# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Train and infer the LeNet quantization-aware network.
"""

import os

import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore import load_checkpoint, load_param_into_net, export
from mindspore.common import dtype as mstype
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.compression.quant.quantizer import OptimizeOption
from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net

from dataset import create_dataset
from config import quant_cfg
from lenet_fusion import LeNet5 as LeNet5Fusion

data_path = "/home/workspace/mindspore_dataset/mnist"
lenet_ckpt_path = "/home/workspace/mindspore_dataset/checkpoint/lenet/ckpt_lenet_noquant-10_1875.ckpt"


def train_lenet_quant(optim_option="QAT"):
    cfg = quant_cfg
    ckpt_path = lenet_ckpt_path
    ds_train = create_dataset(os.path.join(data_path, "train"), cfg.batch_size, 1)
    step_size = ds_train.get_dataset_size()

    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)

    # load the non-quantized checkpoint into the fusion network
    param_dict = load_checkpoint(ckpt_path)
    load_nonquant_param_into_quant_net(network, param_dict)

    # convert fusion network to quantization aware network
    if optim_option == "LEARNED_SCALE":
        quant_optim_options = OptimizeOption.LEARNED_SCALE
        quantizer = QuantizationAwareTraining(bn_fold=False,
                                              per_channel=[True, False],
                                              symmetric=[True, True],
                                              narrow_range=[True, True],
                                              freeze_bn=0,
                                              quant_delay=0,
                                              one_conv_fold=True,
                                              optimize_option=quant_optim_options)
    else:
        quantizer = QuantizationAwareTraining(quant_delay=900,
                                              bn_fold=False,
                                              per_channel=[True, False],
                                              symmetric=[True, False])
    network = quantizer.quantize(network)

    # define network loss
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    # define network optimization
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

    # callback and monitor
    config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
                                   keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpt_callback = ModelCheckpoint(prefix="ckpt_lenet_quant" + optim_option, config=config_ckpt)

    # define model
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    print("============== Starting Training ==============")
    model.train(cfg.epoch_size, ds_train, callbacks=[ckpt_callback, LossMonitor()],
                dataset_sink_mode=True)
    print("============== End Training ==============")


def eval_quant(optim_option="QAT"):
    cfg = quant_cfg
    ds_eval = create_dataset(os.path.join(data_path, "test"), cfg.batch_size, 1)
    ckpt_path = './ckpt_lenet_quant' + optim_option + '-10_937.ckpt'
    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)
    # convert fusion network to quantization aware network
    if optim_option == "LEARNED_SCALE":
        quant_optim_options = OptimizeOption.LEARNED_SCALE
        quantizer = QuantizationAwareTraining(bn_fold=False,
                                              per_channel=[True, False],
                                              symmetric=[True, True],
                                              narrow_range=[True, True],
                                              freeze_bn=0,
                                              quant_delay=0,
                                              one_conv_fold=True,
                                              optimize_option=quant_optim_options)
    else:
        quantizer = QuantizationAwareTraining(quant_delay=0,
                                              bn_fold=False,
                                              freeze_bn=10000,
                                              per_channel=[True, False],
                                              symmetric=[True, False])
    network = quantizer.quantize(network)

    # define loss
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    # define network optimization
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

    # define model
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    # load quantization aware network checkpoint
    param_dict = load_checkpoint(ckpt_path)
    not_load_param = load_param_into_net(network, param_dict)
    if not_load_param:
        raise ValueError("Load param into net fail!")

    print("============== Starting Testing ==============")
    acc = model.eval(ds_eval, dataset_sink_mode=True)
    print("============== {} ==============".format(acc))
    assert acc['Accuracy'] > 0.98


def export_lenet(optim_option="QAT", file_format="MINDIR"):
    cfg = quant_cfg
    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)
    # convert fusion network to quantization aware network
    if optim_option == "LEARNED_SCALE":
        quant_optim_options = OptimizeOption.LEARNED_SCALE
        quantizer = QuantizationAwareTraining(bn_fold=False,
                                              per_channel=[True, False],
                                              symmetric=[True, True],
                                              narrow_range=[True, True],
                                              freeze_bn=0,
                                              quant_delay=0,
                                              one_conv_fold=True,
                                              optimize_option=quant_optim_options)
    else:
        quantizer = QuantizationAwareTraining(quant_delay=0,
                                              bn_fold=False,
                                              freeze_bn=10000,
                                              per_channel=[True, False],
                                              symmetric=[True, False])
    network = quantizer.quantize(network)

    # export network
    inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32)
    export(network, inputs, file_name="lenet_quant", file_format=file_format, quant_mode='AUTO')


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_lenet_quant():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    train_lenet_quant()
    eval_quant()
    export_lenet()
    train_lenet_quant(optim_option="LEARNED_SCALE")
    eval_quant(optim_option="LEARNED_SCALE")
    export_lenet(optim_option="LEARNED_SCALE")


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_lenet_quant_ascend():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    train_lenet_quant(optim_option="LEARNED_SCALE")
    eval_quant(optim_option="LEARNED_SCALE")
    export_lenet(optim_option="LEARNED_SCALE", file_format="AIR")
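

# NOTE: hypothetical convenience entry point, not part of the original test suite.
# It is a minimal sketch assuming a single GPU device and the dataset/checkpoint
# paths configured above; it simply chains the default QAT train/eval/export flow
# so the script can be run directly (python test_lenet_quant.py) without pytest.
if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    train_lenet_quant()
    eval_quant()
    export_lenet()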