1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CORE_IR_META_FUNC_GRAPH_H_ 20 #define MINDSPORE_CORE_IR_META_FUNC_GRAPH_H_ 21 22 #include <unordered_map> 23 #include <string> 24 #include <map> 25 #include <memory> 26 #include <vector> 27 #include <algorithm> 28 29 #include "ir/dtype.h" 30 #include "ir/anf.h" 31 #include "ir/func_graph.h" 32 #include "ir/signature.h" 33 #include "abstract/abstract_value.h" 34 35 namespace mindspore { 36 // namespace to support intermediate representation definition 37 // Graph generator. 38 // Can be called with a pipeline's resources and a list of argument types to 39 // generate a graph corresponding to these types. 40 class MetaFuncGraph : public FuncGraphBase { 41 public: MetaFuncGraph(const std::string & name)42 explicit MetaFuncGraph(const std::string &name) : name_(name) { cache_.clear(); } 43 44 ~MetaFuncGraph() override = default; 45 46 MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase); 47 // Return normalized versions of the arguments. 48 // By default, this returns args unchanged. NormalizeArgs(const abstract::AbstractBasePtrList & args_spec_list)49 virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const { 50 return args_spec_list; 51 } 52 abstract::AbstractBasePtr ToAbstract() override; signatures()53 const std::vector<Signature> &signatures() const { return signatures_; } set_signatures(const std::vector<Signature> & signatures)54 void set_signatures(const std::vector<Signature> &signatures) { signatures_ = signatures; } 55 // Generate a Graph for the given abstract arguments. 56 virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list); 57 58 // Generate a Graph for this type signature. GenerateFromTypes(const TypePtrList &)59 virtual FuncGraphPtr GenerateFromTypes(const TypePtrList &) { 60 MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types."; 61 } 62 name()63 std::string name() { return name_; } ToString()64 std::string ToString() const override { return name_; } hash()65 std::size_t hash() const override { return tid(); } 66 67 virtual bool operator==(const MetaFuncGraph &other) const { return &other == this; } 68 bool operator==(const Value &other) const override { 69 if (other.isa<MetaFuncGraph>()) { 70 return &other == this; 71 } else { 72 return false; 73 } 74 } 75 76 protected: 77 template <typename Derived> shared_from_base()78 std::shared_ptr<Derived> shared_from_base() { 79 return std::static_pointer_cast<Derived>(shared_from_this()); 80 } 81 FuncGraphPtr GenerateStubFunc(const TypePtrList &types); 82 std::string name_; 83 std::vector<Signature> signatures_; 84 std::unordered_map<TypePtrList, FuncGraphPtr, TypeListHasher, TypeListEqual> cache_; 85 }; 86 87 using MetaFuncGraphPtr = std::shared_ptr<MetaFuncGraph>; 88 } // namespace mindspore 89 90 #endif // MINDSPORE_CORE_IR_META_FUNC_GRAPH_H_ 91