# Owner(s): ["module: cpp-extensions"]

import _codecs
import os
import shutil
import sys
import tempfile
import types
import unittest
from typing import Union
from unittest.mock import patch

import numpy as np

import torch
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.serialization import safe_globals
from torch.testing._internal.common_utils import (
    IS_ARM64,
    skipIfTorchDynamo,
    TemporaryFileName,
    TEST_CUDA,
    TEST_XPU,
)
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME


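# Narrow the imported flags: CUDA tests also require a CUDA toolkit
# (CUDA_HOME), and ROCm tests require a HIP build of torch plus ROCM_HOME.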
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None


def remove_build_path():
    if sys.platform == "win32":
        # Do not wipe the extensions build folder on Windows.
        return
    default_build_root = torch.utils.cpp_extension.get_default_build_root()
    if os.path.exists(default_build_root):
        shutil.rmtree(default_build_root, ignore_errors=True)


def generate_faked_module():
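    """Return a minimal fake torch.foo device module.

    Only the hooks that backend registration and the tests below actually
    touch are faked: device count, RNG state, availability checks, and
    lazy-init bookkeeping.
    """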
    def device_count() -> int:
        return 1

    def get_rng_state(device: Union[int, str, torch.device] = "foo") -> torch.Tensor:
        # create a tensor using our custom device object.
        return torch.empty(4, 4, device="foo")

    def set_rng_state(
        new_state: torch.Tensor, device: Union[int, str, torch.device] = "foo"
    ) -> None:
        pass

    def is_available():
        return True

    def current_device():
        return 0
    # create a new module to fake torch.foo dynamically
    foo = types.ModuleType("foo")

    foo.device_count = device_count
    foo.get_rng_state = get_rng_state
    foo.set_rng_state = set_rng_state
    foo.is_available = is_available
    foo.current_device = current_device
    foo._lazy_init = lambda: None
    foo.is_initialized = lambda: True

    return foo


@unittest.skipIf(IS_ARM64, "Does not work on arm")
@unittest.skipIf(TEST_XPU, "XPU does not support cpp extensions currently")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionOpenRegistration(common.TestCase):
79    """Tests Open Device Registration with C++ extensions."""
80
81    module = None
82
83    def setUp(self):
84        super().setUp()
85
86        # cpp extensions use relative paths. Those paths are relative to
87        # this file, so we'll change the working directory temporarily
88        self.old_working_dir = os.getcwd()
89        os.chdir(os.path.dirname(os.path.abspath(__file__)))
90
91        assert self.module is not None
92
93    def tearDown(self):
94        super().tearDown()
95
        # restore the working directory (see setUp)
        os.chdir(self.old_working_dir)

    @classmethod
    def setUpClass(cls):
        remove_build_path()

        cls.module = torch.utils.cpp_extension.load(
            name="custom_device_extension",
            sources=[
                "cpp_extensions/open_registration_extension.cpp",
            ],
            extra_include_paths=["cpp_extensions"],
            extra_cflags=["-g"],
            verbose=True,
        )

        # register the torch.foo module and the foo device with torch;
        # for_storage=True also generates the storage methods used below
        torch.utils.rename_privateuse1_backend("foo")
        torch.utils.generate_methods_for_privateuse1_backend(for_storage=True)
        torch._register_device_module("foo", generate_faked_module())

    def test_base_device_registration(self):
        self.assertFalse(self.module.custom_add_called())
        # create a tensor using our custom device object
        device = self.module.custom_device()
        x = torch.empty(4, 4, device=device)
        y = torch.empty(4, 4, device=device)
        # Check that our device is correct.
        self.assertTrue(x.device == device)
        self.assertFalse(x.is_cpu)
        self.assertFalse(self.module.custom_add_called())
        # calls our custom add kernel, registered to the dispatcher
        z = x + y
        # check that it was called
        self.assertTrue(self.module.custom_add_called())
        z_cpu = z.to(device="cpu")
        # Check that our cross-device copy correctly copied the data to cpu
        self.assertTrue(z_cpu.is_cpu)
        self.assertFalse(z.is_cpu)
        self.assertTrue(z.device == device)
        self.assertEqual(z, z_cpu)

    def test_common_registration(self):
        # check unsupported device and duplicated registration
        with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
            torch._register_device_module("dev", generate_faked_module())
        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
            torch._register_device_module("foo", generate_faked_module())

        # backend name can be renamed to the same name multiple times
        torch.utils.rename_privateuse1_backend("foo")

        # backend name can't be renamed multiple times to different names.
        with self.assertRaisesRegex(
            RuntimeError, "torch.register_privateuse1_backend()"
        ):
            torch.utils.rename_privateuse1_backend("dev")

        # tensor and module methods for the backend can be generated only once
        with self.assertRaisesRegex(RuntimeError, "The custom device module of"):
            torch.utils.generate_methods_for_privateuse1_backend()

        # check whether torch.foo has been registered correctly
        self.assertTrue(
            torch.utils.backend_registration._get_custom_mod_func("device_count")() == 1
        )
        with self.assertRaisesRegex(RuntimeError, "Try to call torch.foo"):
            torch.utils.backend_registration._get_custom_mod_func("func_name_")

        # check attributes after registration
        self.assertTrue(hasattr(torch.Tensor, "is_foo"))
        self.assertTrue(hasattr(torch.Tensor, "foo"))
        self.assertTrue(hasattr(torch.TypedStorage, "is_foo"))
        self.assertTrue(hasattr(torch.TypedStorage, "foo"))
        self.assertTrue(hasattr(torch.UntypedStorage, "is_foo"))
        self.assertTrue(hasattr(torch.UntypedStorage, "foo"))
        self.assertTrue(hasattr(torch.nn.Module, "foo"))
        self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "is_foo"))
        self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "foo"))

    def test_open_device_generator_registration_and_hooks(self):
        device = self.module.custom_device()
        # None of our CPU operations should call the custom add function.
        self.assertFalse(self.module.custom_add_called())

        # a generator must be registered before it can be used
        with self.assertRaisesRegex(
            RuntimeError,
            "Please register a generator to the PrivateUse1 dispatch key",
        ):
            torch.Generator(device=device)

        self.module.register_generator_first()
        gen = torch.Generator(device=device)
        self.assertTrue(gen.device == device)

        # a generator can be registered only once
        with self.assertRaisesRegex(
            RuntimeError,
            "Only can register a generator to the PrivateUse1 dispatch key once",
        ):
            self.module.register_generator_second()

        if not self.module.is_register_hook():
            self.module.register_hook()
        default_gen = self.module.default_generator(0)
        self.assertTrue(
            default_gen.device.type == torch._C._get_privateuse1_backend_name()
        )

    def test_open_device_dispatchstub(self):
        # test that kernels can be reused by the privateuse1 backend through DispatchStub
        input_data = torch.randn(2, 2, 3, dtype=torch.float32, device="cpu")
        foo_input_data = input_data.to("foo")
        output_data = torch.abs(input_data)
        foo_output_data = torch.abs(foo_input_data)
        self.assertEqual(output_data, foo_output_data.cpu())

        output_data = torch.randn(2, 2, 6, dtype=torch.float32, device="cpu")
        # the output operand's will_resize flag is True in TensorIterator.
        foo_input_data = input_data.to("foo")
        foo_output_data = output_data.to("foo")
        # the output operand's will_resize flag is False in TensorIterator.
        torch.abs(input_data, out=output_data[:, :, 0:6:2])
        torch.abs(foo_input_data, out=foo_output_data[:, :, 0:6:2])
        self.assertEqual(output_data, foo_output_data.cpu())

        # the output operand's will_resize flag is True in TensorIterator,
        # and the output is converted to a contiguous tensor in TensorIterator.
        output_data = torch.randn(2, 2, 6, dtype=torch.float32, device="cpu")
        foo_input_data = input_data.to("foo")
        foo_output_data = output_data.to("foo")
        torch.abs(input_data, out=output_data[:, :, 0:6:3])
        torch.abs(foo_input_data, out=foo_output_data[:, :, 0:6:3])
        self.assertEqual(output_data, foo_output_data.cpu())

    def test_open_device_quantized(self):
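        # quantize_per_tensor(input, scale, zero_point, dtype) should run on
        # the custom backend and keep the result there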
        input_data = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu").to("foo")
        quantized_tensor = torch.quantize_per_tensor(input_data, 0.1, 10, torch.qint8)
        self.assertEqual(quantized_tensor.device, torch.device("foo:0"))
        self.assertEqual(quantized_tensor.dtype, torch.qint8)

    def test_open_device_random(self):
        # check that torch.foo has implemented get_rng_state
        with torch.random.fork_rng(device_type="foo"):
            pass

    def test_open_device_tensor(self):
        device = self.module.custom_device()

        # check that tensor.type() prints the expected string for each dtype
        dtypes = {
            torch.bool: "torch.foo.BoolTensor",
            torch.double: "torch.foo.DoubleTensor",
            torch.float32: "torch.foo.FloatTensor",
            torch.half: "torch.foo.HalfTensor",
            torch.int32: "torch.foo.IntTensor",
            torch.int64: "torch.foo.LongTensor",
            torch.int8: "torch.foo.CharTensor",
            torch.short: "torch.foo.ShortTensor",
            torch.uint8: "torch.foo.ByteTensor",
        }
        for tt, dt in dtypes.items():
            test_tensor = torch.empty(4, 4, dtype=tt, device=device)
            self.assertTrue(test_tensor.type() == dt)

        # check whether the attributes and methods of the corresponding custom backend are generated correctly
        x = torch.empty(4, 4)
        self.assertFalse(x.is_foo)

        x = x.foo(torch.device("foo"))
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(x.is_foo)

        # test a torch.device with an explicit index
        y = torch.empty(4, 4)
        self.assertFalse(y.is_foo)

        y = y.foo(torch.device("foo:0"))
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(y.is_foo)

        # test a bare integer device index
        z = torch.empty(4, 4)
        self.assertFalse(z.is_foo)

        z = z.foo(0)
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z.is_foo)

    def test_open_device_packed_sequence(self):
        device = self.module.custom_device()
        a = torch.rand(5, 3)
        b = torch.tensor([1, 1, 1, 1, 1])
        input = torch.nn.utils.rnn.PackedSequence(a, b)
        self.assertFalse(input.is_foo)
        input_foo = input.foo()
        self.assertTrue(input_foo.is_foo)

    def test_open_device_storage(self):
        # check whether the attributes and methods for storage of the corresponding custom backend are generated correctly
        x = torch.empty(4, 4)
        z1 = x.storage()
        self.assertFalse(z1.is_foo)

        z1 = z1.foo()
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z1.is_foo)

        with self.assertRaisesRegex(RuntimeError, "Invalid device"):
            z1.foo(torch.device("cpu"))

        z1 = z1.cpu()
        self.assertFalse(self.module.custom_add_called())
        self.assertFalse(z1.is_foo)

        z1 = z1.foo(device="foo:0", non_blocking=False)
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z1.is_foo)

        with self.assertRaisesRegex(RuntimeError, "Invalid device"):
            z1.foo(device="cuda:0", non_blocking=False)

        # check UntypedStorage
        y = torch.empty(4, 4)
        z2 = y.untyped_storage()
        self.assertFalse(z2.is_foo)

        z2 = z2.foo()
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z2.is_foo)

        # check custom StorageImpl creation
        self.module.custom_storage_registry()

        z3 = y.untyped_storage()
        self.assertFalse(self.module.custom_storageImpl_called())

        z3 = z3.foo()
        self.assertTrue(self.module.custom_storageImpl_called())
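        # custom_storageImpl_called() resets the flag when queried, so the
        # immediate re-query below returns False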
        self.assertFalse(self.module.custom_storageImpl_called())

        z3 = z3[0:3]
        self.assertTrue(self.module.custom_storageImpl_called())

    @skipIfTorchDynamo("unsupported aten.is_pinned.default")
    def test_open_device_storage_pin_memory(self):
        # check that pin_memory works properly on the custom device
        cpu_tensor = torch.empty(3)
        self.assertFalse(cpu_tensor.is_foo)
        self.assertFalse(cpu_tensor.is_pinned("foo"))

        cpu_tensor_pin = cpu_tensor.pin_memory("foo")
        self.assertTrue(cpu_tensor_pin.is_pinned("foo"))

        # test storage pin_memory and is_pinned
        cpu_storage = cpu_tensor.storage()
        # We implement a dummy pin_memory with no practical significance
        # for the custom device. Once tensor.pin_memory() has been called,
        # tensor.is_pinned() will always return True no matter what tensor
        # it is called on.
        self.assertTrue(cpu_storage.is_pinned("foo"))

        cpu_storage_pinned = cpu_storage.pin_memory("foo")
        self.assertTrue(cpu_storage_pinned.is_pinned("foo"))

        # test untyped storage pin_memory and is_pinned
        cpu_tensor = torch.randn([3, 2, 1, 4])
        cpu_untyped_storage = cpu_tensor.untyped_storage()
        self.assertTrue(cpu_untyped_storage.is_pinned("foo"))

        cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("foo")
        self.assertTrue(cpu_untyped_storage_pinned.is_pinned("foo"))

    @unittest.skip(
        "Temporarily disabled due to tiny differences between clang++ and g++ in defining static variables in inline functions"
    )
    def test_open_device_serialization(self):
        self.module.set_custom_device_index(-1)
        storage = torch.UntypedStorage(4, device=torch.device("foo"))
        self.assertEqual(torch.serialization.location_tag(storage), "foo")

        self.module.set_custom_device_index(0)
        storage = torch.UntypedStorage(4, device=torch.device("foo"))
        self.assertEqual(torch.serialization.location_tag(storage), "foo:0")

        cpu_storage = torch.empty(4, 4).storage()
        foo_storage = torch.serialization.default_restore_location(cpu_storage, "foo:0")
        self.assertTrue(foo_storage.is_foo)

        # test tensor BackendMeta serialization
        x = torch.empty(4, 4).long()
        y = x.foo()
        self.assertFalse(self.module.check_backend_meta(y))
        self.module.custom_set_backend_meta(y)
        self.assertTrue(self.module.check_backend_meta(y))

        self.module.custom_serialization_registry()
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "data.pt")
            torch.save(y, path)
            z1 = torch.load(path)
            # loads correctly onto the foo backend device
            self.assertTrue(z1.is_foo)
            # loads BackendMeta data correctly
            self.assertTrue(self.module.check_backend_meta(z1))

            # cross-backend load
            z2 = torch.load(path, map_location="cpu")
            # loads correctly onto the cpu backend device
            self.assertFalse(z2.is_foo)
            # does not carry BackendMeta data across backends
            self.assertFalse(self.module.check_backend_meta(z2))

    def test_open_device_storage_resize(self):
        cpu_tensor = torch.randn([8])
        foo_tensor = cpu_tensor.foo()
        foo_storage = foo_tensor.storage()
        self.assertTrue(foo_storage.size() == 8)

        # Only the tensor resize_ function is registered.
        foo_tensor.resize_(8)
        self.assertTrue(foo_storage.size() == 8)

        with self.assertRaisesRegex(TypeError, "Overflow"):
            foo_tensor.resize_(8**29)

    def test_open_device_storage_type(self):
        # test cpu float storage
        cpu_tensor = torch.randn([8]).float()
        cpu_storage = cpu_tensor.storage()
        self.assertEqual(cpu_storage.type(), "torch.FloatStorage")

        # test custom float storage before defining FloatStorage
        foo_tensor = cpu_tensor.foo()
        foo_storage = foo_tensor.storage()
        self.assertEqual(foo_storage.type(), "torch.storage.TypedStorage")

        class CustomFloatStorage:
            @property
            def __module__(self):
                return "torch." + torch._C._get_privateuse1_backend_name()

            @property
            def __name__(self):
                return "FloatStorage"

        # test custom float storage after defining FloatStorage
        try:
            torch.foo.FloatStorage = CustomFloatStorage()
            self.assertEqual(foo_storage.type(), "torch.foo.FloatStorage")

            # test custom int storage after defining FloatStorage
            foo_tensor2 = torch.randn([8]).int().foo()
            foo_storage2 = foo_tensor2.storage()
            self.assertEqual(foo_storage2.type(), "torch.storage.TypedStorage")
        finally:
            torch.foo.FloatStorage = None

    def test_open_device_faketensor(self):
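        # FakeTensorMode should be able to fake tensors on the custom backend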
        with torch._subclasses.fake_tensor.FakeTensorMode.push():
            a = torch.empty(1, device="foo")
            b = torch.empty(1, device="foo:0")
            result = a + b

    def test_open_device_named_tensor(self):
        torch.empty([2, 3, 4, 5], device="foo", names=["N", "C", "H", "W"])

    # Not an open registration test - this file is just very convenient
    # for testing torch.compile on custom C++ operators
    def test_compile_autograd_function_returns_self(self):
        x_ref = torch.randn(4, requires_grad=True)
        out_ref = self.module.custom_autograd_fn_returns_self(x_ref)
        out_ref.sum().backward()

        x_test = x_ref.clone().detach().requires_grad_(True)
        f_compiled = torch.compile(self.module.custom_autograd_fn_returns_self)
        out_test = f_compiled(x_test)
        out_test.sum().backward()

        self.assertEqual(out_ref, out_test)
        self.assertEqual(x_ref.grad, x_test.grad)

    # Not an open registration test - this file is just very convenient
    # for testing torch.compile on custom C++ operators
    @skipIfTorchDynamo("Temporarily disabled due to torch._ops.OpOverloadPacket")
    def test_compile_autograd_function_aliasing(self):
        x_ref = torch.randn(4, requires_grad=True)
        out_ref = torch.ops._test_funcs.custom_autograd_fn_aliasing(x_ref)
        out_ref.sum().backward()

        x_test = x_ref.clone().detach().requires_grad_(True)
        f_compiled = torch.compile(torch.ops._test_funcs.custom_autograd_fn_aliasing)
        out_test = f_compiled(x_test)
        out_test.sum().backward()

        self.assertEqual(out_ref, out_test)
        self.assertEqual(x_ref.grad, x_test.grad)

    def test_open_device_scalar_type_fallback(self):
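        # ops with no foo kernel fall back to the CPU implementation;
        # triu_indices exercises the scalar-type path of that fallback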
        z_cpu = torch.Tensor([[0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]]).to(torch.int64)
        z = torch.triu_indices(3, 3, device="foo")
        self.assertEqual(z_cpu, z)

    def test_open_device_tensor_type_fallback(self):
        # create tensors located on the custom device
        x = torch.Tensor([[1, 2, 3], [2, 3, 4]]).to("foo")
        y = torch.Tensor([1, 0, 2]).to("foo")
        # create the expected result tensor on the cpu
        z_cpu = torch.Tensor([[0, 2, 1], [1, 3, 2]])
        # Check that our device is correct.
        device = self.module.custom_device()
        self.assertTrue(x.device == device)
        self.assertFalse(x.is_cpu)

        # call the sub op, which will fall back to cpu
        z = torch.sub(x, y)
        self.assertEqual(z_cpu, z)

        # call the index op, which will fall back to cpu
        z_cpu = torch.Tensor([3, 1])
        y = torch.Tensor([1, 0]).long().to("foo")
        z = x[y, y]
        self.assertEqual(z_cpu, z)

    def test_open_device_tensorlist_type_fallback(self):
        # create tensors located on the custom device
        v_foo = torch.Tensor([1, 2, 3]).to("foo")
        # create the expected result tensor on the cpu
        z_cpu = torch.Tensor([2, 4, 6])
        # create tensor lists for the foreach_add op
        x = (v_foo, v_foo)
        y = (v_foo, v_foo)
        # Check that our device is correct.
        device = self.module.custom_device()
        self.assertTrue(v_foo.device == device)
        self.assertFalse(v_foo.is_cpu)

        # call the _foreach_add op, which will fall back to cpu
        z = torch._foreach_add(x, y)
        self.assertEqual(z_cpu, z[0])
        self.assertEqual(z_cpu, z[1])

        # call _fused_adamw_ with an undefined tensor
        self.module.fallback_with_undefined_tensor()

    def test_open_device_numpy_serialization(self):
        torch.utils.rename_privateuse1_backend("foo")
        device = self.module.custom_device()
        default_protocol = torch.serialization.DEFAULT_PROTOCOL
        # This is a hack to test serialization through numpy: with _has_storage
        # patched to return False, the tensor is rebuilt from a numpy array.
        with patch.object(torch._C, "_has_storage", return_value=False):
            x = torch.randn(2, 3)
            x_foo = x.to(device)
            sd = {"x": x_foo}
            rebuild_func = x_foo._reduce_ex_internal(default_protocol)[0]
            self.assertTrue(
                rebuild_func is torch._utils._rebuild_device_tensor_from_numpy
            )
            # Test map_location
            with TemporaryFileName() as f:
                torch.save(sd, f)
                with safe_globals(
                    [
                        np.core.multiarray._reconstruct,
                        np.ndarray,
                        np.dtype,
                        _codecs.encode,
                        type(np.dtype(np.float32))
                        if np.__version__ < "1.25.0"
                        else np.dtypes.Float32DType,
                    ]
                ):
                    sd_loaded = torch.load(f, map_location="cpu")
                self.assertTrue(sd_loaded["x"].is_cpu)

            # Test metadata-only serialization under skip_data
            with TemporaryFileName() as f:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "Cannot serialize tensors on backends with no storage under skip_data context manager",
                ):
                    with torch.serialization.skip_data():
                        torch.save(sd, f)


if __name__ == "__main__":
    common.run_tests()