• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""DeepLabv3."""
16
17import numpy as np
18import mindspore.nn as nn
19from mindspore.ops import operations as P
20from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \
21    DepthwiseConv2dNative, SpaceToBatch, BatchToSpace
22
23
class ASPPSampleBlock(nn.Cell):
    """Bilinear resize used by the ASPP image-pooling branch.

    For each spatial dimension ``dim`` of ``feature_shape`` ([h, w]) the target
    size is ``int((dim * scale_size + 1) / output_stride + 1)``. The ``+ 1`` terms
    presumably account for the 1-pixel border SingleDeepLabV3 pads on each side
    of its input (TODO confirm).
    """

    def __init__(self, feature_shape, scale_size, output_stride):
        super(ASPPSampleBlock, self).__init__()
        target_h, target_w = (int((dim * scale_size + 1) / output_stride + 1)
                              for dim in (feature_shape[0], feature_shape[1]))
        self.sample = P.ResizeBilinear((target_h, target_w), align_corners=True)

    def construct(self, x):
        """Resize ``x`` to the (height, width) fixed at construction."""
        resized = self.sample(x)
        return resized
34
35
class ASPP(nn.Cell):
    """
    ASPP model for DeepLabv3.

    Builds the following branches, whose outputs are concatenated along axis 1:
    an image-pooling branch (average pool -> 1x1 conv/BN/ReLU -> bilinear resize),
    a 1x1 conv/BN/ReLU branch, and one atrous branch per rate in ``atrous_rates``.
    Each atrous branch is SpaceToBatch(rate) -> 3x3 depthwise conv ("valid") ->
    BatchToSpace(rate) -> BN -> ReLU -> 1x1 conv/BN/ReLU.

    Args:
        channel (int): Input channel.
        depth (int): Output channel of each branch.
        feature_shape (list): Spatial shape [h, w] of the network input; the
            feature-map size at a given scale is ceil(dim * scale / output_stride).
        scale_sizes (list): Input scales for multi-scale feature extraction.
        atrous_rates (list): Atrous rates for atrous spatial pyramid pooling, or
            None to build no atrous branches.
        output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Fine tune the batch norm parameters or not.

    Returns:
        Tensor, output tensor with ``depth * (2 + len(atrous_rates))`` channels.

    Examples:
        >>> ASPP(2048, 256, [513, 513], [1.0], [6], 16)
    """
    def __init__(self, channel, depth, feature_shape, scale_sizes,
                 atrous_rates, output_stride, fine_tune_batch_norm=False):
        super(ASPP, self).__init__()
        # 1x1 convolution branch.
        self.aspp0 = _conv_bn_relu(channel,
                                   depth,
                                   ksize=1,
                                   stride=1,
                                   use_batch_statistics=fine_tune_batch_norm)
        self.atrous_rates = []
        if atrous_rates is not None:
            self.atrous_rates = atrous_rates
            # The pointwise conv, depthwise conv, BN and ReLU are shared by every
            # atrous branch; only the SpaceToBatch/BatchToSpace pair differs.
            self.aspp_pointwise = _conv_bn_relu(channel,
                                                depth,
                                                ksize=1,
                                                stride=1,
                                                use_batch_statistics=fine_tune_batch_norm)
            self.aspp_depth_depthwiseconv = DepthwiseConv2dNative(channel,
                                                                  channel_multiplier=1,
                                                                  kernel_size=3,
                                                                  stride=1,
                                                                  dilation=1,
                                                                  pad_mode="valid")
            self.aspp_depth_bn = nn.BatchNorm2d(1 * channel, use_batch_statistics=fine_tune_batch_norm)
            self.aspp_depth_relu = nn.ReLU()

            space_to_batchs = []
            batch_to_spaces = []
            # Entries are stored scale-major: index = rate_index + scale_index * len(atrous_rates),
            # which is the lookup construct() performs.
            for scale_size in scale_sizes:
                # Backbone feature-map size at this scale; height and width are
                # handled separately so non-square inputs pad correctly.
                feature_h = np.ceil((feature_shape[0] * scale_size) / output_stride)
                feature_w = np.ceil((feature_shape[1] * scale_size) / output_stride)
                for rate in atrous_rates:
                    pad_h = self._extra_padding(feature_h, rate)
                    pad_w = self._extra_padding(feature_w, rate)
                    # Pad by `rate` on both sides (room for the 3x3 "valid" conv at
                    # this dilation) plus extra trailing padding so the padded size
                    # is divisible by `rate`; the extra padding is cropped afterwards.
                    space_to_batchs.append(SpaceToBatch(rate, [[rate, rate + pad_h],
                                                               [rate, rate + pad_w]]))
                    batch_to_spaces.append(BatchToSpace(rate, [[0, pad_h], [0, pad_w]]))
            # Never populated; kept so the attribute still exists.
            self.aspp_depths = nn.CellList([])
            self.aspp_depth_spacetobatchs = nn.CellList(space_to_batchs)
            self.aspp_depth_batchtospaces = nn.CellList(batch_to_spaces)

        # Not used by construct(); kept so existing attributes are unchanged.
        self.global_pooling = nn.AvgPool2d(kernel_size=(int(feature_shape[0]), int(feature_shape[1])))
        global_poolings = []
        for scale_size in scale_sizes:
            # Global average pooling over the whole feature map at this scale.
            pooling_h = np.ceil((feature_shape[0] * scale_size) / output_stride)
            pooling_w = np.ceil((feature_shape[1] * scale_size) / output_stride)
            global_poolings.append(nn.AvgPool2d(kernel_size=(int(pooling_h), int(pooling_w))))
        self.global_poolings = nn.CellList(global_poolings)
        self.conv_bn = _conv_bn_relu(channel,
                                     depth,
                                     ksize=1,
                                     stride=1,
                                     use_batch_statistics=fine_tune_batch_norm)
        self.samples = nn.CellList([ASPPSampleBlock(feature_shape, scale_size, output_stride)
                                    for scale_size in scale_sizes])
        self.feature_shape = feature_shape
        self.concat = P.Concat(axis=1)

    @staticmethod
    def _extra_padding(size, rate):
        """Return the smallest p >= 0 such that ``size + 2 * rate + p`` is a multiple of ``rate``."""
        padded = size + 2 * rate
        return int(np.ceil(padded / rate) * rate - padded)

    def construct(self, x, scale_index=0):
        # 1x1 branch.
        aspp0 = self.aspp0(x)
        # Image-pooling branch: pool, project, resize back to feature-map size.
        aspp1 = self.global_poolings[scale_index](x)
        aspp1 = self.conv_bn(aspp1)
        aspp1 = self.samples[scale_index](aspp1)
        output = self.concat((aspp1, aspp0))

        # One atrous branch per rate, selecting the SpaceToBatch/BatchToSpace
        # pair built for the current scale.
        for i in range(len(self.atrous_rates)):
            aspp_i = self.aspp_depth_spacetobatchs[i + scale_index * len(self.atrous_rates)](x)
            aspp_i = self.aspp_depth_depthwiseconv(aspp_i)
            aspp_i = self.aspp_depth_batchtospaces[i + scale_index * len(self.atrous_rates)](aspp_i)
            aspp_i = self.aspp_depth_bn(aspp_i)
            aspp_i = self.aspp_depth_relu(aspp_i)
            aspp_i = self.aspp_pointwise(aspp_i)
            output = self.concat((output, aspp_i))
        return output
140
141
class DecoderSampleBlock(nn.Cell):
    """Bilinear resize to the decoder resolution.

    The target size along each dimension ``dim`` of ``feature_shape`` ([h, w]) is
    ``int((dim * scale_size + 1) / decoder_output_stride + 1)``.
    """

    def __init__(self, feature_shape, scale_size=1.0, decoder_output_stride=4):
        super(DecoderSampleBlock, self).__init__()
        target_size = tuple(int((feature_shape[axis] * scale_size + 1) / decoder_output_stride + 1)
                            for axis in (0, 1))
        self.sample = P.ResizeBilinear(target_size, align_corners=True)

    def construct(self, x):
        """Resize ``x`` to the decoder (height, width)."""
        resized = self.sample(x)
        return resized
152
153
class Decoder(nn.Cell):
    """
    Decode module for DeepLabv3.

    Projects the low-level feature to 48 channels, resizes both it and ``x`` to
    the decoder resolution for the current scale, concatenates them along axis 1
    and refines the result with two depthwise + pointwise conv blocks.

    Args:
        low_level_channel (int): Low level input channel.
        channel (int): Channel of the input ``x``.
        depth (int): Output channel.
        feature_shape (list): Spatial shape [h, w] of the network input.
        scale_sizes (list): Input scales for multi-scale feature extraction; one
            resize block is built per scale.
        decoder_output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Fine tune the batch norm parameters or not.

    Returns:
        Tensor, output tensor with ``depth`` channels.

    Examples:
        >>> Decoder(256, 256, 256, [513, 513], [1.0], 4, False)
    """
    def __init__(self,
                 low_level_channel,
                 channel,
                 depth,
                 feature_shape,
                 scale_sizes,
                 decoder_output_stride,
                 fine_tune_batch_norm):
        super(Decoder, self).__init__()
        # 1x1 projection of the low-level feature to 48 channels.
        self.feature_projection = _conv_bn_relu(low_level_channel, 48, ksize=1, stride=1,
                                                pad_mode="same", use_batch_statistics=fine_tune_batch_norm)
        # Input to the first refinement block is concat(x, projected low-level) = channel + 48.
        self.decoder_depth0 = _deep_conv_bn_relu(channel + 48,
                                                 channel_multiplier=1,
                                                 ksize=3,
                                                 stride=1,
                                                 pad_mode="same",
                                                 dilation=1,
                                                 use_batch_statistics=fine_tune_batch_norm)
        self.decoder_pointwise0 = _conv_bn_relu(channel + 48,
                                                depth,
                                                ksize=1,
                                                stride=1,
                                                use_batch_statistics=fine_tune_batch_norm)
        self.decoder_depth1 = _deep_conv_bn_relu(depth,
                                                 channel_multiplier=1,
                                                 ksize=3,
                                                 stride=1,
                                                 pad_mode="same",
                                                 dilation=1,
                                                 use_batch_statistics=fine_tune_batch_norm)
        self.decoder_pointwise1 = _conv_bn_relu(depth,
                                                depth,
                                                ksize=1,
                                                stride=1,
                                                use_batch_statistics=fine_tune_batch_norm)
        self.depth = depth
        self.concat = P.Concat(axis=1)
        self.samples = nn.CellList([DecoderSampleBlock(feature_shape, scale_size, decoder_output_stride)
                                    for scale_size in scale_sizes])

    def construct(self, x, low_level_feature, scale_index=0):
        low_level_feature = self.feature_projection(low_level_feature)
        # Bring both inputs to the same decoder resolution before concatenation.
        low_level_feature = self.samples[scale_index](low_level_feature)
        x = self.samples[scale_index](x)
        output = self.concat((x, low_level_feature))
        output = self.decoder_depth0(output)
        output = self.decoder_pointwise0(output)
        output = self.decoder_depth1(output)
        output = self.decoder_pointwise1(output)
        return output
222
223
class SingleDeepLabV3(nn.Cell):
    """
    DeepLabv3 Network.

    Args:
        num_classes (int): Class number.
        feature_shape (list): Input image shape, [N,C,H,W].
        backbone (Cell): Backbone network; it is called on the normalized, padded
            input and must return a ``(low_level_feature, feature_map)`` pair.
        channel (int): Backbone (ResNet) output channel.
        depth (int): ASPP block depth.
        scale_sizes (list): Input scales for multi-scale feature extraction; they
            are sorted ascending internally.
        atrous_rates (list): Atrous rates for atrous spatial pyramid pooling, or None.
        decoder_output_stride (int): The ratio of input to output spatial resolution
            of the decoder, or None to skip the decoder.
        output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Fine tune the batch norm parameters or not.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> SingleDeepLabV3(num_classes=10,
        ...                 feature_shape=[1, 3, 224, 224],
        ...                 backbone=resnet50_dl(False),
        ...                 channel=2048,
        ...                 depth=256,
        ...                 scale_sizes=[1.0],
        ...                 atrous_rates=[6],
        ...                 decoder_output_stride=4,
        ...                 output_stride=16)
    """

    def __init__(self,
                 num_classes,
                 feature_shape,
                 backbone,
                 channel,
                 depth,
                 scale_sizes,
                 atrous_rates,
                 decoder_output_stride,
                 output_stride,
                 fine_tune_batch_norm=False):
        super(SingleDeepLabV3, self).__init__()
        self.num_classes = num_classes
        self.channel = channel
        self.depth = depth
        # Ascending order matters: construct() picks the first scale whose size
        # covers the actual input height.
        self.scale_sizes = list(np.sort(scale_sizes))
        self.net = backbone
        self.aspp = ASPP(channel=self.channel,
                         depth=self.depth,
                         feature_shape=[feature_shape[2],
                                        feature_shape[3]],
                         scale_sizes=self.scale_sizes,
                         atrous_rates=atrous_rates,
                         output_stride=output_stride,
                         fine_tune_batch_norm=fine_tune_batch_norm)

        # ASPP output = image-pooling branch + 1x1 branch + one branch per atrous rate.
        atrous_rates_len = 0 if atrous_rates is None else len(atrous_rates)
        self.fc1 = _conv_bn_relu(depth * (2 + atrous_rates_len), depth,
                                 ksize=1,
                                 stride=1,
                                 use_batch_statistics=fine_tune_batch_norm)
        self.fc2 = nn.Conv2d(depth,
                             num_classes,
                             kernel_size=1,
                             stride=1,
                             has_bias=True)
        # Not used by construct(); kept so existing attributes are unchanged.
        self.upsample = P.ResizeBilinear((int(feature_shape[2]),
                                          int(feature_shape[3])),
                                         align_corners=True)
        self.samples = nn.CellList([SampleBlock(feature_shape, scale_size)
                                    for scale_size in self.scale_sizes])
        self.feature_shape = [float(feature_shape[i]) for i in range(4)]

        # One-pixel border on H and W of an NCHW input.
        self.pad = P.Pad(((0, 0), (0, 0), (1, 1), (1, 1)))
        # Not used by construct(); kept so existing attributes are unchanged.
        self.dropout = nn.Dropout(keep_prob=0.9)
        self.shape = P.Shape()
        self.decoder_output_stride = decoder_output_stride
        if decoder_output_stride is not None:
            self.decoder = Decoder(low_level_channel=depth,
                                   channel=depth,
                                   depth=depth,
                                   feature_shape=[feature_shape[2],
                                                  feature_shape[3]],
                                   scale_sizes=self.scale_sizes,
                                   decoder_output_stride=decoder_output_stride,
                                   fine_tune_batch_norm=fine_tune_batch_norm)

    def construct(self, x, scale_index=0):
        # Maps pixel values (assumed to be in [0, 255]) to [-1, 1].
        x = (2.0 / 255.0) * x - 1.0
        x = self.pad(x)
        low_level_feature, feature_map = self.net(x)
        # Select the first (smallest) configured scale that covers the actual
        # input height; shape(x)[2] - 2 removes the padding added above.
        for scale_size in self.scale_sizes:
            if scale_size * self.feature_shape[2] + 1.0 >= self.shape(x)[2] - 2:
                output = self.aspp(feature_map, scale_index)
                output = self.fc1(output)
                if self.decoder_output_stride is not None:
                    output = self.decoder(output, low_level_feature, scale_index)
                output = self.fc2(output)
                output = self.samples[scale_index](output)
                return output
            scale_index += 1
        # NOTE(review): when no configured scale matches, the raw backbone
        # feature map is returned instead of class logits.
        return feature_map
331
332
class SampleBlock(nn.Cell):
    """Bilinear resize to ``ceil(H * scale_size) x ceil(W * scale_size)``.

    ``feature_shape`` is indexed as [N, C, H, W]; only entries 2 and 3 are read.
    """

    def __init__(self,
                 feature_shape,
                 scale_size=1.0):
        super(SampleBlock, self).__init__()
        out_h, out_w = (int(np.ceil(float(dim) * scale_size))
                        for dim in (feature_shape[2], feature_shape[3]))
        self.sample = P.ResizeBilinear((out_h, out_w), align_corners=True)

    def construct(self, x):
        """Resize ``x`` to the scaled (height, width)."""
        resized = self.sample(x)
        return resized
345
346
class DeepLabV3(nn.Cell):
    """DeepLabV3 model with multi-scale training and inference.

    Wraps a SingleDeepLabV3 and runs it at several input scales. In training
    mode it uses the scales in ``image_pyramid`` and averages the logits. In
    inference mode it uses the scales in ``infer_scale_sizes`` and averages the
    softmax probabilities.

    Args:
        num_classes (int): Class number.
        feature_shape (list): Input image shape, [N,C,H,W].
        backbone (Cell): Backbone network handed to SingleDeepLabV3. If None, a
            ``resnet50_dl(fine_tune_batch_norm)`` backbone is created.
        channel (int): Backbone output channel.
        depth (int): ASPP block depth.
        infer_scale_sizes (list): Scales used at inference, or None for none.
            Inference needs at least one scale.
        atrous_rates (list): Atrous rates for ASPP, or None.
        decoder_output_stride (int): Decoder output stride, or None to skip the decoder.
        output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Fine tune the batch norm parameters or not.
        image_pyramid (list): Scales used during training; None means [1.0].

    Returns:
        Tensor, averaged logits (training) or averaged class probabilities
        (inference), reduced over the scale axis.
    """
    def __init__(self, num_classes, feature_shape, backbone, channel, depth, infer_scale_sizes, atrous_rates,
                 decoder_output_stride, output_stride, fine_tune_batch_norm, image_pyramid):
        super(DeepLabV3, self).__init__()
        # Normalize None to an empty list so iteration below and len() in
        # construct() work.
        self.infer_scale_sizes = infer_scale_sizes if infer_scale_sizes is not None else []
        if image_pyramid is None:
            image_pyramid = [1.0]
        self.image_pyramid = image_pyramid

        # Sample blocks are laid out as [training scales..., inference scales...];
        # construct() relies on this (inference indices start at len(image_pyramid)).
        scale_sizes = list(image_pyramid) + list(self.infer_scale_sizes)
        self.samples = nn.CellList([SampleBlock(feature_shape, scale_size) for scale_size in scale_sizes])
        # Honour the caller's backbone; only build the default one when none is given.
        if backbone is None:
            backbone = resnet50_dl(fine_tune_batch_norm)
        self.deeplabv3 = SingleDeepLabV3(num_classes=num_classes,
                                         feature_shape=feature_shape,
                                         backbone=backbone,
                                         channel=channel,
                                         depth=depth,
                                         scale_sizes=scale_sizes,
                                         atrous_rates=atrous_rates,
                                         decoder_output_stride=decoder_output_stride,
                                         output_stride=output_stride,
                                         fine_tune_batch_norm=fine_tune_batch_norm)
        self.softmax = P.Softmax(axis=1)
        # Per-scale results are stacked on a new axis 2 and averaged over it.
        self.concat = P.Concat(axis=2)
        self.expand_dims = P.ExpandDims()
        self.reduce_mean = P.ReduceMean()
        self.sample_common = P.ResizeBilinear((int(feature_shape[2]),
                                               int(feature_shape[3])),
                                              align_corners=True)

    def construct(self, x):
        logits = ()
        if self.training:
            if len(self.image_pyramid) >= 1:
                if self.image_pyramid[0] == 1:
                    # Scale 1: the input is used as-is.
                    logits = self.deeplabv3(x)
                else:
                    x1 = self.samples[0](x)
                    logits = self.deeplabv3(x1)
                    logits = self.sample_common(logits)
                logits = self.expand_dims(logits, 2)
                for i in range(len(self.image_pyramid) - 1):
                    x_i = self.samples[i + 1](x)
                    logits_i = self.deeplabv3(x_i)
                    logits_i = self.sample_common(logits_i)
                    logits_i = self.expand_dims(logits_i, 2)
                    logits = self.concat((logits, logits_i))
            logits = self.reduce_mean(logits, 2)
            return logits
        if len(self.infer_scale_sizes) >= 1:
            # Inference sample blocks follow the training ones.
            infer_index = len(self.image_pyramid)
            x1 = self.samples[infer_index](x)
            logits = self.deeplabv3(x1)
            logits = self.sample_common(logits)
            logits = self.softmax(logits)
            logits = self.expand_dims(logits, 2)
            for i in range(len(self.infer_scale_sizes) - 1):
                x_i = self.samples[i + 1 + infer_index](x)
                logits_i = self.deeplabv3(x_i)
                logits_i = self.sample_common(logits_i)
                logits_i = self.softmax(logits_i)
                logits_i = self.expand_dims(logits_i, 2)
                logits = self.concat((logits, logits_i))
        logits = self.reduce_mean(logits, 2)
        return logits
423
424
def deeplabv3_resnet50(num_classes, feature_shape, image_pyramid,
                       infer_scale_sizes, atrous_rates=None, decoder_output_stride=None,
                       output_stride=16, fine_tune_batch_norm=False):
    """
    Build a ResNet50-based DeepLabv3 network.

    Args:
        num_classes (int): Class number.
        feature_shape (list): Input image shape, [N,C,H,W].
        image_pyramid (list): Input scales used for multi-scale training.
        infer_scale_sizes (list): The scales to resize images to for inference.
        atrous_rates (list): Atrous rates for atrous spatial pyramid pooling. Default: None.
        decoder_output_stride (int): The ratio of input to output spatial resolution of
            the decoder; None disables the decoder. Default: None.
        output_stride (int): The ratio of input to output spatial resolution. Default: 16.
        fine_tune_batch_norm (bool): Fine tune the batch norm parameters or not. Default: False.

    Returns:
        Cell, cell instance of ResNet50 based DeepLabv3 neural network.

    Examples:
        >>> deeplabv3_resnet50(100, [1, 3, 224, 224], [1.0], [1.0])
    """
    resnet_backbone = resnet50_dl(fine_tune_batch_norm)
    # ResNet50 emits 2048 channels; the ASPP/decoder width is fixed at 256.
    model = DeepLabV3(num_classes=num_classes,
                      feature_shape=feature_shape,
                      backbone=resnet_backbone,
                      channel=2048,
                      depth=256,
                      infer_scale_sizes=infer_scale_sizes,
                      atrous_rates=atrous_rates,
                      decoder_output_stride=decoder_output_stride,
                      output_stride=output_stride,
                      fine_tune_batch_norm=fine_tune_batch_norm,
                      image_pyramid=image_pyramid)
    return model
458