• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test_parse_numpy """
16import os
17import shutil
18import subprocess
19import numpy as np
20import mindspore as ms
21from mindspore import nn
22from mindspore import context
23from mindspore import ms_function
24from mindspore import Tensor
25from tests.security_utils import security_off_wrap
26
27
def find_files(file, para):
    """Count the lines containing ``para`` across every file matching a glob.

    Args:
        file: Glob pattern selecting the files to search,
            e.g. ``"./ir_dump_path/*validate*.ir"``.
        para: Text to look for. It is matched literally (plain substring),
            not as a regular expression.

    Returns:
        The total number of matching lines as a decimal string; ``"0"`` when
        the glob matches no file.
    """
    import glob

    # Done in Python rather than `grep '<para>' <file> | wc -l` via a shell:
    # grep read `para` as a regex (so the `.` in "foo.py" matched any char),
    # and the single-quoted shell interpolation broke on patterns with quotes.
    count = 0
    for path in glob.glob(file):
        if not os.path.isfile(path):
            continue  # a glob hit on a directory has no lines to count
        with open(path, encoding="utf-8", errors="replace") as handle:
            count += sum(1 for line in handle if para in line)
    return str(count)
34
35
def remove_path(path):
    """Delete the directory tree rooted at ``path``; do nothing if it is absent."""
    if not os.path.exists(path):
        return
    shutil.rmtree(path)
39
40
@security_off_wrap
def test_ms_function():
    """Check that a standalone ``ms_function`` leaves source debug info in its IR.

    Compiles ``add`` in graph mode with graph dumping enabled. It then expects
    exactly two lines across the dumped ``*validate*.ir`` files that cite this
    file and the line of ``return x + 1``.
    """
    import inspect

    ir_dir = "./ir_dump_path/"
    # Leftover dumps from an earlier (failed) run would inflate the count.
    remove_path(ir_dir)

    # The debug info cites the absolute source line of `return x + 1`, which is
    # 4 lines below this statement. Deriving it (and the file name) at runtime
    # keeps the test valid when lines are added or removed above.
    return_line = inspect.currentframe().f_lineno + 4

    @ms_function
    def add(x):
        return x + 1

    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(save_graphs=True, save_graphs_path="ir_dump_path")
    try:
        input1 = np.random.randn(5, 5)
        add(Tensor(input1, ms.float32))
        pattern = f"{os.path.basename(__file__)}({return_line})/        return x + 1/"
        result = find_files(ir_dir + "*validate*.ir", pattern)
        assert result == '2'
    finally:
        # Always stop dumping and remove the dumps, even when the assert fails,
        # so neither leaks into later tests.
        context.set_context(save_graphs=False)
        remove_path(ir_dir)
54
55
@security_off_wrap
def test_cell_ms_function():
    """Check that a Cell whose ``construct`` is an ``ms_function`` keeps debug info.

    Runs ``Net`` in graph mode with graph dumping enabled. It then expects
    exactly one line across the dumped ``*validate*.ir`` files that cites this
    file and the line of ``return x`` inside ``construct``.
    """
    import inspect

    ir_dir = "./ir_dump_path/"
    # Leftover dumps from an earlier (failed) run would inflate the count.
    remove_path(ir_dir)

    # The debug info cites the absolute source line of `return x`, which is
    # 5 lines below this statement. Deriving it (and the file name) at runtime
    # keeps the test valid when lines are added or removed above.
    return_line = inspect.currentframe().f_lineno + 5

    class Net(nn.Cell):
        @ms_function
        def construct(self, x):
            return x

    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(save_graphs=True, save_graphs_path="ir_dump_path")
    try:
        input1 = np.random.randn(5, 5)
        net = Net()
        net(Tensor(input1, ms.float32))
        pattern = f"{os.path.basename(__file__)}({return_line})/            return x/"
        result = find_files(ir_dir + "*validate*.ir", pattern)
        assert result == '1'
    finally:
        # Always stop dumping and remove the dumps, even when the assert fails,
        # so neither leaks into later tests.
        context.set_context(save_graphs=False)
        remove_path(ir_dir)
72