• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15
16"""YOLOv3 based on ResNet18."""
17
18import numpy as np
19import mindspore as ms
20import mindspore.nn as nn
21from mindspore import context, Tensor
22from mindspore.context import ParallelMode
23from mindspore.parallel._auto_parallel_context import auto_parallel_context
24from mindspore.communication.management import get_group_size
25from mindspore.common.initializer import TruncatedNormal
26from mindspore.ops import operations as P
27from mindspore.ops import functional as F
28from mindspore.ops import composite as C
29
30
def weight_variable():
    """Return the shared weight initializer, ``TruncatedNormal(0.02)``."""
    initializer = TruncatedNormal(0.02)
    return initializer
34
35
class _conv2d(nn.Cell):
    """
    Cell wrapping one ``nn.Conv2d`` configured with ``pad_mode='same'``.

    Args:
        in_channels (int): Input channel.
        out_channels (int): Output channel.
        kernel_size (int): Kernel size handed to ``nn.Conv2d``.
        stride (int): Stride handed to ``nn.Conv2d``. Default: 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super(_conv2d, self).__init__()
        # Collect the fixed convolution options in one place before building the layer.
        conv_options = dict(kernel_size=kernel_size,
                            stride=stride,
                            padding=0,
                            pad_mode='same',
                            weight_init=weight_variable())
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_options)

    def construct(self, x):
        """Apply the wrapped convolution to ``x``."""
        return self.conv(x)
46
47
def _fused_bn(channels, momentum=0.99):
    """Build an ``nn.BatchNorm2d`` over ``channels`` with the given momentum."""
    batch_norm = nn.BatchNorm2d(channels, momentum=momentum)
    return batch_norm
51
52
def _conv_bn_relu(in_channel,
                  out_channel,
                  ksize,
                  stride=1,
                  padding=0,
                  dilation=1,
                  alpha=0.1,
                  momentum=0.99,
                  pad_mode="same"):
    """
    Build a Conv2d -> BatchNorm2d -> LeakyReLU stack as one SequentialCell.

    Args:
        in_channel (int): Input channel of the convolution.
        out_channel (int): Output channel of the convolution and the batchnorm.
        ksize (int): Convolution kernel size.
        stride (int): Convolution stride. Default: 1.
        padding (int): Convolution padding. Default: 0.
        dilation (int): Convolution dilation. Default: 1.
        alpha (float): Argument passed to ``nn.LeakyReLU``. Default: 0.1.
        momentum (float): Momentum for the batchnorm layer. Default: 0.99.
        pad_mode (str): Convolution pad mode. Default: "same".

    Returns:
        SequentialCell, the three stacked layers.
    """
    convolution = nn.Conv2d(in_channel,
                            out_channel,
                            kernel_size=ksize,
                            stride=stride,
                            padding=padding,
                            dilation=dilation,
                            pad_mode=pad_mode)
    normalization = nn.BatchNorm2d(out_channel, momentum=momentum)
    activation = nn.LeakyReLU(alpha)
    return nn.SequentialCell([convolution, normalization, activation])
74
75
class BasicBlock(nn.Cell):
    """
    ResNet basic block: two 3x3 convolutions plus a residual shortcut.

    Args:
        in_channels (int): Input channel.
        out_channels (int): Output channel.
        stride (int): Stride size for the initial convolutional layer. Default: 1.
        momentum (float): Momentum for batchnorm layer. Default: 0.99.

    Returns:
        Tensor, output tensor.

    Examples:
        BasicBlock(64, 128, stride=2).
    """
    expansion = 1

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 momentum=0.99):
        super(BasicBlock, self).__init__()

        self.conv1 = _conv2d(in_channels, out_channels, 3, stride=stride)
        self.bn1 = _fused_bn(out_channels, momentum=momentum)
        self.conv2 = _conv2d(out_channels, out_channels, 3)
        self.bn2 = _fused_bn(out_channels, momentum=momentum)
        self.relu = P.ReLU()
        self.down_sample_layer = None
        # The shortcut must be projected whenever the main branch changes the
        # tensor shape: a different channel count, or a strided first
        # convolution that changes the spatial size. Checking only the channels
        # would leave a strided, same-width block adding mismatched tensors.
        self.downsample = (in_channels != out_channels) or (stride != 1)
        if self.downsample:
            self.down_sample_layer = _conv2d(in_channels, out_channels, 1, stride=stride)
        self.add = P.Add()

    def construct(self, x):
        """Run the two conv/bn stages and add the (optionally projected) input."""
        identity = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        # No activation after the second batchnorm; it is applied after the add.
        x = self.conv2(x)
        x = self.bn2(x)

        if self.downsample:
            identity = self.down_sample_layer(identity)

        out = self.add(x, identity)
        out = self.relu(out)

        return out
129
130
class ResNet(nn.Cell):
    """
    ResNet network.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Number of blocks in each of the four stages.
        in_channels (list): Input channel of each stage.
        out_channels (list): Output channel of each stage.
        strides (list): Stride of the first block of each stage.
            Default: None, which is treated as [1, 2, 2, 2].
        num_classes (int): Class number. When falsy (None or 0) no
            classification head is built. Default: 80.

    Returns:
        Tuple of three tensors ``(c3, c4, out)``: the outputs of ``layer2`` and
        ``layer3``, followed by either the output of ``layer4`` (no head) or the
        output of the dense head (when ``num_classes`` is set).

    Raises:
        ValueError: If layer_nums, in_channels or out_channels is not of length 4.

    Examples:
        ResNet(BasicBlock,
               [2, 2, 2, 2],
               [64, 64, 128, 256],
               [64, 128, 256, 512],
               [1, 2, 2, 2],
               80).
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides=None,
                 num_classes=80):
        super(ResNet, self).__init__()

        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the length of "
                             "layer_num, inchannel, outchannel list must be 4!")

        # The declared default (None) used to be indexed directly and crash with
        # a TypeError; fall back to the strides resnet18() in this file uses.
        if strides is None:
            strides = [1, 2, 2, 2]

        self.conv1 = _conv2d(3, 64, 7, stride=2)
        self.bn1 = _fused_bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0])
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1])
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2])
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3])

        self.num_classes = num_classes
        if num_classes:
            # Classification head: spatial mean -> squeeze -> dense.
            self.reduce_mean = P.ReduceMean(keep_dims=True)
            self.end_point = nn.Dense(out_channels[3], num_classes, has_bias=True,
                                      weight_init=weight_variable(),
                                      bias_init=weight_variable())
            self.squeeze = P.Squeeze(axis=(2, 3))

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
        """
        Make Layer for ResNet.

        The first block maps ``in_channel`` to ``out_channel`` with ``stride``;
        every following block keeps ``out_channel`` and uses stride 1. At least
        one block is always created.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Number of blocks in the layer.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the initial convolutional layer.

        Returns:
            SequentialCell, the output layer.

        Examples:
            _make_layer(BasicBlock, 3, 128, 256, 2).
        """
        layers = [block(in_channel, out_channel, stride=stride)]

        # range(1, layer_num) yields exactly layer_num blocks in total. The
        # previous "middle blocks + unconditional last block" form produced the
        # same count for layer_num >= 2 but built two blocks for layer_num == 1.
        for _ in range(1, layer_num):
            layers.append(block(out_channel, out_channel, stride=1))

        return nn.SequentialCell(layers)

    def construct(self, x):
        """Run the stem and the four stages; see the class docstring for the outputs."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        c1 = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        out = c5
        if self.num_classes:
            out = self.reduce_mean(c5, (2, 3))
            out = self.squeeze(out)
            out = self.end_point(out)

        return c3, c4, out
249
250
def resnet18(class_num=10):
    """
    Build the ResNet18 variant of ``ResNet`` out of ``BasicBlock`` cells.

    Args:
        class_num (int): Class number forwarded as ``num_classes``. Default: 10.

    Returns:
        Cell, cell instance of ResNet18 neural network.

    Examples:
        resnet18(100).
    """
    blocks_per_stage = [2, 2, 2, 2]
    stage_in_channels = [64, 64, 128, 256]
    stage_out_channels = [64, 128, 256, 512]
    stage_strides = [1, 2, 2, 2]
    network = ResNet(BasicBlock,
                     blocks_per_stage,
                     stage_in_channels,
                     stage_out_channels,
                     stage_strides,
                     num_classes=class_num)
    return network
270
271
class YoloBlock(nn.Cell):
    """
    YoloBlock for YOLOv3: six conv-bn-activation layers followed by a 1x1 output convolution.

    Args:
        in_channels (int): Input channel.
        out_chls (int): Middle channel; the 3x3 layers use twice this width.
        out_channels (int): Output channel of the final convolution.

    Returns:
        Tuple of two tensors: the output of ``conv4`` and the output of ``conv6``.

    Examples:
        YoloBlock(1024, 512, 255).

    """
    def __init__(self, in_channels, out_chls, out_channels):
        super(YoloBlock, self).__init__()
        narrow = out_chls
        wide = out_chls * 2

        # Alternate 1x1 (narrow) and 3x3 (wide) layers, three pairs in total.
        self.conv0 = _conv_bn_relu(in_channels, narrow, ksize=1)
        self.conv1 = _conv_bn_relu(narrow, wide, ksize=3)

        self.conv2 = _conv_bn_relu(wide, narrow, ksize=1)
        self.conv3 = _conv_bn_relu(narrow, wide, ksize=3)

        self.conv4 = _conv_bn_relu(wide, narrow, ksize=1)
        self.conv5 = _conv_bn_relu(narrow, wide, ksize=3)

        self.conv6 = nn.Conv2d(wide, out_channels, kernel_size=1, stride=1, has_bias=True)

    def construct(self, x):
        """Return ``(conv4 output, conv6 output)`` for input ``x``."""
        feat = self.conv0(x)
        feat = self.conv1(feat)

        feat = self.conv2(feat)
        feat = self.conv3(feat)

        route = self.conv4(feat)
        head = self.conv5(route)

        return route, self.conv6(head)
315
316
class YOLOv3(nn.Cell):
    """
     YOLOv3 Network.

     Note:
         The backbone is called as ``backbone(x)`` and must return three feature maps.

     Args:
         feature_shape (list): Input image shape, [N,C,H,W]; indices 2 and 3 set the upsample sizes.
         backbone_shape (list): Backbone output channels; the last four entries are used.
         backbone (Cell): Backbone Network.
         out_channel (int): Output channel of every YoloBlock head.

     Returns:
         Tuple of three tensors: the head outputs of ``backblock0``, ``backblock1`` and ``backblock2``.

     Examples:
         YOLOv3(feature_shape=[1,3,416,416],
                backbone_shape=[64, 128, 256, 512],
                backbone=ResNet(BasicBlock, [2, 2, 2, 2], [64, 64, 128, 256],
                                [64, 128, 256, 512], [1, 2, 2, 2], num_classes=None),
                out_channel=255).
     """
    def __init__(self, feature_shape, backbone_shape, backbone, out_channel):
        super(YOLOv3, self).__init__()
        # Channel counts of the last four backbone entries, deepest first.
        ch_last = backbone_shape[-1]
        ch_second = backbone_shape[-2]
        ch_third = backbone_shape[-3]
        ch_fourth = backbone_shape[-4]
        in_height = feature_shape[2]
        in_width = feature_shape[3]

        self.out_channel = out_channel
        self.net = backbone
        self.backblock0 = YoloBlock(ch_last, out_chls=ch_second, out_channels=out_channel)

        self.conv1 = _conv_bn_relu(in_channel=ch_second, out_channel=ch_second//2, ksize=1)
        self.upsample1 = P.ResizeNearestNeighbor((in_height//16, in_width//16))
        self.backblock1 = YoloBlock(in_channels=ch_second+ch_third,
                                    out_chls=ch_third,
                                    out_channels=out_channel)

        self.conv2 = _conv_bn_relu(in_channel=ch_third, out_channel=ch_third//2, ksize=1)
        self.upsample2 = P.ResizeNearestNeighbor((in_height//8, in_width//8))
        self.backblock2 = YoloBlock(in_channels=ch_third+ch_fourth,
                                    out_chls=ch_fourth,
                                    out_channels=out_channel)
        self.concat = P.Concat(axis=1)

    def construct(self, x):
        """Run the backbone and the three chained heads on ``x``."""
        # The upsample targets are (h/16, w/16) and (h/8, w/8), so the middle and
        # first backbone maps are presumably at those resolutions - confirm
        # against the backbone in use.
        map_first, map_middle, map_last = self.net(x)
        route, big_object_output = self.backblock0(map_last)

        # Head 2: shrink, upsample and concatenate with the middle backbone map on axis 1.
        route = self.conv1(route)
        route = self.upsample1(route)
        merged = self.concat((route, map_middle))
        route, medium_object_output = self.backblock1(merged)

        # Head 3: same pattern with the first backbone map.
        route = self.conv2(route)
        route = self.upsample2(route)
        merged = self.concat((route, map_first))
        _, small_object_output = self.backblock2(merged)

        return big_object_output, medium_object_output, small_object_output
377
378
class DetectionBlock(nn.Cell):
    """
     YOLOv3 detection Network. It decodes one raw head output into box predictions.

     Args:
         scale (str): One of 's', 'm', 'l'; selects anchors 0-2, 3-5 or 6-8 of ``config.anchor_scales``.
         config (Class): YOLOv3 config; ``anchor_scales``, ``num_classes`` and ``img_shape`` are read.

     Returns:
         Tuple of four tensors. In training mode ``(grid, prediction, box_xy, box_wh)``,
         otherwise ``(box_xy, box_wh, box_confidence, box_probs)``.

     Raises:
         KeyError: If ``scale`` is not 's', 'm' or 'l'.

     Examples:
         DetectionBlock('l', config).
     """

    def __init__(self, scale, config):
        super(DetectionBlock, self).__init__()

        self.config = config
        # Look the scale up in a fixed table; the for/else raises when nothing matched.
        for known_scale, anchor_ids in (('s', (0, 1, 2)), ('m', (3, 4, 5)), ('l', (6, 7, 8))):
            if scale == known_scale:
                idx = anchor_ids
                break
        else:
            raise KeyError("Invalid scale value for DetectionBlock")
        self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
        self.num_anchors_per_scale = 3
        self.num_attrib = 4 + 1 + self.config.num_classes
        self.ignore_threshold = 0.5
        self.lambda_coord = 1

        self.sigmoid = nn.Sigmoid()
        self.reshape = P.Reshape()
        self.tile = P.Tile()
        self.concat = P.Concat(axis=-1)
        self.input_shape = Tensor(tuple(config.img_shape[::-1]), ms.float32)

    def construct(self, x):
        """Decode the head output ``x``; see the class docstring for the outputs."""
        feature_shape = P.Shape()(x)
        num_batch = feature_shape[0]
        grid_size = feature_shape[2:4]

        # [n, anchors * attrib, g0, g1] -> [n, anchors, attrib, g0, g1] -> [n, g0, g1, anchors, attrib]
        split_shape = (num_batch,
                       self.num_anchors_per_scale,
                       self.num_attrib,
                       grid_size[0],
                       grid_size[1])
        prediction = P.Reshape()(x, split_shape)
        prediction = P.Transpose()(prediction, (0, 3, 4, 1, 2))

        # Per-cell index along each grid axis, as float tensors.
        col_ids = range(grid_size[1])
        row_ids = range(grid_size[0])
        grid_x = P.Cast()(F.tuple_to_array(col_ids), ms.float32)
        grid_y = P.Cast()(F.tuple_to_array(row_ids), ms.float32)
        # Both become [1, g0, g1, 1, 1]: grid_x varies along axis 2, grid_y along axis 1.
        grid_x = self.tile(self.reshape(grid_x, (1, 1, -1, 1, 1)), (1, grid_size[0], 1, 1, 1))
        grid_y = self.tile(self.reshape(grid_y, (1, -1, 1, 1, 1)), (1, 1, grid_size[1], 1, 1))
        # Joined on the last axis: [1, g0, g1, 1, 2] holding (x, y) per cell.
        grid = self.concat((grid_x, grid_y))

        # Attribute layout on the last axis: xy (0:2), wh (2:4), confidence (4), class scores (5:).
        raw_xy = prediction[:, :, :, :, :2]
        raw_wh = prediction[:, :, :, :, 2:4]
        raw_confidence = prediction[:, :, :, :, 4:5]
        raw_probs = prediction[:, :, :, :, 5:]

        grid_dims = P.Cast()(F.tuple_to_array((grid_size[1], grid_size[0])), ms.float32)
        box_xy = (self.sigmoid(raw_xy) + grid) / grid_dims
        box_wh = P.Exp()(raw_wh) * self.anchors / self.input_shape
        box_confidence = self.sigmoid(raw_confidence)
        box_probs = self.sigmoid(raw_probs)

        if self.training:
            return grid, prediction, box_xy, box_wh
        return box_xy, box_wh, box_confidence, box_probs
453
454
class Iou(nn.Cell):
    """Calculate the iou of boxes stored as (x, y, w, h) on the last axis of 6-D tensors."""
    def __init__(self):
        super(Iou, self).__init__()
        self.min = P.Minimum()
        self.max = P.Maximum()

    def construct(self, box1, box2):
        """Return intersection area divided by union area for ``box1`` and ``box2``."""
        two = F.scalar_to_array(2.0)

        # Convert centre/size form into min/max corners for both inputs.
        xy_a = box1[:, :, :, :, :, :2]
        wh_a = box1[:, :, :, :, :, 2:4]
        mins_a = xy_a - wh_a / two
        maxs_a = xy_a + wh_a / two

        xy_b = box2[:, :, :, :, :, :2]
        wh_b = box2[:, :, :, :, :, 2:4]
        mins_b = xy_b - wh_b / two
        maxs_b = xy_b + wh_b / two

        # Overlap extent per axis, clamped at zero when the boxes do not overlap.
        overlap_mins = self.max(mins_a, mins_b)
        overlap_maxs = self.min(maxs_a, maxs_b)
        overlap_wh = self.max(overlap_maxs - overlap_mins, F.scalar_to_array(0.0))

        overlap_area = P.Squeeze(-1)(overlap_wh[:, :, :, :, :, 0:1]) * \
                       P.Squeeze(-1)(overlap_wh[:, :, :, :, :, 1:2])
        area_a = P.Squeeze(-1)(wh_a[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(wh_a[:, :, :, :, :, 1:2])
        area_b = P.Squeeze(-1)(wh_b[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(wh_b[:, :, :, :, :, 1:2])

        union_area = area_a + area_b - overlap_area
        return overlap_area / union_area
484
485
class YoloLossBlock(nn.Cell):
    """
     YOLOv3 Loss block cell. It will finally output loss of the scale.

     Args:
         scale (str): Three scale here, 's', 'm' and 'l'; selects anchors 0-2, 3-5 or 6-8.
         config (Class): The default config of YOLOv3; ``anchor_scales``,
             ``ignore_threshold`` and ``img_shape`` are read.

     Returns:
         Tensor, loss of the scale.

     Raises:
         KeyError: If ``scale`` is not 's', 'm' or 'l'.

     Examples:
         YoloLossBlock('l', ConfigYOLOV3ResNet18()).
     """

    def __init__(self, scale, config):
        super(YoloLossBlock, self).__init__()
        self.config = config
        if scale == 's':
            idx = (0, 1, 2)
        elif scale == 'm':
            idx = (3, 4, 5)
        elif scale == 'l':
            idx = (6, 7, 8)
        else:
            # The message previously named DetectionBlock (copy-paste), which
            # pointed anyone debugging a bad scale at the wrong class.
            raise KeyError("Invalid scale value for YoloLossBlock")
        self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
        self.ignore_threshold = Tensor(self.config.ignore_threshold, ms.float32)
        self.concat = P.Concat(axis=-1)
        self.iou = Iou()
        self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
        self.reduce_sum = P.ReduceSum()
        self.reduce_max = P.ReduceMax(keep_dims=False)
        self.input_shape = Tensor(tuple(config.img_shape[::-1]), ms.float32)

    def construct(self, grid, prediction, pred_xy, pred_wh, y_true, gt_box):
        """
        Compute the loss of this scale.

        Args:
            grid, prediction, pred_xy, pred_wh: Passed positionally by
                YoloWithLossCell from the network's per-scale output.
            y_true (Tensor): Target tensor; last axis is xy (0:2), wh (2:4),
                object mask (4), class targets (5:).
            gt_box (Tensor): 3-D tensor of ground-truth boxes, reshaped below
                to broadcast against the predictions.

        Returns:
            Tensor, summed loss divided by the first dimension of ``prediction``.
        """

        object_mask = y_true[:, :, :, :, 4:5]
        class_probs = y_true[:, :, :, :, 5:]

        # Grid dimensions in reversed order, as a float tensor.
        grid_shape = P.Shape()(prediction)[1:3]
        grid_shape = P.Cast()(F.tuple_to_array(grid_shape[::-1]), ms.float32)

        pred_boxes = self.concat((pred_xy, pred_wh))
        true_xy = y_true[:, :, :, :, :2] * grid_shape - grid
        true_wh = y_true[:, :, :, :, 2:4]
        # Replace zero sizes by 1.0 so the log below is not taken of zero.
        true_wh = P.Select()(P.Equal()(true_wh, 0.0),
                             P.Fill()(P.DType()(true_wh), P.Shape()(true_wh), 1.0),
                             true_wh)
        true_wh = P.Log()(true_wh / self.anchors * self.input_shape)
        # Weight is 2 - w * h of the target box.
        box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4]

        gt_shape = P.Shape()(gt_box)
        gt_box = P.Reshape()(gt_box, (gt_shape[0], 1, 1, 1, gt_shape[1], gt_shape[2]))

        iou = self.iou(P.ExpandDims()(pred_boxes, -2), gt_box) # [batch, grid[0], grid[1], num_anchor, num_gt]
        best_iou = self.reduce_max(iou, -1) # [batch, grid[0], grid[1], num_anchor]
        # 1.0 where the best IoU is below the threshold; used as a constant mask (no gradient).
        ignore_mask = best_iou < self.ignore_threshold
        ignore_mask = P.Cast()(ignore_mask, ms.float32)
        ignore_mask = P.ExpandDims()(ignore_mask, -1)
        ignore_mask = F.stop_gradient(ignore_mask)

        xy_loss = object_mask * box_loss_scale * self.cross_entropy(prediction[:, :, :, :, :2], true_xy)
        wh_loss = object_mask * box_loss_scale * 0.5 * P.Square()(true_wh - prediction[:, :, :, :, 2:4])
        confidence_loss = self.cross_entropy(prediction[:, :, :, :, 4:5], object_mask)
        # Object cells always count; non-object cells count only where ignore_mask is 1.
        confidence_loss = object_mask * confidence_loss + (1 - object_mask) * confidence_loss * ignore_mask
        class_loss = object_mask * self.cross_entropy(prediction[:, :, :, :, 5:], class_probs)

        # Reduce every term to a scalar by summing over all axes.
        xy_loss = self.reduce_sum(xy_loss, ())
        wh_loss = self.reduce_sum(wh_loss, ())
        confidence_loss = self.reduce_sum(confidence_loss, ())
        class_loss = self.reduce_sum(class_loss, ())

        loss = xy_loss + wh_loss + confidence_loss + class_loss
        return loss / P.Shape()(prediction)[0]
562
563
class yolov3_resnet18(nn.Cell):
    """
    ResNet based YOLOv3 network.

    Args:
        config (Class): YOLOv3 config. The attributes ``feature_shape``, ``backbone_layers``,
            ``backbone_input_shape``, ``backbone_shape``, ``backbone_stride`` and ``out_channel``
            are read here; the detection blocks read further attributes.

    Returns:
        Tuple of three detection-block outputs, ordered big, medium, small.

    Examples:
        yolov3_resnet18(config).
    """

    def __init__(self, config):
        super(yolov3_resnet18, self).__init__()
        self.config = config
        cfg = self.config

        # Backbone without a classification head (num_classes=None).
        resnet_backbone = ResNet(BasicBlock,
                                 cfg.backbone_layers,
                                 cfg.backbone_input_shape,
                                 cfg.backbone_shape,
                                 cfg.backbone_stride,
                                 num_classes=None)

        # YOLOv3 network
        self.feature_map = YOLOv3(feature_shape=cfg.feature_shape,
                                  backbone=resnet_backbone,
                                  backbone_shape=cfg.backbone_shape,
                                  out_channel=cfg.out_channel)

        # One detection block per scale, using that scale's anchors.
        self.detect_1 = DetectionBlock('l', cfg)
        self.detect_2 = DetectionBlock('m', cfg)
        self.detect_3 = DetectionBlock('s', cfg)

    def construct(self, x):
        """Run the feature network and decode each of its three outputs."""
        raw_big, raw_medium, raw_small = self.feature_map(x)
        output_big = self.detect_1(raw_big)
        output_me = self.detect_2(raw_medium)
        output_small = self.detect_3(raw_small)

        return output_big, output_me, output_small
605
606
class YoloWithLossCell(nn.Cell):
    """
    Provide YOLOv3 training loss through network.

    Args:
        network (Cell): The training network; its output is indexed as three
            per-scale tuples of four tensors each.
        config (Class): YOLOv3 config.

    Returns:
        Tensor, the loss of the network (sum of the three per-scale losses).
    """
    def __init__(self, network, config):
        super(YoloWithLossCell, self).__init__()
        self.yolo_network = network
        self.config = config
        self.loss_big = YoloLossBlock('l', self.config)
        self.loss_me = YoloLossBlock('m', self.config)
        self.loss_small = YoloLossBlock('s', self.config)

    def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2):
        """Run the network on ``x`` and add up the losses of the three scales."""
        yolo_out = self.yolo_network(x)
        out_l = yolo_out[0]
        out_m = yolo_out[1]
        out_s = yolo_out[2]

        # Each loss block takes the four tensors of its scale, then that scale's targets.
        loss_l = self.loss_big(out_l[0], out_l[1], out_l[2], out_l[3], y_true_0, gt_0)
        loss_m = self.loss_me(out_m[0], out_m[1], out_m[2], out_m[3], y_true_1, gt_1)
        loss_s = self.loss_small(out_s[0], out_s[1], out_s[2], out_s[3], y_true_2, gt_2)
        return loss_l + loss_m + loss_s
632
633
class TrainingWrapper(nn.Cell):
    """
    Encapsulation class of YOLOv3 network training.

    Wraps a loss-producing network together with an optimizer so that one call
    of ``construct`` computes the loss, the gradients and the weight update.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        sens (Number): Value used to fill the sensitivity tensor passed to the
            gradient operation. Default: 1.0.
    """
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ms.ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.grad_reducer = None
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        # Gradients are reduced across devices only in data/hybrid parallel mode.
        self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            # Use the configured device number when one was set, else the group size.
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, *args):
        """Compute the loss for ``args``, apply one optimizer step and return the loss."""
        params = self.weights
        loss = self.network(*args)
        # Sensitivity tensor: same dtype and shape as the loss, filled with self.sens.
        sens_tensor = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        gradients = self.grad(self.network, params)(*args, sens_tensor)
        if self.reducer_flag:
            # apply grad reducer on grads
            gradients = self.grad_reducer(gradients)
        return F.depend(loss, self.optimizer(gradients))
675
676
class YoloBoxScores(nn.Cell):
    """
    Calculate the boxes of the original picture size and the score of each box.

    Args:
        config (Class): YOLOv3 config; ``img_shape`` and ``num_classes`` are read.

    Returns:
        Tensor, the boxes of the original picture size, reshaped to (batch, -1, 4).
        Tensor, the score of each box, reshaped to (batch, -1, num_classes).
    """
    def __init__(self, config):
        super(YoloBoxScores, self).__init__()
        self.input_shape = Tensor(np.array(config.img_shape), ms.float32)
        self.num_classes = config.num_classes

    def construct(self, box_xy, box_wh, box_confidence, box_probs, image_shape):
        """Turn centre/size predictions into corner boxes scaled by ``image_shape``, plus scores."""
        batch_size = F.shape(box_xy)[0]

        # Swap the two components on the last axis: (x, y) -> (y, x) and (w, h) -> (h, w).
        box_yx = P.Concat(-1)((box_xy[:, :, :, :, 1:2], box_xy[:, :, :, :, 0:1]))
        box_hw = P.Concat(-1)((box_wh[:, :, :, :, 1:2], box_wh[:, :, :, :, 0:1]))

        # Size of the image after scaling it by the smallest input/image ratio,
        # then the offset and scale that map network coordinates onto that area.
        # NOTE(review): presumably undoes a letterbox resize - confirm against the data pipeline.
        resize_ratio = P.ReduceMin()(self.input_shape / image_shape)
        new_shape = P.Round()(image_shape * resize_ratio)
        offset = (self.input_shape - new_shape) / 2.0 / self.input_shape
        scale = self.input_shape / new_shape
        box_yx = (box_yx - offset) * scale
        box_hw = box_hw * scale

        # Corners: centre minus/plus half the size.
        half_hw = box_hw / 2.0
        box_min = box_yx - half_hw
        box_max = box_yx + half_hw
        boxes = P.Concat(-1)((box_min[:, :, :, :, 0:1],
                              box_min[:, :, :, :, 1:2],
                              box_max[:, :, :, :, 0:1],
                              box_max[:, :, :, :, 1:2]))
        # Repeat image_shape twice on its last axis so it multiplies all four coordinates.
        boxes = boxes * P.Tile()(image_shape, (1, 2))
        boxes = F.reshape(boxes, (batch_size, -1, 4))

        boxes_scores = F.reshape(box_confidence * box_probs, (batch_size, -1, self.num_classes))
        return boxes, boxes_scores
720
721
class YoloWithEval(nn.Cell):
    """
    Encapsulation class of YOLOv3 evaluation.

    Args:
        network (Cell): The training network. Note that loss function and optimizer must not be added.
        config (Class): YOLOv3 config.

    Returns:
        Tensor, the boxes of the original picture size, concatenated over the three scales.
        Tensor, the score of each box, concatenated over the three scales.
        Tensor, the original picture size (the ``image_shape`` input, passed through).
    """
    def __init__(self, network, config):
        super(YoloWithEval, self).__init__()
        self.yolo_network = network
        self.box_score_0 = YoloBoxScores(config)
        self.box_score_1 = YoloBoxScores(config)
        self.box_score_2 = YoloBoxScores(config)

    def construct(self, x, image_shape):
        """Run the network on ``x`` and convert each scale's output into boxes and scores."""
        yolo_output = self.yolo_network(x)
        boxes_0, scores_0 = self.box_score_0(*yolo_output[0], image_shape)
        boxes_1, scores_1 = self.box_score_1(*yolo_output[1], image_shape)
        boxes_2, scores_2 = self.box_score_2(*yolo_output[2], image_shape)

        # Join the three scales along axis 1 (the per-image box axis).
        all_boxes = P.Concat(1)((boxes_0, boxes_1, boxes_2))
        all_scores = P.Concat(1)((scores_0, scores_1, scores_2))
        return all_boxes, all_scores, image_shape
750