lvwerra committed
Commit 27393ce
1 Parent(s): 5263940

Create codeparrot_training.py

Files changed (1)
  1. codeparrot_training.py +200 -0
codeparrot_training.py ADDED
@@ -0,0 +1,200 @@
+from transformers import GPT2LMHeadModel, AutoTokenizer
+from transformers import AdamW, get_scheduler, set_seed
+from datasets import load_dataset
+from accelerate import Accelerator
+import datasets, transformers
+from huggingface_hub import Repository
+
+from torch.utils.data import IterableDataset
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from argparse import Namespace
+import torch
+import logging
+import wandb
+
+class ConstantLengthDataset(IterableDataset):
+
+    def __init__(self, tokenizer, dataset, seq_length=1024,
+                 num_of_sequences=1024, chars_per_token=3.6):
+        self.tokenizer = tokenizer
+        self.concat_token_id = tokenizer.bos_token_id  # separator inserted between concatenated documents
+        self.dataset = dataset
+        self.seq_length = seq_length
+        self.input_characters = seq_length * chars_per_token * num_of_sequences
+
+    def __iter__(self):
+        iterator = iter(self.dataset)
+        more_examples = True
+        while more_examples:
+            buffer, buffer_len = [], 0
+            while True:
+                if buffer_len >= self.input_characters:
+                    break
+                try:
+                    buffer.append(next(iterator)['content'])
+                    buffer_len += len(buffer[-1])
+                except StopIteration:
+                    iterator = iter(self.dataset)  # restart the stream once it is exhausted
+            tokenized_inputs = self.tokenizer(buffer, truncation=False)['input_ids']
+            all_token_ids = []
+            for tokenized_input in tokenized_inputs:
+                all_token_ids.extend(tokenized_input + [self.concat_token_id])
+            for i in range(0, len(all_token_ids), self.seq_length):
+                input_ids = all_token_ids[i : i + self.seq_length]
+                if len(input_ids) == self.seq_length:
+                    yield torch.tensor(input_ids)
+
+def setup_logging(project_name):
+    logger = logging.getLogger(__name__)
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, handlers=[
+            logging.FileHandler(f"log/debug_{accelerator.process_index}.log"),
+            logging.StreamHandler()])
+    if accelerator.is_main_process:  # we only want to set up logging once
+        wandb.init(project=project_name, config=args)
+        run_name = wandb.run.name
+        tb_writer = SummaryWriter()
+        tb_writer.add_hparams(vars(args), {'0': 0})
+        logger.setLevel(logging.INFO)
+        datasets.utils.logging.set_verbosity_info()
+        transformers.utils.logging.set_verbosity_info()
+    else:
+        tb_writer = None
+        run_name = ''
+        logger.setLevel(logging.ERROR)
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+    return logger, tb_writer, run_name
+
+def create_dataloaders(dataset_name, args):
+    ds_kwargs = {"streaming": True}
+    train_data = load_dataset(dataset_name + '-train', split='train', **ds_kwargs)
+    train_data = train_data.shuffle(buffer_size=args.shuffle_buffer,
+                                    seed=args.seed)
+    valid_data = load_dataset(dataset_name + '-valid', split="train", **ds_kwargs)
+    train_dataset = ConstantLengthDataset(tokenizer, train_data,
+                                          seq_length=args.seq_length)
+    valid_dataset = ConstantLengthDataset(tokenizer, valid_data,
+                                          seq_length=args.seq_length)
+    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
+    eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
+    return train_dataloader, eval_dataloader
+
+def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
+    params_with_wd, params_without_wd = [], []
+    for n, p in model.named_parameters():
+        if any(nd in n for nd in no_decay): params_without_wd.append(p)
+        else: params_with_wd.append(p)
+    return [{'params': params_with_wd, 'weight_decay': args.weight_decay},
+            {'params': params_without_wd, 'weight_decay': 0.0}]
+
+def log_metrics(step, metrics):
+    logger.info(f"Step {step}: {metrics}")
+    if accelerator.is_main_process:
+        wandb.log(metrics)
+        [tb_writer.add_scalar(k, v, step) for k, v in metrics.items()]
+
+def evaluate(args):
+    model.eval()
+    losses = []
+    for step, batch in enumerate(eval_dataloader):
+        with torch.no_grad():
+            outputs = model(batch, labels=batch)
+        loss = outputs.loss.repeat(args.valid_batch_size)  # repeat so gather returns one loss per sample
+        losses.append(accelerator.gather(loss))
+        if args.max_eval_steps > 0 and step >= args.max_eval_steps: break
+    loss = torch.mean(torch.cat(losses))
+    try: perplexity = torch.exp(loss)
+    except OverflowError: perplexity = float("inf")
+    return loss.item(), perplexity.item()
+
+# Accelerator
+accelerator = Accelerator()
+acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
+
+# Hyperparameters
+project_name = 'lvwerra/codeparrot-small'
+dataset_name = '../codeparrot-clean'
+config = {"train_batch_size": 12,
+          "valid_batch_size": 12,
+          "weight_decay": 0.1,
+          "shuffle_buffer": 1_000,
+          "learning_rate": 5e-4,
+          "lr_scheduler_type": "cosine",
+          "num_warmup_steps": 2_000,
+          "gradient_accumulation_steps": 1,
+          "gradient_checkpointing": False,
+          "max_train_steps": 150_000,
+          "max_eval_steps": -1,
+          "seq_length": 1024,
+          "seed": 1,
+          "save_checkpoint_steps": 15_000}
+args = Namespace(**config, **acc_state)
+samples_per_step = accelerator.state.num_processes * args.train_batch_size
+set_seed(args.seed)
+
+# Logging
+logger, tb_writer, run_name = setup_logging(project_name.split("/")[1])
+logger.info(accelerator.state)
+
+# Load model and tokenizer
+if accelerator.is_main_process:
+    hf_repo = Repository("./", clone_from=project_name, revision=run_name)
+model = GPT2LMHeadModel.from_pretrained("./")
+if args.gradient_checkpointing:
+    model.gradient_checkpointing_enable()
+tokenizer = AutoTokenizer.from_pretrained("./")
+
+# Load dataset and dataloader
+train_dataloader, eval_dataloader = create_dataloaders(dataset_name, args)
+
+# Prepare the optimizer and learning rate scheduler
+optimizer = AdamW(get_grouped_params(model, args), lr=args.learning_rate)
+lr_scheduler = get_scheduler(name=args.lr_scheduler_type, optimizer=optimizer,
+                             num_warmup_steps=args.num_warmup_steps,
+                             num_training_steps=args.max_train_steps,)
+def get_lr(): return optimizer.param_groups[0]['lr']
+
+# Prepare everything with our `accelerator`.
+model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
+    model, optimizer, train_dataloader, eval_dataloader)
+
+# Train model
+model.train()
+completed_steps = 0
+for step, batch in enumerate(train_dataloader, start=1):
+    loss = model(batch, labels=batch, use_cache=False).loss
+    log_metrics(step, {'lr': get_lr(), 'samples': step * samples_per_step,
+                       'steps': completed_steps, 'loss/train': loss.item()})
+    loss = loss / args.gradient_accumulation_steps
+    accelerator.backward(loss)
+    if step % args.gradient_accumulation_steps == 0:
+        accelerator.clip_grad_norm_(model.parameters(), 1.0)
+        optimizer.step()
+        lr_scheduler.step()
+        optimizer.zero_grad()
+        completed_steps += 1
+    if step % args.save_checkpoint_steps == 0:
+        logger.info('Evaluating and saving model checkpoint')
+        eval_loss, perplexity = evaluate(args)
+        log_metrics(step, {'loss/eval': eval_loss, 'perplexity': perplexity})
+        accelerator.wait_for_everyone()
+        unwrapped_model = accelerator.unwrap_model(model)
+        unwrapped_model.save_pretrained("./", save_function=accelerator.save)
+        if accelerator.is_main_process:
+            hf_repo.push_to_hub(commit_message=f'step {step}')
+        model.train()
+    if completed_steps >= args.max_train_steps:
+        break
+
+# Evaluate and save the last checkpoint
+logger.info('Evaluating and saving model after training')
+eval_loss, perplexity = evaluate(args)
+log_metrics(step, {'loss/eval': eval_loss, 'perplexity': perplexity})
+accelerator.wait_for_everyone()
+unwrapped_model = accelerator.unwrap_model(model)
+unwrapped_model.save_pretrained("./", save_function=accelerator.save)
+if accelerator.is_main_process:
+    hf_repo.push_to_hub(commit_message='final model')
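
For context only (not part of the committed file): a minimal sketch of how the ConstantLengthDataset above packs streamed text into fixed-length training examples. It assumes it runs in the same module as the class; the toy in-memory corpus and the use of the public gpt2 tokenizer are illustrative assumptions, not part of the training setup.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer with a bos_token_id works
# Hypothetical stand-in for the streamed dataset: dicts with a 'content' field, as in codeparrot-clean.
toy_stream = [{"content": "def add(a, b):\n    return a + b\n" * 40} for _ in range(8)]

packed = ConstantLengthDataset(tokenizer, toy_stream, seq_length=128,
                               num_of_sequences=2, chars_per_token=3.6)
for i, sample in enumerate(packed):
    print(sample.shape)  # torch.Size([128]): every example has exactly seq_length tokens
    if i == 2:  # the dataset restarts its source and yields indefinitely, so stop explicitly
        break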