<!DOCTYPE html>
<html lang="en">
<head>
<!-- Google Tag Manager -->
<script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start':
new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0],
j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src=
'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f);
})(window,document,'script','dataLayer','GTM-T8XT4PS');</script>
<!-- End Google Tag Manager -->
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta http-equiv="X-UA-Compatible" content="ie=edge">
<link rel="shortcut icon" type="image/x-icon" href="/favicon.ico?">
<title>
Accelerating Generative AI with PyTorch: Segment Anything 2 - Fast and furious inference with low latency and fast cold starts | PyTorch
</title>
<meta name="robots" content="index, follow" />
<meta name="description" content="This post is a follow-up to our first entry in the multi-series blog focused on how to accelerate generative AI models with pure, native PyTorch and a focus on latency and elastic scalability. We use torch.compile and torch.export to create highly optimized low latency versions of SAM2 that can be quickly scaled up on new instances.
" />
<meta property="og:image" content="https://pytorch.org/assets/images/social-share.jpg" />
<meta name="twitter:image" content="https://pytorch.org/assets/images/social-share.jpg" />
<meta property="og:locale" content="en_US" />
<meta property="og:type" content="website" />
<meta property="og:title" content="Accelerating Generative AI with PyTorch: Segment Anything 2 - Fast and furious inference with low latency and fast cold starts" />
<meta property="og:description" content="This post is a follow-up to our first entry in the multi-series blog focused on how to accelerate generative AI models with pure, native PyTorch and a focus on latency and elastic scalability. We use torch.compile and torch.export to create highly optimized low latency versions of SAM2 that can be quickly scaled up on new instances.
" />
<meta property="og:site_name" content="PyTorch" />
<meta name="twitter:card" content="summary_large_image" />
<meta name="twitter:title" content="Accelerating Generative AI with PyTorch: Segment Anything 2 - Fast and furious inference with low latency and fast cold starts" />
<meta name="twitter:description" content="This post is a follow-up to our first entry in the multi-series blog focused on how to accelerate generative AI models with pure, native PyTorch and a focus on latency and elastic scalability. We use torch.compile and torch.export to create highly optimized low latency versions of SAM2 that can be quickly scaled up on new instances.
" />
<link rel="stylesheet" href="/assets/main.css">
<script src="/assets/vendor/jquery.min.js"></script>
<script src="/assets/vendor/popper.min.js"></script>
<script src="/assets/vendor/bootstrap.min.js"></script>
<script src="/assets/vendor/anchor.min.js"></script>
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
skipTags: ['script', 'noscript', 'style', 'textarea', 'pre'],
inlineMath: [['$','$']]
}
});
</script>
<script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<script type="text/javascript" src="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.js"></script>
<script>
!function(f,b,e,v,n,t,s)
{if(f.fbq)return;n=f.fbq=function(){n.callMethod?
n.callMethod.apply(n,arguments):n.queue.push(arguments)};
if(!f._fbq)f._fbq=n;n.push=n;n.loaded=!0;n.version='2.0';
n.queue=[];t=b.createElement(e);t.async=!0;
t.src=v;s=b.getElementsByTagName(e)[0];
s.parentNode.insertBefore(t,s)}(window,document,'script',
'https://connect.facebook.net/en_US/fbevents.js');
fbq('init', '243028289693773');
fbq('track', 'PageView');
</script>
<noscript>
<img height="1" width="1"
src="https://www.facebook.com/tr?id=243028289693773&ev=PageView
&noscript=1"/>
</noscript>
<!-- Twitter universal website tag code -->
<img height="1" width="1" style="display:none;" alt="" src="https://analytics.twitter.com/i/adsct?p_id=Twitter&p_user_id=0&txn_id=o2gi1&events=%5B%5B%22pageview%22%2Cnull%5D%5D&tw_sale_amount=0&tw_order_quantity=0" />
<img height="1" width="1" style="display:none;" alt="" src="//t.co/i/adsct?p_id=Twitter&p_user_id=0&txn_id=o2gi1&events=%5B%5B%22pageview%22%2Cnull%5D%5D&tw_sale_amount=0&tw_order_quantity=0" />
<!-- End Twitter universal website tag code -->
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.css" />
<link href="/feed.xml" type="application/atom+xml" rel="alternate" title="PyTorch Blog Posts" />
</head>
<body class="blog">
<!-- Google Tag Manager (noscript) -->
<noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-T8XT4PS"
height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript>
<!-- End Google Tag Manager (noscript) -->
<div class="main-background blog-background blog-detail-background"></div>
<div class="hello-bar">
<div class="container">
Join us at PyTorch Conference in San Francisco, October 22-23. CFP open now! <a target="_blank" href="https://events.linuxfoundation.org/pytorch-conference/">Learn more</a>.
</div>
</div>
<div class="container-fluid header-holder blog-detail-header">
<div class="container">
<div class="header-container">
<a class="header-logo" href="https://pytorch.org" aria-label="PyTorch"></a>
<div class="main-menu">
<ul>
<li class="main-menu-item">
<div id="dropdownMenuButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Learn
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="/get-started">
<span class=dropdown-title>Get Started</span>
<p>Run PyTorch locally or get started quickly with one of the supported cloud platforms</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/tutorials/">
<span class="dropdown-title">Tutorials</span>
<p>What's new in PyTorch tutorials</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/tutorials/beginner/basics/intro.html">
<span class="dropdown-title">Learn the Basics</span>
<p>Familiarize yourself with PyTorch concepts and modules</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/tutorials/recipes/recipes_index.html">
<span class="dropdown-title">PyTorch Recipes</span>
<p>Bite-size, ready-to-deploy PyTorch code examples</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/tutorials/beginner/introyt.html">
<span class="dropdown-title">Intro to PyTorch - YouTube Series</span>
<p>Master PyTorch basics with our engaging YouTube tutorial series</p>
</a>
<a class="nav-dropdown-item" href="/new">
<span class="dropdown-title">New to PyTorch Foundation</span>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<div id="dropdownMenuButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Ecosystem
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="https://landscape.pytorch.org/" target="_blank">
<span class="dropdown-title">Tools</span>
<p>Learn about the tools and frameworks in the PyTorch Ecosystem</p>
</a>
<a class="nav-dropdown-item" href="/join-ecosystem">
<span class="dropdown-title">Join the Ecosystem</span>
</a>
<a class="nav-dropdown-item" href="/#community-module">
<span class=dropdown-title>Community</span>
<p>Join the PyTorch developer community to contribute, learn, and get your questions answered.</p>
</a>
<a class="nav-dropdown-item" href="https://discuss.pytorch.org" target="_blank">
<span class=dropdown-title>Forums</span>
<p>A place to discuss PyTorch code, issues, install, research</p>
</a>
<a class="nav-dropdown-item" href="/resources">
<span class=dropdown-title>Developer Resources</span>
<p>Find resources and get questions answered</p>
</a>
<a class="nav-dropdown-item" href="/ecosystem/contributor-awards-2024">
<span class="dropdown-title">Contributor Awards - 2024</span>
<p>Award winners announced at this year's PyTorch Conference</p>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<div id="dropdownMenuButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Edge
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="/edge">
<span class="dropdown-title">About PyTorch Edge</span>
<p>Build innovative and privacy-aware AI experiences for edge devices</p>
</a>
<a class="nav-dropdown-item" href="/executorch-overview">
<span class="dropdown-title">ExecuTorch</span>
<p>End-to-end solution for enabling on-device inference capabilities across mobile and edge devices</p>
</a>
<a class="nav-dropdown-item" target="_blank" href="https://pytorch.org/executorch/stable/index.html">
<span class="dropdown-title">ExecuTorch Documentation</span>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<div id="docsDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Docs
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="https://pytorch.org/docs">
<span class="dropdown-title">PyTorch</span>
<p>Explore the documentation for comprehensive guidance on how to use PyTorch.</p>
</a>
<a class="nav-dropdown-item" href="/pytorch-domains">
<span class="dropdown-title">PyTorch Domains</span>
<p> Read the PyTorch Domains documentation to learn more about domain-specific libraries.</p>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<div id="dropdownMenuButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Blog & News
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="/blog">
<span class="dropdown-title">PyTorch Blog</span>
<p>Catch up on the latest technical news and happenings</p>
</a>
<a class="nav-dropdown-item" href="/community-blog">
<span class="dropdown-title">Community Blog</span>
<p>Stories from the PyTorch ecosystem</p>
</a>
<a class="nav-dropdown-item" href="/videos">
<span class="dropdown-title">Videos</span>
<p>Learn about the latest PyTorch tutorials, news, and more</p>
</a>
<a class="nav-dropdown-item" href="/community-stories">
<span class="dropdown-title">Community Stories</span>
<p>Learn how our community solves real, everyday machine learning problems with PyTorch</p>
</a>
<a class="nav-dropdown-item" href="/events">
<span class=dropdown-title>Events</span>
<p>Find events, webinars, and podcasts</p>
</a>
<a class="nav-dropdown-item" href="/newsletter">
<span class=dropdown-title>Newsletter</span>
<p>Stay up-to-date with the latest updates</p>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<div id="resourcesDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
About
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="/foundation">
<span class=dropdown-title>PyTorch Foundation</span>
<p>Learn more about the PyTorch Foundation.</p>
</a>
<a class="nav-dropdown-item" href="/governing-board">
<span class=dropdown-title>Governing Board</span>
</a>
<a class="nav-dropdown-item" href="/credits">
<span class=dropdown-title>Cloud Credit Program</span>
</a>
<a class="nav-dropdown-item" href="/tac">
<span class=dropdown-title>Technical Advisory Council</span>
</a>
<a class="nav-dropdown-item" href="/staff">
<span class=dropdown-title>Staff</span>
</a>
<a class="nav-dropdown-item" href="/contact-us">
<span class=dropdown-title>Contact Us</span>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<a href="/join" data-cta="join">
Become a Member
</a>
</li>
<li class="main-menu-item" id="github-main-menu-link">
<a href="https://github.com/pytorch/pytorch" title="Go to PyTorch GitHub">
<div id="topnav-gh-icon"></div>
</a>
</li>
<li class="navSearchWrapper reactNavSearchWrapper" key="search">
<div class="search-border">
<div id="search-icon"></div>
<input
id="search-input"
type="text"
title="Search"
/>
<div id="close-search">X</div>
</div>
</li>
</ul>
</div>
<script src="/assets/main-menu-dropdown.js"></script>
<a class="main-menu-open-button" href="#" data-behavior="open-mobile-menu"></a>
</div>
</div>
</div>
<div class="jumbotron jumbotron-fluid blog-detail-jumbotron">
<div class="container blog-detail-container">
<p class="featured-post">February 26, 2025</p>
<h1>
<a class="blog-title">Accelerating Generative AI with PyTorch: Segment Anything 2 - Fast and furious inference with low latency and fast cold starts</a>
</h1>
</div>
</div>
<div class="main-content-wrapper blog-detail-wrapper">
<div class="main-content blog-detail-content">
<div class="container">
<img src="/assets/images/logo-icon.svg" class="img-fluid author-icon">
<article class="pytorch-article">
<p class="author">
by
Team PyTorch
</p>
<p>This post is a follow-up to our <a href="https://pytorch.org/blog/accelerating-generative-ai/">first entry in the multi-series blog focused on how to accelerate generative AI models</a> with pure, native PyTorch and a focus on latency and elastic scalability. We use torch.compile and torch.export to create highly optimized low latency versions of SAM2 that can be quickly scaled up on new instances.</p>
<p>By utilizing AOTInductor’s (AOTI) ahead-of-time compilation via torch.export, reduced precision, batched prompts and GPU preprocessing we observe up to <strong>13x improvement in p90 execution latency</strong> and <strong>queue times compared to regular eager mode PyTorch</strong>.</p>
<p>We calculate our final results and demonstrate the improvement in a realistic deployment on auto-scaling cloud infrastructure from <a href="https://modal.com">Modal</a>.</p>
<table class="table table-bordered">
<tr>
<td>
</td>
<td colspan="2">p50 execution latency
<br />
(ms / improvement)
</td>
<td colspan="2">p90 execution latency
<br />
(ms / improvement)
</td>
</tr>
<tr>
<td>
</td>
<td>eager float32
</td>
<td>AOTI float16
</td>
<td>eager float32
</td>
<td>AOTI float16
</td>
</tr>
<tr>
<td>AMG
</td>
<td>741
</td>
<td>112 (6.6x)
</td>
<td>1140
</td>
<td>176 (6.5x)
</td>
</tr>
<tr>
<td>SPS
</td>
<td>98
</td>
<td>20 (4.9x)
</td>
<td>130
</td>
<td>28 (4.6x)
</td>
</tr>
<tr>
<td>MPS
</td>
<td>269
</td>
<td>38 (7.1x)
</td>
<td>714
</td>
<td>52 (13.7x)
</td>
</tr>
</table>
<table class="table table-bordered">
<tr>
<td>
</td>
<td colspan="2">p50 queue time (ms / improvement)
</td>
<td colspan="2">p90 queue time (ms / improvement)
</td>
</tr>
<tr>
<td>
</td>
<td>eager float32
</td>
<td>AOTI float16
</td>
<td>eager float32
</td>
<td>AOTI float16
</td>
</tr>
<tr>
<td>AMG
</td>
<td>201
</td>
<td>41 (4.9x)
</td>
<td>815
</td>
<td>327 (2.6x)
</td>
</tr>
<tr>
<td>SPS
</td>
<td>31
</td>
<td>33 (0.9x)
</td>
<td>441
</td>
<td>49 (9.0x)
</td>
</tr>
<tr>
<td>MPS
</td>
<td>40
</td>
<td>37 (1.1x)
</td>
<td>942
</td>
<td>75 (12.6x)
</td>
</tr>
</table>
<h2 id="the-tasks">The Tasks</h2>
<p>The first post focused on processing a small number of varying prompts (points of interest) per image. These points represented the center points of the ground truth masks. For this post, we’ll focus on a broader set of tasks: single prompt segmentation (SPS), multi prompt segmentation (MPS), and automatic mask generation (AMG), which generates the full set of masks for the input image without a given set of prompts. The first post focused on MPS only.</p>
<p><img src="/assets/images/accelerating-generative-ai-2.jpg" alt="comparison of 3 images" style="width:100%" /></p>
<p>The little star in the image represents a user prompt. For AMG there are no prompts and masks are filtered down heuristically from a dense grid of initial candidate prompts (guesses). For SPS and MPS user prompts are derived from the center points of AMG masks. For SPS we choose the mask with the largest area.</p>
<p><strong>Note that SAM2 uses a different backbone than SAM1. In particular, we only consider the largest and most accurate sam2.1_hiera_large backbone for this blog.</strong></p>
<p>We aggregate the scripts needed to reproduce the results in <a href="https://github.com/pytorch/ao/tree/main/examples/sam2_amg_server">torchao’s example folder</a> and incrementally upstream the more stable parts of the <a href="https://github.com/pytorch/ao/tree/main/torchao/_models/sam2">changes to the SAM2 model in torchao</a> to the main <a href="https://github.com/facebookresearch/sam2">SAM2</a> repository. So if you are interested in taking a look at the cutting-edge variant or would like to contribute experimental features, please don’t hesitate to reach out to the torchao repository and team. For the more stable and latest model version, please head on over to SAM2 directly.</p>
<h2 id="overview">Overview</h2>
<p>We categorize the changes presented here into two groups. <strong>Fast</strong> changes constrain themselves to techniques that are not meant to affect model accuracy. <strong>Furious</strong> changes sacrifice some numerical accuracy for additional speed by making use of approximations such as low-precision data types.</p>
<p>Approximations may slightly lower precision metrics in favor of significantly improved performance while still passing an end-to-end check based on mean intersection over union (mIoU).</p>
<p>To measure the performance improvements we processed 1000 images, which were selected at random from the SAM2 validation dataset. We look at the p50 and p90 latency per image. To measure accuracy we consider the mIoU. Most notably for the AMG task we also define a fail count metric. We consider a comparison failed if the <strong>number of masks</strong> differs. This turns out to be a fairly unstable quantity and we can see that the other tasks are not as sensitive to small numeric changes as AMG.</p>
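<p>As a minimal sketch of how p50 and p90 can be derived from per-image timings, the snippet below uses the nearest-rank definition of a percentile. The post does not state which percentile definition was used, so this choice is an assumption, and the latency values are made up for illustration.</p>

```python
import math


def percentile(values, q):
    """Nearest-rank percentile: the smallest sample such that at least
    q percent of all samples are less than or equal to it."""
    if not values:
        raise ValueError("need at least one sample")
    ordered = sorted(values)
    rank = max(1, math.ceil(q * len(ordered) / 100))
    return ordered[rank - 1]


# Hypothetical per-image latencies in milliseconds (not measurements from this post).
latencies_ms = [98, 101, 97, 130, 99, 102, 100, 96, 250, 103]

p50 = percentile(latencies_ms, 50)
p90 = percentile(latencies_ms, 90)
print(f"p50={p50} ms, p90={p90} ms")
```

<p>Reporting p90 next to p50 makes slow outliers visible that a median alone would hide.</p>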
<h2 id="the-setup">The Setup</h2>
<p>We are running the offline experiments on a regular H100 devserver, which is a fairly beefy and performant machine.</p>
<p>However, we try to look at these tasks with realistic constraints. In particular, we would like to emulate a server-side inference environment. That means we don’t use DataLoader to hide the latency of image preprocessing or decoding routines.</p>
<p>For the latency calculations we include decoding, segmentation and conversion of masks to a dictionary of run-length encoded masks. Put differently, we exclude loading the images into in-memory host bytearrays and storing the resulting dictionaries as JSON files on disk. This is meant to emulate a more realistic setting.</p>
<p>More concretely, consider the code below for the routines we include in our measurements. For any task, <code class="language-plaintext highlighter-rouge">gen_masks</code> produces a batched bool Tensor bitmask that represents the corresponding object masks. We then compress this bitmask into a run-length encoded (RLE) format that can be used to transfer the results back from a remote server much more efficiently.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>image_tensors = decode_img_bytes(...)
masks = gen_masks(image_tensors, ...)
rle_dicts = [rle_dict_from_masks(m) for m in masks]
</code></pre></div></div>
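<p>To illustrate why a run-length encoding is so much more compact than a dense bitmask, here is a small pure-Python sketch. It is not the <code class="language-plaintext highlighter-rouge">rle_dict_from_masks</code> routine used above: the row-major traversal, the dictionary keys and the helper names <code class="language-plaintext highlighter-rouge">rle_encode</code> and <code class="language-plaintext highlighter-rouge">rle_decode</code> are assumptions made for this example.</p>

```python
def rle_encode(mask):
    """Run-length encode a 2D boolean mask in row-major order.

    The counts alternate between runs of False and runs of True and
    always start with a (possibly empty) run of False.
    """
    height = len(mask)
    width = len(mask[0]) if height else 0
    counts = []
    current, run = False, 0
    for row in mask:
        for value in row:
            if bool(value) == current:
                run += 1
            else:
                counts.append(run)
                current, run = bool(value), 1
    counts.append(run)
    return {"size": [height, width], "counts": counts}


def rle_decode(rle):
    """Invert rle_encode and return the mask as a list of rows."""
    height, width = rle["size"]
    flat = []
    value = False
    for run in rle["counts"]:
        flat.extend([value] * run)
        value = not value
    return [flat[r * width:(r + 1) * width] for r in range(height)]


mask = [
    [False, False, True, True],
    [False, True, True, True],
    [False, False, False, False],
]
rle = rle_encode(mask)
print(rle)
```

<p>Object masks tend to consist of a few long runs, so the list of counts is usually far shorter than the number of pixels.</p>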
<h2 id="optimizations">Optimizations</h2>
<h3 id="ao-eager-code-optimizations">ao: eager code optimizations</h3>
<p>The most effective tool for this work is the PyTorch autograd profiler combined with <code class="language-plaintext highlighter-rouge">record_function</code>. To build this software, we’ve used the profiler repeatedly to observe the program and confirm the effectiveness of any changes. It’s also important to keep in mind that the profiler itself has overhead. The more data you collect, such as stack traces, the more overhead you introduce, which might skew the collected trace. But it is excellent for finding synchronization points, gaps between kernels and GPU kernels that take a long time.</p>
<p>GPU traces help you understand bottlenecks that are not necessarily easily addressed by compile. We found that AutomaticMaskGeneration in particular is dominated by the data structure used to store the masks and by the routine used to convert the masks to a run-length encoded compressed format. We also found a large part of AMG performance is dominated by the large number of masks created as a single batch. Sometimes candidate masks can be filtered down to fewer candidates earlier in the postprocessing stage by reordering operations. This in turn significantly speeds up the later operations.</p>
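<p>The idea of filtering candidates earlier can be sketched in plain Python. The example assumes that the filter criterion (here a hypothetical per-candidate score) does not depend on the result of the expensive step. Only then does reordering leave the output unchanged. The function names, scores and masks are invented for illustration and do not mirror the actual AMG postprocessing code.</p>

```python
def make_counting_op():
    """Return a stand-in for an expensive per-mask operation plus a call counter."""
    calls = {"n": 0}

    def expensive_op(mask):
        calls["n"] += 1
        # Placeholder for costly work such as compressing or refining a mask.
        return sum(sum(1 for v in row if v) for row in mask)

    return expensive_op, calls


def filter_late(candidates, min_score, op):
    """Run the expensive op on every candidate, then drop low-scoring ones."""
    processed = [(score, op(mask)) for score, mask in candidates]
    return [result for score, result in processed if score >= min_score]


def filter_early(candidates, min_score, op):
    """Drop low-scoring candidates first, then run the expensive op on survivors."""
    kept = [mask for score, mask in candidates if score >= min_score]
    return [op(mask) for mask in kept]


# Hypothetical (score, mask) candidates; scores and masks are made up.
candidates = [
    (0.95, [[True, True], [False, True]]),
    (0.40, [[True, False], [False, False]]),
    (0.91, [[True, True], [True, True]]),
    (0.10, [[False, False], [False, True]]),
]

late_op, late_calls = make_counting_op()
early_op, early_calls = make_counting_op()
late = filter_late(candidates, 0.88, late_op)
early = filter_early(candidates, 0.88, early_op)
print(late == early, late_calls["n"], early_calls["n"])
```

<p>Both orderings return the same results, but the early filter only pays for the candidates that survive.</p>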
<p>In order to confirm the accuracy of our implementation we first compare without any changes in settings and using float32 precision. We see that mIoU is unchanged and the masks match perfectly when using the exact same settings. This means that these eager mode changes did not affect the accuracy of these tasks.</p>
<p>AMG</p>
<table class="table table-bordered">
<tr>
<td>
</td>
<td>p50 latency (ms)
</td>
<td>p90 latency (ms)
</td>
<td>memory (MiB)
</td>
<td>mIoU / fail count
</td>
</tr>
<tr>
<td>Baseline
</td>
<td>864
</td>
<td>1144
</td>
<td>4350
</td>
<td>reference
</td>
</tr>
<tr>
<td>AO
</td>
<td>693
</td>
<td>786
</td>
<td>4010
</td>
<td>1 / 0
</td>
</tr>
</table>
<h3 id="ao-batching-prompts">ao: batching prompts</h3>
<p>Another lossless performance optimization that we were able to apply is batching the user input prompt calculations. When optimizing for latency at batch size 1 on a server-grade GPU such as an H100 we are often left with a lot of spare memory. We can easily trade off that memory for more performance by processing more points of interest (also called user prompts) at once. Remember that SAM2 is split into two parts: First the backbone (image encoder), second the prediction and decoding of masks based on a set of user prompts / points of interest. It is the second part where we may expect a larger or even varying number of inputs and it is this second part where we apply batching.</p>
<p>This causes a large increase in memory, but also much better latency. The baseline generates one mask per prompt in a loop. For AMG the baseline processes 64 prompts at once and all that is needed is to change it to 1024, which is the number of candidate prompts generated. For SPS we process one prompt at a time, but it’s still included below for completeness.</p>
<p>AMG</p>
<table class="table table-bordered">
<tr>
<td>
</td>
<td>p50 latency (ms)
</td>
<td>p90 latency (ms)
</td>
<td>memory (MiB)
</td>
<td>mIoU / fail count
</td>
</tr>
<tr>
<td>Baseline
</td>
<td>864
</td>
<td>1144
</td>
<td>4350
</td>
<td>reference
</td>
</tr>
<tr>
<td>AO + batching
</td>
<td>613
</td>
<td>706
</td>
<td>33786
</td>
<td>0.9999995 / 0
</td>
</tr>
</table>
<p>SPS</p>
<table class="table table-bordered">
<tr>
<td>
</td>
<td>p50 latency (ms)
</td>
<td>p90 latency (ms)
</td>
<td>memory (MiB)
</td>
<td>mIoU
</td>
</tr>
<tr>
<td>Baseline
</td>
<td>116
</td>
<td>181
</td>
<td>1337
</td>
<td>reference
</td>
</tr>
<tr>
<td>AO
</td>
<td>110
</td>
<td>170
</td>
<td>1339
</td>
<td>1
</td>
</tr>
</table>
<p>MPS</p>
<table class="table table-bordered">
<tr>
<td>
</td>
<td>p50 latency (ms)
</td>
<td>p90 latency (ms)
</td>
<td>memory (MiB)
</td>
<td>mIoU
</td>
</tr>
<tr>
<td>Baseline
</td>
<td>276
</td>
<td>681
</td>
<td>1337
</td>
<td>reference
</td>
</tr>
<tr>
<td>AO + batching
</td>
<td>126
</td>
<td>225
</td>
<td>8021
</td>
<td>0.9999992
</td>
</tr>
</table>
<p>As a technical side note: to enable batching for MPS without a significant manual rewrite of the code base to support multiple prompts at the same time, we used a Tensor subclass we call MapTensor. A MapTensor allows us to pass a batch of N prompts, but have it advertise a batch size of 1. Any operation is then automatically broadcast to the wrapped Tensor and propagated throughout the prediction part of the model. This works because individual prompt predictions are independent of one another. This is very similar to torch.vmap.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>center_points_torch = to_map_tensor(center_points_torch)
center_points_label_torch = to_map_tensor(center_points_label_torch)
masks, scores, _ = mask_generator.predictor.predict(
    point_coords=center_points_torch,
    point_labels=center_points_label_torch,
    multimask_output=True,
    return_logits=False,
    return_type="torch",
)
# Unwrapping MapTensor
masks = masks.elems
scores = scores.elems
</code></pre></div></div>
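<p>The mechanism behind this kind of wrapper can be illustrated with a toy class in plain Python. This is not torchao’s MapTensor, which is a Tensor subclass that intercepts PyTorch operations. The class <code class="language-plaintext highlighter-rouge">MapValue</code> and the function <code class="language-plaintext highlighter-rouge">score_prompt</code> below are hypothetical and only show how code written for a single prompt can run unchanged over a wrapped batch.</p>

```python
class MapValue:
    """Toy wrapper: holds N independent elements and applies every
    arithmetic operation to each element separately."""

    def __init__(self, elems):
        self.elems = list(elems)

    def _map(self, other, op):
        if isinstance(other, MapValue):
            if len(other.elems) != len(self.elems):
                raise ValueError("wrapped batches must have the same length")
            return MapValue(op(a, b) for a, b in zip(self.elems, other.elems))
        # Plain values are broadcast to every wrapped element.
        return MapValue(op(a, other) for a in self.elems)

    def __add__(self, other):
        return self._map(other, lambda a, b: a + b)

    def __sub__(self, other):
        return self._map(other, lambda a, b: a - b)

    def __mul__(self, other):
        return self._map(other, lambda a, b: a * b)

    # Addition and multiplication commute for plain numbers.
    __radd__ = __add__
    __rmul__ = __mul__


def score_prompt(x, y):
    """Hypothetical computation written with a single prompt in mind."""
    return (x - 1.0) * 2.0 + y


xs = [1.0, 2.0, 3.0]
ys = [0.5, 0.25, 0.0]

looped = [score_prompt(x, y) for x, y in zip(xs, ys)]
batched = score_prompt(MapValue(xs), MapValue(ys)).elems  # unwrap, as with .elems above
print(looped == batched)
```

<p>Because the per-prompt computations are independent, the wrapped call produces the same values as the explicit loop.</p>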
<h3 id="fast-fullgraph-compilation">fast: fullgraph compilation</h3>
<p>Just as with our first post, we first remove GPU syncs and graph breaks to make use of fullgraph compiled model code with max-autotune kernels where appropriate. After some rewriting, we are able to compile the image encoder and the prediction of masks.</p>
<p>We run the experiments twice to get a sense of the overhead due to compilation. We run it once in an environment with an empty TORCHINDUCTOR_CACHE_DIR and then again while ingesting the artifacts from the previous run. In particular, auto-tuning can take a long time and happens on the first call in a pristine environment. We call the second run “warm”. The first iteration is typically expected to be slow due to various other related initialization processes, but compilation increases its latency significantly, even if an existing cache is used and the exact same shapes are fed again. Having said that, an overhead of a few seconds in a warm environment is often still tolerable on the very first call.</p>
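<p>A minimal sketch of this measurement setup follows. It points TORCHINDUCTOR_CACHE_DIR at a directory of our choosing and reports the first iteration separately from the remaining ones. The cache path is hypothetical, and <code class="language-plaintext highlighter-rouge">LazyModel</code> is a stand-in that merely simulates a one-off startup cost. It does not call torch.compile.</p>

```python
import os
import tempfile
import time

# Hypothetical cache location; set it before the first compiled call so
# that a later ("warm") run can reuse the artifacts written by this one.
cache_dir = os.path.join(tempfile.gettempdir(), "inductor_cache_demo")
os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir


class LazyModel:
    """Stand-in for a compiled model: pays a one-off cost on the first call."""

    def __init__(self, startup_s):
        self.startup_s = startup_s
        self.ready = False

    def __call__(self, x):
        if not self.ready:
            time.sleep(self.startup_s)  # simulates compilation or cache loading
            self.ready = True
        return x * 2


def time_iterations(fn, inputs):
    """Return (first_ms, rest_ms) so startup cost is reported separately."""
    timings = []
    for x in inputs:
        start = time.perf_counter()
        fn(x)
        timings.append((time.perf_counter() - start) * 1000)
    return timings[0], timings[1:]


first_ms, rest_ms = time_iterations(LazyModel(startup_s=0.05), [1, 2, 3, 4, 5])
print(f"first iteration: {first_ms:.1f} ms, later iterations: {max(rest_ms):.3f} ms max")
```

<p>Keeping the first iteration out of the p50 and p90 statistics avoids mixing one-off startup cost with steady-state latency.</p>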
<p>Most of these drawbacks can be mitigated and compiling causes a significant improvement in latency and reduction in memory.</p>
<p>AMG</p>
<table class="table table-bordered">
<tr>
<td>
</td>
<td>p50 latency (ms)
</td>
<td>p90 latency (ms)
</td>
<td>memory (MiB)
</td>
<td>mIoU /
<br />
fail count
</td>
<td>first iteration
<br />
(ms)
</td>
</tr>
<tr>
<td>AO + batching
</td>
<td>613
</td>
<td>706
</td>
<td>33786
</td>
<td>0.9999995 / 0
</td>
<td>1125
</td>
</tr>
<tr>
<td>+ compile (cold)
</td>
<td>423
</td>
<td>513
</td>
<td>29349
</td>
<td>skipped
</td>
<td>404866
</td>
</tr>
<tr>
<td>+ compile (warm)
</td>
<td>439
</td>
<td>530
</td>
<td>29349
</td>
<td>0.994 / 190
</td>
<td>8544
</td>
</tr>
</table>
<p>The number of masks produced per image can vary slightly when using automatic mask generation. There is ambiguity in the number of masks per object the model may produce. For example, a car may be subdivided into frames, windows and doors or treated as a whole. When a modification causes the number of masks to change, we consider the comparison failed and we only calculate the mIoU on masks with an exact match. This does not apply to the other tasks. We found that the number of masks generated is very sensitive to small numerical changes. The other tasks use the same code and MPS in particular can help us further verify correctness.</p>
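<p>The following sketch shows one way to compute an mIoU and a fail count in the spirit of the description above. It rests on several assumptions: masks are represented as sets of pixel coordinates, masks of the reference and the modified run are paired by position, and images with a differing number of masks are counted as failed and left out of the mIoU. How the actual evaluation scripts pair masks is not described here.</p>

```python
def iou(mask_a, mask_b):
    """Intersection over union of two masks given as sets of (row, col) pixels."""
    union = mask_a | mask_b
    if not union:
        return 1.0  # two empty masks are treated as identical
    return len(mask_a & mask_b) / len(union)


def compare_runs(reference, candidate):
    """Compare two runs, each given as one list of masks per image.

    Returns (mIoU over images with equal mask counts, fail count).
    """
    fail_count = 0
    ious = []
    for ref_masks, cand_masks in zip(reference, candidate):
        if len(ref_masks) != len(cand_masks):
            fail_count += 1  # the number of masks changed: comparison failed
            continue
        ious.extend(iou(r, c) for r, c in zip(ref_masks, cand_masks))
    miou = sum(ious) / len(ious) if ious else float("nan")
    return miou, fail_count


# Hypothetical results for two images (not data from this post).
reference = [
    [{(0, 0), (0, 1)}, {(1, 1)}],
    [{(2, 2)}, {(3, 3)}],
]
candidate = [
    [{(0, 0), (0, 1)}, {(1, 1), (1, 2)}],  # same count, second mask grew
    [{(2, 2)}],                            # one mask went missing
]

miou, fail_count = compare_runs(reference, candidate)
print(miou, fail_count)
```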
<p>SPS</p>
<table class="table table-bordered">
<tr>
<td>
</td>
<td>p50 latency (ms)
</td>
<td>p90 latency (ms)
</td>
<td>memory (MiB)
</td>
<td>mIoU
</td>
<td>first iteration
<br />
(ms)
</td>
</tr>
<tr>
<td>AO
</td>
<td>110
</td>
<td>170
</td>
<td>1339
</td>
<td>1
</td>
<td>562
</td>
</tr>
<tr>
<td>+ compile (cold)
</td>
<td>102
</td>
<td>158
</td>
<td>1343
</td>
<td>skipped
</td>
<td>319954
</td>
</tr>
<tr>
<td>+ compile (warm)
</td>
<td>100
</td>
<td>160
</td>
<td>1302
</td>
<td>0.9999
</td>
<td>8947
</td>
</tr>
</table>
<p>MPS</p>
<table class="table table-bordered">
<tr>
<td>
</td>
<td>p50 latency (ms)
</td>
<td>p90 latency (ms)
</td>
<td>memory (MiB)
</td>
<td>mIoU
</td>
<td>first iteration
<br />
(ms)
</td>
</tr>
<tr>
<td>AO + batching
</td>
<td>126
</td>
<td>225
</td>
<td>8021
</td>
<td>0.9999992
</td>
<td>504
</td>
</tr>
<tr>
<td>+ compile (cold)
</td>
<td>129
</td>
<td>215
</td>
<td>8021
</td>
<td>skipped
</td>
<td>333308
</td>
</tr>
<tr>
<td>+ compile (warm)
</td>
<td>113
</td>
<td>213
</td>
<td>8021
</td>
<td>0.998
</td>
<td>8617
</td>
</tr>
</table>
<h3 id="furious-tf32-float16-and-gpu-preprocessing">furious: TF32, float16 and GPU preprocessing</h3>
<p>We found that float16 is the right level of precision for a few significant subcomponents of the model. In particular, the image encoder and mask decoder weights can be converted entirely to float16. We can also use TensorFloat32 precision for the remaining float32 matrix operations. It should be possible to further reduce the precision and we may address this in a future post. In furious mode we also move image preprocessing such as image normalization onto the GPU. We can’t use GPU decoding (nvJPEG) routines, because the differences are too large and the model suffers a significant degradation in mIoU, so image decoding still happens on the CPU.</p>
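<p>To give a feel for how much precision these formats retain, the sketch below rounds a value to float16 using Python’s <code class="language-plaintext highlighter-rouge">struct</code> module and approximates TF32 by truncating a float32 to 10 mantissa bits. Both float16 and TF32 keep 10 explicit mantissa bits, while TF32 keeps float32’s 8 exponent bits. The truncation is only an approximation of the format’s precision, since actual hardware may round differently, and the score and threshold values are hypothetical.</p>

```python
import struct


def to_float16(x):
    """Round a Python float to IEEE half precision and back (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_tf32_truncated(x):
    """Approximate TF32 by keeping 10 of float32's 23 mantissa bits (truncation)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= 0xFFFFE000  # clear the low 13 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


score = 0.8801      # hypothetical mask quality score
threshold = 0.88    # hypothetical filter threshold kept in full precision

half = to_float16(score)
tf32 = to_tf32_truncated(score)

print(f"float16: {half!r}, relative error {abs(half - score) / score:.2e}")
print(f"tf32 (truncated): {tf32!r}, relative error {abs(tf32 - score) / score:.2e}")
print("passes threshold in full precision:", score >= threshold)
print("passes threshold after float16 rounding:", half >= threshold)
```

<p>A relative error of this size is harmless for most values, but it can flip a decision for scores that sit right at a threshold, which is one way small numerical changes can alter the number of masks that survive filtering.</p>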
<p>AMG</p>
<table class="table table-bordered">
<tr>
<td>
</td>
<td>p50 latency (ms)
</td>
<td>p90 latency (ms)
</td>
<td>memory (MiB)
</td>
<td>mIoU /
<br />
fail count
</td>
</tr>
<tr>
<td>AO
<br />
+ batching
<br />
+ compile (warm)
</td>
<td>439
</td>
<td>530
</td>
<td>29349
</td>
<td>0.994 / 190
</td>
</tr>
<tr>
<td>+ furious
</td>
<td>165
</td>
<td>240
</td>
<td>28335
</td>
<td>0.978 / 306
</td>
</tr>
</table>
<p>This causes a significant degradation in mIoU for the AMG task, but has little effect on the other tasks. After an in-depth investigation, we still chalk this up to numerical instability and reordering of operations. More work is needed to further investigate this and it may not be worthwhile to run the AMG task in lower precision. The other tasks, however, benefit drastically in latency with minimal changes in mIoU.</p>
<p>SPS</p>
<table class="table table-bordered">
<tr>
<td>
</td>
<td>p50 latency (ms)
</td>
<td>p90 latency (ms)
</td>
<td>memory (MiB)
</td>
<td>mIoU
</td>
</tr>
<tr>
<td>AO
<br />
+ compile (warm)
</td>
<td>100
</td>
<td>160
</td>
<td>1302
</td>
<td>0.9999
</td>
</tr>
<tr>
<td>+ furious
</td>
<td>32
</td>
<td>63
</td>
<td>861
</td>
<td>0.9997
</td>
</tr>
</table>
<p>MPS</p>
<table class="table table-bordered">
<tr>
<td>
</td>
<td>p50 latency (ms)
</td>
<td>p90 latency (ms)
</td>
<td>memory (MiB)
</td>