• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2import functools
3import itertools
4import logging
5from typing import List, Optional
6from unittest.mock import patch
7
8import sympy
9
10import torch
11
12from ...autotune_process import CUDABenchmarkRequest, TensorMeta
13from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout
14from ...utils import IndentedBuffer, unique
15from ...virtualized import V
16from ..common import KernelTemplate
17from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel
18
19
# Module-level logger. CUDATemplate.generate() uses it to emit the rendered kernel source and
# its argument definitions at DEBUG level.
log = logging.getLogger(__name__)
21
22
class CUDATemplate(KernelTemplate):
    """
    Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly:
    subclasses must implement ``render``.
    """

    # Class-level counter that gives every generated caller a unique ``kernel_hash_name`` suffix.
    # Unless a subclass rebinds it, all CUDATemplate subclasses share this single counter.
    index_counter = itertools.count()

    def __init__(
        self,
        name: str,
        input_nodes: List[Buffer],
        layout: Layout,
        input_reorder: Optional[List[int]] = None,
    ) -> None:
        """
        Initialize the template.

        Args:
            name (str): The name of the CUDATemplate object.
            input_nodes (List[Buffer]): The input buffers of the template.
            layout (Layout): The layout of the output buffer / tensor.
            input_reorder (Optional[List[int]]): Optional list of indices into ``input_nodes`` giving
                the order in which the inputs must appear as kernel call arguments. ``None`` means
                the order of ``input_nodes`` itself.
        """
        super().__init__(name)
        self.input_nodes = input_nodes
        # Stand-in output buffer named "buf_out"; generate() uses its name, dtype and layout.
        self.output_node: Buffer = Buffer("buf_out", layout)
        self.input_reorder = input_reorder
        self.layout = layout

    def generate(  # type: ignore[override]
        self,
        **kwargs,
    ) -> CUDATemplateCaller:
        """
        Generates the CUDA template caller object for the given GEMM template and operation. This CUDATemplateCaller
        may be used to call and benchmark the generated CUDA kernel in a standalone manner to enable Autotuning.

        The kernel is rendered once to obtain its source code and call arguments. The leading call
        arguments are checked against the expected input/output order, and everything is packaged
        into a CUDABenchmarkRequest. The returned caller also carries a ``make_kernel_render``
        closure that re-renders the same template (with the same ``kwargs``) for a concrete
        template buffer and optional epilogue nodes.

        Args:
            kwargs: Additional keyword arguments, forwarded unchanged to ``render``.

        Returns:
            A CUDATemplateCaller object representing the generated CUDA template caller.

        Raises:
            AssertionError: If the rendered kernel's leading call arguments are not the
                (deduplicated) inputs in ``input_reorder`` order followed by the output buffer.
        """
        kernel_name = f"cuda_{self.name}"
        # V.graph.get_dtype is patched (via KernelTemplate._fake_get_dtype) while rendering,
        # presumably so the stand-in output buffer, which is not registered with the graph,
        # still resolves to a dtype.
        with patch.object(
            V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
        ), CUDATemplateKernel(
            kernel_name=kernel_name,
        ) as kernel:
            code = self.render(kernel=kernel, **kwargs)
            # Compute the Python arg definitions once and reuse them for the debug log below.
            python_argdefs = kernel.args.python_argdefs()
            _, call_args, _, _ = python_argdefs
            log.debug("Generated Code:\n%s", code)
            log.debug(
                "Args: cpp_argdefs: %s, python_argdefs: %s",
                kernel.args.cpp_argdefs(),
                python_argdefs,
            )

        input_reorder = (
            self.input_reorder
            if self.input_reorder is not None
            else list(range(len(self.input_nodes)))
        )
        # The kernel must take the (deduplicated) inputs in input_reorder order, followed by
        # the output buffer. Any call args after that prefix are treated as extra args.
        expected_args = list(
            unique(self.input_nodes[idx].get_name() for idx in input_reorder)
        )
        expected_args.append(self.output_node.get_name())
        assert list(call_args)[: len(expected_args)] == expected_args, (
            call_args,
            expected_args,
        )
        # The extra args are symbolic (sympy) expressions. Expand them and convert them to
        # concrete size hints so the standalone benchmark can pass real values.
        extra_args = V.graph.sizevars.size_hints(
            map(sympy.expand, call_args[len(expected_args) :])
        )

        kernel_hash_name = f"cuda_{self.name}_{next(self.index_counter)}"

        # create the BenchmarkRequest
        bmreq = CUDABenchmarkRequest(
            kernel_name=kernel_name,
            input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
            output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
            extra_args=extra_args,
            source_code=code,
        )

        def make_kernel_render(
            template_node: CUDATemplateBuffer,
            epilogue_nodes: Optional[List[IRNode]] = None,
        ):
            """Return a fresh kernel and a ``render`` callable bound to it and to this template's kwargs."""
            # "KERNEL_NAME" looks like a placeholder name that the caller substitutes later
            # (not visible here).
            kernel = CUDATemplateKernel(
                kernel_name="KERNEL_NAME",
            )
            render = functools.partial(
                self.render,
                kernel=kernel,
                template_buffer_node=template_node,
                epilogue_nodes=epilogue_nodes,
                **kwargs,  # includes "op" argument in case of CUTLASSGemmTemplate
            )
            return kernel, render

        return CUDATemplateCaller(
            kernel_hash_name,
            self.name,
            self.input_nodes,
            self.output_node.get_layout(),
            make_kernel_render,
            bmreq,
            self,
            kwargs,
        )

    def header(self) -> IndentedBuffer:
        """Return the C++ standard-library ``#include`` lines shared by all CUDA templates."""
        res = IndentedBuffer()
        res.splice(
            """
                #include <exception>
                #include <iostream>
                #include <memory>
                #include <random>
                #include <vector>
            """
        )
        return res

    def globals(self) -> IndentedBuffer:
        """Return global C++ definitions: the ``PT_EXPORT`` visibility macro and the ``bfloat16`` alias."""
        res = IndentedBuffer()
        res.splice(
            """
                // We compile all models with -fvisibility=hidden. Any symbols that need to be
                // exposed in the final shared library must be declared with PT_EXPORT to make
                // them visible.
                #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
                #define PT_EXPORT __attribute__((__visibility__("default")))
                #else
                #ifdef _WIN32
                #define PT_EXPORT __declspec(dllexport)
                #else
                #define PT_EXPORT
                #endif
                #endif
                using bfloat16 = nv_bfloat16;
            """
        )
        return res

    def render(self, **kwargs) -> str:
        """Render the kernel source code. Subclasses must override this."""
        raise NotImplementedError
170
171
class CUTLASSTemplate(CUDATemplate):
    """
    CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the
    CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels.
    """

    def header(self) -> IndentedBuffer:
        """Return the base CUDA includes plus the C++, CuTe and CUTLASS headers used by CUTLASS kernels."""
        res = super().header()
        # <stdexcept> and <string> are needed by the CUTLASS_CHECK macro emitted in globals()
        # (std::runtime_error, std::string, std::to_string); include them explicitly rather
        # than relying on transitive includes.
        res.splice(
            """
                #include <stdexcept>
                #include <string>
                #include "cute/tensor.hpp"
                #include "cutlass/cutlass.h"
                #include "cutlass/numeric_types.h"
                #include "cutlass/tensor_ref.h"
                #include "cutlass/util/host_tensor.h"
                #include "cutlass/util/reference/host/tensor_fill.h"
                #include "cutlass/util/reference/device/tensor_fill.h"
                #include "cutlass/util/device_memory.h"
            """
        )
        return res

    def globals(self) -> IndentedBuffer:
        """
        Return the base CUDA globals plus CUTLASS helpers: ``using namespace cute``, the
        ``CUTLASS_CHECK`` macro, and the ``identity_op`` pass-through functor.
        """
        res = super().globals()
        res.splice(
            """
                using namespace cute;
                #define CUTLASS_CHECK(status)                                                      \\
                {                                                                                  \\
                  cutlass::Status error = status;                                                  \\
                  if (error != cutlass::Status::kSuccess) {                                        \\
                    auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " +             \\
                        cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__);        \\
                    throw std::runtime_error(msg);                                                 \\
                  }                                                                                \\
                }

                // Used as pass-through functor in EVT just for type casting / rounding
                template <typename T>
                struct identity_op {
                  CUTLASS_HOST_DEVICE
                  T operator()(T val) const { return val; }
                };

            """
        )
        return res

    def cute_int(self, int_str: str, var_name: str) -> str:
        """
        Format an integer expression for CuTe code, tagged with ``var_name`` in a trailing C comment.

        The literal one (``"1"`` or ``"1L"``) becomes the compile-time constant ``cute::Int<1>{}``.
        Any other string is emitted unchanged.
        """
        res = "cute::Int<1>{}" if int_str in {"1", "1L"} else int_str
        return f"{res} /* {var_name} */"

    # torch dtype -> C++ element type used in generated CUTLASS code.
    _DTYPE_TO_CUTLASS = {
        torch.float32: "float",
        torch.float64: "double",
        torch.float16: "cutlass::half_t",
        torch.int32: "int32_t",
        torch.int16: "int16_t",
        torch.int8: "int8_t",
        torch.uint8: "uint8_t",
        torch.bool: "bool",
        torch.bfloat16: "cutlass::bfloat16_t",
    }

    # torch dtype -> unsigned C++ type used for sparse-metadata pointers.
    _DTYPE_TO_CUTLASS_SPARSE_META = {
        torch.int32: "uint32_t",
        torch.int16: "uint16_t",
    }

    def _cast_pointer(
        self, dtype_map: dict, node: Optional[IRNode], ptr: str, kind: str
    ) -> str:
        """
        Wrap ``ptr`` in a C-style pointer cast to the type ``dtype_map`` assigns to ``node``'s dtype.

        Returns ``ptr`` unchanged when ``node`` is None.

        Raises:
            NotImplementedError: If ``node``'s dtype has no entry in ``dtype_map``.
        """
        if node is None:
            return ptr
        dtype = node.get_dtype()
        c_type = dtype_map.get(dtype)
        if c_type is None:
            # Fail loudly instead of emitting an invalid "(None*)" cast into the generated C++.
            raise NotImplementedError(
                f"Unsupported dtype {dtype} for CUTLASS {kind} pointer cast"
            )
        return f"({c_type}*)({ptr})"

    def cutlass_type_cast(self, node: Optional[IRNode], ptr: str) -> str:
        """
        Cast ``ptr`` to a pointer to ``node``'s CUTLASS element type.

        Returns ``ptr`` unchanged when ``node`` is None.

        Raises:
            NotImplementedError: If ``node``'s dtype is not in ``_DTYPE_TO_CUTLASS``.
        """
        return self._cast_pointer(self._DTYPE_TO_CUTLASS, node, ptr, "element")

    def cutlass_sparse_meta_type_cast(self, node: Optional[IRNode], ptr: str) -> str:
        """
        Cast ``ptr`` to a pointer to ``node``'s sparse-metadata type.

        Returns ``ptr`` unchanged when ``node`` is None.

        Raises:
            NotImplementedError: If ``node``'s dtype is not in ``_DTYPE_TO_CUTLASS_SPARSE_META``.
        """
        return self._cast_pointer(
            self._DTYPE_TO_CUTLASS_SPARSE_META, node, ptr, "sparse metadata"
        )
259