# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import Cell
import mindspore.ops.operations as P


# {cast} is expected to be recomputed and fused.
class Net1(Cell):
    def __init__(self):
        super(Net1, self).__init__()
        self.cast = P.Cast()
        self.sum = P.ReduceSum(keep_dims=False)

    def construct(self, x):
        cast_res = self.cast(x, mstype.float32)
        sum1_res = self.sum(cast_res, (0,))
        sum2_res = self.sum(cast_res, (1,))
        return sum1_res, sum2_res


# {sqrt} is expected to be recomputed on Ascend.
class Net2(Cell):
    def __init__(self):
        super(Net2, self).__init__()
        self.sqrt = P.Sqrt()
        self.sum = P.ReduceSum(keep_dims=True)
        self.add = P.Add()
        self.neg = P.Neg()

    def construct(self, x0, x1):
        sqrt_res = self.sqrt(x0)
        neg_res = self.neg(sqrt_res)
        add_res = self.add(x1, sqrt_res)
        sum_res = self.sum(add_res, (0,))
        return neg_res, sum_res


# {sqrt} is expected to be recomputed.
class Net3(Cell):
    def __init__(self):
        super(Net3, self).__init__()
        self.sqrt = P.Sqrt()
        self.add = P.Add()
        self.neg = P.Neg()

    def construct(self, x0, x1):
        sqrt_res = self.sqrt(x0)
        neg_res = self.neg(sqrt_res)
        add_res = self.add(x1, sqrt_res)
        return neg_res, add_res


# {sqrt, neg} are expected to be recomputed.
class Net4(Cell):
    def __init__(self):
        super(Net4, self).__init__()
        self.sqrt = P.Sqrt()
        self.neg = P.Neg()
        self.sum = P.ReduceSum(keep_dims=False)

    def construct(self, x):
        sqrt_res = self.sqrt(x)
        neg_res = self.neg(sqrt_res)
        sum1_res = self.sum(neg_res, (0,))
        sum2_res = self.sum(neg_res, (1,))
        return sum1_res, sum2_res


# {sqrt} is expected to be recomputed.
class Net5(Cell):
    def __init__(self):
        super(Net5, self).__init__()
        self.sqrt = P.Sqrt()
        self.add = P.Add()

    def construct(self, x0, x1, x2):
        sqrt_res = self.sqrt(x0)
        add1_res = self.add(sqrt_res, x1)
        add2_res = self.add(sqrt_res, x2)
        return add1_res, add2_res


def _not_collected(func):
    """Mark ``func`` so pytest does not collect it as a test case.

    The ``test_basic*`` helpers take a ``net`` argument and are called by the
    real test entry points below. Their ``test_`` prefix would otherwise make
    pytest collect them directly and fail with "fixture 'net' not found".
    The names are kept so that existing callers keep working.
    """
    func.__test__ = False
    return func


def _compare_graph_kernel(net, *inputs):
    """Run ``net`` on ``inputs`` with graph kernel disabled, then enabled.

    A fresh ``net()`` instance is built for each run. The outputs of the
    two runs are compared element-wise with ``rtol=atol=1e-3``.

    Args:
        net: a ``Cell`` subclass (not an instance) whose ``construct``
            returns a tuple of tensors.
        *inputs: tensors passed positionally to the cell.
    """
    def get_output(enable_graph_kernel):
        # enable_graph_kernel is a global context switch, so it must be set
        # before the cell is built and run.
        context.set_context(enable_graph_kernel=enable_graph_kernel)
        return net()(*inputs)

    expect = get_output(False)
    output = get_output(True)
    assert len(expect) == len(output)
    for expect_t, output_t in zip(expect, output):
        assert np.allclose(expect_t.asnumpy(), output_t.asnumpy(),
                           rtol=1.e-3, atol=1.e-3)


@_not_collected
def test_basic1(net):
    """Check a one-input ``net`` on a float16 (1024, 1024) tensor."""
    i0 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float16))
    _compare_graph_kernel(net, i0)


@_not_collected
def test_basic2(net):
    """Check a two-input ``net`` on float32 (1, 1024) and (1024, 1024) tensors."""
    i0 = Tensor(np.random.uniform(1, 2, [1, 1024]).astype(np.float32))
    i1 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float32))
    _compare_graph_kernel(net, i0, i1)


@_not_collected
def test_basic3(net):
    """Check a three-input ``net`` on float16 tensors.

    The tensor shapes are (1, 1024), (1024, 1024) and (2048, 1024).
    """
    i0 = Tensor(np.random.uniform(1, 2, [1, 1024]).astype(np.float16))
    i1 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float16))
    i2 = Tensor(np.random.uniform(1, 2, [2048, 1024]).astype(np.float16))
    _compare_graph_kernel(net, i0, i1, i2)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_1():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_basic1(Net1)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_2():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_basic2(Net2)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_3():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_basic2(Net3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_4():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_basic1(Net4)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_5():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_basic3(Net5)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_1():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_basic1(Net1)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_2():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_basic2(Net2)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_3():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_basic2(Net3)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_4():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_basic1(Net4)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_5():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_basic3(Net5)