# Owner(s): ["module: dynamo"] # flake8: noqa: B950 import copy import math from dataclasses import dataclass import torch import torch._dynamo.test_case import torch._dynamo.testing import torch._dynamo.utils from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda if HAS_CUDA: import triton from torch.testing._internal.triton_utils import add_kernel class CustomFunc1(torch.autograd.Function): @staticmethod def forward(ctx, foo): return foo + foo @staticmethod def backward(ctx, grad_output): return grad_output class CustomFunc3(torch.autograd.Function): # Test there is graph break in forward function @staticmethod def forward(ctx, foo): result = foo + foo torch._dynamo.graph_break() result = result + foo ctx.save_for_backward(result) return result @staticmethod def backward(ctx, grad_output): (result,) = ctx.saved_tensors return grad_output * math.sqrt(result.numel()) class Module1(torch.nn.Module): def forward(self, foo): return CustomFunc1().apply(foo) class Module2(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fn = CustomFunc1.apply def forward(self, foo): return self.fn(foo) class Module3(torch.nn.Module): def forward(self, foo): return CustomFunc1().apply(foo) class Module4(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fn = CustomFunc1.apply def forward(self, foo): return self.fn(foo) class Module5(torch.nn.Module): def forward(self, foo): return CustomFunc3().apply(foo) class Module6(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fn = CustomFunc3.apply def forward(self, foo): return self.fn(foo) class LinearFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, weight, bias): output = input.mm(weight.t()) if bias is not None: output += bias.unsqueeze(0).expand_as(output) return output @staticmethod # inputs is a Tuple of all of the inputs passed to forward. # output is the output of the forward(). 
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias


class ModuleLinear(torch.nn.Module):
    def forward(self, input, weight, bias=None):
        return LinearFunction.apply(input, weight, bias)


class MaterializingGradFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.set_materialize_grads(False)
        return x.clone(), x.clone()

    @staticmethod
    def backward(ctx, grad_out1, grad_out2):
        return grad_out1, grad_out2


class MaterializingGradModule(torch.nn.Module):
    def forward(self, x):
        return MaterializingGradFunction.apply(x)


class CustomFuncBwdPrintGraphBreak(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return torch.add(foo, foo)

    @staticmethod
    def backward(ctx, grad_output):
        print("graph break!")
        return grad_output


class CustomFuncBwdPrintModule(torch.nn.Module):
    def forward(self, x):
        return CustomFuncBwdPrintGraphBreak.apply(x)


class CustomFuncStrideBwd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return torch.add(foo, foo)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.stride()


class CustomFuncStrideModule(torch.nn.Module):
    def forward(self, x):
        return CustomFuncStrideBwd.apply(x)


class CustomFuncSaveForBwd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        result = foo + foo
        result = result + foo
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * math.sqrt(result.numel())


class SaveForBwdModule(torch.nn.Module):
    def forward(self, foo):
        return CustomFuncSaveForBwd().apply(foo)


class ContextSaveAndMark(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        with torch.no_grad():
            ctx.save_for_backward(x)
            ctx.mark_non_differentiable(x)
            return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class ContextMarkAndSave(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        with torch.no_grad():
            ctx.mark_non_differentiable(x)
            ctx.save_for_backward(x)
            return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class ModuleWithGradFunc(torch.nn.Module):
    def __init__(self, func):
        super().__init__()
        self.f = func.apply

    def forward(self, x):
        return self.f(x)


class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
    # Sound behaviors, tested for working capture
    def test_autograd_function_equivalence(self):
        for grad in [True, False]:
            for i in range(1, 5):
                torch._dynamo.reset()
                model = globals()[f"Module{i}"]()
                opt_model = torch._dynamo.optimize("eager")(model)
                self.assertTrue(
                    torch.allclose(
                        opt_model(torch.ones(2, 3, requires_grad=grad)),
                        torch.tensor([2.0], requires_grad=grad),
                    )
                )

    def test_autograd_function_has_graph_break(self):
        for grad in [True, False]:
            x = torch.randn(10, requires_grad=grad)
            for model in [Module5(), Module6()]:
                torch._dynamo.reset()
                cnts = torch._dynamo.testing.CompileCounter()
                opt_model = torch._dynamo.optimize(cnts)(model)
                for _ in range(3):
                    ref = model(x)
                    res = opt_model(x)
                    self.assertTrue(torch.allclose(ref, res))
                self.assertEqual(cnts.frame_count, 2)

    def test_linear_setup_context(self):
        model = ModuleLinear()
        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
        input = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        weight = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
        eager_result = model(input, weight)
        optim_result = opt_model(input, weight)
        self.assertEqual(optim_result, eager_result)

    def test_materialize_grad(self):
        model = MaterializingGradModule()
        opt_model = torch._dynamo.optimize("eager")(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        optim_result = opt_model(x)
        eager_result = model(x)
        self.assertEqual(optim_result, eager_result)

    def test_print_in_bwd(self):
        model = CustomFuncBwdPrintModule()
        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: print"):
            opt_model(x)

    def test_stride_in_bwd(self):
        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()
        model = CustomFuncStrideModule()
        opt_model = torch.compile(backend=cnt)(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        ref = model(x)
        res = opt_model(x)

        self.assertEqual(ref, res)
        self.assertEqual(cnt.frame_count, 1)
        # graph break: Illegal getattr invocation stride in strict mode.
        self.assertEqual(
            list(torch._dynamo.utils.counters["graph_break"].values()), [1]
        )

    def test_enum_arg(self):
        from enum import Enum

        class SomeEnum(Enum):
            A = 0
            B = 1

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, e):
                if e is SomeEnum.A:
                    return x.sin()
                else:
                    return x.cos()

            @staticmethod
            def backward(ctx, g):
                return g

        @torch.compile(backend="eager", fullgraph=True)
        def f(x, enum):
            output = Foo.apply(
                x,
                enum,
            )
            return output

        x = torch.tensor([[1.0, 2, 3], [4, 5, 6]], requires_grad=True)
        y = f(x, SomeEnum.A)
        self.assertEqual(y, x.sin())

    def test_save_for_bwd(self):
        model = SaveForBwdModule()
        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        opt_model(x)

    def test_allow_in_graph(self):
        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.allow_in_graph
        class AllowInGraphFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                torch._dynamo.graph_break()
                ctx.x0 = x.size(0)
                return x * 2

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out * ctx.x0

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x):
            return AllowInGraphFunc.apply(x)

        x = torch.rand(2, 3, requires_grad=True)
        result = fn(x)

        self.assertEqual(result, AllowInGraphFunc.apply(x))
        self.assertEqual(cnt.frame_count, 1)

    def test_once_differentiable(self):
        from torch.autograd.function import once_differentiable

        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        class ScaleGradient(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            @once_differentiable
            def backward(ctx, grad):
                return grad * 0.5

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x):
            return ScaleGradient.apply(x)

        x = torch.randn(3, requires_grad=True)
        result = fn(x)

        self.assertEqual(result, ScaleGradient.apply(x))
        self.assertEqual(cnt.frame_count, 1)

    def test_classmethod(self):
        class Shake(torch.autograd.Function):
            @classmethod
            def forward(cls, ctx, foo):
                return foo + foo

            @classmethod
            def backward(cls, ctx, grad_output):
                return grad_output

        def f(x):
            return Shake.apply(x)

        x = torch.randn(4, 4, 4, 4, requires_grad=True)
torch.compile(backend="eager")(f) opt_m(x) def test_function_context_save_and_mark(self): mod = ModuleWithGradFunc(ContextSaveAndMark) args, kwargs = ([torch.rand([1])], {}) before = mod(*args, **kwargs) torch._dynamo.reset() compiled_model = torch._dynamo.optimize("eager")(mod) after = compiled_model(*args, **kwargs) self.assertEqual(before, after) def test_function_context_mark_and_save(self): mod = ModuleWithGradFunc(ContextMarkAndSave) args, kwargs = ([torch.rand([1])], {}) before = mod(*args, **kwargs) torch._dynamo.reset() compiled_model = torch._dynamo.optimize("eager")(mod) after = compiled_model(*args, **kwargs) self.assertEqual(before, after) def test_multi_output(self): torch._dynamo.utils.counters.clear() cnt = torch._dynamo.testing.CompileCounter() class Foo(torch.autograd.Function): @staticmethod def forward(ctx, x): return x.clone(), x.clone() @staticmethod def backward(ctx, grad1, grad2): return grad1 + grad2 @torch.compile(backend=cnt, fullgraph=True) def f(x): return Foo.apply(x) x = torch.randn(3, requires_grad=True) result = f(x) self.assertEqual(result, Foo.apply(x)) self.assertEqual(cnt.frame_count, 1) def test_amp_custom_fwd_bwd(self): torch._dynamo.utils.counters.clear() cnt = torch._dynamo.testing.CompileCounter() class MyMM(torch.autograd.Function): @staticmethod @torch.amp.custom_fwd(device_type="cuda") def forward(ctx, a, b): ctx.save_for_backward(a, b) return a.mm(b) @staticmethod @torch.amp.custom_bwd(device_type="cuda") def backward(ctx, grad): a, b = ctx.saved_tensors return grad.mm(b.t()), a.t().mm(grad) @torch.compile(backend=cnt, fullgraph=True) def fn(a, b): return MyMM.apply(a, b) a = torch.randn([64, 64], dtype=torch.float32, requires_grad=True) grad = a.clone() res = fn(a, a) res.backward(grad) self.assertEqual(res, MyMM.apply(a, a)) self.assertEqual(cnt.frame_count, 1) def test_set_materialize_grads_no_graph_break(self): class MulY(torch.autograd.Function): @staticmethod def forward(ctx, x): ctx.set_materialize_grads(True) return x * 3 @staticmethod def backward(ctx, grad_out): return grad_out * 3 @torch.compile(backend="eager", fullgraph=True) def f(x): return MulY.apply(x) x = torch.tensor(2.0, requires_grad=True) result = f(x) result.sum().backward() self.assertEqual(result, MulY.apply(x)) self.assertEqual(x.grad, 3.0) def test_user_defined_object_as_input(self): cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") @dataclass class Weird: x: int b: torch.Tensor c: torch.Tensor class Foo(torch.autograd.Function): @staticmethod def forward(ctx, x: torch.Tensor, weird: Weird, z: torch.Tensor): ctx.save_for_backward(weird.b, weird.c) return weird.b * weird.c * x.clone() @staticmethod def backward(ctx, grad): b, c = ctx.saved_tensors return grad * b * c, None, grad * 2 @torch.compile(backend=cnt, fullgraph=True) def f(x, weird, z): return Foo.apply(x, weird, z) x = torch.tensor(2.0, requires_grad=True) weird = Weird(1.2, torch.tensor(2.5, requires_grad=True), torch.tensor(3.5)) z = torch.tensor(3.0, requires_grad=True) result = f(x, weird, z) result.sum().backward() self.assertEqual(result, Foo.apply(x, weird, z)) self.assertEqual(x.grad, 2.5 * 3.5) self.assertEqual(z.grad, 2.0) self.assertEqual(weird.b.grad, None) # check Dynamo captured graph is correct! 
        actual_graph = torch._dynamo.testing.normalize_gm(
            cnt.graphs[0].print_readable(print_output=False)
        )
        self.assertExpectedInline(
            actual_graph,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[]", L_z_: "f32[]", L_weird_b: "f32[]", L_weird_c: "f32[]"):
        l_x_ = L_x_
        l_z_ = L_z_
        l_weird_b = L_weird_b
        l_weird_c = L_weird_c

        function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
        fwd_body_0 = self.fwd_body_0
        bwd_body_0 = self.bwd_body_0
        autograd_function_apply: "f32[]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None
        return (autograd_function_apply,)

    class fwd_body_0(torch.nn.Module):
        def forward(self, ctx, x: "f32[]", z: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
            mul: "f32[]" = l_weird_b * l_weird_c
            clone: "f32[]" = x.clone(); x = None
            mul_1: "f32[]" = mul * clone; mul = clone = None
            return (mul_1, [l_weird_b, l_weird_c])

    class bwd_body_0(torch.nn.Module):
        def forward(self, ctx, grad: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
            _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
            mul: "f32[]" = grad * l_weird_b; l_weird_b = None
            mul_1: "f32[]" = mul * l_weird_c; mul = l_weird_c = None
            mul_2: "f32[]" = grad * 2; grad = None
            _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
            return (mul_1, mul_2)
""",
        )

    def test_tensor_list_as_input(self):
        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, tl):
                ctx.save_for_backward(tl[0], tl[1])
                return x.clone() * (tl[0] + tl[1])

            @staticmethod
            def backward(ctx, grad):
                tl0, tl1 = ctx.saved_tensors
                return grad * (tl0 + tl1), None

        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(x, tl):
            return Foo.apply(x, tl)

        x = torch.tensor(2.0, requires_grad=True)
        tl = [
            torch.tensor(3.0, requires_grad=True),
            torch.tensor(4.0, requires_grad=True),
        ]

        result = f(x, tl)
        result.sum().backward()

        self.assertEqual(result, Foo.apply(x, tl))
        self.assertEqual(x.grad, 7.0)
        self.assertEqual(tl[0].grad, None)
        self.assertEqual(tl[1].grad, None)

    def test_multiple_different_non_tensor_inputs(self):
        @dataclass
        class Weird:
            x: int
            b: torch.Tensor
            c: torch.Tensor

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, weird, z, tl):
                ctx.save_for_backward(weird.b, weird.c, tl[0], tl[1])
                return x.clone() * weird.b * weird.c * tl[0]

            @staticmethod
            def backward(ctx, grad):
                b, c, tl0, _ = ctx.saved_tensors
                return grad * b * c * tl0, None, grad * 2, None

        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(x, weird, z, tl):
            return Foo.apply(x, weird, z, tl)

        x = torch.tensor(2.0, requires_grad=True)
        weird = Weird(
            1.2,
            torch.tensor(2.5, requires_grad=True),
            torch.tensor(3.5, requires_grad=True),
        )
        z = torch.tensor(3.0, requires_grad=True)
        tl = [
            torch.tensor(0.5, requires_grad=True),
            torch.tensor(0.6, requires_grad=True),
        ]

        result = f(x, weird, z, tl)
        result.sum().backward()

        self.assertEqual(result, Foo.apply(x, weird, z, tl))
        self.assertEqual(x.grad, 2.5 * 3.5 * 0.5)
        self.assertEqual(z.grad, 2.0)
        self.assertEqual(weird.b.grad, None)
        self.assertEqual(weird.c.grad, None)
        self.assertEqual(tl[0].grad, None)
        self.assertEqual(tl[1].grad, None)

    def test_backward_returns_none_for_tensor_input(self):
        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(y)
                return x.clone() * y

            @staticmethod
            def backward(ctx, grad):
                (y,) = ctx.saved_tensors
                return grad * y, None

        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(x, y):
            return Foo.apply(x, y)

        x = torch.tensor(2.0, requires_grad=True)
        y = torch.tensor(3.0, requires_grad=True)

        result = f(x, y)
        result.sum().backward()

        self.assertEqual(result, Foo.apply(x, y))
        self.assertEqual(x.grad, 3.0)
        self.assertEqual(y.grad, None)

    def test_function_with_bound_free_variable(self):
        class LowerBound(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inputs, bound):
                ctx.save_for_backward(inputs, inputs.new_ones(1) * bound)
                return inputs.clamp(min=bound)

            @staticmethod
            def backward(ctx, grad_output):
                inputs, bound = ctx.saved_tensors
                return (inputs >= bound) * grad_output, None

        class MyMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.gamma = torch.nn.Parameter(torch.rand([4, 128, 32, 32]))

            def forward(self, x):
                gamma = LowerBound.apply(self.gamma, 1)
                return x + gamma

        mod = MyMod()
        args, kwargs = ([torch.rand([4, 128, 32, 32])], {})
        before = mod(*args, **kwargs)

        compiled_model = torch._dynamo.optimize("eager")(mod)
        after = compiled_model(*args, **kwargs)
        self.assertEqual(before, after)

    # I pulled all of these test cases from test_autograd.py
    # In the future, we should make the Dynamo test suite actually
    # run on test_autograd.py (it's disabled right now) and delete these.
    def test_smoke_from_test_autograd(self):
        def mult1(x):
            return x.prod(dim=-1).prod(dim=-1)

        class Mult(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = mult1(x)
                ctx.save_for_backward(x, y)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                x, y = ctx.saved_tensors
                return (grad_output * y)[:, None, None] / x

        mult2 = Mult.apply

        class Double(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = x**2
                ctx.save_for_backward(x, y)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                x, _ = ctx.saved_tensors
                return grad_output * 2 * x

        # this is equivalent, but uses the output of .forward() in .backward()
        class Double2(Double):
            @staticmethod
            def backward(ctx, grad_output):
                x, y = ctx.saved_tensors
                return grad_output * 2 * y / x

        double = Double.apply
        double2 = Double2.apply

        class Identity(torch.autograd.Function):
            @staticmethod
            def forward(ctx, a, b):
                return a, a + b

            @staticmethod
            def backward(ctx, grad_a, grad_b):
                return grad_a + grad_b, grad_b

        class MyFunc2(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                return inp.clone()

            @staticmethod
            def backward(ctx, gO):
                return torch.tensor(float("nan")).expand(10, 10)

        def run_fn(a):
            out = MyFunc2.apply(a)
            return out.sum()

        class MyFn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                return inp.view_as(inp)

            @staticmethod
            def backward(ctx, grad):
                return grad

        class MyAdder(torch.autograd.Function):
            @staticmethod
            def forward(ctx, a, b):
                a.add_(b)
                ctx.mark_dirty(a)
                return a

            @staticmethod
            def backward(ctx, grad):
                return grad, grad

        class InplaceMul(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                result = x.mul_(2)
                ctx.mark_dirty(result)
                return result

            @staticmethod
            def backward(ctx, grad_output):
                pass

            @staticmethod
            def jvp(ctx, x_t):
                if jvp_err:  # noqa: F821
                    return x_t
                else:
                    return x_t.mul_(2)

        class MyFn2(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                return x + y, x

            @staticmethod
            def vjp(ctx, gO1, gO2):
                return gO1 + gO2, gO1

            @staticmethod
            def jvp(ctx, x_t, y_t):
                return x_t + y_t, fn(x_t)  # noqa: F821

        class MyFn3(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp, inplace):
                view = inp.clone()[:3]
                if inplace:
                    view += 2
                return view

            @staticmethod
            def backward(ctx, grad):
                return grad, None

        def test():
            x = torch.ones(2, 4, 4).requires_grad_()
            mult2(x)

            x = torch.tensor(2).double().requires_grad_()
            double(x)
            double2(x)

            x = torch.randn(5, 5, requires_grad=True)
            y = torch.randn(5, 5, requires_grad=True)
            q, p = Identity.apply(x, y)

            a = torch.rand(1, 2)
            b = torch.rand(1, requires_grad=True)
            view_a = MyFn.apply(a)

            a = torch.ones(2, requires_grad=True)
            b = torch.ones(2, requires_grad=True)
            c = MyAdder.apply(a.clone(), b)
            c.sum().backward()

            z = torch.tensor(1.0, requires_grad=True)
            x = z.clone()
            y = InplaceMul.apply(x)

            a = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
            b = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
            c = torch.tensor(1.0, dtype=torch.double)
            d = torch.tensor(1.0, dtype=torch.double)
            MyFn2.apply(a, b)
            MyFn2.apply(c, d)

            base = torch.rand(10, requires_grad=True)
            foo = MyFn3.apply(base, False)

        test()
        opt_test = torch._dynamo.optimize("eager")(test)
        opt_test()

    def test_tensor_subclass_intermediary_input(self):
        class FooTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, data, config, scale):
                self = torch.Tensor._make_wrapper_subclass(
                    cls,
                    config[0],
                    strides=config[1],
                    storage_offset=config[2],
                    dtype=config[3],
                    layout=config[4],
                    requires_grad=config[5],
                    device=data.device,
                )
                self._data = data
                self._config = config
                self._scale = scale
                return self

            def __repr__(self):
                return "FooTensor"

            def __tensor_flatten__(self):
                return ("_data",), (
                    self._config,
                    self._scale,
                )

            @staticmethod
            def __tensor_unflatten__(tensors, metadatas, outer_size, outer_stride):
                return FooTensor(tensors["_data"], metadatas[0], metadatas[1])

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs=None):
                # clone and view are handled here so that Dynamo fakeification
                # passes; this is not intended to handle user code
                if func == torch.ops.aten.clone.default:
                    return FooTensor(
                        args[0]._data.clone(), args[0]._config, args[0]._scale
                    )
                elif func == torch.ops.aten.view.default:
                    new_data = args[0]._data.view(*args[1:])
                    return FooTensor(new_data, args[0]._config, args[0]._scale)

                raise NotImplementedError

        class foo_autograd_fn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # access some data from `x`, where `x` is a tensor subclass
                x2 = x._data + 1.0
                # create and return a tensor subclass from within a torch.autograd.Function
                x3 = FooTensor(x2, x._config, x._scale)
                return x3._data

            @staticmethod
            def backward(ctx, g):
                return g

        x_ref = torch.randn(4, 4).requires_grad_(True)
        x = copy.deepcopy(x_ref)
        scale = torch.tensor(1.0)

        # Weird that this is needed, but not having this breaks a lot of things
        torch._dynamo.allow_in_graph(FooTensor)

        def foo(x, scale):
            config = (
                x.size(),
                x.stride(),
                x.storage_offset(),
                x.dtype,
                x.layout,
                x.requires_grad,
            )
            x = FooTensor(x, config, scale)
            x = foo_autograd_fn.apply(x)
            return x

        y_ref = foo(x_ref, scale)
        y_ref.sum().backward()

        foo_opt = torch.compile(foo, backend="eager")
        y = foo_opt(x, scale)
        y.sum().backward()

        self.assertEqual(y, y_ref)
        self.assertEqual(x.grad, x_ref.grad)

    def test_smuggle_symint_issue_111031(self):
        from torch.autograd import Function

        class Foo(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.x0 = x.size(0)
                return x * 2

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out * ctx.x0

        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True, dynamic=True)
        def foo(x):
            return Foo.apply(x)

        foo(torch.randn(2, requires_grad=True))
        self.assertEqual(cnts.frame_count, 1)

    def test_needs_input_grad(self):
        cnt = torch._dynamo.testing.CompileCounter()

        class NeedsInputGradFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, foo):
                result = foo + foo
                ctx.save_for_backward(result)
                return result

            @staticmethod
            @torch.compile(backend=cnt, fullgraph=True)
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                if ctx.needs_input_grad[0]:
                    return grad_output * result.sin()
                return None

        x = torch.randn(10, requires_grad=True)
        NeedsInputGradFunc.apply(x).sum().backward()
        self.assertEqual(x.grad.shape, x.shape)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 2)

    def test_repeated_save_for_backward_calls(self):
        from torch.autograd import Function

        class Foo(Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x)
                ctx.save_for_backward(x, y)
                return x * y

            @staticmethod
            def backward(ctx, grad_out):
                x, y = ctx.saved_tensors
                return grad_out * x, grad_out * y

        cnts = torch._dynamo.testing.CompileCounter()

        def foo(x, y):
            return Foo.apply(x, y)

        x_ref = torch.randn(2, requires_grad=True)
        y_ref = torch.randn(2, requires_grad=True)
        x_test = x_ref.clone().detach().requires_grad_()
        y_test = y_ref.clone().detach().requires_grad_()

        out_ref = foo(x_ref, y_ref)
        out_ref.sum().backward()

        out_test = torch.compile(foo, backend=cnts)(x_test, y_test)
        out_test.sum().backward()

        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(out_ref, out_test)
        self.assertEqual(x_ref.grad, x_test.grad)
        self.assertEqual(y_ref.grad, y_test.grad)

    def test_smuggle_tensor_and_complex_structures(self):
        from torch.autograd import Function

        class Foo(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.x0 = x
                ctx.x1 = [1, 2, 3]
                return x * 2

            @staticmethod
            def backward(ctx, grad_out):
                x0mul = grad_out * ctx.x0
                for i in ctx.x1:
                    x0mul = (x0mul * i) + x0mul
                return x0mul

        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True, dynamic=True)
        def foo(x):
            return Foo.apply(x)

        foo(torch.randn(2, requires_grad=True))
        self.assertEqual(cnts.frame_count, 1)

    def test_mark_non_differentiable(self):
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        from torch.autograd import Function

        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x, y):
                out1 = x.sin()
                out2 = y * 2
                ctx.mark_non_differentiable(out2)
                return out1, out2

            @staticmethod
            def backward(ctx, grad1, grad2):
                return grad1.cos(), grad2 * 0.0

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x, y):
            return MyFunction.apply(x, y)

        x = torch.tensor(10.0, requires_grad=True)
        y = torch.tensor(20.0, requires_grad=True)
        ref1, ref2 = MyFunction.apply(x, y)
        res1, res2 = fn(x, y)
        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)
        # Ensure out1 requires gradients, out2 does not.
        self.assertTrue(ref1.requires_grad)
        self.assertTrue(res1.requires_grad)
        self.assertFalse(ref2.requires_grad)
        self.assertFalse(res2.requires_grad)
        res1.sum().backward()

        # check Dynamo captured graph is correct!
        actual_graph = torch._dynamo.testing.normalize_gm(
            cnt.graphs[0].print_readable(print_output=False)
        )
        self.assertExpectedInline(
            actual_graph,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[]", L_y_: "f32[]"):
        l_x_ = L_x_
        l_y_ = L_y_

        function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
        fwd_body_0 = self.fwd_body_0
        bwd_body_0 = self.bwd_body_0
        autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_y_, args_tensor_mask = [True, True], non_differentiable_idx = [1]); fwd_body_0 = bwd_body_0 = l_x_ = l_y_ = None
        getitem: "f32[]" = autograd_function_apply[0]
        getitem_1: "f32[]" = autograd_function_apply[1]; autograd_function_apply = None
        return (getitem, getitem_1)

    class fwd_body_0(torch.nn.Module):
        def forward(self, ctx, x: "f32[]", y: "f32[]"):
            out1: "f32[]" = x.sin(); x = None
            out2: "f32[]" = y * 2; y = None
            return ((out1, out2), [])

    class bwd_body_0(torch.nn.Module):
        def forward(self, ctx, grad1: "f32[]", grad2: "f32[]"):
            _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
            cos: "f32[]" = grad1.cos(); grad1 = None
            mul: "f32[]" = grad2 * 0.0; grad2 = None
            _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
            return (cos, mul)
""",
        )

    def test_mark_multi_output_non_differentiable(self):
        from torch.autograd import Function

        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x, y, z):
                out1 = x.sin()
                out2 = y * 2
                out3 = z + 3
                ctx.mark_non_differentiable(out2, out3)
                return out1, out2, out3

            @staticmethod
            def backward(ctx, grad1, grad2, grad3):
                return grad1.cos(), grad2, grad3

        @torch.compile(backend="aot_eager", fullgraph=True)
        def fn(x, y, z):
            return MyFunction.apply(x, y, z)

        x = torch.tensor(10.0, requires_grad=True)
        y = torch.tensor(20.0, requires_grad=True)
        z = torch.tensor(30.0, requires_grad=True)
        ref1, ref2, ref3 = MyFunction.apply(x, y, z)
        res1, res2, res3 = fn(x, y, z)
        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)
        self.assertEqual(ref3, res3)
        # Ensure out1 requires gradients; out2 and out3 do not.
        self.assertTrue(ref1.requires_grad)
        self.assertTrue(res1.requires_grad)
        self.assertFalse(ref2.requires_grad)
        self.assertFalse(res2.requires_grad)
        self.assertFalse(ref3.requires_grad)
        self.assertFalse(res3.requires_grad)
        res1.sum().backward()

    def test_default_values(self):
        from torch.autograd import Function

        class Foo(Function):
            @staticmethod
            def forward(ctx, x, alpha=0.99):
                return x

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out

        @torch.compile
        def foo(x):
            return Foo.apply(x)

        # Make sure guards for default values do not crash
        foo(torch.randn(2))
        foo(torch.randn(2, requires_grad=True))

    def test_tuple_arg(self):
        cnt = torch._dynamo.testing.CompileCounter()

        class TupleArgFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, shape):
                ctx.save_for_backward(torch.randn(shape))
                return x + 1

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return result, None

        @torch.compile(backend=cnt, fullgraph=True)
        def fn():
            return TupleArgFunc.apply(x, shape)

        shape = (10, 10)
        x = torch.randn(shape, requires_grad=True)
        out = fn()
        out.sum().backward()

        self.assertEqual(out, x + 1)
        self.assertEqual(x.grad.shape, shape)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 2)

    @requires_cuda
    def test_triton_kernel_basic(self):
        class Add(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x, y)
                output = torch.zeros_like(x)
                n_elements = output.numel()
                grid = lambda meta: (  # noqa: E731
                    triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
                )
                add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
                return output

            @staticmethod
            def backward(ctx, grad_output):
                x, y = ctx.saved_tensors
                return x * grad_output, y * grad_output

        @torch.compile(fullgraph=True, backend="inductor")
        def f(x, y):
            z = Add.apply(x, y)
            return z

        x = torch.randn(10, device="cuda", requires_grad=True)
        y = torch.randn(10, device="cuda", requires_grad=True)
        z = f(x, y)
        loss = z.sum()
        loss.backward()
        self.assertEqual(x + y, z)

    @requires_cuda
    def test_triton_kernel_multiple_out(self):
        class Add(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x, y)
                ctx.t1 = x
                ctx.t2 = y
                output = torch.zeros_like(x)
                n_elements = output.numel()
                grid = lambda meta: (  # noqa: E731
                    triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
                )
                add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
                return output, x

            @staticmethod
            def backward(ctx, grad_output, old_x):
                x, y = ctx.saved_tensors
                x1 = ctx.t1
                y1 = ctx.t2
                return old_x * x * x1 * grad_output, y * y1 * grad_output

        @torch.compile(fullgraph=True, backend="inductor")
        def f(x, y):
            z = Add.apply(x, y)
            return z

        x = torch.randn(10, device="cuda", requires_grad=True)
        y = torch.randn(10, device="cuda", requires_grad=True)
        z, _ = f(x, y)
        loss = z.sum()
        loss.backward()
        self.assertEqual(x + y, z)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()