-
Notifications
You must be signed in to change notification settings - Fork 302
/
Copy pathindex.html
913 lines (784 loc) · 46.7 KB
/
index.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
<!DOCTYPE html>
<html lang="en">
<head>
<!-- Google Tag Manager -->
<script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start':
new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0],
j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src=
'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f);
})(window,document,'script','dataLayer','GTM-T8XT4PS');</script>
<!-- End Google Tag Manager -->
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta http-equiv="X-UA-Compatible" content="ie=edge">
<link rel="shortcut icon" type="image/x-icon" href="/favicon.ico?">
<title>
Compiling NumPy code into C++ or CUDA via torch.compile | PyTorch
</title>
<meta name="robots" content="index, follow" />
<meta name="description" content="Quansight engineers have implemented support for tracing through NumPy code via
torch.compile in PyTorch 2.1. This feature leverages PyTorch’s compiler to
generate efficient fused vectorized code without having to modify your original
NumPy code. Even more, it also allows for executing NumPy code on CUDA
just by running it through torch.compile under torch.device(&quot;cuda&quot;)!
" />
<meta property="og:image" content="https://pytorch.org/assets/images/social-share.jpg" />
<meta name="twitter:image" content="https://pytorch.org/assets/images/social-share.jpg" />
<meta property="og:locale" content="en_US" />
<meta property="og:type" content="website" />
<meta property="og:title" content="Compiling NumPy code into C++ or CUDA via torch.compile" />
<meta property="og:description" content="Quansight engineers have implemented support for tracing through NumPy code via
torch.compile in PyTorch 2.1. This feature leverages PyTorch’s compiler to
generate efficient fused vectorized code without having to modify your original
NumPy code. Even more, it also allows for executing NumPy code on CUDA
just by running it through torch.compile under torch.device(&quot;cuda&quot;)!
" />
<meta property="og:site_name" content="PyTorch" />
<meta name="twitter:card" content="summary_large_image" />
<meta name="twitter:title" content="Compiling NumPy code into C++ or CUDA via torch.compile" />
<meta name="twitter:description" content="Quansight engineers have implemented support for tracing through NumPy code via
torch.compile in PyTorch 2.1. This feature leverages PyTorch’s compiler to
generate efficient fused vectorized code without having to modify your original
NumPy code. Even more, it also allows for executing NumPy code on CUDA
just by running it through torch.compile under torch.device(&quot;cuda&quot;)!
" />
<link rel="stylesheet" href="/assets/main.css">
<script src="/assets/vendor/jquery.min.js"></script>
<script src="/assets/vendor/popper.min.js"></script>
<script src="/assets/vendor/bootstrap.min.js"></script>
<script src="/assets/vendor/anchor.min.js"></script>
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
skipTags: ['script', 'noscript', 'style', 'textarea', 'pre'],
inlineMath: [['$','$']]
}
});
</script>
<script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<script type="text/javascript" src="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.js"></script>
<script>
!function(f,b,e,v,n,t,s)
{if(f.fbq)return;n=f.fbq=function(){n.callMethod?
n.callMethod.apply(n,arguments):n.queue.push(arguments)};
if(!f._fbq)f._fbq=n;n.push=n;n.loaded=!0;n.version='2.0';
n.queue=[];t=b.createElement(e);t.async=!0;
t.src=v;s=b.getElementsByTagName(e)[0];
s.parentNode.insertBefore(t,s)}(window,document,'script',
'https://connect.facebook.net/en_US/fbevents.js');
fbq('init', '243028289693773');
fbq('track', 'PageView');
</script>
<noscript>
<img height="1" width="1" alt=""
src="https://www.facebook.com/tr?id=243028289693773&amp;ev=PageView&amp;noscript=1"/>
</noscript>
<!-- Twitter universal website tag code -->
<img height="1" width="1" style="display:none;" alt="" src="https://analytics.twitter.com/i/adsct?p_id=Twitter&amp;p_user_id=0&amp;txn_id=o2gi1&amp;events=%5B%5B%22pageview%22%2Cnull%5D%5D&amp;tw_sale_amount=0&amp;tw_order_quantity=0" />
<img height="1" width="1" style="display:none;" alt="" src="//t.co/i/adsct?p_id=Twitter&amp;p_user_id=0&amp;txn_id=o2gi1&amp;events=%5B%5B%22pageview%22%2Cnull%5D%5D&amp;tw_sale_amount=0&amp;tw_order_quantity=0" />
<!-- End Twitter universal website tag code -->
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.css" />
<link href="/feed.xml" type="application/atom+xml" rel="alternate" title="PyTorch Blog Posts" />
</head>
<body class="blog">
<!-- Google Tag Manager (noscript) -->
<noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-T8XT4PS"
height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript>
<!-- End Google Tag Manager (noscript) -->
<div class="main-background blog-background blog-detail-background"></div>
<div class="hello-bar">
<div class="container">
Join us at PyTorch Conference in San Francisco, October 22-23. CFP open now! <a target="_blank" href="https://events.linuxfoundation.org/pytorch-conference/">Learn more</a>.
</div>
</div>
<div class="container-fluid header-holder blog-detail-header">
<div class="container">
<div class="header-container">
<a class="header-logo" href="https://pytorch.org" aria-label="PyTorch"></a>
<div class="main-menu">
<ul>
<li class="main-menu-item">
<div id="dropdownMenuButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Learn
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="/get-started">
<span class=dropdown-title>Get Started</span>
<p>Run PyTorch locally or get started quickly with one of the supported cloud platforms</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/tutorials/">
<span class="dropdown-title">Tutorials</span>
<p>What's new in PyTorch tutorials</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/tutorials/beginner/basics/intro.html">
<span class="dropdown-title">Learn the Basics</span>
<p>Familiarize yourself with PyTorch concepts and modules</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/tutorials/recipes/recipes_index.html">
<span class="dropdown-title">PyTorch Recipes</span>
<p>Bite-size, ready-to-deploy PyTorch code examples</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/tutorials/beginner/introyt.html">
<span class="dropdown-title">Intro to PyTorch - YouTube Series</span>
<p>Master PyTorch basics with our engaging YouTube tutorial series</p>
</a>
<a class="nav-dropdown-item" href="/new">
<span class="dropdown-title">New to PyTorch Foundation</span>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<div id="dropdownMenuButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Ecosystem
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="https://landscape.pytorch.org/" target="_blank">
<span class="dropdown-title">Tools</span>
<p>Learn about the tools and frameworks in the PyTorch Ecosystem</p>
</a>
<a class="nav-dropdown-item" href="/join-ecosystem">
<span class="dropdown-title">Join the Ecosystem</span>
</a>
<a class="nav-dropdown-item" href="/#community-module">
<span class=dropdown-title>Community</span>
<p>Join the PyTorch developer community to contribute, learn, and get your questions answered.</p>
</a>
<a class="nav-dropdown-item" href="https://discuss.pytorch.org" target="_blank">
<span class=dropdown-title>Forums</span>
<p>A place to discuss PyTorch code, issues, install, research</p>
</a>
<a class="nav-dropdown-item" href="/resources">
<span class=dropdown-title>Developer Resources</span>
<p>Find resources and get questions answered</p>
</a>
<a class="nav-dropdown-item" href="/ecosystem/contributor-awards-2024">
<span class="dropdown-title">Contributor Awards - 2024</span>
<p>Award winners announced at this year's PyTorch Conference</p>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<div id="dropdownMenuButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Edge
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="/edge">
<span class="dropdown-title">About PyTorch Edge</span>
<p>Build innovative and privacy-aware AI experiences for edge devices</p>
</a>
<a class="nav-dropdown-item" href="/executorch-overview">
<span class="dropdown-title">ExecuTorch</span>
<p>End-to-end solution for enabling on-device inference capabilities across mobile and edge devices</p>
</a>
<a class="nav-dropdown-item" target="_blank" href="https://pytorch.org/executorch/stable/index.html">
<span class="dropdown-title">ExecuTorch Documentation</span>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<div id="docsDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Docs
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="https://pytorch.org/docs">
<span class="dropdown-title">PyTorch</span>
<p>Explore the documentation for comprehensive guidance on how to use PyTorch.</p>
</a>
<a class="nav-dropdown-item" href="/pytorch-domains">
<span class="dropdown-title">PyTorch Domains</span>
<p> Read the PyTorch Domains documentation to learn more about domain-specific libraries.</p>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<div id="dropdownMenuButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Blog & News
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="/blog">
<span class="dropdown-title">PyTorch Blog</span>
<p>Catch up on the latest technical news and happenings</p>
</a>
<a class="nav-dropdown-item" href="/community-blog">
<span class="dropdown-title">Community Blog</span>
<p>Stories from the PyTorch ecosystem</p>
</a>
<a class="nav-dropdown-item" href="/videos">
<span class="dropdown-title">Videos</span>
<p>Learn about the latest PyTorch tutorials, news, and more</p>
</a>
<a class="nav-dropdown-item" href="/community-stories">
<span class="dropdown-title">Community Stories</span>
<p>Learn how our community solves real, everyday machine learning problems with PyTorch</p>
</a>
<a class="nav-dropdown-item" href="/events">
<span class=dropdown-title>Events</span>
<p>Find events, webinars, and podcasts</p>
</a>
<a class="nav-dropdown-item" href="/newsletter">
<span class=dropdown-title>Newsletter</span>
<p>Stay up-to-date with the latest updates</p>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<div id="resourcesDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
About
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="/foundation">
<span class=dropdown-title>PyTorch Foundation</span>
<p>Learn more about the PyTorch Foundation.</p>
</a>
<a class="nav-dropdown-item" href="/governing-board">
<span class=dropdown-title>Governing Board</span>
</a>
<a class="nav-dropdown-item" href="/credits">
<span class=dropdown-title>Cloud Credit Program</span>
</a>
<a class="nav-dropdown-item" href="/tac">
<span class=dropdown-title>Technical Advisory Council</span>
</a>
<a class="nav-dropdown-item" href="/staff">
<span class=dropdown-title>Staff</span>
</a>
<a class="nav-dropdown-item" href="/contact-us">
<span class=dropdown-title>Contact Us</span>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<a href="/join" data-cta="join">
Become a Member
</a>
</li>
<li class="main-menu-item" id="github-main-menu-link">
<a href="https://github.com/pytorch/pytorch" title="Go to PyTorch GitHub">
<div id="topnav-gh-icon"></div>
</a>
</li>
<li class="navSearchWrapper reactNavSearchWrapper" key="search">
<div class="search-border">
<div id="search-icon"></div>
<input
id="search-input"
type="text"
title="Search"
/>
<div id="close-search">X</div>
</div>
</li>
</ul>
</div>
<script src="/assets/main-menu-dropdown.js"></script>
<a class="main-menu-open-button" href="#" data-behavior="open-mobile-menu"></a>
</div>
</div>
</div>
<div class="jumbotron jumbotron-fluid blog-detail-jumbotron">
<div class="container blog-detail-container">
<p class="featured-post">October 17, 2023</p>
<h1>
<a class="blog-title">Compiling NumPy code into C++ or CUDA via torch.compile</a>
</h1>
</div>
</div>
<div class="main-content-wrapper blog-detail-wrapper">
<div class="main-content blog-detail-content">
<div class="container">
<img src="/assets/images/logo-icon.svg" class="img-fluid author-icon" alt="">
<article class="pytorch-article">
<p class="author">
by
Evgeni Burovski, Ralf Gommers and Mario Lezcano
</p>
<p>Quansight engineers have implemented support for tracing through NumPy code via
<code class="language-plaintext highlighter-rouge">torch.compile</code> in PyTorch 2.1. This feature leverages PyTorch’s compiler to
generate efficient fused vectorized code without having to modify your original
NumPy code. Even more, it also allows for executing NumPy code on CUDA
just by running it through <code class="language-plaintext highlighter-rouge">torch.compile</code> under <code class="language-plaintext highlighter-rouge">torch.device("cuda")</code>!</p>
<p>In this post, we go over how to use this feature and give a few tips and tricks
to make the most out of it.</p>
<h2 id="compiling-numpy-code-into-parallel-c">Compiling NumPy code into Parallel C++</h2>
<p>We take as our running example one step in a K-Means algorithm.
This piece of code is borrowed from this <a href="https://realpython.com/numpy-array-programming/#clustering-algorithms">NumPy book</a></p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
def kmeans(X, means):
    return np.argmin(np.linalg.norm(X - means[:, None], axis=2), axis=0)
</code></pre></div></div>
<p>We create a synthetic dataset with 20M random 2-D points. We can see that,
given that the means are chosen appropriately, the function returns the correct
cluster for all of them</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>npts = 10_000_000
X = np.repeat([[5, 5], [10, 10]], [npts, npts], axis=0)
X = X + np.random.randn(*X.shape) # 2 distinct "blobs"
means = np.array([[5, 5], [10, 10]])
np_pred = kmeans(X, means)
</code></pre></div></div>
<p>Benchmarking this function gives us a baseline of <strong>1.26s</strong> on an AMD 3970X CPU.</p>
<p>Compiling this function is now as easy as wrapping it with <code class="language-plaintext highlighter-rouge">torch.compile</code> and
executing it with the example inputs</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch
compiled_fn = torch.compile(kmeans)
compiled_pred = compiled_fn(X, means)
assert np.allclose(np_pred, compiled_pred)
</code></pre></div></div>
<p>The compiled function yields a 9x speed-up when running it on 1 core. Even
better, as opposed to NumPy, our generated code does take advantage of all the
cores in a processor. As such, when we run it on 32 cores, we get a <strong>57x
speed-up</strong>. Note that PyTorch always uses all the available cores unless
explicitly restricted, so this is the default behavior you get when using
<code class="language-plaintext highlighter-rouge">torch.compile</code>.</p>
<p>We may inspect the generated C++ code by running the script with the
environment variable <code class="language-plaintext highlighter-rouge">TORCH_LOGS=output_code</code>. When doing so, we can see that
<code class="language-plaintext highlighter-rouge">torch.compile</code> was able to compile the broadcasting and the two reductions
into just one for-loop, and parallelize it using OpenMP</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>extern "C" void kernel(const double* in_ptr0, const long* in_ptr1, long* out_ptr0) {
    #pragma omp parallel num_threads(32)
    #pragma omp for
    for(long i0=0L; i0&lt;20000000L; i0+=1L) {
        auto tmp0 = in_ptr0[2L*i0];
        auto tmp1 = in_ptr1[0L];
        auto tmp5 = in_ptr0[1L + (2L*i0)];
        auto tmp6 = in_ptr1[1L];
        // Rest of the kernel omitted for brevity
</code></pre></div></div>
<h2 id="compiling-numpy-code-into-cuda">Compiling NumPy code into CUDA</h2>
<p>Compiling our code so that it runs on CUDA is as simple as setting the
default device to be CUDA</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>with torch.device("cuda"):
    cuda_pred = compiled_fn(X, means)
assert np.allclose(np_pred, cuda_pred)
</code></pre></div></div>
<p>By inspecting the generated code via <code class="language-plaintext highlighter-rouge">TORCH_LOGS=output_code</code>, we see that,
rather than generating CUDA code directly, <code class="language-plaintext highlighter-rouge">torch.compile</code> generates rather
readable <a href="https://triton-lang.org/main/index.html">Triton</a> code</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def triton_(in_ptr0, in_ptr1, out_ptr0, XBLOCK : tl.constexpr):
    xnumel = 20000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex &lt; xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (2*x0), xmask)
    tmp1 = tl.load(in_ptr1 + (0))
    # Rest of the kernel omitted for brevity
</code></pre></div></div>
<p>Running this small snippet on an RTX 2060 gives an <strong>8x speed-up</strong> over the
original NumPy code. This is something, but it is not particularly impressive,
given the speed-ups we have seen on CPU. Let’s look at how to squeeze
the most out of our GPU via a couple of minor changes.</p>
<p><code class="language-plaintext highlighter-rouge">float64</code> vs <code class="language-plaintext highlighter-rouge">float32</code>. Many GPUs, in particular consumer-grade ones, are
rather sluggish when running operations on <code class="language-plaintext highlighter-rouge">float64</code>. For this reason, after changing
the data generation to <code class="language-plaintext highlighter-rouge">float32</code>, the original NumPy code gets only slightly
faster, about 9%, but our CUDA code gets <strong>40% faster</strong>, yielding an <strong>11x
speed-up</strong> over the plain NumPy code.</p>
<p><code class="language-plaintext highlighter-rouge">torch.compile</code>, by default, respects the NumPy semantics, and as such, it uses
<code class="language-plaintext highlighter-rouge">np.float64</code> as its default dtype for all its creation ops. As discussed, this
can hinder performance, so it is possible to change this default by setting</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>from torch._dynamo import config
config.numpy_default_float = "float32"
</code></pre></div></div>
<p><strong>CPU &lt;&gt; CUDA copies</strong>. An 11x speed-up is good, but it is not even close to
the CPU numbers. This is caused by a small transformation that <code class="language-plaintext highlighter-rouge">torch.compile</code>
does behind the scenes. The code above takes NumPy arrays and returns NumPy
arrays. All of these arrays are on CPU, but the computations are performed on
the GPU. This means that every time the function is called, <code class="language-plaintext highlighter-rouge">torch.compile</code> has
to copy all these arrays from CPU to the GPU, and then copy the result back to
CPU to preserve the original semantics. There is no native solution to this
issue in NumPy, as NumPy does not have the notion of a <code class="language-plaintext highlighter-rouge">device</code>. That being
said, we can work around it by creating a wrapper to this function so that it
accepts PyTorch tensors and returns PyTorch tensors.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@torch.compile
def tensor_fn(X, means):
    X, means = X.numpy(), means.numpy()
    ret = kmeans(X, means)
    return torch.from_numpy(ret)

def cuda_fn(X, means):
    with torch.device("cuda"):
        return tensor_fn(X, means)
</code></pre></div></div>
<p>This function now takes tensors in CUDA memory and returns tensors in CUDA
memory, but the function itself is written in NumPy! <code class="language-plaintext highlighter-rouge">torch.compile</code> uses the
<code class="language-plaintext highlighter-rouge">numpy()</code> and the <code class="language-plaintext highlighter-rouge">from_numpy()</code> calls as hints, and optimizes them away, and
internally it simply works with PyTorch tensors without moving the memory at
all. When we keep the tensors in CUDA and perform the computations in
<code class="language-plaintext highlighter-rouge">float32</code>, we see a <strong>200x speed-up</strong> over the initial NumPy implementation on
<code class="language-plaintext highlighter-rouge">float32</code> arrays.</p>
<p><strong>Mixing NumPy and PyTorch</strong>. In this example, we had to write a small adaptor
to convert tensors to ndarrays and then back to tensors. In programs that mix
PyTorch and NumPy, converting a tensor into an ndarray is often implemented as
<code class="language-plaintext highlighter-rouge">x.detach().cpu().numpy()</code>, or simply <code class="language-plaintext highlighter-rouge">x.numpy(force=True)</code>. Since when running
under <code class="language-plaintext highlighter-rouge">torch.compile</code> we can run NumPy code in CUDA, we can implement this
conversion pattern as a call to <code class="language-plaintext highlighter-rouge">x.numpy()</code>, as we did above. Doing so and
running the resulting code under <code class="language-plaintext highlighter-rouge">device("cuda")</code> will generate efficient CUDA
code from original NumPy calls without copying the data from CUDA to CPU at
all. Note that the resulting code does not run without <code class="language-plaintext highlighter-rouge">torch.compile</code>. For it
to run in eager mode one would need to roll back to <code class="language-plaintext highlighter-rouge">x.numpy(force=True)</code>.</p>
<h2 id="further-speed-up-tricks">Further Speed-up tricks</h2>
<p><strong>General advice</strong>. The CUDA code we have shown is already quite efficient, but
it is true that the running example is rather short. When dealing with larger
programs, we may need to tweak parts of it to make it more efficient. A good
place to start is the multiple <a href="https://pytorch.org/docs/main/torch.compiler.html#read-more">tutorials and FAQs for torch.compile</a>.
These showcase a number of ways to inspect the tracing process, and how to
identify problematic code that may cause slowdowns.</p>
<p><strong>Advice when compiling NumPy code</strong>. NumPy, even if rather similar to PyTorch,
is often used very differently. It is rather common to perform computations in
NumPy and then do an if/else depending on values within the array, or perform
operations in-place, perhaps via boolean masks. These constructions, while
supported by <code class="language-plaintext highlighter-rouge">torch.compile</code>, hamper its performance. Changes like writing the
code in a branchless way to avoid graph breaks, or avoiding in-place ops can go
a long way.</p>
<p>To write fast NumPy code, it is best to avoid loops, but sometimes they are
unavoidable. When tracing through a loop, <code class="language-plaintext highlighter-rouge">torch.compile</code> will try to fully
unroll it. This is sometimes desirable, but sometimes it may not even be
possible, like when we have a dynamic stopping condition, like in a while loop.
In these cases, it may be best to just compile the body of the loop, perhaps a
few iterations at a time (loop unrolling).</p>
<p><strong>Debugging NumPy code</strong>. Debugging is rather tricky when a compiler is
involved. To figure out whether an error you are hitting is a <code class="language-plaintext highlighter-rouge">torch.compile</code>
error, or an error from the program, you can execute your NumPy program without
<code class="language-plaintext highlighter-rouge">torch.compile</code> by replacing the NumPy import by <code class="language-plaintext highlighter-rouge">import torch._numpy as np</code>.
This should just be used for <strong>debugging purposes</strong> and is in no way a
replacement for the PyTorch API, as it is <strong>much slower</strong> and, as a private API,
<strong>may change without notice</strong>. See also <a href="https://pytorch.org/docs/stable/torch.compiler_faq.html#does-numpy-work-with-torch-compile">this FAQ</a> for other tricks.</p>
<h2 id="differences-between-numpy-and-torchcompile-numpy">Differences between NumPy and <code class="language-plaintext highlighter-rouge">torch.compile</code> NumPy</h2>
<p><strong>NumPy scalars</strong>. NumPy returns NumPy scalars in almost any case where PyTorch
would return a 0-D tensor (e.g. from <code class="language-plaintext highlighter-rouge">np.sum</code>). Under <code class="language-plaintext highlighter-rouge">torch.compile</code>, NumPy
scalars are treated as 0-D arrays. This is just fine in most cases. The only
case when their behavior diverges is when NumPy scalars are implicitly used as
Python scalars. For example,</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>>>> np.asarray(2) * [1, 2, 3] # 0-D array is an array-like
array([2, 4, 6])
>>> u = np.int32(2)
>>> u * [1, 2, 3] # scalar decays into a Python int
[1, 2, 3, 1, 2, 3]
>>> torch.compile(lambda: u * [1, 2, 3])()
array([2, 4, 6]) # acts as a 0-D array, not as a scalar ?!?!
</code></pre></div></div>
<p>If we compile the first two lines, we see that <code class="language-plaintext highlighter-rouge">torch.compile</code> treats <code class="language-plaintext highlighter-rouge">u</code> as a
0-D array. To recover the eager semantics, we just need to make the casting
explicit</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>>>> torch.compile(lambda: int(u) * [1, 2, 3])()
[1, 2, 3, 1, 2, 3]
</code></pre></div></div>
<p><strong>Type promotion and versioning</strong>. NumPy’s type promotion rules may be, at
times, a bit surprising</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>>>> np.zeros(1, dtype=np.int8) + 127
array([127], dtype=int8)
>>> np.zeros(1, dtype=np.int8) + 128
array([128], dtype=int16)
</code></pre></div></div>
<p>NumPy 2.0 is changing these rules to follow others that are closer to those
of PyTorch. The relevant technical document is <a href="https://numpy.org/neps/nep-0050-scalar-promotion.html">NEP 50</a>.
<code class="language-plaintext highlighter-rouge">torch.compile</code> went ahead and implemented NEP 50 rather than the about-to-be-deprecated rules.</p>
<p>In general, NumPy within torch.compile follows NumPy 2.0 pre-release.</p>
<h2 id="beyond-numpy-scipy-and-scikit-learn">Beyond NumPy: SciPy and scikit-learn</h2>
<p>In parallel to this effort of making <code class="language-plaintext highlighter-rouge">torch.compile</code> understand NumPy code,
other Quansight engineers have designed and proposed a way to support PyTorch
tensors within scikit-learn and SciPy. This was received enthusiastically by
other maintainers from these libraries, as it was shown that using PyTorch as a
backend would often yield considerable speed-ups. Both projects have now merged
initial support for PyTorch tensors across a number of APIs and submodules.</p>
<p>This sets the stepping stone to move towards a future where PyTorch tensors can
be used within other libraries in the Python data ecosystem. Even more, this
will enable running these other libraries on GPUs and even compiling code
mixing these libraries and PyTorch, similar to what we have discussed in
this post.</p>
<p>If you want to learn more about this effort, how to use it, or how to help
move it forward, see <a href="https://labs.quansight.org/blog/array-api-support-scikit-learn">this other blogpost</a>.</p>
<h2 id="conclusion">Conclusion</h2>
<p>PyTorch has committed since its inception to be a framework compatible with the
rest of the Python ecosystem. Enabling compiling NumPy programs, and
establishing the tools necessary to do the same for other prominent libraries
are two more steps in this direction. Quansight and Meta continue working hand
in hand, improving the compatibility between PyTorch and the rest of the
ecosystem.</p>
<p>From Quansight, we would like to thank Mengwei, Voz, and Ed for their
invaluable help in integrating our work with <code class="language-plaintext highlighter-rouge">torch.compile</code>. We would also
like to thank Meta for funding this project as well as previous work on
improving NumPy compatibility within PyTorch, and the project that led to
supporting PyTorch within scikit-learn and SciPy. These are giant leaps towards
consolidating PyTorch as the framework of choice within the open source Python
data ecosystem.</p>
</article>
</div>
</div>
</div>
<!--
-->
<!-- Three-column strip linking to the docs, tutorials and resources pages. -->
<div class="container-fluid docs-tutorials-resources">
  <div class="container">
    <div class="row">
      <div class="col-md-4 text-center">
        <h2>Docs</h2>
        <p>Access comprehensive developer documentation for PyTorch</p>
        <a class="with-right-arrow" href="/docs">View Docs</a>
      </div>

      <div class="col-md-4 text-center">
        <h2>Tutorials</h2>
        <p>Get in-depth tutorials for beginners and advanced developers</p>
        <a class="with-right-arrow" href="https://pytorch.org/tutorials">View Tutorials</a>
      </div>

      <div class="col-md-4 text-center">
        <h2>Resources</h2>
        <p>Find development resources and get your questions answered</p>
        <a class="with-right-arrow" href="/resources">View Resources</a>
      </div>
    </div>
  </div>
</div>
<footer class="site-footer">
<div class="container footer-container">
<!--
  Newsletter sign-up. The form is created by the HubSpot embed script, which is
  loaded over explicit HTTPS instead of a protocol-relative URL.
-->
<div class="newsletter" id="newsletter">
  <p class="newsletter__title is-style-max-width-800"><strong>Stay in touch</strong> for updates, event info, and the latest news</p>
  <script charset="utf-8" type="text/javascript" src="https://js.hsforms.net/forms/embed/v2.js"></script>
  <script>
    // Only create the form when the embed script defined `hbspt`; if that script
    // failed to load (for example, it was blocked), calling into it would throw
    // a ReferenceError.
    if (window.hbspt) {
      hbspt.forms.create({
        region: "na1",
        portalId: "8112310",
        formId: "2fb2231c-000b-4ec5-88a0-1ab242549c9e"
      });
    }
  </script>
  <p class="newsletter__privacy">By submitting this form, I consent to receive marketing emails from the LF and its projects regarding their events, training, research, developments, and related announcements. I understand that I can unsubscribe at any time using the links in the footers of the emails I receive. <a href="https://www.linuxfoundation.org/privacy/">Privacy Policy</a>.</p>
</div>
<!--
  Footer logo and social links. Links that open in a new tab carry
  rel="noopener" so the opened page gets no handle on this window, and the SVG
  viewBox attribute uses one consistent casing.
-->
<div class="lf-grid">
  <div class="footer-logo-wrapper">
    <a href="https://pytorch.org" class="footer-logo">
      <img src="/assets/images/logo-icon.svg" alt="PyTorch logo" width="40">
    </a>
  </div>
  <ul class="social-links">
    <li><a href="https://www.facebook.com/pytorch" target="_blank" rel="noopener" title="PyTorch on Facebook">
      <svg xmlns="http://www.w3.org/2000/svg" viewBox="-0.51 -0.26 26.45 26.45" aria-label="Facebook"><path fill="currentColor" d="M25.497 13.075c0-2.45-.698-4.848-2.011-6.911a12.765 12.765 0 0 0-5.398-4.73A12.671 12.671 0 0 0 11.008.38a12.705 12.705 0 0 0-6.529 2.95A12.827 12.827 0 0 0 .563 9.358a12.896 12.896 0 0 0-.07 7.201 12.831 12.831 0 0 0 3.801 6.103 12.709 12.709 0 0 0 6.471 3.078v-8.957H7.53v-3.708h3.235v-2.824c0-3.213 1.903-4.988 4.813-4.988.956.014 1.909.097 2.852.25V8.67h-1.607a1.83 1.83 0 0 0-1.518.497 1.854 1.854 0 0 0-.561 1.505v2.404h3.535l-.563 3.708h-2.97v8.957a12.725 12.725 0 0 0 7.697-4.337 12.87 12.87 0 0 0 3.054-8.328z"/></svg>
    </a></li>
    <li><a href="https://twitter.com/pytorch" target="_blank" rel="noopener" title="PyTorch on X">
      <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 300 300" aria-label="X"><path fill="currentColor" d="M178.57 127.15 290.27 0h-26.46l-97.03 110.38L89.34 0H0l117.13 166.93L0 300.25h26.46l102.4-116.59 81.8 116.59h89.34M36.01 19.54H76.66l187.13 262.13h-40.66"/></svg>
    </a></li>
    <li><a href="https://www.youtube.com/pytorch" target="_blank" rel="noopener" title="PyTorch on YouTube">
      <svg xmlns="http://www.w3.org/2000/svg" viewBox="0.21 0.27 34.45 25.07" aria-label="YouTube"><path fill="currentColor" d="M33.729 6.084s-.327-2.33-1.317-3.356a4.691 4.691 0 0 0-3.32-1.432c-4.634-.34-11.589-.34-11.589-.34h-.014s-6.954 0-11.59.342a4.692 4.692 0 0 0-3.32 1.432c-.993 1.025-1.315 3.354-1.315 3.354a52.189 52.189 0 0 0-.331 5.473v2.566c.014 1.829.125 3.656.331 5.472 0 0 .322 2.33 1.316 3.36 1.26 1.345 2.916 1.3 3.653 1.445 2.65.26 11.263.34 11.263.34s6.96-.01 11.597-.353a4.691 4.691 0 0 0 3.32-1.432c.993-1.026 1.316-3.356 1.316-3.356.206-1.817.316-3.644.33-5.473v-2.57a52.26 52.26 0 0 0-.33-5.472zM14.076 17.232V7.729l8.951 4.768-8.95 4.735z"/></svg>
    </a></li>
    <li><a href="https://www.linkedin.com/company/pytorch" target="_blank" rel="noopener" title="PyTorch on LinkedIn">
      <svg xmlns="http://www.w3.org/2000/svg" viewBox="-10.23 -10.23 531.96 531.96" aria-label="LinkedIn"><rect width="512" height="512" rx="0" fill="currentColor"/><circle fill="#000" cx="142" cy="138" r="37"/><path stroke="#000" stroke-width="66" d="M244 194v198M142 194v198"/><path fill="#000" d="M276 282c0-20 13-40 36-40 24 0 33 18 33 45v105h66V279c0-61-32-89-76-89-34 0-51 19-59 32"/></svg>
    </a></li>
    <li><a href="https://join.slack.com/t/pytorch/shared_invite/zt-2j2la612p-miUinTTaxXczKOJw48poHA" target="_blank" rel="noopener" title="PyTorch Slack">
      <svg xmlns="http://www.w3.org/2000/svg" viewBox="0.16 -0.03 21.19 21.19" aria-label="Slack"><path fill="currentColor" d="M4.896 13.27a2.147 2.147 0 0 1-2.141 2.142A2.147 2.147 0 0 1 .613 13.27c0-1.178.963-2.141 2.142-2.141h2.141v2.141zm1.08 0c0-1.178.962-2.141 2.141-2.141s2.142.963 2.142 2.141v5.363a2.147 2.147 0 0 1-2.142 2.141 2.147 2.147 0 0 1-2.141-2.142V13.27zm2.141-8.6a2.147 2.147 0 0 1-2.141-2.14c0-1.18.962-2.142 2.141-2.142s2.142.963 2.142 2.141v2.142H8.117zm0 1.08c1.179 0 2.141.962 2.141 2.141a2.147 2.147 0 0 1-2.141 2.142H2.755A2.147 2.147 0 0 1 .613 7.89c0-1.179.963-2.141 2.142-2.141h5.362zm8.599 2.141c0-1.179.963-2.141 2.141-2.141 1.179 0 2.143.962 2.143 2.14a2.147 2.147 0 0 1-2.142 2.142h-2.141V7.89zm-1.08 0a2.147 2.147 0 0 1-2.141 2.142 2.147 2.147 0 0 1-2.141-2.142V2.53c0-1.178.962-2.141 2.141-2.141s2.142.963 2.142 2.141v5.362zm-2.141 8.6c1.179 0 2.142.962 2.142 2.14a2.147 2.147 0 0 1-2.142 2.142 2.147 2.147 0 0 1-2.141-2.141V16.49h2.141zm0-1.08a2.147 2.147 0 0 1-2.141-2.141c0-1.179.962-2.142 2.141-2.142h5.362c1.179 0 2.142.963 2.142 2.142a2.147 2.147 0 0 1-2.142 2.142h-5.362z"></path></svg>
    </a></li>
    <li><a href="/wechat" title="PyTorch on WeChat">
      <svg xmlns="http://www.w3.org/2000/svg" viewBox="0.14 -0.17 38.02 33.02" aria-label="WeChat"><path fill="currentColor" d="M26.289 10.976a12.972 12.972 0 0 0-8.742 3.53 10.386 10.386 0 0 0-3.224 8.795c-1.326-.164-2.535-.345-3.75-.448a2.332 2.332 0 0 0-1.273.216c-1.18.666-2.311 1.418-3.652 2.255.246-1.112.405-2.087.687-3.024a1.15 1.15 0 0 0-.523-1.52C1.737 17.902.02 13.601 1.307 9.165c1.189-4.1 4.11-6.587 8.077-7.884A13.54 13.54 0 0 1 24.18 5.617a10.135 10.135 0 0 1 2.109 5.359zM10.668 9.594a1.564 1.564 0 0 0-2.095-1.472 1.52 1.52 0 0 0-.895 1.964 1.502 1.502 0 0 0 1.391.966 1.545 1.545 0 0 0 1.598-1.46v.002zm8.15-1.566a1.567 1.567 0 0 0-1.528 1.543 1.528 1.528 0 0 0 1.571 1.492 1.52 1.52 0 0 0 1.375-2.117 1.518 1.518 0 0 0-1.415-.919l-.003.001z"></path><path fill="currentColor" d="M33.914 32.137c-1.075-.478-2.062-1.196-3.11-1.306-1.049-.11-2.145.494-3.24.605a10.821 10.821 0 0 1-8.781-2.864c-4.682-4.33-4.013-10.97 1.403-14.518 4.811-3.154 11.874-2.102 15.268 2.273a8.671 8.671 0 0 1-1.002 12.095c-1.046.929-1.422 1.693-.751 2.917.102.257.174.525.213.798zM21.68 20.292a1.264 1.264 0 1 0 .01-2.528 1.264 1.264 0 0 0-.01 2.528zm7.887-2.526a1.266 1.266 0 0 0-1.256 1.21 1.247 1.247 0 1 0 1.256-1.21z"></path></svg>
    </a></li>
  </ul>
</div>
<!-- Copyright notice with links to the Linux Foundation and LF Projects policies. -->
<div class="privacy-policy">
  <div class="copyright">
    <p>
      © Copyright The Linux Foundation. The PyTorch Foundation is a project of The Linux Foundation.
      For web site terms of use, trademark policy and other policies applicable to The PyTorch Foundation please see
      <a href="https://www.linuxfoundation.org/legal/policies/">Linux Foundation Policies</a>. The PyTorch Foundation supports the PyTorch open source
      project, which has been established as PyTorch Project a Series of LF Projects, LLC. For policies applicable to the PyTorch Project a Series of LF Projects, LLC,
      please see <a href="https://www.lfprojects.org/policies/">LF Projects, LLC Policies</a>. <a href="https://www.linuxfoundation.org/privacy">Privacy Policy</a> and <a href="https://www.linuxfoundation.org/terms">Terms of Use</a>.
    </p>
  </div>
</div>
</div>
</footer>
<div class="mobile-main-menu">
<!--
  Mobile menu header: logo link plus the close control. Both anchors have no
  text content, so each needs an aria-label to be announced by assistive
  technology.
-->
<div class="container-fluid">
  <div class="container">
    <div class="mobile-main-menu-header-container">
      <a class="header-logo" href="https://pytorch.org" aria-label="PyTorch"></a>
      <!-- NOTE(review): presumably wired up by script via data-behavior; kept as an anchor so existing hooks still match. -->
      <a class="main-menu-close-button" href="#" data-behavior="close-mobile-menu" aria-label="Close menu"></a>
    </div>
  </div>
</div>
<!--
  Mobile navigation links, grouped under section titles.
  NOTE(review): each <ul class="resources-mobile-menu-items"> is a direct child
  of the outer <ul>, which is not valid list markup. The structure is left as
  is because scripts and styles may rely on the title/items sibling
  relationship; confirm before nesting the sub-lists inside their title <li>.
-->
<div class="mobile-main-menu-links-container">
  <div class="main-menu">
    <ul>
      <li class="navSearchWrapper reactNavSearchWrapper tabletSearchWrapper" key="search">
        <div class="mobile-search-border">
          <!-- aria-label names the field; title alone is not a reliable label. -->
          <input
            id="mobile-search-input"
            type="text"
            title="Search"
            aria-label="Search"
          />
          <div id="mobile-search-icon"></div>
        </div>
      </li>

      <li class="resources-mobile-menu-title">
        <a>Learn</a>
      </li>
      <ul class="resources-mobile-menu-items">
        <li>
          <a href="/get-started">Get Started</a>
        </li>
        <li>
          <a href="https://pytorch.org/tutorials">Tutorials</a>
        </li>
        <li>
          <a href="https://pytorch.org/tutorials/beginner/basics/intro.html">Learn the Basics</a>
        </li>
        <li>
          <a href="https://pytorch.org/tutorials/recipes/recipes_index.html">PyTorch Recipes</a>
        </li>
        <li>
          <a href="https://pytorch.org/tutorials/beginner/introyt.html">Introduction to PyTorch - YouTube Series</a>
        </li>
        <li>
          <a href="/new">New to PyTorch Foundation</a>
        </li>
      </ul>

      <li class="resources-mobile-menu-title">
        <a>Ecosystem</a>
      </li>
      <ul class="resources-mobile-menu-items">
        <li>
          <a href="https://landscape.pytorch.org/">Tools</a>
        </li>
        <li>
          <a href="/join-ecosystem">Join the Ecosystem</a>
        </li>
        <li>
          <a href="/#community-module">Community</a>
        </li>
        <li>
          <a href="https://discuss.pytorch.org">Forums</a>
        </li>
        <li>
          <a href="/resources">Developer Resources</a>
        </li>
        <li>
          <a href="/ecosystem/contributor-awards-2024">Contributor Awards - 2024</a>
        </li>
      </ul>

      <li class="resources-mobile-menu-title">
        <a>Edge</a>
      </li>
      <ul class="resources-mobile-menu-items">
        <li>
          <a href="/edge">About PyTorch Edge</a>
        </li>
        <li>
          <a href="/executorch-overview">ExecuTorch</a>
        </li>
        <li>
          <a href="https://pytorch.org/executorch/stable/index.html">ExecuTorch Documentation</a>
        </li>
      </ul>

      <li class="resources-mobile-menu-title">
        <a>Docs</a>
      </li>
      <ul class="resources-mobile-menu-items">
        <li>
          <a href="https://pytorch.org/docs">PyTorch</a>
        </li>
        <li>
          <a href="/pytorch-domains">PyTorch Domains</a>
        </li>
      </ul>

      <li class="resources-mobile-menu-title">
        <a>Blog &amp; News</a>
      </li>
      <ul class="resources-mobile-menu-items">
        <li>
          <a href="/blog">PyTorch Blog</a>
        </li>
        <li>
          <a href="/community-blog">Community Blog</a>
        </li>
        <li>
          <a href="/videos">Videos</a>
        </li>
        <li>
          <a href="/community-stories">Community Stories</a>
        </li>
        <li>
          <a href="/events">Events</a>
        </li>
        <li>
          <a href="/newsletter">Newsletter</a>
        </li>
      </ul>

      <li class="resources-mobile-menu-title">
        <a>About</a>
      </li>
      <ul class="resources-mobile-menu-items">
        <li>
          <a href="/foundation">PyTorch Foundation</a>
        </li>
        <li>
          <a href="/governing-board">Governing Board</a>
        </li>
        <li>
          <a href="/credits">Cloud Credit Program</a>
        </li>
        <li>
          <a href="/tac">Technical Advisory Council</a>
        </li>
        <li>
          <a href="/staff">Staff</a>
        </li>
        <li>
          <a href="/contact-us">Contact Us</a>
        </li>
      </ul>

      <li class="resources-mobile-menu-title">
        <a href="/join">Become a Member</a>
      </li>
      <li class="resources-mobile-menu-title">
        <a href="https://github.com/pytorch/pytorch" title="Go to PyTorch GitHub"><div id="topnav-gh-icon"></div></a>
      </li>
    </ul>
  </div>
</div>
</div>
<!--
  Page behaviour scripts. The inline blocks call bind() on mobileMenu,
  scrollToAnchor and trackEvents, which are presumably defined by the external
  scripts loaded here (verify against those files).
-->
<script src="/assets/mobile-menu.js"></script>
<script src="/assets/scroll-to-anchor.js"></script>
<script src="/assets/external-links-new-tab.js"></script>
<script src="/assets/search-bar.js"></script>
<script src="/assets/cookie-banner.js"></script>
<script type="text/javascript">
  mobileMenu.bind();
  anchors.add('.pytorch-article h2, .pytorch-article h3, .pytorch-article h4, .pytorch-article h5');

  // Links cannot be created inside code blocks, so flag every link that wraps inline code.
  $("a code.highlighter-rouge").each(function() {
    $(this).closest("a").addClass("has-code");
  });

  scrollToAnchor.bind();

  var hasStaticHeader = $(".blog-header, .blog-detail-header, .resources-header, .get-started-header, .features-header, .ecosystem-header, .hub-header, .mobile-header, .announcement-header, .comm-stories-header").length > 0;

  // When none of the header classes above is present, the header background
  // opacity follows the scroll position.
  if (!hasStaticHeader) {
    $(window).on("scroll", function() {
      var top = $(this).scrollTop();
      var fullPosition = $(".main-background").height() - $(".header-holder").height();
      var bgColor;

      if (top <= 40) {
        // Near the top of the page: fixed, mostly transparent background.
        bgColor = "rgba(0, 0, 0, 0.165)";
      } else if (top >= fullPosition) {
        // Scrolled past the main background: solid black.
        bgColor = "#000000";
      } else {
        // In between: opacity is the fraction of the distance scrolled.
        bgColor = "rgba(0, 0, 0, " + top / fullPosition + ")";
      }

      $(".header-holder").css({"backgroundColor": bgColor});
    });
  }
</script>
<script src="/assets/track-events.js"></script>
<script>trackEvents.bind();</script>
<!-- Cookie consent notice with its close control. -->
<div class="cookie-banner-wrapper">
  <div class="container">
    <p class="gdpr-notice">To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: <a href="https://www.facebook.com/policies/cookies/">Cookies Policy</a>.</p>
    <!-- alt gives the image a text alternative. NOTE(review): assumes this image dismisses the banner, based on its class name; confirm against cookie-banner.js. -->
    <img class="close-button" src="/assets/images/pytorch-x.svg" alt="Close cookie banner">
  </div>
</div>
</body>
</html>