# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations._grad_ops as G


class MinmumGradNet(Cell):
    """Minimal cell that forwards its inputs to the ``MinimumGrad`` operator."""

    def __init__(self):
        super().__init__()
        self.minimum_grad = G.MinimumGrad()

    def construct(self, x, y, dy):
        """Return the output of ``MinimumGrad`` applied to ``(x, y, dy)``."""
        return self.minimum_grad(x, y, dy)


def gen_data():
    """Build the deterministic input tensors used by the comparison test.

    Returns:
        A tuple ``(input_x, input_y, input_dout)`` of float32 tensors:
        ``input_x`` has shape ``[2, 3]``, ``input_y`` has shape ``[1]`` and
        ``input_dout`` is ``np.minimum(input_x, input_y)``.
    """
    # Fixed seed so that both runs of the test see identical data.
    np.random.seed(0)
    input_x_np = np.random.normal(0, 1, [2, 3]).astype(np.float32)
    input_y_np = np.random.normal(0, 1, [1]).astype(np.float32)
    # The element-wise minimum of the inputs is reused as the incoming gradient.
    input_dout_np = np.minimum(input_x_np, input_y_np).astype(np.float32)
    input_x = Tensor(input_x_np)
    input_y = Tensor(input_y_np)
    input_dout = Tensor(input_dout_np)
    return input_x, input_y, input_dout


def get_minimum_grad_output(input_x, input_y, input_dout, enable_graph_kernel=False):
    """Run ``MinmumGradNet`` with graph kernel switched on or off.

    Args:
        input_x: First operand tensor.
        input_y: Second operand tensor.
        input_dout: Incoming gradient tensor.
        enable_graph_kernel: Value passed to
            ``context.set_context(enable_graph_kernel=...)`` before the run.

    Returns:
        A 2-tuple of numpy arrays holding the first two outputs of the
        operator (presumably the gradients w.r.t. ``x`` and ``y`` -- confirm
        against the operator documentation).
    """
    # NOTE: this changes the global context and is not restored afterwards.
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    net = MinmumGradNet()
    result = net(input_x, input_y, input_dout)
    return result[0].asnumpy(), result[1].asnumpy()


def test_minimum_grad():
    """Check that graph kernel on/off produce matching ``MinimumGrad`` outputs."""
    input_x, input_y, input_dout = gen_data()
    result_off = get_minimum_grad_output(input_x, input_y, input_dout, False)
    result_on = get_minimum_grad_output(input_x, input_y, input_dout, True)
    # The run without graph kernel serves as the reference for both outputs.
    assert np.allclose(result_on[0], result_off[0], rtol=1.e-4, atol=1.e-8, equal_nan=True)
    assert np.allclose(result_on[1], result_off[1], rtol=1.e-4, atol=1.e-8, equal_nan=True)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_minimum_grad_gpu():
    """Run the on/off comparison in graph mode on the GPU target."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_minimum_grad()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_minimum_grad_ascend():
    """Run the on/off comparison in graph mode on the Ascend target."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_minimum_grad()