w11wo committed on
Commit f5e3080
1 Parent(s): f1dd461

config and tokenizer

Files changed (7)
  1. .gitattributes +1 -0
  2. config.json +36 -0
  3. create_config.py +9 -0
  4. run.sh +23 -0
  5. run_clm_flax.py +750 -0
  6. tokenizer.json +0 -0
  7. train_tokenizer.py +49 -0
.gitattributes CHANGED
@@ -15,3 +15,4 @@
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+nohup.out filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,36 @@
+{
+  "activation_function": "gelu_new",
+  "architectures": [
+    "GPT2LMHeadModel"
+  ],
+  "attn_pdrop": 0.0,
+  "bos_token_id": 50256,
+  "embd_pdrop": 0.0,
+  "eos_token_id": 50256,
+  "gradient_checkpointing": false,
+  "initializer_range": 0.02,
+  "layer_norm_epsilon": 1e-05,
+  "model_type": "gpt2",
+  "n_ctx": 1024,
+  "n_embd": 768,
+  "n_head": 12,
+  "n_inner": null,
+  "n_layer": 12,
+  "n_positions": 1024,
+  "resid_pdrop": 0.0,
+  "scale_attn_weights": true,
+  "summary_activation": null,
+  "summary_first_dropout": 0.1,
+  "summary_proj_to_labels": true,
+  "summary_type": "cls_index",
+  "summary_use_proj": true,
+  "task_specific_params": {
+    "text-generation": {
+      "do_sample": true,
+      "max_length": 50
+    }
+  },
+  "transformers_version": "4.9.0.dev0",
+  "use_cache": true,
+  "vocab_size": 50257
+}
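
The config above is the stock GPT-2 small layout (12 layers, 12 heads, 768-dim embeddings, 50,257-token vocabulary) with every dropout probability set to 0.0. As a rough sketch, not part of this commit, the file could be loaded like this to build a randomly initialized Flax model; the "./" path is an assumption that the repo is the working directory:

from transformers import GPT2Config, FlaxGPT2LMHeadModel

# Read the committed config.json from the repo root (path is an assumption).
config = GPT2Config.from_pretrained("./")

# Randomly initialized GPT-2 small; no trained weights are involved yet.
model = FlaxGPT2LMHeadModel(config, seed=42)
print(config.n_layer, config.n_head, config.n_embd)  # 12 12 768
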
create_config.py ADDED
@@ -0,0 +1,9 @@
+from transformers import GPT2Config
+
+model_dir = "./"
+
+config = GPT2Config.from_pretrained(
+    "gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0
+)
+config.save_pretrained(model_dir)
+
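
A quick sanity check, not part of the commit, that the dropout overrides above actually end up in the saved config.json (the "./" path mirrors model_dir in the script):

from transformers import GPT2Config

cfg = GPT2Config.from_pretrained("./")  # reload the file written by create_config.py
assert cfg.resid_pdrop == cfg.embd_pdrop == cfg.attn_pdrop == 0.0
print(cfg.vocab_size)  # 50257, inherited from the base "gpt2" config
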
run.sh ADDED
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+python3 run_clm_flax.py \
+    --model_name_or_path="flax_model.msgpack" \
+    --output_dir="./" \
+    --model_type="gpt2" \
+    --config_name="./" \
+    --tokenizer_name="./" \
+    --dataset_language="su" \
+    --do_train \
+    --do_eval \
+    --block_size="512" \
+    --preprocessing_num_workers="64" \
+    --weight_decay="0.0" \
+    --per_device_train_batch_size="64" \
+    --per_device_eval_batch_size="64" \
+    --learning_rate="2e-4" \
+    --warmup_steps="1000" \
+    --overwrite_output_dir \
+    --num_train_epochs="50" \
+    --adam_beta1="0.9" \
+    --adam_beta2="0.999" \
+    --adam_epsilon="1e-8" \
+    --push_to_hub
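
run_clm_flax.py multiplies --per_device_train_batch_size by jax.device_count(), so on an 8-device TPU host (an assumption; the hardware is not stated in the commit) the flags above give 64 × 8 = 512 sequences of 512 tokens, i.e. 262,144 tokens per optimizer step. A hypothetical pre-flight check along those lines:

import jax

per_device_batch = 64  # matches --per_device_train_batch_size above
block_size = 512       # matches --block_size above

n_devices = jax.device_count()
print(f"devices: {n_devices} ({jax.devices()[0].platform})")
print(f"effective batch: {per_device_batch * n_devices} sequences, "
      f"{per_device_batch * n_devices * block_size} tokens per step")
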
run_clm_flax.py ADDED
@@ -0,0 +1,750 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2021 The HuggingFace Team All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
+
+Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
+https://huggingface.co/models?filter=causal-lm
+"""
+# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
+
+import logging
+import math
+import os
+import sys
+import time
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Callable, Optional
+
+import datasets
+from datasets import Dataset, load_dataset, concatenate_datasets
+from tqdm import tqdm
+
+import jax
+import jax.numpy as jnp
+import optax
+import transformers
+from flax import jax_utils, traverse_util
+from flax.jax_utils import unreplicate
+from flax.training import train_state
+from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+from transformers import (
+    CONFIG_MAPPING,
+    FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
+    AutoConfig,
+    AutoTokenizer,
+    FlaxAutoModelForCausalLM,
+    HfArgumentParser,
+    TrainingArguments,
+    is_tensorboard_available,
+)
+from transformers.testing_utils import CaptureLogger
+
+
+logger = logging.getLogger(__name__)
+
+# Cache the result
+has_tensorboard = is_tensorboard_available()
+if has_tensorboard:
+    try:
+        from flax.metrics.tensorboard import SummaryWriter
+    except ImportError as ie:
+        has_tensorboard = False
+        print(
+            f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+        )
+
+else:
+    print(
+        "Unable to display metrics through TensorBoard because the package is not installed: "
+        "Please run pip install tensorboard to enable."
+    )
+
+
+MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+    """
+
+    model_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The model checkpoint for weights initialization."
+            "Don't set if you want to train a model from scratch."
+        },
+    )
+    model_type: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "If training from scratch, pass a model type from the list: "
+            + ", ".join(MODEL_TYPES)
+        },
+    )
+    config_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Pretrained config name or path if not the same as model_name"
+        },
+    )
+    tokenizer_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Pretrained tokenizer name or path if not the same as model_name"
+        },
+    )
+    cache_dir: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Where do you want to store the pretrained models downloaded from s3"
+        },
+    )
+    use_fast_tokenizer: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
+        },
+    )
+    dtype: Optional[str] = field(
+        default="float32",
+        metadata={
+            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+        },
+    )
+
+
+@dataclass
+class DataTrainingArguments:
+    """
+    Arguments pertaining to what data we are going to input our model for training and eval.
+    """
+
+    dataset_language: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The language of the OSCAR, MC4, CC100 dataset to use (via the datasets library)."
+        },
+    )
+    # dataset_name: Optional[str] = field(
+    #     default=None,
+    #     metadata={"help": "The name of the dataset to use (via the datasets library)."},
+    # )
+    # dataset_config_name: Optional[str] = field(
+    #     default=None,
+    #     metadata={
+    #         "help": "The configuration name of the dataset to use (via the datasets library)."
+    #     },
+    # )
+    train_file: Optional[str] = field(
+        default=None, metadata={"help": "The input training data file (a text file)."}
+    )
+    validation_file: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
+        },
+    )
+    max_train_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+            "value if set."
+        },
+    )
+    max_eval_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+            "value if set."
+        },
+    )
+    overwrite_cache: bool = field(
+        default=False,
+        metadata={"help": "Overwrite the cached training and evaluation sets"},
+    )
+    validation_split_percentage: Optional[int] = field(
+        default=5,
+        metadata={
+            "help": "The percentage of the train set used as validation set in case there's no validation split"
+        },
+    )
+    block_size: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Optional input sequence length after tokenization. "
+            "The training dataset will be truncated in block of this size for training. "
+            "Default to the model max input length for single sentence inputs (take into account special tokens)."
+        },
+    )
+    overwrite_cache: bool = field(
+        default=False,
+        metadata={"help": "Overwrite the cached training and evaluation sets"},
+    )
+    preprocessing_num_workers: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of processes to use for the preprocessing."},
+    )
+
+    def __post_init__(self):
+        if (
+            self.dataset_language is None
+            and self.train_file is None
+            and self.validation_file is None
+        ):
+            raise ValueError(
+                "Need either a dataset name or a training/validation file."
+            )
+        else:
+            if self.train_file is not None:
+                extension = self.train_file.split(".")[-1]
+                assert extension in [
+                    "csv",
+                    "json",
+                    "txt",
+                ], "`train_file` should be a csv, a json or a txt file."
+            if self.validation_file is not None:
+                extension = self.validation_file.split(".")[-1]
+                assert extension in [
+                    "csv",
+                    "json",
+                    "txt",
+                ], "`validation_file` should be a csv, a json or a txt file."
+
+
+class TrainState(train_state.TrainState):
+    dropout_rng: jnp.ndarray
+
+    def replicate(self):
+        return jax_utils.replicate(self).replace(
+            dropout_rng=shard_prng_key(self.dropout_rng)
+        )
+
+
+def data_loader(
+    rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False
+):
+    """
+    Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
+    Shuffle batches if `shuffle` is `True`.
+    """
+    steps_per_epoch = len(dataset) // batch_size
+
+    if shuffle:
+        batch_idx = jax.random.permutation(rng, len(dataset))
+    else:
+        batch_idx = jnp.arange(len(dataset))
+
+    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
+    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+
+    for idx in batch_idx:
+        batch = dataset[idx]
+        batch = {k: jnp.array(v) for k, v in batch.items()}
+
+        batch = shard(batch)
+
+        yield batch
+
+
+def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+    summary_writer.scalar("train_time", train_time, step)
+
+    train_metrics = get_metrics(train_metrics)
+    for key, vals in train_metrics.items():
+        tag = f"train_{key}"
+        for i, val in enumerate(vals):
+            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
+
+    for metric_name, value in eval_metrics.items():
+        summary_writer.scalar(f"eval_{metric_name}", value, step)
+
+
+def create_learning_rate_fn(
+    train_ds_size: int,
+    train_batch_size: int,
+    num_train_epochs: int,
+    num_warmup_steps: int,
+    learning_rate: float,
+) -> Callable[[int], jnp.array]:
+    """Returns a linear warmup, linear_decay learning rate function."""
+    steps_per_epoch = train_ds_size // train_batch_size
+    num_train_steps = steps_per_epoch * num_train_epochs
+    warmup_fn = optax.linear_schedule(
+        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
+    )
+    decay_fn = optax.linear_schedule(
+        init_value=learning_rate,
+        end_value=0,
+        transition_steps=num_train_steps - num_warmup_steps,
+    )
+    schedule_fn = optax.join_schedules(
+        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
+    )
+    return schedule_fn
+
+
+def main():
+    # See all possible arguments in src/transformers/training_args.py
+    # or by passing the --help flag to this script.
+    # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+    parser = HfArgumentParser(
+        (ModelArguments, DataTrainingArguments, TrainingArguments)
+    )
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, training_args = parser.parse_json_file(
+            json_file=os.path.abspath(sys.argv[1])
+        )
+    else:
+        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+    if (
+        os.path.exists(training_args.output_dir)
+        and os.listdir(training_args.output_dir)
+        and training_args.do_train
+        and not training_args.overwrite_output_dir
+    ):
+        raise ValueError(
+            f"Output directory ({training_args.output_dir}) already exists and is not empty."
+            "Use --overwrite_output_dir to overcome."
+        )
+
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    # Setup logging, we only want one process per machine to log things on the screen.
+    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+    if jax.process_index() == 0:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+
+    # Set the verbosity to info of the Transformers logger (on main process only):
+    logger.info(f"Training/evaluation parameters {training_args}")
+
+    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
+    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
+    # (the dataset will be downloaded automatically from the datasets Hub).
+    #
+    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
+    # 'text' is found. You can easily tweak this behavior (see below).
+    #
+    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+    # download the dataset.
+    if data_args.dataset_language is not None:
+        # Downloading and loading a dataset from the hub.
+        oscar = load_dataset(
+            "oscar",
+            f"unshuffled_deduplicated_{data_args.dataset_language}",
+            split="train",
+            cache_dir=model_args.cache_dir,
+        )
+
+        cc100 = load_dataset(
+            "cc100",
+            lang=data_args.dataset_language,
+            split="train",
+            cache_dir=model_args.cache_dir,
+        )
+
+        mc4 = load_dataset(
+            "mc4",
+            data_args.dataset_language,
+            split="train",
+            cache_dir=model_args.cache_dir,
+        )
+
+        wiki_files = [str(x) for x in Path("../docs").glob("*.txt")]
+        wiki = load_dataset("text", data_files=wiki_files)
+
+        # want: text column only!
+        oscar = oscar.remove_columns("id")
+        mc4 = mc4.remove_columns(["url", "timestamp"])
+        cc100 = cc100.remove_columns("id")
+
+        # combine datasets (assigned to `dataset`, which the rest of the script uses)
+        dataset = concatenate_datasets([oscar, mc4, cc100, wiki["train"]])
+        # split train and validation
+        # note: renamed `validation` key to `test` everywhere else in the script
+        dataset = dataset.train_test_split(
+            test_size=data_args.validation_split_percentage / 100, seed=42
+        )
+
+    else:
+        data_files = {}
+        if data_args.train_file is not None:
+            data_files["train"] = data_args.train_file
+        if data_args.validation_file is not None:
+            data_files["test"] = data_args.validation_file
+        extension = data_args.train_file.split(".")[-1]
+        if extension == "txt":
+            extension = "text"
+        dataset = load_dataset(
+            extension, data_files=data_files, cache_dir=model_args.cache_dir
+        )
+    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+    # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+    # Load pretrained model and tokenizer
+
+    # Distributed training:
+    # The .from_pretrained methods guarantee that only one local process can concurrently
+    # download model & vocab.
+    if model_args.config_name:
+        config = AutoConfig.from_pretrained(
+            model_args.config_name, cache_dir=model_args.cache_dir
+        )
+    elif model_args.model_name_or_path:
+        config = AutoConfig.from_pretrained(
+            model_args.model_name_or_path, cache_dir=model_args.cache_dir
+        )
+    else:
+        config = CONFIG_MAPPING[model_args.model_type]()
+        logger.warning("You are instantiating a new config instance from scratch.")
+
+    if model_args.tokenizer_name:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.tokenizer_name,
+            cache_dir=model_args.cache_dir,
+            use_fast=model_args.use_fast_tokenizer,
+        )
+    elif model_args.model_name_or_path:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.model_name_or_path,
+            cache_dir=model_args.cache_dir,
+            use_fast=model_args.use_fast_tokenizer,
+        )
+    else:
+        raise ValueError(
+            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
+            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+        )
+
+    if model_args.model_name_or_path:
+        model = FlaxAutoModelForCausalLM.from_pretrained(
+            model_args.model_name_or_path,
+            config=config,
+            seed=training_args.seed,
+            dtype=getattr(jnp, model_args.dtype),
+        )
+    else:
+        model = FlaxAutoModelForCausalLM.from_config(
+            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+
+    # Preprocessing the datasets.
+    # First we tokenize all the texts.
+    if training_args.do_train:
+        column_names = dataset["train"].column_names
+    else:
+        column_names = dataset["test"].column_names
+    text_column_name = "text" if "text" in column_names else column_names[0]
+
+    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
+    tok_logger = transformers.utils.logging.get_logger(
+        "transformers.tokenization_utils_base"
+    )
+
+    def tokenize_function(examples):
+        with CaptureLogger(tok_logger) as cl:
+            output = tokenizer(examples[text_column_name])
+        # clm input could be much much longer than block_size
+        if "Token indices sequence length is longer than the" in cl.out:
+            tok_logger.warning(
+                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
+            )
+        return output
+
+    tokenized_datasets = dataset.map(
+        tokenize_function,
+        batched=True,
+        num_proc=data_args.preprocessing_num_workers,
+        remove_columns=column_names,
+        load_from_cache_file=not data_args.overwrite_cache,
+    )
+
+    if data_args.block_size is None:
+        block_size = tokenizer.model_max_length
+        if block_size > config.max_position_embeddings:
+            logger.warning(
+                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
+                "Picking 1024 instead. You can change that default value by passing --block_size xxx."
+            )
+            block_size = 1024
+    else:
+        if data_args.block_size > tokenizer.model_max_length:
+            logger.warning(
+                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
+                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
+            )
+        block_size = min(data_args.block_size, tokenizer.model_max_length)
+
+    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
+    def group_texts(examples):
+        # Concatenate all texts.
+        concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
+        total_length = len(concatenated_examples[list(examples.keys())[0]])
+        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
+        # customize this part to your needs.
+        total_length = (total_length // block_size) * block_size
+        # Split by chunks of max_len.
+        result = {
+            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+            for k, t in concatenated_examples.items()
+        }
+        result["labels"] = result["input_ids"].copy()
+        return result
+
+    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
+    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
+    # to preprocess.
+    #
+    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
+    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
+
+    lm_datasets = tokenized_datasets.map(
+        group_texts,
+        batched=True,
+        num_proc=data_args.preprocessing_num_workers,
+        load_from_cache_file=not data_args.overwrite_cache,
+    )
+
+    if training_args.do_train:
+        if "train" not in tokenized_datasets:
+            raise ValueError("--do_train requires a train dataset")
+        train_dataset = lm_datasets["train"]
+        if data_args.max_train_samples is not None:
+            train_dataset = train_dataset.select(range(data_args.max_train_samples))
+
+    if training_args.do_eval:
+        if "test" not in tokenized_datasets:
+            raise ValueError("--do_eval requires a validation dataset")
+        eval_dataset = lm_datasets["test"]
+        if data_args.max_eval_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+
+    # Enable tensorboard only on the master node
+    if has_tensorboard and jax.process_index() == 0:
+        summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+
+    # Initialize our training
+    rng = jax.random.PRNGKey(training_args.seed)
+    rng, dropout_rng = jax.random.split(rng)
+
+    # Store some constant
+    num_epochs = int(training_args.num_train_epochs)
+    train_batch_size = (
+        int(training_args.per_device_train_batch_size) * jax.device_count()
+    )
+    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+    steps_per_epoch = len(train_dataset) // train_batch_size
+    total_train_steps = steps_per_epoch * num_epochs
+
+    # Create learning rate schedule
+    linear_decay_lr_schedule_fn = create_learning_rate_fn(
+        len(train_dataset),
+        train_batch_size,
+        training_args.num_train_epochs,
+        training_args.warmup_steps,
+        training_args.learning_rate,
+    )
+
+    # We use Optax's "masking" functionality to not apply weight decay
+    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
+    # mask boolean with the same structure as the parameters.
+    # The mask is True for parameters that should be decayed.
+    # Note that this mask is specifically adapted for FlaxGPT2.
+    # For other models, one should correct the layer norm parameter naming
+    # accordingly.
+    def decay_mask_fn(params):
+        flat_params = traverse_util.flatten_dict(params)
+        flat_mask = {
+            path: (
+                path[-1] != "bias"
+                and path[-2:]
+                not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")]
+            )
+            for path in flat_params
+        }
+        return traverse_util.unflatten_dict(flat_mask)
+
+    # create adam optimizer
+    adamw = optax.adamw(
+        learning_rate=linear_decay_lr_schedule_fn,
+        b1=training_args.adam_beta1,
+        b2=training_args.adam_beta2,
+        eps=training_args.adam_epsilon,
+        weight_decay=training_args.weight_decay,
+        mask=decay_mask_fn,
+    )
+
+    # Setup train state
+    state = TrainState.create(
+        apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng
+    )
+
+    def loss_fn(logits, labels):
+        shift_logits = logits[..., :-1, :]
+        shift_labels = labels[..., 1:]
+        loss = optax.softmax_cross_entropy(
+            shift_logits, onehot(shift_labels, shift_logits.shape[-1])
+        )
+        return loss.mean()
+
+    # Define gradient update step fn
+    def train_step(state, batch):
+        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
+
+        def compute_loss(params):
+            labels = batch.pop("labels")
+            logits = state.apply_fn(
+                **batch, params=params, dropout_rng=dropout_rng, train=True
+            )[0]
+            loss = loss_fn(logits, labels)
+            return loss
+
+        grad_fn = jax.value_and_grad(compute_loss)
+        loss, grad = grad_fn(state.params)
+        grad = jax.lax.pmean(grad, "batch")
+
+        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
+
+        metrics = {
+            "loss": loss,
+            "learning_rate": linear_decay_lr_schedule_fn(state.step),
+        }
+        metrics = jax.lax.pmean(metrics, axis_name="batch")
+
+        return new_state, metrics
+
+    # Define eval fn
+    def eval_step(params, batch):
+        labels = batch.pop("labels")
+        logits = model(**batch, params=params, train=False)[0]
+        loss = loss_fn(logits, labels)
+
+        # summarize metrics
+        metrics = {"loss": loss}
+        metrics = jax.lax.pmean(metrics, axis_name="batch")
+        return metrics
+
+    # Create parallel version of the train and eval step
+    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+    p_eval_step = jax.pmap(eval_step, "batch")
+
+    # Replicate the train state on each device
+    state = state.replicate()
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num Epochs = {num_epochs}")
+    logger.info(
+        f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
+    )
+    logger.info(
+        f"  Total train batch size (w. parallel & distributed) = {train_batch_size}"
+    )
+    logger.info(f"  Total optimization steps = {total_train_steps}")
+
+    train_time = 0
+    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+    for epoch in epochs:
+        # ======================== Training ================================
+        train_start = time.time()
+
+        # Create sampling rng
+        rng, input_rng = jax.random.split(rng)
+        train_metrics = []
+
+        # Generate an epoch by shuffling sampling indices from the train dataset
+        train_loader = data_loader(
+            input_rng, train_dataset, train_batch_size, shuffle=True
+        )
+        steps_per_epoch = len(train_dataset) // train_batch_size
+        # train
+        for _ in tqdm(
+            range(steps_per_epoch), desc="Training...", position=1, leave=False
+        ):
+            batch = next(train_loader)
+            state, train_metric = p_train_step(state, batch)
+            train_metrics.append(train_metric)
+
+        train_time += time.time() - train_start
+
+        train_metric = unreplicate(train_metric)
+
+        epochs.write(
+            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+        )
+
+        # ======================== Evaluating ==============================
+        eval_metrics = []
+        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
+        eval_steps = len(eval_dataset) // eval_batch_size
+        for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
+            # Model forward
+            batch = next(eval_loader)
+            metrics = p_eval_step(state.params, batch)
+            eval_metrics.append(metrics)
+
+        # normalize eval metrics
+        eval_metrics = get_metrics(eval_metrics)
+
+        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+
+        try:
+            eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
+        except OverflowError:
+            eval_metrics["perplexity"] = float("inf")
+
+        # Print metrics and update progress bar
+        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
+        epochs.write(desc)
+        epochs.desc = desc
+
+        # Save metrics
+        if has_tensorboard and jax.process_index() == 0:
+            cur_step = epoch * (len(train_dataset) // train_batch_size)
+            write_metric(
+                summary_writer, train_metrics, eval_metrics, train_time, cur_step
+            )
+
+        # save checkpoint after each epoch and push checkpoint to the hub
+        if jax.process_index() == 0:
+            params = jax.device_get(unreplicate(state.params))
+            model.save_pretrained(
+                training_args.output_dir,
+                params=params,
+                push_to_hub=training_args.push_to_hub,
+                commit_message=f"Saving weights and logs of epoch {epoch+1}",
+            )
+
+
+if __name__ == "__main__":
+    main()
+
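
Once training has pushed a checkpoint via save_pretrained above, the model can be sampled with the task_specific_params stored in config.json (do_sample=True, max_length=50). A hedged sketch, assuming the trained Flax weights and the tokenizer are available locally (or under the eventual hub model id); the Sundanese prompt is only an illustration:

from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

model_path = "./"  # assumption: swap in the hub model id once weights have been pushed
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = FlaxAutoModelForCausalLM.from_pretrained(model_path)

inputs = tokenizer("Wilujeng sumping", return_tensors="np")
outputs = model.generate(inputs["input_ids"], do_sample=True, max_length=50)
print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))
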
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
train_tokenizer.py ADDED
@@ -0,0 +1,49 @@
+from datasets import load_dataset, concatenate_datasets
+from tokenizers import ByteLevelBPETokenizer
+from pathlib import Path
+
+dataset_language = "su"
+validation_split_percentage = 10
+
+# load dataset
+# only the train subset for tokenizing purposes
+oscar = load_dataset(
+    "oscar", f"unshuffled_deduplicated_{dataset_language}", split="train",
+)
+
+cc100 = load_dataset("cc100", lang=dataset_language, split="train")
+
+mc4 = load_dataset("mc4", dataset_language, split="train")
+
+wiki_files = [str(x) for x in Path("../docs").glob("*.txt")]
+wiki = load_dataset("text", data_files=wiki_files)
+
+# want: text column only!
+oscar = oscar.remove_columns("id")
+mc4 = mc4.remove_columns(["url", "timestamp"])
+cc100 = cc100.remove_columns("id")
+
+dataset = concatenate_datasets([oscar, mc4, cc100, wiki["train"]])
+dataset = dataset.train_test_split(test_size=validation_split_percentage / 100, seed=42)
+
+# Instantiate tokenizer
+tokenizer = ByteLevelBPETokenizer()
+
+
+def batch_iterator(batch_size=10000):
+    # iterate over the full train split (dataset is a DatasetDict after train_test_split)
+    for i in range(0, len(dataset["train"]), batch_size):
+        yield dataset["train"][i : i + batch_size]["text"]
+
+
+# Customized training
+tokenizer.train_from_iterator(
+    batch_iterator(),
+    vocab_size=50265,
+    min_frequency=2,
+    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
+)
+
+# Save files to disk
+model_dir = "."
+tokenizer.save(f"{model_dir}/tokenizer.json")
+
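
The tokenizer.json committed above is the direct output of this script. On the transformers side (which is how run.sh consumes it via --tokenizer_name="./"), it can be wrapped in a fast GPT-2 tokenizer; the special-token mapping below simply mirrors the train_from_iterator call and is an assumption about the intended usage:

from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast(
    tokenizer_file="tokenizer.json",
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    mask_token="<mask>",
)
print(tokenizer.vocab_size)  # should match the 50,265 requested above
print(tokenizer("Wilujeng sumping").input_ids)
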