Skip to content

Commit 1b64c96

Browse files
author
Ricardo Decal
committed
another pass
1 parent 9a038e5 commit 1b64c96

1 file changed

Lines changed: 18 additions & 11 deletions

File tree

beginner_source/simple_distributed_training_tutorial.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,10 @@
126126
# ``<|endoftext|>`` separator token before each title line so the model
127127
# learns to reset context at article boundaries.
128128

129-
# Limit dataset size for fast iteration during smoke tests.
129+
# Limit dataset size for fast iteration during smoke tests.
130130
if SMOKE_TEST:
131-
train_ds = train_ds.limit(1000)
132-
val_ds = val_ds.limit(1000)
131+
train_ds = train_ds.limit(2500)
132+
val_ds = val_ds.limit(2500)
133133

134134
###############################################################################
135135
# Tokenize and chunk the data
@@ -275,7 +275,7 @@ def create_model():
275275

276276
model = create_model()
277277
num_params = sum(p.numel() for p in model.parameters())
278-
print(f"Model parameters: {num_params:,} ({num_params / 1e6:.1f}M)")
278+
print(f"Model parameters: {num_params / 1e6:.1f}M")
279279

280280
del model # Free memory before training
281281

@@ -333,7 +333,7 @@ def train_func_per_worker(config: dict):
333333
model.train()
334334
train_loss_sum = 0.0
335335
train_batches = 0
336-
train_items = 0
336+
train_tokens = 0
337337
epoch_start = time.perf_counter()
338338

339339
# iter_torch_batches returns dicts of tensors already on the GPU.
@@ -356,7 +356,7 @@ def train_func_per_worker(config: dict):
356356

357357
train_loss_sum += loss.item()
358358
train_batches += 1
359-
train_items += batch_size
359+
train_tokens += input_ids.numel()
360360

361361
if max_steps_per_epoch and train_batches >= max_steps_per_epoch:
362362
break
@@ -385,14 +385,16 @@ def train_func_per_worker(config: dict):
385385
break
386386

387387
avg_val_loss = val_loss_sum / max(val_batches, 1)
388+
epoch_elapsed = time.perf_counter() - epoch_start
388389

389390
# --- Report metrics -------------------------------------------------
390391
metrics = {
391392
"train_loss": round(avg_train_loss, 4),
392393
"val_loss": round(avg_val_loss, 4),
393394
"epoch": epoch,
394-
"batches_per_sec": round(train_batches / max(train_elapsed, 1e-6), 2),
395-
"items_per_sec": round(train_items / max(train_elapsed, 1e-6), 2),
395+
"epoch_time_sec": round(epoch_elapsed, 2),
396+
"epoch_tokens": train_tokens,
397+
"tokens_per_sec": round(train_tokens / max(train_elapsed, 1e-6), 2),
396398
}
397399
ray.train.report(
398400
metrics=metrics,
@@ -427,9 +429,14 @@ def train_func_per_worker(config: dict):
427429
# ...
428430
# Moving model to device: cuda:0
429431
# Wrapping provided model in DistributedDataParallel.
432+
#
433+
# ``batch_size_per_worker`` is the number of sequences each worker
434+
# processes per gradient step. With 8 workers and a per-worker batch size
435+
# of 16, the **effective global batch size** is 8 × 16 = 128 sequences,
436+
# or 128 × 256 = 32,768 tokens per optimizer step.
430437

431438
NUM_WORKERS = 8 # One worker per GPU on this machine
432-
NUM_EPOCHS = 20
439+
NUM_EPOCHS = 5
433440
BATCH_SIZE_PER_WORKER = 16
434441
LR = 3e-4
435442
WEIGHT_DECAY = 0.1
@@ -471,8 +478,8 @@ def train_func_per_worker(config: dict):
471478
#
472479
# .. code-block:: text
473480
#
474-
# {'train_loss': 10.95, 'val_loss': 10.02, 'epoch': 0,
475-
# 'batches_per_sec': 1.23, 'items_per_sec': 19.68}
481+
# {'train_loss': 7.0646, 'val_loss': 7.6051, 'epoch': 4,
482+
# 'epoch_time_sec': 12.34, 'epoch_tokens': 20480, 'tokens_per_sec': 1759.8}
476483
#
477484
# The per-worker logs show training loss, validation loss, and throughput
478485
# metrics for each epoch. With random weights and only a few steps, expect

0 commit comments

Comments
 (0)