• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15
16
"""test quant_v2"""
18import numpy as np
19import pytest
20import mindspore.common.dtype as mstype
21
22from mindspore.ops.operations._infer_ops import QuantV2
23from mindspore import Tensor, jit, JitConfig
24from tests.st.utils import test_utils
25
26
def generate_random_input(shape, dtype, tensor_type, seed=0):
    """Build a Tensor of standard-normal samples from a freshly seeded NumPy RNG.

    Args:
        shape: Iterable of dimension sizes, unpacked into ``np.random.randn``.
        dtype: NumPy dtype the samples are cast to before being wrapped.
        tensor_type: MindSpore dtype passed to the ``Tensor`` constructor.
        seed: Value handed to ``np.random.seed`` before sampling. The global
            NumPy RNG is reseeded on every call, so two calls with the same
            seed and shape return identical values; pass different seeds to
            obtain different tensors.

    Returns:
        Tensor holding the generated samples.
    """
    # Reseeding here makes each call reproducible on its own, independent of
    # any seeding the caller did beforehand.
    np.random.seed(seed)
    samples = np.random.randn(*shape).astype(dtype)
    return Tensor(samples, dtype=tensor_type)
30
31
def generate_expect_output(data, scale, offset, round_mode):
    """Compute the NumPy reference for quantization: round(data * scale + offset) as int8.

    Args:
        data: NumPy array (or NumPy scalar) to quantize.
        scale: Multiplier applied to ``data``; broadcast by NumPy.
        offset: Value added after scaling; broadcast by NumPy.
        round_mode: "ROUND", "FLOOR", "CEIL" or "TRUNC". Any other value falls
            back to the "ROUND" behaviour.

    Returns:
        The rounded values, clipped to the int8 range and cast to ``np.int8``.
    """
    rounders = {
        "ROUND": np.around,  # halfway cases go to the nearest even integer
        "FLOOR": np.floor,
        "CEIL": np.ceil,
        "TRUNC": np.trunc,
    }
    rounded = rounders.get(round_mode, np.around)(data * scale + offset)
    # Casting an out-of-range float to int8 is undefined in NumPy, so saturate
    # explicitly before the cast.
    # NOTE(review): assumes the device op also saturates at the int8 limits - confirm.
    limits = np.iinfo(np.int8)
    return np.clip(rounded, limits.min, limits.max).astype(np.int8)
44
45
@test_utils.run_with_cell
def quant_forward_func(data, scale, offset, sqrt_mode, round_mode, out_type):
    """Instantiate QuantV2 and apply it to the given arguments, returning its output."""
    quant_op = QuantV2()
    return quant_op(data, scale, offset, sqrt_mode, round_mode, out_type)
50
51
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.platform_x86_ascend910b_training
@pytest.mark.parametrize('mode', ['pynative', 'KBK', 'GE'])
@pytest.mark.parametrize('rounding', ['ROUND', 'FLOOR', 'CEIL', 'TRUNC'])
@pytest.mark.parametrize('support_type', [mstype.float32, mstype.float16, mstype.bfloat16])
def test_quant_static_shape(mode, rounding, support_type):
    """
    Feature: Test quant_v2 with static shape in graph and pynative mode.
    Description: call QuantV2 with a (2, 3, 4, 5) input and scale/offset of shape (5,).
    Expectation: return the correct value.
    """
    # No explicit seeding here: generate_random_input reseeds NumPy on every call.
    x = generate_random_input((2, 3, 4, 5), np.float32, support_type)
    scale = generate_random_input((5,), np.float32, support_type)
    offset = generate_random_input((5,), np.float32, support_type)

    if mode == 'pynative':
        forward = quant_forward_func
    else:
        # 'KBK' compiles with jit level O0; any other mode (i.e. 'GE') uses O2.
        jit_level = "O0" if mode == 'KBK' else "O2"
        forward = jit(quant_forward_func, jit_config=JitConfig(jit_level=jit_level))
    ms_out = forward(x, scale, offset, False, rounding, mstype.int8)

    # bfloat16 tensors are converted with .float() before asnumpy(); the other
    # dtypes are converted directly.
    if support_type == mstype.bfloat16:
        x, scale, offset = x.float(), scale.float(), offset.float()
    expect = generate_expect_output(x.asnumpy(), scale.asnumpy(), offset.asnumpy(), rounding)
    np.testing.assert_allclose(ms_out.asnumpy(), expect)
85
86
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.platform_x86_ascend910b_training
@pytest.mark.parametrize('rounding', ['ROUND', 'FLOOR', 'CEIL', 'TRUNC'])
def test_quant_dynamic_shape(rounding):
    """
    Feature: Test quant_v2 with dynamic shape in graph mode.
    Description: call QuantV2 through a cell whose tensor inputs are declared with unknown dimensions.
    Expectation: return the correct value.
    """
    # No explicit seeding here: generate_random_input reseeds NumPy on every call.
    x = generate_random_input((2, 3, 4, 5), np.float32, mstype.float32)
    scale = generate_random_input((5,), np.float32, mstype.float32)
    offset = generate_random_input((5,), np.float32, mstype.float32)

    # Placeholder tensors: same rank as the real inputs, every dimension left as None.
    x_dyn = Tensor(shape=[None, None, None, None], dtype=mstype.float32)
    scale_dyn = Tensor(shape=[None], dtype=mstype.float32)
    offset_dyn = Tensor(shape=[None], dtype=mstype.float32)

    test_cell = test_utils.to_cell_obj(quant_forward_func)
    test_cell.set_inputs(x_dyn, scale_dyn, offset_dyn, False, rounding, mstype.int8)
    ms_out = test_cell(x, scale, offset, False, rounding, mstype.int8)

    expect = generate_expect_output(x.asnumpy(), scale.asnumpy(), offset.asnumpy(), rounding)
    np.testing.assert_allclose(ms_out.asnumpy(), expect)
113