# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""less Batch Normalization"""
import numpy as np
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from mindspore.ops import operations as P
from mindspore.common import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer


__all__ = ["CommonHeadLastFN", "LessBN"]


class CommonHeadLastFN(Cell):
    r"""
    The last fully-connected layer with feature normalization (FN).

    Both the input features and the weight matrix are L2-normalized
    (along axis 1) before the matrix product, and the result is scaled
    by a learnable scalar multiplier:

    .. math::
        \text{inputs} = \text{norm}(\text{inputs})
        \text{kernel} = \text{norm}(\text{kernel})
        \text{outputs} = \text{multiplier} * (\text{inputs} * \text{kernel} + \text{bias}),

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
        >>> net = CommonHeadLastFN(3, 4)
        >>> output = net(input)
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True):
        super(CommonHeadLastFN, self).__init__()
        # Weight is stored as (out_channels, in_channels); the MatMul below
        # uses transpose_b=True, so x @ w.T yields (batch, out_channels).
        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]),
                                requires_grad=True, name='weight')
        self.x_norm = P.L2Normalize(axis=1)
        self.w_norm = P.L2Normalize(axis=1)
        self.fc = P.MatMul(transpose_a=False, transpose_b=True)
        # Learnable scalar applied after the (optionally biased) product.
        self.multiplier = Parameter(Tensor(np.ones([1]), mstype.float32),
                                    requires_grad=True, name='multiplier')
        self.has_bias = has_bias
        if has_bias:
            self.bias_add = P.BiasAdd()
            self.bias = Parameter(initializer(bias_init, [out_channels]),
                                  requires_grad=True, name='bias')

    def construct(self, x):
        # L2-normalize features and weights, multiply, optionally add bias,
        # then apply the learnable scale.
        normalized_x = self.x_norm(x)
        normalized_w = self.w_norm(self.weight)
        out = self.fc(normalized_x, normalized_w)
        if self.has_bias:
            out = self.bias_add(out, self.bias)
        return self.multiplier * out


class LessBN(Cell):
    """
    Reduce the number of BN automatically to improve the network performance
    and ensure the network accuracy.

    Args:
        network (Cell): Network to be modified.
        fn_flag (bool): Replace FC with FN. default: False.

    Examples:
        >>> network = boost.LessBN(network)
    """

    def __init__(self, network, fn_flag=False):
        super(LessBN, self).__init__()
        self.network = network
        # Tag the wrapped network for the "less_bn" boost mode and refresh
        # its parameter name prefixes before any structural rewriting.
        self.network.set_boost("less_bn")
        self.network.update_cell_prefix()
        if fn_flag:
            self._convert_to_less_bn_net(self.network)
        self.network.add_flags(defer_inline=True)

    def _convert_dense(self, subcell):
        """
        Build a CommonHeadLastFN cell from a Dense cell, reusing its weight.

        Note: has_bias is passed as False, so the Dense cell's bias tensor
        (forwarded as bias_init) is not used by the new cell.
        """
        fn_cell = CommonHeadLastFN(subcell.in_channels,
                                   subcell.out_channels,
                                   subcell.weight,
                                   subcell.bias,
                                   False)
        fn_cell.update_parameters_name(subcell.param_prefix + '.')
        return fn_cell

    def _convert_to_less_bn_net(self, net):
        """
        Recursively walk `net`, replacing the last direct Dense child of each
        container with its FN equivalent.
        """
        dense_children = []
        cells = net.name_cells()
        for child_name in cells:
            subcell = cells[child_name]
            if subcell == net:
                continue
            if isinstance(subcell, Dense):
                dense_children.append((child_name, subcell))
            else:
                self._convert_to_less_bn_net(subcell)

        if dense_children:
            last_name, last_dense = dense_children[-1]
            net.insert_child_to_cell(last_name, self._convert_dense(last_dense))

    def construct(self, *inputs):
        return self.network(*inputs)
144