import itertools
import unittest
from functools import partial

import torch
from torch.testing._internal.common_dtype import (
    all_types_and_complex_and,
    floating_types,
    floating_types_and,
)
from torch.testing._internal.common_methods_invocations import (
    DecorateInfo,
    OpInfo,
    SampleInput,
)
from torch.testing._internal.common_utils import make_tensor


# List of OpInfos that aren't in PyTorch Core yet.
# They are here because we wanted a fast way of writing OpInfos and they may
# not be 100% correct (w.r.t. dtypes and other options).
# TODO: Figure out how to upstream these; delete them once they're upstreamed.
additional_op_db = []


# https://github.com/pytorch/pytorch/pull/61068
# `extra_args` is appended positionally after (weight, bias), so it maps onto
# conv2d's (stride, padding, dilation, groups) parameters in that order.
def sample_inputs_conv2d(
    has_bias, self, device, dtype, requires_grad, extra_args=(), groups=1
):
    in_ch, out_ch = 6, 4
    inp = make_tensor(
        (2, in_ch * groups, 7, 5),
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        low=-1,
        high=1,
    )
    weight = make_tensor(
        (out_ch * groups, in_ch, 3, 2),
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        low=-1,
        high=1,
    )
    bias = None
    if has_bias:
        bias = make_tensor(
            (out_ch * groups,),
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
            low=-1,
            high=1,
        )
    return [SampleInput(inp, args=((weight, bias) + extra_args))]


additional_op_db.extend(
    [
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="no_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(sample_inputs_conv2d, False),
            dtypes=floating_types(),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(sample_inputs_conv2d, True),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, True, extra_args=((2, 2),)
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_no_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, False, extra_args=((2, 2),)
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_padding_with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, True, extra_args=((2, 2), (1, 1))
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_padding_no_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, False, extra_args=((2, 2), (1, 1))
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="strided_padding_dilation_with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, True, extra_args=((2, 2), (1, 1), (2, 2))
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="strided_padding_dilation_no_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, False, extra_args=((2, 2), (1, 1), (2, 2))
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_groups_with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, True, extra_args=((2, 3), 0, 1, 2), groups=2
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_depthwise_with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, True, extra_args=((2, 3), 0, 1, 6), groups=6
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
    ]
)


# TODO: PyTorch core has a check for whether requires_grad=True or not.
# We actually want to test more things for backward here, which is why
# we have our own sample inputs function.
def sample_inputs_embedding(op_info, device, dtype, requires_grad, **kwargs):
    def make_input(shape):
        return make_tensor(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    def make_long_input(shape, *, low, high):
        return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high)

    M = 20
    S = 5

    def generator():
        # 0-D index tensor
        idx = make_long_input((), low=0, high=M)
        yield SampleInput(
            make_input((M, S)),
            args=(idx,),
        )

        # 1-D index tensor
        idx = make_long_input((S,), low=0, high=M)
        yield SampleInput(
            make_input((M, S)),
            args=(idx,),
        )

        # 2-D index tensor
        idx = make_long_input((S, S), low=0, high=M)
        yield SampleInput(
            make_input((M, S)),
            args=(idx,),
        )

        idx = make_long_input((2, 2), low=0, high=S)
        idx[0, 0] = 2
        idx[1, 1] = 2
        yield SampleInput(
            make_input((S, S)),
            args=(idx,),
            kwargs={"padding_idx": 2},
        )

        idx = make_long_input((2, 2), low=0, high=S)
        idx[0, 0] = 4
        idx[1, 1] = 4
        yield SampleInput(
            make_input((S, S)),
            args=(idx,),
            kwargs={"padding_idx": -1},
        )

        # Scale the gradient based on the inverse frequency of a particular index.
        idx = make_long_input((2, 2), low=0, high=S)
        idx[0, 0] = 1
        idx[0, 1] = 1
        weights = make_input((S, S))
        yield SampleInput(
            weights,
            args=(idx,),
            kwargs={"scale_grad_by_freq": True},
        )

    return list(generator())


additional_op_db.append(
    OpInfo(
        "nn.functional.embedding",
        variant_test_name="functorch",
        # We use a lambda to reshuffle the positional arguments.
        # This is because currently only the `input` field of SampleInput
        # is tested in gradient tests.
        op=lambda weight, idx, **kwargs: torch.nn.functional.embedding(
            idx, weight, **kwargs
        ),
        dtypes=floating_types_and(torch.bfloat16, torch.float16),
        sample_inputs_func=sample_inputs_embedding,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        supports_out=False,
    )
)


def sample_inputs_mse_loss(op_info, device, dtype, requires_grad, **kwargs):
    def make_input(shape, requires_grad=requires_grad):
        return make_tensor(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    rhs_requires_grad = kwargs.get("rhs_requires_grad", requires_grad)
    S = 5

    shapes = ((S, S), (S, S, S), (S, S, S, S))
    reductions = ("none", "mean", "sum")

    for shape, reduction in itertools.product(shapes, reductions):
        yield SampleInput(
            make_input(shape),
            args=(make_input(shape, requires_grad=rhs_requires_grad),),
            kwargs={"reduction": reduction},
        )


additional_op_db.append(
    OpInfo(
        "nn.functional.mse_loss",
        variant_test_name="functorch",
        sample_inputs_func=sample_inputs_mse_loss,
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        dtypes=floating_types_and(torch.float16),
        backward_dtypes=floating_types(),
        dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
        backward_dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
    )
)


# TODO: upstream these sample inputs to pytorch/pytorch;
# they are more comprehensive than the ones in core.
def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
    # Short for "advanced index"
    adv_idx = torch.LongTensor([[0, 1], [2, 3]])
    S = 5
    # self_dim, indices
    test_args = [
        (3, ([1, 2],)),
        (3, (slice(0, 3),)),
        (3, ([slice(0, 3), 1],)),
        (3, ([[0, 2, 3], [1, 3, 3], [0, 0, 2]],)),
        (3, ([[0, 0, 3], [1, 1, 3], [0, 0, 2]],)),
        (3, ([slice(None), slice(None), [0, 3]],)),
        (3, ([slice(None), [0, 3], slice(None)],)),
        (3, ([[0, 3], slice(None), slice(None)],)),
        (3, ([[0, 3], [1, 2], slice(None)],)),
        (3, ([[0, 3]],)),
        (3, ([[0, 3], slice(None)],)),
        (3, ([[0, 3], Ellipsis],)),
        (3, ([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])],)),
        (4, ([slice(None), adv_idx, adv_idx, slice(None)],)),
        (4, ([slice(None), adv_idx, slice(None), adv_idx],)),
        (4, ([adv_idx, slice(None), slice(None), adv_idx],)),
        (4, ([slice(None), slice(None), adv_idx, adv_idx],)),
        (4, ([Ellipsis, adv_idx, adv_idx],)),
        (5, ([slice(None), slice(None), adv_idx, slice(None), adv_idx],)),
        (5, ([slice(None), slice(None), adv_idx, adv_idx, slice(None)],)),
        (5, ([slice(None), slice(None), adv_idx, None, adv_idx, slice(None)],)),
        (6, ([slice(None), slice(None), slice(None), adv_idx, adv_idx],)),
        (6, ([slice(None), slice(None), adv_idx, adv_idx, adv_idx],)),
        (6, ([slice(None), slice(None), None, adv_idx, adv_idx, adv_idx],)),
    ]

    def get_shape(dim):
        return tuple(S + i for i in range(dim))

    return tuple(
        SampleInput(
            make_tensor(
                get_shape(self_dim),
                device=device,
                dtype=dtype,
                low=None,
                high=None,
                requires_grad=requires_grad,
            ),
            args=args,
        )
        for self_dim, args in test_args
    )


# TODO: split PyTorch's __getitem__. The problem is that we don't support
# indexing with masks under vmap.
additional_op_db.append(
    OpInfo(
        "__getitem__",
        variant_test_name="functorch",
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        supports_out=False,
        supports_inplace_autograd=False,
        supports_scripting=False,
        op=torch.Tensor.__getitem__,
        assert_jit_shape_analysis=False,  # TODO: support index.Tensor()
        supports_forward_ad=True,
        sample_inputs_func=sample_inputs_getitem,
    )
)


# Turns out at::index_put is different from torch.index_put...
# TODO: figure out how to upstream this
def sample_inputs_aten_index_put(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    inputs = []
    adv_idx = torch.LongTensor([[0, 1], [2, 3]])
    # self_shape, indices
    additional = [
        ((5, 6, 7, 8), [None, adv_idx, adv_idx, None]),
        ((5, 6, 7, 8), [None, adv_idx, None, adv_idx]),
        ((5, 6, 7, 8), [adv_idx, None, None, adv_idx]),
        ((5, 6, 7, 8), [None, None, adv_idx, adv_idx]),
        ((5, 6, 7, 8, 9), [None, None, adv_idx, None, adv_idx]),
        ((5, 6, 7, 8, 9), [None, None, adv_idx, adv_idx, None]),
        ((5, 6, 7, 8, 9, 10), [None, None, None, adv_idx, adv_idx]),
        ((5, 6, 7, 8, 9, 10), [None, None, adv_idx, adv_idx, adv_idx]),
    ]

    for self_shape, indices in additional:
        for broadcast_value in [False, True]:
            inp = make_arg(self_shape)

            tmp_indices = [slice(None) if idx is None else idx for idx in indices]
            values_shape = inp[tmp_indices].shape
            if broadcast_value:
                values_shape = values_shape[3:]
            values = make_arg(values_shape)
            inputs.append(SampleInput(inp, args=(tuple(indices), values)))
    return inputs


def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    make_idx = partial(
        make_tensor, dtype=torch.long, device=device, requires_grad=False
    )

    S = 5
    inputs = []
    for accumulate in [False, True]:
        # putting vectors at indexed locations
        inputs.append(
            SampleInput(
                make_arg((S, S)),
                args=((make_idx((2,), low=0, high=4),), make_arg((2, S))),
                kwargs=dict(accumulate=accumulate),
            )
        )

        # putting multi-dim tensors at indexed locations
        inputs.append(
            SampleInput(
                make_arg((S, S, 2)),
                args=((make_idx((3,), low=0, high=4),), make_arg((3, S, 2))),
                kwargs=dict(accumulate=accumulate),
            )
        )

        # value with size `0` dim
        inputs.append(
            SampleInput(
                make_arg((S, 0)),
                args=((make_idx((3,), low=0, high=4),), make_arg((3, 0))),
                kwargs=dict(accumulate=accumulate),
            )
        )

        # scalar value
        inputs.append(
            SampleInput(
                make_arg((S,)),
                args=((make_idx((), low=0, high=S),), make_arg(())),
                kwargs=dict(accumulate=accumulate),
            )
        )

        # cuda and accumulate don't work well
        # Reference: https://github.com/pytorch/pytorch/issues/72053
        if not accumulate and device == "cuda":
            # Broadcast `values`
            inputs.append(
                SampleInput(
                    make_arg((S, S)),
                    args=((make_idx((2,), low=0, high=S),), make_arg((S,))),
                    kwargs=dict(accumulate=accumulate),
                )
            )

    return inputs


additional_op_db.append(
    OpInfo(
        "index_put",
        variant_test_name="functorch",
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        supports_out=False,
        sample_inputs_func=sample_inputs_index_put,
        supports_forward_ad=True,
    )
)
additional_op_db.append(
    OpInfo(
        "ops.aten.index_put",
        variant_test_name="functorch",
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        supports_out=False,
        sample_inputs_func=sample_inputs_aten_index_put,
        supports_forward_ad=True,
    )
)


def sample_inputs_masked_fill(op_info, device, dtype, requires_grad, **kwargs):
    S = 3
    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    yield SampleInput(
        make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, 10)
    )
    yield SampleInput(make_arg((S, S)), args=(torch.randn(S, device=device) > 0, 10))
    yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, 10))
    yield SampleInput(make_arg((S, S)), args=(torch.randn((), device=device) > 0, 10))
    yield SampleInput(
        make_arg((S,)),
        args=(torch.randn(S, S, device=device) > 0, 10),
        broadcasts_input=True,
    )


additional_op_db.append(
    OpInfo(
        "masked_fill",
        variant_test_name="functorch_Scalar_only",
        dtypes=all_types_and_complex_and(
            torch.bool, torch.half, torch.bfloat16, torch.chalf
        ),
        sample_inputs_func=sample_inputs_masked_fill,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_forward_grad=False,
        supports_out=False,
    )
)


def sample_inputs_new_zeros_with_same_feature_meta(
    op_info, device, dtype, requires_grad, **kwargs
):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    matrix = [
        # tangent, base, num_tangent_bdims
        ([5], [2, 3], 0),
        ([2, 3], [2, 3], 0),
        ([5], [2], 0),
        ([1, 0, 2], [1, 2], 0),
        ([], [1, 2], 0),
        ([8, 7, 5], [2, 3, 11], 1),
        ([6, 7, 5], [2, 3, 4], 2),
        ([6, 4], [3], 2),
    ]
    results = []
    for tangent_shape, base_shape, num_tangent_bdims in matrix:
        tangent = make_arg(tangent_shape)
        base = make_arg(base_shape)
        results.append(
            SampleInput(
                tangent,
                args=(base,),
                kwargs=dict(self_num_batch_dims=num_tangent_bdims),
            )
        )
    return results


additional_op_db.append(
    OpInfo(
        "ops.aten._new_zeros_with_same_feature_meta",
        variant_test_name="functorchonly",
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        supports_out=False,
        supports_autograd=False,
        supports_forward_ad=False,
        sample_inputs_func=sample_inputs_new_zeros_with_same_feature_meta,
    )
)


def sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )

    shapes = ((), (2, 3))
    memory_format_options = [None, torch.contiguous_format]
    for shape, memory_format in itertools.product(shapes, memory_format_options):
        yield SampleInput(
            make_arg(shape),
            kwargs={"memory_format": memory_format} if memory_format else {},
        )


additional_op_db.extend(
    [
        OpInfo(
            "bfloat16",
            op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            skips=(
                # autograd tests don't handle operators that change dtype
                DecorateInfo(unittest.expectedFailure, "TestFwdGradients"),
                DecorateInfo(unittest.expectedFailure, "TestBwdGradients"),
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
                DecorateInfo(
                    unittest.skip("Skipped!"), "TestNNCOpInfo", "test_nnc_correctness"
                ),
            ),
        ),
        OpInfo(
            "bool",
            op=lambda x, *args, **kwargs: x.bool(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "byte",
            op=lambda x, *args, **kwargs: x.byte(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            # The autograd test runner cannot handle functions that change dtype
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "char",
            op=lambda x, *args, **kwargs: x.char(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            # The autograd test runner cannot handle functions that change dtype
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "double",
            op=lambda x, *args, **kwargs: x.double(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "float",
            op=lambda x, *args, **kwargs: x.float(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            skips=(
                # autograd tests don't handle operators that change dtype
                DecorateInfo(unittest.expectedFailure, "TestFwdGradients"),
                DecorateInfo(unittest.expectedFailure, "TestBwdGradients"),
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "half",
            op=lambda x, *args, **kwargs: x.half(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            skips=(
                # autograd tests don't handle operators that change dtype
                DecorateInfo(unittest.expectedFailure, "TestFwdGradients"),
                DecorateInfo(unittest.expectedFailure, "TestBwdGradients"),
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "int",
            op=lambda x, *args, **kwargs: x.int(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "long",
            op=lambda x, *args, **kwargs: x.long(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "short",
            op=lambda x, *args, **kwargs: x.short(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
    ]
)
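

# Usage sketch (illustrative only, not executed here): entries in
# `additional_op_db` are consumed the same way as OpInfos from PyTorch core,
# typically via the `@ops` decorator from
# torch.testing._internal.common_device_type. The test class and test names
# below are hypothetical.
#
#     from torch.testing._internal.common_device_type import (
#         instantiate_device_type_tests,
#         ops,
#     )
#     from torch.testing._internal.common_utils import run_tests, TestCase
#
#     class TestAdditionalOps(TestCase):  # hypothetical test class
#         @ops(additional_op_db, allowed_dtypes=(torch.float32,))
#         def test_forward(self, device, dtype, op):
#             # Run each sample through the op; OpInfo.__call__ dispatches to op.op
#             for sample in op.sample_inputs(device, dtype):
#                 op(sample.input, *sample.args, **sample.kwargs)
#
#     instantiate_device_type_tests(TestAdditionalOps, globals())
#
#     if __name__ == "__main__":
#         run_tests()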