#
# Copyright (c) 2024 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

import inspect

import torch
from executorch.backends.apple.mps.test.test_mps_utils import TestMPS


class TestMPSIndexingOps(TestMPS):
    """Indexing, slicing, index-put and slice-scatter test cases.

    Every test builds a small ``torch.nn.Module`` plus example inputs and hands
    them to ``self.lower_and_test_with_partitioner`` (inherited from
    ``TestMPS``).

    ``func_name`` is always ``inspect.stack()[0].function[5:]``, i.e. the name
    of the running test method with its first five characters (``test_``)
    removed. The expression must stay inline in each test: moving it into a
    helper would make frame 0 the helper instead of the test.

    NOTE(review): presumably the inherited helper lowers the module to the MPS
    delegate and compares against eager execution -- confirm in ``TestMPS``.
    """

    def test_mps_indexing_get_1(self):
        """Index a 3x2 integer tensor with one index list per dimension."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[[0, 1, 2], [0, 1, 0]]

        module = IndexGet()
        model_inputs = (torch.tensor([[1, 2], [3, 4], [5, 6]]),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_indexing_get_2(self):
        """Full slice on dim 0 combined with an index list on dim 1."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[:, [0, 1, 0]]

        module = IndexGet()
        model_inputs = (torch.tensor([[1, 2], [3, 4], [5, 6]]),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_indexing_get_3(self):
        """Full slice on dim 0, index lists on dims 1 and 2 (1x3x2 input)."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[:, [0, 1, 0], [0, 1, 0]]

        module = IndexGet()
        model_inputs = (torch.tensor([[[1, 2], [3, 4], [5, 6]]]),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_indexing_get_4(self):
        """Same indexing expression as ``get_3`` but on a 2x3x2 input."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[:, [0, 1, 0], [0, 1, 0]]

        module = IndexGet()
        model_inputs = (
            torch.tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]),
        )

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_indexing_get_5(self):
        """Index list on dim 1 of a random (5, 7, 3) tensor."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[:, [0, 4, 2]]

        module = IndexGet()
        model_inputs = (torch.randn(5, 7, 3),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_indexing_get_6(self):
        """Nested (2-D) index list on dim 1 of a random (5, 7, 3) tensor."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[:, [[0, 1], [4, 3]]]

        module = IndexGet()
        model_inputs = (torch.randn(5, 7, 3),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_indexing_get_7(self):
        """Index list on dim 0 only of a random (5, 7, 3) tensor."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[[0, 4, 2]]

        module = IndexGet()
        model_inputs = (torch.randn(5, 7, 3),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_indexing_get_8(self):
        """Index list on dim 0, full slice on dim 1, integer index on dim 2."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[[0, 2, 1], :, 0]

        module = IndexGet()
        model_inputs = (torch.ones(3, 2, 4),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_indices2d(self):
        """Index a 4x3 tensor with 2x2 row and column index tensors."""

        class IndexGet(torch.nn.Module):
            def forward(self, x, rows, columns):
                return x[rows, columns]

        module = IndexGet()
        # `reshape` replaces the deprecated non-in-place `Tensor.resize`; for a
        # freshly created contiguous `arange` both yield the same 4x3 tensor.
        x = torch.arange(0, 12).reshape(4, 3)
        rows = torch.tensor([[0, 0], [3, 3]])
        columns = torch.tensor([[0, 2], [0, 2]])
        model_inputs = (
            x,
            rows,
            columns,
        )

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_slicing_using_advanced_index_for_column_0(self):
        """Plain row slice ``x[1:4]`` of a 4x3 tensor (no index list here)."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                return x[1:4]

        module = IndexGet()
        model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_slicing_using_advanced_index_for_column_1(self):
        """Row slice ``1:4`` combined with an index list on the columns."""

        class IndexGet(torch.nn.Module):
            def forward(self, x):
                # using advanced index for column
                return x[1:4, [1, 2]]

        module = IndexGet()
        model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    # NOTE(review): the boolean-mask indexing test below is disabled in the
    # original file without a stated reason -- presumably unsupported by the
    # backend at the time; confirm before re-enabling or deleting.
    # def test_boolean_array_indexing(self):
    #     class IndexGet(torch.nn.Module):
    #         def __init__(self):
    #             super().__init__()

    #         def forward(self, x):
    #             return x[x > 5]

    #     module = IndexGet()
    #     model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),)

    #     self.lower_and_test_with_partitioner(
    #         module, model_inputs, func_name=inspect.stack()[0].function[5:]
    #     )

    def test_mps_indexing_put_1(self):
        """In-place indexed assignment ``x[:, :, y] = z``.

        Inputs: ``x`` is ones of shape (1, 8, 128, 8), ``y`` is the index
        tensor ``[1]`` and ``z`` is random with shape (8, 1, 8).
        """

        class IndexPut(torch.nn.Module):
            def forward(self, x, y, z):
                x[:, :, y] = z
                return x

        module = IndexPut()
        # Named `base` rather than `input` to avoid shadowing the builtin.
        base = torch.ones(1, 8, 128, 8)
        indices = torch.tensor([1])
        values = torch.randn(8, 1, 8)
        model_inputs = (
            base,
            indices,
            values,
        )

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_indexing_slice_scatter_1(self):
        """``slice_scatter`` with only ``start=6``: (2, 8) ones into (8, 8) zeros."""

        class IndexSliceScatter(torch.nn.Module):
            def forward(self, x, y):
                return x.slice_scatter(y, start=6)

        module = IndexSliceScatter()
        # Named `base` rather than `input` to avoid shadowing the builtin.
        base = torch.zeros(8, 8)
        src = torch.ones(2, 8)
        model_inputs = (
            base,
            src,
        )

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_indexing_slice_scatter_2(self):
        """``slice_scatter`` on dim 1 with ``start=2, end=6, step=2``.

        Scatters (8, 2) ones into (8, 8) zeros.
        """

        class IndexSliceScatter(torch.nn.Module):
            def forward(self, x, y):
                return x.slice_scatter(y, dim=1, start=2, end=6, step=2)

        module = IndexSliceScatter()
        # Named `base` rather than `input` to avoid shadowing the builtin.
        base = torch.zeros(8, 8)
        src = torch.ones(8, 2)
        model_inputs = (
            base,
            src,
        )

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )