# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Graph-mode tests for how Cells bind their construct() arguments.

Covers default values, *varargs, keyword-only arguments, **kwargs, tuple
expansion at call sites, Parameters mixed with varargs, gradients over such
nets, and None inputs.
"""
import numpy as np

import mindspore.context as context
import mindspore.ops.composite as C
from mindspore import Tensor, Parameter
from mindspore import nn
from mindspore.nn import Cell
from mindspore.ops import operations as P

# Select graph mode globally for every Cell executed by the tests below.
context.set_context(mode=context.GRAPH_MODE)

# Shared gradient operators. grad_all (get_all=True) is not referenced by any
# test in this file; grad_all_with_sens (sens_param=True) is used by
# test_net_vargs_expand.
grad_all = C.GradOperation(get_all=True)
grad_all_with_sens = C.GradOperation(sens_param=True)


def test_parser_three_default_mixed_args_subnet():
    """
    Feature: Default arguments in a nested Cell's construct
    Description: NetOut calls SubNetDefaultMixedArgs with only the required
        positional argument, so x=3, x1=None and x2=(1, 2) take their
        defaults and the inner construct returns its input unchanged.
    Expectation: The output equals tensor1.
    """
    class SubNetDefaultMixedArgs(Cell):
        def __init__(self):
            super().__init__()

        def construct(self, y, x=3, x1=None, x2=(1, 2)):
            # x2 is never read; it only adds a tuple default to the signature.
            # NOTE(review): `x1 == None` (rather than `is None`) may be
            # deliberate, to exercise that comparison in graph mode; confirm
            # before changing it.
            if x == 3:
                if x1 == None:
                    return y
            return -y

    class NetOut(Cell):
        def __init__(self):
            super(NetOut, self).__init__()
            self.net_inside = SubNetDefaultMixedArgs()

        def construct(self, x, y=3):
            # y (bound to tensor2 by the caller) is accepted but unused; only x
            # is forwarded, so every default of the inner construct applies.
            z = self.net_inside(x)

            return z

    tensor1 = Tensor(np.full((2, 3), 2).astype(np.float32))
    tensor2 = Tensor(np.full((3, 2), 4).astype(np.float32))
    net = NetOut()
    assert np.all(net(tensor1, tensor2).asnumpy() == tensor1.asnumpy())


# SecondNet below intentionally declares defaulted parameters before *var.
# pylint: disable=keyword-arg-before-vararg
def test_net_vararg_kwonlyarg_kwarg():
    """
    Feature: Mixed positional, default, *varargs, keyword-only and **kwargs
        arguments
    Description: FirstNet is called with no inputs, so its defaults
        (x=1, z=8, y=3) are used, and it calls SecondNet with nine positional
        and four keyword arguments. SecondNet binds x=22, y=33, p=1, q=3,
        var=(8, 2, 3, 4, 5), key1=10, key2=20 and
        kwargs={"key3": 30, "key4": 40}.
    Expectation: No error is raised (the result is not checked).
    """
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            # SecondNet is defined later in this function; the name is looked
            # up when FirstNet() is instantiated, by which point it exists.
            self.net = SecondNet()

        def construct(self, x=1, z=2 + 2 + 4, y=3):
            # Positional order is x, y, z here, so SecondNet gets p=x, q=y and
            # var[0]=z; key3/key4 have no named parameter and land in **kwargs.
            c = self.net(22, 33, x, y, z, 2, 3, 4, 5, key1=10, key2=20, key3=30, key4=40)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()

        def construct(self, x, y=2, p=5, q=40, *var, key1=1, key2=3, **kwargs):
            a = x - y
            b = p * q
            c = a / b
            # Only the first four varargs are read; var[4] is bound but unused.
            d = var[0] * var[1] * var[2] * var[3]
            e = key1 - key2 - kwargs["key3"] + kwargs["key4"]
            return a + b + c + d + e

    net = FirstNet()
    net()


# SecondNet below intentionally declares defaulted parameters before *var.
# pylint: disable=keyword-arg-before-vararg
def test_net_vararg_normal_input():
    """
    Feature: Same argument mix as test_net_vararg_kwonlyarg_kwarg, with tensor
        inputs
    Description: FirstNet receives the same int32 tensor of shape (2, 3, 4)
        for x, z and y, so SecondNet gets tensors for p, q and var[0] while
        its other arguments remain Python constants.
    Expectation: No error is raised (the result is not checked).
    """
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            # SecondNet is defined later in this function; the name is looked
            # up when FirstNet() is instantiated, by which point it exists.
            self.net = SecondNet()

        def construct(self, x=1, z=2 + 2 + 4, y=3):
            # Positional order is x, y, z here, so SecondNet gets p=x, q=y and
            # var[0]=z; key3/key4 have no named parameter and land in **kwargs.
            c = self.net(22, 33, x, y, z, 2, 3, 4, 5, key1=10, key2=20, key3=30, key4=40)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()

        def construct(self, x, y=2, p=5, q=40, *var, key1=1, key2=3, **kwargs):
            a = x - y
            b = p * q
            c = a / b
            # Only the first four varargs are read; var[4] is bound but unused.
            d = var[0] * var[1] * var[2] * var[3]
            e = key1 - key2 - kwargs["key3"] + kwargs["key4"]
            return a + b + c + d + e

    x = Tensor(np.ones((2, 3, 4), np.int32))
    net = FirstNet()
    net(x, x, x)


def test_prim_vararg_kwonlyarg():
    """
    Feature: *args passed to primitives inside construct
    Description: FirstNet passes two tensors (read back from a dict) as x and
        y, two more as *args, and the keyword-only z and r. SecondNet uses the
        vararg tuple by indexing (Maximum), as a whole tuple (AddN) and by
        unpacking (Maximum(*args)).
    Expectation: No error is raised (the result is not checked).
    """
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.max = P.Maximum()
            self.min = P.Minimum()
            # SecondNet is defined later in this function; the name is looked
            # up when FirstNet() is instantiated, by which point it exists.
            self.net = SecondNet()
            self.x = Tensor(np.ones((2, 3, 4), np.float32))
            self.y = Tensor(np.ones((2, 3, 4), np.float32))

        def construct(self):
            a = self.max(self.x, self.y)
            b = self.min(self.x, self.y)
            # Values are routed through a dict to exercise dict lookup in graph
            # mode before being passed on.
            t = {"x": a, "y": b}
            c = self.net(t["x"], t["y"], a, b, z=a, r=b)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()
            self.addN = P.AddN()
            self.max = P.Maximum()
            # self.add is created but not used in construct.
            self.add = P.Add()

        def construct(self, x, y, *args, z=0, r=1):
            c = self.max(args[0], args[1])
            # The vararg tuple itself is passed as AddN's single input.
            d = self.addN(args)
            # The vararg tuple is unpacked into Maximum's two inputs.
            e = self.max(*args)
            ret = x + y + c + d + e + z + r
            return ret

    net = FirstNet()
    net()


def test_no_vararg():
    """
    Feature: Keyword-only arguments without *args
    Description: SecondNet.construct uses a bare `*` separator, so z and r can
        only be given by keyword; FirstNet passes both.
    Expectation: No error is raised (the result is not checked).
    """
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.max = P.Maximum()
            self.min = P.Minimum()
            # SecondNet is defined later in this function; the name is looked
            # up when FirstNet() is instantiated, by which point it exists.
            self.net = SecondNet()
            self.x = Tensor(np.ones((2, 3, 4), np.float32))
            self.y = Tensor(np.ones((2, 3, 4), np.float32))

        def construct(self):
            # NOTE(review): t is never used; possibly kept to exercise dict
            # creation in graph mode. Confirm before removing it.
            t = {"x": self.x, "y": self.y}
            a = self.max(self.x, self.y)
            b = self.min(self.x, self.y)
            c = self.net(a, b, z=a, r=b)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()

        def construct(self, x, y, *, z=0, r=1):
            ret = x + y + z + r
            return ret

    net = FirstNet()
    net()


def test_net_variable_and_weights():
    """
    Feature: *args combined with trainable Parameters
    Description: FirstNet.construct takes *args and reads only args[0]; it
        mixes its own Parameter w1 with constant tensors and expands a tuple
        into SecondNet's *args. SecondNet also adds its own Parameter w2.
    Expectation: No error is raised (the result is not checked).
    """
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.max = P.Maximum()
            self.min = P.Minimum()
            # SecondNet is defined later in this function; the name is looked
            # up when FirstNet() is instantiated, by which point it exists.
            self.net = SecondNet()
            self.x = Tensor(np.ones((3, 4), np.float32))
            self.y = Tensor(np.ones((3, 4), np.float32))
            self.weight = Parameter(Tensor(np.ones((2, 3, 4)).astype(np.float32)), "w1", requires_grad=True)

        def construct(self, *args):
            # Shapes differ: self.x/self.y are (3, 4), self.weight is
            # (2, 3, 4) and args[0] is (4,) in this test, so these element-wise
            # ops presumably rely on broadcasting.
            t = (self.x, self.y)
            a = self.max(self.x, self.weight)
            b = self.min(self.weight, args[0])
            # *t expands the (self.x, self.y) tuple into SecondNet's *args.
            c = self.net(a, b, *t)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()
            self.addN = P.AddN()
            self.max = P.Maximum()
            # self.add is created but not used in construct.
            self.add = P.Add()
            self.weight = Parameter(Tensor(np.ones((2, 3, 4), np.float32)), "w2", requires_grad=True)

        def construct(self, a, b, *args):
            c = self.max(args[0], a)
            # The vararg tuple itself is passed as AddN's single input.
            d = self.addN(args)
            ret = a + b + c + d + self.weight
            return ret

    net = FirstNet()
    # Only the first of these three inputs is read by FirstNet.construct.
    x = Tensor(np.ones((4,), np.float32))
    y = Tensor(np.ones((4,), np.float32))
    z = Tensor(np.ones((4,), np.float32))
    net(x, y, z)
def test_net_vargs_expand():
    """
    Feature: *inputs forwarded to a gradient operation
    Description: InputBackward.construct unpacks its *inputs into the
        GradOperation(sens_param=True) applied to AddNet. AddNet.construct
        takes two inputs, and the third value (sens) is passed as the
        sensitivity.
    Expectation: No error is raised (the gradient is not checked).
    """
    class InputBackward(Cell):
        """ InputBackward definition """

        def __init__(self, network, c1=None, c2=None):
            super(InputBackward, self).__init__()
            self.network = network
            self.network.set_train()
            # Module-level GradOperation(sens_param=True).
            self.grad = grad_all_with_sens
            # c1 and c2 are stored but not read by construct.
            self.c1 = c1
            self.c2 = c2

        def construct(self, *inputs):
            # Forward every received input, in order, to the gradient
            # function of the wrapped network.
            return self.grad(self.network)(*inputs)

    class AddNet(Cell):
        def __init__(self):
            super(AddNet, self).__init__()

        def construct(self, x, y):
            return x + y

    net = InputBackward(AddNet())
    x = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
    y = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
    sens = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))

    net.set_train()
    net(x, y, sens)


def test_mixed_precision_const_parameter():
    """
    Feature: Shape-derived constants passed through *args under the fp32 flag
    Description: NetMain passes the spatial sizes of x (28) and y (14) to
        NetLoss as *args. NetLoss compares args[0] with z's size (28), so the
        `args[0] == 28` branch is taken. The fp32 flag is set on the net with
        add_flags_recursive.
    Expectation: No error is raised (the result is not checked).
    """
    class NetLoss(Cell):
        def __init__(self):
            super(NetLoss, self).__init__()
            self.shape = P.Shape()
            self.up_sample1 = P.ResizeBilinear((14, 14))
            self.up_sample2 = P.ResizeBilinear((28, 28))
            self.up_sample3 = P.ResizeBilinear((36, 36))

        def construct(self, x, y, z, *args):
            # Only args[0] (NetMain's size_x) is read; args[1] is unused, as is
            # up_sample3.
            ret = 0
            if args[0] == self.shape(z)[2]:
                if args[0] == 14:
                    ret = self.up_sample1(y) + x
                elif args[0] == 28:
                    ret = self.up_sample2(y) - x
                else:
                    ret = x / y
            else:
                ret = x * y
            ret = ret * z
            return ret

    class NetMain(Cell):
        def __init__(self, loss_fn):
            super(NetMain, self).__init__()
            self.loss_fn = loss_fn
            self.shape = P.Shape()

        def construct(self, x, y, z):
            # Index 2 of each input's shape: 28 for x and 14 for y in this
            # test.
            size_x = self.shape(x)[2]
            size_y = self.shape(y)[2]
            ret = self.loss_fn(x, y, z, size_x, size_y)
            return ret

    loss_fn = NetLoss()
    net = NetMain(loss_fn)
    net.add_flags_recursive(fp32=True)
    x = Tensor(np.ones((1, 3, 28, 28), np.float32))
    y = Tensor(np.ones((1, 3, 14, 14), np.float32))
    z = Tensor(np.ones((1, 3, 28, 28), np.float32))
    _ = net(x, y, z)


def test_pass_args_by_key_ward_way():
    """
    Feature: Passing inputs by keyword
    Description: KeyWardNet is called with y and z given by keyword, in a
        different order than declared. GradNet (get_all=True,
        sens_param=True) is called with y, z and sens given by keyword.
    Expectation: No error is raised (results are not checked).
    """
    class KeyWardNet(Cell):
        def __init__(self):
            super(KeyWardNet, self).__init__()

        def construct(self, x, y, z):
            return x + y - z

    class GradNet(Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.grad = C.GradOperation(get_all=True, sens_param=True)
            self.net = net
            # self.sens is not read; construct receives sens as an input.
            self.sens = Tensor(np.ones((3, 3, 4), np.float32))

        def construct(self, x, y, z, sens):
            return self.grad(self.net)(x, y, z, sens)

    # x and y are (1, 3, 4) while z and sens are (3, 3, 4).
    x = Tensor(np.ones((1, 3, 4), np.float32))
    y = Tensor(np.ones((1, 3, 4), np.float32))
    z = Tensor(np.ones((3, 3, 4), np.float32))
    net = KeyWardNet()
    net(x, z=z, y=y)
    grad_net = GradNet(net)
    sens = Tensor(np.ones((3, 3, 4), np.float32))
    grad_net(x, y=y, z=z, sens=sens)


def test_none_input():
    """
    Feature: Net's inputs
    Description: Support none input for the outermost net
    Expectation: no error
    """

    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.op = nn.ResizeBilinear()

        def construct(self, a, b, c, d):
            # All four inputs are forwarded positionally to nn.ResizeBilinear.
            return self.op(a, b, c, d)

    x = Tensor(np.array([1, 2, 3, 4]).astype(np.float32).reshape((1, 1, 2, 2,)))
    net = Net()
    # b, c, d are a size tuple, None and True. NOTE(review): presumably these
    # map to nn.ResizeBilinear's size, scale_factor and align_corners; confirm
    # against its signature.
    net(x, (4, 4), None, True)