# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops as ops
import mindspore.ops.operations as P


def _run_net(net_cls, inputs, enable_graph_kernel):
    """Build a fresh ``net_cls`` instance and run it on ``inputs``.

    Args:
        net_cls: Cell subclass to instantiate (called with no arguments).
        inputs: Sequence of tensors passed positionally to the network.
        enable_graph_kernel: Value written to the ``enable_graph_kernel``
            context option before the network is built.

    Returns:
        Whatever the network returns for ``inputs``.
    """
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    return net_cls()(*inputs)


def _assert_graph_kernel_matches(net_cls, inputs):
    """Assert the network gives close results with graph kernel off and on.

    The network is first run with ``enable_graph_kernel=False`` to obtain the
    expected value and then with ``enable_graph_kernel=True``. The complete
    outputs are compared, using relative and absolute tolerances of 1e-2.
    """
    expect = _run_net(net_cls, inputs, False)
    output = _run_net(net_cls, inputs, True)
    expect_np = expect.asnumpy().copy()
    output_np = output.asnumpy().copy()
    assert np.allclose(expect_np, output_np, rtol=1.e-2, atol=1.e-2)


def _random_tensor(shape):
    """Return a float32 Tensor of ``shape`` with values drawn uniformly from [1, 2)."""
    return Tensor(np.random.uniform(1, 2, shape).astype(np.float32))


def test_case_1():
    """Compare an elementwise + full-reduction network with graph kernel off/on."""
    class Net1(Cell):
        def __init__(self):
            super(Net1, self).__init__()
            self.sub = ops.Sub()
            self.mul = ops.Mul()
            self.sum = ops.ReduceSum(keep_dims=False)
            self.add = ops.Add()

        def construct(self, x, y, z):
            t1 = self.sub(x, y)
            t2 = self.mul(t1, x)
            t3 = self.add(y, t2)
            t4 = self.add(t3, t3)
            t5 = z + 1.0
            t6 = self.sum(t4)
            t7 = self.add(t5, t6)
            return t7

    N = 8
    x = _random_tensor([N, N, N])
    y = _random_tensor([N, N, N])
    z = _random_tensor([N, N, N])
    _assert_graph_kernel_matches(Net1, (x, y, z))


def test_case_2():
    """Compare a sqrt/add/neg elementwise network with graph kernel off/on."""
    class Net2(Cell):
        def __init__(self):
            super(Net2, self).__init__()
            self.sqrt = P.Sqrt()
            self.add = P.Add()
            self.neg = P.Neg()

        def construct(self, x, y):
            sqrt_res = self.sqrt(x)
            add_res = self.add(y, sqrt_res)
            neg_res = self.neg(add_res)
            return neg_res

    N = 16
    x = _random_tensor([N, N])
    y = _random_tensor([N, N])
    # The network returns a single [N, N] tensor, so the whole result is
    # compared; indexing it with [0] would only check the first row.
    _assert_graph_kernel_matches(Net2, (x, y))


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_case_1():
    """Run test_case_1 in graph mode on GPU with the flags set below."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    context.set_context(graph_kernel_flags="--enable_low_precision=true --disable_pass=highlevelopt2.atomic_clean")
    test_case_1()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_case_2():
    """Run test_case_2 in graph mode on GPU with the flags set below."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    context.set_context(graph_kernel_flags="--enable_low_precision=true")
    test_case_2()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_case_1():
    """Run test_case_1 in graph mode on Ascend with the flags set below."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(graph_kernel_flags="--enable_low_precision=true --disable_pass=highlevelopt2.atomic_clean")
    test_case_1()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_case_2():
    """Run test_case_2 in graph mode on Ascend with the flags set below."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(graph_kernel_flags="--enable_low_precision=true")
    test_case_2()