1 /** 2 * Copyright 2019-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_RESOLVE_H_ 18 #define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_RESOLVE_H_ 19 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "ir/anf.h" 25 #include "ir/manager.h" 26 #include "include/common/utils/python_adapter.h" 27 #include "pipeline/jit/ps/parse/parse_base.h" 28 #include "abstract/abstract_value.h" 29 #include "utils/log_adapter.h" 30 31 // forward declaration of ResourceBase 32 namespace mindspore { 33 namespace pipeline { 34 class ResourceBase; 35 using ResourceBasePtr = std::shared_ptr<ResourceBase>; 36 } // namespace pipeline 37 } // namespace mindspore 38 39 namespace mindspore { 40 namespace parse { 41 // NameSpace class for resolving python code. 42 class NameSpace final : public Named { 43 public: 44 NameSpace(const std::string &module, const py::object &namespace_obj, const py::object &module_obj = py::none()) 45 : Named(module + ": \'" + std::string(py::str(namespace_obj)) + "\'"), 46 module_(module), 47 namespace_obj_(namespace_obj), 48 module_obj_(module_obj) {} ~NameSpace()49 ~NameSpace() override { 50 py::gil_scoped_acquire gil_acquire; 51 namespace_obj_ = py::none(); 52 module_obj_ = py::none(); 53 } 54 MS_DECLARE_PARENT(NameSpace, Named); 55 namespace_obj()56 const py::object &namespace_obj() const { return namespace_obj_; } module_obj()57 const py::object &module_obj() const { return module_obj_; } set_module_obj(py::object module_obj)58 void set_module_obj(py::object module_obj) { module_obj_ = module_obj; } module()59 const std::string &module() const { return module_; } ToAbstract()60 abstract::AbstractBasePtr ToAbstract() override { 61 return std::make_shared<abstract::AbstractScalar>(shared_from_base<NameSpace>(), std::make_shared<External>()); 62 } 63 64 private: 65 // namespace of the module 66 std::string module_; 67 // namespace object 68 py::object namespace_obj_; 69 // module object 70 py::object module_obj_; 71 }; 72 using NameSpacePtr = std::shared_ptr<NameSpace>; 73 74 // Symbol in NameSpace or Class which shall be resolved. 75 class Symbol final : public Named { 76 public: Symbol(const std::string & symbol)77 explicit Symbol(const std::string &symbol) : Named(symbol), symbol_(symbol) {} Symbol(const std::string & symbol,const std::string & name)78 Symbol(const std::string &symbol, const std::string &name) : Named(name), symbol_(symbol) {} 79 80 ~Symbol() override = default; 81 MS_DECLARE_PARENT(Symbol, Named); 82 symbol()83 const std::string &symbol() const { return symbol_; } ToAbstract()84 abstract::AbstractBasePtr ToAbstract() override { 85 return std::make_shared<abstract::AbstractScalar>(shared_from_base<Symbol>(), std::make_shared<External>()); 86 } 87 88 private: 89 std::string symbol_; 90 }; 91 using SymbolPtr = std::shared_ptr<Symbol>; 92 93 class Script final : public Named { 94 public: Script(const std::string & script)95 explicit Script(const std::string &script) : Named(script), script_(script) {} Script(const std::string & script,const std::string & name)96 Script(const std::string &script, const std::string &name) : Named(name), script_(script) {} 97 98 ~Script() override = default; 99 MS_DECLARE_PARENT(Script, Named); 100 script()101 std::string script() const { return script_; } ToAbstract()102 abstract::AbstractBasePtr ToAbstract() override { 103 return std::make_shared<abstract::AbstractScript>(shared_from_base<Script>()); 104 } ToString()105 std::string ToString() const override { return "\'" + name() + "\'"; } 106 107 private: 108 std::string script_; 109 }; 110 using ScriptPtr = std::shared_ptr<Script>; 111 112 // PyObjectWrapper class wrappers resolved python object for further processing. 113 class PyObjectWrapper : public Named { 114 public: 115 explicit PyObjectWrapper(const py::object &obj, const std::string &name = "Python object") Named(name)116 : Named(name), obj_(std::make_unique<py::object>(obj)) {} ~PyObjectWrapper()117 ~PyObjectWrapper() { 118 py::gil_scoped_acquire acquire_gil; 119 obj_ = nullptr; 120 } 121 122 MS_DECLARE_PARENT(PyObjectWrapper, Named); obj()123 py::object obj() const { return *obj_; } 124 hash()125 std::size_t hash() const override { return tid(); } 126 127 virtual bool operator==(const PyObjectWrapper &other) const { 128 if (obj().get_type() != other.obj().get_type()) { 129 return false; 130 } 131 try { 132 return obj().equal(other.obj()); catch(const std::exception & e)133 } catch (const std::exception &e) { 134 // Return false if the comparison is ambiguous. Such as numpy.array. 135 MS_LOG(INFO) << e.what() << "\n" 136 << "This: {" << py::str(obj()) << ", " << py::str(obj().get_type()) << "}, Other: {" 137 << py::str(other.obj()) << ", " << py::str(other.obj().get_type()) << "}"; 138 return false; 139 } 140 } 141 bool operator==(const Named &other) const override { 142 if (other.isa<PyObjectWrapper>()) { 143 auto &other_py_obj = static_cast<const PyObjectWrapper &>(other); 144 return *this == other_py_obj; 145 } 146 return false; 147 } 148 149 private: 150 // The object that needs to be resolved 151 std::unique_ptr<py::object> obj_; 152 }; 153 using PyObjectWrapperPtr = std::shared_ptr<PyObjectWrapper>; 154 155 // InterpretedObject class wrappers interpreted python object. 156 class InterpretedObject final : public PyObjectWrapper { 157 public: 158 explicit InterpretedObject(const py::object &obj); 159 ~InterpretedObject() override = default; 160 MS_DECLARE_PARENT(InterpretedObject, PyObjectWrapper); ToAbstract()161 abstract::AbstractBasePtr ToAbstract() override { 162 return std::make_shared<abstract::AbstractScalar>(shared_from_base<InterpretedObject>(), 163 std::make_shared<External>()); 164 } set_has_converted(bool has_converted)165 void set_has_converted(bool has_converted) { has_converted_ = has_converted; } has_converted()166 bool has_converted() const { return has_converted_; } 167 168 private: 169 bool has_converted_ = false; 170 }; 171 using InterpretedObjectPtr = std::shared_ptr<InterpretedObject>; 172 173 class MsClassObject final : public PyObjectWrapper { 174 public: MsClassObject(const py::object & obj,const std::string & name)175 explicit MsClassObject(const py::object &obj, const std::string &name) 176 : PyObjectWrapper(obj, "MsClassObject: \'" + name + "\'") {} 177 ~MsClassObject() override = default; 178 MS_DECLARE_PARENT(MsClassObject, PyObjectWrapper); 179 abstract::AbstractBasePtr ToAbstract() override; 180 }; 181 using MsClassObjectPtr = std::shared_ptr<MsClassObject>; 182 183 // ClassType class wrappers class name in python 184 class ClassType final : public PyObjectWrapper { 185 public: 186 explicit ClassType(const py::object &obj, const std::string &name = "Python class type") PyObjectWrapper(obj,name)187 : PyObjectWrapper(obj, name) {} 188 ~ClassType() override = default; 189 MS_DECLARE_PARENT(ClassType, PyObjectWrapper); 190 abstract::AbstractBasePtr ToAbstract() override; 191 }; 192 using ClassTypePtr = std::shared_ptr<ClassType>; 193 194 // Resolve symbol in namespace. 195 AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, 196 const AnfNodePtr &node); 197 AnfNodePtr ResolveSymbolWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &object_node, 198 const AnfNodePtr &attr_node, const AnfNodePtr &node); 199 AnfNodePtr ResolveGetItemWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &getitem_node, 200 const AnfNodePtr &attr_node, const AnfNodePtr &node); 201 AnfNodePtr ResolveClassObjectWithAttr(const py::object &cls_obj, const AnfNodePtr &attr, const AnfNodePtr &node); 202 203 AnfNodePtr ResolveInterpretedObjectOfSetAttr(const AnfNodePtr &target_node, const AnfNodePtr &attr_node, 204 const AnfNodePtr &value_node); 205 206 AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj); 207 // Check if node is cnode with getitem. 208 bool IsGetItemCNode(const AnfNodePtr &node); 209 210 // Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager(). 211 bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true); 212 213 // Resolve all graphs in manager which is defined outside of pipeline::Resource. 214 // Mainly used for test cases or resolve graphs which will not be managed by manager. 215 bool ResolveAll(const FuncGraphManagerPtr &manager); 216 217 py::object GetSymbolObject(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node); 218 bool ResolveObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, AnfNodePtr *const node, 219 bool is_element_obj = false); 220 ValuePtr GetParameterValue(const py::object ¶m_obj); 221 } // namespace parse 222 } // namespace mindspore 223 224 #endif // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_RESOLVE_H_ 225