• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# mypy: ignore-errors
2
3import logging
4
5from functorch.compile import make_boxed_func
6
7from ..backends.common import aot_autograd
8from .registry import register_backend, register_experimental_backend
9
10
11log = logging.getLogger(__name__)
12
13
@register_experimental_backend
def openxla_eval(model, fake_tensor_inputs):
    """Experimental backend entry point for the non-boxed XLA path.

    Forwards ``model`` and ``fake_tensor_inputs`` to ``xla_backend_helper``
    with ``boxed=False`` and returns the helper's result unchanged.
    """
    unboxed_fn = xla_backend_helper(model, fake_tensor_inputs, boxed=False)
    return unboxed_fn
17
18
def openxla_eval_boxed(model, fake_tensor_inputs):
    """Boxed variant of the XLA compile entry point.

    Forwards ``model`` and ``fake_tensor_inputs`` to ``xla_backend_helper``
    with ``boxed=True`` and returns the helper's result unchanged.
    """
    boxed_fn = xla_backend_helper(model, fake_tensor_inputs, boxed=True)
    return boxed_fn
21
22
def xla_backend_helper(model, fake_tensor_inputs, boxed=False):
    """Return a callable that lazily compiles ``model`` via the torch_xla bridge.

    The returned function extracts a compiled graph on its first invocation
    (using that call's arguments), caches it, and reuses it for every later
    call.

    Args:
        model: Object handed to ``bridge.extract_compiled_graph`` on first call.
        fake_tensor_inputs: Accepted for the backend signature; not read here.
        boxed: When true, the returned callable is wrapped with
            ``make_boxed_func``.

    Returns:
        The lazy forward function, optionally wrapped by ``make_boxed_func``.

    Raises:
        ImportError: If ``torch_xla.core.dynamo_bridge`` cannot be imported.
    """
    try:
        import torch_xla.core.dynamo_bridge as bridge
    except ImportError as import_failure:
        install_hint = "Please follow the instruction in https://github.com/pytorch/xla#pytorchxla to install torch_xla"
        raise ImportError(install_hint) from import_failure

    # Filled in by the first call to ``fwd``; ``None`` means "not compiled yet".
    compiled_graph = None

    def fwd(*args):
        nonlocal model, compiled_graph
        if compiled_graph is None:
            compiled_graph = bridge.extract_compiled_graph(model, args)
            # The closure no longer needs the model once the graph exists,
            # so drop this reference to it.
            del model
        return compiled_graph(*args)

    if not boxed:
        return fwd
    return make_boxed_func(fwd)
42
43
# Build the ``openxla`` backend by passing the boxed XLA compiler to
# ``aot_autograd`` as its forward compiler, then register the result under
# the backend name "openxla".
openxla = aot_autograd(fw_compiler=openxla_eval_boxed)
register_backend(name="openxla", compiler_fn=openxla)
48