NotoriousH2 committed on
Commit
1c4102a
·
verified ·
1 Parent(s): 8e96ac1

Add train_sft.py

Browse files
Files changed (1) hide show
  1. train_sft.py +67 -0
train_sft.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""C17d: all solutions + length filter (1500 characters or fewer only) + NaN prevention."""
import json
import os
import random
import re
from collections import defaultdict

import numpy as np
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    EarlyStoppingCallback,
)
from trl import SFTConfig, SFTTrainer

# Seed every random number generator the script uses so that shuffling and
# training start from the same state on each run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # The device capability is only queried once CUDA is known to be
    # available; when its first component is at least 8, float32 matmul
    # precision is switched to 'high'.
    if torch.cuda.get_device_capability()[0] >= 8:
        torch.set_float32_matmul_precision("high")

# Instruction (Korean) that to_sft prepends to every question.
# English gloss: "Solve the given math problem step by step and write your
# answer. Be sure to output the final answer on the last line in the form
# \boxed{integer}. Example: \boxed{42}"
SP = (
    "주어진 수학 문제를 단계별로 풀고 답변을 작성하세요.\n"
    "반드시 최종 답변을 \\boxed{정수} 형식으로 마지막 줄에 출력하세요.\n"
    "예시: \\boxed{42}"
)

print("=== C17d: All solutions, length-filtered (≤1500 chars) ===")

# Load the generated solutions. The encoding is given explicitly because the
# default text encoding of open() depends on the platform locale.
with open("data/GSM8K_full_qwen3_30b.json", encoding="utf-8") as f:
    data = json.load(f)

# Length filter: keep only records whose answer is at most 1500 characters.
filtered = [d for d in data if len(d["answer"]) <= 1500]
# Log gloss: "original: N -> after filter: M (removed: N-M)".
print(
    f"원본: {len(data)}개 → 필터 후: {len(filtered)}개 "
    f"(제거: {len(data) - len(filtered)})"
)
if not filtered:
    # Without this guard the per-question average below divides by zero.
    raise ValueError("no records left after the 1500-character length filter")

random.shuffle(filtered)

# Group the records by question text so the split can be made per question.
by_question = defaultdict(list)
for d in filtered:
    by_question[d["question"]].append(d)
uq = len(by_question)
print(f"Unique: {uq}, avg {len(filtered)/uq:.1f}/q")

# Hold out about 5% of the *questions* rather than 5% of the records. A
# question can have several solutions; splitting per record would put the
# same question on both sides and make the validation loss (which drives
# early stopping and best-checkpoint selection) look better than it is.
questions = list(by_question)
random.shuffle(questions)
split = int(len(questions) * 0.95)
train = [d for q in questions[:split] for d in by_question[q]]
test = [d for q in questions[split:] for d in by_question[q]]
# Undo the per-question grouping so the training list is not ordered by question.
random.shuffle(train)


def to_sft(ex):
    """Turn one raw record into a prompt/completion pair of chat messages.

    The user message is the shared instruction ``SP``, a blank line, and the
    record's ``"question"``; the assistant message is the record's
    ``"answer"``.
    """
    user_message = {"role": "user", "content": "\n\n".join((SP, ex["question"]))}
    assistant_message = {"role": "assistant", "content": ex["answer"]}
    return {"prompt": [user_message], "completion": [assistant_message]}

# Build the datasets and replace every raw column with the "prompt" /
# "completion" pair produced by to_sft. The columns to drop are read from each
# dataset itself instead of from a throwaway one-row dataset.
_sft_columns = ("prompt", "completion")
_raw_train = Dataset.from_list(train)
_raw_test = Dataset.from_list(test)
cols = [c for c in _raw_train.column_names if c not in _sft_columns]
train_ds = _raw_train.map(to_sft, remove_columns=cols)
test_ds = _raw_test.map(
    to_sft,
    remove_columns=[c for c in _raw_test.column_names if c not in _sft_columns],
)
# Log gloss: "train: N / validation: M".
print(f"학습: {len(train_ds)} / 검증: {len(test_ds)}")

# Tokenizer and model are both loaded from the same local checkpoint directory.
base_model_dir = "outputs/models/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(base_model_dir)
model = AutoModelForCausalLM.from_pretrained(
    base_model_dir,
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)
# Padding reuses the end-of-sequence token.
# NOTE(review): this assignment is unconditional, so it replaces any pad token
# the checkpoint's tokenizer already defines — confirm that is intended.
tokenizer.pad_token = tokenizer.eos_token
# Gradient checkpointing is turned on and the generation cache turned off
# before training.
model.gradient_checkpointing_enable()
model.config.use_cache = False

# Training configuration, grouped by concern. Values are unchanged.
cfg = SFTConfig(
    # Run bookkeeping.
    output_dir="outputs/c17d_checkpoints",
    report_to="none",
    seed=SEED,
    logging_steps=50,
    # Evaluation and checkpointing share a 200-step cadence; the checkpoint
    # with the best eval_loss is reloaded when training ends.
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    # Batching and sequence length.
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=2,
    max_length=2048,
    # Optimisation schedule.
    num_train_epochs=3,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    weight_decay=0.01,
    max_grad_norm=1.0,
    optim="paged_adamw_8bit",
    bf16=True,
    neftune_noise_alpha=5,
)

# Wire model, tokenizer, datasets and config into the trainer. Early stopping
# is configured with a patience of 3.
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    args=cfg,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
# Log gloss: "training start (3 epochs, all solutions, <=1500 chars)".
print("학습 시작 (3 epochs, 모든 풀이, ≤1500자)")
r = trainer.train()
# Log gloss: "done!" followed by the training loss reported by train().
print(f"완료! Loss: {r.training_loss:.4f}")

# Write the trained model and its tokenizer to the output directory.
SAVE = "outputs/models/c17d-gemma-3-1b-it-Math"
os.makedirs(SAVE, exist_ok=True)
model.eval()
# NOTE(review): safe_serialization=False is passed explicitly — presumably to
# write the legacy PyTorch weight format instead of safetensors; confirm this
# is still wanted.
model.save_pretrained(SAVE, safe_serialization=False)
tokenizer.save_pretrained(SAVE)
# Log gloss: "saved: <path>".
print(f"저장: {SAVE}")

# Drop the references to the model and trainer, then release cached GPU memory.
del model, trainer
torch.cuda.empty_cache()
# Log gloss: "GPU released".
print("GPU 해제")