-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathllama_model.cpp
More file actions
984 lines (829 loc) · 44.8 KB
/
llama_model.cpp
File metadata and controls
984 lines (829 loc) · 44.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
// Copyright (c) 2025, IST Austria, developed by Erik Schultheis
// SPDX-License-Identifier: Apache-2.0
//
#include "llama_model.h"
#include <cmath>
#include <sstream>
#include "kernels/kernels.h"
#include "llama_gradients.h"
#include "llama_optimizer.h"
#include "llama_run_state.h"
#include "llama_weights.h"
#include "utilities/comm.h"
/// Constructs the model and creates its weights manager.
///
/// \param config  Transformer configuration; stored in `Config`.
/// \param options Model options; stored in `Options` and forwarded to the weights manager.
/// \param rank    Forwarded unchanged to `LLamaWeightsManager::create`.
/// \param world   Forwarded unchanged to `LLamaWeightsManager::create`.
/// \param alloc   Allocator to use; when null, a fresh `TensorAllocator` is created and owned by the model.
LLamaModel::LLamaModel(TransformerConfig config, const LLamaOptions& options, int rank, int world,
                       const std::shared_ptr<TensorAllocator>& alloc)
    : Config(config),
      Options(options),
      Allocator(alloc != nullptr ? alloc : std::make_shared<TensorAllocator>())
{
    // The weights manager allocates through the (possibly defaulted) allocator chosen above.
    Parameters = LLamaWeightsManager::create(Config, options, rank, world, *Allocator);
}
// Destructor is defaulted here in the .cpp file rather than in the header.
// NOTE(review): presumably so that members whose types are only forward-declared in the
// header can be destroyed where their full definitions are visible — confirm against llama_model.h.
LLamaModel::~LLamaModel() = default;
/// Forward matmul `out = inp x weight^T (+ bias)` with optional on-the-fly quantization of the input.
///
/// Three cases, selected by dtype:
///  * weight and `inp.Value` share a dtype: multiply the unquantized input directly.
///  * weight is BF16: multiply the quantized input `inp.Quant` without scale factors.
///  * otherwise: multiply `inp.Quant` and pass both the weight scale and the input scale.
///
/// \param reuse_inp_quant When true, `inp.Quant` is assumed to be up to date and is not recomputed.
/// \param B,T,C,OC        The matmul is issued with sizes (OC, B*T, C); `B*T*C` elements are quantized.
void forward_qmm(Tensor& out, QuantizableTensor& inp, Tensor& weight, const Tensor& bias,
                 cublasLtHandle_t handle, Tensor& workspace,
                 int B, int T, int C, int OC,
                 const cudaDeviceProp& dp, bool reuse_inp_quant,
                 cudaStream_t stream, EMatmulBackend backend) {
    const int rows = B * T;

    // Matching dtypes: no quantization step required at all.
    const bool same_dtype = (weight.DType == inp.Value.DType);
    if (same_dtype) {
        matmul(out, weight, inp.Value, bias, nullptr, nullptr,
               handle, workspace, OC, rows, C, EMMTranspose::TN, false, stream, backend);
        return;
    }

    // Refresh the quantized copy of the input unless the caller vouches for the existing one.
    if (!reuse_inp_quant) {
        quantize_with_abs_max(inp.Quant, inp.Quant.scale(), inp.Value, inp.Quant.abs_max(), rows * C, dp, stream);
    }

    const bool needs_scales = (weight.DType != ETensorDType::BF16);
    if (needs_scales) {
        // Scaled matmul: hand over the scale factors of both operands.
        matmul(out, weight, inp.Quant, bias, weight.scale(), inp.Quant.scale(),
               handle, workspace, OC, rows, C, EMMTranspose::TN, false, stream, backend);
    } else {
        // BF16 weights: operands are used as-is, without scale factors.
        matmul(out, weight, inp.Quant, bias, nullptr, nullptr,
               handle, workspace, OC, rows, C, EMMTranspose::TN, false, stream, backend);
    }
}
template<typename Function>
void trace_or_execute_cuda_graph(Function&& function, cudaStream_t stream, cudaGraphExec_t& instance, bool enabled) {
if (enabled) {
cudaGraph_t graph;
CUDA_CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal));
function();
CUDA_CHECK(cudaStreamEndCapture(stream, &graph));
if (instance == nullptr) {
CUDA_CHECK(cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0));
}
cudaGraphExecUpdateResultInfo result;
if(auto status = cudaGraphExecUpdate(instance, graph, &result); status != cudaSuccess)
{
fprintf(stderr, "Graph update failed: %d\n", result.result);
CUDA_CHECK(status);
}
CUDA_CHECK(cudaGraphDestroy(graph));
CUDA_CHECK(cudaGraphLaunch(instance, stream));
} else {
function();
}
}
// Runs the forward pass for one micro-batch, up to and including the final RMSNorm.
//
// inputs:     INT32 tensor located on the host (both asserted below); Sizes[0] and Sizes[1]
//             are read as B and T. It is copied asynchronously into rs->Inputs.
// comm:       communicator handed to the Parameters->gather_* calls.
// micro_step: index of the current micro-step; step 0 invalidates cached parameter state.
//
// The function blocks the host until the input copy has completed (TransferDone), and records
// rs->ForwardDone on the main stream. The language-model head is NOT applied here.
void LLamaModel::forward(Tensor inputs, NCCLCommunicator& comm, int micro_step) {
NVTX_RANGE_FN();
// optional timing instrumentation: mark the start of the forward pass on the main stream
if(Options.TriggerTimingEvents) {
RunState->setup_timing_events(micro_step);
CUDA_CHECK(cudaEventRecord(RunState->TimingForwardStart[micro_step], RunState->MainStream));
}
assert(inputs.DType == ETensorDType::INT32);
auto& rs = RunState;
cudaStream_t main_stream = rs->MainStream;
long B = inputs.Sizes[0];
long T = inputs.Sizes[1];
long V = Config.VocabSize;
long C = Config.HiddenSize;
// If this is the first micro-step, the parameters have just changed, and we can not
// re-use any cached values
if(micro_step == 0) {
Parameters->invalidate();
}
// the device-side input buffer must be large enough, and the source must live on the host
assert(rs->Inputs.Sizes[0] >= B);
assert(rs->Inputs.Sizes[1] >= T);
assert(inputs.Device == -1);
{
NvtxRange r{"copy-input"};
// no point running this copy on side stream: input is needed by embedding gradients, which is
// the last op in backward.
CUDA_CHECK(cudaMemcpyAsync(rs->Inputs.Data, inputs.Data, inputs.bytes(), cudaMemcpyHostToDevice, main_stream));
// TransferDone is waited on (host-side) at the end of this function
CUDA_CHECK(cudaEventRecord(rs->TransferDone, main_stream));
}
{
// embedding lookup: rs->Inputs -> rs->Encoded
NvtxRange emb_range("embedding");
Parameters->gather_embeddings(comm);
encoder_forward(
rs->Encoded,
rs->Inputs,
Parameters->get_embeddings(main_stream),
Tensor{}, B, T, C, V, main_stream);
Parameters->release_embeddings(main_stream);
}
// reset the abs-max statistics (only present for some configurations) before they are re-accumulated
if(rs->AbsMaxes.has_value())
fill_zero(rs->AbsMaxes.value(), main_stream);
// start gathering the weights of the first block; subsequent blocks are prefetched inside the loop
Parameters->gather_block(0, comm, *rs);
for (int l = 0; l < Config.NumLayers; l++) {
NvtxRange layer_range("Layer", l);
// prefetch
if (l != Config.NumLayers - 1) {
Parameters->gather_block(l + 1, comm, *rs);
}
auto& wgt = Parameters->get_block(l, main_stream);
// input residual stream of this layer: the embeddings for layer 0, otherwise the FFN residual of layer l-1
Tensor residual = l == 0 ? rs->Encoded : rs->get_res_ffn(l-1, main_stream);
// fuse RMSNorm with residual, except in the first layer when no residual exists yet.
// mark_res_ffn_ready records an event, and we need to wait for that event outside the
// graph, so this block has to be separate.
if (l == 0) {
rmsnorm_forward(rs->Acts[0].LN1.Value, rs->Acts[0].LN1_Rstd, residual, wgt.LN1_w,
rs->Acts[0].LN1.Quant.abs_max(), Config.RmsNormEps, B, T, C, main_stream);
} else {
auto& prev = rs->Acts[l-1];
fused_residual_rmsnorm_forward(residual, rs->Acts[l].LN1.Value, rs->Acts[l].LN1_Rstd,
prev.ResidualAtt, prev.MlpDown, wgt.LN1_w,
rs->Acts[l].LN1.Quant.abs_max(),
Config.RmsNormEps, B * T, C, main_stream);
rs->mark_res_ffn_ready(l-1, main_stream);
}
// the MlpUp buffer is acquired for the duration of the block and released right after it
rs->Acts[l].MlpUp = rs->acquire_mlp_up(l);
// run the remainder of the block, optionally captured/replayed as a CUDA graph
trace_or_execute_cuda_graph([&](){_forward_block(wgt, rs->Acts[l], residual);},
main_stream, rs->ForwardBlockGraph, rs->Options.UseCudaGraphs);
Parameters->release_block(l, main_stream);
rs->release_mlp_up(rs->Acts[l].MlpUp);
// NOTE(review): put_res_ffn on the side stream presumably stores/offloads the finished residual of
// the previous layer — confirm against LLamaRunState.
if(l > 0) {
rs->put_res_ffn(l-1, rs->SideStream);
}
}
{
// final norm: combines the last layer's residual and MLP output, then normalizes into rs->LNF
NvtxRange r{"LNF"};
auto& acts = rs->Acts[Config.NumLayers-1];
Parameters->gather_lnf(comm);
fused_residual_rmsnorm_forward(rs->get_res_ffn(Config.NumLayers - 1, main_stream), rs->LNF, rs->LNF_Rstd, acts.ResidualAtt,
acts.MlpDown, Parameters->get_lnf(main_stream), nullptr, Config.RmsNormEps, B * T, C, main_stream);
Parameters->release_lnf(main_stream);
rs->mark_res_ffn_ready(Config.NumLayers-1, main_stream);
rs->put_res_ffn(Config.NumLayers-1, rs->SideStream);
}
// do not return before inputs can be accessed again.
CUDA_CHECK(cudaEventSynchronize(rs->TransferDone));
CUDA_CHECK(cudaEventRecord(rs->ForwardDone, main_stream));
if(Options.TriggerTimingEvents) {
CUDA_CHECK(cudaEventRecord(RunState->TimingForwardEnd[micro_step], RunState->MainStream));
}
}
// Forward pass of one transformer block, starting after the first RMSNorm (which the caller runs).
//
// weights:  weights of this block.
// acts:     activation buffers of this block; acts.LN1 must already hold the normalized input.
// residual: residual stream entering the block; it is added to the attention output.
//
// All work is enqueued on rs->MainStream. B and T are taken from rs->Inputs, not from the
// caller's tensor. The caller may capture this function into a CUDA graph.
void LLamaModel::_forward_block(sLLamaBlockWeights<Tensor>& weights, sLLamaLayerActivations& acts, Tensor& residual)
{
auto& rs = RunState;
long B = rs->Inputs.Sizes[0];
long T = rs->Inputs.Sizes[1];
long C = Config.HiddenSize;
long D = Config.IntermediateSize;
long Hq = Config.NumQueryHeads;
long Hkv = Config.NumKeyValHeads;
long Hs = Config.head_size();
cudaStream_t main_stream = rs->MainStream;
// 1) projection to QKV vectors (note k,v may be fewer heads than q)
forward_qmm(acts.QKV, acts.LN1, weights.Attn_QKV_w, weights.Attn_QKV_b,
rs->CublasLtHandle, rs->CuBlasWorkspace,
B, T, C, Config.qkv_channels(),
rs->DeviceProp, false, main_stream, rs->MatmulBackend);
// 2) apply RoPE to q,k (potentially in place)
rope_forward(acts.QKV, acts.QKV, rs->FreqCis, nullptr, B, T, Hq, Hkv, Hs, main_stream);
// 3) attention: att <- softmax(qk^T)v
attention_forward_cudnn(acts.Att.Value, acts.LSE, acts.QKV, rs->CuBlasWorkspace, rs->CudnnHandle, B, T, Hq, Hkv, Hs, main_stream);
// quantize attention if necessary
// (only the abs-max is computed here; the quantization itself happens inside forward_qmm below)
if(acts.Att.Quant) {
abs_max(acts.Att.Quant.abs_max(), acts.Att.Value, acts.Att.Value.nelem(), rs->DeviceProp, main_stream);
}
// 4) attention output projection (no bias)
forward_qmm(acts.AttO, acts.Att, weights.Attn_Out_w, Tensor{},
rs->CublasLtHandle, rs->CuBlasWorkspace,
B, T, C, C,
rs->DeviceProp, false, main_stream, rs->MatmulBackend);
// 5) residual add (residual + AttO -> ResidualAtt) fused with the second RMSNorm (-> LN2)
fused_residual_rmsnorm_forward(acts.ResidualAtt, acts.LN2.Value, acts.LN2_Rstd, residual, acts.AttO, weights.LN2_w,
acts.LN2.Quant.abs_max(), Config.RmsNormEps, B * T, C, main_stream);
// 6) MLP up projection; output has 2*D channels, which swiglu_forward reduces to D
forward_qmm(acts.MlpUp, acts.LN2, weights.MLP_Up_w, Tensor{},
rs->CublasLtHandle, rs->CuBlasWorkspace,
B, T, C, 2 * D,
rs->DeviceProp, false, main_stream, rs->MatmulBackend);
// 7) SwiGLU activation
swiglu_forward(acts.SwiGLu.Value, acts.MlpUp, acts.SwiGLu.Quant.abs_max(), B, T, D, main_stream);
// 8) MLP down projection; the residual add for MlpDown is fused into the next layer's RMSNorm by the caller
forward_qmm(acts.MlpDown, acts.SwiGLu, weights.MLP_Down_w, Tensor{},
rs->CublasLtHandle, rs->CuBlasWorkspace,
B, T, D, C,
rs->DeviceProp, false, main_stream, rs->MatmulBackend);
}
// Runs a forward pass followed by the LM head and loss computation, without a backward pass.
//
// inputs:  token tensor, forwarded to forward().
// targets: target tensor; may live on the host (Device == -1) or on a device.
// Returns {full_loss, loss_1k}, where
//   full_loss = rs->get_loss()      / (B*T)
//   loss_1k   = rs->get_loss(1024)  / (B*min(1024, T))
// NOTE(review): get_loss(1024) presumably restricts the loss to the first 1024 positions — confirm.
//
// The function synchronizes the whole device before reading the losses.
std::pair<float, float> LLamaModel::validate(Tensor inputs, Tensor targets, NCCLCommunicator& comm, int micro_step) {
NVTX_RANGE_FN();
// convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow
auto& rs = RunState;
const size_t V = Config.VocabSize;
const size_t Vp = Config.VocabSize;
long B = inputs.Sizes[0];
long T = inputs.Sizes[1];
long C = Config.HiddenSize;
cudaStream_t main_stream = rs->MainStream;
forward(inputs, comm, micro_step);
NvtxRange classifier_and_loss_range("classifier_and_loss");
// fused classifier: does the forward pass and first part of the backward pass
const float d_loss = 1.0f / float(B * T); // results in the uniform average loss over all elements
// note: we don't need to generate dlogits here
fill_zero(rs->Losses, main_stream);
// synchronous copy of the targets into the device-side buffer; direction depends on where targets live
if(targets.Device == -1) {
CUDA_CHECK(cudaMemcpy(rs->Targets.Data, targets.Data, targets.bytes(), cudaMemcpyHostToDevice));
} else {
CUDA_CHECK(cudaMemcpy(rs->Targets.Data, targets.Data, targets.bytes(), cudaMemcpyDeviceToDevice));
}
// the LM head is processed in LMHeadChunks equally sized chunks of B*T/LMHeadChunks rows each,
// so that only one chunk of logits (rs->Output) is needed at a time
long nano_batches = Options.LMHeadChunks;
int nano_batch_size = div_exact(B * T, nano_batches);
Parameters->gather_head(comm);
rs->temp_acquire(rs->Output);
for(int nano_step = 0; nano_step < nano_batches; nano_step++) {
// views into the per-row buffers, advanced (in bytes) to the start of the current chunk
Tensor lnf_slice = rs->LNF;
lnf_slice.Data += nano_step * nano_batch_size * C * get_dtype_size(lnf_slice.DType);
Tensor tgt = rs->Targets;
tgt.Data += nano_step * nano_batch_size * get_dtype_size(tgt.DType);
Tensor losses = rs->Losses;
losses.Data += nano_step * nano_batch_size * get_dtype_size(losses.DType);
Tensor lse = rs->LSE;
lse.Data += nano_step * nano_batch_size * get_dtype_size(lse.DType);
// logits for this chunk
matmul(rs->Output, Parameters->get_head(main_stream), lnf_slice,
Tensor{}, nullptr, nullptr, rs->CublasLtHandle, rs->CuBlasWorkspace, V, nano_batch_size, C, EMMTranspose::TN, false, main_stream, rs->MatmulBackend);
// accumulate the losses inside rs->losses.
// NOTE(review): the `false` argument (the training path passes `true`) presumably disables writing
// dlogits, matching the "we don't need to generate dlogits" note above — confirm against fused_classifier.
fused_classifier(rs->Output, losses, lse, d_loss, tgt, 0.f, nano_batch_size, V, Vp, false, main_stream);
}
rs->temp_free(rs->Output);
Parameters->release_head(main_stream);
_reduce_loss(*rs, comm, B, T);
// _reduce_loss ends with an async device-to-host copy; wait for it before reading the host-side loss
CUDA_CHECK(cudaDeviceSynchronize());
float full_loss = rs->get_loss() / (B*T);
float loss_1k = rs->get_loss(1024) / (B*std::min(1024l, T));
return {full_loss, loss_1k};
}
/// Backward pass of a (possibly quantized) matmul: computes the input gradient `dinp`, the weight
/// gradient `dweight`, and optionally the bias gradient `dbias`.
///
/// Three cases, selected by dtype (mirroring `forward_qmm`):
///  * weight and `inp.Value` share a dtype: plain matmuls on the unquantized values.
///  * weight is BF16: `dout` (and `inp`, unless reused) are converted into their `.Quant` tensors,
///    and the matmuls run without scale factors.
///  * otherwise: `dout` is quantized using its abs-max, and weight/activation/gradient are
///    explicitly transposed into temporaries so that both matmuls can run in TN layout with
///    scale factors.
///
/// \param dinp                Receives the gradient w.r.t. the matmul input (overwritten).
/// \param dweight             Receives the weight gradient; accumulated into if `accumulate_gradient`.
/// \param dbias               Bias gradient; skipped when the tensor is empty.
/// \param dout                Gradient w.r.t. the matmul output.
/// \param inp                 The matmul input from the forward pass.
/// \param bias_buffer         Scratch buffer handed to `backward_bias`.
/// \param reuse_inp           When true, `inp.Quant` is assumed valid and is not recomputed.
/// \param B,T,C,OC            Forward matmul sizes: (B*T, C) input, OC output channels.
void backward_qmm(Tensor& dinp, Tensor& dweight, Tensor dbias,
                  QuantizableTensor& dout, QuantizableTensor& inp, Tensor& weight, Tensor bias_buffer,
                  bool accumulate_gradient,
                  LLamaRunState& rs,
                  int B, int T, int C, int OC,
                  bool reuse_inp, cudaStream_t stream) {
    if (weight.DType == inp.Value.DType) {
        // No quantization involved: use the full-precision tensors directly.
        matmul(dinp, weight, dout.Value, Tensor{}, nullptr, nullptr,
               rs.CublasLtHandle, rs.CuBlasWorkspace, C, B*T, OC, EMMTranspose::NN, false, stream, rs.MatmulBackend);
        matmul(dweight, inp.Value, dout.Value, Tensor{}, nullptr, nullptr,
               rs.CublasLtHandle, rs.CuBlasWorkspace, C, OC, B*T, EMMTranspose::NT, accumulate_gradient, stream, rs.MatmulBackend);
        if (dbias) {
            backward_bias(dbias, dout.Value, nullptr, nullptr, bias_buffer, B, T, OC, rs.DeviceProp, stream);
        }
    } else if (weight.DType == ETensorDType::BF16) {
        quantize_with_abs_max(dout.Quant, dout.Quant.scale(), dout.Value, nullptr, B*T*OC, rs.DeviceProp, stream);
        if(!reuse_inp) {
            // Each quantized tensor owns its scale slot: the input must use its own scale
            // (as forward_qmm does), not the one belonging to `dout`.
            quantize_with_abs_max(inp.Quant, inp.Quant.scale(), inp.Value, nullptr, B*T*C, rs.DeviceProp, stream);
        }
        matmul(dinp, weight, dout.Quant, Tensor{}, nullptr, nullptr,
               rs.CublasLtHandle, rs.CuBlasWorkspace, C, B*T, OC, EMMTranspose::NN, false, stream, rs.MatmulBackend);
        matmul(dweight, inp.Quant, dout.Quant, Tensor{}, nullptr, nullptr,
               rs.CublasLtHandle, rs.CuBlasWorkspace, C, OC, B*T, EMMTranspose::NT, accumulate_gradient, stream, rs.MatmulBackend);
        if (dbias) {
            // the bias gradient is computed from the unquantized output gradient
            backward_bias(dbias, dout.Value, nullptr, nullptr, bias_buffer, B, T, OC, rs.DeviceProp, stream);
        }
    } else {
        quantize_with_abs_max(dout.Quant, dout.Quant.scale(), dout.Value, dout.Quant.abs_max(), B*T*OC, rs.DeviceProp, stream);
        auto& inp_q = inp.Quant;

        // input gradient: transpose the weight into a temporary so the matmul can use TN layout
        auto weight_tp = rs.temp_alloc(inp_q.DType, {C, OC});
        transpose(weight_tp, weight, OC, C, stream);
        matmul(dinp, weight_tp, dout.Quant, Tensor{}, weight.scale(), dout.Quant.scale(),
               rs.CublasLtHandle, rs.CuBlasWorkspace, C, B*T, OC, EMMTranspose::TN, false, stream, rs.MatmulBackend);
        rs.temp_free(weight_tp);

        // weight gradient: needs transposed copies of both the activation and the output gradient
        auto activation_tp = rs.temp_alloc(inp_q.DType, {C, B*T});
        auto grad_tp = rs.temp_alloc(rs.Options.grad_dtype(), {OC, B*T});
        if(reuse_inp) {
            // inp is already quantized from the forward pass, so just transpose here
            transpose(activation_tp, inp_q, B*T, C, stream);
        } else {
            // even though we're re-using (and overwriting) the main tensor, each tensor still has its own version
            // of the absmax-scale, so we can reuse the existing scale from the forward pass
            quantize_and_transpose_with_abs_max(activation_tp, activation_tp.scale(), inp.Value, inp.Quant.abs_max(), B*T, C, rs.DeviceProp, stream);
        }
        transpose(grad_tp, dout.Quant, B*T, OC, stream);
        matmul(dweight, activation_tp, grad_tp, Tensor{}, inp_q.scale(), dout.Quant.scale(),
               rs.CublasLtHandle, rs.CuBlasWorkspace, C, OC, B*T, EMMTranspose::TN, accumulate_gradient, stream, rs.MatmulBackend);
        if (dbias) {
            backward_bias(dbias, dout.Quant, inp_q.scale(), dout.Quant.scale(), bias_buffer, B, T, OC, rs.DeviceProp, stream);
        }
        // temporaries are released in reverse order of allocation
        rs.temp_free(grad_tp);
        rs.temp_free(activation_tp);
    }
}
/// Runs the backward pass for one micro-batch; must follow a matching call to `forward`.
///
/// \param inputs           Host-side token tensor of the micro-batch (also handed to `encoder_backward`).
/// \param targets          Host-side target tensor; copied asynchronously into `rs->Targets`.
/// \param comm             Communicator used for parameter gathering and gradient reduction.
/// \param z_loss           z-loss coefficient, forwarded to the LM-head backward.
/// \param grad_accum_steps Total number of micro-steps in the current accumulation cycle.
/// \param micro_step       Index of this micro-step; step 0 resets losses and gradients, and the
///                         last step triggers the loss reduction.
///
/// Records `rs->BackwardDone` on the main stream and blocks the host until the asynchronous
/// copy of `targets` has completed.
void LLamaModel::backward(Tensor inputs, Tensor targets, NCCLCommunicator& comm, float z_loss, int grad_accum_steps, int micro_step) {
    auto& rs = RunState;
    cudaStream_t main_stream = rs->MainStream;
    NVTX_RANGE_FN();
    if(Options.TriggerTimingEvents) {
        CUDA_CHECK(cudaEventRecord(rs->TimingBackwardStart[micro_step], main_stream));
    }

    // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow
    long B = inputs.Sizes[0];
    long T = inputs.Sizes[1];
    const size_t C = Config.HiddenSize;
    const size_t L = Config.NumLayers;

    {
        NvtxRange r{"copy-targets"};
        // make sure rs->Targets is no longer needed by the previous step.
        CUDA_CHECK(cudaStreamWaitEvent(rs->SideStream, rs->BackwardDone, 0));
        CUDA_CHECK(cudaMemcpyAsync(rs->Targets.Data, targets.Data, targets.bytes(), cudaMemcpyHostToDevice, rs->SideStream));
        CUDA_CHECK(cudaEventRecord(rs->TransferDone, rs->SideStream));
        // we will wait in _backward_lmhead for this transfer to be done.
    }

    bool last_step = micro_step == grad_accum_steps - 1;
    // on the first micro-step zero the gradients, as we're about to += accumulate into them
    if (micro_step == 0) {
        NvtxRange classifier_and_loss_range("zero gradients");
        // there are currently two state vars during the gradient accumulation inner loop:
        // 1) the losses accumulate += into rs->losses, reset here
        // 2) the gradients accumulate += into grads_memory, reset here
        fill_zero(rs->Losses, main_stream);
        Grads->start_micro_step(rs->SideStream, micro_step, grad_accum_steps);
        // _backward_lmhead waits on this event before touching the gradients
        CUDA_CHECK(cudaEventRecord(rs->SideStreamEvent, rs->SideStream));
    } else {
        Grads->start_micro_step(main_stream, micro_step, grad_accum_steps);
    }

    // reset residual stream gradients (put here to work with gradient accumulation)
    fill_zero(rs->DLNF, main_stream);
    fill_zero(rs->DActs[L-1].DResFFN.Value, main_stream);

    _backward_lmhead(B, T, z_loss, micro_step, grad_accum_steps, comm);

    // ok, now reduce the loss across all ranks
    if (last_step) {
        _reduce_loss(*rs, comm, B, T);
    }

    bool accumulate;
    auto& d_lnf_w = Grads->get_non_block_full(LLamaWeightID::LNF_W, main_stream, comm, accumulate);
    Parameters->gather_lnf(comm);
    // backward the final layernorm
    rmsnorm_backward(rs->DActs[L-1].DResFFN.Value, d_lnf_w, rs->RMSNormScratch, rs->DActs[L - 1].DResFFN.Value, rs->DLNF,
                     rs->get_res_ffn(L-1, main_stream), Parameters->get_lnf(main_stream), rs->LNF_Rstd,
                     rs->DActs[L-1].DResFFN.Quant.abs_max(), B, T, C, rs->DeviceProp, main_stream);
    rs->release_res_ffn(L-1, main_stream);

    Parameters->release_lnf(main_stream);
    Grads->notify_non_block(LLamaWeightID::LNF_W, main_stream, comm);

    // Prefetch the residual that the last layer consumes. L is unsigned, so for a single-layer
    // model `L-2` would wrap around; in that case layer 0 reads rs->Encoded and there is
    // nothing to fetch (same guard as the in-loop prefetch below).
    if (L > 1) {
        rs->fetch_res_ffn(L-2, comm.stream());
    }
    Parameters->gather_block(L - 1, comm, *rs);

    // now backward all the layers
    for (int l = L-1; l >= 0; l--) {
        NvtxRange layer_range("Layer", l);
        auto& dw = Grads->get_block_full(l, main_stream, comm, accumulate);

        // prefetch previous layer
        if(l > 1) {
            rs->fetch_res_ffn(l-2, comm.stream());
        }
        if(l > 0) {
            Parameters->gather_block(l - 1, comm, *rs);
        } else if (!last_step) {
            // nothing left to prefetch for this pass: start gathering the embeddings instead,
            // unless this is the last micro-step of the accumulation cycle
            Parameters->gather_embeddings(comm);
        }

        auto& weights = Parameters->get_block(l, main_stream);
        auto& d_acts = rs->DActs.at(l);
        Tensor residual = l == 0 ? rs->Encoded : rs->get_res_ffn(l - 1, main_stream);

        // the MlpUp activation and its gradient share the same buffer
        rs->Acts[l].MlpUp = rs->acquire_mlp_up(l);
        rs->DActs[l].DMlpUp.Value = rs->Acts[l].MlpUp;
        trace_or_execute_cuda_graph([&]() {
            _recompute_block(weights, rs->Acts[l], residual);
            _backward_block(accumulate, weights, dw, rs->Acts[l], rs->DActs[l]);
        }, main_stream, rs->BackwardBlockGraph, rs->Options.UseCudaGraphs);
        rs->release_mlp_up(rs->Acts[l].MlpUp);

        // backward of the first RMSNorm of this block; kept outside the graph because it
        // reads (and releases) the residual of the previous layer
        if(l > 0) {
            auto& prev_dacts = rs->DActs.at(l - 1);
            rmsnorm_backward(prev_dacts.DResFFN.Value, dw.get_tensor(LLamaWeightID::LN1_W), rs->RMSNormScratch, prev_dacts.DResAtt.Value, d_acts.DLN1,
                             rs->get_res_ffn(l-1, main_stream), weights.LN1_w, rs->Acts[l].LN1_Rstd, prev_dacts.DResFFN.Quant.abs_max(),
                             B, T, C, rs->DeviceProp, main_stream);
            rs->release_res_ffn(l - 1, main_stream);
        } else {
            rmsnorm_backward(rs->DEmb, dw.get_tensor(LLamaWeightID::LN1_W), rs->RMSNormScratch, d_acts.DResAtt.Value, d_acts.DLN1,
                             rs->Encoded, weights.LN1_w, rs->Acts[l].LN1_Rstd, nullptr, B, T, C, rs->DeviceProp, main_stream);
        }
        Parameters->release_block(l, main_stream);
        Grads->notify_block(l, main_stream, comm);
    }

    auto& d_emb = Grads->get_non_block_full(LLamaWeightID::EMBEDDING, main_stream, comm, accumulate);
    encoder_backward(d_emb, rs->EncoderBwdScratch, rs->EncoderBwdIndices, rs->EncoderBwdInfo,
                     rs->DEmb, rs->Inputs, inputs, B, T, C, OptimizerRNG(), main_stream, rs->SideStreamEvent, rs->SideStream);
    Grads->notify_non_block(LLamaWeightID::EMBEDDING, main_stream, comm);

    // make sure all gradients are communicated before we go to the update step.
    Grads->end_micro_step(main_stream, comm);
    CUDA_CHECK(cudaEventRecord(rs->BackwardDone, main_stream));

    // do not return before the host buffers can be accessed again: TransferDone was recorded
    // after the asynchronous copy of `targets` at the top of this function.
    CUDA_CHECK(cudaEventSynchronize(rs->TransferDone));
    if(Options.TriggerTimingEvents) {
        CUDA_CHECK(cudaEventRecord(rs->TimingBackwardEnd[micro_step], main_stream));
    }
}
// Forward and backward pass of the LM head, including the loss.
//
// For each of Options.LMHeadChunks chunks of rows it computes the logits into rs->Output, runs the
// fused classifier (loss + in-place logit gradient), accumulates the LM-head weight gradient, and
// writes the gradient w.r.t. the final-norm output into the matching slice of rs->DLNF.
//
// B, T:              batch size and sequence length of the micro-batch.
// z_loss:            z-loss coefficient; divided by B*T*grad_accum_steps before use.
// micro_step:        index of the current micro-step.
// grad_accum_steps:  number of micro-steps per accumulation cycle; folded into d_loss.
void LLamaModel::_backward_lmhead(long B, long T, float z_loss, int micro_step, int grad_accum_steps, NCCLCommunicator& comm) {
auto& rs = RunState;
const size_t C = Config.HiddenSize;
const size_t V = Config.VocabSize;
const size_t Vp = Config.VocabSize;
cudaStream_t main_stream = rs->MainStream;
if(Options.TriggerTimingEvents) {
CUDA_CHECK(cudaEventRecord(rs->TimingHeadStart[micro_step], main_stream));
}
// rows are processed in LMHeadChunks equally sized chunks so only one chunk of logits is live at a time
long nano_batches = Options.LMHeadChunks;
int nano_batch_size = div_exact(B * T, nano_batches);
const float d_loss =
1.0f / (float) (B * T * grad_accum_steps); // results in the uniform average loss over all elements
NvtxRange classifier_and_loss_range("lm-head");
Parameters->gather_head(comm);
rs->temp_acquire(rs->Output);
for (int nano_step = 0; nano_step < nano_batches; nano_step++) {
// views into the per-row buffers, advanced (in bytes) to the start of the current chunk
Tensor lnf_slice = rs->LNF;
lnf_slice.Data += nano_step * nano_batch_size * C * get_dtype_size(lnf_slice.DType);
Tensor tgt = rs->Targets;
tgt.Data += nano_step * nano_batch_size * get_dtype_size(tgt.DType);
Tensor losses = rs->Losses;
losses.Data += nano_step * nano_batch_size * get_dtype_size(losses.DType);
Tensor lse = rs->LSE;
lse.Data += nano_step * nano_batch_size * get_dtype_size(lse.DType);
Tensor dlnf_slice = rs->DLNF;
dlnf_slice.Data += nano_step * nano_batch_size * C * get_dtype_size(dlnf_slice.DType);
// logits for this chunk; this does not need the targets yet, so it can overlap with their transfer
matmul(rs->Output, Parameters->get_head(main_stream), lnf_slice, Tensor{},
nullptr, nullptr, rs->CublasLtHandle, rs->CuBlasWorkspace, V, nano_batch_size, C, EMMTranspose::TN,
false, main_stream, rs->MatmulBackend);
if(nano_step == 0) {
// make sure Targets have been copied
CUDA_CHECK(cudaStreamWaitEvent(main_stream, rs->TransferDone, 0));
}
// accumulate the losses inside rs->losses, and kick off the backward pass inside the fused classifier
fused_classifier(rs->Output, losses, lse, d_loss, tgt, z_loss / (B*T*grad_accum_steps), nano_batch_size, V, Vp, true, main_stream);
// if we reset model grads to zero, now is the time we need to wait
if (micro_step == 0 && nano_step == 0) {
CUDA_CHECK(cudaStreamWaitEvent(main_stream, rs->SideStreamEvent, 0));
}
// handle the LM-head. We run the d_lmhead matmul first, so that the gradient reduction can overlap with the DLNF matmul.
// `accumulate` is an out-parameter set by get_non_block_full
bool accumulate;
// get the correct matrix depending on whether we have tied embeddings
auto& d_lmhead = [&]() -> Tensor& {
if (Config.TiedWordEmbeddings) {
return Grads->get_non_block_full(LLamaWeightID::EMBEDDING, main_stream, comm, accumulate);
} else {
return Grads->get_non_block_full(LLamaWeightID::LM_HEAD, main_stream, comm, accumulate);
}
}();
// even if we overwrite for first micro-batch, we need to accumulate on non-first nano batch
accumulate |= nano_step != 0;
matmul(d_lmhead, lnf_slice, rs->Output, Tensor{}, nullptr, nullptr,
rs->CublasLtHandle, rs->CuBlasWorkspace, C, V, nano_batch_size, EMMTranspose::NT, accumulate, main_stream, rs->MatmulBackend);
// with tied embeddings the gradient is shared with the embedding matrix, which is notified
// by the caller after the embedding backward pass; only a separate LM head is notified here
if (nano_step == nano_batches - 1 && !Config.TiedWordEmbeddings) {
Grads->notify_non_block(LLamaWeightID::LM_HEAD, main_stream, comm);
}
// gradient w.r.t. the final-norm output for this chunk
matmul(dlnf_slice, Parameters->get_head(main_stream), rs->Output, Tensor{}, nullptr, nullptr,
rs->CublasLtHandle, rs->CuBlasWorkspace, C, nano_batch_size, V, EMMTranspose::NN, false, main_stream, rs->MatmulBackend);
}
rs->temp_free(rs->Output);
Parameters->release_head(main_stream);
// NOTE(review): this runs on the side stream while rs->LSE is written on the main stream above;
// any required ordering is presumably handled inside reduce_lse_stats or by the caller — confirm.
reduce_lse_stats(rs->LSEHost, rs->LSE.get<float>(), rs->LSE.nelem(), micro_step == 0, rs->SideStream);
CUDA_CHECK(cudaEventRecord(rs->LSEDone, rs->SideStream));
if(Options.TriggerTimingEvents) {
CUDA_CHECK(cudaEventRecord(rs->TimingHeadEnd[micro_step], main_stream));
}
}
// Recomputes, ahead of the backward pass of one block, those activations selected by the
// Recompute* flags in rs->Options. With no flag set, this function enqueues no work.
//
// weights:  weights of the block.
// acts:     activation buffers of the block; recomputed tensors are written back in place.
// residual: residual stream that entered the block in the forward pass.
//
// Mirrors the operation order of _forward_block; all work goes to rs->MainStream.
void LLamaModel::_recompute_block(sLLamaBlockWeights<Tensor>& weights, sLLamaLayerActivations& acts, Tensor& residual) {
NvtxRange classifier_and_loss_range("recompute");
auto& rs = RunState;
cudaStream_t main_stream = rs->MainStream;
// convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow
long B = rs->Inputs.Sizes[0];
long T = rs->Inputs.Sizes[1];
const size_t C = Config.HiddenSize;
long D = Config.IntermediateSize;
long Hq = Config.NumQueryHeads;
long Hkv = Config.NumKeyValHeads;
long Hs = Config.head_size();
auto& opt = rs->Options;
// Figure out which parts we need to recompute
// (coarser flags imply the finer ones: RecomputeBlock implies everything below; RecomputeAtt implies
// LN1 and QKV; RecomputeFFN implies LN2 and SwiGLU)
bool recompute_ln1 = opt.RecomputeRMSNorm || opt.RecomputeAtt || opt.RecomputeBlock;
bool recompute_ln2 = opt.RecomputeRMSNorm || opt.RecomputeFFN || opt.RecomputeBlock;
bool recompute_qkv = opt.RecomputeQKV || opt.RecomputeAtt || opt.RecomputeBlock;
bool recompute_swiglu = opt.RecomputeSwiGLu || opt.RecomputeFFN || opt.RecomputeBlock;
bool recompute_att = opt.RecomputeAtt || opt.RecomputeBlock;
// Attention block
if(recompute_ln1) {
// no abs-max is collected during recomputation (nullptr)
rmsnorm_forward(acts.LN1.Value, acts.LN1_Rstd, residual, weights.LN1_w, nullptr, Config.RmsNormEps, B, T, C, main_stream);
}
if (recompute_qkv) {
// two scenarios: 1) we do not recompute the RMSnorm; then, we _will_ overwrite the full-precision copy of acts.LN1,
// but _can_ reuse the quantized version
// 2) we recompute RMSNorm; then, acts.LN1 will be correct, but its quantized version will not, so
// we have to re-quantize
forward_qmm(acts.QKV, acts.LN1, weights.Attn_QKV_w, weights.Attn_QKV_b,
rs->CublasLtHandle, rs->CuBlasWorkspace,
B, T, C, Config.qkv_channels(),
rs->DeviceProp, !recompute_ln1, main_stream, rs->MatmulBackend);
rope_forward(acts.QKV, acts.QKV, rs->FreqCis, nullptr, B, T, Hq, Hkv, Hs, main_stream);
}
if (recompute_att) {
attention_forward_cudnn(acts.Att.Value, acts.LSE, acts.QKV, rs->CuBlasWorkspace, rs->CudnnHandle, B, T, Hq, Hkv, Hs, main_stream);
// AttO not needed in backward pass; but if we want to recompute the entire transformer block, we need its output
// to recompute the FFN part
if (opt.RecomputeBlock) {
forward_qmm(acts.AttO, acts.Att, weights.Attn_Out_w, Tensor{},
rs->CublasLtHandle, rs->CuBlasWorkspace,
B, T, C, C,
rs->DeviceProp, false, main_stream, rs->MatmulBackend);
}
}
// Feed-forward block
if(recompute_ln2) {
if (opt.RecomputeBlock) {
// full-block recompute: ResidualAtt itself has to be rebuilt from residual + AttO
fused_residual_rmsnorm_forward(acts.ResidualAtt, acts.LN2.Value, acts.LN2_Rstd, residual, acts.AttO, weights.LN2_w,
nullptr, Config.RmsNormEps, B * T, C, main_stream);
} else {
// ResidualAtt is still available from the forward pass; only the norm is recomputed
rmsnorm_forward(acts.LN2.Value, acts.LN2_Rstd, acts.ResidualAtt, weights.LN2_w,
nullptr, Config.RmsNormEps, B, T, C, main_stream);
}
}
// NOTE(review): this checks only RecomputeFFN, not RecomputeBlock, unlike the flags above; this is
// correct only if RecomputeBlock implies RecomputeFFN in the options handling — confirm.
if(opt.RecomputeFFN) {
forward_qmm(acts.MlpUp, acts.LN2, weights.MLP_Up_w, Tensor{},
rs->CublasLtHandle, rs->CuBlasWorkspace,
B, T, C, 2 * D,
rs->DeviceProp, false, main_stream, rs->MatmulBackend);
}
if(recompute_swiglu) {
// when a quantized SwiGLU buffer exists, write the quantized result directly;
// otherwise recompute the full-precision value
if (acts.SwiGLu.Quant) {
swiglu_forward_quant(acts.SwiGLu.Quant, acts.SwiGLu.Quant.scale(), acts.MlpUp, acts.SwiGLu.Quant.abs_max(), B, T, D, main_stream);
} else {
swiglu_forward(acts.SwiGLu.Value, acts.MlpUp, nullptr, B, T, D, main_stream);
}
}
}
// Backward pass of one transformer block, from the gradient of the block output
// (d_acts.DResFFN) down to the gradient of the first RMSNorm's output (d_acts.DLN1).
// The backward of the first RMSNorm itself is done by the caller.
//
// accumulate: whether weight gradients are accumulated into (true) or overwritten (false).
// weights:    weights of the block.
// d_weights:  gradient tensors of the block, addressed by LLamaWeightID.
// acts:       (possibly recomputed) forward activations of the block.
// d_acts:     gradient buffers of the block.
void LLamaModel::_backward_block(bool accumulate, sLLamaBlockWeights<Tensor>& weights, SimpleTensorContainer& d_weights,
sLLamaLayerActivations& acts, sLLamaLayerGradients& d_acts) {
using namespace LLamaWeightID;
auto& rs = RunState;
cudaStream_t main_stream = rs->MainStream;
// convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow
long B = rs->Inputs.Sizes[0];
long T = rs->Inputs.Sizes[1];
const size_t C = Config.HiddenSize;
long D = Config.IntermediateSize;
long Hq = Config.NumQueryHeads;
long Hkv = Config.NumKeyValHeads;
long Hs = Config.head_size();
// backward the 2nd matmul of MLP
// note that _recompute_block guarantees that SwiGLu is already quantized (if necessary),
// which is why reuse_inp is `true` here
rs->temp_acquire(d_acts.DSwiGLU);
backward_qmm(d_acts.DSwiGLU, d_weights.get_tensor(DOWN_W), Tensor{}, d_acts.DResFFN, acts.SwiGLu, weights.MLP_Down_w, Tensor{},
accumulate, *rs, B, T, D, C, true, main_stream);
// backward through the SwiGLU activation: DSwiGLU -> DMlpUp
swiglu_backward(d_acts.DMlpUp.Value, d_acts.DSwiGLU, acts.MlpUp, d_acts.DMlpUp.Quant.abs_max(), B, T, D, main_stream);
rs->temp_free(d_acts.DSwiGLU);
// a separate quantized gradient buffer is only needed when the gradient dtype differs from DMlpUp's dtype
if(Options.grad_dtype() != d_acts.DMlpUp.Value.DType) {
rs->temp_acquire(d_acts.DMlpUp.Quant);
}
// backward the 1st matmul of MLP; the quantized LN2 can be reused unless the RMSNorm was recomputed
backward_qmm(d_acts.DLN2, d_weights.get_tensor(UP_W), Tensor{}, d_acts.DMlpUp, acts.LN2, weights.MLP_Up_w, Tensor{},
accumulate, *rs, B, T, C, 2 * D, !rs->Options.RecomputeRMSNorm, main_stream);
if(Options.grad_dtype() != d_acts.DMlpUp.Value.DType) {
rs->temp_free(d_acts.DMlpUp.Quant);
}
// rmsnorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above
rmsnorm_backward(d_acts.DResAtt.Value, d_weights.get_tensor(LN2_W), rs->RMSNormScratch, d_acts.DResFFN.Value, d_acts.DLN2,
acts.ResidualAtt, weights.LN2_w, acts.LN2_Rstd, d_acts.DResAtt.Quant.abs_max(), B, T, C, rs->DeviceProp, main_stream);
// whether LN1 was recomputed decides if its quantized copy can be reused in the QKV backward below
bool recompute_ln1 = rs->Options.RecomputeRMSNorm || rs->Options.RecomputeAtt;
// backward the attention output projection
backward_qmm(d_acts.DAttY, d_weights.get_tensor(ATTO_W), Tensor{}, d_acts.DResAtt, acts.Att, weights.Attn_Out_w, Tensor{},
accumulate, *rs, B, T, C, C, false, main_stream);
rs->temp_acquire(d_acts.DQKV.Value);
rs->temp_acquire(rs->CuDNNWorkspace);
// attention backward is processed in AttBwdChunks chunks; each chunk covers B / AttBwdChunks batch entries
for (int i=0; i < Options.AttBwdChunks; ++i) {
long chunk_batch_size = div_exact(B, (long)Options.AttBwdChunks);
Tensor d_qkv = shard_view(d_acts.DQKV.Value, i, Options.AttBwdChunks);
Tensor lse = shard_view(acts.LSE, i, Options.AttBwdChunks);
Tensor att = shard_view(acts.Att.Value, i, Options.AttBwdChunks);
Tensor d_atty = shard_view(d_acts.DAttY, i, Options.AttBwdChunks);
Tensor qkv = shard_view(acts.QKV, i, Options.AttBwdChunks);
attention_backward_cudnn(d_qkv, lse, att, d_atty, qkv, rs->CuDNNWorkspace, rs->CudnnHandle,
chunk_batch_size, T, Hq, Hkv, Hs, main_stream);
}
rs->temp_free(rs->CuDNNWorkspace);
// undo RoPE on the gradient, in place
rope_backward(d_acts.DQKV.Value, d_acts.DQKV.Value, rs->FreqCis, d_acts.DQKV.Quant.abs_max(), B, T, Hq, Hkv, Hs, main_stream);
// backward the QKV projection (the only matmul in the block with a bias gradient)
backward_qmm(d_acts.DLN1, d_weights.get_tensor(QKV_W), d_weights.get_tensor(QKV_B), d_acts.DQKV, acts.LN1, weights.Attn_QKV_w, rs->MatmulBiasScratch,
accumulate, *rs, B, T, C, Config.qkv_channels(), !recompute_ln1, main_stream);
rs->temp_free(d_acts.DQKV.Value);
}
// Collapses the accumulated losses into `acts.GroupedLosses`, averages them across ranks through
// `comm`, and enqueues an asynchronous device-to-host copy into `acts.LossHost`.
// Everything is enqueued on `acts.MainStream`; nothing here synchronizes with the host.
void LLamaModel::_reduce_loss(LLamaRunState& acts, NCCLCommunicator& comm, int B, int T) {
    NVTX_RANGE_FN();
    cudaStream_t stream = acts.MainStream;

    // step 1: local reduction of the losses on this GPU (across all microsteps)
    grouped_loss_sum(acts.GroupedLosses, acts.Losses, B, T, stream);

    // step 2: reduction across GPUs, so that every rank ends up with the same final values
    auto* grouped = acts.GroupedLosses.get<float>();
    const auto count = acts.GroupedLosses.nelem();
    comm.reduce_mean(grouped, count, stream);

    // step 3: hand the result to the host; the copy is async, so LossHost is only meaningful
    // after the stream has progressed past this point
    CUDA_CHECK(cudaMemcpyAsync(acts.LossHost, grouped, sizeof(float) * count, cudaMemcpyDeviceToHost, stream));
}
// Enqueues the global gradient-norm computation on the main stream.
//
// The main stream first waits on the `BackwardDone` event; once the norm work has been enqueued,
// `NormDone` is recorded on the same stream. When `Options.UseCudaGraphs` is set, the work is
// captured into a CUDA graph on the first call and that graph is re-launched on later calls.
void LLamaModel::calculate_gradient_norm(NCCLCommunicator& comm, float grad_clip) {
    NVTX_RANGE_FN();
    auto& rs = RunState;
    cudaStream_t main_stream = rs->MainStream;

    CUDA_CHECK(cudaStreamWaitEvent(main_stream, rs->BackwardDone));

    if (!rs->Options.UseCudaGraphs) {
        // eager path: enqueue the kernels directly
        _calculate_gradient_norm(comm, grad_clip, main_stream);
    } else {
        if (!rs->GlobalNormGraph) {
            // first call: record the norm computation into a graph and instantiate it once
            // NOTE(review): grad_clip is handed over by value while capturing; presumably a later
            // call with a different grad_clip would still replay the value captured here -- confirm.
            cudaGraph_t captured;
            CUDA_CHECK(cudaStreamBeginCapture(main_stream, cudaStreamCaptureModeThreadLocal));
            _calculate_gradient_norm(comm, grad_clip, main_stream);
            CUDA_CHECK(cudaStreamEndCapture(main_stream, &captured));
            CUDA_CHECK(cudaGraphInstantiate(&rs->GlobalNormGraph, captured, nullptr, nullptr, 0));
            CUDA_CHECK(cudaGraphDestroy(captured));
        }
        CUDA_CHECK(cudaGraphLaunch(rs->GlobalNormGraph, main_stream));
    }

    CUDA_CHECK(cudaEventRecord(rs->NormDone, main_stream));
}
// Gives access to the model's run state through the generic IRunState interface.
// `RunState` is dereferenced without a null check.
IRunState& LLamaModel::get_run_state() const {
    IRunState& state = *RunState;
    return state;
}
// Number of tensors per transformer block; fill_block_shapes() describes exactly this many
// (QKV_W, ATTO_W, UP_W, DOWN_W, LN1_W, LN2_W, QKV_B).
std::size_t LLamaModel::num_block_tensors() const {
    constexpr std::size_t block_tensor_count = 7;
    return block_tensor_count;
}
// Writes rank, dtype and sizes of every per-block tensor into `target`.
//
// Matrices use `matrix_dtype`; the norm weights and the QKV bias use `other_dtype`.
// A tensor described with zero columns is recorded as rank 1. Without `config.UseQKVBias`,
// the QKV bias is described as a rank-1 tensor with zero elements.
void LLamaModel::fill_block_shapes(GenericTensorContainer& target, const TransformerConfig& config,
                                   ETensorDType matrix_dtype, ETensorDType other_dtype) const
{
    const long hidden = config.HiddenSize;
    const long intermediate = config.IntermediateSize;
    const long head_size = config.head_size();
    // rows of the fused QKV projection: all query heads plus key and value heads
    const long qkv_rows = (config.NumQueryHeads + 2 * config.NumKeyValHeads) * head_size;
    const long qkv_bias_rows = config.UseQKVBias ? qkv_rows : 0;

    // fills in the shape description of the tensor stored under `id`
    auto describe = [&target](auto id, long rows, long cols, ETensorDType dtype) {
        Tensor& entry = target.get_tensor(id);
        entry.DType = dtype;
        entry.Sizes[0] = rows;
        entry.Sizes[1] = cols;
        entry.Rank = (cols == 0) ? 1 : 2;
    };

    // weight matrices
    describe(LLamaWeightID::QKV_W, qkv_rows, hidden, matrix_dtype);
    describe(LLamaWeightID::ATTO_W, hidden, hidden, matrix_dtype);
    describe(LLamaWeightID::UP_W, 2 * intermediate, hidden, matrix_dtype);
    describe(LLamaWeightID::DOWN_W, hidden, intermediate, matrix_dtype);

    // vectors
    describe(LLamaWeightID::LN1_W, hidden, 0, other_dtype);
    describe(LLamaWeightID::LN2_W, hidden, 0, other_dtype);
    describe(LLamaWeightID::QKV_B, qkv_bias_rows, 0, other_dtype);
}
// Number of tensors that live outside the transformer blocks; fill_non_block_shapes() describes
// exactly this many (EMBEDDING, LNF_W, LM_HEAD).
std::size_t LLamaModel::num_non_block_tensors() const {
    constexpr std::size_t non_block_tensor_count = 3;
    return non_block_tensor_count;
}
// Writes rank, dtype and sizes of the tensors outside the transformer blocks into `target`.
//
// The embedding and LM head use `matrix_dtype`, the final norm weight uses `other_dtype`.
// With `config.TiedWordEmbeddings`, the LM head is described with zero rows and zero columns
// (which the rank rule below records as rank 1).
void LLamaModel::fill_non_block_shapes(GenericTensorContainer& target, const TransformerConfig& config,
                                       ETensorDType matrix_dtype, ETensorDType other_dtype) const
{
    const long vocab = config.VocabSize;
    const long hidden = config.HiddenSize;

    // fills in the shape description of the tensor stored under `id`
    auto describe = [&target](auto id, long rows, long cols, ETensorDType dtype) {
        Tensor& entry = target.get_tensor(id);
        entry.DType = dtype;
        entry.Sizes[0] = rows;
        entry.Sizes[1] = cols;
        entry.Rank = (cols == 0) ? 1 : 2;
    };

    describe(LLamaWeightID::EMBEDDING, vocab, hidden, matrix_dtype);
    describe(LLamaWeightID::LNF_W, hidden, 0, other_dtype);

    // a tied LM head gets an empty shape
    const bool tied = config.TiedWordEmbeddings;
    const long head_rows = tied ? 0 : vocab;
    const long head_cols = tied ? 0 : hidden;
    describe(LLamaWeightID::LM_HEAD, head_rows, head_cols, matrix_dtype);
}
// Enqueues, on `stream`, the kernels that compute the global gradient norm.
//
// Squared norms of all gradient shards are accumulated into `NormBuffer`, summed into a single
// value, reduced through `comm`, and finally turned into the norm by `global_norm_sqrt`, which
// also receives `grad_clip` and the host-side `NormHost` buffer.
// The LM head gradient is skipped when the word embeddings are tied.
void LLamaModel::_calculate_gradient_norm(NCCLCommunicator& comm, float grad_clip, cudaStream_t stream) {
    auto& rs = RunState;

    // start from a clean accumulator
    fill_zero(rs->NormBuffer, stream);

    const auto accumulate_sq_norm = [&](const TensorShard& shard) {
        global_norm_squared(rs->NormBuffer, shard, shard.nelem(), rs->DeviceProp, stream);
    };

    // non-block gradients
    accumulate_sq_norm(Grads->get_non_block_shard(LLamaWeightID::EMBEDDING, stream));
    if (!Config.TiedWordEmbeddings) {
        accumulate_sq_norm(Grads->get_non_block_shard(LLamaWeightID::LM_HEAD, stream));
    }
    accumulate_sq_norm(Grads->get_non_block_shard(LLamaWeightID::LNF_W, stream));

    // per-layer gradients
    for (int layer = 0; layer < Config.NumLayers; ++layer) {
        auto& block_grads = Grads->get_block_shard(layer, stream);
        visit([&](Tensor& t) { accumulate_sq_norm(t); }, block_grads);
    }

    auto* norm_data = rs->NormBuffer.get<float>();
    // collapse the buffer (in place) into a single norm-squared element
    deterministic_sum(norm_data, norm_data, rs->NormBuffer.nelem(), stream);
    // potential cross-gpu reduction
    comm.reduce_norm(norm_data, stream);
    // tiny kernel (1 thread): computes the norm and the scale factor, and places the result on the
    // host for later display
    global_norm_sqrt(norm_data, rs->NormHost, grad_clip, rs->DeviceProp, stream);
}
// Runs one AdamW optimizer step over all parameters.
//
// Sequence: compute the global gradient norm (which yields the gradient scale), update the
// embeddings and the final norm weight, update every transformer block in layer order, and
// finally update the LM head unless the word embeddings are tied. Weight decay is applied to the
// matrices only; norm weights and the QKV bias are updated with a decay of 0.
//
// Throws std::logic_error if no optimizer state has been allocated.
// Events recorded on the main stream: OptEmbeddingsDone, LayerUpdateDone[i] per layer,
// OptimizerDone, and (with Options.TriggerTimingEvents) TimingOptimizerStart/End.
void LLamaModel::update(NCCLCommunicator& comm, float learning_rate, float beta_1, float beta_2, int t, float epsilon, float weight_decay, float grad_clip) {
    NVTX_RANGE_FN();
    auto& rs = RunState;
    cudaStream_t main_stream = rs->MainStream;

    if (!OptimizerState) {
        throw std::logic_error("LLamaModel::update() but no optimizer available");
    }

    if (Options.TriggerTimingEvents) {
        CUDA_CHECK(cudaEventRecord(rs->TimingOptimizerStart, main_stream));
    }

    Parameters->begin_optimizer(rs->Stack, main_stream);
    OptimizerState->begin_optimizer(rs->Stack, main_stream);

    // grad_scale gets deposited into NormBuffer[1] and can be used on main_stream after this.
    calculate_gradient_norm(comm, grad_clip);
    float* grad_scale = rs->NormBuffer.get<float>() + 1;

    // single-tensor AdamW step; every invocation draws one fresh value from the optimizer RNG,
    // so the order of invocations below must stay fixed
    auto apply_adamw = [&](Tensor& param, Tensor& grad, Tensor& first_moment, Tensor& second_moment,
                           Tensor& moment_scales, float decay) {
        adamw_update(param, grad, first_moment, second_moment, grad.nelem(),
                     learning_rate, beta_1, beta_2, t, epsilon, decay, grad_scale, moment_scales,
                     param.abs_max(), OptimizerRNG(), main_stream);
    };

    auto& nb_scales = OptimizerState->non_block_m_scales();
    using namespace LLamaWeightID;

    // --- embeddings ---
    apply_adamw(Parameters->get_master_embeddings(),
                Grads->get_non_block_shard(EMBEDDING, main_stream),
                OptimizerState->non_block_m().get_tensor(EMBEDDING),
                OptimizerState->non_block_v().get_tensor(EMBEDDING),
                nb_scales.get_tensor(EMBEDDING),
                weight_decay);
    comm.reduce_max(Parameters->get_master_embeddings().abs_max());

    // --- final norm weight (no weight decay) ---
    apply_adamw(Parameters->get_master_lnf_w(),
                Grads->get_non_block_shard(LNF_W, main_stream),
                OptimizerState->non_block_m().get_tensor(LNF_W),
                OptimizerState->non_block_v().get_tensor(LNF_W),
                nb_scales.get_tensor(LNF_W),
                0.f);
    comm.reduce_max(Parameters->get_master_lnf_w().abs_max());
    CUDA_CHECK(cudaEventRecord(rs->OptEmbeddingsDone, main_stream));

    // --- transformer blocks ---
    for (int layer = 0; layer < Config.NumLayers; ++layer) {
        NvtxRange layer_range("Layer", layer);
        Parameters->fetch_master_block(layer, comm.stream());
        OptimizerState->fetch_block(layer, comm.stream());

        auto& block_params = Parameters->get_master_block(layer, main_stream);
        auto& block_grads = Grads->get_block_shard(layer, main_stream);
        auto& block_m = OptimizerState->get_block_m(layer, main_stream);
        auto& block_v = OptimizerState->get_block_v(layer, main_stream);
        auto& block_m_scales = OptimizerState->get_block_scales_m(layer);

        // updates the tensor stored under `id` in the current block
        auto step = [&](auto id, float decay) {
            apply_adamw(block_params.get_tensor(id), block_grads.get_tensor(id), block_m.get_tensor(id),
                        block_v.get_tensor(id), block_m_scales.get_tensor(id), decay);
        };

        step(LN1_W, 0.f);
        step(LN2_W, 0.f);
        step(QKV_W, weight_decay);
        if (Config.UseQKVBias) {
            step(QKV_B, 0.f);
        }
        step(ATTO_W, weight_decay);
        step(UP_W, weight_decay);
        step(DOWN_W, weight_decay);

        auto scales = Parameters->get_scales_for_block(layer);
        // This reduction deliberately runs on the main stream, even though that keeps kernels from
        // overlapping. The communication is tiny, and doing it here guarantees that the abs-maxes
        // are ready when the main stream quantizes (which happens in `release_master_block`), so
        // we would have to wait at that point anyway.
        // TODO there's probably a way to schedule this so that we can avoid this idle time. If it
        // turns out to actually matter (e.g., for small models), we can investigate more.
        comm.reduce_max(scales.first, scales.second - scales.first, main_stream);

        Parameters->release_master_block(layer, main_stream, rs->SideStream, *rs);
        OptimizerState->store_block(layer, main_stream, rs->SideStream);
        CUDA_CHECK(cudaEventRecord(rs->LayerUpdateDone[layer], main_stream));
    }

    // --- LM head (only a separate parameter when embeddings are not tied) ---
    if (!Config.TiedWordEmbeddings) {
        apply_adamw(Parameters->get_master_lmhead(),
                    Grads->get_non_block_shard(LM_HEAD, main_stream),
                    OptimizerState->non_block_m().get_tensor(LM_HEAD),
                    OptimizerState->non_block_v().get_tensor(LM_HEAD),
                    nb_scales.get_tensor(LM_HEAD),
                    weight_decay);
        comm.reduce_max(Parameters->get_master_lmhead().abs_max());
    }

    comm.wait_on_comms(main_stream);

    // end in the reverse order of the begin_optimizer calls above
    OptimizerState->end_optimizer(rs->Stack);
    Parameters->end_optimizer(rs->Stack);

    CUDA_CHECK(cudaEventRecord(rs->OptimizerDone, main_stream));
    if (Options.TriggerTimingEvents) {
        CUDA_CHECK(cudaEventRecord(rs->TimingOptimizerEnd, main_stream));
    }
}
// Allocates everything needed to run training steps: activations / run state, optimizer state,
// the device memory stack for temporaries, and the gradient manager.
//
// The required stack size is not known up front. A dummy stack (null base pointer, 1 TiB nominal
// capacity) is therefore used first to simulate how temporaries will be requested later; its
// maximum utilization then determines the size of the real stack.
//
// Parameters:
//   options - run options; forwarded to the allocation helpers, and `UseCustomMatmul` selects the
//             matmul backend stored in the run state.
//   comm    - communicator; provides rank / world size for the gradient manager and the final barrier.
//   B, T    - forwarded to ::allocate_run_state (presumably batch size and sequence length -- confirm).
//
// On return, `OptimizerState`, `Grads`, `OptimizerRNG` and `RunState` are set. The call ends with a
// barrier, so it only returns after all ranks have reached this point.
void LLamaModel::allocate_run_state(const LLamaOptions& options, NCCLCommunicator& comm, int B, int T) {
    NVTX_RANGE_FN();

    int dev = 0;
    CUDA_CHECK(cudaGetDevice(&dev));

    // create a dummy stack and simulate the way we're going to use temporaries later, to determine
    // how much we need to allocate. The capacity is written with a leading 64-bit literal so that
    // the whole product is evaluated in 64-bit arithmetic.
    DeviceMemoryStack stack(nullptr, 1024ll * 1024 * 1024 * 1024, dev);

    LLamaRunState acts;
    {
        auto ctx = Allocator->with_context("Activations");
        acts = ::allocate_run_state(Config, options, B, T, stack, Allocator);
    }

    OptimizerState = std::make_unique<LLamaOptimizerStateManager>(Config, *this, options, comm);
    OptimizerState->allocate_state(*this, acts.MainStream, options.offload_alloc(), *Allocator);

    // dry run of the optimizer's begin/end sequence against the dummy stack, so that its
    // temporaries are included in the measured utilization
    Parameters->begin_optimizer(stack, comm.stream());
    OptimizerState->begin_optimizer(stack, comm.stream());
    OptimizerState->end_optimizer(stack);
    Parameters->end_optimizer(stack);

    {
        // now allocate the real stack, sized by what the simulation needed
        auto ctx = Allocator->with_context("Stack");
        long required_size = stack.max_utilization();
        acts.Stack = DeviceMemoryStack{Allocator->allocate(ETensorDType::BYTE, "stack", {required_size}).Data, (std::size_t)required_size, dev};
        acts.Stack.set_high_mark(stack.get_high_mark());
    }

    {
        auto ctx = Allocator->with_context("Gradients");
        Grads = create_grads_manager(42, 0, *this, Config, options, comm.rank(), comm.world_size(), Allocator);
    }

    OptimizerRNG = std::minstd_rand{42};
    RunState = std::make_unique<LLamaRunState>(std::move(acts));
    RunState->MatmulBackend = options.UseCustomMatmul ? EMatmulBackend::Custom : EMatmulBackend::CuBLAS;

    comm.barrier();  // make sure *all* GPUs have allocated the model before returning
}
// Exposes the model parameters through the generic tensor-container interface.
// `Parameters` is dereferenced without a null check.
ITensorContainer& LLamaModel::weights() {
    ITensorContainer& container = *Parameters;
    return container;
}
// Exposes the optimizer state through the AdamWStateManager interface.
// `OptimizerState` is dereferenced without a null check.
AdamWStateManager& LLamaModel::optimizer() {
    AdamWStateManager& state = *OptimizerState;
    return state;
}
// Serializes the optimizer RNG into a byte vector.
// The bytes are the engine's textual stream representation (what `operator<<` writes), one byte
// per character; set_rng_state() reads the same format back.
std::vector<std::byte> LLamaModel::rng_state() const {
    std::stringstream serialized;
    serialized << OptimizerRNG;

    // look at the stream's buffer directly instead of copying it into a string first
    const std::string_view text = serialized.rdbuf()->view();

    std::vector<std::byte> state;
    state.reserve(text.size());
    for (const char c : text) {
        state.push_back(static_cast<std::byte>(c));
    }
    return state;
}
// Restores the optimizer RNG from bytes in the engine's textual stream representation, i.e. the
// format that `operator<<` writes for the engine.
// NOTE(review): the stream state is not checked after extraction, so malformed or empty input is
// not reported; presumably the RNG then keeps its previous state -- confirm that this is intended.
void LLamaModel::set_rng_state(const std::vector<std::byte>& state) {
    std::stringstream serialized;
    const char* chars = reinterpret_cast<const char*>(state.data());
    serialized.write(chars, state.size());
    serialized >> OptimizerRNG;
}
// Loads the model parameters from `file_name` by delegating to the parameter manager.
// `allow_cast` and `comm` are passed through unchanged.
void LLamaModel::import_weights(const std::string& file_name, bool allow_cast, NCCLCommunicator& comm) {
    auto& params = *Parameters;
    params.import_from_file(file_name, allow_cast, comm);
}
// Randomly initializes the model parameters by delegating to the parameter manager.
// The seed is fixed to 42; the model's `Options` and `comm` are passed through.
void LLamaModel::init_weights(NCCLCommunicator& comm) {
    auto& params = *Parameters;
    params.random_init(42, Options, comm);
}
// Writes the model parameters to `file_name` by delegating to the parameter manager.
void LLamaModel::export_weights(const std::string& file_name, NCCLCommunicator& comm) {
    auto& params = *Parameters;
    params.export_to_file(file_name, comm);
}
// Hook invoked after a checkpoint has been restored: asks the parameter manager to synchronize
// its abs-max values through `comm`.
void LLamaModel::on_restore_checkpoint(NCCLCommunicator& comm) {
    auto& params = *Parameters;
    params.synchronize_absmax(comm);
}
std::string_view LLamaModel::model_type() const {
return Config.model_name();
}