from typing import Any import torch @torch.jit.script class MyScriptClass: """Intended to be scripted.""" def __init__(self, x): self.foo = x def set_foo(self, x): self.foo = x @torch.jit.script def uses_script_class(x): """Intended to be scripted.""" foo = MyScriptClass(x) return foo.foo class IdListFeature: def __init__(self) -> None: self.id_list = torch.ones(1, 1) def returns_self(self) -> "IdListFeature": return IdListFeature() class UsesIdListFeature(torch.nn.Module): def forward(self, feature: Any): if isinstance(feature, IdListFeature): return feature.id_list else: return feature