• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2import torch
3from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
4from torch.overrides import TorchFunctionMode
5
6
7class AutogradStateOpsFailSafeguard(TorchFunctionMode):
8    """
9    Detect grad state ops during exporting the graph and fail the process by
10    raising an error, to avoid unexpected behavior. Those grad mode ops could be:
11    `torch.no_grad`
12    `torch.enable_grad`
13    `torch.set_grad_enabled`
14
15    Export with predispatch mode is exempted.
16    """
17
18    def __torch_function__(self, func, types, args=(), kwargs=None):
19        kwargs = kwargs or {}
20        unsupported_grad_mode_ops = [
21            torch._C._set_grad_enabled,
22        ]
23        # It's only enabled while tracing, by confirming the torch dispatch mode is
24        # any active PROXY. This is to allow the autograd ops out of tracing.
25        current_state = torch._C.is_grad_enabled()
26        if func in unsupported_grad_mode_ops:
27            assert len(args) == 1
28            changed_state = args[0]
29            mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
30            # Intend to check if it's not the pre_dispatch mode. It's allowed to use
31            # autograd ops in pre_dispatch mode, e.g. `torch.no_grad`
32            if (
33                mode
34                and isinstance(mode, ProxyTorchDispatchMode)
35                and not mode.pre_dispatch
36                and changed_state != current_state
37            ):
38                raise RuntimeError(
39                    f"Encountered autograd state manager op {func} trying to change global autograd state "
40                    "while exporting. This is unsafe because we don't capture this op in torch.export "
41                    "today, hence we can't reflect the user intention soundly. You can fix this by "
42                    "adding a torch.no_grad() context around the export call."
43                )
44        return func(*args, **kwargs)
45