• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15
16import numpy as np
17import pytest
18
19import mindspore.context as context
20import mindspore.nn as nn
21from mindspore import Tensor
22from mindspore.ops import operations as P
23
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

# Axes exercised for every 5-D input; axis5/axis6 cover negative indexing.
axis0 = 0
axis1 = 1
axis2 = 2
axis3 = 3
axis4 = 4
axis5 = -1
axis6 = -2

# A fixed seed makes the generated inputs identical on every run, so a
# failing comparison can be reproduced. A private RandomState is used so
# the global NumPy generator (shared with other test modules) is untouched.
_rng = np.random.RandomState(0)

# One 5-D input per dtype under test, plus a 1-D float32 input (x5).
# With these value ranges, cumulative sums along any axis (length <= 5)
# stay well inside each dtype's range, so no integer overflow can occur:
# int8 in [-25, 20], uint8 in [0, 45], int32 within +/-50000.
x0 = _rng.rand(3, 3, 4, 5, 3).astype(np.float32)
x1 = _rng.rand(2, 3, 4, 5, 3).astype(np.float16)
x2 = _rng.randint(-10000, 10000, size=(2, 3, 4, 5, 3)).astype(np.int32)
x3 = _rng.randint(-5, 5, size=(2, 3, 4, 5, 3)).astype(np.int8)
x4 = _rng.randint(0, 10, size=(2, 3, 4, 5, 3)).astype(np.uint8)
x5 = _rng.rand(3).astype(np.float32)

# The test grid: every array in list1 is combined with every axis in list2
# (array-major order), matching the slot layout of the CumSum cell.
list1 = [x0, x1, x2, x3, x4]
list2 = [axis0, axis1, axis2, axis3, axis4, axis5, axis6]
44
class CumSum(nn.Cell):
    """Network that applies ``P.CumSum`` to a fixed set of (input, axis) slots.

    Slots 0-34 hold every combination of the 5-D inputs ``x0``..``x4``
    (outer loop) with the axes ``axis0``..``axis6`` (inner loop), i.e. the
    same order in which the test functions walk ``list1`` x ``list2``.
    Slot 35 holds the 1-D input ``x5`` with ``axis0``.

    Args:
        exclusive: forwarded unchanged to ``P.CumSum``.
        reverse: forwarded unchanged to ``P.CumSum``.
    """

    def __init__(self, exclusive=False, reverse=False):
        super().__init__()
        self.cumsum_op = P.CumSum(exclusive, reverse)

        sources = (x0, x1, x2, x3, x4)
        axes = (axis0, axis1, axis2, axis3, axis4, axis5, axis6)
        slot = 0
        for src in sources:
            for ax in axes:
                # setattr goes through Cell.__setattr__ exactly like
                # ``self.xN = ...``; every slot gets its own Tensor.
                setattr(self, "x%d" % slot, Tensor(src))
                setattr(self, "axis%d" % slot, ax)
                slot += 1

        # Trailing 1-D case.
        self.x35 = Tensor(x5)
        self.axis35 = axis0

    def construct(self):
        # One output per slot, in slot order. This body is compiled in
        # GRAPH_MODE, so the calls are spelled out explicitly.
        return (self.cumsum_op(self.x0, self.axis0),
                self.cumsum_op(self.x1, self.axis1),
                self.cumsum_op(self.x2, self.axis2),
                self.cumsum_op(self.x3, self.axis3),
                self.cumsum_op(self.x4, self.axis4),
                self.cumsum_op(self.x5, self.axis5),
                self.cumsum_op(self.x6, self.axis6),
                self.cumsum_op(self.x7, self.axis7),
                self.cumsum_op(self.x8, self.axis8),
                self.cumsum_op(self.x9, self.axis9),
                self.cumsum_op(self.x10, self.axis10),
                self.cumsum_op(self.x11, self.axis11),
                self.cumsum_op(self.x12, self.axis12),
                self.cumsum_op(self.x13, self.axis13),
                self.cumsum_op(self.x14, self.axis14),
                self.cumsum_op(self.x15, self.axis15),
                self.cumsum_op(self.x16, self.axis16),
                self.cumsum_op(self.x17, self.axis17),
                self.cumsum_op(self.x18, self.axis18),
                self.cumsum_op(self.x19, self.axis19),
                self.cumsum_op(self.x20, self.axis20),
                self.cumsum_op(self.x21, self.axis21),
                self.cumsum_op(self.x22, self.axis22),
                self.cumsum_op(self.x23, self.axis23),
                self.cumsum_op(self.x24, self.axis24),
                self.cumsum_op(self.x25, self.axis25),
                self.cumsum_op(self.x26, self.axis26),
                self.cumsum_op(self.x27, self.axis27),
                self.cumsum_op(self.x28, self.axis28),
                self.cumsum_op(self.x29, self.axis29),
                self.cumsum_op(self.x30, self.axis30),
                self.cumsum_op(self.x31, self.axis31),
                self.cumsum_op(self.x32, self.axis32),
                self.cumsum_op(self.x33, self.axis33),
                self.cumsum_op(self.x34, self.axis34),
                self.cumsum_op(self.x35, self.axis35))
165
166
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cumsum():
    """Check inclusive, forward CumSum against ``np.cumsum``.

    The cell's outputs are every (array, axis) pair of ``list1`` x
    ``list2`` in array-major order, followed by ``x5`` along ``axis0``.
    """
    cumsum = CumSum()
    output = cumsum()

    cases = [(arr, ax) for arr in list1 for ax in list2]
    cases.append((x5, axis0))
    # Fail loudly if the cell and the case list drift apart; a manual index
    # counter would silently ignore extra outputs.
    assert len(output) == len(cases)

    for out, (arr, ax) in zip(output, cases):
        expect = np.cumsum(arr, axis=ax)
        # Check shape first so a mismatch is reported as such rather than
        # as a broadcasting failure inside the value comparison.
        assert out.shape == expect.shape
        np.testing.assert_allclose(out.asnumpy(), expect, rtol=0, atol=1.0e-5)
190
191
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cumsum2():
    """Check inclusive, reverse CumSum (``exclusive=False, reverse=True``).

    The NumPy reference flips the input along the axis, takes an inclusive
    ``np.cumsum`` and flips the result back.
    """
    def reference(arr, ax):
        flipped = np.flip(arr, axis=ax)
        return np.flip(np.cumsum(flipped, axis=ax), axis=ax)

    cumsum = CumSum(exclusive=False, reverse=True)
    output = cumsum()

    cases = [(arr, ax) for arr in list1 for ax in list2]
    cases.append((x5, axis0))
    # Fail loudly if the cell and the case list drift apart.
    assert len(output) == len(cases)

    for out, (arr, ax) in zip(output, cases):
        expect = reference(arr, ax)
        # Shape first, so a mismatch is reported clearly.
        assert out.shape == expect.shape
        np.testing.assert_allclose(out.asnumpy(), expect, rtol=0, atol=1.0e-5)
216
217
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cumsum3():
    """Check exclusive, forward CumSum (``exclusive=True, reverse=False``).

    The NumPy reference shifts the input one step along the axis (a zero is
    inserted at the front and the last element dropped) and then takes an
    inclusive ``np.cumsum``.
    """
    def reference(arr, ax):
        padded = np.insert(arr, 0, [0], axis=ax)
        shifted = np.delete(padded, -1, axis=ax)
        return np.cumsum(shifted, axis=ax)

    cumsum = CumSum(exclusive=True, reverse=False)
    output = cumsum()

    cases = [(arr, ax) for arr in list1 for ax in list2]
    cases.append((x5, axis0))
    # Fail loudly if the cell and the case list drift apart.
    assert len(output) == len(cases)

    for out, (arr, ax) in zip(output, cases):
        expect = reference(arr, ax)
        # Shape first, so a mismatch is reported clearly.
        assert out.shape == expect.shape
        np.testing.assert_allclose(out.asnumpy(), expect, rtol=0, atol=1.0e-5)
242
243
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cumsum4():
    """Check exclusive, reverse CumSum (``exclusive=True, reverse=True``).

    The NumPy reference flips the input along the axis, applies the
    exclusive shift (zero inserted at the front, last element dropped),
    takes an inclusive ``np.cumsum`` and flips the result back.
    """
    def reference(arr, ax):
        flipped = np.flip(arr, axis=ax)
        padded = np.insert(flipped, 0, [0], axis=ax)
        shifted = np.delete(padded, -1, axis=ax)
        return np.flip(np.cumsum(shifted, axis=ax), axis=ax)

    cumsum = CumSum(exclusive=True, reverse=True)
    output = cumsum()

    cases = [(arr, ax) for arr in list1 for ax in list2]
    cases.append((x5, axis0))
    # Fail loudly if the cell and the case list drift apart.
    assert len(output) == len(cases)

    for out, (arr, ax) in zip(output, cases):
        expect = reference(arr, ax)
        # Shape first, so a mismatch is reported clearly.
        assert out.shape == expect.shape
        np.testing.assert_allclose(out.asnumpy(), expect, rtol=0, atol=1.0e-5)
272