• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import contextlib
2import threading
3from typing import Callable, Generator, Iterable, Optional, Union
4
5from .custom_ops import custom_op
6from .infer_schema import infer_schema
7
8
def triton_op(
    name: str,
    fn: Optional[Callable] = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    schema: Optional[str] = None,
) -> Callable:
    """Create a custom operator whose implementation is backed by 1+ triton kernels.

    Use this instead of :func:`torch.library.custom_op` when the implementation
    consists of 1+ triton kernels. :func:`torch.library.custom_op` treats
    custom operators as opaque (:func:`torch.compile` and
    :func:`torch.export.export` will never trace into them), but ``triton_op``
    makes the implementation visible to these subsystems, allowing them
    to optimize the triton kernel(s).

    Note that ``fn`` must only consist of calls to PyTorch-understood
    operators and triton kernels. Any triton kernels called inside ``fn``
    must be wrapped in a call to :func:`torch._library.capture_triton`.

    Args:
        name (str): A name for the custom op that looks like "{namespace}::{name}",
            e.g. "mylib::my_linear". The name is used as the op's stable identifier
            in PyTorch subsystems (e.g. torch.export, FX graphs).
            To avoid name collisions, please use your project name as the namespace;
            e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
        fn (None | Callable): The function implementing the operator. If omitted,
            ``triton_op`` returns a decorator to be applied to the function.
        mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
            This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
            it pessimistically assumes that all inputs to the operator are being mutated.
        schema (None | str): A schema string for the operator. If None
            (recommended) we'll infer a schema for the operator from its type
            annotations. We recommend letting us infer a schema unless you
            have a specific reason not to.
            Example: "(Tensor x, int y) -> (Tensor, Tensor)".

    Returns:
        If ``fn`` is given, the custom operator created for it; otherwise a
        decorator that creates the custom operator from the decorated function.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch._library import triton_op, capture_triton
        >>>
        >>> import triton
        >>> from triton import language as tl
        >>>
        >>> @triton.jit
        >>> def add_kernel(
        >>>     in_ptr0,
        >>>     in_ptr1,
        >>>     out_ptr,
        >>>     n_elements,
        >>>     BLOCK_SIZE: "tl.constexpr",
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     block_start = pid * BLOCK_SIZE
        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     output = x + y
        >>>     tl.store(out_ptr + offsets, output, mask=mask)
        >>>
        >>> @triton_op("mylib::add", mutates_args=())
        >>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     # NB: we need to wrap the triton kernel in a call to capture_triton
        >>>     capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
        >>>     return output
        >>>
        >>> @torch.compile
        >>> def f(x, y):
        >>>     return add(x, y)
        >>>
        >>> x = torch.randn(3, device="cuda")
        >>> y = torch.randn(3, device="cuda")
        >>>
        >>> z = f(x, y)
        >>> assert torch.allclose(z, x + y)

    """

    def dec(fn: Callable) -> Callable:
        def backend_fn(*args, **kwargs):  # type: ignore[no-untyped-def]
            # Optimization: we're passing regular Tensors into the triton kernel, so
            # no need to go through HOP dispatch
            with set_capture_triton_enabled(False):
                return fn(*args, **kwargs)

        # backend_fn only takes *args/**kwargs, so its signature carries no type
        # information; the schema must come from the caller or from ``fn``.
        # Previously a caller-supplied ``schema`` was silently ignored.
        op_schema = (
            schema
            if schema is not None
            else infer_schema(fn, mutates_args=mutates_args)
        )
        result = custom_op(
            name,
            backend_fn,
            mutates_args=mutates_args,
            schema=op_schema,
        )
        from .._subclasses.functional_tensor import FunctionalTensorMode

        # We require that the user pass us a function that is make_fx traceable,
        # so we can just register it as the Fake/meta kernel.
        result.register_fake(fn)

        # We decompose the operator when FunctionalTensorMode is active.
        # The goal is to decompose the operator in AOTDispatcher.
        # - With torch.compile, this means that the backend (usually Inductor)
        #   can see a call to the triton kernel(s) and so it can directly optimize
        #   them by inlining them into the lowering process.
        # - With post-dispatch torch.export, this means that there will
        #   be a call(s) to the triton_kernel_wrapper_functional HOP in the
        #   graph (that we have yet to figure out how to serialize).
        def functional_decomp(  # type: ignore[no-untyped-def]
            mode, _, types, args, kwargs
        ):
            with mode:
                return fn(*args, **kwargs)

        result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
        return result

    if fn is None:
        return dec
    else:
        return dec(fn)
135
136
# Value reported by is_capture_triton_enabled() on a thread that has never
# stored an explicit setting in ``capture_triton_enabled``.
capture_triton_enabled_default: bool = True
# Per-thread holder for the capture_triton switch; set_capture_triton_enabled()
# writes its ``value`` attribute.
capture_triton_enabled = threading.local()
139
140
@contextlib.contextmanager
def set_capture_triton_enabled(enabled: bool) -> Generator[None, None, None]:
    """If triton kernels annotated with @capture_triton should dispatch via HOP
    or go straight to the triton kernel execution.

    We have this switch because eager-mode performance of HOP dispatch is slow
    enough to matter (~1ms) and we know that capture_triton isn't necessary in
    some situations (eager-mode with regular Tensors)

    Args:
        enabled: The value ``is_capture_triton_enabled()`` reports on the
            current thread while the context is active.

    The previous setting for the current thread is restored on exit, even if
    the body raises.
    """
    # Snapshot and install the new setting *before* entering ``try`` so the
    # ``finally`` clause can never run with ``prev`` unbound (which would
    # replace the original error with an UnboundLocalError).
    prev = is_capture_triton_enabled()
    capture_triton_enabled.value = enabled
    try:
        yield
    finally:
        capture_triton_enabled.value = prev
156
157
def is_capture_triton_enabled() -> bool:
    """Report whether capture_triton should wrap kernels for HOP dispatch.

    Returns the current thread's setting stored by
    ``set_capture_triton_enabled``; a thread that has never stored one gets
    ``capture_triton_enabled_default``.
    """
    try:
        return capture_triton_enabled.value
    except AttributeError:
        # Nothing has been stored on this thread's local yet.
        return capture_triton_enabled_default
160
161
def capture_triton(triton_kernel: Callable, /) -> Callable:
    """Allows capture of a triton kernel into a graph via make_fx or
    non-strict export (coming soon).

    These technologies perform Dispatcher-based tracing (via
    ``__torch_dispatch__``) and cannot see calls to raw triton kernels.
    The ``capture_triton`` API returns a new callable that can actually
    be traced into a graph.

    Args:
        triton_kernel: A kernel produced by ``triton.jit`` or ``triton.autotune``.

    Returns:
        A traceable wrapper around ``triton_kernel`` that is launched with the
        same ``kernel[grid](*args)`` syntax as the raw kernel. If capture is
        currently disabled (see ``set_capture_triton_enabled``),
        ``triton_kernel`` itself is returned unchanged.

    Raises:
        RuntimeError: If ``triton_kernel`` is neither a triton ``JITFunction``
            nor an ``Autotuner``.

    Examples:

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import triton
        >>> from triton import language as tl
        >>> from torch.fx.experimental.proxy_tensor import make_fx
        >>> from torch._library import capture_triton
        >>>
        >>> @triton.jit
        >>> def add_kernel(
        >>>     in_ptr0,
        >>>     in_ptr1,
        >>>     out_ptr,
        >>>     n_elements,
        >>>     BLOCK_SIZE: "tl.constexpr",
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     block_start = pid * BLOCK_SIZE
        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     output = x + y
        >>>     tl.store(out_ptr + offsets, output, mask=mask)
        >>>
        >>> def add(x, y):
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid_fn(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     capture_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
        >>>     return output
        >>>
        >>> x = torch.randn(3, device="cuda")
        >>> y = torch.randn(3, device="cuda")
        >>> gm = make_fx(add)(x, y)
        >>> print(gm.code)
        >>> # def forward(self, x_1, y_1):
        >>> #     empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
        >>> #     triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
        >>> #         kernel_idx = 0, constant_args_idx = 0,
        >>> #         grid = [(1, 1, 1)], kwargs = {
        >>> #             'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
        >>> #             'n_elements': 3, 'BLOCK_SIZE': 16
        >>> #         })
        >>> #     return empty_like

    """
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction

    from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper

    if not isinstance(triton_kernel, (JITFunction, Autotuner)):
        raise RuntimeError(
            "capture_triton only works on functions annotated with triton.jit or triton.autotune"
        )
    # Capture is switched off in eager mode with regular Tensors (e.g. inside
    # triton_op's backend) to skip the comparatively slow HOP dispatch, so
    # hand back the raw kernel.
    if not is_capture_triton_enabled():
        return triton_kernel
    return TraceableTritonKernelWrapper(triton_kernel, None, None)
234