# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train Mobilenetv2_quant on Cifar10"""


import pytest
import numpy as np
from easydict import EasyDict as ed

from mindspore import context
from mindspore import Tensor
from mindspore import nn
from mindspore.train.model import Model
from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.common import set_seed

from dataset import create_dataset
from lr_generator import get_lr
from utils import Monitor, CrossEntropyWithLabelSmooth
from mobilenetV2 import mobilenetV2

# Settings consumed by test_mobilenetv2_quant and forwarded to create_dataset.
# Not every key is read in this file; the unused ones are kept as-is because
# create_dataset receives the whole object.
config_ascend_quant = ed({
    "num_classes": 10,
    "image_height": 224,
    "image_width": 224,
    "batch_size": 200,
    "step_threshold": 10,
    "data_load_mode": "mindata",
    "epoch_size": 1,
    "start_epoch": 200,
    "warmup_epochs": 1,
    "lr": 0.3,
    "momentum": 0.9,
    "weight_decay": 4e-5,
    "label_smooth": 0.1,
    "loss_scale": 1024,
    "save_checkpoint": True,
    "save_checkpoint_epochs": 1,
    "keep_checkpoint_max": 300,
    "save_checkpoint_path": "./checkpoint",
})

# Location of the CIFAR-10 binary batches on the test machine.
dataset_path = "/home/workspace/mindspore_dataset/cifar-10-batches-bin/"


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_mobilenetv2_quant():
    """Train a quantization-aware MobileNetV2 and check time and loss bounds.

    Builds the network, loss, dataset, learning-rate tensor and Momentum
    optimizer from ``config_ascend_quant``, trains for ``epoch_size`` epochs
    in graph mode on an Ascend device with a ``Monitor`` callback, then
    asserts that

    * ``monitor.step_mseconds`` is below 650, and
    * the mean of ``monitor.losses`` is below 2.32.
    """
    set_seed(1)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    cfg = config_ascend_quant
    print("training configure: {}".format(cfg))

    num_epochs = cfg.epoch_size

    # Network first, then loss, then dataset: same construction order as the
    # seeded reference run.
    net = mobilenetV2(num_classes=cfg.num_classes)

    # Label-smoothed cross entropy when a positive smoothing factor is
    # configured, plain sparse softmax cross entropy otherwise.
    criterion = (
        CrossEntropyWithLabelSmooth(smooth_factor=cfg.label_smooth,
                                    num_classes=cfg.num_classes)
        if cfg.label_smooth > 0
        else nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    )

    train_set = create_dataset(dataset_path=dataset_path,
                               config=cfg,
                               repeat_num=1,
                               batch_size=cfg.batch_size)
    steps_per_epoch = train_set.get_dataset_size()

    # Convert the fusion network into its quantization-aware counterpart.
    qat = QuantizationAwareTraining(bn_fold=True,
                                    per_channel=[True, False],
                                    symmetric=[True, False])
    net = qat.quantize(net)

    # Learning-rate values come from the project helper; the schedule spans
    # start_epoch + epoch_size epochs and is requested from the step at which
    # start_epoch begins.
    lr_values = get_lr(global_step=cfg.start_epoch * steps_per_epoch,
                       lr_init=0,
                       lr_end=0,
                       lr_max=cfg.lr,
                       warmup_epochs=cfg.warmup_epochs,
                       total_epochs=num_epochs + cfg.start_epoch,
                       steps_per_epoch=steps_per_epoch)
    lr_tensor = Tensor(lr_values)

    # Only parameters flagged as requiring gradients are optimized.
    trainable_params = [p for p in net.get_parameters() if p.requires_grad]
    optimizer = nn.Momentum(trainable_params,
                            learning_rate=lr_tensor,
                            momentum=cfg.momentum,
                            weight_decay=cfg.weight_decay)

    trainer = Model(net, loss_fn=criterion, optimizer=optimizer)

    print("============== Starting Training ==============")
    step_monitor = Monitor(lr_init=lr_tensor.asnumpy(),
                           step_threshold=cfg.step_threshold)
    trainer.train(num_epochs, train_set, callbacks=[step_monitor],
                  dataset_sink_mode=False)
    print("============== End Training ==============")

    # Timing bound; presumably milliseconds per step, going by the Monitor
    # attribute name -- confirm against utils.Monitor.
    max_step_time = 650
    measured_step_time = step_monitor.step_mseconds
    print('train_time_used:{}'.format(measured_step_time))
    assert measured_step_time < max_step_time

    # Loss bound on the mean of the per-step losses the monitor recorded.
    max_avg_step_loss = 2.32
    mean_step_loss = np.array(step_monitor.losses).mean()
    print("average step loss:{}".format(mean_step_loss))
    assert mean_step_loss < max_avg_step_loss


if __name__ == '__main__':
    test_mobilenetv2_quant()