• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_
20 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_
21 
22 #include <string>
23 #include <sstream>
24 #include <memory>
25 #include <vector>
26 #include <unordered_set>
27 #include <unordered_map>
28 #include <initializer_list>
29 #include <iostream>
30 #include <algorithm>
31 #include <map>
32 #include <stdexcept>
33 #include <list>
34 #include <utility>
35 
36 #include "backend/optimizer/common/visit.h"
37 #include "base/base.h"
38 #include "utils/log_adapter.h"
39 #include "base/base_ref.h"
40 
41 namespace mindspore {
42 class CondVar;
43 class SeqVar;
44 using CondVarPtr = std::shared_ptr<CondVar>;
45 using SVarPtr = std::shared_ptr<SeqVar>;
46 const int kInvalidVarIndex = -2;
47 
48 using ConditionFunc = std::function<bool(const BaseRef &)>;
49 
50 // Base wildcard variable which could match any anf node.
51 class Var : public Base {
52   friend class VarHasher;
53 
54  public:
tag_(std::move (tag))55   explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); }
tag_(std::move (tag))56   explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) {
57     EnsureTag();
58   }
Var(const Var & other)59   Var(const Var &other) : Base(other), tag_(other.tag_) {}
60   virtual Var &operator=(const Var &other) {
61     if (&other == this) {
62       return *this;
63     }
64     this->tag_ = other.tag_;
65     return *this;
66   }
67   ~Var() override = default;
68   MS_DECLARE_PARENT(Var, Base);
69 
matches(const BaseRef &)70   virtual bool matches(const BaseRef &) { return true; }
71 
72   virtual bool operator==(const Var &other) const { return tag_ == other.tag_; }
73   bool operator!=(const Var &other) const { return !(&other == this); }
74 
tag()75   std::string tag() const { return tag_; }
primitive()76   PrimitivePtr primitive() const { return primitive_; }
ToString()77   std::string ToString() const override {
78     std::ostringstream buffer;
79     buffer << "Var(" << tag_ << ")";
80     return buffer.str();
81   }
hash()82   std::size_t hash() const override { return std::hash<std::string>()(tag_); }
83 
84  protected:
85   void EnsureTag();
86 
87   std::string tag_;
88   PrimitivePtr primitive_;
89 };
90 
91 // VarNode means variable node, a subclass of AnfNode
92 class VarNode : public AnfNode {
93  public:
VarNode(const VarPtr & value,const FuncGraphPtr & func_graph)94   VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {}
95   ~VarNode() override = default;
96   MS_DECLARE_PARENT(VarNode, AnfNode);
97 
98   const VarPtr var_;
99 };
100 using VarNodePtr = std::shared_ptr<VarNode>;
101 
102 class VarHasher {
103  public:
operator()104   std::size_t operator()(const Var &var) const { return var.hash(); }
105 };
106 
107 // Condition Var, match an anf node when condition function return true.
108 class CondVar : public Var {
109  public:
CondVar(const ConditionFunc & cond)110   explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {}
111   ~CondVar() override = default;
112   MS_DECLARE_PARENT(CondVar, Var);
matches(const BaseRef & value)113   bool matches(const BaseRef &value) override {
114     MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString();
115     if (utils::isa<Var>(value)) {
116       return false;
117     }
118     return cond_fn_(value);
119   }
120 
121  private:
122   ConditionFunc cond_fn_;
123 };
124 
125 using Seq = VectorRef;
126 using SeqPtr = std::shared_ptr<Seq>;
127 
128 // Sequence Var which could match multiple consecutive input nodes of a CNode.
129 class SeqVar : public Var {
130  public:
SeqVar()131   SeqVar() { subvar_ = std::make_shared<Var>(); }
132   ~SeqVar() override = default;
133   MS_DECLARE_PARENT(SeqVar, Var);
SeqVar(const VarPtr subvar)134   explicit SeqVar(const VarPtr subvar) : subvar_(nullptr) { subvar_ = subvar; }
matches(const BaseRef & value)135   bool matches(const BaseRef &value) override {
136     // match Seq.
137     if (utils::isa<Seq>(value)) {
138       const Seq &seq = utils::cast<Seq>(value);
139       return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) {
140         auto eq = subvar_->matches(v);
141         return eq;
142       });
143     }
144     return false;
145   }
146   bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; }
147   std::string ToString() const override;
148 
149  private:
150   VarPtr subvar_;
151 };
152 
153 bool operator==(const VarPtr &lhs, const VarPtr &rhs);
154 
155 inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); }
156 
157 std::ostream &operator<<(std::ostream &os, const VarPtr &var);
158 
159 using Equiv = std::map<VarPtr, BaseRef>;
160 using EquivPtr = std::shared_ptr<Equiv>;
161 using PrimitiveVarMap = std::unordered_map<PrimitivePtr, VarPtr>;
162 using PrimitiveVarMapPtr = std::shared_ptr<PrimitiveVarMap>;
163 
DefaultTypeEq(const BaseRef & x,const BaseRef & y)164 inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); }
165 
166 class PatternEngine {
167  public:
PatternEngine(const std::shared_ptr<Visitor> & visitor)168   explicit PatternEngine(const std::shared_ptr<Visitor> &visitor) : visitor_(visitor) {}
169   ~PatternEngine() = default;
170 
171   EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
172                  EquivPtr equiv) const;
173   // Replace pattern with equivalent
174 
175  private:
176   EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
177                      const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const;
178   bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern,
179                 VectorRef *const values_expr) const;
180   bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
181                 VectorRef *const values_expr) const;
182   static bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);
183   std::shared_ptr<Visitor> visitor_;
184 };
185 }  // namespace mindspore
186 namespace std {
187 using mindspore::ERROR;
188 using mindspore::LogStream;
189 using mindspore::NoExceptionType;
190 template <>
191 struct hash<mindspore::VarPtr> {
192   std::size_t operator()(const mindspore::VarPtr var) const {
193     if (var == nullptr) {
194       MS_LOG(ERROR) << "Invalid var ptr";
195       return 0;
196     }
197     return std::hash<std::string>{}(var->tag());
198   }
199 };
200 }  // namespace std
201 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_
202