# Owner(s): ["oncall: jit"] import os import sys from collections import OrderedDict from typing import Any, List, Tuple import torch import torch.nn as nn from torch.testing._internal.jit_utils import JitTestCase # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) if __name__ == "__main__": raise RuntimeError( "This test file is not meant to be run directly, use:\n\n" "\tpython test/test_jit.py TESTNAME\n\n" "instead." ) class TestModuleContainers(JitTestCase): def test_sequential_intermediary_types(self): class A(torch.nn.Module): def forward(self, x): return x + 3 class B(torch.nn.Module): def forward(self, x): return {"1": x} class C(torch.nn.Module): def __init__(self) -> None: super().__init__() self.foo = torch.nn.Sequential(A(), B()) def forward(self, x): return self.foo(x) self.checkModule(C(), (torch.tensor(1),)) def test_moduledict(self): class Inner(torch.nn.Module): def forward(self, x): return x + 10 class Inner2(torch.nn.Module): def forward(self, x): return x * 2 class Inner3(torch.nn.Module): def forward(self, x): return (x - 4) * 3 class M(torch.nn.Module): def __init__(self) -> None: super().__init__() modules = OrderedDict( [ ("one", Inner()), ("two", Inner2()), ("three", Inner3()), ] ) self.moduledict = nn.ModuleDict(modules) def forward(self, x, skip_name): # type: (Tensor, str) names = torch.jit.annotate(List[str], []) values = [] for name in self.moduledict: names.append(name) for name, mod in self.moduledict.items(): if name != skip_name: names.append(name) x = mod(x) values.append(x) for mod in self.moduledict.values(): x = mod(x) values.append(x) for key in self.moduledict.keys(): names.append(key) return x, names class M2(M): def forward(self, x, skip_name): # type: (Tensor, str) names = torch.jit.annotate(List[str], []) values = [] x2 = x iter = 0 for name in self.moduledict: names.append(name) for i, (name, mod) in enumerate(self.moduledict.items()): iter += i if name != skip_name: names.append(name) x = mod(x) values.append(x) for i, mod in enumerate(self.moduledict.values()): iter += i x = mod(x) values.append(x) for i, key in enumerate(self.moduledict.keys()): iter += i names.append(key) for mod, mod in zip(self.moduledict.values(), self.moduledict.values()): iter += i x2 = mod(mod(x2)) return x, x2, names, iter for name in ["", "one", "two", "three"]: inp = torch.tensor(1) self.checkModule(M(), (inp, name)) self.checkModule(M2(), (inp, name)) def test_custom_container_forward(self): class Inner(torch.nn.Module): def forward(self, x): return x + 10 class CustomSequential(nn.Sequential): def __init__(self) -> None: super().__init__(nn.ReLU(), Inner()) def forward(self, x): x = x + 3 for mod in self: x = mod(x) return x - 5 self.checkModule(CustomSequential(), (torch.tensor(0.5),)) class CustomModuleList(nn.ModuleList): def __init__(self) -> None: super().__init__([nn.ReLU(), Inner()]) def forward(self, x): x = x + 3 for mod in self: x = mod(x) return x - 5 self.checkModule(CustomModuleList(), (torch.tensor(0.5),)) class CustomModuleDict(nn.ModuleDict): def __init__(self) -> None: super().__init__( OrderedDict( [ ("one", Inner()), ("two", nn.ReLU()), ("three", Inner()), ] ) ) def forward(self, x): x = x + 3 names = torch.jit.annotate(List[str], []) for name, mod in self.items(): x = mod(x) names.append(name) return names, x - 5 self.checkModule(CustomModuleDict(), (torch.tensor(0.5),)) def test_script_module_list_sequential(self): class M(torch.jit.ScriptModule): def __init__(self, mod_list): super().__init__() self.mods = mod_list @torch.jit.script_method def forward(self, v): for m in self.mods: v = m(v) return v with torch.jit.optimized_execution(False): m = M(nn.Sequential(nn.ReLU())) self.assertExportImportModule(m, (torch.randn(2, 2),)) def test_script_modulelist_index(self): class Sub(torch.nn.Module): def __init__(self, i): super().__init__() self.i = i def forward(self, thing): return thing - self.i class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.mods = nn.ModuleList([Sub(i) for i in range(10)]) def forward(self, v): v = self.mods[4].forward(v) v = self.mods[-1].forward(v) v = self.mods[-9].forward(v) return v x = torch.tensor(1) self.checkModule(M(), (x,)) class MForward(torch.nn.Module): def __init__(self) -> None: super().__init__() self.mods = nn.ModuleList([Sub(i) for i in range(10)]) def forward(self, v): v = self.mods[4](v) v = self.mods[-1](v) v = self.mods[-9](v) return v self.checkModule(MForward(), (torch.tensor(1),)) class M2(M): def forward(self, v): return self.mods[-11].forward(v) with self.assertRaisesRegexWithHighlight( Exception, "Index -11 out of range", "self.mods[-11]" ): torch.jit.script(M2()) class M3(M): def forward(self, v): i = 3 return self.mods[i].forward(v) with self.assertRaisesRegexWithHighlight( Exception, "Enumeration is supported", "self.mods[i]" ): torch.jit.script(M3()) class M4(M): def forward(self, v): i = 3 return self.mods[i].forward(v) with self.assertRaisesRegex(Exception, "will fail because i is not a literal"): torch.jit.script(M4()) def test_module_interface_special_methods(self): class CustomModuleInterface(torch.nn.Module): pass class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList): def __init__(self, modules=None): CustomModuleInterface.__init__(self) torch.nn.ModuleList.__init__(self, modules) class CustomSequential(CustomModuleInterface, torch.nn.Sequential): def __init__(self, modules=None): CustomModuleInterface.__init__(self) torch.nn.Sequential.__init__(self, modules) class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict): def __init__(self, modules=None): CustomModuleInterface.__init__(self) torch.nn.ModuleDict.__init__(self, modules) class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() # work around aliasing issue for 'is' operator by scripting ReLU up front self.submod = torch.jit.script(torch.nn.ReLU()) self.modulelist = CustomModuleList([self.submod]) self.sequential = CustomSequential(self.submod) self.moduledict = CustomModuleDict({"submod": self.submod}) def forward(self, inputs): assert ( self.modulelist[0] is self.submod ), "__getitem__ failing for ModuleList" assert len(self.modulelist) == 1, "__len__ failing for ModuleList" for module in self.modulelist: assert module is self.submod, "__iter__ failing for ModuleList" assert ( self.sequential[0] is self.submod ), "__getitem__ failing for Sequential" assert len(self.sequential) == 1, "__len__ failing for Sequential" for module in self.sequential: assert module is self.submod, "__iter__ failing for Sequential" assert ( self.moduledict["submod"] is self.submod ), "__getitem__ failing for ModuleDict" assert len(self.moduledict) == 1, "__len__ failing for ModuleDict" # note: unable to index moduledict with a string variable currently i = 0 for key in self.moduledict: i += 1 assert i == len(self.moduledict), "iteration failing for ModuleDict" assert "submod" in self.moduledict, "__contains__ fails for ModuleDict" for key in self.moduledict.keys(): assert key == "submod", "keys() fails for ModuleDict" for item in self.moduledict.items(): assert item[0] == "submod", "items() fails for ModuleDict" assert item[1] is self.submod, "items() fails for ModuleDict" for value in self.moduledict.values(): assert value is self.submod, "values() fails for ModuleDict" return inputs m = MyModule() self.checkModule(m, [torch.randn(2, 2)]) def test_special_method_with_override(self): class CustomModuleInterface(torch.nn.Module): pass class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList): def __init__(self, modules=None): CustomModuleInterface.__init__(self) torch.nn.ModuleList.__init__(self, modules) def __len__(self): # this is arbitrary, just to check that the overridden py __len__ from # CustomModuleList takes precedence over the automatically generated # __len__ added by the jit compiler return 2 class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() # work around aliasing issue for 'is' operator by scripting ReLU up front self.submod = torch.jit.script(torch.nn.ReLU()) self.modulelist = CustomModuleList([self.submod]) def forward(self, inputs): assert len(self.modulelist) == 2, "__len__ failing for ModuleList" return inputs m = MyModule() self.checkModule(m, [torch.randn(2, 2)]) mm = torch.jit.script(m) def test_moduledict_getitem(self): class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.relu = torch.jit.script(torch.nn.ReLU()) self.tanh = torch.jit.script(torch.nn.Tanh()) self.moduledict = torch.nn.ModuleDict( {"relu": self.relu, "tanh": self.tanh} ) def forward(self, input): assert self.moduledict["relu"] is self.relu assert self.moduledict["tanh"] is self.tanh return input m = MyModule() self.checkModule(m, [torch.randn(2, 2)]) def test_moduledict_keyerror(self): class BadModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None}) def forward(self, input): assert self.moduledict["blah"] == "blah", "this is a keyerror" with self.assertRaisesRegexWithHighlight( RuntimeError, "Key Error, blah", 'self.moduledict["blah"' ): b = BadModule() torch.jit.script(b) class AnotherBadModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None}) def forward(self, input): idx = "blah" assert self.moduledict[idx] == "blah", "this is a string literal error" with self.assertRaisesRegexWithHighlight( RuntimeError, "Unable to extract string literal index. " "ModuleDict indexing is only supported with string literals. " "For example, 'i = \"a\"; self.layers\\[i\\]\\(x\\)' will fail " "because i is not a literal.", "self.moduledict[idx]", ): b = AnotherBadModule() torch.jit.script(b) def test_normal_list_attribute_with_modules_error(self): """ Test that an attempt to script a module with a regular list attribute containing other modules fails with a relevant error message. """ class Mod(torch.nn.Module): def __init__(self) -> None: super().__init__() self.a = [torch.nn.ReLU(), torch.nn.ReLU()] def forward(self): return len(self.a) error_msg = "Could not infer type of list element: Cannot infer concrete type of torch.nn.Module" with self.assertRaisesRegexWithHighlight(RuntimeError, error_msg, "self.a"): torch.jit.script(Mod()) def test_empty_dict_override_contains(self): class CustomModuleInterface(torch.nn.Module): pass class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict): def __init__(self, modules=None): CustomModuleInterface.__init__(self) torch.nn.ModuleDict.__init__(self, modules) class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() # work around aliasing issue for 'is' operator by scripting ReLU up front self.submod = torch.jit.script(torch.nn.ReLU()) self.moduledict = CustomModuleDict() def forward(self, inputs): assert ( "submod" not in self.moduledict ), "__contains__ fails for ModuleDict" return inputs m = MyModule() self.checkModule(m, [torch.randn(2, 2)]) def test_typed_module_dict(self): """ Test that a type annotation can be provided for a ModuleDict that allows non-static indexing. """ @torch.jit.interface class ModuleInterface(torch.nn.Module): def forward(self, inp: Any) -> Any: pass class ImplementsInterface(torch.nn.Module): def forward(self, inp: Any) -> Any: if isinstance(inp, torch.Tensor): return torch.max(inp, dim=0) return inp class DoesNotImplementInterface(torch.nn.Module): def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return torch.max(inp, dim=0) # Test annotation of submodule. class Mod(torch.nn.Module): def __init__(self) -> None: super().__init__() self.d = torch.nn.ModuleDict({"module": ImplementsInterface()}) def forward(self, x: torch.Tensor, key: str) -> Any: value: ModuleInterface = self.d[key] return value.forward(x) m = Mod() self.checkModule(m, (torch.randn(2, 2), "module")) # Test annotation of self. class ModDict(torch.nn.ModuleDict): def __init__(self) -> None: super().__init__({"module": ImplementsInterface()}) def forward(self, x: torch.Tensor, key: str) -> Any: submodule: ModuleInterface = self[key] return submodule.forward(x) m = ModDict() self.checkModule(m, (torch.randn(2, 2), "module")) # Test error message thrown when annotated attribute does not comply with the # annotation. class ModWithWrongAnnotation(torch.nn.ModuleDict): def __init__(self) -> None: super().__init__() self.d = torch.nn.ModuleDict({"module": DoesNotImplementInterface()}) def forward(self, x: torch.Tensor, key: str) -> Any: submodule: ModuleInterface = self.d[key] return submodule.forward(x) with self.assertRaisesRegexWithHighlight( RuntimeError, r"Attribute module is not of annotated type", "self.d[key]" ): torch.jit.script(ModWithWrongAnnotation()) def test_typed_module_list(self): """ Test that a type annotation can be provided for a ModuleList that allows non-static indexing. """ @torch.jit.interface class ModuleInterface(torch.nn.Module): def forward(self, inp: Any) -> Any: pass class ImplementsInterface(torch.nn.Module): def forward(self, inp: Any) -> Any: if isinstance(inp, torch.Tensor): return torch.max(inp, dim=0) return inp class DoesNotImplementInterface(torch.nn.Module): def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return torch.max(inp, dim=0) # Test annotation of submodule. class Mod(torch.nn.Module): def __init__(self) -> None: super().__init__() self.l = torch.nn.ModuleList([ImplementsInterface()]) def forward(self, x: torch.Tensor, idx: int) -> Any: value: ModuleInterface = self.l[idx] return value.forward(x) m = Mod() self.checkModule(m, (torch.randn(2, 2), 0)) # Test annotation of self. class ModList(torch.nn.ModuleList): def __init__(self) -> None: super().__init__([ImplementsInterface()]) def forward(self, x: torch.Tensor, idx: int) -> Any: submodule: ModuleInterface = self[idx] return submodule.forward(x) m = ModList() self.checkModule(m, (torch.randn(2, 2), 0)) # Test error message thrown when annotated attribute does not comply with the # annotation. class ModWithWrongAnnotation(torch.nn.ModuleList): def __init__(self) -> None: super().__init__() self.l = torch.nn.ModuleList([DoesNotImplementInterface()]) def forward(self, x: torch.Tensor, idx: int) -> Any: submodule: ModuleInterface = self.l[idx] return submodule.forward(x) with self.assertRaisesRegexWithHighlight( RuntimeError, r"Attribute 0 is not of annotated type", "self.l[idx]" ): torch.jit.script(ModWithWrongAnnotation()) def test_module_properties(self): class ModuleWithProperties(torch.nn.Module): __jit_unused_properties__ = ["ignored_attr"] def __init__(self, a: int): super().__init__() self.a = a def forward(self, a: int, b: int): self.attr = a + b return self.attr @property def attr(self): return self.a @property def ignored_attr(self): return sum([self.a]) @torch.jit.unused @property def ignored_attr_2(self): return sum([self.a]) @ignored_attr_2.setter def ignored_attr_2(self, value): self.a = sum([self.a]) @attr.setter def attr(self, a: int): if a > 0: self.a = a else: self.a = 0 class ModuleWithNoSetter(torch.nn.Module): def __init__(self, a: int): super().__init__() self.a = a def forward(self, a: int, b: int): self.attr + a + b @property def attr(self): return self.a + 1 self.checkModule( ModuleWithProperties(5), ( 5, 6, ), ) self.checkModule( ModuleWithProperties(5), ( -5, -6, ), ) self.checkModule( ModuleWithNoSetter(5), ( 5, 6, ), ) self.checkModule( ModuleWithNoSetter(5), ( -5, -6, ), ) mod = ModuleWithProperties(3) scripted_mod = torch.jit.script(mod) with self.assertRaisesRegex(AttributeError, "has no attribute"): scripted_mod.ignored_attr def test_module_inplace_construct(self): class M(nn.Module): def __init__(self, start: int): super().__init__() self.linear = nn.Linear(3, 3) self.attribute = start self.parameter = nn.Parameter(torch.tensor(3, dtype=torch.float)) def method(self) -> int: return self.attribute @torch.jit.unused def unused_method(self): return self.attribute + self.attribute def forward(self, x): return self.linear(self.linear(x)) class N(nn.Module): def __init__(self) -> None: super().__init__() self.linear = nn.Linear(4, 4) @torch.jit.ignore def ignored_method(self, x): return x def forward(self, x): return self.linear(x) m = torch.jit.script(M(3)) n = torch.jit.script(N()) n._reconstruct(m._c) inp = torch.rand((3)) # Check that both modules produce the same output. with torch.no_grad(): m_out = m(inp) n_out = n(inp) self.assertEqual(m_out, n_out) # Check that ignored method is still intact. self.assertEqual(inp, n.ignored_method(inp)) def test_parameterlist_script_getitem(self): class MyModule(nn.Module): def __init__(self) -> None: super().__init__() self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)]) self.parameter_list = nn.ParameterList( [nn.Parameter(torch.zeros(1)) for _ in range(10)] ) def forward(self, x): self.module_list[0] self.parameter_list[0] return x self.checkModule(MyModule(), (torch.zeros(1))) def test_parameterlist_script_iter(self): class MyModule(nn.Module): def __init__(self) -> None: super().__init__() self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)]) self.parameter_list = nn.ParameterList( [nn.Parameter(torch.zeros(1)) for _ in range(10)] ) def forward(self, x): r = x for i, p in enumerate(self.parameter_list): r = r + p + i return r self.checkModule(MyModule(), (torch.zeros(1),)) def test_parameterdict_script_getitem(self): class MyModule(nn.Module): def __init__(self) -> None: super().__init__() self.parameter_dict = nn.ParameterDict( {k: nn.Parameter(torch.zeros(1)) for k in ["a", "b", "c"]} ) def forward(self, x): return ( self.parameter_dict["a"] * x + self.parameter_dict["b"] * self.parameter_dict["c"] ) self.checkModule(MyModule(), (torch.ones(1),))