AGiottonini committed on
Commit
6569534
·
verified ·
1 Parent(s): eea6f25

Upload bert-distilled-pretrain.py

Browse files
Files changed (1) hide show
  1. bert-distilled-pretrain.py +295 -0
bert-distilled-pretrain.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import os
4
+ import tqdm
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import DataLoader
9
+
10
+ from transformers.models.bert import (
11
+ BertForPreTraining,
12
+ BertTokenizer,
13
+ BertConfig
14
+ )
15
+
16
+ from muon import SingleDeviceMuonWithAuxAdam
17
+
18
class Dataset():
    """Map-style dataset of sequences read from a FASTA file.

    All records are loaded eagerly when the object is created and filtered by
    length; tokenization happens lazily in ``__getitem__``, with one
    whitespace-separated token per sequence character.
    """

    def __init__(self, file_path, tokenizer, min_length=32, max_length=512):
        """Read ``file_path`` and keep the sequences that fit the length window.

        Args:
            file_path: Path to the FASTA file to load.
            tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
                padding="max_length", return_tensors="pt")``.
            min_length: Lower bound of the length window. Two positions are
                subtracted before comparing with the raw sequence length
                (presumably reserved for special tokens -- confirm).
            max_length: Upper bound of the length window (same two-position
                adjustment) and the padded length requested from the tokenizer.
        """
        self.sequences = []

        self.tokenizer = tokenizer

        self.min_length = min_length
        self.max_length = max_length

        self._load_data(file_path)

    @staticmethod
    def _parse_fasta(lines):
        """Return the sequence of every FASTA record found in ``lines``.

        A record starts at a line beginning with ``>``; the following lines,
        up to the next header, are concatenated into its sequence. A header
        with no sequence lines yields an empty string. Lines that appear
        before the first header belong to no record and are skipped instead
        of raising ``IndexError``. Trailing ``\\r``/``\\n`` characters are
        removed so CRLF files do not leak ``\\r`` into the sequences.
        """
        sequences = []
        chunks = None  # pieces of the record being read; None before the first header

        for line in lines:
            line = line.rstrip("\r\n")

            if line.startswith(">"):
                if chunks is not None:
                    sequences.append("".join(chunks))
                chunks = []
            elif chunks is not None:
                chunks.append(line)

        # Flush the last record, which has no following header to close it.
        if chunks is not None:
            sequences.append("".join(chunks))

        return sequences

    def _load_data(self, file_path):
        """Parse ``file_path`` and store the length-filtered sequences."""
        # First pass only counts lines so the progress bar has a total.
        with open(file_path, "rb") as f:
            n_lines = sum(1 for _ in f)

        with open(file_path, "r") as f:
            with tqdm.tqdm(f, total=n_lines) as pbar:
                all_sequences = self._parse_fasta(pbar)

        self.sequences = [
            s for s in all_sequences
            if self.min_length - 2 <= len(s) <= self.max_length - 2
        ]

    def __len__(self):
        """Return the number of sequences kept after filtering."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Tokenize and return the sequence stored at position ``idx``.

        The sequence characters are joined with single spaces before being
        passed to the tokenizer, which is asked to pad to ``max_length`` and
        to return PyTorch tensors.
        """
        sequence = self.sequences[idx]

        tokens = self.tokenizer(
            " ".join(sequence),
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt"
        )

        return tokens
66
+
67
+
68
class DistillationLoss(nn.Module):
    """Weighted sum of a hard cross-entropy loss and a soft KL-divergence loss.

    The result is ``alpha * hard_loss + (1 - alpha) * soft_loss`` where the
    hard loss compares the student logits with the labels and the soft loss
    compares the temperature-scaled student and teacher distributions.
    """

    def __init__(self, alpha: float, temperature: float, num_labels: int, ignore_index: int):
        """Store the hyper-parameters and build the two underlying losses.

        Args:
            alpha: Weight of the hard loss; the soft loss gets ``1 - alpha``.
            temperature: Divisor applied to both logits before the softmax.
            num_labels: Size of the last logits dimension (number of classes).
            ignore_index: Label value excluded from the hard loss.
        """
        super().__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.num_labels = num_labels
        self.ignore_index = ignore_index

        # log_target=True: the teacher side is passed as log-probabilities.
        self.soft_loss_fn = nn.KLDivLoss(reduction="batchmean", log_target=True)
        self.hard_loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, student_logits, teacher_logits, labels):
        """
        Compute the distillation loss.

        Args:
            student_logits (torch.Tensor): Logits from the student model.
            teacher_logits (torch.Tensor): Logits from the teacher model.
            labels (torch.Tensor): Ground truth labels.

        Returns:
            torch.Tensor: The computed distillation loss.
        """
        # The teacher may run in a lower precision than the student; do the
        # softmax and the KL term in the student's dtype.
        teacher_logits = teacher_logits.to(student_logits.dtype)

        # Soft loss
        # NOTE(review): "batchmean" divides the summed KL by the size of the
        # first dimension only, and no temperature**2 factor is applied --
        # confirm this scaling relative to the hard loss is intended.
        soft_loss = self.soft_loss_fn(
            torch.log_softmax(student_logits / self.temperature, dim=-1),
            torch.log_softmax(teacher_logits / self.temperature, dim=-1)
        )

        # Hard loss
        # reshape (not view) so non-contiguous logits/labels are accepted.
        hard_loss = self.hard_loss_fn(
            student_logits.reshape(-1, self.num_labels), labels.reshape(-1)
        )

        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss
103
+
104
+
105
def main(
    model_name: str,
    student_hidden_size: int,
    student_intermediate_size: int,
    student_num_attention_heads: int,
    student_num_hidden_layers: int,
    train_data_path: str,
    batch_size: int,
    epochs: int,
    lr: float,
    default_lr: float,
    teacher_use_bf16: bool=True,
    teacher_use_sdpa: bool=True,
    min_length: int=32,
    max_length: int=512,
    alpha: float=0.1,
    temperature: float=10.0,
    use_muon: bool=True,
    weight_decay: float=0.01,
    betas: Tuple[float, float]=(0.9,0.95),
    default_weight_decay: float=0.01,
    default_betas: Tuple[float, float]=(0.9, 0.95),
    device: torch.device=torch.device("cpu"),
    num_workers: int=256,
    wandb_entity: str="giottonini-axel-unibe"
):
    """Distil a pretrained BERT teacher into a smaller, freshly initialised student.

    Args:
        model_name: Hub name of the teacher; also used for the tokenizer and
            as the base configuration of the student.
        student_hidden_size: ``hidden_size`` override for the student config.
        student_intermediate_size: ``intermediate_size`` override.
        student_num_attention_heads: ``num_attention_heads`` override.
        student_num_hidden_layers: ``num_hidden_layers`` override.
        train_data_path: FASTA file handed to ``Dataset``.
        batch_size: Batch size of the training ``DataLoader``.
        epochs: Number of passes over the dataset.
        lr: Learning rate of the encoder weight matrices (``ndim >= 2``).
        default_lr: Learning rate of every other optimised parameter.
        teacher_use_bf16: Load the teacher with ``torch_dtype=torch.bfloat16``.
        teacher_use_sdpa: Load the teacher with ``attn_implementation="sdpa"``.
        min_length: Forwarded to ``Dataset``.
        max_length: Forwarded to ``Dataset``.
        alpha: Hard-loss weight forwarded to ``DistillationLoss``.
        temperature: Softmax temperature forwarded to ``DistillationLoss``.
        use_muon: ``use_muon`` flag of the encoder-weight parameter group.
        weight_decay: Weight decay of the encoder-weight group.
        betas: Betas of the encoder-weight group, only set when ``use_muon``
            is false.
        default_weight_decay: Weight decay of the remaining parameters.
        default_betas: Betas of the remaining parameters.
        device: Device the models and batches are moved to.
        num_workers: Worker count of the ``DataLoader``.
        wandb_entity: Weights & Biases entity the run is logged under.

    Side effects:
        Starts a Weights & Biases run, logs the loss at every step and writes
        ``<project>/<run name>/<step>/checkpoint.pt`` every 1000 steps.
    """
    import wandb

    # Built without an f-string: reusing double quotes inside a replacement
    # field is a syntax error before Python 3.12.
    wandb_project = model_name.replace("/", "_") + "-distilled"
    wandb_run = wandb.init(
        entity=wandb_entity,
        project=wandb_project,
        config=dict(
            hidden_size=student_hidden_size,
            intermediate_size=student_intermediate_size,
            num_attention_heads=student_num_attention_heads,
            num_hidden_layers=student_num_hidden_layers,
            alpha=alpha,
            temperature=temperature,
            use_muon=use_muon,
            lr=lr,
            weight_decay=weight_decay,
            betas=betas,
            default_lr=default_lr,
            default_weight_decay=default_weight_decay,
            default_betas=default_betas
        )
    )

    # Initialize tokenizer, teacher model and student model
    tokenizer = BertTokenizer.from_pretrained(model_name)

    teacher_model_kwargs = dict()
    if teacher_use_bf16:
        teacher_model_kwargs["torch_dtype"] = torch.bfloat16
    if teacher_use_sdpa:
        teacher_model_kwargs["attn_implementation"] = "sdpa"
    teacher_model = BertForPreTraining.from_pretrained(model_name, **teacher_model_kwargs)
    # NOTE(review): the training loop below calls `teacher_model`, not this
    # compiled wrapper -- confirm which one is meant to run.
    teacher_model_compiled = torch.compile(teacher_model, mode="max-autotune", fullgraph=True)

    # Base the student on the teacher's own configuration (previously
    # hard-coded to "Rostlab/prot_bert") so both share the same vocabulary.
    student_config = BertConfig.from_pretrained(
        model_name,
        hidden_size=student_hidden_size,
        intermediate_size=student_intermediate_size,
        num_attention_heads=student_num_attention_heads,
        num_hidden_layers=student_num_hidden_layers
    )
    student_model = BertForPreTraining(student_config)

    teacher_model_compiled.to(device) # type: ignore
    student_model.to(device) # type: ignore

    # Load dataset
    dataset = Dataset(train_data_path, tokenizer, min_length=min_length, max_length=max_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) # type: ignore

    # Loss function
    # Padding positions are identified by the pad *token* id; the previously
    # used pad_token_type_id is a segment id, not a vocabulary index.
    loss_fn = DistillationLoss(alpha, temperature, tokenizer.vocab_size, tokenizer.pad_token_id)

    # Initialize optimizer
    encoder_params = list(student_model.bert.encoder.parameters())
    hidden_weights = [p for p in encoder_params if p.ndim >= 2]
    hidden_biases = [p for p in encoder_params if p.ndim < 2]

    # The embeddings and the prediction head can share tensors (tied
    # input/output embeddings); keep each tensor once so it is not updated
    # twice per optimizer step.
    nonhidden_params = []
    seen_param_ids = set()
    for param in (*student_model.bert.embeddings.parameters(), *student_model.cls.parameters()):
        if id(param) not in seen_param_ids:
            seen_param_ids.add(id(param))
            nonhidden_params.append(param)

    hidden_param_group = dict(
        params=hidden_weights,
        use_muon=use_muon,
        lr=lr,
        weight_decay=weight_decay
    )
    if not use_muon:
        hidden_param_group["betas"] = betas # type: ignore

    default_param_group = dict(
        params=hidden_biases + nonhidden_params,
        use_muon=False,
        lr=default_lr,
        betas=default_betas,
        weight_decay=default_weight_decay
    )

    optimizer = SingleDeviceMuonWithAuxAdam([hidden_param_group, default_param_group])

    # Training loop
    wandb_run.watch(student_model)
    step = 0
    for epoch in range(epochs):
        with tqdm.tqdm(dataloader) as pbar:
            for batch in pbar:
                # Clear optimizer and model gradients
                optimizer.zero_grad()
                student_model.zero_grad()

                # Send the data to the device; squeeze(1) drops the extra
                # dimension each item carries (presumably the singleton batch
                # axis added by return_tensors="pt" -- confirm).
                batch = {k : v.squeeze(1).to(device) for k, v in batch.items()}

                # Compute teacher logits
                with torch.no_grad():
                    teacher_logits = teacher_model(
                        input_ids=batch["input_ids"],
                        attention_mask=batch["attention_mask"]
                    ).prediction_logits

                # Compute student logits
                student_logits = student_model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"]
                ).prediction_logits

                # Loss backpropagation and optimization step; the input ids
                # themselves serve as the hard labels.
                loss = loss_fn(student_logits, teacher_logits, batch["input_ids"])
                loss.backward()
                optimizer.step()

                step += 1
                pbar.set_description(f"Epoch {epoch} | Loss: {loss.item():.4f}")
                wandb_run.log(dict(loss=loss.item()))

                # Save checkpoint
                if step % 1000 == 0:
                    checkpoint = dict(
                        state_dict=student_model.state_dict(),
                        optimizer=optimizer.state_dict(),
                    )

                    checkpoint_dir = os.path.join(wandb_run.project, wandb_run.name, str(step)) # type: ignore
                    # exist_ok: do not abort training if the directory is already there.
                    os.makedirs(checkpoint_dir, exist_ok=True)
                    torch.save(checkpoint, os.path.join(checkpoint_dir, "checkpoint.pt"))
262
+
263
+
264
if __name__ == "__main__":
    # Command-line entry point: parse the options, resolve the device and
    # start the distillation run.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="Rostlab/prot_bert", help="Name of the teacher model")
    parser.add_argument("--student_hidden_size", type=int, default=16, help="Hidden size of the student model")
    parser.add_argument("--student_intermediate_size", type=int, default=64, help="Intermediate size of the student model")
    parser.add_argument("--student_num_attention_heads", type=int, default=4, help="Number of attention heads of the student model")
    parser.add_argument("--student_num_hidden_layers", type=int, default=12, help="Number of hidden layers of the student model")
    # Required: there is no default, and a missing path would only fail
    # later when the dataset tries to open it.
    parser.add_argument("--train_data_path", type=str, required=True, help="Path to the training data (fasta file)")
    parser.add_argument("--batch_size", type=int, default=1024, help="Batch size for training")
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs for training")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate for the hidden parameters")
    parser.add_argument("--default_lr", type=float, default=1e-4, help="Learning rate for the non-hidden parameters and biases")
    parser.add_argument("--device", type=int, default=-1, help="GPU device to use (-1 for CPU)")
    args = vars(parser.parse_args())

    # Use the requested GPU only when CUDA is available and the index is
    # non-negative; otherwise run on the CPU.
    gpu_index = int(args["device"])
    use_gpu = torch.cuda.is_available() and gpu_index >= 0
    args["device"] = torch.device(f"cuda:{gpu_index}" if use_gpu else "cpu")

    # Every argparse destination matches a keyword parameter of main().
    main(**args)