<!DOCTYPE html>
<html lang="en">
<head>
<!-- Google Tag Manager -->
<script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start':
new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0],
j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src=
'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f);
})(window,document,'script','dataLayer','GTM-T8XT4PS');</script>
<!-- End Google Tag Manager -->
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta http-equiv="X-UA-Compatible" content="ie=edge">
<link rel="shortcut icon" type="image/x-icon" href="/favicon.ico?">
<title>
CUDA-Free Inference for LLMs | PyTorch
</title>
<meta name="robots" content="index, follow" />
<meta name="description" content="In this blog, we discuss the methods we used to achieve FP16 inference with popular LLM models such as Meta’s Llama3-8B and IBM’s Granite-8B Code, where 100% of the computation is performed using OpenAI’s Triton Language.
For single token generation, our Triton kernel-based models reach 0.76-0.78x the performance of CUDA-kernel-dominant workflows for both Llama and Granite on NVIDIA H100 GPUs, and 0.62-0.82x on NVIDIA A100 GPUs.
Why explore using 100% Triton? Triton provides a path for running LLMs on different types of GPUs: NVIDIA, AMD, and in the future Intel and other GPU-based accelerators. It also provides a higher level of abstraction in Python for programming GPUs, and it has allowed us to write performant kernels faster than with vendor-specific APIs. In the rest of this blog, we share how we achieve CUDA-free compute, micro-benchmark individual kernels for comparison, and discuss how future Triton kernels could close the remaining gaps.
Figure 1. Inference throughput benchmarks with Triton and CUDA variants of Llama3-8B and Granite-8B, on NVIDIA H100 and A100
Settings: batch size = 2, input sequence length = 512, output sequence length = 256
2.0 Composition of a Transformer Block
We start with a breakdown of the computations that happen in Transformer-based models. The figure below shows the “kernels” of a typical Transformer block.
Figure 2. Transformer Block by core kernels
The core operations for a Llama3 architecture are summarized in this list:
RMSNorm
Matrix multiplication: Fused QKV
RoPE
Attention
Matrix multiplication: Output Projection
RMSNorm
Matrix multiplication: Fused Gate + Up Projection
Activation function: SiLU
Element Wise Multiplication
Matrix multiplication: Down Projection
Each of these operations is computed on the GPU by executing one or more kernels. While the specifics of these kernels can vary across transformer models, the core operations remain the same. For example, unlike Llama3, IBM’s Granite 8B Code model uses bias in the MLP layer, and such changes do require modifications to the kernels. A typical model is a stack of these transformer blocks wired together with embedding layers.
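As a minimal illustration of three of the lighter operations listed above, the pure-Python sketch below implements RMSNorm, SiLU, and the element-wise multiplication that combines the gate and up projections. It works on plain lists for clarity; the function names and the eps value are illustrative assumptions, not taken from the kernels used in this work.

```python
import math

def rms_norm(x, weight, eps=1e-5):
    # Divide x by its root-mean-square, then apply a learned per-channel scale.
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

def silu(v):
    # SiLU(v) = v * sigmoid(v)
    return v / (1.0 + math.exp(-v))

def gated_mlp_activation(gate, up):
    # Llama-style MLP: SiLU of the gate projection, multiplied
    # element-wise by the up projection, before the down projection.
    return [silu(g) * u for g, u in zip(gate, up)]

normed = rms_norm([1.0, -2.0, 3.0, -4.0], weight=[1.0] * 4)
print([round(v, 4) for v in normed])
print(gated_mlp_activation([0.0, 1.0], [2.0, 3.0]))
```

These are the kinds of small, memory-bound operations that torch.compile can fuse into generated Triton kernels.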
3.0 Model Inference
Model architecture code is typically contained in a Python model.py file that is launched by PyTorch. In the default PyTorch eager execution mode, these kernels are all executed with CUDA. To achieve 100% Triton for end-to-end Llama3-8B and Granite-8B inference, we need to write and integrate handwritten Triton kernels and leverage torch.compile to generate Triton ops. First, we replace smaller ops with compiler-generated Triton kernels; second, we replace more expensive and complex computations (e.g. matrix multiplication and flash attention) with handwritten Triton kernels.
torch.compile generates Triton kernels automatically for RMSNorm, RoPE, SiLU and element-wise multiplication. Using tools such as Nsight Systems, we can observe these generated kernels: they appear as tiny dark green kernels in between the matrix multiplications and attention.
Figure 3. Trace of Llama3-8B with torch.compile, showing CUDA kernels being used for matrix multiplications and flash attention
In the above trace, the two major ops, matrix multiplication and attention, make up 80% of the E2E latency in a Llama3-8B style model, and both still run as CUDA kernels. To close the remaining gap, we replace both the matmul and attention kernels with handwritten Triton kernels.
4.0 Triton SplitK GEMM Kernel
For the matrix multiplications in the linear layers, we wrote a custom FP16 Triton GEMM (General Matrix-Matrix Multiply) kernel that leverages a SplitK work decomposition. We have previously discussed this parallelization in other blogs as a way to accelerate the decoding portion of LLM inference.
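SplitK parallelizes over the shared K dimension: the K range is cut into SPLIT_K slices, each slice produces a partial M x N product, and the partials are summed into the output (on a GPU, each slice typically runs as its own program instance, and atomic adds are one common way to combine them). This helps decoding, where M (the number of tokens) is tiny, so tiling only over M and N leaves much of the GPU idle. The serial pure-Python sketch below shows only the arithmetic decomposition; the function name and split_k parameter are illustrative, not the API of the kernel described here.

```python
def matmul_split_k(a, b, split_k):
    # a is M x K, b is K x N (lists of lists); returns the M x N product.
    m, k, n = len(a), len(b), len(b[0])
    chunk = (k + split_k - 1) // split_k  # ceil(K / split_k)
    c = [[0.0] * n for _ in range(m)]
    for s in range(split_k):
        k_start, k_end = s * chunk, min((s + 1) * chunk, k)
        # Each slice computes an independent partial product over its K range...
        partial = [[sum(a[i][kk] * b[kk][j] for kk in range(k_start, k_end))
                    for j in range(n)] for i in range(m)]
        # ...and the partials are reduced into the output.
        for i in range(m):
            for j in range(n):
                c[i][j] += partial[i][j]
    return c

a = [[1.0, 2.0, 3.0, 4.0]]  # M = 1, like a single decode token
b = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
print(matmul_split_k(a, b, split_k=2))  # [[12.0, 1.0]]
```

The result does not depend on split_k; only the amount of available parallelism changes.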
5.0 GEMM Kernel Tuning
To achieve optimal performance, we tuned our SplitK GEMM kernel using exhaustive search. Granite-8B and Llama3-8B have linear layers with the following shapes:
Linear Layer | Shape (in_features, out_features)
Fused QKV Projection | (4096, 6144)
Output Projection | (4096, 4096)
Fused Gate + Up Projection | (4096, 28672)
Down Projection | (14336, 4096)
Figure 4. Granite-8B and Llama3-8B Linear Layer Weight Matrix Shapes
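The shapes in Figure 4 follow from the model hyperparameters. The sketch below assumes the published Llama3-8B configuration (hidden size 4096, 32 query heads, 8 key/value heads for grouped-query attention, head dimension 128, MLP intermediate size 14336); that Granite-8B uses the same dimensions is taken from the table, and the helper name is illustrative.

```python
def linear_shapes(hidden=4096, n_heads=32, n_kv_heads=8, intermediate=14336):
    head_dim = hidden // n_heads  # 128
    # Q uses every head; K and V each use only the smaller set of KV heads.
    qkv_out = (n_heads + 2 * n_kv_heads) * head_dim
    return {
        'Fused QKV Projection': (hidden, qkv_out),
        'Output Projection': (n_heads * head_dim, hidden),
        # Gate and up weights are concatenated into one matrix.
        'Fused Gate + Up Projection': (hidden, 2 * intermediate),
        'Down Projection': (intermediate, hidden),
    }

for name, shape in linear_shapes().items():
    print(name, shape)
```

For example, the fused QKV width is (32 + 2 * 8) * 128 = 6144.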
Each of these linear layers has a different weight matrix shape, so for optimal performance the Triton kernel must be tuned for each shape profile. After tuning for each linear layer, we achieved a 1.20x E2E speedup on Llama3-8B and Granite-8B over the untuned Triton kernel.
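Exhaustive tuning means timing every candidate configuration for each weight shape and keeping the fastest. The sketch below shows that loop with a hypothetical benchmark callable standing in for a real kernel launch; the parameter names (BLOCK_N, BLOCK_K, SPLIT_K), candidate values and toy cost model are illustrative assumptions, not the search space used in this work. In practice, Triton’s triton.autotune decorator and triton.testing.do_bench fill this role.

```python
import itertools

def exhaustive_tune(shape, benchmark, space):
    # Enumerate every combination of the tunable parameters.
    keys = list(space)
    candidates = [dict(zip(keys, vals))
                  for vals in itertools.product(*(space[k] for k in keys))]
    # Time each candidate and keep the one with the lowest latency.
    timings = [(benchmark(shape, cfg), cfg) for cfg in candidates]
    best_time, best_cfg = min(timings, key=lambda tc: tc[0])
    return best_cfg, best_time

def toy_benchmark(shape, cfg):
    # Hypothetical stand-in for launching and timing the kernel on a GPU.
    # It pretends that about 132 parallel tiles is ideal; not a real model.
    _, n = shape
    tiles = (n // cfg['BLOCK_N']) * cfg['SPLIT_K']
    return abs(tiles - 132) + cfg['BLOCK_K'] / 1000

space = {'BLOCK_N': [64, 128], 'BLOCK_K': [64, 128], 'SPLIT_K': [1, 2, 3]}
for shape in [(4096, 6144), (4096, 4096)]:
    print(shape, exhaustive_tune(shape, toy_benchmark, space))
```

Because each shape gets its own search, different linear layers can end up with different configurations, which is the point of tuning per shape profile.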
6.0 Flash Attention Kernel
We evaluated a suite of existing Triton flash attention kernels with different configurations, namely:
AMD Flash
OpenAI Flash
Dao AI Lab Flash
XFormers Flash
PyTorch FlexAttention
We evaluated the text generation quality of each kernel, first in eager mode and then, if we were able to torch.compile the kernel with standard methods, in compiled mode. We noted the following:
Kernel | Text Generation Quality | torch.compile | Support for Arbitrary Sequence Length
AMD Flash | Coherent | Yes | Yes
OpenAI Flash | Incoherent | Did not evaluate. WIP to debug precision in eager mode first | No
Dao AI Lab Flash | Incoherent | Did not evaluate. WIP to debug precision in eager mode first | Yes
Xformers FlashDecoding | Hit a compilation error before we were able to evaluate text quality | WIP | No (this kernel is optimized for decoding)
PyTorch FlexAttention | Coherent | WIP | WIP
Figure 5. Table of combinations we tried with different Flash Attention Kernels
The above table summarizes what we observed out of the box. With some effort, we expect that kernels 2-5 can be modified to meet the above criteria. However, this also shows that a kernel that works for benchmarking is often only the first step toward one that is usable as an end-to-end production kernel.
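Checking a candidate kernel against a simple reference is a cheap first filter before judging generated text. As a minimal sketch of the idea behind flash attention, the pure-Python code below computes single-query attention over key/value blocks with an online (running-max) softmax and compares it with a naive reference. The block size and tolerance are assumptions; real Triton kernels also tile over queries and run in FP16, where precision problems are far more likely than in this float64 sketch.

```python
import math
import random

def naive_attention(q, keys, values):
    # Reference: softmax(q . k / sqrt(d)) applied to the values, one query.
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    d_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(d_v)]

def blocked_attention(q, keys, values, block=2):
    # Flash-attention-style: visit K/V in blocks, keeping a running max,
    # a running softmax denominator and an unnormalized output accumulator.
    scale = 1.0 / math.sqrt(len(q))
    d_v = len(values[0])
    run_max, run_sum, acc = float('-inf'), 0.0, [0.0] * d_v
    for start in range(0, len(keys), block):
        ks, vs = keys[start:start + block], values[start:start + block]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in ks]
        new_max = max(run_max, max(scores))
        # Rescale the earlier partial results to the new running max.
        correction = math.exp(run_max - new_max)
        run_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, vs):
            p = math.exp(s - new_max)
            run_sum += p
            acc = [a + p * vj for a, vj in zip(acc, v)]
        run_max = new_max
    return [a / run_sum for a in acc]

random.seed(0)
d = 4
q = [random.uniform(-1, 1) for _ in range(d)]
keys = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(7)]
values = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(7)]
ref = naive_attention(q, keys, values)
out = blocked_attention(q, keys, values, block=3)
# Expected to be near zero (float64 rounding only).
print(max(abs(r - o) for r, o in zip(ref, out)))
```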
We chose the AMD flash attention kernel for our subsequent tests because it can be compiled via torch.compile and produces coherent output in both eager and compiled mode.
To make the AMD flash attention kernel compatible with torch.compile, we had to define it as a PyTorch custom operator. This process is explained in detail here. The linked tutorial wraps a simple image crop operation, but wrapping a more complex flash attention kernel follows a similar two-step process:
Wrap the function into a PyTorch Custom Operator
Add a FakeTensor kernel to the operator which, given the shapes of the flash kernel’s input tensors (q, k and v), computes the shape of its output
After defining the Triton flash kernel as a custom op, we were able to successfully compile it for our E2E runs.
Figure 6. Trace of Llama3-8B with torch.compile, after swapping in Triton matmul and Triton flash attention kernels
From Figure 6, we note that after integrating both the SplitK matrix multiplication kernel and the custom-op-wrapped flash attention kernel, and then running torch.compile, we achieve a forward pass that uses 100% Triton computation kernels.
7.0 End-to-End Benchmarks
We performed end-to-end measurements with Granite-8B and Llama3-8B on NVIDIA H100s and A100s (single GPU), benchmarking two different configurations.
The Triton kernel configuration uses:
Triton SplitK GEMM
AMD Triton Flash Attention
The CUDA kernel configuration uses:
cuBLAS GEMM
cuDNN Flash Attention - Scaled Dot-Product Attention (SDPA)
We measured the following median inter-token latencies in both eager and torch.compile modes, with typical inference settings:
GPU | Model | Kernel Config | Median Latency (Eager) [ms/tok] | Median Latency (Compiled) [ms/tok]
H100 | Granite-8B | Triton | 27.42 | 11.59
H100 | Granite-8B | CUDA | 18.84 | 9.50
H100 | Llama3-8B | Triton | 20.36 | 10.61
H100 | Llama3-8B | CUDA | 16.59 | 8.59
A100 | Granite-8B | Triton | 53.44 | 16.88
A100 | Granite-8B | CUDA | 37.13 | 14.25
A100 | Llama3-8B | Triton | 44.44 | 17.94
A100 | Llama3-8B | CUDA | 32.45 | 12.96
Figure 7. Granite-8B and Llama3-8B Single Token Generation Latency on H100 and A100
(batch size = 2, input sequence length = 512, output sequence length = 256)
To summarize, the Triton models can get up to 78% of the performance of the CUDA models on the H100 and up to 82% on the A100.
The performance gap can be explained by the kernel latencies we observe for matmul and flash attention, which are discussed in the next section.
8.0 Microbenchmarks
Kernel | Triton [µs] | CUDA [µs]
QKV Projection Matmul | 25 | 21
Flash Attention | 13 | 8
Output Projection Matmul | 21 | 17
Gate + Up Projection Matmul | 84 | 83
Down Projection Matmul | 58 | 42
Figure 8. Triton and CUDA Kernel Latency Comparison (Llama3-8B on NVIDIA H100)
Input: an arbitrary prompt (batch size = 1, prompt length = 44 tokens); latencies are for decoding.
From the above, we note the following:
Triton matmul kernels are up to 1.4x slower than CUDA (1.2-1.4x, except the Gate + Up projection, which is at near parity)
AMD’s Triton Flash Attention kernel is 1.6x slower than CUDA SDPA
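The per-kernel slowdowns can be recomputed directly from the Figure 8 latencies. The short script below does that arithmetic (variable names are illustrative); it also makes visible that the Gate + Up projection matmul is close to parity while the down projection and flash attention account for most of the gap.

```python
# Latencies from Figure 8, in microseconds: (Triton, CUDA).
latencies_us = {
    'QKV Projection Matmul': (25, 21),
    'Flash Attention': (13, 8),
    'Output Projection Matmul': (21, 17),
    'Gate + Up Projection Matmul': (84, 83),
    'Down Projection Matmul': (58, 42),
}

for kernel, (triton_us, cuda_us) in latencies_us.items():
    # Slowdown of the Triton kernel relative to its CUDA counterpart.
    print(f'{kernel}: {triton_us / cuda_us:.2f}x')

# Combined time of the five kernels in each configuration.
total_triton = sum(t for t, _ in latencies_us.values())
total_cuda = sum(c for _, c in latencies_us.values())
print(f'All five kernels: {total_triton / total_cuda:.2f}x')
```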
These results highlight the need to further improve the performance of core primitives such as GEMM and Flash Attention. We leave this to future research: recent works (e.g. FlashAttention-3, FlexAttention) provide ways to make better use of the underlying hardware, as well as Triton pathways that we hope to build on for greater speedups. To illustrate this, we compared FlexAttention with SDPA and AMD’s Triton Flash kernel.
We are working to verify E2E performance with FlexAttention. For now, initial microbenchmarks with Flex show promise for longer context lengths and decoding problem shapes, where the query vector is small:
Figure 9. FlexAttention Kernel Benchmarks on NVIDIA H100 SXM5 80GB
(batch=1, num_heads=32, seq_len=seq_len, head_dim=128)
9.0 Future Work
For future work, we plan to further optimize our matmuls to make better use of the hardware, for example with TMA on H100 as discussed in this blog we published, and to explore different work decompositions (persistent-kernel techniques such as StreamK) for greater speedups in our Triton-based approach. For flash attention, we plan to explore FlexAttention and FlashAttention-3, since the techniques used in these kernels can help further close the gap between Triton and CUDA.
We also note that our prior work has shown promising results for FP8 Triton GEMM kernel performance versus cuBLAS FP8 GEMM, so in a future post we will explore E2E FP8 LLM inference.
" />
<meta property="og:image" content="https://pytorch.org/assets/images/social-share.jpg" />
<meta name="twitter:image" content="https://pytorch.org/assets/images/social-share.jpg" />
<meta property="og:locale" content="en_US" />
<meta property="og:type" content="website" />
<meta property="og:title" content="CUDA-Free Inference for LLMs" />
<meta property="og:description" content="In this blog, we discuss the methods we used to achieve FP16 inference with popular LLM models such as Meta’s Llama3-8B and IBM’s Granite-8B Code, where 100% of the computation is performed using OpenAI’s Triton Language.
For single token generation, our Triton kernel-based models reach 0.76-0.78x the performance of CUDA-kernel-dominant workflows for both Llama and Granite on NVIDIA H100 GPUs, and 0.62-0.82x on NVIDIA A100 GPUs.
Why explore using 100% Triton? Triton provides a path for running LLMs on different types of GPUs: NVIDIA, AMD, and in the future Intel and other GPU-based accelerators. It also provides a higher level of abstraction in Python for programming GPUs, and it has allowed us to write performant kernels faster than with vendor-specific APIs. In the rest of this blog, we share how we achieve CUDA-free compute, micro-benchmark individual kernels for comparison, and discuss how future Triton kernels could close the remaining gaps.
Figure 1. Inference throughput benchmarks with Triton and CUDA variants of Llama3-8B and Granite-8B, on NVIDIA H100 and A100
Settings: batch size = 2, input sequence length = 512, output sequence length = 256
2.0 Composition of a Transformer Block
We start with a breakdown of the computations that happen in Transformer-based models. The figure below shows the “kernels” of a typical Transformer block.
Figure 2. Transformer Block by core kernels
The core operations for a Llama3 architecture are summarized in this list:
RMSNorm
Matrix multiplication: Fused QKV
RoPE
Attention
Matrix multiplication: Output Projection
RMSNorm
Matrix multiplication: Fused Gate + Up Projection
Activation function: SiLU
Element Wise Multiplication
Matrix multiplication: Down Projection
Each of these operations is computed on the GPU by executing one or more kernels. While the specifics of these kernels can vary across transformer models, the core operations remain the same. For example, unlike Llama3, IBM’s Granite 8B Code model uses bias in the MLP layer, and such changes do require modifications to the kernels. A typical model is a stack of these transformer blocks wired together with embedding layers.
3.0 Model Inference
Model architecture code is typically contained in a Python model.py file that is launched by PyTorch. In the default PyTorch eager execution mode, these kernels are all executed with CUDA. To achieve 100% Triton for end-to-end Llama3-8B and Granite-8B inference, we need to write and integrate handwritten Triton kernels and leverage torch.compile to generate Triton ops. First, we replace smaller ops with compiler-generated Triton kernels; second, we replace more expensive and complex computations (e.g. matrix multiplication and flash attention) with handwritten Triton kernels.
torch.compile generates Triton kernels automatically for RMSNorm, RoPE, SiLU and element-wise multiplication. Using tools such as Nsight Systems, we can observe these generated kernels: they appear as tiny dark green kernels in between the matrix multiplications and attention.
Figure 3. Trace of Llama3-8B with torch.compile, showing CUDA kernels being used for matrix multiplications and flash attention
In the above trace, the two major ops, matrix multiplication and attention, make up 80% of the E2E latency in a Llama3-8B style model, and both still run as CUDA kernels. To close the remaining gap, we replace both the matmul and attention kernels with handwritten Triton kernels.
4.0 Triton SplitK GEMM Kernel
For the matrix multiplications in the linear layers, we wrote a custom FP16 Triton GEMM (General Matrix-Matrix Multiply) kernel that leverages a SplitK work decomposition. We have previously discussed this parallelization in other blogs as a way to accelerate the decoding portion of LLM inference.
5.0 GEMM Kernel Tuning
To achieve optimal performance, we tuned our SplitK GEMM kernel using exhaustive search. Granite-8B and Llama3-8B have linear layers with the following shapes:
Linear Layer | Shape (in_features, out_features)
Fused QKV Projection | (4096, 6144)
Output Projection | (4096, 4096)
Fused Gate + Up Projection | (4096, 28672)
Down Projection | (14336, 4096)
Figure 4. Granite-8B and Llama3-8B Linear Layer Weight Matrix Shapes
Each of these linear layers has a different weight matrix shape, so for optimal performance the Triton kernel must be tuned for each shape profile. After tuning for each linear layer, we achieved a 1.20x E2E speedup on Llama3-8B and Granite-8B over the untuned Triton kernel.
6.0 Flash Attention Kernel
We evaluated a suite of existing Triton flash attention kernels with different configurations, namely:
AMD Flash
OpenAI Flash
Dao AI Lab Flash
XFormers Flash
PyTorch FlexAttention
We evaluated the text generation quality of each kernel, first in eager mode and then, if we were able to torch.compile the kernel with standard methods, in compiled mode. We noted the following:
Kernel | Text Generation Quality | torch.compile | Support for Arbitrary Sequence Length
AMD Flash | Coherent | Yes | Yes
OpenAI Flash | Incoherent | Did not evaluate. WIP to debug precision in eager mode first | No
Dao AI Lab Flash | Incoherent | Did not evaluate. WIP to debug precision in eager mode first | Yes
Xformers FlashDecoding | Hit a compilation error before we were able to evaluate text quality | WIP | No (this kernel is optimized for decoding)
PyTorch FlexAttention | Coherent | WIP | WIP
Figure 5. Table of combinations we tried with different Flash Attention Kernels
The above table summarizes what we observed out of the box. With some effort, we expect that kernels 2-5 can be modified to meet the above criteria. However, this also shows that a kernel that works for benchmarking is often only the first step toward one that is usable as an end-to-end production kernel.
We chose the AMD flash attention kernel for our subsequent tests because it can be compiled via torch.compile and produces coherent output in both eager and compiled mode.
To make the AMD flash attention kernel compatible with torch.compile, we had to define it as a PyTorch custom operator. This process is explained in detail here. The linked tutorial wraps a simple image crop operation, but wrapping a more complex flash attention kernel follows a similar two-step process:
Wrap the function into a PyTorch Custom Operator
Add a FakeTensor kernel to the operator which, given the shapes of the flash kernel’s input tensors (q, k and v), computes the shape of its output
After defining the Triton flash kernel as a custom op, we were able to successfully compile it for our E2E runs.
Figure 6. Trace of Llama3-8B with torch.compile, after swapping in Triton matmul and Triton flash attention kernels
From Figure 6, we note that after integrating both the SplitK matrix multiplication kernel and the custom-op-wrapped flash attention kernel, and then running torch.compile, we achieve a forward pass that uses 100% Triton computation kernels.
7.0 End-to-End Benchmarks
We performed end-to-end measurements with Granite-8B and Llama3-8B on NVIDIA H100s and A100s (single GPU), benchmarking two different configurations.
The Triton kernel configuration uses:
Triton SplitK GEMM
AMD Triton Flash Attention
The CUDA kernel configuration uses:
cuBLAS GEMM
cuDNN Flash Attention - Scaled Dot-Product Attention (SDPA)
We measured the following median inter-token latencies in both eager and torch.compile modes, with typical inference settings:
GPU | Model | Kernel Config | Median Latency (Eager) [ms/tok] | Median Latency (Compiled) [ms/tok]
H100 | Granite-8B | Triton | 27.42 | 11.59
H100 | Granite-8B | CUDA | 18.84 | 9.50
H100 | Llama3-8B | Triton | 20.36 | 10.61
H100 | Llama3-8B | CUDA | 16.59 | 8.59
A100 | Granite-8B | Triton | 53.44 | 16.88
A100 | Granite-8B | CUDA | 37.13 | 14.25
A100 | Llama3-8B | Triton | 44.44 | 17.94
A100 | Llama3-8B | CUDA | 32.45 | 12.96
Figure 7. Granite-8B and Llama3-8B Single Token Generation Latency on H100 and A100
(batch size = 2, input sequence length = 512, output sequence length = 256)
To summarize, the Triton models can get up to 78% of the performance of the CUDA models on the H100 and up to 82% on the A100.
The performance gap can be explained by the kernel latencies we observe for matmul and flash attention, which are discussed in the next section.
8.0 Microbenchmarks
Kernel | Triton [µs] | CUDA [µs]
QKV Projection Matmul | 25 | 21
Flash Attention | 13 | 8
Output Projection Matmul | 21 | 17
Gate + Up Projection Matmul | 84 | 83
Down Projection Matmul | 58 | 42
Figure 8. Triton and CUDA Kernel Latency Comparison (Llama3-8B on NVIDIA H100)
Input: an arbitrary prompt (batch size = 1, prompt length = 44 tokens); latencies are for decoding.
From the above, we note the following:
Triton matmul kernels are up to 1.4x slower than CUDA (1.2-1.4x, except the Gate + Up projection, which is at near parity)
AMD’s Triton Flash Attention kernel is 1.6x slower than CUDA SDPA
These results highlight the need to further improve the performance of core primitives such as GEMM and Flash Attention. We leave this to future research: recent works (e.g. FlashAttention-3, FlexAttention) provide ways to make better use of the underlying hardware, as well as Triton pathways that we hope to build on for greater speedups. To illustrate this, we compared FlexAttention with SDPA and AMD’s Triton Flash kernel.
We are working to verify E2E performance with FlexAttention. For now, initial microbenchmarks with Flex show promise for longer context lengths and decoding problem shapes, where the query vector is small:
Figure 9. FlexAttention Kernel Benchmarks on NVIDIA H100 SXM5 80GB
(batch=1, num_heads=32, seq_len=seq_len, head_dim=128)
9.0 Future Work
For future work, we plan to further optimize our matmuls to make better use of the hardware, for example with TMA on H100 as discussed in this blog we published, and to explore different work decompositions (persistent-kernel techniques such as StreamK) for greater speedups in our Triton-based approach. For flash attention, we plan to explore FlexAttention and FlashAttention-3, since the techniques used in these kernels can help further close the gap between Triton and CUDA.
We also note that our prior work has shown promising results for FP8 Triton GEMM kernel performance versus cuBLAS FP8 GEMM, so in a future post we will explore E2E FP8 LLM inference.
" />
<meta property="og:site_name" content="PyTorch" />
<meta name="twitter:card" content="summary_large_image" />
<meta name="twitter:title" content="CUDA-Free Inference for LLMs" />
<meta name="twitter:description" content="In this blog, we discuss the methods we used to achieve FP16 inference with popular LLM models such as Meta’s Llama3-8B and IBM’s Granite-8B Code, where 100% of the computation is performed using OpenAI’s Triton Language.
For single token generation, our Triton kernel-based models reach 0.76-0.78x the performance of CUDA-kernel-dominant workflows for both Llama and Granite on NVIDIA H100 GPUs, and 0.62-0.82x on NVIDIA A100 GPUs.
Why explore using 100% Triton? Triton provides a path for running LLMs on different types of GPUs: NVIDIA, AMD, and in the future Intel and other GPU-based accelerators. It also provides a higher level of abstraction in Python for programming GPUs, and it has allowed us to write performant kernels faster than with vendor-specific APIs. In the rest of this blog, we share how we achieve CUDA-free compute, micro-benchmark individual kernels for comparison, and discuss how future Triton kernels could close the remaining gaps.
Figure 1. Inference throughput benchmarks with Triton and CUDA variants of Llama3-8B and Granite-8B, on NVIDIA H100 and A100
Settings: batch size = 2, input sequence length = 512, output sequence length = 256
2.0 Composition of a Transformer Block
We start with a breakdown of the computations that happen in Transformer-based models. The figure below shows the “kernels” of a typical Transformer block.
Figure 2. Transformer Block by core kernels
The core operations for a Llama3 architecture are summarized in this list:
RMSNorm
Matrix multiplication: Fused QKV
RoPE
Attention
Matrix multiplication: Output Projection
RMSNorm
Matrix multiplication: Fused Gate + Up Projection
Activation function: SiLU
Element Wise Multiplication
Matrix multiplication: Down Projection
Each of these operations is computed on the GPU by executing one or more kernels. While the specifics of these kernels can vary across transformer models, the core operations remain the same. For example, unlike Llama3, IBM’s Granite 8B Code model uses bias in the MLP layer, and such changes do require modifications to the kernels. A typical model is a stack of these transformer blocks wired together with embedding layers.
3.0 Model Inference
Model architecture code is typically contained in a Python model.py file that is launched by PyTorch. In the default PyTorch eager execution mode, these kernels are all executed with CUDA. To achieve 100% Triton for end-to-end Llama3-8B and Granite-8B inference, we need to write and integrate handwritten Triton kernels and leverage torch.compile to generate Triton ops. First, we replace smaller ops with compiler-generated Triton kernels; second, we replace more expensive and complex computations (e.g. matrix multiplication and flash attention) with handwritten Triton kernels.
torch.compile generates Triton kernels automatically for RMSNorm, RoPE, SiLU and element-wise multiplication. Using tools such as Nsight Systems, we can observe these generated kernels: they appear as tiny dark green kernels in between the matrix multiplications and attention.
Figure 3. Trace of Llama3-8B with torch.compile, showing CUDA kernels being used for matrix multiplications and flash attention
In the above trace, the two major ops, matrix multiplication and attention, make up 80% of the E2E latency in a Llama3-8B style model, and both still run as CUDA kernels. To close the remaining gap, we replace both the matmul and attention kernels with handwritten Triton kernels.
4.0 Triton SplitK GEMM Kernel
For the matrix multiplications in the linear layers, we wrote a custom FP16 Triton GEMM (General Matrix-Matrix Multiply) kernel that leverages a SplitK work decomposition. We have previously discussed this parallelization in other blogs as a way to accelerate the decoding portion of LLM inference.
5.0 GEMM Kernel Tuning
To achieve optimal performance, we tuned our SplitK GEMM kernel using exhaustive search. Granite-8B and Llama3-8B have linear layers with the following shapes:
Linear Layer | Shape (in_features, out_features)
Fused QKV Projection | (4096, 6144)
Output Projection | (4096, 4096)
Fused Gate + Up Projection | (4096, 28672)
Down Projection | (14336, 4096)
Figure 4. Granite-8B and Llama3-8B Linear Layer Weight Matrix Shapes
Each of these linear layers has a different weight matrix shape, so for optimal performance the Triton kernel must be tuned for each shape profile. After tuning for each linear layer, we achieved a 1.20x E2E speedup on Llama3-8B and Granite-8B over the untuned Triton kernel.
6.0 Flash Attention Kernel
We evaluated a suite of existing Triton flash attention kernels with different configurations, namely:
AMD Flash
OpenAI Flash
Dao AI Lab Flash
XFormers Flash
PyTorch FlexAttention
We evaluated the text generation quality of each kernel, first in eager mode and then, if we were able to torch.compile the kernel with standard methods, in compiled mode. We noted the following:
Kernel | Text Generation Quality | torch.compile | Support for Arbitrary Sequence Length
AMD Flash | Coherent | Yes | Yes
OpenAI Flash | Incoherent | Did not evaluate. WIP to debug precision in eager mode first | No
Dao AI Lab Flash | Incoherent | Did not evaluate. WIP to debug precision in eager mode first | Yes
Xformers FlashDecoding | Hit a compilation error before we were able to evaluate text quality | WIP | No (this kernel is optimized for decoding)
PyTorch FlexAttention | Coherent | WIP | WIP
Figure 5. Table of combinations we tried with different Flash Attention Kernels
The above table summarizes what we observed out of the box. With some effort, we expect that kernels 2-5 can be modified to meet the above criteria. However, this also shows that a kernel that works for benchmarking is often only the first step toward one that is usable as an end-to-end production kernel.
We chose the AMD flash attention kernel for our subsequent tests because it can be compiled via torch.compile and produces coherent output in both eager and compiled mode.
To make the AMD flash attention kernel compatible with torch.compile, we had to define it as a PyTorch custom operator. This process is explained in detail here. The linked tutorial wraps a simple image crop operation, but wrapping a more complex flash attention kernel follows a similar two-step process:
Wrap the function into a PyTorch Custom Operator
Add a FakeTensor kernel to the operator which, given the shapes of the flash kernel’s input tensors (q, k and v), computes the shape of its output
After defining the Triton flash kernel as a custom op, we were able to successfully compile it for our E2E runs.
Figure 6. Trace of Llama3-8B with torch.compile, after swapping in Triton matmul and Triton flash attention kernels
From Figure 6, we note that after integrating both the SplitK matrix multiplication kernel and the custom-op-wrapped flash attention kernel, and then running torch.compile, we achieve a forward pass that uses 100% Triton computation kernels.
7.0 End-to-End Benchmarks
We performed end-to-end measurements with Granite-8B and Llama3-8B on NVIDIA H100s and A100s (single GPU), benchmarking two different configurations.
The Triton kernel configuration uses:
Triton SplitK GEMM
AMD Triton Flash Attention
The CUDA kernel configuration uses:
cuBLAS GEMM
cuDNN Flash Attention - Scaled Dot-Product Attention (SDPA)
We measured the following median inter-token latencies in both eager and torch.compile modes, with typical inference settings:
GPU | Model | Kernel Config | Median Latency (Eager) [ms/tok] | Median Latency (Compiled) [ms/tok]
H100 | Granite-8B | Triton | 27.42 | 11.59
H100 | Granite-8B | CUDA | 18.84 | 9.50
H100 | Llama3-8B | Triton | 20.36 | 10.61
H100 | Llama3-8B | CUDA | 16.59 | 8.59
A100 | Granite-8B | Triton | 53.44 | 16.88
A100 | Granite-8B | CUDA | 37.13 | 14.25
A100 | Llama3-8B | Triton | 44.44 | 17.94
A100 | Llama3-8B | CUDA | 32.45 | 12.96
Figure 7. Granite-8B and Llama3-8B Single Token Generation Latency on H100 and A100
(batch size = 2, input sequence length = 512, output sequence length = 256)
To summarize, the Triton models can get up to 78% of the performance of the CUDA models on the H100 and up to 82% on the A100.
The performance gap can be explained by the kernel latencies we observe for matmul and flash attention, which are discussed in the next section.
8.0 Microbenchmarks
Kernel | Triton [µs] | CUDA [µs]
QKV Projection Matmul | 25 | 21
Flash Attention | 13 | 8
Output Projection Matmul | 21 | 17
Gate + Up Projection Matmul | 84 | 83
Down Projection Matmul | 58 | 42
Figure 8. Triton and CUDA Kernel Latency Comparison (Llama3-8B on NVIDIA H100)