• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
"""Less Batch Normalization (LessBN): a boost wrapper that reduces BN usage, plus the FN layer it can substitute for Dense layers."""
16from __future__ import absolute_import
17
18import numpy as np
19from mindspore.nn.cell import Cell
20from mindspore.nn.layer import Dense
21from mindspore.ops import operations as P
22from mindspore.common import Tensor, Parameter
23from mindspore.common import dtype as mstype
24from mindspore.common.initializer import initializer
25
26
# Public API of this module.
__all__ = ["CommonHeadLastFN", "LessBN"]
28
29
class CommonHeadLastFN(Cell):
    r"""
    Fully normalized (FN) layer intended to stand in for the last fully connected layer.

    The rows of the input and the rows of the weight are L2-normalized along axis 1.
    The normalized input is then multiplied by the transposed normalized weight and the
    bias (if any) is added. The result is scaled by a trainable multiplier:

    .. math::
        \begin{aligned}
        \text{inputs} &= \text{norm}(\text{inputs}) \\
        \text{kernel} &= \text{norm}(\text{kernel}) \\
        \text{outputs} &= \text{multiplier} * (\text{inputs} * \text{kernel}^{T} + \text{bias})
        \end{aligned}

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer of the trainable
            weight, whose shape is (out_channels, in_channels). The values of str refer to the
            function `initializer`. Default: ``'normal'``.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer of the trainable
            bias, whose shape is (out_channels,). Only used when `has_bias` is ``True``. The values
            of str refer to the function `initializer`. Default: ``'zeros'``.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: ``True``.

    Inputs:
        - **x** (Tensor) - Tensor of shape (N, in_channels).

    Outputs:
        Tensor of shape (N, out_channels).

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor
        >>> x = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
        >>> net = CommonHeadLastFN(3, 4)
        >>> output = net(x)
        >>> print(output.shape)
        (2, 4)
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True):
        super().__init__()
        # Stored as (out_channels, in_channels); the MatMul below transposes it (transpose_b=True).
        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]),
                                requires_grad=True, name='weight')
        # Input rows and weight rows are normalized independently along axis 1.
        self.x_norm = P.L2Normalize(axis=1)
        self.w_norm = P.L2Normalize(axis=1)
        self.fc = P.MatMul(transpose_a=False, transpose_b=True)
        # Trainable scalar scale of the output, initialized to one.
        self.multiplier = Parameter(Tensor(np.ones([1]), mstype.float32),
                                    requires_grad=True, name='multiplier')
        self.has_bias = has_bias
        if not self.has_bias:
            return
        self.bias_add = P.BiasAdd()
        self.bias = Parameter(initializer(bias_init, [out_channels]),
                              requires_grad=True, name='bias')

    def construct(self, x):
        normed_x = self.x_norm(x)
        normed_w = self.w_norm(self.weight)
        logits = self.fc(normed_x, normed_w)
        if self.has_bias:
            logits = self.bias_add(logits, self.bias)
        return self.multiplier * logits
85
86
class LessBN(Cell):
    """
    Reduce the number of BN automatically to improve the network performance
    and ensure the network accuracy.

    When the wrapped network does not yet carry the ``"less_bn"`` flag, this wrapper does the following:

    - It enables the ``"less_bn"`` boost option on it.
    - It refreshes its cell prefixes.
    - If `fn_flag` is set, it replaces Dense layers with :class:`CommonHeadLastFN` (see `fn_flag`).

    A network that already carries the flag is left untouched.

    Args:
        network (Cell): Network to be modified.
        fn_flag (bool): Replace FC with FN. When ``True``, the network is walked recursively. Within
            every cell visited, the last direct ``Dense`` child (in ``name_cells()`` order) is replaced
            by a bias-free :class:`CommonHeadLastFN` built from that Dense layer's weight.
            Default: ``False`` .

    Inputs:
        - **inputs** - Positional inputs forwarded unchanged to `network`.

    Outputs:
        Whatever `network` returns.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, Parameter, nn
        >>> from mindspore import ops
        >>> from mindspore.nn import WithLossCell
        >>> from mindspore import dtype as mstype
        >>> from mindspore import boost
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        >>> size, in_features, out_features = 16, 16, 10
        >>> net = Net(in_features, out_features)
        >>> loss = nn.MSELoss()
        >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> net_with_loss = WithLossCell(net, loss)
        >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
        >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
        >>> train_network = boost.LessBN(net_with_loss)
        >>> output = train_network(inputs, label)
    """

    def __init__(self, network, fn_flag=False):
        super(LessBN, self).__init__()
        self.network = network
        if "less_bn" not in self.network.get_flags():
            self.network.set_boost("less_bn")
            # Must run before the conversion: _convert_dense reads each Dense cell's param_prefix.
            self.network.update_cell_prefix()
            if fn_flag:
                self._convert_to_less_bn_net(self.network)
            # Use this interface with caution, as it may cause uncontrollable problems.
            # It is used here to keep the "less_bn" flag from disappearing
            # (presumably when the network would otherwise be inlined -- TODO confirm).
            self.network.add_flags(defer_inline=True)

    @staticmethod
    def _convert_dense(subcell):
        """
        Build the CommonHeadLastFN cell that replaces the given Dense cell.

        The new cell keeps the Dense layer's channel sizes and uses ``subcell.weight`` as its
        weight initializer. Its parameter names are prefixed with the Dense cell's
        ``param_prefix``.

        Args:
            subcell (Dense): The Dense layer to convert.

        Returns:
            CommonHeadLastFN, the replacement cell.
        """
        prefix = subcell.param_prefix
        # The Dense bias is not carried over: the FN replacement is built with has_bias=False,
        # so no bias initializer is needed.
        new_subcell = CommonHeadLastFN(subcell.in_channels,
                                       subcell.out_channels,
                                       weight_init=subcell.weight,
                                       has_bias=False)
        new_subcell.update_parameters_name(prefix + '.')
        return new_subcell

    def construct(self, *inputs):
        return self.network(*inputs)

    def _convert_to_less_bn_net(self, net):
        """
        Recursively replace Dense layers of `net` with CommonHeadLastFN cells, in place.

        Each non-Dense direct child is converted recursively. Among the direct children
        that are Dense, only the last one (in ``name_cells()`` order) is replaced.

        Args:
            net (Cell): The cell to convert.
        """
        cells = net.name_cells()
        dense_names = []
        for name, subcell in cells.items():
            if subcell is net:
                # A cell registered as its own child would otherwise recurse forever.
                continue
            if isinstance(subcell, Dense):
                dense_names.append(name)
            else:
                self._convert_to_less_bn_net(subcell)

        if dense_names:
            last_name = dense_names[-1]
            net.insert_child_to_cell(last_name, LessBN._convert_dense(cells[last_name]))