feat: better multi-node support (#158)
* reproducible data loader
* custom sharding
* model parallel across multiple nodes
- src/dalle_mini/data.py +12 -3
- tools/train/config/mega/config.json +27 -8
- tools/train/config/mini/config.json +1 -1
- tools/train/train.py +50 -9
src/dalle_mini/data.py
CHANGED
@@ -43,6 +43,8 @@ class Dataset:
         if self.seed_dataset is None:
             # create a random seed
             self.seed_dataset = random.randint(0, 2**32 - 1)
+        # set numpy rng
+        self.np_rng = np.random.default_rng(self.seed_dataset)
         self.multi_hosts = jax.process_count() > 1
         # feed blank captions only in streaming mode for now
         # otherwise dataset could be cached with same blanked captions
@@ -173,6 +175,7 @@ class Dataset:
                 blank_caption_function,
                 text_column=self.text_column,
                 blank_caption_prob=self.blank_caption_prob,
+                rng=self.np_rng,
             )
             if hasattr(self, "train_dataset"):
                 self.train_dataset = (
@@ -180,7 +183,9 @@ class Dataset:
                     if self.streaming
                     else self.train_dataset.map(
                         partial_blank_caption_function,
-                        num_proc=
+                        num_proc=None
+                        if self.seed_dataset
+                        else self.preprocessing_num_workers,
                         load_from_cache_file=False,
                         desc="Blanking some captions",
                     )
@@ -316,8 +321,12 @@ def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
     return shifted_input_ids


-def blank_caption_function(example, text_column, blank_caption_prob):
-    if
+def blank_caption_function(example, text_column, blank_caption_prob, rng=None):
+    if (
+        blank_caption_prob
+        and (rng.random() if rng is not None else np.random.random())
+        < blank_caption_prob
+    ):
         example[text_column] = ""
     return example
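The rng argument threaded through blank_caption_function is what makes the data loader reproducible: every host builds a numpy Generator from the same seed_dataset, so they all blank the same captions, and num_proc is forced to None when a seed is set so the generator is consumed in a deterministic order. A minimal standalone sketch of the idea, using toy captions rather than the real dataset:

import numpy as np


def blank_caption(example, text_column, blank_caption_prob, rng=None):
    # same logic as the patched blank_caption_function: draw from the passed-in
    # Generator when given, otherwise fall back to numpy's global RNG
    if (
        blank_caption_prob
        and (rng.random() if rng is not None else np.random.random())
        < blank_caption_prob
    ):
        example[text_column] = ""
    return example


seed_dataset = 1234  # hypothetical seed value
captions = [{"caption": f"a photo of item {i}"} for i in range(10)]

# two "hosts" (or two restarts) seeded identically blank exactly the same rows
rng = np.random.default_rng(seed_dataset)
run_a = [blank_caption(dict(c), "caption", 0.3, rng) for c in captions]
rng = np.random.default_rng(seed_dataset)
run_b = [blank_caption(dict(c), "caption", 0.3, rng) for c in captions]
assert run_a == run_b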
tools/train/config/mega/config.json
CHANGED
@@ -1,30 +1,49 @@
 {
   "activation_dropout": 0.0,
-  "activation_function": "
+  "activation_function": "swish",
   "attention_dropout": 0.0,
   "bos_token_id": 16385,
   "d_model": 2048,
   "decoder_attention_heads": 32,
-  "decoder_ffn_dim":
+  "decoder_ffn_dim": 4096,
   "decoder_layerdrop": 0.0,
-  "decoder_layers":
+  "decoder_layers": 25,
   "decoder_start_token_id": 16384,
+  "do_sample": true,
   "dropout": 0.0,
   "encoder_attention_heads": 32,
-  "encoder_ffn_dim":
+  "encoder_ffn_dim": 4096,
   "encoder_layerdrop": 0.0,
-  "encoder_layers":
+  "encoder_layers": 25,
-  "encoder_vocab_size":
+  "encoder_vocab_size": 50272,
   "eos_token_id": 16385,
+  "force_ln_scale": false,
+  "gradient_checkpointing": false,
   "image_length": 256,
-  "image_vocab_size":
+  "image_vocab_size": 16415,
   "init_std": 0.01,
   "is_encoder_decoder": true,
+  "ln_positions": "normformer",
+  "ln_type": "layernorm",
+  "max_length": 257,
   "max_text_length": 64,
+  "min_length": 257,
   "model_type": "dallebart",
   "normalize_text": true,
   "pad_token_id": 16385,
   "scale_embedding": false,
+  "sinkhorn_iters": 1,
+  "tau_init": 0.05,
   "tie_word_embeddings": false,
-  "
+  "use_absolute_position_embeddings": true,
+  "use_alibi": false,
+  "use_bias": false,
+  "use_cache": true,
+  "use_cosine_attention": false,
+  "use_deepnet_scaling": false,
+  "use_final_ln_decoder": true,
+  "use_final_ln_encoder": true,
+  "use_glu": true,
+  "use_head_scale": false,
+  "use_swin_position_embeddings": false
 }
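Nothing in this config runs on its own, but the new values can be spot-checked quickly. A small sketch, assuming it is executed from the repository root (the image_length + 1 relation presumably accounts for the decoder start token):

import json

# read the updated mega config and check a few of the values added above
with open("tools/train/config/mega/config.json") as f:
    cfg = json.load(f)

# generation length is pinned to a fixed size: 256 image tokens (+1, see above)
assert cfg["max_length"] == cfg["min_length"] == cfg["image_length"] + 1  # 257

# text encoder and image decoder use separately sized vocabularies
print(cfg["encoder_vocab_size"], cfg["image_vocab_size"])  # 50272 16415
print(cfg["encoder_layers"], cfg["decoder_layers"], cfg["d_model"])  # 25 25 2048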
tools/train/config/mini/config.json
CHANGED
@@ -16,7 +16,7 @@
   "eos_token_id": 16385,
   "gradient_checkpointing": false,
   "image_length": 256,
-  "image_vocab_size":
+  "image_vocab_size": 16391,
   "init_std": 0.02,
   "is_encoder_decoder": true,
   "max_text_length": 64,
tools/train/train.py
CHANGED
@@ -368,6 +368,12 @@ class TrainingArguments:
             "help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."
         },
     )
+    shard_shampoo_across: str = field(
+        default="dp",
+        metadata={
+            "help": "Whether to shard the optimizer across data devices (dp), model devices (mp) or both (2d)."
+        },
+    )
 
     num_train_epochs: int = field(
         default=3, metadata={"help": "Total number of training epochs to perform."}
@@ -450,6 +456,11 @@ class TrainingArguments:
         metadata={"help": "Verify that TPU is not in use."},
     )
 
+    use_vmap_trick: bool = field(
+        default=True,
+        metadata={"help": "Verify that TPU is not in use."},
+    )
+
     mp_devices: Optional[int] = field(
         default=1,
         metadata={
@@ -500,6 +511,11 @@ class TrainingArguments:
                 f"Output directory ({self.output_dir}) already exists and is not empty."
                 "Use --overwrite_output_dir to overcome."
             )
+        assert self.shard_shampoo_across in [
+            "dp",
+            "mp",
+            "2d",
+        ], f"Shard shampoo across {self.shard_shampoo_across} not supported."
         assert (
             self.mp_devices > 0
         ), f"Number of devices for model parallelism must be > 0"
@@ -530,6 +546,12 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
+    # check arguments
+    if training_args.mp_devices > jax.local_device_count():
+        assert (
+            data_args.seed_dataset is not None
+        ), "Seed dataset must be provided when model is split over multiple hosts"
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -748,8 +770,20 @@ def main():
             graft_type=graft_type,
             nesterov=False,
             exponent_override=0,
-            statistics_partition_spec=PartitionSpec(
-
+            statistics_partition_spec=PartitionSpec(
+                None, training_args.shard_shampoo_across, None
+            )
+            if training_args.shard_shampoo_across != "2d"
+            else PartitionSpec(None, "dp", "mp"),
+            preconditioner_partition_spec=PartitionSpec(
+                training_args.shard_shampoo_across, None, None
+            )
+            if training_args.shard_shampoo_across != "2d"
+            else PartitionSpec(
+                "mp" if training_args.mp_devices > training_args.dp_devices else "dp",
+                None,
+                None,
+            ),
             num_devices_for_pjit=training_args.dp_devices,
             shard_optimizer_states=True,
             inverse_failure_threshold=0.1,
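The new statistics_partition_spec / preconditioner_partition_spec arguments decide which mesh axis the Shampoo optimizer state is sharded over. A standalone sketch of the same selection logic, so the three --shard_shampoo_across choices can be compared side by side (it uses jax.sharding.PartitionSpec; the training script itself imports PartitionSpec from the pjit-era experimental API, and this helper is illustrative, not part of the script):

from jax.sharding import PartitionSpec


def shampoo_partition_specs(shard_across, mp_devices, dp_devices):
    # mirrors the spec selection in the hunk above (illustrative helper)
    if shard_across in ("dp", "mp"):
        # shard the second axis of the stacked statistics, and the first
        # axis of the preconditioners, over the chosen mesh axis
        statistics = PartitionSpec(None, shard_across, None)
        preconditioner = PartitionSpec(shard_across, None, None)
    elif shard_across == "2d":
        # spread statistics over both mesh axes; put the preconditioners
        # on whichever axis has more devices
        statistics = PartitionSpec(None, "dp", "mp")
        preconditioner = PartitionSpec(
            "mp" if mp_devices > dp_devices else "dp", None, None
        )
    else:
        raise ValueError(f"Shard shampoo across {shard_across} not supported.")
    return statistics, preconditioner


for choice in ("dp", "mp", "2d"):
    print(choice, shampoo_partition_specs(choice, mp_devices=8, dp_devices=4))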
@@ -917,7 +951,7 @@ def main():
 
     # "vmap trick" avoids a crash when mp_devices > 1 (not sure why it happens)
     # lead to better perf: see https://wandb.ai/dalle-mini/dalle-mini/reports/JAX-pmap-vs-pjit--VmlldzoxNDg1ODA2
-    use_vmap_trick =
+    use_vmap_trick = training_args.use_vmap_trick
 
     # make grad_param_spec for vmap
     if use_vmap_trick:
@@ -1145,7 +1179,8 @@ def main():
         self.log_time("train_per_log", delta_time, offset=False)
 
     def log_time(self, key, duration, offset=True):
-
+        if jax.process_index() == 0:
+            wandb.log({f"time/{key}": duration, **self.state_dict})
         if offset:
             self.offset_time += duration
 
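log_time now follows the same rule as the rest of the metrics logging: every process measures durations, but only process 0 writes them to Weights & Biases, so a multi-host run still maps onto a single wandb run. A minimal standalone sketch of the same guard (the wandb.init call and names are illustrative, not taken from the script):

import jax
import wandb


def log_duration(key, duration, extra=None):
    # all hosts call this, but only the first process actually logs
    if jax.process_index() == 0:
        wandb.log({f"time/{key}": duration, **(extra or {})})


if jax.process_index() == 0:
    wandb.init(project="timing-demo")  # hypothetical project name
log_duration("train_per_log", 12.3, {"train/epoch": 0})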
@@ -1191,7 +1226,11 @@ def main():
         # ======================== Evaluating ==============================
         if training_args.do_eval:
             start_eval_time = time.perf_counter()
-            eval_loader = dataset.dataloader(
+            eval_loader = dataset.dataloader(
+                "eval",
+                eval_batch_size_per_step
+                * max(1, training_args.mp_devices // jax.local_device_count()),
+            )
             eval_steps = (
                 len_eval_dataset // eval_batch_size_per_step
                 if len_eval_dataset is not None
@@ -1353,10 +1392,12 @@ def main():
         metrics_logger.update_state_metrics(local_state)
         metrics_logger.log({})
 
-        #
+        # load data - may be replicated on multiple nodes
+        node_groups = max(1, training_args.mp_devices // jax.local_device_count())
+        loader_bs = batch_size_per_node * node_groups
         train_loader = dataset.dataloader(
             "train",
-
+            loader_bs,
             epoch,
         )
         # train
@@ -1373,12 +1414,12 @@ def main():
 
             # set correct shape to batch
             # - add grad_step dim if gradient_accumulation_steps > 1
-            # - split per dp device if not multi-host for vmap trick (does not work in multi-host)
             bs_shape = (
-                (batch_size_per_node_per_grad_step,)
+                (batch_size_per_node_per_grad_step * node_groups,)
                 if not use_vmap_trick
                 else (
                     jax.local_device_count()
+                    * node_groups
                     // training_args.mp_devices,  # local dp devices
                     training_args.per_device_train_batch_size,
                 )
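To see what the node_groups bookkeeping buys, here is the arithmetic for one hypothetical topology: 4 hosts with 8 devices each and the model sharded over mp_devices = 16, so a single model replica spans two hosts and those two hosts must be fed identical data. The batch_size_per_node* values are treated as given inputs here; their derivation lives elsewhere in train.py.

# hypothetical topology: 4 hosts x 8 local devices, one model replica spans 2 hosts
local_device_count = 8                 # jax.local_device_count()
mp_devices = 16                        # --mp_devices
per_device_train_batch_size = 2        # --per_device_train_batch_size
batch_size_per_node_per_grad_step = 1  # assumed, computed earlier in train.py
batch_size_per_node = 1                # assumed, no gradient accumulation here
use_vmap_trick = True

# hosts holding parts of the same model replica form one "node group"
# and read the same data
node_groups = max(1, mp_devices // local_device_count)  # 2
loader_bs = batch_size_per_node * node_groups            # one group loads 2 nodes' worth

# shape of the batch handed to pjit
bs_shape = (
    (batch_size_per_node_per_grad_step * node_groups,)
    if not use_vmap_trick
    else (
        local_device_count * node_groups // mp_devices,  # local dp devices -> 1
        per_device_train_batch_size,
    )
)
print(node_groups, loader_bs, bs_shape)  # 2 2 (1, 2)

The eval loader applies the same max(1, mp_devices // jax.local_device_count()) factor, so each node group also draws a consistent evaluation batch.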