• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Owner(s): ["module: inductor"]
2import torch
3from torch._inductor.codegen.aoti_hipify_utils import maybe_hipify_code_wrapper
4from torch._inductor.codegen.codegen_device_driver import cuda_kernel_driver
5from torch._inductor.test_case import run_tests, TestCase
6
7
# Each entry pairs a CUDA-flavoured C++ snippet with the exact text that
# maybe_hipify_code_wrapper(code, True) is expected to produce for it.
# Keeping input and expected output side by side means the two public tables
# derived below can never drift out of index alignment.
_HIPIFY_CASES = [
    ("CUresult code = EXPR;", "hipError_t code = EXPR;"),
    ("CUfunction kernel = nullptr;", "hipFunction_t kernel = nullptr;"),
    (
        "static CUfunction kernel = nullptr;",
        "static hipFunction_t kernel = nullptr;",
    ),
    (
        "CUdeviceptr var = reinterpret_cast<CUdeviceptr>(arg.data_ptr());",
        "hipDeviceptr_t var = reinterpret_cast<hipDeviceptr_t>(arg.data_ptr());",
    ),
    (
        "at::cuda::CUDAStreamGuard guard(at::cuda::getStreamFromExternal());",
        "at::hip::HIPStreamGuardMasqueradingAsCUDA guard(at::hip::getStreamFromExternalMasqueradingAsCUDA());",
    ),
    # Hipification should be idempotent: hipifying already-hipified code
    # must be a no-op, so this input and its expected output are identical.
    (
        "at::hip::HIPStreamGuardMasqueradingAsCUDA guard(at::hip::getStreamFromExternalMasqueradingAsCUDA());",
        "at::hip::HIPStreamGuardMasqueradingAsCUDA guard(at::hip::getStreamFromExternalMasqueradingAsCUDA());",
    ),
]

# Snippets fed to the hipifier.
TEST_CODES = [cuda_code for cuda_code, _ in _HIPIFY_CASES]
# Expected hipified output, index-aligned with TEST_CODES.
HIP_CODES = [hip_code for _, hip_code in _HIPIFY_CASES]
26
27
class TestCppWrapperHipify(TestCase):
    """Tests for ``maybe_hipify_code_wrapper`` on AOTInductor C++ wrapper code.

    With the force flag set (``maybe_hipify_code_wrapper(code, True)``) the
    CUDA spellings in TEST_CODES are expected to become the HIP spellings in
    HIP_CODES. Without the flag, hipification is expected only on ROCm builds
    (``torch.version.hip is not None``); elsewhere the input is left unchanged.
    """

    def test_hipify_basic_declaration(self) -> None:
        """Forced hipification maps every TEST_CODES entry to its HIP_CODES entry."""
        # assertEqual instead of a bare ``assert`` so this table sanity check
        # is not stripped when Python runs with -O.
        self.assertEqual(len(TEST_CODES), len(HIP_CODES))
        for cuda_code, expected in zip(TEST_CODES, HIP_CODES):
            # subTest reports every mismatching snippet, not just the first.
            with self.subTest(code=cuda_code):
                self.assertEqual(maybe_hipify_code_wrapper(cuda_code, True), expected)

    def test_hipify_aoti_driver_header(self) -> None:
        """Forced hipification of the generated CUDA driver header matches the HIP text."""
        header = cuda_kernel_driver()
        expected = """
            #define CUDA_DRIVER_CHECK(EXPR)                    \\
            do {                                               \\
                hipError_t code = EXPR;                          \\
                const char *msg;                               \\
                hipDrvGetErrorString(code, &msg);                  \\
                if (code != hipSuccess) {                    \\
                    throw std::runtime_error(                  \\
                        std::string("CUDA driver error: ") +   \\
                        std::string(msg));                     \\
                }                                              \\
            } while (0);

            namespace {

            struct Grid {
                Grid(uint32_t x, uint32_t y, uint32_t z)
                  : grid_x(x), grid_y(y), grid_z(z) {}
                uint32_t grid_x;
                uint32_t grid_y;
                uint32_t grid_z;

                bool is_non_zero() {
                    return grid_x > 0 && grid_y > 0 && grid_z > 0;
                }
            };

            }  // anonymous namespace

            static inline hipFunction_t loadKernel(
                    std::string filePath,
                    const std::string &funcName,
                    uint32_t sharedMemBytes,
                    const std::optional<std::string> &cubinDir = std::nullopt) {
                if (cubinDir) {
                    std::filesystem::path p1{*cubinDir};
                    std::filesystem::path p2{filePath};
                    filePath = (p1 / p2.filename()).string();
                }

                hipModule_t mod;
                hipFunction_t func;
                CUDA_DRIVER_CHECK(hipModuleLoad(&mod, filePath.c_str()));
                CUDA_DRIVER_CHECK(hipModuleGetFunction(&func, mod, funcName.c_str()));
                if (sharedMemBytes > 0) {
                    CUDA_DRIVER_CHECK(hipFuncSetAttribute(
                        func,
                        hipFuncAttributeMaxDynamicSharedMemorySize,
                        sharedMemBytes
                    ))
                }
                return func;
            }

            static inline void launchKernel(
                    hipFunction_t func,
                    uint32_t gridX,
                    uint32_t gridY,
                    uint32_t gridZ,
                    uint32_t numWarps,
                    uint32_t sharedMemBytes,
                    void* args[],
                    hipStream_t stream) {
                CUDA_DRIVER_CHECK(hipModuleLaunchKernel(
                    func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
                ));
            }
        """
        # On ROCm builds this test expects the generated header to scale the
        # block size by 64 threads per warp instead of 32.
        if torch.version.hip is not None:
            expected = expected.replace("32*numWarps", "64*numWarps")
        result = maybe_hipify_code_wrapper(header, True)
        # Trailing whitespace after the last line is not significant.
        self.assertEqual(result.rstrip(), expected.rstrip())

    def test_hipify_cross_platform(self) -> None:
        """Unforced hipification rewrites code only when running on ROCm."""
        self.assertEqual(len(TEST_CODES), len(HIP_CODES))
        on_rocm = torch.version.hip is not None
        for cuda_code in TEST_CODES:
            with self.subTest(code=cuda_code):
                hip_result = maybe_hipify_code_wrapper(cuda_code, True)
                result = maybe_hipify_code_wrapper(cuda_code)
                # ROCm: behaves like the forced path. Otherwise: identity.
                expected = hip_result if on_rocm else cuda_code
                self.assertEqual(result, expected)
120
121
if __name__ == "__main__":
    # Script entry point: hand control to run_tests from torch._inductor.test_case.
    run_tests()
124