# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Utils for Cell related computation."""

# pylint: disable=missing-docstring

import numpy as np

from mindspore import ParameterTuple
from mindspore import nn, context
from mindspore.common.api import _cell_graph_executor, jit
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.composite import GradOperation
from . import keyword


def get_uniform_with_shape(shape):
    """Return a float32 array of ``shape`` drawn uniformly from [-0.1, 0.1).

    NumPy's global RNG is reseeded with 1 on every call, so calls with the same
    shape always return identical values. As a side effect, the global RNG
    state is reset.
    """
    np.random.seed(1)
    return np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)


def set_block_param_with_rand(net, rand_func=None):
    """Overwrite every trainable parameter of ``net`` with ``rand_func(shape)``.

    Does nothing unless ``net`` is an ``nn.Cell`` and ``rand_func`` is given.
    ``rand_func`` is called with each parameter's NumPy shape and must return
    data accepted by ``Tensor``.
    """
    if not isinstance(net, nn.Cell) or rand_func is None:
        return
    # Initialize parameter data before reading it back below.
    net.init_parameters_data()
    for param in net.trainable_params():
        param.set_data(Tensor(rand_func(param.data.asnumpy().shape)))


def compile_block(net, *inputs, rand_func=None, training=True):
    """Set train mode and optional random params, then compile ``net`` on ``inputs``.

    Returns whatever ``_cell_graph_executor.compile`` returns.
    """
    set_block_training(net, training)
    set_block_param_with_rand(net, rand_func)
    return _cell_graph_executor.compile(net, *inputs)


def run_block(net, *inputs, rand_func=None, training=True):
    """Set train mode and optional random params, then execute ``net(*inputs)``.

    In PyNative mode the call is wrapped in a ``jit``-decorated function.
    """
    set_block_training(net, training)
    set_block_param_with_rand(net, rand_func)
    if context.get_context("mode") == context.PYNATIVE_MODE:
        @jit
        def _func_pynative(*block_inputs):
            return net(*block_inputs)

        return _func_pynative(*inputs)
    return net(*inputs)


class IthOutputCell(nn.Cell):
    """Cell returning only output ``output_index`` of ``network``."""

    def __init__(self, network, output_index):
        # When wrapping a Cell, disable auto_prefix so this wrapper does not
        # prefix the wrapped cell's parameter names (see nn.Cell auto_prefix).
        if isinstance(network, nn.Cell):
            super(IthOutputCell, self).__init__(auto_prefix=False)
        else:
            super(IthOutputCell, self).__init__()
        self.network = network
        self.output_index = output_index

    def construct(self, *inputs):
        predict = self.network(*inputs)[self.output_index]
        return predict


def get_output_cell(network, num_input, output_index, training=True):
    """Build an ``IthOutputCell`` selecting ``output_index`` and set its train mode.

    ``num_input`` is accepted for interface compatibility and is ignored.
    """
    _ = num_input
    net = IthOutputCell(network, output_index)
    set_block_training(net, training)
    return net


class OutputReduceSumCell(nn.Cell):
    """Cell that applies ``ReduceSum`` to each of ``network``'s outputs.

    With ``output_num == 1`` the single reduced tensor is returned; otherwise
    a tuple of ``output_num`` reduced tensors is returned.
    """

    def __init__(self, network, output_num):
        super(OutputReduceSumCell, self).__init__()
        self.output_num = output_num
        self.network = network
        self.reduce_sum = P.ReduceSum()

    def construct(self, *inputs):
        # axis=None: presumably a reduction over all axes.
        if self.output_num == 1:
            return self.reduce_sum(self.network(*inputs), None)
        # Evaluate the network once and index its outputs, instead of
        # re-running it for every output index.
        outputs = self.network(*inputs)
        ret = F.make_tuple()
        for index in range(self.output_num):
            predict_reduce = self.reduce_sum(outputs[index], None)
            ret = ret + F.make_tuple(predict_reduce)
        return ret


def get_output_reduce_cell(network, output_num, training=True):
    """Build an ``OutputReduceSumCell`` around ``network`` and set its train mode."""
    net = OutputReduceSumCell(network, output_num)
    set_block_training(net, training)
    return net


class InputOpNet(nn.Cell):
    """Cell wrapping a primitive ``op`` plus up to four constants ``c1``..``c4``.

    ``construct`` is not usable directly. ``gen_net`` installs one of the
    variants below as ``construct``:

    * ``construct<N>_c<M>`` calls ``op`` with N inputs followed by M constants.
    * ``constructc<M>_<N>`` passes the constants first.
    * ``construct0_c<M>_fake`` calls ``op`` with the constants only and adds a
      single dummy input ``data`` to the result.
    """

    def __init__(self, op, c1=None, c2=None, c3=None, c4=None):
        super(InputOpNet, self).__init__()
        self.op = op
        self.c1 = c1
        self.c2 = c2
        self.c3 = c3
        self.c4 = c4

    def construct(self, *inputs):
        # Replaced per instance by gen_net.
        raise NotImplementedError

    def construct0_c0_fake(self, data):
        x = self.op() + data
        return x

    def construct0_c1_fake(self, data):
        x = self.op(self.c1) + data
        return x

    def construct0_c2_fake(self, data):
        x = self.op(self.c1, self.c2) + data
        return x

    def construct0_c3_fake(self, data):
        x = self.op(self.c1, self.c2, self.c3) + data
        return x

    def construct0_c0(self):
        x = self.op()
        return x

    def construct0_c1(self):
        x = self.op(self.c1)
        return x

    def construct0_c2(self):
        x = self.op(self.c1, self.c2)
        return x

    def construct1_c0(self, x1):
        x = self.op(x1)
        return x

    def construct1_c1(self, x1):
        x = self.op(x1, self.c1)
        return x

    def construct1_c2(self, x1):
        x = self.op(x1, self.c1, self.c2)
        return x

    def construct1_c3(self, x1):
        x = self.op(x1, self.c1, self.c2, self.c3)
        return x

    def construct1_c4(self, x1):
        x = self.op(x1, self.c1, self.c2, self.c3, self.c4)
        return x

    def constructc1_1(self, x1):
        x = self.op(self.c1, x1)
        return x

    def construct2_c0(self, x1, x2):
        x = self.op(x1, x2)
        return x

    def construct2_c1(self, x1, x2):
        x = self.op(x1, x2, self.c1)
        return x

    def construct2_c3(self, x1, x2):
        x = self.op(x1, x2, self.c1, self.c2, self.c3)
        return x

    def construct3_c0(self, x1, x2, x3):
        x = self.op(x1, x2, x3)
        return x

    def construct3_c1(self, x1, x2, x3):
        x = self.op(x1, x2, x3, self.c1)
        return x

    def construct4_c0(self, x1, x2, x3, x4):
        x = self.op(x1, x2, x3, x4)
        return x

    def construct4_c1(self, x1, x2, x3, x4):
        x = self.op(x1, x2, x3, x4, self.c1)
        return x

    def construct4_c2(self, x1, x2, x3, x4):
        x = self.op(x1, x2, x3, x4, self.c1, self.c2)
        return x

    def construct4_c4(self, x1, x2, x3, x4):
        x = self.op(x1, x2, x3, x4, self.c1, self.c2, self.c3, self.c4)
        return x

    def construct5_c0(self, x1, x2, x3, x4, x5):
        x = self.op(x1, x2, x3, x4, x5)
        return x

    def construct6_c0(self, x1, x2, x3, x4, x5, x6):
        x = self.op(x1, x2, x3, x4, x5, x6)
        return x

    def construct5_c1(self, x1, x2, x3, x4, x5):
        x = self.op(x1, x2, x3, x4, x5, self.c1)
        return x

    def construct5_c4(self, x1, x2, x3, x4, x5):
        x = self.op(x1, x2, x3, x4, x5, self.c1, self.c2, self.c3, self.c4)
        return x

    def construct7_c0(self, x1, x2, x3, x4, x5, x6, x7):
        x = self.op(x1, x2, x3, x4, x5, x6, x7)
        return x

    def construct8_c0(self, x1, x2, x3, x4, x5, x6, x7, x8):
        x = self.op(x1, x2, x3, x4, x5, x6, x7, x8)
        return x

    def construct9_c0(self, x1, x2, x3, x4, x5, x6, x7, x8, x9):
        x = self.op(x1, x2, x3, x4, x5, x6, x7, x8, x9)
        return x


def gen_net(op, input_num, training=True, desc_const=(), const_first=False, add_fake_input=False):
    """Wrap primitive ``op`` in an ``InputOpNet`` with a matching ``construct``.

    An ``nn.Cell`` ``op`` is returned unchanged, and ``training`` is not
    applied to it. Otherwise the variant is chosen from ``input_num`` inputs
    and ``len(desc_const)`` constants. It is ``constructc<M>_<N>`` when
    ``const_first`` is set, else ``construct<N>_c<M>``. ``_fake`` is appended
    when ``add_fake_input`` is set.

    Raises:
        AttributeError: if ``InputOpNet`` defines no variant for that combination.
    """
    if isinstance(op, nn.Cell):
        return op
    net = InputOpNet(op, *desc_const)
    if const_first:
        fn_name = 'constructc%d_%d' % (len(desc_const), input_num)
    else:
        fn_name = 'construct%d_c%d' % (input_num, len(desc_const))
    if add_fake_input:
        fn_name += '_fake'
    f = getattr(net, fn_name, None)
    if f is None:
        # Name the missing variant instead of surfacing a bare getattr failure.
        raise AttributeError("InputOpNet has no method '%s' (input_num=%s, %s constants, "
                             "const_first=%s, add_fake_input=%s)"
                             % (fn_name, input_num, len(desc_const), const_first, add_fake_input))
    # Install the chosen bound method as this instance's construct.
    setattr(net, "construct", f)
    set_block_training(net, training)
    return net


class OperationBackward(nn.Cell):
    """Cell computing ``grad_op(network)(*inputs, sens)`` with a fixed ``sens``."""

    def __init__(self, network, grad_op, sens):
        # auto_prefix=False for Cells so parameter names of the wrapped cell are not prefixed.
        if isinstance(network, nn.Cell):
            super(OperationBackward, self).__init__(auto_prefix=False)
        else:
            super(OperationBackward, self).__init__()
        self.network = network
        self.grad = grad_op
        self.sens = sens

    def construct(self, *inputs):
        return self.grad(self.network)(*inputs, self.sens)


class OperationBackwardWithNoSens(nn.Cell):
    """Cell computing ``grad_op(network)(*inputs)``, with no sens argument."""

    def __init__(self, network, grad_op):
        # auto_prefix=False for Cells so parameter names of the wrapped cell are not prefixed.
        if isinstance(network, nn.Cell):
            super(OperationBackwardWithNoSens, self).__init__(auto_prefix=False)
        else:
            super(OperationBackwardWithNoSens, self).__init__()
        self.network = network
        self.grad = grad_op

    def construct(self, *inputs):
        return self.grad(self.network)(*inputs)


class NNBackward(nn.Cell):
    """Cell computing ``grad_op(network, params)(*inputs, sens)``.

    ``params`` is a ``ParameterTuple`` of ``network.trainable_params()``.
    """

    def __init__(self, network, grad_op, sens):
        # auto_prefix=False for Cells so parameter names of the wrapped cell are not prefixed.
        if isinstance(network, nn.Cell):
            super(NNBackward, self).__init__(auto_prefix=False)
        else:
            super(NNBackward, self).__init__()
        self.network = network
        self.grad = grad_op
        self.params = ParameterTuple(network.trainable_params())
        self.sens = sens

    def construct(self, *inputs):
        return self.grad(self.network, self.params)(*inputs, self.sens)


class NNBackwardWithNoSens(nn.Cell):
    """Cell computing ``grad_op(network, params)(*inputs)``, with no sens argument.

    ``params`` is a ``ParameterTuple`` of ``network.trainable_params()``.
    """

    def __init__(self, network, grad_op):
        # auto_prefix=False for Cells so parameter names of the wrapped cell are not prefixed.
        if isinstance(network, nn.Cell):
            super(NNBackwardWithNoSens, self).__init__(auto_prefix=False)
        else:
            super(NNBackwardWithNoSens, self).__init__()
        self.network = network
        self.grad = grad_op
        self.params = ParameterTuple(network.trainable_params())

    def construct(self, *inputs):
        return self.grad(self.network, self.params)(*inputs)


def gen_grad_net(net, grad_op, input_num, sens=None, training=True, desc_const=(),
                 const_first=False, add_fake_input=False):
    """Wrap ``net`` in the backward cell that matches ``grad_op``'s flags.

    A non-Cell ``net`` is first wrapped via ``gen_net``. The wrapper is chosen
    as follows:

    * ``get_by_list`` set: ``NNBackward``, or ``NNBackwardWithNoSens`` when
      ``sens_param`` is unset.
    * ``get_by_list`` unset: ``OperationBackward``, or
      ``OperationBackwardWithNoSens`` when ``sens_param`` is unset.

    ``training`` is applied to the returned wrapper.
    """
    if not isinstance(net, nn.Cell):
        net = gen_net(net, input_num, desc_const=desc_const, const_first=const_first,
                      add_fake_input=add_fake_input)
    if grad_op.get_by_list:
        if grad_op.sens_param:
            net = NNBackward(net, grad_op, sens)
        else:
            net = NNBackwardWithNoSens(net, grad_op)
    else:
        if grad_op.sens_param:
            net = OperationBackward(net, grad_op, sens)
        else:
            net = OperationBackwardWithNoSens(net, grad_op)
    set_block_training(net, training)
    return net


def set_block_training(net, training=True):
    """Call ``net.set_train(training)`` if ``net`` is an ``nn.Cell``; otherwise do nothing."""
    if isinstance(net, nn.Cell):
        net.set_train(training)


def set_block_phase(net, phase='train'):
    """Assign ``net.phase = phase`` if ``net`` is an ``nn.Cell``; otherwise do nothing."""
    if isinstance(net, nn.Cell):
        net.phase = phase


def create_funcs(verification_set, block_generator, block_runner, grad_op=None, default_rand_func=None):
    """Replace each config's block with a callable that builds and runs it.

    ``verification_set[keyword.function]`` is iterated as a collection of
    config mappings. In each config, the value under ``keyword.block`` is
    replaced in place by a function. When that function is called with test
    inputs, it builds a net with ``block_generator`` and executes it with
    ``block_runner``. The mutated config collection is returned.

    Config keys read (default in parentheses):
        keyword.block: op or Cell under test (required).
        keyword.init_param_with: ``rand_func`` passed to ``block_runner``
            (``default_rand_func``).
        keyword.num_outputs: number of trailing inputs treated as sens values (0).
        keyword.desc_const: constant args forwarded to ``block_generator`` (empty).
        keyword.const_first: forwarded to ``block_generator`` (False).
        keyword.add_fake_input: forwarded to ``block_generator`` (False).
        keyword.split_outputs: differentiate each output separately (True).

    Args:
        grad_op: if truthy, the generated functions compute gradients and
            ``block_generator`` is called as
            ``block_generator(block, grad_op, input_num, [sens], **options)``.
            Otherwise they run the forward pass and it is called as
            ``block_generator(block, input_num, **options)``.
    """
    def create_func(block, num_outputs, rand_func, desc_const, const_first, add_fake_input, split_outputs):
        # Options forwarded unchanged to every block_generator call below.
        gen_kwargs = {
            'desc_const': desc_const,
            'const_first': const_first,
            'add_fake_input': add_fake_input,
        }

        def function(*inputs):
            # gradient
            if grad_op:
                if num_outputs == 0:
                    # No sens inputs: rebuild the grad op with sens_param disabled,
                    # keeping get_all / get_by_list from the supplied grad_op.
                    grad_op_ = GradOperation(get_all=grad_op.get_all,
                                             get_by_list=grad_op.get_by_list, sens_param=False)
                    b = block_generator(block, grad_op_, len(inputs), **gen_kwargs)
                    return block_runner(b, *inputs, rand_func=rand_func)
                if num_outputs == 1:
                    # The last input is the sens for the single output.
                    b = block_generator(block, grad_op, len(inputs) - 1, inputs[-1], **gen_kwargs)
                    return block_runner(b, *(inputs[:-1]), rand_func=rand_func)
                # The trailing `num_outputs` inputs are sens values, one per output.
                block_inputs = inputs[0:len(inputs) - num_outputs]
                sens_inputs = inputs[len(inputs) - num_outputs:]
                if split_outputs:
                    # Differentiate each output on its own and concatenate the
                    # resulting gradients into one flat list.
                    ret = []
                    for i in range(num_outputs):
                        bi = get_output_cell(block, len(block_inputs), i)
                        bi = block_generator(bi, grad_op, len(block_inputs), sens_inputs[i], **gen_kwargs)
                        grads_i = block_runner(bi, *block_inputs, rand_func=rand_func)
                        if isinstance(grads_i, tuple):
                            ret.extend(grads_i)
                        else:
                            ret.append(grads_i)
                    return ret
                # Differentiate all outputs at once, passing the sens values as a tuple.
                b = block_generator(block, grad_op, len(block_inputs), tuple(sens_inputs), **gen_kwargs)
                return block_runner(b, *block_inputs, rand_func=rand_func)
            # forward
            inputs_num = len(inputs)
            if add_fake_input and inputs_num == 1:
                # The single input is only the dummy operand of a `*_fake` construct.
                inputs_num = 0
            b = block_generator(block, inputs_num, **gen_kwargs)
            return block_runner(b, *inputs, rand_func=rand_func)

        return function

    bc_configs = verification_set[keyword.function]
    for config in bc_configs:
        block = config[keyword.block]
        rand_func = config.get(keyword.init_param_with, default_rand_func)
        num_outputs = config.get(keyword.num_outputs, 0)
        # Immutable default, matching gen_net's `desc_const=()`.
        desc_const = config.get(keyword.desc_const, ())
        const_first = config.get(keyword.const_first, False)
        add_fake_input = config.get(keyword.add_fake_input, False)
        split_outputs = config.get(keyword.split_outputs, True)
        config[keyword.block] = create_func(block, num_outputs, rand_func, desc_const,
                                            const_first, add_fake_input, split_outputs)
    return bc_configs