• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Owner(s): ["module: pt2-dispatcher"]
2from __future__ import annotations
3
4import typing
5from typing import List, Optional, Sequence, Union  # noqa: F401
6
7import torch
8from torch import Tensor, types
9from torch.testing._internal.common_utils import run_tests, TestCase
10
11
# None of the prototype functions in this module mutate their inputs, so every
# infer_schema call below is given an empty collection of mutated-argument
# names. An immutable empty tuple is used instead of ``{}``, which is a mutable
# dict despite looking like an empty set.
mutates_args = ()
13
14
class TestInferSchemaWithAnnotation(TestCase):
    """Check the schema strings ``torch.library.infer_schema`` derives from annotations.

    This module enables ``from __future__ import annotations``, so each
    annotation on the local prototype functions below is kept as a string
    rather than evaluated at definition time. The typing names those strings
    mention (``typing``, ``List``, ``Optional``, ``Tensor``, ``types``, ...)
    are imported at module level. Presumably ``infer_schema`` resolves them
    there; confirm against ``torch.library`` if this changes.
    """

    def _assert_schema(self, prototype, expected: str) -> None:
        """Assert that ``infer_schema(prototype)`` returns exactly ``expected``."""
        result = torch.library.infer_schema(prototype, mutates_args=mutates_args)
        self.assertEqual(result, expected)

    def test_tensor(self):
        def foo_op(x: torch.Tensor) -> torch.Tensor:
            return x.clone()

        self._assert_schema(foo_op, "(Tensor x) -> Tensor")

        def foo_op_2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x.clone() + y

        self._assert_schema(foo_op_2, "(Tensor x, Tensor y) -> Tensor")

    def test_native_types(self):
        # Python ``int`` is expected to map to ``SymInt`` (not ``int``), both
        # as an argument and as a return type.
        def foo_op(x: int) -> int:
            return x

        self._assert_schema(foo_op, "(SymInt x) -> SymInt")

        def foo_op_2(x: bool) -> bool:
            return x

        self._assert_schema(foo_op_2, "(bool x) -> bool")

        def foo_op_3(x: str) -> int:
            return 1

        self._assert_schema(foo_op_3, "(str x) -> SymInt")

        def foo_op_4(x: float) -> float:
            return x

        self._assert_schema(foo_op_4, "(float x) -> float")

    def test_torch_types(self):
        def foo_op_1(x: torch.types.Number) -> torch.types.Number:
            return x

        self._assert_schema(foo_op_1, "(Scalar x) -> Scalar")

        def foo_op_2(x: torch.dtype) -> int:
            return 1

        self._assert_schema(foo_op_2, "(ScalarType x) -> SymInt")

        def foo_op_3(x: torch.device) -> int:
            return 1

        self._assert_schema(foo_op_3, "(Device x) -> SymInt")

    def test_type_variants(self):
        # Annotations spelled with the ``typing.`` module prefix.
        def foo_op_1(x: typing.Optional[int]) -> int:
            return 1

        self._assert_schema(foo_op_1, "(SymInt? x) -> SymInt")

        def foo_op_2(x: typing.Sequence[int]) -> int:
            return 1

        self._assert_schema(foo_op_2, "(SymInt[] x) -> SymInt")

        def foo_op_3(x: typing.List[int]) -> int:
            return 1

        self._assert_schema(foo_op_3, "(SymInt[] x) -> SymInt")

        def foo_op_4(x: typing.Optional[typing.Sequence[int]]) -> int:
            return 1

        self._assert_schema(foo_op_4, "(SymInt[]? x) -> SymInt")

        def foo_op_5(x: typing.Optional[typing.List[int]]) -> int:
            return 1

        self._assert_schema(foo_op_5, "(SymInt[]? x) -> SymInt")

        # The order of members inside the Union must not affect the result.
        def foo_op_6(x: typing.Union[int, float, bool]) -> types.Number:
            return x

        self._assert_schema(foo_op_6, "(Scalar x) -> Scalar")

        def foo_op_7(x: typing.Union[int, bool, float]) -> types.Number:
            return x

        self._assert_schema(foo_op_7, "(Scalar x) -> Scalar")

    def test_no_library_prefix(self):
        # Same coverage as the tests above, but using names imported directly
        # (``Tensor``, ``List``, ``Optional``, ...) instead of ``torch.`` /
        # ``typing.``-qualified forms, including mixed spellings.
        def foo_op(x: Tensor) -> Tensor:
            return x.clone()

        self._assert_schema(foo_op, "(Tensor x) -> Tensor")

        def foo_op_2(x: Tensor) -> torch.Tensor:
            return x.clone()

        self._assert_schema(foo_op_2, "(Tensor x) -> Tensor")

        def foo_op_3(x: torch.Tensor) -> Tensor:
            return x.clone()

        self._assert_schema(foo_op_3, "(Tensor x) -> Tensor")

        def foo_op_4(x: List[int]) -> types.Number:
            return x[0]

        self._assert_schema(foo_op_4, "(SymInt[] x) -> Scalar")

        def foo_op_5(x: Optional[int]) -> int:
            return 1

        self._assert_schema(foo_op_5, "(SymInt? x) -> SymInt")

        def foo_op_6(x: Sequence[int]) -> int:
            return 1

        self._assert_schema(foo_op_6, "(SymInt[] x) -> SymInt")

        def foo_op_7(x: List[int]) -> int:
            return 1

        self._assert_schema(foo_op_7, "(SymInt[] x) -> SymInt")

        def foo_op_8(x: Optional[Sequence[int]]) -> int:
            return 1

        self._assert_schema(foo_op_8, "(SymInt[]? x) -> SymInt")

        def foo_op_9(x: Optional[List[int]]) -> int:
            return 1

        self._assert_schema(foo_op_9, "(SymInt[]? x) -> SymInt")

        def foo_op_10(x: Union[int, float, bool]) -> types.Number:
            return x

        self._assert_schema(foo_op_10, "(Scalar x) -> Scalar")

        def foo_op_11(x: Union[int, bool, float]) -> types.Number:
            return x

        self._assert_schema(foo_op_11, "(Scalar x) -> Scalar")

    def test_unsupported_annotation(self):
        # ``D`` and ``E`` are deliberately undefined names. Because annotations
        # are postponed in this module, defining these functions succeeds; only
        # the ``infer_schema`` call is expected to raise, so it alone sits
        # inside each ``assertRaisesRegex`` context. The dots in the patterns
        # are escaped so they match literal periods only.
        def foo_op(x: D) -> Tensor:  # noqa: F821
            return torch.Tensor(x)

        with self.assertRaisesRegex(
            ValueError,
            r"Unsupported type annotation D\. It is not a type\.",
        ):
            torch.library.infer_schema(foo_op, mutates_args=mutates_args)

        def foo_op_2(x: Tensor) -> E:  # noqa: F821
            return x

        with self.assertRaisesRegex(
            ValueError,
            r"Unsupported type annotation E\. It is not a type\.",
        ):
            torch.library.infer_schema(foo_op_2, mutates_args=mutates_args)
204
if __name__ == "__main__":
    # Executed only when this file is run directly (not on import); hands
    # control to the run_tests helper imported from common_utils.
    run_tests()
207