# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Train, evaluate and export the LeNet quantization aware network.
"""

import os
import pytest
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore import load_checkpoint, load_param_into_net, export
from mindspore.train import Model
from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.compression.quant.quantizer import OptimizeOption
from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
from dataset import create_dataset
from config import quant_cfg
from lenet_fusion import LeNet5 as LeNet5Fusion

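# MNIST dataset and the pre-trained, non-quantized LeNet checkpoint; these paths are
# assumed to be pre-provisioned on the test host.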
data_path = "/home/workspace/mindspore_dataset/mnist"
lenet_ckpt_path = "/home/workspace/mindspore_dataset/checkpoint/lenet/ckpt_lenet_noquant-10_1875.ckpt"

def train_lenet_quant(optim_option="QAT"):
    cfg = quant_cfg
    ckpt_path = lenet_ckpt_path
    ds_train = create_dataset(os.path.join(data_path, "train"), cfg.batch_size, 1)
    step_size = ds_train.get_dataset_size()

    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)

    # load the pre-trained, non-quantized checkpoint into the fusion network
    param_dict = load_checkpoint(ckpt_path)
    load_nonquant_param_into_quant_net(network, param_dict)

    # convert fusion network to quantization aware network
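    # Two-element options (per_channel, symmetric, narrow_range) configure weight and
    # activation fake-quant separately: the first entry applies to weights, the second
    # to activations (data flow).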
    if optim_option == "LEARNED_SCALE":
        quant_optim_options = OptimizeOption.LEARNED_SCALE
        quantizer = QuantizationAwareTraining(bn_fold=False,
                                              per_channel=[True, False],
                                              symmetric=[True, True],
                                              narrow_range=[True, True],
                                              freeze_bn=0,
                                              quant_delay=0,
                                              one_conv_fold=True,
                                              optimize_option=quant_optim_options)
    else:
        quantizer = QuantizationAwareTraining(quant_delay=900,
                                              bn_fold=False,
                                              per_channel=[True, False],
                                              symmetric=[True, False])
    network = quantizer.quantize(network)

    # define network loss
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    # define network optimization
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

    # call back and monitor
    config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
                                   keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpt_callback = ModelCheckpoint(prefix="ckpt_lenet_quant" + optim_option, config=config_ckpt)

    # define model
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    print("============== Starting Training ==============")
    model.train(cfg.epoch_size, ds_train, callbacks=[ckpt_callback, LossMonitor()],
                dataset_sink_mode=True)
    print("============== End Training ==============")


def eval_quant(optim_option="QAT"):
    cfg = quant_cfg
    ds_eval = create_dataset(os.path.join(data_path, "test"), cfg.batch_size, 1)
    ckpt_path = './ckpt_lenet_quant' + optim_option + '-10_937.ckpt'
    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)
    # convert fusion network to quantization aware network
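    # The quantizer configuration below must match the one used during training so that
    # the parameters in the quantized checkpoint map onto this network.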
    if optim_option == "LEARNED_SCALE":
        quant_optim_options = OptimizeOption.LEARNED_SCALE
        quantizer = QuantizationAwareTraining(bn_fold=False,
                                              per_channel=[True, False],
                                              symmetric=[True, True],
                                              narrow_range=[True, True],
                                              freeze_bn=0,
                                              quant_delay=0,
                                              one_conv_fold=True,
                                              optimize_option=quant_optim_options)
    else:
        quantizer = QuantizationAwareTraining(quant_delay=0,
                                              bn_fold=False,
                                              freeze_bn=10000,
                                              per_channel=[True, False],
                                              symmetric=[True, False])
    network = quantizer.quantize(network)

    # define loss
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    # define network optimization
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

    # define model
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    # load quantization aware network checkpoint
    param_dict = load_checkpoint(ckpt_path)
    not_load_param = load_param_into_net(network, param_dict)
    if not_load_param:
        raise ValueError("Failed to load parameters into the quantization aware network!")

    print("============== Starting Testing ==============")
    acc = model.eval(ds_eval, dataset_sink_mode=True)
    print("============== {} ==============".format(acc))
    assert acc['Accuracy'] > 0.98


def export_lenet(optim_option="QAT", file_format="MINDIR"):
    cfg = quant_cfg
    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)
    # convert fusion network to quantization aware network
    if optim_option == "LEARNED_SCALE":
        quant_optim_options = OptimizeOption.LEARNED_SCALE
        quantizer = QuantizationAwareTraining(bn_fold=False,
                                              per_channel=[True, False],
                                              symmetric=[True, True],
                                              narrow_range=[True, True],
                                              freeze_bn=0,
                                              quant_delay=0,
                                              one_conv_fold=True,
                                              optimize_option=quant_optim_options)
    else:
        quantizer = QuantizationAwareTraining(quant_delay=0,
                                              bn_fold=False,
                                              freeze_bn=10000,
                                              per_channel=[True, False],
                                              symmetric=[True, False])
    network = quantizer.quantize(network)

    # export network
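    # The all-ones tensor only supplies the NCHW input shape used to trace the graph
    # for export; its values are irrelevant.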
    inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32)
    export(network, inputs, file_name="lenet_quant", file_format=file_format, quant_mode='AUTO')


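# Single-card GPU test: runs training, evaluation and export for both the default QAT
# and the LEARNED_SCALE optimize options.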
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_lenet_quant():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    train_lenet_quant()
    eval_quant()
    export_lenet()
    train_lenet_quant(optim_option="LEARNED_SCALE")
    eval_quant(optim_option="LEARNED_SCALE")
    export_lenet(optim_option="LEARNED_SCALE")


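# Single-card Ascend test: exercises only the LEARNED_SCALE option and exports to AIR.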
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_lenet_quant_ascend():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    train_lenet_quant(optim_option="LEARNED_SCALE")
    eval_quant(optim_option="LEARNED_SCALE")
    export_lenet(optim_option="LEARNED_SCALE", file_format="AIR")