NotoriousH2 committed on
Commit
10c8d20
·
verified ·
1 Parent(s): 1dcce72

Add train_rs_sft.py

Browse files
Files changed (1) hide show
  1. train_rs_sft.py +143 -0
train_rs_sft.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """C20: Variants of C18-2 (the 48.5% recipe) with different replay ratios"""
2
+ import json, re, random, torch, numpy as np, os
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from trl import SFTTrainer, SFTConfig
5
+ from datasets import Dataset
6
+
7
# Single seed shared by every RNG the script touches so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Only query the device once CUDA is known to be present; on devices with
    # compute-capability major version >= 8, select the 'high' float32 matmul
    # precision setting.
    if torch.cuda.get_device_capability()[0] >= 8:
        torch.set_float32_matmul_precision('high')

# Instruction (Korean) prepended to every question. In English: "Solve the
# given math problem step by step and write the answer. Always print the final
# answer on the last line in \boxed{integer} form. Example: \boxed{42}".
SP = (
    "주어진 수학 문제를 단계별로 풀고 답변을 작성하세요.\n"
    "반드시 최종 답변을 \\boxed{정수} 형식으로 마지막 줄에 출력하세요.\n"
    "예시: \\boxed{42}"
)

# Checkpoint every condition starts from.
BASE = "outputs/models/c17d-gemma-3-1b-it-Math"
15
+
16
def _load_json(path):
    """Read *path* as UTF-8 text and return the parsed JSON value.

    The encoding is given explicitly so that non-ASCII text in the datasets
    decodes the same way regardless of the platform's locale default.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)


# A sample is kept only when its "n_correct" field is at least this value.
# NOTE(review): presumably the number of correct rollouts for the question
# during rejection sampling - confirm against the script that writes the file.
MIN_CORRECT = 4
# Teacher answers longer than this many characters are dropped.
MAX_ANSWER_CHARS = 1500

# Load RS1+RS2 (the winning combo)
rs1 = _load_json("outputs/c17d_rs/sft_dataset.json")
rs2 = _load_json("outputs/c17d_rs2/sft_dataset.json")

# Merge both rounds, keeping the first occurrence of each (question, answer)
# pair; the same question may still appear with several distinct answers.
seen = set()
rs_combined = []
for d in rs1 + rs2:
    if d["n_correct"] < MIN_CORRECT:
        continue
    key = (d["question"], d["answer"])
    if key not in seen:
        seen.add(key)
        rs_combined.append({"question": d["question"], "answer": d["answer"], "source": "gsm8k"})
print(f"RS1+RS2 combined: {len(rs_combined)}")

# Original teacher-generated data, used later as the replay pool.
orig_data = _load_json("data/GSM8K_full_qwen3_30b.json")
orig_filtered = [d for d in orig_data if len(d["answer"]) <= MAX_ANSWER_CHARS]
35
+
36
def to_sft(ex):
    """Turn one question/answer record into a prompt/completion chat pair.

    Args:
        ex: Mapping with string fields "question" and "answer".

    Returns:
        A dict with two single-message lists: "prompt" holds a user message
        made of the module-level instruction ``SP``, a blank line and the
        question; "completion" holds an assistant message with the answer.
    """
    user_turn = {"role": "user", "content": "\n\n".join((SP, ex["question"]))}
    assistant_turn = {"role": "assistant", "content": ex["answer"]}
    return {"prompt": [user_turn], "completion": [assistant_turn]}
39
+
40
# Questions already present in the rejection-sampled set; replay samples are
# drawn only from teacher data whose question is NOT in this set.
rs_qs = {d["question"] for d in rs_combined}

SAVE1 = "outputs/models/c20-1-2x-replay"
SAVE2 = "outputs/models/c20-2-5x-replay"
SAVE3 = "outputs/models/c20-3-lr3e-6"

# One row per experiment:
# (log header, replay multiplier, learning rate, trainer output dir, save dir)
CONDITIONS = [
    # Condition 1: RS1+RS2 + 2x replay (more aggressive RS)
    ("C20-1: RS1+RS2 + 2x replay", 2, 2e-6, "outputs/c20_1_ckpt", SAVE1),
    # Condition 2: RS1+RS2 + 5x replay (more teacher data)
    ("C20-2: RS1+RS2 + 5x replay", 5, 2e-6, "outputs/c20_2_ckpt", SAVE2),
    # Condition 3: RS1+RS2 + 3x replay + lr=3e-6 (higher lr)
    ("C20-3: RS1+RS2 + 3x replay + lr=3e-6", 3, 3e-6, "outputs/c20_3_ckpt", SAVE3),
]


def _build_mix(replay_ratio):
    """Return ``(mixed, n_replay)`` for one condition.

    ``mixed`` is every sample of ``rs_combined`` plus up to
    ``len(rs_combined) * replay_ratio`` replay samples, shuffled together.
    ``n_replay`` is the number of replay samples actually taken, which is
    smaller than requested when the replay pool is too small.

    The RNG is re-seeded on every call, so each condition sees the same
    shuffled replay order and differs only in how long a prefix it keeps.
    """
    random.seed(SEED)
    replay = [d for d in orig_filtered if d["question"] not in rs_qs]
    random.shuffle(replay)
    replay = replay[:int(len(rs_combined) * replay_ratio)]
    mixed = rs_combined + replay
    random.shuffle(mixed)
    return mixed, len(replay)


def _run_condition(header, replay_ratio, learning_rate, ckpt_dir, save_dir):
    """Fine-tune a fresh copy of ``BASE`` on one data mix and save the result.

    Args:
        header: Text printed between ``===`` markers when the run starts.
        replay_ratio: Replay samples requested per rejection-sampled sample.
        learning_rate: Peak learning rate passed to ``SFTConfig``.
        ckpt_dir: ``output_dir`` handed to the trainer.
        save_dir: Directory that receives the final model and tokenizer.
    """
    import gc  # local import: only needed for the cleanup step below

    print(f"\n=== {header} ===")
    mixed, n_replay = _build_mix(replay_ratio)
    print(f" RS: {len(rs_combined)} + replay: {n_replay} = {len(mixed)}")

    # Convert to prompt/completion format and drop every source column.
    ds = Dataset.from_list(mixed)
    cols = [c for c in ds.column_names if c not in ["prompt", "completion"]]
    ds = ds.map(to_sft, remove_columns=cols)

    # Every condition starts again from the same base checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(BASE)
    model = AutoModelForCausalLM.from_pretrained(
        BASE, dtype=torch.bfloat16, device_map="auto",
        attn_implementation='flash_attention_2')
    tokenizer.pad_token = tokenizer.eos_token
    model.gradient_checkpointing_enable()
    model.config.use_cache = False

    cfg = SFTConfig(report_to='none', seed=SEED, num_train_epochs=1, warmup_ratio=0.05,
                    weight_decay=0.01, max_grad_norm=1.0, per_device_train_batch_size=8,
                    gradient_accumulation_steps=4, max_length=2048, lr_scheduler_type='cosine',
                    learning_rate=learning_rate, bf16=True, optim="paged_adamw_8bit",
                    output_dir=ckpt_dir, logging_steps=25, save_strategy="no")
    trainer = SFTTrainer(model=model, processing_class=tokenizer, train_dataset=ds, args=cfg)
    r = trainer.train()
    print(f" Loss: {r.training_loss:.4f}")

    os.makedirs(save_dir, exist_ok=True)
    model.eval()
    model.save_pretrained(save_dir, safe_serialization=False)
    tokenizer.save_pretrained(save_dir)

    # Drop the references and collect cycles first, so that emptying the CUDA
    # cache can actually return the memory before the next model is loaded.
    del model, trainer
    gc.collect()
    torch.cuda.empty_cache()


for _header, _ratio, _lr, _ckpt_dir, _save_dir in CONDITIONS:
    _run_condition(_header, _ratio, _lr, _ckpt_dir, _save_dir)

print("\n=== All conditions complete ===")