• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from __future__ import annotations
2
3import itertools
4from typing import Sequence
5
6from torchgen.api import cpp
7from torchgen.api.types import DispatcherSignature
8from torchgen.code_template import CodeTemplate
9from torchgen.context import with_native_function
10from torchgen.model import Argument, NativeFunction, SchemaKind, TensorOptionsArguments
11from torchgen.utils import FileManager
12
13
# Note [Manual Backend kernels]
# For these ops, we want to manually register to dispatch key Backend and
# skip codegen-ed registration to all keys before Backend.
# For codegen this means:
#   - op set below must match ops with manual_kernel_registration=True in native_functions.yaml
#     where we skip codegen backend kernels
#   - all ops below are part of MANUAL_AUTOGRAD to skip codegen Autograd kernel registration
#   - all ops below are part of MANUAL_TRACER to skip codegen Tracer kernel registration
# Note: we still register to dispatch key Profiler for these ops, keeping it untouched for now.
# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
MANUAL_BACKEND = {
    "options",
    "data",
    "set_data",
    "is_leaf",
    "output_nr",
    "_version",
    "retain_grad",
    "_backward",
    "requires_grad_",
}

# For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys.
# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
MANUAL_AUTOGRAD_AND_TRACER = {
    "resize_",
    "resize_as_",
    "detach",
    "detach_",
    "copy_",
    "_fw_primal",
    "_make_dual",
}

# Currently MANUAL_AUTOGRAD and MANUAL_TRACER share the same set of ops:
#   union(MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER)
# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
MANUAL_AUTOGRAD = MANUAL_TRACER = MANUAL_BACKEND | MANUAL_AUTOGRAD_AND_TRACER
52
# These functions we don't want to record for tracing, because we always want
# to trace their constituent parts.  This is a temporary hack in lieu
# of proper scopes, where subsequent compilation passes can ask for the unfolding
# on demand.  Only concrete ATen methods can be disabled this way; it will have
# NO EFFECT otherwise.
DONT_RECORD_TRACE = {
    "convolution",
    "conv1d",
    "conv2d",
    "conv3d",
    "conv_transpose1d",
    "conv_transpose2d",
    "conv_transpose3d",
    "lstm_cell",
    "gru_cell",
    "rnn_tanh_cell",
    "rnn_relu_cell",
    # FIXME: figure out a better way when we support sparse tensors in jit
    "_coalesced",
}
73
74
def should_trace(f: NativeFunction) -> bool:
    """Return True if a tracing kernel should be generated for `f`."""
    # Operations involving Storage or Type are not traceable at the moment
    untraceable_arg_types = {"Storage", "Type", "ConstQuantizerPtr"}
    for arg in f.func.schema_order_arguments():
        if str(arg.type) in untraceable_arg_types:
            return False
    # We can't trace functions which don't have any Tensor or TensorList returns
    has_tensor_like_return = any(r.type.is_tensor_like() for r in f.func.returns)
    if not has_tensor_like_return:
        return False
    return f.func.name.name.base not in DONT_RECORD_TRACE
86
87
# Conditional C++ snippet: emits the `true` branch when ${cond} holds at
# runtime, otherwise the `false` branch.
SELECT = CodeTemplate(
    """\

if (${cond}) {
  ${true}
} else {
  ${false}
}
"""
)

# Assigns the traced op's fully-qualified name to the local `op_name` symbol
# declared in PRE_RECORD_TRACE.
OP_NAME = CodeTemplate(
    """\
op_name = c10::Symbol::fromQualString("aten::${trace_name}");
"""
)

# These functions have their names recorded under trace renamed.
RENAME_TRACE = {
    "zero": "zeros_like",  # replacing aten::zero_ with aten::zeros_like
    "fill": "full_like",  # replacing aten::fill_ with aten::full_like
}
110
111
def format_trace_op_name(f: NativeFunction) -> str:
    """Emit the C++ that assigns `op_name` for `f`'s trace node.

    Functional/out ops (and dunder methods) record a single fixed name;
    in-place ops record either the in-place or out-of-place name depending
    on `tracer_state->force_outplace`.
    """
    # TODO: byte-for-byte compatible with old codegen behavior - should clean up
    records_single_name = (
        f.func.kind() in (SchemaKind.functional, SchemaKind.out)
        or f.func.name.name.dunder_method
    )
    if records_single_name:
        # special case for *_out functions: the in-place and out-of-place ops
        # are overloaded with the same name in the JIT
        name = str(f.func.name.name)
        return OP_NAME.substitute(trace_name=RENAME_TRACE.get(name, name))

    # otherwise, this is an in-place op and we need to emit both in- and
    # out-of-place versions
    outplace = f.func.name.name.base
    inplace = cpp.name(f.func)
    return SELECT.substitute(
        cond="tracer_state->force_outplace",
        true=OP_NAME.substitute(trace_name=RENAME_TRACE.get(outplace, outplace)),
        false=OP_NAME.substitute(trace_name=RENAME_TRACE.get(inplace, inplace)),
    )
136
137
# Records a single named input on the trace node.
ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${input});""")

140
def format_trace_inputs(f: NativeFunction) -> str:
    """Emit the jit::tracer::addInputs calls recording `f`'s arguments on the trace node."""

    def dispatch_trace_input(arg: Argument | TensorOptionsArguments) -> Sequence[str]:
        # TensorOptions is traced as its four constituent fields rather than
        # as one opaque struct.
        if isinstance(arg, TensorOptionsArguments):
            name = "options"
            return [
                ADD_TRACE_INPUT.substitute(
                    name=name, input="c10::optTypeMetaToScalarType(options.dtype_opt())"
                ),
                ADD_TRACE_INPUT.substitute(name=name, input="options.layout()"),
                ADD_TRACE_INPUT.substitute(name=name, input="options.device()"),
                ADD_TRACE_INPUT.substitute(name=name, input="options.pinned_memory()"),
            ]
        else:
            name = arg.name
            if str(arg.type) == "Tensor?[]":
                return [f'jit::tracer::addInputs(node, "{name}", {name});']
            else:
                return [ADD_TRACE_INPUT.substitute(name=name, input=name)]

    args: list[Argument | TensorOptionsArguments] = list(
        f.func.schema_order_arguments()
    )

    if f.func.is_out_fn():
        # *_out functions take the result as a separate argument, but we don't want to
        # trace that argument directly. Instead, we trace its TensorOptions.
        # So first, we need to remove the out argument from the list of arguments to trace.
        num_out_args = len(f.func.arguments.out)
        args = args[:-num_out_args]

    trace_inputs = itertools.chain.from_iterable(
        dispatch_trace_input(arg) for arg in args
    )

    if f.func.is_out_fn():
        # for *_out functions, handle the result argument differently for inplace/outplace.
        # For inplace: just add the input to the end to confirm with the JIT schema
        inplace = [
            ADD_TRACE_INPUT.substitute(
                name=f.func.arguments.out[i].name, input=f.func.arguments.out[i].name
            )
            for i in range(num_out_args)
        ]

        # for outplace: do nothing, except if the function is a factory.
        # Factories are a bit special because their out-of-place overloads
        # take an extra TensorOptions argument, which is missing in the _out function
        has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns)
        has_tensor_input_arg = any(
            a.type.is_tensor_like() for a in f.func.arguments.flat_non_out
        )
        is_factory_method = f.category_override == "factory" or (
            has_tensor_return and not has_tensor_input_arg
        )

        # HACK: preserve old codegen behavior - the old codegen set the `is_factory_method`
        # flag for the whole family of ops with the same basename if any of them is a
        # factory method. For most cases the whole family of ops are indeed all factory
        # method - 'normal' is the only exception. So we handle it specially here to avoid
        # cloning the old logic.
        if f.func.name.name.base == "normal":
            is_factory_method = True

        if is_factory_method:
            # Trace the out tensor's TensorOptions in place of the TensorOptions
            # argument that the out-of-place overload takes.
            outplace = [
                ADD_TRACE_INPUT.substitute(
                    name="out",
                    input="c10::optTypeMetaToScalarType(out.options().dtype_opt())",
                ),
                ADD_TRACE_INPUT.substitute(name="out", input="out.options().layout()"),
                ADD_TRACE_INPUT.substitute(name="out", input="out.options().device()"),
                ADD_TRACE_INPUT.substitute(
                    name="out", input="out.options().pinned_memory()"
                ),
            ]
        else:
            outplace = []

        trace_inputs = itertools.chain(
            trace_inputs,
            [
                SELECT.substitute(
                    cond="tracer_state->force_outplace",
                    true="\n".join(outplace),
                    false="\n".join(inplace),
                )
            ],
        )

    return "\n".join(trace_inputs)
231
232
# `torch.jit.trace` has an undocumented keyword argument `_force_outplace`,
# which forces jit to replace functions with outplace variants (for
# example `aten::add_` becomes `aten::add`).
#
# This replacement is implemented in-place with minimum modifications of
# the arguments stack (as it assumes that the outplace call has the same
# arguments as the inplace version).
#
# However there are no such substitutions available for `aten::fill_`
# and `aten::zero_` operators, as we never implemented `aten::fill`
# and `aten::zero`. So the jit tracing hack replaces `aten::zero_` with
# `aten::zeros_like` and replaces `aten::fill_` with `aten::full_like`.
#
# But as they potentially can have different arguments, we also have
# to hack into the stack and add missing ones.
#
# A possible alternative would be:
#
#  - Add `aten::fill` and `aten::zero`
#
#  - Or keep `aten::zeros_like` arguments aligned with `aten::zero_`
# arguments (inside of the `native_functions.yaml`)
RENAME_TRACE_ADD_ARGS = {
    "fill": """\
    jit::tracer::addInputs(node, "options", ::std::optional<ScalarType>());
    jit::tracer::addInputs(node, "options", layout_or_default(::std::nullopt));
    jit::tracer::addInputs(node, "options", device_or_default(::std::nullopt));
    jit::tracer::addInputs(node, "options", pinned_memory_or_default(::std::nullopt));
    ::std::optional<MemoryFormat> memory_format = c10::MemoryFormat::Preserve;
    jit::tracer::addInputs(node, "memory_format", memory_format);
""",
    "zero": """\
    jit::tracer::addInputs(node, "options", ::std::optional<ScalarType>());
    jit::tracer::addInputs(node, "options", layout_or_default(::std::nullopt));
    jit::tracer::addInputs(node, "options", device_or_default(::std::nullopt));
    jit::tracer::addInputs(node, "options", pinned_memory_or_default(::std::nullopt));
    ::std::optional<MemoryFormat> memory_format = c10::MemoryFormat::Preserve;
    jit::tracer::addInputs(node, "memory_format", memory_format);
""",
}

# Runtime check that an in-place op traced as outplace is not aliasing trouble.
INPLACE_GUARD = CodeTemplate(
    """\
jit::tracer::ensureUniqueIfOutOfPlaced("${name}", ${mutable_input});
"""
)
279
# Boilerplate emitted before the redispatch: create the trace node, record
# its inputs, insert it into the graph, then detach the tracing state so the
# redispatched call below is not itself traced.
PRE_RECORD_TRACE = CodeTemplate(
    """\
torch::jit::Node* node = nullptr;
std::shared_ptr<jit::tracer::TracingState> tracer_state;
if (jit::tracer::isTracing()) {
  tracer_state = jit::tracer::getTracingState();
  at::Symbol op_name;
  ${set_op_name}
  node = tracer_state->createNode(op_name, /*num_outputs=*/0);
  jit::tracer::recordSourceLocation(node);
  ${add_trace_inputs}
  tracer_state->insertNode(node);
  ${inplace_guard}
  jit::tracer::setTracingState(nullptr);
}
"""
)
297
298
def format_prerecord_trace(f: NativeFunction) -> str:
    """Emit the C++ that records a trace node (name + inputs) before dispatching `f`."""
    if not should_trace(f):
        return ""

    # TODO: clean up old codegen behavior
    is_inplace = (
        f.func.kind() in (SchemaKind.inplace, SchemaKind.out)
        and not f.func.name.name.dunder_method
    )

    # In-place ops renamed under trace (zero_/fill_) need extra stack entries
    # to match the schema of their out-of-place replacements.
    add_args = ""
    if is_inplace:
        add_args = RENAME_TRACE_ADD_ARGS.get(f.func.name.name.base, "")
    if add_args:
        additional_inputs = SELECT.substitute(
            cond="tracer_state->force_outplace",
            true=add_args,
            false="",
        )
    else:
        additional_inputs = ""

    if is_inplace:
        mutable = f.func.arguments.out[0].name if f.func.arguments.out else "self"
        inplace_guard = INPLACE_GUARD.substitute(
            name=cpp.name(f.func),
            mutable_input=mutable,
        )
    else:
        inplace_guard = ""

    return PRE_RECORD_TRACE.substitute(
        set_op_name=format_trace_op_name(f),
        add_trace_inputs=format_trace_inputs(f) + additional_inputs,
        inplace_guard=inplace_guard,
    )
333
334
# Boilerplate emitted after the redispatch: restore the tracing state and
# record the call's outputs on the node created by PRE_RECORD_TRACE.
POST_RECORD_TRACE = CodeTemplate(
    """\
if (tracer_state) {
  jit::tracer::setTracingState(std::move(tracer_state));
  ${add_trace_outputs}
}
"""
)
343
344
def format_postrecord_trace(f: NativeFunction) -> str:
    """Emit the C++ that records `f`'s outputs on the trace node after dispatch."""
    if not should_trace(f):
        return ""

    def add_output(n: str) -> str:
        return f"jit::tracer::addOutput(node, {n});"

    # For outplacing ops, *_out overloads require special handling to move the
    # output *argument* to a return value
    if not f.func.is_out_fn():
        return POST_RECORD_TRACE.substitute(
            add_trace_outputs=[add_output(n) for n in cpp.return_names(f)]
        )

    output_names_outplace = [arg.name for arg in f.func.arguments.out]
    output_names_inplace = cpp.return_names(f)

    # Code size optimization: the common case is that the return value is
    # the same for both variants
    if output_names_outplace == output_names_inplace:
        return POST_RECORD_TRACE.substitute(
            add_trace_outputs=[add_output(n) for n in output_names_outplace]
        )

    selection = SELECT.substitute(
        cond="force_outplace",
        true="\n".join(add_output(n) for n in output_names_outplace),
        false="\n".join(add_output(n) for n in output_names_inplace),
    )
    return POST_RECORD_TRACE.substitute(add_trace_outputs=selection)
377
378
def tie_return_values(f: NativeFunction) -> str:
    """Return the C++ left-hand side binding the result(s) of the redispatch."""
    if len(f.func.returns) == 1:
        # Unnamed single returns fall back to the conventional name `result`.
        single_name = f.func.returns[0].name or "result"
        return f"auto {single_name}"
    # Multiple returns are unpacked with structured bindings.
    return f'auto [{", ".join(cpp.return_names(f))}]'
384
385
def get_return_value(f: NativeFunction) -> str:
    """Return the C++ expression handed to `return` for `f`'s result(s)."""
    names = cpp.return_names(f)
    if len(f.func.returns) == 1:
        return names[0]
    # out= overloads return references to the out arguments, so no move.
    if f.func.kind() == SchemaKind.out:
        return f'std::forward_as_tuple({", ".join(names)})'
    moved = ", ".join(f"std::move({n})" for n in names)
    return f"std::make_tuple({moved})"
395
396
# The actual redispatch past the Tracer key, optionally binding the result.
TRACE_DISPATCH = CodeTemplate(
    """\
${assign_return_values}at::_ops::${unambiguous_name}::redispatch(${unpacked_args});"""
)
401
402
def emit_trace_body(f: NativeFunction) -> list[str]:
    """Build the statements making up the body of `f`'s tracing kernel."""
    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    keyset_expr = "ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer)"
    redispatch_args = ", ".join(
        [keyset_expr] + [a.expr for a in dispatcher_sig.exprs()]
    )

    # Only functional/mutable ops with returns bind the redispatch result;
    # out= ops return the out arguments themselves.
    needs_assignment = (
        f.func.kind() in [SchemaKind.functional, SchemaKind.mutable]
        and f.func.returns
    )
    assign = f"{tie_return_values(f)} = " if needs_assignment else ""

    trace_body: list[str] = [format_prerecord_trace(f)]

    # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
    # We could probably work harder to ensure that the fast variants are
    # called instead, but the perf benefit would be minimal.
    trace_body.append(
        TRACE_DISPATCH.substitute(
            assign_return_values=assign,
            unambiguous_name=f.func.name.unambiguous_name(),
            unpacked_args=redispatch_args,
        )
    )

    trace_body.append(format_postrecord_trace(f))
    if f.func.returns:
        trace_body.append(f"return {get_return_value(f)};")
    return trace_body
438
439
# Skeleton of a generated tracing kernel definition.
METHOD_DEFINITION = CodeTemplate(
    """\
${return_type} ${type_wrapper_name}(${formals}) {
  ${type_definition_body}
}
"""
)
447
448
def type_wrapper_name(f: NativeFunction, key: str = "Default") -> str:
    """Name of the generated wrapper function for `f`, optionally per dispatch key."""
    base = cpp.name(f.func)
    overload = f.func.name.overload_name
    name = f"{base}_{overload}" if overload else base

    # The key argument is only used in gen_variable_type where we need fns per autograd dispatch key.
    # In gen_trace_type and gen_inplace_view_type where only one fn per native_fn must be generated,
    # the key argument should not be passed.
    # We do not append key if it is Default so that generated functions from
    # before per-dispatch-key derivatives were added retain the same names.
    return name if key == "Default" else f"{name}_{key}"
463
464
@with_native_function
def method_definition(f: NativeFunction) -> str:
    """Render the full C++ definition of the tracing kernel for `f`."""
    # MANUAL_TRACER ops are registered by hand and must never reach codegen.
    assert cpp.name(f.func) not in MANUAL_TRACER

    # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    params = ["c10::DispatchKeySet ks"]
    for a in f.func.schema_order_arguments():
        arg_type = cpp.argument_type(a, binds="__placeholder__", symint=True).cpp_type()
        params.append(f"{arg_type} {a.name}")

    return METHOD_DEFINITION.substitute(
        return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(),
        type_wrapper_name=type_wrapper_name(f),
        formals=", ".join(params),
        type_definition_body=emit_trace_body(f),
    )
485
486
# TORCH_LIBRARY impl registration line for a generated tracing kernel.
WRAPPER_REGISTRATION = CodeTemplate(
    """\
m.impl("${name}",
       TORCH_FN(${class_type}::${type_wrapper_name})
);
"""
)
494
495
@with_native_function
def method_registration(f: NativeFunction) -> str:
    """Render the m.impl(...) registration for `f`'s tracing kernel."""
    # MANUAL_TRACER ops are registered by hand and must never reach codegen.
    assert cpp.name(f.func) not in MANUAL_TRACER

    return WRAPPER_REGISTRATION.substitute(
        class_type="TraceType",
        name=f.func.name,
        type_wrapper_name=type_wrapper_name(f),
    )
505
506
def gen_trace_type_func(fn: NativeFunction) -> dict[str, list[str]]:
    """Build the per-function environment for the sharded TraceType.cpp template."""
    ops_header = f"#include <ATen/ops/{fn.root_name}_ops.h>"
    return {
        "ops_headers": [ops_header],
        "trace_method_definitions": [method_definition(fn)],
        "trace_wrapper_registrations": [method_registration(fn)],
    }
513
514
def gen_trace_type(
    out: str, native_functions: list[NativeFunction], template_path: str
) -> None:
    """Write the sharded TraceType.cpp files for all traceable native functions.

    Ops in MANUAL_TRACER are excluded; their tracing kernels are registered
    manually in torch/csrc/autograd/VariableTypeManual.cpp.
    """
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_sharded(
        "TraceType.cpp",
        [fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER],
        key_fn=lambda fn: fn.root_name,
        base_env={
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/TraceType.cpp",
        },
        env_callable=gen_trace_type_func,
        num_shards=5,
        sharded_keys={
            "ops_headers",
            "trace_method_definitions",
            "trace_wrapper_registrations",
        },
    )
537