• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15''' test FOR_ITER for pijit '''
16import pytest
17import dis
18import sys
19from mindspore import jit, Tensor
20from mindspore._c_expression import get_code_extra
21
22
def for_range(x):
    """Sum the values produced by ``range(x)`` using an explicit ``for`` loop.

    The explicit loop (rather than ``sum``) is deliberate: it is the FOR_ITER
    pattern this module exercises under PIJit loop unrolling.
    """
    total = 0
    for step in range(x):
        total = total + step
    return total
28
29
def for_enumerate(x):
    """Accumulate indices and values while iterating ``enumerate([x, x, x])``.

    Returns the accumulated result (``0 + 1 + 2 + 3 * x`` for numeric ``x``).
    The accumulator is returned, rather than the input list, so that comparing
    the eager result with the PIJit result actually checks the unrolled loop
    body.
    """
    x = [x, x, x]
    res = 0
    for i, v in enumerate(x):
        res = res + i
        res = res + v
    # Previously returned ``x``, which discarded ``res`` entirely.
    return res
37
38
def for_zip(x):
    """Return the last tuple produced by zipping four copies of ``[x, x, x]``.

    The loop body is intentionally empty; only the FOR_ITER over ``zip`` and
    the final binding of the loop variable matter.
    """
    triple = [x, x, x]
    last = None
    for last in zip(triple, triple, triple, triple):
        pass
    return last
45
46
def for_mix(x):
    """Iterate ``enumerate`` over a materialized ``zip`` and accumulate a sum.

    Each step adds the enumerate index and the first element of the zipped
    4-tuple, so the result is ``0 + 1 + 2 + 3 * x`` for numeric ``x``.
    """
    triple = [x, x, x]
    acc = 0
    for idx, group in enumerate(list(zip(triple, triple, triple, triple))):
        acc = acc + idx
        acc = acc + group[0]
    return acc
54
55
def for_mix_with_sideeffect(x):
    """Partially consume a ``zip`` iterator, then return what is left of it.

    ``zip`` over a single iterable yields 1-tuples, so each item ``i`` is
    ``((index, value),)``. The loop breaks once the enumerate index reaches 1.
    The remaining items are then drained with ``list(z)``, so the iterator
    state left behind by the ``break`` is observable in the return value.
    """
    x = [x, x, x]
    z = zip(list(enumerate(x)))
    for i in z:
        # Compare the enumerate index (i[0][0]), not the (index, value)
        # tuple (i[0]). Comparing the tuple with 1 is always False, which
        # would make the break unreachable.
        if i[0][0] == 1:
            break
    return list(z)
63
64
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('func', [for_range, for_enumerate, for_mix, for_zip])
@pytest.mark.parametrize('param', [1, Tensor([1])])
def test_for_iter_unrolling(func, param):
    """
    Feature: PIJit loop unrolling of FOR_ITER loops.
    Description: Run func eagerly, then through PIJit with loop_unrolling enabled,
        and inspect the code extra recorded for func.
    Expectation: The code is GRAPH_CALLABLE, has been called at least once, and the
        PIJit result equals the eager result.
    """
    expected = func(param)
    compiled = jit(fn=func, mode="PIJit", jit_config={"loop_unrolling": True})
    result = compiled(param)
    code_extra = get_code_extra(func)

    assert code_extra["stat"] == "GRAPH_CALLABLE"
    assert code_extra["code"]["call_count_"] > 0
    assert expected == result
84
85
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('func', [for_mix_with_sideeffect])
@pytest.mark.parametrize('param', [1, Tensor([1])])
def test_not_implement_for_iter(func, param):
    """
    Feature: PIJit loop unrolling of FOR_ITER loops.
    Description: Run a loop over a zip/enumerate iterator that may break early and
        then reads the rest of the iterator, eagerly and through PIJit with
        loop_unrolling enabled.
    Expectation: The code is GRAPH_CALLABLE, has been called at least once, and the
        PIJit result equals the eager result.
    """
    expected = func(param)
    compiled = jit(fn=func, mode="PIJit", jit_config={"loop_unrolling": True})
    result = compiled(param)
    code_extra = get_code_extra(func)

    assert code_extra["stat"] == "GRAPH_CALLABLE"
    assert code_extra["code"]["call_count_"] > 0
    assert expected == result
105