# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_ascend_control_sink """
import pytest
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as op

26
class ControlSimpleIf(nn.Cell):
    """Single if/else whose condition combines two tensor comparisons.

    The taken branch selects between two pre-computed 9x sums, so the test
    can tell from the output which branch executed.
    """

    def __init__(self):
        super().__init__()
        self.addn = op.AddN()

    def construct(self, x, y, z, input1, input2):
        triple1 = self.addn([input1, input1, input1])
        triple2 = self.addn([input2, input2, input2])
        nine1 = self.addn([triple1, triple1, triple1])
        nine2 = self.addn([triple2, triple2, triple2])
        cond1 = x > y
        cond2 = y > z
        # dodge pylint
        if cond1 and cond2:
            out = self.addn([nine1, nine1])
        else:
            out = self.addn([nine2, nine2])
        return self.addn([out, input1])
46
47
class ControlSimpleIfWithAssign(nn.Cell):
    """if/else where the false branch performs an Assign to a Parameter."""

    def __init__(self, input_shape):
        super().__init__()
        self.addn = op.AddN()
        self.assign = op.Assign()
        self.input_data = Parameter(initializer(1, input_shape, mstype.float32), name="var")

    def construct(self, x, y, input_data):
        if x > y:
            result = self.addn([input_data, input_data, input_data])
        else:
            # Assign returns the assigned value, i.e. input_data itself.
            result = self.assign(self.input_data, input_data)
        return result
61
62
class ControlIfinIf(nn.Cell):
    """Nested if: the outer true branch contains another if/else on y."""

    def construct(self, x, y):
        if x > y:
            x = x + 1
            if y < 0:
                y = y + 1
            else:
                y = y + 2
        else:
            x = x + 2
        return x + y
77
78
class ControlIfbyIfbyIf(nn.Cell):
    """Three sequential if/else blocks, each consuming the previous result."""

    def __init__(self):
        super().__init__()
        self.addn = op.AddN()

    def construct(self, x, y, cond1, cond2, input_data):
        base = self.addn([input_data, input_data, input_data])
        if x > y:
            stage1 = self.addn([base, base])
        else:
            stage1 = self.addn([base, base, base])
        if cond1:
            stage2 = self.addn([stage1, stage1])
        else:
            stage2 = self.addn([stage1, stage1, stage1])
        if cond2:
            final = self.addn([stage2, stage2, stage2])
        else:
            final = self.addn([stage2, stage2])
        return final
99
100
class ControlSimpleWhile(nn.Cell):
    """Simplest while loop: the body rebinds the condition variable."""

    def __init__(self):
        super().__init__()
        self.addn = op.AddN()

    def construct(self, x, y, input_data):
        result = input_data
        while x:
            result = self.addn([input_data, input_data, input_data])
            # rebinding x to y ends the loop once y is falsy
            x = y
        return result
112
113
class ControlMixedWhileIf(nn.Cell):
    # Exercises deeply nested while/if control flow mixed with Assign side
    # effects on a Parameter, stressing Ascend control-sink compilation.
    def __init__(self):
        super().__init__()
        self.assign = op.Assign()
        # NOTE(review): shape is written as (1), which is the int 1, not the
        # tuple (1,) — initializer appears to accept this; confirm intent.
        self.var = Parameter(initializer(1, (1), mstype.float32), name="var")

    def construct(self, x, y, z, c2, c4):
        """Accumulate loop counters into `out` across two nested-while regions.

        The exact scalar result is pinned by test_mixed_while_if (3318 for
        x=2, c2=14, c4=0); the Assign calls only write c4 into self.var.
        """
        out = c4
        self.assign(self.var, c4)
        while x < c2:
            y = c4
            self.assign(self.var, c4)
            # y climbs toward c2 in steps of 2 while 2*y < c2, then steps of 1
            while y < c2 and x < c2:
                if 2 * y < c2:
                    y = y + 2
                else:
                    y = y + 1
            out = out + y
            z = c4
            self.assign(self.var, c4)
            while z < c2:
                z = z + 1
            out = out + z
            x = x + 1
        out = out + x
        # second region: x advances from c2 to 2*c2
        while x < 2 * c2:
            y = c4
            self.assign(self.var, c4)
            x = x + 1
            while y < c2:
                z = c4
                self.assign(self.var, c4)
                while z < c2:
                    z = z + 1
                # x >= c2 here, so the else branch (y + 1) is the one taken
                if x < c2:
                    y = y - 1
                else:
                    y = y + 1
                out = out + z
            out = out + y
        out = out + x
        return out
156
157
class AndOperation(nn.Cell):
    """Logical `and` between the reduced sums of two tensors."""

    def __init__(self):
        super().__init__()
        self.reduce_sum = op.ReduceSum()

    def construct(self, x, y):
        left = self.reduce_sum(x)
        right = self.reduce_sum(y)
        return left and right
168
169
class OrOperation(nn.Cell):
    """Logical `or` between the reduced sums of two tensors."""

    def __init__(self):
        super().__init__()
        self.reduce_sum = op.ReduceSum()

    def construct(self, x, y):
        left = self.reduce_sum(x)
        right = self.reduce_sum(y)
        return left or right
180
181
class NotOperation(nn.Cell):
    """Logical negation of a tensor's reduced sum."""

    def __init__(self):
        super().__init__()
        self.reduce_sum = op.ReduceSum()

    def construct(self, x):
        total = self.reduce_sum(x)
        return not total
190
191
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_simple_if():
    """cond1 (3 > 2) holds but cond2 (2 > 3) fails, so the else branch runs."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    x = np.array(3).astype(np.float32)
    y = np.array(2).astype(np.float32)
    z = np.array(3).astype(np.float32)
    shape = (127, 7, 53, 31)
    first = np.random.randn(*shape).astype(np.float32)
    second = np.random.randn(*shape).astype(np.float32)
    net = ControlSimpleIf()
    result = net(Tensor(x), Tensor(y), Tensor(z), Tensor(first), Tensor(second))
    # else branch: 2 * (3 * (3 * input2)), then input1 added on top
    expect = second * 3 * 3 * 2 + first
    assert np.allclose(expect, result.asnumpy(), 0.0001, 0.0001)
208
209
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_simple_if_with_assign():
    """x <= y drives the else branch; Assign hands back input_data unchanged."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    lhs = np.array(0).astype(np.float32)
    rhs = np.array(1).astype(np.float32)
    shape = (127, 7, 53, 31)
    data = np.random.randn(*shape).astype(np.float32)
    net = ControlSimpleIfWithAssign(shape)
    result = net(Tensor(lhs), Tensor(rhs), Tensor(data))
    assert np.allclose(data, result.asnumpy(), 0.0001, 0.0001)
224
225
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if():
    """x > y and y >= 0, so x gains 1 and y gains 2: result is x + y + 3."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    a = np.array(2.345678).astype(np.float32)
    b = np.array(1.234567).astype(np.float32)
    net = ControlIfinIf()
    result = net(Tensor(a), Tensor(b))
    assert np.allclose(a + b + 3, result.asnumpy(), 0.0001, 0.0001)
238
239
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_by_if():
    """Sequential ifs: x>y (x2), cond1 True (x2), cond2 False (x2) on a 3x sum."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    x = np.array(2.345678).astype(np.float32)
    y = np.array(1.234567).astype(np.float32)
    # np.bool was removed in NumPy 1.24; np.bool_ is the supported scalar type.
    cond1 = np.array(True).astype(np.bool_)
    cond2 = np.array(False).astype(np.bool_)
    input_shape = (127, 7, 53, 31)
    input_data = np.random.randn(*input_shape).astype(np.float32)
    net = ControlIfbyIfbyIf()
    output = net(Tensor(x), Tensor(y), Tensor(cond1), Tensor(cond2), Tensor(input_data))
    expect = input_data * 3 * 2 * 2 * 2
    assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
256
257
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_simple_while():
    """Loop runs exactly once (x=True then x=y=False), yielding input * 3."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    # np.bool was removed in NumPy 1.24; np.bool_ is the supported scalar type.
    x = np.array(True).astype(np.bool_)
    y = np.array(False).astype(np.bool_)
    input_shape = (127, 7, 53, 31)
    input_data = np.random.randn(*input_shape).astype(np.float32)
    net = ControlSimpleWhile()
    output = net(Tensor(x), Tensor(y), Tensor(input_data))
    expect = input_data * 3
    assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
272
273
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mixed_while_if():
    """Pin the scalar accumulated by ControlMixedWhileIf's nested loops."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    x_np = np.array(2).astype(np.int32)
    y_np = np.array(14).astype(np.int32)
    z_np = np.array(1).astype(np.int32)
    upper = Tensor([14], mstype.int32)
    zero = Tensor([0], mstype.int32)
    net = ControlMixedWhileIf()
    result = net(Tensor(x_np), Tensor(y_np), Tensor(z_np), upper, zero)
    expect = np.array(3318).astype(np.int32)
    assert np.allclose(expect, result.asnumpy(), 0.0001, 0.0001)
289
290
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_and_or_operation():
    """and/or/not over reduced sums must match NumPy truthiness semantics."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    lhs = np.array([0, 1]).astype(np.float32)
    rhs = np.array([0, 0]).astype(np.float32)

    result = AndOperation()(Tensor(lhs), Tensor(rhs))
    assert np.allclose(np.sum(lhs) and np.sum(rhs), result.asnumpy(), 0.0001, 0.0001)

    result = OrOperation()(Tensor(lhs), Tensor(rhs))
    assert np.allclose(np.sum(lhs) or np.sum(rhs), result.asnumpy(), 0.0001, 0.0001)

    result = NotOperation()(Tensor(lhs))
    assert np.allclose(not np.sum(lhs), result.asnumpy(), 0.0001, 0.0001)
313