1# Owner(s): ["module: pt2-dispatcher"] 2from __future__ import annotations 3 4import typing 5from typing import List, Optional, Sequence, Union # noqa: F401 6 7import torch 8from torch import Tensor, types 9from torch.testing._internal.common_utils import run_tests, TestCase 10 11 12mutates_args = {} 13 14 15class TestInferSchemaWithAnnotation(TestCase): 16 def test_tensor(self): 17 def foo_op(x: torch.Tensor) -> torch.Tensor: 18 return x.clone() 19 20 result = torch.library.infer_schema(foo_op, mutates_args=mutates_args) 21 self.assertEqual(result, "(Tensor x) -> Tensor") 22 23 def foo_op_2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 24 return x.clone() + y 25 26 result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args) 27 self.assertEqual(result, "(Tensor x, Tensor y) -> Tensor") 28 29 def test_native_types(self): 30 def foo_op(x: int) -> int: 31 return x 32 33 result = torch.library.infer_schema(foo_op, mutates_args=mutates_args) 34 self.assertEqual(result, "(SymInt x) -> SymInt") 35 36 def foo_op_2(x: bool) -> bool: 37 return x 38 39 result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args) 40 self.assertEqual(result, "(bool x) -> bool") 41 42 def foo_op_3(x: str) -> int: 43 return 1 44 45 result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args) 46 self.assertEqual(result, "(str x) -> SymInt") 47 48 def foo_op_4(x: float) -> float: 49 return x 50 51 result = torch.library.infer_schema(foo_op_4, mutates_args=mutates_args) 52 self.assertEqual(result, "(float x) -> float") 53 54 def test_torch_types(self): 55 def foo_op_1(x: torch.types.Number) -> torch.types.Number: 56 return x 57 58 result = torch.library.infer_schema(foo_op_1, mutates_args=mutates_args) 59 self.assertEqual(result, "(Scalar x) -> Scalar") 60 61 def foo_op_2(x: torch.dtype) -> int: 62 return 1 63 64 result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args) 65 self.assertEqual(result, "(ScalarType x) -> SymInt") 66 67 def foo_op_3(x: torch.device) -> int: 68 return 1 69 70 result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args) 71 self.assertEqual(result, "(Device x) -> SymInt") 72 73 def test_type_variants(self): 74 def foo_op_1(x: typing.Optional[int]) -> int: 75 return 1 76 77 result = torch.library.infer_schema(foo_op_1, mutates_args=mutates_args) 78 self.assertEqual(result, "(SymInt? x) -> SymInt") 79 80 def foo_op_2(x: typing.Sequence[int]) -> int: 81 return 1 82 83 result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args) 84 self.assertEqual(result, "(SymInt[] x) -> SymInt") 85 86 def foo_op_3(x: typing.List[int]) -> int: 87 return 1 88 89 result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args) 90 self.assertEqual(result, "(SymInt[] x) -> SymInt") 91 92 def foo_op_4(x: typing.Optional[typing.Sequence[int]]) -> int: 93 return 1 94 95 result = torch.library.infer_schema(foo_op_4, mutates_args=mutates_args) 96 self.assertEqual(result, "(SymInt[]? x) -> SymInt") 97 98 def foo_op_5(x: typing.Optional[typing.List[int]]) -> int: 99 return 1 100 101 result = torch.library.infer_schema(foo_op_5, mutates_args=mutates_args) 102 self.assertEqual(result, "(SymInt[]? x) -> SymInt") 103 104 def foo_op_6(x: typing.Union[int, float, bool]) -> types.Number: 105 return x 106 107 result = torch.library.infer_schema(foo_op_6, mutates_args=mutates_args) 108 self.assertEqual(result, "(Scalar x) -> Scalar") 109 110 def foo_op_7(x: typing.Union[int, bool, float]) -> types.Number: 111 return x 112 113 result = torch.library.infer_schema(foo_op_7, mutates_args=mutates_args) 114 self.assertEqual(result, "(Scalar x) -> Scalar") 115 116 def test_no_library_prefix(self): 117 def foo_op(x: Tensor) -> Tensor: 118 return x.clone() 119 120 result = torch.library.infer_schema(foo_op, mutates_args=mutates_args) 121 self.assertEqual(result, "(Tensor x) -> Tensor") 122 123 def foo_op_2(x: Tensor) -> torch.Tensor: 124 return x.clone() 125 126 result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args) 127 self.assertEqual(result, "(Tensor x) -> Tensor") 128 129 def foo_op_3(x: torch.Tensor) -> Tensor: 130 return x.clone() 131 132 result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args) 133 self.assertEqual(result, "(Tensor x) -> Tensor") 134 135 def foo_op_4(x: List[int]) -> types.Number: 136 return x[0] 137 138 result = torch.library.infer_schema(foo_op_4, mutates_args=mutates_args) 139 self.assertEqual(result, "(SymInt[] x) -> Scalar") 140 141 def foo_op_5(x: Optional[int]) -> int: 142 return 1 143 144 result = torch.library.infer_schema(foo_op_5, mutates_args=mutates_args) 145 self.assertEqual(result, "(SymInt? x) -> SymInt") 146 147 def foo_op_6(x: Sequence[int]) -> int: 148 return 1 149 150 result = torch.library.infer_schema(foo_op_6, mutates_args=mutates_args) 151 self.assertEqual(result, "(SymInt[] x) -> SymInt") 152 153 def foo_op_7(x: List[int]) -> int: 154 return 1 155 156 result = torch.library.infer_schema(foo_op_7, mutates_args=mutates_args) 157 self.assertEqual(result, "(SymInt[] x) -> SymInt") 158 159 def foo_op_8(x: Optional[Sequence[int]]) -> int: 160 return 1 161 162 result = torch.library.infer_schema(foo_op_8, mutates_args=mutates_args) 163 self.assertEqual(result, "(SymInt[]? x) -> SymInt") 164 165 def foo_op_9(x: Optional[List[int]]) -> int: 166 return 1 167 168 result = torch.library.infer_schema(foo_op_9, mutates_args=mutates_args) 169 self.assertEqual(result, "(SymInt[]? x) -> SymInt") 170 171 def foo_op_10(x: Union[int, float, bool]) -> types.Number: 172 return x 173 174 result = torch.library.infer_schema(foo_op_10, mutates_args=mutates_args) 175 self.assertEqual(result, "(Scalar x) -> Scalar") 176 177 def foo_op_11(x: Union[int, bool, float]) -> types.Number: 178 return x 179 180 result = torch.library.infer_schema(foo_op_11, mutates_args=mutates_args) 181 self.assertEqual(result, "(Scalar x) -> Scalar") 182 183 def test_unsupported_annotation(self): 184 with self.assertRaisesRegex( 185 ValueError, 186 r"Unsupported type annotation D. It is not a type.", 187 ): 188 189 def foo_op(x: D) -> Tensor: # noqa: F821 190 return torch.Tensor(x) 191 192 torch.library.infer_schema(foo_op, mutates_args=mutates_args) 193 194 with self.assertRaisesRegex( 195 ValueError, 196 r"Unsupported type annotation E. It is not a type.", 197 ): 198 199 def foo_op_2(x: Tensor) -> E: # noqa: F821 200 return x 201 202 torch.library.infer_schema(foo_op_2, mutates_args=mutates_args) 203 204 205if __name__ == "__main__": 206 run_tests() 207