Lines Matching full:torch
4 import torch
5 import torch._dynamo.test_case
6 import torch._dynamo.testing
7 import torch.onnx.operators
8 from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same
9 from torch.nn import functional as F
10 from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
11 from torch.testing._internal.common_utils import TEST_WITH_ROCM
16 self.prev = torch.is_grad_enabled()
20 torch._C._set_grad_enabled(self.mode)
23 torch._C._set_grad_enabled(self.prev)
28 torch._dynamo.graph_break()
32 class CtxManagerTests(torch._dynamo.test_case.TestCase):
37 with torch.no_grad():
44 with torch.set_grad_enabled(False):
51 with torch.enable_grad():
58 with torch.set_grad_enabled(True):
59 if torch.is_grad_enabled():
64 with torch.no_grad():
65 torch._dynamo.testing.standard_test(
68 torch._dynamo.testing.standard_test(
71 torch._dynamo.testing.standard_test(self, fn=fn3, nargs=2, expected_ops=5)
72 torch._dynamo.testing.standard_test(self, fn=fn4, nargs=2, expected_ops=5)
73 with torch.enable_grad():
74 torch._dynamo.testing.standard_test(self, fn=fn1, nargs=2, expected_ops=5)
75 torch._dynamo.testing.standard_test(self, fn=fn2, nargs=2, expected_ops=5)
76 torch._dynamo.testing.standard_test(
79 torch._dynamo.testing.standard_test(
85 prev_grad = torch.is_grad_enabled()
86 torch.set_grad_enabled(False)
90 torch.set_grad_enabled(prev_grad)
93 a = torch.randn([3, 4])
94 b = torch.randn([3, 4])
95 cnts = torch._dynamo.testing.CompileCounter()
96 opt_fn = torch._dynamo.optimize(cnts)(fn)
103 before = torch.is_grad_enabled()
104 with torch.set_grad_enabled(False):
105 torch._dynamo.graph_break()
106 with torch.set_grad_enabled(True):
107 x = torch.mul(x, 5)
108 torch._dynamo.graph_break()
109 x = torch.sqrt(x)
110 assert torch.is_grad_enabled()
111 assert not torch.is_grad_enabled()
112 assert torch.is_grad_enabled() == before
115 a = torch.randn([3, 4])
116 cnts = torch._dynamo.testing.CompileCounter()
117 opt_fn = torch._dynamo.optimize(cnts)(fn)
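
The matched lines above exercise torch.no_grad, torch.enable_grad, and torch.set_grad_enabled as context managers, including across explicit graph breaks. A minimal, self-contained sketch of that pattern under torch.compile follows; it is reconstructed from the matched lines only, so the exact surrounding test code is an assumption.

import torch

def fn(x):
    before = torch.is_grad_enabled()
    with torch.set_grad_enabled(False):
        # Dynamo splits the graph at each explicit break but must restore
        # the correct grad mode on both sides of it.
        torch._dynamo.graph_break()
        with torch.set_grad_enabled(True):
            x = torch.mul(x, 5)
            torch._dynamo.graph_break()
            x = torch.sqrt(x)
            assert torch.is_grad_enabled()
        assert not torch.is_grad_enabled()
    assert torch.is_grad_enabled() == before
    return x

opt_fn = torch.compile(fn, backend="eager")
opt_fn(torch.rand(3, 4))
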
124 # wrap torch.profiler.* as NullContextVariable and do nothing
127 with torch.profiler.profile():
129 with torch.profiler.record_function("my_function"):
135 x = torch.randn((2, 2), requires_grad=True)
137 cnts = torch._dynamo.testing.CompileCounter()
138 opt_fn = torch._dynamo.optimize(cnts)(fn)
144 # wrap torch.autograd.profiler.* as NullContextVariable and do nothing
147 with torch.autograd.profiler.profile():
149 with torch.autograd.profiler.record_function("my_function"):
155 x = torch.randn((2, 2), requires_grad=True)
157 cnts = torch._dynamo.testing.CompileCounter()
158 opt_fn = torch._dynamo.optimize(cnts)(fn)
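
The two comments above note that Dynamo wraps torch.profiler.* and torch.autograd.profiler.* context managers as NullContextVariable, i.e. compiled code simply ignores them instead of graph-breaking. A hedged illustration of what that looks like from the caller's side; the function below is a sketch, not part of the listed file.

import torch

def fn(x):
    # Under torch.compile these profiler context managers are traced as
    # no-ops (NullContextVariable), so the body compiles as if they were
    # not there.
    with torch.profiler.profile():
        x = x + 1
    with torch.autograd.profiler.record_function("my_function"):
        x = torch.relu(x)
    return x

opt_fn = torch.compile(fn, backend="eager")
opt_fn(torch.randn(2, 2, requires_grad=True))
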
163 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
166 s = torch.cuda.Stream()
167 x = torch.mul(x, 5)
168 x = torch.add(x, 2)
169 current_stream = torch.cuda.current_stream()
171 with torch.cuda.stream(s):
172 x = torch.relu(x)
174 x = torch.add(x, 1)
175 x = torch.cos(x)
178 x = torch.randn((2, 2), device="cuda")
180 cnts = torch._dynamo.testing.CompileCounter()
181 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
188 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
191 s = torch.cuda.Stream()
192 x = torch.mul(x, 5)
193 x = torch.add(x, 2)
197 tcs = torch.cuda.stream(s)
198 current_stream = torch.cuda.current_stream()
202 x = torch.relu(x)
205 x = torch.add(x, 1)
206 x = torch.cos(x)
209 x = torch.randn((2, 2), device="cuda")
211 cnts = torch._dynamo.testing.CompileCounter()
212 opt_fn = torch._dynamo.optimize(cnts)(fn)
219 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
222 x = torch.mul(x, 5)
223 x = torch.add(x, 2)
225 current_stream = torch.cuda.current_stream()
228 with torch.cuda.stream(s):
229 x = torch.relu(x)
232 with torch.cuda.stream(current_stream):
233 x = torch.relu(x)
235 s2 = torch.cuda.Stream()
237 with torch.cuda.stream(s2):
238 x = torch.relu(x)
241 x = torch.add(x, 1)
242 x = torch.cos(x)
245 x = torch.randn((2, 2), device="cuda")
246 s = torch.cuda.Stream()
248 cnts = torch._dynamo.testing.CompileCounter()
249 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
255 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
258 x = torch.mul(x, 1)
259 x = torch.add(x, 2)
261 new_stream = torch.cuda.Stream()
262 cur_stream = torch.cuda.current_stream()
265 with torch.cuda.stream(new_stream):
266 x = torch.sin(x)
267 x = torch.add(x, 3)
271 x = torch.add(x, 4)
275 with torch.cuda.stream(new_stream):
276 x = torch.add(x, 5)
281 x = torch.relu(x)
282 x = torch.cos(x)
285 x = torch.randn((2, 2), device="cuda")
287 cnts = torch._dynamo.testing.CompileCounter()
288 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
294 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
297 x = torch.mul(x, 1)
298 x = torch.add(x, 2)
300 cur_stream = torch.cuda.current_stream()
306 x = torch.mul(x, 1)
307 x = torch.add(x, 2)
309 cur_stream = torch.cuda.current_stream()
314 x = torch.randn((2, 2), device="cuda")
316 cnts = torch._dynamo.testing.CompileCounter()
317 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
318 opt_fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2)
324 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
332 s0 = torch.cuda.Stream()
333 s1 = torch.cuda.Stream()
334 x = torch.randn(2, 2)
335 cnts = torch._dynamo.testing.CompileCounter()
336 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
349 torch._dynamo.reset()
350 cnts = torch._dynamo.testing.CompileCounter()
351 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
364 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
367 e = torch.cuda.Event()
368 x = torch.mul(x, 5)
369 x = torch.add(x, 2)
372 x = torch.randn((2, 2), device="cuda")
374 cnts = torch._dynamo.testing.CompileCounter()
375 opt_fn = torch._dynamo.optimize(cnts)(fn)
381 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
384 e = torch.cuda.Event()
386 x = torch.mul(x, 5)
387 x = torch.add(x, 2)
391 torch.cuda.current_stream().wait_event(e)
392 x = torch.add(x, 1)
393 x = torch.cos(x)
396 x = torch.randn((2, 2), device="cuda")
398 cnts = torch._dynamo.testing.CompileCounter()
399 opt_fn = torch._dynamo.optimize(cnts)(fn)
405 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
407 user_stream = torch.cuda.Stream()
408 event = torch.cuda.Event()
409 foo = torch.empty((2, 2), device="cuda")
415 x = torch.randn((1024, 1024), device="cuda")
416 cnts = torch._dynamo.testing.CompileCounter()
420 fn = torch._dynamo.optimize(cnts)(fn)
422 with torch.cuda.stream(user_stream):
423 torch.mm(x, x, out=foo)
434 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
437 x = torch.mul(x, 1)
438 x = torch.add(x, 2)
440 x = torch.add(x, 3)
446 with torch.cuda.stream(new_stream):
447 x = torch.add(x, 4)
449 new_event = torch.cuda.Event()
453 x = torch.add(x, 5)
458 x = torch.relu(x)
459 x = torch.cos(x)
462 x = torch.randn((2, 2), device="cuda")
463 cur_stream = torch.cuda.current_stream()
464 new_stream = torch.cuda.Stream()
466 cnts = torch._dynamo.testing.CompileCounter()
467 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
473 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
476 x = torch.mul(x, 1)
477 x = torch.add(x, 2)
479 cur_stream = torch.cuda.current_stream()
480 new_stream = torch.cuda.Stream()
482 x = torch.add(x, 3)
488 with torch.cuda.stream(new_stream):
489 x = torch.add(x, 4)
491 new_event = torch.cuda.Event()
495 x = torch.add(x, 5)
500 x = torch.relu(x)
501 x = torch.cos(x)
504 x = torch.randn((2, 2), device="cuda")
506 cnts = torch._dynamo.testing.CompileCounter()
507 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
513 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
516 with torch.cuda.device(x.device.index - 1):
517 x = torch.sin(x + 1)
520 x = torch.randn((2, 2), device="cuda")
522 opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
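
The CUDA fragments above compile functions that create torch.cuda.Stream objects, switch streams with torch.cuda.stream(...), and record or wait on torch.cuda.Event objects. A minimal sketch of the stream pattern being tested; the wait_stream synchronization calls are an assumption here, since lines that do not contain the word torch are not shown in the listing.

import torch

def fn(x):
    s = torch.cuda.Stream()
    x = torch.mul(x, 5)
    x = torch.add(x, 2)

    current_stream = torch.cuda.current_stream()
    s.wait_stream(current_stream)  # assumed synchronization (see lead-in)
    with torch.cuda.stream(s):
        x = torch.relu(x)
    current_stream.wait_stream(s)  # assumed synchronization (see lead-in)

    x = torch.add(x, 1)
    x = torch.cos(x)
    return x

if torch.cuda.is_available():
    opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
    opt_fn(torch.randn(2, 2, device="cuda"))
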
528 if torch.autograd._profiler_enabled():
533 x = torch.randn((2, 2), requires_grad=True)
534 cnts = torch._dynamo.testing.CompileCounter()
535 opt_fn = torch._dynamo.optimize(cnts)(fn)
537 if torch.autograd._profiler_enabled():
538 torch.autograd._disable_profiler()
539 assert not torch.autograd._profiler_enabled()
544 with torch.autograd.profiler.profile():
545 assert torch.autograd._profiler_enabled()
550 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
552 if not torch.cuda.is_bf16_supported():
555 class MyModule(torch.nn.Module):
557 a_float32 = torch.rand((8, 8), device="cuda")
558 b_float32 = torch.rand((8, 8), device="cuda")
559 d_float32 = torch.rand((8, 8), device="cuda")
561 with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
562 e_float16 = torch.mm(a_float32, b_float32)
563 f_float16 = torch.mm(d_float32, e_float16)
567 real = module(torch.tensor([0.5]))
571 graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
572 exported = graph(torch.tensor([0.5]))
578 self.assertEqual(exported.dtype, torch.bfloat16)
580 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
582 class MyModule(torch.nn.Module):
584 a_float32 = torch.rand((8, 8), device="cuda")
585 b_float32 = torch.rand((8, 8), device="cuda")
587 with torch.cuda.amp.autocast(dtype=torch.float64):
588 c_float64 = torch.mm(a_float32, b_float32)
592 real = module(torch.tensor([0.5]))
596 graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
597 exported = graph(torch.tensor([0.5]))
603 self.assertEqual(exported.dtype, torch.float64)
607 with torch.cpu.amp.autocast(dtype=torch.bfloat16):
608 c_float16 = torch.mm(a_float32, b_float32)
609 if torch.is_autocast_cpu_enabled():
613 a = torch.rand((8, 8))
614 b = torch.rand((8, 8))
616 opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
625 class MyModule(torch.nn.Module):
627 with torch.autocast("cpu"):
628 with torch.autocast("cuda", dtype=torch.float32):
634 dtype = torch.float32
638 query = torch.ones(
641 key = torch.ones(
644 value = torch.ones(
653 opt_mod = torch._dynamo.optimize("inductor")(module)
661 self.assertEqual(compiled.dtype, torch.float32)
664 class MyModule(torch.nn.Module):
666 a_float32 = torch.rand((8, 8), device="cpu")
667 b_float32 = torch.rand((8, 8), device="cpu")
668 d_float32 = torch.rand((8, 8), device="cpu")
670 with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
671 e_float16 = torch.mm(a_float32, b_float32)
672 f_float16 = torch.mm(d_float32, e_float16)
676 real = module(torch.tensor([0.5]))
680 graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
681 exported = graph(torch.tensor([0.5]))
686 self.assertEqual(exported.dtype, torch.bfloat16)
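
The autocast fragments above export or compile modules that run torch.mm inside torch.autocast regions and then check the dtype of the result. A small standalone version of the CPU case; it uses torch.compile with the eager backend instead of torch._dynamo.export, which is a substitution made for this sketch.

import torch

def fn(a, b):
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        e = torch.mm(a, b)
        f = torch.mm(e, b)
    return f

a = torch.rand(8, 8)
b = torch.rand(8, 8)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
out = opt_fn(a, b)
assert out.dtype == torch.bfloat16  # mm runs in bfloat16 under CPU autocast
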
689 class MyModule(torch.nn.Module):
691 a_float32 = torch.rand((8, 8), device="cpu")
692 b_float32 = torch.rand((8, 8), device="cpu")
693 torch._dynamo.graph_break()
694 d_float32 = torch.rand((8, 8), device="cpu")
696 with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
697 e_float16 = torch.mm(a_float32, b_float32)
698 torch._dynamo.graph_break()
699 f_float16 = torch.mm(d_float32, e_float16)
703 real = module(torch.tensor([0.5]))
707 opt = torch._dynamo.optimize("eager")(module)
708 res = opt(torch.tensor([0.5]))
713 self.assertEqual(res.dtype, torch.bfloat16)
718 with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
719 x = torch.mm(x, x)
720 torch._dynamo.graph_break()
721 x = torch.relu(x)
724 x = torch.rand([4, 4])
725 self.assertEqual(x.dtype, torch.float32)
727 opt_fn = torch._dynamo.optimize("eager")(fn)
729 self.assertTrue(torch.allclose(res, opt_res))
730 self.assertEqual(res.dtype, torch.bfloat16)
731 self.assertEqual(opt_res.dtype, torch.bfloat16)
734 class MyModule(torch.nn.Module):
737 torch._dynamo.graph_break()
738 return torch.mm(x, y)
741 a_float32 = torch.rand((8, 8), device="cpu")
742 b_float32 = torch.rand((8, 8), device="cpu")
744 with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
745 torch._dynamo.graph_break()
746 with torch.autocast(
747 device_type="cpu", dtype=torch.bfloat16, enabled=False
749 torch._dynamo.graph_break()
750 g_float32 = torch.mm(a_float32, b_float32)
751 with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
753 torch._dynamo.graph_break()
762 real_16, real_32 = module(torch.tensor([0.5]))
768 graph = torch._dynamo.optimize("eager")(module)
769 out_16, out_32 = graph(torch.tensor([0.5]))
776 self.assertEqual(out_16.dtype, torch.bfloat16)
778 self.assertEqual(out_32.dtype, torch.float32)
781 class MyModule(torch.nn.Module):
787 return torch.mm(x, y) + self.bias
790 torch._dynamo.graph_break()
791 return torch.mm(x, y) + self.bias
794 a_float32 = torch.rand((8, 8), device="cpu")
795 b_float32 = torch.rand((8, 8), device="cpu")
797 with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
798 with torch.autocast(
799 device_type="cpu", dtype=torch.bfloat16, enabled=False
801 g_float32 = torch.mm(a_float32, b_float32)
809 module = MyModule(bias=torch.rand((8, 8), device="cpu", dtype=torch.bfloat16))
811 with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
813 res = torch.rand((8, 8), device="cpu", dtype=torch.float32) + torch.rand(
814 (8, 8), device="cpu", dtype=torch.bfloat16
816 self.assertEqual(res.dtype, torch.float32)
818 real_16, real_32 = module(torch.tensor([0.5]))
824 graph = torch._dynamo.optimize("eager")(module)
825 out_16, out_32 = graph(torch.tensor([0.5]))
832 self.assertEqual(out_16.dtype, torch.bfloat16)
834 self.assertEqual(out_32.dtype, torch.float32)
836 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
838 class MyModule(torch.nn.Module):
840 a_float32 = torch.rand((8, 8), device="cuda")
841 b_float32 = torch.rand((8, 8), device="cuda")
842 d_float32 = torch.rand((8, 8), device="cuda")
844 with torch.autocast(device_type="cuda", dtype=torch.float64):
845 e_float64 = torch.mm(a_float32, b_float32)
846 f_float64 = torch.mm(d_float32, e_float64)
850 real = module(torch.tensor([0.5]))
854 graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
855 exported = graph(torch.tensor([0.5]))
860 self.assertEqual(exported.dtype, torch.float64)
862 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
864 class MyModule(torch.nn.Module):
866 a_float32 = torch.rand((8, 8), device="cuda")
867 b_float32 = torch.rand((8, 8), device="cuda")
868 d_float32 = torch.rand((8, 8), device="cuda")
870 with torch.autocast("cuda"):
871 e_float64 = torch.mm(a_float32, b_float32)
872 f_float64 = torch.mm(d_float32, e_float64)
876 real = module(torch.tensor([0.5]))
880 graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
881 exported = graph(torch.tensor([0.5]))
886 self.assertEqual(exported.dtype, torch.float16)
888 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
891 with torch.cuda.amp.autocast(False):
892 x = torch.sin(x + 1)
896 with torch.cpu.amp.autocast(False):
897 x = torch.cos(x + 1)
900 x = torch.rand([2, 3])
903 opt_f1 = torch.compile(backend="eager")(f1)
904 opt_f2 = torch.compile(backend="eager")(f2)
910 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
913 @torch.amp.autocast(device_type="cuda", dtype=torch.float16)
920 @torch.cuda.amp.autocast(dtype=torch.float16)
927 @torch.cpu.amp.autocast(dtype=torch.float16)
934 return torch.mm(a, b)
943 a_float32 = torch.rand((8, 8), device="cuda")
944 b_float32 = torch.rand((8, 8), device="cuda")
947 opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
950 self.assertTrue(res[0].dtype == torch.float16)
951 self.assertTrue(res[1].dtype == torch.float16)
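
The fragment above applies autocast as a decorator (torch.amp.autocast, torch.cuda.amp.autocast, torch.cpu.amp.autocast) and asserts that the compiled results come out in float16. A hedged CPU-only variant of the decorator form; bfloat16 is used here so the sketch runs without a GPU, and that dtype choice is not from the original test.

import torch

@torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16)
def autocast_mm(a, b):
    # The matmul runs under CPU autocast because of the decorator.
    return torch.mm(a, b)

def fn(a, b):
    return autocast_mm(a, b)

a = torch.rand(8, 8)
b = torch.rand(8, 8)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
assert opt_fn(a, b).dtype == torch.bfloat16
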
961 x = torch.randn(2, 3)
962 opt_fn = torch.compile(backend="eager", fullgraph=False)(fn)
966 @torch.compile(backend="eager", fullgraph=True)
973 x = torch.randn(2, 3)
978 @torch.compile(backend="eager", fullgraph=False)
981 torch._dynamo.graph_break()
986 x = torch.randn(2, 3)
994 if torch.is_grad_enabled():
996 x = torch.relu(x)
999 x = torch.rand(2, 3)
1000 cnts = torch._dynamo.testing.CompileCounter()
1001 opt_fn = torch.compile(backend=cnts, fullgraph=True)(fn)
1003 with torch.no_grad():
1010 with torch.enable_grad():
1021 if torch.is_grad_enabled():
1024 if torch.is_grad_enabled():
1027 x = torch.relu(x)
1030 x = torch.rand(2, 3)
1031 cnts = torch._dynamo.testing.CompileCounter()
1032 opt_fn = torch.compile(backend=cnts, fullgraph=True)(fn)
1034 with torch.no_grad():
1041 with torch.enable_grad():
1052 if torch.is_grad_enabled():
1054 torch._dynamo.graph_break()
1055 x = torch.relu(x)
1058 x = torch.rand(2, 3)
1059 cnts = torch._dynamo.testing.CompileCounter()
1060 opt_fn = torch.compile(backend=cnts, fullgraph=False)(fn)
1062 with torch.no_grad():
1069 with torch.enable_grad():
1080 if torch.is_grad_enabled():
1083 if torch.is_grad_enabled():
1085 torch._dynamo.graph_break()
1087 x = torch.relu(x)
1090 x = torch.rand(2, 3)
1091 cnts = torch._dynamo.testing.CompileCounter()
1092 opt_fn = torch.compile(backend=cnts, fullgraph=False)(fn)
1094 with torch.no_grad():
1101 torch._dynamo.reset()
1102 cnts = torch._dynamo.testing.CompileCounter()
1103 opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn)
1105 with torch.enable_grad():
1114 with torch.no_grad():
1115 torch._dynamo.graph_break()
1116 return torch.sin(z)
1119 a = torch.mm(x, y)
1123 torch._dynamo.reset()
1124 cnts = torch._dynamo.testing.CompileCounter()
1125 opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn)
1126 x = torch.randn(4, 4, requires_grad=True)
1127 y = torch.randn(4, 4, requires_grad=True)
1128 z = torch.randn(4)
1135 with torch.autocast(device_type=device, dtype=torch.bfloat16):
1136 z = torch.mm(x, y)
1137 torch._dynamo.graph_break()
1138 return torch.sin(z)
1141 z = torch.mm(x, y)
1145 x = torch.rand(3, 3).to(device)
1146 y = torch.rand(3, 3).to(device)
1147 opt_fn = torch.compile(backend="eager")(fn)
1155 torch.cuda.is_available() and torch.cuda.is_bf16_supported()
1162 @torch.autograd.graph.disable_saved_tensors_hooks("This is not supported")
1166 x, y = torch.ones(
1168 ), torch.zeros(
1174 torch.compile(fn, backend=eager, fullgraph=True)(torch.randn(()))
1182 class GraphModule(torch.nn.Module):
1184 …_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not suppor…
1186 x: "f32[1]" = torch.ones(1)
1188 y: "f32[1]" = torch.zeros(1)
1192 …_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_ho…
1199 @torch.autograd.graph.disable_saved_tensors_hooks("This is not supported")
1203 x, y = torch.ones(
1205 ), torch.zeros(
1211 with torch.autograd.graph.disable_saved_tensors_hooks(
1214 torch.compile(fn, backend=eager, fullgraph=True)(torch.randn(()))
1222 class GraphModule(torch.nn.Module):
1224 …_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not suppor…
1226 x: "f32[1]" = torch.ones(1)
1228 y: "f32[1]" = torch.zeros(1)
1232 …_saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable('Previously disab…
1239 @torch.autograd.graph.disable_saved_tensors_hooks("This is not supported")
1241 @torch.autograd.graph.disable_saved_tensors_hooks(
1249 x, y = torch.ones(
1251 ), torch.zeros(
1257 with torch.autograd.graph.disable_saved_tensors_hooks(
1260 torch.compile(fn, backend=eager, fullgraph=True)(torch.randn(()))
1268 class GraphModule(torch.nn.Module):
1270 …_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not suppor…
1272 x: "f32[1]" = torch.ones(1)
1274 y: "f32[1]" = torch.zeros(1)
1276 …_saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable('This is not supp…
1280 …_saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable('This is not supp…
1284 …_saved_tensors_hooks_disable_3 = torch._C._autograd._saved_tensors_hooks_disable('Previously disab…
1291 with torch.autograd.graph.disable_saved_tensors_hooks(
1295 torch._dynamo.graph_break()
1299 torch.compile(fn, backend=eager, fullgraph=False)(torch.randn(()))
1309 class GraphModule(torch.nn.Module):
1313 …_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not suppor…
1317 …_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_ho…
1327 class GraphModule(torch.nn.Module):
1331 …_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not suppor…
1335 …_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_ho…
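
The blocks above compile functions that run under torch.autograd.graph.disable_saved_tensors_hooks, both as a decorator and as a nested context manager, and compare the captured graph against _saved_tensors_hooks_disable / _saved_tensors_hooks_enable calls. A small sketch of the context-manager form under torch.compile; the error-message string below is illustrative.

import torch

def fn(x):
    with torch.autograd.graph.disable_saved_tensors_hooks(
        "saved-tensor hooks are not allowed here"
    ):
        y = torch.sin(x)
    return torch.cos(y)

opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
opt_fn(torch.randn(3, requires_grad=True))
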
1341 ctx_wrappers = [(torch.enable_grad, True), (torch.no_grad, False)]
1344 torch._dynamo.reset()
1360 assert torch.is_grad_enabled() == mode_inverse
1365 x = torch.zeros(10, requires_grad=True)
1366 opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1371 ctx_wrappers = [(torch.enable_grad, True), (torch.no_grad, False)]
1375 torch._dynamo.reset()
1395 assert torch.is_grad_enabled() == mode_inverse
1400 x = torch.zeros(10, requires_grad=True)
1401 opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1409 torch._dynamo.reset()
1415 with torch.set_grad_enabled(mode_inverse):
1418 @torch.set_grad_enabled(mode)
1427 inner_func = torch.set_grad_enabled(mode)(inner_func)
1431 assert torch.is_grad_enabled() == mode_inverse
1433 with torch.set_grad_enabled(mode_inverse):
1436 x = torch.zeros(10, requires_grad=True)
1437 opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1444 ctx = torch.set_grad_enabled(True)
1445 torch._dynamo.graph_break()
1450 x = torch.zeros(10, requires_grad=False)
1451 cnts = torch._dynamo.testing.CompileCounter()
1452 opt_fn = torch.compile(fn, backend=cnts)
1464 torch._dynamo.graph_break()
1469 x = torch.zeros(10, requires_grad=False)
1470 cnts = torch._dynamo.testing.CompileCounter()
1471 opt_fn = torch.compile(fn, backend=cnts)
1482 torch._dynamo.graph_break()
1493 x = torch.zeros(10, requires_grad=False)
1494 cnts = torch._dynamo.testing.CompileCounter()
1495 opt_fn = torch.compile(fn, backend=cnts)
1502 torch._dynamo.graph_break()
1507 ctx = gn(torch.set_grad_enabled(True))
1513 x = torch.zeros(10, requires_grad=False)
1514 cnts = torch._dynamo.testing.CompileCounter()
1515 opt_fn = torch.compile(fn, backend=cnts)
1526 x = gn(x, torch.set_grad_enabled(True), 2, 3, torch._dynamo.graph_break())
1529 x = torch.zeros(10, requires_grad=False)
1530 cnts = torch._dynamo.testing.CompileCounter()
1531 opt_fn = torch.compile(fn, backend=cnts)
1538 from torch._dynamo.test_case import run_tests