# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train Resnet50_quant on Cifar10"""

import pytest
import numpy as np
from easydict import EasyDict as ed

from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore.compression.quant import QuantizationAwareTraining
from mindspore import set_seed

from resnet_quant_manual import resnet50_quant
from dataset import create_dataset
from lr_generator import get_lr
from utils import Monitor, CrossEntropy


# Hyper-parameters shared by the whole test; treated as read-only below.
config_quant = ed({
    "class_num": 10,
    "batch_size": 128,
    "step_threshold": 20,
    "loss_scale": 1024,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "epoch_size": 1,
    "pretrained_epoch_size": 90,
    "buffer_size": 1000,
    "image_height": 224,
    "image_width": 224,
    "data_load_mode": "original",
    "save_checkpoint": True,
    "save_checkpoint_epochs": 1,
    "keep_checkpoint_max": 50,
    "save_checkpoint_path": "./",
    "warmup_epochs": 0,
    "lr_decay_mode": "cosine",
    "use_label_smooth": True,
    "label_smooth_factor": 0.1,
    "lr_init": 0,
    "lr_max": 0.005,
})

dataset_path = "/home/workspace/mindspore_dataset/cifar-10-batches-bin/"


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_resnet50_quant():
    """Train the quantization-aware ResNet-50 and check the mean step loss.

    Builds the network, loss, dataset, learning-rate schedule and Momentum
    optimizer from ``config_quant``, converts the network with
    ``QuantizationAwareTraining``, trains for ``config.epoch_size`` epochs in
    graph mode on Ascend, and asserts that the mean of the losses collected by
    the ``Monitor`` callback is below 2.60.
    """
    set_seed(1)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    config = config_quant
    print("training configure: {}".format(config))
    epoch_size = config.epoch_size

    # define network
    net = resnet50_quant(class_num=config.class_num)
    net.set_train(True)

    # define loss
    # Use a local value rather than overwriting the shared module-level
    # config when label smoothing is disabled.
    smooth_factor = config.label_smooth_factor if config.use_label_smooth else 0.0
    loss = CrossEntropy(smooth_factor=smooth_factor,
                        num_classes=config.class_num)

    # define dataset
    dataset = create_dataset(dataset_path=dataset_path,
                             config=config,
                             repeat_num=1,
                             batch_size=config.batch_size)
    step_size = dataset.get_dataset_size()

    # convert fusion network to quantization aware network
    quantizer = QuantizationAwareTraining(bn_fold=True,
                                          per_channel=[True, False],
                                          symmetric=[True, False])
    net = quantizer.quantize(net)

    # get learning rate
    # The decay mode comes from the config instead of a duplicated literal.
    lr = Tensor(get_lr(lr_init=config.lr_init,
                       lr_end=0.0,
                       lr_max=config.lr_max,
                       warmup_epochs=config.warmup_epochs,
                       total_epochs=config.epoch_size,
                       steps_per_epoch=step_size,
                       lr_decay_mode=config.lr_decay_mode))

    # define optimization
    # NOTE(review): config.loss_scale is handed to the optimizer while the
    # Model below is created without a loss_scale_manager. Presumably the
    # optimizer rescales gradients by 1/loss_scale although the loss itself is
    # never scaled up; the 2.60 threshold was presumably calibrated with this
    # setup -- confirm before changing either side.
    trainable_params = [param for param in net.get_parameters()
                        if param.requires_grad]
    opt = Momentum(trainable_params,
                   learning_rate=lr,
                   momentum=config.momentum,
                   weight_decay=config.weight_decay,
                   loss_scale=config.loss_scale)

    # define model
    model = Model(net, loss_fn=loss, optimizer=opt)

    print("============== Starting Training ==============")
    monitor = Monitor(lr_init=lr.asnumpy(),
                      step_threshold=config.step_threshold)

    callbacks = [monitor]
    model.train(epoch_size, dataset, callbacks=callbacks,
                dataset_sink_mode=False)
    print("============== End Training ==============")

    expect_avg_step_loss = 2.60
    avg_step_loss = np.mean(np.array(monitor.losses))

    print("average step loss:{}".format(avg_step_loss))
    assert avg_step_loss < expect_avg_step_loss, (
        "average step loss {} is not below {}".format(
            avg_step_loss, expect_avg_step_loss))


if __name__ == '__main__':
    test_resnet50_quant()