-
Notifications
You must be signed in to change notification settings - Fork 302
/
Copy pathindex.html
902 lines (716 loc) · 57.9 KB
/
index.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
<!DOCTYPE html>
<html lang="en">
<head>
<!-- Character encoding is declared first: the HTML spec requires the charset declaration within the first 1024 bytes of the document. -->
<meta charset="UTF-8">
<!-- Google Tag Manager -->
<script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start':
new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0],
j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src=
'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f);
})(window,document,'script','dataLayer','GTM-T8XT4PS');</script>
<!-- End Google Tag Manager -->
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta http-equiv="X-UA-Compatible" content="ie=edge">
<link rel="shortcut icon" type="image/x-icon" href="/favicon.ico?">
<title>
Accelerating Generative AI with PyTorch: Segment Anything, Fast | PyTorch
</title>
<meta name="robots" content="index, follow" />
<meta name="description" content="This post is the first part of a multi-series blog focused on how to accelerate generative AI models with pure, native PyTorch. We are excited to share a breadth of newly released PyTorch performance features alongside practical examples of how these features can be combined to see how far we can push PyTorch native performance." />
<meta property="og:image" content="https://pytorch.org/assets/images/social-share.jpg" />
<meta name="twitter:image" content="https://pytorch.org/assets/images/social-share.jpg" />
<meta property="og:locale" content="en_US" />
<meta property="og:type" content="website" />
<meta property="og:title" content="Accelerating Generative AI with PyTorch: Segment Anything, Fast" />
<meta property="og:description" content="This post is the first part of a multi-series blog focused on how to accelerate generative AI models with pure, native PyTorch. We are excited to share a breadth of newly released PyTorch performance features alongside practical examples of how these features can be combined to see how far we can push PyTorch native performance." />
<meta property="og:site_name" content="PyTorch" />
<meta name="twitter:card" content="summary_large_image" />
<meta name="twitter:title" content="Accelerating Generative AI with PyTorch: Segment Anything, Fast" />
<meta name="twitter:description" content="This post is the first part of a multi-series blog focused on how to accelerate generative AI models with pure, native PyTorch. We are excited to share a breadth of newly released PyTorch performance features alongside practical examples of how these features can be combined to see how far we can push PyTorch native performance." />
<link rel="stylesheet" href="/assets/main.css">
<!-- These links must precede the tracking-pixel <img> elements at the end of <head>: the parser implicitly closes <head> at the first <img>, which would otherwise push them into <body>. -->
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.css" />
<link href="/feed.xml" type="application/atom+xml" rel="alternate" title="PyTorch Blog Posts" />
<script src="/assets/vendor/jquery.min.js"></script>
<script src="/assets/vendor/popper.min.js"></script>
<script src="/assets/vendor/bootstrap.min.js"></script>
<script src="/assets/vendor/anchor.min.js"></script>
<!-- MathJax 2.x configuration; it must be declared before MathJax.js is loaded. -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
skipTags: ['script', 'noscript', 'style', 'textarea', 'pre'],
inlineMath: [['$','$']]
}
});
</script>
<!-- cdn.mathjax.org has been retired; load the same MathJax 2.x build and config from cdnjs. -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.9/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.js"></script>
<script>
!function(f,b,e,v,n,t,s)
{if(f.fbq)return;n=f.fbq=function(){n.callMethod?
n.callMethod.apply(n,arguments):n.queue.push(arguments)};
if(!f._fbq)f._fbq=n;n.push=n;n.loaded=!0;n.version='2.0';
n.queue=[];t=b.createElement(e);t.async=!0;
t.src=v;s=b.getElementsByTagName(e)[0];
s.parentNode.insertBefore(t,s)}(window,document,'script',
'https://connect.facebook.net/en_US/fbevents.js');
fbq('init', '243028289693773');
fbq('track', 'PageView');
</script>
<noscript>
<img height="1" width="1" alt=""
src="https://www.facebook.com/tr?id=243028289693773&amp;ev=PageView&amp;noscript=1"/>
</noscript>
<!-- Twitter universal website tag code -->
<!-- NOTE(review): <img> is not permitted in <head> and implicitly ends it, so these pixels are kept last. Consider moving them into <body>. -->
<img height="1" width="1" style="display:none;" alt="" src="https://analytics.twitter.com/i/adsct?p_id=Twitter&amp;p_user_id=0&amp;txn_id=o2gi1&amp;events=%5B%5B%22pageview%22%2Cnull%5D%5D&amp;tw_sale_amount=0&amp;tw_order_quantity=0" />
<img height="1" width="1" style="display:none;" alt="" src="https://t.co/i/adsct?p_id=Twitter&amp;p_user_id=0&amp;txn_id=o2gi1&amp;events=%5B%5B%22pageview%22%2Cnull%5D%5D&amp;tw_sale_amount=0&amp;tw_order_quantity=0" />
<!-- End Twitter universal website tag code -->
</head>
<body class="blog">
<!-- Google Tag Manager (noscript) -->
<noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-T8XT4PS"
height="0" width="0" style="display:none;visibility:hidden" title="Google Tag Manager"></iframe></noscript>
<!-- End Google Tag Manager (noscript) -->
<!-- Decorative page background; intentionally empty. -->
<div class="main-background blog-background blog-detail-background"></div>
<!-- Site-wide announcement bar. The link opens in a new tab, so it carries rel="noopener" and an accessible name that says so. -->
<div class="hello-bar">
<div class="container">
Join us at PyTorch Conference in San Francisco, October 22-23. CFP open now! <a target="_blank" rel="noopener" href="https://events.linuxfoundation.org/pytorch-conference/" aria-label="Learn more about PyTorch Conference (opens in a new tab)">Learn more</a>.
</div>
</div>
<div class="container-fluid header-holder blog-detail-header">
<div class="container">
<div class="header-container">
<a class="header-logo" href="https://pytorch.org" aria-label="PyTorch"></a>
<!-- NOTE(review): id="dropdownMenuButton" is repeated on several menu items below. Ids should be unique, but they are left unchanged because /assets/main-menu-dropdown.js or the stylesheet may select on them; confirm before renaming. -->
<div class="main-menu">
<ul>
<li class="main-menu-item">
<div id="dropdownMenuButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Learn
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="/get-started">
<span class="dropdown-title">Get Started</span>
<p>Run PyTorch locally or get started quickly with one of the supported cloud platforms</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/tutorials/">
<span class="dropdown-title">Tutorials</span>
<p>What's new in PyTorch tutorials</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/tutorials/beginner/basics/intro.html">
<span class="dropdown-title">Learn the Basics</span>
<p>Familiarize yourself with PyTorch concepts and modules</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/tutorials/recipes/recipes_index.html">
<span class="dropdown-title">PyTorch Recipes</span>
<p>Bite-size, ready-to-deploy PyTorch code examples</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/tutorials/beginner/introyt.html">
<span class="dropdown-title">Intro to PyTorch - YouTube Series</span>
<p>Master PyTorch basics with our engaging YouTube tutorial series</p>
</a>
<a class="nav-dropdown-item" href="/new">
<span class="dropdown-title">New to PyTorch Foundation</span>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<div id="dropdownMenuButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Ecosystem
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="https://landscape.pytorch.org/" target="_blank" rel="noopener">
<span class="dropdown-title">Tools</span>
<p>Learn about the tools and frameworks in the PyTorch Ecosystem</p>
</a>
<a class="nav-dropdown-item" href="/join-ecosystem">
<span class="dropdown-title">Join the Ecosystem</span>
</a>
<a class="nav-dropdown-item" href="/#community-module">
<span class="dropdown-title">Community</span>
<p>Join the PyTorch developer community to contribute, learn, and get your questions answered.</p>
</a>
<a class="nav-dropdown-item" href="https://discuss.pytorch.org" target="_blank" rel="noopener">
<span class="dropdown-title">Forums</span>
<p>A place to discuss PyTorch code, issues, install, research</p>
</a>
<a class="nav-dropdown-item" href="/resources">
<span class="dropdown-title">Developer Resources</span>
<p>Find resources and get questions answered</p>
</a>
<a class="nav-dropdown-item" href="/ecosystem/contributor-awards-2024">
<span class="dropdown-title">Contributor Awards - 2024</span>
<p>Award winners announced at this year's PyTorch Conference</p>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<div id="dropdownMenuButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Edge
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="/edge">
<span class="dropdown-title">About PyTorch Edge</span>
<p>Build innovative and privacy-aware AI experiences for edge devices</p>
</a>
<a class="nav-dropdown-item" href="/executorch-overview">
<span class="dropdown-title">ExecuTorch</span>
<p>End-to-end solution for enabling on-device inference capabilities across mobile and edge devices</p>
</a>
<a class="nav-dropdown-item" target="_blank" rel="noopener" href="https://pytorch.org/executorch/stable/index.html">
<span class="dropdown-title">ExecuTorch Documentation</span>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<div id="docsDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Docs
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="https://pytorch.org/docs">
<span class="dropdown-title">PyTorch</span>
<p>Explore the documentation for comprehensive guidance on how to use PyTorch.</p>
</a>
<a class="nav-dropdown-item" href="/pytorch-domains">
<span class="dropdown-title">PyTorch Domains</span>
<p> Read the PyTorch Domains documentation to learn more about domain-specific libraries.</p>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<div id="dropdownMenuButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Blog & News
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="/blog">
<span class="dropdown-title">PyTorch Blog</span>
<p>Catch up on the latest technical news and happenings</p>
</a>
<a class="nav-dropdown-item" href="/community-blog">
<span class="dropdown-title">Community Blog</span>
<p>Stories from the PyTorch ecosystem</p>
</a>
<a class="nav-dropdown-item" href="/videos">
<span class="dropdown-title">Videos</span>
<p>Learn about the latest PyTorch tutorials, news, and more</p>
</a>
<a class="nav-dropdown-item" href="/community-stories">
<span class="dropdown-title">Community Stories</span>
<p>Learn how our community solves real, everyday machine learning problems with PyTorch</p>
</a>
<a class="nav-dropdown-item" href="/events">
<span class="dropdown-title">Events</span>
<p>Find events, webinars, and podcasts</p>
</a>
<a class="nav-dropdown-item" href="/newsletter">
<span class="dropdown-title">Newsletter</span>
<p>Stay up-to-date with the latest updates</p>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<div id="resourcesDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
About
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="/foundation">
<span class="dropdown-title">PyTorch Foundation</span>
<p>Learn more about the PyTorch Foundation.</p>
</a>
<a class="nav-dropdown-item" href="/governing-board">
<span class="dropdown-title">Governing Board</span>
</a>
<a class="nav-dropdown-item" href="/credits">
<span class="dropdown-title">Cloud Credit Program</span>
</a>
<a class="nav-dropdown-item" href="/tac">
<span class="dropdown-title">Technical Advisory Council</span>
</a>
<a class="nav-dropdown-item" href="/staff">
<span class="dropdown-title">Staff</span>
</a>
<a class="nav-dropdown-item" href="/contact-us">
<span class="dropdown-title">Contact Us</span>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<a href="/join" data-cta="join">
Become a Member
</a>
</li>
<li class="main-menu-item" id="github-main-menu-link">
<!-- Icon-only link: aria-label supplies the accessible name (title alone is not reliably announced). -->
<a href="https://github.com/pytorch/pytorch" title="Go to PyTorch GitHub" aria-label="PyTorch on GitHub">
<div id="topnav-gh-icon"></div>
</a>
</li>
<li class="navSearchWrapper reactNavSearchWrapper" key="search">
<div class="search-border">
<div id="search-icon"></div>
<input
id="search-input"
type="text"
title="Search"
aria-label="Search"
/>
<div id="close-search">X</div>
</div>
</li>
</ul>
</div>
<script src="/assets/main-menu-dropdown.js"></script>
<!-- Mobile menu toggle, driven by data-behavior; it has no text, so aria-label names it. -->
<a class="main-menu-open-button" href="#" data-behavior="open-mobile-menu" aria-label="Open menu"></a>
</div>
</div>
</div>
<div class="jumbotron jumbotron-fluid blog-detail-jumbotron">
<div class="container blog-detail-container">
<!-- Publication date; <time datetime> makes it machine-readable without changing its display. -->
<p class="featured-post"><time datetime="2023-11-16">November 16, 2023</time></p>
<h1>
<a class="blog-title">Accelerating Generative AI with PyTorch: Segment Anything, Fast</a>
</h1>
</div>
</div>
<div class="main-content-wrapper blog-detail-wrapper">
<div class="main-content blog-detail-content">
<div class="container">
<img src="/assets/images/logo-icon.svg" class="img-fluid author-icon" alt="">
<article class="pytorch-article">
<p class="author">
by
Team PyTorch
</p>
<p>This post is the first part of a multi-series blog focused on how to accelerate generative AI models with pure, native PyTorch. We are excited to share a breadth of newly released PyTorch performance features alongside practical examples of how these features can be combined to see how far we can push PyTorch native performance.</p>
<p>As announced during the <a href="https://www.youtube.com/watch?v=IWpM_9AsC-U">PyTorch Developer Conference 2023</a>, the PyTorch team <a href="https://github.com/pytorch-labs/segment-anything-fast">rewrote Meta’s Segment Anything (“SAM”) Model</a> <strong>resulting in 8x faster code</strong> than <a href="https://github.com/facebookresearch/segment-anything">the original implementation</a>, with no loss of accuracy, all using native PyTorch optimizations. We leverage a breadth of new PyTorch features:</p>
<ul>
<li><a href="https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html">torch.compile</a>: A compiler for PyTorch models</li>
<li><a href="https://github.com/pytorch-labs/ao/tree/main#torchao">GPU quantization</a>: Accelerate models with reduced precision operations</li>
<li><a href="https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html">Scaled Dot Product Attention (SDPA)</a>: Memory efficient attention implementations</li>
<li><a href="https://pytorch.org/tutorials/prototype/semi_structured_sparse.html">Semi-Structured (2:4) Sparsity</a>: A GPU-optimized sparse memory format</li>
<li><a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">Nested Tensor</a>: Batch together non-uniformly sized data into a single Tensor, such as images of different sizes.</li>
<li><strong>Custom operators with Triton:</strong> Write GPU operations using Triton Python DSL and easily integrate it into PyTorch’s various components with custom operator registration.</li>
</ul>
<p>We encourage readers to copy-paste code from <a href="https://github.com/pytorch-labs/segment-anything-fast">our implementation of SAM on GitHub</a> and <a href="https://github.com/pytorch-labs/segment-anything-fast/issues">ask us questions</a> on GitHub.</p>
<p><img src="/assets/images/accelerating-generative-ai/bar_chart_7.png" alt="A quick glimpse of increasing throughput and decreasing memory overhead" style="width:100%;" /></p>
<p><em>A quick glimpse of increasing throughput and decreasing memory overhead with our newly released, PyTorch native, features. Benchmarks run on p4d.24xlarge instance (8x A100s).</em></p>
<h2 id="segmentanything-model">SegmentAnything Model</h2>
<p><a href="https://github.com/facebookresearch/segment-anything">SAM</a> is a zero-shot vision model for generating promptable image masks.</p>
<p><img src="/assets/images/accelerating-generative-ai/intro_image.jpg" alt="sam image masks" style="width:100%;display: block;max-width:600px; margin-left:auto; margin-right:auto;" /></p>
<p>The SAM architecture (described <a href="https://arxiv.org/abs/2304.02643">in its paper</a>) includes multiple prompt and image encoders based on the Transformer architecture. Of these, we measured performance across the smallest and largest vision transformer backbones: <a href="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth">ViT-B</a> and <a href="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth">ViT-H</a>. For simplicity, we only show traces for the ViT-B model.</p>
<h2 id="optimizations">Optimizations</h2>
<p>Below we tell the story of optimizing SAM: profiling, identifying bottlenecks, and building new features into PyTorch that solve these problems. Throughout, we showcase our new PyTorch features: <strong>torch.compile, SDPA, Triton kernels, Nested Tensor and semi-structured sparsity.</strong> Each of the following sections builds on the previous ones, ending with our SAM-fast, now <a href="https://github.com/pytorch-labs/segment-anything-fast">available on GitHub</a>. We motivate each feature using real kernel and memory traces, using fully PyTorch native tooling, and visualize these traces with <a href="https://perfetto.dev/">Perfetto UI</a>.</p>
<h3 id="baseline">Baseline</h3>
<p>Our SAM baseline is Facebook Research’s <a href="https://github.com/facebookresearch/segment-anything">unmodified model</a>, using float32 dtype and a batch size of 1. After some initial warmup, we can look at a kernel trace using the <a href="https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html">PyTorch Profiler</a>:</p>
<p><img src="/assets/images/accelerating-generative-ai/baseline_trace.jpg" alt="kernel trace" style="width:100%;" /></p>
<p>We notice two areas ripe for optimization.</p>
<p>The first is long calls to aten::index, the underlying call resulting from a Tensor index operation (e.g., []). While the actual GPU time spent on aten::index is relatively low, aten::index launches two kernels, and a blocking cudaStreamSynchronize happens in between. This means the CPU waits for the GPU to finish processing before it launches the second kernel. To optimize SAM, we should aim to remove blocking GPU syncs that cause idle time.</p>
<p>The second is significant time spent on GPU in matrix multiplication (dark green on stream 7 above). This is common in Transformers. We can significantly speed up SAM if we can reduce the amount of GPU time spent on matrix multiplication.</p>
<p>We can measure the throughput (img/s) and memory overhead (GiB) of out-of-the-box SAM to establish a baseline:</p>
<p><img src="/assets/images/accelerating-generative-ai/bar_chart_0.png" alt="throughput (img/s) and memory overhead (GiB) from out of the box SAM" style="width:100%;" /></p>
<h3 id="bfloat16-half-precision-gpu-syncs-and-batching">Bfloat16 Half precision (+GPU syncs and batching)</h3>
<p>To address the second issue, the significant time spent in matrix multiplication, we can turn to <a href="https://en.wikipedia.org/wiki/Bfloat16_floating-point_format">bfloat16</a>. Bfloat16 is a commonly used half-precision type. By using less precision for parameters and activations, we can save significant time and memory in computation. When reducing the precision of parameters, it’s critical to validate end-to-end model accuracy.</p>
<p><img src="/assets/images/accelerating-generative-ai/bfloat16_snippet.jpg" alt="replacing padding dtypes with half precision, bfloat16" style="width:100%;" /></p>
<p><em>Shown here is an example of replacing padding dtypes with half precision, bfloat16. <a href="https://github.com/pytorch-labs/segment-anything-fast/blame/main/segment_anything_fast/modeling/prompt_encoder.py#L86">Code is here</a>.</em></p>
<p>In addition to simply setting <code class="language-plaintext highlighter-rouge">model.to(torch.bfloat16)</code>, we have to change a few small places that assume the default dtype.</p>
<p>Now, in order to remove GPU syncs we need to audit operations that cause them. We can find these pieces of code by searching the GPU traces for calls to <code class="language-plaintext highlighter-rouge">cudaStreamSynchronize</code>. In fact, we found two locations that we were able to rewrite to be sync-free.</p>
<p><img src="/assets/images/accelerating-generative-ai/code1.jpg" alt="code sample 1" style="width:100%;" /></p>
<p><img src="/assets/images/accelerating-generative-ai/bfloat16_snippet2.jpg" alt="replacing padding dtypes with half precision, bfloat16" style="width:100%;" /></p>
<p>Specifically, we see that within SAM’s image encoder, there are variables acting as coordinate scalers, q_coords and k_coords. These are both allocated and processed on the CPU. However, once these variables are used to index into rel_pos_resized, the index operation automatically moves these variables to the GPU. This copy causes the GPU sync we’ve observed above. We notice a second call to index in SAM’s prompt encoder: we can use torch.where to rewrite this as shown above.</p>
<p><strong>Kernel trace</strong></p>
<p>After applying these changes, we begin to see significant time between individual kernel calls. This is typically observed with small batch sizes (1 here) due to the overhead of launching kernels. To get a closer look at practical areas for optimization, we can start to profile SAM inference with batch size 8:</p>
<p><img src="/assets/images/accelerating-generative-ai/bfloat16_trace.jpg" alt="profile SAM inference with batch size 8" style="width:100%;" /></p>
<p>Looking at the time spent per kernel, we observe that most of SAM’s GPU time is spent on elementwise kernels and the softmax operation. With this, we now see that matrix multiplications have become a much smaller relative overhead.</p>
<p><img src="/assets/images/accelerating-generative-ai/bfloat16_kernels.jpg" alt="matrix multiplications have become a much smaller relative overhead" style="width:100%;" /></p>
<p>Taking the GPU sync and bfloat16 optimizations together, we have now improved SAM performance by up to 3x.</p>
<p><img src="/assets/images/accelerating-generative-ai/bar_chart_1.png" alt="SAM performance by up to 3x" style="width:100%;" /></p>
<h3 id="torchcompile-graph-breaks-and-cuda-graphs">Torch.compile (+graph breaks and CUDA graphs)</h3>
<p>When observing a large number of small operations, such as the elementwise kernels profiled above, turning to a compiler to fuse operations can have strong benefits. PyTorch’s recently released <strong>torch.compile</strong> does a great job optimizing by:</p>
<ol>
<li>Fusing together sequences of operations such as nn.LayerNorm or nn.GELU into a single GPU kernel, and</li>
<li>Epilogues: fusing operations that immediately follow matrix multiplication kernels to reduce the number of GPU kernel calls.</li>
</ol>
<p>Through these optimizations, we reduce the number of GPU global memory roundtrips, thus speeding up inference. We can now try torch.compile on SAM’s <a href="https://github.com/pytorch-labs/segment-anything-fast/blob/3bd74614fe7285de4de3d763d8ec2e951c4c589c/experiments/eval_combo.py#L196-L201">image encoder</a>. To maximize performance we use a few advanced compile techniques such as:</p>
<ul>
<li>Using torch.compile’s max-autotune mode, which enables <a href="https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/">CUDA graphs</a> and shape-specific kernels with custom epilogues.</li>
<li>By setting <code class="language-plaintext highlighter-rouge">TORCH_LOGS="graph_breaks,recompiles"</code>, we can manually verify that we are not running into <a href="https://pytorch.org/docs/main/torch.compiler_faq.html#graph-breaks">graph breaks</a> or recompiles.</li>
<li>Padding the batch of images input to the encoder with zeros ensures compile sees static shapes, so it can always use shape-specific optimized kernels with custom epilogues without recompilations.</li>
</ul>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>predictor.model.image_encoder = \
torch.compile(predictor.model.image_encoder, mode=use_compile)
</code></pre></div></div>
<p><strong>Kernel trace</strong></p>
<p><img src="/assets/images/accelerating-generative-ai/compile_trace.jpg" alt="Kernel trace" style="width:100%;" /></p>
<p>torch.compile is working beautifully. We launch a single CUDA graph, which makes up a significant portion of GPU time within the timed region. Let’s run our profile again and look at the percentage of GPU time spent in specific kernels:</p>
<p><img src="/assets/images/accelerating-generative-ai/compile_kernels.jpg" alt="the percentage of GPU time spent in specific kernels" style="width:100%;" /></p>
<p>We now see softmax makes up a significant portion of the time, followed by various GEMM variants. In summary, we observe the following measurements for batch size 8 with the above changes.</p>
<p><img src="/assets/images/accelerating-generative-ai/bar_chart_2.png" alt="measurements for batch size 8 and above" style="width:100%;" /></p>
<h3 id="sdpa-scaled_dot_product_attention">SDPA: scaled_dot_product_attention</h3>
<p>Next up, we can tackle one of the most common areas for transformer performance overhead: the attention mechanism. Naive attention implementations scale quadratically in time and memory with sequence length. PyTorch’s <a href="https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html?highlight=scaled_dot_product_attention#torch.nn.functional.scaled_dot_product_attention">scaled_dot_product_attention</a> operation, built upon the principles of <a href="https://arxiv.org/pdf/2205.14135.pdf">Flash Attention</a>, <a href="https://github.com/Dao-AILab/flash-attention">FlashAttentionV2</a> and <a href="https://github.com/facebookresearch/xformers">xFormers’ memory efficient attention</a>, can significantly speed up GPU attention. Combined with torch.compile, this operation allows us to express and fuse a common pattern within variants of MultiheadAttention. After <a href="https://github.com/facebookresearch/segment-anything/compare/50cb459d080bcd783a4b481d3bde4150d35ac497...7dc75fdf283693f73606f2fe7fdcb693afcb16b9">a small set of changes</a> we can adapt the model to use scaled_dot_product_attention.</p>
<p><img src="/assets/images/accelerating-generative-ai/sdpa_snippet.jpg" alt="PyTorch native attention implementation" style="width:100%;display: block;max-width:600px; margin-left:auto; margin-right:auto;" /></p>
<p><em>PyTorch native attention implementation, <a href="https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py#L236">see code here</a>.</em></p>
<p><strong>Kernel trace</strong></p>
<p>We can now see that in particular the memory efficient attention kernel is taking up a large amount of computational time on the GPU:</p>
<p><img src="/assets/images/accelerating-generative-ai/sdpa_kernels.jpg" alt="memory efficient attention kernel is taking up a large amount of computational time on the GPU" style="width:100%;display: block;max-width:600px; margin-left:auto; margin-right:auto;" /></p>
<p>Using PyTorch’s native scaled_dot_product_attention, we can significantly increase the batch size. We now observe the following measurements for batch size 32 with the above changes.</p>
<p><img src="/assets/images/accelerating-generative-ai/bar_chart_3.png" alt="batch size 32 and above" style="width:100%;" /></p>
<h3 id="triton-custom-sdpa-for-fused-relative-positional-encoding">Triton: Custom SDPA for fused relative positional encoding</h3>
<p>Transitioning away from inference throughput for a moment, we started profiling overall SAM memory. Within the image encoder, we saw significant spikes in memory allocation:</p>
<p><img src="/assets/images/accelerating-generative-ai/triton_trace.png" alt="spikes in memory allocation" style="width:100%;" /></p>
<p>Zooming in, we see this allocation happens within add_decomposed_rel_pos, <a href="https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py#L373">on the following line:</a></p>
<p><img src="/assets/images/accelerating-generative-ai/triton_snippet.jpg" alt="we see this allocation happens within add_decomposed_rel_pos" style="width:100%;display: block;max-width:500px; margin-left:auto; margin-right:auto;" /></p>
<p>The attn variable here is the addition of two smaller tensors: rel_h of shape (B, q_h, q_w, k_h, 1) and rel_w of shape (B, q_h, q_w, 1, k_w).</p>
<p>It’s not surprising that the memory efficient attention kernel (used via SDPA) is taking a long time with an attention bias size over 3.0GiB. If instead of allocating this large attn tensor, we thread into SDPA the two smaller rel_h and rel_w tensors, and only construct attn as needed, we’d anticipate significant performance gain.</p>
<p>Unfortunately, this is not a trivial modification; SDPA kernels are highly optimized and written in CUDA. We can turn to Triton, with its easy-to-understand and easy-to-use <a href="https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html">tutorial on a FlashAttention implementation</a>. After some significant digging, and in close collaboration with xFormers’ Daniel Haziza, we found one case of input shapes where it is relatively straightforward to implement a fused version of the kernel. The <a href="https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/flash_4.py">details have been added to the repository</a>. Surprisingly, this can be done in under 350 lines of code for the inference case.</p>
<p>This is a great example of extending PyTorch with a new kernel, straightforwardly built with Triton code.</p>
<p><strong>Kernel trace</strong></p>
<p><img src="/assets/images/accelerating-generative-ai/triton_kernels.jpg" alt="kernel trace" style="width:100%;display: block;max-width:600px; margin-left:auto; margin-right:auto;" /></p>
<p>With our custom positional Triton kernel we observe the following measurements for batch size 32.</p>
<p><img src="/assets/images/accelerating-generative-ai/bar_chart_4.png" alt="we observe the following measurements for batch size 32" style="width:100%;" /></p>
<h3 id="nt-nestedtensor-and-batching-predict_torch">NT: NestedTensor and batching predict_torch</h3>
<p>We have spent a lot of time on the image encoder. This makes sense, since it takes up the most computational time. At this point, however, it is fairly well optimized, and the operator that takes the most time would require significant additional investment to be improved.</p>
<p>We made an interesting observation about the <a href="https://github.com/pytorch-labs/segment-anything-fast/blob/7cd6ba3cea451602acb7d36d176da06c70ac68f1/experiments/eval_combo.py#L137-L157">mask prediction pipeline</a>: for each image, there is an associated size, coords, and fg_labels Tensor. Each of these tensors is of a different batch size. Each image itself is also of a different size. This representation of data looks like <a href="https://en.wikipedia.org/wiki/Jagged_array">Jagged Data</a>. With PyTorch’s recently released <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">NestedTensor</a>, we can modify our data pipeline to batch the coords and fg_labels Tensors into a single NestedTensor. This can have significant performance benefits for the prompt encoder and mask decoder that follow the image encoder. Invoking:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>torch.nested.nested_tensor(data, dtype=dtype, layout=torch.jagged)
</code></pre></div></div>
<p><strong>Kernel trace</strong></p>
<p><img src="/assets/images/accelerating-generative-ai/trace1.jpg" alt="Kernel trace" style="width:100%;" /></p>
<p><img src="/assets/images/accelerating-generative-ai/nt_kernel.jpg" alt="we can launch kernels much faster from the CPU than the GPU can process" style="width:100%;display: block;max-width:600px; margin-left:auto; margin-right:auto;" /></p>
<p>We can see now that we can launch kernels much faster from the CPU than the GPU can process and that it spends a long time waiting at the end of our timed region for the GPU to finish (cudaDeviceSynchronize). We also don’t see any more idle time (white space) between kernels on the GPU.</p>
<p>With Nested Tensor, we observe the following measurements for batch size 32 and above changes.</p>
<p><img src="/assets/images/accelerating-generative-ai/bar_chart_5.png" alt="batch size 32 and above changes" style="width:100%;" /></p>
<h3 id="int8-quantization-and-approximating-matmul">int8: quantization and approximating matmul</h3>
<p>We notice in the above trace that significant time is now spent in GEMM kernels. We’ve optimized enough that matrix multiplication now accounts for more inference time than scaled dot product attention.</p>
<p>Building on earlier learnings going from fp32 to bfloat16, let’s go a step further, emulating even lower precision with int8 quantization. Looking at quantization methods, we focus on <a href="https://pytorch.org/tutorials/recipes/quantization.html">dynamic quantization</a>, wherein our model observes the range of possible inputs and weights of a layer, and subdivides the expressible int8 range to uniformly “spread out” observed values. Ultimately, each float input will be mapped to a single integer in the range [-128, 127]. For more information, see PyTorch’s <a href="https://pytorch.org/tutorials/recipes/quantization.html">tutorial on quantization</a>.</p>
<p>Reducing precision can immediately lead to peak memory savings, but to realize inference speedups, we have to make full use of int8 through SAM’s operations. This requires building an efficient int8@int8 matrix multiplication kernel, as well as casting logic to translate from high to low precision (quantization) as well as reversing back from low to high (dequantization). Utilizing the power of torch.compile, we can compile and fuse together these quantization and dequantization routines into efficient single kernels and epilogues of our matrix multiplication. The resulting implementation is <a href="https://github.com/pytorch-labs/segment-anything-fast/blob/21b0208ae46eefc5659f7f200a2bf447add8765b/segment_anything_fast/dynamic_quant.py">fairly short and less than 250 lines of code</a>. For more information on the APIs and usage, see <a href="https://github.com/pytorch-labs/ao/tree/main#torchao">pytorch-labs/ao</a>.</p>
<p>While it’s common to see some accuracy regression when quantizing models at inference time, SAM has been particularly robust to lower precision inference with minimal loss of accuracy. With quantization added, we now observe the following measurements for <strong>batch size 32</strong> and above changes.</p>
<p><img src="/assets/images/accelerating-generative-ai/bar_chart_6.png" alt="batch size 32 and above changes" style="width:100%;" /></p>
<h3 id="sparse-semi-structured-24-sparsity">sparse: Semi-structured (2:4) sparsity</h3>
<p>Matrix multiplications are still our bottleneck. We can turn to the model acceleration playbook with another classic method to approximate matrix multiplication: sparsification. By sparsifying our matrices (i.e., zeroing out values), we could theoretically use fewer bits to store weight and activation tensors. The process by which we decide which weights in the tensor to set to zero is called pruning. The idea behind pruning is that small weights in a weight tensor contribute little to the net output of a layer, typically the product of weights with activations. Pruning away small weights can potentially reduce model size without significant loss of accuracy.</p>
<p>Methods for pruning are varied, from completely unstructured, wherein weights are greedily pruned, to highly structured, wherein large sub-components of a tensor are pruned at a time. Choice of method is not trivial. While unstructured pruning may theoretically have the least impact on accuracy, GPUs are also highly efficient with multiplying large, dense matrices and may suffer significant performance degradation in sparse regimes. One recent pruning method supported in PyTorch, called semi-structured (or 2:4) sparsity, seeks to strike a balance. This sparse storage reduces the original tensor by a significant 50%, while simultaneously resulting in a dense tensor output that can leverage highly performant, 2:4 GPU kernels. See the following picture for an illustration.</p>
<p><img src="/assets/images/accelerating-generative-ai/sparse_image.png" alt="dense tensor output that can leverage highly performant, 2:4 GPU kernels" style="width:100%;display: block;max-width:600px; margin-left:auto; margin-right:auto;" /></p>
<p>From <a href="https://developer.nvidia.com/blog/exploiting-ampere-structured-sparsity-with-cusparselt">developer.nvidia.com/blog/exploiting-ampere-structured-sparsity-with-cusparselt</a></p>
<p>In order to use this sparse storage format and the associated fast kernels we need to prune our weights such that they adhere to the constraints for the format. We pick the two smallest weights to prune in a 1 by 4 region, measuring the performance vs accuracy tradeoff. It is easy to change a weight from its default PyTorch (“strided”) layout to this new, semi-structured sparse layout. To implement <code class="language-plaintext highlighter-rouge">apply_sparse(model)</code> we only require 32 lines of Python code:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
# Sparsity helper functions
def apply_fake_sparsity(model):
    """
    Simulate 2:4 sparsity on every torch.nn.Linear layer of ``model``.

    Uses the torch.ao.pruning flow: a WeightNormSparsifier zeroes the two
    smallest weights in each 1x4 block of every Linear weight, and the masks
    are then squashed into the weights. The weights keep their default dense
    (strided) layout; they simply contain zeros in a 2:4 pattern.
    """
    # torch.ao.pruning flow
    from torch.ao.pruning import WeightNormSparsifier

    # One config entry per Linear weight, addressed by fully-qualified name.
    sparse_config = [
        {"tensor_fqn": f"{name}.weight"}
        for name, mod in model.named_modules()
        if isinstance(mod, torch.nn.Linear)
    ]

    sparsifier = WeightNormSparsifier(
        sparsity_level=1.0,
        sparse_block_shape=(1, 4),
        zeros_per_block=2,  # 2 zeros in every 1x4 block, i.e. the 2:4 pattern
    )
    sparsifier.prepare(model, sparse_config)
    # Two mask-update passes, matching the original recipe.
    for _ in range(2):
        sparsifier.step()
    # Fold the masks into the weights and drop the pruning parametrizations.
    sparsifier.squash_mask()
def apply_sparse(model):
    """
    Prune ``model`` to a 2:4 pattern, then move every torch.nn.Linear weight
    into the semi-structured sparse layout so the 2:4 GPU kernels can be used.
    """
    apply_fake_sparsity(model)
    for module in model.modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        # Re-wrap the converted weight as a Parameter on the same module.
        module.weight = torch.nn.Parameter(to_sparse_semi_structured(module.weight))
</code></pre></div></div>
<p>With 2:4 sparsity, we observe peak performance on SAM with vit_b and batch size 32:</p>
<p><img src="/assets/images/accelerating-generative-ai/bar_chart_7.png" alt="With 2:4 sparsity, we observe peak performance on SAM with vit_b and batch size 32" style="width:100%;" /></p>
<h3 id="conclusion">Conclusion</h3>
<p>Wrapping up, we are excited to have <a href="https://www.youtube.com/watch?v=IWpM_9AsC-U">announced</a> our fastest implementation of <a href="https://github.com/facebookresearch/segment-anything">Segment Anything</a> to date. We rewrote Meta’s original SAM in pure PyTorch with no loss of accuracy using a breadth of newly released features:</p>
<ul>
<li><strong>Torch.compile</strong> PyTorch’s native JIT compiler, providing fast, automated fusion of PyTorch operations [<a href="https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html">tutorial</a>]</li>
<li><strong>GPU quantization</strong> accelerate models with reduced precision operations [<a href="https://github.com/pytorch-labs/ao/tree/main#torchao">api</a>]</li>
<li><strong>Scaled Dot Product Attention (SDPA)</strong> a new, memory efficient implementation of Attention [<a href="https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html">tutorial</a>]</li>
<li><strong>Semi-Structured (2:4) Sparsity</strong> accelerate models with fewer bits to store weights and activations [<a href="https://pytorch.org/tutorials/prototype/semi_structured_sparse.html">tutorial</a>]</li>
<li><strong>Nested Tensor</strong> Highly optimized, ragged array handling for non-uniform batch and image sizes [<a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">tutorial</a>]</li>
<li><strong>Triton kernels.</strong> Custom GPU operations, easily built and optimized via Triton</li>
</ul>
<p>For more details on how to reproduce the data presented in this blog post, check out <a href="https://github.com/pytorch-labs/segment-anything-fast/tree/main/experiments">the experiments folder of segment-anything-fast</a>. Please don’t hesitate to contact us or <a href="https://github.com/pytorch-labs/segment-anything-fast/issues/new">open an issue</a> if you run into any technical issues.</p>
<p>In our next post, we are excited to share similar performance gains with our PyTorch natively authored LLM!</p>
<h3 id="acknowledgements">Acknowledgements</h3>
<p>We would like to thank Meta’s <a href="https://github.com/facebookresearch/xformers">xFormers</a> team including Daniel Haziza and Francisco Massa for authoring SDPA kernels and helping us design our custom one-off Triton kernel.</p>
</article>
</div>
</div>
</div>
<!--
-->
<div class="container-fluid docs-tutorials-resources">
<div class="container">
<div class="row">
<div class="col-md-4 text-center">
<h2>Docs</h2>
<p>Access comprehensive developer documentation for PyTorch</p>
<a class="with-right-arrow" href="/docs">View Docs</a>
</div>
<div class="col-md-4 text-center">
<h2>Tutorials</h2>
<p>Get in-depth tutorials for beginners and advanced developers</p>
<a class="with-right-arrow" href="https://pytorch.org/tutorials">View Tutorials</a>
</div>
<div class="col-md-4 text-center">
<h2>Resources</h2>
<p>Find development resources and get your questions answered</p>
<a class="with-right-arrow" href="/resources">View Resources</a>
</div>
</div>
</div>
</div>
<footer class="site-footer">
<div class="container footer-container">
<div class="newsletter" id="newsletter">
<p
class="newsletter__title is-style-max-width-800"><strong>Stay in touch</strong> for updates, event info, and the latest news</p>
<script charset="utf-8" type="text/javascript" src="//js.hsforms.net/forms/embed/v2.js"></script>
<script>
// Render the HubSpot newsletter sign-up form. If the embed script above fails
// to load (for example, because a content blocker blocks it), the global
// `hbspt` is never defined. In that case, skip rendering instead of throwing
// an uncaught ReferenceError.
if (typeof hbspt !== "undefined") {
  hbspt.forms.create({
    region: "na1",
    portalId: "8112310",
    formId: "2fb2231c-000b-4ec5-88a0-1ab242549c9e"
  });
}
</script>
<p
class="newsletter__privacy">By submitting this form, I consent to receive marketing emails from the LF and its projects regarding their events, training, research, developments, and related announcements. I understand that I can unsubscribe at any time using the links in the footers of the emails I receive. <a href="https://www.linuxfoundation.org/privacy/">Privacy Policy</a>.</p>
</div>
<div class="lf-grid">
<div class="footer-logo-wrapper">
<a href="https://pytorch.org" class="footer-logo">
<img src="/assets/images/logo-icon.svg" alt="PyTorch logo" width="40">
</a>
</div>
<ul class="social-links">
<li><a href="https://www.facebook.com/pytorch" target="_blank" title="PyTorch on Facebook">
<svg xmlns="http://www.w3.org/2000/svg" viewbox="-0.51 -0.26 26.45 26.45" aria-label="Facebook"><path fill="currentColor" d="M25.497 13.075c0-2.45-.698-4.848-2.011-6.911a12.765 12.765 0 0 0-5.398-4.73A12.671 12.671 0 0 0 11.008.38a12.705 12.705 0 0 0-6.529 2.95A12.827 12.827 0 0 0 .563 9.358a12.896 12.896 0 0 0-.07 7.201 12.831 12.831 0 0 0 3.801 6.103 12.709 12.709 0 0 0 6.471 3.078v-8.957H7.53v-3.708h3.235v-2.824c0-3.213 1.903-4.988 4.813-4.988.956.014 1.909.097 2.852.25V8.67h-1.607a1.83 1.83 0 0 0-1.518.497 1.854 1.854 0 0 0-.561 1.505v2.404h3.535l-.563 3.708h-2.97v8.957a12.725 12.725 0 0 0 7.697-4.337 12.87 12.87 0 0 0 3.054-8.328z"/></svg>
</a></li>
<li><a href="https://twitter.com/pytorch" target="_blank" title="PyTorch on X">
<svg xmlns="http://www.w3.org/2000/svg" viewbox="0 0 300 300" aria-label="X"><path fill="currentColor" d="M178.57 127.15 290.27 0h-26.46l-97.03 110.38L89.34 0H0l117.13 166.93L0 300.25h26.46l102.4-116.59 81.8 116.59h89.34M36.01 19.54H76.66l187.13 262.13h-40.66"/></svg>
</a></li>
<li><a href="https://www.youtube.com/pytorch" target="_blank" title="PyTorch on YouTube">
<svg xmlns="http://www.w3.org/2000/svg" viewbox="0.21 0.27 34.45 25.07" aria-label="YouTube"><path fill="currentColor" d="M33.729 6.084s-.327-2.33-1.317-3.356a4.691 4.691 0 0 0-3.32-1.432c-4.634-.34-11.589-.34-11.589-.34h-.014s-6.954 0-11.59.342a4.692 4.692 0 0 0-3.32 1.432c-.993 1.025-1.315 3.354-1.315 3.354a52.189 52.189 0 0 0-.331 5.473v2.566c.014 1.829.125 3.656.331 5.472 0 0 .322 2.33 1.316 3.36 1.26 1.345 2.916 1.3 3.653 1.445 2.65.26 11.263.34 11.263.34s6.96-.01 11.597-.353a4.691 4.691 0 0 0 3.32-1.432c.993-1.026 1.316-3.356 1.316-3.356.206-1.817.316-3.644.33-5.473v-2.57a52.26 52.26 0 0 0-.33-5.472zM14.076 17.232V7.729l8.951 4.768-8.95 4.735z"/></svg>
</a></li>
<li><a href="https://www.linkedin.com/company/pytorch" target="_blank" title="PyTorch on LinkedIn">
<svg xmlns="http://www.w3.org/2000/svg" viewbox="-10.23 -10.23 531.96 531.96" aria-label="LinkedIn"><rect width="512" height="512" rx="0" fill="currentColor"/><circle fill="#000" cx="142" cy="138" r="37"/><path stroke="#000" stroke-width="66" d="M244 194v198M142 194v198"/><path fill="#000" d="M276 282c0-20 13-40 36-40 24 0 33 18 33 45v105h66V279c0-61-32-89-76-89-34 0-51 19-59 32"/></svg>
</a></li>
<li><a href="https://join.slack.com/t/pytorch/shared_invite/zt-2j2la612p-miUinTTaxXczKOJw48poHA" target="_blank" title="PyTorch Slack">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0.16 -0.03 21.19 21.19" aria-label="Slack"><path fill="currentColor" d="M4.896 13.27a2.147 2.147 0 0 1-2.141 2.142A2.147 2.147 0 0 1 .613 13.27c0-1.178.963-2.141 2.142-2.141h2.141v2.141zm1.08 0c0-1.178.962-2.141 2.141-2.141s2.142.963 2.142 2.141v5.363a2.147 2.147 0 0 1-2.142 2.141 2.147 2.147 0 0 1-2.141-2.142V13.27zm2.141-8.6a2.147 2.147 0 0 1-2.141-2.14c0-1.18.962-2.142 2.141-2.142s2.142.963 2.142 2.141v2.142H8.117zm0 1.08c1.179 0 2.141.962 2.141 2.141a2.147 2.147 0 0 1-2.141 2.142H2.755A2.147 2.147 0 0 1 .613 7.89c0-1.179.963-2.141 2.142-2.141h5.362zm8.599 2.141c0-1.179.963-2.141 2.141-2.141 1.179 0 2.143.962 2.143 2.14a2.147 2.147 0 0 1-2.142 2.142h-2.141V7.89zm-1.08 0a2.147 2.147 0 0 1-2.141 2.142 2.147 2.147 0 0 1-2.141-2.142V2.53c0-1.178.962-2.141 2.141-2.141s2.142.963 2.142 2.141v5.362zm-2.141 8.6c1.179 0 2.142.962 2.142 2.14a2.147 2.147 0 0 1-2.142 2.142 2.147 2.147 0 0 1-2.141-2.141V16.49h2.141zm0-1.08a2.147 2.147 0 0 1-2.141-2.141c0-1.179.962-2.142 2.141-2.142h5.362c1.179 0 2.142.963 2.142 2.142a2.147 2.147 0 0 1-2.142 2.142h-5.362z"></path></svg>
</a></li>
<li><a href="/wechat" title="PyTorch on WeChat">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0.14 -0.17 38.02 33.02" aria-label="WeChat"><path fill="currentColor" d="M26.289 10.976a12.972 12.972 0 0 0-8.742 3.53 10.386 10.386 0 0 0-3.224 8.795c-1.326-.164-2.535-.345-3.75-.448a2.332 2.332 0 0 0-1.273.216c-1.18.666-2.311 1.418-3.652 2.255.246-1.112.405-2.087.687-3.024a1.15 1.15 0 0 0-.523-1.52C1.737 17.902.02 13.601 1.307 9.165c1.189-4.1 4.11-6.587 8.077-7.884A13.54 13.54 0 0 1 24.18 5.617a10.135 10.135 0 0 1 2.109 5.359zM10.668 9.594a1.564 1.564 0 0 0-2.095-1.472 1.52 1.52 0 0 0-.895 1.964 1.502 1.502 0 0 0 1.391.966 1.545 1.545 0 0 0 1.598-1.46v.002zm8.15-1.566a1.567 1.567 0 0 0-1.528 1.543 1.528 1.528 0 0 0 1.571 1.492 1.52 1.52 0 0 0 1.375-2.117 1.518 1.518 0 0 0-1.415-.919l-.003.001z"></path><path fill="currentColor" d="M33.914 32.137c-1.075-.478-2.062-1.196-3.11-1.306-1.049-.11-2.145.494-3.24.605a10.821 10.821 0 0 1-8.781-2.864c-4.682-4.33-4.013-10.97 1.403-14.518 4.811-3.154 11.874-2.102 15.268 2.273a8.671 8.671 0 0 1-1.002 12.095c-1.046.929-1.422 1.693-.751 2.917.102.257.174.525.213.798zM21.68 20.292a1.264 1.264 0 1 0 .01-2.528 1.264 1.264 0 0 0-.01 2.528zm7.887-2.526a1.266 1.266 0 0 0-1.256 1.21 1.247 1.247 0 1 0 1.256-1.21z"></path></svg>
</a></li>
</ul>
</div>
<div class="privacy-policy">
<div class="copyright">
<p>© Copyright The Linux Foundation. The PyTorch Foundation is a project of The Linux Foundation.
For web site terms of use, trademark policy and other policies applicable to The PyTorch Foundation please see
<a href="https://www.linuxfoundation.org/legal/policies/">Linux Foundation Policies</a>. The PyTorch Foundation supports the PyTorch open source
project, which has been established as PyTorch Project a Series of LF Projects, LLC. For policies applicable to the PyTorch Project a Series of LF Projects, LLC,
please see <a href="https://www.lfprojects.org/policies/">LF Projects, LLC Policies</a>. <a href="https://www.linuxfoundation.org/privacy">Privacy Policy</a> and <a href="https://www.linuxfoundation.org/terms">Terms of Use</a>.</p>
</div>
</div>
</div>
</footer>
<div class="mobile-main-menu">
<div class="container-fluid">
<div class="container">
<div class="mobile-main-menu-header-container">
<a class="header-logo" href="https://pytorch.org" aria-label="PyTorch"></a>
<a class="main-menu-close-button" href="#" data-behavior="close-mobile-menu"></a>
</div>
</div>
</div>
<div class="mobile-main-menu-links-container">
<div class="main-menu">
<ul>
<li class="navSearchWrapper reactNavSearchWrapper tabletSearchWrapper" key="search">
<div class="mobile-search-border">
<input
id="mobile-search-input"
type="text"
title="Search"
/>
<div id="mobile-search-icon"></div>
</div>
</li>
<li class="resources-mobile-menu-title">
<a>Learn</a>
</li>
<ul class="resources-mobile-menu-items">
<li>
<a href="/get-started">Get Started</a>
</li>
<li>
<a href="https://pytorch.org/tutorials">Tutorials</a>
</li>
<li>
<a href="https://pytorch.org/tutorials/beginner/basics/intro.html">Learn the Basics</a>
</li>
<li>
<a href="https://pytorch.org/tutorials/recipes/recipes_index.html">PyTorch Recipes</a>
</li>
<li>
<a href="https://pytorch.org/tutorials/beginner/introyt.html">Introduction to PyTorch - YouTube Series</a>
</li>
<li>
<a href="/new">New to PyTorch Foundation</a>
</li>
</ul>
<li class="resources-mobile-menu-title">
<a>Ecosystem</a>
</li>
<ul class="resources-mobile-menu-items">
<li>
<a href="https://landscape.pytorch.org/">Tools</a>
</li>
<li>
<a href="/join-ecosystem">Join the Ecosystem</a>
</li>
<li>
<a href="/#community-module">Community</a>
</li>
<li>
<a href="https://discuss.pytorch.org">Forums</a>
</li>
<li>
<a href="/resources">Developer Resources</a>
</li>
<li>
<a href="/ecosystem/contributor-awards-2024">Contributor Awards - 2024</a>
</li>
</ul>
<li class="resources-mobile-menu-title">
<a>Edge</a>
</li>
<ul class="resources-mobile-menu-items">
<li>
<a href="/edge">About PyTorch Edge</a>
</li>
<li>
<a href="/executorch-overview">ExecuTorch</a>
</li>
<li>
<a href="https://pytorch.org/executorch/stable/index.html">ExecuTorch Documentation</a>
</li>
</ul>
<li class="resources-mobile-menu-title">
<a>Docs</a>
</li>
<ul class="resources-mobile-menu-items">
<li>
<a href="https://pytorch.org/docs">PyTorch</a>
</li>
<li>
<a href="/pytorch-domains">PyTorch Domains</a>
</li>
</ul>
<li class="resources-mobile-menu-title">
<a>Blog & News</a>
</li>
<ul class="resources-mobile-menu-items">
<li>
<a href="/blog">PyTorch Blog</a>
</li>
<li>
<a href="/community-blog">Community Blog</a>
</li>
<li>
<a href="/videos">Videos</a>
</li>
<li>
<a href="/community-stories">Community Stories</a>
</li>
<li>
<a href="/events">Events</a>
</li>
<li>
<a href="/newsletter">Newsletter</a>
</li>
</ul>
<li class="resources-mobile-menu-title">
<a>About</a>
</li>
<ul class="resources-mobile-menu-items">
<li>
<a href="/foundation">PyTorch Foundation</a>
</li>
<li>
<a href="/governing-board">Governing Board</a>
</li>
<li>
<a href="/credits">Cloud Credit Program</a>
</li>
<li>
<a href="/tac">Technical Advisory Council</a>
</li>
<li>
<a href="/staff">Staff</a>
</li>
<li>
<a href="/contact-us">Contact Us</a>
</li>
</ul>
<li class="resources-mobile-menu-title">
<a href="/join">Become a Member</a>
</li>
<li class="resources-mobile-menu-title">
<a href="https://github.com/pytorch/pytorch" title="Go to PyTorch GitHub"><div id="topnav-gh-icon"></div></a>
</li>
</ul>
</div>
</div>
</div>
<script src="/assets/mobile-menu.js"></script>
<script src="/assets/scroll-to-anchor.js"></script>
<script src="/assets/external-links-new-tab.js"></script>
<script src="/assets/search-bar.js"></script>
<script src="/assets/cookie-banner.js"></script>
<script type="text/javascript">
mobileMenu.bind();
anchors.add('.pytorch-article h2, .pytorch-article h3, .pytorch-article h4, .pytorch-article h5');

// Links cannot be created inside code blocks, so tag every link that wraps an
// inline highlighted code element with the "has-code" class.
$("a code.highlighter-rouge").closest("a").addClass("has-code");

scrollToAnchor.bind();

// When the page uses one of these headers, the scroll-driven header fade below
// is not installed.
var hasStaticHeader = $(".blog-header, .blog-detail-header, .resources-header, .get-started-header, .features-header, .ecosystem-header, .hub-header, .mobile-header, .announcement-header, .comm-stories-header").length > 0;

if (!hasStaticHeader) {
  $(window).on("scroll", function() {
    var scrolled = $(this).scrollTop();
    // Scroll offset from which the header is drawn fully opaque black.
    var opaqueAt = $(".main-background").height() - $(".header-holder").height();
    var color;
    if (scrolled <= 40) {
      color = "rgba(0, 0, 0, 0.165)";
    } else if (scrolled >= opaqueAt) {
      color = "#000000";
    } else {
      // Alpha grows linearly with the scroll offset.
      color = "rgba(0, 0, 0, " + scrolled / opaqueAt + ")";
    }
    $(".header-holder").css({"backgroundColor": color});
  });
}
</script>
<script src="/assets/track-events.js"></script>
<script>trackEvents.bind();</script>
<div class="cookie-banner-wrapper">
<div class="container">
<p class="gdpr-notice">To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: <a href="https://www.facebook.com/policies/cookies/">Cookies Policy</a>.</p>
<img class="close-button" src="/assets/images/pytorch-x.svg" alt="Close">
</div>
</div>
</body>
</html>