# Owner(s): ["oncall: pt2"]
import functools
import sys
import unittest
from unittest.mock import patch

import torch
import torch.utils.checkpoint
from functorch.compile import aot_function, min_cut_rematerialization_partition, nop

from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, run_tests, TestCase

if IS_WINDOWS and IS_CI:
    sys.stderr.write("torch.compile not supported on windows")
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("torch.compile not supported on windows")


def count_philox_rand(gm, args, freq):
    # A "compiler" that asserts the traced graph contains exactly `freq`
    # philox_rand calls and then returns the graph module unchanged.
    assert [node.target for node in gm.graph.nodes].count(
        torch.ops.rngprims.philox_rand.default
    ) == freq
    return gm


class TestFunctionalizationRngOps(TestCase):
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like_dynamic(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        for seed in range(1, 10):
            shape = (seed, seed)
            x = torch.rand(shape, device=device, dtype=dtype)
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True)
            res = opt_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like_dynamic_bwd(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        for seed in range(1, 10):
            shape = (seed, seed)
            x = torch.rand(shape, device=device, dtype=dtype, requires_grad=True)
            torch.cuda.manual_seed(seed)
            ref = fn(x)
            ref.sum().backward()

            torch.cuda.manual_seed(seed)
            opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True)
            res = opt_fn(x)
            res.sum().backward()

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand(self, dtype, device):
        shape = (10,)

        def fn(x):
            a = torch.rand(*shape, device=device, dtype=dtype) * x
            a = torch.rand(*shape, device=device, dtype=dtype) * a
            return a

        x = torch.rand(*shape, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_autograd_function(self, dtype, device):
        shape = (16, 16)

        class Custom(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                a = torch.rand_like(x) * a
                return a

            @staticmethod
            def backward(ctx, grad_out):
                (x,) = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        custom = Custom.apply

        x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
        x_clone = x.clone().detach().requires_grad_(True)

        torch.cuda.manual_seed(123)
        ref = custom(x)
        ref.sum().backward()

        torch.cuda.manual_seed(123)
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom = aot_function(custom, fwd_compiler, bwd_compiler)
        res = aot_custom(x_clone)
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_multiple_subgraphs(self, dtype, device):
        # Checks that rng state is maintained when there are multiple aot
        # traced graphs.
        shape = (16, 16)

        class CustomOp1(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                a = torch.rand_like(x) * a
                return a

            @staticmethod
            def backward(ctx, grad_out):
                (x,) = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        class CustomOp2(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                return a

            @staticmethod
            def backward(ctx, grad_out):
                (x,) = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.rand_like(x)

        custom_op1 = CustomOp1.apply
        custom_op2 = CustomOp2.apply

        def fn(x):
            a = custom_op1(x)
            b = a.sin()
            return custom_op2(b)

        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom_op1 = aot_function(custom_op1, fwd_compiler, bwd_compiler)
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=2)
        aot_custom_op2 = aot_function(custom_op2, fwd_compiler, bwd_compiler)

        def aot_fn(x):
            a = aot_custom_op1(x)
            b = a.sin()
            return aot_custom_op2(b)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
            x_clone = x.clone().detach().requires_grad_(True)

            torch.cuda.manual_seed(seed)
            ref = fn(x)
            ref.sum().backward()

            torch.cuda.manual_seed(seed)
            res = aot_fn(x_clone)
            res.sum().backward()

            self.assertEqual(ref, res)
            self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_set_get_rng_state(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            state = torch.cuda.get_rng_state()
            a = torch.rand_like(x) * a
            torch.cuda.set_rng_state(state)
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            fwd_compiler = functools.partial(count_philox_rand, freq=3)
            aot_fn = aot_function(fn, fwd_compiler)
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_min_cut_partitioner(self, dtype, device):
        # Checks that the calling convention is maintained
        shape = (16, 16)

        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            a = torch.sin(a)
            a = torch.sin(a)
            a = torch.sin(a)
            return a

        x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
        x_clone = x.clone().detach().requires_grad_(True)

        torch.cuda.manual_seed(123)
        ref = fn(x)
        ref.sum().backward()

        torch.cuda.manual_seed(123)
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=0)
        aot_custom = aot_function(
            fn,
            fwd_compiler,
            bwd_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        # aot_custom = aot_function(fn, fwd_compiler, bwd_compiler)
        res = aot_custom(x_clone)
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

    # TODO - Dropout needs more work because of offset calculation
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    @dtypes(torch.float32)
    def test_checkpoint(self, dtype, device):
        def g(x, y):
            return torch.nn.functional.dropout(x, 0.6)

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(g, x, y, use_reentrant=False)

        # x = torch.rand(2, 2, device="cuda", requires_grad=True)
        x = torch.ones(2, 2, device="cuda", requires_grad=True)
        y = torch.rand(2, 2, device="cuda", requires_grad=True)

        torch.cuda.manual_seed(123)
        ref = fn(x, y)

        # With checkpointing, dropout should be recomputed in the bwd, with
        # the philox state passed from the fwd.
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=0)
        aot_fn = aot_function(fn, fwd_compiler, bwd_compiler)
        # We can't check accuracy here because rand_like generates different
        # random numbers than dropout.
        res = aot_fn(x, y)
        res.sum().backward()

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_dropout_decomp(self, dtype, device):
        def fn(x):
            return torch.nn.functional.dropout(x, 0.6) * x

        x = torch.rand(10, device=device, dtype=dtype)

        # Ensure the decomp is happening
        aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=1))
        # We can't check accuracy here because rand_like generates different
        # random numbers than dropout.
        aot_fn(x)


only_for = ("cuda",)
instantiate_device_type_tests(TestFunctionalizationRngOps, globals(), only_for=only_for)


class NegativeTest(TestCase):
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_on_cpu(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        aot_fn = aot_function(fn, nop)
        # RNG op functionalization is not supported on CPU, so tracing should raise.
        with self.assertRaises(RuntimeError):
            aot_fn(x)


only_for = ("cpu",)
instantiate_device_type_tests(NegativeTest, globals(), only_for=only_for)

if __name__ == "__main__":
    run_tests()