• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" test graph fallback """
16import pytest
17import numpy as np
18
19import mindspore.nn as nn
20from mindspore import Tensor, jit, context
21from mindspore.ops import operations as P
22from mindspore.ops import functional as F
23from mindspore.nn.probability import distribution
24import mindspore.common.dtype as mstype
25import mindspore.common._monad as monad
26import mindspore.scipy.linalg as alg
27
28context.set_context(mode=context.GRAPH_MODE)
29
# `add_func` is defined in current file.
def add_func(x, y):
    """Return the sum of *x* and *y* (target for F.partial below)."""
    total = x + y
    return total
33
34
@jit
def do_increment(i):
    """Return i + 1, via F.partial binding 1 as add_func's first argument."""
    increment_by_one = F.partial(add_func, 1)
    return increment_by_one(i)
39
40
def test_increment():
    """
    Feature: JIT Fallback
    Description: Test F.partial with a function defined in the current file.
    Expectation: No exception.
    """
    result = do_increment(9)
    assert result == 10
44
45
@jit
def use_monad(x, y):
    """Multiply x by y, attaching a U-monad dependency to the product."""
    product = P.Mul()(x, y)
    product = F.depend(product, monad.U)
    return product
51
52
def test_use_monad():
    """
    Feature: JIT Fallback
    Description: Test F.depend with a monad value in graph mode.
    Expectation: No exception.
    """
    first = Tensor(1.0, mstype.float32)
    second = Tensor(1.0, mstype.float32)
    print(use_monad(first, second))
57
58
@jit
def use_tuple_of_tensor():
    """Return a tuple containing two scalar Tensor constants."""
    pair = (Tensor(1), Tensor(1))
    return pair
63
64
def test_tuple_of_tensor():
    """
    Feature: JIT Fallback
    Description: Test tuple of tensor in graph mode.
    Expectation: No exception.
    """
    out = use_tuple_of_tensor()
    print(out)
72
73
@jit
def use_list_of_tensor():
    """Return a list containing two scalar Tensor constants."""
    items = [Tensor(1), Tensor(1)]
    return items
78
79
def test_list_of_tensor():
    """
    Feature: JIT Fallback
    Description: Test list of tensor in graph mode.
    Expectation: No exception.
    """
    out = use_list_of_tensor()
    print(out)
87
88
class Net(nn.Cell):
    """Cell exercising builtin len() and a range loop inside construct."""

    def __init__(self):
        super(Net, self).__init__()
        # Fixed 3-element tensor, so len(self.x) is 3.
        self.x = Tensor([2, 3, 4])

    def construct(self):
        """Print each index below len(self.x) and return the length."""
        length = len(self.x)
        for index in range(length):
            print(index)
        return length
99
100
def test_builtins_len():
    """
    Feature: JIT Fallback
    Description: Test builtin len() inside a Cell in graph mode.
    Expectation: No exception.
    """
    model = Net()
    model()
104
105
@jit
def np_fallback_func():
    """Build a Tensor from a NumPy array created inside the graph, then double it."""
    # Keep the explicit tuple() call: the builtin-call fallback is part of
    # what this function exercises.
    values = tuple([2, 3, 4, 5])
    np_values = np.array(values).astype(np.float32)
    result = Tensor(np_values)
    result = result + result
    return result
113
114
def test_np_fallback_func():
    """
    Feature: JIT Fallback
    Description: Test NumPy array creation inside a jit function.
    Expectation: No exception.
    """
    out = np_fallback_func()
    print(out)
117
118
# Test `return` interpret node.
@jit
def div_mod_func1():
    """Return Tensor(divmod(8, 3)), i.e. the pair (2, 2)."""
    dividend = 8
    divisor = 3
    pair = divmod(dividend, divisor)
    return Tensor(pair)
126
127
def test_div_mod_func1():
    """
    Feature: JIT Fallback
    Description: Test divmod on constants in graph mode.
    Expectation: No exception.
    """
    out = div_mod_func1()
    print(out)  # (2, 2)
130
131
# Test interpret node with parameters as input.
@jit
def div_mod_func2(x, y):
    """Return Tensor(divmod(x, y)) computed from the function parameters."""
    pair = divmod(x, y)
    return Tensor(pair)
137
138
def test_div_mod_func2_scalar():
    """
    Feature: JIT Fallback
    Description: Test divmod in graph.
    Expectation: No exception.
    """
    out = div_mod_func2(8, 3)
    print(out)  # (2, 2)
146
147
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_div_mod_func2_tensor():
    """
    Feature: JIT Fallback
    Description: Test divmod with Tensor input in graph. We'll support it in Tensor Input Fallback solution.
    Expectation: Not supported exception.
    """
    # Tensor (variable) inputs cannot be evaluated by the JIT Fallback
    # interpreter, so divmod on them is expected to raise RuntimeError with
    # the exact message asserted below.
    with pytest.raises(RuntimeError) as err:
        print(div_mod_func2(Tensor(8), Tensor(3)))
    assert "Not support Tensor or variable type as input during running JIT Fallback, but got" in str(err.value)
158
159
@jit
def select_func(cond, x, y):
    """Pick between x and y based on the type of `cond` (Tensor uses F.select)."""
    if isinstance(cond, (tuple, list)):
        result = y
    elif isinstance(cond, Tensor):
        result = F.select(cond, x, y)
    else:
        result = x
    return result
169
170
def test_select_func():
    """
    Feature: JIT Fallback
    Description: Test isinstance branching with F.select in graph mode.
    Expectation: No exception.
    """
    condition = Tensor([True, False])
    lhs = Tensor([2, 3], mstype.float32)
    rhs = Tensor([1, 2], mstype.float32)
    print(select_func(condition, lhs, rhs))
176
177
@jit
def select_func2(cond, x, y):
    # NOTE(review): unlike select_func, the second check is a plain `if`
    # rather than `elif` — for a tuple/list cond the first assignment
    # (output = y) is immediately overwritten by the else branch
    # (output = x). Presumably deliberate, to exercise independent if/else
    # branch structures in graph mode; confirm the asymmetry with
    # select_func is intended.
    if isinstance(cond, (tuple, list)):
        output = y
    if isinstance(cond, Tensor):
        output = F.select(cond, x, y)
    else:
        output = x
    return output
187
188
def test_select_func2():
    """
    Feature: JIT Fallback
    Description: Test isinstance branching with independent if/else in graph mode.
    Expectation: No exception.
    """
    condition = Tensor([True, False])
    lhs = Tensor([2, 3], mstype.float32)
    rhs = Tensor([1, 2], mstype.float32)
    print(select_func2(condition, lhs, rhs))
194
195
@jit
def slice_func(a, b):
    """Assign `b` into a[1:3, ::] (broadcast slice write) and return `a`."""
    # `a[1:3, ::]` selects the second and third slices along axis 0; the
    # assignment broadcasts `b` across the remaining axes.
    a[1:3, ::] = b
    return a
200
201
def test_slice_func():
    """
    Feature: JIT Fallback
    Description: Test tensor slice assignment in graph mode.
    Expectation: No exception.
    """
    target = Tensor(np.arange(60).reshape(3, 4, 5), dtype=mstype.float32)
    fill = Tensor([1], dtype=mstype.float32)
    print(slice_func(target, fill))
206
207
def test_context():
    """
    Feature: JIT Fallback
    Description: Test context in graph.
    Expectation: No exception.
    """
    class ContextNet(nn.Cell):
        def __init__(self):
            super(ContextNet, self).__init__()
            # Captured at construction time, outside the compiled graph.
            self.mode = context.get_context("mode")

        def construct(self):
            result = 1
            if self.mode == context.GRAPH_MODE:
                result = 2
            return result

    net = ContextNet()
    print(net())
228
229
def test_scipy_module():
    """
    Feature: JIT Fallback
    Description: Test scipy module in graph.
    Expectation: No exception.
    """
    class Network(nn.Cell):
        def construct(self, x):
            return alg.eigh(x)

    matrix = Tensor([[2, 0, 0, 0], [0, 1, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]])
    result = Network()(matrix)
    print(result)
244
245
def test_probability_cauchy():
    """
    Feature: JIT Fallback
    Description: NumPy method is called in probability cauchy.
    Expectation: No exception.
    """
    class CauchyProb(nn.Cell):
        """Cell evaluating the main statistics of a Cauchy distribution."""

        def __init__(self, loc, scale, seed=10, dtype=mstype.float32, name='Cauchy'):
            super().__init__()
            # loc/scale arrive as NumPy arrays here; the distribution is
            # expected to handle the conversion internally.
            self.b = distribution.Cauchy(loc, scale, seed, dtype, name)

        def construct(self, value, loc=None, scale=None):
            # Evaluate all probability / cumulative / survival variants,
            # each with per-call loc and scale overrides.
            out1 = self.b.prob(value, loc, scale)
            out2 = self.b.log_prob(value, loc, scale)
            out3 = self.b.cdf(value, loc, scale)
            out4 = self.b.log_cdf(value, loc, scale)
            out5 = self.b.survival_function(value, loc, scale)
            out6 = self.b.log_survival(value, loc, scale)
            return out1, out2, out3, out4, out5, out6


    # NOTE(review): each (1024, 512, 7, 7) float32 array is ~100 MB; five of
    # them make this a heavy test — consider smaller shapes if memory or
    # runtime becomes an issue.
    loc = np.random.randn(1024, 512, 7, 7).astype(np.float32)
    scale = np.random.uniform(0.0001, 100, size=(1024, 512, 7, 7)).astype(np.float32)
    loc_a = np.random.randn(1024, 512, 7, 7).astype(np.float32)
    scale_a = np.random.uniform(0.0001, 100, size=(1024, 512, 7, 7)).astype(np.float32)
    value = np.random.randn(1024, 512, 7, 7).astype(np.float32)

    net = CauchyProb(loc, scale)
    net(Tensor(value), Tensor(loc_a), Tensor(scale_a))
275