• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train Mobilenetv2_quant on Cifar10"""
16
17
18import pytest
19import numpy as np
20from easydict import EasyDict as ed
21
22from mindspore import context
23from mindspore import Tensor
24from mindspore import nn
25from mindspore.train.model import Model
26from mindspore.compression.quant import QuantizationAwareTraining
27from mindspore.common import set_seed
28
29from dataset import create_dataset
30from lr_generator import get_lr
31from utils import Monitor, CrossEntropyWithLabelSmooth
32from mobilenetV2 import mobilenetV2
33
# Settings for the quantization-aware MobileNetV2 test below, exposed with
# attribute access through EasyDict.  The test itself reads num_classes,
# batch_size, step_threshold, epoch_size, start_epoch, warmup_epochs, lr,
# momentum, weight_decay and label_smooth; the whole object is also handed to
# create_dataset(), which presumably consumes the image/data-loading entries
# (verify against dataset.py).
config_ascend_quant = ed(dict(
    num_classes=10,
    image_height=224,
    image_width=224,
    batch_size=200,
    step_threshold=10,
    data_load_mode="mindata",
    epoch_size=1,
    start_epoch=200,
    warmup_epochs=1,
    lr=0.3,
    momentum=0.9,
    weight_decay=4e-5,
    label_smooth=0.1,
    loss_scale=1024,
    save_checkpoint=True,
    save_checkpoint_epochs=1,
    keep_checkpoint_max=300,
    save_checkpoint_path="./checkpoint",
))

# Location of the CIFAR-10 binary batches passed to create_dataset().
dataset_path = "/home/workspace/mindspore_dataset/cifar-10-batches-bin/"
56
57
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_mobilenetv2_quant():
    """Train a quantization-aware MobileNetV2 and check step time and loss.

    The test seeds MindSpore, selects graph mode on an Ascend device, builds
    the network, loss, dataset, learning-rate schedule and Momentum optimizer
    from ``config_ascend_quant``, and trains for ``config.epoch_size`` epochs
    with a ``Monitor`` callback attached.

    Raises:
        AssertionError: if ``monitor.step_mseconds`` is not below 650 or the
            mean of ``monitor.losses`` is not below 2.32.
    """
    set_seed(1)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    config = config_ascend_quant
    print("training configure: {}".format(config))

    num_epochs = config.epoch_size

    # Network to be trained.
    network = mobilenetV2(num_classes=config.num_classes)

    # Loss: label-smoothed cross entropy when a positive smoothing factor is
    # configured, plain sparse softmax cross entropy otherwise.
    loss = (
        CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth,
                                    num_classes=config.num_classes)
        if config.label_smooth > 0
        else nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    )

    # Training data; its size gives the number of steps in one epoch.
    dataset = create_dataset(dataset_path=dataset_path,
                             config=config,
                             repeat_num=1,
                             batch_size=config.batch_size)
    steps_per_epoch = dataset.get_dataset_size()

    # Wrap the network for quantization-aware training.
    network = QuantizationAwareTraining(
        bn_fold=True,
        per_channel=[True, False],
        symmetric=[True, False],
    ).quantize(network)

    # Learning-rate schedule, generated as if start_epoch epochs had already
    # been completed.
    lr_schedule = get_lr(global_step=config.start_epoch * steps_per_epoch,
                         lr_init=0,
                         lr_end=0,
                         lr_max=config.lr,
                         warmup_epochs=config.warmup_epochs,
                         total_epochs=num_epochs + config.start_epoch,
                         steps_per_epoch=steps_per_epoch)
    lr = Tensor(lr_schedule)

    # Optimizer over the parameters that require gradients.
    trainable_params = [param for param in network.get_parameters()
                        if param.requires_grad]
    opt = nn.Momentum(trainable_params, lr, config.momentum,
                      config.weight_decay)

    model = Model(network, loss_fn=loss, optimizer=opt)

    print("============== Starting Training ==============")
    monitor = Monitor(lr_init=lr.asnumpy(),
                      step_threshold=config.step_threshold)
    model.train(num_epochs, dataset, callbacks=[monitor],
                dataset_sink_mode=False)
    print("============== End Training ==============")

    # Step-time check against the value recorded by the monitor.
    max_step_time = 650
    measured_step_time = monitor.step_mseconds
    print('train_time_used:{}'.format(measured_step_time))
    assert measured_step_time < max_step_time

    # Loss check: the mean over all recorded step losses must beat the bound.
    max_avg_step_loss = 2.32
    mean_step_loss = np.array(monitor.losses).mean()
    print("average step loss:{}".format(mean_step_loss))
    assert mean_step_loss < max_avg_step_loss
122
123
# Allow running the test directly as a script, without the pytest runner.
if __name__ == "__main__":
    test_mobilenetv2_quant()
126