• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_PATTERN_TO_PATTERN_H
18 #define MINDSPORE_PATTERN_TO_PATTERN_H
19 
#include <functional>
#include <initializer_list>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "base/base.h"
#include "base/base_ref.h"
#include "include/backend/optimizer/optimizer.h"
#include "include/backend/visible.h"
29 
30 namespace mindspore {
31 namespace opt {
32 bool BACKEND_EXPORT AlwaysReturnTrue(const BaseRef &);
33 
34 class BACKEND_EXPORT PatternMap {
35  public:
36   PatternMap() = default;
37   bool Contains(const std::string &name) const;
38   bool CheckSeq(const std::string &name) const;
39   AnfNodePtr Get(const std::string &name) const;
40   const std::vector<AnfNodePtr> &GetSeq(const std::string &name) const;
41   bool Emplace(const std::string &name, const AnfNodePtr &node);
42   bool Emplace(const std::string &name, const std::vector<AnfNodePtr> &v);
43   void Clear();
44   bool Check(const std::string &name, const AnfNodePtr &node) const;
45   void Erase(const mindspore::HashSet<std::string> &del_set);
GetOptScope()46   const mindspore::HashSet<AnfNodePtr> &GetOptScope() const { return opt_scope_; }
47 
48  private:
49   mindspore::HashSet<std::string> name_set_;
50   mindspore::HashMap<std::string, AnfNodePtr> node_map_;
51   mindspore::HashMap<std::string, std::vector<AnfNodePtr>> seq_map_;
52   mindspore::HashSet<AnfNodePtr> opt_scope_;
53 };
54 
55 using PatternMapPtr = std::shared_ptr<PatternMap>;
56 using BuildCNodeFunc = std::function<AnfNodePtr(const PatternMap &, const AnfNodePtr &)>;
57 using BuildValueFunc = std::function<AnfNodePtr(const PatternMap &)>;
58 
59 class BACKEND_EXPORT DefaultCNodeFunc {
60  public:
61   DefaultCNodeFunc() = default;
operator()62   AnfNodePtr operator()(const PatternMap &, const AnfNodePtr &default_cnode) const { return default_cnode; }
63 };
64 
65 class BACKEND_EXPORT InplaceCNodeFunc {
66  public:
InplaceCNodeFunc(std::string s)67   explicit InplaceCNodeFunc(std::string s) : s_(std::move(s)) {}
operator()68   AnfNodePtr operator()(const PatternMap &m, const AnfNodePtr & /* default_cnode */) const { return m.Get(s_); }
69 
70  private:
71   std::string s_;
72 };
73 
74 class BACKEND_EXPORT DefaultValueFunc {
75  public:
DefaultValueFunc(ValuePtr v)76   explicit DefaultValueFunc(ValuePtr v) : v_(std::move(v)) {}
operator()77   AnfNodePtr operator()(const PatternMap &) const { return NewValueNode(v_); }
78 
79  private:
80   ValuePtr v_;
81 };
82 
83 class BACKEND_EXPORT InplaceValueFunc {
84  public:
InplaceValueFunc(std::string s)85   explicit InplaceValueFunc(std::string s) : s_(std::move(s)) {}
operator()86   AnfNodePtr operator()(const PatternMap &m) const { return m.Get(s_); }
87 
88  private:
89   std::string s_;
90 };
91 
92 class BACKEND_EXPORT PatternToPatternPass;
93 class BACKEND_EXPORT UnpackNode {
94  public:
95   UnpackNode &operator=(const std::string &name);
96   UnpackNode &operator=(const UnpackNode &u) = default;
97   UnpackNode(const UnpackNode &u) = default;
98 
99  private:
UnpackNode(AnfNodePtr node)100   explicit UnpackNode(AnfNodePtr node) : node_(std::move(node)) {}
101   AnfNodePtr node_ = nullptr;
102   std::string key_;
103   friend class DstPattern;
104   friend class PatternToPatternPass;
105 };
106 
107 class BACKEND_EXPORT PatternNode {
108  public:
PatternNode(const PrimitivePtr & p)109   PatternNode(const PrimitivePtr &p)  // NOLINT
110       : type_("prim"), p_(NewValueNode(std::make_shared<Primitive>(p->name()))) {}
PatternNode(const char * name)111   PatternNode(const char *name) : type_("name"), name_(name) {}        // NOLINT
PatternNode(std::vector<UnpackNode> & v)112   PatternNode(std::vector<UnpackNode> &v) : type_("unpack"), v_(v) {}  // NOLINT
113   PatternNode(const PatternNode &) = default;
114   PatternNode &operator=(const PatternNode &) = default;
115 
116  private:
117   std::string type_;
118   std::string name_;
119   ValueNodePtr p_;
120   std::vector<UnpackNode> v_;
121   friend class SrcPattern;
122   friend class DstPattern;
123 };
124 
125 class BACKEND_EXPORT SrcPattern {
126  public:
127   SrcPattern &AddVar(const std::string &name, const PatternConditionFunc &f = AlwaysReturnTrue);
128   SrcPattern &AddSeqVar(const std::string &name, const PatternConditionFunc &f = AlwaysReturnTrue);
129   const BaseRef &GetRef(const std::string &name) const;
130   SrcPattern &AddCNode(const std::string &name, const std::initializer_list<PatternNode> &v);
131   BaseRef GetRoot() const;
132 
133  private:
SrcPattern(PatternMapPtr m)134   explicit SrcPattern(PatternMapPtr m) : m_(std::move(m)), has_root_(false) {}
135   bool CheckEmptySeqVar(const std::string &name, const EquivPtr &equiv, const std::vector<PatternNode> &inputs,
136                         size_t *now_pattern);
137   bool match(const std::string &name, const AnfNodePtr &node, const EquivPtr &equiv);
138   bool build_pattern_map(const AnfNodePtr &node, const EquivPtr &equiv);
139   PatternMapPtr m_;
140   mindspore::HashMap<std::string, BaseRef> ref_map_;
141   mindspore::HashMap<std::string, std::vector<PatternNode>> inputs_map_;
142   bool has_root_;
143   std::string root_;
144   friend class PatternToPatternPass;
145 };
146 
147 class BACKEND_EXPORT DstPattern {
148  public:
149   DstPattern &AddCNode(const string &name, const std::initializer_list<PatternNode> &inputs,
150                        const BuildCNodeFunc &buildfunc = DefaultCNodeFunc());
151   DstPattern &AddValueNode(const string &name, const BuildValueFunc &buildfunc);
152 
153  private:
DstPattern(PatternMapPtr m)154   explicit DstPattern(PatternMapPtr m) : m_(std::move(m)) {}
155   AnfNodePtr Root();
156   void clear();
157   void set_info(PatternToPatternPass *now_pass, const FuncGraphPtr &func_graph);
158   friend class PatternToPatternPass;
159   PatternMapPtr m_;
160   mindspore::HashSet<std::string> dst_set_;
161   bool fail_ = false;
162   AnfNodePtr root_ = nullptr;
163   FuncGraphPtr fg_ = nullptr;
164   PatternToPatternPass *pass_ = nullptr;
165 };
166 
167 class BACKEND_EXPORT PatternToPatternPass : public PatternPass {
168  public:
169   explicit PatternToPatternPass(const std::string &name = "", bool is_fast_pass = false, bool multigraph = true)
PatternPass(name,multigraph)170       : PatternPass(name, multigraph),
171         m_(std::make_shared<PatternMap>()),
172         src_pattern_(SrcPattern(m_)),
173         dst_pattern_(DstPattern(m_)),
174         is_fast_pass_(is_fast_pass) {}
175   ~PatternToPatternPass() override = default;
176   virtual void DefineSrcPattern(SrcPattern *src_pattern) = 0;
177   virtual void DefineDstPattern(DstPattern *dst_pattern) = 0;
CheckMatchedDAG(const PatternMap &,const FuncGraphPtr &,const AnfNodePtr &)178   virtual bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const { return true; }
179   bool IsFastPass() override;
180   AnfNodePtr GetSrcPatternRoot();
181   std::string GetPatternRootPrimitiveName() override;
182   AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
183   void AfterProcess(const AnfNodePtr &old_node, const AnfNodePtr &new_node, const FuncGraphPtr &sub_graph,
184                     const FuncGraphIndexPtr &func_graph_index) override;
185   std::vector<UnpackNode> Unpacking(const std::string &s);
186 
187  private:
188   PatternMapPtr m_;
189   SrcPattern src_pattern_;
190   DstPattern dst_pattern_;
191   AnfNodePtr src_pattern_root_ = nullptr;
192   bool is_fast_pass_;
193 };
194 }  // namespace opt
195 }  // namespace mindspore
196 
197 #endif  // MINDSPORE_PATTERN_TO_PATTERN_H
198