# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""less Batch Normalization"""
import numpy as np
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from mindspore.ops import operations as P
from mindspore.common import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer


__all__ = ["CommonHeadLastFN", "LessBN"]


class CommonHeadLastFN(Cell):
    r"""
    The last full normalization (FN) layer.

    This layer implements the operation as:

    .. math::
        \text{inputs} = \text{norm}(\text{inputs})
        \text{kernel} = \text{norm}(\text{kernel})
        \text{outputs} = \text{multiplier} * (\text{inputs} * \text{kernel} + \text{bias}),

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is the same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            the same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor
        >>> input = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
        >>> net = CommonHeadLastFN(3, 4)
        >>> output = net(input)
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True):
        super(CommonHeadLastFN, self).__init__()
        weight_shape = [out_channels, in_channels]
        self.weight = Parameter(initializer(weight_init, weight_shape), requires_grad=True, name='weight')
        self.x_norm = P.L2Normalize(axis=1)
        self.w_norm = P.L2Normalize(axis=1)
        self.fc = P.MatMul(transpose_a=False, transpose_b=True)
        self.multiplier = Parameter(Tensor(np.ones([1]), mstype.float32), requires_grad=True, name='multiplier')
        self.has_bias = has_bias
        if self.has_bias:
            bias_shape = [out_channels]
            self.bias_add = P.BiasAdd()
            self.bias = Parameter(initializer(bias_init, bias_shape), requires_grad=True, name='bias')

    def construct(self, x):
        x = self.x_norm(x)
        w = self.w_norm(self.weight)
        x = self.fc(x, w)
        if self.has_bias:
            x = self.bias_add(x, self.bias)
        x = self.multiplier * x
        return x


class LessBN(Cell):
    """
    Automatically reduce the number of BatchNorm (BN) layers to improve
    network performance while preserving accuracy.

    Args:
        network (Cell): Network to be modified.
        fn_flag (bool): Whether to replace the last fully connected (FC) layer with a full
            normalization (FN) layer. Default: False.

    Examples:
        >>> from mindspore import boost
        >>> network = boost.LessBN(network)
    """

    def __init__(self, network, fn_flag=False):
        super(LessBN, self).__init__()
        self.network = network
        self.network.set_boost("less_bn")
        self.network.update_cell_prefix()
        if fn_flag:
            self._convert_to_less_bn_net(self.network)
        self.network.add_flags(defer_inline=True)

    def _convert_dense(self, subcell):
        """
        Convert a Dense cell to an FN (CommonHeadLastFN) cell.
        """
        prefix = subcell.param_prefix
        new_subcell = CommonHeadLastFN(subcell.in_channels,
                                       subcell.out_channels,
                                       subcell.weight,
                                       subcell.bias,
                                       False)
        new_subcell.update_parameters_name(prefix + '.')

        return new_subcell

    def _convert_to_less_bn_net(self, net):
        """
        Recursively convert the network to a less_bn network: the last Dense cell
        found at each level is replaced with an FN cell.
        """
        cells = net.name_cells()
        dense_name = []
        dense_list = []
        for name in cells:
            subcell = cells[name]
            if subcell == net:
                continue
            elif isinstance(subcell, Dense):
                dense_name.append(name)
                dense_list.append(subcell)
            else:
                self._convert_to_less_bn_net(subcell)

        if dense_list:
            new_subcell = self._convert_dense(dense_list[-1])
            net.insert_child_to_cell(dense_name[-1], new_subcell)

    def construct(self, *inputs):
        return self.network(*inputs)
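

# A minimal usage sketch, not part of the original module: it assumes a small,
# hypothetical `SimpleNet` cell and shows how LessBN wraps a network so that,
# with `fn_flag=True`, the last Dense layer is replaced by a CommonHeadLastFN head.
if __name__ == "__main__":
    import mindspore.nn as nn

    class SimpleNet(Cell):
        """Hypothetical network used only to illustrate LessBN wrapping."""
        def __init__(self):
            super(SimpleNet, self).__init__()
            self.flatten = nn.Flatten()
            self.fc1 = Dense(784, 256)
            self.relu = nn.ReLU()
            self.fc2 = Dense(256, 10)  # last Dense cell: candidate for FN replacement

        def construct(self, x):
            x = self.flatten(x)
            x = self.relu(self.fc1(x))
            return self.fc2(x)

    # Wrap the network: LessBN sets the "less_bn" boost flag and, because
    # fn_flag=True, converts the final Dense cell into a CommonHeadLastFN head.
    boosted_net = LessBN(SimpleNet(), fn_flag=True)
    dummy_input = Tensor(np.random.randn(2, 784).astype(np.float32))
    output = boosted_net(dummy_input)
    print(output.shape)  # expected: (2, 10)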