• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16@File  : test_parse_method.py
17@Author:
18@Date  : 2019-06-27
19@Desc  : test parse the object's method
20"""
21import logging
22from dataclasses import dataclass
23
24import numpy as np
25import pytest
26
27import mindspore.nn as nn
28from mindspore import context
29from mindspore._extends.parse.standard_method import ms_len
30from mindspore.common.api import ms_function
31from mindspore.common.tensor import Tensor
32from mindspore.ops.composite import core
33from mindspore.ops.primitive import constexpr
34from mindspore.ops import functional as F
35from ..ut_filter import non_graph_engine
36
37
def setup_module(module):
    """Module-level setup: put MindSpore into PyNative mode.

    ``module`` is the module object handed in by pytest; it is not used.
    """
    pynative_mode = context.PYNATIVE_MODE
    context.set_context(mode=pynative_mode)
40
41
# Shared logger for this test module. Its level is raised to ERROR, so the
# log.debug() traces emitted by the tests are suppressed by default.
log = logging.getLogger("test")
log.setLevel(logging.ERROR)
44
45
@ms_function
def default_parameter_f(x, y=3):
    """Return ``x + y``; checks parsing of a function with a default parameter.

    Args:
        x: first operand.
        y: second operand; defaults to 3 when the caller omits it.

    Returns:
        The sum ``x + y``.
    """
    z = x + y
    return z
51
52
# Test case: parse a function that relies on a default parameter value.
def test_parse_defalut_parameter_case1():
    """Call default_parameter_f with only ``x`` so ``y`` falls back to its default."""
    log.debug("begin test_parse_defalut_parameter_case1")
    result = default_parameter_f(2)
    log.debug("finished test_parse_defalut_parameter_case1, ret = %r", result)
59
60
def get_val_fn(x):
    """Return ``x`` increased by 3."""
    return x + 3
65
66
# Test case: parse a boolean `not` applied to a comparison.
@ms_function
def bool_exp(x, y):
    """Return ``not (x > y)``, i.e. whether ``x`` is not greater than ``y``."""
    return not x > y
72
73
def test_bool_exp():
    """Smoke test: bool_exp must parse and run for x=1, y=2."""
    lhs, rhs = 1, 2
    bool_exp(lhs, rhs)
77
78
# Test case: parse an @ms_function that takes variable positional arguments (*args).
@ms_function
def var_parameter_f(x, *args):
    """Return ``x`` plus the first three values collected in ``args``.

    Values in ``args`` beyond the third are ignored. Callers must pass at
    least three extra positional values for the ``args[...]`` indexing.
    """
    z = x + args[0] + args[1] + args[2]
    return z
85
86
def test_var_parameter_case1():
    """Pass five positional values so var_parameter_f collects four in ``*args``."""
    log.debug("start test_var_parameter_case1")
    extra = (2, 3, 4, 5)
    var_parameter_f(1, *extra)
    log.debug("end test_var_parameter_case1")
92
93
class Net(nn.Cell):
    """Cell whose ``construct`` calls another method of the same object.

    Args:
        value1: value added to the input by ``get_test_value``.
    """

    def __init__(self, value1):
        super(Net, self).__init__()
        # relu, softmax, axis and TC are not used by construct below.
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(0)
        self.axis = 0
        # ClassTest is defined later in the module; the name is only looked up
        # when a Net is instantiated.
        self.TC = ClassTest("test_class", 1.2)
        self.value = value1

    @ms_function
    def construct(self, x):
        # Delegates to a sibling method, giving x + self.value.
        x = self.get_test_value(x)
        return x

    def get_test_value(self, x):
        # Return x + self.value; invoked from construct.
        ret = x + self.value
        return ret
113
114
class ClassTest:
    """Plain Python helper object whose methods are called from parsed code.

    Args:
        name: value returned by ``get_name``.
        value1: base value to which ``get_value`` adds ``inc``.
    """

    def __init__(self, name, value1):
        self.name = name
        self.value = value1

    def get_name(self):
        # Return the name given at construction.
        return self.name

    def get_value(self, inc):
        # Return the stored value plus inc.
        ret = self.value + inc
        return ret

    def __call__(self, *args, **kwargs):
        # Makes instances callable; the call does nothing and returns None.
        pass
131
132
# Test: call a method of the cell itself from parsed graph code.
@non_graph_engine
def test_call_method_on_construct():
    """Net.construct should return ``x + y`` through its get_test_value helper."""
    log.debug("begin test_call_method_on_construct")

    input_x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
    offset = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
    expected = np.array([[3, 5, 7], [2, 3, 5]]).astype(np.int32)

    actual = Net(offset).construct(input_x).asnumpy()
    print(actual)
    assert np.all(actual == expected)

    log.debug("finished test_call_method_on_construct")
150
151
# Test: call a method of another object, held as a cell attribute, from parsed graph code.
class Net1(nn.Cell):
    """Cell whose ``construct`` calls a method on a ClassTest attribute.

    Args:
        v1: value stored inside the ClassTest helper.
        v2: value passed to ``ClassTest.get_value`` as ``inc``.
    """

    def __init__(self, v1, v2):
        super(Net1, self).__init__()
        # relu, softmax and axis are not used by construct below.
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(0)
        self.axis = 0
        self.TC = ClassTest("test_class", v1)
        self.value = v2

    @ms_function
    def construct(self, x):
        # Written as x + (v1 + v2) via the helper object; the test that calls
        # this (test_call_other_object_method) expects a TypeError.
        x = x + self.TC.get_value(self.value)
        return x
168
169
@non_graph_engine
def test_call_other_object_method():
    """Calling Net1.construct is expected to raise TypeError."""
    log.debug("begin test_call_other_object_method")

    input_x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
    helper_value = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
    increment = Tensor(np.array([[5, 4, 5], [1, 1, 2]]).astype(np.int32))
    expected = np.array([[8, 9, 12], [3, 4, 7]]).astype(np.int32)

    net = Net1(helper_value, increment)
    with pytest.raises(TypeError):
        # The remaining lines only run if construct() does not raise.
        actual = net.construct(input_x).asnumpy()
        print(actual)
        assert np.all(actual == expected)

    log.debug("finished test_call_other_object_method")
188
189
# Test: call a method of a global object (not reached through self) from parsed graph code.
# Module-level Tensor and ClassTest instance that Net2 references directly.
value = Tensor(np.array([[3, 4, 5], [1, 1, 2]]).astype(np.int32))
TC = ClassTest("test_class", value)
193
194
class Net2(nn.Cell):
    """Cell whose graph methods read the module-level ``TC`` object.

    Args:
        value1: value combined with ``TC`` in ``construct`` and ``construct1``.
    """

    def __init__(self, value1):
        super(Net2, self).__init__()
        self.value = value1

    @ms_function
    def construct(self, x):
        # Calls a method on the global TC: x + (TC.value + self.value).
        x = x + TC.get_value(self.value)
        return x

    @ms_function
    def construct1(self, x):
        # Reads the global TC's attribute directly instead of calling a method.
        x = x + TC.value
        x = x + self.value
        return x
212
213
@non_graph_engine
def test_call_no_self_other_object_method():
    """Net2.construct, which calls a method on the global TC, is expected to raise TypeError."""
    # Log messages name this test; they previously repeated
    # test_call_other_object_method's name.
    log.debug("begin test_call_no_self_other_object_method")
    x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
    y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
    z = np.array([[6, 9, 12], [3, 4, 7]]).astype(np.int32)

    net = Net2(y)
    with pytest.raises(TypeError):
        # The remaining lines only run if construct() does not raise.
        output = net.construct(x)
        result = output.asnumpy()
        print(result)
        assert np.all(result == z)

    log.debug("finished test_call_no_self_other_object_method")
230
231
def test_call_no_self_other_object_attr_value():
    """Placeholder for reading a global object's attribute from parsed code.

    The scenario is not supported (a Tensor cannot be used as an init input).
    The test is reported as skipped instead of silently passing with an
    empty body.
    """
    pytest.skip("do not support tensor as init input.")
236
237
# Test case: use * to unpack varargs when forwarding them inside an @ms_function.
def vararg1(x, y):
    """Return ``x + y``; the target function that varargs_main forwards to."""
    z = x + y
    return z
243
244
def varargs_main(fn):
    """Wrap ``fn`` in an @ms_function that forwards its positional arguments.

    Args:
        fn: callable invoked as ``fn(*args)`` inside the decorated closure.

    Returns:
        The @ms_function-decorated closure ``t1``.
    """

    @ms_function
    def t1(*args):
        # Unpack the collected varargs back into positional arguments for fn.
        return fn(*args)

    return t1
253
254
def test_var_parameter_case3():
    """Forward ``(1, 2)`` through varargs_main to vararg1."""
    log.debug("start test_var_parameter_case3")
    wrapped = varargs_main(vararg1)
    result = wrapped(1, 2)
    log.debug("ret = %r", result)
    log.debug("end test_var_parameter_case3")
261
262
# Test case: set a flag on a function through the @core decorator.
@core(tg=True)
def set_flag(x):
    """Return ``x + 1``; decorated with ``core(tg=True)`` to attach the flag under test."""
    return x + 1
268
269
@ms_function
def set_test_flag_main(x, y):
    """Return ``set_flag(x) + y``, calling the @core-flagged function from parsed code."""
    z = set_flag(x)
    z = z + y
    return z
276
277
def test_set_flag():
    """Run set_test_flag_main so the @core(tg=True) function is parsed and called."""
    log.debug("begin test_set_flag")
    result = set_test_flag_main(2, 3)
    log.debug("finished test_set_flag, ret = %r", result)
283
284
@dataclass
class Access:
    """Dataclass with two fields that is instantiated and used inside parsed code."""
    a: int  # first candidate value
    b: int  # second candidate value

    def max(self):
        # Return the larger of a and b (b when they are equal).
        if self.a > self.b:
            return self.a
        return self.b
294
295
@ms_function
def invoke_dataclass(x, y):
    """Build ``Access(x, y)`` inside parsed code and return ``acs.max()``."""
    acs = Access(x, y)
    return acs.max()
301
302
def test_access():
    """Smoke test: construct and use the Access dataclass from parsed code."""
    first, second = 1, 2
    invoke_dataclass(first, second)
306
@dataclass
class Access2:
    """Like Access, but ``max`` deliberately reads an attribute that does not exist."""
    a: int  # first candidate value
    b: int  # second candidate value

    def max(self):
        # self.c is not a field of this dataclass. This branch is meant to
        # raise AttributeError (see test_access_attr_error).
        if self.a > self.b:
            return self.c
        return self.b
316
317
@ms_function
def invoke_dataclass2(x, y):
    """Build ``Access2(x, y)`` inside parsed code and return ``acs.max()``.

    When ``x > y``, this reaches Access2.max's read of the undefined ``self.c``.
    """
    acs = Access2(x, y)
    return acs.max()
323
324
def test_access_attr_error():
    """Reading an undefined dataclass attribute in parsed code raises AttributeError."""
    bigger, smaller = 2, 1  # a > b, so Access2.max takes the self.c branch
    with pytest.raises(AttributeError):
        invoke_dataclass2(bigger, smaller)
329
330
def myfunc(x):
    """Return ``x * x``; called inside the for loop of ms_infer_for_func."""
    return x * x
334
335
@ms_function
def ms_infer_for():
    """Sum the constant floats 1.1, 2.3 and 3.3 with a for loop in parsed code."""
    a = 0.0
    for x in [1.1, 2.3, 3.3]:
        a = a + x
    return a
343
344
def test_infer_for():
    """Smoke test for a for loop over a constant list in parsed code."""
    _ = ms_infer_for()
348
349
@ms_function
def ms_infer_for_func(y):
    """Add ``myfunc(x)`` to ``y`` for each x in 1.0, 2.0, 3.0 and return it.

    myfunc squares its input, so 1.0, 4.0 and 9.0 are added to ``y`` in turn.
    This exercises calling a plain Python function inside a for loop in
    parsed code.
    """
    for x in [1.0, 2.0, 3.0]:
        y = myfunc(x) + y
    return y
356
357
def test_ms_infer_for_func():
    """Smoke test: run ms_infer_for_func starting from y = 1.0."""
    start = 1.0
    ms_infer_for_func(start)
361
362
@ms_function
def add(x, y):
    """Return ``x + y``; test_add calls it with a mixed int/float pair."""
    return x + y
367
368
def test_add():
    """Add a Python int and a float inside parsed code.

    The result is deliberately not returned. pytest expects test functions to
    return None, and recent versions warn (PytestReturnNotNoneWarning)
    otherwise.
    """
    add(1, 2.0)
373
374
@ms_function
def add_list():
    """Index a constant list in parsed code and return ``a[1] + a[2]`` (2 + 3)."""
    a = [1, 2, 3]
    b = a[1] + a[2]
    return b
381
382
def test_list():
    """Index a constant list inside parsed code.

    The result is deliberately not returned. pytest expects test functions to
    return None, and recent versions warn (PytestReturnNotNoneWarning)
    otherwise.
    """
    add_list()
386
387
@ms_function
def compare_list_len():
    """Return the length of a constant three-element list via ``ms_len``.

    Despite the name, nothing is compared; the function only calls ms_len.
    """
    a = [1, 2, 3]
    return ms_len(a)
393
394
def test_list_len():
    """Smoke test: ms_len applied to a constant list inside parsed code."""
    _ = compare_list_len()
398
399
@ms_function
def add_tuple():
    """Index a constant tuple in parsed code and return ``a[1] + a[2]`` (2 + 3)."""
    a = (1, 2, 3)
    b = a[1] + a[2]
    return b
406
407
def test_tuple():
    """Index a constant tuple inside parsed code.

    The result is deliberately not returned. pytest expects test functions to
    return None, and recent versions warn (PytestReturnNotNoneWarning)
    otherwise.
    """
    add_tuple()
411
412
def invoke_func(x):
    """Return ``x * x``; called twice from tuple_of_node."""
    return x * x
416
417
@ms_function
def tuple_of_node(x, y):
    """Pack two computed values into a tuple and index it in parsed code.

    Returns:
        ``(y * y) * x``, taken from the tuple's second element.
    """
    a = invoke_func(x)
    b = invoke_func(y)
    c = (a, b)
    d = c[1] * x
    return d
426
427
def test_tuple_node():
    """Build and index a tuple of computed nodes inside parsed code.

    The result is deliberately not returned. pytest expects test functions to
    return None, and recent versions warn (PytestReturnNotNoneWarning)
    otherwise.
    """
    tuple_of_node(1, 2)
432
433
@ms_function
def range_spec(x, y):
    """Loop over ``range(1, 10, 3)`` in parsed code, incrementing ``x`` each pass.

    The range yields 1, 4 and 7, so the result is ``x + 3 + y``.
    """
    for _ in range(1, 10, 3):
        x = x + 1
    return x + y
440
441
def test_range():
    """Run a stepped ``range`` loop inside parsed code.

    The result is deliberately not returned. pytest expects test functions to
    return None, and recent versions warn (PytestReturnNotNoneWarning)
    otherwise.
    """
    range_spec(10, 10)
446
def test_expr():
    """Call a @constexpr-decorated helper directly with a constant tuple."""
    pair = (1, 2)

    @constexpr
    def tuple_len(x):
        # Checks that the given tuple has exactly two items.
        assert len(x) == 2

    tuple_len(pair)
454
455
def test_tuple_to_array():
    """Convert ``range(10)`` (a range, not a tuple) with F.tuple_to_array and print it."""
    print(F.tuple_to_array(range(10)))
461