• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Define the namespace of MindSpore op definition."""
16import os
17import sys
18import inspect
19import types
20from mindspore._extends.parse.namespace import ModuleNamespace
21from mindspore.nn import CellList, SequentialCell
22
23
# Namespaces of MindSpore modules, used to recognise MindSpore classes and
# functions by name (see is_subtree / is_functional / get_functional).
(
    _ms_common_ns,
    _ms_nn_ns,
    _ms_ops_ns,
    _ms_functional_ns,
) = (
    ModuleNamespace(ns_module)
    for ns_module in (
        'mindspore.common',
        'mindspore.nn',
        'mindspore.ops.operations',
        'mindspore.ops.functional',
    )
)

# Instances matching an entry of this list are never converted to a symbol tree.
# Only str and type entries are expected here.
_subtree_black_list = [CellList, SequentialCell]

# When truthy, MindSpore built-in cells (classes of the mindspore.nn namespace)
# are converted to symbol trees as well.
_ms_cells_to_subtree = False

# Module paths that must never be classified as third-party modules.
_ignore_third_party_paths = list()
36
def is_subtree(cls_inst):
    """Determine whether ``cls_inst`` should be converted into a symbol subtree.

    Args:
        cls_inst: The object to classify.

    Returns:
        bool: False if the instance is blacklisted in ``_subtree_black_list`` (by type, or by
        class name for string entries), or is an instance of the same-named class from the
        ``mindspore.common`` or ``mindspore.ops.operations`` namespaces;
        ``bool(_ms_cells_to_subtree)`` for instances of the same-named class from the
        ``mindspore.nn`` namespace; True otherwise.
    """
    cls_name = type(cls_inst).__name__
    # _subtree_black_list may hold both types and strings. isinstance() raises TypeError
    # when its second argument contains a str, so string entries are compared with the
    # class name and only the remaining entries are passed to isinstance().
    # NOTE(review): string entries are assumed to be class names -- confirm with the code
    # that extends the blacklist.
    black_list_names = [item for item in _subtree_black_list if isinstance(item, str)]
    black_list_types = tuple(item for item in _subtree_black_list if not isinstance(item, str))
    if cls_name in black_list_names or isinstance(cls_inst, black_list_types):
        return False
    if cls_name in _ms_common_ns and isinstance(cls_inst, _ms_common_ns[cls_name]):
        return False
    if cls_name in _ms_nn_ns and isinstance(cls_inst, _ms_nn_ns[cls_name]):
        # Built-in nn cells are only expanded when explicitly enabled.
        return bool(_ms_cells_to_subtree)
    if cls_name in _ms_ops_ns and isinstance(cls_inst, _ms_ops_ns[cls_name]):
        return False
    return True
49
50
def is_ms_function(func_obj):
    """Determine whether ``func_obj`` is defined inside the imported ``mindspore`` package.

    The object's source file (resolved with ``inspect.getabsfile``) is compared with the
    directory that contains ``mindspore.__file__``.

    Args:
        func_obj: Any object; functions, methods, classes, modules, frames and code objects
            can be resolved to a file.

    Returns:
        bool: True if the object's file lies under the mindspore package directory. False for
        builtin functions, for objects whose file cannot be determined, and when mindspore
        is not imported or has no usable ``__file__``.
    """
    if isinstance(func_obj, types.BuiltinFunctionType):
        return False
    try:
        # module, class, method, function, traceback, frame, or code object was expected
        func_file = inspect.getabsfile(func_obj)
    except TypeError:
        return False
    func_file = os.path.normcase(func_file)
    ms_module = sys.modules.get('mindspore')
    # getattr on None also yields None, so this covers "mindspore not imported" too.
    ms_file = getattr(ms_module, '__file__', None)
    if not ms_file:
        return False
    ms_dir = os.path.normcase(os.path.dirname(os.path.abspath(ms_file)))
    # Append a separator before the prefix test so that a sibling directory sharing the
    # prefix (e.g. ``mindspore_foo`` next to ``mindspore``) is not mistaken for mindspore.
    return func_file.startswith(os.path.join(ms_dir, ''))
68
69
def is_functional(func_name):
    """Determine whether ``func_name`` is defined in the ``mindspore.ops.functional`` namespace.

    Args:
        func_name: Name of the function to look up.

    Returns:
        bool: True if the name is present in the functional namespace.
    """
    return func_name in _ms_functional_ns
73
74
def get_functional(func_name):
    """Look up ``func_name`` in the ``mindspore.ops.functional`` namespace.

    Args:
        func_name: Name of the function to fetch.

    Returns:
        The object registered under ``func_name`` in the functional namespace, or None
        when the name is not present.
    """
    return _ms_functional_ns[func_name] if func_name in _ms_functional_ns else None
80
81
def _is_path_under(path, directory):
    """Return True if ``path`` is ``directory`` itself or lies somewhere beneath it.

    A separator is appended to ``directory`` before the prefix test so that, for example,
    ``/work/proj2/a.py`` is not considered to be under ``/work/proj``.
    """
    return path == directory or path.startswith(os.path.join(directory, ''))


def is_third_party(func_obj):
    """Check whether ``func_obj`` comes from a third-party module.

    The decision is based on the file of the module that defines ``func_obj``; the rules
    below are applied in order and the first match wins:

    1. No usable module ``__file__`` -> third party.
    2. Under a path listed in ``_ignore_third_party_paths`` -> not third party.
    3. Under the directory containing the ``os`` module (the Python library dir) -> third party.
    4. Under the top-level package directory containing the current working directory
       (the user workspace) -> not third party.
    5. Path contains a ``site-packages`` or ``dist-packages`` component -> third party.
    6. Anything else -> not third party.

    Args:
        func_obj: The object whose defining module is inspected.

    Returns:
        bool: True if the object is considered to come from a third-party module.
    """
    module = inspect.getmodule(func_obj)
    # A module without a __file__ attribute (normally a c++ lib) is considered to be a third
    # party module. getmodule() may also return None and some modules set __file__ to None;
    # both are treated the same way instead of crashing in os.path.abspath().
    module_file = getattr(module, '__file__', None)
    if not module_file:
        return True
    module_path = os.path.abspath(module_file)
    for path in _ignore_third_party_paths:
        # Normalize registered paths so relative or trailing-slash entries still match.
        if _is_path_under(module_path, os.path.abspath(path)):
            return False
    # Python builtin modules are treated as third-party libraries.
    python_builtin_dir = os.path.abspath(os.path.dirname(os.__file__))
    if _is_path_under(module_path, python_builtin_dir):
        return True
    # Check if module is under user workspace directory.
    # NOTE(review): if the workspace resolves to the filesystem root (e.g. cwd is "/"),
    # every module matches here, including ones installed in site-packages.
    user_workspace_dir = get_top_level_module_path(os.getcwd())
    if _is_path_under(module_path, user_workspace_dir):
        return False
    # Third-party modules live under site-packages, or dist-packages on Debian-based
    # distributions.
    split_path = module_path.split(os.path.sep)
    return "site-packages" in split_path or "dist-packages" in split_path
105
106
def get_top_level_module_path(module_path):
    """Return the top-most package directory that encloses ``module_path``.

    Starting from ``module_path`` (made absolute), walk upwards for as long as the parent
    directory contains an ``__init__.py`` file, and return the last directory reached. That
    is the first directory, ``module_path`` itself included, whose parent is not a regular
    package. ``module_path`` itself does not need to contain ``__init__.py``. The filesystem
    root is returned unchanged.

    Args:
        module_path: A directory path, absolute or relative to the current directory.

    Returns:
        str: Absolute path of the top-level package directory (or of ``module_path`` itself
        when its parent is not a package).
    """
    current = os.path.abspath(module_path)
    while True:
        parent = os.path.abspath(os.path.dirname(current))
        # Stop at the filesystem root, or when the parent is not a package. If the
        # path does not exist or is accessed without permission, os.path.isfile
        # returns False.
        if parent == current or not os.path.isfile(os.path.join(parent, '__init__.py')):
            return current
        current = parent
119