#
#  Copyright (c) 2024 Apple Inc. All rights reserved.
#  Provided subject to the LICENSE file in the top level directory.
#

import inspect
import unittest

import torch
from executorch.backends.apple.mps.test.test_mps_utils import TestMPS


class TestMPSIndexingOps(TestMPS):
    """Indexing-op lowering tests for the MPS partitioner.

    Every test wraps a single indexing expression (advanced-index get,
    in-place index put, basic slicing, or ``slice_scatter``) in a small
    ``torch.nn.Module`` and passes it, together with sample inputs, to
    ``lower_and_test_with_partitioner`` inherited from ``TestMPS``.
    """

    def _lower_and_test(self, module, model_inputs):
        """Run ``lower_and_test_with_partitioner`` named after the calling test.

        ``func_name`` is the caller's function name minus its first five
        characters (the ``test_`` prefix), i.e. the same value as
        ``inspect.stack()[1].function[5:]``. The caller's frame is read
        directly because ``inspect.stack()`` builds a ``FrameInfo`` (including
        source-context lines) for every frame on the stack.
        """
        # The frame object is used inline and never stored in a local, so no
        # frame <-> locals reference cycle is created.
        caller_name = inspect.currentframe().f_back.f_code.co_name
        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=caller_name[len("test_") :]
        )

    def test_mps_indexing_get_1(self):
        """Index dims 0 and 1 of a (3, 2) tensor with paired index lists."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[[0, 1, 2], [0, 1, 0]]

        model_inputs = (torch.tensor([[1, 2], [3, 4], [5, 6]]),)
        self._lower_and_test(IndexGet(), model_inputs)

    def test_mps_indexing_get_2(self):
        """Keep dim 0 and gather dim 1 of a (3, 2) tensor with a repeated index."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[:, [0, 1, 0]]

        model_inputs = (torch.tensor([[1, 2], [3, 4], [5, 6]]),)
        self._lower_and_test(IndexGet(), model_inputs)

    def test_mps_indexing_get_3(self):
        """Keep dim 0 and index dims 1 and 2 of a (1, 3, 2) tensor."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[:, [0, 1, 0], [0, 1, 0]]

        model_inputs = (torch.tensor([[[1, 2], [3, 4], [5, 6]]]),)
        self._lower_and_test(IndexGet(), model_inputs)

    def test_mps_indexing_get_4(self):
        """Keep dim 0 and index dims 1 and 2 of a (2, 3, 2) tensor."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[:, [0, 1, 0], [0, 1, 0]]

        model_inputs = (
            torch.tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]),
        )
        self._lower_and_test(IndexGet(), model_inputs)

    def test_mps_indexing_get_5(self):
        """Gather dim 1 of a (5, 7, 3) tensor with an unsorted index list."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[:, [0, 4, 2]]

        model_inputs = (torch.randn(5, 7, 3),)
        self._lower_and_test(IndexGet(), model_inputs)

    def test_mps_indexing_get_6(self):
        """Gather dim 1 of a (5, 7, 3) tensor with a nested (2, 2) index list."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[:, [[0, 1], [4, 3]]]

        model_inputs = (torch.randn(5, 7, 3),)
        self._lower_and_test(IndexGet(), model_inputs)

    def test_mps_indexing_get_7(self):
        """Gather dim 0 of a (5, 7, 3) tensor with an index list."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[[0, 4, 2]]

        model_inputs = (torch.randn(5, 7, 3),)
        self._lower_and_test(IndexGet(), model_inputs)

    def test_mps_indexing_get_8(self):
        """Combine an index list (dim 0), a full slice (dim 1) and an int (dim 2)."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[[0, 2, 1], :, 0]

        model_inputs = (torch.ones(3, 2, 4),)
        self._lower_and_test(IndexGet(), model_inputs)

    def test_mps_indices2d(self):
        """Index with runtime (2, 2) row/column tensors passed as model inputs.

        On the 4x3 ``arange`` matrix this selects the four corner elements.
        """

        class IndexGet(torch.nn.Module):
            def forward(self, x, rows, columns):
                return x[rows, columns]

        # Non-inplace ``Tensor.resize`` is deprecated; ``reshape`` produces the
        # same 4x3 matrix from the contiguous arange.
        x = torch.arange(0, 12).reshape(4, 3)
        rows = torch.tensor([[0, 0], [3, 3]])
        columns = torch.tensor([[0, 2], [0, 2]])
        model_inputs = (x, rows, columns)
        self._lower_and_test(IndexGet(), model_inputs)

    def test_mps_slicing_using_advanced_index_for_column_0(self):
        """Basic row slice ``x[1:4]`` (despite the name, no advanced index is used)."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[1:4]

        model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),)
        self._lower_and_test(IndexGet(), model_inputs)

    def test_mps_slicing_using_advanced_index_for_column_1(self):
        """Row slice combined with an advanced index on the column dim."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                # using advanced index for column
                return x[1:4, [1, 2]]

        model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),)
        self._lower_and_test(IndexGet(), model_inputs)

    # NOTE(review): no reason was recorded for disabling this test. Boolean-mask
    # indexing returns a tensor whose size depends on the data (the number of
    # True elements), which is the presumed obstacle -- confirm before enabling.
    @unittest.skip("boolean-mask indexing (data-dependent output shape) is disabled")
    def test_boolean_array_indexing(self):
        """Select the elements of a (4, 3) tensor that are greater than 5."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[x > 5]

        model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),)
        self._lower_and_test(IndexGet(), model_inputs)

    def test_mps_indexing_put_1(self):
        """Write ``z`` into ``x[:, :, y]`` in place and return ``x``.

        With ``x`` of shape (1, 8, 128, 8) and ``y = [1]`` the written region
        is (1, 8, 1, 8); ``z`` of shape (8, 1, 8) is broadcast into it.
        """

        class IndexPut(torch.nn.Module):
            def forward(self, x, y, z):
                x[:, :, y] = z
                return x

        target = torch.ones(1, 8, 128, 8)
        indices = torch.tensor([1])
        values = torch.randn(8, 1, 8)
        model_inputs = (target, indices, values)
        self._lower_and_test(IndexPut(), model_inputs)

    def test_mps_indexing_slice_scatter_1(self):
        """Embed a (2, 8) ``src`` into rows 6 onward (default dim 0) of an (8, 8) tensor."""

        class IndexSliceScatter(torch.nn.Module):
            def forward(self, x, y):
                return x.slice_scatter(y, start=6)

        base = torch.zeros(8, 8)
        src = torch.ones(2, 8)
        model_inputs = (base, src)
        self._lower_and_test(IndexSliceScatter(), model_inputs)

    def test_mps_indexing_slice_scatter_2(self):
        """Embed an (8, 2) ``src`` into columns 2 and 4 (``2:6:2`` on dim 1)."""

        class IndexSliceScatter(torch.nn.Module):
            def forward(self, x, y):
                return x.slice_scatter(y, dim=1, start=2, end=6, step=2)

        base = torch.zeros(8, 8)
        src = torch.ones(8, 2)
        model_inputs = (base, src)
        self._lower_and_test(IndexSliceScatter(), model_inputs)
