# mypy: ignore-errors

from copy import copy
from functools import partial

import torch
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.opinfo.core import (
    BinaryUfuncInfo,
    ReductionOpInfo,
    SampleInput,
    UnaryUfuncInfo,
)
from torch.utils._pytree import tree_map


# random integer used for sizes
def _rnd():
    return torch.randint(3, 8, ()).item()


def _raggedness_matches(nt1, nt2):
    return (
        nt1.is_nested
        and nt2.is_nested
        and nt1._ragged_idx == nt2._ragged_idx
        and nt1.shape[nt1._ragged_idx] == nt2.shape[nt2._ragged_idx]
    )


# Generates a random NT.
# dims should be something like [5, None, 10], with None indicating that a
# random ragged structure should be used
def random_nt_from_dims(
    dims, device=None, dtype=None, layout=torch.strided, requires_grad=False
):
    sizes = [[d if d is not None else _rnd() for d in dims[1:]] for _ in range(dims[0])]
    return torch.nested.nested_tensor(
        [torch.randn(*size) for size in sizes],
        device=device,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
    )
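
# A minimal usage sketch (illustrative only; the shape symbol below is just notation):
#   nt = random_nt_from_dims([3, None, 8], layout=torch.jagged)
#   nt.shape  # e.g. torch.Size([3, j1, 8]), where j1 denotes the ragged dim
# Each of the 3 components is a (length_i, 8) tensor with a random length_i in [3, 8).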


# Helper function for generating a comprehensive set of NJT sample inputs.
def _sample_njts(device, dtype, requires_grad=False, dims=None):
    if dims is None:
        dims = [2, 3, 4]
    if not isinstance(dims, (list, tuple)):
        dims = [dims]

    # contiguous NJTs
    for dim in dims:
        # with min / max seqlen cached
        shape = (_rnd(), None, *[_rnd() for _ in range(dim - 2)])
        nt = random_nt_from_dims(
            shape,
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
            layout=torch.jagged,
        )
        yield nt

        # without min / max seqlen cached
        values = nt.values().clone().detach()
        offsets = nt.offsets().clone().detach()
        yield torch.nested.nested_tensor_from_jagged(values, offsets)

    # TODO: add non-contiguous NJTs
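
# Usage sketch (illustrative): each requested dim yields two jagged-layout NJTs,
# one built via torch.nested.nested_tensor (min / max seqlen cached) and one
# rebuilt from its values / offsets buffers (no cached seqlen metadata), e.g.
#   nts = list(_sample_njts(device="cpu", dtype=torch.float32, dims=[3]))
#   len(nts)  # -> 2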


# Computes an unbind-based reference for a given OpInfo on a given SampleInput.
# This reference unbinds the input NJT and invokes the op on each of the components,
# optionally wrapping the result in an NJT.
def unbind_reference(op, sample, wrap_output_as_njt=True):
    from torch._prims_common import canonicalize_dims

    assert sample.input.is_nested
    out_ref_components = []
    for i, component in enumerate(sample.input.unbind(dim=0)):

        def _slice_njts(t, i=i, inp=sample.input):
            # any NJT with the same ragged structure as the input should
            # also be sliced to pass to the reference
            if isinstance(t, torch.Tensor) and _raggedness_matches(t, inp):
                return t[i]
            else:
                return t

        args = tree_map(_slice_njts, sample.args)
        kwargs = tree_map(_slice_njts, sample.kwargs)

        # Need to adjust dim to apply on NJT component
        if "dim" in kwargs:
            kwargs["dim"] = canonicalize_dims(sample.input.dim(), kwargs["dim"]) - 1
            assert kwargs["dim"] >= 0

        # TODO: handle this
        assert "dims" not in kwargs

        out_ref_component = op.op(component, *args, **kwargs)

        # TODO: handle list / tuple / non-NJT outputs
        assert not isinstance(out_ref_component, (list, tuple))
        out_ref_components.append(out_ref_component)

    if wrap_output_as_njt:
        return torch.nested.as_nested_tensor(out_ref_components, layout=torch.jagged)

    return out_ref_components
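
# Worked example (illustrative): for a 3-component NJT and an OpInfo wrapping
# torch.abs, this computes [abs(nt[0]), abs(nt[1]), abs(nt[2])] and, with
# wrap_output_as_njt=True, repacks the per-component results into a jagged NJT
# with the same ragged structure as the input.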


# Computes the reference value for a reduction op.
def reduction_reference(op, sample):
    assert sample.input.is_nested
    dim = sample.kwargs.get("dim", None)
    keepdim = sample.kwargs.get("keepdim", False)
    assert dim != 0, "reductions over the batch dim are not supported"
    assert "dims" not in sample.kwargs
    assert sample.input._ragged_idx == 1

    if dim is None:
        # calculate reference value by running reduction on values buffer
        return op.op(sample.input.values(), *sample.args, **sample.kwargs)

    if dim == sample.input._ragged_idx:
        # calculate reference value by running an unbind reference and stacking
        out_ref_components = unbind_reference(op, sample, wrap_output_as_njt=False)
        return torch.stack(out_ref_components, dim=0)

    # unbind reference works for other reductions
    return unbind_reference(op, sample)
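
# Shape sketch (illustrative, keepdim=False): for an NJT of shape (B, j1, D) and
# an OpInfo wrapping torch.sum,
#   dim=None -> reduces the flat (total, D) values buffer to a scalar
#   dim=1    -> per-component sums stacked into a dense (B, D) tensor
#   dim=2    -> unbind reference; output is an NJT of shape (B, j1)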


def sample_inputs_elementwise_njt_unary(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    if not op_kwargs:
        op_kwargs = {}

    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        yield SampleInput(njt, kwargs=dict(op_kwargs))


def sample_inputs_elementwise_njt_binary(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    if not op_kwargs:
        op_kwargs = {}

    for njt1 in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        # TODO: account for non-contiguous NJTs here
        # TODO: provide sample inputs for broadcasting cases and mixed (NT, T), (T, NT) inputs
        njt2 = torch.randn_like(njt1)
        yield SampleInput(njt1, args=(njt2,), kwargs=dict(op_kwargs))


def sample_inputs_njt_reduction(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    if not op_kwargs:
        op_kwargs = {}

    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        # dim-wise reduction; includes reduction over the ragged dim
        # NB: reduction over the batch dim is not supported!
        # TODO: Cover this in the set of error inputs
        for dim in range(1, njt.dim()):
            for keepdim in [False, True]:
                yield SampleInput(
                    njt, kwargs={**op_kwargs, "dim": dim, "keepdim": keepdim}
                )

        # full reduction
        yield SampleInput(njt, kwargs=dict(op_kwargs))


def unsupported_sample_inputs_func(op_name):
    def _f(op_info, device, dtype, requires_grad, op_name=op_name, **kwargs):
        raise RuntimeError(
            f"OpInfo for {op_name} does not support NJT. Support can be added by modifying "
            "torch/testing/_internal/opinfo/definitions/nested.py."
        )

    return _f


def unsupported_reference(op_name):
    def _f(op, sample):
        raise RuntimeError(
            f"OpInfo for {op_name} does not define a ref() function. Support can be added by "
            "modifying torch/testing/_internal/opinfo/definitions/nested.py."
        )

    return _f


# === BEGIN OP-SPECIFIC SAMPLE INPUTS FUNCS ===
def sample_inputs_clone(op_info, device, dtype, requires_grad, **kwargs):
    # contiguous NJTs
    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        yield SampleInput(njt)

    for memory_format in (torch.contiguous_format, torch.preserve_format):
        # construct a "non-contiguous with holes" NJT
        values = torch.randn(
            10, 5, device=device, dtype=dtype, requires_grad=requires_grad
        )
        offsets = torch.tensor([0, 2, 4, 10], device=device, dtype=torch.int64)
        lengths = torch.tensor([2, 1, 3], device=device, dtype=torch.int64)
        njt = torch.nested.nested_tensor_from_jagged(
            values, offsets=offsets, lengths=lengths
        )

        yield SampleInput(njt, kwargs={"memory_format": memory_format})
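
# Note on the "holes" layout above (numbers are illustrative): offsets [0, 2, 4, 10]
# with lengths [2, 1, 3] means component 0 covers values rows 0:2, component 1
# covers rows 2:3 (row 3 is a hole), and component 2 covers rows 4:7 (rows 7:10
# are holes), so the components don't tile the values buffer contiguously.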


def sample_inputs_mvl_gamma(p):
    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"p": p})


def sample_inputs_polygamma_n(n):
    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"n": n})


def sample_inputs_special_polygamma_n(n):
    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"n": n})


def sample_inputs_masked_select(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2]
    ):
        yield SampleInput(
            njt, kwargs={"mask": (torch.randn_like(njt, requires_grad=False) < 0.0)}
        )


sample_inputs_nn_functional_threshold = partial(
    sample_inputs_elementwise_njt_unary,
    op_kwargs={"threshold": float.fromhex("0x1.3ap-3"), "value": -9},
)
# === END OP-SPECIFIC SAMPLE INPUTS FUNCS ===


# Mapping of OpInfo full names -> sample_inputs_funcs, which define the set of sample inputs
# (involving NJTs) to pass to the op. Full name consists of the OpInfo's name and variant name
# separated by a period (e.g. special.polygamma.special_polygamma_n_0). These must be
# specified when they cannot be auto-generated for some reason. Try to keep these sorted
# in alphabetical order!
njt_sample_inputs = {
    "clone": sample_inputs_clone,
    "masked_select": sample_inputs_masked_select,
    **{f"mvlgamma.mvlgamma_p_{p}": sample_inputs_mvl_gamma(p=p) for p in (1, 3, 5)},
    "nn.functional.threshold": sample_inputs_nn_functional_threshold,
    **{f"polygamma.polygamma_n_{n}": sample_inputs_polygamma_n(n=n) for n in range(5)},
    "special.polygamma.special_polygamma_n_0": sample_inputs_special_polygamma_n(n=0),
}


# Translates an OpInfo entry to one that operates on NJTs.
def translate_opinfo(op):
    new_op = copy(op)
    new_op.supports_njt = True

    if op.full_name in njt_sample_inputs:
        new_op.sample_inputs_func = njt_sample_inputs[op.full_name]
        # TODO: make the reference customizable
        new_op.ref = unbind_reference
    elif isinstance(op, UnaryUfuncInfo):
        new_op.sample_inputs_func = partial(
            sample_inputs_elementwise_njt_unary, op_kwargs=None
        )
        new_op.ref = unbind_reference
    elif isinstance(op, BinaryUfuncInfo):
        new_op.sample_inputs_func = partial(
            sample_inputs_elementwise_njt_binary, op_kwargs=None
        )
        new_op.ref = unbind_reference
    elif isinstance(op, ReductionOpInfo):
        new_op.sample_inputs_func = partial(sample_inputs_njt_reduction, op_kwargs=None)
        new_op.ref = reduction_reference
    # TODO: Translate the rest of the OpInfos
    else:
        new_op.sample_inputs_func = unsupported_sample_inputs_func(op.full_name)
        new_op.ref = unsupported_reference(op.full_name)
        new_op.supports_njt = False

    return new_op


njt_op_db = [translate_opinfo(op) for op in op_db]
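
# Consumption sketch (hypothetical; not part of this module): a test could iterate
# njt_op_db and check each op against its reference, along the lines of
#   for op in njt_op_db:
#       if not op.supports_njt:
#           continue
#       for sample in op.sample_inputs_func(op, device, dtype, requires_grad=False):
#           out = op.op(sample.input, *sample.args, **sample.kwargs)
#           torch.testing.assert_close(out, op.ref(op, sample))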