# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""less Batch Normalization"""
from __future__ import absolute_import

import numpy as np
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from mindspore.ops import operations as P
from mindspore.common import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer


__all__ = ["CommonHeadLastFN", "LessBN"]


class CommonHeadLastFN(Cell):
    r"""
    The last fully connected layer with input and weight normalization (FN).

    This layer implements the operation as:

    .. math::
        \text{inputs} = \text{norm}(\text{inputs})

        \text{kernel} = \text{norm}(\text{kernel})

        \text{outputs} = \text{multiplier} * (\text{inputs} * \text{kernel} + \text{bias})

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is the same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            the same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: ``True``.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor
        >>> x = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
        >>> net = CommonHeadLastFN(3, 4)
        >>> output = net(x)
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True):
        super(CommonHeadLastFN, self).__init__()
        weight_shape = [out_channels, in_channels]
        self.weight = Parameter(initializer(weight_init, weight_shape), requires_grad=True, name='weight')
        # L2-normalize the input rows and the weight rows before the matmul.
        self.x_norm = P.L2Normalize(axis=1)
        self.w_norm = P.L2Normalize(axis=1)
        self.fc = P.MatMul(transpose_a=False, transpose_b=True)
        # A learnable scalar that rescales the normalized logits.
        self.multiplier = Parameter(Tensor(np.ones([1]), mstype.float32), requires_grad=True, name='multiplier')
        self.has_bias = has_bias
        if self.has_bias:
            bias_shape = [out_channels]
            self.bias_add = P.BiasAdd()
            self.bias = Parameter(initializer(bias_init, bias_shape), requires_grad=True, name='bias')

    def construct(self, x):
        x = self.x_norm(x)
        w = self.w_norm(self.weight)
        x = self.fc(x, w)
        if self.has_bias:
            x = self.bias_add(x, self.bias)
        x = self.multiplier * x
        return x
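

# For intuition, a minimal NumPy sketch of the forward pass computed by
# CommonHeadLastFN above (an illustration, not part of this module; it ignores
# the epsilon handling inside P.L2Normalize):
#
#     x_n = x / np.linalg.norm(x, axis=1, keepdims=True)   # normalize input rows
#     w_n = w / np.linalg.norm(w, axis=1, keepdims=True)   # normalize weight rows
#     out = multiplier * (x_n @ w_n.T + bias)              # scaled cosine-similarity logits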


class LessBN(Cell):
    """
    Reduce the number of BN (BatchNorm) layers automatically to improve
    network performance while maintaining network accuracy.

    Args:
        network (Cell): Network to be modified.
        fn_flag (bool): If ``True``, replace the last FC (Dense) layer with an FN
            (CommonHeadLastFN) layer. Default: ``False``.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, Parameter, nn
        >>> from mindspore import ops
        >>> from mindspore.nn import WithLossCell
        >>> from mindspore import dtype as mstype
        >>> from mindspore import boost
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        >>> size, in_features, out_features = 16, 16, 10
        >>> net = Net(in_features, out_features)
        >>> loss = nn.MSELoss()
        >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> net_with_loss = WithLossCell(net, loss)
        >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
        >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
        >>> train_network = boost.LessBN(net_with_loss)
        >>> output = train_network(inputs, label)
    """

    def __init__(self, network, fn_flag=False):
        super(LessBN, self).__init__()
        self.network = network
        if "less_bn" not in self.network.get_flags():
            self.network.set_boost("less_bn")
        self.network.update_cell_prefix()
        if fn_flag:
            self._convert_to_less_bn_net(self.network)
        # Please use this interface with caution, as it may cause uncontrollable problems.
        # The purpose here is to prevent the "less_bn" flag from disappearing.
        self.network.add_flags(defer_inline=True)

    @staticmethod
    def _convert_dense(subcell):
        """
        Convert a Dense cell to an FN (CommonHeadLastFN) cell.
        """
        prefix = subcell.param_prefix
        # Reuse the Dense layer's weight and bias as initializers of the new
        # head; has_bias is False, so the bias term is dropped by the FN cell.
        new_subcell = CommonHeadLastFN(subcell.in_channels,
                                       subcell.out_channels,
                                       subcell.weight,
                                       subcell.bias,
                                       False)
        new_subcell.update_parameters_name(prefix + '.')

        return new_subcell

    def construct(self, *inputs):
        return self.network(*inputs)

    def _convert_to_less_bn_net(self, net):
        """
        Convert a network to a less_bn network by recursively collecting Dense
        cells and replacing the last one with an FN cell.
        """
        cells = net.name_cells()
        dense_name = []
        dense_list = []
        for name in cells:
            subcell = cells[name]
            if subcell == net:
                continue
            if isinstance(subcell, Dense):
                dense_name.append(name)
                dense_list.append(subcell)
            else:
                self._convert_to_less_bn_net(subcell)

        if dense_list:
            new_subcell = LessBN._convert_dense(dense_list[-1])
            net.insert_child_to_cell(dense_name[-1], new_subcell)
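

# A minimal runnable sketch (an illustrative assumption, not part of this
# module's public API): wrap a hypothetical toy network with LessBN and let
# fn_flag=True replace its last Dense layer with CommonHeadLastFN. Shapes and
# the toy network are placeholders.
if __name__ == "__main__":
    from mindspore import nn

    class _ToyNet(Cell):
        def __init__(self):
            super(_ToyNet, self).__init__()
            self.dense = Dense(16, 10)

        def construct(self, x):
            return self.dense(x)

    net_with_loss = nn.WithLossCell(_ToyNet(), nn.MSELoss())
    # fn_flag=True triggers _convert_to_less_bn_net, which swaps the last
    # Dense cell for a CommonHeadLastFN cell before training starts.
    train_network = LessBN(net_with_loss, fn_flag=True)
    data = Tensor(np.ones([4, 16]), mstype.float32)
    label = Tensor(np.zeros([4, 10]), mstype.float32)
    print(train_network(data, label))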