# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" test graph fallback """
16import mindspore.common.dtype as mstype
17import mindspore.nn as nn
18from mindspore import Tensor, jit, context
19from mindspore import ops, tensor
20
21import numpy as np
22import pytest
23
24context.set_context(mode=context.GRAPH_MODE)
25
26
class ControlNet(nn.Cell):
    """Cell that compares a constant sum against the input and branches.

    construct() builds two int32 Tensors (4 and 5) inline and returns
    their sum when 4 + 5 > x, otherwise their difference.
    """

    def inner_function_1(self, a, b):
        """Return the sum of a and b (branch taken when 4 + 5 > x)."""
        return a + b

    def inner_function_2(self, a, b):
        """Return the difference of a and b (fallthrough branch)."""
        return a - b

    def construct(self, x):
        """Define Tensors inside construct and pick a branch by comparison."""
        lhs = Tensor(np.array(4), mstype.int32)
        rhs = Tensor(np.array(5), mstype.int32)
        if lhs + rhs > x:
            return self.inner_function_1(lhs, rhs)
        return self.inner_function_2(lhs, rhs)
40
41
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_control_sink_tensor():
    """
    Feature: Fallback feature: support define Tensor in Class construct.
    Description: Fallback feature: support define Tensor in Class construct.
    Expectation: Fallback feature: support define Tensor in Class construct.
    """
    net = ControlNet()
    result = net(Tensor(np.array(1), mstype.int32))
    # 4 + 5 > 1, so the sum branch runs: expect 4 + 5 == 9.
    assert result == Tensor(9, mstype.int32)
58
59
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_tensor_list():
    """
    Feature: Fallback feature
    Description: support Basic method of Tensor list.
    Expectation: No exception.
    """

    @jit
    def np_tensor_list():
        first = Tensor(np.array(4), mstype.int32)
        second = Tensor(np.array(5), mstype.int32)
        third = Tensor(np.array(6), mstype.int32)
        elements = [first, second]
        for elem in elements:
            print(elem)
        elements.append(elements[-1] + third)
        return elements

    result = np_tensor_list()
    print("tensor_list:", result)
    # Two initial tensors plus the appended sum.
    assert len(result) == 3
86
87
@jit
def np_fallback_func_tensor_index(x):
    """Build a Tensor through numpy interop inside a jit graph and index it.

    Deliberately exercises Python builtins (tuple), numpy array creation
    with astype, Tensor construction from a numpy array, Tensor addition,
    and Tensor indexing with a Tensor index — do not simplify these steps.
    """
    array_x = tuple([2, 3, 4, 5])
    np_x = np.array(array_x).astype(np.float32)
    me_x = Tensor(np_x)
    # Double each element: data becomes (4.0, 6.0, 8.0, 10.0).
    me_x = me_x + me_x
    return me_x[x]
95
96
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_fallback_func_tensor_index():
    """
    Feature: Fallback feature: support Tensor index.
    Description: Fallback feature: support Tensor index.
    Expectation: Fallback feature: support Tensor index.
    """
    index = Tensor(1, mstype.int32)
    # Source data (2, 3, 4, 5) is doubled, so position 1 holds 6.0.
    expected = Tensor(6, mstype.float32)
    assert np_fallback_func_tensor_index(index) == expected
112
113
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_tensor_compare_with_variable():
    """
    Feature: Fallback feature
    Description: Test ms.Tensor() in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo(x):
        # Loop while x is positive; abs(Tensor([-1])) == 1, so each
        # iteration decrements x by exactly 1. The while condition compares
        # the variable input against a Tensor created in the loop header.
        while x > Tensor([0]):
            x = x - abs(Tensor([-1]))
        return x

    # Starting from 6, six decrements bring x to exactly 0.
    res = foo(Tensor([6]))
    assert res == 0
134
135
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_tensor_add():
    """
    Feature: Fallback feature
    Description: support Tensor add.
    Expectation: No exception.
    """

    @jit
    def np_tensor_add():
        first = Tensor(np.array(4))
        second = Tensor(np.array(5))
        elements = [first, second]
        for elem in elements:
            print(elem)
        x = 6
        np_x = np.array(x)
        addend = Tensor(np_x)
        total = elements[-1] + addend
        elements.append(total)
        return elements

    result = np_tensor_add()
    print("tensor_list:", result)
    # Appended element is 5 + 6.
    assert result[-1] == 11
165
166
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_user_define_bprop_using_fallback():
    """
    Feature: Fallback feature
    Description: user define bprop support jit fallback.
    Expectation: No exception.
    """
    class TestBpropCell(nn.Cell):
        """Cell whose forward and custom bprop both round-trip through numpy."""

        def __init__(self):
            super().__init__()
            self.const_value = 1

        def construct(self, x):
            # Forward: compute (x + x) * x on a numpy copy (fallback path),
            # then wrap the result back into a float32 tensor.
            x = x * self.const_value
            x = x.asnumpy()
            x = (x + x) * x
            return tensor(x, mstype.float32)

        def bprop(self, x, out, dout):
            # User-defined gradient, also computed via numpy fallback.
            x = dout.asnumpy()
            x = 2 * (x * x) * (np.log(x) + 1)
            return (tensor(x, mstype.float32),)

    class TestCell(nn.Cell):
        """Sandwiches the custom-bprop cell between plain arithmetic ops."""

        def __init__(self):
            super().__init__()
            self.user_define_bprop = TestBpropCell()

        def construct(self, x):
            x = 2 * x
            x = self.user_define_bprop(x)
            x = x + 1
            x = 2 * x
            return x

    test_cell = TestCell()
    input_x = Tensor([1, 2, 3, 4], mstype.float32)
    graph_output = ops.grad(test_cell)(input_x)

    # Cross-check against PyNative-mode gradients. Restore GRAPH_MODE in a
    # finally block: the original code left the global context in
    # PYNATIVE_MODE if the grad call raised, polluting every later test.
    context.set_context(mode=context.PYNATIVE_MODE)
    try:
        pynative_out = ops.grad(test_cell)(input_x)
    finally:
        context.set_context(mode=context.GRAPH_MODE)

    assert np.allclose(graph_output.asnumpy(), pynative_out.asnumpy())
213