# mypy: ignore-errors

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._utils import wrapper_set_seed
import torch.utils._pytree as pytree


def make_fx_check(
    func,
    args,
    kwargs,
    tracing_mode,
    assert_close=torch.testing.assert_close,
    randomize_data=False,
):
    """Check that tracing `func` with make_fx preserves its behavior.

    Runs ``func(*args, **kwargs)`` eagerly and compares the result against
    running the graph produced by ``make_fx`` on the same inputs.
    """
    f, *new_args = handle_sizes_for_dynamic_shapes(func, args, kwargs)

    def run(f, *args, **kwargs):
        return wrapper_set_seed(f, *args, **kwargs)

    traced_f = make_fx(f, tracing_mode=tracing_mode)(*new_args)

    msg = (
        "op(*args, **kwargs) and make_fx(op)(*args, **kwargs) produced different "
        "values. This could mean that your abstract impls (meta/FakeTensor impls) "
        "are incorrect, that your operator is not completely traceable (e.g., "
        "it relies on some global state), or that there is a bug in make_fx. "
        "Note that if you passed a python function (and not an operator) to "
        "make_fx_check, it is still possible that the python function will "
        "work with torch.compile, because torch.compile handles capturing "
        "pieces of your python code to compile."
    )

    # Randomize the data and run the traced graph with it, to catch bugs
    # where we may have baked in Tensor data into the trace.
    # This is not guaranteed to succeed, because `f` might have preconditions
    # on the values of the inputs, so we just ignore it if we used
    # random data and it fails.
    if randomize_data:
        new_args = randomize(new_args)
    try:
        expected = run(f, *new_args)
    except Exception:
        if randomize_data:
            return
        raise
    result = run(traced_f, *new_args)
    assert_close(result, expected, msg=msg)


# Arguably we should make make_fx promote torch.Size() objects to symbolic shapes.
# Absent that, here is our strategy:
#
# If any argument is a torch.Size(), maybe get dynamic shapes for it by:
# - Create a temporary Tensor whose size is the torch.Size() we want. Note that
#   we use an expanded Tensor as we cannot pass "meta" Tensors to make_fx.
# - Pass it to make_fx such that it is converted to a proxy Tensor
# - Unpack the size in the wrapper to get a torch.Size with dynamic shapes (in
#   symbolic mode, a no-op otherwise)
def handle_sizes_for_dynamic_shapes(func, args, kwargs):
    def f(args, kwargs, extra_args, extra_kwargs):
        # Swap the placeholder Tensors back out for their (possibly symbolic) sizes.
        if extra_args:
            for i, t in extra_args:
                args[i] = t.size()
        if extra_kwargs:
            for k, t in extra_kwargs.items():
                kwargs[k] = t.size()

        return func(*args, **kwargs)

    extra_args = []
    extra_kwargs = {}
    for i, arg in enumerate(args):
        if isinstance(arg, torch.Size):
            extra_args.append((i, torch.empty(arg, device="cpu")))
    for key, value in kwargs.items():
        if isinstance(value, torch.Size):
            extra_kwargs[key] = torch.empty(value, device="cpu")

    return f, args, kwargs, extra_args, extra_kwargs


def randomize(args):
    # Replace floating-point Tensors in `args` with freshly randomized values,
    # preserving requires_grad; non-floating-point Tensors are left untouched.
    def transform(x):
        if not x.dtype.is_floating_point:
            return x
        return x.detach().clone().uniform_(0, 1).requires_grad_(x.requires_grad)

    return pytree.tree_map_only(torch.Tensor, transform, args)
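

# Example usage (a minimal sketch, not part of the original module): `my_op` and
# `example_args` below are hypothetical, chosen only to illustrate how
# make_fx_check might be called. The tracing modes accepted by make_fx include
# "real", "fake", and "symbolic"; iterating over all three is illustrative only.
if __name__ == "__main__":
    def my_op(x, y):
        # A simple, traceable computation with no data-dependent control flow.
        return (x * y).sum(dim=-1)

    example_args = (torch.randn(3, 4), torch.randn(3, 4))
    for mode in ("real", "fake", "symbolic"):
        # Raises via assert_close if eager and traced results disagree.
        make_fx_check(my_op, example_args, {}, tracing_mode=mode)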