# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Util class or function."""
from mindspore.train.serialization import load_checkpoint
import mindspore.nn as nn
import mindspore.common.dtype as mstype

from .yolo import YoloLossBlock


class AverageMeter:
    """Computes and stores the average and current value.

    Args:
        name (str): Label used by ``__str__`` and as the scalar tag passed to
            ``tb_writer.add_scalar``.
        fmt (str): Format spec appended to ``avg`` in ``__str__`` (default ':f').
        tb_writer: Optional object with an ``add_scalar(tag, value, step)``
            method; when set, every ``update`` logs the latest value.
    """

    def __init__(self, name, fmt=':f', tb_writer=None):
        self.name = name
        self.fmt = fmt
        self.tb_writer = tb_writer
        self.cur_step = 1
        # reset() initialises val/avg/sum/count.
        self.reset()

    def reset(self):
        """Zero the running statistics (does not touch ``cur_step``)."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        if self.tb_writer is not None:
            self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
        self.cur_step += 1

    def __str__(self):
        fmtstr = '{name}:{avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)


# Parameter attribute names copied per supported cell type, in logging order.
_CONV_DENSE_ATTRS = ('weight', 'bias')
_BATCHNORM_ATTRS = ('moving_mean', 'moving_variance', 'gamma', 'beta')


def _copy_cell_params(cell, prefix, attr_names, param_dict, found, not_found):
    """Copy checkpoint entries ``<prefix>.<attr>`` into ``cell.<attr>``.

    For each name in ``attr_names`` (in order), the key is appended to
    ``found`` if it exists in ``param_dict`` (and its data is copied into the
    cell's parameter via ``set_data``), otherwise to ``not_found``.
    """
    for attr in attr_names:
        key = '{}.{}'.format(prefix, attr)
        if key in param_dict:
            getattr(cell, attr).set_data(param_dict[key].data)
            found.append(key)
        else:
            not_found.append(key)


def load_backbone(net, ckpt_path, args):
    """Load darknet53 backbone checkpoint.

    Cells of ``net`` whose names start with ``feature_map.backbone`` are
    looked up in the checkpoint under the ``network.backbone`` prefix instead.
    Conv2d/Dense weight and bias, and BatchNorm moving_mean, moving_variance,
    gamma and beta are copied in place. Found and missing keys are logged.

    Args:
        net: Network to populate (modified in place).
        ckpt_path (str): Path passed to ``load_checkpoint``.
        args: Object exposing a ``logger`` with an ``info`` method.

    Returns:
        The same ``net`` object.
    """
    param_dict = load_checkpoint(ckpt_path)
    yolo_backbone_prefix = 'feature_map.backbone'
    darknet_backbone_prefix = 'network.backbone'
    find_param = []
    not_found_param = []
    net.init_parameters_data()
    for name, cell in net.cells_and_names():
        # NOTE(review): the collapsed source does not show indentation; this
        # assumes only backbone cells are loaded (nested under the prefix
        # check), as the function name implies — confirm against history.
        if not name.startswith(yolo_backbone_prefix):
            continue
        name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix)
        if isinstance(cell, (nn.Conv2d, nn.Dense)):
            _copy_cell_params(cell, name, _CONV_DENSE_ATTRS, param_dict,
                              find_param, not_found_param)
        elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
            _copy_cell_params(cell, name, _BATCHNORM_ATTRS, param_dict,
                              find_param, not_found_param)

    args.logger.info('================found_param {}========='.format(len(find_param)))
    args.logger.info(find_param)
    args.logger.info('================not_found_param {}========='.format(len(not_found_param)))
    args.logger.info(not_found_param)
    args.logger.info('=====load {} successfully ====='.format(ckpt_path))

    return net


# Parameter-name suffixes excluded from weight decay: every bias, plus the
# BatchNorm gamma/beta (scale/shift) parameters.
_NO_DECAY_SUFFIXES = ('.bias', '.gamma', '.beta')


def default_wd_filter(x):
    """Return True if parameter ``x`` should receive weight decay.

    Parameters whose ``name`` ends with ``.bias``, ``.gamma`` or ``.beta``
    are excluded. Be careful: any non-BN parameter that happens to use one of
    these suffixes is excluded as well.
    """
    return not x.name.endswith(_NO_DECAY_SUFFIXES)


def get_param_groups(network):
    """Split trainable params into optimizer groups by weight-decay policy.

    Returns:
        list[dict]: ``[{'params': no_decay, 'weight_decay': 0.0},
        {'params': decay}]``, using the same rule as ``default_wd_filter``.
    """
    decay_params = []
    no_decay_params = []
    for param in network.trainable_params():
        if default_wd_filter(param):
            decay_params.append(param)
        else:
            no_decay_params.append(param)

    return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]


class ShapeRecord:
    """Log image shape.

    Counts how often each input image size is seen, plus a running
    ``'total'``, and reports the distribution via ``show``.
    """

    def __init__(self):
        # 320, 352, ..., 608 (step 32), then the running total.
        self.shape_record = {size: 0 for size in range(320, 609, 32)}
        self.shape_record['total'] = 0

    def set(self, shape):
        """Count one occurrence of ``shape``.

        ``shape`` may be a scalar size or a non-empty sequence whose first
        element is the size; unseen sizes get a new counter.
        """
        # Unwrap any non-empty sequence (a length-1 list previously reached
        # int(list) and raised TypeError).
        if hasattr(shape, '__len__') and len(shape) >= 1:
            shape = shape[0]
        shape = int(shape)
        self.shape_record[shape] = self.shape_record.get(shape, 0) + 1
        self.shape_record['total'] += 1

    def show(self, logger):
        """Log each recorded size's share of ``total`` as a percentage."""
        total = float(self.shape_record['total'])
        for key, count in self.shape_record.items():
            # Avoid ZeroDivisionError when nothing has been recorded yet.
            rate = count / total if total else 0.0
            logger.info('shape {}: {:.2f}%'.format(key, rate*100))


def keep_loss_fp32(network):
    """Keep loss of network with float32.

    Every ``YoloLossBlock`` sub-cell is switched to float32 via ``to_float``.
    """
    for _, cell in network.cells_and_names():
        if isinstance(cell, YoloLossBlock):
            cell.to_float(mstype.float32)