# Owner(s): ["NNC"] import contextlib import math import operator import os import unittest import warnings from typing import List import torch import torch.nn.functional as F from torch.testing import FileCheck # these needs to be set before `common_utils` # infers `GRAPH_EXECUTOR`. # this file **requires** these settings # and setting them after `GRAPH_EXECUTOR` is # inferred erroneously runs or skips # some tests torch._C._jit_set_profiling_executor(True) torch._C._get_graph_executor_optimize(True) from itertools import combinations, permutations, product from textwrap import dedent from jit.test_fuser_common import TestFuserCommon # noqa: F401 from test_jit import ( backward_graph, get_lstm_inputs, get_milstm_inputs, LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell, ) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCPU, OpDTypes, ops, ) from torch.testing._internal.common_jit import JitCommonTestCase from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_utils import ( enable_profiling_mode_for_profiling_tests, GRAPH_EXECUTOR, IS_FBCODE, ProfilingMode, run_tests, skipIfTorchDynamo, slowTest, TEST_WITH_ASAN, TEST_WITH_ROCM, ) from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn from torch.testing._internal.jit_utils import ( clone_inputs, get_traced_sample_variant_pairs, JitTestCase, NoTracerWarnContextManager, RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, set_fusion_group_inlining, TensorExprTestOptions, warmup_backward, ) FUSION_GROUP = "prim::TensorExprGroup" LLVM_ENABLED = torch._C._llvm_enabled() autograd_check_set = { "aten::__is__", "prim::AutogradAllNonZero", "prim::AutogradAllZero", "prim::ListConstruct", } def strip_profiling_nodes(nodes): profiling_opcodes = {"prim::BailoutTemplate", "prim::BailOut"} return [n for n in nodes if n.kind() not in profiling_opcodes] def warmup_forward(f, *args, profiling_count=2): for i in range(profiling_count): results = f(*args) return results @contextlib.contextmanager def texpr_reductions_enabled(): old = torch._C._jit_set_texpr_reductions_enabled(True) try: yield finally: torch._C._jit_set_texpr_reductions_enabled(old) @contextlib.contextmanager def texpr_enable_strategy(strategy): old = torch._C._jit_set_fusion_strategy(strategy) try: yield finally: torch._C._jit_set_fusion_strategy(old) @contextlib.contextmanager def inline_fusion_groups(): old_inlining = torch._C._debug_get_fusion_group_inlining() torch._C._debug_set_fusion_group_inlining(True) try: yield finally: torch._C._debug_set_fusion_group_inlining(old_inlining) class TestTEFuser(JitTestCase): def setUp(self): super().setUp() self.tensorexpr_options = TensorExprTestOptions() # note: `self.dynamic_shapes` instatiated in specialization of class # defined below fusion_strategy = [("DYNAMIC", 20)] if self.dynamic_shapes else [("STATIC", 20)] self.old_fusion_strategy = torch._C._jit_set_fusion_strategy(fusion_strategy) self.devices = ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"] self.int_dtypes = [ torch.int8, torch.int16, torch.int32, torch.int64, torch.bool, ] self.fp_dtypes = [ torch.float16, torch.float32, torch.float64, torch.bfloat16, ] self.dtypes = self.int_dtypes + self.fp_dtypes def tearDown(self): self.tensorexpr_options.restore() torch._C._jit_set_fusion_strategy(self.old_fusion_strategy) super().tearDown() def assertAllFused(self, graph, except_for=None): except_for = except_for if except_for is not None else set() # TODO - upstream guards = ( 
"prim::TypeCheck", "prim::RequiresGradCheck", "prim::TensorExprDynamicGuard", ) guard_found = False def autodiff_guard(node): if node.kind() != "aten::all": return False inps = list(node.inputs()) if len(inps) != 1 or inps[0].node().kind() != "prim::ListConstruct": return False li_inps = list(inps[0].node().inputs()) for li_inp in li_inps: if li_inp.node().kind() in ( "prim::AutogradAllNonZero", "prim::AutogradAllZero", ): return True return False def is_guard(node): return node.kind() in guards or autodiff_guard(node) for node in graph.block().nodes(): if node.kind() == "prim::Constant": continue if is_guard(node): self.assertFalse(guard_found) guard_found = True continue if node.kind() in except_for: continue if node.kind() == "prim::If": self.assertTrue(is_guard(node.prev())) continue self.assertTrue(False, "Found unexpected node:" + node.kind()) self.assertTrue(guard_found) def assertLastGraphAllFused(self): self.assertAllFused(torch.jit.last_executed_optimized_graph()) def findFusionGroups(self, graph): result = [] for n in graph.nodes(): if n.kind() == FUSION_GROUP: result.append(n.g("Subgraph")) continue for block in n.blocks(): result += self.findFusionGroups(block) return result def test_typecheck(self): a = torch.ones(1) def fused_kernel(a, b): return (a + b) * 2.0 scripted = self.checkScript(fused_kernel, (a, a)) graph = scripted.graph_for(a, a) # double check we fused fusion_groups = self.findFusionGroups(graph) self.assertEqual(len(fusion_groups), 1) # we use a bigger tensor now (size 2) # if we won't trigger a recompilation # we will still create a tensor up to (size 1) # if the type check fails a = torch.ones(2) # shape changed if we don't trigger recompilation # we would compute the wrong result silently self.assertEqual(scripted(a, a), fused_kernel(a, a)) def test_sum_simple(self): def func(x): x2 = x * x return x2.sum() with texpr_reductions_enabled(): a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu") a = a.reshape(5, 3) scripted = self.checkScript(func, (a,)) self.assertLastGraphAllFused() def test_nop(self): pass def test_sum_dim(self): def func(x): return x.sum((0,)) * 2 def func_neg(x): return x.sum((-2,)) * 2 with texpr_reductions_enabled(): a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu") a = a.reshape(5, 3) scripted = self.checkScript(func, (a,)) self.assertLastGraphAllFused() scripted = self.checkScript(func_neg, (a,)) self.assertLastGraphAllFused() def test_sum_keepdim_cast(self): def func(x): return x.sum((0,), keepdim=True, dtype=torch.double) * 2 with texpr_reductions_enabled(): a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu") a = a.reshape(5, 3) self.checkScript(func, (a,)) self.assertLastGraphAllFused() def test_abs(self): for device in self.devices: def func(x): return x.abs() * 2 a = torch.randn(5, device=device) scripted = self.checkScript(func, (a,)) self.assertLastGraphAllFused() def test_unsqueeze_size_calculation(self): for device in self.devices: def foo(b, d): x = d.unsqueeze(1) y = x * 42.0 z = b + y r = z / 42.0 return r inputs = ( torch.rand(20, 28, device=device, requires_grad=True), torch.rand(20, device=device), ) scripted = self.checkScript(foo, inputs) self.assertAllFused(scripted.graph_for(*inputs)) def test_zero_element_tensors(self): for device in self.devices: def decode(sin_t, cos_t): theta = torch.atan2(sin_t.float(), cos_t.float()) return theta sin = torch.zeros(0, device=device) cos = torch.zeros(0, device=device) inputs = [sin, cos] ge = self.checkScript(decode, inputs) 
def test_arg_configurations_smoke(self): if self.dynamic_shapes: self.skipTest("TODO: chunk dynamic shapes") # A smoke test to make sure we won't use the same kernel for contiguous # and non-contiguous arguments. # TODO: add optionally enabled debug counters to the fuser to verify # that we really can tell the difference between configurations for device in self.devices: def f(x, y): z1, z2 = (x + y).chunk(2, dim=1) return z1 * z2 x = torch.randn(4, 4, dtype=torch.float, device=device) y = torch.randn(4, 4, dtype=torch.float, device=device) traced_f = torch.jit.trace(f, (x, y)) self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y)) def test_broadcast(self): for device in self.devices: def scaleshift(x, scale, shift): return x * scale + shift inputs = [ torch.randn(4, 4, dtype=torch.float, device=device), torch.randn(4, dtype=torch.float, device=device), torch.randn(4, dtype=torch.float, device=device), ] self.checkScript(scaleshift, inputs) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(not RUN_CUDA_HALF, "no half support") @unittest.skipIf( GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on" ) def test_cuda_half(self): x = torch.randn(4, 4, dtype=torch.half, device="cuda") y = torch.randn(4, 4, dtype=torch.half, device="cuda") funcs = [self.fn_test_comparison_gt_lt, self.fn_test_relu, self.fn_test_exp] # Note: Non fused inputs must be float to prevent loss of precision inputs = (x.float(), y.float()) fusion_inputs = (x, y) for fn in funcs: local_inputs = [t.clone().requires_grad_() for t in inputs] local_fusion_inputs = [t.clone().requires_grad_() for t in fusion_inputs] # Verifies outputs fusion = torch.jit.trace(fn, local_fusion_inputs, check_trace=False) outputs = fn(*local_inputs) fusion_outputs = fusion(*local_fusion_inputs) outputs_half = [t.half() for t in outputs] self.assertEqual(outputs_half, fusion_outputs) # Verifies gradients for output, fusion_output in zip(outputs_half, fusion_outputs): grads = torch.autograd.grad( output.float().sum(), local_inputs, allow_unused=True, retain_graph=True, ) fusion_grads = torch.autograd.grad( fusion_output.sum(), local_fusion_inputs, allow_unused=True, retain_graph=True, ) grads_half = [t.half() for t in grads] self.assertEqual(grads_half, fusion_grads) def test_checks_cat_inputs(self): # single fusion node causes error with set_fusion_group_inlining(True): for device in self.devices: # We shouldn't treat cat nodes as broadcasting. All their inputs # need to be checked for having the same map size, before we can # run the kernel. def f(x, y): return torch.cat([x + 2 * x + x**2, y + 4 * y + y**3], dim=0) # NOTE: y is broadcastable to x, but output of f(x, y) should have # shape 3x4, and not 4x4. 
                x = torch.randn(2, 4, dtype=torch.float, device=device)
                y = torch.randn(1, 4, dtype=torch.float, device=device)
                scripted = self.checkScript(f, (x, y))
                self.assertEqual(scripted(x, y).shape, (3, 4))
                self.assertAllFused(scripted.graph_for(x, y))

    def test_chunk(self):
        if self.dynamic_shapes:
            self.skipTest("TODO: chunk dynamic shapes")
        for device in self.devices:

            def fn(x):
                a, b, c = x.chunk(3, 1)
                return a * b + c

            inputs = [torch.randn(10, 6, dtype=torch.float, device=device)]
            self.checkScript(fn, inputs)
            self.assertLastGraphAllFused()

    def test_chunk_correctness(self):
        if self.dynamic_shapes:
            self.skipTest("TODO: chunk dynamic shapes")
        for device in self.devices:

            def chunk_4_0(x):
                x0, x1, x2, x3 = x.chunk(4, 0)
                return x0 + x1 + x2 + x3

            def chunk_4_1(x):
                x0, x1, x2, x3 = x.chunk(4, 1)
                return x0 + x1 + x2 + x3

            def chunk_4_last(x):
                x0, x1, x2, x3 = x.chunk(4, 2)
                return x0 + x1 + x2 + x3

            fns = [chunk_4_0, chunk_4_1, chunk_4_last]
            tensors = [
                # splitSize = 1
                torch.randn(4, 4, 4, dtype=torch.float, device=device),
                # contiguous case
                torch.randn(12, 8, 16, dtype=torch.float, device=device),
                # non-contiguous case
                torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(
                    1, 2
                ),
            ]

            for tensor in tensors:
                for fn in fns:
                    self.checkScript(fn, [tensor])
                    self.assertLastGraphAllFused()

    def test_chunk_distributes(self):
        if self.dynamic_shapes:
            self.skipTest("TODO: chunk dynamic shapes")
        for device in self.devices:

            def f(x, y):
                z1, z2 = (x + y).chunk(2, dim=1)
                return z1 * z2

            x = torch.randn(4, 4, dtype=torch.float, device=device)
            y = torch.randn(4, 4, dtype=torch.float, device=device)
            ge = self.checkTrace(f, (x, y))
            graph = ge.graph_for(x, y)
            # XXX: The old fuser does broadcast_tensors but the new fuser doesn't.
# FileCheck().check("broadcast_tensors").check('with ' + FUSION_GROUP + '_') \ # .check_count('ConstantChunk', 2, exactly=True).run(str(graph)) FileCheck().check("with " + FUSION_GROUP + "_").check_count( "ConstantChunk", 1, exactly=True ).run(str(graph)) def test_chunk_motion_deduplicates_inputs(self): if self.dynamic_shapes: self.skipTest("TODO: chunk dynamic shapes") for device in self.devices: def func1(x): z = x * x z0, z1 = z.chunk(2) return z0 * z1 def func2(x): z = x * x * x z0, z1 = z.chunk(2) return z0 * z1 inputs = [torch.tensor([1.1, 1.2], device=device, dtype=torch.float)] for func in [func1, func2]: self.checkScript(func, inputs) self.assertLastGraphAllFused() def test_chunk_multiple(self): if self.dynamic_shapes: self.skipTest("TODO: chunk dynamic shapes") for device in self.devices: # The arguments are intentionally used out of order as a test to see # if the fusion compiler adds extra args in the correct order def fn(s, x, y, z): z1, z2 = z.chunk(2, 2) x1, x2, x3 = x.chunk(3, 1) y1, y2 = y.chunk(2, 0) return s + x1 + x2 + x3 + y1 + y2 + z1 + z2 inputs = [ torch.randn(5, 2, 3, dtype=torch.float, device=device), torch.randn(5, 6, 3, dtype=torch.float, device=device), torch.randn(10, 2, 3, dtype=torch.float, device=device), torch.randn(5, 2, 6, dtype=torch.float, device=device), ] ge = self.checkScript(fn, inputs) self.assertAllFused(ge.graph_for(*inputs)) def test_minmax(self): for device in self.devices: def tmax(a, b): return torch.max(2 * a, b) def tmin(a, b): return torch.min(2 * a, b) a = torch.randn(4, 4, dtype=torch.float) b = torch.randn(4, 4, dtype=torch.float) nan = torch.tensor(float("nan"), dtype=torch.float) for f, inputs, device in product( (tmax, tmin), ([a, b], [a, nan], [b, nan]), self.devices ): inputs = [t.to(device) for t in inputs] s = self.checkScript(f, inputs) self.assertAllFused(s.graph_for(*inputs)) def test_clamp(self): for device in self.devices: def func2(a, b): return torch.clamp(a + b, min=0, max=2) def funcInf(a, b): return torch.clamp(a + b, min=0, max=float("inf")) def funcNegInf(a, b): return torch.clamp(a + b, min=float("-inf"), max=0) def funcOptMin(a, b): return torch.clamp(a + b, max=2) def funcOptMax(a, b): return torch.clamp(a + b, min=0) a = torch.randn(4, 4, dtype=torch.float, device=device, requires_grad=True) b = torch.randn(4, 4, dtype=torch.float, device=device) nan = torch.tensor(float("nan"), dtype=torch.float, device=device) funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax) for f, inputs in product(funcs, [[a, b], [a, nan]]): inp1, inp2 = inputs s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING) self.assertAllFused( s.graph_for(inp1, inp2), except_for={"aten::size", "aten::_size_if_not_equal"}, ) c = s(inp1, inp2) with enable_profiling_mode_for_profiling_tests(): warmup_backward(c.sum()) graph = backward_graph(s) self.assertAllFused( graph, except_for={"aten::Float", "aten::_grad_sum_to_size"}.union( autograd_check_set ), ) def test_clamp_double(self): for device in self.devices: def clamp_double(x, eta: float): return 1 - x.clamp(eta, 1 - eta) x = torch.tensor([1.0, 1.0], dtype=torch.double, device=device) eta = 1e-9 s = self.checkScript( clamp_double, (x, eta), profiling=ProfilingMode.PROFILING, atol=1e-10, rtol=1e-5, ) self.assertAllFused(s.graph_for(x, eta), except_for={"aten::sub"}) def test_clamp_int(self): for device in self.devices: def clamp_int(x, eta: int): return x.clamp(0, eta) x = torch.tensor([1, 1], device=device) eta = 1 << 32 s = self.checkScript(clamp_int, (x, eta), 
profiling=ProfilingMode.PROFILING) self.assertAllFused(s.graph_for(x, eta)) def test_add_bool(self): sizes = [(1,), (2,), (4, 4)] for device, size in product(self.devices, sizes): def f(x, y, z): return x + y + z x = torch.randint(0, 2, size, dtype=torch.bool, device=device) y = torch.randint(0, 2, size, dtype=torch.bool, device=device) z = torch.randint(0, 2, size, dtype=torch.bool, device=device) ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) self.assertAllFused(ge.graph_for(x, y, z)) def test_mul_bool(self): for device in self.devices: def f(x, y, z): return x * y * z x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) self.assertAllFused(ge.graph_for(x, y, z)) def test_div_bool(self): for device in self.devices: def f(x, y, z): return (x + y) / z x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) z = torch.ones_like(x, dtype=torch.bool, device=device) ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) self.assertAllFused(ge.graph_for(x, y, z)) def test_bitwise_ops(self): def apply(fn): return lambda x, y, z: fn(fn(x, y), z) binary_ops = [ operator.__and__, operator.__or__, operator.__xor__, operator.__lshift__, operator.__rshift__, ] devices = self.devices for dtype, op, device in product(self.int_dtypes, binary_ops, devices): try: x = self.data_for(dtype, device) y = self.data_for(dtype, device) z = self.data_for(dtype, device) fn = apply(op) ref = fn(x, y, z) except Exception: # If eager mode doesn't support a dtype/op/device combo, # neither does the fuser. Catch everything to avoid needing to # guess what errors might be thrown by eager. continue try: t = torch.jit.trace(fn, (x, y, z)) self.assertEqual(ref, t(x, y, z)) self.assertAllFused(t.graph_for(x, y, z)) except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) ) from e def test_minmax_int_ops(self): def apply(fn): return lambda x, y, z: fn(fn(x, y), z) binary_ops = [torch.min, torch.max] devices = self.devices for dtype, op, device in product(self.int_dtypes, binary_ops, devices): try: x = self.data_for(dtype, device) y = self.data_for(dtype, device) z = self.data_for(dtype, device) fn = apply(op) ref = fn(x, y, z) except Exception: # If eager mode doesn't support a dtype/op/device combo, # neither does the fuser. Catch everything to avoid needing to # guess what errors might be thrown by eager. 
continue try: t = torch.jit.trace(fn, (x, y, z)) self.assertEqual(ref, t(x, y, z)) self.assertAllFused(t.graph_for(x, y, z)) except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) ) from e def test_comparison_eq_ne(self): for device in self.devices: def f(x, y): mask = (x == 0).type_as(x) z = x * mask + y mask = (x != 0).type_as(x) z = z * mask + y return z x = torch.randn(4, 4, dtype=torch.float, device=device) y = torch.randn(4, 4, dtype=torch.float, device=device) ge = self.checkTrace(f, (x, y)) self.assertAllFused(ge.graph_for(x, y)) @staticmethod def fn_test_comparison_gt_lt(x, y): mask = (x > 0).type_as(x) z = x * mask + y mask = (x < 0).type_as(x) z = z * mask + y return z def test_comparison_gt_lt(self): for device in self.devices: x = torch.randn(4, 4, dtype=torch.float, device=device) y = torch.randn(4, 4, dtype=torch.float, device=device) ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y)) self.assertAllFused(ge.graph_for(x, y)) def test_comparison_ge_le(self): for device in self.devices: def f(x, y): mask = (x >= 0).type_as(x) z = x * mask + y mask = (x <= 0).type_as(x) z = z * mask + y return z x = torch.randn(4, 4, dtype=torch.float, device=device) y = torch.randn(4, 4, dtype=torch.float, device=device) ge = self.checkTrace(f, (x, y)) self.assertAllFused(ge.graph_for(x, y)) x.requires_grad_(True) y.requires_grad_(True) self.assertAllFused( ge.graph_for(x, y), except_for=( "aten::size", "prim::BroadcastSizes", "aten::_size_if_not_equal", ), ) def test_addcmul(self): for device in self.devices: t = torch.randn(1, 4, dtype=torch.float, device=device) t1 = torch.randn(4, 1, dtype=torch.float, device=device) t2 = torch.randn(1, 4, dtype=torch.float, device=device) def foo(t, t1, t2): return t.addcmul(t + 1, t2, value=0.1) ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True) graph = ge.graph_for(t, t1, t2) fusion_groups = self.findFusionGroups(graph) self.assertEqual(len(fusion_groups), 1) FileCheck().check("aten::add(").check("aten::addcmul(").run( str(fusion_groups[0]) ) # TODO: We leak CUDA memory here because the traced graph holds onto a # constant-ified tensor. Since the Python-global CompilationUnit is alive # until the end of the process, the memory is effectively leaked. # Removed `_cuda` suffix from this test which disables leak-checking. # If this is a real problem, we'll need to revisit Torchscript Function # lifetimes in Python. 
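    # Illustrative sketch (ours, not part of the upstream suite; the helper name
    # `_sketch_traced_constant_capture` is made up and is never collected by
    # unittest). The retention described above comes from tracing a function that
    # closes over a tensor: the tracer bakes that tensor into the graph as a
    # prim::Constant, and the graph (and hence the tensor) stays alive for as long
    # as the Python-global CompilationUnit does.
    def _sketch_traced_constant_capture(self):
        weight = torch.randn(4, 4)

        def foo(x):
            # `weight` is not a traced input, so tracing records it as a constant.
            return x * weight

        traced = torch.jit.trace(foo, (torch.randn(4, 4),))
        # The captured tensor shows up as a constant node in the traced graph.
        FileCheck().check("prim::Constant").run(str(traced.graph))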
def test_lerp(self): for device in self.devices: start = torch.randn(4, 1, dtype=torch.float, device=device) end = torch.randn(1, 4, dtype=torch.float, device=device) weight = torch.tensor(0.5, dtype=torch.float, device=device) # scalar weight overload def foo_weight_scalar(start, end): return torch.lerp(start + 1, end, 0.5) # tensor weight overload def foo_weight_tensor(start, end): return torch.lerp(start + 1, end, weight) ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end)) graph = ge_weight_scalar.graph_for(start, end) self.assertAllFused(graph) # TODO: uncomment when TE enables support for scalar tensors # ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end)) # graph = ge_weight_tensor.graph_for(start, end) # self.assertAllFused(graph) def test_concat(self): # disabling concat causes error with single concat node with set_fusion_group_inlining(True): for device in self.devices: hx = torch.randn(3, 20, dtype=torch.float, device=device) cx = torch.randn(3, 20, dtype=torch.float, device=device) def foo(hx, cx): return torch.cat((hx + cx, hx * cx)) ge = self.checkTrace(foo, (hx, cx)) graph = ge.graph_for(hx, cx) self.assertAllFused(graph) # XXX: TE fuser can handle concats in a fusion group. # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) def test_remove_output_used_only_in_size(self): for device in self.devices: def test_fuse(a, b): c = a + b d = c + b return d scripted_f = torch.jit.script(test_fuse) x = torch.ones(1, requires_grad=True, device=device) y = torch.ones(1, requires_grad=True, device=device) warmup_forward(scripted_f, x, y, profiling_count=3) g = scripted_f.graph_for(x, y) diff_nodes = g.findAllNodes("prim::DifferentiableGraph") self.assertEqual(len(diff_nodes), 1) g = diff_nodes[0].g("Subgraph") if_nodes = [n for n in g.nodes() if n.kind() == "prim::If"] self.assertEqual(len(if_nodes), 1) # the if node and the fusion group inside it should only have one output self.assertEqual(len(list(if_nodes[0].outputs())), 1) def test_concat_invariant(self): for device in self.devices: # Invariant: the output of prim::FusedConcat may # not be an input to any node inside the FusionGroup. def fn(x, y, z): x1 = x + y y1 = x - y w = torch.cat([x1, y1]) return w + z x = torch.randn(2, 2, dtype=torch.float, device=device) y = torch.randn(2, 2, dtype=torch.float, device=device) z = torch.randn(4, 2, dtype=torch.float, device=device) ge = self.checkTrace(fn, (x, y, z)) graph = ge.graph_for(x, y, z) self.assertAllFused(graph, except_for={"aten::add"}) # XXX: TE fuser can handle concats inside a fusion group. 
# FileCheck().check("FusedConcat").check_next("return").run(str(graph)) @staticmethod def fn_test_exp(x, y): return (x + 0.5 * y).exp() def test_exp(self): for device in self.devices: x = torch.randn(4, 4, dtype=torch.float, device=device) y = torch.randn(4, 4, dtype=torch.float, device=device) ge = self.checkTrace(self.fn_test_exp, (x, y)) self.assertAllFused(ge.graph_for(x, y)) def test_threshold(self): for device in self.devices: def f(x): return torch.threshold(x, 0, -10) + x + x + x x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device=device) scripted = self.checkScript(f, (x,)) self.assertAllFused(scripted.graph_for(x)) def test_scalar_arg(self): for device in self.devices: def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor: return p * (x * x + x) x = torch.randn(4, 4, dtype=torch.float, device=device) p = 3 scripted = self.checkScript(fn_test_scalar_arg, (x, p)) self.assertAllFused(scripted.graph_for(x, p)) x.requires_grad_(True) # use another function otherwise we will bailout # and won't be able to do fused checks def fn_test_scalar_arg_requires_grad( x: torch.Tensor, p: float ) -> torch.Tensor: return p * (x * x + x) scripted = torch.jit.script(fn_test_scalar_arg_requires_grad) out = scripted(x, p) out = scripted(x, p) out = scripted(x, p) self.assertAllFused( scripted.graph_for(x, p), except_for=( "aten::size", "prim::BroadcastSizes", "aten::_size_if_not_equal", ), ) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") def test_fusion_reuse_multi_gpu(self): def fn(x, y): return x * y * x * y inputs_cpu = [ torch.randn(4, 4, dtype=torch.float), torch.randn(4, 4, dtype=torch.float), ] inputs_cuda0 = [x.cuda(0) for x in inputs_cpu] inputs_cuda1 = [y.cuda(1) for y in inputs_cpu] # Should not crash; these should compile different kernels. ge = self.checkScript(fn, inputs_cpu) self.assertAllFused(ge.graph_for(*inputs_cpu)) ge(*inputs_cuda0) ge(*inputs_cuda1) # TODO: we're currently not checking 'device' in the type info when pulling # nodes into a fusion group. We should fix that and re-enable this test. @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") def test_kernel_cache_multi_gpu(self): def not_fusible(x): return x def fn(x, y, z): x_out = x * x * x * x * x # fusion: lambda x. x * x * x * x * x y_out = y * y * y * y * y z_out = z * z * z * z * z return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out) inputs = [ torch.randn(4, 4, dtype=torch.float), torch.randn(4, 4, dtype=torch.float, device="cuda:0"), torch.randn(4, 4, dtype=torch.float, device="cuda:1"), ] prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs() # There are 3 FusionGroups. Because they have the same graph, they # should reuse the same KernelSpec in the KernelSpec cache. ge = self.checkScript(fn, inputs) self.assertGraphContainsExactly(ge.graph_for(*inputs), FUSION_GROUP, 3, True) new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs() # XXX: This assumes that the same kernel isn't already used by another test # FIXME: Use the TE fuser's way of querying the cache. 
# self.assertEqual(new_cache_size - prev_cache_size, 1) @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") def test_nonzero_device_cuda(self): device = "cuda:" + str(1) x = torch.tensor([0.4], dtype=torch.float, device=device) y = torch.tensor([0.7], dtype=torch.float, device=device) def doit(x, y): return torch.sigmoid(torch.tanh(x * (x + y) + x)) ge = self.checkTrace(doit, (x, y)) self.assertAllFused(ge.graph_for(x, y)) def test_lstm(self): for device in self.devices: inputs = get_lstm_inputs(device, training=True) module = self.checkScript(LSTMCellS, inputs) self.assertAllFused( module.graph_for(inputs), except_for={"prim::TupleConstruct"} ) def test_lstm_concat(self): # single fusion node causes error with set_fusion_group_inlining(True): for device in self.devices: inputs = get_lstm_inputs(device) ge = self.checkTrace(LSTMCellC, inputs) graph = ge.graph_for(*inputs) except_nodes = {"prim::TupleConstruct", "aten::linear"} # TODO... Chunk if self.dynamic_shapes: except_nodes = except_nodes.union( {"aten::add", "prim::ConstantChunk"} ) self.assertAllFused(ge.graph_for(*inputs), except_for=except_nodes) # XXX: TE fuser can handle concats inside a fusion group. # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) def test_lstm_gates_permutations(self): for device in self.devices: # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh. # Test that any permutation of this will still result in one FusionGroup. choices = ["x.mm(w_ih.t())", "hx.mm(w_hh.t())", "b_ih", "b_hh"] template = dedent( """ def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh): gates = {} + {} + {} + {} ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) return ingate * forgetgate * cellgate * outgate """ ) for permutation in permutations(choices, len(choices)): code = template.format(*permutation) scope = {} exec(code, globals(), scope) cu = torch.jit.CompilationUnit(code) fusion_group_len = 2 if self.dynamic_shapes else 1 inputs = get_lstm_inputs(device, training=False) self.assertEqual(cu.cell(*inputs), scope["cell"](*inputs)) forward_graph = cu.cell.graph_for(*inputs) self.assertGraphContainsExactly( forward_graph, FUSION_GROUP, fusion_group_len ) # TODO: Fuser doesn't work at all when inputs require grad. 
    # Fix that
    def test_lstm_traced(self):
        for device in self.devices:
            inputs = get_lstm_inputs(device)
            ge = self.checkTrace(LSTMCellF, inputs)
            graph = ge.graph_for(*inputs)
            fusion_groups = self.findFusionGroups(graph)
            # TODO: chunk
            fusion_group_len = 2 if self.dynamic_shapes else 1
            self.assertEqual(len(fusion_groups), fusion_group_len)
            f = FileCheck()
            if not self.dynamic_shapes:
                f.check("Chunk")
            f.check("aten::sigmoid").check("aten::tanh").run(
                str(fusion_groups[0 if not self.dynamic_shapes else 1])
            )

    def test_milstm(self):
        if self.dynamic_shapes:
            self.skipTest("don't run conv with dynamic shapes")
        for device in self.devices:
            inputs = get_milstm_inputs(device, training=True)
            module = self.checkScript(MiLSTMCell, inputs)
            forward_graph = module.graph_for(*inputs)
            # TODO: chunk
            fusion_group_len = 2 if self.dynamic_shapes else 1
            self.assertGraphContainsExactly(
                forward_graph, FUSION_GROUP, fusion_group_len, consider_subgraphs=True
            )
            FileCheck().check("DifferentiableGraph").check("TupleConstruct").check_next(
                "return"
            ).check(FUSION_GROUP).run(str(forward_graph))
            hy, cy = module(*inputs)
            warmup_backward((hy + cy).sum())

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skip("rand_like is not supported yet")
    def test_rand_cuda(self):
        class M(torch.jit.ScriptModule):
            __constants__ = ["d"]

            def __init__(self) -> None:
                super().__init__()
                self.d = torch.device("cuda")

            @torch.jit.script_method
            def create(self, x):
                return x * x + x + torch.rand_like(x)

        x = torch.zeros([3, 4, 5], dtype=torch.float, device="cuda")
        m = M()
        out1 = m.create(x)
        out2 = m.create(x)
        self.assertNotEqual(out1, out2)
        self.assertTrue(torch.all(out1 >= 0))
        self.assertTrue(torch.all(out1 < 1))
        self.assertTrue(torch.all(out2 >= 0))
        self.assertTrue(torch.all(out2 < 1))
        self.assertAllFused(m.create.graph_for(x))

    @staticmethod
    def fn_test_relu(x, y):
        return F.relu(x + 0.5 * y)

    def test_relu(self):
        for device in self.devices:
            x = torch.randn(4, 4, dtype=torch.float, device=device)
            y = torch.randn(4, 4, dtype=torch.float, device=device)
            ge = self.checkTrace(self.fn_test_relu, (x, y))
            self.assertAllFused(ge.graph_for(x, y))

    def test_erf(self):
        for device in self.devices:
            # only enabled on gpu
            if device == "cpu":
                continue

            def fn_test_erf(x):
                return F.relu(torch.erf(x) - torch.erfc(x))

            x = torch.randn(4, 4, dtype=torch.float, device=device)
            ge = self.checkScript(fn_test_erf, (x,), profiling=ProfilingMode.PROFILING)
            self.assertAllFused(ge.graph_for(x))
            x.requires_grad_(True)
            ge = self.checkScript(fn_test_erf, (x,), profiling=ProfilingMode.PROFILING)
            self.assertAllFused(
                ge.graph_for(x),
                except_for=(
                    "aten::size",
                    "prim::BroadcastSizes",
                    "aten::_size_if_not_equal",
                ),
            )

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skip("rand_like is not supported yet")
    def test_rand_broadcast_cuda(self):
        def fn_test_rand(x, y):
            r = torch.rand_like(y)
            return r * x + x

        # If using profiling, a different function is needed to test different
        # shapes, or we'll use a cached script.
def fn_test_rand2(x, y): r = torch.rand_like(y) return r * x * x x = torch.randn(4, 4, dtype=torch.float, device="cuda") y = torch.randn(4, 4, dtype=torch.float, device="cuda") script_f = torch.jit.script(fn_test_rand) warmup_forward(script_f, x, y) out = script_f(x, y) self.assertAllFused(script_f.graph_for(x, y)) x.requires_grad_(True) out = script_f(x, y) self.assertAllFused( script_f.graph_for(x, y), except_for=( "aten::size", "prim::BroadcastSizes", "aten::_size_if_not_equal", ), ) # test that broadcasting random produces correct results x = torch.ones(4, 4, dtype=torch.float, device="cuda") y = torch.ones(4, dtype=torch.float, device="cuda") script_f = torch.jit.script(fn_test_rand2) warmup_forward(script_f, x, y) out = script_f(x, y) self.assertEqual(out[0, :] + torch.zeros(4, 4, device="cuda"), out) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skip("rand_like is not supported yet") def test_rand_diamond(self): def fn_test_diamond(x, y): r = torch.rand_like(y) a = x + r b = y - r return a + b x = torch.randn(4, 4, dtype=torch.float, device="cuda") y = torch.randn(4, 4, dtype=torch.float, device="cuda") script_f = torch.jit.script(fn_test_diamond) warmup_forward(script_f, x, y) out = script_f(x, y) self.assertEqual(out, x + y) def test_scalar(self): def fn(x, y): return 2 * x + y x = torch.tensor(0.1, dtype=torch.float, device="cpu") y = torch.tensor(1, dtype=torch.float, device="cpu") ge = self.checkScript(fn, (x, y)) self.assertAllFused(ge.graph_for(x, y)) def test_inlined_optimized_graph(self): @torch.jit.script def foo(x): return torch.relu(x + x) for _ in range(3): foo(torch.rand([4, 4])) for _ in range(3): foo(torch.rand([10])) for _ in range(3): foo(torch.rand([2, 2, 2])) g = torch.jit.last_executed_optimized_graph() FileCheck().check_count("prim::If", 1, exactly=True).check( "prim::TensorExpr" ).run(g) torch._C._jit_pass_inline(g) f = FileCheck() for _ in range(3): f.check("prim::If").check("prim::TensorExpr") f.run(g) def test_small_constant(self): for device in self.devices: def fn_test_small_constant(x, y): return (1e-8 * x + 5e-9 * y) * 1e8 x = torch.randn(4, 4, dtype=torch.float, device=device) y = torch.randn(4, 4, dtype=torch.float, device=device) ge = self.checkTrace(fn_test_small_constant, (x, y)) self.assertAllFused(ge.graph_for(x, y)) # Currently we don't pull constants into fusion groups, because in some # cases it could remove the constant from the original graph and now our # fusion group needs to return that constant for its other users. # Instead of never pulling constants into the fusion group, we should just # be more careful at how we rewrite its users. # TODO: fix that and reenable the test. def test_tensor_scalar_ops(self): for device in self.devices: def should_fuse(x): z = 3.0 y = x + z return x * y def should_fuse_scalar(x, z): y = x + int(z) return x * y inputs = [torch.randn(2, 2, dtype=torch.float, device=device)] ge = self.checkScript(should_fuse, inputs) graph = ge.graph_for(*inputs) fusion_groups = self.findFusionGroups(graph) self.assertEqual(len(fusion_groups), 1) FileCheck().check("aten::add").check("aten::mul").run(str(fusion_groups[0])) inputs = [ torch.randn(2, 2, dtype=torch.float, device=device), torch.tensor(3.0, dtype=torch.float, device=device), ] ge = self.checkScript(should_fuse_scalar, inputs) # Check that the fused graph computes correct results when the scalar # input changes. 
inputs = [ torch.randn(2, 2, dtype=torch.float, device=device), torch.tensor(7.0, dtype=torch.float, device=device), ] self.assertEqual(ge(*inputs), should_fuse_scalar(*inputs)) # The TE fuser supports fusion of non-constant scalars self.assertGraphContainsExactly( ge.graph_for(*inputs), FUSION_GROUP, 1, consider_subgraphs=True ) def test_where_and_typing(self): for device in self.devices: def f(x, y): mask = x > y res = torch.where(mask, x, y) return mask, res x = torch.randn(4, 4, dtype=torch.double, device=device) y = torch.randn(4, 4, dtype=torch.double, device=device) script_f = self.checkScript(f, (x, y)) self.assertAllFused( script_f.graph_for(x, y), except_for={"prim::TupleConstruct"} ) def test_disabled(self): old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu() torch._C._jit_override_can_fuse_on_cpu(False) def fn(a): return a**2 + a x = torch.randn(4, dtype=torch.float, device="cpu") s = self.checkScript(fn, (x,)) g = s.graph_for(x) self.assertEqual(len(self.findFusionGroups(g)), 0) torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuser_state) def data_for(self, dtype, device="cuda", size=None): if size is None: v = torch.arange(1, 3, dtype=torch.float, device=device) else: v = torch.rand(*size, device=device) if dtype == torch.bool: return v > 2 elif dtype in [torch.qint8, torch.quint8, torch.qint32]: return torch.quantize_per_tensor(v, 0.1, 1, dtype=dtype) else: return v.to(dtype) def test_torch_to(self): # test no op @torch.jit.script def foo(x): return x.to(torch.float) foo(torch.tensor([3.0], dtype=torch.float)) foo(torch.tensor([3.0], dtype=torch.float)) FileCheck().check_not("TensorExpr").run( torch.jit.last_executed_optimized_graph() ) # test not fusing non-const inputs @torch.jit.script def foo(x, dtype: int): return x.to(dtype) foo(torch.tensor([3.0], dtype=torch.float), torch.int) foo(torch.tensor([3.0], dtype=torch.float), torch.int) FileCheck().check_not("TensorExpr").run( torch.jit.last_executed_optimized_graph() ) # test not fusing to_pinned inputs @torch.jit.script def foo(x, dtype: int): return x.to(pin_memory=True) foo(torch.tensor([3.0], dtype=torch.float), torch.int) foo(torch.tensor([3.0], dtype=torch.float), torch.int) FileCheck().check_not("TensorExpr").run( torch.jit.last_executed_optimized_graph() ) # test across-device not supported if torch.cuda.is_available(): @torch.jit.script def foo(x): return x.to(device="cuda") foo(torch.tensor([3.0], dtype=torch.float)) foo(torch.tensor([3.0], dtype=torch.float)) FileCheck().check_not("TensorExpr").run( torch.jit.last_executed_optimized_graph() ) sizes = [(1, 4), (4, 4)] # reuses cast impl, smaller dtype set for faster test dtypes = [ torch.bool, torch.int, torch.float16, torch.float32, torch.float64, ] class MyMod(torch.nn.Module): def __init__(self, dtype): super().__init__() self.dtype = dtype def forward(self, x): return x.to(self.dtype) bad_dtypes = [] for dtype, output_dtype, device, size in product( dtypes, dtypes, self.devices, sizes ): # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue if dtype == output_dtype: continue x = self.data_for(dtype, device, size=size) mod = MyMod(output_dtype) ref = mod.forward(x) # use freezing to make non-Tensor args to `to` constant mod = torch.jit.freeze(torch.jit.script(mod.eval())) warmup_forward(mod.forward, x) self.assertEqual(ref, mod.forward(x)) self.assertLastGraphAllFused() @unittest.skip("Temporarily disabled") def test_masked_fill(self): dtypes = [ torch.int8, 
torch.int16, torch.int32, torch.int64, # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed # torch.float16, torch.float32, torch.float64, torch.bool, ] sizes = [(2,), (4, 4)] for self_dtype, device, scalar_val, size in product( dtypes, self.devices, [0.4, 3], sizes ): input_v = self.data_for(self_dtype, device, size=size) mask = self.data_for(torch.bool, device, size=size) def fn(input_v, mask): return torch.masked_fill(input_v, mask, scalar_val) ref = fn(input_v, mask) try: t = torch.jit.trace(fn, (input_v, mask)) torch.testing.assert_close(ref, t(input_v, mask)) self.assertLastGraphAllFused() except Exception as e: raise RuntimeError( " ".join( [ "Failed:", str(self_dtype), op.__name__, # noqa: F821 device, str(size), ] ) ) from e def test_isnan(self): x = torch.rand([4]) x[0] = float("nan") inputs = [x, torch.tensor([float("nan"), 0.5])] dtypes = [ torch.int8, torch.int16, torch.int32, torch.int64, torch.float16, torch.float32, torch.float64, torch.bool, ] for inp, device, dtype in product(inputs, self.devices, dtypes): # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue inp = inp.to(device=device, dtype=dtype) try: f = torch.jit.trace(lambda x: x.isnan(), (inp,)) warmup_forward(f, inp) self.assertEqual(f(inp), inp.isnan()) self.assertLastGraphAllFused() except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), "isnan", device]) ) from e def test_gelu(self): def apply(fn): return lambda x, approximate: fn(x, approximate) unary_ops = [ F.gelu, ] sizes = [(1,), (2,), (4, 4)] for dtype, op, device, size in product( self.dtypes, unary_ops, self.devices, sizes ): # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue try: x = self.data_for(dtype, device, size=size) cond = self.data_for(torch.bool, device) fn = apply(op) ref = fn(x, cond) except Exception: # If eager mode doesn't support a dtype/op/device combo, # neither does the fuser. Catch everything to avoid needing to # guess what errors might be thrown by eager. continue try: t = torch.jit.trace(fn, (x, cond)) torch.testing.assert_close(ref, t(x, cond)) self.assertAllFused(t.graph_for(x, cond)) except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device, str(size)]) ) from e def test_unary_ops(self): with torch._jit_internal._disable_emit_hooks(): def apply(fn): return lambda x: fn(x) unary_ops = [ torch.lgamma, torch.sigmoid, torch.reciprocal, torch.neg, torch.relu, F.relu6, torch.log, torch.log10, torch.log1p, torch.log2, torch.exp, torch.expm1, torch.erf, torch.erfc, torch.cos, torch.sin, torch.tan, torch.acos, torch.asin, torch.cosh, torch.sinh, torch.atan, torch.tanh, F.hardtanh, F.hardsigmoid, F.hardswish, F.softplus, F.silu, F.mish, F.elu, torch.sqrt, torch.rsqrt, torch.abs, # TODO broken on int8 since # https://github.com/pytorch/pytorch/pull/85144 # RuntimeError: Invalid integral op_type: 23 # torch.ceil, # torch.floor, # torch.round, # torch.trunc, torch.frac, # TODO: broken on ROCm? 
# F.hardshrink, F.leaky_relu, lambda x: torch.threshold(x, 0, -10), # TODO: broken since type promotion was added # lambda x: torch.clamp(x, -10, 10), ] gpu_only = {torch.erf, torch.erfc} sizes = [(1,), (2,), (4, 4)] for dtype, op, device, size in product( self.dtypes, unary_ops, self.devices, sizes ): # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue # todo - re-enable. fails with .500 if dtype == torch.bfloat16 and op == torch.round: continue if op in gpu_only and device == "cpu": continue try: x = self.data_for(dtype, device, size=size) fn = apply(op) ref = fn(x) except Exception: # If eager mode doesn't support a dtype/op/device combo, # neither does the fuser. Catch everything to avoid needing to # guess what errors might be thrown by eager. continue try: t = torch.jit.trace(fn, (x,)) torch.testing.assert_close(ref, t(x)) self.assertAllFused(t.graph_for(x)) except Exception as e: raise RuntimeError( " ".join( ["Failed:", str(dtype), op.__name__, device, str(size)] ) ) from e def test_binary_ops(self): def apply(fn): return lambda x, y: fn(x, y) binary_ops = [ operator.__and__, operator.__or__, operator.__xor__, torch.add, torch.sub, torch.mul, torch.min, torch.max, lambda x, y: torch.lerp(x, y, 0.5), torch.atan2, torch.div, torch.eq, torch.ne, torch.ge, torch.gt, torch.lt, torch.fmod, torch.remainder, lambda x, y: y.type_as(x), ] fp_only = [ torch.fmod, torch.remainder, ] devices = self.devices for dtype, op, device in product(self.dtypes, binary_ops, devices): if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue try: x = self.data_for(dtype, device) y = self.data_for(dtype, device) fn = apply(op) ref = fn(x, y) except Exception: # If eager mode doesn't support a dtype/op/device combo, # neither does the fuser. Catch everything to avoid needing to # guess what errors might be thrown by eager. continue try: t = torch.jit.trace(fn, (x, y)) self.assertEqual(ref, t(x, y)) if op not in fp_only or dtype.is_floating_point: self.assertAllFused(t.graph_for(x, y)) except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) ) from e def test_binary_scalar_ops(self): def apply(fn): return lambda x, y: fn(x, y) ir_template = """ graph(%x : {dtype_x}, %y : {dtype_y}): %z = {op}(%x, %y) return (%z)""" binary_ops = [ "aten::mul", "aten::add", "aten::sub", "aten::div", "aten::lt", "aten::le", "aten::eq", "aten::ne", "aten::gt", "aten::ge", "aten::__or__", "aten::__xor__", "aten::__and__", "aten::__lshift__", "aten::__rshift__", ] dtypes = ["int", "float", "bool"] values = {"int": [10, 3], "float": [12.34, 2.78], "bool": [True, False]} devices = self.devices for dtype_x, dtype_y, op, device in product( dtypes, dtypes, binary_ops, devices ): code = ir_template.format(**locals()) # Interpret the graph try: graph = torch._C.parse_ir(code) for x, y in product(values[dtype_x], values[dtype_y]): ref = torch._C._jit_interpret_graph(graph, (x, y)) except Exception: # If we can't interpret this IR, don't bother checking NNC. 
continue # Compile the graph try: k = torch._C._te.TensorExprKernel(graph) except Exception as e: raise RuntimeError( " ".join(["Compilation failed:", device, str(code)]) ) from e # Run the graph for x, y in product(values[dtype_x], values[dtype_y]): ref = torch._C._jit_interpret_graph(graph, (x, y)) try: res = k.run((x, y)) self.assertEqual(ref, res) except Exception as e: raise RuntimeError( " ".join( ["Failed at runtime:", device, str(x), str(y), str(code)] ) ) from e def test_matmul(self): if self.dynamic_shapes: self.skipTest("don't run conv with dynamic shapes") def fn(x, y): return torch.matmul(x, y) devices = ["cpu"] # No cuda support for ext calls yet sizes = [ [[128, 128], [128, 128]], [[10, 10], [10, 10]], [[1, 16], [16, 128]], [[128], [128]], [[128], [128, 128]], [[3], [3]], [[3, 4], [4]], [[10, 3, 4], [4]], [[10, 3, 4], [10, 4, 5]], [[10, 3, 4], [4, 5]], ] # Only 2D x 2D matrix multiply is supported. For non-supported sizes we # still want to run results verification to test that we didn't # accidentally fuse it, but we skip the 'is-fused' check. # TODO: add support for other shape combinations and make this set empty: skip_is_fused_check_sizes = [ "[[128], [128]]", "[[128], [128, 128]]", "[[3], [3]]", "[[3, 4], [4]]", "[[10, 3, 4], [4]]", "[[10, 3, 4], [10, 4, 5]]", "[[10, 3, 4], [4, 5]]", ] for dtype, size, device in product(self.dtypes, sizes, devices): if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue try: size_x, size_y = size x = self.data_for(dtype, device, size=size_x) y = self.data_for(dtype, device, size=size_y) ref = fn(x, y) except Exception as e: # If eager mode doesn't support a dtype/op/device combo, # neither does the fuser. Catch everything to avoid needing to # guess what errors might be thrown by eager. continue try: t = torch.jit.trace(fn, (x, y)) t(x, y) self.assertEqual(ref, t(x, y)) if str(size) not in skip_is_fused_check_sizes: self.assertAllFused(t.graph_for(x, y)) except Exception as e: raise RuntimeError(" ".join(["Failed:", str(dtype), device])) from e def test_binary_tensor_scalar_ops(self): with torch._jit_internal._disable_emit_hooks(): def apply_with_scalar(fn, scalar): return lambda x: fn(x, scalar) # FIXME: Fails in IR Eval: torch.int64 and_ cpu binary_ops = [ operator.__and__, operator.__or__, operator.__xor__, torch.add, torch.sub, torch.mul, torch.eq, torch.ne, torch.ge, torch.lt, torch.gt, ] devices = self.devices # Maybe we should split this into separate tests to speed it up by # only using scalar values relevant to particular ops scalars = [1.5, 3, 0, -2.0, -1] for dtype, op, device, scalar in product( self.dtypes, binary_ops, devices, scalars ): if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue try: x = self.data_for(dtype, device) fn = apply_with_scalar(op, scalar) ref = fn(x) except Exception: # If eager mode doesn't support a dtype/op/device combo, # neither does the fuser. Catch everything to avoid needing to # guess what errors might be thrown by eager. 
continue try: t = torch.jit.trace(fn, (x)) self.assertEqual(ref, t(x)) self.assertAllFused(t.graph_for(x)) except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) ) from e def test_binary_div_ops(self): def apply_with_scalar(fn, scalar): return lambda x: fn(x, scalar) binary_ops = [ torch.div, torch.remainder, torch.fmod, ] devices = self.devices # Maybe we should split this into separate tests to speed it up by # only using scalar values relevant to particular ops scalars = [1.5, 3, -2.0, -1] # skip 0 for dtype, op, device, scalar in product( self.dtypes, binary_ops, devices, scalars ): if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue try: x = self.data_for(dtype, device) fn = apply_with_scalar(op, scalar) ref = fn(x) except Exception: # If eager mode doesn't support a dtype/op/device combo, # neither does the fuser. Catch everything to avoid needing to # guess what errors might be thrown by eager. continue try: t = torch.jit.trace(fn, (x)) self.assertEqual(ref, t(x)) except Exception as e: raise RuntimeError( f"Failed: {dtype} {op.__name__} {device} {scalar}" ) from e def test_binary_pow(self): def apply_with_scalar(fn, scalar): return lambda x: fn(x, scalar) dtypes = [ # FIXME: 'pow' fails with dtype=torch.float16/device=cuda/scalar=0 # torch.float16, torch.float32, torch.float64, # torch.bool intentionally not included ] binary_ops = [ torch.pow, ] # Maybe we should split this into separate tests to speed it up by # only using scalar values relevant to particular ops scalars = [1.5, 3, 0, -2.0, -1] for dtype, op, device, scalar in product( dtypes, binary_ops, self.devices, scalars ): if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue try: x = self.data_for(dtype, device) fn = apply_with_scalar(op, scalar) ref = fn(x) except Exception: # If eager mode doesn't support a dtype/op/device combo, # neither does the fuser. Catch everything to avoid needing to # guess what errors might be thrown by eager. continue try: t = torch.jit.trace(fn, (x)) self.assertEqual(ref, t(x)) self.assertAllFused(t.graph_for(x)) except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) ) from e def test_ternary_ops(self): def apply(fn): return lambda x, y, z: fn(x, y, z) ternary_ops = [ torch.lerp, torch.addcmul, ] devices = self.devices for dtype, op, device in product(self.dtypes, ternary_ops, devices): if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue try: x = self.data_for(dtype, device) y = self.data_for(dtype, device) z = self.data_for(dtype, device) fn = apply(op) ref = fn(x, y, z) except Exception: # If eager mode doesn't support a dtype/op/device combo, # neither does the fuser. Catch everything to avoid needing to # guess what errors might be thrown by eager. 
continue try: t = torch.jit.trace(fn, (x, y, z)) self.assertEqual(ref, t(x, y, z)) self.assertAllFused(t.graph_for(x, y, z)) except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) ) from e def test_ternary_norm_ops(self): def apply(fn): return lambda x, y, z: fn(x, y, z) ternary_ops = [ F.batch_norm, ] devices = self.devices for dtype, op, device in product(self.dtypes, ternary_ops, devices): if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue try: x = self.data_for(dtype, device, size=[5, 3, 128, 128]) y = self.data_for(dtype, device, size=[3]) z = self.data_for(dtype, device, size=[3]) fn = apply(op) ref = fn(x, y, z) except Exception: # If eager mode doesn't support a dtype/op/device combo, # neither does the fuser. Catch everything to avoid needing to # guess what errors might be thrown by eager. continue try: t = torch.jit.trace(fn, (x, y, z)) self.assertEqual(ref, t(x, y, z)) self.assertAllFused(t.graph_for(x, y, z)) except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) ) from e @unittest.skip( "FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure" ) def test_list_ops(self): def apply(fn): return lambda x, y, z: fn([x * x, y * y, z * z]) devices = self.devices list_ops = [ torch.cat, ] for dtype, op, device in product(self.dtypes, list_ops, devices): if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue try: x = self.data_for(dtype, device, size=[5, 4, 1, 7]) y = self.data_for(dtype, device, size=[5, 4, 1, 7]) z = self.data_for(dtype, device, size=[5, 4, 1, 7]) fn = apply(op) ref = fn(x, y, z) except Exception: # If eager mode doesn't support a dtype/op/device combo, # neither does the fuser. Catch everything to avoid needing to # guess what errors might be thrown by eager. continue try: t = torch.jit.trace(fn, (x, y, z)) self.assertEqual(ref, t(x, y, z)) self.assertAllFused(t.graph_for(x, y, z)) except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) ) from e def test_where_ops(self): def apply(fn): return lambda cond, x, y: fn(cond, x, y) ops = [ torch.where, lambda cond, x, y: torch.where(cond, x, 3.1415), lambda cond, x, y: torch.where(cond, 42, y), ] devices = self.devices for dtype, op, device in product(self.dtypes, ops, devices): if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue try: cond = self.data_for(torch.bool, device) x = self.data_for(dtype, device) y = self.data_for(dtype, device) fn = apply(op) ref = fn(cond, x, y) except Exception: # If eager mode doesn't support a dtype/op/device combo, # neither does the fuser. Catch everything to avoid needing to # guess what errors might be thrown by eager. continue try: t = torch.jit.trace(fn, (cond, x, y)) self.assertEqual(ref, t(cond, x, y)) self.assertAllFused(t.graph_for(cond, x, y)) except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) ) from e def test_unsupported_dtypes(self): for device in self.devices: def fn(x): return x * x + x unsupported_dtypes = [ torch.uint8, torch.complex32, torch.complex64, torch.complex128, torch.qint8, torch.quint8, torch.qint32, ] for dtype in unsupported_dtypes: try: x = self.data_for(dtype, device) ref = fn(x) except Exception: # If eager mode doesn't support a dtype/op/device combo, # neither does the fuser. Catch everything to avoid needing to # guess what errors might be thrown by eager. 
continue t = torch.jit.trace(fn, (x,)) self.assertEqual(ref, t(x)) self.assertEqual(len(self.findFusionGroups(t.graph_for(x))), 0) def test_superslomo(self): devices = self.devices.copy() if not LLVM_ENABLED: devices.remove("cpu") for device in devices: # Test extracted from Super-SloMo: https://github.com/avinashpaliwal/Super-SloMo # A few interesting things happen here: strided inputs of mixed size, # plus outputs of mixed shapes. The latter characteristic happened to # expose a memory corruption bug due to not properly guarding the # outputs. def eager(t0, t1, t2, t3, t4): t5 = torch.mul(t0, t4) t6 = torch.mul(t2, t3) t7 = torch.mul(t6, t1) t9 = torch.add(t5, t7) t11 = torch.add(t0, t6) ft_p = torch.div(t9, t11) return (ft_p, t11, t9, t6) t0 = torch.rand(1, 6, 352, 352, device=device).transpose(0, 1) t1 = torch.rand(6, 3, 352, 352, device=device) t2 = torch.rand(6, device=device)[None, None, None, :].permute(3, 0, 1, 2) t3 = torch.rand(6, 1, 352, 352, device=device) t4 = torch.rand(6, 3, 352, 352, device=device) inputs = [t0, t1, t2, t3, t4] script = torch.jit.script(eager) for _ in range(4): for pair in zip(script(*inputs), eager(*inputs)): test, ref = pair torch.testing.assert_close(test, ref) self.assertAllFused( script.graph_for(*inputs), except_for={"prim::TupleConstruct"} ) def test_sub_gt_and(self): for device in self.devices: def eager(t1, t2, t3, t4, t: float): w = t1 - t2 h = t3 - t4 k = (w > t) & (h > t) assert k.dtype == torch.bool if t > 0.5: # Putting a use of k in a never-executed conditional prevents # profiling its type, which leaves it as "Tensor". If we # propagate Tensor back to the definition of k, we have to be # careful not to create a fusion group containing it. return k + 1 return w t = torch.rand(8, dtype=torch.float, device=device) scripted = self.checkScript(eager, (t, t, t, t, 0.1)) @skipIfTorchDynamo("too slow") def test_chunk_mul_one(self): if self.dynamic_shapes: self.skipTest("TODO: chunk dynamic shapes") for device in self.devices: def eager(x): z, y, w = torch.chunk(x, 3, -1) return z * 3, y, w x = torch.rand(64, 1, 3072, dtype=torch.float, device=device) z, y, w = eager(x) script = self.checkScript(eager, (x,)) def test_eq_unsqueeze_type_as(self): for device in self.devices: def eager(a, b): mask = b == 1 mask = torch.unsqueeze(mask, -1) x = mask.type_as(a) return x, mask a = torch.rand(1, 64, 1024, device=device, dtype=torch.float) b = torch.randint(-2, 2, (1, 64), device=device, dtype=torch.long) script = self.checkScript(eager, (a, b)) def test_neg_pow(self): def eager_tt(a: torch.Tensor, b: torch.Tensor): return torch.neg(torch.pow(a, b)) def eager_ts(a: torch.Tensor, b: float): return torch.neg(torch.pow(a, b)) def eager_st(a: float, b: torch.Tensor): return torch.neg(torch.pow(a, b)) a = torch.rand(1, dtype=torch.float) b = torch.rand(1, dtype=torch.float) s = b.item() script = self.checkScript(eager_tt, (a, b)) # TODO: re-enable fusion, which doesn't work right now. 
just test correctness for now # self.assertAllFused(script.graph_for(a, b)) script = self.checkScript(eager_ts, (a, s)) # self.assertAllFused(script.graph_for(a, s)) script = self.checkScript(eager_st, (s, b)) # self.assertAllFused(script.graph_for(s, b)) @unittest.skipIf(not LLVM_ENABLED, "Too slow to run with the TE interpreter") def test_conv2d_depthwise(self): if self.dynamic_shapes: self.skipTest("don't run conv with dynamic shapes") def eager(input, weight, bias): return torch.conv2d(input, weight, bias, stride=1, padding=1, groups=72) input = torch.rand((1, 72, 56, 56), dtype=torch.float) weight = torch.rand((72, 1, 3, 3), dtype=torch.float) bias = torch.rand((72), dtype=torch.float) script = self.checkScript(eager, (input, weight, bias)) self.assertAllFused(script.graph_for(input, weight, bias)) def test_conv2d(self): if self.dynamic_shapes: self.skipTest("don't run conv with dynamic shapes") def eager(input, weight, bias): return torch.conv2d(input, weight, bias, stride=1, padding=1, groups=1) input = torch.rand((1, 64, 56, 56), dtype=torch.float) weight = torch.rand((64, 64, 3, 3), dtype=torch.float) bias = torch.rand((64), dtype=torch.float) script = self.checkScript(eager, (input, weight, bias)) FileCheck().check_not("TensorExpr").run( torch.jit.last_executed_optimized_graph() ) def test_type_as_cat(self): with inline_fusion_groups(): def eager(x, y): return torch.cat((x, y.type_as(x)), dim=1) dtypes = self.dtypes.copy() # CPU fuser doesn't support float16. dtypes.remove(torch.float16) dtypes.remove(torch.bfloat16) for dtype1, dtype2 in product(dtypes, dtypes): x = torch.randint(2, (1, 13)).to(dtype1) zero = torch.tensor([[0]]).to(dtype2) one = torch.tensor([[1]]).to(dtype2) script = torch.jit.trace(eager, (x, zero)) for _ in range(3): torch.testing.assert_close(script(x, zero), eager(x, zero)) torch.testing.assert_close(script(x, one), eager(x, one)) self.assertAllFused(script.graph_for(x, one)) def test_to_device(self): def eager(x): return x.to(device="cpu").relu() x = torch.rand(8) script = self.checkScript(eager, (x,)) self.assertAllFused(script.graph_for(x)) def test_dims(self): def eager(x, y): return x / (y + 0.0001) x = torch.linspace(-1, 1, 768, dtype=torch.float32).as_strided( (1, 1, 768), (768, 1, 1) ) y = torch.tensor([[[2.0]]], dtype=torch.float32) script = self.checkScript(eager, (x, y)) self.assertAllFused(script.graph_for(x, y)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_channels_last_dims_dynamic(self): def eager(x, y): return x + (y + 0.0001) indices = [0, 1, 2, 3] sets = [] for i in range(0, len(indices) + 1): for subset in combinations(indices, i): sets.append(subset) # noqa: PERF402 for set in sets: size = [2, 3, 4, 5] for index in set: size[index] = 1 inp = torch.rand(size).to(memory_format=torch.channels_last).cuda() with texpr_enable_strategy([("DYNAMIC", 20)]): foo_s = torch.jit.trace(eager, (inp, inp)) for _ in range(3): out = foo_s(inp, inp) out_eager = eager(inp, inp) self.assertEqual(out_eager, out) self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) g = torch.jit.last_executed_optimized_graph() FileCheck().check("TensorExpr").run(g) def test_exhaust_specializations(self): with texpr_enable_strategy([("STATIC", 1)]): @torch.jit.script def foo(x): return x + x + x for _ in range(3): foo(torch.rand([2, 2])) for _ in range(3): foo(torch.rand([4, 4, 4])) g = torch.jit.last_executed_optimized_graph() torch._C._jit_pass_inline(g) FileCheck().check_count("TensorExpr", 2, exactly=True).run(g) def 
test_unsqueeze_var_dim(self): def eager(x, y, z: int): return x * torch.unsqueeze(y, dim=z) x = torch.rand(4, 4, 64).permute(1, 0, 2) y = torch.rand(4, 4) z = 2 script = self.checkScript(eager, (x, y, z)) def _test_fwd_bwd(self, fn): x = torch.arange(-10, 10, dtype=torch.float32, requires_grad=True) xs = torch.arange(-10, 10, dtype=torch.float32, requires_grad=True) script = torch.jit.script(fn) for i in range(11): y = fn(x) g0 = torch.rand_like(y) y.backward(g0) ys = script(xs) ys.backward(g0) with torch.no_grad(): x -= 0.1 * x.grad xs -= 0.1 * xs.grad x.grad = None xs.grad = None torch.testing.assert_close(y, ys) def test_relu_fwd_bwd(self): def eager(x): return torch.relu(x * 1.01) self._test_fwd_bwd(eager) def test_hardswish_fwd_bwd(self): def eager(x): return F.hardswish(x) * 1.01 self._test_fwd_bwd(eager) def test_hardsigmoid_fwd_bwd(self): def eager(x): return F.hardsigmoid(x) * 1.01 self._test_fwd_bwd(eager) def test_cat_graph_opt(self): def foo(x, y, z): return torch.log(torch.cat([x, y, z])) self.checkScript( foo, (torch.rand([5, 5]), torch.rand([2, 5]), torch.rand([1, 5])) ) # TODO: not sure why not updated graph isn't reflected in last_optimized_graph self.assertLastGraphAllFused() def test_dynamic_cat(self): with inline_fusion_groups(): @torch.jit.script def repro( xs: List[torch.Tensor], ys: List[torch.Tensor], zs: List[torch.Tensor] ): return [ torch.cat([x, torch.cat([y, z], dim=-1)], dim=-1) for x, y, z in zip(xs, ys, zs) ] for _ in range(3): N = 3 xs = [torch.ones(21) for _ in range(N)] # Note: concat of ys and zs will have the same size for each # pair, even though the individual ys and zs do not. ys = [torch.ones(N - i) for i in range(N)] zs = [torch.ones(i) for i in range(N)] repro(xs, ys, zs) def test_scalar_only_inputs(self): def eager(b: float): a = torch.ones(1) return a * b script = self.checkScript(eager, (1.0,)) def test_cat_2k_args(self): with inline_fusion_groups(): def eager(x): return torch.relu(torch.cat([x for _ in range(2000)])) x = torch.randn(1) trace = self.checkTrace(eager, (x,)) fusion_groups = self.findFusionGroups(trace.graph_for(x)) self.assertEqual(len(fusion_groups), 0) def test_adaptive_avg_pool2d(self): # TODO: once the adaptive_avg_pool2d is available in OpInfo DB, this # test should be moved there with inline_fusion_groups(): def foo1(x): return torch.nn.functional.adaptive_avg_pool2d(x, (2, 2)) def foo2(x): return torch.nn.functional.adaptive_avg_pool2d(x, (2)) x = torch.randn(4, 4, 4) for foo in [foo1, foo2]: f = torch.jit.trace(foo, (x,)) kernel = torch._C._te.TensorExprKernel(f.graph) correct_val = f(x) self.assertEqual(kernel.run((x,)), correct_val) def test_unrolled_cat(self): with inline_fusion_groups(): def eager(x): ret = torch.empty(0) for i in range(x.shape[0]): ret = torch.cat([ret, x[i].relu()]) return ret script = torch.jit.script(eager) # Warm up with size=1 tensor; since the loop iterates once the # profile data will be "burned in" assuming size=1, and then # unrolled. x = torch.ones(1, 1) for _ in range(3): script(x) torch.testing.assert_close(eager(x), script(x)) # Now when an input hits the unrolled path, it will produce an # incorrectly-sized tensor, since size=1 has been burned in. 
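            # Conceptually, the specialized script now behaves like the hand-unrolled
            # sketch below (illustrative only; the real unrolling happens on the
            # profiled TorchScript graph, not in Python source):
            #
            #   def unrolled_for_size_1(x):
            #       ret = torch.empty(0)
            #       ret = torch.cat([ret, x[0].relu()])  # single burned-in iteration
            #       return ret
            #
            # A size-8 input no longer matches the burned-in assumption, so correctness
            # depends on the guards around the specialized graph; the check below
            # verifies the result still matches eager.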
x = torch.ones((8, 1)) torch.testing.assert_close(eager(x), script(x)) @skipIfTorchDynamo("too slow") @unittest.skipIf(TEST_WITH_ASAN, "takes 10+ minutes on asan") @unittest.skipIf(TEST_WITH_ROCM, "Tensor-likes are not close for nans") def test_batch_norm(self): def test(fn, args): trace = torch.jit.trace(fn, args) self.assertAllFused(trace.graph_for(*args)) # TODO: Are `NaN`'s actually ok here or did this pass silently before, because `equal_nan=True` was the # default? torch.testing.assert_close(fn(*args), trace(*args), equal_nan=True) def bn(i, x): return torch.batch_norm(i, x, x, x, x, False, 0.1, 1e-4, False).relu() def bn_no_weight(i, x): return torch.batch_norm(i, None, x, x, x, False, 0.1, 1e-4, False).relu() def bn_no_bias(i, x): return torch.batch_norm(i, x, None, x, x, False, 0.1, 1e-4, False).relu() def bn_neither(i, x): return torch.batch_norm(i, None, None, x, x, False, 0.1, 1e-4, False).relu() for device in self.devices: i = torch.randn(4, 16, 32, 40, device=device) x = torch.randn(16, device=device) for fn in [bn, bn_no_weight, bn_no_bias, bn_neither]: test(fn, (i, x)) def test_profiler(self): @torch.jit.script def test(x, y, z): return x * y + z args = [torch.randn(4) for _ in range(3)] with torch.autograd.profiler.profile() as prof: for _ in range(3): test(*args) self.assertIn("fused_mul_add", prof.table()) def test_skip_grad_in_check(self): @torch.jit.script def foo(x): return (x + 2) / 2 inp = torch.rand([4, 4]) for _ in range(3): foo(inp) inp.requires_grad_(True) with torch.inference_mode(): for _ in range(3): foo(inp) g = torch.jit.last_executed_optimized_graph() torch._C._jit_pass_inline(g) torch._C._jit_pass_inline(g) FileCheck().check_count("prim::If", 1, exactly=True).run(g) def test_dynamic_shapes(self): from functools import partial n = 10 gen_tensor = ( lambda n: R(1, n), lambda n: R(n, n), lambda n: R(n, n).transpose(0, 1), lambda n: R(n + 1, n + 1, 2)[:n, n, 0], lambda n: R(n, n, 2)[:, :, 0], lambda n: R(n, n + 1, n + 2, n + 3).to(memory_format=torch.channels_last), ) with texpr_enable_strategy([("DYNAMIC", 20)]): def foo(x, y, z): return torch.sigmoid(torch.tanh(x)) foo.__disable_jit_function_caching__ = True def fi(x, y, z): return torch.tanh(x + y) fi.__disable_jit_function_caching__ = True def fum(x, y, z): return torch.tanh(x + y) + z fum.__disable_jit_function_caching__ = True funcs = [foo, fi, fum] with inline_fusion_groups(): for device in self.devices: I = partial(torch.randint, 0, 100, device=device) R = partial(torch.randn, device=device) for i, func in enumerate(funcs): num_args = i + 1 for j, gen in enumerate(gen_tensor): inps = (gen(n), gen(n), gen(n)) func_s = torch.jit.trace(func, inps, check_trace=False) torch._C._jit_pass_erase_shape_information(func_s.graph) for _ in range(2): x, y, z = gen(n), gen(n), gen(n) func_s(x, y, z) for incr in range(3): func_s(*[gen(n + 1) for _ in range(3)]) g = torch.jit.last_executed_optimized_graph() torch._C._jit_pass_inline(g) torch._C._jit_pass_dce(g) # We should see only one optimized kernel FileCheck().check_count( "TensorExprDynamicGuard", 1, exactly=True ).run(g) self.assertEqual(func(*inps), func_s(*inps)) gen = gen_tensor[0] inps = (gen(n), gen(n), gen(n)) foo_s = torch.jit.trace(foo, inps) torch._C._jit_pass_erase_shape_information(foo_s.graph) g_prev = None for gen in gen_tensor: for i in range(3): foo_s(*[gen(n + i) for _ in range(3)]) inps = (gen(n), gen(n), gen(n)) self.assertEqual(foo_s(*inps), foo(*inps)) g = torch.jit.last_executed_optimized_graph() torch._C._jit_pass_inline(g) 
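                # After inlining (and the DCE just below), the flattened graph should
                # contain one TensorExprDynamicGuard per entry in gen_tensor: a dynamic
                # fusion group is reused across sizes, but each distinct stride/layout
                # pattern still needs its own guarded specialization.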
torch._C._jit_pass_dce(g) FileCheck().check_count( "TensorExprDynamicGuard", len(gen_tensor), exactly=True ).run(g) @unittest.skipIf(not RUN_CUDA, "half-precision NNC fusion requires CUDA") def test_autocast_up(self): def f(x): y = x._autocast_to_full_precision(True, True) z = torch.exp(y) return z x = torch.rand((2, 2), dtype=torch.half, device="cuda") scr = torch.jit.script(f) scr(x) scr(x) self.assertLastGraphAllFused() @unittest.skipIf(not RUN_CUDA, "half-precision NNC fusion requires CUDA") def test_autocast_down(self): def f(x): y = torch.sigmoid(x) z = y._autocast_to_reduced_precision(True, True, torch.half, torch.half) return z x = torch.rand((2, 2), dtype=torch.float, device="cuda") scr = torch.jit.script(f) scr(x) scr(x) self.assertLastGraphAllFused() @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") def test_to_dtype(self): def f(x): y = torch.sigmoid(x) z = y._autocast_to_reduced_precision(True, True, torch.half, torch.bfloat16) h = z._autocast_to_full_precision(True, True) i = h.to(dtype=torch.bfloat16) j = i.to(dtype=torch.float32) return j x = torch.rand((2, 2), dtype=torch.float32) scr = torch.jit.trace(f, x) scr(x) scr(x) self.assertLastGraphAllFused() self.assertEqual(f(x), scr(x), atol=4e-3, rtol=4e-3) bf_x = torch.rand((2, 2), dtype=torch.bfloat16) bf_scr = torch.jit.trace(f, bf_x) bf_scr(bf_x) bf_scr(bf_x) graph = bf_scr.graph_for(bf_x) fusion_groups = self.findFusionGroups(graph) self.assertEqual(len(fusion_groups), 2) self.assertEqual(f(bf_x), bf_scr(bf_x), atol=4e-3, rtol=4e-3) def test_with_strict_fusion(self): def success(x): with torch.jit.strict_fusion(): return x + x + x scripted = self.checkScript(success, (torch.rand([4]),)) g = torch.jit.last_executed_optimized_graph() FileCheck().check_not("aten::add").check("prim::TensorExprGroup").run(g) def foo(x): with torch.jit.strict_fusion(): return x + x + torch.rand([4]) + 3 with self.assertRaises(Exception) as error_out: foo_s = torch.jit.script(foo) foo_s(torch.rand([4])) foo_s(torch.rand([4])) print(torch.jit.last_executed_optimized_graph()) fc = FileCheck().check("Found unfused operators") fc.check("aten::rand(SymInt[] size") fc.check("torch.rand([4]").run(str(error_out.exception)) with warnings.catch_warnings(record=True) as warns: foo(torch.rand([4])) FileCheck().check("Only works in script mode").run(str(warns[0])) def test_autodiff(x): with torch.jit.strict_fusion(): return torch.rand([4]) + x + x + x foo_s = torch.jit.script(test_autodiff) inp = torch.rand([4], requires_grad=True) with self.assertRaises(Exception) as error_out: for _ in range(3): foo_s(inp) f = FileCheck().check("unfused operators").check("aten::rand") f.run(str(error_out.exception)) def test_separate_fusions(x, y): with torch.jit.strict_fusion(): return x + x + x, y + y + y inp = torch.rand([4], requires_grad=True) with self.assertRaises(Exception) as error_out: for _ in range(3): foo_s = torch.jit.script(test_separate_fusions) foo_s(inp, inp) f = FileCheck().check("Found multiple fusions") f.run(str(error_out.exception)) def test_constant_chunk_shapes(self): # We had an issue where buildShapeExpressions would fail as show below: # # %1 : Tensor = Constant[..] # not supported, we don't build this shape # %2 : Tensor = Constant[..] # not supported # %3 : Tensor = aten::add(%1, %2) # inputs not supported, we don't build shape # ... 
= prim::ConstantChunk[..](%3) # it forgets to check whether input shapes exist, and fails if self.dynamic_shapes: self.skipTest("TODO: chunk dynamic shapes") for device in self.devices: def f(x, y): r = torch.tensor(4) z1, z2 = (x + y + r).chunk(2, dim=1) return z1 * z2 x = torch.randn(4, 4, dtype=torch.float, device=device) y = torch.randn(4, 4, dtype=torch.float, device=device) ge = self.checkTrace(f, (x, y)) graph = ge.graph_for(x, y) # make sure that we are actually testing the right scenario FileCheck().check("with " + FUSION_GROUP + "_").check_count( "ConstantChunk", 1, exactly=True ).run(str(graph)) f_traced = torch.jit.trace(f, (x, y)) for i in range(4): # make sure this doesn't error out res = f_traced(x, y) self.assertEqual(res, f(x, y)) @unittest.skipIf(not RUN_CUDA_HALF, "half-precision NNC fusion requires CUDA") def test_pow_multiple_dtype(self): # https://github.com/pytorch/pytorch/issues/75476 def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor: p = torch.sigmoid(p) result = p**gamma return result x = torch.rand((2, 2), dtype=torch.half, device="cuda") ref = fn(x) script_fn = torch.jit.script(fn) for i in range(4): res = script_fn(x) self.assertEqual(ref, res) class TestTEFuserStatic(TestTEFuser): dynamic_shapes = False class TestTEFuserDynamic(TestTEFuser): dynamic_shapes = True del TestTEFuser works_list = [ "__radd__", "__rdiv__", "__rmul__", "__rmod__", "abs", "acos", "add", "addcmul", "addmm.decomposed", "asin", "atan", "atan2", "ceil", "clamp", "clamp.scalar", "contiguous", "cos", "cosh", "div.no_rounding_mode", "div.true_rounding", "div.floor_rounding", "div.trunc_rounding", "eq", "erf", "erfc", "exp", "expand", "expand_as", "expm1", "floor", "fmod", "fmod.autodiffed", "ge", "gt", "isnan", "le", "lerp", "lgamma", "log", "log10", "log1p", "log2", "lt", "masked_fill", "max.binary", "mean", "min.binary", "mm", "mul", "ne", "neg", "nn.functional.hardshrink", "nn.functional.hardsigmoid", "nn.functional.hardswish", "nn.functional.softplus", "nn.functional.hardtanh", "nn.functional.leaky_relu", "nn.functional.relu", "nn.functional.relu6", "nn.functional.softsign", "nn.functional.tanhshrink", "nn.functional.threshold", "permute", "pow", "reciprocal", "remainder", "remainder.autodiffed", "reshape", "reshape_as", "round", "rsub", "rsub.rsub_tensor", "rsqrt", "sigmoid", "sign", "sin", "sinh", "sqrt", "sub", "sum", "t", "tan", "tanh", "transpose", "true_divide", "trunc", "unsqueeze", "view", "view_as", "where", "bool", "byte", "char", "double", "float", "half", "int", "long", "short", "bool.channels_last", "byte.channels_last", "char.channels_last", "double.channels_last", "float.channels_last", "half.channels_last", "int.channels_last", "long.channels_last", "short.channels_last", ] known_failures = [ "__rmatmul__", "frac", "matmul", ] # If your OpInfo test causes this test to fail, add it here skip_ops = ["conj"] def get_name(op): l = [op.name] if op.variant_test_name != "": l.append(op.variant_test_name) return ".".join(l) # Purpose of this class is to allow super() calls. # super() [with no arguments] fails, presumably because of how instantiate_device_type_tests works. # super(TestNNCOpInfo, self) fails because TestNNCOpInfo gets deleted from global scope. 
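# (instantiate_device_type_tests() below generates device-specific copies such as
#  TestNNCOpInfoCPU / TestNNCOpInfoCUDA and then removes the generic TestNNCOpInfo
#  from the calling scope, so any super() form that names TestNNCOpInfo would
#  resolve against a class that no longer exists.)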
# super(JitCommonTestCase, self).fn() would skip JitCommonTestCase.fn() implementation class TestNNCOpInfoParent(JitCommonTestCase): pass class TestNNCOpInfo(TestNNCOpInfoParent): def setUp(self): super(TestNNCOpInfoParent, self).setUp() self.tensorexpr_options = TensorExprTestOptions() def tearDown(self): self.tensorexpr_options.restore() super(TestNNCOpInfoParent, self).tearDown() def te_compile(self, device, dtype, op): if op.name in skip_ops: return sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) for sample_input in sample_inputs_itr: arg_values = [sample_input.input] + list(sample_input.args) kwarg_values = sample_input.kwargs param_names = [] param_values = [] fx_args = [] for idx, v in enumerate(arg_values): if isinstance(v, torch.Tensor): param_names.append(f"arg_{idx}") param_values.append(v) fx_args.append(param_names[-1]) else: fx_args.append(f"{repr(v)}") for k, v in kwarg_values.items(): if isinstance(v, torch.Tensor): param_names.append(k) param_values.append(v) fx_args.append(f"{k} = {k}") else: fx_args.append(f"{k} = {repr(v)}") code = f""" def f({', '.join(param_names)}): return op.op({', '.join(fx_args)})""" g = {"torch": torch, "inf": math.inf, "op": op} exec(code, g) f = g["f"] f.__module__ = "test" out = f(*param_values) ts_g = torch.jit.trace(f, param_values) kernel = torch._C._te.TensorExprKernel(ts_g.graph) correct_val = f(*param_values) self.assertEqual(kernel.run(tuple(param_values)), correct_val) self.assertEqual(kernel.fallback(tuple(param_values)), correct_val) @onlyCPU @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") @ops( [op for op in op_db if get_name(op) in works_list], allowed_dtypes=(torch.float,), ) def test_working(self, device, dtype, op): self.te_compile(device, dtype, op) @onlyCPU @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") @ops( [op for op in op_db if get_name(op) in known_failures], allowed_dtypes=(torch.float,), ) def test_failures(self, device, dtype, op): try: self.te_compile(device, dtype, op) except Exception as e: pass else: raise RuntimeError( "Expected test to fail. If it now works, move op into works_list" ) @onlyCPU @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") @ops( [op for op in op_db if get_name(op) not in works_list + known_failures], allowed_dtypes=(torch.float,), ) def test_unsupported(self, device, dtype, op): if get_name(op) in skip_ops: return try: with warnings.catch_warnings(): warnings.simplefilter("ignore", TracerWarning) # noqa: F821 self.te_compile(device, dtype, op) except Exception as e: pass else: raise RuntimeError( "Expected test to fail. 
If it now works, move op into works_list" ) @slowTest @onlyCPU @ops(op_db, dtypes=OpDTypes.supported) def test_nnc_correctness(self, device, dtype, op): if not op.supports_tracing: self.skipTest("Requires tracing support") with NoTracerWarnContextManager() as no_warn: variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op) for variant, sample in variant_sample_pairs: trace = create_traced_fn(self, variant, cache_traced_fn=True) ref = variant( *clone_inputs((sample.input, *sample.args)), **sample.kwargs ) trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) val = trace( *clone_inputs((sample.input, *sample.args)), **sample.kwargs ) atol = 2e-1 if dtype == torch.bfloat16 else 1e-5 rtol = 2e-1 if dtype == torch.bfloat16 else 1e-5 self.assertEqual(ref, val, atol=atol, rtol=rtol) # https://github.com/pytorch/pytorch/issues/35600 # each torch.jit.trace adds state to the _python_cu compilation unit # since this test traces a lot of functions, out-of-memory can occur # if the CU is not cleared. torch.jit._state._python_cu.drop_all_functions() # CPU fuser not currently used in fbcode only_for = ("cuda") if IS_FBCODE else ("cpu", "cuda") instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for) # Purpose of this class is to allow super() calls. (See TestNNCOpInfoParent) class TestLoopnestRandomizationParent(JitTestCase): pass class TestLoopnestRandomization(TestLoopnestRandomizationParent): def setUp(self): super(TestLoopnestRandomizationParent, self).setUp() self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu() self.old_must_use_cpu_state = torch._C._jit_get_te_must_use_llvm_cpu() self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu() torch._C._jit_override_can_fuse_on_cpu(True) # TODO: force LLVM. need to add it to asan, mac, windows builds + sandcastle # torch._C._jit_set_te_must_use_llvm_cpu(True) torch._C._jit_override_can_fuse_on_gpu(True) self.old_profiling_executor = torch._C._jit_set_profiling_executor(True) self.old_profiling_mode = torch._C._get_graph_executor_optimize(True) self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining() torch._C._debug_set_fusion_group_inlining(False) self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() torch._C._jit_set_texpr_fuser_enabled(True) self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu() torch._C._jit_set_te_must_use_llvm_cpu(False) # Set the seed to 1. This tests the codepath through random # transformation. os.environ["PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED"] = "1" def tearDown(self): torch._C._jit_set_profiling_executor(self.old_profiling_executor) torch._C._get_graph_executor_optimize(self.old_profiling_mode) torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state) torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state) torch._C._jit_set_te_must_use_llvm_cpu(self.old_must_use_cpu_state) torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining) torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state) torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu) # Set it back to 0. 
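        # (setUp unconditionally set the seed to "1". Resetting it to "0" leaves the
        # random-transform path disabled for later tests; note that the previous value
        # of the variable is not saved and restored.)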
os.environ["PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED"] = "0" super(TestLoopnestRandomizationParent, self).tearDown() @onlyCPU @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") def test_relu(self, device): def fn_test_relu(x, y): return F.relu(x + 0.5 * y) x = torch.randn(4, 4, dtype=torch.float, device=device) y = torch.randn(4, 4, dtype=torch.float, device=device) fn = fn_test_relu traced_fn = torch.jit.trace(fn, (x, y)) ref = fn(x, y) res = traced_fn(x, y) assert torch.allclose(ref, res) instantiate_device_type_tests(TestLoopnestRandomization, globals(), only_for=("cpu")) if __name__ == "__main__": run_tests()