• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Train Resnet50_quant on Cifar10"""
16
17import pytest
18import numpy as np
19from easydict import EasyDict as ed
20
21from mindspore import context
22from mindspore import Tensor
23from mindspore.nn.optim.momentum import Momentum
24from mindspore.train.model import Model
25from mindspore.compression.quant import QuantizationAwareTraining
26from mindspore import set_seed
27
28from resnet_quant_manual import resnet50_quant
29from dataset import create_dataset
30from lr_generator import get_lr
31from utils import Monitor, CrossEntropy
32
33
# Hyper-parameters for the quantization-aware ResNet-50 / CIFAR-10 training
# check below. Some keys are read directly by ``test_resnet50_quant``; the
# whole object is also handed to ``create_dataset``.
config_quant = ed(dict(
    class_num=10,              # number of output classes of the network
    batch_size=128,
    step_threshold=20,         # forwarded to the Monitor callback
    loss_scale=1024,           # forwarded to the Momentum optimizer
    momentum=0.9,
    weight_decay=1e-4,
    epoch_size=1,
    pretrained_epoch_size=90,
    buffer_size=1000,
    image_height=224,
    image_width=224,
    data_load_mode="original",
    save_checkpoint=True,
    save_checkpoint_epochs=1,
    keep_checkpoint_max=50,
    save_checkpoint_path="./",
    warmup_epochs=0,
    lr_decay_mode="cosine",
    use_label_smooth=True,
    label_smooth_factor=0.1,
    lr_init=0,
    lr_max=0.005,
))

# Absolute path of the CIFAR-10 binary batches consumed by the test.
dataset_path = "/home/workspace/mindspore_dataset/cifar-10-batches-bin/"
60
61
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_resnet50_quant():
    """Quantization-aware training smoke test for ResNet-50 on CIFAR-10.

    Builds the manual ResNet-50 quant network and converts it with
    ``QuantizationAwareTraining``. It then trains for ``config.epoch_size``
    epochs on Ascend in graph mode, with dataset sinking disabled. Finally it
    asserts that the mean of the step losses collected by ``Monitor`` stays
    below a fixed threshold.
    """
    set_seed(1)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    config = config_quant
    print("training configure: {}".format(config))
    epoch_size = config.epoch_size

    # define network
    net = resnet50_quant(class_num=config.class_num)
    net.set_train(True)

    # define loss
    # Resolve the smoothing factor locally rather than writing it back into
    # ``config``: ``config`` is the shared module-level ``config_quant``, so a
    # write would leak into every later use of that object.
    smooth_factor = config.label_smooth_factor if config.use_label_smooth else 0.0
    loss = CrossEntropy(smooth_factor=smooth_factor, num_classes=config.class_num)

    # define dataset
    dataset = create_dataset(dataset_path=dataset_path,
                             config=config,
                             repeat_num=1,
                             batch_size=config.batch_size)
    step_size = dataset.get_dataset_size()

    # convert fusion network to quantization aware network
    # Per the MindSpore QuantizationAwareTraining docs, element 0 of each list
    # applies to weights and element 1 to activations (data flow).
    quantizer = QuantizationAwareTraining(bn_fold=True,
                                          per_channel=[True, False],
                                          symmetric=[True, False])
    net = quantizer.quantize(net)

    # get learning rate (one value per training step)
    lr = Tensor(get_lr(lr_init=config.lr_init,
                       lr_end=0.0,
                       lr_max=config.lr_max,
                       warmup_epochs=config.warmup_epochs,
                       total_epochs=epoch_size,
                       steps_per_epoch=step_size,
                       lr_decay_mode=config.lr_decay_mode))

    # define optimization
    # NOTE(review): config.loss_scale is passed as the optimizer's loss_scale,
    # but the Model below is built without a loss-scale manager. The MindSpore
    # Optimizer docs say this value should match a FixedLossScaleManager
    # (drop_overflow_update=False). Without one, the gradients appear to be
    # divided by the scale with no matching loss scaling. It is left as-is
    # because expect_avg_step_loss is presumably calibrated against this
    # setup. Confirm before changing.
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
                   config.weight_decay, config.loss_scale)

    # define model
    model = Model(net, loss_fn=loss, optimizer=opt)

    print("============== Starting Training ==============")
    monitor = Monitor(lr_init=lr.asnumpy(),
                      step_threshold=config.step_threshold)

    callbacks = [monitor]
    model.train(epoch_size, dataset, callbacks=callbacks,
                dataset_sink_mode=False)
    print("============== End Training ==============")

    expect_avg_step_loss = 2.60
    # Fail with an explicit message instead of through np.mean([]) -> nan,
    # which would make the threshold check below fail opaquely.
    assert len(monitor.losses) > 0, "Monitor recorded no step losses during training"
    avg_step_loss = np.mean(np.array(monitor.losses))

    print("average step loss:{}".format(avg_step_loss))
    assert avg_step_loss < expect_avg_step_loss, \
        "average step loss {} is not below {}".format(avg_step_loss, expect_avg_step_loss)
128
129
if __name__ == "__main__":
    # Allow running the training check directly, outside of pytest.
    test_resnet50_quant()
132