1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2022 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CORE_IR_META_FUNC_GRAPH_H_ 20 #define MINDSPORE_CORE_IR_META_FUNC_GRAPH_H_ 21 22 #include <string> 23 #include <map> 24 #include <memory> 25 #include <vector> 26 #include <algorithm> 27 #include "ir/dtype.h" 28 #include "ir/anf.h" 29 #include "ir/func_graph.h" 30 #include "ir/signature.h" 31 #include "abstract/abstract_value.h" 32 33 namespace mindspore { 34 // namespace to support intermediate representation definition 35 // Graph generator. 36 // Can be called with a pipeline's resources and a list of argument types to 37 // generate a graph corresponding to these types. 38 class MS_CORE_API MetaFuncGraph : public FuncGraphBase { 39 public: MetaFuncGraph(const std::string & name)40 explicit MetaFuncGraph(const std::string &name) : name_(name) { 41 cache_.clear(); 42 debug_info_ = std::make_shared<DebugInfo>(); 43 } 44 ~MetaFuncGraph()45 ~MetaFuncGraph() { subclass_destruct_flag_ = true; } 46 47 MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase); 48 // Return normalized versions of the arguments. 49 // By default, this returns args unchanged. NormalizeArgs(const abstract::AbstractBasePtrList & args_abs_list)50 virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_abs_list) const { 51 return args_abs_list; 52 } 53 abstract::AbstractBasePtr ToAbstract() override; signatures()54 const std::vector<Signature> &signatures() const { return signatures_; } set_signatures(const std::vector<Signature> & signatures)55 void set_signatures(const std::vector<Signature> &signatures) { signatures_ = signatures; } 56 // Generate a Graph for the given abstract arguments. 57 virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_abs_list); 58 59 // Generate a Graph for this type signature. GenerateFromTypes(const TypePtrList &)60 virtual FuncGraphPtr GenerateFromTypes(const TypePtrList &) { 61 MS_LOG(INTERNAL_EXCEPTION) << "Undefined the method of generating graph from types. func_name:" << name(); 62 } 63 name()64 std::string name() { return name_; } ToString()65 std::string ToString() const override { 66 std::ostringstream buffer; 67 buffer << "MetaFuncGraph-"; 68 buffer << name_; 69 buffer << "." << debug_info_->get_id(); 70 return buffer.str(); 71 } hash()72 std::size_t hash() const override { return tid(); } 73 74 virtual bool operator==(const MetaFuncGraph &other) const { return &other == this; } 75 bool operator==(const Value &other) const override { 76 if (other.isa<MetaFuncGraph>()) { 77 return &other == this; 78 } else { 79 return false; 80 } 81 } 82 DoBreakLoop()83 void DoBreakLoop() override { cache_.clear(); } 84 set_node_expr_src(const std::string & node_expr_src)85 void set_node_expr_src(const std::string &node_expr_src) { node_expr_src_ = node_expr_src; } 86 set_scope_name(const std::string & scope_name)87 void set_scope_name(const std::string &scope_name) { scope_name_ = scope_name; } scope_name()88 std::string scope_name() { return scope_name_; } 89 90 protected: 91 template <typename Derived> shared_from_base()92 std::shared_ptr<Derived> shared_from_base() { 93 return std::static_pointer_cast<Derived>(shared_from_this()); 94 } 95 FuncGraphPtr GenerateStubFunc(const TypePtrList &types) const; 96 std::string name_; 97 std::vector<Signature> signatures_; 98 TypeListMap<FuncGraphPtr> cache_; 99 std::string node_expr_src_ = ""; 100 std::string scope_name_ = ""; 101 102 private: 103 DebugInfoPtr debug_info_{nullptr}; 104 }; 105 106 using MetaFuncGraphPtr = std::shared_ptr<MetaFuncGraph>; 107 } // namespace mindspore 108 109 #endif // MINDSPORE_CORE_IR_META_FUNC_GRAPH_H_ 110