• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""
17######################## train YOLOv3 example ########################
18train YOLOv3 and get network model files(.ckpt) :
19python train.py --image_dir /data --anno_path /data/coco/train_coco.txt --mindrecord_dir=/data/Mindrecord_train
20
If mindrecord_dir is empty, it will generate the mindrecord files from image_dir and anno_path.
Note: if mindrecord_dir isn't empty, it will use mindrecord_dir rather than image_dir and anno_path.
23"""
24
25import os
26import time
27import pytest
28import numpy as np
29import mindspore.nn as nn
30from mindspore import context, Tensor
31from mindspore.train import Model
32from mindspore.common.initializer import initializer
33from mindspore.train.callback import Callback
34
35from src.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
36from src.dataset import create_yolo_dataset
37from src.config import ConfigYOLOV3ResNet18
38
# Seed NumPy's global random generator so repeated runs draw the same values.
np.random.seed(1)


def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
    """Build a per-step, exponentially decayed learning-rate schedule.

    Args:
        learning_rate: Base learning rate, used unchanged at step 0.
        start_step: Number of leading steps dropped from the returned schedule.
        global_step: Total number of steps the schedule is computed for.
        decay_step: Number of steps after which the rate has been multiplied
            by ``decay_rate`` once.
        decay_rate: Multiplicative decay factor.
        steps: When true the decay happens in discrete stairs
            (exponent ``i // decay_step``); otherwise it is continuous
            (exponent ``i / decay_step``).

    Returns:
        A float32 NumPy array with the rates for steps
        ``start_step`` .. ``global_step - 1``.
    """
    def _exponent(step):
        # Floor division yields the staircase schedule, true division the smooth one.
        return step // decay_step if steps else step / decay_step

    schedule = [learning_rate * (decay_rate ** _exponent(step)) for step in range(global_step)]
    return np.array(schedule).astype(np.float32)[start_step:]
51
52
def init_net_param(network, init_value='ones'):
    """Re-initialize the trainable weights of ``network``.

    Every trainable parameter whose data is a ``Tensor`` is overwritten with
    the result of ``initializer(init_value, shape, dtype)``, keeping the
    parameter's own shape and dtype. Parameters whose name contains ``beta``,
    ``gamma`` or ``bias`` are left untouched.

    Args:
        network: Object exposing ``trainable_params()``.
        init_value: Initializer spec forwarded to ``initializer``.
    """
    skipped_tags = ('beta', 'gamma', 'bias')
    for param in network.trainable_params():
        if not isinstance(param.data, Tensor):
            continue
        if any(tag in param.name for tag in skipped_tags):
            continue
        param.set_data(initializer(init_value, param.data.shape, param.data.dtype))
59
class ModelCallback(Callback):
    """Callback that records and prints the network output after every step."""

    def __init__(self):
        super().__init__()
        # One entry (converted with ``asnumpy()``) per ``step_end`` invocation.
        self.loss_list = []

    def step_end(self, run_context):
        """Store the step's network output and print it with the epoch number."""
        params = run_context.original_args()
        outputs = params.net_outputs
        self.loss_list.append(outputs.asnumpy())
        print("epoch: {}, outputs are: {}".format(params.cur_epoch_num, str(outputs)))
69
class TimeMonitor(Callback):
    """Callback that records the wall-clock duration of every epoch.

    Args:
        data_size: Number of steps per epoch; used to derive the average
            per-step time from the epoch time.
    """

    def __init__(self, data_size):
        super(TimeMonitor, self).__init__()
        self.data_size = data_size
        # Duration of each finished epoch, in milliseconds.
        self.epoch_mseconds_list = []
        # Epoch duration divided by ``data_size``, in milliseconds.
        self.per_step_mseconds_list = []
        # Defined here so ``epoch_end`` never hits a missing attribute if it
        # is invoked before ``epoch_begin``.
        self.epoch_time = time.time()

    def epoch_begin(self, run_context):
        """Remember the time at which the epoch started."""
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        """Append the elapsed epoch time and the average per-step time."""
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        self.epoch_mseconds_list.append(epoch_mseconds)
        self.per_step_mseconds_list.append(epoch_mseconds / self.data_size)
84
# Directory holding the pre-generated YOLOv3 mindrecord training files.
DATA_DIR = "/home/workspace/mindspore_dataset/coco/coco2017/mindrecord_train/yolov3"


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_yolov3():
    """Train YOLOv3-ResNet18 for a few epochs on Ascend and check loss and speed.

    The test trains for ``epoch_size`` epochs in dataset sink mode and then
    asserts that:

    * the first three recorded loss values are below fixed upper bounds;
    * at least one of epochs 3-5 finished within the epoch time budget;
    * at least one of epochs 3-5 stayed within the per-step time budget.

    The first two epochs are excluded from the timing checks (the first epoch
    is slower because of graph compilation).

    Raises:
        KeyError: If ``DATA_DIR`` is not an existing directory.
        AssertionError: If the first mindrecord file is missing or any of the
            loss/timing expectations is not met.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    rank = 0
    device_num = 1
    lr_init = 0.001
    epoch_size = 5
    batch_size = 32
    loss_scale = 1024
    mindrecord_dir = DATA_DIR

    # The mindrecord files are expected to already exist in mindrecord_dir,
    # named yolo.mindrecord0, 1, ... file_num.
    # NOTE(review): KeyError is an unusual type for a missing directory; it is
    # kept so the raised exception type does not change.
    if not os.path.isdir(mindrecord_dir):
        raise KeyError("mindrecord path is not exist.")

    prefix = "yolo.mindrecord"
    mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
    print("yolov3 mindrecord is ", mindrecord_file)
    if not os.path.exists(mindrecord_file):
        print("mindrecord file is not exist.")
        # Raise explicitly: a bare `assert False` is removed under `python -O`,
        # which would let the test pass without training anything.
        raise AssertionError("mindrecord file is not exist: {}".format(mindrecord_file))

    loss_scale = float(loss_scale)

    # When creating the MindDataset, use the first mindrecord file, such as yolo.mindrecord0.
    dataset = create_yolo_dataset(mindrecord_file, repeat_num=1,
                                  batch_size=batch_size, device_num=device_num, rank=rank)
    dataset_size = dataset.get_dataset_size()
    print("Create dataset done!")

    net = yolov3_resnet18(ConfigYOLOV3ResNet18())
    net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())

    # The learning-rate schedule is laid out for 60 epochs even though only
    # `epoch_size` epochs are actually run below.
    total_epoch_size = 60
    lr = Tensor(get_lr(learning_rate=lr_init, start_step=0,
                       global_step=total_epoch_size * dataset_size,
                       decay_step=1000, decay_rate=0.95, steps=True))
    opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
    net = TrainingWrapper(net, opt, loss_scale)

    model_callback = ModelCallback()
    time_monitor_callback = TimeMonitor(data_size=dataset_size)
    callback = [model_callback, time_monitor_callback]

    model = Model(net)
    print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.")
    model.train(epoch_size, dataset, callbacks=callback, dataset_sink_mode=True,
                sink_size=dataset.get_dataset_size())

    # An assertion fails when the loss value, overflow state or loss_scale value is wrong.
    loss_value = np.array(model_callback.loss_list)
    expect_loss_value = [6850, 4250, 2750]
    print("loss value: {}".format(loss_value))
    for idx, upper_bound in enumerate(expect_loss_value):
        assert loss_value[idx] < upper_bound, \
            "loss #{} is {}, expected < {}".format(idx, loss_value[idx], upper_bound)

    # Timing checks look at epochs 3..5 only and pass if any one of them is
    # within budget.
    epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)
    expect_epoch_mseconds = 1250
    print("epoch mseconds: {}".format(epoch_mseconds[2]))
    assert np.any(epoch_mseconds[2:5] <= expect_epoch_mseconds), \
        "epoch times {} all exceed {} ms".format(epoch_mseconds[2:5], expect_epoch_mseconds)

    per_step_mseconds = np.array(time_monitor_callback.per_step_mseconds_list)
    expect_per_step_mseconds = 130
    print("per step mseconds: {}".format(per_step_mseconds[2]))
    assert np.any(per_step_mseconds[2:5] <= expect_per_step_mseconds), \
        "per-step times {} all exceed {} ms".format(per_step_mseconds[2:5], expect_per_step_mseconds)
    print("yolov3 test case passed.")
166