# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
from tests.st.utils import test_utils
from tests.st.ops.dynamic_shape.test_op_utils import TEST_OP


@pytest.fixture(autouse=True)
def _seed_numpy_rng():
    """Seed NumPy's global RNG before every test in this module.

    Several tests draw their inputs with ``np.random.randn``. Without a fixed
    seed a failing run (e.g. a floor/trunc quotient that lands on the other
    side of an integer boundary on device) could not be reproduced.
    """
    np.random.seed(0)


def generate_random_input(shape, dtype):
    """Return two independent standard-normal arrays of ``shape`` cast to ``dtype``."""
    return np.random.randn(*shape).astype(dtype), np.random.randn(*shape).astype(dtype)


def generate_expect_forward_output(x, y, rounding_mode):
    """Compute the NumPy reference result for ``div(x, y, rounding_mode)``.

    Args:
        x: dividend (array or scalar).
        y: divisor (array or scalar); broadcast against ``x`` by NumPy.
        rounding_mode: ``'floor'`` -> ``np.floor_divide``, ``'trunc'`` -> true
            division truncated toward zero; any other value (including
            ``None``) -> plain true division.
    """
    if rounding_mode == 'floor':
        return np.floor_divide(x, y)
    if rounding_mode == 'trunc':
        return np.trunc(np.divide(x, y))
    return np.divide(x, y)


class NetNone(ms.nn.Cell):
    """Cell wrapping ``ms.ops.div`` with no ``rounding_mode`` argument."""

    def __init__(self):
        super().__init__()
        self.div = ms.ops.div

    def construct(self, x, y):
        return self.div(x, y)


class NetFloor(ms.nn.Cell):
    """Cell wrapping ``ms.ops.div`` with ``rounding_mode="floor"``."""

    def __init__(self):
        super().__init__()
        self.div = ms.ops.div

    def construct(self, x, y):
        return self.div(x, y, rounding_mode="floor")


class NetTrunc(ms.nn.Cell):
    """Cell wrapping ``ms.ops.div`` with ``rounding_mode="trunc"``."""

    def __init__(self):
        super().__init__()
        self.div = ms.ops.div

    def construct(self, x, y):
        return self.div(x, y, rounding_mode="trunc")


@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_div_vmap(mode):
    """
    Feature: pyboost function.
    Description: test function div vmap feature.
    Expectation: expect correct result.
    """
    ms.context.set_context(mode=mode)
    x = np.array([7, 8, 9], dtype=np.float32)
    y = np.array([14, 6, 12], dtype=np.float32)
    # Map over the last (only) axis of both 1-D inputs and stack results on axis 0.
    output = ms.ops.vmap(ms.ops.div, in_axes=-1, out_axes=0)(ms.Tensor(x), ms.Tensor(y))
    expect = generate_expect_forward_output(x, y, None)
    np.testing.assert_allclose(output.asnumpy(), expect, rtol=1e-3)


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.parametrize('rounding_mode', [None, 'floor', 'trunc'])
def test_ops_div_std(mode, rounding_mode):
    """
    Feature: pyboost function.
    Description: test function div forward/backward.
    Expectation: expect correct result.
    """
    # forward test
    ms.context.set_context(mode=mode)
    x, y = generate_random_input((4, 5, 6), np.float32)
    # Any rounding_mode other than 'floor'/'trunc' falls back to plain division.
    net_classes = {'floor': NetFloor, 'trunc': NetTrunc}
    net = net_classes.get(rounding_mode, NetNone)()
    output = net(ms.Tensor(x, dtype=ms.float32), ms.Tensor(y, dtype=ms.float32))
    expect = generate_expect_forward_output(x, y, rounding_mode)
    np.testing.assert_allclose(output.asnumpy(), expect, rtol=1e-3)

    # backward test: d(x / y)/dx == 1 / y
    x, y = np.array([1.0, 5.0, 7.5]), np.array([4.0, 2.0, 3.0])
    net = NetNone()
    output = ms.ops.grad(net, (0,))(ms.Tensor(x, dtype=ms.float32), ms.Tensor(y, dtype=ms.float32))
    expect = 1.0 / y  # [0.25, 0.5, 0.3333...]
    np.testing.assert_allclose(output.asnumpy(), expect, rtol=1e-3)


@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_ops_div_forward_case01(mode):
    """
    Feature: pyboost function.
    Description: test function div.
    Expectation: expect correct result.
    """
    ms.context.set_context(mode=mode)
    # Large dividend with a divisor broadcast along the last axis.
    x = np.random.randn(64, 32, 3578).astype(np.float32)
    y = np.random.randn(64, 32, 1).astype(np.float32)
    rounding_mode = None
    net = NetNone()
    output = net(ms.Tensor(x, dtype=ms.float32), ms.Tensor(y, dtype=ms.float32))
    expect = generate_expect_forward_output(x, y, rounding_mode)
    np.testing.assert_allclose(output.asnumpy(), expect, rtol=1e-3)


@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_ops_div_forward_case02(mode):
    """
    Feature: pyboost function.
    Description: test function div.
    Expectation: expect correct result.
    """
    ms.context.set_context(mode=mode)
    x = np.random.randn(64, 32, 1).astype(np.float32)
    # Divisor is a Python int scalar rather than a Tensor.
    y = 7168
    rounding_mode = None
    net = NetNone()
    output = net(ms.Tensor(x, dtype=ms.float32), y)
    expect = generate_expect_forward_output(x, y, rounding_mode)
    np.testing.assert_allclose(output.asnumpy(), expect, rtol=1e-3)


@test_utils.run_with_cell
def div_forward_dyn(x, y):
    """Plain ``ms.ops.div`` wrapped by ``run_with_cell`` for the dynamic-shape test."""
    return ms.ops.div(x, y)


@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
def test_div_dynamic_shape():
    """
    Feature: Test dynamic shape.
    Description: test function div dynamic feature.
    Expectation: expect correct result.
    """
    # Two input sets whose shapes differ ((2, 4)/(1, 4) vs (2, 3)/(1, 3));
    # presumably TEST_OP runs the op on both to exercise dynamic shapes.
    ms_x0 = ms.Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), ms.float32)
    ms_y0 = ms.Tensor(np.array([[1, 2, 3, 4]]), ms.float32)
    ms_x1 = ms.Tensor(np.array([[1, 2, 3], [5, 6, 7]]), ms.float32)
    ms_y1 = ms.Tensor(np.array([[1, 2, 3]]), ms.float32)
    TEST_OP(div_forward_dyn, [[ms_x0, ms_y0], [ms_x1, ms_y1]], '', disable_input_check=True, disable_yaml_check=True)