# mypy: ignore-errors

"""Dynamo backends that delegate graph compilation to torch_xla."""

import logging

from functorch.compile import make_boxed_func

from ..backends.common import aot_autograd
from .registry import register_backend, register_experimental_backend


log = logging.getLogger(__name__)


@register_experimental_backend
def openxla_eval(model, fake_tensor_inputs):
    """Compile ``model`` with torch_xla and return a plain (unboxed) callable.

    Registered as an experimental backend via ``register_experimental_backend``.
    ``fake_tensor_inputs`` is accepted for the backend signature but is not
    used; compilation happens lazily on the first real call.
    """
    return xla_backend_helper(model, fake_tensor_inputs)


def openxla_eval_boxed(model, fake_tensor_inputs):
    """Same as ``openxla_eval`` but wraps the result with ``make_boxed_func``.

    This variant is used as the forward compiler for the ``openxla``
    AOTAutograd backend below. NOTE(review): presumably AOTAutograd expects
    the boxed calling convention here; see functorch ``make_boxed_func``.
    """
    return xla_backend_helper(model, fake_tensor_inputs, boxed=True)


def xla_backend_helper(model, fake_tensor_inputs, boxed=False):
    """Build a callable that compiles ``model`` through torch_xla on first use.

    Args:
        model: the graph/module handed over by the compiler stack; it is
            passed to ``torch_xla.core.dynamo_bridge.extract_compiled_graph``.
        fake_tensor_inputs: unused here.
        boxed: when true, the returned callable is wrapped with
            ``make_boxed_func``.

    Returns:
        A callable. Its first invocation compiles ``model`` using that
        invocation's arguments, caches the compiled graph, and drops the
        reference to ``model``; every invocation returns the compiled graph's
        result for the given arguments.

    Raises:
        ImportError: if ``torch_xla`` cannot be imported.
    """
    # torch_xla is optional, so it is only imported when this backend is used.
    try:
        import torch_xla.core.dynamo_bridge as bridge
    except ImportError as e:
        raise ImportError(
            "Please follow the instruction in https://github.com/pytorch/xla#pytorchxla to install torch_xla"
        ) from e

    compiled = None

    def fwd(*args):
        nonlocal model, compiled

        # Fast path: the graph was already compiled by an earlier call.
        if compiled is not None:
            return compiled(*args)

        # First call: compile against the actual arguments, then release the
        # closure's reference to the model so it is not kept alive by us.
        compiled = bridge.extract_compiled_graph(model, args)
        del model
        return compiled(*args)

    if boxed:
        return make_boxed_func(fwd)
    return fwd


# AOTAutograd-based backend whose forward graphs are compiled by torch_xla.
openxla = aot_autograd(
    fw_compiler=openxla_eval_boxed,
)
register_backend(name="openxla", compiler_fn=openxla)