
Commit d745ba2

Improve Mosaic tutorial: add images, fix subprocess output, refactor buggy model
- Add GPT-2 memory profiling images (with/without activation checkpointing)
- Add Google Colab download instructions for generated files
- Fix subprocess.run to capture and print Mosaic CLI output
- Split Mosaic analysis into separate code blocks for readability
- Refactor GPT2WithDebugOverhead to use wrapper pattern instead of subclassing, fixing transformers version compatibility issues
- Remove try/except workaround that was bypassing the tutorial's purpose
- Update section formatting (bold headers instead of RST underlines)
1 parent afcd23a commit d745ba2

3 files changed

Lines changed: 83 additions & 30 deletions

File tree

Two new image files (164 KB and 278 KB): the GPT-2 memory profiles with and without activation checkpointing
beginner_source/mosaic_memory_profiling_tutorial.py

beginner_source/mosaic_memory_profiling_tutorial.py

Lines changed: 83 additions & 30 deletions
@@ -175,13 +175,13 @@
 
 # Install dependencies if needed
 try:
-    from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
+    from transformers import GPT2LMHeadModel, GPT2Tokenizer
     from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 except ImportError:
     subprocess.check_call(
         [sys.executable, "-m", "pip", "install", "-q", "transformers"]
     )
-    from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
+    from transformers import GPT2LMHeadModel, GPT2Tokenizer
     from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
 try:
@@ -481,10 +481,49 @@ def run_training_ac(
     print("\nNote: Mosaic profile generation encountered issues.")
     print("This may happen if running in an environment without full Mosaic support.")
 
+######################################################################
+# Download Generated Files (Google Colab)
+# ----------------------------------------
+#
+# If running in Google Colab, you can download the generated snapshot
+# and profile files using the following code:
+#
+# .. code-block:: python
+#
+#    from google.colab import files
+#
+#    print("Downloading memory snapshots and profiles...")
+#    files.download('snapshot_baseline.pickle')
+#    files.download('snapshot_with_ac.pickle')
+#    files.download('profile_baseline.html')
+#    files.download('profile_with_ac.html')
+#
+
 ######################################################################
 # Results Interpretation: Activation Checkpointing
 # -------------------------------------------------
 #
+# The generated HTML profiles visualize memory usage over time, with
+# allocations colored by category. Here's what the profiles look like:
+#
+# .. figure:: /_static/img/mosaic/mosaic-categorical-memory-profiling-gpt2-without-ac.png
+#    :alt: GPT-2 memory profile without activation checkpointing
+#    :align: center
+#    :width: 600px
+#
+# **Baseline (without activation checkpointing):** Notice the large
+# activation memory (shown in one color) that persists throughout
+# the forward pass.
+#
+# .. figure:: /_static/img/mosaic/mosaic-categorical-memory-profiling-gpt2-with-ac.png
+#    :alt: GPT-2 memory profile with activation checkpointing
+#    :align: center
+#    :width: 600px
+#
+# **With activation checkpointing:** Activation memory is significantly
+# reduced as intermediate activations are discarded and recomputed
+# during the backward pass.
+#
 # What We Observed
 # ~~~~~~~~~~~~~~~~
 #
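For reference, the recompute behaviour described in the added text above can be turned on for the tutorial's GPT-2 model roughly as follows. This is a minimal sketch using the Hugging Face gradient_checkpointing_enable() helper; the tutorial's own run_training_ac function may enable activation checkpointing differently.

    import torch
    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2").to("cuda")
    # Wrap each transformer block in torch.utils.checkpoint: intermediate
    # activations are dropped after the forward pass and recomputed during
    # backward, lowering peak activation memory at the cost of extra compute.
    model.gradient_checkpointing_enable()
    model.train()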
@@ -580,11 +619,17 @@ def run_training_ac(
 # debugging, but forgot to remove them before training.
 
 
-class GPT2WithDebugOverhead(GPT2LMHeadModel):
-    """GPT2 with abandoned 'feature analysis' code that bloats peak memory."""
+class GPT2WithDebugOverhead(torch.nn.Module):
+    """GPT2 wrapper with abandoned 'feature analysis' code that bloats peak memory.
+
+    This wrapper adds extra projection layers that consume memory but serve no
+    purpose - simulating abandoned debug code that was never cleaned up.
+    """
 
-    def __init__(self, config):
-        super().__init__(config)
+    def __init__(self, base_model):
+        super().__init__()
+        self.base_model = base_model
+        config = base_model.config
 
         # BUG: Large projection layers from an abandoned experiment
         self.debug_projections = torch.nn.ModuleList(
@@ -600,7 +645,7 @@ def __init__(self, config):
 
     def forward(self, input_ids=None, labels=None, **kwargs):
         # Run normal GPT-2 forward with hidden states
-        outputs = super().forward(
+        outputs = self.base_model(
             input_ids=input_ids,
             labels=labels,
             output_hidden_states=True,
@@ -680,14 +725,9 @@ def run_training_with_bug(snapshot_path, num_steps=3):
     device = torch.device("cuda")
 
     print("Loading buggy model with debug overhead...")
-    config = GPT2Config.from_pretrained("gpt2")
-    model = GPT2WithDebugOverhead(config).to(device)
-
-    # Load pretrained weights
-    pretrained = GPT2LMHeadModel.from_pretrained("gpt2")
-    model.load_state_dict(pretrained.state_dict(), strict=False)
-    del pretrained
-    torch.cuda.empty_cache()
+    # Load pretrained GPT-2 and wrap it with the debug overhead
+    base_model = GPT2LMHeadModel.from_pretrained("gpt2")
+    model = GPT2WithDebugOverhead(base_model).to(device)
 
     model.train()
 
@@ -745,35 +785,50 @@ def run_training_with_bug(snapshot_path, num_steps=3):
 print("Training with debug projection overhead (BUG)")
 print("=" * 60)
 
-try:
-    buggy_memory = run_training_with_bug("snapshot_with_bug.pickle", num_steps=3)
-except (AttributeError, ValueError) as e:
-    # Handle transformers version compatibility issues
-    print(f"Note: Skipping buggy model demo due to transformers compatibility: {e}")
-    buggy_memory = baseline_memory_debug
+buggy_memory = run_training_with_bug("snapshot_with_bug.pickle", num_steps=3)
 
 ######################################################################
 # Use Mosaic to Find the Problem
 # -------------------------------
 #
 # Analyze both snapshots to identify the source of extra memory usage.
+# We'll run Mosaic's peak memory analysis on each snapshot separately.
+
+######################################################################
+# Analyze the Baseline (Clean) Snapshot
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 if HAS_CUDA and HAS_MOSAIC_CLI:
-    print("\n" + "=" * 60)
+    print("=" * 60)
     print("MOSAIC: Analyzing the Baseline Snapshot")
     print("=" * 60)
 
-    subprocess.run(
+    result = subprocess.run(
         ["mosaic_get_memory_usage_peak", "--snapshot", "snapshot_debug_baseline.pickle"],
+        capture_output=True,
+        text=True,
     )
+    print(result.stdout)
+    if result.stderr:
+        print(result.stderr)
 
-    print("\n" + "=" * 60)
+######################################################################
+# Analyze the Buggy Snapshot
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+if HAS_CUDA and HAS_MOSAIC_CLI:
+    print("=" * 60)
     print("MOSAIC: Analyzing the Buggy Snapshot")
     print("=" * 60)
 
-    subprocess.run(
+    result = subprocess.run(
         ["mosaic_get_memory_usage_peak", "--snapshot", "snapshot_with_bug.pickle"],
+        capture_output=True,
+        text=True,
     )
+    print(result.stdout)
+    if result.stderr:
+        print(result.stderr)
 
 ######################################################################
 # Analyzing The Mosaic Output
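The capture-and-print pattern introduced above is repeated verbatim for both snapshots; it could equally be factored into a small helper. A sketch of that refactor (the helper name run_mosaic_peak is hypothetical and not part of the tutorial):

    import subprocess

    def run_mosaic_peak(snapshot_path):
        """Run Mosaic's peak-memory analysis on a snapshot and print its output."""
        result = subprocess.run(
            ["mosaic_get_memory_usage_peak", "--snapshot", snapshot_path],
            capture_output=True,
            text=True,
        )
        print(result.stdout)
        if result.stderr:
            print(result.stderr)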
@@ -783,8 +838,7 @@ def run_training_with_bug(snapshot_path, num_steps=3):
 # memory allocation. Let's look at how to find abandoned or unnecessary code
 # that's bloating the memory.
 #
-# 1. Optimizer State Allocations Delta
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# **1. Optimizer State Allocations Delta**
 #
 # In the buggy snapshot output, we can see that the first two stack traces
 # represent the **optimizer state allocations** (like ``zeros_like`` for Adam
@@ -809,11 +863,10 @@ def run_training_with_bug(snapshot_path, num_steps=3):
 # - 148 calls
 # - 0.464 GB + 0.464 GB
 #
-# **What this tells us:** The optimizer is tracking more tensors! This is your
+# What this tells us: The optimizer is tracking more tensors! This is your
 # first clue that there are extra parameters or tensors in the computation graph.
 #
-# 2. Additional Activation Allocations
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# **2. Additional Activation Allocations**
 #
 # The buggy version shows **extra allocations** that don't appear in the
 # baseline model. Scrolling down the Mosaic output of the buggy model we can
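The extra optimizer-state allocations called out above come from the additional parameters that the debug projections place in the computation graph. A sketch of one quick way to confirm this, comparing parameter counts between the plain model and the wrapped model (variable names are illustrative):

    from transformers import GPT2LMHeadModel

    base_model = GPT2LMHeadModel.from_pretrained("gpt2")
    buggy_model = GPT2WithDebugOverhead(GPT2LMHeadModel.from_pretrained("gpt2"))

    def count_params(model):
        return sum(p.numel() for p in model.parameters())

    # Adam keeps two extra tensors (exp_avg, exp_avg_sq) per parameter, so any
    # increase here shows up twice in the optimizer-state allocations.
    print(f"baseline parameters: {count_params(base_model):,}")
    print(f"buggy parameters:    {count_params(buggy_model):,}")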
