# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test FOR_ITER handling (loop unrolling) for PIJit."""
import pytest
import dis
import sys
from mindspore import jit, Tensor
from mindspore._c_expression import get_code_extra


def for_range(x):
    """Sum the integers ``0 .. x-1`` with a ``for ... in range(x)`` loop."""
    res = 0
    for i in range(x):
        res = res + i
    return res


def for_enumerate(x):
    """Accumulate indices and values while iterating ``enumerate`` over ``[x, x, x]``.

    Returns the accumulated ``res`` so the result of the unrolled loop body is
    actually compared against the PIJit-compiled version. Returning the list
    ``x`` instead would discard ``res`` and leave the loop body unchecked.
    """
    x = [x, x, x]
    res = 0
    for i, v in enumerate(x):
        res = res + i
        res = res + v
    return res


def for_zip(x):
    """Iterate ``zip`` over four copies of ``[x, x, x]`` and return the last tuple.

    Returns ``None`` if the zip is empty (it never is here).
    """
    x = [x, x, x]
    v = None
    for v in zip(x, x, x, x):
        pass
    return v


def for_mix(x):
    """Iterate ``enumerate`` over a materialized ``zip``.

    Sums each index and the first element of each zipped tuple.
    """
    x = [x, x, x]
    res = 0
    for i, v in enumerate(list(zip(x, x, x, x))):
        res = res + i
        res = res + v[0]
    return res


def for_mix_with_sideeffect(x):
    """Partially consume a ``zip`` iterator in a loop, then drain what is left.

    ``zip`` over a single iterable yields 1-tuples, so each ``i`` is
    ``((index, value),)`` and the index lives at ``i[0][0]``. Comparing
    ``i[0]`` (a tuple) with ``1`` would always be False, which would make the
    ``break`` unreachable and the "items left after break" side effect
    untested.
    """
    x = [x, x, x]
    z = zip(list(enumerate(x)))
    for i in z:
        if i[0][0] == 1:
            break
    # `z` is a one-shot iterator shared with the loop above: after breaking
    # at index 1, only the element with index 2 remains to be collected.
    return list(z)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('func', [for_range, for_enumerate, for_mix, for_zip])
@pytest.mark.parametrize('param', [1, Tensor([1])])
def test_for_iter_unrolling(func, param):
    """
    Feature: Test loop unrolling
    Description: Run simple for-loops (range/enumerate/zip and their mixes)
        under PIJit with loop unrolling enabled and compare with pure Python.
    Expectation: The code compiles to a callable graph, is called at least
        once, and returns the same result as the Python run.
    """
    config = {"loop_unrolling": True}
    expected = func(param)
    result = jit(fn=func, mode="PIJit", jit_config=config)(param)
    jcr = get_code_extra(func)

    assert jcr["stat"] == "GRAPH_CALLABLE"
    assert jcr["code"]["call_count_"] > 0
    assert expected == result


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('func', [for_mix_with_sideeffect])
@pytest.mark.parametrize('param', [1, Tensor([1])])
def test_not_implement_for_iter(func, param):
    """
    Feature: Test loop unrolling
    Description: Run a loop that breaks early out of a shared iterator and
        then consumes the remainder after the loop (iterator side effect)
        under PIJit with loop unrolling enabled.
    Expectation: The code compiles to a callable graph, is called at least
        once, and returns the same result as the Python run.
    """
    config = {"loop_unrolling": True}
    expected = func(param)
    result = jit(fn=func, mode="PIJit", jit_config=config)(param)
    jcr = get_code_extra(func)

    assert jcr["stat"] == "GRAPH_CALLABLE"
    assert jcr["code"]["call_count_"] > 0
    assert expected == result