# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Train a Conv2dBnFoldQuant cell for one step and check the resulting loss.
"""

import pytest
import numpy as np
from mindspore import nn
from mindspore import context
from mindspore import Tensor
from mindspore.common import set_seed
from mindspore.compression.quant import create_quant_config


class Net(nn.Cell):
    """Minimal network wrapping a single ``nn.Conv2dBnFoldQuant`` layer.

    Args:
        qconfig: quantization config passed through as ``quant_config`` to the
            layer (the test below builds it with ``create_quant_config()``).
    """

    def __init__(self, qconfig):
        super().__init__()
        # 2 input channels -> 3 output channels, 2x2 kernel, stride 1, no padding,
        # so a (1, 2, 5, 5) input yields a (1, 3, 4, 4) output.
        self.conv = nn.Conv2dBnFoldQuant(2, 3, kernel_size=(2, 2), stride=(1, 1),
                                         pad_mode='valid', quant_config=qconfig)

    def construct(self, x):
        """Apply the quantized convolution layer to ``x``."""
        return self.conv(x)


def test_conv2d_bn_fold_quant():
    """Train ``Net`` for one step and compare the loss with a reference value.

    Builds ``Net`` with the default quant config and wraps it with MSELoss and
    a Momentum optimizer. Runs a single training step on an all-ones
    input/target pair and asserts that the returned loss is within 0.1 of
    0.940427.

    NOTE(review): because of its ``test_`` prefix, pytest also collects this
    function on its own. It then runs without the platform marks and without
    the Ascend context set by ``test_conv2d_bn_fold_quant_ascend``; confirm
    that is intended.
    """
    # Fix the global random seed; the reference loss below presumably depends on it.
    set_seed(1)
    quant_config = create_quant_config()
    network = Net(quant_config)

    inputs = Tensor(np.ones([1, 2, 5, 5]).astype(np.float32))
    # MSELoss regression target: keep it float32 like ``inputs`` instead of
    # int32, so the loss does not rely on implicit int -> float promotion.
    # All values are 1.0 either way, so the expected loss is unaffected.
    label = Tensor(np.ones([1, 3, 4, 4]).astype(np.float32))

    # Only parameters that require gradients are handed to the optimizer.
    trainable_params = filter(lambda x: x.requires_grad, network.get_parameters())
    opt = nn.Momentum(trainable_params, learning_rate=0.1, momentum=0.9)
    loss = nn.MSELoss()
    net_with_loss = nn.WithLossCell(network, loss)
    train_network = nn.TrainOneStepCell(net_with_loss, opt)
    train_network.set_train()
    out_loss = train_network(inputs, label)

    expect_loss = np.array([0.940427])
    error = np.array([0.1])
    diff = out_loss.asnumpy() - expect_loss
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_conv2d_bn_fold_quant_ascend():
    """Run the Conv2dBnFoldQuant training check in graph mode on Ascend."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_conv2d_bn_fold_quant()