-
Notifications
You must be signed in to change notification settings - Fork 4k
Expand file tree
/
Copy pathgpt_model.py
More file actions
853 lines (773 loc) · 37 KB
/
gpt_model.py
File metadata and controls
853 lines (773 loc) · 37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
from collections import OrderedDict
from typing import Any, Callable, Dict, Literal, Optional
import torch
from torch import Tensor
from megatron.core import tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.inference.utils import InferenceMode
from megatron.core.models.common.embeddings import YarnRotaryEmbedding
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import (
MultimodalRotaryEmbedding,
RotaryEmbedding,
)
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
FineGrainedActivationOffloadingInterface as off_interface,
)
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.quantization.utils import get_quant_config_or_none
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.transformer.enums import CudaGraphScope, ModelType
from megatron.core.transformer.multi_token_prediction import (
MultiTokenPredictionBlock,
mtp_on_this_rank,
process_mtp_loss,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import (
WrappedTensor,
deprecate_inference_params,
is_using_quantization_scales,
)
class GPTModel(LanguageModule):
    """GPT Transformer language model.

    Args:
        config (TransformerConfig):
            Transformer config
        transformer_layer_spec (ModuleSpec):
            Specifies module to use for transformer layers
        vocab_size (int):
            Vocabulary size
        max_sequence_length (int):
            maximum size of sequence. This is used for positional embedding
        pre_process (bool, optional):
            Include embedding layer (used with pipeline parallelism). Defaults to True.
        post_process (bool, optional):
            Include an output layer (used with pipeline parallelism). Defaults to True.
        fp16_lm_cross_entropy (bool, optional):
            Defaults to False.
        parallel_output (bool, optional):
            Do not gather the outputs, keep them split across tensor
            parallel ranks. Defaults to True.
        share_embeddings_and_output_weights (bool, optional):
            When True, input embeddings and output logit weights are shared. Defaults to False.
        position_embedding_type (Literal[learned_absolute, rope, mrope, yarn, none], optional):
            Position embedding type. If ``config`` has a ``position_embedding_type``
            attribute, that value is used instead when selecting the rotary embedding.
            Defaults to 'learned_absolute'.
        rotary_percent (float, optional):
            Percent of rotary dimension to use for rotary position embeddings.
            Only used for the rotary types ('rope', 'yarn', 'mrope'). Defaults to 1.0.
        rotary_base (int, optional):
            Base period for rotary position embeddings. Only used for the rotary
            types ('rope', 'yarn', 'mrope'). Defaults to 10000.
        rope_scaling (bool, optional): Toggle RoPE scaling ('rope' only). Defaults to False.
        rope_scaling_factor (float, optional): RoPE scaling factor ('rope' only).
            Defaults to 8.0.
        scatter_embedding_sequence_parallel (bool, optional):
            Whether embeddings should be scattered across sequence parallel
            region or not. Defaults to True.
        seq_len_interpolation_factor (Optional[float], optional):
            scale of linearly interpolating RoPE for longer sequences.
            The value must be a float larger than 1.0. Defaults to None.
        mtp_block_spec (Optional[ModuleSpec], optional):
            Spec for the multi-token prediction (MTP) block. When None, no MTP block is
            built. Defaults to None.
        pg_collection (ProcessGroupCollection, optional): Model communication process groups.
            Forwarded to the parent ``LanguageModule`` constructor.
        vp_stage (Optional[int], optional):
            Virtual pipeline stage index, forwarded to the decoder and the MTP block.
            Defaults to None.
    """

    def __init__(
        self,
        config: TransformerConfig,
        transformer_layer_spec: ModuleSpec,
        vocab_size: int,
        max_sequence_length: int,
        pre_process: bool = True,
        post_process: bool = True,
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        share_embeddings_and_output_weights: bool = False,
        position_embedding_type: Literal[
            'learned_absolute', 'rope', 'mrope', 'yarn', 'none'
        ] = 'learned_absolute',
        rotary_percent: float = 1.0,
        rotary_base: int = 10000,
        rope_scaling: bool = False,
        rope_scaling_factor: float = 8.0,
        scatter_embedding_sequence_parallel: bool = True,
        seq_len_interpolation_factor: Optional[float] = None,
        mtp_block_spec: Optional[ModuleSpec] = None,
        pg_collection: Optional[ProcessGroupCollection] = None,
        vp_stage: Optional[int] = None,
    ) -> None:
        super().__init__(config=config, pg_collection=pg_collection)

        if has_config_logger_enabled(config):
            log_config_to_disk(config, locals(), prefix=type(self).__name__)

        self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
        self.vp_stage = vp_stage
        # One-shot flag: preprocess_for_fine_grained_offloading() marks parameters as
        # not offloadable on its first call and then clears this flag.
        self.disable_param_offloading = True

        # A config-level setting takes precedence over the constructor argument.
        if hasattr(self.config, 'position_embedding_type'):
            self.position_embedding_type = self.config.position_embedding_type
        else:
            self.position_embedding_type = position_embedding_type

        # megatron core pipelining currently depends on model type
        # TODO: remove this dependency ?
        self.model_type = ModelType.encoder_or_decoder

        # These 4 attributes are needed for TensorRT-LLM export.
        self.max_position_embeddings = max_sequence_length
        self.rotary_percent = rotary_percent
        if hasattr(self.config, 'rotary_base'):
            self.rotary_base = self.config.rotary_base
        else:
            self.rotary_base = rotary_base
        self.rotary_scaling = rope_scaling
        self.mtp_block_spec = mtp_block_spec
        self.mtp_process = mtp_block_spec is not None and mtp_on_this_rank(
            self.config, ignore_virtual=False, vp_stage=vp_stage
        )

        # The embedding is needed on the first pipeline stage and on any stage that
        # hosts the MTP block (which consumes the embedding in _postprocess).
        if self.pre_process or self.mtp_process:
            # NOTE(review): this passes the constructor argument, not
            # self.position_embedding_type (which may come from config) - confirm intended.
            self.embedding = LanguageModelEmbedding(
                config=self.config,
                vocab_size=self.vocab_size,
                max_sequence_length=self.max_sequence_length,
                position_embedding_type=position_embedding_type,
                scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
                tp_group=self.pg_collection.tp,
            )

        # NOTE(review): the rotary modules below receive the constructor `rotary_base`,
        # not self.rotary_base (which may come from config) - confirm intended.
        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            self.rotary_pos_emb = RotaryEmbedding(
                kv_channels=self.config.kv_channels,
                rotary_percent=rotary_percent,
                rotary_interleaved=self.config.rotary_interleaved,
                seq_len_interpolation_factor=seq_len_interpolation_factor,
                rotary_base=rotary_base,
                rope_scaling=rope_scaling,
                rope_scaling_factor=rope_scaling_factor,
                use_cpu_initialization=self.config.use_cpu_initialization,
                cp_group=self.pg_collection.cp,
            )
        elif self.position_embedding_type == 'yarn':
            self.rotary_pos_emb = YarnRotaryEmbedding(
                kv_channels=self.config.kv_channels,
                rotary_percent=rotary_percent,
                rotary_interleaved=self.config.rotary_interleaved,
                seq_len_interpolation_factor=seq_len_interpolation_factor,
                rotary_base=rotary_base,
                scaling_factor=getattr(self.config, "yarn_rotary_scaling_factor"),
                original_max_position_embeddings=getattr(
                    self.config, "yarn_original_max_position_embeddings"
                ),
                beta_fast=getattr(self.config, "yarn_beta_fast"),
                beta_slow=getattr(self.config, "yarn_beta_slow"),
                mscale=getattr(self.config, "yarn_mscale"),
                mscale_all_dim=getattr(self.config, "yarn_mscale_all_dim"),
                correction_range_round_to_int=getattr(
                    self.config, "yarn_correction_range_round_to_int"
                ),
                use_cpu_initialization=self.config.use_cpu_initialization,
            )
        elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention:
            self.rotary_pos_emb = MultimodalRotaryEmbedding(
                kv_channels=self.config.kv_channels,
                rotary_percent=rotary_percent,
                rotary_interleaved=self.config.rotary_interleaved,
                seq_len_interpolation_factor=seq_len_interpolation_factor,
                rotary_base=rotary_base,
            )
            self.mrope_section = self.config.mrope_section
            assert (
                self.mrope_section is not None
            ), "mrope require mrope_section setting, but we got None from TransformerConfig"

        # Cache for RoPE tensors which do not change between iterations.
        # Keyed by the inference context's max_sequence_length (see _preprocess).
        self.rotary_pos_emb_cache = {}

        # Transformer.
        self.decoder = TransformerBlock(
            config=self.config,
            spec=transformer_layer_spec,
            pre_process=self.pre_process,
            post_process=self.post_process,
            pg_collection=self.pg_collection,
            vp_stage=vp_stage,
        )

        if self.mtp_process:
            self.mtp = MultiTokenPredictionBlock(
                config=self.config,
                spec=self.mtp_block_spec,
                vp_stage=vp_stage,
                pg_collection=self.pg_collection,
            )
            # NOTE(review): nesting of this call under `if self.mtp_process` was
            # reconstructed - confirm it should not run unconditionally.
            self._setup_mtp_cuda_graphs()

        # Output
        if self.post_process:
            if self.config.defer_embedding_wgrad_compute:
                # The embedding activation buffer preserves a reference to the input activations
                # of the final embedding projection layer GEMM. It will hold the activations for
                # all the micro-batches of a global batch for the last pipeline stage. Once we are
                # done with all the back props for all the microbatches for the last pipeline stage,
                # it will be in the pipeline flush stage. During this pipeline flush we use the
                # input activations stored in embedding activation buffer and gradient outputs
                # stored in gradient buffer to calculate the weight gradients for the embedding
                # final linear layer.
                self.embedding_activation_buffer = []
                self.grad_output_buffer = []
            else:
                self.embedding_activation_buffer = None
                self.grad_output_buffer = None

            self.output_layer = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                self.vocab_size,
                config=config,
                init_method=(
                    config.embedding_init_method
                    if config.use_mup and not self.share_embeddings_and_output_weights
                    else config.init_method
                ),
                bias=False,
                skip_bias_add=False,
                gather_output=not self.parallel_output,
                # With tied weights on a stage that also owns the embedding, the output
                # layer borrows the embedding weight instead of allocating its own.
                skip_weight_param_allocation=self.pre_process
                and self.share_embeddings_and_output_weights,
                embedding_activation_buffer=self.embedding_activation_buffer,
                grad_output_buffer=self.grad_output_buffer,
                tp_group=self.pg_collection.tp,
            )

        if self.pre_process or self.post_process or self.mtp_process:
            self.setup_embeddings_and_output_layer()

        if has_config_logger_enabled(self.config):
            log_config_to_disk(
                self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
            )

        # Give every submodule that supports it a chance to finalize with its
        # per-module quantization config.
        for name, module in self.named_modules():
            if hasattr(module, 'finish_init'):
                quant_config = get_quant_config_or_none(name, self.config.quant_recipe)
                module.finish_init(quant_config)

    def set_input_tensor(self, input_tensor: Tensor) -> None:
        """Sets input tensor to the model.

        See megatron.model.transformer.set_input_tensor()

        Args:
            input_tensor (Tensor): Sets the input tensor for the model. A bare tensor
                (or None) is accepted and wrapped in a one-element list.
        """
        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
        self.decoder.set_input_tensor(input_tensor[0])

    def _preprocess(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        decoder_input: Tensor = None,
        inference_context: BaseInferenceContext = None,
        packed_seq_params: PackedSeqParams = None,
        padding_mask: Optional[Tensor] = None,
    ):
        """Preprocesses inputs for the transformer decoder.

        Applies embeddings to input tokens, or uses `decoder_input` from a previous
        pipeline stage. Also sets up rotary positional embeddings.

        Args:
            input_ids: Token ids. Embedded on the first pipeline stage, and used for the
                batch size of the static-batching ``sequence_len_offset``.
            position_ids: Position ids for the embedding and, for 'mrope', for the
                multimodal rotary embedding.
            decoder_input: Precomputed decoder input. When given, the embedding is skipped.
            inference_context: Inference state, or None.
            packed_seq_params: Packed-sequence parameters. Its ``qkv_format`` and
                ``cp_group`` are forwarded to the rotary embedding.
            padding_mask: Optional mask with the same shape as ``input_ids``. With sequence
                parallelism it is transposed and scattered across the sequence-parallel
                region on the embedding stage.

        Returns:
            tuple: ``(decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin,
            sequence_len_offset, padding_mask)``. ``rotary_pos_cos_sin`` is appended as a
            7th element only on the FlashInfer fused-RoPE path.
        """
        # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
        # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.

        in_inference_mode = InferenceMode.is_active()

        # Decoder embedding.
        if decoder_input is not None:
            pass
        elif self.pre_process:
            if padding_mask is not None:
                assert padding_mask.shape == input_ids.shape, (
                    f"padding_mask shape {padding_mask.shape} does not match "
                    f"input_ids shape {input_ids.shape}"
                )
            decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
            if padding_mask is not None and self.config.sequence_parallel:
                padding_mask = (
                    tensor_parallel.scatter_to_sequence_parallel_region(
                        padding_mask.transpose(0, 1).contiguous()
                    )
                    .transpose(0, 1)
                    .contiguous()
                )
        else:
            # intermediate stage of pipeline
            # decoder will get hidden_states from encoder.input_tensor
            decoder_input = None

        # Rotary positional embeddings (embedding is None for PP intermediate devices)
        rotary_pos_emb = None
        rotary_pos_cos = None
        rotary_pos_sin = None
        # this is used to store combined cos/sin embeddings, exclusively for flash infer rope
        rotary_pos_cos_sin = None

        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            use_flash_infer_fused_rope = (
                hasattr(inference_context, 'use_flashinfer_fused_rope')
                and inference_context.use_flashinfer_fused_rope
            )
            if (
                in_inference_mode
                and inference_context is not None
                and (self.config.flash_decode or use_flash_infer_fused_rope)
            ):
                assert (
                    not self.config.flash_decode
                ) or inference_context.is_static_batching(), (
                    "Flash decode is only applicable to static batching."
                )
                # The precomputed RoPE tables depend only on the max sequence length, so
                # they are built once per length. Build only on a cache miss:
                # dict.setdefault would evaluate (and discard) its default on every call.
                cache_key = inference_context.max_sequence_length
                if self.config.flash_decode:
                    # Flash decoding uses precomputed cos and sin for RoPE
                    if cache_key not in self.rotary_pos_emb_cache:
                        self.rotary_pos_emb_cache[cache_key] = self.rotary_pos_emb.get_cos_sin(
                            cache_key
                        )
                    rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache[cache_key]
                elif use_flash_infer_fused_rope:
                    assert not self.mtp_process, "MTP not tested with flashinfer_fused_rope"
                    if cache_key not in self.rotary_pos_emb_cache:
                        self.rotary_pos_emb_cache[cache_key] = torch.cat(
                            self.rotary_pos_emb.get_cos_sin(cache_key), -1
                        )
                    rotary_pos_cos_sin = self.rotary_pos_emb_cache[cache_key]
            else:
                rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
                    inference_context, self.decoder, decoder_input, self.config, packed_seq_params
                )
                rotary_pos_emb = self.rotary_pos_emb(
                    rotary_seq_len,
                    packed_seq=packed_seq_params is not None
                    and packed_seq_params.qkv_format == 'thd',
                    cp_group=packed_seq_params.cp_group if packed_seq_params is not None else None,
                )
        elif self.position_embedding_type == 'yarn':
            if not in_inference_mode or not self.config.flash_decode:
                rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
                    inference_context, self.decoder, decoder_input, self.config, packed_seq_params
                )
                rotary_pos_emb, _ = self.rotary_pos_emb(
                    rotary_seq_len,
                    packed_seq=packed_seq_params is not None
                    and packed_seq_params.qkv_format == 'thd',
                    cp_group=packed_seq_params.cp_group if packed_seq_params is not None else None,
                )
            else:
                raise NotImplementedError(
                    "Flash decoding uses precomputed cos and sin for RoPE, not implemented in "
                    "YarnRotaryEmbedding yet."
                )
        elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention:
            if not in_inference_mode or not self.config.flash_decode:
                rotary_pos_emb = self.rotary_pos_emb(
                    position_ids,
                    self.mrope_section,
                    cp_group=packed_seq_params.cp_group if packed_seq_params is not None else None,
                )
            else:
                # Flash decoding uses precomputed cos and sin for RoPE
                raise NotImplementedError(
                    "Flash decoding uses precomputed cos and sin for RoPE, not implemented in "
                    "MultimodalRotaryEmbedding yet."
                )

        # Static-batching inference with local CUDA graphs (not full-iteration scope) or
        # flash decode gets a per-sample copy of the context's sequence offset.
        if (
            in_inference_mode
            and inference_context is not None
            and (
                (
                    self.config.cuda_graph_impl == "local"
                    and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope
                )
                or self.config.flash_decode
            )
            and inference_context.is_static_batching()
        ):
            current_batch_size = input_ids.shape[0]
            sequence_len_offset = torch.tensor(
                [inference_context.sequence_len_offset] * current_batch_size,
                dtype=torch.int32,
                device=torch.cuda.current_device(),
            )
        else:
            sequence_len_offset = None

        if in_inference_mode:
            # Clear the outputs for padding tokens when using dynamic batching with
            # quantization scales to avoid corrupting amax calculations.
            # decoder_input is None on intermediate pipeline stages, so there is nothing
            # to clear there.
            if (
                decoder_input is not None
                and inference_context is not None
                and inference_context.is_dynamic_batching()
                and is_using_quantization_scales(self.config)
            ):
                decoder_input[inference_context.padding_slice] = 0.0

            # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
            # reference held by this caller function, enabling early garbage collection for
            # inference. Skip wrapping if decoder_input is logged after decoder completion.
            if not has_config_logger_enabled(self.config):
                decoder_input = WrappedTensor(decoder_input)

        preproc_output = (
            decoder_input,
            rotary_pos_emb,
            rotary_pos_cos,
            rotary_pos_sin,
            sequence_len_offset,
            padding_mask,
        )
        if rotary_pos_cos_sin is not None:
            # only in the case of flashinfer fused rope will we
            # return this extra tensor
            # this is for backwards compatibility with
            # legacy unit tests, which break if you
            # return a 7 tuple instead of 6.
            preproc_output += (rotary_pos_cos_sin,)

        return preproc_output

    def preprocess_for_fine_grained_offloading(self):
        """Preprocess for fine-grained activation offloading.

        Initializes the offloading chunk handler on every call. On the first call only,
        it marks the decoder, MTP (if present) and output-layer (if present) parameters
        as not offloadable.
        """
        off_interface.init_chunk_handler(
            vp_size=self.config.virtual_pipeline_model_parallel_size,
            vp_stage=self.vp_stage,
            min_offloaded_tensor_size=self.config.min_offloaded_tensor_size,
        )
        if self.disable_param_offloading:
            for param in self.decoder.parameters():
                off_interface.mark_not_offloadable(param)
            if self.mtp_process:
                for param in self.mtp.parameters():
                    off_interface.mark_not_offloadable(param)
            if self.post_process:
                for param in self.output_layer.parameters():
                    off_interface.mark_not_offloadable(param)
            self.disable_param_offloading = False

    def forward(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        attention_mask: Tensor,
        decoder_input: Tensor = None,
        labels: Tensor = None,
        inference_context: BaseInferenceContext = None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
        runtime_gather_output: Optional[bool] = None,
        *,
        inference_params: Optional[BaseInferenceContext] = None,
        loss_mask: Optional[Tensor] = None,
        padding_mask: Optional[Tensor] = None,
        output_processor: Optional[Callable[..., Tensor]] = None,
        output_processor_context: Optional[Any] = None,
    ) -> Tensor:
        """Forward function of the GPT Model.

        This function passes the input tensors through the embedding layer, then the
        decoder, and finally into the post-processing layer (optional).

        It either returns the Loss values if labels are given or the final hidden units.

        Args:
            runtime_gather_output (bool): Gather output at runtime. Default None means
                `parallel_output` arg in the constructor will be used.
            padding_mask (Tensor, optional): Padding mask for MoE routing.
                Shape [bsz, seq_length]. True = padding (exclude), False = valid (include).
                Only used for MoE layers to exclude padding tokens from routing computations.
            output_processor (Callable, optional): Custom postprocess hook that receives
                decoder hidden states and output-layer helpers, then returns the model output.
            output_processor_context (Any, optional): User-defined context object forwarded to
                `output_processor`.
            inference_params (optional): Deprecated alias for `inference_context`.

        Returns:
            The result of `_postprocess`. This is the hidden states on a stage without the
            output layer, the `output_processor` result if one is given, logits transposed
            to batch-first when `labels` is None, or the language-model loss otherwise.
        """
        if self.config.fine_grained_activation_offloading:
            self.preprocess_for_fine_grained_offloading()

        inference_context = deprecate_inference_params(inference_context, inference_params)

        preproc_output = self._preprocess(
            input_ids=input_ids,
            position_ids=position_ids,
            decoder_input=decoder_input,
            inference_context=inference_context,
            packed_seq_params=packed_seq_params,
            padding_mask=padding_mask,
        )

        (
            decoder_input,
            rotary_pos_emb,
            rotary_pos_cos,
            rotary_pos_sin,
            sequence_len_offset,
            padding_mask,
        ) = preproc_output[:6]

        # The optional 7th element is only present on the FlashInfer fused-RoPE path.
        rotary_pos_cos_sin = preproc_output[6] if len(preproc_output) == 7 else None

        # Run decoder.
        hidden_states = self.decoder(
            hidden_states=decoder_input,
            attention_mask=attention_mask,
            inference_context=inference_context,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            rotary_pos_cos_sin=rotary_pos_cos_sin,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            padding_mask=padding_mask,
            **(extra_block_kwargs or {}),
        )

        return self._postprocess(
            hidden_states=hidden_states,
            input_ids=input_ids,
            position_ids=position_ids,
            labels=labels,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            mtp_in_postprocess=self.mtp_process,
            loss_mask=loss_mask,
            decoder_input=decoder_input,
            attention_mask=attention_mask,
            padding_mask=padding_mask,
            inference_params=inference_params,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            runtime_gather_output=runtime_gather_output,
            extra_block_kwargs=extra_block_kwargs,
            inference_context=inference_context,
            output_processor=output_processor,
            output_processor_context=output_processor_context,
        )

    def _postprocess(
        self,
        hidden_states,
        input_ids,
        position_ids,
        labels,
        rotary_pos_emb,
        rotary_pos_cos,
        rotary_pos_sin,
        mtp_in_postprocess=None,
        loss_mask=None,
        decoder_input=None,
        attention_mask=None,
        padding_mask=None,
        inference_params=None,
        packed_seq_params=None,
        sequence_len_offset=None,
        runtime_gather_output=None,
        extra_block_kwargs=None,
        inference_context=None,
        output_processor=None,
        output_processor_context=None,
    ):
        """Postprocesses decoder hidden states to generate logits or compute loss.

        Applies Multi-Token Prediction if enabled, generates output logits through
        the output layer, and computes language model loss when labels are provided.

        Returns:
            - ``hidden_states`` when this stage has no output layer (``post_process`` False);
            - the ``output_processor`` result when one is supplied;
            - logits transposed from sequence-first to batch-first when ``labels`` is None;
            - otherwise the result of ``compute_language_model_loss(labels, logits)``.
        """
        in_inference_mode = InferenceMode.is_active()

        if in_inference_mode:
            assert runtime_gather_output, "Inference must always gather TP logits"

        # Check if speculative decoding is active. When it is, MTP must be
        # computed *after* verification so that it is conditioned on verified
        # tokens rather than stale speculative tokens from the previous step.
        is_spec_decode = (
            in_inference_mode
            and inference_context is not None
            and inference_context.is_dynamic_batching()
            and inference_context.num_speculative_tokens > 0
        )

        # logits and loss
        output_weight = None
        if self.share_embeddings_and_output_weights:
            output_weight = self.shared_embedding_or_output_weight()

        if mtp_in_postprocess and not (in_inference_mode or is_spec_decode):
            hidden_states = self.mtp(
                input_ids=input_ids,
                position_ids=position_ids,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                inference_params=None,  # MTP layers don't use KV cache
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
                padding_mask=padding_mask,
                embedding=self.embedding,
                **(extra_block_kwargs or {}),
            )

        if not self.post_process:
            return hidden_states

        if self.config.mtp_num_layers:
            assert self.config.mtp_num_layers > 0
            if in_inference_mode or is_spec_decode:
                # Cache decoder hidden states for serial MTP computation
                # after speculative token verification.
                self._decoder_hidden_states_cache = hidden_states
            else:
                # In training/eval, use the utility function for processing MTP loss/scaling.
                hidden_states = process_mtp_loss(
                    hidden_states=hidden_states,
                    labels=labels,
                    loss_mask=loss_mask,
                    output_layer=self.output_layer,
                    output_weight=output_weight,
                    runtime_gather_output=runtime_gather_output,
                    is_training=self.training,
                    compute_language_model_loss=self.compute_language_model_loss,
                    config=self.config,
                    cp_group=self.pg_collection.cp,
                    packed_seq_params=packed_seq_params,
                    scale_logits_fn=self._scale_logits if self.config.use_mup else None,
                )

        sequence_parallel_override = False

        if output_processor is not None:
            return output_processor(
                hidden_states=hidden_states,
                output_layer=self.output_layer,
                output_weight=output_weight,
                labels=labels,
                loss_mask=loss_mask,
                input_ids=input_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
                decoder_input=decoder_input,
                inference_context=inference_context,
                packed_seq_params=packed_seq_params,
                runtime_gather_output=runtime_gather_output,
                context=output_processor_context,
                compute_language_model_loss=self.compute_language_model_loss,
                scale_logits=self._scale_logits,
                config=self.config,
            )

        if (
            in_inference_mode
            and inference_context is not None
            and inference_context.config.materialize_only_last_token_logits
        ):
            if inference_context.is_static_batching():
                hidden_states = hidden_states[-1:, :, :]
            else:
                if self.output_layer.sequence_parallel:
                    # Perform the sequence parallel gather here instead of after the output layer
                    # because we need to slice the last token logits from the full view of the
                    # packed logits across all requests.
                    hidden_states = gather_from_sequence_parallel_region(
                        hidden_states, group=self.pg_collection.tp
                    )
                    self.output_layer.sequence_parallel = False
                    sequence_parallel_override = True

                # Reshape [S, B, H] (with B=1) to [1, S, H] for logit extraction,
                # then back to [S', B, H] for the output layer.
                reshaped = hidden_states.squeeze(1).unsqueeze(0)
                hidden_states = inference_context.last_token_logits(reshaped).unsqueeze(1)

        try:
            logits, _ = self.output_layer(
                hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
            )
        finally:
            # Restore sequence parallel execution to the output layer if it was disabled
            # above. This runs even if the output layer raises, so the shared module is
            # never left with sequence parallelism switched off.
            if sequence_parallel_override:
                self.output_layer.sequence_parallel = True

        # Apply MuP output scaling to logits
        logits = self._scale_logits(logits)

        if has_config_logger_enabled(self.config):
            payload = OrderedDict(
                {
                    'input_ids': input_ids,
                    'position_ids': position_ids,
                    'attention_mask': attention_mask,
                    'decoder_input': decoder_input,
                    'logits': logits,
                }
            )
            log_config_to_disk(self.config, payload, prefix='input_and_logits')

        if labels is None:
            # [s b h] => [b s h]
            return logits.transpose(0, 1).contiguous()

        loss = self.compute_language_model_loss(labels, logits)

        return loss

    def build_schedule_plan(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        attention_mask: Tensor,
        decoder_input: Tensor = None,
        labels: Tensor = None,
        inference_context: BaseInferenceContext = None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
        runtime_gather_output: Optional[bool] = None,
        inference_params: Optional[BaseInferenceContext] = None,
        loss_mask: Optional[Tensor] = None,
        padding_mask: Optional[Tensor] = None,
        *,
        output_processor: Optional[Callable[..., Tensor]] = None,
        output_processor_context: Optional[Any] = None,
    ):
        """Builds a computation schedule plan for the model.

        This function creates a schedule plan for a model chunk, including
        preprocessing, transformer layers, and postprocessing.
        The schedule plan is used to optimize computation and memory usage
        in distributed environments.

        Args:
            input_ids (Tensor): Input token IDs.
            position_ids (Tensor): Position IDs.
            attention_mask (Tensor): Attention mask.
            decoder_input (Tensor, optional): Decoder input tensor. Defaults to None.
            labels (Tensor, optional): Labels for loss computation. Defaults to None.
            inference_context (BaseInferenceContext, optional):
                Inference context. Accepted for signature parity with `forward`; not
                forwarded to the schedule plan. Defaults to None.
            packed_seq_params (PackedSeqParams, optional):
                Parameters for packed sequences. Defaults to None.
            extra_block_kwargs (dict, optional):
                Additional keyword arguments for blocks. Defaults to None.
            runtime_gather_output (Optional[bool], optional):
                Whether to gather output at runtime. Defaults to None.
            inference_params (InferenceParams, optional):
                Parameters for inference. Accepted for signature parity; not forwarded.
                Defaults to None.
            loss_mask (Optional[Tensor], optional): Loss mask. Defaults to None.
            padding_mask (Optional[Tensor], optional): Padding mask. Defaults to None.
            output_processor (Callable, optional): Custom postprocess hook to run in the
                schedule-plan postprocess node instead of the default logits/loss path.
            output_processor_context (Any, optional): User-defined context object forwarded to
                `output_processor`.

        Returns:
            TransformerModelChunkSchedulePlan: The model chunk schedule plan.
        """
        if self.config.fine_grained_activation_offloading:
            self.preprocess_for_fine_grained_offloading()

        # Imported lazily (kept function-local as in the original layout).
        from ..common.model_chunk_schedule_plan import TransformerModelChunkSchedulePlan

        return TransformerModelChunkSchedulePlan(
            self,
            input_ids,
            position_ids,
            attention_mask,
            decoder_input,
            labels,
            packed_seq_params,
            extra_block_kwargs,
            runtime_gather_output,
            loss_mask,
            padding_mask,
            output_processor=output_processor,
            output_processor_context=output_processor_context,
        )

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
    ) -> ShardedStateDict:
        """Sharded state dict implementation for GPTModel backward-compatibility.

        Builds the sharded state dict via the parent implementation. It then removes the
        output layer's `_extra_state` entry, asserting that the entry carries no data.
        Tying of the word embeddings and the output layer (including for the MTP stage)
        is presumably handled by the parent implementation.

        Args:
            prefix (str): Module name prefix.
            sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
            metadata (Optional[Dict]): metadata controlling sharded state dict creation.

        Returns:
            ShardedStateDict: sharded state dict for the GPTModel
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        output_layer_extra_state_key = f'{prefix}output_layer._extra_state'

        # Old GPT checkpoints only stored the output layer weight key. So we remove the
        # _extra_state key but check that it doesn't contain any data anyway
        output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None)
        assert not (
            output_extra_state and output_extra_state.data
        ), f'Expected output layer extra state to be empty, got: {output_extra_state}'

        return sharded_state_dict