import itertools
import unittest
from functools import partial

import torch
from torch.testing._internal.common_dtype import (
    all_types_and_complex_and,
    floating_types,
    floating_types_and,
)
from torch.testing._internal.common_methods_invocations import (
    DecorateInfo,
    OpInfo,
    SampleInput,
)
from torch.testing._internal.common_utils import make_tensor


# List of OpInfos that aren't in PyTorch Core yet.
# They are here because we wanted a fast way of writing OpInfos, so they may not be
# 100% correct (w.r.t. dtypes and other options).
# TODO: Figure out how to upstream these; delete them once they're upstreamed.

additional_op_db = []

# https://github.com/pytorch/pytorch/pull/61068


def sample_inputs_conv2d(
    has_bias, self, device, dtype, requires_grad, extra_args=(), groups=1
):
    # `self` receives the (unused) OpInfo. `extra_args` is forwarded positionally
    # after (weight, bias), i.e. as conv2d's (stride, padding, dilation, groups).
    in_ch, out_ch = 6, 4
    inp = make_tensor(
        (2, in_ch * groups, 7, 5),
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        low=-1,
        high=1,
    )
    weight = make_tensor(
        (out_ch * groups, in_ch, 3, 2),
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        low=-1,
        high=1,
    )
    bias = None
    if has_bias:
        bias = make_tensor(
            (out_ch * groups,),
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
            low=-1,
            high=1,
        )
    return [SampleInput(inp, args=((weight, bias) + extra_args))]


additional_op_db.extend(
    [
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="no_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(sample_inputs_conv2d, False),
            dtypes=floating_types(),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(sample_inputs_conv2d, True),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, True, extra_args=((2, 2),)
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_no_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, False, extra_args=((2, 2),)
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_padding_with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, True, extra_args=((2, 2), (1, 1))
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_padding_no_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, False, extra_args=((2, 2), (1, 1))
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="strided_padding_dilation_with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, True, extra_args=((2, 2), (1, 1), (2, 2))
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="strided_padding_dilation_no_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, False, extra_args=((2, 2), (1, 1), (2, 2))
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_groups_with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, True, extra_args=((2, 3), 0, 1, 2), groups=2
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_depthwise_with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, True, extra_args=((2, 3), 0, 1, 6), groups=6
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
    ]
)
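

# A minimal usage sketch (not part of the op database): the OpInfo test runners call
# op(sample.input, *sample.args, **sample.kwargs), so the (weight, bias) + extra_args
# tuple built above lines up with conv2d's (weight, bias, stride, padding, dilation,
# groups) positional arguments. The device/dtype below are arbitrary example choices.
def _demo_conv2d_sample():
    samples = sample_inputs_conv2d(
        True, None, device="cpu", dtype=torch.float32, requires_grad=False
    )
    for sample in samples:
        out = torch.nn.functional.conv2d(sample.input, *sample.args)
        # Input is (2, 6, 7, 5) with a 3x2 kernel and default stride/padding,
        # so the output is (2, 4, 5, 4).
        assert out.shape == (2, 4, 5, 4)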


# TODO: PyTorch core's sampler has a check on requires_grad=True vs. False.
# We want to test more things for backward here, which is why we have our own.
def sample_inputs_embedding(op_info, device, dtype, requires_grad, **kwargs):
    def make_input(shape):
        return make_tensor(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    def make_long_input(shape, *, low, high):
        return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high)

    M = 20
    S = 5

    def generator():
        # 0-D index tensor
        idx = make_long_input((), low=0, high=M)
        yield SampleInput(
            make_input((M, S)),
            args=(idx,),
        )

        # 1-D index tensor
        idx = make_long_input((S,), low=0, high=M)
        yield SampleInput(
            make_input((M, S)),
            args=(idx,),
        )

        # 2-D index tensor
        idx = make_long_input((S, S), low=0, high=M)
        yield SampleInput(
            make_input((M, S)),
            args=(idx,),
        )

        idx = make_long_input((2, 2), low=0, high=S)
        idx[0, 0] = 2
        idx[1, 1] = 2
        yield SampleInput(
            make_input((S, S)),
            args=(idx,),
            kwargs={"padding_idx": 2},
        )

        idx = make_long_input((2, 2), low=0, high=S)
        idx[0, 0] = 4
        idx[1, 1] = 4
        yield SampleInput(
            make_input((S, S)),
            args=(idx,),
            kwargs={"padding_idx": -1},
        )

        # Scale the gradient based on the inverse frequency of a particular index.
        idx = make_long_input((2, 2), low=0, high=S)
        idx[0, 0] = 1
        idx[0, 1] = 1
        weights = make_input((S, S))
        yield SampleInput(
            weights,
            args=(idx,),
            kwargs={"scale_grad_by_freq": True},
        )

    return list(generator())


additional_op_db.append(
    OpInfo(
        "nn.functional.embedding",
        variant_test_name="functorch",
        # We use a lambda to reshuffle the positional arguments, because currently
        # only the `input` field of SampleInput is tested in gradient tests.
        op=lambda weight, idx, **kwargs: torch.nn.functional.embedding(
            idx, weight, **kwargs
        ),
        dtypes=floating_types_and(torch.bfloat16, torch.float16),
        sample_inputs_func=sample_inputs_embedding,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        supports_out=False,
    )
)
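

# A minimal sketch (not part of the op database) of the argument reshuffle above:
# the lambda makes `weight` the SampleInput.input (the tensor the gradient tests
# focus on) while the functional form still receives (idx, weight). The shapes here
# are arbitrary example values.
def _demo_embedding_reshuffle():
    reshuffled = lambda weight, idx, **kwargs: torch.nn.functional.embedding(
        idx, weight, **kwargs
    )
    weight = torch.randn(10, 3, requires_grad=True)
    idx = torch.tensor([1, 2, 4, 5])
    out = reshuffled(weight, idx)
    assert out.shape == (4, 3)  # one 3-dim embedding row per index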


def sample_inputs_mse_loss(op_info, device, dtype, requires_grad, **kwargs):
    def make_input(shape, requires_grad=requires_grad):
        return make_tensor(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    rhs_requires_grad = kwargs.get("rhs_requires_grad", requires_grad)
    S = 5

    shapes = ((S, S), (S, S, S), (S, S, S, S))
    reductions = ("none", "mean", "sum")

    for shape, reduction in itertools.product(shapes, reductions):
        yield SampleInput(
            make_input(shape),
            args=(make_input(shape, requires_grad=rhs_requires_grad),),
            kwargs={"reduction": reduction},
        )


additional_op_db.append(
    OpInfo(
        "nn.functional.mse_loss",
        variant_test_name="functorch",
        sample_inputs_func=sample_inputs_mse_loss,
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        dtypes=floating_types_and(torch.float16),
        backward_dtypes=floating_types(),
        dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
        backward_dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
    )
)
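

# A minimal consumption sketch (not part of the op database): each sample above pairs
# two same-shaped tensors with a reduction mode, which maps directly onto
# torch.nn.functional.mse_loss. Device/dtype below are arbitrary example choices.
def _demo_mse_loss_sample():
    samples = sample_inputs_mse_loss(None, "cpu", torch.float32, requires_grad=True)
    sample = next(iter(samples))
    loss = torch.nn.functional.mse_loss(sample.input, *sample.args, **sample.kwargs)
    # The first sample uses reduction="none", so the loss keeps the input's shape.
    assert loss.shape == sample.input.shape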


# TODO: upstream sample inputs to pytorch/pytorch.
# We are more comprehensive.
def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
    # Short for "advanced index"
    adv_idx = torch.LongTensor([[0, 1], [2, 3]])
    S = 5
    # self_dim, indices
    test_args = [
        (3, ([1, 2],)),
        (3, (slice(0, 3),)),
        (3, ([slice(0, 3), 1],)),
        (3, ([[0, 2, 3], [1, 3, 3], [0, 0, 2]],)),
        (3, ([[0, 0, 3], [1, 1, 3], [0, 0, 2]],)),
        (3, ([slice(None), slice(None), [0, 3]],)),
        (3, ([slice(None), [0, 3], slice(None)],)),
        (3, ([[0, 3], slice(None), slice(None)],)),
        (3, ([[0, 3], [1, 2], slice(None)],)),
        (
            3,
            (
                [
                    [0, 3],
                ],
            ),
        ),
        (3, ([[0, 3], slice(None)],)),
        (3, ([[0, 3], Ellipsis],)),
        (3, ([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])],)),
        (4, ([slice(None), adv_idx, adv_idx, slice(None)],)),
        (4, ([slice(None), adv_idx, slice(None), adv_idx],)),
        (4, ([adv_idx, slice(None), slice(None), adv_idx],)),
        (4, ([slice(None), slice(None), adv_idx, adv_idx],)),
        (4, ([Ellipsis, adv_idx, adv_idx],)),
        (5, ([slice(None), slice(None), adv_idx, slice(None), adv_idx],)),
        (5, ([slice(None), slice(None), adv_idx, adv_idx, slice(None)],)),
        (5, ([slice(None), slice(None), adv_idx, None, adv_idx, slice(None)],)),
        (6, ([slice(None), slice(None), slice(None), adv_idx, adv_idx],)),
        (6, ([slice(None), slice(None), adv_idx, adv_idx, adv_idx],)),
        (6, ([slice(None), slice(None), None, adv_idx, adv_idx, adv_idx],)),
    ]

    def get_shape(dim):
        return tuple(S + i for i in range(dim))

    return tuple(
        SampleInput(
            make_tensor(
                get_shape(self_dim),
                device=device,
                dtype=dtype,
                low=None,
                high=None,
                requires_grad=requires_grad,
            ),
            args=args,
        )
        for self_dim, args in test_args
    )


# TODO: split PyTorch's __getitem__ OpInfo. The problem is that we don't support
# indexing with boolean masks under vmap.
additional_op_db.append(
    OpInfo(
        "__getitem__",
        variant_test_name="functorch",
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        supports_out=False,
        supports_inplace_autograd=False,
        supports_scripting=False,
        op=torch.Tensor.__getitem__,
        assert_jit_shape_analysis=False,  # TODO: support index.Tensor()
        supports_forward_ad=True,
        sample_inputs_func=sample_inputs_getitem,
    )
)
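

# A minimal sketch (not part of the op database) of what one of the advanced-index
# samples above exercises: SampleInput.args holds a single index expression that is
# handed to Tensor.__getitem__. With two non-adjacent advanced indices, the
# broadcasted index shape moves to the front of the result.
def _demo_getitem_adv_index():
    adv_idx = torch.LongTensor([[0, 1], [2, 3]])
    t = torch.randn(5, 6, 7, 8)  # get_shape(4) above
    out = t[:, adv_idx, :, adv_idx]
    # adv_idx is 2x2 and sits at dims 1 and 3, so the result is (2, 2) followed by
    # the remaining sliced dims of size 5 and 7.
    assert out.shape == (2, 2, 5, 7)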


# Turns out at::index_put is different from torch.index_put...
# TODO: figure out how to upstream this
def sample_inputs_aten_index_put(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    inputs = []
    adv_idx = torch.LongTensor([[0, 1], [2, 3]])
    # self_shape, indices
    additional = [
        ((5, 6, 7, 8), [None, adv_idx, adv_idx, None]),
        ((5, 6, 7, 8), [None, adv_idx, None, adv_idx]),
        ((5, 6, 7, 8), [adv_idx, None, None, adv_idx]),
        ((5, 6, 7, 8), [None, None, adv_idx, adv_idx]),
        ((5, 6, 7, 8, 9), [None, None, adv_idx, None, adv_idx]),
        ((5, 6, 7, 8, 9), [None, None, adv_idx, adv_idx, None]),
        ((5, 6, 7, 8, 9, 10), [None, None, None, adv_idx, adv_idx]),
        ((5, 6, 7, 8, 9, 10), [None, None, adv_idx, adv_idx, adv_idx]),
    ]
    for self_shape, indices in additional:
        for broadcast_value in [False, True]:
            inp = make_arg(self_shape)

            tmp_indices = [slice(None) if idx is None else idx for idx in indices]
            values_shape = inp[tuple(tmp_indices)].shape
            if broadcast_value:
                values_shape = values_shape[3:]
            values = make_arg(values_shape)
            inputs.append(SampleInput(inp, args=(tuple(indices), values)))
    return inputs


def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    make_idx = partial(
        make_tensor, dtype=torch.long, device=device, requires_grad=False
    )
    S = 5
    inputs = []
    for accumulate in [False, True]:
        # putting vectors at indexed locations
        inputs.append(
            SampleInput(
                make_arg((S, S)),
                args=((make_idx((2,), low=0, high=4),), make_arg((2, S))),
                kwargs=dict(accumulate=accumulate),
            )
        )

        # putting multi-dim tensors at indexed locations
        inputs.append(
            SampleInput(
                make_arg((S, S, 2)),
                args=((make_idx((3,), low=0, high=4),), make_arg((3, S, 2))),
                kwargs=dict(accumulate=accumulate),
            )
        )

        # value with size `0` dim
        inputs.append(
            SampleInput(
                make_arg((S, 0)),
                args=((make_idx((3,), low=0, high=4),), make_arg((3, 0))),
                kwargs=dict(accumulate=accumulate),
            )
        )

        # scalar value
        inputs.append(
            SampleInput(
                make_arg((S,)),
                args=((make_idx((), low=0, high=S),), make_arg(())),
                kwargs=dict(accumulate=accumulate),
            )
        )

        # cuda and accumulate don't work well
        # Reference: https://github.com/pytorch/pytorch/issues/72053
        if not accumulate and device == "cuda":
            # Broadcast `values`
            inputs.append(
                SampleInput(
                    make_arg((S, S)),
                    args=((make_idx((2,), low=0, high=S),), make_arg((S,))),
                    kwargs=dict(accumulate=accumulate),
                )
            )

    return inputs


additional_op_db.append(
    OpInfo(
        "index_put",
        variant_test_name="functorch",
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        supports_out=False,
        sample_inputs_func=sample_inputs_index_put,
        supports_forward_ad=True,
    )
)
additional_op_db.append(
    OpInfo(
        "ops.aten.index_put",
        variant_test_name="functorch",
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        supports_out=False,
        sample_inputs_func=sample_inputs_aten_index_put,
        supports_forward_ad=True,
    )
)
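

# A minimal sketch (not part of the op database) of the difference the comment above
# refers to: torch.index_put takes a tuple of index tensors, while the aten overload
# also accepts None entries that act like full slices along their dimensions (this is
# what sample_inputs_aten_index_put exercises). Shapes below are example values.
def _demo_index_put_variants():
    # Python-level variant: indices are plain tensors; here we overwrite rows 0 and 2.
    t = torch.randn(5, 5)
    rows = torch.tensor([0, 2])
    out = torch.index_put(t, (rows,), torch.zeros(2, 5), accumulate=False)

    # aten variant: None entries select everything along that dimension. With the
    # two advanced indices at non-adjacent dims, values must have shape (2, 2, 5, 7).
    adv_idx = torch.LongTensor([[0, 1], [2, 3]])
    u = torch.randn(5, 6, 7, 8)
    out2 = torch.ops.aten.index_put(
        u, (None, adv_idx, None, adv_idx), torch.zeros(2, 2, 5, 7)
    )
    return out, out2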


def sample_inputs_masked_fill(op_info, device, dtype, requires_grad, **kwargs):
    S = 3
    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )

    yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, 10))
    yield SampleInput(make_arg((S, S)), args=(torch.randn(S, device=device) > 0, 10))
    yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, 10))
    yield SampleInput(make_arg((S, S)), args=(torch.randn((), device=device) > 0, 10))
    yield SampleInput(
        make_arg((S,)),
        args=(torch.randn(S, S, device=device) > 0, 10),
        broadcasts_input=True,
    )


additional_op_db.append(
    OpInfo(
        "masked_fill",
        variant_test_name="functorch_Scalar_only",
        dtypes=all_types_and_complex_and(
            torch.bool, torch.half, torch.bfloat16, torch.chalf
        ),
        sample_inputs_func=sample_inputs_masked_fill,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_forward_grad=False,
        supports_out=False,
    )
)
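

# A minimal sketch (not part of the op database): this variant only exercises scalar
# fill values (hence "functorch_Scalar_only"), and the mask broadcasts against the
# input just as in the samples above.
def _demo_masked_fill_scalar():
    x = torch.randn(3, 3)
    mask = torch.randn(3) > 0  # a (3,)-shaped mask broadcasts against the (3, 3) input
    out = torch.masked_fill(x, mask, 10)
    assert out.shape == x.shape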


def sample_inputs_new_zeros_with_same_feature_meta(
    op_info, device, dtype, requires_grad, **kwargs
):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    matrix = [
        # tangent, base, num_tangent_bdims
        ([5], [2, 3], 0),
        ([2, 3], [2, 3], 0),
        ([5], [2], 0),
        ([1, 0, 2], [1, 2], 0),
        ([], [1, 2], 0),
        ([8, 7, 5], [2, 3, 11], 1),
        ([6, 7, 5], [2, 3, 4], 2),
        ([6, 4], [3], 2),
    ]
    results = []
    for tangent_shape, base_shape, num_tangent_bdims in matrix:
        tangent = make_arg(tangent_shape)
        base = make_arg(base_shape)
        results.append(
            SampleInput(
                tangent,
                args=(base,),
                kwargs=dict(self_num_batch_dims=num_tangent_bdims),
            )
        )
    return results


additional_op_db.append(
    OpInfo(
        "ops.aten._new_zeros_with_same_feature_meta",
        variant_test_name="functorchonly",
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        supports_out=False,
        supports_autograd=False,
        supports_forward_ad=False,
        sample_inputs_func=sample_inputs_new_zeros_with_same_feature_meta,
    )
)
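

# A minimal sketch (not part of the op database): as we understand this forward-AD
# helper, it builds a zero tensor whose trailing sizes mirror `base` while the first
# `self_num_batch_dims` sizes come from `tangent` (matching the matrix above).
def _demo_new_zeros_with_same_feature_meta():
    tangent = torch.randn(8, 7, 5)
    base = torch.randn(2, 3, 11)
    out = torch.ops.aten._new_zeros_with_same_feature_meta(
        tangent, base, self_num_batch_dims=1
    )
    # Expected shape under the description above: (8, 2, 3, 11).
    return out.shape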


def sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    shapes = ((), (2, 3))
    memory_format_options = [None, torch.contiguous_format]
    for shape, memory_format in itertools.product(shapes, memory_format_options):
        yield SampleInput(
            make_arg(shape),
            kwargs={"memory_format": memory_format} if memory_format else {},
        )


additional_op_db.extend(
    [
        OpInfo(
            "bfloat16",
            op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            skips=(
                # autograd tests don't handle operators that change dtype
                DecorateInfo(unittest.expectedFailure, "TestFwdGradients"),
                DecorateInfo(unittest.expectedFailure, "TestBwdGradients"),
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
                DecorateInfo(
                    unittest.skip("Skipped!"), "TestNNCOpInfo", "test_nnc_correctness"
                ),
            ),
        ),
        OpInfo(
            "bool",
            op=lambda x, *args, **kwargs: x.bool(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "byte",
            op=lambda x, *args, **kwargs: x.byte(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            # The autograd test runner cannot handle functions that change dtype
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "char",
            op=lambda x, *args, **kwargs: x.char(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            # The autograd test runner cannot handle functions that change dtype
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "double",
            op=lambda x, *args, **kwargs: x.double(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "float",
            op=lambda x, *args, **kwargs: x.float(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            skips=(
                # autograd tests don't handle operators that change dtype
                DecorateInfo(unittest.expectedFailure, "TestFwdGradients"),
                DecorateInfo(unittest.expectedFailure, "TestBwdGradients"),
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "half",
            op=lambda x, *args, **kwargs: x.half(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            skips=(
                # autograd tests don't handle operators that change dtype
                DecorateInfo(unittest.expectedFailure, "TestFwdGradients"),
                DecorateInfo(unittest.expectedFailure, "TestBwdGradients"),
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "int",
            op=lambda x, *args, **kwargs: x.int(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "long",
            op=lambda x, *args, **kwargs: x.long(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "short",
            op=lambda x, *args, **kwargs: x.short(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
    ]
)

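
# A minimal consumption sketch (not part of the op database): the conversion OpInfos
# above wrap the Tensor dtype-conversion methods, and each sample optionally carries
# a memory_format kwarg. Device/dtype below are arbitrary example choices.
def _demo_conversion_sample():
    samples = sample_inputs_conversion(None, "cpu", torch.float32, requires_grad=False)
    for sample in samples:
        out = sample.input.float(**sample.kwargs)
        assert out.dtype == torch.float32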