• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
"""MobileNetV2 Quant model definition."""
16
17import numpy as np
18
19import mindspore.nn as nn
20from mindspore.ops import operations as P
21from mindspore import Tensor
22
# Names exported by `from <this module> import *`: only the mobilenetV2 network class.
__all__ = ['mobilenetV2']
24
25
def _make_divisible(v, divisor, min_value=None):
    """Round ``v`` to a multiple of ``divisor`` without shrinking it by more than 10%.

    For non-negative ``v`` the value is rounded to the nearest multiple of
    ``divisor`` (halves round up), then clamped from below by ``min_value``.
    If that result is still under 90% of ``v``, one more ``divisor`` is added.

    Args:
        v (float): Value to adjust, typically a width-scaled channel count.
        divisor (int): Rounding granularity.
        min_value (int): Lower bound for the result. Default: ``divisor``.

    Returns:
        The adjusted value (an int when ``divisor`` and ``min_value`` are ints).
    """
    lower_bound = divisor if min_value is None else min_value
    nearest_multiple = int(v + divisor / 2) // divisor * divisor
    adjusted = max(lower_bound, nearest_multiple)
    # Make sure that rounding down does not go down by more than 10%.
    return adjusted + divisor if adjusted < 0.9 * v else adjusted
34
35
class GlobalAvgPooling(nn.Cell):
    """
    Global average pooling over the spatial axes.

    The input is averaged over axes 2 and 3, and those axes are dropped from
    the result (``keep_dims=False``). Assuming an NCHW input, this maps
    (N, C, H, W) to (N, C); the layout is not checked here.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> GlobalAvgPooling()
    """

    def __init__(self):
        super().__init__()
        # The reduction op is built once; the axes are passed at call time.
        self.mean = P.ReduceMean(keep_dims=False)

    def construct(self, x):
        # Average over axes 2 and 3 and return the reduced tensor directly.
        return self.mean(x, (2, 3))
56
57
class ConvBNReLU(nn.Cell):
    """
    Convolution (regular or depthwise) fused with BatchNorm and ReLU.

    Zero padding of ``(kernel_size - 1) // 2`` is applied, so an odd kernel at
    stride 1 keeps the spatial size unchanged.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        kernel_size (int): Convolution kernel size. Default: 3.
        stride (int): Convolution stride. Default: 1.
        groups (int): Number of channel groups: 1 for a regular convolution,
            the input channel count for a depthwise one. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        super().__init__()
        same_pad = (kernel_size - 1) // 2
        conv_options = dict(stride=stride,
                            pad_mode='pad',
                            padding=same_pad,
                            group=groups,
                            has_bn=True,
                            activation='relu')
        self.conv = nn.Conv2dBnAct(in_planes, out_planes, kernel_size, **conv_options)

    def construct(self, x):
        # Single fused conv + BN + ReLU cell.
        return self.conv(x)
90
91
class InvertedResidual(nn.Cell):
    """
    MobileNetV2 inverted residual block definition.

    The block is an optional 1x1 expansion ConvBNReLU (only when
    ``expand_ratio != 1``), a depthwise ConvBNReLU carrying the stride, and a
    1x1 projection Conv2dBnAct with BatchNorm and no activation argument
    ("pw-linear"). The input is added to the output when ``stride == 1`` and
    ``inp == oup``.

    Args:
        inp (int): Input channel.
        oup (int): Output channel.
        stride (int): Stride of the depthwise convolution; must be 1 or 2.
        expand_ratio (int): Expansion ratio of the input channel; the hidden
            width is ``round(inp * expand_ratio)``.

    Returns:
        Tensor, output tensor.

    Raises:
        ValueError: If ``stride`` is neither 1 nor 2.

    Examples:
        >>> InvertedResidual(3, 256, 1, 1)
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        # Explicit check rather than `assert`, which `python -O` strips out.
        if stride not in (1, 2):
            raise ValueError("stride must be 1 or 2, but got {}".format(stride))

        hidden_dim = int(round(inp * expand_ratio))
        # The skip connection needs an unchanged spatial size and channel count.
        self.use_res_connect = stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw: 1x1 expansion to hidden_dim channels
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim,
                       stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1,
                           pad_mode='pad', padding=0, group=1, has_bn=True)
        ])
        self.conv = nn.SequentialCell(layers)
        self.add = P.Add()

    def construct(self, x):
        out = self.conv(x)
        if self.use_res_connect:
            out = self.add(out, x)
        return out
135
136
class mobilenetV2(nn.Cell):
    """
    mobilenetV2 fusion architecture.

    Layout: a stride-2 ConvBNReLU stem, a stack of InvertedResidual blocks
    described by ``inverted_residual_setting``, a final 1x1 ConvBNReLU, and a
    head of global average pooling, optional dropout and a DenseBnAct
    classifier.

    Args:
        num_classes (int): Number of classes. Default: 1000.
        width_mult (float): Channel width multiplier. Scaled channel counts are
            rounded with ``_make_divisible``. The last conv is scaled by
            ``max(1.0, width_mult)``, so a multiplier below 1 does not narrow
            it. Default: 1.0.
        has_dropout (bool): Whether to insert dropout (20% drop rate) before
            the classifier. Default: False.
        inverted_residual_setting (list): Rows of ``[t, c, n, s]``: expansion
            ratio, output channels, number of blocks, stride of the first
            block. Default: None, which selects the built-in table below.
        round_nearest (int): Channel counts are rounded to a multiple of this
            number. Default: 8.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> mobilenetV2(num_classes=1000)
    """

    def __init__(self, num_classes=1000, width_mult=1.,
                 has_dropout=False, inverted_residual_setting=None, round_nearest=8):
        super(mobilenetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        # setting of inverted residual blocks
        self.cfgs = inverted_residual_setting
        if inverted_residual_setting is None:
            self.cfgs = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # building first layer
        input_channel = _make_divisible(
            input_channel * width_mult, round_nearest)
        self.out_channels = _make_divisible(
            last_channel * max(1.0, width_mult), round_nearest)

        features = [ConvBNReLU(3, input_channel, stride=2)]
        # building inverted residual blocks; only the first block of each
        # stage uses the stage stride
        for t, c, n, s in self.cfgs:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(
                    block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(
            input_channel, self.out_channels, kernel_size=1))
        # chain the layers in a SequentialCell so they run in order
        self.features = nn.SequentialCell(features)

        # mobilenet head
        head = [GlobalAvgPooling()]
        if has_dropout:
            # MindSpore's nn.Dropout takes `keep_prob` as its first argument,
            # so nn.Dropout(0.2) would drop 80% of the features. keep_prob=0.8
            # gives the 20% drop rate of the reference MobileNetV2 head.
            head.append(nn.Dropout(keep_prob=0.8))
        head.append(nn.DenseBnAct(self.out_channels, num_classes,
                                  has_bias=True, has_bn=False))
        self.head = nn.SequentialCell(head)

        # init weights (_initialize_weights calls init_parameters_data itself)
        self._initialize_weights()

    def construct(self, x):
        x = self.features(x)
        x = self.head(x)
        return x

    def _initialize_weights(self):
        """
        Initialize the parameters of every sub-cell in place.

        - Conv weights: normal(0, sqrt(2 / (kh * kw * out_channels))).
        - Dense weights: normal(0, 0.01).
        - Biases that exist: zero.
        - BatchNorm2d: gamma set to one, beta set to zero.

        The global NumPy RNG is seeded once (seed 1) before walking the cells,
        so initialization is reproducible. It is not reseeded per cell: doing so
        would restart the random stream for every cell and give layers with the
        same weight shape bit-identical initial weights.

        Returns:
            None.
        """

        def _zeros_like(param):
            # float32 zero tensor with the shape of `param`
            return Tensor(np.zeros(param.data.shape, dtype="float32"))

        def _init_conv(conv):
            # He-normal, fan-out mode
            fan_out = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels
            conv.weight.set_data(Tensor(np.random.normal(
                0, np.sqrt(2. / fan_out), conv.weight.data.shape).astype("float32")))
            if conv.bias is not None:
                conv.bias.set_data(_zeros_like(conv.bias))

        def _init_dense(dense):
            dense.weight.set_data(Tensor(np.random.normal(
                0, 0.01, dense.weight.data.shape).astype("float32")))
            if dense.bias is not None:
                dense.bias.set_data(_zeros_like(dense.bias))

        self.init_parameters_data()
        np.random.seed(1)
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Conv2d):
                _init_conv(m)
            elif isinstance(m, nn.Conv2dBnAct):
                _init_conv(m.conv)
            elif isinstance(m, nn.BatchNorm2d):
                m.gamma.set_data(
                    Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
                m.beta.set_data(_zeros_like(m.beta))
            elif isinstance(m, nn.Dense):
                _init_dense(m)
            elif isinstance(m, nn.DenseBnAct):
                _init_dense(m.dense)
264