# mypy: allow-untyped-defs
import functools
import itertools
import logging
from typing import List, Optional
from unittest.mock import patch

import sympy

import torch

from ...autotune_process import CUDABenchmarkRequest, TensorMeta
from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout
from ...utils import IndentedBuffer, unique
from ...virtualized import V
from ..common import KernelTemplate
from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel


log = logging.getLogger(__name__)


class CUDATemplate(KernelTemplate):
    """Base class for CUDA C++ kernel templates; subclasses must implement ``render``."""

    # Class attribute shared by every CUDATemplate instance; each call to
    # ``generate`` draws the next value to build a distinct ``kernel_hash_name``.
    index_counter = itertools.count()

    def __init__(
        self,
        name: str,
        input_nodes: List[Buffer],
        layout: Layout,
        input_reorder: Optional[List[int]] = None,
    ) -> None:
        """

        Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly.

        Args:
            name (str): The name of the CUDATemplate object.
            input_nodes (List[Buffer]): A list of input buffers.
            layout (Layout): The layout of the output buffer / tensor.
            input_reorder (Optional[List[int]]): An optional list of indices into ``input_nodes``
                giving the order in which inputs appear in the kernel's call arguments.
                ``None`` means the natural order.

        """
        super().__init__(name)
        self.input_nodes = input_nodes
        # Placeholder output buffer named "buf_out" carrying the requested layout.
        self.output_node: Buffer = Buffer("buf_out", layout)
        self.input_reorder = input_reorder
        self.layout = layout

    def generate(  # type: ignore[override]
        self,
        **kwargs,
    ) -> CUDATemplateCaller:
        """
        Generates the CUDA template caller object for the given GEMM template and operation. This CUDATemplateCaller
        may be used to call and benchmark the generated CUDA kernel in a standalone manner to enable Autotuning.

        The template is rendered once into a standalone kernel named ``cuda_<name>`` whose source is wrapped in a
        ``CUDABenchmarkRequest``. The returned caller also receives ``make_kernel_render``. That callback builds a
        fresh ``CUDATemplateKernel`` named ``KERNEL_NAME`` and a ``functools.partial`` of ``self.render`` bound to a
        template buffer node and optional epilogue nodes.

        Args:
            kwargs: Additional keyword arguments, forwarded to ``self.render`` (for both the benchmark rendering
                and the deferred rendering) and passed to the returned caller.

        Returns:
            A CUDATemplateCaller object representing the generated CUDA template caller.

        Raises:
            AssertionError: If the leading call arguments of the rendered kernel are not the de-duplicated,
                reordered input names followed by the output buffer name.
        """
        kernel_name = f"cuda_{self.name}"
        # NOTE(review): ``V.graph.get_dtype`` is patched while rendering, presumably so
        # dtype lookups for the placeholder output buffer ``buf_out`` (not a registered
        # graph buffer) resolve via ``_fake_get_dtype`` — confirm in KernelTemplate.
        with patch.object(
            V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
        ), CUDATemplateKernel(
            kernel_name=kernel_name,
        ) as kernel:
            code = self.render(kernel=kernel, **kwargs)
            _, call_args, _, _ = kernel.args.python_argdefs()
            log.debug("Generated Code:\n%s", code)
            # These argdef calls run eagerly even when DEBUG logging is disabled; they
            # are intentionally left unguarded in case they touch kernel/graph state.
            log.debug(
                "Args: cpp_argdefs: %s, python_argdefs: %s",
                kernel.args.cpp_argdefs(),
                kernel.args.python_argdefs(),
            )

        # The kernel's call args must start with the input names (in input_reorder
        # order, de-duplicated via ``unique``) followed by the single output buffer.
        input_reorder = (
            self.input_reorder
            if self.input_reorder is not None
            else list(range(len(self.input_nodes)))
        )
        expected_args = list(
            unique(self.input_nodes[idx].get_name() for idx in input_reorder)
        )
        expected_args.append(self.output_node.get_name())
        assert list(call_args)[: len(expected_args)] == expected_args, (
            call_args,
            expected_args,
        )
        # Any remaining call args are symbolic expressions (presumably sizes); expand
        # them and resolve to concrete size hints so the benchmark can pass real values.
        extra_args = V.graph.sizevars.size_hints(
            map(sympy.expand, call_args[len(expected_args) :])
        )

        kernel_hash_name = f"cuda_{self.name}_{next(self.index_counter)}"

        # create the BenchmarkRequest
        bmreq = CUDABenchmarkRequest(
            kernel_name=kernel_name,
            input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
            output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
            extra_args=extra_args,
            source_code=code,
        )

        def make_kernel_render(
            template_node: CUDATemplateBuffer,
            epilogue_nodes: Optional[List[IRNode]] = None,
        ):
            """Return ``(kernel, render)``, where ``render()`` renders this template into ``kernel``."""
            kernel = CUDATemplateKernel(
                kernel_name="KERNEL_NAME",
            )
            render = functools.partial(
                self.render,
                kernel=kernel,
                template_buffer_node=template_node,
                epilogue_nodes=epilogue_nodes,
                **kwargs,  # includes "op" argument in case of CUTLASSGemmTemplate
            )
            return kernel, render

        return CUDATemplateCaller(
            kernel_hash_name,
            self.name,
            self.input_nodes,
            self.output_node.get_layout(),
            make_kernel_render,
            bmreq,
            self,
            kwargs,
        )

    def header(self) -> IndentedBuffer:
        """Return the C++ standard-library ``#include`` lines emitted at the top of the kernel source."""
        res = IndentedBuffer()
        res.splice(
            """
                #include <exception>
                #include <iostream>
                #include <memory>
                #include <random>
                #include <vector>
            """
        )
        return res

    def globals(self) -> IndentedBuffer:
        """Return global C++ definitions: the ``PT_EXPORT`` visibility macro and a ``bfloat16`` alias."""
        res = IndentedBuffer()
        res.splice(
            """
                // We compile all models with -fvisibility=hidden. Any symbols that need to be
                // exposed in the final shared library must be declared with PT_EXPORT to make
                // them visible.
                #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
                #define PT_EXPORT __attribute__((__visibility__("default")))
                #else
                #ifdef _WIN32
                #define PT_EXPORT __declspec(dllexport)
                #else
                #define PT_EXPORT
                #endif
                #endif
                using bfloat16 = nv_bfloat16;
            """
        )
        return res

    def render(self, **kwargs) -> str:
        """Render the kernel source code; must be implemented by subclasses."""
        raise NotImplementedError


class CUTLASSTemplate(CUDATemplate):
    """
    CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the
    CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels.
    """

    def header(self) -> IndentedBuffer:
        """Extend the base C++ includes with the CuTe / CUTLASS headers."""
        res = super().header()
        res.splice(
            """
                #include "cute/tensor.hpp"
                #include "cutlass/cutlass.h"
                #include "cutlass/numeric_types.h"
                #include "cutlass/tensor_ref.h"
                #include "cutlass/util/host_tensor.h"
                #include "cutlass/util/reference/host/tensor_fill.h"
                #include "cutlass/util/reference/device/tensor_fill.h"
                #include "cutlass/util/device_memory.h"
            """
        )
        return res

    def globals(self) -> IndentedBuffer:
        """Extend the base globals with ``using namespace cute``, ``CUTLASS_CHECK`` and ``identity_op``."""
        res = super().globals()
        res.splice(
            """
                using namespace cute;
                #define CUTLASS_CHECK(status) \\
                { \\
                  cutlass::Status error = status; \\
                  if (error != cutlass::Status::kSuccess) { \\
                    auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\
                        cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\
                    throw std::runtime_error(msg); \\
                  } \\
                }

                // Used as pass-through functor in EVT just for type casting / rounding
                template <typename T>
                struct identity_op {
                  CUTLASS_HOST_DEVICE
                  T operator()(T val) const { return val; }
                };

            """
        )
        return res

    def cute_int(self, int_str: str, var_name: str) -> str:
        """
        Format an integer expression for CuTe code, annotated with a ``/* var_name */`` comment.

        The literal ``"1"`` or ``"1L"`` is emitted as the compile-time constant ``cute::Int<1>{}``.
        Any other string is emitted unchanged.
        """
        res = "cute::Int<1>{}" if int_str in {"1", "1L"} else int_str
        return f"{res} /* {var_name} */"

    # torch dtype -> C++ element type name used in generated CUTLASS code.
    _DTYPE_TO_CUTLASS = {
        torch.float32: "float",
        torch.float64: "double",
        torch.float16: "cutlass::half_t",
        torch.int32: "int32_t",
        torch.int16: "int16_t",
        torch.int8: "int8_t",
        torch.uint8: "uint8_t",
        torch.bool: "bool",
        torch.bfloat16: "cutlass::bfloat16_t",
    }

    # torch dtype -> unsigned C++ type used for sparse-metadata pointers.
    _DTYPE_TO_CUTLASS_SPARSE_META = {
        torch.int32: "uint32_t",
        torch.int16: "uint16_t",
    }

    def cutlass_type_cast(self, node: Optional[IRNode], ptr: str) -> str:
        """
        Wrap ``ptr`` in a C-style pointer cast to ``node``'s CUTLASS element type.

        Returns ``ptr`` unchanged when ``node`` is None.
        """
        if node is None:
            return ptr
        # NOTE(review): a dtype missing from _DTYPE_TO_CUTLASS yields "(None*)" in the
        # generated C++ rather than a Python error.
        return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})"

    def cutlass_sparse_meta_type_cast(self, node: Optional[IRNode], ptr: str) -> str:
        """
        Wrap ``ptr`` in a C-style pointer cast to the unsigned sparse-metadata type for ``node``'s dtype.

        Returns ``ptr`` unchanged when ``node`` is None.
        """
        if node is None:
            return ptr
        # NOTE(review): an unmapped dtype yields "(None*)" in the generated C++.
        return (
            f"({self._DTYPE_TO_CUTLASS_SPARSE_META.get(node.get_dtype())}*)({ptr})"
        )