# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test ops """
import numpy as np

import mindspore.nn as nn
import mindspore.ops.composite as C
import mindspore.ops.functional as F
import mindspore.ops.operations as P
from mindspore import Tensor
from mindspore.common.api import _cell_graph_executor


# Gradient op that differentiates w.r.t. all inputs and takes an explicit
# sensitivity (output gradient) as the last call argument.
grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)


class InputBackward(nn.Cell):
    """Wraps ``network`` so calling it computes input gradients.

    The actual entry point is installed later by ``gen_backward_net``,
    which rebinds ``construct`` to the ``constructN`` variant matching the
    number of forward inputs (the trailing argument is always the
    sensitivity tensor).
    """

    def __init__(self, network, c1=None, c2=None):
        super(InputBackward, self).__init__()
        self.network = network
        self.network.set_train()
        self.grad = grad_all_with_sens
        # c1/c2 are kept for interface parity with other wrappers; the
        # construct variants below do not read them.
        self.c1 = c1
        self.c2 = c2

    def construct(self, *inputs):
        # Placeholder; replaced via setattr in gen_backward_net.
        pass

    def construct1(self, x1, sens):
        return self.grad(self.network)(x1, sens)

    def construct2(self, x1, x2, sens):
        return self.grad(self.network)(x1, x2, sens)

    def construct3(self, x1, x2, x3, sens):
        return self.grad(self.network)(x1, x2, x3, sens)

    def construct4(self, x1, x2, x3, x4, sens):
        return self.grad(self.network)(x1, x2, x3, x4, sens)

    def construct5(self, x1, x2, x3, x4, x5, sens):
        return self.grad(self.network)(x1, x2, x3, x4, x5, sens)

    def construct6(self, x1, x2, x3, x4, x5, x6, sens):
        return self.grad(self.network)(x1, x2, x3, x4, x5, x6, sens)

    def construct7(self, x1, x2, x3, x4, x5, x6, x7, sens):
        return self.grad(self.network)(x1, x2, x3, x4, x5, x6, x7, sens)


class InputOpNet(nn.Cell):
    """Calls a primitive ``op`` with a configurable mix of tensor inputs
    and constant attributes.

    Method naming convention: ``constructN_cM`` takes N tensor inputs and
    appends M stored constants; the ``_fack`` suffix adds a fake data
    input so graphs with no real inputs can still be compiled;
    ``constructcM_N`` passes the constants *before* the tensor inputs.
    ``gen_net`` selects the right variant and installs it as ``construct``.
    """

    def __init__(self, op, get_first=False,
                 c1=None, c2=None, c3=None, c4=None):
        super(InputOpNet, self).__init__()
        self.op = op
        # When True, only the first element of a tuple output is returned.
        self.get_first = get_first
        self.c1 = c1
        self.c2 = c2
        self.c3 = c3
        self.c4 = c4

    def construct(self, *inputs):
        # Placeholder; replaced via setattr in gen_net.
        pass

    def construct0_c0_fack(self, data):
        x = self.op() + data
        if self.get_first:
            x = x[0]
        return x

    def construct0_c1_fack(self, data):
        x = self.op(self.c1) + data
        if self.get_first:
            x = x[0]
        return x

    def construct0_c2_fack(self, data):
        x = self.op(self.c1, self.c2) + data
        if self.get_first:
            x = x[0]
        return x

    def construct0_c0(self):
        x = self.op()
        if self.get_first:
            x = x[0]
        return x

    def construct0_c1(self):
        x = self.op(self.c1)
        if self.get_first:
            x = x[0]
        return x

    def construct0_c2(self):
        x = self.op(self.c1, self.c2)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c0(self, x1):
        x = self.op(x1)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c1(self, x1):
        x = self.op(x1, self.c1)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c2(self, x1):
        x = self.op(x1, self.c1, self.c2)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c3(self, x1):
        x = self.op(x1, self.c1, self.c2, self.c3)
        if self.get_first:
            x = x[0]
        return x

    def construct1_c4(self, x1):
        x = self.op(x1, self.c1, self.c2, self.c3, self.c4)
        if self.get_first:
            x = x[0]
        return x

    def constructc1_1(self, x1):
        # Constant-first variant: constant precedes the tensor input.
        x = self.op(self.c1, x1)
        if self.get_first:
            x = x[0]
        return x

    def construct2_c0(self, x1, x2):
        x = self.op(x1, x2)
        if self.get_first:
            x = x[0]
        return x

    def construct2_c1(self, x1, x2):
        x = self.op(x1, x2, self.c1)
        if self.get_first:
            x = x[0]
        return x

    def construct2_c3(self, x1, x2):
        x = self.op(x1, x2, self.c1, self.c2, self.c3)
        if self.get_first:
            x = x[0]
        return x

    def construct3_c0(self, x1, x2, x3):
        x = self.op(x1, x2, x3)
        if self.get_first:
            x = x[0]
        return x

    def construct3_c1(self, x1, x2, x3):
        x = self.op(x1, x2, x3, self.c1)
        if self.get_first:
            x = x[0]
        return x

    def construct4_c0(self, x1, x2, x3, x4):
        x = self.op(x1, x2, x3, x4)
        if self.get_first:
            x = x[0]
        return x

    def construct4_c1(self, x1, x2, x3, x4):
        x = self.op(x1, x2, x3, x4, self.c1)
        if self.get_first:
            x = x[0]
        return x

    def construct5_c0(self, x1, x2, x3, x4, x5):
        x = self.op(x1, x2, x3, x4, x5)
        if self.get_first:
            x = x[0]
        return x

    def construct6_c0(self, x1, x2, x3, x4, x5, x6):
        x = self.op(x1, x2, x3, x4, x5, x6)
        if self.get_first:
            x = x[0]
        return x

    def construct5_c1(self, x1, x2, x3, x4, x5):
        x = self.op(x1, x2, x3, x4, x5, self.c1)
        if self.get_first:
            x = x[0]
        return x


class NetOutputAsLoss(nn.Cell):
    """Turns one element of a multi-output network into a scalar-like loss
    by indexing the network's output tuple at ``output_index``.

    ``get_loss_fun`` installs the ``constructN`` variant matching the
    number of network inputs as ``construct``.
    """

    def __init__(self, network, output_index):
        super(NetOutputAsLoss, self).__init__()
        self.network = network
        self.output_index = output_index

    def construct(self, *inputs):
        # Placeholder; replaced via setattr in get_loss_fun.
        pass

    def construct1(self, x1):
        predict = self.network(x1)[self.output_index]
        return predict

    def construct2(self, x1, x2):
        predict = self.network(x1, x2)[self.output_index]
        return predict

    def construct3(self, x1, x2, x3):
        predict = self.network(x1, x2, x3)[self.output_index]
        return predict

    def construct4(self, x1, x2, x3, x4):
        predict = self.network(x1, x2, x3, x4)[self.output_index]
        return predict

    def construct5(self, x1, x2, x3, x4, x5):
        predict = self.network(x1, x2, x3, x4, x5)[self.output_index]
        return predict


def get_loss_fun(construct_net, num_input, output_index):
    """Wrap ``construct_net`` as a loss selecting output ``output_index``,
    bound to the variant taking ``num_input`` inputs."""
    net = NetOutputAsLoss(construct_net, output_index)
    f = getattr(net, 'construct%d' % num_input)
    setattr(net, "construct", f)
    return net


def build_construct_graph(net, *inputs, execute=True):
    """Compile the forward graph of ``net``; optionally execute it."""
    net.set_train()
    _cell_graph_executor.compile(net, *inputs)
    if execute:
        # NOTE(review): inputs is passed as a single tuple here (not
        # unpacked) — confirm against the executor's __call__ contract.
        _cell_graph_executor(net, inputs)


def build_backward_graph(net, output_shapes, inputs, execute=True):
    """Compile the backward graph of ``net``; optionally execute it.

    A sensitivity tensor shaped like the output is appended to ``inputs``
    before wrapping the network for gradient computation.
    """
    inputs = append_sens_to_inputs(output_shapes, inputs)
    net = gen_backward_net(net, len(inputs) - 1)
    net.set_train()
    # Fix: unpack the input list so compile receives individual tensors,
    # consistent with build_construct_graph above.
    _cell_graph_executor.compile(net, *inputs)
    if execute:
        _cell_graph_executor(net, inputs)


def convert(shp, dtype=np.float32, scale=6):
    """Convert a shape (list) into a random Tensor scaled by ``scale``;
    non-list values are passed through unchanged (pre-built constants)."""
    if isinstance(shp, list):
        if not shp:
            # Empty shape -> random scalar tensor.
            return Tensor((np.random.rand() * scale).astype(dtype))
        return Tensor((np.random.rand(*shp) * scale).astype(dtype))
    return shp


def gen_inputs(input_shapes, config):
    """Build the forward-input tensors described by ``input_shapes``.

    If there are no real inputs and the config requests it, a single fake
    input is produced so the graph still has a data dependency.
    """
    add_fack_input = config.get('add_fack_input', False)
    if not input_shapes and add_fack_input:
        return [Tensor(np.array([1.0]).astype(config.get('fack_input_type', np.float32)))]
    return [convert(shp) for shp in input_shapes]


def gen_backward_inputs(input_shapes, output_shapes, config):
    """Build backward-pass inputs: the forward inputs plus a sensitivity
    tensor shaped like the first output."""
    add_fack_input = config.get('add_fack_input', False)
    if not input_shapes and add_fack_input:
        inputs = [Tensor(np.array([1.0]))]
    else:
        inputs = [convert(shp) for shp in input_shapes]
    sens_shape = output_shapes[0]
    sens = convert(sens_shape)
    return inputs + [sens]


def append_sens_to_inputs(output_shapes, inputs):
    """Return ``inputs`` extended with a normally-distributed sensitivity
    tensor of shape ``output_shapes``."""
    sens = Tensor(np.random.normal(0, 1, output_shapes).astype(np.float32))
    return inputs + [sens]


def gen_net(shapes, config, get_first=False):
    """Build an ``InputOpNet`` for ``config['op']`` and bind the construct
    variant implied by the number of shapes/constants in ``config``.

    Recognized config keys: ``op``, ``const`` (constant args), ``const_first``
    (constants precede tensors), ``add_fack_input`` (use the ``_fack``
    variant).
    """
    add_fack_input = config.get('add_fack_input', False)
    op = config['op']
    if 'const' not in config:
        const_input = []
    else:
        const_input = config['const']
    const_first = False
    if 'const_first' in config:
        const_first = config['const_first']

    net = InputOpNet(op, get_first, *const_input)
    if const_first:
        fn_name = 'constructc%d_%d' % (len(const_input), len(shapes))
    else:
        fn_name = 'construct%d_c%d' % (len(shapes), len(const_input))
    if add_fack_input:
        fn_name += '_fack'
    f = getattr(net, fn_name)
    setattr(net, "construct", f)
    return net


def gen_backward_net(construct_net, input_num):
    """Wrap ``construct_net`` in ``InputBackward`` bound to the variant
    taking ``input_num`` forward inputs (plus the sensitivity)."""
    net = InputBackward(construct_net)
    f = getattr(net, 'construct%d' % input_num)
    setattr(net, "construct", f)
    return net


def batch_tuple_tensor(data, batch_size):
    """Tile each tensor in ``data`` along a new leading batch dimension."""
    ret = [Tensor(np.tile(d.asnumpy(), (batch_size, 1))) for d in data]
    return tuple(ret)


class OutPutWrap(nn.Cell):
    """Normalizes a network's output into a tuple of tensors.

    Each output is multiplied by a cast of 1 (an identity in value) so
    every element participates in the graph. Single non-tuple outputs are
    returned directly. ``get_output_wrap`` installs the ``constructN``
    variant matching the number of inputs.
    """

    def __init__(self, network, num_output, output_is_tuple):
        super(OutPutWrap, self).__init__()
        self.network = network
        self.num_output = num_output
        self.one = Tensor(np.array([1]))
        self.dtype = P.DType()
        self.cast = P.Cast()
        self.output_is_tuple = output_is_tuple

    def construct(self, *inputs):
        # Placeholder; replaced via setattr in get_output_wrap.
        pass

    def construct1(self, x1):
        ret = F.make_tuple()
        predict = self.network(x1)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret

    def construct2(self, x1, x2):
        ret = F.make_tuple()
        predict = self.network(x1, x2)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret

    def construct3(self, x1, x2, x3):
        ret = F.make_tuple()
        predict = self.network(x1, x2, x3)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret

    def construct4(self, x1, x2, x3, x4):
        ret = F.make_tuple()
        predict = self.network(x1, x2, x3, x4)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret

    def construct5(self, x1, x2, x3, x4, x5):
        ret = F.make_tuple()
        predict = self.network(x1, x2, x3, x4, x5)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret

    def construct6(self, x1, x2, x3, x4, x5, x6):
        ret = F.make_tuple()
        predict = self.network(x1, x2, x3, x4, x5, x6)
        if self.num_output == 1 and self.output_is_tuple == 0:
            return predict * self.cast(self.one, self.dtype(predict))
        for i in range(self.num_output):
            ret = ret + F.make_tuple(predict[i] * self.cast(self.one, self.dtype(predict[i])))
        return ret


def get_output_wrap(network, num_input, num_output, output_is_tuple=0):
    """Wrap ``network`` in ``OutPutWrap`` bound to the variant taking
    ``num_input`` inputs and producing ``num_output`` outputs."""
    net = OutPutWrap(network, num_output, output_is_tuple)
    f = getattr(net, 'construct%d' % num_input)
    setattr(net, "construct", f)
    return net