• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Util class or function."""
16from mindspore.train.serialization import load_checkpoint
17import mindspore.nn as nn
18import mindspore.common.dtype as mstype
19
20from .yolo import YoloLossBlock
21
22
class AverageMeter:
    """Computes and stores the average and current value.

    Optionally mirrors each update to a TensorBoard-style writer.

    Args:
        name (str): tag used when logging to ``tb_writer`` and in ``__str__``.
        fmt (str): format spec applied to the running average in ``__str__``.
        tb_writer: optional writer exposing ``add_scalar(tag, value, step)``;
            if None, no external logging is performed.
    """

    def __init__(self, name, fmt=':f', tb_writer=None):
        self.name = name
        self.fmt = fmt
        self.tb_writer = tb_writer
        self.cur_step = 1
        # reset() initializes val/avg/sum/count; the original code then
        # redundantly re-assigned the same four attributes — dropped here.
        self.reset()

    def reset(self):
        """Clear all running statistics (does not touch cur_step)."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record value ``val`` observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        if self.tb_writer is not None:
            self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
        self.cur_step += 1

    def __str__(self):
        fmtstr = '{name}:{avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)
55
56
def _copy_params(cell, attr_names, prefix, param_dict, find_param, not_found_param):
    """Copy each `<prefix>.<attr>` entry of param_dict into the matching cell attribute.

    Keys found in ``param_dict`` are loaded via ``set_data`` and appended to
    ``find_param``; missing keys are appended to ``not_found_param``.
    """
    for attr in attr_names:
        key = '{}.{}'.format(prefix, attr)
        if key in param_dict:
            getattr(cell, attr).set_data(param_dict[key].data)
            find_param.append(key)
        else:
            not_found_param.append(key)


def load_backbone(net, ckpt_path, args):
    """Load darknet53 backbone checkpoint into the YOLO backbone.

    Parameter names in the checkpoint use the ``network.backbone`` prefix,
    while the YOLO net uses ``feature_map.backbone``; names are remapped
    before lookup. Found / missing parameter names are reported through
    ``args.logger``.

    Args:
        net: YOLO network whose backbone parameters are to be loaded.
        ckpt_path (str): path to the darknet53 checkpoint file.
        args: object with a ``logger`` attribute used for reporting.

    Returns:
        The same ``net``, with backbone parameters loaded in place.
    """
    param_dict = load_checkpoint(ckpt_path)
    yolo_backbone_prefix = 'feature_map.backbone'
    darknet_backbone_prefix = 'network.backbone'
    find_param = []
    not_found_param = []
    net.init_parameters_data()
    for name, cell in net.cells_and_names():
        if not name.startswith(yolo_backbone_prefix):
            continue
        # Remap the cell name into the checkpoint's naming scheme.
        name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix)
        if isinstance(cell, (nn.Conv2d, nn.Dense)):
            _copy_params(cell, ('weight', 'bias'), name,
                         param_dict, find_param, not_found_param)
        elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
            _copy_params(cell, ('moving_mean', 'moving_variance', 'gamma', 'beta'),
                         name, param_dict, find_param, not_found_param)

    args.logger.info('================found_param {}========='.format(len(find_param)))
    args.logger.info(find_param)
    args.logger.info('================not_found_param {}========='.format(len(not_found_param)))
    args.logger.info(not_found_param)
    args.logger.info('=====load {} successfully ====='.format(ckpt_path))

    return net
114
115
def default_wd_filter(x):
    """Default weight-decay filter.

    Returns False for parameters that should be excluded from weight decay:
    all bias terms plus BatchNorm gamma/beta (be careful if x does not
    include BN). Returns True for everything else.
    """
    # str.endswith accepts a tuple, so one call covers all excluded suffixes.
    return not x.name.endswith(('.bias', '.gamma', '.beta'))
130
131
def get_param_groups(network):
    """Param groups for optimizer.

    Splits trainable parameters into a zero-weight-decay group (biases and
    BatchNorm gamma/beta — be careful if the network does not include BN)
    and a regular group that keeps the optimizer's default weight decay.
    """
    no_decay_suffixes = ('.bias', '.gamma', '.beta')
    params = network.trainable_params()
    no_decay_params = [p for p in params if p.name.endswith(no_decay_suffixes)]
    decay_params = [p for p in params if not p.name.endswith(no_decay_suffixes)]

    return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
151
152
class ShapeRecord:
    """Tally how often each training image shape (input resolution) occurs."""

    def __init__(self):
        # Pre-seeded with the standard multi-scale YOLO training resolutions;
        # 'total' counts every recorded shape. Unlisted resolutions are added
        # on the fly by set().
        self.shape_record = {
            320: 0,
            352: 0,
            384: 0,
            416: 0,
            448: 0,
            480: 0,
            512: 0,
            544: 0,
            576: 0,
            608: 0,
            'total': 0
        }

    def set(self, shape):
        """Record one occurrence of ``shape``.

        ``shape`` may be a scalar or a sequence (the first entry is used).
        Resolutions not in the pre-seeded table are counted too, instead of
        raising KeyError as the original implementation did.
        """
        # Guard the len() call so plain ints / numpy scalars are accepted.
        if hasattr(shape, '__len__') and len(shape) > 1:
            shape = shape[0]
        shape = int(shape)
        self.shape_record[shape] = self.shape_record.get(shape, 0) + 1
        self.shape_record['total'] += 1

    def show(self, logger):
        """Log the percentage distribution of recorded shapes via ``logger``."""
        total = float(self.shape_record['total'])
        if not total:
            # Avoid ZeroDivisionError when nothing has been recorded yet.
            logger.info('shape record is empty')
            return
        for key in self.shape_record:
            rate = self.shape_record[key] / total
            logger.info('shape {}: {:.2f}%'.format(key, rate * 100))
181
182
def keep_loss_fp32(network):
    """Keep loss of network with float32.

    Walks every cell in ``network`` and forces each YoloLossBlock to
    compute in float32, regardless of the network's overall precision.
    """
    loss_cells = (cell for _, cell in network.cells_and_names()
                  if isinstance(cell, YoloLossBlock))
    for loss_cell in loss_cells:
        loss_cell.to_float(mstype.float32)
188