# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Utils for Cell related computation."""

# pylint: disable=missing-docstring

import numpy as np

from mindspore import ParameterTuple
from mindspore import nn, context
from mindspore.common.api import _cell_graph_executor, ms_function
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.composite import GradOperation
from . import keyword


def get_uniform_with_shape(shape):
    """Return a float32 array of ``shape`` sampled uniformly from [-0.1, 0.1).

    NumPy's global RNG is reseeded with 1 on every call, so the values are
    deterministic for a given shape. Note that this also resets the global
    RNG state as a side effect.
    """
    np.random.seed(1)
    return np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)


def set_block_param_with_rand(net, rand_func=None):
    """Overwrite every trainable parameter of ``net`` with ``rand_func(shape)``.

    Does nothing when ``net`` is not an ``nn.Cell`` or ``rand_func`` is None.
    ``net.init_parameters_data()`` is called first so parameter data exists
    before it is replaced.
    """
    if not isinstance(net, nn.Cell) or rand_func is None:
        return
    net.init_parameters_data()
    for param in net.trainable_params():
        param.set_data(Tensor(rand_func(param.data.asnumpy().shape)))


def compile_block(net, *inputs, rand_func=None, training=True):
    """Prepare ``net`` and compile it for ``inputs`` with the graph executor.

    Sets the training flag, optionally re-initialises trainable parameters
    via ``rand_func``, and returns the result of
    ``_cell_graph_executor.compile``.
    """
    set_block_training(net, training)
    set_block_param_with_rand(net, rand_func)
    return _cell_graph_executor.compile(net, *inputs)


def run_block(net, *inputs, rand_func=None, training=True):
    """Prepare ``net`` and execute it on ``inputs``, returning its output.

    In PyNative mode the call is wrapped in ``ms_function``; in any other
    mode ``net`` is called directly.
    """
    set_block_training(net, training)
    set_block_param_with_rand(net, rand_func)
    if context.get_context("mode") == context.PYNATIVE_MODE:
        @ms_function
        def _func_pynative(*inputs):
            return net(*inputs)

        return _func_pynative(*inputs)
    return net(*inputs)


class IthOutputCell(nn.Cell):
    """Cell that returns only output ``output_index`` of ``network``."""

    def __init__(self, network, output_index):
        if isinstance(network, nn.Cell):
            # Presumably keeps the wrapped cell's parameter names free of an
            # extra prefix from this wrapper.
            super(IthOutputCell, self).__init__(auto_prefix=False)
        else:
            super(IthOutputCell, self).__init__()
        self.network = network
        self.output_index = output_index

    def construct(self, *inputs):
        predict = self.network(*inputs)[self.output_index]
        return predict


def get_output_cell(network, num_input, output_index, training=True):
    """Wrap ``network`` in an ``IthOutputCell`` and set its training flag.

    ``num_input`` is unused; it is kept for interface compatibility.
    """
    _ = num_input
    net = IthOutputCell(network, output_index)
    set_block_training(net, training)
    return net


class OutputReduceSumCell(nn.Cell):
    """Cell that applies ``ReduceSum`` (axis argument None) to each output of ``network``.

    Returns a single reduced tensor when ``output_num == 1``; otherwise a
    tuple with one reduced tensor per output index in ``range(output_num)``.
    """

    def __init__(self, network, output_num):
        super(OutputReduceSumCell, self).__init__()
        self.output_num = output_num
        self.network = network
        self.reduce_sum = P.ReduceSum()

    def construct(self, *inputs):
        # Run the wrapped network once and index its outputs, instead of
        # re-running it for every output index (which also made stochastic
        # networks reduce outputs taken from different forward passes).
        outputs = self.network(*inputs)
        if self.output_num == 1:
            return self.reduce_sum(outputs, None)
        ret = F.make_tuple()
        for index in range(self.output_num):
            predict_reduce = self.reduce_sum(outputs[index], None)
            ret = ret + F.make_tuple(predict_reduce)
        return ret


def get_output_reduce_cell(network, output_num, training=True):
    """Wrap ``network`` in an ``OutputReduceSumCell`` and set its training flag."""
    net = OutputReduceSumCell(network, output_num)
    set_block_training(net, training)
    return net


class InputOpNet(nn.Cell):
    """Cell wrapping a primitive ``op`` plus up to four constant arguments.

    ``construct`` itself is not implemented; ``gen_net`` installs one of the
    variants below as the instance's ``construct``. Naming scheme:

    * ``construct<N>_c<K>``: call ``op`` with N runtime inputs followed by the
      constants ``c1``..``cK``.
    * ``constructc<K>_<N>``: constants first, then the runtime inputs.
    * ``..._fake``: call ``op`` with constants only, then add the single
      runtime ``data`` input to the result.
    """

    def __init__(self, op, c1=None, c2=None, c3=None, c4=None):
        super(InputOpNet, self).__init__()
        self.op = op
        self.c1 = c1
        self.c2 = c2
        self.c3 = c3
        self.c4 = c4

    def construct(self, *inputs):
        raise NotImplementedError

    def construct0_c0_fake(self, data):
        x = self.op() + data
        return x

    def construct0_c1_fake(self, data):
        x = self.op(self.c1) + data
        return x

    def construct0_c2_fake(self, data):
        x = self.op(self.c1, self.c2) + data
        return x

    def construct0_c3_fake(self, data):
        x = self.op(self.c1, self.c2, self.c3) + data
        return x

    def construct0_c0(self):
        x = self.op()
        return x

    def construct0_c1(self):
        x = self.op(self.c1)
        return x

    def construct0_c2(self):
        x = self.op(self.c1, self.c2)
        return x

    def construct1_c0(self, x1):
        x = self.op(x1)
        return x

    def construct1_c1(self, x1):
        x = self.op(x1, self.c1)
        return x

    def construct1_c2(self, x1):
        x = self.op(x1, self.c1, self.c2)
        return x

    def construct1_c3(self, x1):
        x = self.op(x1, self.c1, self.c2, self.c3)
        return x

    def construct1_c4(self, x1):
        x = self.op(x1, self.c1, self.c2, self.c3, self.c4)
        return x

    def constructc1_1(self, x1):
        x = self.op(self.c1, x1)
        return x

    def construct2_c0(self, x1, x2):
        x = self.op(x1, x2)
        return x

    def construct2_c1(self, x1, x2):
        x = self.op(x1, x2, self.c1)
        return x

    def construct2_c3(self, x1, x2):
        x = self.op(x1, x2, self.c1, self.c2, self.c3)
        return x

    def construct3_c0(self, x1, x2, x3):
        x = self.op(x1, x2, x3)
        return x

    def construct3_c1(self, x1, x2, x3):
        x = self.op(x1, x2, x3, self.c1)
        return x

    def construct4_c0(self, x1, x2, x3, x4):
        x = self.op(x1, x2, x3, x4)
        return x

    def construct4_c1(self, x1, x2, x3, x4):
        x = self.op(x1, x2, x3, x4, self.c1)
        return x

    def construct4_c2(self, x1, x2, x3, x4):
        x = self.op(x1, x2, x3, x4, self.c1, self.c2)
        return x

    def construct4_c4(self, x1, x2, x3, x4):
        x = self.op(x1, x2, x3, x4, self.c1, self.c2, self.c3, self.c4)
        return x

    def construct5_c0(self, x1, x2, x3, x4, x5):
        x = self.op(x1, x2, x3, x4, x5)
        return x

    def construct6_c0(self, x1, x2, x3, x4, x5, x6):
        x = self.op(x1, x2, x3, x4, x5, x6)
        return x

    def construct5_c1(self, x1, x2, x3, x4, x5):
        x = self.op(x1, x2, x3, x4, x5, self.c1)
        return x

    def construct5_c4(self, x1, x2, x3, x4, x5):
        x = self.op(x1, x2, x3, x4, x5, self.c1, self.c2, self.c3, self.c4)
        return x


def gen_net(op, input_num, training=True, desc_const=(), const_first=False, add_fake_input=False):
    """Build an ``InputOpNet`` around ``op`` with a matching ``construct`` variant.

    An ``op`` that is already an ``nn.Cell`` is returned unchanged (its training
    flag is not touched). Otherwise the variant name is derived from
    ``input_num``, ``len(desc_const)``, ``const_first`` and ``add_fake_input``
    (see ``InputOpNet``).

    Raises:
        AttributeError: if ``InputOpNet`` has no variant with the derived name.
    """
    if isinstance(op, nn.Cell):
        return op
    net = InputOpNet(op, *desc_const)
    if const_first:
        fn_name = 'constructc%d_%d' % (len(desc_const), input_num)
    else:
        fn_name = 'construct%d_c%d' % (input_num, len(desc_const))
    if add_fake_input:
        fn_name += '_fake'
    f = getattr(net, fn_name)
    setattr(net, "construct", f)
    set_block_training(net, training)
    return net


class OperationBackward(nn.Cell):
    """Compute ``grad_op(network)`` on the inputs, appending ``sens`` as the last argument."""

    def __init__(self, network, grad_op, sens):
        if isinstance(network, nn.Cell):
            super(OperationBackward, self).__init__(auto_prefix=False)
        else:
            super(OperationBackward, self).__init__()
        self.network = network
        self.grad = grad_op
        self.sens = sens

    def construct(self, *inputs):
        return self.grad(self.network)(*inputs, self.sens)


class OperationBackwardWithNoSens(nn.Cell):
    """Compute ``grad_op(network)`` on the inputs without a sensitivity argument."""

    def __init__(self, network, grad_op):
        if isinstance(network, nn.Cell):
            super(OperationBackwardWithNoSens, self).__init__(auto_prefix=False)
        else:
            super(OperationBackwardWithNoSens, self).__init__()
        self.network = network
        self.grad = grad_op

    def construct(self, *inputs):
        return self.grad(self.network)(*inputs)


class NNBackward(nn.Cell):
    """Compute ``grad_op(network, params)`` over the trainable params, with ``sens`` appended."""

    def __init__(self, network, grad_op, sens):
        if isinstance(network, nn.Cell):
            super(NNBackward, self).__init__(auto_prefix=False)
        else:
            super(NNBackward, self).__init__()
        self.network = network
        self.grad = grad_op
        self.params = ParameterTuple(network.trainable_params())
        self.sens = sens

    def construct(self, *inputs):
        return self.grad(self.network, self.params)(*inputs, self.sens)


class NNBackwardWithNoSens(nn.Cell):
    """Compute ``grad_op(network, params)`` over the trainable params, without sensitivity."""

    def __init__(self, network, grad_op):
        if isinstance(network, nn.Cell):
            super(NNBackwardWithNoSens, self).__init__(auto_prefix=False)
        else:
            super(NNBackwardWithNoSens, self).__init__()
        self.network = network
        self.grad = grad_op
        self.params = ParameterTuple(network.trainable_params())

    def construct(self, *inputs):
        return self.grad(self.network, self.params)(*inputs)


def gen_grad_net(net, grad_op, input_num, sens=None, training=True, desc_const=(),
                 const_first=False, add_fake_input=False):
    """Build a gradient cell for ``net`` according to ``grad_op``'s flags.

    A non-Cell ``net`` is first wrapped with ``gen_net``. The wrapper class is
    chosen by ``grad_op.get_by_list`` (gradients w.r.t. trainable params:
    ``NNBackward*``; otherwise ``OperationBackward*``) and ``grad_op.sens_param``
    (whether ``sens`` is passed as the last argument).
    """
    if not isinstance(net, nn.Cell):
        net = gen_net(net, input_num, desc_const=desc_const, const_first=const_first, add_fake_input=add_fake_input)
    if grad_op.get_by_list:
        if grad_op.sens_param:
            net = NNBackward(net, grad_op, sens)
        else:
            net = NNBackwardWithNoSens(net, grad_op)
    else:
        if grad_op.sens_param:
            net = OperationBackward(net, grad_op, sens)
        else:
            net = OperationBackwardWithNoSens(net, grad_op)
    set_block_training(net, training)
    return net


def set_block_training(net, training=True):
    """Call ``net.set_train(training)`` when ``net`` is an ``nn.Cell``; otherwise do nothing."""
    if isinstance(net, nn.Cell):
        net.set_train(training)


def set_block_phase(net, phase='train'):
    """Assign ``net.phase = phase`` when ``net`` is an ``nn.Cell``; otherwise do nothing."""
    if isinstance(net, nn.Cell):
        net.phase = phase


def create_funcs(verification_set, block_generator, block_runner, grad_op=None, default_rand_func=None):
    """Replace each config's block with a callable that builds and runs it.

    For every config in ``verification_set[keyword.function]``,
    ``config[keyword.block]`` is replaced (in place) by a function taking the
    test inputs. That function builds a cell via ``block_generator`` and
    executes it via ``block_runner``. When ``grad_op`` is given, a gradient
    cell is built; otherwise a forward cell is built.

    Returns:
        The (mutated) list of configs.
    """
    def create_func(block, num_outputs, rand_func, desc_const, const_first, add_fake_input, split_outputs):
        def function(*inputs):
            # gradient
            if grad_op:
                if num_outputs == 0:
                    # No sensitivity inputs: use a copy of grad_op without sens_param.
                    grad_op_ = GradOperation(get_all=grad_op.get_all,
                                             get_by_list=grad_op.get_by_list, sens_param=False)
                    b = block_generator(block, grad_op_, len(inputs), desc_const=desc_const,
                                        const_first=const_first, add_fake_input=add_fake_input)
                    return block_runner(b, *inputs, rand_func=rand_func)
                if num_outputs == 1:
                    # The last input is the sensitivity of the single output.
                    b = block_generator(block, grad_op, len(inputs) - 1, inputs[-1], desc_const=desc_const,
                                        const_first=const_first, add_fake_input=add_fake_input)
                    return block_runner(b, *(inputs[:-1]), rand_func=rand_func)
                # The trailing num_outputs inputs are sensitivities, one per output.
                block_inputs = inputs[0:len(inputs) - num_outputs]
                sens_inputs = tuple(inputs[len(inputs) - num_outputs:])
                if split_outputs:
                    # One gradient run per output; results are flattened into one list.
                    ret = []
                    for i in range(num_outputs):
                        bi = get_output_cell(block, len(block_inputs), i)
                        # NOTE(review): bi is already a Cell, so block_generator will
                        # not apply desc_const/const_first/add_fake_input to a
                        # primitive block on this path -- confirm this is intended.
                        bi = block_generator(bi, grad_op, len(block_inputs), sens_inputs[i], desc_const=desc_const,
                                             const_first=const_first, add_fake_input=add_fake_input)
                        grads_i = block_runner(bi, *block_inputs, rand_func=rand_func)
                        if isinstance(grads_i, tuple):
                            ret.extend(grads_i)
                        else:
                            ret.append(grads_i)
                    return ret
                # All sensitivities are passed together as one tuple.
                b = block_generator(block, grad_op, len(block_inputs), sens_inputs, desc_const=desc_const,
                                    const_first=const_first, add_fake_input=add_fake_input)
                return block_runner(b, *block_inputs, rand_func=rand_func)
            # forward
            inputs_num = len(inputs)
            if add_fake_input and inputs_num == 1:
                # input is faked: pick the `_fake` variant built for zero real inputs
                inputs_num = 0
            b = block_generator(block, inputs_num, desc_const=desc_const, const_first=const_first,
                                add_fake_input=add_fake_input)
            return block_runner(b, *inputs, rand_func=rand_func)

        return function

    bc_configs = verification_set[keyword.function]
    for config in bc_configs:
        block = config[keyword.block]
        rand_func = config.get(keyword.init_param_with, default_rand_func)
        num_outputs = config.get(keyword.num_outputs, 0)
        desc_const = config.get(keyword.desc_const, [])
        const_first = config.get(keyword.const_first, False)
        add_fake_input = config.get(keyword.add_fake_input, False)
        split_outputs = config.get(keyword.split_outputs, True)
        config[keyword.block] = create_func(block, num_outputs, rand_func, desc_const,
                                            const_first, add_fake_input, split_outputs)
    return bc_configs