 
 # Install dependencies if needed
 try:
-    from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
+    from transformers import GPT2LMHeadModel, GPT2Tokenizer
     from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 except ImportError:
     subprocess.check_call(
         [sys.executable, "-m", "pip", "install", "-q", "transformers"]
     )
-    from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
+    from transformers import GPT2LMHeadModel, GPT2Tokenizer
     from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
 try:
@@ -481,10 +481,49 @@ def run_training_ac(
     print("\nNote: Mosaic profile generation encountered issues.")
     print("This may happen if running in an environment without full Mosaic support.")
 
+######################################################################
+# Download Generated Files (Google Colab)
+# ----------------------------------------
+#
+# If running in Google Colab, you can download the generated snapshot
+# and profile files using the following code:
+#
+# .. code-block:: python
+#
+#     from google.colab import files
+#
+#     print("Downloading memory snapshots and profiles...")
+#     files.download('snapshot_baseline.pickle')
+#     files.download('snapshot_with_ac.pickle')
+#     files.download('profile_baseline.html')
+#     files.download('profile_with_ac.html')
+#
+
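+######################################################################
+# How the Snapshots Are Captured (Sketch)
+# ----------------------------------------
+#
+# The ``snapshot_*.pickle`` files above come from PyTorch's CUDA
+# allocator history. The tutorial's own snapshot helper is defined
+# earlier in the script; the sketch below only illustrates the standard
+# ``torch.cuda.memory`` calls such a helper would typically wrap, and the
+# ``max_entries`` value and file name here are placeholder assumptions:
+#
+# .. code-block:: python
+#
+#     import torch
+#
+#     # Start recording allocation/free events with stack traces
+#     torch.cuda.memory._record_memory_history(max_entries=100000)
+#
+#     # ... run a few training steps here ...
+#
+#     # Dump a snapshot file that Mosaic can analyze, then stop recording
+#     torch.cuda.memory._dump_snapshot("snapshot_baseline.pickle")
+#     torch.cuda.memory._record_memory_history(enabled=None)
+#
+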
 ######################################################################
 # Results Interpretation: Activation Checkpointing
 # -------------------------------------------------
 #
+# The generated HTML profiles visualize memory usage over time, with
+# allocations colored by category. Here's what the profiles look like:
+#
+# .. figure:: /_static/img/mosaic/mosaic-categorical-memory-profiling-gpt2-without-ac.png
+#    :alt: GPT-2 memory profile without activation checkpointing
+#    :align: center
+#    :width: 600px
+#
+# **Baseline (without activation checkpointing):** Notice the large
+# activation memory (shown in one color) that persists throughout
+# the forward pass.
+#
+# .. figure:: /_static/img/mosaic/mosaic-categorical-memory-profiling-gpt2-with-ac.png
+#    :alt: GPT-2 memory profile with activation checkpointing
+#    :align: center
+#    :width: 600px
+#
+# **With activation checkpointing:** Activation memory is significantly
+# reduced as intermediate activations are discarded and recomputed
+# during the backward pass.
+#
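+# For readers who want to reproduce this effect on their own model: the
+# full ``run_training_ac`` helper is defined earlier in this script, but
+# checkpointing itself only needs a couple of lines. The sketch below is
+# an assumption-level example using Hugging Face's built-in
+# ``gradient_checkpointing_enable`` API rather than the tutorial's exact
+# wiring:
+#
+# .. code-block:: python
+#
+#     from transformers import GPT2LMHeadModel
+#
+#     model = GPT2LMHeadModel.from_pretrained("gpt2").cuda()
+#
+#     # Drop intermediate activations during the forward pass and recompute
+#     # them in backward, trading extra compute for lower peak memory.
+#     model.gradient_checkpointing_enable()
+#     model.config.use_cache = False  # KV cache is incompatible with checkpointing
+#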
 # What We Observed
 # ~~~~~~~~~~~~~~~~
 #
@@ -580,11 +619,17 @@ def run_training_ac(
 # debugging, but forgot to remove them before training.
 
 
-class GPT2WithDebugOverhead(GPT2LMHeadModel):
-    """GPT2 with abandoned 'feature analysis' code that bloats peak memory."""
+class GPT2WithDebugOverhead(torch.nn.Module):
+    """GPT2 wrapper with abandoned 'feature analysis' code that bloats peak memory.
+
+    This wrapper adds extra projection layers that consume memory but serve no
+    purpose - simulating abandoned debug code that was never cleaned up.
+    """
 
-    def __init__(self, config):
-        super().__init__(config)
+    def __init__(self, base_model):
+        super().__init__()
+        self.base_model = base_model
+        config = base_model.config
 
         # BUG: Large projection layers from an abandoned experiment
         self.debug_projections = torch.nn.ModuleList(
@@ -600,7 +645,7 @@ def __init__(self, config):
 
     def forward(self, input_ids=None, labels=None, **kwargs):
         # Run normal GPT-2 forward with hidden states
-        outputs = super().forward(
+        outputs = self.base_model(
             input_ids=input_ids,
             labels=labels,
             output_hidden_states=True,
@@ -680,14 +725,9 @@ def run_training_with_bug(snapshot_path, num_steps=3):
     device = torch.device("cuda")
 
     print("Loading buggy model with debug overhead...")
-    config = GPT2Config.from_pretrained("gpt2")
-    model = GPT2WithDebugOverhead(config).to(device)
-
-    # Load pretrained weights
-    pretrained = GPT2LMHeadModel.from_pretrained("gpt2")
-    model.load_state_dict(pretrained.state_dict(), strict=False)
-    del pretrained
-    torch.cuda.empty_cache()
+    # Load pretrained GPT-2 and wrap it with the debug overhead
+    base_model = GPT2LMHeadModel.from_pretrained("gpt2")
+    model = GPT2WithDebugOverhead(base_model).to(device)
 
     model.train()
 
@@ -745,35 +785,50 @@ def run_training_with_bug(snapshot_path, num_steps=3):
 print("Training with debug projection overhead (BUG)")
 print("=" * 60)
 
-try:
-    buggy_memory = run_training_with_bug("snapshot_with_bug.pickle", num_steps=3)
-except (AttributeError, ValueError) as e:
-    # Handle transformers version compatibility issues
-    print(f"Note: Skipping buggy model demo due to transformers compatibility: {e}")
-    buggy_memory = baseline_memory_debug
+buggy_memory = run_training_with_bug("snapshot_with_bug.pickle", num_steps=3)
 
 ######################################################################
 # Use Mosaic to Find the Problem
 # -------------------------------
 #
 # Analyze both snapshots to identify the source of extra memory usage.
+# We'll run Mosaic's peak memory analysis on each snapshot separately.
+
+######################################################################
+# Analyze the Baseline (Clean) Snapshot
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 if HAS_CUDA and HAS_MOSAIC_CLI:
-    print("\n" + "=" * 60)
+    print("=" * 60)
     print("MOSAIC: Analyzing the Baseline Snapshot")
     print("=" * 60)
 
-    subprocess.run(
+    result = subprocess.run(
         ["mosaic_get_memory_usage_peak", "--snapshot", "snapshot_debug_baseline.pickle"],
+        capture_output=True,
+        text=True,
     )
+    print(result.stdout)
+    if result.stderr:
+        print(result.stderr)
 
-    print("\n" + "=" * 60)
+######################################################################
+# Analyze the Buggy Snapshot
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+if HAS_CUDA and HAS_MOSAIC_CLI:
+    print("=" * 60)
     print("MOSAIC: Analyzing the Buggy Snapshot")
     print("=" * 60)
 
-    subprocess.run(
+    result = subprocess.run(
         ["mosaic_get_memory_usage_peak", "--snapshot", "snapshot_with_bug.pickle"],
+        capture_output=True,
+        text=True,
     )
+    print(result.stdout)
+    if result.stderr:
+        print(result.stderr)
 
 ######################################################################
 # Analyzing The Mosaic Output
@@ -783,8 +838,7 @@ def run_training_with_bug(snapshot_path, num_steps=3):
 # memory allocation. Let's look at how to find abandoned or unnecessary code
 # that's bloating the memory.
 #
-# 1. Optimizer State Allocations Delta
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# **1. Optimizer State Allocations Delta**
 #
 # In the buggy snapshot output, we can see that the first two stack traces
 # represent the **optimizer state allocations** (like ``zeros_like`` for Adam
@@ -809,11 +863,10 @@ def run_training_with_bug(snapshot_path, num_steps=3):
 # - 148 calls
 # - 0.464 GB + 0.464 GB
 #
-# **What this tells us:** The optimizer is tracking more tensors! This is your
+# What this tells us: The optimizer is tracking more tensors! This is your
 # first clue that there are extra parameters or tensors in the computation graph.
 #
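+# A quick way to confirm this clue outside of Mosaic is to compare the
+# trainable parameter counts of the two models, since every extra trainable
+# tensor gets matching Adam state (``exp_avg`` / ``exp_avg_sq``). This is a
+# hedged sketch: ``n_params`` is a hypothetical helper, and
+# ``GPT2WithDebugOverhead`` is the wrapper class defined above.
+#
+# .. code-block:: python
+#
+#     baseline = GPT2LMHeadModel.from_pretrained("gpt2")
+#     buggy = GPT2WithDebugOverhead(GPT2LMHeadModel.from_pretrained("gpt2"))
+#
+#     def n_params(m):
+#         # Count only trainable parameters, i.e. what the optimizer tracks
+#         return sum(p.numel() for p in m.parameters() if p.requires_grad)
+#
+#     print(f"baseline: {n_params(baseline):,} trainable parameters")
+#     print(f"buggy:    {n_params(buggy):,} trainable parameters")
+#
+# The difference corresponds to the ``debug_projections`` layers that the
+# optimizer now has to track.
+#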
-# 2. Additional Activation Allocations
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# **2. Additional Activation Allocations**
 #
 # The buggy version shows **extra allocations** that don't appear in the
 # baseline. Scrolling down the Mosaic output of the buggy model we can