# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parameter init."""
import math
from functools import reduce
import numpy as np
from mindspore.common import initializer as init
from mindspore.common.initializer import Initializer as MeInitializer
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.nn as nn
from .util import load_backbone


def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    ================= ====================================================

    Args:
        nonlinearity: name of the non-linear function, e.g. ``'relu'``.
        param: optional parameter for the non-linear function. Only used
            for ``'leaky_relu'``, where it is the negative slope
            (0.01 when omitted).

    Returns:
        The gain as a number.

    Raises:
        ValueError: if `nonlinearity` is not supported, or if `param` is
            neither ``None`` nor an int/float for ``'leaky_relu'``.

    Examples:
        >>> gain = calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d',
                  'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif isinstance(param, (int, float)) and not isinstance(param, bool):
            # True/False are instances of int, hence the explicit bool exclusion.
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))

    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))


def _assignment(arr, num):
    """Write the value(s) of `num` into `arr` and return the written array.

    Args:
        arr: destination numpy array; modified in place.
        num: scalar or array-like assigned to every position of `arr`.

    Returns:
        `arr` for arrays with at least one dimension. For a 0-d array, a
        0-d view onto the same, now updated, data.
    """
    if arr.shape == ():
        # A 0-d array cannot be slice-assigned; go through a 1-element view.
        arr = arr.reshape(1)
        arr[:] = num
        arr = arr.reshape(())
    else:
        # Slice assignment handles scalars and ndarrays (including 0-d ones) alike.
        arr[:] = num
    return arr


def _calculate_correct_fan(array, mode):
    """Return the fan-in or fan-out of `array`, selected by `mode`.

    Args:
        array: object with a ``shape`` of at least two dimensions.
        mode: ``'fan_in'`` or ``'fan_out'`` (case-insensitive).

    Raises:
        ValueError: if `mode` is not one of the supported modes.
    """
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))

    fan_in, fan_out = _calculate_fan_in_and_fan_out(array)
    return fan_in if mode == 'fan_in' else fan_out


def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Sample values for `arr` according to the method described in
    `Delving deep into rectifiers: Surpassing human-level performance on
    ImageNet classification` - He, K. et al. (2015), using a uniform
    distribution. The sampled values are drawn from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    Note:
        `arr` itself is not modified; only its shape is used. The sampled
        values are returned as a new numpy array.

    Args:
        arr: an n-dimensional array (n >= 2) whose shape determines the fan
            and the shape of the result.
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: name of the non-linear function,
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Returns:
        A numpy array with the same shape as `arr` holding the samples.

    Examples:
        >>> w = np.empty((3, 5))
        >>> values = kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    fan = _calculate_correct_fan(arr, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    return np.random.uniform(-bound, bound, arr.shape)


def _calculate_fan_in_and_fan_out(arr):
    """Calculate fan in and fan out.

    The first two dimensions of ``arr.shape`` are treated as
    (output feature maps, input feature maps); the product of any remaining
    dimensions is the receptive field size.

    Args:
        arr: object exposing a ``shape`` with at least two dimensions.

    Returns:
        Tuple ``(fan_in, fan_out)``.

    Raises:
        ValueError: if `arr` has fewer than two dimensions.
    """
    dimensions = len(arr.shape)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions")

    num_input_fmaps = arr.shape[1]
    num_output_fmaps = arr.shape[0]
    receptive_field_size = 1
    if dimensions > 2:
        receptive_field_size = reduce(lambda x, y: x * y, arr.shape[2:])
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out


class KaimingUniform(MeInitializer):
    """Kaiming uniform initializer.

    Args:
        a: negative slope forwarded to `kaiming_uniform_`.
        mode: ``'fan_in'`` or ``'fan_out'``, forwarded to `kaiming_uniform_`.
        nonlinearity: nonlinearity name forwarded to `kaiming_uniform_`.
    """
    def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
        super(KaimingUniform, self).__init__()
        self.a = a
        self.mode = mode
        self.nonlinearity = nonlinearity

    def _initialize(self, arr):
        """Fill `arr` in place with Kaiming-uniform samples.

        NOTE(review): presumably invoked by the MindSpore initializer
        machinery with a numpy array -- confirm against the base class.
        """
        tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity)
        _assignment(arr, tmp)


def default_recurisive_init(custom_cell):
    """Initialize parameter.

    Walks every cell yielded by ``custom_cell.cells_and_names()``. Each
    ``nn.Conv2d`` and ``nn.Dense`` cell gets Kaiming-uniform weights
    (``a=sqrt(5)``) and, when it has a bias, a bias initialized with
    ``init.Uniform(1 / sqrt(fan_in))``. All other cells, including batch
    norm layers, are left untouched.

    Args:
        custom_cell: the cell whose sub-cells should be initialized.
    """
    for _, cell in custom_cell.cells_and_names():
        # Conv2d and Dense share exactly the same initialization scheme.
        if not isinstance(cell, (nn.Conv2d, nn.Dense)):
            # BatchNorm1d/BatchNorm2d and every other cell type keep their values.
            continue

        cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
                                              cell.weight.shape,
                                              cell.weight.dtype))
        if cell.bias is not None:
            fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight)
            bound = 1 / math.sqrt(fan_in)
            cell.bias.set_data(init.initializer(init.Uniform(bound),
                                                cell.bias.shape,
                                                cell.bias.dtype))


def load_yolov3_params(args, network):
    """Load yolov3 darknet parameter from checkpoint.

    Args:
        args: configuration object. Reads ``pretrained_backbone`` and
            ``resume_yolov3`` (checkpoint paths, falsy to skip) and logs
            through ``args.logger``.
        network: the network that receives the parameters.
    """
    if args.pretrained_backbone:
        network = load_backbone(network, args.pretrained_backbone, args)
        args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone))
    else:
        args.logger.info('Not load pre-trained backbone, please be careful')

    if args.resume_yolov3:
        wrapper_prefix = 'yolo_network.'
        param_dict = load_checkpoint(args.resume_yolov3)
        param_dict_new = {}
        for key, values in param_dict.items():
            # 'moments.*' entries are skipped -- presumably optimizer state
            # that the network itself does not need; confirm.
            if key.startswith('moments.'):
                continue
            # Strip the wrapper prefix so the key matches the bare network's names.
            if key.startswith(wrapper_prefix):
                new_key = key[len(wrapper_prefix):]
            else:
                new_key = key
            param_dict_new[new_key] = values
            args.logger.info('in resume {}'.format(key))

        args.logger.info('resume finished')
        load_param_into_net(network, param_dict_new)
        args.logger.info('load_model {} success'.format(args.resume_yolov3))