• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import numpy as np
17import pytest
18import mindspore.context as context
19from mindspore import Tensor
20from mindspore.nn import Cell
21import mindspore.ops as ops
22import mindspore.ops.operations as P
23
def test_case_1():
    """Check that Net1 gives matching results with graph kernel off and on.

    Three random float32 tensors of shape ``[8, 8, 8]`` drawn from
    ``uniform(1, 2)`` are fed to the same network twice: first with
    ``enable_graph_kernel=False`` (reference), then with
    ``enable_graph_kernel=True``. Both results must agree within
    ``rtol=1e-2`` and ``atol=1e-2``.
    """
    class Net1(Cell):
        """Sub -> Mul -> Add -> Add chain, reduced with ReduceSum and added to ``z + 1.0``."""

        def __init__(self):
            super().__init__()
            self.sub = ops.Sub()
            self.mul = ops.Mul()
            self.sum = ops.ReduceSum(keep_dims=False)
            self.add = ops.Add()
            self.pow = ops.Pow()

        def construct(self, x, y, z):
            diff = self.sub(x, y)
            scaled = self.mul(diff, x)
            shifted = self.add(y, scaled)
            doubled = self.add(shifted, shifted)
            z_plus_one = z + 1.0
            reduced = self.sum(doubled)
            return self.add(z_plus_one, reduced)

    def get_output(x, y, z, net, enable_graph_kernel=False):
        """Build a fresh ``net`` instance and run it under the requested graph-kernel setting."""
        context.set_context(enable_graph_kernel=enable_graph_kernel)
        return net()(x, y, z)

    def make_input(shape):
        """Return a float32 Tensor of the given shape sampled from uniform(1, 2)."""
        return Tensor(np.random.uniform(1, 2, shape).astype(np.float32))

    N = 8
    # The list comprehension draws the samples in x, y, z order.
    x, y, z = [make_input([N, N, N]) for _ in range(3)]

    # Run both configurations first, then convert the results to numpy.
    reference = get_output(x, y, z, Net1, False)
    candidate = get_output(x, y, z, Net1, True)
    reference_np = reference.asnumpy().copy()
    candidate_np = candidate.asnumpy().copy()
    assert np.allclose(reference_np, candidate_np, rtol=1.e-2, atol=1.e-2)
57
58
def test_case_2():
    """Compare Net2's output with graph kernel disabled and enabled.

    Net2 computes ``neg(add(y, sqrt(x)))`` for two random float32 inputs of
    shape ``[16, 16]`` drawn from ``uniform(1, 2)``. The network is run once
    with ``enable_graph_kernel=False`` (reference) and once with
    ``enable_graph_kernel=True``; the complete outputs must agree within
    ``rtol=1e-2`` and ``atol=1e-2``.
    """
    class Net2(Cell):
        """Element-wise chain: Sqrt -> Add -> Neg."""

        def __init__(self):
            super(Net2, self).__init__()
            self.sqrt = P.Sqrt()
            self.add = P.Add()
            self.neg = P.Neg()

        def construct(self, x, y):
            sqrt_res = self.sqrt(x)
            add_res = self.add(y, sqrt_res)
            neg_res = self.neg(add_res)
            return neg_res

    def get_output(x, y, net, enable_graph_kernel=False):
        """Instantiate ``net`` and run it on ``(x, y)`` under the given graph-kernel setting."""
        context.set_context(enable_graph_kernel=enable_graph_kernel)
        net_obj = net()
        output = net_obj(x, y)
        return output

    N = 16
    x = Tensor(np.random.uniform(1, 2, [N, N]).astype(np.float32))
    y = Tensor(np.random.uniform(1, 2, [N, N]).astype(np.float32))
    expect = get_output(x, y, Net2, False)
    output = get_output(x, y, Net2, True)
    # construct() returns a single Tensor rather than a tuple, so indexing the
    # result with [0] would select only the first slice along axis 0 and leave
    # the remaining rows unchecked. Compare the whole tensor instead.
    expect_np = expect.asnumpy().copy()
    output_np = output.asnumpy().copy()
    assert np.allclose(expect_np, output_np, 1.e-2, 1.e-2)
87
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_case_1():
    """Run ``test_case_1`` in graph mode on the GPU target.

    Graph kernel is given ``--enable_low_precision=true`` together with
    ``--disable_pass=highlevelopt2.atomic_clean``.
    """
    graph_kernel_flags = "--enable_low_precision=true --disable_pass=highlevelopt2.atomic_clean"
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    context.set_context(graph_kernel_flags=graph_kernel_flags)
    test_case_1()
95
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_case_2():
    """Run ``test_case_2`` in graph mode on the GPU target.

    Graph kernel is given ``--enable_low_precision=true``.
    """
    graph_kernel_flags = "--enable_low_precision=true"
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    context.set_context(graph_kernel_flags=graph_kernel_flags)
    test_case_2()
103
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_case_1():
    """Run ``test_case_1`` in graph mode on the Ascend target.

    Graph kernel is given ``--enable_low_precision=true`` together with
    ``--disable_pass=highlevelopt2.atomic_clean``.
    """
    graph_kernel_flags = "--enable_low_precision=true --disable_pass=highlevelopt2.atomic_clean"
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(graph_kernel_flags=graph_kernel_flags)
    test_case_1()
112
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_case_2():
    """Run ``test_case_2`` in graph mode on the Ascend target.

    Graph kernel is given ``--enable_low_precision=true``.
    """
    graph_kernel_flags = "--enable_low_precision=true"
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(graph_kernel_flags=graph_kernel_flags)
    test_case_2()
121