# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_tensor_setitem """
import numpy as onp
import pytest

from mindspore import Tensor, context
from mindspore.nn import Cell
from mindspore import dtype as mstype


def setup_module():
    """Run every test in this module in PyNative mode."""
    context.set_context(mode=context.PYNATIVE_MODE)


def setup_testcase(input_np, case_fn):
    """Run ``case_fn`` on a MindSpore Tensor and on a NumPy array and compare.

    ``case_fn`` applies a sequence of ``__setitem__`` operations to its
    argument and returns it. It is executed once inside a ``Cell`` on a Tensor
    built from ``input_np`` and once directly on ``input_np`` (the NumPy
    reference). Both results must agree element-wise and in shape.

    Args:
        input_np (numpy.ndarray): Initial data. Mutated in place by the NumPy
            reference run.
        case_fn (Callable): Function that performs the set-item operations and
            returns the updated container.
    """
    # Build the Tensor from a private copy so the in-place NumPy reference run
    # below can never alias the buffer the MindSpore run reads or returns.
    input_ms = Tensor(input_np.copy())

    class TensorSetItem(Cell):
        def construct(self, x):
            return case_fn(x)

    class NumpySetItem():
        def __call__(self, x):
            return case_fn(x)

    out_ms = TensorSetItem()(input_ms)
    out_np = NumpySetItem()(input_np)
    # assert_array_equal rejects shape mismatches; ``onp.all(a == b)`` would
    # broadcast and silently accept e.g. an extra leading unit dimension.
    onp.testing.assert_array_equal(out_ms.asnumpy(), out_np)


class TensorSetItemByList(Cell):
    """Cell applying list-index set-items; mirrors ``NumpySetItemByList``.

    NOTE(review): not referenced by the tests in this part of the file.
    """

    def construct(self, x):
        x[[0, 1], [1, 2], [1, 3]] = [3, 4]
        x[([0, 1], [0, 2], [1, 1])] = [10, 5]
        x[[0, 1], ..., [0, 1]] = 4
        return x


class NumpySetItemByList():
    """NumPy counterpart of ``TensorSetItemByList``."""

    def __call__(self, x):
        x[[0, 1], [1, 2], [1, 3]] = [3, 4]
        x[([0, 1], [0, 2], [1, 1])] = [10, 5]
        x[[0, 1], ..., [0, 1]] = 4
        return x


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_list():
    """Set items using lists / tuples of lists as advanced indices."""
    x = onp.ones((2, 3, 4), dtype=onp.float32)

    def cases(x):
        x[[0, 1], [1, 2], [1, 3]] = [3, 4]
        x[([0, 1], [0, 2], [1, 1])] = [10, 5]
        x[[0, 1], ..., [0, 1]] = 4
        return x
    setup_testcase(x, cases)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_with_sequence():
    """Assign Python lists / tuples (possibly mixed) as values."""
    x = onp.ones((2, 3, 4), dtype=onp.float32)

    def cases(x):
        x[...] = [3]
        x[..., 1] = ([1, 2, 3], [4, 5, 6])
        x[0] = ((0, 1, 2, 3), (4, 5, 6, 7), [8, 9, 10, 11])
        x[1:2] = ((0, 1, 2, 3), (4, 5, 6, 7), [8, 9, 10, 11])
        return x
    setup_testcase(x, cases)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_dtype():
    """Assign int, float and bool values (and mixtures) into a float32 array."""
    x = onp.ones((2, 3, 4), dtype=onp.float32)

    def cases(x):
        x[...] = 3
        x[..., 1] = 3.0
        x[0] = True
        x[1:2] = ((0, False, 2, 3), (4.0, 5, 6, 7), [True, 9, 10, 11])
        return x
    setup_testcase(x, cases)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_tuple_with_int():
    """Tuple indices mixing ints, Ellipsis, bools and None; scalar values."""
    x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32)

    def cases(x):
        x[..., 2, False, 1] = -1
        x[0, True, 0, None, True] = -2
        x[0, ..., None] = -3
        x[..., 0, None, 1, True, True, None] = -4
        return x
    setup_testcase(x, cases)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_tuple_with_list():
    """Same tuple indices as the int case, but with list values."""
    x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32)

    def cases(x):
        x[..., 2, False, 1] = [-1]
        x[0, True, 0, None, True] = [-2, -2, -2, -2]
        x[0, ..., None] = [[-3], [-3], [-3], [-3]]
        x[..., 0, None, 1, True, True, None] = [[[-4]], [[-4]]]
        x[None, True, [1, 0], (False, True, True), [2]] = [[2, 3]]
        return x
    setup_testcase(x, cases)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_nested_unit_list():
    """Advanced indices given as deeply nested single-element lists."""
    x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32)

    def cases(x):
        x[[[[0]]], True] = -1
        x[[1], ..., [[[[2]]]]] = -2
        x[0, [[[2]]], [1]] = -3
        return x
    setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_with_broadcast():
    """Values whose shapes must be broadcast to the indexed region."""
    x = onp.arange(2*3*4*5*6).reshape(2, 3, 4, 5, 6).astype(onp.float32)
    v1 = onp.full((1, 4, 5), -1).tolist()
    v2 = onp.full((4, 1, 6), -2).tolist()

    def cases(x):
        x[..., 4] = v1
        x[0, 2] = v2
        x[1, 0, ..., 3] = [[-3], [-3], [-3], [-3]]
        x[0, ..., 1, 3, 5] = -4
        return x
    setup_testcase(x, cases)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_mul_by_scalar():
    """Assign values computed from other slices of the same array."""
    x = onp.ones((4, 5), dtype=onp.float32)

    def cases(x):
        x[1, :] = x[1, :]*2
        x[:, 2] = x[:, 3]*3.0
        return x
    setup_testcase(x, cases)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_slice():
    """Single slices, including empty, negative-start and reversed ones."""
    x = onp.ones((3, 4, 5), dtype=onp.float32)

    def cases(x):
        x[1:2] = 2
        x[-3:1] = 3
        x[-10:3:2] = 4
        x[5:0:3] = 5
        x[5:5:5] = 6
        x[-1:2] = 7
        x[1:0:-1] = 8
        return x
    setup_testcase(x, cases)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_tuple_of_slices():
    """Tuple indices combining slices with ints, including empty slices."""
    x = onp.ones((3, 4, 5), dtype=onp.float32)

    def cases(x):
        x[1:2, 2] = 2
        x[0, -4:1] = 3
        x[1, -10:3:2] = 4
        x[5:0:3, 3] = 5
        x[1:1, 2:2] = 6
        return x
    setup_testcase(x, cases)


class TensorItemSetWithNumber(Cell):
    """Cell wrapping ``tensor.itemset(value)`` (single-argument form)."""

    def construct(self, tensor, number_value):
        ret = tensor.itemset(number_value)
        return ret


def _np_itemset(array, *args):
    """Set one element of a NumPy array in place, mirroring ``ndarray.itemset``.

    ``ndarray.itemset`` was removed in NumPy 2.0, so the NumPy reference side of
    the itemset tests goes through this helper instead.

    Forms:
        ``_np_itemset(array, value)``: ``array`` must contain exactly one element.
        ``_np_itemset(array, flat_index, value)``: index into the flattened array.
        ``_np_itemset(array, index_tuple, value)``: full multi-dimensional index
            of a single element.

    Raises:
        ValueError: single-argument form used on an array with ``size != 1``.
    """
    if len(args) == 1:
        if array.size != 1:
            raise ValueError("can only set a size-1 array without an index")
        array.flat[0] = args[0]
        return
    index, value = args
    if isinstance(index, tuple):
        array[index] = value
    else:
        array.flat[index] = value


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_itemset_with_number():
    """``Tensor.itemset(value)`` matches NumPy on a one-element tensor."""
    net = TensorItemSetWithNumber()
    # Deterministic one-element input (``onp.ndarray([1])`` would be uninitialized).
    input_1d_np = onp.ones((1,), dtype=onp.float32)
    input_1d_ms = Tensor(input_1d_np, mstype.float32)

    input_3d_np = onp.arange(60).reshape(3, 4, 5).astype(onp.int32)
    input_3d_ms = Tensor(input_3d_np, mstype.float32)

    value_np_1, value_np_2 = 1, 2.0

    output_1d_ms_1 = net(input_1d_ms, value_np_1)
    output_1d_ms_2 = net(input_1d_ms, value_np_2)

    _np_itemset(input_1d_np, value_np_1)
    assert onp.all(output_1d_ms_1.asnumpy() == input_1d_np)
    _np_itemset(input_1d_np, value_np_2)
    assert onp.all(output_1d_ms_2.asnumpy() == input_1d_np)

    # Without an index, itemset is expected to fail on a multi-element tensor.
    with pytest.raises(IndexError):
        net(input_3d_ms, value_np_1)
    with pytest.raises(IndexError):
        net(input_3d_ms, value_np_2)


class TensorItemSetByItemWithNumber(Cell):
    """Cell wrapping ``tensor.itemset(index, value)`` (indexed form)."""

    def construct(self, tensor, index, number_value):
        ret = tensor.itemset(index, number_value)
        return ret


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_dim_expand():
    """Indices that add dimensions (None / bool) combined with advanced indices."""
    x = onp.ones((2, 3, 4), dtype=onp.float32)

    def cases(x):
        x[None, True, [1, 0], (False, True, True), [2]] = 2
        x[([[0]]), ..., [[1]]] = [[[3, 3, 3]]]
        x[0:1] = [[2, 3, 4, 5]]
        x[..., (0, 1, 2), None, :, True, None] = [[[3], [3], [3], [3]]]
        return x
    setup_testcase(x, cases)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_itemset_by_number_with_number():
    """``Tensor.itemset(flat_index, value)`` matches NumPy; bad indices raise."""
    net = TensorItemSetByItemWithNumber()
    # Deterministic one-element input (``onp.ndarray([1])`` would be uninitialized).
    input_1d_np = onp.ones((1,), dtype=onp.float32)
    input_1d_ms = Tensor(input_1d_np, mstype.float32)

    input_3d_np = onp.arange(60).reshape(3, 4, 5).astype(onp.int32)
    input_3d_ms = Tensor(input_3d_np, mstype.float32)

    # index_np_3 == 60 is one past the last flat index of the 60-element tensor;
    # index_np_4 is a float and therefore not a valid index.
    index_np_1, index_np_2, index_np_3, index_np_4 = 0, 30, 60, 2.0
    value_np_1, value_np_2 = 1, 2.0

    output_1d_ms_1 = net(input_1d_ms, index_np_1, value_np_1)
    output_1d_ms_2 = net(input_1d_ms, index_np_1, value_np_2)
    # The 3-D updates are chained, so each output accumulates the previous ones;
    # the NumPy reference below is updated cumulatively to match.
    output_3d_ms_1 = net(input_3d_ms, index_np_1, value_np_1)
    output_3d_ms_2 = net(output_3d_ms_1, index_np_1, value_np_2)
    output_3d_ms_3 = net(output_3d_ms_2, index_np_2, value_np_1)
    output_3d_ms_4 = net(output_3d_ms_3, index_np_2, value_np_2)

    _np_itemset(input_1d_np, index_np_1, value_np_1)
    assert onp.all(output_1d_ms_1.asnumpy() == input_1d_np)
    _np_itemset(input_1d_np, index_np_1, value_np_2)
    assert onp.all(output_1d_ms_2.asnumpy() == input_1d_np)
    _np_itemset(input_3d_np, index_np_1, value_np_1)
    assert onp.all(output_3d_ms_1.asnumpy() == input_3d_np)
    _np_itemset(input_3d_np, index_np_1, value_np_2)
    assert onp.all(output_3d_ms_2.asnumpy() == input_3d_np)
    _np_itemset(input_3d_np, index_np_2, value_np_1)
    assert onp.all(output_3d_ms_3.asnumpy() == input_3d_np)
    _np_itemset(input_3d_np, index_np_2, value_np_2)
    assert onp.all(output_3d_ms_4.asnumpy() == input_3d_np)

    with pytest.raises(IndexError):
        net(input_1d_ms, index_np_2, value_np_1)
    with pytest.raises(IndexError):
        net(input_1d_ms, index_np_2, value_np_2)
    with pytest.raises(TypeError):
        net(input_1d_ms, index_np_4, value_np_1)
    with pytest.raises(TypeError):
        net(input_1d_ms, index_np_4, value_np_2)
    with pytest.raises(IndexError):
        net(input_3d_ms, index_np_3, value_np_1)
    with pytest.raises(IndexError):
        net(input_3d_ms, index_np_3, value_np_2)
    with pytest.raises(TypeError):
        net(input_3d_ms, index_np_4, value_np_1)
    with pytest.raises(TypeError):
        net(input_3d_ms, index_np_4, value_np_2)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_itemset_by_tuple_with_number():
    """``Tensor.itemset(index_tuple, value)`` matches NumPy; bad tuples raise."""
    net = TensorItemSetByItemWithNumber()
    # Deterministic one-element input (``onp.ndarray([1])`` would be uninitialized).
    input_1d_np = onp.ones((1,), dtype=onp.float32)
    input_1d_ms = Tensor(input_1d_np, mstype.float32)

    input_3d_np = onp.arange(60).reshape(3, 4, 5).astype(onp.int32)
    input_3d_ms = Tensor(input_3d_np, mstype.float32)

    # Valid: (0,) for the 1-D tensor, (1, 1, 0) for the 3-D tensor.
    # Invalid: wrong tuple length for the tensor's rank, or out-of-range (3, 4, 5).
    index_np_1, index_np_2, index_np_3, index_np_4, index_np_5 = (0,), (1, 2), (1, 1, 0), (3, 4, 5), (1, 2, 3, 4)
    value_np_1, value_np_2 = 1, 2.0

    output_1d_ms_1 = net(input_1d_ms, index_np_1, value_np_1)
    _np_itemset(input_1d_np, index_np_1, value_np_1)
    assert onp.all(output_1d_ms_1.asnumpy() == input_1d_np)

    output_1d_ms_2 = net(input_1d_ms, index_np_1, value_np_2)
    _np_itemset(input_1d_np, index_np_1, value_np_2)
    assert onp.all(output_1d_ms_2.asnumpy() == input_1d_np)

    output_3d_ms_1 = net(input_3d_ms, index_np_3, value_np_1)
    _np_itemset(input_3d_np, index_np_3, value_np_1)
    assert onp.all(output_3d_ms_1.asnumpy() == input_3d_np)

    output_3d_ms_2 = net(input_3d_ms, index_np_3, value_np_2)
    _np_itemset(input_3d_np, index_np_3, value_np_2)
    assert onp.all(output_3d_ms_2.asnumpy() == input_3d_np)

    with pytest.raises(IndexError):
        net(input_1d_ms, index_np_2, value_np_1)
    with pytest.raises(IndexError):
        net(input_1d_ms, index_np_2, value_np_2)
    with pytest.raises(IndexError):
        net(input_3d_ms, index_np_1, value_np_1)
    with pytest.raises(IndexError):
        net(input_3d_ms, index_np_1, value_np_2)
    with pytest.raises(IndexError):
        net(input_3d_ms, index_np_2, value_np_1)
    with pytest.raises(IndexError):
        net(input_3d_ms, index_np_2, value_np_2)
    with pytest.raises(IndexError):
        net(input_3d_ms, index_np_4, value_np_1)
    with pytest.raises(IndexError):
        net(input_3d_ms, index_np_4, value_np_2)
    with pytest.raises(IndexError):
        net(input_3d_ms, index_np_5, value_np_1)
    with pytest.raises(IndexError):
        net(input_3d_ms, index_np_5, value_np_2)