• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import pytest
16import numpy as np
17import mindspore.nn as nn
18import mindspore.ops.operations as P
19import mindspore.ops.functional as F
20from mindspore import context, Tensor
21from mindspore.common import dtype as mstype
22
# Configure MindSpore once at import time: graph mode on the Ascend target.
context.set_context(
    device_target="Ascend",
    mode=context.GRAPH_MODE,
)
24
25
class NpuFloatNet(nn.Cell):
    """Cell that computes ``x - (-x)`` and guards the result with the NPU float status.

    Based on the related code in test_math_ops.py. The output is replaced by
    zeros wherever the summed float-status flags are greater than zero (the
    tests in this file treat that as "overflowed"); otherwise the computed
    result is returned.
    """

    def __init__(self):
        super().__init__()
        # Not referenced by construct(); kept so the cell's attributes are unchanged.
        self.mul = P.Mul()

        # Float-status primitives.
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_status = P.NPUClearFloatStatus()

        # Helpers used to build the zero tensor and pick the final output.
        self.fill = P.Fill()
        self.shape_op = P.Shape()
        self.select = P.Select()
        self.less = P.Less()
        self.cast = P.Cast()
        self.dtype = P.DType()
        self.reduce_sum = P.ReduceSum(keep_dims=True)

        # Arithmetic under test.
        self.sub = P.Sub()
        self.neg = P.Neg()

    def construct(self, x):
        """Return ``x - (-x)``, or zeros where the float-status flag sum is positive."""
        status = self.alloc_status()
        cleared = self.clear_status(status)
        # Make x depend on the clear op.
        x = F.depend(x, cleared)
        doubled = self.sub(x, self.neg(x))
        # Make the status buffer (and thus get_status) depend on the arithmetic result.
        status = F.depend(status, doubled)
        fetched = self.get_status(status)
        # Make the status buffer (and thus reduce_sum) depend on get_status.
        status = F.depend(status, fetched)
        flag_sum = self.reduce_sum(status, (0,))

        # Zero tensor shaped like the result, expressed in the flag's dtype for the comparison.
        result_dtype = self.dtype(doubled)
        zeros = self.fill(result_dtype, self.shape_op(doubled), 0.0)
        zeros_as_flag = self.cast(zeros, self.dtype(flag_sum))
        flag_raised = self.less(zeros_as_flag, flag_sum)

        # Where the flag is raised emit zeros (cast back to the result dtype), else the result.
        return self.select(flag_raised, self.cast(zeros_as_flag, result_dtype), doubled)
60
61
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_float_not_overflow():
    """A float16 input of 655 does not overflow, so the net returns 655 * 2."""
    shape = (8, 5, 3, 1)
    value = 655
    inputs = Tensor(np.full(shape, value, dtype=np.float16), dtype=mstype.float16)
    actual = NpuFloatNet()(inputs)
    # No overflow here, so the doubled value must come back unchanged.
    expected = Tensor(np.full(shape, value * 2, dtype=np.float16), dtype=mstype.float16)
    np.testing.assert_array_equal(actual.asnumpy(), expected.asnumpy())
74
75
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_float_overflow():
    """A float16 input of 65504 overflows when doubled, so the net returns all zeros."""
    inputs = Tensor(np.full((8, 5, 3, 1), 65504, dtype=np.float16), dtype=mstype.float16)
    result = NpuFloatNet()(inputs).asnumpy()
    # Every element must be zero once the overflow has been detected.
    assert np.all(result == 0)
86