# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
from tests.st.control.cases_register import case_register
import mindspore.context as context
import mindspore as ms
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore import ops, Tensor, nn

context.set_context(mode=context.GRAPH_MODE)


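# Each case is registered with a precheck level and the device targets it runs on.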
@case_register.level0
@case_register.target_gpu
@case_register.target_ascend
def test_dde_make_tuple_joined_with_tuple_output_primitive():
    """
    Feature: Eliminate unused elements for tuples.
    Description: One branch returns a make_tuple while the other returns a tuple-output node such as top_k.
    Expectation: Correct result and no exception.
    """

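    # Joining a hand-built tuple with the native tuple output of topk exercises
    # the eliminate-unused-element (DDE) handling described in the docstring.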
    @ms.jit
    def topk_fun(x, k):
        if k == 0:
            output = (ms.ops.ones((0,), dtype=ms.float32), ms.ops.ones((0,), dtype=ms.int32))
        else:
            output = ms.ops.topk(x, k, None, True, True)
        return output

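    # k == 0 takes the make_tuple branch, so both outputs are empty tensors.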
    x = ms.tensor([1., 2., 3.])
    k = ms.tensor([0])
    out = topk_fun(x, k)
    expect_out0 = ms.ops.ones((0,), dtype=ms.float32)
    expect_out1 = ms.ops.ones((0,), dtype=ms.int32)
    assert np.allclose(out[0].asnumpy(), expect_out0.asnumpy())
    assert np.allclose(out[1].asnumpy(), expect_out1.asnumpy())


@case_register.level0
@case_register.target_gpu
@case_register.target_ascend
def test_dde_parameter_converted_to_value_tuple():
    """
    Feature: Eliminate unused elements for tuples.
    Description: The value_tuple is converted from a parameter, so it has no sequence_nodes set.
    Expectation: Correct result and no exception.
    """

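    # _old_norm computes the p-norm of x: (sum(x ** p)) ** (1 / p).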
    def _old_norm(norm_type, x):
        out = F.pow((F.reduce_sum(F.pow(x, norm_type))), 1. / norm_type).astype(x.dtype)
        return out

    class ClipByNormFuncNet(nn.Cell):
        def __init__(self, max_norm, norm_type=2.0, error_if_nonfinite=False):
            super().__init__()
            self.max_norm = max_norm
            self.norm_type = norm_type
            self.error_if_nonfinite = error_if_nonfinite
            self.partial_op = P.Partial()
            self.hyper_map = C.HyperMap()

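        # Clip-by-global-norm: the varargs tuple *x is the parameter that the
        # graph compiler converts to a value_tuple without sequence_nodes.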
        def construct(self, *x):
            is_tensor = False
            if isinstance(x, Tensor):
                x = [x]
                is_tensor = True
            total_norm = _old_norm(self.norm_type,
                                   F.stack(self.hyper_map(self.partial_op(_old_norm, self.norm_type), x)))
            clip_coef = self.max_norm / (total_norm + 1e-6)
            if clip_coef < 1:
                ret = self.hyper_map(self.partial_op(F.mul, clip_coef), x)
            else:
                ret = x
            if is_tensor:
                return ret[0]
            return ret

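    # Differentiates the wrapped network w.r.t. its inputs; sens_param=True means
    # the caller supplies the output sensitivity explicitly.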
    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = ops.GradOperation(sens_param=True)

        def construct(self, *x):
            gradient_function = self.grad_op(self.net)
            return gradient_function(*x)

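    # Run the forward pass, then the backward pass with the forward output as the sensitivity.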
    ms.set_context(mode=ms.GRAPH_MODE)
    net = ClipByNormFuncNet(max_norm=1, norm_type=2, error_if_nonfinite=True)
    net.set_train()
    x = [ops.ones((2, 2)), ops.ones((2,))]
    ms_output = net(*x)
    output = GradNetWrtX(net)(*x, ms_output)
    expect_out = np.array([[0.4082482, 0.4082482], [0.4082482, 0.4082482]]).astype(np.float32)
    assert np.allclose(output.asnumpy(), expect_out)