# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
######################## train YOLOv3 example ########################
train YOLOv3 and get network model files(.ckpt) :
python train.py --image_dir /data --anno_path /data/coco/train_coco.txt --mindrecord_dir=/data/Mindrecord_train

If the mindrecord_dir is empty, it will generate mindrecord file by image_dir and anno_path.
Note if mindrecord_dir isn't empty, it will use mindrecord_dir rather than image_dir and anno_path.
"""

import os
import time
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.train import Model
from mindspore.common.initializer import initializer
from mindspore.train.callback import Callback

from src.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
from src.dataset import create_yolo_dataset
from src.config import ConfigYOLOV3ResNet18

np.random.seed(1)


def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
    """Build an exponentially decayed per-step learning-rate array.

    Args:
        learning_rate: Base learning rate, i.e. the value at step 0.
        start_step: Index of the first step to keep; earlier entries are dropped.
        global_step: Number of steps the schedule is computed for.
        decay_step: Step interval that scales the decay exponent.
        decay_rate: Base of the exponential decay.
        steps: If True the exponent is ``i // decay_step`` (the rate changes
            only every ``decay_step`` steps); otherwise it is ``i / decay_step``.

    Returns:
        A float32 numpy array holding the schedule entries from ``start_step``
        onward.
    """
    lr_each_step = [
        learning_rate * (decay_rate ** ((i // decay_step) if steps else (i / decay_step)))
        for i in range(global_step)
    ]
    # Cast the whole schedule first, then drop the steps before start_step.
    return np.array(lr_each_step).astype(np.float32)[start_step:]


def init_net_param(network, init_value='ones'):
    """Init the parameters in network.

    Every trainable parameter whose data is a Tensor is re-initialized with
    ``init_value``, except parameters whose name contains 'beta', 'gamma' or
    'bias', which are left untouched.

    Args:
        network: Object exposing ``trainable_params()``.
        init_value: Initializer spec forwarded to ``initializer``.
    """
    skipped_name_parts = ('beta', 'gamma', 'bias')
    for p in network.trainable_params():
        if not isinstance(p.data, Tensor):
            continue
        if any(part in p.name for part in skipped_name_parts):
            continue
        p.set_data(initializer(init_value, p.data.shape, p.data.dtype))


class ModelCallback(Callback):
    """Callback that records the network output after every step."""

    def __init__(self):
        super(ModelCallback, self).__init__()
        # One entry per step_end call: cb_params.net_outputs as a numpy value.
        self.loss_list = []

    def step_end(self, run_context):
        """Store the current network output and print it with the epoch number."""
        cb_params = run_context.original_args()
        self.loss_list.append(cb_params.net_outputs.asnumpy())
        print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs)))


class TimeMonitor(Callback):
    """Time Monitor.

    Records the wall-clock duration of every epoch in milliseconds, plus that
    duration divided by ``data_size``.
    """

    def __init__(self, data_size):
        super(TimeMonitor, self).__init__()
        self.data_size = data_size
        self.epoch_mseconds_list = []
        self.per_step_mseconds_list = []
        # Initialized here so epoch_end never hits a missing attribute even if
        # epoch_begin was not invoked first; epoch_begin overwrites it.
        self.epoch_time = time.time()

    def epoch_begin(self, run_context):
        """Remember the start time of the epoch."""
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        """Append the elapsed epoch time (ms) and the per-step average (ms)."""
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        self.epoch_mseconds_list.append(epoch_mseconds)
        self.per_step_mseconds_list.append(epoch_mseconds / self.data_size)


DATA_DIR = "/home/workspace/mindspore_dataset/coco/coco2017/mindrecord_train/yolov3"


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_yolov3():
    """Train YOLOv3-ResNet18 for a few epochs and check loss and speed budgets.

    The test runs in graph mode on Ascend, trains ``epoch_size`` epochs from
    the mindrecord files under ``DATA_DIR`` and then asserts that:

    * the first three recorded network outputs are below fixed thresholds;
    * at least one of the epochs from the third onward meets the per-epoch
      and the per-step time budgets.

    Raises:
        KeyError: If ``DATA_DIR`` is not an existing directory.
        AssertionError: If the first mindrecord file is missing or any of the
            loss / timing expectations is violated.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    rank = 0
    device_num = 1
    lr_init = 0.001
    epoch_size = 5
    batch_size = 32
    loss_scale = 1024
    mindrecord_dir = DATA_DIR

    # The mindrecord files must already exist in mindrecord_dir;
    # they are named yolo.mindrecord0, 1, ... file_num.
    if not os.path.isdir(mindrecord_dir):
        raise KeyError("mindrecord path is not exist.")

    prefix = "yolo.mindrecord"
    mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
    print("yolov3 mindrecord is ", mindrecord_file)
    if not os.path.exists(mindrecord_file):
        print("mindrecord file is not exist.")
        assert False
    loss_scale = float(loss_scale)

    # When creating the MindDataset, use the first mindrecord file, such as yolo.mindrecord0.
    dataset = create_yolo_dataset(mindrecord_file, repeat_num=1,
                                  batch_size=batch_size, device_num=device_num, rank=rank)
    dataset_size = dataset.get_dataset_size()
    print("Create dataset done!")

    net = yolov3_resnet18(ConfigYOLOV3ResNet18())
    net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())

    # NOTE(review): the schedule is built for 60 epochs although only
    # epoch_size epochs are trained here -- presumably to mirror the full
    # training schedule; confirm before changing.
    total_epoch_size = 60
    lr = Tensor(get_lr(learning_rate=lr_init, start_step=0,
                       global_step=total_epoch_size * dataset_size,
                       decay_step=1000, decay_rate=0.95, steps=True))
    opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
    net = TrainingWrapper(net, opt, loss_scale)

    model_callback = ModelCallback()
    time_monitor_callback = TimeMonitor(data_size=dataset_size)
    callback = [model_callback, time_monitor_callback]

    model = Model(net)
    print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.")
    model.train(epoch_size, dataset, callbacks=callback, dataset_sink_mode=True,
                sink_size=dataset_size)

    # The assertions below fail when the loss value, overflow state or
    # loss_scale value is wrong.
    loss_value = np.array(model_callback.loss_list)

    expect_loss_value = [6850, 4250, 2750]
    print("loss value: {}".format(loss_value))
    for idx, upper_bound in enumerate(expect_loss_value):
        assert loss_value[idx] < upper_bound

    # Timing is judged on the epochs from the third one onward; the first
    # two are skipped (presumably warm-up / graph compilation -- see the
    # message printed before training). One epoch within budget is enough.
    first_timed_epoch = 2

    timed_epoch_mseconds = time_monitor_callback.epoch_mseconds_list[first_timed_epoch:epoch_size]
    expect_epoch_mseconds = 1250
    print("epoch mseconds: {}".format(
        time_monitor_callback.epoch_mseconds_list[first_timed_epoch]))
    assert any(elapsed <= expect_epoch_mseconds for elapsed in timed_epoch_mseconds)

    timed_per_step_mseconds = time_monitor_callback.per_step_mseconds_list[first_timed_epoch:epoch_size]
    expect_per_step_mseconds = 130
    print("per step mseconds: {}".format(
        time_monitor_callback.per_step_mseconds_list[first_timed_epoch]))
    assert any(elapsed <= expect_per_step_mseconds for elapsed in timed_per_step_mseconds)
    print("yolov3 test case passed.")