# Owner(s): ["oncall: jit"] import os import sys from typing import List import torch # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from torch.testing._internal.jit_utils import JitTestCase if __name__ == "__main__": raise RuntimeError( "This test file is not meant to be run directly, use:\n\n" "\tpython test/test_jit.py TESTNAME\n\n" "instead." ) # Tests that Python slice class is supported in TorchScript class TestSlice(JitTestCase): def test_slice_kwarg(self): def slice_kwarg(x: List[int]): return x[slice(1, stop=2)] with self.assertRaisesRegex( RuntimeError, "Slice does not accept any keyword arguments" ): torch.jit.script(slice_kwarg) def test_slice_three_nones(self): def three_nones(x: List[int]): return x[slice(None, None, None)] self.checkScript(three_nones, (range(10),)) def test_slice_two_nones(self): def two_nones(x: List[int]): return x[slice(None, None)] self.checkScript(two_nones, (range(10),)) def test_slice_one_none(self): def one_none(x: List[int]): return x[slice(None)] self.checkScript(one_none, (range(10),)) def test_slice_stop_only(self): def fn(x: List[int]): return x[slice(5)] self.checkScript(fn, (range(10),)) def test_slice_stop_only_with_nones(self): def fn(x: List[int]): return x[slice(None, 5, None)] self.checkScript(fn, (range(10),)) def test_slice_start_stop(self): def fn(x: List[int]): return x[slice(1, 5)] self.checkScript(fn, (range(10),)) def test_slice_start_stop_with_none(self): def fn(x: List[int]): return x[slice(1, 5, None)] self.checkScript(fn, (range(10),)) def test_slice_start_stop_step(self): def fn(x: List[int]): return x[slice(0, 6, 2)] self.checkScript(fn, (range(10),)) def test_slice_string(self): def fn(x: str): return x[slice(None, 3, 1)] self.checkScript(fn, ("foo_bar",)) def test_slice_tensor(self): def fn(x: torch.Tensor): return x[slice(None, 3, 1)] self.checkScript(fn, (torch.ones(10),)) def test_slice_tensor_multidim(self): def fn(x: torch.Tensor): return x[slice(None, 3, 1), 0] self.checkScript(fn, (torch.ones((10, 10)),)) def test_slice_tensor_multidim_with_dots(self): def fn(x: torch.Tensor): return x[slice(None, 3, 1), ...] self.checkScript(fn, (torch.ones((10, 10)),)) def test_slice_as_variable(self): def fn(x: List[int]): a = slice(1) return x[a] self.checkScript(fn, (range(10),)) def test_slice_stop_clipped(self): def fn(x: List[int]): return x[slice(1000)] self.checkScript(fn, (range(10),)) def test_slice_dynamic_index(self): def t(x): slice1 = x[0:1] zero = 0 one = zero + 1 slice2 = x[zero:one] return slice1 + slice2 self.checkScript(t, (torch.zeros(3, 2, 3),)) def test_tuple_slicing(self): def tuple_slice(a): if bool(a): b = (1, 2, 3, 4) else: b = (4, 3, 2, 1) c = b[-4:4] e = c[1:-1] return e self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True) scripted_fn = torch.jit.script(tuple_slice) self.assertEqual(scripted_fn(torch.tensor(1)), (2, 3)) tuple_graph = scripted_fn.graph slices = tuple_graph.findAllNodes("prim::TupleConstruct") num_outputs = {len(x.output().type().elements()) for x in slices} # there should be only one tupleSlice with length of 2 self.assertTrue(num_outputs == {2}) self.run_pass("lower_all_tuples", tuple_graph) self.assertTrue("Tuple" not in str(tuple_graph)) def test_module_list_slicing(self): class Bar(torch.nn.Module): def __init__(self, identifier: str): super().__init__() self.identifier = identifier def forward(self): return 0 class Foo(torch.nn.Module): def __init__(self) -> None: super().__init__() module_list = [Bar("A"), Bar("B"), Bar("C"), Bar("D"), Bar("E")] self.test = torch.nn.ModuleList(module_list) def forward(self): return self.test[::-2], self.test[1:4:] scripted_foo = torch.jit.script(Foo()) result1, result2 = scripted_foo() self.assertEqual(len(result1), 3) self.assertEqual(result1[0].identifier, "E") self.assertEqual(result1[1].identifier, "C") self.assertEqual(result1[2].identifier, "A") self.assertEqual(len(result2), 3) self.assertEqual(result2[0].identifier, "B") self.assertEqual(result2[1].identifier, "C") self.assertEqual(result2[2].identifier, "D")