• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Parameter init."""
16import math
17from functools import reduce
18import numpy as np
19from mindspore.common import initializer as init
20from mindspore.common.initializer import Initializer as MeInitializer
21from mindspore.train.serialization import load_checkpoint, load_param_into_net
22import mindspore.nn as nn
23from .util import load_backbone
24
def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    ================= ====================================================

    Args:
        nonlinearity: name of the non-linear function: ``'linear'``,
            ``'identity'``, ``'conv1d'``/``'conv2d'``/``'conv3d'``,
            ``'conv_transpose1d'``/``'conv_transpose2d'``/``'conv_transpose3d'``,
            ``'sigmoid'``, ``'tanh'``, ``'relu'`` or ``'leaky_relu'``.
        param: optional parameter for the non-linear function. Only used for
            ``'leaky_relu'``, where it is the negative slope (0.01 if None).

    Returns:
        ``1`` for the linear-like functions and sigmoid, otherwise a float.

    Raises:
        ValueError: if ``nonlinearity`` is not supported, or if ``param`` is
            neither None nor an int/float (booleans are rejected) for
            ``'leaky_relu'``.

    Examples:
        >>> round(calculate_gain('leaky_relu', 0.2), 4)  # negative_slope=0.2
        1.3868
    """
    linear_fns = ['linear', 'identity', 'conv1d', 'conv2d', 'conv3d',
                  'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif isinstance(param, (int, float)) and not isinstance(param, bool):
            # bool is a subclass of int, so True/False must be rejected explicitly.
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))

    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
65
66
def _assignment(arr, num):
    """Copy ``num`` into ``arr`` in place and return the written array.

    Args:
        arr: array supporting numpy-style ``reshape`` and slice assignment
            (presumably a ``numpy.ndarray``); it is modified in place.
        num: scalar or array broadcastable to ``arr.shape``.

    Returns:
        ``arr`` itself, or for a 0-d ``arr`` a 0-d view sharing its memory.
    """
    if arr.shape == ():
        # A 0-d array cannot be slice-assigned directly; write through a
        # 1-element view so the original buffer is updated in place.
        flat = arr.reshape((1,))
        flat[:] = num
        return flat.reshape(())
    # Slice assignment broadcasts Python scalars, 0-d arrays and same-shaped
    # arrays alike. Slicing ``num`` first is unnecessary and raises IndexError
    # when ``num`` is a 0-d array.
    arr[:] = num
    return arr
79
80
def _calculate_correct_fan(array, mode):
    """Return fan-in or fan-out of ``array`` depending on ``mode``.

    Args:
        array: object with a ``shape`` of at least two dimensions.
        mode: ``'fan_in'`` or ``'fan_out'``, matched case-insensitively.

    Raises:
        ValueError: if ``mode`` is not supported (checked before any shape
            inspection), or if ``array`` has fewer than two dimensions.
    """
    normalized = mode.lower()
    supported = ['fan_in', 'fan_out']
    if normalized not in supported:
        raise ValueError("Mode {} not supported, please use one of {}".format(normalized, supported))

    # Pair each mode name with its value: ('fan_in', fan_in), ('fan_out', fan_out).
    fans_by_mode = dict(zip(supported, _calculate_fan_in_and_fan_out(array)))
    return fans_by_mode[normalized]
89
90
def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Return He (Kaiming) uniform samples with the same shape as ``arr``.

    Implements the method described in `Delving deep into rectifiers:
    Surpassing human-level performance on ImageNet classification` - He, K.
    et al. (2015), using a uniform distribution. The result has values sampled
    from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization. Despite the trailing underscore, ``arr``
    is not modified: only its ``shape`` is read, and a new float64 array drawn
    from NumPy's global random state is returned.

    Args:
        arr: array-like with a ``shape`` of at least two dimensions.
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``).
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing
            ``'fan_in'`` preserves the magnitude of the variance of the weights
            in the forward pass. Choosing ``'fan_out'`` preserves the
            magnitudes in the backwards pass.
        nonlinearity: the non-linear function name, recommended to use only
            with ``'relu'`` or ``'leaky_relu'`` (default).

    Returns:
        numpy.ndarray of shape ``arr.shape``. When the selected fan is zero,
        ``arr`` has no elements and an empty array of that shape is returned.

    Raises:
        ValueError: for an unsupported ``mode`` or ``nonlinearity``, an invalid
            ``a`` with ``'leaky_relu'``, or an ``arr`` with fewer than two
            dimensions.

    Examples:
        >>> w = np.empty((3, 5))
        >>> kaiming_uniform_(w, mode='fan_in', nonlinearity='relu').shape
        (3, 5)
    """
    fan = _calculate_correct_fan(arr, mode)
    gain = calculate_gain(nonlinearity, a)
    if fan == 0:
        # fan is a product of dimension sizes, so zero means arr has no
        # elements; nothing to sample, and 1/sqrt(0) would raise.
        return np.empty(arr.shape)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # uniform bound giving the target std
    return np.random.uniform(-bound, bound, arr.shape)
123
124
def _calculate_fan_in_and_fan_out(arr):
    """Return ``(fan_in, fan_out)`` derived from the shape of ``arr``.

    Dimension 1 is taken as the input feature-map count, dimension 0 as the
    output feature-map count, and all trailing dimensions multiply together
    into a receptive-field size that scales both.

    Raises:
        ValueError: if ``arr`` has fewer than two dimensions.
    """
    shape = arr.shape
    if len(shape) < 2:
        raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions")

    receptive_field = 1
    for extent in shape[2:]:
        receptive_field *= extent
    return shape[1] * receptive_field, shape[0] * receptive_field
140
141
class KaimingUniform(MeInitializer):
    """MindSpore initializer producing He (Kaiming) uniform values.

    Args:
        a: negative slope handed to ``kaiming_uniform_`` (which passes it to
            ``calculate_gain``; only relevant for ``'leaky_relu'``).
        mode: ``'fan_in'`` or ``'fan_out'``.
        nonlinearity: name of the non-linear function used to pick the gain.
    """
    def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
        super().__init__()
        self.a, self.mode, self.nonlinearity = a, mode, nonlinearity

    def _initialize(self, arr):
        # Sample fresh values for arr's shape, then copy them into arr in place.
        samples = kaiming_uniform_(arr, a=self.a, mode=self.mode,
                                   nonlinearity=self.nonlinearity)
        _assignment(arr, samples)
153
154
def default_recurisive_init(custom_cell):
    """Re-initialize the parameters of every Conv2d and Dense sub-cell.

    Weights are re-drawn with ``KaimingUniform(a=math.sqrt(5))``. When a bias
    exists, it is re-drawn with ``init.Uniform(1 / sqrt(fan_in))``, where
    ``fan_in`` is computed from the weight's shape. All other cell types,
    BatchNorm included, are left untouched. The (misspelled) function name is
    kept unchanged because callers may depend on it.

    Args:
        custom_cell: MindSpore cell whose sub-cells are visited through
            ``cells_and_names()``.
    """
    for _, sub_cell in custom_cell.cells_and_names():
        # Conv2d and Dense share one initialization scheme.
        if not isinstance(sub_cell, (nn.Conv2d, nn.Dense)):
            continue

        sub_cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
                                                  sub_cell.weight.shape,
                                                  sub_cell.weight.dtype))
        if sub_cell.bias is None:
            continue

        fan_in = _calculate_fan_in_and_fan_out(sub_cell.weight)[0]
        bias_scale = 1 / math.sqrt(fan_in)
        sub_cell.bias.set_data(init.initializer(init.Uniform(bias_scale),
                                                sub_cell.bias.shape,
                                                sub_cell.bias.dtype))
180
def load_yolov3_params(args, network):
    """Optionally load a pre-trained backbone and/or resume YOLOv3 weights.

    Args:
        args: configuration object; reads ``pretrained_backbone``,
            ``resume_yolov3`` and ``logger``.
        network: network whose parameters are loaded.

    Checkpoint keys starting with ``'moments.'`` are skipped, and a leading
    ``'yolo_network.'`` prefix is stripped before loading into ``network``.
    """
    if not args.pretrained_backbone:
        args.logger.info('Not load pre-trained backbone, please be careful')
    else:
        network = load_backbone(network, args.pretrained_backbone, args)
        args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone))

    if not args.resume_yolov3:
        return

    wrapper_prefix = 'yolo_network.'
    renamed_params = {}
    for key, value in load_checkpoint(args.resume_yolov3).items():
        if key.startswith('moments.'):
            # NOTE(review): presumably optimizer state, not network weights — confirm.
            continue
        if key.startswith(wrapper_prefix):
            new_key = key[len(wrapper_prefix):]
        else:
            new_key = key
        renamed_params[new_key] = value
        args.logger.info('in resume {}'.format(key))

    args.logger.info('resume finished')
    load_param_into_net(network, renamed_params)
    args.logger.info('load_model {} success'.format(args.resume_yolov3))
205