-
Notifications
You must be signed in to change notification settings - Fork 24k
/
Copy pathfunctional.py
898 lines (765 loc) · 43.3 KB
/
functional.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
import torch
from typing import Tuple, List
from torch._vmap_internals import _vmap
# Utility functions
def _as_tuple(inp, arg_name, fn_name):
# Ensures that inp is a tuple of Tensors
# Returns whether or not the original inp was a tuple and the tupled version of the input
is_inp_tuple = True
if not isinstance(inp, tuple):
inp = (inp,)
is_inp_tuple = False
for i, el in enumerate(inp):
if not isinstance(el, torch.Tensor):
if is_inp_tuple:
raise TypeError("The {} given to {} must be either a Tensor or a tuple of Tensors but the"
" value at index {} has type {}.".format(arg_name, fn_name, i, type(el)))
else:
raise TypeError("The {} given to {} must be either a Tensor or a tuple of Tensors but the"
" given {} has type {}.".format(arg_name, fn_name, arg_name, type(el)))
return is_inp_tuple, inp
def _tuple_postprocess(res, to_unpack):
# Unpacks a potentially nested tuple of Tensors
# to_unpack should be a single boolean or a tuple of two booleans.
# It is used to:
# - invert _as_tuple when res should match the inp given to _as_tuple
# - optionally remove nesting of two tuples created by multiple calls to _as_tuple
if isinstance(to_unpack, tuple):
assert len(to_unpack) == 2
if not to_unpack[1]:
res = tuple(el[0] for el in res)
if not to_unpack[0]:
res = res[0]
else:
if not to_unpack:
res = res[0]
return res
def _grad_preprocess(inputs, create_graph, need_graph):
# Preprocess the inputs to make sure they require gradient
# inputs is a tuple of Tensors to preprocess
# create_graph specifies if the user wants gradients to flow back to the Tensors in inputs
# need_graph specifies if we internally want gradients to flow back to the Tensors in res
# Note that we *always* create a new Tensor object to be able to see the difference between
# inputs given as arguments and the same Tensors automatically captured by the user function.
# Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576
res = []
for inp in inputs:
if create_graph and inp.requires_grad:
# Create at least a new Tensor object in a differentiable way
if not inp.is_sparse:
# Use .view_as() to get a shallow copy
res.append(inp.view_as(inp))
else:
# We cannot use view for sparse Tensors so we clone
res.append(inp.clone())
else:
res.append(inp.detach().requires_grad_(need_graph))
return tuple(res)
def _grad_postprocess(inputs, create_graph):
# Postprocess the generated Tensors to avoid returning Tensors with history when the user did not
# request it.
if isinstance(inputs[0], torch.Tensor):
if not create_graph:
return tuple(inp.detach() for inp in inputs)
else:
return inputs
else:
return tuple(_grad_postprocess(inp, create_graph) for inp in inputs)
def _validate_v(v, other, is_other_tuple):
# This assumes that other is the correct shape, and v should match
# Both are assumed to be tuples of Tensors
if len(other) != len(v):
if is_other_tuple:
raise RuntimeError("v is a tuple of invalid length: should be {} but got {}.".format(len(other), len(v)))
else:
raise RuntimeError("The given v should contain a single Tensor.")
for idx, (el_v, el_other) in enumerate(zip(v, other)):
if el_v.size() != el_other.size():
prepend = ""
if is_other_tuple:
prepend = "Entry {} in ".format(idx)
raise RuntimeError("{}v has invalid size: should be {} but got {}.".format(
prepend, el_other.size(), el_v.size()))
def _check_requires_grad(inputs, input_type, strict):
# Used to make all the necessary checks to raise nice errors in strict mode.
if not strict:
return
if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]:
raise RuntimeError("Invalid input_type to _check_requires_grad")
for i, inp in enumerate(inputs):
if inp is None:
# This can only be reached for grad_inputs.
raise RuntimeError("The output of the user-provided function is independent of input {}."
" This is not allowed in strict mode.".format(i))
if not inp.requires_grad:
if input_type == "hessian":
raise RuntimeError("The hessian of the user-provided function with respect to input {}"
" is independent of the input. This is not allowed in strict mode."
" You should ensure that your function is thrice differentiable and that"
" the hessian depends on the inputs.".format(i))
elif input_type == "jacobian":
raise RuntimeError("While computing the hessian, found that the jacobian of the user-provided"
" function with respect to input {} is independent of the input. This is not"
" allowed in strict mode. You should ensure that your function is twice"
" differentiable and that the jacobian depends on the inputs (this would be"
" violated by a linear function for example).".format(i))
elif input_type == "grad_inputs":
raise RuntimeError("The gradient with respect to input {} is independent of the inputs of the"
" user-provided function. This is not allowed in strict mode.".format(i))
else:
raise RuntimeError("Output {} of the user-provided function does not require gradients."
" The outputs must be computed in a differentiable manner from the input"
" when running in strict mode.".format(i))
def _autograd_grad(outputs, inputs, grad_outputs=None, create_graph=False, retain_graph=None):
# Version of autograd.grad that accepts `None` in outputs and do not compute gradients for them.
# This has the extra constraint that inputs has to be a tuple
assert isinstance(outputs, tuple)
if grad_outputs is None:
grad_outputs = (None,) * len(outputs)
assert isinstance(grad_outputs, tuple)
assert len(outputs) == len(grad_outputs)
new_outputs: Tuple[torch.Tensor, ...] = tuple()
new_grad_outputs: Tuple[torch.Tensor, ...] = tuple()
for out, grad_out in zip(outputs, grad_outputs):
if out is not None and out.requires_grad:
new_outputs += (out,)
new_grad_outputs += (grad_out,)
if len(new_outputs) == 0:
# No differentiable output, we don't need to call the autograd engine
return (None,) * len(inputs)
else:
return torch.autograd.grad(new_outputs, inputs, new_grad_outputs, allow_unused=True,
create_graph=create_graph, retain_graph=retain_graph)
def _fill_in_zeros(grads, refs, strict, create_graph, stage):
# Used to detect None in the grads and depending on the flags, either replace them
# with Tensors full of 0s of the appropriate size based on the refs or raise an error.
# strict and create graph allow us to detect when it is appropriate to raise an error
# stage gives us information of which backward call we consider to give good error message
if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
raise RuntimeError("Invalid stage argument '{}' to _fill_in_zeros".format(stage))
res: Tuple[torch.Tensor, ...] = tuple()
for i, grads_i in enumerate(grads):
if grads_i is None:
if strict:
if stage == "back":
raise RuntimeError("The output of the user-provided function is independent of "
"input {}. This is not allowed in strict mode.".format(i))
elif stage == "back_trick":
raise RuntimeError("The gradient with respect to the input is independent of entry {}"
" in the grad_outputs when using the double backward trick to compute"
" forward mode gradients. This is not allowed in strict mode.".format(i))
elif stage == "double_back":
raise RuntimeError("The jacobian of the user-provided function is independent of "
"input {}. This is not allowed in strict mode.".format(i))
else:
raise RuntimeError("The hessian of the user-provided function is independent of "
"entry {} in the grad_jacobian. This is not allowed in strict "
"mode as it prevents from using the double backward trick to "
"replace forward mode AD.".format(i))
grads_i = torch.zeros_like(refs[i])
else:
if strict and create_graph and not grads_i.requires_grad:
if "double" not in stage:
raise RuntimeError("The jacobian of the user-provided function is independent of "
"input {}. This is not allowed in strict mode when create_graph=True.".format(i))
else:
raise RuntimeError("The hessian of the user-provided function is independent of "
"input {}. This is not allowed in strict mode when create_graph=True.".format(i))
res += (grads_i,)
return res
# Public API
def vjp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the vector-Jacobian product of ``func`` at ``inputs``.

    The result is the dot product between the vector ``v`` and the Jacobian
    of ``func`` evaluated at the point given by ``inputs``.

    Args:
        func (function): a Python function that takes Tensor inputs and
            returns a Tensor or a tuple of Tensors.
        inputs (tuple of Tensors or Tensor): the point at which ``func`` is
            evaluated.
        v (tuple of Tensors or Tensor): the vector that multiplies the
            Jacobian. It must have the same size as the output of ``func``.
            It may only be left to ``None`` when ``func`` returns a single
            Tensor holding a single element, in which case it acts as a
            Tensor containing a single ``1``.
        create_graph (bool, optional): if ``True``, both the output and the
            result are computed in a differentiable way. Note that when
            ``strict`` is ``False``, the result may not require gradients
            or may be disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): if ``True``, an error is raised when an
            input is detected such that all the outputs are independent of
            it. If ``False``, a Tensor of zeros (the expected mathematical
            value) is returned as the vjp for such inputs.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            vjp (tuple of Tensors or Tensor): result of the dot product,
            with the same shape as the inputs.

    Example:

        >>> def exp_reducer(x):
        ...   return x.exp().sum(dim=1)
        >>> inputs = torch.rand(4, 4)
        >>> v = torch.ones(4)
        >>> vjp(exp_reducer, inputs, v)
        (tensor([5.7817, 7.2458, 5.7830, 6.7782]),
         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
                [2.1288, 1.0652, 1.5483, 2.5035],
                [2.2046, 1.1292, 1.1432, 1.3059],
                [1.3225, 1.6652, 1.7753, 2.0152]]))

        >>> vjp(exp_reducer, inputs, v, create_graph=True)
        (tensor([5.7817, 7.2458, 5.7830, 6.7782], grad_fn=<SumBackward1>),
         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
                [2.1288, 1.0652, 1.5483, 2.5035],
                [2.2046, 1.1292, 1.1432, 1.3059],
                [1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=<MulBackward0>))

        >>> def adder(x, y):
        ...   return 2 * x + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = torch.ones(2)
        >>> vjp(adder, inputs, v)
        (tensor([2.4225, 2.3340]),
         (tensor([2., 2.]), tensor([3., 3.])))
    """
    # The forward pass always needs grad mode, whatever the ambient mode is
    with torch.enable_grad():
        inputs_were_tuple, inputs = _as_tuple(inputs, "inputs", "vjp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        outputs = func(*inputs)
        outputs_were_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "vjp")
        _check_requires_grad(outputs, "outputs", strict=strict)

        if v is None:
            if len(outputs) != 1 or outputs[0].nelement() != 1:
                raise RuntimeError("The vector v can only be None if the "
                                   "user-provided function returns "
                                   "a single Tensor with a single element.")
        else:
            _, v = _as_tuple(v, "v", "vjp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, outputs, outputs_were_tuple)

    # The backward pass only forces grad mode on when the user asked for a graph;
    # otherwise the ambient grad mode is left untouched.
    if create_graph:
        grad_mode = True
    else:
        grad_mode = torch.is_grad_enabled()
    with torch.set_grad_enabled(grad_mode):
        grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph)
        vjp_res = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back")

    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    vjp_res = _grad_postprocess(vjp_res, create_graph)

    return _tuple_postprocess(outputs, outputs_were_tuple), _tuple_postprocess(vjp_res, inputs_were_tuple)
def jvp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the Jacobian-vector product of ``func`` at ``inputs``.

    The result is the dot product between the Jacobian of ``func`` evaluated
    at the point given by ``inputs`` and the vector ``v``.

    Args:
        func (function): a Python function that takes Tensor inputs and
            returns a Tensor or a tuple of Tensors.
        inputs (tuple of Tensors or Tensor): the point at which ``func`` is
            evaluated.
        v (tuple of Tensors or Tensor): the vector multiplied by the
            Jacobian. It must have the same size as the input of ``func``.
            It may only be left to ``None`` when the input of ``func`` is a
            single Tensor holding a single element, in which case it acts as
            a Tensor containing a single ``1``.
        create_graph (bool, optional): if ``True``, both the output and the
            result are computed in a differentiable way. Note that when
            ``strict`` is ``False``, the result may not require gradients
            or may be disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): if ``True``, an error is raised when an
            input is detected such that all the outputs are independent of
            it. If ``False``, a Tensor of zeros (the expected mathematical
            value) is returned as the jvp for such inputs.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            jvp (tuple of Tensors or Tensor): result of the dot product,
            with the same shape as the output.

    Example:

        >>> def exp_reducer(x):
        ...   return x.exp().sum(dim=1)
        >>> inputs = torch.rand(4, 4)
        >>> v = torch.ones(4, 4)
        >>> jvp(exp_reducer, inputs, v)
        (tensor([6.3090, 4.6742, 7.9114, 8.2106]),
         tensor([6.3090, 4.6742, 7.9114, 8.2106]))

        >>> jvp(exp_reducer, inputs, v, create_graph=True)
        (tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SumBackward1>),
         tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SqueezeBackward1>))

        >>> def adder(x, y):
        ...   return 2 * x + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.ones(2), torch.ones(2))
        >>> jvp(adder, inputs, v)
        (tensor([2.2399, 2.5005]),
         tensor([5., 5.]))

    Note:
        The jvp is currently computed by using the backward of the backward
        (sometimes called the double backwards trick) as we don't have support
        for forward mode AD in PyTorch at the moment.
    """
    # The forward pass and the first backward always need grad mode
    with torch.enable_grad():
        inputs_were_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is None:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError("The vector v can only be None if the input to "
                                   "the user-provided function is a single Tensor "
                                   "with a single element.")
        else:
            _, v = _as_tuple(v, "v", "jvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, inputs_were_tuple)

        outputs = func(*inputs)
        outputs_were_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "jvp")
        _check_requires_grad(outputs, "outputs", strict=strict)

        # The backward is linear so the value of grad_outputs is not important as
        # it won't appear in the double backward graph. We only need to ensure that
        # it does not contain inf or nan.
        grad_outputs = tuple(torch.zeros_like(out, requires_grad=True) for out in outputs)

        # First backward: always built with a graph so that it can be differentiated again
        grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True)
        _check_requires_grad(grad_inputs, "grad_inputs", strict=strict)

    # Second backward (w.r.t. grad_outputs): grad mode is only forced on when the user
    # asked for a graph; otherwise the ambient grad mode is left untouched.
    if create_graph:
        grad_mode = True
    else:
        grad_mode = torch.is_grad_enabled()
    with torch.set_grad_enabled(grad_mode):
        grad_res = _autograd_grad(grad_inputs, grad_outputs, v, create_graph=create_graph)
        jvp_res = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")

    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    jvp_res = _grad_postprocess(jvp_res, create_graph)

    return _tuple_postprocess(outputs, outputs_were_tuple), _tuple_postprocess(jvp_res, outputs_were_tuple)
def _construct_standard_basis_for(tensors: Tuple[torch.Tensor, ...], tensor_numels: Tuple[int, ...]) -> Tuple[torch.Tensor, ...]:
# This function:
# - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
# - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
# - Each chunk corresponds to one tensor. The chunk has the same dtype and
# device as the tensor
#
# For example, with tensor_numels = [1, 2, 1], this function returns:
# ( tensor([[1], tensor([[0, 0], tensor([[0],
# [0], [1, 0], [0],
# [0], [0, 1], [0],
# [0]]) , [0, 0]]) , [1]]) )
#
# Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
# Precondition: tensors always has at least one element.
#
# See NOTE: [Computing jacobian with vmap and grad for multiple tensors]
# for context behind this function. All the pre-conditions are guarded for
# in torch.autograd.functional.jacobian.
assert len(tensors) == len(tensor_numels)
assert len(tensors) > 0
total_numel = sum(tensor_numels)
diag_start_indices = (0, *torch.tensor(tensor_numels).cumsum(dim=0)[:-1].neg().unbind())
chunks = tuple(tensor.new_zeros(total_numel, tensor_numel)
for tensor, tensor_numel in zip(tensors, tensor_numels))
for chunk, diag_start_idx in zip(chunks, diag_start_indices):
chunk.diagonal(diag_start_idx).fill_(1)
return chunks
def jacobian(func, inputs, create_graph=False, strict=False, vectorize=False):
    r"""Function that computes the Jacobian of a given function.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        create_graph (bool, optional): If ``True``, the Jacobian will be
            computed in a differentiable manner. Note that when ``strict`` is
            ``False``, the result may not require gradients or may be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            jacobian for said inputs, which is the expected mathematical value.
            Defaults to ``False``.
        vectorize (bool, optional): This feature is experimental, please use at
            your own risk. When computing the jacobian, usually we invoke
            ``autograd.grad`` once per row of the jacobian. If this flag is
            ``True``, we use the vmap prototype feature as the backend to
            vectorize calls to ``autograd.grad`` so we only invoke it once
            instead of once per row. This should lead to performance
            improvements in many use cases, however, due to this feature
            being incomplete, there may be performance cliffs. Please
            use `torch._C._debug_only_display_vmap_fallback_warnings(True)`
            to show any performance warnings and file us issues if
            warnings exist for your use case. Defaults to ``False``.

    Returns:
        Jacobian (Tensor or nested tuple of Tensors): if there is a single
            input and output, this will be a single Tensor containing the
            Jacobian for the linearized inputs and output. If one of the two is
            a tuple, then the Jacobian will be a tuple of Tensors. If both of
            them are tuples, then the Jacobian will be a tuple of tuple of
            Tensors where ``Jacobian[i][j]`` will contain the Jacobian of the
            ``i``\th output and ``j``\th input and will have as size the
            concatenation of the sizes of the corresponding output and the
            corresponding input and will have same dtype and device as the
            corresponding input.

    Raises:
        RuntimeError: if ``strict`` and ``vectorize`` are both ``True``; or, in
            strict mode, if an output does not require gradients, if an output
            is independent of an input, or (with ``create_graph=True``) if a
            row of the Jacobian does not require gradients.

    Example:

        >>> def exp_reducer(x):
        ...   return x.exp().sum(dim=1)
        >>> inputs = torch.rand(2, 2)
        >>> jacobian(exp_reducer, inputs)
        tensor([[[1.4917, 2.4352],
                 [0.0000, 0.0000]],
                [[0.0000, 0.0000],
                 [2.4369, 2.3799]]])

        >>> jacobian(exp_reducer, inputs, create_graph=True)
        tensor([[[1.4917, 2.4352],
                 [0.0000, 0.0000]],
                [[0.0000, 0.0000],
                 [2.4369, 2.3799]]], grad_fn=<ViewBackward>)

        >>> def exp_adder(x, y):
        ...   return 2 * x.exp() + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> jacobian(exp_adder, inputs)
        (tensor([[2.8052, 0.0000],
                [0.0000, 3.3963]]),
         tensor([[3., 0.],
                 [0., 3.]]))
    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(outputs,
                                              "outputs of the user-provided function",
                                              "jacobian")
        _check_requires_grad(outputs, "outputs", strict=strict)

        if vectorize:
            if strict:
                raise RuntimeError('torch.autograd.functional.jacobian: `strict=True` '
                                   'and `vectorize=True` are not supported together. '
                                   'Please either set `strict=False` or '
                                   '`vectorize=False`.')
            # NOTE: [Computing jacobian with vmap and grad for multiple outputs]
            #
            # Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
            # It turns out we can compute the jacobian of this function with a single
            # call to autograd.grad by using vmap over the correct grad_outputs.
            #
            # Firstly, one way to compute the jacobian is to concatenate x**2 and
            # x.sum() into a vector with 4 elements.
            # E.g., use g(x) = torch.cat([x**2, x.sum().reshape(1)])
            #
            # To get the first row of the jacobian, we call
            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0]))
            # To get the 2nd row of the jacobian, we call
            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0]))
            # and so on.
            #
            # Using vmap, we can vectorize all 4 of these computations into one by
            # passing the standard basis for R^4 as the grad_output.
            # vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
            #
            # Now, how do we compute the jacobian *without stacking the output*?
            # We can just split the standard basis across the outputs. So to
            # compute the jacobian of f(x), we'd use
            # >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
            # The grad_outputs looks like the following:
            # ( torch.tensor([[1, 0, 0],
            #                 [0, 1, 0],
            #                 [0, 0, 1],
            #                 [0, 0, 0]]),
            #   torch.tensor([[0],
            #                 [0],
            #                 [0],
            #                 [1]]) )
            #
            # But we're not done yet!
            # >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...)))
            # returns a Tensor of shape [4, 3]. We have to remember to split the
            # jacobian of shape [4, 3] into two:
            # - one of shape [3, 3] for the first output
            # - one of shape [   3] for the second output

            # Step 1: Construct grad_outputs by splitting the standard basis
            output_numels = tuple(output.numel() for output in outputs)
            grad_outputs = _construct_standard_basis_for(outputs, output_numels)
            flat_outputs = tuple(output.reshape(-1) for output in outputs)

            # Step 2: Call vmap + autograd.grad
            def vjp_for_basis(grad_output):
                vj = list(_autograd_grad(flat_outputs, inputs, grad_output, create_graph=create_graph))
                for el_idx, vj_el in enumerate(vj):
                    if vj_el is not None:
                        continue
                    # The outputs are independent of this input: its rows are all zeros
                    vj[el_idx] = torch.zeros_like(inputs[el_idx])
                return tuple(vj)

            jacobians_of_flat_output = _vmap(vjp_for_basis)(grad_outputs)

            # Step 3: The returned jacobian is one big tensor per input. In this step,
            # we split each Tensor by output.
            jacobian_input_output = []
            for jac_input_i, input_i in zip(jacobians_of_flat_output, inputs):
                jacobian_input_i_output = []
                for jac_chunk, output_j in zip(jac_input_i.split(output_numels, dim=0), outputs):
                    jacobian_input_i_output_j = jac_chunk.view(output_j.shape + input_i.shape)
                    jacobian_input_i_output.append(jacobian_input_i_output_j)
                jacobian_input_output.append(jacobian_input_i_output)

            # Step 4: Right now, `jacobian_input_output` is a List[List[Tensor]].
            # The outer List corresponds to the number of inputs,
            # the inner List corresponds to the number of outputs.
            # We need to exchange the order of these and convert to tuples
            # before returning.
            jacobian_output_input = tuple(zip(*jacobian_input_output))

            jacobian_output_input = _grad_postprocess(jacobian_output_input, create_graph)
            return _tuple_postprocess(jacobian_output_input, (is_outputs_tuple, is_inputs_tuple))

        # Non-vectorized path: one autograd.grad call per element of each output
        jacobian_rows = []
        for out_idx, out in enumerate(outputs):
            flat_out = out.reshape(-1)
            # One list of rows per input; row j is d(flat_out[j]) / d(input)
            rows_per_input = tuple([] for _ in range(len(inputs)))
            for j in range(out.nelement()):
                vj = _autograd_grad((flat_out[j],), inputs,
                                    retain_graph=True, create_graph=create_graph)

                for el_idx, (rows, vj_el, inp_el) in enumerate(zip(rows_per_input, vj, inputs)):
                    if vj_el is not None:
                        if strict and create_graph and not vj_el.requires_grad:
                            # Report the index of the *input* the jacobian is independent of
                            msg = ("The jacobian of the user-provided function is "
                                   "independent of input {}. This is not allowed in "
                                   "strict mode when create_graph=True.".format(el_idx))
                            raise RuntimeError(msg)
                        rows.append(vj_el)
                    else:
                        if strict:
                            msg = ("Output {} of the user-provided function is "
                                   "independent of input {}. This is not allowed in "
                                   "strict mode.".format(out_idx, el_idx))
                            raise RuntimeError(msg)
                        rows.append(torch.zeros_like(inp_el))

            # Stack the rows and restore the shapes: output size followed by input size
            jacobian_rows.append(tuple(torch.stack(rows, dim=0).view(out.size() + inputs[el_idx].size())
                                       for (el_idx, rows) in enumerate(rows_per_input)))

        jacobian_res = _grad_postprocess(tuple(jacobian_rows), create_graph)

        return _tuple_postprocess(jacobian_res, (is_outputs_tuple, is_inputs_tuple))
def hessian(func, inputs, create_graph=False, strict=False, vectorize=False):
    r"""Compute the Hessian of a scalar function ``func`` at ``inputs``.

    The Hessian is obtained as the Jacobian of the Jacobian of ``func``.

    Args:
        func (function): a Python function that takes Tensor inputs and
            returns a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): the point at which ``func`` is
            evaluated.
        create_graph (bool, optional): if ``True``, the Hessian is computed
            in a differentiable manner. Note that when ``strict`` is
            ``False``, the result may not require gradients or may be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): if ``True``, an error is raised when an
            input is detected such that all the outputs are independent of
            it. If ``False``, a Tensor of zeros (the expected mathematical
            value) is returned as the hessian for such inputs.
            Defaults to ``False``.
        vectorize (bool, optional): This feature is experimental, please use
            at your own risk. It is forwarded to the outer ``jacobian``
            call: when ``True``, the vmap prototype feature is used as the
            backend to vectorize calls to ``autograd.grad`` so it is only
            invoked once instead of once per row of the hessian. This should
            lead to performance improvements in many use cases, however, due
            to this feature being incomplete, there may be performance
            cliffs. Please use
            `torch._C._debug_only_display_vmap_fallback_warnings(True)`
            to show any performance warnings and file us issues if
            warnings exist for your use case. Defaults to ``False``.

    Returns:
        Hessian (Tensor or a tuple of tuple of Tensors): if there is a single
            input, this will be a single Tensor containing the Hessian for
            the input. If it is a tuple, then the Hessian will be a tuple of
            tuples where ``Hessian[i][j]`` will contain the Hessian of the
            ``i``\th input and ``j``\th input, with as size the concatenation
            of the size of the ``i``\th input and the size of the ``j``\th
            input. ``Hessian[i][j]`` will have the same dtype and device as
            the corresponding ``i``\th input.

    Example:

        >>> def pow_reducer(x):
        ...   return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> hessian(pow_reducer, inputs)
        tensor([[[[5.2265, 0.0000],
                  [0.0000, 0.0000]],
                 [[0.0000, 4.8221],
                  [0.0000, 0.0000]]],
                [[[0.0000, 0.0000],
                  [1.9456, 0.0000]],
                 [[0.0000, 0.0000],
                  [0.0000, 3.2550]]]])

        >>> hessian(pow_reducer, inputs, create_graph=True)
        tensor([[[[5.2265, 0.0000],
                  [0.0000, 0.0000]],
                 [[0.0000, 4.8221],
                  [0.0000, 0.0000]]],
                [[[0.0000, 0.0000],
                  [1.9456, 0.0000]],
                 [[0.0000, 0.0000],
                  [0.0000, 3.2550]]]], grad_fn=<ViewBackward>)


        >>> def pow_adder_reducer(x, y):
        ...   return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> hessian(pow_adder_reducer, inputs)
        ((tensor([[4., 0.],
                  [0., 4.]]),
          tensor([[0., 0.],
                  [0., 0.]])),
         (tensor([[0., 0.],
                  [0., 0.]]),
          tensor([[6., 0.],
                  [0., 6.]])))
    """
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian")

    def scalar_func(*args):
        # Wraps func to enforce that it returns a single Tensor with a single element
        result = func(*args)
        result_is_tuple, result_as_tuple = _as_tuple(result, "outputs of the user-provided function", "hessian")
        _check_requires_grad(result_as_tuple, "outputs", strict=strict)

        if result_is_tuple or not isinstance(result, torch.Tensor):
            raise RuntimeError("The function given to hessian should return a single Tensor")

        if result.nelement() != 1:
            raise RuntimeError("The Tensor returned by the function given to hessian should contain a single element")

        return result.squeeze()

    def gradient_func(*args):
        # Inner jacobian: always built with a graph so that the outer jacobian can differentiate it
        grad = jacobian(scalar_func, args, create_graph=True)
        _check_requires_grad(grad, "jacobian", strict=strict)
        return grad

    hess = jacobian(gradient_func, inputs, create_graph=create_graph, strict=strict, vectorize=vectorize)
    return _tuple_postprocess(hess, (is_inputs_tuple, is_inputs_tuple))
def vhp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the product of a vector ``v`` with the Hessian of a scalar
    function, evaluated at the point given by ``inputs``.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the vector Hessian
            product is computed. Must be the same size as the input of
            ``func``. This argument is optional when ``func``'s input contains
            a single element and (if it is not provided) will be set as a
            Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result may not require gradients or may be
            disconnected from the inputs.
            Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            vhp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            vhp (tuple of Tensors or Tensor): result of the dot product with the
            same shape as the inputs.

    Raises:
        RuntimeError: if ``v`` is ``None`` while the input is not a single
            Tensor with a single element, if ``func`` does not return a single
            Tensor, or if the returned Tensor does not contain exactly one
            element.

    Example:

        >>> def pow_reducer(x):
        ...   return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> v = torch.ones(2, 2)
        >>> vhp(pow_reducer, inputs, v)
        (tensor(0.5591),
         tensor([[1.0689, 1.2431],
                 [3.0989, 4.4456]]))
        >>> vhp(pow_reducer, inputs, v, create_graph=True)
        (tensor(0.5591, grad_fn=<SumBackward0>),
         tensor([[1.0689, 1.2431],
                 [3.0989, 4.4456]], grad_fn=<MulBackward0>))
        >>> def pow_adder_reducer(x, y):
        ...   return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.zeros(2), torch.ones(2))
        >>> vhp(pow_adder_reducer, inputs, v)
        (tensor(4.8053),
         (tensor([0., 0.]),
          tensor([6., 6.])))
    """
    # The forward pass and the first backward pass always run with grad mode
    # on: the second backward pass below needs a graph to differentiate through.
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vhp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is None:
            # Omitting v is only accepted for one input holding one element.
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError("The vector v can only be None if the input to the user-provided function "
                                   "is a single Tensor with a single element.")
        else:
            _, v = _as_tuple(v, "v", "vhp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "vhp")
        _check_requires_grad(outputs, "outputs", strict=strict)

        # Short-circuits exactly like the original: outputs[0] is only
        # inspected when func did not return a tuple.
        returned_single_tensor = not is_outputs_tuple and isinstance(outputs[0], torch.Tensor)
        if not returned_single_tensor:
            raise RuntimeError("The function given to vhp should return a single Tensor")

        if outputs[0].nelement() != 1:
            raise RuntimeError("The Tensor returned by the function given to vhp should contain a single element")

        # First-order gradients, kept differentiable for the second pass.
        first_order_grads = _autograd_grad(outputs, inputs, create_graph=True)
        _check_requires_grad(first_order_grads, "jacobian", strict=strict)

    # Outside the forced-on region: honour the caller's grad mode unless a
    # differentiable result was explicitly requested.
    if create_graph:
        grad_mode = True
    else:
        grad_mode = torch.is_grad_enabled()

    with torch.set_grad_enabled(grad_mode):
        # Differentiate the first-order gradients w.r.t. the inputs, with v
        # passed as the third argument of the helper.
        raw_product = _autograd_grad(first_order_grads, inputs, v, create_graph=create_graph)
        product = _fill_in_zeros(raw_product, inputs, strict, create_graph, "double_back")

    outputs = _grad_postprocess(outputs, create_graph)
    product = _grad_postprocess(product, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(product, is_inputs_tuple)
def hvp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the product of the Hessian of a scalar function with a vector
    ``v``, evaluated at the point given by ``inputs``.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the Hessian vector
            product is computed. Must be the same size as the input of
            ``func``. This argument is optional when ``func``'s input contains
            a single element and (if it is not provided) will be set as a
            Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result will be
            computed in a differentiable way. Note that when ``strict`` is
            ``False``, the result may not require gradients or may be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            hvp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            hvp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the inputs.

    Raises:
        RuntimeError: if ``v`` is ``None`` while the input is not a single
            Tensor with a single element, if ``func`` does not return a single
            Tensor, or if the returned Tensor does not contain exactly one
            element.

    Example:

        >>> def pow_reducer(x):
        ...   return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> v = torch.ones(2, 2)
        >>> hvp(pow_reducer, inputs, v)
        (tensor(0.1448),
         tensor([[2.0239, 1.6456],
                 [2.4988, 1.4310]]))
        >>> hvp(pow_reducer, inputs, v, create_graph=True)
        (tensor(0.1448, grad_fn=<SumBackward0>),
         tensor([[2.0239, 1.6456],
                 [2.4988, 1.4310]], grad_fn=<MulBackward0>))
        >>> def pow_adder_reducer(x, y):
        ...   return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.zeros(2), torch.ones(2))
        >>> hvp(pow_adder_reducer, inputs, v)
        (tensor(2.3030),
         (tensor([0., 0.]),
          tensor([6., 6.])))

    Note:

        This function is significantly slower than `vhp` due to backward mode AD constraints.
        If your function is twice continuously differentiable, then hvp = vhp.t(). So if you
        know that your function satisfies this condition, you should use vhp instead, which is
        much faster with the current implementation.
    """
    # The forward pass and both graph-building backward passes always run with
    # grad mode on, independently of the caller's current grad mode.
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hvp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is None:
            # Omitting v is only accepted for one input holding one element.
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError("The vector v can only be None if the input to the user-provided function "
                                   "is a single Tensor with a single element.")
        else:
            _, v = _as_tuple(v, "v", "hvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "hvp")
        _check_requires_grad(outputs, "outputs", strict=strict)

        # Short-circuits exactly like the original: outputs[0] is only
        # inspected when func did not return a tuple.
        returned_single_tensor = not is_outputs_tuple and isinstance(outputs[0], torch.Tensor)
        if not returned_single_tensor:
            raise RuntimeError("The function given to hvp should return a single Tensor")

        if outputs[0].nelement() != 1:
            raise RuntimeError("The Tensor returned by the function given to hvp should contain a single element")

        # First-order gradients, kept differentiable for the later passes.
        first_order_grads = _autograd_grad(outputs, inputs, create_graph=True)
        _check_requires_grad(first_order_grads, "jacobian", strict=strict)

        # Zero-valued placeholders (one per input) that require grad, so the
        # final pass can differentiate with respect to them.
        placeholders = tuple(torch.zeros_like(tensor, requires_grad=True) for tensor in inputs)

        double_back = _autograd_grad(first_order_grads, inputs, placeholders, create_graph=True)
        # NOTE(review): this re-checks the first-order gradients under the
        # "hessian" label rather than `double_back`; behavior is kept as-is --
        # confirm whether `double_back` was the intended argument.
        _check_requires_grad(first_order_grads, "hessian", strict=strict)

    # Outside the forced-on region: honour the caller's grad mode unless a
    # differentiable result was explicitly requested.
    if create_graph:
        grad_mode = True
    else:
        grad_mode = torch.is_grad_enabled()

    with torch.set_grad_enabled(grad_mode):
        # Differentiate `double_back` w.r.t. the placeholders, with v passed as
        # the third argument of the helper.
        raw_product = _autograd_grad(double_back, placeholders, v, create_graph=create_graph)
        product = _fill_in_zeros(raw_product, inputs, strict, create_graph, "double_back_trick")

    outputs = _grad_postprocess(outputs, create_graph)
    product = _grad_postprocess(product, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(product, is_inputs_tuple)