1from __future__ import annotations 2 3import itertools 4from typing import Sequence 5 6from torchgen.api import cpp 7from torchgen.api.types import DispatcherSignature 8from torchgen.code_template import CodeTemplate 9from torchgen.context import with_native_function 10from torchgen.model import Argument, NativeFunction, SchemaKind, TensorOptionsArguments 11from torchgen.utils import FileManager 12 13 14# Note [Manual Backend kernels] 15# For these ops, we want to manually register to dispatch key Backend and 16# skip codegen-ed registeration to all keys before Backend. 17# For codegen this means: 18# - op set below must match ops with manual_kernel_registration=True in native_functions.yaml 19# where we skip codegen backend kernels 20# - all ops below are part of MANUAL_AUTOGRAD to skip codegen Autograd kernel registration 21# - all ops below are part of MANUAL_TRACER to skip codegen Tracer kernel registration 22# Note: we still register to dispatch key Profiler for these ops, keeping it untouched for now. 23# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp 24MANUAL_BACKEND = { 25 "options", 26 "data", 27 "set_data", 28 "is_leaf", 29 "output_nr", 30 "_version", 31 "retain_grad", 32 "_backward", 33 "requires_grad_", 34} 35 36# For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys. 37# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp 38MANUAL_AUTOGRAD_AND_TRACER = { 39 "resize_", 40 "resize_as_", 41 "detach", 42 "detach_", 43 "copy_", 44 "_fw_primal", 45 "_make_dual", 46} 47 48# Currently MANUAL_AUTOGRAD and MANUAL_TRACER share the same set of ops: 49# union(MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER) 50# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp 51MANUAL_AUTOGRAD = MANUAL_TRACER = MANUAL_BACKEND | MANUAL_AUTOGRAD_AND_TRACER 52 53# These functions we don't want to record for tracing, because we always want 54# to trace their constituent parts. This is a temporary hack in lieue 55# of proper scopes, where subsequent compilation passes can ask for the unfolding 56# on demand. Only concrete ATen methods can be disabled this way; it will have 57# NO EFFECT otherwise. 58DONT_RECORD_TRACE = { 59 "convolution", 60 "conv1d", 61 "conv2d", 62 "conv3d", 63 "conv_transpose1d", 64 "conv_transpose2d", 65 "conv_transpose3d", 66 "lstm_cell", 67 "gru_cell", 68 "rnn_tanh_cell", 69 "rnn_relu_cell", 70 # FIXME: figure out a better way when we support sparse tensors in jit 71 "_coalesced", 72} 73 74 75def should_trace(f: NativeFunction) -> bool: 76 # Operations involving Storage or Type are not traceable at the moment 77 if any( 78 str(arg.type) in {"Storage", "Type", "ConstQuantizerPtr"} 79 for arg in f.func.schema_order_arguments() 80 ): 81 return False 82 # We can't trace functions which don't have any Tensor or TensorList returns 83 if not any(r.type.is_tensor_like() for r in f.func.returns): 84 return False 85 return f.func.name.name.base not in DONT_RECORD_TRACE 86 87 88SELECT = CodeTemplate( 89 """\ 90 91if (${cond}) { 92 ${true} 93} else { 94 ${false} 95} 96""" 97) 98 99OP_NAME = CodeTemplate( 100 """\ 101op_name = c10::Symbol::fromQualString("aten::${trace_name}"); 102""" 103) 104 105# These functions have their names recorded under trace renamed, 106RENAME_TRACE = { 107 "zero": "zeros_like", # replacing aten::zero_ with aten::zeros_like 108 "fill": "full_like", # replacing aten::fill_ with aten::full_like 109} 110 111 112def format_trace_op_name(f: NativeFunction) -> str: 113 # TODO: byte-for-byte compatible with old codegen behavior - should clean up 114 if ( 115 f.func.kind() in (SchemaKind.functional, SchemaKind.out) 116 or f.func.name.name.dunder_method 117 ): 118 # special case for *_out functions: the in-place and out-of-place ops 119 # are overloaded with the same name in the JIT 120 trace_name = str(f.func.name.name) 121 trace_name = RENAME_TRACE.get(trace_name, trace_name) 122 return OP_NAME.substitute(trace_name=trace_name) 123 124 # otherwise, this is an in-place op and we need to emit both in- and 125 # out-of-place versions 126 outplace_trace_name = f.func.name.name.base 127 inplace_trace_name = cpp.name(f.func) 128 outplace_trace_name = RENAME_TRACE.get(outplace_trace_name, outplace_trace_name) 129 inplace_trace_name = RENAME_TRACE.get(inplace_trace_name, inplace_trace_name) 130 131 return SELECT.substitute( 132 cond="tracer_state->force_outplace", 133 true=OP_NAME.substitute(trace_name=outplace_trace_name), 134 false=OP_NAME.substitute(trace_name=inplace_trace_name), 135 ) 136 137 138ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${input});""") 139 140 141def format_trace_inputs(f: NativeFunction) -> str: 142 def dispatch_trace_input(arg: Argument | TensorOptionsArguments) -> Sequence[str]: 143 if isinstance(arg, TensorOptionsArguments): 144 name = "options" 145 return [ 146 ADD_TRACE_INPUT.substitute( 147 name=name, input="c10::optTypeMetaToScalarType(options.dtype_opt())" 148 ), 149 ADD_TRACE_INPUT.substitute(name=name, input="options.layout()"), 150 ADD_TRACE_INPUT.substitute(name=name, input="options.device()"), 151 ADD_TRACE_INPUT.substitute(name=name, input="options.pinned_memory()"), 152 ] 153 else: 154 name = arg.name 155 if str(arg.type) == "Tensor?[]": 156 return [f'jit::tracer::addInputs(node, "{name}", {name});'] 157 else: 158 return [ADD_TRACE_INPUT.substitute(name=name, input=name)] 159 160 args: list[Argument | TensorOptionsArguments] = list( 161 f.func.schema_order_arguments() 162 ) 163 164 if f.func.is_out_fn(): 165 # *_out functions take the result as a separate argument, but we don't want to 166 # trace that argument directly. Instead, we trace its TensorOptions. 167 # So first, we need to remove the out argument from the list of arguments to trace. 168 num_out_args = len(f.func.arguments.out) 169 args = args[:-num_out_args] 170 171 trace_inputs = itertools.chain.from_iterable( 172 dispatch_trace_input(arg) for arg in args 173 ) 174 175 if f.func.is_out_fn(): 176 # for *_out functions, handle the result argument differently for inplace/outplace. 177 # For inplace: just add the input to the end to confirm with the JIT schema 178 inplace = [ 179 ADD_TRACE_INPUT.substitute( 180 name=f.func.arguments.out[i].name, input=f.func.arguments.out[i].name 181 ) 182 for i in range(num_out_args) 183 ] 184 185 # for outplace: do nothing, except if the function is a factory. 186 # Factories are a bit special because their out-of-place overloads 187 # take an extra TensorOptions argument, which is missing in the _out function 188 has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns) 189 has_tensor_input_arg = any( 190 a.type.is_tensor_like() for a in f.func.arguments.flat_non_out 191 ) 192 is_factory_method = f.category_override == "factory" or ( 193 has_tensor_return and not has_tensor_input_arg 194 ) 195 196 # HACK: preserve old codegen behavior - the old codegen set the `is_factory_method` 197 # flag for the whole family of ops with the same basename if any of them is a 198 # factory method. For most cases the whole family of ops are indeed all factory 199 # method - 'normal' is the only exception. So we handle it specially here to avoid 200 # cloning the old logic. 201 if f.func.name.name.base == "normal": 202 is_factory_method = True 203 204 if is_factory_method: 205 outplace = [ 206 ADD_TRACE_INPUT.substitute( 207 name="out", 208 input="c10::optTypeMetaToScalarType(out.options().dtype_opt())", 209 ), 210 ADD_TRACE_INPUT.substitute(name="out", input="out.options().layout()"), 211 ADD_TRACE_INPUT.substitute(name="out", input="out.options().device()"), 212 ADD_TRACE_INPUT.substitute( 213 name="out", input="out.options().pinned_memory()" 214 ), 215 ] 216 else: 217 outplace = [] 218 219 trace_inputs = itertools.chain( 220 trace_inputs, 221 [ 222 SELECT.substitute( 223 cond="tracer_state->force_outplace", 224 true="\n".join(outplace), 225 false="\n".join(inplace), 226 ) 227 ], 228 ) 229 230 return "\n".join(trace_inputs) 231 232 233# `torch.jit.trace` have undocumented keyword argument `_force_outplace`, 234# which force jit to replace functions with outplace variants (for 235# example `aten::add_` becomes `aten::add`). 236# 237# This replacement implemented in-place with minimum modifications of 238# arguments stack (as it assumes that outplace call has the same arguments 239# as inplace version). 240# 241# However there are no such substitutions available for `aten::fill_` 242# and `aten::zero_` operators, as we never implemented `aten::fill` 243# and `aten::zero`. So jit tracing hack replacing `aten::zero_` with 244# `aten::zeros_like` and replacing `aten::fill_` with `aten::full_like`. 245# 246# But as they potentially can have different arguments, we also have 247# to hack into the stack and add missing ones. 248# 249# A possible alternative would be: 250# 251# - Add `aten::fill` and `aten::zero` 252# 253# - Or keep `aten::zeros_like` arguments aligned with `aten::zero_` 254# arguments (inside of the `native_functions.yaml`) 255RENAME_TRACE_ADD_ARGS = { 256 "fill": """\ 257 jit::tracer::addInputs(node, "options", ::std::optional<ScalarType>()); 258 jit::tracer::addInputs(node, "options", layout_or_default(::std::nullopt)); 259 jit::tracer::addInputs(node, "options", device_or_default(::std::nullopt)); 260 jit::tracer::addInputs(node, "options", pinned_memory_or_default(::std::nullopt)); 261 ::std::optional<MemoryFormat> memory_format = c10::MemoryFormat::Preserve; 262 jit::tracer::addInputs(node, "memory_format", memory_format); 263""", 264 "zero": """\ 265 jit::tracer::addInputs(node, "options", ::std::optional<ScalarType>()); 266 jit::tracer::addInputs(node, "options", layout_or_default(::std::nullopt)); 267 jit::tracer::addInputs(node, "options", device_or_default(::std::nullopt)); 268 jit::tracer::addInputs(node, "options", pinned_memory_or_default(::std::nullopt)); 269 ::std::optional<MemoryFormat> memory_format = c10::MemoryFormat::Preserve; 270 jit::tracer::addInputs(node, "memory_format", memory_format); 271""", 272} 273 274INPLACE_GUARD = CodeTemplate( 275 """\ 276jit::tracer::ensureUniqueIfOutOfPlaced("${name}", ${mutable_input}); 277""" 278) 279 280PRE_RECORD_TRACE = CodeTemplate( 281 """\ 282torch::jit::Node* node = nullptr; 283std::shared_ptr<jit::tracer::TracingState> tracer_state; 284if (jit::tracer::isTracing()) { 285 tracer_state = jit::tracer::getTracingState(); 286 at::Symbol op_name; 287 ${set_op_name} 288 node = tracer_state->createNode(op_name, /*num_outputs=*/0); 289 jit::tracer::recordSourceLocation(node); 290 ${add_trace_inputs} 291 tracer_state->insertNode(node); 292 ${inplace_guard} 293 jit::tracer::setTracingState(nullptr); 294} 295""" 296) 297 298 299def format_prerecord_trace(f: NativeFunction) -> str: 300 if not should_trace(f): 301 return "" 302 303 # TODO: clean up old codegen behavior 304 is_inplace = ( 305 f.func.kind() in (SchemaKind.inplace, SchemaKind.out) 306 and not f.func.name.name.dunder_method 307 ) 308 add_args = ( 309 RENAME_TRACE_ADD_ARGS.get(f.func.name.name.base, "") if is_inplace else "" 310 ) 311 additional_inputs = ( 312 SELECT.substitute( 313 cond="tracer_state->force_outplace", 314 true=add_args, 315 false="", 316 ) 317 if add_args 318 else "" 319 ) 320 321 return PRE_RECORD_TRACE.substitute( 322 set_op_name=format_trace_op_name(f), 323 add_trace_inputs=format_trace_inputs(f) + additional_inputs, 324 inplace_guard=INPLACE_GUARD.substitute( 325 name=cpp.name(f.func), 326 mutable_input=f.func.arguments.out[0].name 327 if f.func.arguments.out 328 else "self", 329 ) 330 if is_inplace 331 else "", 332 ) 333 334 335POST_RECORD_TRACE = CodeTemplate( 336 """\ 337if (tracer_state) { 338 jit::tracer::setTracingState(std::move(tracer_state)); 339 ${add_trace_outputs} 340} 341""" 342) 343 344 345def format_postrecord_trace(f: NativeFunction) -> str: 346 if not should_trace(f): 347 return "" 348 349 # For outplacing ops, *_out overloads require special handling to move the 350 # output *argument* to a return value 351 if f.func.is_out_fn(): 352 output_names_outplace = [arg.name for arg in f.func.arguments.out] 353 output_names_inplace = cpp.return_names(f) 354 355 # Code size optimization: the common case is that the return value is 356 # the same for both variants 357 if output_names_outplace == output_names_inplace: 358 outputs = [ 359 f"jit::tracer::addOutput(node, {n});" for n in output_names_outplace 360 ] 361 return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs) 362 363 selection = SELECT.substitute( 364 cond="force_outplace", 365 true="\n".join( 366 f"jit::tracer::addOutput(node, {n});" for n in output_names_outplace 367 ), 368 false="\n".join( 369 f"jit::tracer::addOutput(node, {n});" for n in output_names_inplace 370 ), 371 ) 372 return POST_RECORD_TRACE.substitute(add_trace_outputs=selection) 373 else: 374 output_names = cpp.return_names(f) 375 outputs = [f"jit::tracer::addOutput(node, {n});" for n in output_names] 376 return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs) 377 378 379def tie_return_values(f: NativeFunction) -> str: 380 if len(f.func.returns) == 1: 381 return f'auto {f.func.returns[0].name or "result"}' 382 names = cpp.return_names(f) 383 return f'auto [{", ".join(names)}]' 384 385 386def get_return_value(f: NativeFunction) -> str: 387 names = cpp.return_names(f) 388 if len(f.func.returns) == 1: 389 return names[0] 390 if f.func.kind() == SchemaKind.out: 391 return f'std::forward_as_tuple({", ".join(names)})' 392 else: 393 moved = ", ".join(f"std::move({name})" for name in names) 394 return f"std::make_tuple({moved})" 395 396 397TRACE_DISPATCH = CodeTemplate( 398 """\ 399${assign_return_values}at::_ops::${unambiguous_name}::redispatch(${unpacked_args});""" 400) 401 402 403def emit_trace_body(f: NativeFunction) -> list[str]: 404 trace_body: list[str] = [] 405 406 trace_body.append(format_prerecord_trace(f)) 407 408 dispatcher_sig = DispatcherSignature.from_schema(f.func) 409 dispatcher_exprs = dispatcher_sig.exprs() 410 411 # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance. 412 # See Note [Plumbing Keys Through The Dispatcher] for details. 413 dispatch_key_set = "ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer)" 414 redispatch_args = ", ".join([dispatch_key_set] + [a.expr for a in dispatcher_exprs]) 415 416 assign_return_values = ( 417 f"{tie_return_values(f)} = " 418 if f.func.kind() in [SchemaKind.functional, SchemaKind.mutable] 419 and f.func.returns 420 else "" 421 ) 422 423 # Note that this calls the slow, dispatching variants of manual_cpp_binding ops. 424 # We could probably work harder to ensure that the fast variants are 425 # called instead, but the perf benefit would be minimal. 426 trace_body.append( 427 TRACE_DISPATCH.substitute( 428 assign_return_values=assign_return_values, 429 unambiguous_name=f.func.name.unambiguous_name(), 430 unpacked_args=redispatch_args, 431 ) 432 ) 433 434 trace_body.append(format_postrecord_trace(f)) 435 if f.func.returns: 436 trace_body.append(f"return {get_return_value(f)};") 437 return trace_body 438 439 440METHOD_DEFINITION = CodeTemplate( 441 """\ 442${return_type} ${type_wrapper_name}(${formals}) { 443 ${type_definition_body} 444} 445""" 446) 447 448 449def type_wrapper_name(f: NativeFunction, key: str = "Default") -> str: 450 if f.func.name.overload_name: 451 name = f"{cpp.name(f.func)}_{f.func.name.overload_name}" 452 else: 453 name = cpp.name(f.func) 454 455 # The key argument is only used in gen_variable_type where we need fns per autograd dispatch key. 456 # In gen_trace_type and gen_inplace_view_type where only one fn per native_fn must be generated, 457 # the key argument should not be passed. 458 # We do not append key if it is Default so that generated functions from 459 # before per-dispatch-key derivatives were added retain the same names. 460 if key != "Default": 461 name = name + f"_{key}" 462 return name 463 464 465@with_native_function 466def method_definition(f: NativeFunction) -> str: 467 assert cpp.name(f.func) not in MANUAL_TRACER 468 469 formals = ", ".join( 470 # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance. 471 # See Note [Plumbing Keys Through The Dispatcher] for details. 472 ["c10::DispatchKeySet ks"] 473 + [ 474 f'{cpp.argument_type(a, binds="__placeholder__", symint=True).cpp_type()} {a.name}' 475 for a in f.func.schema_order_arguments() 476 ] 477 ) 478 479 return METHOD_DEFINITION.substitute( 480 return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(), 481 type_wrapper_name=type_wrapper_name(f), 482 formals=formals, 483 type_definition_body=emit_trace_body(f), 484 ) 485 486 487WRAPPER_REGISTRATION = CodeTemplate( 488 """\ 489m.impl("${name}", 490 TORCH_FN(${class_type}::${type_wrapper_name}) 491); 492""" 493) 494 495 496@with_native_function 497def method_registration(f: NativeFunction) -> str: 498 assert cpp.name(f.func) not in MANUAL_TRACER 499 500 return WRAPPER_REGISTRATION.substitute( 501 name=f.func.name, 502 type_wrapper_name=type_wrapper_name(f), 503 class_type="TraceType", 504 ) 505 506 507def gen_trace_type_func(fn: NativeFunction) -> dict[str, list[str]]: 508 return { 509 "ops_headers": [f"#include <ATen/ops/{fn.root_name}_ops.h>"], 510 "trace_method_definitions": [method_definition(fn)], 511 "trace_wrapper_registrations": [method_registration(fn)], 512 } 513 514 515def gen_trace_type( 516 out: str, native_functions: list[NativeFunction], template_path: str 517) -> None: 518 # NOTE: see Note [Sharded File] at the top of the VariableType.cpp 519 # template regarding sharding of the generated files. 520 fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) 521 fm.write_sharded( 522 "TraceType.cpp", 523 [fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER], 524 key_fn=lambda fn: fn.root_name, 525 base_env={ 526 "generated_comment": "@" 527 + f"generated from {fm.template_dir_for_comments()}/TraceType.cpp", 528 }, 529 env_callable=gen_trace_type_func, 530 num_shards=5, 531 sharded_keys={ 532 "ops_headers", 533 "trace_method_definitions", 534 "trace_wrapper_registrations", 535 }, 536 ) 537