# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_ascend_control_sink """
import pytest
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore.ops import operations as op
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer


class ControlSimpleIf(nn.Cell):
    """Single if/else on the conjunction of two tensor comparisons.

    When ``x > y and y > z`` the result is ``2 * (9 * input1) + input1``,
    otherwise ``2 * (9 * input2) + input1``.
    """

    def __init__(self):
        super().__init__()
        self.addn = op.AddN()

    def construct(self, x, y, z, input1, input2):
        addn1 = self.addn([input1, input1, input1])  # 3 * input1
        addn2 = self.addn([input2, input2, input2])  # 3 * input2
        addn11 = self.addn([addn1, addn1, addn1])  # 9 * input1
        addn22 = self.addn([addn2, addn2, addn2])  # 9 * input2
        cond1 = x > y
        cond2 = y > z
        # The comparisons are bound to names before the `if` (original note:
        # "dodge pylint"); presumably this keeps a linter quiet -- TODO confirm.
        if cond1 and cond2:
            out = self.addn([addn11, addn11])
        else:
            out = self.addn([addn22, addn22])
        out_me = self.addn([out, input1])
        return out_me


class ControlSimpleIfWithAssign(nn.Cell):
    """If/else whose false branch writes the input into a Parameter.

    Returns ``3 * input_data`` when ``x > y``; otherwise assigns ``input_data``
    to the ``var`` Parameter and returns the output of ``Assign``.
    """

    def __init__(self, input_shape):
        super().__init__()
        self.addn = op.AddN()
        self.assign = op.Assign()
        # Parameter of the given shape, initialized to all ones.
        self.input_data = Parameter(initializer(1, input_shape, mstype.float32), name="var")

    def construct(self, x, y, input_data):
        if x > y:
            out = self.addn([input_data, input_data, input_data])
        else:
            out = self.assign(self.input_data, input_data)
        return out


class ControlIfinIf(nn.Cell):
    """Nested if/else inside the true branch of an outer if/else.

    Whichever branches run, ``y`` is added to ``x`` after the outer if/else.
    """

    def construct(self, x, y):
        if x > y:
            x = x + 1
            if y < 0:
                y = y + 1
            else:
                y = y + 2
        else:
            x = x + 2
        x = x + y
        return x


class ControlIfbyIfbyIf(nn.Cell):
    """Three sequential if/else statements, each feeding the next.

    Starting from ``3 * input_data``, each if/else multiplies the running
    value by 2 or 3 depending on its condition.
    """

    def __init__(self):
        super().__init__()
        self.addn = op.AddN()

    def construct(self, x, y, cond1, cond2, input_data):
        tri_in = self.addn([input_data, input_data, input_data])
        if x > y:
            addn_1 = self.addn([tri_in, tri_in])
        else:
            addn_1 = self.addn([tri_in, tri_in, tri_in])
        if cond1:
            addn_2 = self.addn([addn_1, addn_1])
        else:
            addn_2 = self.addn([addn_1, addn_1, addn_1])
        if cond2:
            out = self.addn([addn_2, addn_2, addn_2])
        else:
            out = self.addn([addn_2, addn_2])
        return out


class ControlSimpleWhile(nn.Cell):
    """While loop driven by a tensor condition.

    While ``x`` is truthy, the output becomes ``3 * input_data`` and the loop
    condition is replaced by ``y``. If ``x`` is initially falsy, ``input_data``
    is returned unchanged.
    """

    def __init__(self):
        super().__init__()
        self.addn = op.AddN()

    def construct(self, x, y, input_data):
        out = input_data
        while x:
            out = self.addn([input_data, input_data, input_data])
            x = y
        return out


class ControlMixedWhileIf(nn.Cell):
    """Nested while loops mixed with if/else and Parameter assignments.

    Two top-level while loops (bounded by ``c2`` and ``2 * c2``) each contain
    inner while loops and an if/else. ``c4`` is assigned to the ``var``
    Parameter at several points, and a running sum accumulates into ``out``.
    """

    def __init__(self):
        super().__init__()
        self.assign = op.Assign()
        # One-element shape written as a real tuple; the previous `(1)` was
        # just the int 1, which only worked because initializer also accepts
        # an int shape.
        self.var = Parameter(initializer(1, (1,), mstype.float32), name="var")

    def construct(self, x, y, z, c2, c4):
        out = c4
        self.assign(self.var, c4)
        while x < c2:
            y = c4
            self.assign(self.var, c4)
            while y < c2 and x < c2:
                if 2 * y < c2:
                    y = y + 2
                else:
                    y = y + 1
            out = out + y
            z = c4
            self.assign(self.var, c4)
            while z < c2:
                z = z + 1
            out = out + z
            x = x + 1
        out = out + x
        while x < 2 * c2:
            y = c4
            self.assign(self.var, c4)
            x = x + 1
            while y < c2:
                z = c4
                self.assign(self.var, c4)
                while z < c2:
                    z = z + 1
                if x < c2:
                    y = y - 1
                else:
                    y = y + 1
                out = out + z
            out = out + y
        out = out + x
        return out
class AndOperation(nn.Cell):
    """Python `and` applied to the ReduceSum of each input."""

    def __init__(self):
        super().__init__()
        self.reduce_sum = op.ReduceSum()

    def construct(self, x, y):
        x_sum = self.reduce_sum(x)
        y_sum = self.reduce_sum(y)
        out = x_sum and y_sum
        return out


class OrOperation(nn.Cell):
    """Python `or` applied to the ReduceSum of each input."""

    def __init__(self):
        super().__init__()
        self.reduce_sum = op.ReduceSum()

    def construct(self, x, y):
        x_sum = self.reduce_sum(x)
        y_sum = self.reduce_sum(y)
        out = x_sum or y_sum
        return out


class NotOperation(nn.Cell):
    """Python `not` applied to the ReduceSum of the input."""

    def __init__(self):
        super().__init__()
        self.reduce_sum = op.ReduceSum()

    def construct(self, x):
        x_sum = self.reduce_sum(x)
        return not x_sum


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_simple_if():
    """x > y holds but y > z does not, so the input2 branch is taken."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    x = np.array(3).astype(np.float32)
    y = np.array(2).astype(np.float32)
    z = np.array(3).astype(np.float32)
    input_shape = (127, 7, 53, 31)
    input1 = np.random.randn(*input_shape).astype(np.float32)
    input2 = np.random.randn(*input_shape).astype(np.float32)
    net = ControlSimpleIf()
    output = net(Tensor(x), Tensor(y), Tensor(z), Tensor(input1), Tensor(input2))
    expect = input2 * 3 * 3 * 2 + input1
    assert np.allclose(expect, output.asnumpy(), rtol=0.0001, atol=0.0001)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_simple_if_with_assign():
    """x < y, so the Assign branch runs and its output equals the input."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    x = np.array(0).astype(np.float32)
    y = np.array(1).astype(np.float32)
    input_shape = (127, 7, 53, 31)
    input_data = np.random.randn(*input_shape).astype(np.float32)
    net = ControlSimpleIfWithAssign(input_shape)
    output = net(Tensor(x), Tensor(y), Tensor(input_data))
    expect = input_data
    assert np.allclose(expect, output.asnumpy(), rtol=0.0001, atol=0.0001)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if():
    """x > y and y >= 0, so the result is (x + 1) + (y + 2)."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    x = np.array(2.345678).astype(np.float32)
    y = np.array(1.234567).astype(np.float32)
    net = ControlIfinIf()
    output = net(Tensor(x), Tensor(y))
    expect = x + y + 3
    assert np.allclose(expect, output.asnumpy(), rtol=0.0001, atol=0.0001)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_by_if():
    """Branches taken: x > y (x2), cond1 True (x2), cond2 False (x2)."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    x = np.array(2.345678).astype(np.float32)
    y = np.array(1.234567).astype(np.float32)
    # `np.bool_` rather than the removed `np.bool` alias (AttributeError on
    # NumPy 1.24-1.26).
    cond1 = np.array(True).astype(np.bool_)
    cond2 = np.array(False).astype(np.bool_)
    input_shape = (127, 7, 53, 31)
    input_data = np.random.randn(*input_shape).astype(np.float32)
    net = ControlIfbyIfbyIf()
    output = net(Tensor(x), Tensor(y), Tensor(cond1), Tensor(cond2), Tensor(input_data))
    expect = input_data * 3 * 2 * 2 * 2
    assert np.allclose(expect, output.asnumpy(), rtol=0.0001, atol=0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_simple_while():
    """The loop body runs exactly once, because x is then replaced by False."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    # `np.bool_` rather than the removed `np.bool` alias (AttributeError on
    # NumPy 1.24-1.26).
    x = np.array(True).astype(np.bool_)
    y = np.array(False).astype(np.bool_)
    input_shape = (127, 7, 53, 31)
    input_data = np.random.randn(*input_shape).astype(np.float32)
    net = ControlSimpleWhile()
    output = net(Tensor(x), Tensor(y), Tensor(input_data))
    expect = input_data * 3
    assert np.allclose(expect, output.asnumpy(), rtol=0.0001, atol=0.0001)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mixed_while_if():
    """Check the accumulated sum of ControlMixedWhileIf.

    With x=2, c2=14, c4=0:
    - The first loop runs 12 times, adding y=14 and z=14 each time (336).
      Then x=14 is added (350).
    - The second loop runs 14 times, each adding 14 * 14 (inner z sums) plus
      y=14 (2940). Then x=28 is added.
    Total: 350 + 2940 + 28 = 3318.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    x = np.array(2).astype(np.int32)
    y = np.array(14).astype(np.int32)
    z = np.array(1).astype(np.int32)
    c2 = Tensor([14], mstype.int32)
    c4 = Tensor([0], mstype.int32)
    net = ControlMixedWhileIf()
    output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4)
    expect = np.array(3318).astype(np.int32)
    assert np.allclose(expect, output.asnumpy(), rtol=0.0001, atol=0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_and_or_operation():
    """Compare graph-mode and/or/not on reduced sums with NumPy's results."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    x = np.array([0, 1]).astype(np.float32)
    y = np.array([0, 0]).astype(np.float32)
    net = AndOperation()
    output = net(Tensor(x), Tensor(y))
    expect = np.sum(x) and np.sum(y)
    assert np.allclose(expect, output.asnumpy(), rtol=0.0001, atol=0.0001)

    net = OrOperation()
    output = net(Tensor(x), Tensor(y))
    expect = np.sum(x) or np.sum(y)
    assert np.allclose(expect, output.asnumpy(), rtol=0.0001, atol=0.0001)

    net = NotOperation()
    output = net(Tensor(x))
    expect = not np.sum(x)
    assert np.allclose(expect, output.asnumpy(), rtol=0.0001, atol=0.0001)