# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Less Batch Normalization: a normalized dense head and the network wrapper that installs it."""
import numpy as np
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from mindspore.ops import operations as P
from mindspore.common import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer


__all__ = ["CommonHeadLastFN", "LessBN"]


class CommonHeadLastFN(Cell):
    r"""
    Fully connected head whose input and kernel are both L2-normalized.

    The forward pass computes:

    .. math::
        \text{inputs} = \text{norm}(\text{inputs})

        \text{kernel} = \text{norm}(\text{kernel})

        \text{outputs} = \text{multiplier} * (\text{inputs} * \text{kernel} + \text{bias}),

    where :math:`\text{norm}` is `L2Normalize` along axis 1, the product is a
    matrix multiplication against the transposed kernel, the bias term is only
    present when `has_bias` is True, and :math:`\text{multiplier}` is a
    trainable scalar that starts at 1.

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initial value of the
            trainable kernel of shape `[out_channels, in_channels]`; forwarded to the function
            `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initial value of the
            trainable bias of shape `[out_channels]`; forwarded to the function `initializer`.
            Only used when `has_bias` is True. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
        >>> net = CommonHeadLastFN(3, 4)
        >>> output = net(x)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True):
        super().__init__()
        # Kernel is stored as [out, in]; the matmul below transposes it.
        kernel = initializer(weight_init, [out_channels, in_channels])
        self.weight = Parameter(kernel, requires_grad=True, name='weight')

        # Separate normalizers for the activations and for the kernel rows.
        self.x_norm = P.L2Normalize(axis=1)
        self.w_norm = P.L2Normalize(axis=1)
        self.fc = P.MatMul(transpose_a=False, transpose_b=True)

        # Trainable scalar applied to the final output, initialised to 1.
        unit_scale = Tensor(np.ones([1]), mstype.float32)
        self.multiplier = Parameter(unit_scale, requires_grad=True, name='multiplier')

        self.has_bias = has_bias
        if self.has_bias:
            self.bias_add = P.BiasAdd()
            offset = initializer(bias_init, [out_channels])
            self.bias = Parameter(offset, requires_grad=True, name='bias')

    def construct(self, x):
        """Normalize `x` and the kernel, multiply them, add the optional bias, then scale."""
        unit_x = self.x_norm(x)
        unit_w = self.w_norm(self.weight)
        out = self.fc(unit_x, unit_w)
        if self.has_bias:
            out = self.bias_add(out, self.bias)
        return self.multiplier * out


class LessBN(Cell):
    """
    Reduce the number of BN automatically to improve the network performance
    and ensure the network accuracy.

    On construction the wrapped network is marked with the "less_bn" boost, its
    cell prefixes are refreshed, and the `defer_inline` flag is set on it. When
    `fn_flag` is True, the cell tree is additionally walked and, inside every
    visited cell, the last direct child of type `Dense` is swapped for a
    bias-free `CommonHeadLastFN` built from that layer's channels and parameters.

    Args:
        network (Cell): Network to be modified.
        fn_flag (bool): Replace FC with FN. Default: False.

    Examples:
        >>> network = boost.LessBN(network)
    """

    def __init__(self, network, fn_flag=False):
        super().__init__()
        self.network = network
        network.set_boost("less_bn")
        network.update_cell_prefix()
        if fn_flag:
            self._convert_to_less_bn_net(network)
        network.add_flags(defer_inline=True)

    def _convert_dense(self, subcell):
        """
        Build the FN replacement for a dense cell.

        The new cell takes its channel counts, weight and bias initial values
        from `subcell`, is created without a bias, and has its parameters
        renamed under the dense cell's parameter prefix.
        """
        fn_cell = CommonHeadLastFN(in_channels=subcell.in_channels,
                                   out_channels=subcell.out_channels,
                                   weight_init=subcell.weight,
                                   bias_init=subcell.bias,
                                   has_bias=False)
        fn_cell.update_parameters_name(subcell.param_prefix + '.')
        return fn_cell

    def _convert_to_less_bn_net(self, net):
        """
        Recursively convert `net` into a less_bn network.

        Non-dense children are descended into; among the dense children that
        sit directly under `net`, only the last one encountered is replaced.
        """
        last_dense_name = None
        last_dense = None
        for child_name, child in net.name_cells().items():
            if child == net:
                continue
            if isinstance(child, Dense):
                # Remember only the most recent dense layer at this level.
                last_dense_name, last_dense = child_name, child
            else:
                self._convert_to_less_bn_net(child)

        if last_dense is not None:
            replacement = self._convert_dense(last_dense)
            net.insert_child_to_cell(last_dense_name, replacement)

    def construct(self, *inputs):
        """Forward all inputs to the wrapped network and return its result."""
        return self.network(*inputs)