• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2import sys
3import types
4
5import torch
6
7
class _XNNPACKEnabled:
    """Read-only data descriptor reporting whether XNNPACK is enabled.

    Every read calls ``torch._C._is_xnnpack_enabled()``; nothing is cached,
    so the value always reflects the current state reported by torch.
    Because ``__set__`` is defined, this is a data descriptor: instances of
    the owning class cannot shadow it through their ``__dict__``, and any
    assignment raises.
    """

    def __get__(self, obj, objtype=None):
        """Return the result of ``torch._C._is_xnnpack_enabled()``.

        Args:
            obj: Instance the attribute is read from (unused). The same
                value is also returned for class-level access (``obj is None``).
            objtype: Owner class (unused). Optional, as in the descriptor
                protocol signature ``__get__(instance, owner=None)``, so a
                direct ``descriptor.__get__(obj)`` call also works.
        """
        return torch._C._is_xnnpack_enabled()

    def __set__(self, obj, val):
        """Reject assignment to the attribute.

        Raises:
            RuntimeError: Always; the flag cannot be changed through this
                attribute.
        """
        raise RuntimeError("Assignment not supported")
14
15
class XNNPACKEngine(types.ModuleType):
    """Module stand-in that wraps the real module and adds ``enabled``.

    ``enabled`` is a class-level ``_XNNPACKEnabled`` descriptor, so reading it
    queries torch on every access and assigning it raises RuntimeError.
    Any other attribute that normal lookup does not find on the proxy is
    fetched from the wrapped module ``m``.
    """

    # Descriptors only take effect on the type, which is why this lives on
    # the proxy class rather than in the wrapped module's namespace.
    enabled = _XNNPACKEnabled()

    def __init__(self, m, name):
        """Create the proxy.

        Args:
            m: The module object being wrapped; kept as ``self.m``.
            name: Name passed to ``types.ModuleType``.
        """
        super().__init__(name)
        self.m = m

    def __getattr__(self, attr):
        """Forward a failed attribute lookup to the wrapped module.

        Python calls this only after normal lookup on the proxy fails.
        Raises AttributeError if the wrapped module lacks ``attr`` as well.
        """
        lookup = self.m.__getattribute__
        return lookup(attr)
25
26
# Install an XNNPACKEngine proxy wrapping this module under the same name in
# ``sys.modules``. A module's own namespace cannot host a descriptor, but the
# proxy's class can, so ``<this module>.enabled`` is computed on every read.
# Technique (the "sys.modules replacement trick"):
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = XNNPACKEngine(
    m=sys.modules[__name__],
    name=__name__,
)
30