# mypy: ignore-errors

from copy import copy
from functools import partial

import torch
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.opinfo.core import (
    BinaryUfuncInfo,
    ReductionOpInfo,
    SampleInput,
    UnaryUfuncInfo,
)
from torch.utils._pytree import tree_map


# random integer used for sizes
def _rnd():
    return torch.randint(3, 8, ()).item()


def _raggedness_matches(nt1, nt2):
    return (
        nt1.is_nested
        and nt2.is_nested
        and nt1._ragged_idx == nt2._ragged_idx
        and nt1.shape[nt1._ragged_idx] == nt2.shape[nt2._ragged_idx]
    )


# Generates a random NT.
# dims should be something like [5, None, 10], with None indicating that a
# random ragged structure should be used.
def random_nt_from_dims(
    dims, device=None, dtype=None, layout=torch.strided, requires_grad=False
):
    sizes = [
        [d if d is not None else _rnd() for d in dims[1:]] for _ in range(dims[0])
    ]
    return torch.nested.nested_tensor(
        [torch.randn(*size) for size in sizes],
        device=device,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
    )


# Helper function for generating a comprehensive set of NJT sample inputs.
def _sample_njts(device, dtype, requires_grad=False, dims=None):
    if dims is None:
        dims = [2, 3, 4]
    if not isinstance(dims, (list, tuple)):
        dims = [dims]

    # contiguous NJTs
    for dim in dims:
        # with min / max seqlen cached
        shape = (_rnd(), None, *[_rnd() for _ in range(dim - 2)])
        nt = random_nt_from_dims(
            shape,
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
            layout=torch.jagged,
        )
        yield nt

        # without min / max seqlen cached
        values = nt.values().clone().detach()
        offsets = nt.offsets().clone().detach()
        yield torch.nested.nested_tensor_from_jagged(values, offsets)

    # TODO: add non-contiguous NJTs


# Computes an unbind-based reference for a given OpInfo on a given SampleInput.
# This reference unbinds the input NJT and invokes the op on each of the components,
# optionally wrapping the result in an NJT.
def unbind_reference(op, sample, wrap_output_as_njt=True):
    assert sample.input.is_nested
    out_ref_components = []
    for i, component in enumerate(sample.input.unbind(dim=0)):

        def _slice_njts(t, i=i, inp=sample.input):
            # any NJT with the same ragged structure as the input should
            # also be sliced to pass to the reference
            if isinstance(t, torch.Tensor) and _raggedness_matches(t, inp):
                return t[i]
            else:
                return t

        args = tree_map(_slice_njts, sample.args)
        kwargs = tree_map(_slice_njts, sample.kwargs)

        from torch._prims_common import canonicalize_dims

        # Need to adjust dim to apply to each NJT component
        if "dim" in kwargs:
            kwargs["dim"] = canonicalize_dims(sample.input.dim(), kwargs["dim"]) - 1
            assert kwargs["dim"] >= 0

        # TODO: handle ops that accept a "dims" kwarg
        assert "dims" not in kwargs

        out_ref_component = op.op(component, *args, **kwargs)

        # TODO: handle list / tuple / non-NJT outputs
        assert not isinstance(out_ref_component, (list, tuple))
        out_ref_components.append(out_ref_component)

    if wrap_output_as_njt:
        return torch.nested.as_nested_tensor(out_ref_components, layout=torch.jagged)

    return out_ref_components

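
# As an illustrative sketch only (not exercised by tests): for a unary op such
# as torch.sin on a 3D NJT, unbind_reference is morally equivalent to
#
#   nt = random_nt_from_dims([3, None, 4], layout=torch.jagged)
#   ref = torch.nested.as_nested_tensor(
#       [torch.sin(t) for t in nt.unbind(dim=0)], layout=torch.jagged
#   )
#
# i.e. the op runs on each dense component and the results are re-wrapped into
# an NJT with the same ragged structure.
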

# Computes the reference value for a reduction op.
def reduction_reference(op, sample):
    assert sample.input.is_nested
    dim = sample.kwargs.get("dim", None)
    keepdim = sample.kwargs.get("keepdim", False)
    assert dim != 0, "reductions over the batch dim are not supported"
    assert "dims" not in sample.kwargs
    assert sample.input._ragged_idx == 1

    if dim is None:
        # calculate reference value by running reduction on values buffer
        return op.op(sample.input.values(), *sample.args, **sample.kwargs)

    if dim == sample.input._ragged_idx:
        # calculate reference value by running an unbind reference and stacking
        out_ref_components = unbind_reference(op, sample, wrap_output_as_njt=False)
        return torch.stack(out_ref_components, dim=0)

    # unbind reference works for other reductions
    return unbind_reference(op, sample)

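
# As an illustrative sketch only (not exercised by tests): summing over the
# ragged dim of a (B, j1, D) NJT gives a dense (B, D) result; note that dim=1
# on the NJT becomes dim=0 on each unbound component:
#
#   nt = random_nt_from_dims([3, None, 4], layout=torch.jagged)
#   ref = torch.stack([t.sum(dim=0) for t in nt.unbind(dim=0)], dim=0)
#   # ref.shape == (3, 4)
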

def sample_inputs_elementwise_njt_unary(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    if not op_kwargs:
        op_kwargs = {}

    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        yield SampleInput(njt, kwargs=dict(op_kwargs))


def sample_inputs_elementwise_njt_binary(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    if not op_kwargs:
        op_kwargs = {}

    for njt1 in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        # TODO: account for non-contiguous NJTs here
        # TODO: provide sample inputs for broadcasting cases and
        # mixed (NT, T) / (T, NT) inputs
        njt2 = torch.randn_like(njt1)
        yield SampleInput(njt1, args=(njt2,), kwargs=dict(op_kwargs))


def sample_inputs_njt_reduction(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    if not op_kwargs:
        op_kwargs = {}

    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        # dim-wise reduction; includes reduction over the ragged dim
        # NB: reduction over the batch dim is not supported!
        # TODO: Cover this in the set of error inputs
        for dim in range(1, njt.dim()):
            for keepdim in [False, True]:
                yield SampleInput(
                    njt, kwargs={**op_kwargs, "dim": dim, "keepdim": keepdim}
                )

        # full reduction
        yield SampleInput(njt, kwargs=dict(op_kwargs))


def unsupported_sample_inputs_func(op_name):
    def _f(op_info, device, dtype, requires_grad, op_name=op_name, **kwargs):
        raise RuntimeError(
            f"OpInfo for {op_name} does not support NJT. Support can be added by modifying "
            "torch/testing/_internal/opinfo/definitions/nested.py."
        )

    return _f


def unsupported_reference(op_name):
    def _f(op, sample):
        raise RuntimeError(
            f"OpInfo for {op_name} does not define a ref() function. Support can be added by "
            "modifying torch/testing/_internal/opinfo/definitions/nested.py."
        )

    return _f


# === BEGIN OP-SPECIFIC SAMPLE INPUTS FUNCS ===
def sample_inputs_clone(op_info, device, dtype, requires_grad, **kwargs):
    # contiguous NJTs
    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        yield SampleInput(njt)

    for memory_format in (torch.contiguous_format, torch.preserve_format):
        # construct a "non-contiguous with holes" NJT
        values = torch.randn(
            10, 5, device=device, dtype=dtype, requires_grad=requires_grad
        )
        offsets = torch.tensor([0, 2, 4, 10], device=device, dtype=torch.int64)
        lengths = torch.tensor([2, 1, 3], device=device, dtype=torch.int64)
        njt = torch.nested.nested_tensor_from_jagged(
            values, offsets=offsets, lengths=lengths
        )

        yield SampleInput(njt, kwargs={"memory_format": memory_format})


def sample_inputs_mvl_gamma(p):
    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"p": p})


def sample_inputs_polygamma_n(n):
    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"n": n})


def sample_inputs_special_polygamma_n(n):
    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"n": n})


def sample_inputs_masked_select(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2]
    ):
        yield SampleInput(
            njt, kwargs={"mask": (torch.randn_like(njt, requires_grad=False) < 0.0)}
        )


sample_inputs_nn_functional_threshold = partial(
    sample_inputs_elementwise_njt_unary,
    op_kwargs={"threshold": float.fromhex("0x1.3ap-3"), "value": -9},
)
# === END OP-SPECIFIC SAMPLE INPUTS FUNCS ===


# Mapping of OpInfo full names -> sample_inputs_funcs, which define the set of
# sample inputs (involving NJTs) to pass to the op. A full name consists of the
# OpInfo's name and variant name separated by a period
# (e.g. special.polygamma.special_polygamma_n_0). Entries are only necessary
# when sample inputs cannot be auto-generated for some reason. Try to keep
# these sorted in alphabetical order!
njt_sample_inputs = {
    "clone": sample_inputs_clone,
    "masked_select": sample_inputs_masked_select,
    **{f"mvlgamma.mvlgamma_p_{p}": sample_inputs_mvl_gamma(p=p) for p in (1, 3, 5)},
    "nn.functional.threshold": sample_inputs_nn_functional_threshold,
    **{f"polygamma.polygamma_n_{n}": sample_inputs_polygamma_n(n=n) for n in range(5)},
    "special.polygamma.special_polygamma_n_0": sample_inputs_special_polygamma_n(n=0),
}

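
# As an illustrative sketch only (the op and kwargs below are hypothetical
# examples, not registered entries): an op needing custom NJT samples would be
# registered by adding an entry keyed by its full name, e.g.
#
#   njt_sample_inputs["nn.functional.hardtanh"] = partial(
#       sample_inputs_elementwise_njt_unary,
#       op_kwargs={"min_val": -0.5, "max_val": 0.5},
#   )
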

# Translates an OpInfo entry to one that operates on NJTs.
def translate_opinfo(op):
    new_op = copy(op)
    new_op.supports_njt = True

    if op.full_name in njt_sample_inputs:
        new_op.sample_inputs_func = njt_sample_inputs[op.full_name]
        # TODO: make the reference customizable
        new_op.ref = unbind_reference
    elif isinstance(op, UnaryUfuncInfo):
        new_op.sample_inputs_func = partial(
            sample_inputs_elementwise_njt_unary, op_kwargs=None
        )
        new_op.ref = unbind_reference
    elif isinstance(op, BinaryUfuncInfo):
        new_op.sample_inputs_func = partial(
            sample_inputs_elementwise_njt_binary, op_kwargs=None
        )
        new_op.ref = unbind_reference
    elif isinstance(op, ReductionOpInfo):
        new_op.sample_inputs_func = partial(sample_inputs_njt_reduction, op_kwargs=None)
        new_op.ref = reduction_reference
    # TODO: Translate the rest of the OpInfos
    else:
        new_op.sample_inputs_func = unsupported_sample_inputs_func(op.full_name)
        new_op.ref = unsupported_reference(op.full_name)
        new_op.supports_njt = False

    return new_op


njt_op_db = [translate_opinfo(op) for op in op_db]
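
# As an illustrative usage sketch only: a test suite consuming this database
# can filter on the supports_njt flag set by translate_opinfo above, e.g.
#
#   njt_supported_ops = [op for op in njt_op_db if op.supports_njt]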