• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" test_tensor_setitem """
16import numpy as onp
17import pytest
18
19from mindspore import Tensor, context
20from mindspore.nn import Cell
21from mindspore import dtype as mstype
22
23
def setup_module():
    """Module-level setup hook: put MindSpore into PyNative execution mode."""
    pynative_mode = context.PYNATIVE_MODE
    context.set_context(mode=pynative_mode)
26
27
def setup_testcase(input_np, case_fn):
    """Run ``case_fn`` on a MindSpore tensor and on a NumPy array and compare.

    Args:
        input_np: NumPy array used as the test input. A ``Tensor`` is built
            from it first; afterwards the array itself is passed to
            ``case_fn``, so in-place assignments done by ``case_fn`` modify
            the caller's array.
        case_fn: Callable taking one array-like argument and returning the
            value to compare.

    Raises:
        AssertionError: If the MindSpore result and the NumPy result differ in
            shape or in any element.
    """
    input_ms = Tensor(input_np)

    class TensorSetItem(Cell):
        """Thin Cell wrapper so ``case_fn`` is executed through ``construct``."""

        def construct(self, x):
            return case_fn(x)

    out_ms = TensorSetItem()(input_ms)
    # The NumPy reference needs no wrapper object; call the case directly.
    out_np = case_fn(input_np)

    # array_equal also rejects mismatched shapes, whereas ``onp.all(a == b)``
    # would broadcast and could accept a result of the wrong shape.
    assert onp.array_equal(out_ms.asnumpy(), out_np)
42
43
class TensorSetItemByList(Cell):
    """Cell whose ``construct`` performs three list-indexed assignments on ``x``."""

    def construct(self, x):
        """Assign into ``x`` with list indices and return ``x``."""
        # Three index lists, one per axis, with a two-element list value.
        x[([0, 1], [1, 2], [1, 3])] = [3, 4]
        # Same kind of index again with different positions and values.
        x[[0, 1], [0, 2], [1, 1]] = [10, 5]
        # Index lists on both sides of an ellipsis, scalar value.
        x[([0, 1], ..., [0, 1])] = 4
        return x
50
51
class NumpySetItemByList:
    """Callable that performs three list-indexed assignments on its argument."""

    def __call__(self, x):
        """Assign into ``x`` in place using list indices and return ``x``."""
        # A comma-separated subscript is a tuple, so the indices can be named.
        per_axis_lists = ([0, 1], [1, 2], [1, 3])
        explicit_tuple = ([0, 1], [0, 2], [1, 1])
        around_ellipsis = ([0, 1], Ellipsis, [0, 1])

        x[per_axis_lists] = [3, 4]
        x[explicit_tuple] = [10, 5]
        x[around_ellipsis] = 4
        return x
58
59
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_list():
    """Compare list-indexed assignment on a (2, 3, 4) float32 array of ones."""
    def assign_by_lists(arr):
        # One index list per axis, two-element list value.
        arr[[0, 1], [1, 2], [1, 3]] = [3, 4]
        # The same kind of index written as an explicit tuple of lists.
        arr[([0, 1], [0, 2], [1, 1])] = [10, 5]
        # Index lists on both sides of an ellipsis, scalar value.
        arr[[0, 1], ..., [0, 1]] = 4
        return arr

    ones = onp.ones((2, 3, 4), dtype=onp.float32)
    setup_testcase(ones, assign_by_lists)
74
75
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_with_sequence():
    """Compare assignment of list/tuple values via ellipsis, int and slice indices."""
    def assign_sequences(arr):
        # One-element list assigned through a bare ellipsis.
        arr[...] = [3]
        # Tuple of two lists assigned at position 1 of the last axis.
        arr[..., 1] = ([1, 2, 3], [4, 5, 6])
        # Mixed tuple/list nesting as the value, first by integer index ...
        arr[0] = ((0, 1, 2, 3), (4, 5, 6, 7), [8, 9, 10, 11])
        # ... then by a one-element slice.
        arr[1:2] = ((0, 1, 2, 3), (4, 5, 6, 7), [8, 9, 10, 11])
        return arr

    ones = onp.ones((2, 3, 4), dtype=onp.float32)
    setup_testcase(ones, assign_sequences)
91
92
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_dtype():
    """Compare assignment of int, float, bool and mixed-type values into float32 data."""
    def assign_mixed_types(arr):
        # Python int value.
        arr[...] = 3
        # Python float value.
        arr[..., 1] = 3.0
        # Python bool value.
        arr[0] = True
        # Nested sequence mixing int, float and bool entries.
        arr[1:2] = ((0, False, 2, 3), (4.0, 5, 6, 7), [True, 9, 10, 11])
        return arr

    ones = onp.ones((2, 3, 4), dtype=onp.float32)
    setup_testcase(ones, assign_mixed_types)
108
109
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_tuple_with_int():
    """Compare scalar assignment through index tuples mixing ints, bools, None and ellipsis."""
    def assign_scalars(arr):
        # Each index tuple combines integers with True/False, None and/or "...".
        arr[..., 2, False, 1] = -1
        arr[0, True, 0, None, True] = -2
        arr[0, ..., None] = -3
        arr[..., 0, None, 1, True, True, None] = -4
        return arr

    counting = onp.arange(24).reshape(2, 3, 4)
    setup_testcase(counting.astype(onp.float32), assign_scalars)
125
126
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_tuple_with_list():
    """Compare list-valued assignment through index tuples mixing ints, bools, None and lists."""
    def assign_lists(arr):
        # Same index tuples as the scalar variant, but with (nested) list values.
        arr[..., 2, False, 1] = [-1]
        arr[0, True, 0, None, True] = [-2, -2, -2, -2]
        arr[0, ..., None] = [[-3], [-3], [-3], [-3]]
        arr[..., 0, None, 1, True, True, None] = [[[-4]], [[-4]]]
        # Index tuple that additionally contains a list, a bool tuple and a one-element list.
        arr[None, True, [1, 0], (False, True, True), [2]] = [[2, 3]]
        return arr

    counting = onp.arange(24).reshape(2, 3, 4)
    setup_testcase(counting.astype(onp.float32), assign_lists)
143
144
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_nested_unit_list():
    """Compare scalar assignment through deeply nested one-element index lists."""
    def assign_with_nested_lists(arr):
        # Triple-nested single index combined with True.
        arr[[[[0]]], True] = -1
        # Flat and quadruple-nested single indices around an ellipsis.
        arr[[1], ..., [[[[2]]]]] = -2
        # Integer, triple-nested and flat single indices together.
        arr[0, [[[2]]], [1]] = -3
        return arr

    counting = onp.arange(24).reshape(2, 3, 4)
    setup_testcase(counting.astype(onp.float32), assign_with_nested_lists)
159
160
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_with_broadcast():
    """Compare assignments whose values have fewer or size-1 dimensions than the target."""
    # Nested-list values of shape (1, 4, 5) and (4, 1, 6), captured by the case below.
    minus_ones = onp.full((1, 4, 5), -1).tolist()
    minus_twos = onp.full((4, 1, 6), -2).tolist()

    def assign_broadcast(arr):
        arr[..., 4] = minus_ones
        arr[0, 2] = minus_twos
        # Column of four one-element rows.
        arr[1, 0, ..., 3] = [[-3], [-3], [-3], [-3]]
        # Plain scalar.
        arr[0, ..., 1, 3, 5] = -4
        return arr

    counting = onp.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
    setup_testcase(counting.astype(onp.float32), assign_broadcast)
178
179
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_mul_by_scalar():
    """Compare assigning a row/column computed by scaling a slice of the same array."""
    def scale_and_assign(arr):
        # Row 1 replaced by itself times an int.
        doubled_row = arr[1, :] * 2
        arr[1, :] = doubled_row
        # Column 2 replaced by column 3 times a float.
        tripled_column = arr[:, 3] * 3.0
        arr[:, 2] = tripled_column
        return arr

    ones = onp.ones((4, 5), dtype=onp.float32)
    setup_testcase(ones, scale_and_assign)
193
194
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_slice():
    """Compare scalar assignment through single slices, including out-of-range and empty ones."""
    def assign_by_slices(arr):
        arr[1:2] = 2
        # Negative start.
        arr[-3:1] = 3
        # Start below the lower bound, step of 2.
        arr[-10:3:2] = 4
        # Start beyond the end with a positive step.
        arr[5:0:3] = 5
        # Start equal to stop.
        arr[5:5:5] = 6
        arr[-1:2] = 7
        # Negative step.
        arr[1:0:-1] = 8
        return arr

    ones = onp.ones((3, 4, 5), dtype=onp.float32)
    setup_testcase(ones, assign_by_slices)
213
214
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_tuple_of_slices():
    """Compare scalar assignment through index tuples that combine slices and integers."""
    def assign_by_slice_tuples(arr):
        # Slice on the first axis, integer on the second.
        arr[1:2, 2] = 2
        # Integer first, then slices with negative / out-of-range bounds.
        arr[0, -4:1] = 3
        arr[1, -10:3:2] = 4
        # Slice starting beyond the end of the first axis.
        arr[5:0:3, 3] = 5
        # Two slices whose start equals their stop.
        arr[1:1, 2:2] = 6
        return arr

    ones = onp.ones((3, 4, 5), dtype=onp.float32)
    setup_testcase(ones, assign_by_slice_tuples)
231
232
class TensorItemSetWithNumber(Cell):
    """Cell that forwards to ``tensor.itemset(number_value)`` (no index argument)."""

    def construct(self, tensor, number_value):
        """Return whatever ``tensor.itemset(number_value)`` returns."""
        return tensor.itemset(number_value)
237
238
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_itemset_with_number():
    """Check ``itemset(value)`` without an index for an int and a float value.

    For a one-element tensor the result is compared with a NumPy array that
    had the same value assigned to its only element. For the (3, 4, 5) tensor
    the call is expected to raise ``IndexError``.
    """
    net = TensorItemSetWithNumber()
    # Deterministic input: onp.ndarray([1]) would expose uninitialized memory.
    input_1d_np = onp.zeros((1,), dtype=onp.float32)
    input_1d_ms = Tensor(input_1d_np, mstype.float32)

    input_3d_np = onp.arange(60).reshape(3, 4, 5).astype(onp.int32)
    input_3d_ms = Tensor(input_3d_np, mstype.float32)

    value_np_1, value_np_2 = 1, 2.0

    # Both calls are made on the same input tensor before anything is compared.
    output_1d_ms_1 = net(input_1d_ms, value_np_1)
    output_1d_ms_2 = net(input_1d_ms, value_np_2)

    # The reference uses index assignment; ndarray.itemset is gone in NumPy 2.0.
    input_1d_np[0] = value_np_1
    assert onp.all(output_1d_ms_1.asnumpy() == input_1d_np)
    input_1d_np[0] = value_np_2
    assert onp.all(output_1d_ms_2.asnumpy() == input_1d_np)

    with pytest.raises(IndexError):
        net(input_3d_ms, value_np_1)
    with pytest.raises(IndexError):
        net(input_3d_ms, value_np_2)
266
267
class TensorItemSetByItemWithNumber(Cell):
    """Cell that forwards to ``tensor.itemset(index, number_value)``."""

    def construct(self, tensor, index, number_value):
        """Return whatever ``tensor.itemset(index, number_value)`` returns."""
        return tensor.itemset(index, number_value)
272
273
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_dim_expand():
    """Compare assignments whose index tuples contain None/bool entries or nested lists."""
    def assign_expanded(arr):
        # None and True ahead of a list, a bool tuple and a one-element list.
        arr[None, True, [1, 0], (False, True, True), [2]] = 2
        # Double-nested single indices around an ellipsis, triple-nested value.
        arr[([[0]]), ..., [[1]]] = [[[3, 3, 3]]]
        # One-row nested list assigned through a slice.
        arr[0:1] = [[2, 3, 4, 5]]
        # Tuple index followed by None, a full slice, True and None.
        arr[..., (0, 1, 2), None, :, True, None] = [[[3], [3], [3], [3]]]
        return arr

    ones = onp.ones((2, 3, 4), dtype=onp.float32)
    setup_testcase(ones, assign_expanded)
288
289
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_itemset_by_number_with_number():
    """Check ``itemset(index, value)`` with integer (flat) indices.

    Valid indices are compared with a NumPy array updated through ``.flat``.
    Indices 30 (on the one-element tensor) and 60 (on the 60-element tensor)
    are expected to raise ``IndexError``; the float index 2.0 is expected to
    raise ``TypeError``.
    """
    net = TensorItemSetByItemWithNumber()
    # Deterministic input: onp.ndarray([1]) would expose uninitialized memory.
    input_1d_np = onp.zeros((1,), dtype=onp.float32)
    input_1d_ms = Tensor(input_1d_np, mstype.float32)

    input_3d_np = onp.arange(60).reshape(3, 4, 5).astype(onp.int32)
    input_3d_ms = Tensor(input_3d_np, mstype.float32)

    index_np_1, index_np_2, index_np_3, index_np_4 = 0, 30, 60, 2.0
    value_np_1, value_np_2 = 1, 2.0

    # All MindSpore calls happen first; the 3-D calls are chained on each
    # other's output so the updates accumulate like the NumPy reference below.
    output_1d_ms_1 = net(input_1d_ms, index_np_1, value_np_1)
    output_1d_ms_2 = net(input_1d_ms, index_np_1, value_np_2)
    output_3d_ms_1 = net(input_3d_ms, index_np_1, value_np_1)
    output_3d_ms_2 = net(output_3d_ms_1, index_np_1, value_np_2)
    output_3d_ms_3 = net(output_3d_ms_2, index_np_2, value_np_1)
    output_3d_ms_4 = net(output_3d_ms_3, index_np_2, value_np_2)

    # The reference uses ``.flat`` assignment for the flat integer index;
    # ndarray.itemset is gone in NumPy 2.0.
    input_1d_np.flat[index_np_1] = value_np_1
    assert onp.all(output_1d_ms_1.asnumpy() == input_1d_np)
    input_1d_np.flat[index_np_1] = value_np_2
    assert onp.all(output_1d_ms_2.asnumpy() == input_1d_np)
    input_3d_np.flat[index_np_1] = value_np_1
    assert onp.all(output_3d_ms_1.asnumpy() == input_3d_np)
    input_3d_np.flat[index_np_1] = value_np_2
    assert onp.all(output_3d_ms_2.asnumpy() == input_3d_np)
    input_3d_np.flat[index_np_2] = value_np_1
    assert onp.all(output_3d_ms_3.asnumpy() == input_3d_np)
    input_3d_np.flat[index_np_2] = value_np_2
    assert onp.all(output_3d_ms_4.asnumpy() == input_3d_np)

    # (tensor, index, expected exception), each tried with both values.
    failing_calls = (
        (input_1d_ms, index_np_2, IndexError),
        (input_1d_ms, index_np_4, TypeError),
        (input_3d_ms, index_np_3, IndexError),
        (input_3d_ms, index_np_4, TypeError),
    )
    for tensor, index, expected_error in failing_calls:
        for value in (value_np_1, value_np_2):
            with pytest.raises(expected_error):
                net(tensor, index, value)
342
343
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_itemset_by_tuple_with_number():
    """Check ``itemset(index, value)`` with tuple indices.

    ``(0,)`` on the one-element tensor and ``(1, 1, 0)`` on the (3, 4, 5)
    tensor are compared with a NumPy array updated by index assignment. Every
    other tensor/index combination exercised below is expected to raise
    ``IndexError``.
    """
    net = TensorItemSetByItemWithNumber()
    # Deterministic input: onp.ndarray([1]) would expose uninitialized memory.
    input_1d_np = onp.zeros((1,), dtype=onp.float32)
    input_1d_ms = Tensor(input_1d_np, mstype.float32)

    input_3d_np = onp.arange(60).reshape(3, 4, 5).astype(onp.int32)
    input_3d_ms = Tensor(input_3d_np, mstype.float32)

    index_np_1, index_np_2, index_np_3, index_np_4, index_np_5 = (0,), (1, 2), (1, 1, 0), (3, 4, 5), (1, 2, 3, 4)
    value_np_1, value_np_2 = 1, 2.0

    # The reference uses tuple index assignment; ndarray.itemset is gone in
    # NumPy 2.0.
    output_1d_ms_1 = net(input_1d_ms, index_np_1, value_np_1)
    input_1d_np[index_np_1] = value_np_1
    assert onp.all(output_1d_ms_1.asnumpy() == input_1d_np)

    output_1d_ms_2 = net(input_1d_ms, index_np_1, value_np_2)
    input_1d_np[index_np_1] = value_np_2
    assert onp.all(output_1d_ms_2.asnumpy() == input_1d_np)

    output_3d_ms_1 = net(input_3d_ms, index_np_3, value_np_1)
    input_3d_np[index_np_3] = value_np_1
    assert onp.all(output_3d_ms_1.asnumpy() == input_3d_np)

    output_3d_ms_2 = net(input_3d_ms, index_np_3, value_np_2)
    input_3d_np[index_np_3] = value_np_2
    assert onp.all(output_3d_ms_2.asnumpy() == input_3d_np)

    # (tensor, index) pairs expected to raise IndexError, each tried with both values.
    failing_calls = (
        (input_1d_ms, index_np_2),
        (input_3d_ms, index_np_1),
        (input_3d_ms, index_np_2),
        (input_3d_ms, index_np_4),
        (input_3d_ms, index_np_5),
    )
    for tensor, index in failing_calls:
        for value in (value_np_1, value_np_2):
            with pytest.raises(IndexError):
                net(tensor, index, value)
396