-
Notifications
You must be signed in to change notification settings - Fork 42
Expand file tree
/
Copy pathmodeling_llada.py
More file actions
1937 lines (1677 loc) · 78 KB
/
modeling_llada.py
File metadata and controls
1937 lines (1677 loc) · 78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2025 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# Modified from LLaDA repos: https://github.com/ML-GSAI/LLaDA
from __future__ import annotations
import logging
import math
import sys
from abc import abstractmethod
from collections import defaultdict
from functools import partial
from typing import (
Callable,
Dict,
Iterable,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
cast,
)
from dataclasses import fields
from typing import List, Optional, Tuple, Union
try:
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
FLEX_ATTN_AVAILABLE = True
except:
FLEX_ATTN_AVAILABLE = False
import torch
import torch.backends.cuda
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch import einsum, values_copy
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto import AutoModel
from transformers.cache_utils import Cache
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from .configuration_llada import (
LLaDAConfig,
StrEnum,
InitFnType,
ActivationType,
BlockType,
LayerNormType,
ModelConfig,
ActivationCheckpointingStrategy,
)
from einops import rearrange
if sys.version_info.minor > 8:
from collections.abc import MutableMapping
elif sys.version_info.minor == 8:
from typing import MutableMapping
else:
raise SystemExit("This script supports Python 3.8 or higher")
import re
from .tp_linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
# Public API of this module: the names exported by `from <this module> import *`.
__all__ = [
    "LayerNormBase",
    "LayerNorm",
    "RMSLayerNorm",
    "GemmaRMSLayerNorm",
    "RotaryEmbedding",
    "Activation",
    "GELU",
    "ReLU",
    "SwiGLU",
    "LLaDABlock",
    "LLaDASequentialBlock",
    "LLaDAModel",
    "LLaDAOutput",
    "LLaDAGenerateOutput",
]
def _all_gather_cat(
tensor: torch.Tensor,
dim: int = 1,
group: Optional[dist.ProcessGroup] = None,
normal_len: int = 0,
last_len: int = 0,
) -> torch.Tensor:
"""
Gather tensors along `dim` from all ranks and concatenate them.
Only the last chunk may be shorter than `normal_len`; all others are exactly `normal_len`.
Args:
tensor: local tensor on current rank
dim: dimension along which to concatenate
normal_len: length of the first (world_size-1) ranks along `dim`
last_len: length of the last rank along `dim`
Returns:
Concatenated tensor of shape [total_len, ...] along `dim`
"""
world_size = dist.get_world_size(group)
rank = dist.get_rank(group)
if world_size == 1:
return tensor
# 1. Move the concatenation dimension to 0 for easier all_gather
tensor = tensor.movedim(dim, 0) # [L_local, ...]
L_local = tensor.size(0)
# 2. Compute global length across all ranks
total_len = normal_len * (world_size - 1) + last_len
# 3. Pre-allocate receive buffers (same shape for all ranks, sized for the largest chunk)
max_len = max(normal_len, last_len)
gather_list = [
torch.empty([max_len] + list(tensor.shape[1:]),
dtype=tensor.dtype,
device=tensor.device)
for _ in range(world_size)
]
# 4. Copy local data into the corresponding buffer (only first L_local rows are valid)
gather_list[rank][:L_local] = tensor
# 5. All-gather (communicate only valid parts)
dist.all_gather(gather_list, gather_list[rank], group=group)
# 6. Trim padding and concatenate
gathered = torch.cat(gather_list, dim=0)[:total_len]
# 7. Move dimension back to original position
return gathered.movedim(0, dim)
class H2Embed:
    """
    Token embedding that can add a "soft" embedding derived from logits.

    At positions selected by ``mask_index`` the output is
    ``iter_cont_weight * (softmax(logits / tau) @ W_e) + embedding(x)``.
    Every other position is the plain lookup ``embedding(x)``.

    With ``sp_size > 1``, each rank embeds one contiguous slice of the sequence,
    and the slices are all-gathered back into the full sequence.
    """

    def __init__(self, embedding: nn.Embedding, tau: float = 1.0):
        """
        Args:
            embedding: token embedding module; its weight ``W_e`` has shape [V, d].
            tau: softmax temperature; lower values yield sharper distributions.
        """
        self.embedding = embedding
        self.W_e = embedding.weight
        self.tau = tau
        self.sp_size = 1  # sequence parallelism disabled by default

    def __call__(
        self,
        x: torch.Tensor,
        mask_index: Optional[torch.Tensor] = None,
        logits: Optional[torch.Tensor] = None,
        iter_cont_weight: float = 0.0
    ) -> torch.Tensor:
        """
        Args:
            x: [B, L] token ids.
            mask_index: [B, L] bool tensor; True marks positions that receive the
                continuous (logit-derived) contribution.
            logits: [B, L, V] logits used to build the continuous embeddings.
            iter_cont_weight: scale of the continuous term. It is ADDED to the
                discrete embedding (``w * soft + hard``), not convexly blended.

        Returns:
            [B, L, d] embeddings. The continuous term is applied only when both
            ``mask_index`` and ``logits`` are given.
        """
        seq_len = x.shape[1]

        if self.sp_size > 1:
            # The rank is looked up only on this path, so the common non-parallel
            # path never touches vLLM's TP state (its getters can raise when no TP
            # group has been set up).
            # NOTE(review): slicing uses the vLLM TP rank while _all_gather_cat uses
            # the default group's rank; these agree only when the TP group spans the
            # whole world -- confirm.
            if torch.distributed.is_initialized():
                rank = get_tensor_model_parallel_rank()
            else:
                rank = 0
            normal_seq_len = (seq_len + self.sp_size - 1) // self.sp_size
            last_seq_len = seq_len - normal_seq_len * (self.sp_size - 1)
            part_start = normal_seq_len * rank
            part_end = min(normal_seq_len * (rank + 1), seq_len)
            x_part = x[:, part_start:part_end]
            if mask_index is not None:
                mask_part = mask_index[:, part_start:part_end]
                logits_part = logits[:, part_start:part_end] if logits is not None else None
            else:
                mask_part = None
                logits_part = None
        else:
            x_part, mask_part, logits_part = x, mask_index, logits

        # Base discrete embedding.
        result_part = self.embedding(x_part)

        # Add the continuous embedding at the selected positions.
        if mask_part is not None and logits_part is not None:
            prob = torch.softmax(logits_part / self.tau, dim=-1)  # [B, L_part, V]
            # Match the embedding table's dtype so logits of a different dtype
            # (e.g. float32 logits with a bf16 table) do not make the matmul fail.
            input_embeds_h = prob.to(self.W_e.dtype) @ self.W_e  # [B, L_part, d]
            result_part = torch.where(
                mask_part.unsqueeze(-1),
                iter_cont_weight * input_embeds_h + result_part,
                result_part,
            )

        # Reassemble the full sequence from the per-rank slices.
        if self.sp_size > 1:
            return _all_gather_cat(
                result_part,
                dim=1,
                group=None,
                normal_len=normal_seq_len,
                last_len=last_seq_len,
            )
        return result_part
def replace_linear_class(
    linear: nn.Linear, style: Literal["colwise", "rowwise"],
    rank: Optional[int] = None, world_size: Optional[int] = None,
) -> Union[ColumnParallelLinear, RowParallelLinear]:
    """
    Build a tensor-parallel linear layer with the same geometry as ``linear``.

    Only ``in_features``/``out_features``, bias presence, dtype and device are taken
    from ``linear``. Its weights are NOT copied; the caller must load or shard
    weights into the returned module.

    Args:
        linear (nn.Linear): the layer whose shape is mirrored.
        style (str): ``"colwise"`` -> ColumnParallelLinear, ``"rowwise"`` ->
            RowParallelLinear; any other string falls back to ReplicatedLinear.
        rank (int, optional): tensor-parallel rank. ``None`` (default) queries
            vLLM's tensor-parallel group.
        world_size (int, optional): tensor-parallel size. ``None`` (default)
            queries vLLM's tensor-parallel group.

    Returns:
        The new linear module, built with ``return_bias=False``.

    Raises:
        ValueError: if ``style`` is not a string.
    """
    if not isinstance(style, str):
        raise ValueError(
            f"Unsupported parallel style type {type(style)}, expected str")

    vllm_linear_cls = {
        "colwise": ColumnParallelLinear,
        "rowwise": RowParallelLinear,
    }.get(style, ReplicatedLinear)

    # Explicit arguments win; otherwise use vLLM's global tensor-parallel state.
    # This keeps the explicit rank/world_size arguments from being silently ignored.
    if rank is None:
        rank = get_tensor_model_parallel_rank()
    if world_size is None:
        world_size = get_tensor_model_parallel_world_size()

    new_module = vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        tp_rank=rank,
        tp_size=world_size,
        return_bias=False,
    )
    new_module = new_module.to(dtype=linear.weight.dtype, device=linear.weight.device)
    return new_module
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
@torch.compile()
def scaled_dot_product_attention(q, k, v, mask=None, attn_mask=None, dropout_p=0.0, is_causal=False):
    """
    ``torch.compile``-d wrapper around ``F.scaled_dot_product_attention``.

    ``mask`` is accepted for call-site compatibility but is ignored; only
    ``attn_mask``, ``dropout_p`` and ``is_causal`` are forwarded.
    """
    return F.scaled_dot_product_attention(
        q,
        k,
        v,
        is_causal=is_causal,
        dropout_p=dropout_p,
        attn_mask=attn_mask,
    )
class ModuleType(StrEnum):
    """Role of a weight matrix; ``init_weights`` uses it to pick the std for the ``full_megatron`` init."""

    in_module = "in"  # input projections, e.g. attention QKV (att_proj) and ff_proj
    out_module = "out"  # output projections, e.g. attn_out and ff_out
    emb = "emb"  # token (wte) and positional (wpe) embeddings
    final_out = "final_out"  # final output projection
def init_weights(
    config: ModelConfig,
    module: Union[nn.Linear, nn.Embedding],
    d: Optional[int] = None,
    layer_id: Optional[int] = None,
    std_factor: float = 1.0,
    type_of_module: Optional[ModuleType] = None,
) -> None:
    """
    Initialise ``module.weight`` in place according to ``config.init_fn``.

    For ``nn.Linear`` modules the bias (if any) is zeroed. With the ``normal``
    scheme, modules flagged ``_is_residual`` are further divided by
    ``sqrt(2 * n_layers)``.

    :param config: The model config. Reads ``init_fn``, ``init_std``,
        ``init_cutoff_factor``, ``d_model`` and ``n_layers``.
    :param module: The linear or embedding submodule to initialize.
    :param d: Effective fan-in used by the ``mitchell`` and ``fan_in`` schemes.
        Defaults to ``config.d_model``; may be smaller than the real size for fused layers.
    :param layer_id: For ``mitchell``, the std is additionally scaled by
        ``1 / sqrt(2 * (layer_id + 1))``.
    :param std_factor: Std multiplier for the ``normal``, ``mitchell`` and ``fan_in`` schemes.
    :param type_of_module: Required by ``full_megatron`` to choose the std.
    :raises RuntimeError: ``full_megatron`` with a missing or unknown ``type_of_module``.
    :raises NotImplementedError: for an unsupported ``config.init_fn``.
    """
    fan_in = d if d is not None else config.d_model
    scheme = config.init_fn
    weight = module.weight

    if scheme == InitFnType.normal:
        std = config.init_std * std_factor
        cutoff = config.init_cutoff_factor
        if cutoff is None:
            nn.init.normal_(weight, mean=0.0, std=std)
        else:
            bound = cutoff * std
            nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-bound, b=bound)
    elif scheme == InitFnType.mitchell:
        std = std_factor / math.sqrt(fan_in)
        if layer_id is not None:
            std /= math.sqrt(2 * (layer_id + 1))
        nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
    elif scheme == InitFnType.kaiming_normal:
        nn.init.kaiming_normal_(weight, nonlinearity="relu")
    elif scheme == InitFnType.fan_in:
        nn.init.normal_(weight, mean=0.0, std=std_factor / math.sqrt(fan_in))
    elif scheme == InitFnType.full_megatron:
        if type_of_module is None:
            raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
        cutoff = 3 if config.init_cutoff_factor is None else config.init_cutoff_factor
        if type_of_module == ModuleType.in_module or type_of_module == ModuleType.emb:
            # QKV / feed-forward input projections and token/positional embeddings.
            std = config.init_std
        elif type_of_module == ModuleType.out_module:
            # attn_out / ff_out: shrink with depth.
            std = config.init_std / math.sqrt(2.0 * config.n_layers)
        elif type_of_module == ModuleType.final_out:
            # final output projection
            std = config.d_model**-0.5
        else:
            raise RuntimeError(f"Unknown module type '{type_of_module}'")
        nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-cutoff * std, b=cutoff * std)
    else:
        raise NotImplementedError(config.init_fn)

    if isinstance(module, nn.Linear):
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        if scheme == InitFnType.normal and getattr(module, "_is_residual", False):
            with torch.no_grad():
                weight.div_(math.sqrt(2 * config.n_layers))
def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
    """
    Replace infinities in ``x`` in place with the dtype's finite extremes.

    ``-inf`` becomes ``torch.finfo(x.dtype).min`` when ``check_neg_inf`` is set.
    ``+inf`` becomes ``torch.finfo(x.dtype).max`` when ``check_pos_inf`` is set.
    Returns ``None``.
    """
    replacements = (
        (check_neg_inf, float("-inf"), "min"),
        (check_pos_inf, float("inf"), "max"),
    )
    for enabled, inf_value, limit_name in replacements:
        if enabled:
            # finfo is only queried for enabled checks, as in a float-only path.
            x.masked_fill_(x == inf_value, getattr(torch.finfo(x.dtype), limit_name))
def activation_checkpoint_function(cfg: ModelConfig):
    """
    Return ``torch.utils.checkpoint.checkpoint`` pre-configured for this model.

    A checkpointed region is re-run during backward. RNG state must be stashed and
    restored so that any dropout inside it draws the same mask both times; otherwise
    the gradients are computed for a different mask than the forward used.

    That stashing is only needed when some dropout probability is non-zero, so
    ``preserve_rng_state`` is enabled exactly in that case. The original code
    enabled it only when all dropouts were 0, which is the inverse. The
    non-reentrant checkpoint implementation is used.
    """
    from torch.utils.checkpoint import checkpoint

    # NOTE(review): assumes dropout is the only source of randomness inside
    # checkpointed regions -- confirm before adding other stochastic ops there.
    uses_dropout = any(
        p != 0.0 for p in (cfg.attention_dropout, cfg.embedding_dropout, cfg.residual_dropout)
    )
    return partial(
        checkpoint,
        preserve_rng_state=uses_dropout,
        use_reentrant=False,
    )
class BufferCache(dict, MutableMapping[str, torch.Tensor]):
    """
    Cache for attention biases and other things that would normally be stored as buffers.

    It is a plain ``dict`` from names to tensors, shared between modules. Examples are
    the ``"rope_pos_sin"`` / ``"rope_pos_cos"`` RoPE tables and ``"causal_attention_bias"``.

    We avoid using buffers because we've run into various issues doing so with FSDP.
    In general it appears the way FSDP handles buffers is not well-defined.
    It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
    since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
    NaNs when they're synchronized due to casting or some other issue.
    """
def _non_meta_init_device(config: ModelConfig) -> torch.device:
if config.init_device is not None and config.init_device != "meta":
return torch.device(config.init_device)
else:
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Dropout(nn.Dropout):
    """``nn.Dropout`` that skips the kernel and returns ``input`` itself when ``p == 0``."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.p != 0.0:
            return F.dropout(input, self.p, self.training, self.inplace)
        return input
class LayerNormBase(nn.Module):
    """
    Shared base for the layer-norm variants in this module.

    It owns ``eps``, ``normalized_shape`` (``size`` or ``config.d_model``) and the
    optional affine parameters: ``weight`` (initialised to 1) and ``bias``
    (initialised to 0). Subclasses implement ``forward``.
    """

    def __init__(
        self,
        config: ModelConfig,
        *,
        size: Optional[int] = None,
        elementwise_affine: Optional[bool] = True,
        eps: float = 1e-05,
    ):
        super().__init__()
        self.config = config
        self.eps = eps
        self.normalized_shape = (size or config.d_model,)

        # elementwise_affine=None defers the decision to config.layer_norm_with_affine.
        affine = elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine)
        if not affine:
            self.register_parameter("bias", None)
            self.register_parameter("weight", None)
            return

        self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
        # bias_for_layer_norm=None defers to the global include_bias flag.
        use_bias = self.config.bias_for_layer_norm
        if use_bias is None:
            use_bias = self.config.include_bias
        if use_bias:
            self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
        else:
            self.register_parameter("bias", None)

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @classmethod
    def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
        """Instantiate the layer norm selected by ``config.layer_norm_type``."""
        norm_type = config.layer_norm_type
        if norm_type == LayerNormType.default or norm_type == LayerNormType.low_precision:
            return LayerNorm(config, size=size, low_precision=norm_type == LayerNormType.low_precision, **kwargs)
        if norm_type == LayerNormType.rms:
            return RMSLayerNorm(config, size=size, **kwargs)
        if norm_type == LayerNormType.gemma_rms:
            return GemmaRMSLayerNorm(config, size=size, **kwargs)
        raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")

    def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        """Cast ``tensor`` to ``dtype`` (default: the autocast dtype) if autocast is on for its device."""
        # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
        # `is_autocast_cpu_enabled()` for CPU autocast.
        # See https://github.com/pytorch/pytorch/issues/110966.
        device_type = tensor.device.type
        if device_type == "cuda" and torch.is_autocast_enabled():
            if dtype is None:
                dtype = torch.get_autocast_gpu_dtype()
        elif device_type == "cpu" and torch.is_autocast_cpu_enabled():
            if dtype is None:
                dtype = torch.get_autocast_cpu_dtype()
        else:
            return tensor
        return tensor.to(dtype=dtype)

    def reset_parameters(self):
        """Restore ``weight`` to ones and ``bias`` to zeros (where present)."""
        if self.weight is not None:
            torch.nn.init.ones_(self.weight)  # type: ignore
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)  # type: ignore
class LayerNorm(LayerNormBase):
    """
    The default :class:`LayerNorm` implementation, which can optionally run in low precision.

    In low-precision mode, the input, weight and bias are first cast to the active
    autocast dtype (when autocast is on). ``F.layer_norm`` then runs with autocast
    disabled, i.e. directly in that dtype.
    """

    def __init__(
        self,
        config: ModelConfig,
        size: Optional[int] = None,
        low_precision: bool = False,
        elementwise_affine: Optional[bool] = None,
        eps: float = 1e-05,
    ):
        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
        self.low_precision = low_precision

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.low_precision:
            return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)

        device_type = x.device.type
        x_low = self._cast_if_autocast_enabled(x)
        w_low = None if self.weight is None else self._cast_if_autocast_enabled(self.weight)
        b_low = None if self.bias is None else self._cast_if_autocast_enabled(self.bias)
        with torch.autocast(enabled=False, device_type=device_type):
            return F.layer_norm(x_low, self.normalized_shape, weight=w_low, bias=b_low, eps=self.eps)
class RMSLayerNorm(LayerNormBase):
    """
    RMS layer norm, a simplified :class:`LayerNorm` implementation.

    Computes ``x * rsqrt(mean(x ** 2, dim=-1) + eps)`` in float32 with autocast
    disabled and casts back to the input dtype. It then applies ``weight``
    (and ``bias``) when present.
    """

    def __init__(
        self,
        config: ModelConfig,
        size: Optional[int] = None,
        elementwise_affine: Optional[bool] = None,
        eps: Optional[float] = None,
    ):
        # An explicitly passed eps is honoured. When omitted, config.rms_norm_eps is
        # used, which is what callers relying on the default already got.
        super().__init__(
            config,
            size=size,
            elementwise_affine=elementwise_affine,
            eps=config.rms_norm_eps if eps is None else eps,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.autocast(enabled=False, device_type=x.device.type):
            og_dtype = x.dtype
            x = x.to(torch.float32)
            variance = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.eps)
            x = x.to(og_dtype)

        if self.weight is not None:
            if self.bias is not None:
                return self.weight * x + self.bias
            else:
                return self.weight * x
        else:
            return x
class GemmaRMSLayerNorm(LayerNormBase):
    """
    Gemma RMS layer norm, a simplified :class:`LayerNorm` implementation.

    Normalises like :class:`RMSLayerNorm` (float32, autocast disabled, cast back to
    the input dtype), but scales by ``(1 + weight)`` instead of ``weight``.
    """

    def __init__(
        self,
        config: ModelConfig,
        size: Optional[int] = None,
        elementwise_affine: Optional[bool] = None,
        eps: Optional[float] = None,
    ):
        # An explicitly passed eps is honoured. When omitted, config.rms_norm_eps is
        # used, which is what callers relying on the default already got.
        super().__init__(
            config,
            size=size,
            elementwise_affine=elementwise_affine,
            eps=config.rms_norm_eps if eps is None else eps,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.autocast(enabled=False, device_type=x.device.type):
            og_dtype = x.dtype
            x = x.to(torch.float32)
            variance = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.eps)
            x = x.to(og_dtype)

        # NOTE(review): LayerNormBase initialises weight to ones, so an untrained
        # instance scales by 2. Checkpoints are presumably expected to store
        # (scale - 1) here -- confirm.
        if self.weight is not None:
            if self.bias is not None:
                return x * (1 + self.weight) + self.bias
            else:
                return x * (1 + self.weight)
        else:
            return x
class RotaryEmbedding(nn.Module):
    """
    [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).

    The sin/cos tables live in the shared cache under ``"rope_pos_sin"`` /
    ``"rope_pos_cos"`` with shape ``(1, 1, seq_len, head_dim)``. They are rebuilt
    whenever a longer sequence is requested.
    """

    def __init__(self, config: ModelConfig, cache: BufferCache, tp_size: int = 1):
        super().__init__()
        self.config = config
        self.__cache = cache
        self.tp_size = tp_size  # stored only; the rotation itself does not use it
        self.rope_theta = config.rope_theta
        # Warm the cache for the longest sequence the model supports.
        self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))

    def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return fresh copies of the ``(sin, cos)`` tables for ``seq_len`` positions on ``device``."""
        cached_sin = self.__cache.get("rope_pos_sin")
        cached_cos = self.__cache.get("rope_pos_cos")
        cache_is_usable = (
            cached_sin is not None
            and cached_cos is not None
            and cached_sin.shape[-2] >= seq_len
            and cached_cos.shape[-2] >= seq_len
        )
        if cache_is_usable:
            if cached_sin.device != device:
                cached_sin = cached_sin.clone().to(device)
                self.__cache["rope_pos_sin"] = cached_sin
            if cached_cos.device != device:
                cached_cos = cached_cos.clone().to(device)
                self.__cache["rope_pos_cos"] = cached_cos
            return cached_sin[:, :, :seq_len, :].clone(), cached_cos[:, :, :seq_len, :].clone()

        with torch.autocast(device.type, enabled=False):
            head_dim = self.config.d_model // self.config.n_heads
            exponents = torch.arange(0, head_dim, 2, device=device, dtype=torch.float) / head_dim
            inv_freq = 1.0 / (self.rope_theta ** exponents)
            steps = torch.arange(seq_len, device=device, dtype=torch.float)
            freqs = einsum("i , j -> i j", steps, inv_freq)
            angles = torch.cat((freqs, freqs), dim=-1)
            pos_sin = angles.sin()[None, None, :, :].clone()
            pos_cos = angles.cos()[None, None, :, :].clone()
        self.__cache["rope_pos_sin"] = pos_sin.clone()
        self.__cache["rope_pos_cos"] = pos_cos.clone()
        return pos_sin, pos_cos

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Map the last-dim halves ``(x1, x2)`` of a 4-D tensor to ``(-x2, x1)``."""
        B, nh, T, hs = x.size()
        first, second = x.view(B, nh, T, 2, hs // 2).unbind(dim=-2)
        return torch.cat((-second, first), dim=-1)

    def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Rotate ``t`` by the given angles: ``t * cos + rotate_half(t) * sin``."""
        rotated = (t * pos_cos) + (self.rotate_half(t) * pos_sin)
        return rotated.to(t.dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor, block_end_index: Optional[torch.Tensor] = None, start_pos: int=0) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply RoPE to queries ``q`` and keys ``k`` (both ``(B, heads, T, head_dim)``).

        Keys occupy positions ``[start_pos, start_pos + key_len)``. Queries occupy the
        ``query_len`` positions ending at ``block_end_index + start_pos``, or at the
        end of the keys when ``block_end_index`` is ``None``.
        """
        if self.config.rope_full_precision:
            q_, k_ = q.float(), k.float()
        else:
            q_, k_ = q, k

        with torch.autocast(q.device.type, enabled=False):
            query_len, key_len = q_.shape[-2], k_.shape[-2]
            pos_sin, pos_cos = self.get_rotary_embedding(key_len + start_pos, q_.device)
            pos_sin = pos_sin.type_as(q_)
            pos_cos = pos_cos.type_as(q_)

            q_end = (key_len if block_end_index is None else block_end_index) + start_pos
            q_begin = q_end - query_len
            q_ = self.apply_rotary_pos_emb(
                pos_sin[:, :, q_begin:q_end, :],
                pos_cos[:, :, q_begin:q_end, :],
                q_,
            )
            k_ = self.apply_rotary_pos_emb(pos_sin[:, :, start_pos:], pos_cos[:, :, start_pos:], k_)

        return q_.type_as(q), k_.type_as(k)
class Activation(nn.Module):
    """
    Base class for the feed-forward activations.

    ``output_multiplier`` is the ratio of the activation's output width to its input
    width. For example, it is 0.5 for SwiGLU, which consumes two halves of its input.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @property
    @abstractmethod
    def output_multiplier(self) -> float:
        raise NotImplementedError

    @classmethod
    def build(cls, config: ModelConfig) -> Activation:
        """Instantiate the activation named by ``config.activation_type``."""
        act_type = config.activation_type
        if act_type == ActivationType.gelu:
            return cast(Activation, GELU(approximate="none"))
        if act_type == ActivationType.relu:
            return cast(Activation, ReLU(inplace=False))
        if act_type == ActivationType.silu:
            return cast(Activation, SiLU(inplace=False))
        if act_type == ActivationType.swiglu:
            return SwiGLU(config)
        raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
class GELU(nn.GELU):
    """``nn.GELU`` exposing the ``output_multiplier`` used to size the FFN output projection."""

    @property
    def output_multiplier(self) -> float:
        # Element-wise activation: output width equals input width.
        return 1.0
class ReLU(nn.ReLU):
    """``nn.ReLU`` exposing the ``output_multiplier`` used to size the FFN output projection."""

    @property
    def output_multiplier(self) -> float:
        # Element-wise activation: output width equals input width.
        return 1.0
class SiLU(nn.SiLU):
    """``nn.SiLU`` exposing the ``output_multiplier`` used to size the FFN output projection."""

    @property
    def output_multiplier(self) -> float:
        # Element-wise activation: output width equals input width.
        return 1.0
class SwiGLU(Activation):
    """
    SwiGLU gating.

    Splits the last dimension into ``(value, gate)`` halves and returns
    ``silu(gate) * value``. For an even-sized last dimension, the output is half
    as wide as the input.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value

    @property
    def output_multiplier(self) -> float:
        # Half of the input width survives the gating.
        return 0.5
def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
    """
    Build an additive causal mask of shape ``(1, 1, seq_len, seq_len)`` in float32.

    Entries above the diagonal (future positions) hold ``torch.finfo(float32).min``.
    The diagonal and everything below it are 0.
    """
    very_negative = torch.finfo(torch.float).min
    bias = torch.full((seq_len, seq_len), very_negative, device=device, dtype=torch.float)
    bias = bias.triu(diagonal=1)  # zero out the diagonal and the lower triangle
    return bias.view(1, 1, seq_len, seq_len)  # type: ignore
def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
    """
    Fetch, or build and cache, the causal bias stored under ``cache["causal_attention_bias"]``.

    A cached bias at least ``seq_len`` wide is reused and moved to ``device`` if
    needed. It is returned untrimmed, so it may be larger than ``seq_len``, and
    callers slice it. Otherwise a fresh ``seq_len``-sized bias is built and cached.
    The returned tensor never aliases the cached one.
    """
    cached = cache.get("causal_attention_bias")
    if cached is None or cached.shape[-1] < seq_len:
        with torch.autocast(device.type, enabled=False):
            fresh = causal_attention_bias(seq_len, device)
        cache["causal_attention_bias"] = fresh.clone()
        return fresh

    if cached.device != device:
        cached = cached.clone().to(device)
        cache["causal_attention_bias"] = cached.clone()
    return cached.clone()
def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
    """
    Build the ALiBi bias of shape ``(1, n_heads, seq_len, seq_len)``.

    Entry ``[0, h, i, j]`` equals ``-|j - i| / 2 ** ((h + 1) * alibi_bias_max / n_heads)``.
    """
    offsets = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device)
    # shape: (1, 1, seq_len, seq_len), holding -|j - i|
    distance = offsets.view(1, 1, 1, seq_len) - offsets.view(1, 1, seq_len, 1)
    distance.abs_().mul_(-1)

    # shape: (n_heads,), holding per-head exponents (h + 1) * alibi_bias_max / n_heads
    head_exp = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
    head_exp.mul_(config.alibi_bias_max / config.n_heads)
    slopes = 1.0 / (2 ** head_exp.view(1, config.n_heads, 1, 1))

    # shape: (1, n_heads, seq_len, seq_len)
    return distance * slopes  # type: ignore
class LLaDABlock(nn.Module):
    """
    A base class for transformer block implementations.

    Builds the components shared by concrete blocks:
    - residual dropout
    - optional q/k layer norms
    - the activation
    - the attention and feed-forward output projections
    - optional rotary embeddings
    - an optional flash-attention kernel

    It also implements :meth:`attention`, which runs non-causal (bidirectional)
    attention. Subclasses implement :meth:`forward`; :meth:`build` picks the
    subclass from ``config.block_type``.
    """

    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
        super().__init__()
        self.layer_id = layer_id
        self.config = config
        # Feed-forward hidden width: explicit mlp_hidden_size, else mlp_ratio * d_model.
        self.hidden_size = (
            config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
        )
        # Shared tensor cache; handed to RotaryEmbedding below for its sin/cos tables.
        self.__cache = cache
        assert config.d_model % config.n_heads == 0

        self._activation_checkpoint_fn = None
        # Tensor-parallel degree from vLLM. Falls back to 1 when the getter raises
        # AssertionError (presumably because no TP group has been initialised).
        try:
            self.tp_size = get_tensor_model_parallel_world_size()
        except AssertionError:
            self.tp_size = 1

        # Dropout.
        self.dropout = Dropout(config.residual_dropout)

        # Layer norms.
        # NOTE(review): q_norm is sized d_model and k_norm the full KV width, but
        # attention() splits heads per rank when tp_size > 1 -- confirm that
        # attention_layer_norm is never combined with tensor parallelism.
        self.k_norm: Optional[LayerNormBase] = None
        self.q_norm: Optional[LayerNormBase] = None
        if config.attention_layer_norm:
            self.k_norm = LayerNormBase.build(
                config,
                size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
                elementwise_affine=config.attention_layer_norm_with_affine,
            )
            self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)

        # Activation function.
        self.act = Activation.build(config)
        # The activation's output width (output_multiplier * hidden_size) must be a whole number.
        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0

        # Attention output projection.
        self.attn_out = nn.Linear(
            config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
        )

        # Feed-forward output projection.
        self.ff_out = nn.Linear(
            int(self.act.output_multiplier * self.hidden_size),
            config.d_model,
            bias=config.include_bias,
            device=config.init_device,
        )
        # Flag read by init_weights: under the "normal" init, residual layers are scaled down.
        self.ff_out._is_residual = True  # type: ignore

        # Rotary embeddings.
        if self.config.rope:
            self.rotary_emb = RotaryEmbedding(config, self.__cache, tp_size=self.tp_size)

        # Optional flash-attention kernel. Any exception raised while importing it is
        # swallowed, and attention then uses the SDPA path.
        self.flash_attn_func = None
        if config.flash_attention:
            try:
                from flash_attn import flash_attn_func  # type: ignore

                self.flash_attn_func = flash_attn_func
            except:
                pass

    def reset_parameters(self):
        """Re-initialise the q/k norms and the attention / feed-forward output projections."""
        if self.k_norm is not None:
            self.k_norm.reset_parameters()
        if self.q_norm is not None:
            self.q_norm.reset_parameters()
        init_weights(
            self.config,
            self.attn_out,
            d=self.config.d_model,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )
        init_weights(
            self.config,
            self.ff_out,
            d=self.ff_out.in_features,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
        """Install a checkpoint wrapper for the ``fine_grained`` strategy; clear it for any other value."""
        if strategy == ActivationCheckpointingStrategy.fine_grained:
            self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
        else:
            self._activation_checkpoint_fn = None

    @classmethod
    def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
        """
        Cast ``bias`` to the active autocast dtype for its device (else to
        ``input_dtype``), then replace ``-inf`` with that dtype's finite minimum.

        When no cast is needed, ``bias`` itself is returned. The in-place ``-inf``
        replacement then also modifies the caller's tensor.
        """
        target_dtype = input_dtype
        # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
        # `is_autocast_cpu_enabled()` for CPU autocast.
        # See https://github.com/pytorch/pytorch/issues/110966.
        if bias.device.type == "cuda" and torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
            target_dtype = torch.get_autocast_cpu_dtype()
        if bias.dtype != target_dtype:
            bias = bias.to(target_dtype)
            ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
        return bias

    def _scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Computes scaled dot product attention on query, key and value tensors, using an optional
        attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.

        Inputs are laid out as (B, n_heads, T, head_dim). When flash-attention was imported
        and no mask is given, it is used (after transposing to its (B, T, n_heads, head_dim)
        layout). Otherwise, KV heads are repeated to match the query heads (grouped-query
        attention) and the compiled SDPA wrapper is called.

        Attention is always non-causal: ``is_causal`` is accepted but not forwarded.
        """
        if self.flash_attn_func is not None and attn_mask is None:
            r = self.flash_attn_func(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=False
            )
            return r.transpose(1, 2)
        else:
            # torch's sdpa doesn't support GQA, so we're doing this
            # (enable_gqa only exists in newer PyTorch releases): repeat each KV head
            # so K/V have as many heads as Q.
            assert k.size(1) == v.size(1)
            num_kv_heads = k.size(1)
            num_q_heads = q.size(1)
            if num_q_heads != num_kv_heads:
                assert num_q_heads % num_kv_heads == 0
                k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
                v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)

            # Modify: MDM set causal to False, and with no attn_mask.
            return scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_mask,
                dropout_p=dropout_p,
                is_causal=False,
            )

    def attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        attention_bias: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        kv_cache: Optional[Cache] = None,
        use_cache: bool = False,
        replace_position: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Multi-head (optionally grouped-query) bidirectional attention.

        Args:
            q, k, v: projected inputs of shape (B, T, hidden). Heads are split into
                ``n_heads // tp_size`` query heads and ``effective_n_kv_heads // tp_size``
                key/value heads, each of width ``hidden // (n_heads // tp_size)``.
            mask: unused.
            attention_bias: sliced and cast via ``_cast_attn_bias`` but NOT applied;
                SDPA is always called with ``attn_mask=None``.
            layer_past: ``(past_key, past_value)``. New k/v are appended to it, or,
                when ``replace_position`` is given, written into the slice ``[start, end)``.
            kv_cache: object whose ``update(k, v, layer_id, replace_position)`` returns
                the keys/values to attend over. Takes precedence over ``layer_past``.
            use_cache: when True, the keys/values (before RoPE is applied) are returned
                as ``present``.
            replace_position: ``(start, end)`` of the block being recomputed. ``end``
                is also passed to RoPE as the query block's end position.

        Returns:
            ``(attn_out(att), present)``, where ``att`` has shape (B, T, hidden) and
            ``present`` is ``(k, v)`` or ``None``.
        """
        B, T, C = q.size()  # batch size, sequence length, hidden size of q
        dtype = k.dtype
        tp_size = getattr(self, "tp_size", 1)

        # Optionally apply layer norm to keys and queries.
        if self.q_norm is not None and self.k_norm is not None:
            q = self.q_norm(q).to(dtype=dtype)
            k = self.k_norm(k).to(dtype=dtype)

        # Move head forward to be next to the batch dim.
        # Head counts are per tensor-parallel rank.
        actual_nheads = self.config.n_heads//tp_size
        actual_kv_heads = self.config.effective_n_kv_heads//tp_size
        # shape: (B, nh, T, hs)
        q = q.view(B, T, actual_nheads, C // actual_nheads).transpose(1, 2)
        # shape: (B, n_kv_h, T, hs)
        k = k.view(B, T, actual_kv_heads, C // actual_nheads).transpose(1, 2)
        # shape: (B, n_kv_h, T, hs)
        v = v.view(B, T, actual_kv_heads, C // actual_nheads).transpose(1, 2)

        # Merge with cached keys/values: an external cache object wins over layer_past.
        if kv_cache is not None:
            k, v = kv_cache.update(k, v, self.layer_id, replace_position)
        elif layer_past is not None:
            past_key, past_value = layer_past
            if replace_position is None:
                k = torch.cat((past_key, k), dim=-2)
                v = torch.cat((past_value, v), dim=-2)
            else:
                # replace_position is a (start, end) pair: overwrite positions
                # [start, end) of the cached keys/values (dim 2) with the freshly
                # computed k/v for that block.
                start, end = replace_position
                k = past_key.slice_scatter(k, dim=2, start=start, end=end)
                v = past_value.slice_scatter(v, dim=2, start=start, end=end)

        # Cached keys are stored before RoPE; rotation is re-applied on every call.
        present = (k, v) if use_cache else None
        query_len, key_len = q.shape[-2], k.shape[-2]  # could be different if layer_past not None

        if self.config.rope:
            # Apply rotary embeddings.
            if replace_position is None:
                q, k = self.rotary_emb(q, k)
            else:
                q, k = self.rotary_emb(q, k, replace_position[1])

        if attention_bias is not None:
            # Resize and cast attention bias.
            # The current dtype of the attention bias might not match the dtype that the SDP attn function will
            # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
            # as down-casting the attention bias to the autocast precision will result in -infs, which will
            # cause the SDP attn function to produce NaNs.
            # NOTE: the result is not passed on below (attn_mask=None), so the bias has no
            # effect on the attention output.
            attention_bias = self._cast_attn_bias(
                attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
            )

        # Get the attention scores.
        # shape: (B, nh, T, hs)
        att = self._scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=0.0 if not self.training else self.config.attention_dropout,
            is_causal=False,
        )

        # Re-assemble all head outputs side-by-side.
        att = att.transpose(1, 2).contiguous().view(B, T, C)

        # Apply output projection.
        return self.attn_out(att), present

    @abstractmethod
    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.FloatTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        kv_cache = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the block (implemented by subclasses); returns ``(output, present_kv_or_None)``."""
        raise NotImplementedError

    @classmethod
    def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> LLaDABlock:
        """Instantiate the block class selected by ``config.block_type`` (sequential or llama)."""
        if config.block_type == BlockType.sequential:
            return LLaDASequentialBlock(layer_id, config, cache)
        elif config.block_type == BlockType.llama:
            return LLaDALlamaBlock(layer_id, config, cache)
        else:
            raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
class LLaDASequentialBlock(LLaDABlock):
"""
This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``