# Owner(s): ["module: cpp-extensions"]

import _codecs
import os
import shutil
import sys
import tempfile
import types
import unittest
from typing import Union
from unittest.mock import patch

import numpy as np

import torch
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.serialization import safe_globals
from torch.testing._internal.common_utils import (
    IS_ARM64,
    skipIfTorchDynamo,
    TemporaryFileName,
    TEST_CUDA,
    TEST_XPU,
)
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME


TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None


def remove_build_path():
    if sys.platform == "win32":
        # Not wiping the extensions build folder on Windows, since open file
        # handles can make the removal fail.
        return
    default_build_root = torch.utils.cpp_extension.get_default_build_root()
    if os.path.exists(default_build_root):
        shutil.rmtree(default_build_root, ignore_errors=True)


def generate_faked_module():
    def device_count() -> int:
        return 1

    def get_rng_state(device: Union[int, str, torch.device] = "foo") -> torch.Tensor:
        # create a tensor using our custom device object.
        return torch.empty(4, 4, device="foo")

    def set_rng_state(
        new_state: torch.Tensor, device: Union[int, str, torch.device] = "foo"
    ) -> None:
        pass

    def is_available():
        return True

    def current_device():
        return 0

    # create a new module to fake torch.foo dynamically
    foo = types.ModuleType("foo")
    foo.device_count = device_count
    foo.get_rng_state = get_rng_state
    foo.set_rng_state = set_rng_state
    foo.is_available = is_available
    foo.current_device = current_device
    foo._lazy_init = lambda: None
    foo.is_initialized = lambda: True
    return foo
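
# Illustrative sketch (never invoked by this suite): the minimal sequence that
# turns the fake module above into a usable "foo" backend. setUpClass below
# performs the same steps against the compiled C++ extension; creating a tensor
# on "foo" additionally assumes the extension's PrivateUse1 allocator is loaded.
def _example_register_foo_backend():
    torch.utils.rename_privateuse1_backend("foo")
    torch.utils.generate_methods_for_privateuse1_backend(for_storage=True)
    torch._register_device_module("foo", generate_faked_module())
    # After registration, the generated helpers exist:
    # Tensor.is_foo, Tensor.foo(), torch.foo.device_count(), etc.
    return torch.empty(2, 2, device="foo")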
@unittest.skipIf(IS_ARM64, "Does not work on arm")
@unittest.skipIf(TEST_XPU, "XPU does not support cpp_extension currently")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionOpenRegistration(common.TestCase):
    """Tests Open Device Registration with C++ extensions."""

    module = None

    def setUp(self):
        super().setUp()
        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily.
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))
        assert self.module is not None

    def tearDown(self):
        super().tearDown()
        # return to the original working directory (see setUp)
        os.chdir(self.old_working_dir)

    @classmethod
    def setUpClass(cls):
        remove_build_path()
        cls.module = torch.utils.cpp_extension.load(
            name="custom_device_extension",
            sources=[
                "cpp_extensions/open_registration_extension.cpp",
            ],
            extra_include_paths=["cpp_extensions"],
            extra_cflags=["-g"],
            verbose=True,
        )
        # register the torch.foo module and the foo device to torch
        torch.utils.rename_privateuse1_backend("foo")
        torch.utils.generate_methods_for_privateuse1_backend(for_storage=True)
        torch._register_device_module("foo", generate_faked_module())

    def test_base_device_registration(self):
        self.assertFalse(self.module.custom_add_called())
        # create tensors using our custom device object
        device = self.module.custom_device()
        x = torch.empty(4, 4, device=device)
        y = torch.empty(4, 4, device=device)
        # check that our device is correct
        self.assertTrue(x.device == device)
        self.assertFalse(x.is_cpu)
        self.assertFalse(self.module.custom_add_called())
        # calls our custom add kernel, registered to the dispatcher
        z = x + y
        # check that it was called
        self.assertTrue(self.module.custom_add_called())
        z_cpu = z.to(device="cpu")
        # check that our cross-device copy correctly copied the data to cpu
        self.assertTrue(z_cpu.is_cpu)
        self.assertFalse(z.is_cpu)
        self.assertTrue(z.device == device)
        self.assertEqual(z, z_cpu)

    def test_common_registration(self):
        # check unsupported device and duplicated registration
        with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
            torch._register_device_module("dev", generate_faked_module())
        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
            torch._register_device_module("foo", generate_faked_module())

        # the backend name can be renamed to the same name multiple times
        torch.utils.rename_privateuse1_backend("foo")

        # the backend name can't be renamed multiple times to different names
        with self.assertRaisesRegex(
            RuntimeError, "torch.register_privateuse1_backend()"
        ):
            torch.utils.rename_privateuse1_backend("dev")

        # the tensor/storage/module helper methods can be generated only once
        with self.assertRaisesRegex(RuntimeError, "The custom device module of"):
            torch.utils.generate_methods_for_privateuse1_backend()

        # check whether torch.foo has been registered correctly
        self.assertTrue(
            torch.utils.backend_registration._get_custom_mod_func("device_count")() == 1
        )
        with self.assertRaisesRegex(RuntimeError, "Try to call torch.foo"):
            torch.utils.backend_registration._get_custom_mod_func("func_name_")

        # check attributes after registration
        self.assertTrue(hasattr(torch.Tensor, "is_foo"))
        self.assertTrue(hasattr(torch.Tensor, "foo"))
        self.assertTrue(hasattr(torch.TypedStorage, "is_foo"))
        self.assertTrue(hasattr(torch.TypedStorage, "foo"))
        self.assertTrue(hasattr(torch.UntypedStorage, "is_foo"))
        self.assertTrue(hasattr(torch.UntypedStorage, "foo"))
        self.assertTrue(hasattr(torch.nn.Module, "foo"))
        self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "is_foo"))
        self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "foo"))

    def test_open_device_generator_registration_and_hooks(self):
        device = self.module.custom_device()
        # None of our CPU operations should call the custom add function.
        self.assertFalse(self.module.custom_add_called())

        # check that a generator must be registered before use
        with self.assertRaisesRegex(
            RuntimeError,
            "Please register a generator to the PrivateUse1 dispatch key",
        ):
            torch.Generator(device=device)

        self.module.register_generator_first()
        gen = torch.Generator(device=device)
        self.assertTrue(gen.device == device)

        # a generator can be registered only once
        with self.assertRaisesRegex(
            RuntimeError,
            "Only can register a generator to the PrivateUse1 dispatch key once",
        ):
            self.module.register_generator_second()

        if self.module.is_register_hook() is False:
            self.module.register_hook()
        default_gen = self.module.default_generator(0)
        self.assertTrue(
            default_gen.device.type == torch._C._get_privateuse1_backend_name()
        )
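
    # Illustrative usage (a sketch; not collected by unittest since the name
    # lacks the "test" prefix): once a generator is registered for PrivateUse1,
    # the stock torch.Generator API works against the renamed device.
    def _example_generator_usage(self):
        gen = torch.Generator(device=self.module.custom_device())
        gen.manual_seed(42)
        return gen.initial_seed()  # 42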
    def test_open_device_dispatchstub(self):
        # test that CPU kernels can be reused by the privateuse1 backend
        # through DispatchStub
        input_data = torch.randn(2, 2, 3, dtype=torch.float32, device="cpu")
        foo_input_data = input_data.to("foo")
        output_data = torch.abs(input_data)
        foo_output_data = torch.abs(foo_input_data)
        self.assertEqual(output_data, foo_output_data.cpu())

        output_data = torch.randn(2, 2, 6, dtype=torch.float32, device="cpu")
        foo_input_data = input_data.to("foo")
        foo_output_data = output_data.to("foo")
        # out= with a matching-shape strided view: the output operand's
        # will_resize flag is False in TensorIterator.
        torch.abs(input_data, out=output_data[:, :, 0:6:2])
        torch.abs(foo_input_data, out=foo_output_data[:, :, 0:6:2])
        self.assertEqual(output_data, foo_output_data.cpu())

        # out= with a mismatched shape: the output operand's will_resize flag
        # is True, and TensorIterator converts the output to a contiguous
        # tensor.
        output_data = torch.randn(2, 2, 6, dtype=torch.float32, device="cpu")
        foo_input_data = input_data.to("foo")
        foo_output_data = output_data.to("foo")
        torch.abs(input_data, out=output_data[:, :, 0:6:3])
        torch.abs(foo_input_data, out=foo_output_data[:, :, 0:6:3])
        self.assertEqual(output_data, foo_output_data.cpu())

    def test_open_device_quantized(self):
        input_data = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu").to("foo")
        quantized_tensor = torch.quantize_per_tensor(input_data, 0.1, 10, torch.qint8)
        self.assertEqual(quantized_tensor.device, torch.device("foo:0"))
        self.assertEqual(quantized_tensor.dtype, torch.qint8)

    def test_open_device_random(self):
        # check that torch.foo has implemented get_rng_state
        with torch.random.fork_rng(device_type="foo"):
            pass

    def test_open_device_tensor(self):
        device = self.module.custom_device()
        # check whether tensor.type() prints as expected
        dtypes = {
            torch.bool: "torch.foo.BoolTensor",
            torch.double: "torch.foo.DoubleTensor",
            torch.float32: "torch.foo.FloatTensor",
            torch.half: "torch.foo.HalfTensor",
            torch.int32: "torch.foo.IntTensor",
            torch.int64: "torch.foo.LongTensor",
            torch.int8: "torch.foo.CharTensor",
            torch.short: "torch.foo.ShortTensor",
            torch.uint8: "torch.foo.ByteTensor",
        }
        for tt, dt in dtypes.items():
            test_tensor = torch.empty(4, 4, dtype=tt, device=device)
            self.assertTrue(test_tensor.type() == dt)

        # check whether the attributes and methods of the corresponding custom
        # backend are generated correctly
        x = torch.empty(4, 4)
        self.assertFalse(x.is_foo)
        x = x.foo(torch.device("foo"))
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(x.is_foo)

        # test a different device-type input
        y = torch.empty(4, 4)
        self.assertFalse(y.is_foo)
        y = y.foo(torch.device("foo:0"))
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(y.is_foo)

        # test an integer device-index input
        z = torch.empty(4, 4)
        self.assertFalse(z.is_foo)
        z = z.foo(0)
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z.is_foo)

    def test_open_device_packed_sequence(self):
        device = self.module.custom_device()
        a = torch.rand(5, 3)
        b = torch.tensor([1, 1, 1, 1, 1])
        input = torch.nn.utils.rnn.PackedSequence(a, b)
        self.assertFalse(input.is_foo)
        input_foo = input.foo()
        self.assertTrue(input_foo.is_foo)
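
    # The storage tests below rely on generate_methods_for_privateuse1_backend
    # being called with for_storage=True (see setUpClass): that is what attaches
    # the generated `is_foo` attribute and `foo()` method to TypedStorage and
    # UntypedStorage, mirroring the Tensor helpers checked above.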
    def test_open_device_storage(self):
        # check whether the attributes and methods for storage of the
        # corresponding custom backend are generated correctly
        x = torch.empty(4, 4)
        z1 = x.storage()
        self.assertFalse(z1.is_foo)

        z1 = z1.foo()
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z1.is_foo)

        with self.assertRaisesRegex(RuntimeError, "Invalid device"):
            z1.foo(torch.device("cpu"))

        z1 = z1.cpu()
        self.assertFalse(self.module.custom_add_called())
        self.assertFalse(z1.is_foo)

        z1 = z1.foo(device="foo:0", non_blocking=False)
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z1.is_foo)

        with self.assertRaisesRegex(RuntimeError, "Invalid device"):
            z1.foo(device="cuda:0", non_blocking=False)

        # check UntypedStorage
        y = torch.empty(4, 4)
        z2 = y.untyped_storage()
        self.assertFalse(z2.is_foo)

        z2 = z2.foo()
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z2.is_foo)

        # check custom StorageImpl creation
        self.module.custom_storage_registry()

        z3 = y.untyped_storage()
        self.assertFalse(self.module.custom_storageImpl_called())

        z3 = z3.foo()
        self.assertTrue(self.module.custom_storageImpl_called())
        # custom_storageImpl_called() clears its flag when queried, so the
        # immediate re-query returns False; slicing triggers it again.
        self.assertFalse(self.module.custom_storageImpl_called())

        z3 = z3[0:3]
        self.assertTrue(self.module.custom_storageImpl_called())

    @skipIfTorchDynamo("unsupported aten.is_pinned.default")
    def test_open_device_storage_pin_memory(self):
        # check that pin_memory functions properly on the custom device
        cpu_tensor = torch.empty(3)
        self.assertFalse(cpu_tensor.is_foo)
        self.assertFalse(cpu_tensor.is_pinned("foo"))

        cpu_tensor_pin = cpu_tensor.pin_memory("foo")
        self.assertTrue(cpu_tensor_pin.is_pinned("foo"))

        # test storage pin_memory and is_pinned
        cpu_storage = cpu_tensor.storage()
        # We implement a dummy pin_memory with no practical significance for
        # the custom device. Once tensor.pin_memory() has been called,
        # tensor.is_pinned() will always return True no matter what tensor
        # it is called on.
        self.assertTrue(cpu_storage.is_pinned("foo"))

        cpu_storage_pinned = cpu_storage.pin_memory("foo")
        self.assertTrue(cpu_storage_pinned.is_pinned("foo"))

        # test untyped storage pin_memory and is_pinned
        cpu_tensor = torch.randn([3, 2, 1, 4])
        cpu_untyped_storage = cpu_tensor.untyped_storage()
        self.assertTrue(cpu_untyped_storage.is_pinned("foo"))

        cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("foo")
        self.assertTrue(cpu_untyped_storage_pinned.is_pinned("foo"))

    @unittest.skip(
        "Temporarily disabled due to subtle differences between clang++ and g++ "
        "in defining static variables in inline functions"
    )
    def test_open_device_serialization(self):
        self.module.set_custom_device_index(-1)
        storage = torch.UntypedStorage(4, device=torch.device("foo"))
        self.assertEqual(torch.serialization.location_tag(storage), "foo")

        self.module.set_custom_device_index(0)
        storage = torch.UntypedStorage(4, device=torch.device("foo"))
        self.assertEqual(torch.serialization.location_tag(storage), "foo:0")

        cpu_storage = torch.empty(4, 4).storage()
        foo_storage = torch.serialization.default_restore_location(cpu_storage, "foo:0")
        self.assertTrue(foo_storage.is_foo)

        # test tensor metadata serialization
        x = torch.empty(4, 4).long()
        y = x.foo()
        self.assertFalse(self.module.check_backend_meta(y))
        self.module.custom_set_backend_meta(y)
        self.assertTrue(self.module.check_backend_meta(y))

        self.module.custom_serialization_registry()
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "data.pt")
            torch.save(y, path)
            z1 = torch.load(path)
            # loads correctly onto the foo backend device
            self.assertTrue(z1.is_foo)
            # loads BackendMeta data correctly
            self.assertTrue(self.module.check_backend_meta(z1))

            # cross-backend load
            z2 = torch.load(path, map_location="cpu")
            # loads correctly onto the cpu backend device
            self.assertFalse(z2.is_foo)
            # BackendMeta data is not attached to the cpu tensor
            self.assertFalse(self.module.check_backend_meta(z2))
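
    # A minimal sketch of the restore step that custom backends lean on, using
    # only public torch.serialization APIs (the "foo:0" tag assumes the backend
    # registered in setUpClass). default_restore_location consults the
    # deserializers registered via torch.serialization.register_package.
    def _example_restore_location(self):
        cpu_storage = torch.empty(4, 4).storage()
        return torch.serialization.default_restore_location(cpu_storage, "foo:0")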
    def test_open_device_storage_resize(self):
        cpu_tensor = torch.randn([8])
        foo_tensor = cpu_tensor.foo()
        foo_storage = foo_tensor.storage()
        self.assertTrue(foo_storage.size() == 8)

        # Only the tensor resize_ function is registered.
        foo_tensor.resize_(8)
        self.assertTrue(foo_storage.size() == 8)

        with self.assertRaisesRegex(TypeError, "Overflow"):
            foo_tensor.resize_(8**29)

    def test_open_device_storage_type(self):
        # test cpu float storage
        cpu_tensor = torch.randn([8]).float()
        cpu_storage = cpu_tensor.storage()
        self.assertEqual(cpu_storage.type(), "torch.FloatStorage")

        # test custom float storage before defining FloatStorage
        foo_tensor = cpu_tensor.foo()
        foo_storage = foo_tensor.storage()
        self.assertEqual(foo_storage.type(), "torch.storage.TypedStorage")

        class CustomFloatStorage:
            @property
            def __module__(self):
                return "torch." + torch._C._get_privateuse1_backend_name()

            @property
            def __name__(self):
                return "FloatStorage"

        # test custom float storage after defining FloatStorage
        try:
            torch.foo.FloatStorage = CustomFloatStorage()
            self.assertEqual(foo_storage.type(), "torch.foo.FloatStorage")

            # test custom int storage after defining FloatStorage
            foo_tensor2 = torch.randn([8]).int().foo()
            foo_storage2 = foo_tensor2.storage()
            self.assertEqual(foo_storage2.type(), "torch.storage.TypedStorage")
        finally:
            torch.foo.FloatStorage = None

    def test_open_device_faketensor(self):
        with torch._subclasses.fake_tensor.FakeTensorMode.push():
            a = torch.empty(1, device="foo")
            b = torch.empty(1, device="foo:0")
            result = a + b

    def test_open_device_named_tensor(self):
        torch.empty([2, 3, 4, 5], device="foo", names=["N", "C", "H", "W"])

    # Not an open registration test - this file is just very convenient
    # for testing torch.compile on custom C++ operators
    def test_compile_autograd_function_returns_self(self):
        x_ref = torch.randn(4, requires_grad=True)
        out_ref = self.module.custom_autograd_fn_returns_self(x_ref)
        out_ref.sum().backward()

        x_test = x_ref.clone().detach().requires_grad_(True)
        f_compiled = torch.compile(self.module.custom_autograd_fn_returns_self)
        out_test = f_compiled(x_test)
        out_test.sum().backward()

        self.assertEqual(out_ref, out_test)
        self.assertEqual(x_ref.grad, x_test.grad)

    # Not an open registration test - this file is just very convenient
    # for testing torch.compile on custom C++ operators
    @skipIfTorchDynamo("Temporarily disabled due to torch._ops.OpOverloadPacket")
    def test_compile_autograd_function_aliasing(self):
        x_ref = torch.randn(4, requires_grad=True)
        out_ref = torch.ops._test_funcs.custom_autograd_fn_aliasing(x_ref)
        out_ref.sum().backward()

        x_test = x_ref.clone().detach().requires_grad_(True)
        f_compiled = torch.compile(torch.ops._test_funcs.custom_autograd_fn_aliasing)
        out_test = f_compiled(x_test)
        out_test.sum().backward()

        self.assertEqual(out_ref, out_test)
        self.assertEqual(x_ref.grad, x_test.grad)

    def test_open_device_scalar_type_fallback(self):
        z_cpu = torch.Tensor([[0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]]).to(torch.int64)
        z = torch.triu_indices(3, 3, device="foo")
        self.assertEqual(z_cpu, z)
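
    # Background for the fallback tests below (a summary, not extension-specific
    # code): when an op has no PrivateUse1 kernel, a registered boxed CPU
    # fallback can copy the inputs to CPU, run the CPU kernel, and copy the
    # results back to the custom device, which is why the outputs compare equal
    # to their CPU-computed counterparts.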
    def test_open_device_tensor_type_fallback(self):
        # create tensors located on the custom device
        x = torch.Tensor([[1, 2, 3], [2, 3, 4]]).to("foo")
        y = torch.Tensor([1, 0, 2]).to("foo")
        # create the expected result tensor located on cpu
        z_cpu = torch.Tensor([[0, 2, 1], [1, 3, 2]])
        # check that our device is correct
        device = self.module.custom_device()
        self.assertTrue(x.device == device)
        self.assertFalse(x.is_cpu)

        # call the sub op, which will fall back to cpu
        z = torch.sub(x, y)
        self.assertEqual(z_cpu, z)

        # call the index op, which will fall back to cpu
        z_cpu = torch.Tensor([3, 1])
        y = torch.Tensor([1, 0]).long().to("foo")
        z = x[y, y]
        self.assertEqual(z_cpu, z)

    def test_open_device_tensorlist_type_fallback(self):
        # create tensors located on the custom device
        v_foo = torch.Tensor([1, 2, 3]).to("foo")
        # create the expected result tensor located on cpu
        z_cpu = torch.Tensor([2, 4, 6])
        # create tensorlists for the foreach_add op
        x = (v_foo, v_foo)
        y = (v_foo, v_foo)
        # check that our device is correct
        device = self.module.custom_device()
        self.assertTrue(v_foo.device == device)
        self.assertFalse(v_foo.is_cpu)

        # call the _foreach_add op, which will fall back to cpu
        z = torch._foreach_add(x, y)
        self.assertEqual(z_cpu, z[0])
        self.assertEqual(z_cpu, z[1])

        # call _fused_adamw_ with an undefined tensor
        self.module.fallback_with_undefined_tensor()

    def test_open_device_numpy_serialization(self):
        torch.utils.rename_privateuse1_backend("foo")
        device = self.module.custom_device()
        default_protocol = torch.serialization.DEFAULT_PROTOCOL
        # This is a hack to test serialization through numpy
        with patch.object(torch._C, "_has_storage", return_value=False):
            x = torch.randn(2, 3)
            x_foo = x.to(device)
            sd = {"x": x_foo}
            rebuild_func = x_foo._reduce_ex_internal(default_protocol)[0]
            self.assertTrue(
                rebuild_func is torch._utils._rebuild_device_tensor_from_numpy
            )
            # test map_location
            with TemporaryFileName() as f:
                torch.save(sd, f)
                with safe_globals(
                    [
                        np.core.multiarray._reconstruct,
                        np.ndarray,
                        np.dtype,
                        _codecs.encode,
                        type(np.dtype(np.float32))
                        if np.__version__ < "1.25.0"
                        else np.dtypes.Float32DType,
                    ]
                ):
                    sd_loaded = torch.load(f, map_location="cpu")
                self.assertTrue(sd_loaded["x"].is_cpu)

            # test metadata_only
            with TemporaryFileName() as f:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "Cannot serialize tensors on backends with no storage under skip_data context manager",
                ):
                    with torch.serialization.skip_data():
                        torch.save(sd, f)


if __name__ == "__main__":
    common.run_tests()