1414
1515 * Pre-train a GPT-2 (~124M-parameter) language model using PyTorch
1616 and Hugging Face Transformers.
17- * Distribute training across multiple GPUs with Ray Train.
18- * Stream training data from Hugging Face datasets with Ray Data.
17+ * Distribute training across multiple GPUs with Ray Train using minimal code changes.
18+ * Stream training data from Hugging Face datasets with Ray Data's distributed workers.
1919 * Save and load distributed checkpoints.
2020 * Scale from a single node to a multi-node cluster with minimal code changes.
21+ * Optimize cost and performance with heterogeneous clusters.
2122 * Monitor training with the Ray dashboard.
2223
2324 .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
5051Then, import the required libraries:
5152"""
5253
53- ###############################################################################
54-
5554import time
5655
5756import numpy as np
8786train_ds = ray.data.from_huggingface(hf_ds["train"])
8887val_ds = ray.data.from_huggingface(hf_ds["validation"])
8988
89+ # Limit dataset size for fast iteration during smoke tests.
90+ if SMOKE_TEST:
91+     train_ds = train_ds.limit(2500)
92+     val_ds = val_ds.limit(2500)
93+
9094print(f"Dataset schema:\n{train_ds.schema()}")
9195
9296###############################################################################
101105# This means that the dataset has one column called ``text`` and it is a string.
102106#
103107# Inspect raw data
104109# ~~~~~~~~~~~~~~~~
105110#
106111# Use ``take(n)`` to fetch a small number of rows for inspection.
107112# Each row is a dictionary with the column names as keys.
113+
108114print("--- Raw data sample ---")
109115sample = train_ds.take(2)
110116for i, row in enumerate(sample):
117123# .. code-block:: text
118124#
119125# Row 0: ''
120- # Row 1: ' = Valkyria Chronicles III = \n '
126+ # Row 1: ' = Valkyria Chronicles III = '
121127#
122128# Each row in Wikitext-103 is a single line from a Wikipedia article.
123129# Consecutive rows belong to the same article, with empty rows separating
124130# paragraphs. New articles begin with a title line like
125131# ``= Article Title =``. The tokenization step below inserts an
126132# ``<|endoftext|>`` separator token before each title line so the model
127133# learns to reset context at article boundaries.
128-
129- # Limit dataset size for fast iteration during smoke tests.=
130- if SMOKE_TEST:
131-     train_ds = train_ds.limit(2500)
132-     val_ds = val_ds.limit(2500)
133-
134- ###############################################################################
134+ #
135135# Tokenize and chunk the data
136136# ----------------------------
137137#
@@ -195,12 +195,14 @@ def tokenize_and_chunk(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
195195 }
196196
197197
198+
198199###############################################################################
199200# Apply the tokenization with ``map_batches()``. This operation is **lazy**,
200201# meaning that Ray Data defers execution until a downstream consumer requests the
201202# results. Lazy execution lets Ray optimize the entire pipeline before any
202203# work begins.
203204
205+ # These do not trigger execution.
204206train_ds = train_ds.map_batches(tokenize_and_chunk, batch_format="numpy")
205207val_ds = val_ds.map_batches(tokenize_and_chunk, batch_format="numpy")
206208
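###############################################################################
# If you want to peek at the tokenized output before training, Ray Data's
# ``take_batch()`` pulls a single small batch through the lazy pipeline. The
# snippet below is an optional sketch (it triggers a small amount of execution,
# which this tutorial otherwise defers until training starts):
#
# .. code-block:: python
#
#     batch = train_ds.take_batch(2)
#     print(batch.keys())                # expected: dict_keys(['input_ids', 'labels'])
#     print(len(batch["input_ids"][0]))  # should equal BLOCK_SIZE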
@@ -236,29 +238,23 @@ def tokenize_and_chunk(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
236238# .. code-block:: text
237239#
238240# Execution plan: InputDataBuffer[Input]
239- # -> TaskPoolMapOperator[Filter]
240241# -> TaskPoolMapOperator[MapBatches(tokenize_and_chunk)]
241242# -> OutputSplitter[split(8, equal=True)]
242243#
243- # This tells you exactly how Ray Data will stream through filter, tokenize,
244- # and split the data across 8 workers.
244+ # This shows exactly how Ray Data streams the data through tokenization
245+ # and splits it across the 8 training workers.
245246
246247###############################################################################
247248# Define the transformer model
248249# ----------------------------
249250#
250251# The model is a decoder-only transformer language model using Hugging Face's
251- # ``GPT2LMHeadModel``.
252+ # ``GPT2LMHeadModel``. The hyperparameters below define the standard GPT-2 "small" architecture.
252253#
253- # The GPT-2 "small" architecture:
254- #
255- # * 12 transformer layers, 12 attention heads, 768 hidden size
256- # * ~124M parameters
257- # * Built-in causal attention masking and weight tying
258254
259255
260256def create_model():
261-     """Create a fresh GPT-2 model from config (random weights)."""
257+     """Create a GPT-2 small model with random weights."""
262258    model = GPT2LMHeadModel(GPT2Config(
263259        vocab_size=VOCAB_SIZE,
264260        n_positions=BLOCK_SIZE,
@@ -270,6 +266,7 @@ def create_model():
270266    return model
271267
272268
269+
273270###############################################################################
274271# Verify the model size:
275272
@@ -280,8 +277,7 @@ def create_model():
280277del model # Free memory before training
281278
282279###############################################################################
283- # You should see approximately **123.8M parameters** — the standard GPT-2
284- # "small" size.
280+ # You should see approximately 123.8M parameters.
285281
286282###############################################################################
287283# Define the distributed training function
@@ -294,15 +290,15 @@ def create_model():
294290#
295291# The key Ray Train integration points are:
296292#
297- # 1. ** ``ray.train.get_dataset_shard("train")``** retrieves the
293+ # 1. ``ray.train.get_dataset_shard("train")`` retrieves the
298294# worker's portion of the data. Ray Data automatically splits the
299295# dataset across all workers.
300- # 2. ** ``ray.train.torch.prepare_model(model)``** wraps the model in
296+ # 2. ``ray.train.torch.prepare_model(model)`` wraps the model in
301297# ``DistributedDataParallel`` and moves it to the correct GPU.
302- # 3. ** ``shard.iter_torch_batches(batch_size=...)``** returns an iterator
298+ # 3. ``shard.iter_torch_batches(batch_size=...)`` returns an iterator
303299# of ``dict[str, torch.Tensor]`` batches, with tensors automatically
304- # placed on the worker's GPU.
305- # 4. ** ``ray.train.report(metrics, checkpoint=...)``** reports metrics
300+ # placed on the worker's GPU. Setting ``prefetch_batches=2`` fetches the
# next two batches in the background so the GPU doesn't sit idle waiting for data.
301+ # 4. ``ray.train.report(metrics, checkpoint=...)`` reports metrics
306302# to the driver and optionally saves a checkpoint.
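#
# Schematically, those four calls slot into a training function like the
# sketch below. This is a condensed preview of the full implementation that
# follows; validation, metrics, and checkpointing are omitted, and the
# ``config`` keys ``lr`` and ``epochs`` are illustrative placeholders:
#
# .. code-block:: python
#
#     def train_func_per_worker(config: dict):
#         train_data_shard = ray.train.get_dataset_shard("train")
#
#         model = ray.train.torch.prepare_model(create_model())
#         optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])
#
#         for epoch in range(config["epochs"]):
#             for batch in train_data_shard.iter_torch_batches(
#                 batch_size=config["batch_size_per_worker"], dtypes=torch.long
#             ):
#                 loss = model(input_ids=batch["input_ids"],
#                              labels=batch["labels"]).loss
#                 optimizer.zero_grad()
#                 loss.backward()
#                 optimizer.step()
#
#             ray.train.report({"epoch": epoch, "train_loss": loss.item()})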
307303
308304
@@ -338,7 +334,7 @@ def train_func_per_worker(config: dict):
338334
339335        # iter_torch_batches returns dicts of tensors already on the GPU.
340336        for batch in train_data_shard.iter_torch_batches(
341-            batch_size=batch_size, dtypes=torch.long
337+            batch_size=batch_size, dtypes=torch.long, prefetch_batches=2
342338        ):
343339            input_ids = batch["input_ids"]
344340            labels = batch["labels"]
@@ -371,7 +367,7 @@ def train_func_per_worker(config: dict):
371367
372368        with torch.no_grad():
373369            for batch in val_data_shard.iter_torch_batches(
374-                batch_size=batch_size, dtypes=torch.long
370+                batch_size=batch_size, dtypes=torch.long, prefetch_batches=2
375371            ):
376372                input_ids = batch["input_ids"]
377373                labels = batch["labels"]
@@ -402,18 +398,20 @@ def train_func_per_worker(config: dict):
402398 )
403399
404400
401+
405402###############################################################################
406403# Configure and launch distributed training
407404# ------------------------------------------
408405#
409- # The ``TorchTrainer`` brings everything together. It accepts:
406+ # The ``TorchTrainer`` brings everything together. Calling ``trainer.fit()``
407+ # triggers execution of the full data pipeline and training loop. The trainer accepts:
410408#
411- # * ** ``train_func_per_worker``** : the function each worker executes.
412- # * ** ``train_loop_config``** : a dictionary of hyperparameters forwarded
409+ # * ``train_func_per_worker``: the function each worker executes.
410+ # * ``train_loop_config``: a dictionary of hyperparameters forwarded
413411# to the training function.
414- # * ** ``datasets``** : a dictionary of Ray Datasets. Ray Train automatically
412+ # * ``datasets``: a dictionary of Ray Datasets. Ray Train automatically
415413# splits each dataset across workers.
416- # * ** ``scaling_config``** : specifies the number of workers and whether to
414+ # * ``scaling_config``: specifies the number of workers and whether to
417415# use GPUs.
418416#
419417# Setting ``num_workers=8`` launches 8 parallel workers, one per GPU. Ray
@@ -424,8 +422,9 @@ def train_func_per_worker(config: dict):
424422# .. code-block:: text
425423#
426424# Started training worker group of size 8:
427- # - (ip=10.0.176.183, pid=25636) world_rank=0, local_rank=0, node_rank=0
428- # - (ip=10.0.176.183, pid=25637) world_rank=1, local_rank=1, node_rank=0
425+ #
426+ # * (ip=10.0.176.183, pid=25636) world_rank=0, local_rank=0, node_rank=0
427+ # * (ip=10.0.176.183, pid=25637) world_rank=1, local_rank=1, node_rank=0
429428# ...
430429# Moving model to device: cuda:0
431430# Wrapping provided model in DistributedDataParallel.
@@ -452,6 +451,7 @@ def train_func_per_worker(config: dict):
452451 "batch_size_per_worker" : BATCH_SIZE_PER_WORKER ,
453452 "max_steps_per_epoch" : 5 if SMOKE_TEST else None ,
454453 },
454+ # Register the datasets,
455455 datasets = {"train" : train_ds , "validation" : val_ds },
456456 scaling_config = ScalingConfig (
457457 num_workers = NUM_WORKERS ,
@@ -483,12 +483,11 @@ def train_func_per_worker(config: dict):
483483#
484484# The per-worker logs show training loss, validation loss, and throughput
485485# metrics for each epoch. With random weights and only a few steps, expect
486- # a high loss (~10–11) — this is normal. In a real training run with more
487- # epochs and the full dataset, you would see loss steadily decrease and
488- # throughput stabilize.
486+ # a high loss (~10–11). That's roughly ln(50257) ≈ 10.8, the cross-entropy of
# a model guessing uniformly over the vocabulary.
489487
490488###############################################################################
491489# Checkpointing
492491# ~~~~~~~~~~~~~
493492#
494493# In a production training run you would enable checkpointing so that
@@ -507,23 +506,70 @@ def train_func_per_worker(config: dict):
507506# )
508507#
509508# Inside the training function, save a checkpoint with
510- # ``ray.train.report()``:
509+ # ``ray.train.report()``. Every worker must call ``ray.train.report()``, but
# only the rank-0 worker attaches a checkpoint:
511510#
512511# .. code-block:: python
513512#
514513 #     with tempfile.TemporaryDirectory() as tmp_dir:
515- #         model.module.save_pretrained(tmp_dir)  # .module unwraps DDP
516- #         checkpoint = ray.train.Checkpoint.from_directory(tmp_dir)
514+ #         checkpoint = None
515+ #         if ray.train.get_context().get_world_rank() == 0:
516+ #             torch.save(model.module.state_dict(),
517+ #                        os.path.join(tmp_dir, "model.pt"))
518+ #             torch.save(optimizer.state_dict(),
519+ #                        os.path.join(tmp_dir, "optimizer.pt"))
520+ #             torch.save({"epoch": epoch},
521+ #                        os.path.join(tmp_dir, "extra_state.pt"))
522+ #             checkpoint = ray.train.Checkpoint.from_directory(tmp_dir)
517523 #         ray.train.report(metrics={...}, checkpoint=checkpoint)
518524#
525+ # Note that ``.module`` unwraps the ``DistributedDataParallel`` wrapper so
526+ # you save the underlying model's weights rather than the DDP-wrapped module.
527+ #
528+ # To **resume training from a checkpoint**, call
529+ # ``ray.train.get_checkpoint()`` at the top of your training function.
530+ # When Ray Train restarts workers (for example, after a failure), it
531+ # automatically provides the latest checkpoint. If no checkpoint exists
532+ # (i.e. this is a fresh run), the function returns ``None``:
533+ #
534+ # .. code-block:: python
535+ #
536+ #     def train_func_per_worker(config: dict):
537+ #         model = create_model()
538+ #         model = ray.train.torch.prepare_model(model)
539+ #         optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])
540+ #
541+ #         # Resume from the latest checkpoint if one exists.
542+ #         start_epoch = 0
543+ #         checkpoint = ray.train.get_checkpoint()
544+ #         if checkpoint:
545+ #             with checkpoint.as_directory() as ckpt_dir:
546+ #                 model.module.load_state_dict(
547+ #                     torch.load(os.path.join(ckpt_dir, "model.pt"))
548+ #                 )
549+ #                 optimizer.load_state_dict(
550+ #                     torch.load(os.path.join(ckpt_dir, "optimizer.pt"))
551+ #                 )
552+ #                 start_epoch = torch.load(
553+ #                     os.path.join(ckpt_dir, "extra_state.pt")
554+ #                 )["epoch"] + 1
555+ #
556+ #         for epoch in range(start_epoch, config["epochs"]):
557+ #             # ... training loop ...
558+ #
559+ # You can also call ``TorchTrainer.restore(path, datasets=...)`` to
560+ # restore an entire interrupted experiment from its results directory
561+ # without re-specifying the full trainer configuration. See the `Ray Train
562+ # checkpointing guide
563+ # <https://docs.ray.io/en/latest/train/user-guides/checkpoints.html>`__
564+ # for more details.
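#
# A minimal sketch of that restore path (the storage path below is a
# placeholder for wherever your ``RunConfig`` wrote results):
#
# .. code-block:: python
#
#     restored_trainer = TorchTrainer.restore(
#         "/mnt/cluster_storage/gpt2-pretraining",  # placeholder path
#         datasets={"train": train_ds, "validation": val_ds},
#     )
#     result = restored_trainer.fit()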
565+ #
519566# Scaling to a multi-node cluster
520567# -------------------------------
521568#
522569# The code above runs on a single 8-GPU machine. Scaling to a multi-node
523570# cluster requires only two changes:
524571#
525- # 1. **Increase ``num_workers``** to match the total number of GPUs across
526- # all nodes.
525- # 1. **Increase ``num_workers``** to match the total number of GPUs across
526- #    all nodes.
572+ # 1. **Increase the worker count**: set ``num_workers`` to the total number of GPUs in the cluster.
527573# 2. **Set a shared storage path** so that all nodes can access checkpoints.
528574#
529575# For example, to train on a cluster of 4 nodes with 8 GPUs each
@@ -549,8 +595,6 @@ def train_func_per_worker(config: dict):
549595# Ray Train automatically:
550596#
551597# * Launches workers across all available nodes.
552- # * Initializes ``torch.distributed`` with the NCCL backend.
553- # * Configures ``DistributedDataParallel`` across nodes.
554598# * Shards data across all workers.
555599#
556600# No changes to the training function are needed. The same
@@ -564,6 +608,24 @@ def train_func_per_worker(config: dict):
564608# `FullyShardedDataParallel <https://pytorch.org/docs/stable/fsdp.html>`__
565609# (FSDP) to shard parameters, gradients, and optimizer states across
566610# workers by setting ``prepare_model(parallel_strategy="fsdp")``.
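#
# For example (a sketch; only the ``prepare_model()`` call in the training
# function changes):
#
# .. code-block:: python
#
#     model = ray.train.torch.prepare_model(model, parallel_strategy="fsdp")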
611+ #
612+ # Heterogeneous clusters: separate data and training resources
613+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
614+ #
615+ # Because Ray Data and Ray Train are separate systems, they don't have to
616+ # share the same machines. By default, Ray Data preprocessing and training
617+ # workers all run on the same nodes. However, you can optionally add
618+ # **CPU-only nodes** to your cluster and Ray Data will automatically
619+ # schedule preprocessing tasks on them, keeping your expensive GPU nodes
620+ # free for training.
621+ #
622+ # This is useful when data preprocessing is a bottleneck. If you notice
623+ # low GPU utilization because workers are waiting on data, you can add
624+ # cheaper CPU-only nodes to the cluster and Ray Data scales out
625+ # preprocessing to them.
626+ #
627+ # For more details, see `Configuring data ingest
628+ # <https://docs.ray.io/en/latest/train/user-guides/data-loading-preprocessing.html>`__.
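#
# After the CPU-only nodes join, you can confirm what Ray sees with
# ``ray.cluster_resources()`` (the numbers below are illustrative):
#
# .. code-block:: python
#
#     import ray
#     print(ray.cluster_resources())
#     # e.g. {'CPU': 192.0, 'GPU': 8.0, ...} once CPU-only nodes are attached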
567629
568630###############################################################################
569631# Fault tolerance
@@ -582,8 +644,7 @@ def train_func_per_worker(config: dict):
582644# storage. When recovering from a failure, training resumes from the
583645# latest checkpoint rather than starting over.
584646# * **Node failure handling**: If an entire node goes down, Ray
585- # redistributes work to surviving nodes and replaces the failed node
586- # when new resources become available.
647+ # replaces the failed node and resumes training.
587648#
588649# To enable automatic failure recovery, configure ``FailureConfig`` in your ``RunConfig``:
589650#
@@ -637,9 +698,13 @@ def train_func_per_worker(config: dict):
637698# * Ran distributed training across 8 GPUs using Ray Train's
638699# ``TorchTrainer`` with only minimal changes to a standard PyTorch
639700# training loop.
640- # * Learned how to save distributed checkpoints for model recovery.
701+ # * Learned how to save and load distributed checkpoints for model
702+ # recovery.
641703# * Learned how to scale to multi-node clusters by changing
642704# ``ScalingConfig`` and ``RunConfig``.
705+ # * Learned how heterogeneous clusters let you run data preprocessing
706+ # on CPU nodes and training on GPU nodes for cost and performance
707+ # optimization.
643708# * Learned about Ray Train's **fault tolerance** mechanisms for
644709# production training jobs.
645710# * Monitored training with the Ray dashboard.