# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================


"""test quant_v2"""
import numpy as np
import pytest
import mindspore.common.dtype as mstype

from mindspore.ops.operations._infer_ops import QuantV2
from mindspore import Tensor, jit, JitConfig
from tests.st.utils import test_utils


def generate_random_input(shape, dtype, tensor_type):
    """Draw standard-normal data of `shape` and wrap it in a MindSpore Tensor.

    The data is generated as numpy `dtype` and converted to `tensor_type`.
    The global NumPy RNG is intentionally NOT reseeded here: callers seed it
    once (np.random.seed) so that consecutive calls produce different data.
    Reseeding on every call made `scale` and `offset` identical, so a kernel
    that mixed up those two operands would still have passed.
    """
    data = np.random.randn(*shape).astype(dtype)
    return Tensor(data, dtype=tensor_type)


# QuantV2 round_mode string -> NumPy rounding function used by the reference.
_ROUNDING_FUNCS = {
    "ROUND": np.around,
    "FLOOR": np.floor,
    "CEIL": np.ceil,
    "TRUNC": np.trunc,
}


def generate_expect_output(data, scale, offset, round_mode):
    """NumPy reference for QuantV2 with sqrt_mode=False.

    Computes `rounding(data * scale + offset)` and casts the result to int8.

    Args:
        data: input array.
        scale: scale array, broadcast against `data`.
        offset: offset array, broadcast against `data`.
        round_mode: "ROUND", "FLOOR", "CEIL" or "TRUNC"; any other value falls
            back to the "ROUND" behaviour.

    Returns:
        np.ndarray of dtype int8.
    """
    # NOTE(review): np.around rounds exact halves to the nearest even value;
    # confirm this matches the kernel's "ROUND" tie-breaking.
    rounding = _ROUNDING_FUNCS.get(round_mode, np.around)
    return rounding(data * scale + offset).astype(np.int8)


@test_utils.run_with_cell
def quant_forward_func(data, scale, offset, sqrt_mode, round_mode, out_type):
    """Instantiate the QuantV2 primitive and apply it to the given arguments.

    Decorated with test_utils.run_with_cell (presumably runs the function
    inside a Cell -- confirm against tests/st/utils).
    """
    net = QuantV2()
    return net(data, scale, offset, sqrt_mode, round_mode, out_type)


def _run_quant(mode, *args):
    """Run quant_forward_func in the requested execution mode.

    'pynative' calls it directly; 'KBK' compiles it with jit_level O0; any
    other mode (here 'GE') compiles it with jit_level O2.
    """
    if mode == 'pynative':
        return quant_forward_func(*args)
    jit_level = "O0" if mode == 'KBK' else "O2"
    return jit(quant_forward_func, jit_config=JitConfig(jit_level=jit_level))(*args)


def _to_numpy(tensor, tensor_type):
    """Convert `tensor` to numpy, upcasting bfloat16 to float32 first.

    The bfloat16 upcast is done via Tensor.float(), as the original test does.
    """
    if tensor_type == mstype.bfloat16:
        return tensor.float().asnumpy()
    return tensor.asnumpy()


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.platform_x86_ascend910b_training
@pytest.mark.parametrize('mode', ['pynative', 'KBK', 'GE'])
@pytest.mark.parametrize('rounding', ['ROUND', 'FLOOR', 'CEIL', 'TRUNC'])
@pytest.mark.parametrize('support_type', [mstype.float32, mstype.float16, mstype.bfloat16])
def test_quant_static_shape(mode, rounding, support_type):
    """
    Feature: Test quant_v2 with static shape in graph and pynative mode.
    Description: call ops.quant_v2 with valid inputs.
    Expectation: return the correct value.
    """
    # Seeded once so that x, scale and offset all receive distinct data.
    np.random.seed(1)
    x = generate_random_input((2, 3, 4, 5), np.float32, support_type)
    scale = generate_random_input((5,), np.float32, support_type)
    offset = generate_random_input((5,), np.float32, support_type)

    ms_out = _run_quant(mode, x, scale, offset, False, rounding, mstype.int8)

    # NOTE(review): for float16 the reference arithmetic runs in float16 numpy;
    # confirm this matches the kernel's internal precision.
    expect = generate_expect_output(_to_numpy(x, support_type), _to_numpy(scale, support_type),
                                    _to_numpy(offset, support_type), rounding)
    np.testing.assert_array_equal(ms_out.asnumpy(), expect)


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.platform_x86_ascend910b_training
@pytest.mark.parametrize('rounding', ['ROUND', 'FLOOR', 'CEIL', 'TRUNC'])
def test_quant_dynamic_shape(rounding):
    """
    Feature: Test quant_v2 with dynamic shape in graph mode.
    Description: call ops.quant_v2 with valid inputs.
    Expectation: return the correct value.
    """
    # Seeded once so that x, scale and offset all receive distinct data.
    np.random.seed(1)
    x = generate_random_input((2, 3, 4, 5), np.float32, mstype.float32)
    scale = generate_random_input((5,), np.float32, mstype.float32)
    offset = generate_random_input((5,), np.float32, mstype.float32)

    # Placeholder tensors with every dimension unknown, used to declare dynamic inputs.
    x_dyn = Tensor(shape=[None, None, None, None], dtype=mstype.float32)
    scale_dyn = Tensor(shape=[None], dtype=mstype.float32)
    offset_dyn = Tensor(shape=[None], dtype=mstype.float32)

    test_cell = test_utils.to_cell_obj(quant_forward_func)
    test_cell.set_inputs(x_dyn, scale_dyn, offset_dyn, False, rounding, mstype.int8)
    ms_out = test_cell(x, scale, offset, False, rounding, mstype.int8)

    expect = generate_expect_output(x.asnumpy(), scale.asnumpy(), offset.asnumpy(), rounding)
    np.testing.assert_array_equal(ms_out.asnumpy(), expect)