# Owner(s): ["module: cpp-extensions"]

import _codecs
import os
import shutil
import sys
import tempfile
import types
import unittest
from typing import Union
from unittest.mock import patch

import numpy as np

import torch
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.serialization import safe_globals
from torch.testing._internal.common_utils import (
    IS_ARM64,
    skipIfTorchDynamo,
    TemporaryFileName,
    TEST_CUDA,
    TEST_XPU,
)
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME


TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None


def remove_build_path():
    if sys.platform == "win32":
        # Don't wipe the extensions build folder on Windows
        return
    default_build_root = torch.utils.cpp_extension.get_default_build_root()
    if os.path.exists(default_build_root):
        shutil.rmtree(default_build_root, ignore_errors=True)


def generate_faked_module():
    def device_count() -> int:
        return 1

    def get_rng_state(device: Union[int, str, torch.device] = "foo") -> torch.Tensor:
        # create a tensor using our custom device object.
        return torch.empty(4, 4, device="foo")

    def set_rng_state(
        new_state: torch.Tensor, device: Union[int, str, torch.device] = "foo"
    ) -> None:
        pass

    def is_available():
        return True

    def current_device():
        return 0

    # create a new module to fake torch.foo dynamically
    foo = types.ModuleType("foo")

    foo.device_count = device_count
    foo.get_rng_state = get_rng_state
    foo.set_rng_state = set_rng_state
    foo.is_available = is_available
    foo.current_device = current_device
    foo._lazy_init = lambda: None
    foo.is_initialized = lambda: True

    return foo


@unittest.skipIf(IS_ARM64, "Does not work on arm")
@unittest.skipIf(TEST_XPU, "XPU does not support cpp_extension currently")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionOpenRegistration(common.TestCase):
    """Tests Open Device Registration with C++ extensions."""

    module = None

    def setUp(self):
        super().setUp()

        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

        assert self.module is not None

    def tearDown(self):
        super().tearDown()

        # return to the original working directory (see setUp)
        os.chdir(self.old_working_dir)

    @classmethod
    def setUpClass(cls):
        remove_build_path()

        cls.module = torch.utils.cpp_extension.load(
            name="custom_device_extension",
            sources=[
                "cpp_extensions/open_registration_extension.cpp",
            ],
            extra_include_paths=["cpp_extensions"],
            extra_cflags=["-g"],
            verbose=True,
        )

        # register the torch.foo module and the foo device with torch
        torch.utils.rename_privateuse1_backend("foo")
        torch.utils.generate_methods_for_privateuse1_backend(for_storage=True)
        torch._register_device_module("foo", generate_faked_module())

    def test_base_device_registration(self):
        self.assertFalse(self.module.custom_add_called())
        # create a tensor using our custom device object
        device = self.module.custom_device()
        x = torch.empty(4, 4, device=device)
        y = torch.empty(4, 4, device=device)
        # Check that our device is correct.
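        # Creating tensors on the custom device only exercises the extension's
        # allocator; the custom add kernel must not have been called yet.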
        self.assertTrue(x.device == device)
        self.assertFalse(x.is_cpu)
        self.assertFalse(self.module.custom_add_called())
        # calls our custom add kernel, registered to the dispatcher
        z = x + y
        # check that it was called
        self.assertTrue(self.module.custom_add_called())
        z_cpu = z.to(device="cpu")
        # Check that our cross-device copy correctly copied the data to cpu
        self.assertTrue(z_cpu.is_cpu)
        self.assertFalse(z.is_cpu)
        self.assertTrue(z.device == device)
        self.assertEqual(z, z_cpu)

    def test_common_registration(self):
        # check unsupported device and duplicated registration
        with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
            torch._register_device_module("dev", generate_faked_module())
        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
            torch._register_device_module("foo", generate_faked_module())

        # the backend name can be renamed to the same name multiple times
        torch.utils.rename_privateuse1_backend("foo")

        # the backend name can't be renamed multiple times to different names.
        with self.assertRaisesRegex(
            RuntimeError, "torch.register_privateuse1_backend()"
        ):
            torch.utils.rename_privateuse1_backend("dev")

        # generator, tensor and module can be registered only once
        with self.assertRaisesRegex(RuntimeError, "The custom device module of"):
            torch.utils.generate_methods_for_privateuse1_backend()

        # check whether torch.foo has been registered correctly
        self.assertTrue(
            torch.utils.backend_registration._get_custom_mod_func("device_count")() == 1
        )
        with self.assertRaisesRegex(RuntimeError, "Try to call torch.foo"):
            torch.utils.backend_registration._get_custom_mod_func("func_name_")

        # check the attributes after registration
        self.assertTrue(hasattr(torch.Tensor, "is_foo"))
        self.assertTrue(hasattr(torch.Tensor, "foo"))
        self.assertTrue(hasattr(torch.TypedStorage, "is_foo"))
        self.assertTrue(hasattr(torch.TypedStorage, "foo"))
        self.assertTrue(hasattr(torch.UntypedStorage, "is_foo"))
        self.assertTrue(hasattr(torch.UntypedStorage, "foo"))
        self.assertTrue(hasattr(torch.nn.Module, "foo"))
        self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "is_foo"))
        self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "foo"))

    def test_open_device_generator_registration_and_hooks(self):
        device = self.module.custom_device()
        # None of our CPU operations should call the custom add function.
        self.assertFalse(self.module.custom_add_called())

        # check that the generator is registered before use
        with self.assertRaisesRegex(
            RuntimeError,
            "Please register a generator to the PrivateUse1 dispatch key",
        ):
            torch.Generator(device=device)

        self.module.register_generator_first()
        gen = torch.Generator(device=device)
        self.assertTrue(gen.device == device)

        # the generator can be registered only once
        with self.assertRaisesRegex(
            RuntimeError,
            "Only can register a generator to the PrivateUse1 dispatch key once",
        ):
            self.module.register_generator_second()

        if self.module.is_register_hook() is False:
            self.module.register_hook()
        default_gen = self.module.default_generator(0)
        self.assertTrue(
            default_gen.device.type == torch._C._get_privateuse1_backend_name()
        )

    def test_open_device_dispatchstub(self):
        # test that CPU kernels can be reused by the privateuse1 backend through DispatchStub
        input_data = torch.randn(2, 2, 3, dtype=torch.float32, device="cpu")
        foo_input_data = input_data.to("foo")
        output_data = torch.abs(input_data)
        foo_output_data = torch.abs(foo_input_data)
        self.assertEqual(output_data, foo_output_data.cpu())

        output_data = torch.randn(2, 2, 6, dtype=torch.float32, device="cpu")
        # the output operand's will_resize flag is True in TensorIterator.
        foo_input_data = input_data.to("foo")
        foo_output_data = output_data.to("foo")
        # the output operand's will_resize flag is False in TensorIterator.
        torch.abs(input_data, out=output_data[:, :, 0:6:2])
        torch.abs(foo_input_data, out=foo_output_data[:, :, 0:6:2])
        self.assertEqual(output_data, foo_output_data.cpu())

        # The output operand's will_resize flag is True in TensorIterator,
        # and the output is converted to a contiguous tensor inside TensorIterator.
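        # A step of 3 over the size-6 last dimension keeps only two elements, so the
        # sliced output no longer matches the input shape and is non-contiguous;
        # TensorIterator therefore resizes it and writes the result through a
        # contiguous temporary.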
        output_data = torch.randn(2, 2, 6, dtype=torch.float32, device="cpu")
        foo_input_data = input_data.to("foo")
        foo_output_data = output_data.to("foo")
        torch.abs(input_data, out=output_data[:, :, 0:6:3])
        torch.abs(foo_input_data, out=foo_output_data[:, :, 0:6:3])
        self.assertEqual(output_data, foo_output_data.cpu())

    def test_open_device_quantized(self):
        input_data = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu").to("foo")
        quantized_tensor = torch.quantize_per_tensor(input_data, 0.1, 10, torch.qint8)
        self.assertEqual(quantized_tensor.device, torch.device("foo:0"))
        self.assertEqual(quantized_tensor.dtype, torch.qint8)

    def test_open_device_random(self):
        # check that torch.foo has implemented get_rng_state
        with torch.random.fork_rng(device_type="foo"):
            pass

    def test_open_device_tensor(self):
        device = self.module.custom_device()

        # check whether printing tensor.type() meets the expectation
        dtypes = {
            torch.bool: "torch.foo.BoolTensor",
            torch.double: "torch.foo.DoubleTensor",
            torch.float32: "torch.foo.FloatTensor",
            torch.half: "torch.foo.HalfTensor",
            torch.int32: "torch.foo.IntTensor",
            torch.int64: "torch.foo.LongTensor",
            torch.int8: "torch.foo.CharTensor",
            torch.short: "torch.foo.ShortTensor",
            torch.uint8: "torch.foo.ByteTensor",
        }
        for tt, dt in dtypes.items():
            test_tensor = torch.empty(4, 4, dtype=tt, device=device)
            self.assertTrue(test_tensor.type() == dt)

        # check whether the attributes and methods of the corresponding custom backend are generated correctly
        x = torch.empty(4, 4)
        self.assertFalse(x.is_foo)

        x = x.foo(torch.device("foo"))
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(x.is_foo)

        # test a different type of device input
        y = torch.empty(4, 4)
        self.assertFalse(y.is_foo)

        y = y.foo(torch.device("foo:0"))
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(y.is_foo)

        # test a different type of device input
        z = torch.empty(4, 4)
        self.assertFalse(z.is_foo)

        z = z.foo(0)
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z.is_foo)

    def test_open_device_packed_sequence(self):
        device = self.module.custom_device()
        a = torch.rand(5, 3)
        b = torch.tensor([1, 1, 1, 1, 1])
        input = torch.nn.utils.rnn.PackedSequence(a, b)
        self.assertFalse(input.is_foo)
        input_foo = input.foo()
        self.assertTrue(input_foo.is_foo)

    def test_open_device_storage(self):
        # check whether the attributes and methods for storage of the corresponding custom backend are generated correctly
        x = torch.empty(4, 4)
        z1 = x.storage()
        self.assertFalse(z1.is_foo)

        z1 = z1.foo()
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z1.is_foo)

        with self.assertRaisesRegex(RuntimeError, "Invalid device"):
            z1.foo(torch.device("cpu"))

        z1 = z1.cpu()
        self.assertFalse(self.module.custom_add_called())
        self.assertFalse(z1.is_foo)

        z1 = z1.foo(device="foo:0", non_blocking=False)
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z1.is_foo)

        with self.assertRaisesRegex(RuntimeError, "Invalid device"):
            z1.foo(device="cuda:0", non_blocking=False)

        # check UntypedStorage
        y = torch.empty(4, 4)
        z2 = y.untyped_storage()
        self.assertFalse(z2.is_foo)

        z2 = z2.foo()
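        # Moving the untyped storage to the custom backend should not route
        # through the custom add kernel.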
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z2.is_foo)

        # check custom StorageImpl creation
        self.module.custom_storage_registry()

        z3 = y.untyped_storage()
        self.assertFalse(self.module.custom_storageImpl_called())

        z3 = z3.foo()
        self.assertTrue(self.module.custom_storageImpl_called())
        self.assertFalse(self.module.custom_storageImpl_called())

        z3 = z3[0:3]
        self.assertTrue(self.module.custom_storageImpl_called())

    @skipIfTorchDynamo("unsupported aten.is_pinned.default")
    def test_open_device_storage_pin_memory(self):
        # Check that pin_memory functions properly on the custom device
        cpu_tensor = torch.empty(3)
        self.assertFalse(cpu_tensor.is_foo)
        self.assertFalse(cpu_tensor.is_pinned("foo"))

        cpu_tensor_pin = cpu_tensor.pin_memory("foo")
        self.assertTrue(cpu_tensor_pin.is_pinned("foo"))

        # Test storage pin_memory and is_pinned
        cpu_storage = cpu_tensor.storage()
        # We implement a dummy pin_memory of no practical significance
        # for the custom device. Once tensor.pin_memory() has been called,
        # tensor.is_pinned() will always return True no matter
        # what tensor it's called on.
        self.assertTrue(cpu_storage.is_pinned("foo"))

        cpu_storage_pinned = cpu_storage.pin_memory("foo")
        self.assertTrue(cpu_storage_pinned.is_pinned("foo"))

        # Test untyped storage pin_memory and is_pinned
        cpu_tensor = torch.randn([3, 2, 1, 4])
        cpu_untyped_storage = cpu_tensor.untyped_storage()
        self.assertTrue(cpu_untyped_storage.is_pinned("foo"))

        cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("foo")
        self.assertTrue(cpu_untyped_storage_pinned.is_pinned("foo"))

    @unittest.skip(
        "Temporarily disabled due to the tiny differences between clang++ and g++ in defining static variables in inline functions"
    )
    def test_open_device_serialization(self):
        self.module.set_custom_device_index(-1)
        storage = torch.UntypedStorage(4, device=torch.device("foo"))
        self.assertEqual(torch.serialization.location_tag(storage), "foo")

        self.module.set_custom_device_index(0)
        storage = torch.UntypedStorage(4, device=torch.device("foo"))
        self.assertEqual(torch.serialization.location_tag(storage), "foo:0")

        cpu_storage = torch.empty(4, 4).storage()
        foo_storage = torch.serialization.default_restore_location(cpu_storage, "foo:0")
        self.assertTrue(foo_storage.is_foo)

        # test tensor MetaData serialization
        x = torch.empty(4, 4).long()
        y = x.foo()
        self.assertFalse(self.module.check_backend_meta(y))
        self.module.custom_set_backend_meta(y)
        self.assertTrue(self.module.check_backend_meta(y))

        self.module.custom_serialization_registry()
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "data.pt")
            torch.save(y, path)
            z1 = torch.load(path)
            # loads correctly onto the foo backend device
            self.assertTrue(z1.is_foo)
            # loads BackendMeta data correctly
            self.assertTrue(self.module.check_backend_meta(z1))

            # cross-backend
            z2 = torch.load(path, map_location="cpu")
            # loads correctly onto the cpu backend device
            self.assertFalse(z2.is_foo)
            # loads BackendMeta data correctly
            self.assertFalse(self.module.check_backend_meta(z2))

    def test_open_device_storage_resize(self):
        cpu_tensor = torch.randn([8])
        foo_tensor = cpu_tensor.foo()
        foo_storage = foo_tensor.storage()
        self.assertTrue(foo_storage.size() == 8)

        # Only the tensor-level resize_ function is registered for this backend.
        foo_tensor.resize_(8)
        self.assertTrue(foo_storage.size() == 8)

        with self.assertRaisesRegex(TypeError, "Overflow"):
            foo_tensor.resize_(8**29)

    def test_open_device_storage_type(self):
        # test cpu float storage
        cpu_tensor = torch.randn([8]).float()
        cpu_storage = cpu_tensor.storage()
        self.assertEqual(cpu_storage.type(), "torch.FloatStorage")

        # test custom float storage before defining FloatStorage
        foo_tensor = cpu_tensor.foo()
        foo_storage = foo_tensor.storage()
        self.assertEqual(foo_storage.type(), "torch.storage.TypedStorage")

        class CustomFloatStorage:
            @property
            def __module__(self):
                return "torch." + torch._C._get_privateuse1_backend_name()

            @property
            def __name__(self):
                return "FloatStorage"

        # test custom float storage after defining FloatStorage
        try:
            torch.foo.FloatStorage = CustomFloatStorage()
            self.assertEqual(foo_storage.type(), "torch.foo.FloatStorage")

            # test custom int storage after defining FloatStorage
            foo_tensor2 = torch.randn([8]).int().foo()
            foo_storage2 = foo_tensor2.storage()
            self.assertEqual(foo_storage2.type(), "torch.storage.TypedStorage")
        finally:
            torch.foo.FloatStorage = None

    def test_open_device_faketensor(self):
        with torch._subclasses.fake_tensor.FakeTensorMode.push():
            a = torch.empty(1, device="foo")
            b = torch.empty(1, device="foo:0")
            result = a + b

    def test_open_device_named_tensor(self):
        torch.empty([2, 3, 4, 5], device="foo", names=["N", "C", "H", "W"])

    # Not an open registration test - this file is just very convenient
    # for testing torch.compile on custom C++ operators
    def test_compile_autograd_function_returns_self(self):
        x_ref = torch.randn(4, requires_grad=True)
        out_ref = self.module.custom_autograd_fn_returns_self(x_ref)
        out_ref.sum().backward()

        x_test = x_ref.clone().detach().requires_grad_(True)
        f_compiled = torch.compile(self.module.custom_autograd_fn_returns_self)
        out_test = f_compiled(x_test)
        out_test.sum().backward()

        self.assertEqual(out_ref, out_test)
        self.assertEqual(x_ref.grad, x_test.grad)

    # Not an open registration test - this file is just very convenient
    # for testing torch.compile on custom C++ operators
    @skipIfTorchDynamo("Temporarily disabled due to torch._ops.OpOverloadPacket")
    def test_compile_autograd_function_aliasing(self):
        x_ref = torch.randn(4, requires_grad=True)
        out_ref = torch.ops._test_funcs.custom_autograd_fn_aliasing(x_ref)
        out_ref.sum().backward()

        x_test = x_ref.clone().detach().requires_grad_(True)
        f_compiled = torch.compile(torch.ops._test_funcs.custom_autograd_fn_aliasing)
        out_test = f_compiled(x_test)
        out_test.sum().backward()

        self.assertEqual(out_ref, out_test)
        self.assertEqual(x_ref.grad, x_test.grad)

    def test_open_device_scalar_type_fallback(self):
        z_cpu = torch.Tensor([[0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]]).to(torch.int64)
        z = torch.triu_indices(3, 3, device="foo")
        self.assertEqual(z_cpu, z)

    def test_open_device_tensor_type_fallback(self):
        # create tensors located on the custom device
        x = torch.Tensor([[1, 2, 3], [2, 3, 4]]).to("foo")
        y = torch.Tensor([1, 0, 2]).to("foo")
        # create the expected result tensor on the cpu
        z_cpu = torch.Tensor([[0, 2, 1], [1, 3, 2]])
        # Check that our device is correct.
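        # (The sub and index ops below have no custom-backend kernels, so they go
        # through the CPU fallback and must still match the precomputed z_cpu.)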
        device = self.module.custom_device()
        self.assertTrue(x.device == device)
        self.assertFalse(x.is_cpu)

        # call the sub op, which will fall back to the cpu
        z = torch.sub(x, y)
        self.assertEqual(z_cpu, z)

        # call the index op, which will fall back to the cpu
        z_cpu = torch.Tensor([3, 1])
        y = torch.Tensor([1, 0]).long().to("foo")
        z = x[y, y]
        self.assertEqual(z_cpu, z)

    def test_open_device_tensorlist_type_fallback(self):
        # create tensors located on the custom device
        v_foo = torch.Tensor([1, 2, 3]).to("foo")
        # create the expected result tensor on the cpu
        z_cpu = torch.Tensor([2, 4, 6])
        # create tensor lists for the foreach_add op
        x = (v_foo, v_foo)
        y = (v_foo, v_foo)
        # Check that our device is correct.
        device = self.module.custom_device()
        self.assertTrue(v_foo.device == device)
        self.assertFalse(v_foo.is_cpu)

        # call the _foreach_add op, which will fall back to the cpu
        z = torch._foreach_add(x, y)
        self.assertEqual(z_cpu, z[0])
        self.assertEqual(z_cpu, z[1])

        # call _fused_adamw_ with an undefined tensor.
        self.module.fallback_with_undefined_tensor()

    def test_open_device_numpy_serialization(self):
        torch.utils.rename_privateuse1_backend("foo")
        device = self.module.custom_device()
        default_protocol = torch.serialization.DEFAULT_PROTOCOL
        # This is a hack to test serialization through numpy
        with patch.object(torch._C, "_has_storage", return_value=False):
            x = torch.randn(2, 3)
            x_foo = x.to(device)
            sd = {"x": x_foo}
            rebuild_func = x_foo._reduce_ex_internal(default_protocol)[0]
            self.assertTrue(
                rebuild_func is torch._utils._rebuild_device_tensor_from_numpy
            )
            # Test map_location
            with TemporaryFileName() as f:
                torch.save(sd, f)
                with safe_globals(
                    [
                        np.core.multiarray._reconstruct,
                        np.ndarray,
                        np.dtype,
                        _codecs.encode,
                        type(np.dtype(np.float32))
                        if np.__version__ < "1.25.0"
                        else np.dtypes.Float32DType,
                    ]
                ):
                    sd_loaded = torch.load(f, map_location="cpu")
                self.assertTrue(sd_loaded["x"].is_cpu)

            # Test metadata_only
            with TemporaryFileName() as f:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "Cannot serialize tensors on backends with no storage under skip_data context manager",
                ):
                    with torch.serialization.skip_data():
                        torch.save(sd, f)


if __name__ == "__main__":
    common.run_tests()