# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests that dumped IR files record the Python source location of ms_function code."""
import glob
import inspect
import os
import shutil
import subprocess
import numpy as np
import mindspore as ms
from mindspore import nn
from mindspore import context
from mindspore import ms_function
from mindspore import Tensor
from tests.security_utils import security_off_wrap


def find_files(file, para):
    """Count the lines matching a grep pattern across one or more files.

    Args:
        file: A file path or shell-style glob (e.g. "./ir_dump_path/*validate*.ir").
        para: A grep basic regular expression to search for.

    Returns:
        str: Total number of matching lines over every file the glob matches,
        as a decimal string (e.g. '2'); '0' when the glob matches no file.

    Raises:
        RuntimeError: If grep reports an error (exit status greater than 1).
    """
    # Expand the glob in Python and pass an argument list to grep rather than
    # interpolating into a shell command, so quotes or shell metacharacters in
    # `para` or `file` can neither break nor inject into the command line.
    paths = sorted(glob.glob(file))
    if not paths:
        # With no file operands grep would read stdin and block.
        return '0'
    # -c: count matching lines per file; -h: omit file names from that output;
    # --: stop option parsing so a pattern starting with '-' is not an option.
    proc = subprocess.run(["grep", "-c", "-h", "--", para] + paths,
                          stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False)
    # grep exits 0 (some match) or 1 (no match); anything higher is an error.
    if proc.returncode > 1:
        raise RuntimeError("grep failed on %s: %s" % (file, str(proc.stderr, 'utf-8').strip()))
    total = sum(int(count) for count in str(proc.stdout, 'utf-8').split())
    return str(total)


def remove_path(path):
    """Recursively delete directory `path` if it exists."""
    if os.path.exists(path):
        shutil.rmtree(path)


@security_off_wrap
def test_ms_function():
    """The validate-stage IR of an ms_function records the line of its body statement."""
    remove_path("./ir_dump_path/")  # avoid counting stale dumps from earlier runs
    # Derive the expected line number from this statement's own line (the
    # `return` below is exactly 4 lines further down) instead of hard-coding it,
    # so edits elsewhere in the file cannot silently break the check.
    return_line = inspect.currentframe().f_lineno + 4

    @ms_function
    def add(x):
        return x + 1

    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(save_graphs=True, save_graphs_path="ir_dump_path")
    try:
        input1 = np.random.randn(5, 5)
        add(Tensor(input1, ms.float32))
        pattern = "%s(%d)/ return x + 1/" % (os.path.basename(__file__), return_line)
        result = find_files("./ir_dump_path/*validate*.ir", pattern)
        assert result == '2'
    finally:
        # Runs even when the assertion fails: turn graph saving back off and
        # delete the dump so later tests do not inherit either.
        context.set_context(save_graphs=False)
        remove_path("./ir_dump_path/")


@security_off_wrap
def test_cell_ms_function():
    """The validate-stage IR of an ms_function Cell.construct records its body line."""
    remove_path("./ir_dump_path/")  # avoid counting stale dumps from earlier runs
    # The `return x` inside construct is exactly 5 lines below this statement.
    return_line = inspect.currentframe().f_lineno + 5

    class Net(nn.Cell):
        @ms_function
        def construct(self, x):
            return x

    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(save_graphs=True, save_graphs_path="ir_dump_path")
    try:
        input1 = np.random.randn(5, 5)
        net = Net()
        net(Tensor(input1, ms.float32))
        pattern = "%s(%d)/ return x/" % (os.path.basename(__file__), return_line)
        result = find_files("./ir_dump_path/*validate*.ir", pattern)
        assert result == '1'
    finally:
        # Runs even when the assertion fails: turn graph saving back off and
        # delete the dump so later tests do not inherit either.
        context.set_context(save_graphs=False)
        remove_path("./ir_dump_path/")