1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_PATTERN_TO_PATTERN_H 18 #define MINDSPORE_PATTERN_TO_PATTERN_H 19 20 #include <memory> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 25 #include "base/base.h" 26 #include "base/base_ref.h" 27 #include "include/backend/optimizer/optimizer.h" 28 #include "include/backend/visible.h" 29 30 namespace mindspore { 31 namespace opt { 32 bool BACKEND_EXPORT AlwaysReturnTrue(const BaseRef &); 33 34 class BACKEND_EXPORT PatternMap { 35 public: 36 PatternMap() = default; 37 bool Contains(const std::string &name) const; 38 bool CheckSeq(const std::string &name) const; 39 AnfNodePtr Get(const std::string &name) const; 40 const std::vector<AnfNodePtr> &GetSeq(const std::string &name) const; 41 bool Emplace(const std::string &name, const AnfNodePtr &node); 42 bool Emplace(const std::string &name, const std::vector<AnfNodePtr> &v); 43 void Clear(); 44 bool Check(const std::string &name, const AnfNodePtr &node) const; 45 void Erase(const mindspore::HashSet<std::string> &del_set); GetOptScope()46 const mindspore::HashSet<AnfNodePtr> &GetOptScope() const { return opt_scope_; } 47 48 private: 49 mindspore::HashSet<std::string> name_set_; 50 mindspore::HashMap<std::string, AnfNodePtr> node_map_; 51 mindspore::HashMap<std::string, std::vector<AnfNodePtr>> seq_map_; 52 mindspore::HashSet<AnfNodePtr> opt_scope_; 53 }; 54 55 using PatternMapPtr = std::shared_ptr<PatternMap>; 56 using BuildCNodeFunc = std::function<AnfNodePtr(const PatternMap &, const AnfNodePtr &)>; 57 using BuildValueFunc = std::function<AnfNodePtr(const PatternMap &)>; 58 59 class BACKEND_EXPORT DefaultCNodeFunc { 60 public: 61 DefaultCNodeFunc() = default; operator()62 AnfNodePtr operator()(const PatternMap &, const AnfNodePtr &default_cnode) const { return default_cnode; } 63 }; 64 65 class BACKEND_EXPORT InplaceCNodeFunc { 66 public: InplaceCNodeFunc(std::string s)67 explicit InplaceCNodeFunc(std::string s) : s_(std::move(s)) {} operator()68 AnfNodePtr operator()(const PatternMap &m, const AnfNodePtr & /* default_cnode */) const { return m.Get(s_); } 69 70 private: 71 std::string s_; 72 }; 73 74 class BACKEND_EXPORT DefaultValueFunc { 75 public: DefaultValueFunc(ValuePtr v)76 explicit DefaultValueFunc(ValuePtr v) : v_(std::move(v)) {} operator()77 AnfNodePtr operator()(const PatternMap &) const { return NewValueNode(v_); } 78 79 private: 80 ValuePtr v_; 81 }; 82 83 class BACKEND_EXPORT InplaceValueFunc { 84 public: InplaceValueFunc(std::string s)85 explicit InplaceValueFunc(std::string s) : s_(std::move(s)) {} operator()86 AnfNodePtr operator()(const PatternMap &m) const { return m.Get(s_); } 87 88 private: 89 std::string s_; 90 }; 91 92 class BACKEND_EXPORT PatternToPatternPass; 93 class BACKEND_EXPORT UnpackNode { 94 public: 95 UnpackNode &operator=(const std::string &name); 96 UnpackNode &operator=(const UnpackNode &u) = default; 97 UnpackNode(const UnpackNode &u) = default; 98 99 private: UnpackNode(AnfNodePtr node)100 explicit UnpackNode(AnfNodePtr node) : node_(std::move(node)) {} 101 AnfNodePtr node_ = nullptr; 102 std::string key_; 103 friend class DstPattern; 104 friend class PatternToPatternPass; 105 }; 106 107 class BACKEND_EXPORT PatternNode { 108 public: PatternNode(const PrimitivePtr & p)109 PatternNode(const PrimitivePtr &p) // NOLINT 110 : type_("prim"), p_(NewValueNode(std::make_shared<Primitive>(p->name()))) {} PatternNode(const char * name)111 PatternNode(const char *name) : type_("name"), name_(name) {} // NOLINT PatternNode(std::vector<UnpackNode> & v)112 PatternNode(std::vector<UnpackNode> &v) : type_("unpack"), v_(v) {} // NOLINT 113 PatternNode(const PatternNode &) = default; 114 PatternNode &operator=(const PatternNode &) = default; 115 116 private: 117 std::string type_; 118 std::string name_; 119 ValueNodePtr p_; 120 std::vector<UnpackNode> v_; 121 friend class SrcPattern; 122 friend class DstPattern; 123 }; 124 125 class BACKEND_EXPORT SrcPattern { 126 public: 127 SrcPattern &AddVar(const std::string &name, const PatternConditionFunc &f = AlwaysReturnTrue); 128 SrcPattern &AddSeqVar(const std::string &name, const PatternConditionFunc &f = AlwaysReturnTrue); 129 const BaseRef &GetRef(const std::string &name) const; 130 SrcPattern &AddCNode(const std::string &name, const std::initializer_list<PatternNode> &v); 131 BaseRef GetRoot() const; 132 133 private: SrcPattern(PatternMapPtr m)134 explicit SrcPattern(PatternMapPtr m) : m_(std::move(m)), has_root_(false) {} 135 bool CheckEmptySeqVar(const std::string &name, const EquivPtr &equiv, const std::vector<PatternNode> &inputs, 136 size_t *now_pattern); 137 bool match(const std::string &name, const AnfNodePtr &node, const EquivPtr &equiv); 138 bool build_pattern_map(const AnfNodePtr &node, const EquivPtr &equiv); 139 PatternMapPtr m_; 140 mindspore::HashMap<std::string, BaseRef> ref_map_; 141 mindspore::HashMap<std::string, std::vector<PatternNode>> inputs_map_; 142 bool has_root_; 143 std::string root_; 144 friend class PatternToPatternPass; 145 }; 146 147 class BACKEND_EXPORT DstPattern { 148 public: 149 DstPattern &AddCNode(const string &name, const std::initializer_list<PatternNode> &inputs, 150 const BuildCNodeFunc &buildfunc = DefaultCNodeFunc()); 151 DstPattern &AddValueNode(const string &name, const BuildValueFunc &buildfunc); 152 153 private: DstPattern(PatternMapPtr m)154 explicit DstPattern(PatternMapPtr m) : m_(std::move(m)) {} 155 AnfNodePtr Root(); 156 void clear(); 157 void set_info(PatternToPatternPass *now_pass, const FuncGraphPtr &func_graph); 158 friend class PatternToPatternPass; 159 PatternMapPtr m_; 160 mindspore::HashSet<std::string> dst_set_; 161 bool fail_ = false; 162 AnfNodePtr root_ = nullptr; 163 FuncGraphPtr fg_ = nullptr; 164 PatternToPatternPass *pass_ = nullptr; 165 }; 166 167 class BACKEND_EXPORT PatternToPatternPass : public PatternPass { 168 public: 169 explicit PatternToPatternPass(const std::string &name = "", bool is_fast_pass = false, bool multigraph = true) PatternPass(name,multigraph)170 : PatternPass(name, multigraph), 171 m_(std::make_shared<PatternMap>()), 172 src_pattern_(SrcPattern(m_)), 173 dst_pattern_(DstPattern(m_)), 174 is_fast_pass_(is_fast_pass) {} 175 ~PatternToPatternPass() override = default; 176 virtual void DefineSrcPattern(SrcPattern *src_pattern) = 0; 177 virtual void DefineDstPattern(DstPattern *dst_pattern) = 0; CheckMatchedDAG(const PatternMap &,const FuncGraphPtr &,const AnfNodePtr &)178 virtual bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const { return true; } 179 bool IsFastPass() override; 180 AnfNodePtr GetSrcPatternRoot(); 181 std::string GetPatternRootPrimitiveName() override; 182 AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override; 183 void AfterProcess(const AnfNodePtr &old_node, const AnfNodePtr &new_node, const FuncGraphPtr &sub_graph, 184 const FuncGraphIndexPtr &func_graph_index) override; 185 std::vector<UnpackNode> Unpacking(const std::string &s); 186 187 private: 188 PatternMapPtr m_; 189 SrcPattern src_pattern_; 190 DstPattern dst_pattern_; 191 AnfNodePtr src_pattern_root_ = nullptr; 192 bool is_fast_pass_; 193 }; 194 } // namespace opt 195 } // namespace mindspore 196 197 #endif // MINDSPORE_PATTERN_TO_PATTERN_H 198