• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for ``enumerate`` support inside MindSpore graph-mode ``construct``."""
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
22
# Compile every Cell's construct() in this module as a static graph, so the
# tests below exercise enumerate's graph-mode implementation.
context.set_context(mode=context.GRAPH_MODE)
24
25
def test_enumerate_list_const():
    """Unpack ``enumerate`` pairs over a constant list attribute.

    Indices 0..3 sum to 6 and the stored values 11 + 22 + 33 + 44 sum to 110.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.value = [11, 22, 33, 44]

        def construct(self):
            idx_total = 0
            item_total = 0
            for idx, item in enumerate(self.value):
                idx_total += idx
                item_total += item
            return idx_total, item_total

    result = Net()()
    assert result == (6, 110)
42
43
def test_enumerate_tuple_const():
    """Unpack ``enumerate`` pairs over a constant tuple attribute.

    Indices 0..3 sum to 6 and the stored values 11 + 22 + 33 + 44 sum to 110.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.value = (11, 22, 33, 44)

        def construct(self):
            total_index = 0
            total_value = 0
            for position, element in enumerate(self.value):
                total_index += position
                total_value += element
            return total_index, total_value

    assert Net()() == (6, 110)
60
61
def test_enumerate_tensor_const():
    """Return ``enumerate`` over a constant (2, 3) Tensor attribute directly.

    The returned pairs are checked, not just graph compilation:
    - there is one pair per row;
    - the indices count from 0;
    - each row matches the NumPy data the Tensor was built from.
    """
    data = np.arange(2 * 3).reshape(2, 3)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor(data)

        def construct(self):
            return enumerate(self.value)

    net = Net()
    out = net()
    # Expect one (index, row) pair per slice along axis 0, i.e. the same rows
    # that iterating the NumPy array yields.
    assert len(out) == data.shape[0]
    for expected_index, (index, row) in enumerate(out):
        assert index == expected_index
        assert np.array_equal(row.asnumpy(), data[expected_index])
73
74
def test_enumerate_list_parameter():
    """Unpack ``enumerate`` pairs over a list built from two graph inputs.

    Distinct inputs are passed so the check can detect dropped, duplicated or
    reordered elements. The visited indices 0 and 1 must sum to 1.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y):
            index_sum = 0
            value = [x, y]
            ret = ()
            for i, j in enumerate(value):
                index_sum += i
                ret += (j,)
            return index_sum, ret

    x = Tensor(np.arange(4))
    y = Tensor(np.arange(4, 8))
    net = Net()
    index_sum, ret = net(x, y)
    assert index_sum == 1
    assert len(ret) == 2
    # Elements must come back in list order: x first, then y.
    assert np.array_equal(ret[0].asnumpy(), x.asnumpy())
    assert np.array_equal(ret[1].asnumpy(), y.asnumpy())
92
93
def test_enumerate_tuple_parameter():
    """Unpack ``enumerate`` pairs over a tuple built from two graph inputs.

    Distinct inputs are passed so the check can detect dropped, duplicated or
    reordered elements. The visited indices 0 and 1 must sum to 1.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y):
            index_sum = 0
            value = (x, y)
            ret = ()
            for i, j in enumerate(value):
                index_sum += i
                ret += (j,)
            return index_sum, ret

    x = Tensor(np.arange(4))
    y = Tensor(np.arange(4, 8))
    net = Net()
    index_sum, ret = net(x, y)
    assert index_sum == 1
    assert len(ret) == 2
    # Elements must come back in tuple order: x first, then y.
    assert np.array_equal(ret[0].asnumpy(), x.asnumpy())
    assert np.array_equal(ret[1].asnumpy(), y.asnumpy())
111
112
def test_enumerate_tensor_parameter():
    """Unpack ``enumerate`` pairs over a (2, 3) Tensor graph input.

    Two rows are visited, so the indices 0 and 1 sum to 1. The collected rows
    must match the NumPy source data.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            index_sum = 0
            ret = ()
            for i, j in enumerate(x):
                index_sum += i
                ret += (j,)
            return index_sum, ret

    data = np.arange(2 * 3).reshape(2, 3)
    x = Tensor(data)
    net = Net()
    index_sum, ret = net(x)
    assert index_sum == 1
    # One collected element per slice along axis 0, in order.
    assert len(ret) == data.shape[0]
    for row, expected in zip(ret, data):
        assert np.array_equal(row.asnumpy(), expected)
129
130
def test_enumerate_tuple_const_1():
    """Index, rather than unpack, the pairs produced by ``enumerate``.

    Each loop item is read as ``pair[0]`` (the index) and ``pair[1]`` (the
    value). The expected sums are 6 and 110.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.value = (11, 22, 33, 44)

        def construct(self):
            total_index = 0
            total_value = 0
            for pair in enumerate(self.value):
                total_index += pair[0]
                total_value += pair[1]
            return total_index, total_value

    assert Net()() == (6, 110)
147
148
def test_enumerate_tensor_const_1():
    """Index the ``enumerate`` pairs over a constant (2, 3) Tensor attribute.

    Indices 0 and 1 must sum to 1. The collected rows must match the NumPy
    source data.
    """
    data = np.arange(2 * 3).reshape(2, 3)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor(data)

        def construct(self):
            index_sum = 0
            ret = ()
            for i in enumerate(self.value):
                index_sum += i[0]
                ret += (i[1],)
            return index_sum, ret

    net = Net()
    index_sum, ret = net()
    assert index_sum == 1
    # One collected element per slice along axis 0, in order.
    assert len(ret) == data.shape[0]
    for row, expected in zip(ret, data):
        assert np.array_equal(row.asnumpy(), expected)
165
166
def test_enumerate_tuple_parameter_1():
    """Index the ``enumerate`` pairs over a tuple of two graph inputs.

    Distinct inputs are passed so ordering is observable. Indices 0 and 1
    must sum to 1.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y):
            index_sum = 0
            value = (x, y)
            ret = ()
            for i in enumerate(value):
                index_sum += i[0]
                ret += (i[1],)
            return index_sum, ret

    x = Tensor(np.arange(4))
    y = Tensor(np.arange(4, 8))
    net = Net()
    index_sum, ret = net(x, y)
    assert index_sum == 1
    assert len(ret) == 2
    # Elements must come back in tuple order: x first, then y.
    assert np.array_equal(ret[0].asnumpy(), x.asnumpy())
    assert np.array_equal(ret[1].asnumpy(), y.asnumpy())
184
185
def test_enumerate_tensor_parameter_1():
    """Index the ``enumerate`` pairs over a (2, 3) Tensor graph input.

    Indices 0 and 1 must sum to 1. The collected rows must match the NumPy
    source data.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            index_sum = 0
            ret = ()
            for i in enumerate(x):
                index_sum += i[0]
                ret += (i[1],)
            return index_sum, ret

    data = np.arange(2 * 3).reshape(2, 3)
    x = Tensor(data)
    net = Net()
    index_sum, ret = net(x)
    assert index_sum == 1
    # One collected element per slice along axis 0, in order.
    assert len(ret) == data.shape[0]
    for row, expected in zip(ret, data):
        assert np.array_equal(row.asnumpy(), expected)
202
203
def test_enumerate_tuple_const_2():
    """``enumerate`` with an explicit start offset of 1.

    The indices become 1..4 (sum 10), while the values still sum to 110.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.value = (11, 22, 33, 44)

        def construct(self):
            total_index = 0
            total_value = 0
            for pair in enumerate(self.value, 1):
                total_index += pair[0]
                total_value += pair[1]
            return total_index, total_value

    assert Net()() == (10, 110)
220
221
def test_enumerate_tensor_const_2():
    """``enumerate(..., 1)`` over a constant (2, 3) Tensor attribute.

    With start=1 the two rows get indices 1 and 2 (sum 3). The collected rows
    must match the NumPy source data.
    """
    data = np.arange(2 * 3).reshape(2, 3)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor(data)

        def construct(self):
            index_sum = 0
            ret = ()
            for i in enumerate(self.value, 1):
                index_sum += i[0]
                ret += (i[1],)
            return index_sum, ret

    net = Net()
    index_sum, ret = net()
    assert index_sum == 1 + 2
    # The start offset shifts only the indices, not which rows are visited.
    assert len(ret) == data.shape[0]
    for row, expected in zip(ret, data):
        assert np.array_equal(row.asnumpy(), expected)
238
239
def test_enumerate_tuple_parameter_2():
    """``enumerate(..., 1)`` over a tuple of two graph inputs.

    With start=1 the indices are 1 and 2 (sum 3). Distinct inputs make the
    element order observable.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y):
            index_sum = 0
            value = (x, y)
            ret = ()
            for i in enumerate(value, 1):
                index_sum += i[0]
                ret += (i[1],)
            return index_sum, ret

    x = Tensor(np.arange(4))
    y = Tensor(np.arange(4, 8))
    net = Net()
    index_sum, ret = net(x, y)
    assert index_sum == 1 + 2
    assert len(ret) == 2
    # The start offset shifts only the indices; elements keep tuple order.
    assert np.array_equal(ret[0].asnumpy(), x.asnumpy())
    assert np.array_equal(ret[1].asnumpy(), y.asnumpy())
257
258
def test_enumerate_tensor_parameter_2():
    """Unpack ``enumerate(..., 1)`` pairs over a (2, 3) Tensor graph input.

    With start=1 the two rows get indices 1 and 2 (sum 3). The collected rows
    must match the NumPy source data.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            index_sum = 0
            ret = ()
            for i, j in enumerate(x, 1):
                index_sum += i
                ret += (j,)
            return index_sum, ret

    data = np.arange(2 * 3).reshape(2, 3)
    x = Tensor(data)
    net = Net()
    index_sum, ret = net(x)
    assert index_sum == 1 + 2
    # The start offset shifts only the indices, not which rows are visited.
    assert len(ret) == data.shape[0]
    for row, expected in zip(ret, data):
        assert np.array_equal(row.asnumpy(), expected)
275
276
def test_enumerate_start_type_error():
    """A float ``start`` argument must make ``enumerate`` raise TypeError.

    The error message must name the offending 'start' argument.
    """
    class Net(nn.Cell):
        def construct(self, x):
            # 1.2 is deliberately not an int.
            return enumerate((x, x), start=1.2)

    net = Net()
    data = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
    with pytest.raises(TypeError, match="For 'enumerate', the 'start'"):
        net(data)
290