• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" test_tensor_slice """
16import numpy as np
17import pytest
18
19from mindspore import Tensor
20from mindspore import context
21from mindspore import dtype as mstype
22from mindspore.nn import Cell
23from ....mindspore_test_framework.mindspore_test import mindspore_test
24from ....mindspore_test_framework.pipeline.forward.compile_forward \
25    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
26    pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
27
28
class NetWorkFancyIndex(Cell):
    """Cell that applies an index object, fixed at construction time, to its input.

    Args:
        index: Any indexing object (int/bool list, tuple mixing ints, lists,
            tuples, slices, ``...`` and bools) used as ``tensor[index]``.
    """

    def __init__(self, index):
        super().__init__()
        self.index = index

    def construct(self, tensor):
        # Graph-compiled body: just the subscript, bound to a local first.
        selected = tensor[self.index]
        return selected
36
37
class TensorItemByNone(Cell):
    """Cell that returns ``tensor.item()`` called without an index argument."""

    def construct(self, tensor):
        # No index is passed; the error table in this file expects a ValueError
        # when this is applied to a multi-element (3-D) tensor.
        return tensor.item()
42
43
class TensorItemByItem(Cell):
    """Cell that returns ``tensor.item(index)`` for a caller-supplied index.

    The index is a graph input (int, float or tuple in the case tables of
    this file), not a stored attribute.
    """

    def construct(self, tensor, index):
        return tensor.item(index)
48
49
# Input shapes shared by the fancy-index tests below.
_SHAPE_3D = (3, 4, 5)
_SHAPE_6D = (3, 4, 5, 6, 7, 8)


def _build_fancy_index_case(index, shape):
    """Select graph mode and build a ``NetWorkFancyIndex`` plus its float32 input.

    Args:
        index: Index object stored on the cell and applied in ``construct``.
        shape: Shape of the input tensor, filled with ``0 .. prod(shape) - 1``.

    Returns:
        ``(net, input_me)``. The caller performs ``net(input_me)`` itself so a
        surrounding ``pytest.raises`` covers only that call.
    """
    context.set_context(mode=context.GRAPH_MODE)
    net = NetWorkFancyIndex(index)
    input_np = np.arange(int(np.prod(shape))).reshape(shape)
    input_me = Tensor(input_np, dtype=mstype.float32)
    return net, input_me


def test_tensor_fancy_index_integer_list():
    net, input_me = _build_fancy_index_case([0, 2, 1], _SHAPE_3D)
    net(input_me)


def test_tensor_fancy_index_boolean_list():
    net, input_me = _build_fancy_index_case([True, True, False], _SHAPE_3D)
    net(input_me)


def test_tensor_fancy_index_integer_boolean_list_graph():
    net, input_me = _build_fancy_index_case([1, 2, True, False], _SHAPE_3D)
    net(input_me)


def test_tensor_fancy_index_integer_list_mixed():
    index = (1, [2, 1, 3], slice(1, 3, 1), ..., 4)
    net, input_me = _build_fancy_index_case(index, _SHAPE_6D)
    net(input_me)


def test_tensor_fancy_index_integer_tuple_mixed():
    index = (1, (2, 1, 3), slice(1, 3, 1), ..., 4)
    net, input_me = _build_fancy_index_case(index, _SHAPE_6D)
    net(input_me)


def test_tensor_fancy_index_integer_list_tuple_mixed():
    index = (1, [2, 1, 3], (3, 2, 1), slice(1, 3, 1), ..., 4)
    net, input_me = _build_fancy_index_case(index, _SHAPE_6D)
    net(input_me)


def test_tensor_fancy_index_integer_list_tuple_bool_mixed():
    # True scalars inside the tuple index are accepted here.
    index = (1, [2, 1, 3], True, (3, 2, 1), slice(1, 3, 1), ..., True, 4)
    net, input_me = _build_fancy_index_case(index, _SHAPE_6D)
    net(input_me)


def test_tensor_fancy_index_integer_list_tuple_bool_mixed_error():
    # Same index as the test above, but with a False scalar, which is expected
    # to raise IndexError when the cell is called.
    index = (1, [2, 1, 3], True, (3, 2, 1), slice(1, 3, 1), ..., False, 4)
    net, input_me = _build_fancy_index_case(index, _SHAPE_6D)
    with pytest.raises(IndexError):
        net(input_me)
121
122
# Shared inputs for the ``Tensor.item`` case tables below.
# Use ``np.array`` rather than ``np.ndarray([1])``: the latter allocates a
# shape-(1,) array whose single element is uninitialised memory.
input_1d_np = np.array([1], dtype=np.float32)
input_1d_ms = Tensor(input_1d_np, mstype.float32)
# 60-element (3, 4, 5) tensor; the values themselves are not checked.
input_3d_np = np.random.randint(3, size=(3, 4, 5)).astype(np.int32)
input_3d_ms = Tensor(input_3d_np, mstype.float32)
# Scalar indices:
#   0   - valid for both inputs
#   1.0 - float index, expected to raise TypeError
#   30  - valid flat index for the 3-D input, out of range for the 1-D input
#   60  - one past the last flat index of the 3-D input
index_np_1, index_np_2, index_np_3, index_np_4 = 0, 1.0, 30, 60
# Tuple indices of length 1..4; (3, 4, 4) is out of range for shape (3, 4, 5).
tuple_index_np_1, tuple_index_np_2, tuple_index_np_3, tuple_index_np_4, tuple_index_np_5 = \
    (0,), (1, 2), (1, 2, 3), (3, 4, 4), (1, 2, 3, 4)
130
# Valid ``Tensor.item`` cases as (name, cell, graph inputs); each row is expanded
# into the ``{'block': ..., 'desc_inputs': ...}`` config the test framework expects.
test_cases = [
    (case_name, {'block': case_block, 'desc_inputs': case_inputs})
    for case_name, case_block, case_inputs in (
        ('TensorItemByNone', TensorItemByNone(), [input_1d_ms]),
        ('1dTensorItemByInt', TensorItemByItem(), [input_1d_ms, index_np_1]),
        ('3dTensorItemByInt', TensorItemByItem(), [input_3d_ms, index_np_1]),
        ('3dTensorItemByInt2', TensorItemByItem(), [input_3d_ms, index_np_3]),
        ('1dTensorItemByTuple', TensorItemByItem(), [input_1d_ms, tuple_index_np_1]),
        ('3dTensorItemByTuple', TensorItemByItem(), [input_3d_ms, tuple_index_np_3]),
    )
]
139
140
# Failing ``Tensor.item`` cases as (name, cell, expected exception, graph inputs);
# each row becomes ``{'block': (cell, {'exception': exc}), 'desc_inputs': inputs}``.
test_error_cases = [
    (case_name, {'block': (case_block, {'exception': expected_exc}), 'desc_inputs': case_inputs})
    for case_name, case_block, expected_exc, case_inputs in (
        ('TensorItemByNoneForMulDimsTensor', TensorItemByNone(), ValueError,
         [input_3d_ms]),
        ('TensorItemByFloatError', TensorItemByItem(), TypeError,
         [input_1d_ms, index_np_2]),
        ('TensorItemByFloatError2', TensorItemByItem(), TypeError,
         [input_3d_ms, index_np_2]),
        ('TensorItemByIntOverBoundary', TensorItemByItem(), IndexError,
         [input_1d_ms, index_np_3]),
        ('TensorItemByIntOverBoundary2', TensorItemByItem(), IndexError,
         [input_3d_ms, index_np_4]),
        ('1dTensorItemBy2dTuple', TensorItemByItem(), ValueError,
         [input_1d_ms, tuple_index_np_2]),
        ('3dTensorItemBy2dTuple', TensorItemByItem(), ValueError,
         [input_3d_ms, tuple_index_np_2]),
        ('3dTensorItemBy3dTupleOutOfBoundary', TensorItemByItem(), IndexError,
         [input_3d_ms, tuple_index_np_4]),
        ('3dTensorItemBy4dTuple', TensorItemByItem(), ValueError,
         [input_3d_ms, tuple_index_np_5]),
    )
]
179
180
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
    """Hand ``test_cases`` to the ``mindspore_test`` compile-forward pipeline.

    Graph mode is selected before the case list is returned.
    """
    graph_mode = context.GRAPH_MODE
    context.set_context(mode=graph_mode)
    return test_cases
185
186
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
def test_check_exception():
    """Hand ``test_error_cases`` to the exception-checking compile-forward pipeline.

    Each case carries the exception type it is expected to raise.
    """
    # Select graph mode explicitly, as test_exec does, instead of relying on the
    # global context state left behind by whichever test ran before this one.
    context.set_context(mode=context.GRAPH_MODE)
    return test_error_cases
190