• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2import logging
3from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
4
5from ...autotune_process import CUDABenchmarkRequest
6from ...ir import (
7    Buffer,
8    ChoiceCaller,
9    CUDATemplateBuffer,
10    IRNode,
11    Layout,
12    PrimitiveInfoType,
13    TensorBox,
14)
15from ...utils import sympy_product
16from ...virtualized import V
17from ..common import IndentedBuffer, Kernel, OpOverrides
18from ..cpp_utils import CppPrinter, DTYPE_TO_CPP
19
20
21if TYPE_CHECKING:
22    from torch._inductor.codegen.cuda.cuda_template import CUDATemplate
23
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Shorthand for the `doprint` method of a CppPrinter instance; used below to
# turn size / stride expressions into the strings returned to template code.
cexpr = CppPrinter().doprint
27
28
29def _normalize_idx(index: int, total_length: int) -> int:
30    return index if index >= 0 else index + total_length
31
32
class CUDAKernel(Kernel):
    """
    Common base class for kernels generated for CUDA / Cutlass.
    """

    # Op-override table used by this kernel family: the generic OpOverrides.
    overrides = OpOverrides  # type: ignore[assignment]
39
40
class CUDATemplateKernel(CUDAKernel):
    """
    Template kernels defined by CUDA / Cutlass in C++.

    Most methods are hooks called from template code: they return C++ source
    fragments (argument names, pointers, sizes, strides, dtypes) as strings.
    """

    # Parameters appended after the buffer arguments in every function
    # signature produced by `def_kernel`.
    _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream"

    def __init__(self, kernel_name) -> None:
        """
        Initializes a new instance of the CUDATemplateKernel class.

        Args:
            kernel_name (str): The name of the kernel.
        """
        super().__init__()
        self.kernel_name = kernel_name
        # Mapping from arg name to IRNode.
        self.named_nodes: Dict[str, IRNode] = {}

    def arg_name(self, node: IRNode) -> Optional[str]:
        """
        Returns arg name of a given input or output node.

        Returns None when `node` is None or when its name is registered in
        neither `self.args.input_buffers` nor `self.args.output_buffers`.
        If the name is present in both, the output entry takes precedence.
        """
        if node is None:
            return None
        return {**self.args.input_buffers, **self.args.output_buffers}.get(
            node.get_name(), None
        )

    def check_not_null(self, node: IRNode) -> str:
        """
        Generates code to check that a node is not null.

        The emitted C++ block throws `std::runtime_error` when the argument
        pointer is null while the product of all of the node's sizes is
        greater than zero. Returns an empty string when `node` is None or
        has no registered arg name.
        """

        if node is None:
            return ""

        # Product of every dimension (index 0 through the last one).
        size_str = self.size(node, 0, -1)
        name_str = self.arg_name(node)
        if name_str is None:
            return ""

        res = IndentedBuffer(initial_indent=2)
        res.tabwidth = 1
        res.splice(
            f"""
            {{
              if (!{name_str}) {{
                int64_t {name_str}_size = {size_str};
                if ({name_str}_size > 0) {{
                  throw std::runtime_error("input {name_str} is null but size is not 0!");
                }}
              }}
            }}
            """
        )
        return res.getvalue()

    def def_kernel(
        self,
        inputs: List[IRNode],
        outputs: List[IRNode],
        names_str: str = "",
        input_reorder: Optional[List[int]] = None,
    ) -> str:
        """
        Hook called from template code to generate function definition and
        needed args.

        Every non-None node is recorded in `self.named_nodes` under its
        template name and registered in `self.args.input_buffers` /
        `self.args.output_buffers`.

        Args:
            inputs: List of input IRNodes
            outputs: List of output IRNodes
            names_str: Comma separated list of input + output argument names.
            input_reorder: The actual order of input nodes.
                           e.g. The template might have input argument defined as [X, W, Bias],
                           and the actual input passed into this template could be [Bias, X, W].
                           In this case, the `input_reorder` would be [2, 0, 1].

        Returns:
            The C++ function signature (without body) of the kernel.

        Raises:
            RuntimeError: if the number of names differs from the number of
                inputs plus outputs.
        """

        names = [x.strip() for x in names_str.strip().split(",")]
        if len(inputs) + len(outputs) != len(names):
            raise RuntimeError(
                f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}"
            )

        if input_reorder is not None:
            assert len(inputs) == len(input_reorder)
        else:
            input_reorder = list(range(len(inputs)))

        # Inputs are registered in `input_reorder` order; names[idx] and
        # inputs[idx] always refer to the same template argument.
        for idx in input_reorder:
            name = names[idx]
            node = inputs[idx]
            if node is not None:
                self.named_nodes[name] = node
                self.args.input_buffers[node.get_name()] = name

        # Output names follow the input names in `names`.
        output_names = names[len(inputs) : len(inputs) + len(outputs)]
        for name, node in zip(output_names, outputs):
            if node is not None:
                self.named_nodes[name] = node
                self.args.output_buffers[node.get_name()] = name

        arg_defs, *_ = self.args.cpp_argdefs()
        return f"PT_EXPORT int {self.kernel_name}({', '.join(arg_defs)}, {self._EXTRA_CPP_ARGS})"

    def call_kernel(
        self,
        name: str,
        node: "CUDATemplateBuffer",  # type: ignore[name-defined]
    ) -> None:
        """
        Generates code to call the kernel through V.graph.wrapper_code.
        used from within torch._inductor.wrapper.WrapperCodeGen

        name: Name of kernel function.
        node: The CUDATemplateBuffer node which contains information about the kernel, its fused epilogue nodes
        as well as all required inputs and outputs.
        """
        wrapper = V.graph.wrapper_code
        _, call_args, _, arg_types = self.args.python_argdefs()
        # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
        for i, arg in enumerate(call_args):
            if V.graph.is_unspec_arg(arg):
                call_args[i] = arg + ".item()"
            else:
                call_args[i] = f"c_void_p({arg}.data_ptr())"

        # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size.
        # workspace_size should have already been retrieved prior to this call.
        call_args.append("None")

        workspace_size = node.get_workspace_size()
        needs_workspace = workspace_size > 0
        if needs_workspace:
            wrapper.generate_workspace_allocation(
                workspace_size, V.graph.scheduler.current_device, False
            )
            call_args.append("c_void_p(workspace.data_ptr())")
        else:
            call_args.append("None")

        wrapper.generate_kernel_call(
            name,
            call_args,
            cuda=True,
            triton=False,
            arg_types=arg_types,
        )
        # Release the workspace buffer once the call has been emitted.
        if needs_workspace:
            wrapper.writeline(wrapper.make_free_by_names(["workspace"]))

    def dtype(self, node: IRNode) -> Optional[str]:
        """
        Generates code which represents dtype of a given node.

        Returns "void" when `node` is None, otherwise the DTYPE_TO_CPP entry
        for the node's layout dtype (None if there is no such entry).
        """

        if node is None:
            return "void"
        return DTYPE_TO_CPP.get(node.get_layout().dtype)

    def cutlass_dtype(self, node: IRNode, default_dtype="void") -> Optional[str]:
        """
        Returns the CUTLASS type name for the node's layout dtype, or
        `default_dtype` when `node` is None.

        Helper method, called into from CUTLASSGemmTemplate.
        """
        if node is None:
            return default_dtype
        # Imported at call time; the module-level import of cuda_template is
        # guarded by TYPE_CHECKING (presumably to avoid a circular import).
        from torch._inductor.codegen.cuda.cuda_template import CUTLASSTemplate

        return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]

    def max_valid_index(self, node: IRNode, default=-1):
        """
        Returns the sum over all dimensions of `(size - 1) * stride`, or
        `default` when `node` is None.

        Helper method, called into from CUTLASSGemmTemplate.
        """
        if node is None:
            return default
        return sum(
            (size - 1) * stride
            for size, stride in zip(node.get_size(), node.get_stride())
        )

    def offset(self, node: IRNode) -> str:
        """
        Generates code which represents offset of a given node.

        Returns "0" when `node` is None.
        """

        if node is None:
            return "0"
        return str(node.get_layout().offset)

    def ptr(self, node: IRNode) -> str:
        """
        Generates code which represents pointer of a given node.

        Returns "nullptr" when `node` is None or has no registered arg name;
        otherwise the arg name, plus the layout offset when it is non-zero.
        """

        if node is None:
            return "nullptr"
        arg_name = self.arg_name(node)
        if arg_name is None:
            return "nullptr"
        offset = self.offset(node)
        return arg_name if offset == "0" else f"{arg_name} + {offset}"

    def size(
        self,
        node: IRNode,
        start_index: int,
        end_index: Optional[int] = None,
        default_value: int = 0,
    ) -> str:
        """
        Hook called from template code to get the size of an arg.
        Generates code which represents the product of the sizes of a given
        node over the inclusive index range [start_index, end_index].
        Negative indices count from the end. If end_index is None, only the
        size at start_index is used.
        If node is None or the range is empty, returns default_value.

        TODO: Will add needed args to pass it in if it is dynamic.
        """

        if node is None:
            return str(default_value)

        node_sizes = node.get_size()
        rank = len(node_sizes)

        start_index = _normalize_idx(start_index, rank)
        if end_index is None:
            end_index = start_index
        end_index = _normalize_idx(end_index, rank)

        # `+ 1` makes end_index inclusive.
        sizes = node_sizes[start_index : end_index + 1]
        if len(sizes) == 0:
            return str(default_value)

        val = sympy_product(sizes)
        return cexpr(self.rename_indexing(val))

    def stride(self, node: IRNode, index: int, default_value: int = 0) -> str:
        """
        Hook called from template code to get the stride of an arg.
        Generates code which represents stride of a given node at index.
        If node is None, or index is still negative after normalization,
        returns default_value.

        TODO: Will add needed args to pass it in if it is dynamic.
        """

        if node is None:
            return str(default_value)

        index = _normalize_idx(index, len(node.get_size()))
        if index < 0:
            return str(default_value)

        stride = node.get_stride()[index]
        return cexpr(self.rename_indexing(stride))

    def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str:
        """
        Hook called from template code to get the row or column stride of an arg.
        This is required by some CUTLASS 2.X APIs.
        If the node is in row_major, it returns stride[-2].
        If the node is in column_major, it returns stride[-1].
        If node is None or has fewer than two strides, returns default_value.

        Raises:
            RuntimeError: if neither of the last two strides equals 1.

        TODO: Will add needed args to pass it in if it is dynamic.
        """

        if node is None or len(node.get_stride()) < 2:
            return str(default_value)

        strides = node.get_stride()
        last_stride = strides[-1]
        second_last_stride = strides[-2]
        if last_stride == 1:
            # Row major: innermost dimension is contiguous.
            return cexpr(self.rename_indexing(second_last_stride))
        elif second_last_stride == 1:
            # Column major: second-to-last dimension is contiguous.
            return cexpr(self.rename_indexing(last_stride))
        else:
            raise RuntimeError(
                f"At least 1 stride should be 1. Strides: {node.get_stride()=}"
            )
311
class CUDATemplateCaller(ChoiceCaller):
    """
    ChoiceCaller subclass wrapping a single CUDA template kernel.

    Attributes:
        name (str): The name of the caller.
        category (str): The category of the caller.
        make_kernel_render: Callable stored for the CUDATemplateBuffer built
            in `output_node`.
        bmreq (CUDABenchmarkRequest): The benchmark request for the caller.
        template (CUDATemplate): The template passed on to the buffer.
        info_kwargs: Optional extra information; an "op" entry, when present,
            is what `info_dict` reports on.
    """

    def __init__(
        self,
        name: str,
        category: str,
        input_nodes: List[Buffer],
        layout: Layout,
        make_kernel_render: Callable[[CUDATemplateBuffer, Optional[List[IRNode]]], str],
        bmreq: CUDABenchmarkRequest,
        template: "CUDATemplate",  # type: ignore[name-defined]
        info_kwargs: Optional[Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]],  # type: ignore[type-arg]
    ) -> None:
        """Store the caller's configuration; name, inputs and layout go to the base class."""
        super().__init__(name, input_nodes, layout)
        self.category = category
        self.make_kernel_render = make_kernel_render
        self.bmreq = bmreq
        self.template = template
        self.info_kwargs = info_kwargs

    def precompile(self) -> None:
        """Delegate precompilation to the benchmark request."""
        request = self.bmreq
        assert request is not None
        request.precompile()

    def benchmark(self, *args, out) -> float:
        """Run the benchmark request on `args`, writing into `out`, and return its result."""
        request = self.bmreq
        assert request is not None
        # @TODO: Hack for ensuring that Cutlass Kernel is preferred
        return request.benchmark(*args, output_tensor=out)

    def __str__(self) -> str:
        """Short description naming the benchmark request's source file."""
        return f"CUDATemplateCaller(source_file={self.bmreq.source_file})"

    def call_name(self) -> str:
        """Qualified name under which the kernel is called."""
        return f"cuda_template_kernels.{self.name}"

    def hash_key(self) -> str:
        """Key combining the category with the benchmark request's hash key."""
        parts = (self.category, self.bmreq.hash_key)
        return "-".join(parts)

    def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        info_kwargs = self.info_kwargs
        # Without an "op" entry there is nothing specific to report.
        if info_kwargs is None or "op" not in info_kwargs:
            return {"backend": "CUDA", "op_type": "unknown"}

        op: Any = info_kwargs["op"]
        return {
            "backend": "CUDA",
            "op_type": type(op).__name__,
            "op_conf_name": str(op.configuration_name()),
            "op_arch": str(op.arch),
            "tile_shape": str(op.tile_description.tile_shape),
            "epilogue_schedule": str(op.epilogue_schedule),
            "kernel_schedule": str(op.kernel_schedule),
            "element_accumulator": str(op.accumulator_type()),
            "op_name": str(op.procedural_name()),
            "instruction_shape": str(
                op.tile_description.math_instruction.instruction_shape
            ),
        }

    def output_node(self) -> TensorBox:
        """Build the CUDATemplateBuffer for this choice and wrap it in a TensorBox."""
        # Refresh the workspace size before it is read for the buffer.
        self.bmreq.update_workspace_size()
        template_buffer = CUDATemplateBuffer(
            layout=self.layout,
            inputs=self.input_nodes,
            make_kernel_render=self.make_kernel_render,
            workspace_size=self.bmreq.workspace_size,
            template=self.template,
        )
        return TensorBox.create(template_buffer)
398