# Owner(s): ["oncall: jit"] import unittest from itertools import product import torch from torch.jit._passes._property_propagation import apply_input_props_using_example from torch.testing._internal.common_utils import TEST_CUDA from torch.testing._internal.jit_utils import JitTestCase try: from torchvision import models except ImportError: models = None if __name__ == "__main__": raise RuntimeError( "This test file is not meant to be run directly, use:\n\n" "\tpython test/test_jit.py TESTNAME\n\n" "instead." ) class TestDeviceAnalysis(JitTestCase): @classmethod def setUpClass(cls): cls.cpu = torch.device("cpu") cls.cuda = torch.device("cuda") cls.vulkan = torch.device("vulkan") cls.mkldnn = torch.device( "mkldnn" ) # MKLDNN can't mix with other device types at all cls.device_types = [cls.cpu, cls.cuda, cls.vulkan] @staticmethod def node_output_device(graph): graph_out = list(graph.outputs()) assert len(graph_out) == 1 return graph_out[0].type().device() def prop_device_on_graph(self, graph, example_devices, in_shapes=None): graph_inputs = list(graph.inputs()) torch._C._jit_pass_erase_shape_information(graph) self.assertEqual(len(graph_inputs), len(example_devices)) for graph_i, device_i in zip(graph_inputs, example_devices): if device_i is not None: graph_i.setType(graph_i.type().with_device(device_i)) if in_shapes: for graph_i, shapes_i in zip(graph_inputs, in_shapes): if shapes_i is not None: graph_i.setType(graph_i.type().with_sizes(shapes_i)) torch._C._jit_pass_propagate_shapes_on_graph(graph) torch._C._jit_pass_propagate_device(graph) def assert_device_equal( self, fn, in_devices, expected_device, in_shapes=None, subtest_str="" ): with self.subTest( f"In device: {in_devices}, expected: {expected_device}, \n {subtest_str}" ): graph = torch.jit.script(fn).graph self.prop_device_on_graph(graph, in_devices, in_shapes) actual_device = self.node_output_device(graph) if expected_device is None or actual_device is None: self.assertEqual(actual_device, expected_device) else: self.assertEqual( actual_device.type, expected_device.type, "Failed Verification" ) def test_device_apply(self): # Test if the device is properly applied to the input def add_self(x): return x + x graph = torch.jit.script(add_self).graph graph_input = next(graph.inputs()) graph_input.setType(graph_input.type().with_device(self.cpu)) # self.prop_device_on_graph(graph, [self.cpu]) self.assertEqual(graph_input.type().device(), self.cpu) @unittest.skipIf(models is None, "Requires torchvision") def test_mobilenet(self): in_cpu = torch.randn(1, 3, 224, 224, device=self.cpu) in_example = in_cpu expected_device = self.cpu m = torch.jit.script(models.mobilenet_v3_small()) m.eval() graph = torch.jit.freeze(m).graph # torch._C._jit_pass_erase_shape_information(graph) apply_input_props_using_example(graph, in_example) torch._C._jit_pass_propagate_shapes_on_graph(graph) torch._C._jit_pass_propagate_device(graph) actual_device = self.node_output_device(graph) if expected_device is None or actual_device is None: self.assertEqual(actual_device, expected_device) else: self.assertEqual( actual_device.type, expected_device.type, "Failed Verification" ) def test_simple(self): def add_self(x): return x + x def relu_(x): return torch.nn.functional.relu_(x) functions = [add_self, relu_] for in_device, fn in product(self.device_types, functions): self.assert_device_equal(fn, [in_device], in_device) def test_set_dtype(self): def set_device(x): return x.to("cpu") for in_device in self.device_types: self.assert_device_equal(set_device, 
    def test_device_arg(self):
        # Test that no device gets propagated when the device is passed in
        # as an argument
        def set_device(x, device_name: torch.device):
            return x.to(device=device_name)

        for in_device in self.device_types:
            self.assert_device_equal(set_device, [in_device, None], None)

    def test_tensor_as_fns(self):
        def view_as_fn(x, y):
            return x.view_as(y)

        def expand_as_fn(x, y):
            return x.expand_as(y)

        def reshape_as_fn(x, y):
            return x.reshape_as(y)

        for test_fn in [view_as_fn, expand_as_fn, reshape_as_fn]:
            self.assert_device_equal(test_fn, [self.cpu, self.cpu], self.cpu)
            self.assert_device_equal(test_fn, [self.cuda, None], self.cuda)
            self.assert_device_equal(test_fn, [None, self.mkldnn], None)

        def type_as_fn(x, y):
            return x.type_as(y)

        self.assert_device_equal(type_as_fn, [self.cpu, self.cpu], self.cpu)
        self.assert_device_equal(type_as_fn, [self.cuda, None], None)
        self.assert_device_equal(type_as_fn, [None, self.mkldnn], self.mkldnn)

    def zerodim_test_core(self, device_pairs):
        # Test the support of zerodim tensors with non-zerodim tensors
        def mul(x, y):
            return x * y

        def add(x, y):
            return x + y

        fns = [mul, add]
        input_shapes = [
            ((1, 2, 2), (2, 2)),  # Different dim, non-zerodim
            ((1, 2, 2), ()),  # one zerodim
            ((), ()),  # both zerodim
        ]

        for fn, shapes, devices in product(fns, input_shapes, device_pairs):
            subtest_str = f"{fn.__name__} \n shapes: {shapes}, \n devices: {devices}"
            in0 = torch.rand(shapes[0], device=devices[0])
            in1 = torch.rand(shapes[1], device=devices[1])

            try:
                out = fn(in0, in1)
            except Exception as e:
                # Don't expect eager failures for CPU zerodim tensors
                for i in range(len(devices)):
                    if shapes[i] == () and devices[i] == self.cpu:
                        raise e

                # Only expect eager failures on different devices
                if devices[0] == devices[1]:
                    raise e

                # Expect result device to be None for the failure cases.
                self.assert_device_equal(fn, devices, None, shapes, subtest_str)
                continue

            self.assert_device_equal(fn, devices, out.device, shapes, subtest_str)

            # Test that without shapes we either get the same device or None,
            # i.e. that the analysis is conservative when tensor shapes are
            # unknown.
            graph = torch.jit.script(fn).graph
            self.prop_device_on_graph(graph, devices)
            actual_device = self.node_output_device(graph)
            self.assertTrue(
                (actual_device is None) or (actual_device.type == out.device.type)
            )

    def test_zerodim_cpu(self):
        # Allow for minimal testing locally
        self.zerodim_test_core([(self.cpu, self.cpu)])
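    # An illustrative addition (not from the upstream suite): in eager mode a
    # zero-dim CPU tensor may mix with a CUDA tensor, since it is treated like
    # a scalar and the result lands on the CUDA device. This is the eager
    # behavior that makes zerodim_test_core tolerate cross-device successes
    # when one operand is a zerodim CPU tensor.
    @unittest.skipIf(not TEST_CUDA, "No CUDA")
    def test_zerodim_eager_promotion_sketch(self):
        zerodim_cpu = torch.rand((), device=self.cpu)
        on_cuda = torch.rand(2, 2, device=self.cuda)
        out = zerodim_cpu * on_cuda
        self.assertEqual(out.device.type, self.cuda.type)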
    def test_zerodim_no_device(self):
        # If a device is missing, you should never be able to infer the
        # device type.
        def mul(x, y):
            return x * y

        def add(x, y):
            return x + y

        fns = [mul, add]
        device_pairs = [
            (self.cpu, None),
            (None, self.cpu),
            (None, None),
        ]
        input_shapes = [
            ((1, 2, 2), (2, 2)),  # Different dim, non-zerodim
            ((1, 2, 2), ()),  # one zerodim
            ((), ()),  # both zerodim
        ]

        for fn, shapes, devices in product(fns, input_shapes, device_pairs):
            self.assert_device_equal(fn, devices, None, shapes)

    @unittest.skipIf(not TEST_CUDA, "No CUDA")
    def test_zerodim_gpu(self):
        device_pairs = [
            (self.cpu, self.cuda),
            (self.cuda, self.cpu),
            (self.cuda, self.cuda),
        ]
        self.zerodim_test_core(device_pairs)

    def test_custom_device_op(self):
        # Test each of the custom device-conversion functions and check that
        # the device type is correctly applied
        def set_cuda(x):
            return x.cuda()

        def set_cpu(x):
            return x.cpu()

        def set_mkldnn(x):
            return x.to_mkldnn()

        device_pairs = (
            (set_cuda, self.cuda),
            (set_cpu, self.cpu),
            (set_mkldnn, self.mkldnn),
        )

        for fn, out_device in device_pairs:
            for in_device in self.device_types:
                self.assert_device_equal(fn, [in_device], out_device)

    def test_device_if_propagation(self):
        def test_fn(x, y, z: bool):
            if z:
                return x + 3
            else:
                return y * 2

        # Branches on the same device propagate; mixed devices merge to None.
        self.assert_device_equal(test_fn, [self.cpu, self.cpu, None], self.cpu)
        self.assert_device_equal(test_fn, [self.mkldnn, self.mkldnn, None], self.mkldnn)
        self.assert_device_equal(test_fn, [self.cpu, self.cuda, None], None)

    def test_loop_simple(self):
        def test_fn(x, y, z: int):
            for _ in range(z):
                y = x
            return y

        self.assert_device_equal(test_fn, [self.cpu, self.cpu, None], self.cpu)
        self.assert_device_equal(test_fn, [self.cpu, self.cuda, None], None)
        self.assert_device_equal(test_fn, [self.cpu, None, None], None)

    def test_loop_device_change(self):
        def test_fn(x, z: int):
            for _ in range(z):
                x = x.cuda()
            return x

        # The loop may run zero times, so a CPU input merges cpu (no
        # iterations) with cuda (after an iteration) to None; a CUDA input
        # yields cuda either way.
        self.assert_device_equal(test_fn, [self.cpu, None], None)
        self.assert_device_equal(test_fn, [self.cuda, None], self.cuda)
        self.assert_device_equal(test_fn, [None, None], None)

    def test_while_change(self):
        def test_fn(x, z: int):
            while z > 0:
                x = x.cuda()
                z = 0
            return x

        self.assert_device_equal(test_fn, [self.cpu, None], None)
        self.assert_device_equal(test_fn, [self.cuda, None], self.cuda)
        self.assert_device_equal(test_fn, [None, None], None)

    def test_nested_loops(self):
        def test_fn(x, z: int):
            for i in range(z):
                x = x.cpu()
                for _ in range(i):
                    x = x + 1

            return x

        self.assert_device_equal(test_fn, [self.cpu, None], self.cpu)
        self.assert_device_equal(test_fn, [self.cuda, None], None)
        self.assert_device_equal(test_fn, [None, None], None)

    def test_if_loop_mix(self):
        def test_fn(x, y, z: bool, a: bool):
            c = x
            while a:
                if z:
                    c = x + 3
                else:
                    c = y * 2
                a = False
            return c

        self.assert_device_equal(test_fn, [self.cpu, self.cpu, None, None], self.cpu)
        self.assert_device_equal(
            test_fn, [self.mkldnn, self.mkldnn, None, None], self.mkldnn
        )
        self.assert_device_equal(test_fn, [self.cpu, self.cuda, None, None], None)
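    # An illustrative addition (not from the upstream suite): the same merge
    # rule exercised by test_device_if_propagation, checked on a conditional
    # expression instead of an if-statement (TorchScript compiles both to
    # prim::If).
    def test_if_expression_sketch(self):
        def pick(x, y, z: bool):
            return x if z else y

        self.assert_device_equal(pick, [self.cpu, self.cpu, None], self.cpu)
        self.assert_device_equal(pick, [self.cpu, self.cuda, None], None)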