# mypy: allow-untyped-defs
import logging
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union

from ...autotune_process import CUDABenchmarkRequest
from ...ir import (
    Buffer,
    ChoiceCaller,
    CUDATemplateBuffer,
    IRNode,
    Layout,
    PrimitiveInfoType,
    TensorBox,
)
from ...utils import sympy_product
from ...virtualized import V
from ..common import IndentedBuffer, Kernel, OpOverrides
from ..cpp_utils import CppPrinter, DTYPE_TO_CPP


if TYPE_CHECKING:
    from torch._inductor.codegen.cuda.cuda_template import CUDATemplate

log = logging.getLogger(__name__)

cexpr = CppPrinter().doprint


def _normalize_idx(index: int, total_length: int) -> int:
    # Map a possibly negative index (Python-style) to a non-negative one.
    return index if index >= 0 else index + total_length


class CUDAKernel(Kernel):
    """
    Base class for CUDA / Cutlass based Kernels
    """

    overrides = OpOverrides  # type: ignore[assignment]


class CUDATemplateKernel(CUDAKernel):
    """
    Template kernels defined by CUDA / Cutlass in C++.
    """

    _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream"

    def __init__(self, kernel_name) -> None:
        """
        Initializes a new instance of the CUDATemplateKernel class.

        Args:
            kernel_name (str): The name of the kernel.
        """
        super().__init__()
        self.kernel_name = kernel_name
        # Mapping from arg name to IRNode.
        self.named_nodes: Dict[str, IRNode] = {}

    def arg_name(self, node: IRNode) -> Optional[str]:
        """
        Returns the arg name of a given input or output node.
        """
        if node is None:
            return None
        return {**self.args.input_buffers, **self.args.output_buffers}.get(
            node.get_name(), None
        )

    def check_not_null(self, node: IRNode) -> str:
        """
        Generates code to check that a node is not null.
        """

        if node is None:
            return ""

        size_str = self.size(node, 0, -1)
        name_str = self.arg_name(node)
        if name_str is None:
            return ""

        res = IndentedBuffer(initial_indent=2)
        res.tabwidth = 1
        res.splice(
            f"""
            {{
              if (!{name_str}) {{
                int64_t {name_str}_size = {size_str};
                if ({name_str}_size > 0) {{
                  throw std::runtime_error("input {name_str} is null but size is not 0!");
                }}
              }}
            }}
            """
        )
        return res.getvalue()
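
    # Illustrative sketch of the guard check_not_null() generates, assuming a
    # hypothetical input buffer bound to the C++ argument name "X" with
    # symbolic sizes [M, N] (so size(node, 0, -1) renders the product "M*N"):
    #
    #   {
    #     if (!X) {
    #       int64_t X_size = M*N;
    #       if (X_size > 0) {
    #         throw std::runtime_error("input X is null but size is not 0!");
    #       }
    #     }
    #   }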
118 """ 119 120 names = [x.strip() for x in names_str.strip().split(",")] 121 if len(inputs) + len(outputs) != len(names): 122 raise RuntimeError( 123 f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}" 124 ) 125 126 if input_reorder is not None: 127 assert len(inputs) == len(input_reorder) 128 else: 129 input_reorder = list(range(len(inputs))) 130 131 for idx in input_reorder: 132 name = names[idx] 133 node = inputs[idx] 134 if node is not None: 135 self.named_nodes[name] = node 136 self.args.input_buffers[node.get_name()] = name 137 138 for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): 139 if node is not None: 140 self.named_nodes[name] = node 141 self.args.output_buffers[node.get_name()] = name 142 143 arg_defs, *_ = self.args.cpp_argdefs() 144 return f"PT_EXPORT int {self.kernel_name}({', '.join(arg_defs)}, {self._EXTRA_CPP_ARGS})" 145 146 def call_kernel( 147 self, 148 name: str, 149 node: "CUDATemplateBuffer", # type: ignore[name-defined] 150 ) -> None: 151 """ 152 Generates code to call the kernel through V.graph.wrapper_code. 153 used from within torch._inductor.wrapper.WrapperCodeGen 154 155 name: Name of kernel function. 156 node: The CUDATemplateBuffer node which contains information about the kernel, it's fused epilogue nodes 157 as well as all required inputs and outputs. 158 """ 159 wrapper = V.graph.wrapper_code 160 _, call_args, _, arg_types = self.args.python_argdefs() 161 # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar 162 for i in range(len(call_args)): 163 if V.graph.is_unspec_arg(call_args[i]): 164 call_args[i] = call_args[i] + ".item()" 165 else: 166 call_args[i] = f"c_void_p({call_args[i]}.data_ptr())" 167 168 # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size. 169 # workspace_size should have already been retrieved prior to this call. 170 call_args.append("None") 171 172 if node.get_workspace_size() > 0: 173 wrapper.generate_workspace_allocation( 174 node.get_workspace_size(), V.graph.scheduler.current_device, False 175 ) 176 call_args.append("c_void_p(workspace.data_ptr())") 177 else: 178 call_args.append("None") 179 180 wrapper.generate_kernel_call( 181 name, 182 call_args, 183 cuda=True, 184 triton=False, 185 arg_types=arg_types, 186 ) 187 if node.get_workspace_size() > 0: 188 wrapper.writeline(wrapper.make_free_by_names(["workspace"])) 189 190 def dtype(self, node: IRNode) -> Optional[str]: 191 """ 192 Generates code which represents dtype of a given node. 193 """ 194 195 if node is None: 196 return "void" 197 return DTYPE_TO_CPP.get(node.get_layout().dtype) 198 199 def cutlass_dtype(self, node: IRNode, default_dtype="void") -> Optional[str]: 200 # Helper method, called into from CUTLASSGemmTemplate 201 if node is None: 202 return default_dtype 203 from torch._inductor.codegen.cuda.cuda_template import CUTLASSTemplate 204 205 return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype] 206 207 def max_valid_index(self, node: IRNode, default=-1): 208 # Helper method, called into from CUTLASSGemmTemplate 209 if node is None: 210 return default 211 max_valid_offset = 0 212 for i in range(len(node.get_size())): 213 max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i] 214 return max_valid_offset 215 216 def offset(self, node: IRNode) -> str: 217 """ 218 Generates code which represents offset of a given node. 
219 """ 220 221 if node is None: 222 return "0" 223 return str(node.get_layout().offset) 224 225 def ptr(self, node: IRNode) -> str: 226 """ 227 Generates code which represents pointer of a given node. 228 """ 229 230 if node is None: 231 return "nullptr" 232 arg_name = self.arg_name(node) 233 if arg_name is None: 234 return "nullptr" 235 offset = self.offset(node) 236 return arg_name if offset == "0" else f"{arg_name} + {offset}" 237 238 def size( 239 self, 240 node: IRNode, 241 start_index: int, 242 end_index: Optional[int] = None, 243 default_value: int = 0, 244 ) -> str: 245 """ 246 Hook called from template code to get the size of an arg. 247 Generates code which represents size of a given node in [start_index, end_index). 248 If node is None, returns default_value. 249 250 TODO: Will add needed args to pass it in if it is dynamic. 251 """ 252 253 if node is None: 254 return str(default_value) 255 256 start_index = _normalize_idx(start_index, len(node.get_size())) 257 if end_index is None: 258 end_index = start_index 259 end_index = _normalize_idx(end_index, len(node.get_size())) 260 261 sizes = node.get_size()[start_index : end_index + 1] 262 if len(sizes) == 0: 263 return str(default_value) 264 265 val = sympy_product(sizes) 266 return cexpr(self.rename_indexing(val)) 267 268 def stride(self, node: IRNode, index: int, default_value: int = 0) -> str: 269 """ 270 Hook called from template code to get the stride of an arg. 271 Generates code which represents stride of a given node at index. 272 If node is None, returns default_value. 273 274 TODO: Will add needed args to pass it in if it is dynamic. 275 """ 276 277 if node is None: 278 return str(default_value) 279 280 index = _normalize_idx(index, len(node.get_size())) 281 if index < 0: 282 return str(default_value) 283 284 stride = node.get_stride()[index] 285 return cexpr(self.rename_indexing(stride)) 286 287 def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str: 288 """ 289 Hook called from template code to get the row or column stride of an arg. 290 This is required by some CUTLASS 2.X APIs. 291 If the node is in row_major, it returns stride[-2]. 292 If the node is in column_major, it returns stride[-1]. 293 294 TODO: Will add needed args to pass it in if it is dynamic. 295 """ 296 297 if node is None or len(node.get_stride()) < 2: 298 return str(default_value) 299 300 stride0 = node.get_stride()[-1] 301 stride1 = node.get_stride()[-2] 302 if stride0 == 1: 303 return cexpr(self.rename_indexing(stride1)) 304 elif stride1 == 1: 305 return cexpr(self.rename_indexing(stride0)) 306 else: 307 raise RuntimeError( 308 f"At least 1 stride should be 1. Strides: {node.get_stride()=}" 309 ) 310 311 312class CUDATemplateCaller(ChoiceCaller): 313 """ 314 CUDATemplateCaller 315 316 This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller. 317 Attributes: 318 name (str): The name of the caller. 319 category (str): The category of the caller. 320 bmreq (CUDABenchmarkRequest): The benchmark request for the caller. 321 template_buffer (CUDATemplateBuffer): The template buffer for the caller. 
322 """ 323 324 def __init__( 325 self, 326 name: str, 327 category: str, 328 input_nodes: List[Buffer], 329 layout: Layout, 330 make_kernel_render: Callable[[CUDATemplateBuffer, Optional[List[IRNode]]], str], 331 bmreq: CUDABenchmarkRequest, 332 template: "CUDATemplate", # type: ignore[name-defined] 333 info_kwargs: Optional[Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]], # type: ignore[type-arg] 334 ) -> None: 335 super().__init__(name, input_nodes, layout) 336 self.category = category 337 self.make_kernel_render = make_kernel_render 338 self.bmreq = bmreq 339 self.template = template 340 self.info_kwargs = info_kwargs 341 342 def precompile(self) -> None: 343 assert self.bmreq is not None 344 self.bmreq.precompile() 345 346 def benchmark(self, *args, out) -> float: 347 assert self.bmreq is not None 348 return self.bmreq.benchmark( 349 *args, output_tensor=out 350 ) # @TODO: Hack for ensuring that Cutlass Kernel is preferred 351 352 def __str__(self) -> str: 353 return f"CUDATemplateCaller(source_file={self.bmreq.source_file})" 354 355 def call_name(self) -> str: 356 return f"cuda_template_kernels.{self.name}" 357 358 def hash_key(self) -> str: 359 return "-".join( 360 [ 361 self.category, 362 self.bmreq.hash_key, 363 ] 364 ) 365 366 def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: 367 """Information returned here is logged to the autotune log file when that is enabled.""" 368 if self.info_kwargs is not None and "op" in self.info_kwargs: 369 op: Any = self.info_kwargs["op"] 370 return { 371 "backend": "CUDA", 372 "op_type": type(op).__name__, 373 "op_conf_name": str(op.configuration_name()), 374 "op_arch": str(op.arch), 375 "tile_shape": str(op.tile_description.tile_shape), 376 "epilogue_schedule": str(op.epilogue_schedule), 377 "kernel_schedule": str(op.kernel_schedule), 378 "element_accumulator": str(op.accumulator_type()), 379 "op_name": str(op.procedural_name()), 380 "instruction_shape": str( 381 op.tile_description.math_instruction.instruction_shape 382 ), 383 } 384 else: 385 return {"backend": "CUDA", "op_type": "unknown"} 386 387 def output_node(self) -> TensorBox: 388 self.bmreq.update_workspace_size() 389 return TensorBox.create( 390 CUDATemplateBuffer( 391 layout=self.layout, 392 inputs=self.input_nodes, 393 make_kernel_render=self.make_kernel_render, 394 workspace_size=self.bmreq.workspace_size, 395 template=self.template, 396 ) 397 ) 398