# Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import itertools import os import unittest from collections import namedtuple from functorch_additional_op_db import additional_op_db import torch import torch.utils._pytree as pytree from functorch import vmap from torch.testing._internal.autograd_function_db import autograd_function_db from torch.testing._internal.common_device_type import toleranceOverride from torch.testing._internal.common_methods_invocations import DecorateInfo, op_db from torch.testing._internal.common_modules import module_db from torch.testing._internal.custom_op_db import custom_op_db IS_FBCODE = os.getenv("FUNCTORCH_TEST_FBCODE") == "1" def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values): outs = [] out_spec = None for idx in range(batch_size): flat_args, args_spec = pytree.tree_flatten(batched_args) flat_dims, dims_spec = pytree.tree_flatten(in_dims) assert args_spec == dims_spec new_args = [ a.select(in_dim, idx) if in_dim is not None else a for a, in_dim in zip(flat_args, flat_dims) ] out = op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values) flat_out, out_spec = pytree.tree_flatten(out) outs.append(flat_out) # use the same out_dim for all outputs if isinstance(out_dim, int): flat_out_dim = [out_dim for _ in flat_out] else: flat_out_dim, _ = pytree.tree_flatten(out_dim) outs = zip(*outs) result = [] for i, out_lst in enumerate(outs): if flat_out_dim[i] is not None: if not all(isinstance(x, torch.Tensor) for x in out_lst): raise ValueError( f"vmap `{op}` must only return " "Tensors. Did you mean to set out_dims= to None for output?" ) result.append(torch.stack(out_lst)) else: # not batched over, result should be the same for all batches result.append(out_lst[0]) return pytree.tree_unflatten(result, out_spec) # Like loop helper function but for 2 levels of vmap. If we need more levels than this, probably possible # to generalize the loops function but it seemed too complicated for this def loop2( op, in_dims1, in_dims2, out_dim1, out_dim2, batch_size1, batch_size2, *batched_args, **kwarg_values, ): outs = [] flat_args, args_spec = pytree.tree_flatten(batched_args) flat_dims1, dims_spec1 = pytree.tree_flatten(in_dims1) flat_dims2, dims_spec2 = pytree.tree_flatten(in_dims2) assert args_spec == dims_spec1 assert args_spec == dims_spec2 assert len(flat_dims1) == len(flat_dims2) for idx1 in range(batch_size1): out_split = [] arg_split = [ a.select(in_dim1, idx1) if in_dim1 is not None else a for a, in_dim1 in zip(flat_args, flat_dims1) ] for idx2 in range(batch_size2): new_args = [ a.select(in_dim, idx2) if in_dim is not None else a for a, in_dim in zip(arg_split, flat_dims2) ] out = op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values) out_split.append(out) outs.append(out_split) loop_out = [] for out_split in outs: if isinstance(out_split[0], torch.Tensor): loop_out.append(torch.stack(out_split, out_dim1)) else: new_out = [] for idx in range(len(out_split[0])): new_out.append(torch.stack([i[idx] for i in out_split], out_dim1)) loop_out.append(new_out) new_out = [] if isinstance(loop_out, torch.Tensor): new_out = torch.stack(loop_out, out_dim2) else: for idx in range(len(loop_out[0])): new_out.append(torch.stack([i[idx] for i in loop_out], out_dim2)) return new_out def is_valid_inplace_sample_input(sample_input, op, inplace_variant): if inplace_variant is None: return False if sample_input.broadcasts_input: return False if not isinstance(sample_input.input, torch.Tensor): return False # Check if input's dtype matches the output's dtype args = (sample_input.input,) + sample_input.args kwargs = sample_input.kwargs output_dtype = op(*args, **kwargs).dtype return sample_input.input.dtype == output_dtype # This is kind of dangerous, please think carefully before using it. # Known risks: # - the return better not be mutated so it's best to return immutable types # (e.g. prefer tuples to list) # - Don't hash tensors in a global context, that'll keep them around forever def memoize(fn): memo = {} def wrapped(*args): if args not in memo: memo[args] = fn(*args) return memo[args] return wrapped # NB: This is O(2 ** num_tensors). # num_tensors ranges from 1 to 10, with 2-4 being most common. # Try not to extravagate it if you're modifying it. @memoize def get_bdim_choices(num_tensors): choices = [] # full of zeros choices.append((0,) * num_tensors) # All permutations of (-1, None) options = (-1, None) choices.extend(itertools.product(options, repeat=num_tensors)) assert choices[-1] == (None,) * num_tensors return tuple(choices[:-1]) # NB: This is O(2 ** num_tensors). # num_tensors ranges from 1 to 10, with 2-4 being most common. # Try not to extravagate it if you're modifying it. def get_bdim_choices_batch_norm( num_tensors, _, running_mean=None, running_var=None, *args ): choices = [] options = (-1, None) # instance norm turns these into unbatched 0 tensors, so we cannot batch the input if either is not specified if running_mean is None or running_var is None: choices.append((None,) + (0,) * (num_tensors - 1)) for choice in itertools.product(options, repeat=num_tensors - 1): choices.append((None,) + choice) else: # running_mean and running_var are specified as tensors. Batch norm doesn't work if the input is batched but # running_mean/var are unbatched, so this tests all other cases choices.append((0,) * num_tensors) for choice in itertools.product(options, repeat=num_tensors): input_bdim = choice[0] running_mean_bdim = choice[1] running_var_bdim = choice[2] if input_bdim and (not running_mean_bdim or not running_var_bdim): continue choices.append(choice) assert choices[-1] == (None,) * num_tensors return tuple(choices[:-1]) def add_batch_dim(arg, bdim, batch_size=3): assert bdim == 0 or bdim == -1 assert isinstance(arg, torch.Tensor) if bdim == 0: shape = [1] * len(arg.shape) shape.insert(bdim, batch_size) return (arg.repeat(shape), bdim) if bdim == -1: arg = arg.unsqueeze(-1).expand(*arg.shape, batch_size).contiguous() return (arg, bdim) def construct_in_dims(bdim_choice_for_tensors, is_tensors): result = [] bdim = iter(bdim_choice_for_tensors) for is_tensor in is_tensors: if not is_tensor: result.append(None) continue result.append(next(bdim)) return tuple(result) def is_batch_norm_training(op_name, kwarg_values): batch_norm_fns = ( "nn.functional.batch_norm", "nn.functional.instance_norm", ) # instance norm calls batch norm if op_name not in batch_norm_fns: return False # batch norm and instance norm require the value to be a plain bool default_training = ( op_name == "nn.functional.instance_norm" ) # instance norm defaults to training, batch norm doesn't is_training = tuple( arg for arg in tuple(kwarg_values.values()) if isinstance(arg, bool) ) if len(is_training) == 0: return default_training else: assert len(is_training) == 1 return is_training[0] def generate_vmap_inputs( arg_values, kwarg_values, is_batch_norm_and_training=False, batch_size=2 ): flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values)) is_tensors = [isinstance(a, torch.Tensor) for a in flat_args] num_tensors = sum(is_tensors) # For Batch Norm, if there's only an input, we can't # batch it since running_mean/var will be seen as unbatched tensors if num_tensors == 1 and is_batch_norm_and_training: return bdim_choices = ( get_bdim_choices_batch_norm(num_tensors, *arg_values) if is_batch_norm_and_training else get_bdim_choices(num_tensors) ) @memoize def get_batched_arg(arg, bdim): assert isinstance(arg, torch.Tensor) assert bdim is not None result, _ = add_batch_dim(arg, bdim, batch_size) return result for bdim_choice in bdim_choices: flat_in_dims = construct_in_dims(bdim_choice, is_tensors) flat_batched_args = tuple( arg if in_dim is None else get_batched_arg(arg, in_dim) for arg, in_dim in zip(flat_args, flat_in_dims) ) batched_args = pytree.tree_unflatten(flat_batched_args, arg_spec) in_dims = pytree.tree_unflatten(flat_in_dims, arg_spec) yield batched_args, in_dims, kwarg_values def clone_if_tensor(x): if isinstance(x, torch.Tensor): return x.clone() return x # Helper function to compare output of `vmap` against the # `for-loop` version. def _compute_quantities_for_vmap_test( op, orig_batched_args, orig_kwarg_values, in_dims, out_dim, batch_size, compute_loop_out=True, clone_inputs=False, ): def maybe_clone_inputs(): if clone_inputs: batched_args = pytree.tree_map(clone_if_tensor, orig_batched_args) kwarg_values = pytree.tree_map(clone_if_tensor, orig_kwarg_values) return batched_args, kwarg_values return orig_batched_args, orig_kwarg_values batched_args, kwarg_values = maybe_clone_inputs() if compute_loop_out: loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values) else: loop_out = None # Used for debugging the resulting operations # from functorch import make_fx # def f(a): # return op(a) # t = make_fx(vmap(f, in_dims=in_dims, out_dims=out_dim))(*batched_args, **kwarg_values) # print(in_dims, [arg.shape for arg in batched_args], kwarg_values) batched_args, kwarg_values = maybe_clone_inputs() batched_out = vmap(op, in_dims=in_dims, out_dims=out_dim)( *batched_args, **kwarg_values ) # Tests case where we dispatch to a batching rule with no bdims # This should be handled by autogenerated plumbing. For vmap support # added via a manual plumbing you may need to handle this specially. def add_bdim_if_tensor(x): if isinstance(x, torch.Tensor): return x.unsqueeze(1) return x def f(dummy, *args, **kwargs): return op(*args, **kwargs) dummy = torch.ones(batch_size, 1) vmapvmap_expected = pytree.tree_map(add_bdim_if_tensor, batched_out) inner_in_dims = (0,) + pytree.tree_map(lambda x: None, in_dims) outer_in_dims = (0,) + in_dims batched_args, kwarg_values = maybe_clone_inputs() vmapvmap_output = vmap( vmap(f, inner_in_dims, out_dims=out_dim), outer_in_dims, out_dims=out_dim )(dummy, *batched_args, **kwarg_values) yield (batched_out, loop_out, vmapvmap_output, vmapvmap_expected) # Function with more friendly return types # compared to `_compute_quantities_for_vmap_test` def compute_quantities_for_vmap_test( op, orig_batched_args, orig_kwarg_values, in_dims, out_dim=0, batch_size=2, compute_loop_out=True, clone_inputs=False, ): for quantities in _compute_quantities_for_vmap_test( op, orig_batched_args, orig_kwarg_values, in_dims, out_dim, batch_size, compute_loop_out, clone_inputs, ): yield (quantities[0], quantities[1]) yield (quantities[2], quantities[3]) def get_fallback_and_vmap_exhaustive( op, arg_values, kwarg_values, is_batch_norm_and_training=False, compute_loop_out=True, ): out_dim = 0 batch_size = 2 def make_batched(t): if isinstance(t, torch.Tensor): shape = list(t.shape) shape.insert(out_dim, batch_size) return t.expand(*shape) return t # Inputs generated by `generate_vmap_inputs` just copy/expand the unbatched inputs # over the batched dimension. Thus we can compute the expected value once and just # expand it based on the `out_dim` and `batch_size`. expected_unbatched = op(*arg_values, **kwarg_values) expected_batched = pytree.tree_map(make_batched, expected_unbatched) generator = generate_vmap_inputs( arg_values, kwarg_values, is_batch_norm_and_training ) for batched_args, in_dims, kwarg_values in generator: for quantities in _compute_quantities_for_vmap_test( op, batched_args, kwarg_values, in_dims, out_dim, batch_size, compute_loop_out=False, ): assert quantities[1] is None yield (quantities[0], expected_batched) yield (quantities[2], quantities[3]) def opinfo_in_dict(opinfo, d): return (opinfo.name in d) or (f"{opinfo.name}.{opinfo.variant_test_name}" in d) DecorateMeta = namedtuple( "DecorateMeta", [ "op_name", "variant_name", "decorator", "device_type", "dtypes", ], ) def decorate( op_name, variant_name="", *, decorator=None, device_type=None, dtypes=None ): assert decorator is not None return DecorateMeta( op_name=op_name, variant_name=variant_name, decorator=decorator, device_type=device_type, dtypes=dtypes, ) def xfail(op_name, variant_name="", *, device_type=None, dtypes=None): return decorate( op_name=op_name, variant_name=variant_name, decorator=unittest.expectedFailure, device_type=device_type, dtypes=dtypes, ) def skip(op_name, variant_name="", *, device_type=None, dtypes=None): return decorate( op_name=op_name, variant_name=variant_name, decorator=unittest.skip("Skipped!"), device_type=device_type, dtypes=dtypes, ) def skipOps(test_case_name, base_test_name, to_skip): all_opinfos = op_db + additional_op_db + autograd_function_db + custom_op_db for decorate_meta in to_skip: matching_opinfos = [ o for o in all_opinfos if o.name == decorate_meta.op_name and o.variant_test_name == decorate_meta.variant_name ] assert len(matching_opinfos) > 0, f"Couldn't find OpInfo for {decorate_meta}" assert len(matching_opinfos) == 1, ( "OpInfos should be uniquely determined by their (name, variant_name). " f"Got more than one result for ({decorate_meta.op_name}, {decorate_meta.variant_name})" ) opinfo = matching_opinfos[0] decorators = list(opinfo.decorators) new_decorator = DecorateInfo( decorate_meta.decorator, test_case_name, base_test_name, device_type=decorate_meta.device_type, dtypes=decorate_meta.dtypes, ) decorators.append(new_decorator) opinfo.decorators = tuple(decorators) # This decorator doesn't modify fn in any way def wrapped(fn): return fn return wrapped def decorateForModules(decorator, module_classes, device_type=None, dtypes=None): # This decorator doesn't modify fn in any way def wrapped( fn, module_classes=module_classes, decorator=decorator, device_type=device_type, dtypes=dtypes, ): name_parts = fn.__qualname__.split(".") assert ( len(name_parts) == 2 ), "Decorator only applies to a test function of a test class" test_case_name, base_test_name = name_parts for module_cls in module_classes: matching_module_infos = [m for m in module_db if m.module_cls == module_cls] assert ( len(matching_module_infos) == 1 ), f"Couldn't find single ModuleInfo for {module_cls}" module_info = matching_module_infos[0] decorators = list(module_info.decorators) new_decorator = DecorateInfo( decorator, test_case_name, base_test_name, device_type=device_type, dtypes=dtypes, ) decorators.append(new_decorator) module_info.decorators = tuple(decorators) return fn return wrapped def expectedFailureIf(condition): def decorator(fn): if condition: return unittest.expectedFailure(fn) return fn return decorator def tol2(op_name, variant_name, override_dct, *, device_type=None): return (op_name, variant_name, override_dct, device_type) def tol1(op_name, override_dct, *, device_type=None): return tol2(op_name, "", override_dct, device_type=device_type) def opsToleranceOverride(test_case_name, base_test_name, overrides): all_opinfos = op_db + additional_op_db for override in overrides: op_name, variant_name, override, device_type = override matching_opinfos = [ o for o in all_opinfos if o.name == op_name and o.variant_test_name == variant_name ] assert len(matching_opinfos) == 1, f"Couldn't find OpInfo for {override}" opinfo = matching_opinfos[0] decorators = list(opinfo.decorators) decorators.append( DecorateInfo( toleranceOverride(override), test_case_name, base_test_name, device_type=device_type, ) ) opinfo.decorators = tuple(decorators) # This decorator doesn't modify fn in any way def wrapped(fn): return fn return wrapped class DisableVmapFallback: def __enter__(self): self.prev_state = torch._C._functorch._is_vmap_fallback_enabled() torch._C._functorch._set_vmap_fallback_enabled(False) def __exit__(self, *ignored): torch._C._functorch._set_vmap_fallback_enabled(self.prev_state) def check_vmap_fallback(test_case, thunk, opinfo, dry_run=False): try: with DisableVmapFallback(): thunk() except Exception: if not dry_run: raise if opinfo.variant_test_name: print(f"xfail('{opinfo.name}', '{opinfo.variant_test_name}'),") else: print(f"xfail('{opinfo.name}'),")