# Adapted from: https://github.com/ctlllll/axolotl/blob/f86767e/src/axolotl/monkeypatch/medusa_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support speculative decoding for huggingface models."""

import contextlib
import copy
from dataclasses import dataclass
from typing import Any

import torch
import transformers
from packaging.version import Version
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
from transformers import Cache, DynamicCache, PretrainedConfig, PreTrainedModel
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
)
from transformers.trainer_pt_utils import LabelSmoother
from transformers.utils import ModelOutput
from transformers.utils.quantization_config import CompressedTensorsConfig

from ...export.plugins.hf_spec_export import (
    EagleExporter,
    EagleMedusaExporter,
    SpeculativeDecodingExporter,
)
from ..eagle.conversion import EagleDMRegistry
from ..eagle.eagle_model import EagleModel
from ..eagle.utils import expand_mask, make_causal_mask
from ..medusa.conversion import MedusaDMRegistry
from ..medusa.medusa_model import MedusaModel
from ..utils import (
    AcceptanceRateValidation,
    ResBlock,
    _setup_kimi_k2_decoder,
    enable_cp_ttt_patch,
    get_ttt_msk_func,
    temporary_set_config_value,
)

__all__ = ["HFARValidation", "HFEagleModel", "HFMedusaModel"]

IGNORE_TOKEN_ID = LabelSmoother.ignore_index

ENABLE_CP_TTT_PATCH = False

# module variable to cache attention mask for cp ttt
CACHED_SHARD_TTT_MASKS = {}


def _get_empty_cache(config):
    """Return an empty cache. Handle different versions of transformers for unit tests."""
    if Version(transformers.__version__) >= Version("4.54"):
        return DynamicCache(config=config)
    else:
        return DynamicCache()


@MedusaDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
class HFMedusaModel(MedusaModel):
    """Medusa Model Class for huggingface models."""

    def modify(self, medusa_num_heads=0, medusa_num_layers=0):
        """Constructor.

        Args:
            medusa_num_heads: number of medusa heads.
            medusa_num_layers: number of ResBlock layers in each head.
        """
        super().modify(medusa_num_heads=medusa_num_heads, medusa_num_layers=medusa_num_layers)
        self.config.medusa = {
            "num_medusa_heads": medusa_num_heads,
            "num_medusa_layers": medusa_num_layers,
        }
        hidden_size = self.lm_head.weight.shape[-1]
        vocab_size = self.lm_head.weight.shape[0]

        # Create a list of Medusa heads
        self.medusa_heads = nn.ModuleList(
            [
                nn.Sequential(
                    *([ResBlock(hidden_size) for _ in range(self.medusa_num_layers)]),
                    nn.Linear(hidden_size, vocab_size, bias=False),
                )
                for _ in range(self.medusa_num_heads)
            ]
        )

        # Ensure medusa_head's dtype and device align with the base_model
        self.medusa_heads.to(self.lm_head.weight.dtype).to(self.lm_head.weight.device)
        self.medusa_heads.device = self.lm_head.weight.device
        if hasattr(self, "hf_device_map") and "lm_head" in self.hf_device_map:
            self.hf_device_map["medusa_heads"] = self.hf_device_map["lm_head"]

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        freeze_base_model: bool = True,
        medusa_heads_coefficient: float | None = 0.2,
        medusa_decay_coefficient: float | None = 0.8,
        **kwargs,
    ) -> Any:
        """Forward pass of the MedusaModel.

        Returns:
            torch.Tensor: A tensor containing predictions from all Medusa heads.
        """
        # Pass input through the base model
        with torch.no_grad() if freeze_base_model else contextlib.nullcontext():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                cache_position=cache_position,
                **kwargs,
            )
        hidden_states = outputs.last_hidden_state

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = (
            slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        )
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        medusa_logits = [
            self.medusa_heads[i](hidden_states[:, slice_indices, :])
            for i in range(self.medusa_num_heads)
        ]

        if labels is not None:
            loss = 0
            loss_fct = CrossEntropyLoss()
            # Base model loss
            if not freeze_base_model:
                loss_logits = logits.view(-1, logits.shape[-1])
                loss_labels = labels.view(-1)
                base_model_loss = loss_fct(loss_logits, loss_labels)
                loss += base_model_loss
            # Medusa loss
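            # Each Medusa head predicts one token further ahead than the previous one, so the
            # labels are shifted left by one extra position per head and the trailing logits
            # without a matching label are dropped; later heads are down-weighted by
            # medusa_decay_coefficient**i.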
            for i in range(self.medusa_num_heads):
                labels = labels[..., 1:].contiguous()
                loss_logits = medusa_logits[i][:, : -(1 + i)].contiguous()
                loss_logits = loss_logits.view(-1, loss_logits.shape[-1])
                loss_labels = labels.view(-1)
                loss += (
                    loss_fct(loss_logits, loss_labels)
                    * medusa_decay_coefficient**i
                    * medusa_heads_coefficient
                )
        else:
            loss = None

        return ModelOutput(
            loss=loss,
            logits=logits,
            medusa_logits=medusa_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class ParallelDraft(nn.Module):
    """ParallelDraft module with multiple Medusa heads and a shared lm head."""

    def __init__(self, hidden_size: int, vocab_size: int, num_heads: int = 1, num_layers: int = 1):
        """Init function for ParallelDraft."""
        super().__init__()
        self.medusa_heads = torch.nn.ModuleList(
            [
                nn.Sequential(
                    *([ResBlock(hidden_size) for _ in range(num_layers)]),
                )
                for _ in range(num_heads)
            ]
        )
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, x):
        """Forward function."""
        output = []
        for head in self.medusa_heads:
            x_head = head(x)
            output.append(self.lm_head(x_head))
        return output


class EagleModule(nn.Module):
    """Eagle module used in EAGLE model."""

    def __init__(self, config, decoder_layer_cls, bias=False):
        """Init function for EagleModule."""
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [decoder_layer_cls(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        if config.use_last_layernorm:
            self.norm = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps)

        # Optionally, we use a smaller vocab table for eagle module
        if config.draft_vocab_size != config.vocab_size or config.has_lm_head:
            # Need an extra lm_head for eagle module since vocab size is reduced.
            assert config.draft_vocab_size <= config.vocab_size, (
                "EAGLE module's vocab size should be <= base model vocab size!"
            )
            # Initialize the buffers to zero.
            # Their values depend on the specific tokenizer and calibration dataset, and should be set in the training script.
            if config.draft_vocab_size < config.vocab_size:
                self.register_buffer("d2t", torch.zeros(config.draft_vocab_size, dtype=torch.int64))
            self.lm_head = nn.Linear(
                config.hidden_size,
                config.draft_vocab_size,
                bias=False,
            )

        if config.use_aux_hidden_state:
            # In EAGLE-3, the FC concatenates hidden states from multiple base model layers
            self.fc = nn.Linear(
                len(config.eagle_aux_hidden_state_layer_ids) * config.hidden_size,
                config.hidden_size,
                bias=bias,
            )
            first_layer_attn = self.layers[0].self_attn
            # Expand first attn input dim since it accepts cat(input_embeds, hidden_states)
            self._expand_first_attn_in_dim(first_layer_attn)
            # EAGLE-3's first attention requires [input_layernorm_output, aux_hidden_states]
            first_layer_attn.register_forward_pre_hook(
                self._eagle3_attention_forward_pre_hook, with_kwargs=True
            )
            # In EAGLE-3, input_embeds and hidden_states are normalized separately before concatenation.
            self.layers[0].input_layernorm = LlamaRMSNorm(
                config.hidden_size, eps=config.rms_norm_eps
            )
            self.layers[0].hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        if self.config.parallel_draft_step > 1:
            self.parallel_draft_heads = ParallelDraft(
                config.hidden_size,
                config.draft_vocab_size,
                num_heads=self.config.parallel_draft_step - 1,
                num_layers=self.config.parallel_draft_heads_num_layers,
            )

    def _maybe_init_rope(self):
        if self.config.eagle_decoder_type == "llama" and not hasattr(self, "rotary_emb"):
            self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

    def _expand_first_attn_in_dim(self, first_layer_attn):
        """Modify qkv projection in first layer to accept 2h hidden size."""
        # Find Linear modules to expand
        eagle_attn_type = type(first_layer_attn)
        if eagle_attn_type.__name__ == "LlamaAttention":
            expand_modules = ["q_proj", "k_proj", "v_proj"]
        elif eagle_attn_type.__name__ == "DeepseekV3Attention":
            if first_layer_attn.q_lora_rank is None:
                expand_modules = ["q_proj", "kv_a_proj_with_mqa"]
            else:
                expand_modules = ["q_a_proj", "kv_a_proj_with_mqa"]
        else:
            raise ValueError(f"Unsupported attention type: {eagle_attn_type}")

        # Replace Linear with 2x input dim
        for module in expand_modules:
            original_linear = getattr(first_layer_attn, module)
            assert isinstance(original_linear, nn.Linear), f"Module {module} is not a Linear"
            setattr(
                first_layer_attn,
                module,
                nn.Linear(
                    original_linear.in_features * 2,
                    original_linear.out_features,
                    bias=first_layer_attn.config.attention_bias,
                ),
            )

    def _eagle3_attention_forward_pre_hook(self, module, args, kwargs):
        """Concat input_embeds and hidden_states for EAGLE-3's first attention layer."""
        if "hidden_states" not in kwargs:
            raise ValueError("hidden_states not found in kwargs")
        if self._input_embeds is None:
            raise ValueError("self._input_embeds is None")
        input_embeds = self._input_embeds
        self._input_embeds = None
        kwargs["hidden_states"] = torch.cat(
            (input_embeds, self.layers[0].hidden_norm(kwargs["hidden_states"])), dim=-1
        )
        return args, kwargs

    def forward(
        self,
        hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = False,
    ):
        """Forward function for EagleModule."""
        batch_size, seq_length, _ = hidden_states.shape
        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values.get_seq_length()
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = hidden_states.device if hidden_states is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        inputs_embeds = inputs_embeds.to(hidden_states.dtype).to(hidden_states.device)

        # In EAGLE-3, we save the input embeddings to an attribute and consume them in the first
        # decoder layer through the forward pre-hook. Input embeddings and hidden states are
        # normalized separately before concatenation, so the first layer's default input norm is
        # effectively disabled.
        self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)

        if self.config.eagle_decoder_type == "llama":
            position_embeddings = self.rotary_emb(hidden_states, position_ids)
        else:
            position_embeddings = None

        for decoder_layer in self.layers:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                position_embeddings=position_embeddings,
            )
            # For HF >= 4.54.0, layer_outputs is a tensor; for older versions, it is a tuple.
            if isinstance(layer_outputs, tuple):
                hidden_states = layer_outputs[0]
            else:
                hidden_states = layer_outputs

        pre_norm_h = hidden_states
        post_norm_h = self.norm(hidden_states) if hasattr(self, "norm") else hidden_states
        return post_norm_h, pre_norm_h, past_key_values


@dataclass
class EagleBaseModelOutput:
    out_hiddens: torch.Tensor
    aux_hiddens: torch.Tensor | None = None
    logits: torch.Tensor | None = None
    input_embeds: torch.Tensor | None = None
    loss: torch.Tensor | None = None

    @classmethod
    def from_offline_dict(cls, d: dict):
        return cls(
            out_hiddens=d.get("base_model_hidden_states"),
            aux_hiddens=d.get("aux_hidden_states"),
            logits=d.get("base_model_logits"),
            input_embeds=d.get("base_model_input_embeds"),
            loss=None,
        )


@EagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
class HFEagleModel(EagleModel):
    """Eagle Model Class for huggingface models."""

    # Use functions to get base model parts without creating tied modules.
    @property
    def _base_model(self):
        return self.get_submodule(self.base_model_path)

    @property
    def _base_model_embeddings(self):
        return self.get_submodule(self.base_model_embeddings_path)

    @property
    def _base_model_lm_head(self):
        return self.get_submodule(self.base_model_lm_head_path)

    @property
    def _base_llm_config(self):
        """Return the llm config for the base model, from LLM or VLM."""
        return (
            getattr(self.config, "text_config", None)
            or getattr(self.config, "llm_config", None)
            or self.config
        )

    @property
    def _draft_model_config(self):
        """Return the llm config for the draft model."""
        return self.eagle_config

    def _enable_cp_ttt(self):
        if self.training and not self.eagle_mix_hidden_states:
            return enable_cp_ttt_patch()
        return contextlib.nullcontext()

    def _nvtx_range(self, name):
        """Optionally create an NVTX range for the given name when config.eagle_enable_nvtx is set."""
        if not self.eagle_enable_nvtx:
            return contextlib.nullcontext()
        try:
            import torch.cuda.nvtx as nvtx

            return nvtx.range(name)
        except Exception as e:
            print(f"Failed to create NVTX range {name}: {e}")
            return contextlib.nullcontext()

    def get_exporter(self) -> SpeculativeDecodingExporter:
        """Get the exporter for the draft model."""
        exporter_cls = (
            EagleExporter if self.eagle_config.parallel_draft_step <= 1 else EagleMedusaExporter
        )
        return exporter_cls(self)

    def _find_base_model_parts(self):
        """Find model parts from different models and set base_{part}_path attributes."""
        base_model_parts_mapping = {
            "base_model_path": [
                "model.language_model",
                "model",
                "backbone",
                "language_model.backbone",
            ],
            "base_model_embeddings_path": [
                "model.embed_tokens",
                "backbone.embeddings",
                "language_model.backbone.embeddings",
                "model.language_model.embed_tokens",
            ],
            "base_model_lm_head_path": ["lm_head", "language_model.lm_head"],
        }
        for name, paths in base_model_parts_mapping.items():
            found_submodule = False
            for path in paths:
                try:
                    submodule = self.get_submodule(path)
                    assert isinstance(submodule, torch.nn.Module)
                    print(f"Found {name} at {path}")
                    found_submodule = True
                    setattr(self, name, path)
                    break
                except Exception:
                    continue
            if not found_submodule:
                raise ValueError(f"Part {name} not found in model")

    def _set_default_aux_hidden_state_layers(self):
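        # Default EAGLE-3 choice of auxiliary layers: one early, one middle, and one late decoder
        # layer of the base model (the ids are deduplicated below for very shallow models).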
        # Read a custom config attribute since we override num_hidden_layers for offline training
        num_layers = self._base_llm_config.num_hidden_layers
        if self.eagle_offline and (num_layers is None or num_layers <= 0):
            num_layers = getattr(self.config, "num_orig_hidden_layers", 0)
        self.eagle_config.eagle_aux_hidden_state_layer_ids = [
            1,
            max(0, num_layers // 2 - 1),
            max(0, num_layers - 4),
        ]
        self.eagle_config.eagle_aux_hidden_state_layer_ids = list(
            set(self.eagle_config.eagle_aux_hidden_state_layer_ids)
        )

    def _collect_aux_hidden_states_forward_hook(self, module, input, output) -> None:
        """Collect auxiliary hidden states from base model intermediate layers, save them in attribute."""
        hidden_states = (
            output.clone().detach()
            if isinstance(output, torch.Tensor)
            else output[0].clone().detach()
        )
        self._aux_hidden_states.append(hidden_states)

    def pop_and_gather_aux_hiddens(self):
        """Pop auxiliary hidden states from base model and gather them on the draft model device."""
        if not self.eagle_config.use_aux_hidden_state:
            return None
        # In PTQ, the forward method is called inside try/except to find the max batch size.
        # This can leave stale aux hidden states at the front of the list.
        # To be safe, we only keep the last num_aux_h items.
        num_aux_h = len(self.eagle_config.eagle_aux_hidden_state_layer_ids)
        aux_h_list = self._aux_hidden_states[-num_aux_h:]
        self._aux_hidden_states.clear()

        # Gather aux hidden states on the draft model device
        aux_hiddens = torch.cat(
            [h.to(self.eagle_module.fc.weight.device) for h in aux_h_list], dim=-1
        )
        return aux_hiddens

    def _get_eagle_device(self):
        """Return the device where we should place eagle module."""
        if self.eagle_offline:
            # For offline training, the base model has no layers.
            # Read the device from the base model lm_head instead.
            return self._base_model_lm_head.weight.device
        else:
            # When there is a base model, put eagle on the last layer's device.
            base_model_last_layer = self._base_model.layers[-1]
            return next(base_model_last_layer.parameters()).device

    def modify(
        self,
        config,
    ):
        """Constructor.

        Args:
            config: The config for eagle decoder layers.
        """
        super().modify(config)

        if self.eagle_decoder_type == "llama":
            # Use default eagle config
            decoder_cls = LlamaDecoderLayer
        elif self.eagle_decoder_type == "kimik2":
            decoder_cls = _setup_kimi_k2_decoder()

        self.eagle_config = PretrainedConfig.from_dict(config.eagle_architecture_config)
        self.eagle_config.eagle_decoder_type = self.eagle_decoder_type
        # Hidden size and vocab size must match base model
        self.eagle_config.hidden_size = self._base_llm_config.hidden_size
        self.eagle_config.vocab_size = self._base_llm_config.vocab_size
        self.eagle_config.max_position_embeddings = self._base_llm_config.max_position_embeddings
        self.eagle_config.draft_vocab_size = getattr(
            self.eagle_config, "draft_vocab_size", self.eagle_config.vocab_size
        )
        if self.eagle_config._attn_implementation is None:
            self.eagle_config._attn_implementation = "sdpa"

        # Patch for Kimi-K2-Thinking, avoid quantizing drafter
        quant_config = getattr(self.config, "quantization_config", None)
        if isinstance(quant_config, CompressedTensorsConfig):
            quant_config.ignore.append("re:.*eagle_module.*")

        # Set default aux_hidden_state layers
        if (
            self.eagle_config.use_aux_hidden_state
            and len(self.eagle_config.eagle_aux_hidden_state_layer_ids) == 0
        ):
            self._set_default_aux_hidden_state_layers()

        # Freeze all parameters
        if self.eagle_freeze_base_model:
            for name, param in self.named_parameters():
                param.requires_grad = False

        self.eagle_module = EagleModule(
            self.eagle_config,
            decoder_cls,
        )

        # find base model, lm head, and embeddings paths
        self._find_base_model_parts()
        self.eagle_module.to(self._base_model.dtype).to(self._get_eagle_device())

        # EAGLE-3 auxiliary hidden_states
        if (not self.eagle_offline) and self.eagle_config.use_aux_hidden_state:
            self._aux_hidden_states = []
            for layer_idx, layer in enumerate(self._base_model.layers):
                if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids:
                    layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook)

        # delete base model layers for offline training
        if self.eagle_offline:
            self._base_model._modules.pop("layers")

        # NOTE: this is a temporary hack to bypass hf trainer check:
        # https://github.com/huggingface/transformers/blob/v4.56-release/src/transformers/trainer.py#L566
        self.is_quantized = False

        if self.eagle_use_torch_compile:
            self._activate_torch_compile()

        self._cached_attn_blk_masks = {}

    def _activate_torch_compile(self):
        import torch._dynamo

        torch._dynamo.config.suppress_errors = True  # Allow fallback to eager mode

        compile_targets = [
            ("_prepare_eagle_inputs", {}),
            ("_eagle_forward", {"mode": "max-autotune"}),
            ("_eagle_loss", {"fullgraph": True}),
        ]
        for name, kwargs in compile_targets:
            try:
                setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs))
            except Exception:  # noqa: PERF203
                print(f"Disabling torch.compile for {name} due to compilation error.")

    def _get_ttt_attention_mask(self, batch_size, seq_length, ttt_step):
        # compile and cache flex attention masks on first call
        if ttt_step not in self._cached_attn_blk_masks:
            self._cached_attn_blk_masks.update(
                {ttt_step: self._compute_ttt_attention_mask(batch_size, seq_length, ttt_step)}
            )
        return self._cached_attn_blk_masks[ttt_step]

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, past_key_values_length, device, dtype
    ):
        """Expand the 2-D attention mask to 4-D and apply causal mask."""
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        # construct causal mask
        if input_shape[-1] > 1:
            combined_attention_mask = make_causal_mask(
                input_shape,
                dtype,
                device=device,
                past_key_values_length=past_key_values_length,
            )

        # merge causal mask with padding mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = expand_mask(attention_mask, dtype, tgt_len=input_shape[-1]).to(
                device
            )
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def _prepare_eagle_inputs(
        self,
        input_ids,
        attention_mask,
        position_ids,
        eagle_cache,
        base_outputs,
    ):
        """Helper function to prepare eagle inputs for the 0th eagle forward pass."""
        b, seq_length = input_ids.shape
        past_kv_len = eagle_cache.get_seq_length() if eagle_cache is not None else 0
        seq_len_with_past = seq_length + past_kv_len

        # Prepare eagle_input_embeds: Shift left 1 token
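        # The drafter predicts token t+1 from the base model's hidden state at position t together
        # with the embedding of token t+1, so both the embeddings and the padding mask below are
        # rolled left by one position relative to the base model inputs.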
        with torch.no_grad():
            if base_outputs.input_embeds is None:
                eagle_input_embeds = self._base_model_embeddings(input_ids.roll(-1, 1))
            else:
                eagle_input_embeds = base_outputs.input_embeds.roll(-1, 1)

        # Prepare eagle_input_hiddens
        if self.eagle_config.use_aux_hidden_state:
            # concat base model intermediate (pre-norm) hiddens
            eagle_input_hiddens = self.eagle_module.fc(base_outputs.aux_hiddens)
        else:
            # use base model output (post-norm) hiddens
            eagle_input_hiddens = base_outputs.out_hiddens

        # Prepare attention_mask
        if attention_mask is None:
            eagle_attention_mask = torch.ones(  # default: all tokens are valid
                (b, seq_len_with_past), dtype=torch.bool, device=eagle_input_hiddens.device
            )
        else:
            eagle_attention_mask = attention_mask.roll(-1, 1)  # Shift left 1 token
        # Expand the 2-D attention mask to 4-D and apply causal mask.
        eagle_attention_mask = self._prepare_decoder_attention_mask(
            eagle_attention_mask,
            (b, seq_length),
            past_kv_len,
            eagle_input_hiddens.device,
            eagle_input_hiddens.dtype,
        )

        # Prepare position_ids
        if position_ids is None:
            eagle_position_ids = (
                torch.arange(
                    past_kv_len,
                    seq_len_with_past,
                    dtype=torch.long,
                    device=eagle_input_hiddens.device,
                )
                .unsqueeze(0)
                .view(-1, seq_length)
            )
        else:
            eagle_position_ids = position_ids.view(-1, seq_length).long()

        base_model_logits = base_outputs.logits
        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
            base_model_logits = self._map_logits_to_draft_vocab(base_model_logits)
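        # Detach the base model's greedy tokens and softmax distribution so no gradients flow back
        # through them; they are returned alongside the eagle inputs and serve downstream as the
        # drafter's training targets.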
        base_output_predict_tok = base_model_logits.argmax(dim=-1).detach()
        base_output_softmax_logits = torch.softmax(base_model_logits, dim=2).detach()

        return (
            eagle_input_embeds,
            eagle_input_hiddens,
            eagle_attention_mask,
            eagle_position_ids,
            base_output_predict_tok,
            base_output_softmax_logits,
        )

    def _compute_ttt_attention_mask(
        self, batch_size, seq_length, ttt_step
    ) -> BlockMask | torch.Tensor:
        """Return the TTT attention mask as a BlockMask or Tensor, depending on the eagle attn impl."""
        msk_func = get_ttt_msk_func(seq_length, ttt_step)
        dtypemin = torch.finfo(self._base_llm_config.dtype).min
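        # At TTT step s, queries cover the current seq_length positions while the key/value side
        # also spans the s earlier passes over the sequence, hence kv_len = seq_length * (1 + s).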
        q_len = seq_length
        kv_len = seq_length * (1 + ttt_step)
        if self.eagle_config._attn_implementation == "flex_attention":
            # Return block mask for flex attention
            block_mask = create_block_mask(msk_func, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len)
            return block_mask
        else:
            # Return tensor mask for non-flex attention
            tensor_mask = msk_func(
                None,
                None,
                torch.arange(q_len).view(1, 1, q_len, 1),
                torch.arange(kv_len).view(1, 1, 1, kv_len),
            ).to(self.device)
            tensor_mask = torch.full_like(
                tensor_mask, 0, dtype=self._base_llm_config.dtype, device=self.device
            ).masked_fill(~tensor_mask, dtypemin)
            # Note: (hg) repeat mask for kimi-k2 compatibility
            tensor_mask = tensor_mask.repeat(batch_size, 1, 1, 1)
            return tensor_mask

    def _base_model_forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values,
        freeze_base_model,
        labels,
        **kwargs,
    ):
        with torch.no_grad() if freeze_base_model else contextlib.nullcontext():
            outputs = super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_hidden_states=True,
                **kwargs,
            )
        past_key_values = getattr(outputs, "past_key_values", None)
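        # With output_hidden_states=True, hidden_states[0] is the embedding output and
        # hidden_states[-1] is the final decoder layer output of the base model.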
        base_input_embeds = outputs.hidden_states[0]
        base_model_hidden_states = outputs.hidden_states[-1]
        base_model_logits = outputs.logits

        # Optionally, compute base model loss when we want to tune the base model.
        base_model_loss = None
        if not freeze_base_model and labels is not None:  # Base model loss
            loss_fct = CrossEntropyLoss()
            loss_logits = base_model_logits.view(-1, base_model_logits.shape[-1])
            labels = labels.view(-1)
            base_model_loss = loss_fct(loss_logits, labels)

        return EagleBaseModelOutput(
            input_embeds=base_input_embeds,
            aux_hiddens=self.pop_and_gather_aux_hiddens(),
            out_hiddens=base_model_hidden_states,
            logits=base_model_logits,
            loss=base_model_loss,
        ), past_key_values

    def _map_logits_to_draft_vocab(self, full_logits):
        assert hasattr(self.eagle_module, "d2t"), "d2t buffer not initialized"
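        # d2t holds, for each draft-vocab index, the offset to its index in the full vocab;
        # arange + d2t gives the gather indices that slice the full logits down to the draft vocab.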
        reverse_mapping = (
            torch.arange(len(self.eagle_module.d2t)).to(self.eagle_module.d2t.device)
            + self.eagle_module.d2t
        )
        return full_logits[:, :, reverse_mapping]

    def _eagle_forward(
        self,
        eagle_input_hidden_states,
        inputs_embeds,
        attention_mask,
        position_ids,
        eagle_cache=None,
    ):
        eagle_postnorm_h, eagle_prenorm_h, eagle_cache = self.eagle_module(
            eagle_input_hidden_states,
            inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=True,
            past_key_values=eagle_cache,
        )
        eagle_lm_head = (
            self.eagle_module.lm_head
            if hasattr(self.eagle_module, "lm_head")
            else self._base_model_lm_head
        )
        eagle_logits = eagle_lm_head(eagle_postnorm_h)
        draft_logits_list = [eagle_logits]
        if self.eagle_config.parallel_draft_step > 1:
            # Get additional draft logits from parallel draft heads
            draft_logits = self.eagle_module.parallel_draft_heads(eagle_postnorm_h)
            draft_logits_list += draft_logits
        return eagle_postnorm_h, eagle_prenorm_h, draft_logits_list, eagle_cache

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int = 0,
        loss_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> Any:
        """Forward pass of the EagleModel.

        Returns:
            loss: Loss of base model or eagle model.
            logits: Base model logits.
            past_key_values: Base model past key values with eagle cache attached.
            hidden_states: Base model hidden states.
            train_acc: Drafter training accuracies.
        """
        eagle_cache = getattr(past_key_values, "eagle_cache", None)
        if self.training:
            assert past_key_values is None, "past_key_values should be None in training"

        if loss_mask is None:
            # By default, mask out padding tokens in loss computation
            loss_mask = (
                attention_mask.clone().detach()
                if attention_mask is not None
                else torch.ones_like(input_ids, dtype=torch.bool)
            )

        # ====First, run base model forward====
        if self.eagle_offline:
            # Parse base model outputs forwarded from teacher
            assert "base_model_outputs" in kwargs
            base_outputs = EagleBaseModelOutput.from_offline_dict(kwargs["base_model_outputs"])
            if base_outputs.logits is None:
                base_outputs.logits = self.lm_head(base_outputs.out_hiddens)
            past_key_values = None
        else:
            with self._nvtx_range("base_model_forward"):
                base_outputs, past_key_values = self._base_model_forward(
                    input_ids,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    self.eagle_freeze_base_model,
                    labels,
                    **kwargs,
                )

        if not isinstance(past_key_values, Cache):
            past_key_values = _get_empty_cache(self._base_llm_config)
        if not isinstance(eagle_cache, Cache):
            eagle_cache = _get_empty_cache(self.eagle_module.config)
        past_key_values.eagle_cache = eagle_cache

        # ====Prepare inputs for the first eagle forward pass====
        eagle_loss = None
        num_parallel = self.eagle_config.parallel_draft_step
        num_ttt = self.eagle_ttt_steps
        train_accs = torch.zeros(num_parallel, num_ttt, device=input_ids.device)
        b, seq_length, _ = base_outputs.out_hiddens.shape

        with self._nvtx_range("prepare_eagle_inputs"):
            (
                eagle_input_embeds,
                eagle_input_hiddens,
                eagle_attn_mask_0,
                eagle_position_ids,
                base_output_predict_tok,
                base_output_softmax_logits,
            ) = self._prepare_eagle_inputs(
                input_ids,
                attention_mask,
                position_ids,
                eagle_cache,
                base_outputs,
            )
        self.eagle_module._maybe_init_rope()

        # ====Run eagle forward with extra training-time-test steps====
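        # Training-time-test (TTT): after the first pass, the drafter is fed its own (re-aligned)
        # output hidden states so that later steps see the distribution it will face at inference,
        # when it drafts several tokens recursively.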
        for ttt_step in range(self.eagle_ttt_steps):
            # TODO: (hg) during cp training, this mask is not used. Maybe turn it off then.
            eagle_attention_mask = (
                eagle_attn_mask_0
                if self.eagle_mix_hidden_states or ttt_step == 0
                else self._get_ttt_attention_mask(b, seq_length, ttt_step)
            )
            with self._enable_cp_ttt(), self._nvtx_range("eagle_forward"):
                _, eagle_output_hiddens, eagle_logits, eagle_cache = self._eagle_forward(
                    eagle_input_hiddens,
                    eagle_input_embeds,
                    eagle_attention_mask,
                    eagle_position_ids,
                    None if self.eagle_mix_hidden_states else eagle_cache,
                )
            eagle_output_hiddens = eagle_output_hiddens.roll(1, 1)
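            # When hidden-state mixing is enabled, only a shrinking fraction of positions
            # (roughly 1 / (2**ttt_step + 1) of the sequence) is replaced with the drafter's own
            # hidden states; otherwise the whole sequence is replaced for the next TTT step.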
            if self.eagle_mix_hidden_states:
                batch_size, seq_len_s, _ = eagle_input_hiddens.shape
                num_to_replace = max(1, seq_len_s // (2**ttt_step + 1))
                # Randomly select positions for each batch to replace
                rand_indices = torch.rand(
                    batch_size, seq_len_s, device=eagle_input_hiddens.device
                ).argsort(dim=1)[:, :num_to_replace]
                # Clone to avoid inplace modification that breaks autograd
                eagle_input_hiddens = eagle_input_hiddens.clone()
                batch_indices = torch.arange(batch_size)[:, None]
                eagle_input_hiddens[batch_indices, rand_indices] = eagle_output_hiddens[
                    batch_indices, rand_indices
                ]
            else:
                eagle_input_hiddens = eagle_output_hiddens

            for i in range(self.eagle_config.parallel_draft_step):