
Commit 9ba4e41

Author: Ricardo Decal
Commit message: another pass
Parent: 1b64c96

1 file changed: 116 additions & 51 deletions

File changed: beginner_source/simple_distributed_training_tutorial.py
@@ -14,10 +14,11 @@

 * Pre-train a GPT-2 (~124M-parameter) language model using PyTorch
   and Hugging Face Transformers.
-* Distribute training across multiple GPUs with Ray Train.
-* Stream training data from Hugging Face datasets with Ray Data.
+* Distribute training across multiple GPUs with Ray Train using minimal code changes.
+* Stream training data from Hugging Face datasets with Ray Data's distributed workers.
 * Save and load distributed checkpoints.
 * Scale from a single node to a multi-node cluster with minimal code changes.
+* Optimize cost and performance with heterogeneous clusters.
 * Monitor training with the Ray dashboard.

 .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
@@ -50,8 +51,6 @@
 Then, import the required libraries:
 """

-###############################################################################
-
 import time

 import numpy as np
@@ -87,6 +86,11 @@
 train_ds = ray.data.from_huggingface(hf_ds["train"])
 val_ds = ray.data.from_huggingface(hf_ds["validation"])

+# Limit dataset size for fast iteration during smoke tests.
+if SMOKE_TEST:
+    train_ds = train_ds.limit(2500)
+    val_ds = val_ds.limit(2500)
+
 print(f"Dataset schema:\n{train_ds.schema()}")

 ###############################################################################
@@ -101,10 +105,12 @@
 # This means that the dataset has one column called ``text`` and it is a string.
 #
 # Inspect raw data
+#
 # ~~~~~~~~~~~~~~~~
 #
 # Use ``take(n)`` to fetch a small number of rows for inspection.
 # Each row is a dictionary with the column names as keys.
+
 print("--- Raw data sample ---")
 sample = train_ds.take(2)
 for i, row in enumerate(sample):
@@ -117,21 +123,15 @@
 # .. code-block:: text
 #
 #    Row 0: ''
-#    Row 1: ' = Valkyria Chronicles III = \n'
+#    Row 1: ' = Valkyria Chronicles III = '
 #
 # Each row in Wikitext-103 is a single line from a Wikipedia article.
 # Consecutive rows belong to the same article, with empty rows separating
 # paragraphs. New articles begin with a title line like
 # ``= Article Title =``. The tokenization step below inserts an
 # ``<|endoftext|>`` separator token before each title line so the model
 # learns to reset context at article boundaries.
-
-# Limit dataset size for fast iteration during smoke tests.=
-if SMOKE_TEST:
-    train_ds = train_ds.limit(2500)
-    val_ds = val_ds.limit(2500)
-
-###############################################################################
+#
 # Tokenize and chunk the data
 # ----------------------------
 #
@@ -195,12 +195,14 @@ def tokenize_and_chunk(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
     }


+
 ###############################################################################
 # Apply the tokenization with ``map_batches()``. This operation is **lazy**,
 # meaning that Ray Data defers execution until a downstream consumer requests the
 # results. Lazy execution lets Ray optimize the entire pipeline before any
 # work begins.

+# These do not trigger execution.
 train_ds = train_ds.map_batches(tokenize_and_chunk, batch_format="numpy")
 val_ds = val_ds.map_batches(tokenize_and_chunk, batch_format="numpy")

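A minimal, self-contained sketch of the lazy-execution behavior this hunk describes (not part of the commit; the toy dataset and ``fake_tokenize`` stand-in are illustrative only):

.. code-block:: python

    import numpy as np
    import ray

    ds = ray.data.from_items([{"text": "hello world"}, {"text": "lazy pipelines"}])

    def fake_tokenize(batch: dict) -> dict:
        # Stand-in for tokenize_and_chunk: just record each row's length.
        batch["n_chars"] = np.array([len(t) for t in batch["text"]])
        return batch

    # Lazy: building the pipeline does not run fake_tokenize yet.
    ds = ds.map_batches(fake_tokenize, batch_format="numpy")

    # Consuming the dataset (take, iteration, materialize) triggers execution.
    print(ds.take(1))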
@@ -236,29 +238,23 @@ def tokenize_and_chunk(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
 # .. code-block:: text
 #
 #    Execution plan: InputDataBuffer[Input]
-#    -> TaskPoolMapOperator[Filter]
 #    -> TaskPoolMapOperator[MapBatches(tokenize_and_chunk)]
 #    -> OutputSplitter[split(8, equal=True)]
 #
-# This tells you exactly how Ray Data will stream through filter, tokenize,
-# and split the data across 8 workers.
+# This tells you exactly how Ray Data will stream through tokenization
+# and split the data across 8 trainer workers.

 ###############################################################################
 # Define the transformer model
 # ----------------------------
 #
 # The model is a decoder-only transformer language model using Hugging Face's
-# ``GPT2LMHeadModel``.
+# ``GPT2LMHeadModel``. The hyperparameters below are the standard GPT-2 "small" architecture.
 #
-# The GPT-2 "small" architecture:
-#
-# * 12 transformer layers, 12 attention heads, 768 hidden size
-# * ~124M parameters
-# * Built-in causal attention masking and weight tying


 def create_model():
-    """Create a fresh GPT-2 model from config (random weights)."""
+    """Create a GPT-2 small model with random weights."""
     model = GPT2LMHeadModel(GPT2Config(
         vocab_size=VOCAB_SIZE,
         n_positions=BLOCK_SIZE,
@@ -270,6 +266,7 @@ def create_model():
     return model


+
 ###############################################################################
 # Verify the model size:

@@ -280,8 +277,7 @@ def create_model():
 del model  # Free memory before training

 ###############################################################################
-# You should see approximately **123.8M parameters** — the standard GPT-2
-# "small" size.
+# You should see approximately 123.8M parameters.

 ###############################################################################
 # Define the distributed training function
@@ -294,15 +290,15 @@ def create_model():
 #
 # The key Ray Train integration points are:
 #
-# 1. **``ray.train.get_dataset_shard("train")``** retrieves the
+# 1. ``ray.train.get_dataset_shard("train")`` retrieves the
 #    worker's portion of the data. Ray Data automatically splits the
 #    dataset across all workers.
-# 2. **``ray.train.torch.prepare_model(model)``** wraps the model in
+# 2. ``ray.train.torch.prepare_model(model)`` wraps the model in
 #    ``DistributedDataParallel`` and moves it to the correct GPU.
-# 3. **``shard.iter_torch_batches(batch_size=...)``** returns an iterator
+# 3. ``shard.iter_torch_batches(batch_size=...)`` returns an iterator
 #    of ``dict[str, torch.Tensor]`` batches, with tensors automatically
-#    placed on the worker's GPU.
-# 4. **``ray.train.report(metrics, checkpoint=...)``** reports metrics
+#    placed on the worker's GPU. Setting ``prefetch_batches=2`` opportunistically fetches 2 batches ahead of the current batch.
+# 4. ``ray.train.report(metrics, checkpoint=...)`` reports metrics
 #    to the driver and optionally saves a checkpoint.


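A compressed sketch (not part of this commit) of how the four integration points above fit together inside a worker function; it reuses ``create_model`` from earlier in the tutorial and assumes the config keys ``lr`` and ``batch_size``:

.. code-block:: python

    import ray.train
    import ray.train.torch
    import torch

    def sketch_train_func(config: dict):
        # 1. This worker's shard of the "train" dataset.
        shard = ray.train.get_dataset_shard("train")

        # 2. Wrap the model in DistributedDataParallel on this worker's GPU.
        model = ray.train.torch.prepare_model(create_model())
        optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])

        # 3. Iterate over batches that already live on the GPU.
        for batch in shard.iter_torch_batches(
            batch_size=config["batch_size"], dtypes=torch.long
        ):
            loss = model(input_ids=batch["input_ids"], labels=batch["labels"]).loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # 4. Send metrics back to the driver (checkpoint optional).
        ray.train.report({"train_loss": loss.item()})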
@@ -338,7 +334,7 @@ def train_func_per_worker(config: dict):

         # iter_torch_batches returns dicts of tensors already on the GPU.
         for batch in train_data_shard.iter_torch_batches(
-            batch_size=batch_size, dtypes=torch.long
+            batch_size=batch_size, dtypes=torch.long, prefetch_batches=2
         ):
             input_ids = batch["input_ids"]
             labels = batch["labels"]
@@ -371,7 +367,7 @@ def train_func_per_worker(config: dict):

         with torch.no_grad():
             for batch in val_data_shard.iter_torch_batches(
-                batch_size=batch_size, dtypes=torch.long
+                batch_size=batch_size, dtypes=torch.long, prefetch_batches=2
             ):
                 input_ids = batch["input_ids"]
                 labels = batch["labels"]
@@ -402,18 +398,20 @@ def train_func_per_worker(config: dict):
     )


+
 ###############################################################################
 # Configure and launch distributed training
 # ------------------------------------------
 #
-# The ``TorchTrainer`` brings everything together. It accepts:
+# The ``TorchTrainer`` brings everything together. Running ``trainer.fit()`` finally
+# triggers the execution of the full data pipeline and training loop. The Trainer accepts:
 #
-# * **``train_func_per_worker``**: the function each worker executes.
-# * **``train_loop_config``**: a dictionary of hyperparameters forwarded
+# * ``train_func_per_worker``: the function each worker executes.
+# * ``train_loop_config``: a dictionary of hyperparameters forwarded
 #   to the training function.
-# * **``datasets``**: a dictionary of Ray Datasets. Ray Train automatically
+# * ``datasets``: a dictionary of Ray Datasets. Ray Train automatically
 #   splits each dataset across workers.
-# * **``scaling_config``**: specifies the number of workers and whether to
+# * ``scaling_config``: specifies the number of workers and whether to
 #   use GPUs.
 #
 # Setting ``num_workers=8`` launches 8 parallel workers, one per GPU. Ray
@@ -424,8 +422,9 @@ def train_func_per_worker(config: dict):
 # .. code-block:: text
 #
 #    Started training worker group of size 8:
-#    - (ip=10.0.176.183, pid=25636) world_rank=0, local_rank=0, node_rank=0
-#    - (ip=10.0.176.183, pid=25637) world_rank=1, local_rank=1, node_rank=0
+#
+#    * (ip=10.0.176.183, pid=25636) world_rank=0, local_rank=0, node_rank=0
+#    * (ip=10.0.176.183, pid=25637) world_rank=1, local_rank=1, node_rank=0
 #    ...
 #    Moving model to device: cuda:0
 #    Wrapping provided model in DistributedDataParallel.
@@ -452,6 +451,7 @@ def train_func_per_worker(config: dict):
         "batch_size_per_worker": BATCH_SIZE_PER_WORKER,
         "max_steps_per_epoch": 5 if SMOKE_TEST else None,
     },
+    # Register the datasets.
    datasets={"train": train_ds, "validation": val_ds},
     scaling_config=ScalingConfig(
         num_workers=NUM_WORKERS,
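For context, a minimal, self-contained sketch of how the pieces described above fit together (not part of this commit; the hyperparameter values are illustrative, and ``train_func_per_worker``, ``train_ds``, and ``val_ds`` come from earlier in the tutorial):

.. code-block:: python

    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    trainer = TorchTrainer(
        train_func_per_worker,  # executed once on each of the 8 workers
        train_loop_config={"lr": 3e-4, "epochs": 1, "batch_size_per_worker": 16},
        datasets={"train": train_ds, "validation": val_ds},  # sharded across workers
        scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
    )

    # fit() launches the workers and triggers the whole lazy data pipeline.
    result = trainer.fit()
    print(result.metrics)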
@@ -483,12 +483,11 @@ def train_func_per_worker(config: dict):
 #
 # The per-worker logs show training loss, validation loss, and throughput
 # metrics for each epoch. With random weights and only a few steps, expect
-# a high loss (~10–11) — this is normal. In a real training run with more
-# epochs and the full dataset, you would see loss steadily decrease and
-# throughput stabilize.
+# a high loss (~10–11).

 ###############################################################################
 # Checkpointing
+#
 # ~~~~~~~~~~~~~
 #
 # In a production training run you would enable checkpointing so that
@@ -507,23 +506,70 @@ def train_func_per_worker(config: dict):
 #    )
 #
 # Inside the training function, save a checkpoint with
-# ``ray.train.report()``:
+# ``ray.train.report()``. Every worker must still call ``ray.train.report()``:
 #
 # .. code-block:: python
 #
 #    with tempfile.TemporaryDirectory() as tmp_dir:
-#        model.module.save_pretrained(tmp_dir)  # .module unwraps DDP
-#        checkpoint = ray.train.Checkpoint.from_directory(tmp_dir)
+#        checkpoint = None
+#        if ray.train.get_context().get_world_rank() == 0:
+#            torch.save(model.module.state_dict(),
+#                       os.path.join(tmp_dir, "model.pt"))
+#            torch.save(optimizer.state_dict(),
+#                       os.path.join(tmp_dir, "optimizer.pt"))
+#            torch.save({"epoch": epoch},
+#                       os.path.join(tmp_dir, "extra_state.pt"))
+#            checkpoint = ray.train.Checkpoint.from_directory(tmp_dir)
 #        ray.train.report(metrics={...}, checkpoint=checkpoint)
 #
+# Note that ``.module`` unwraps the ``DistributedDataParallel`` wrapper so
+# you save the underlying model weights rather than the DDP wrapper.
+#
+# To **resume training from a checkpoint**, call
+# ``ray.train.get_checkpoint()`` at the top of your training function.
+# When Ray Train restarts workers (for example, after a failure), it
+# automatically provides the latest checkpoint. If no checkpoint exists
+# (i.e. this is a fresh run), the function returns ``None``:
+#
+# .. code-block:: python
+#
+#    def train_func_per_worker(config: dict):
+#        model = create_model()
+#        model = ray.train.torch.prepare_model(model)
+#        optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])
+#
+#        # Resume from the latest checkpoint if one exists.
+#        start_epoch = 0
+#        checkpoint = ray.train.get_checkpoint()
+#        if checkpoint:
+#            with checkpoint.as_directory() as ckpt_dir:
+#                model.module.load_state_dict(
+#                    torch.load(os.path.join(ckpt_dir, "model.pt"))
+#                )
+#                optimizer.load_state_dict(
+#                    torch.load(os.path.join(ckpt_dir, "optimizer.pt"))
+#                )
+#                start_epoch = torch.load(
+#                    os.path.join(ckpt_dir, "extra_state.pt")
+#                )["epoch"] + 1
+#
+#        for epoch in range(start_epoch, config["epochs"]):
+#            # ... training loop ...
+#
+# You can also call ``TorchTrainer.restore(path, datasets=...)`` to
+# restore an entire interrupted experiment from its results directory
+# without re-specifying the full trainer configuration. See the `Ray Train
+# checkpointing guide
+# <https://docs.ray.io/en/latest/train/user-guides/checkpoints.html>`__
+# for more details.
+#
 # Scaling to a multi-node cluster
 # -------------------------------
 #
 # The code above runs on a single 8-GPU machine. Scaling to a multi-node
 # cluster requires only two changes:
 #
-# 1. **Increase ``num_workers``** to match the total number of GPUs across
-#    all nodes.
+# 1. **Increase ``num_workers``** to match the total number of GPUs in the cluster.
 # 2. **Set a shared storage path** so that all nodes can access checkpoints.
 #
 # For example, to train on a cluster of 4 nodes with 8 GPUs each
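The concrete snippet for that 4-node example sits outside this hunk; a sketch of the two changes with assumed values (worker count and bucket name are illustrative placeholders):

.. code-block:: python

    from ray.train import RunConfig, ScalingConfig

    # 4 nodes x 8 GPUs = 32 workers; shared storage so every node can
    # read and write checkpoints.
    scaling_config = ScalingConfig(num_workers=32, use_gpu=True)
    run_config = RunConfig(storage_path="s3://my-bucket/gpt2-runs")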
@@ -549,8 +595,6 @@ def train_func_per_worker(config: dict):
 # Ray Train automatically:
 #
 # * Launches workers across all available nodes.
-# * Initializes ``torch.distributed`` with the NCCL backend.
-# * Configures ``DistributedDataParallel`` across nodes.
 # * Shards data across all workers.
 #
 # No changes to the training function are needed. The same
@@ -564,6 +608,24 @@ def train_func_per_worker(config: dict):
 # `FullyShardedDataParallel <https://pytorch.org/docs/stable/fsdp.html>`__
 # (FSDP) to shard parameters, gradients, and optimizer states across
 # workers by setting ``prepare_model(parallel_strategy="fsdp")``.
+#
+# Heterogeneous clusters: separate data and training resources
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Because Ray Data and Ray Train are separate systems, they don't have to
+# share the same machines. By default, Ray Data preprocessing and training
+# workers all run on the same nodes. However, you can optionally add
+# **CPU-only nodes** to your cluster and Ray Data will automatically
+# schedule preprocessing tasks on them, keeping your expensive GPU nodes
+# free for training.
+#
+# This is useful when data preprocessing is a bottleneck. If you notice
+# low GPU utilization because workers are waiting on data, you can add
+# cheaper CPU-only nodes to the cluster and Ray Data scales out
+# preprocessing to them.
+#
+# For more details, see `Configuring data ingest
+# <https://docs.ray.io/en/latest/train/user-guides/data-loading-preprocessing.html>`__.

 ###############################################################################
 # Fault tolerance
@@ -582,8 +644,7 @@ def train_func_per_worker(config: dict):
 # storage. When recovering from a failure, training resumes from the
 # latest checkpoint rather than starting over.
 # * **Node failure handling**: If an entire node goes down, Ray
-#   redistributes work to surviving nodes and replaces the failed node
-#   when new resources become available.
+#   replaces the failed node and resumes training.
 #
 # To enable automatic failure recovery, configure ``FailureConfig`` in your ``RunConfig``:
 #
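The configuration snippet itself falls outside this hunk; a minimal sketch of the idea (the storage path and retry count are assumed values, not from the commit):

.. code-block:: python

    from ray.train import FailureConfig, RunConfig

    run_config = RunConfig(
        storage_path="s3://my-bucket/gpt2-runs",       # assumed shared storage
        failure_config=FailureConfig(max_failures=3),  # retry up to 3 failures
    )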
@@ -637,9 +698,13 @@ def train_func_per_worker(config: dict):
 # * Ran distributed training across 8 GPUs using Ray Train's
 #   ``TorchTrainer`` with only minimal changes to a standard PyTorch
 #   training loop.
-# * Learned how to save distributed checkpoints for model recovery.
+# * Learned how to save and load distributed checkpoints for model
+#   recovery.
 # * Learned how to scale to multi-node clusters by changing
 #   ``ScalingConfig`` and ``RunConfig``.
+# * Learned how heterogeneous clusters let you run data preprocessing
+#   on CPU nodes and training on GPU nodes for cost and performance
+#   optimization.
 # * Learned about Ray Train's **fault tolerance** mechanisms for
 #   production training jobs.
 # * Monitored training with the Ray dashboard.
