• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022-2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test graph JIT Fallback runtime feature """
16import os
17import math
18from functools import reduce
19import pytest
20import numpy as np
21import mindspore as ms
22from mindspore import nn
23from mindspore import Tensor, tensor
24from mindspore import ops
25from mindspore import mutable, jit
26
# Run the tests in this module in MindSpore graph mode.
ms.set_context(mode=ms.GRAPH_MODE)
28
29
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_getattr_tensor_with_wrong_attr():
    """
    Feature: Syntax getattr.
    Description: Graph syntax getattr on a Tensor input with a nonexistent attribute name ("abs2").
    Expectation: AttributeError.
    """

    @ms.jit
    def call_missing_attr(x):
        abs_func = getattr(x, "abs2")
        return abs_func()

    bad_input = ms.Tensor([-1, -2, -3])
    # NOTE(review): an earlier comment said "Not throw error any more, should move to ST.",
    # yet this check still expects AttributeError — confirm which is current.
    with pytest.raises(AttributeError, match="object has no attribute"):
        call_missing_attr(bad_input)
50
51
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_getattr_list_with_wrong_attr():
    """
    Feature: Syntax getattr.
    Description: Graph syntax getattr on a list input with a nonexistent attribute name ("abs2").
    Expectation: AttributeError.
    """

    @ms.jit
    def call_missing_attr(x):
        abs_func = getattr(x, "abs2")
        return abs_func()

    bad_input = [1, 2, 3, 4]
    # NOTE(review): an earlier comment said "Not throw error any more, should move to ST.",
    # yet this check still expects AttributeError — confirm which is current.
    with pytest.raises(AttributeError, match="object has no attribute"):
        call_missing_attr(bad_input)
72
73
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_getattr_tuple_with_wrong_attr():
    """
    Feature: Syntax getattr.
    Description: Graph syntax getattr support tuple input; a tuple has no "shape" attribute.
    Expectation: AttributeError.
    """

    @ms.jit
    def foo(x):
        abs_func = getattr(x, "shape")
        return abs_func()

    # NOTE(review): an earlier comment said "Not throw error any more, should move to ST.",
    # yet this check still expects AttributeError — confirm which is current.
    with pytest.raises(AttributeError) as err:
        foo((1, 2, 3, 4))
    assert "object has no attribute" in str(err.value)
94
95
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_getattr_dict_with_wrong_attr():
    """
    Feature: Syntax getattr.
    Description: Graph syntax getattr support dict input with a nonexistent attribute name ("abs2").
    Expectation: AttributeError.
    """

    @ms.jit
    def foo(x):
        abs_func = getattr(x, "abs2")
        return abs_func()

    # NOTE(review): an earlier comment said "Not throw error any more, should move to ST.",
    # yet this check still expects AttributeError — confirm which is current.
    with pytest.raises(AttributeError) as err:
        foo({"1": 1, "2": 2})
    assert "object has no attribute" in str(err.value)
116
117
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pyexecute_with_scalar_input():
    """
    Feature: Fallback runtime.
    Description: The PyExecute node takes a scalar input: one dimension of a tensor whose
        input signature is fully dynamic.
    Expectation: No error.
    """
    def _check_is_inf_nan(x):
        if math.isinf(x) or math.isnan(x) or np.isinf(x) or np.isnan(x):
            return True
        return False

    class InnerNet(nn.Cell):
        def construct(self, x):
            return _check_is_inf_nan(x.shape[0])

    # Rank-5 signature with every dimension unknown; the real input's first dim is 2.
    dynamic_sig = Tensor(shape=[None, None, None, None, None], dtype=ms.float32)
    real_input = Tensor(np.random.randint(6, size=(2, 4, 3, 4, 5)), dtype=ms.float32)
    net = InnerNet()
    net.set_inputs(dynamic_sig)
    # 2 is neither inf nor nan.
    assert not net(real_input)
144
145
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pyexecute_with_scalar_input_2():
    """
    Feature: Fallback runtime.
    Description: The PyExecute node takes a Python float scalar (math.inf) passed straight to construct.
    Expectation: No error.
    """
    def _check_is_inf_nan(x):
        if math.isinf(x) or math.isnan(x) or np.isinf(x) or np.isnan(x):
            return True
        return False

    class InnerNet(nn.Cell):
        def construct(self, x):
            return _check_is_inf_nan(x)

    # math.inf must be detected as infinite.
    assert InnerNet()(math.inf)
169
170
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pyexecute_with_scalar_input_3():
    """
    Feature: Fallback runtime.
    Description: all() over a generator of scalar comparisons on the shape of a dynamic-shape tensor.
    Expectation: No error.
    """

    class InnerNet(nn.Cell):
        def construct(self, x):
            shp = x.shape
            return all(i < 3 for i in shp)

    net = InnerNet()
    net.set_inputs(Tensor(shape=[None] * 5, dtype=ms.float32))
    real_input = Tensor(np.random.randint(6, size=(2, 4, 3, 4, 5)), dtype=ms.float32)
    # Several dims (4, 3, 4, 5) are >= 3, so all(...) is False.
    assert not net(real_input)
194
195
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pyexecute_with_scalar_input_4():
    """
    Feature: Fallback runtime.
    Description: any() over a generator of scalar comparisons on the shape of a dynamic-shape tensor.
    Expectation: No error.
    """

    class InnerNet(nn.Cell):
        def construct(self, x):
            shp = x.shape
            return any(i < 3 for i in shp)

    net = InnerNet()
    net.set_inputs(Tensor(shape=[None] * 5, dtype=ms.float32))
    real_input = Tensor(np.random.randint(6, size=(2, 4, 3, 4, 5)), dtype=ms.float32)
    # The leading dim 2 is < 3, so any(...) is True.
    assert net(real_input)
219
220
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pyexecute_as_multitype_fg_input():
    """
    Feature: Fallback runtime.
    Description: Slicing a custom Python object stored on a Cell yields a PyExecute node, which
        can not be used as a multitype function graph input.
    Expectation: No error.
    """
    class SubClass:
        """Container whose item access is a no-op: __getitem__ always returns None."""

        def __getitem__(self, item):
            pass

        def __setitem__(self, key, target):
            pass

    class InnerNet(nn.Cell):
        def __init__(self, tuple_input):
            super().__init__()
            self.data = tuple_input

        def construct(self, start):
            return self.data[start:]

    container = SubClass()
    container[0] = [1, 2, 3, 4, 5]  # silently discarded by the no-op __setitem__
    net = InnerNet(container)
    # self.data[start:] goes through SubClass.__getitem__, which returns None.
    assert net(1) is None
251
252
def user_mul(x, y):
    """Return the product ``x * y``; the binary reducer used by ``reduce_user_mul``."""
    product = x * y
    return product
255
256
@ms.jit
def reduce_user_mul(x):
    """Fold the sequence ``x`` with ``user_mul`` (the product of its elements) inside a jit graph."""
    return reduce(user_mul, x)
261
262
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pyexecute_with_func_graph_input():
    """
    Feature: Fallback runtime.
    Description: functools.reduce with a user function inside jit, for both a constant tuple and a
        mutable tuple, so the PyExecute node receives a FuncGraph input.
    Expectation: No error.
    """
    constant_input = (1, 2, 3)
    variable_input = mutable((1, 2, 3), False)
    results = [reduce_user_mul(arg) for arg in (constant_input, variable_input)]
    # 1 * 2 * 3 == 6 in both cases.
    assert all(res == 6 for res in results)
280
281
@pytest.mark.skip('backend not support different type in value tuple')
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_anytype():
    """
    Feature: Fallback runtime.
    Description: An operator (ReLU) consumes a tensor rebuilt from the PyExecute output of asnumpy().
    Expectation: No error.
    """

    @jit
    def func(x):
        x = x.asnumpy()
        x = ms.Tensor(x)
        x = ops.ReLU()(x)
        return x

    x_np = np.array([1, -1])
    actual = func(ms.Tensor(np.array([1, -1])))
    # ReLU reference computed directly with numpy.
    expected = np.maximum(x_np, 0)
    assert np.allclose(expected, actual.asnumpy())
309
310
class CreateDynTensor(nn.Cell):
    """Build two tensors from the mutable (runtime) shape of ``x``, fill each with 1 and add them.

    ``ms.mutable(ops.shape(x))`` keeps the shape variable, so the ``Tensor(...)`` calls are expected
    to become PyExecute nodes (contrast ``CreateNotDynTensor``, where the shape is a constant).
    The ``# @jit.typing`` comments declare that output type, in both supported placements: a
    comment on the line before the statement, and a trailing comment on the same line.
    NOTE(review): presumably each annotation must stay directly above / on its statement —
    do not insert lines between them.
    """

    def construct(self, x):
        # @jit.typing: () -> tensor_type[int32]
        shape_tensor1 = Tensor(ms.mutable(ops.shape(x)), ms.int32)
        output1 = ops.FillV2()(shape_tensor1, Tensor(1, ms.int32))

        shape_tensor2 = Tensor(ms.mutable(ops.shape(x)), ms.int32)  # @jit.typing: () -> tensor_type[int32]
        output2 = ops.FillV2()(shape_tensor2, Tensor(1, ms.int32))
        return output1 + output2
320
321
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_shape_tensor():
    """
    Feature: Fallback runtime.
    Description: Set PyExecute output type by the annotation from comment.
    Expectation: No error.
    """
    net = CreateDynTensor()
    x = Tensor(dtype=ms.int32, input_data=[2, 2])
    out = net(x)
    # x has shape (2,), so each FillV2 yields [1, 1] and their sum is [2, 2].
    # Assert instead of returning: pytest warns when a test function returns a non-None value.
    assert np.array_equal(out.asnumpy(), np.array([2, 2], dtype=np.int32))
337
338
class CreateNotDynTensor(nn.Cell):
    """Same computation as ``CreateDynTensor`` but without ``ms.mutable``.

    Here the shape of ``x`` is a compile-time constant, so the ``Tensor(...)`` calls are not
    expected to be converted to PyExecute nodes.
    """

    def construct(self, x):
        # ops.shape(x) is a constant, should not convert to PyExecute.
        shape_tensor1 = Tensor(ops.shape(x), ms.int32)
        output1 = ops.FillV2()(shape_tensor1, Tensor(1, ms.int32))

        shape_tensor2 = Tensor(ops.shape(x), ms.int32)
        output2 = ops.FillV2()(shape_tensor2, Tensor(1, ms.int32))
        return output1 + output2
348
349
@pytest.mark.skip('ops.shape(x) is constant, not mutable.')
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_not_dynamic_shape_tensor():
    """
    Feature: Fallback runtime.
    Description: Not convert to PyExecute.
    Expectation: No error.
    """
    net = CreateNotDynTensor()
    x = Tensor(dtype=ms.int32, input_data=[2, 2])
    out = net(x)
    # x has shape (2,), so each FillV2 yields [1, 1] and their sum is [2, 2].
    # Assert instead of returning: pytest warns when a test function returns a non-None value.
    assert np.array_equal(out.asnumpy(), np.array([2, 2], dtype=np.int32))
366
367
class CreateDynTensorWithInputDtype(nn.Cell):
    """Build two tensors from the mutable shape of ``x`` using a caller-supplied ``dtype``.

    The ``# @jit.typing`` annotations use the ``{dtype}`` placeholder, so the declared PyExecute
    output type follows the ``dtype`` argument. Both annotation placements are covered: a comment
    on the line before the statement, and a trailing comment on the same line.
    NOTE(review): presumably each annotation must stay directly above / on its statement.
    """

    def construct(self, x, dtype):
        # @jit.typing: () -> tensor_type[{dtype}]
        shape_tensor1 = Tensor(ms.mutable(ops.shape(x)), dtype)
        output1 = ops.FillV2()(shape_tensor1, Tensor(1, dtype))

        shape_tensor2 = Tensor(ms.mutable(ops.shape(x)), dtype)  # @jit.typing: () -> tensor_type[{dtype}]
        # Fill with the requested dtype, as output1 does. This was hard-coded to ms.int32,
        # which made output1 + output2 mix dtypes for any dtype other than int32.
        output2 = ops.FillV2()(shape_tensor2, Tensor(1, dtype))
        return output1 + output2
377
378
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_shape_dtype_tensor():
    """
    Feature: Fallback runtime.
    Description: Set PyExecute output type by the annotation from comment, with the dtype passed as input.
    Expectation: No error.
    """
    net = CreateDynTensorWithInputDtype()
    x = Tensor(dtype=ms.int32, input_data=[2, 2])
    out = net(x, ms.int32)
    # x has shape (2,), so each FillV2 yields [1, 1] and their sum is [2, 2].
    # Assert instead of returning: pytest warns when a test function returns a non-None value.
    assert np.array_equal(out.asnumpy(), np.array([2, 2], dtype=np.int32))
394
395
class MakeTensorAsConstant(ms.nn.Cell):
    """Build tensors with the ``ms.tensor`` API from the (constant) shape of ``x``, fill and add them.

    No ``ms.mutable`` is used, so the shape is a compile-time constant. Per the test that uses
    this Cell, the tensors are expected to be created as constants at compile time.
    """

    def construct(self, x):
        shape_tensor1 = ms.tensor(ops.shape(x), ms.int32)
        output1 = ops.FillV2()(shape_tensor1, ms.Tensor(1, ms.int32))

        shape_tensor2 = ms.tensor(ops.shape(x), ms.int32)
        output2 = ops.FillV2()(shape_tensor2, ms.Tensor(1, ms.int32))
        return output1 + output2
404
405
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_make_tensor_as_constant():
    """
    Feature: Fallback runtime.
    Description: Test tensor API, create constant Tensor on compile time.
    Expectation: No error.
    """
    net = MakeTensorAsConstant()
    x = ms.Tensor(dtype=ms.int32, input_data=[2, 2])
    out = net(x)
    # x has shape (2,), so each FillV2 yields [1, 1] and their sum is [2, 2].
    # Assert instead of returning: pytest warns when a test function returns a non-None value.
    assert np.array_equal(out.asnumpy(), np.array([2, 2], dtype=np.int32))
421
422
class MakeTensorWithShapeDtype(nn.Cell):
    """Build tensors with ``ms.tensor`` from the mutable (runtime) shape of ``x`` and a local dtype.

    There are no ``# @jit.typing`` comments here; the dtype comes from the local ``dtype`` variable.
    NOTE(review): output2 fills with ``ms.int32`` while output1 uses ``dtype``. They are equal here
    because ``dtype = ms.int32``, but they would diverge if ``dtype`` changed.
    """

    def construct(self, x):
        dtype = ms.int32
        shape_tensor1 = ms.tensor(ms.mutable(ops.shape(x)), dtype)  # shape is mutable, so dtype is used at runtime (RT).
        output1 = ops.FillV2()(shape_tensor1, Tensor(1, dtype))

        shape_tensor2 = ms.tensor(ms.mutable(ops.shape(x)), dtype)
        output2 = ops.FillV2()(shape_tensor2, Tensor(1, ms.int32))
        return output1 + output2
432
433
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_make_tensor_with_dynamic_shape_dtype():
    """
    Feature: Fallback runtime.
    Description: Test tensor API with a mutable shape and a dtype held in a variable, so the tensor
        is built at runtime (PyExecute) with that dtype.
    Expectation: No error.
    """
    net = MakeTensorWithShapeDtype()
    x = Tensor(dtype=ms.int32, input_data=[2, 2])
    out = net(x)
    # x has shape (2,), so each FillV2 yields [1, 1] and their sum is [2, 2].
    # Assert instead of returning: pytest warns when a test function returns a non-None value.
    assert np.array_equal(out.asnumpy(), np.array([2, 2], dtype=np.int32))
449
450
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gelu():
    """
    Feature: Fallback runtime.
    Description: Set PyInterpret output type by the annotation from comment. The same tanh-form GELU
        is written three ways: annotation on the line before the return (1), trailing annotation on
        an intermediate statement (2), and no annotation (3). All three results must be identical.
    Expectation: No error.
    """
    # Variant 1: `@jit.typing` comment directly above the return statement.
    @ms.jit
    def gelu_forward_1(x):
        # @jit.typing: () -> tensor_type[float32]
        return 0.5 * x * (1 + ms.ops.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * ms.ops.pow(x, 3))))

    # Variant 2: the same formula split into steps, with a trailing annotation on var3.
    @ms.jit
    def gelu_forward_2(x):
        math_var = math.sqrt(2 / math.pi)
        pow_var = ms.ops.pow(x, 3)
        var1 = 0.044715 * pow_var
        var2 = x + var1
        var3 = math_var * var2  # @jit.typing: () -> tensor_type[float32]
        tanh_var = ms.ops.tanh(var3)
        return 0.5 * x * (1 + tanh_var)

    # Variant 3: identical to variant 2 but without any annotation (baseline).
    @ms.jit
    def gelu_forward_3(x):
        math_var = math.sqrt(2 / math.pi)
        pow_var = ms.ops.pow(x, 3)
        var1 = 0.044715 * pow_var
        var2 = x + var1
        var3 = math_var * var2  # No @jit.typing
        tanh_var = ms.ops.tanh(var3)
        return 0.5 * x * (1 + tanh_var)

    x = ms.Tensor(9, dtype=ms.float32)
    res1 = gelu_forward_1(x)
    res2 = gelu_forward_2(x)
    res3 = gelu_forward_3(x)
    # Exact equality: adding or moving the annotation must not change the numeric result.
    assert np.all(res1.asnumpy() == res2.asnumpy())
    assert np.all(res1.asnumpy() == res3.asnumpy())
493
494
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_save():
    """
    Feature: Fallback runtime.
    Description: Use numpy.save() inside a jit function.
    Expectation: No error.
    """
    @ms.jit
    def func(x):
        if isinstance(x, ms.Tensor):
            np.save("x_data.npy", x.asnumpy())

    # Drop a stale file left by an earlier failed run; otherwise the load below
    # could pass even if func() stopped saving.
    if os.path.exists("x_data.npy"):
        os.remove("x_data.npy")
    x = ms.Tensor([1, 2, 3])
    try:
        func(x)
        x_load = np.load("x_data.npy")
        assert np.all(x_load == x.asnumpy())
    finally:
        # Clean up even when the call or the assertion fails.
        if os.path.exists("x_data.npy"):
            os.remove("x_data.npy")
516
517
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_save_with_args():
    """
    Feature: Fallback runtime.
    Description: Test numpy.save() and isolated side effect for top args.
    Expectation: No error.
    """
    def _save_tensor(data):
        np.save("data_from_args.npy", data.asnumpy())

    class NpSaveWithArgsNet(nn.Cell):
        def construct(self, *args):
            x = args[0]
            _save_tensor(x)
            return x

    # Drop a stale file from an earlier failed run so the check cannot pass on old data.
    if os.path.exists("data_from_args.npy"):
        os.remove("data_from_args.npy")
    x = ms.Tensor(np.array([-0.5962, 0.4985, 0.2349, -0.4396, 0.4525]), ms.float32)
    net = NpSaveWithArgsNet()
    try:
        output = net(x)
        print(f'output: {output}')
        x_load = np.load("data_from_args.npy")
        assert np.all(x_load == x.asnumpy())
    finally:
        # Clean up even when the call or the assertion fails.
        if os.path.exists("data_from_args.npy"):
            os.remove("data_from_args.npy")
545
546
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_save_with_call_kw1():
    """
    Feature: Fallback runtime.
    Description: Test numpy.save() and isolated side effect for calling kw function.
    Expectation: No error.
    """
    def _save_tensor(data):
        np.save("data_from_kw.npy", data.asnumpy())

    class NpSaveWithCallKw(nn.Cell):
        def construct(self, x):
            _save_tensor(data=x)
            return x

    # Drop a stale file from an earlier failed run so the check cannot pass on old data.
    if os.path.exists("data_from_kw.npy"):
        os.remove("data_from_kw.npy")
    x = ms.Tensor(np.array([-0.5962, 0.4985, 0.2349, -0.4396, 0.4525]), ms.float32)
    net = NpSaveWithCallKw()
    try:
        output = net(x)
        print(f'output: {output}')
        x_load = np.load("data_from_kw.npy")
        assert np.all(x_load == x.asnumpy())
    finally:
        # Clean up even when the call or the assertion fails.
        if os.path.exists("data_from_kw.npy"):
            os.remove("data_from_kw.npy")
573
574
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_save_with_call_kw2():
    """
    Feature: Fallback runtime.
    Description: Test numpy.save() and isolated side effect for calling kw function with if.
    Expectation: No error.
    """
    def _save_tensor_with_if(data):
        if True:  # pylint: disable=using-constant-test
            np.save("data_from_kw_with_if.npy", data.asnumpy())

    class NpSaveWithCallKw(nn.Cell):
        def construct(self, x):
            _save_tensor_with_if(data=x)
            return x

    # Drop a stale file from an earlier failed run so the check cannot pass on old data.
    if os.path.exists("data_from_kw_with_if.npy"):
        os.remove("data_from_kw_with_if.npy")
    x = ms.Tensor(np.array([-0.5962, 0.4985, 0.2349, -0.4396, 0.4525]), ms.float32)
    net = NpSaveWithCallKw()
    try:
        output = net(x)
        print(f'output: {output}')
        x_load = np.load("data_from_kw_with_if.npy")
        assert np.all(x_load == x.asnumpy())
    finally:
        # Clean up even when the call or the assertion fails.
        if os.path.exists("data_from_kw_with_if.npy"):
            os.remove("data_from_kw_with_if.npy")
602
603
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pyexecute_raise_error_with_dynamic_length_sequence():
    """
    Feature: Fallback runtime.
    Description: A helper called from construct compares shape tuples of a tensor whose shape
        depends on the data (nonzero + gather_nd) and raises IndexError on mismatch. The shapes
        match here, so the raise branch is not taken.
    Expectation: No error.
    """
    def _check_dim_shape_valid(data, tensor_index):
        if data.shape[:tensor_index.ndim] != tensor_index.shape[:]:
            raise IndexError(f"The shape of index {tensor_index.shape} does not match the shape "
                             f"of the indexed data {data.shape}")

    class InnerNet(nn.Cell):
        def construct(self, x):
            idx1 = Tensor([[True, False], [False, True], [True, True]])
            idx2 = Tensor([True, True, True, False])
            indices = idx1.nonzero()
            x1 = ops.gather_nd(x, indices)
            _check_dim_shape_valid(x1, idx2)
            return x1

    net = InnerNet()
    input_x = Tensor(np.arange(6).reshape(3, 2).astype(np.float32))
    ret = net(input_x)
    # idx1 has four True entries, so x1 has four elements, matching idx2's shape (4,).
    assert np.allclose(ret.asnumpy(), np.array([0.0, 3.0, 4.0, 5.0]))
633
634
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pyexecute_raise_error_with_dynamic_length_sequence_2():
    """
    Feature: Fallback runtime.
    Description: Same setup as the non-raising variant, but the helper's shape comparison is
        inverted, so matching shapes trigger the IndexError raised from inside construct.
    Expectation: IndexError with a "does not match the shape" message.
    """
    def _check_dim_shape_valid(data, tensor_index):
        if data.shape[:tensor_index.ndim] == tensor_index.shape[:]:
            raise IndexError(f"The shape of index {tensor_index.shape} does not match the shape "
                             f"of the indexed data {data.shape}")

    class InnerNet(nn.Cell):
        def construct(self, x):
            idx1 = Tensor([[True, False], [False, True], [True, True]])
            idx2 = Tensor([True, True, True, False])
            indices = idx1.nonzero()
            x1 = ops.gather_nd(x, indices)
            _check_dim_shape_valid(x1, idx2)
            return x1

    # Build inputs outside pytest.raises so only the network call can satisfy the check.
    net = InnerNet()
    input_x = Tensor(np.arange(6).reshape(3, 2).astype(np.float32))
    with pytest.raises(IndexError) as err:
        net(input_x)
    assert "does not match the shape" in str(err.value)
665
666
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parse_subscript():
    """
    Feature: JIT Fallback
    Description: Index a Python list of Tensors with Tensor subscripts (interpreted nodes) in graph mode.
    Expectation: No exception.
    """

    class Network(nn.Cell):
        def construct(self):
            x = [Tensor([11]), Tensor([22]), Tensor([33])]
            y = x[Tensor([0])] + x[Tensor([1])] + x[Tensor([2])]
            return y

    # 11 + 22 + 33 == 66
    total = Network()()
    assert total.asnumpy() == 66
688
689
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_tensor_func():
    """
    Feature: JIT Fallback
    Description: The tensor() function applied to asnumpy() results, in a jit function and in a Cell.
    Expectation: No exception.
    """
    ms.set_context(mode=ms.GRAPH_MODE)

    @jit
    def func(x):
        x = tensor(x.asnumpy(), x.dtype)
        return ops.Abs()(x)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.abs = ops.Abs()

        def construct(self, x, y):
            y1 = tensor(x.asnumpy() + y.asnumpy(), dtype=ms.float32)
            return self.abs(y1)

    # Case 1: rebuilding a float32 tensor from numpy keeps float32; Abs gives |x|.
    signed_input = Tensor([-1, 1, 2, -2], dtype=ms.float32)
    expected_abs = np.array([1, 1, 2, 2], dtype=np.float32)
    abs_out = func(signed_input).asnumpy()
    assert abs_out.dtype == expected_abs.dtype
    assert np.allclose(abs_out, expected_abs)

    # Case 2: int32 -1 plus float32 -1 via numpy is -2, forced to float32, then abs -> 2.
    net_out = Net()(ms.Tensor(-1, dtype=ms.int32), ms.Tensor(-1, dtype=ms.float32)).asnumpy()
    expected_net = np.array(2, dtype=np.float32)
    assert np.allclose(net_out, expected_net)
    assert net_out.dtype == expected_net.dtype
730
731
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fallback_self_variable_as_func_args():
    """
    Feature: JIT Fallback
    Description: construct calls a method through the class, passing `self` as an explicit argument.
    Expectation: No exception
    """

    class Network(nn.Cell):
        def __init__(self):
            super().__init__()
            self.value = 5

        def construct(self, x, y):
            return Network.func(self, x, y)

        def func(self, x, y):
            return x + y

    net = Network()
    lhs = Tensor([1], dtype=ms.float32)
    rhs = Tensor([2], dtype=ms.float32)
    # 1 + 2 == 3
    assert np.allclose(net(lhs, rhs).asnumpy(), np.array([3], dtype=np.float32))
756
757
@pytest.mark.skip(reason="ParseCall convert whole script to pyexecute")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fallback_tensor_with_variable_input():
    """
    Feature: JIT Fallback
    Description: Generate Tensor with graph, using a dtype taken from a variable input,
        under MS_DEV_JIT_SYNTAX_LEVEL=0.
    Expectation: No exception
    """

    @jit
    def foo(x):
        return Tensor([0], dtype=x.dtype)

    # Save the prior value and restore it in `finally`, so a failure here cannot leak
    # syntax level 0 into later tests (the old code also forced '2' rather than restoring).
    saved_level = os.environ.get('MS_DEV_JIT_SYNTAX_LEVEL')
    os.environ['MS_DEV_JIT_SYNTAX_LEVEL'] = '0'
    try:
        ret = foo(Tensor([1, 2, 3]))
        assert ret == Tensor([0])
    finally:
        if saved_level is None:
            os.environ.pop('MS_DEV_JIT_SYNTAX_LEVEL', None)
        else:
            os.environ['MS_DEV_JIT_SYNTAX_LEVEL'] = saved_level
777
778
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fallback_map_with_variable_input():
    """
    Feature: JIT Fallback
    Description: map() with a lambda over asnumpy() results of variable Tensor inputs inside jit.
    Expectation: No exception
    """

    @jit
    def foo(x, y):
        m = map(lambda a, b: a + b, x.asnumpy(), y.asnumpy())
        return tuple(m)

    # Element-wise sums: 1 + 4, 2 + 5, 3 + 6.
    assert foo(Tensor([1, 2, 3]), Tensor([4, 5, 6])) == (5, 7, 9)
796