• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_
19 
#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "base/base.h"
#include "utils/log_adapter.h"
#include "base/base_ref.h"
#include "include/backend/visible.h"
#include "include/backend/optimizer/visitor.h"
39 
40 namespace mindspore {
41 class CondVar;
42 class SeqVar;
43 using CondVarPtr = std::shared_ptr<CondVar>;
44 using SVarPtr = std::shared_ptr<SeqVar>;
45 const int kInvalidVarIndex = -2;
46 
47 using PatternConditionFunc = std::function<bool(const BaseRef &)>;
48 
49 // Base wildcard variable which could match any anf node.
50 class BACKEND_EXPORT Var : public Base {
51  public:
tag_(std::move (tag))52   explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); }
tag_(std::move (tag))53   explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) {
54     EnsureTag();
55   }
Var(const Var & other)56   Var(const Var &other) : Base(other), tag_(other.tag_), primitive_(other.primitive_) {}
57   Var &operator=(const Var &other) {
58     if (&other == this) {
59       return *this;
60     }
61     this->tag_ = other.tag_;
62     this->primitive_ = other.primitive_;
63     return *this;
64   }
65   ~Var() override = default;
66   MS_DECLARE_PARENT(Var, Base);
67 
matches(const BaseRef &)68   virtual bool matches(const BaseRef &) { return true; }
69 
70   virtual bool operator==(const Var &other) const { return tag_ == other.tag_; }
71   bool operator!=(const Var &other) const { return !(&other == this); }
72 
tag()73   std::string tag() const { return tag_; }
primitive()74   PrimitivePtr primitive() const { return primitive_; }
ToString()75   std::string ToString() const override {
76     std::ostringstream buffer;
77     buffer << "Var(" << tag_ << ")";
78     return buffer.str();
79   }
hash()80   std::size_t hash() const override { return std::hash<std::string>()(tag_); }
81 
82  protected:
83   void EnsureTag();
84 
85   std::string tag_;
86   PrimitivePtr primitive_;
87 };
88 
89 // VarNode means variable node, a subclass of AnfNode
90 class VarNode : public AnfNode {
91  public:
VarNode(const VarPtr & value,const FuncGraphPtr & func_graph)92   VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {}
93   ~VarNode() override = default;
94   MS_DECLARE_PARENT(VarNode, AnfNode);
95 
96   const VarPtr var_;
97 };
98 using VarNodePtr = std::shared_ptr<VarNode>;
99 
100 // Condition Var, match an anf node when condition function return true.
101 class CondVar : public Var {
102  public:
CondVar(const PatternConditionFunc & cond)103   explicit CondVar(const PatternConditionFunc &cond) : cond_fn_(cond) {}
CondVar(const PatternConditionFunc & cond,std::string tag)104   explicit CondVar(const PatternConditionFunc &cond, std::string tag) : Var(tag), cond_fn_(cond) {}
105   ~CondVar() override = default;
106   MS_DECLARE_PARENT(CondVar, Var);
matches(const BaseRef & value)107   bool matches(const BaseRef &value) override {
108     MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString();
109     if (utils::isa<Var>(value)) {
110       return false;
111     }
112     return cond_fn_(value);
113   }
114 
115  private:
116   PatternConditionFunc cond_fn_;
117 };
118 
119 using Seq = VectorRef;
120 using SeqPtr = std::shared_ptr<Seq>;
121 
122 // Sequence Var which could match multiple consecutive input nodes of a CNode.
123 class BACKEND_EXPORT SeqVar : public Var {
124  public:
SeqVar()125   SeqVar() { subvar_ = std::make_shared<Var>(); }
126   ~SeqVar() override = default;
127   MS_DECLARE_PARENT(SeqVar, Var);
SeqVar(const VarPtr subvar)128   explicit SeqVar(const VarPtr subvar) : subvar_(nullptr) { subvar_ = subvar; }
SeqVar(const PatternConditionFunc & cond)129   explicit SeqVar(const PatternConditionFunc &cond) { subvar_ = std::make_shared<CondVar>(cond); }
matches(const BaseRef & value)130   bool matches(const BaseRef &value) override {
131     // match Seq.
132     if (utils::isa<Seq>(value)) {
133       const Seq &seq = utils::cast<Seq>(value);
134       return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) {
135         auto eq = subvar_->matches(v);
136         return eq;
137       });
138     }
139     return false;
140   }
141   bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; }
142   std::string ToString() const override;
143 
144  private:
145   VarPtr subvar_;
146 };
147 
148 bool operator==(const VarPtr &lhs, const VarPtr &rhs);
149 
150 inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); }
151 
152 std::ostream &operator<<(std::ostream &os, const VarPtr &var);
153 
154 using Equiv = std::map<VarPtr, BaseRef>;
155 using EquivPtr = std::shared_ptr<Equiv>;
156 using PrimitiveVarMap = mindspore::HashMap<PrimitivePtr, VarPtr>;
157 using PrimitiveVarMapPtr = std::shared_ptr<PrimitiveVarMap>;
158 
DefaultTypeEq(const BaseRef & x,const BaseRef & y)159 inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); }
160 
161 class PatternEngine {
162  public:
PatternEngine(const std::shared_ptr<Visitor> & visitor)163   explicit PatternEngine(const std::shared_ptr<Visitor> &visitor) : visitor_(visitor) {}
164   ~PatternEngine() = default;
165 
166   EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
167                  EquivPtr equiv) const;
168   // Replace pattern with equivalent
169 
170  private:
171   EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
172                      const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const;
173   bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern,
174                 VectorRef *const values_expr) const;
175   bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
176                 VectorRef *const values_expr) const;
177   static bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);
178   std::shared_ptr<Visitor> visitor_;
179 };
180 }  // namespace mindspore
181 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_
182