1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_
20 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_
21
#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "backend/optimizer/common/visit.h"
#include "base/base.h"
#include "utils/log_adapter.h"
#include "base/base_ref.h"
40
41 namespace mindspore {
42 class CondVar;
43 class SeqVar;
44 using CondVarPtr = std::shared_ptr<CondVar>;
45 using SVarPtr = std::shared_ptr<SeqVar>;
46 const int kInvalidVarIndex = -2;
47
48 using ConditionFunc = std::function<bool(const BaseRef &)>;
49
50 // Base wildcard variable which could match any anf node.
51 class Var : public Base {
52 friend class VarHasher;
53
54 public:
tag_(std::move (tag))55 explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); }
tag_(std::move (tag))56 explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) {
57 EnsureTag();
58 }
Var(const Var & other)59 Var(const Var &other) : Base(other), tag_(other.tag_) {}
60 virtual Var &operator=(const Var &other) {
61 if (&other == this) {
62 return *this;
63 }
64 this->tag_ = other.tag_;
65 return *this;
66 }
67 ~Var() override = default;
68 MS_DECLARE_PARENT(Var, Base);
69
matches(const BaseRef &)70 virtual bool matches(const BaseRef &) { return true; }
71
72 virtual bool operator==(const Var &other) const { return tag_ == other.tag_; }
73 bool operator!=(const Var &other) const { return !(&other == this); }
74
tag()75 std::string tag() const { return tag_; }
primitive()76 PrimitivePtr primitive() const { return primitive_; }
ToString()77 std::string ToString() const override {
78 std::ostringstream buffer;
79 buffer << "Var(" << tag_ << ")";
80 return buffer.str();
81 }
hash()82 std::size_t hash() const override { return std::hash<std::string>()(tag_); }
83
84 protected:
85 void EnsureTag();
86
87 std::string tag_;
88 PrimitivePtr primitive_;
89 };
90
91 // VarNode means variable node, a subclass of AnfNode
92 class VarNode : public AnfNode {
93 public:
VarNode(const VarPtr & value,const FuncGraphPtr & func_graph)94 VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {}
95 ~VarNode() override = default;
96 MS_DECLARE_PARENT(VarNode, AnfNode);
97
98 const VarPtr var_;
99 };
100 using VarNodePtr = std::shared_ptr<VarNode>;
101
102 class VarHasher {
103 public:
operator()104 std::size_t operator()(const Var &var) const { return var.hash(); }
105 };
106
107 // Condition Var, match an anf node when condition function return true.
108 class CondVar : public Var {
109 public:
CondVar(const ConditionFunc & cond)110 explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {}
111 ~CondVar() override = default;
112 MS_DECLARE_PARENT(CondVar, Var);
matches(const BaseRef & value)113 bool matches(const BaseRef &value) override {
114 MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString();
115 if (utils::isa<Var>(value)) {
116 return false;
117 }
118 return cond_fn_(value);
119 }
120
121 private:
122 ConditionFunc cond_fn_;
123 };
124
125 using Seq = VectorRef;
126 using SeqPtr = std::shared_ptr<Seq>;
127
128 // Sequence Var which could match multiple consecutive input nodes of a CNode.
129 class SeqVar : public Var {
130 public:
SeqVar()131 SeqVar() { subvar_ = std::make_shared<Var>(); }
132 ~SeqVar() override = default;
133 MS_DECLARE_PARENT(SeqVar, Var);
SeqVar(const VarPtr subvar)134 explicit SeqVar(const VarPtr subvar) : subvar_(nullptr) { subvar_ = subvar; }
matches(const BaseRef & value)135 bool matches(const BaseRef &value) override {
136 // match Seq.
137 if (utils::isa<Seq>(value)) {
138 const Seq &seq = utils::cast<Seq>(value);
139 return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) {
140 auto eq = subvar_->matches(v);
141 return eq;
142 });
143 }
144 return false;
145 }
146 bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; }
147 std::string ToString() const override;
148
149 private:
150 VarPtr subvar_;
151 };
152
153 bool operator==(const VarPtr &lhs, const VarPtr &rhs);
154
155 inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); }
156
157 std::ostream &operator<<(std::ostream &os, const VarPtr &var);
158
159 using Equiv = std::map<VarPtr, BaseRef>;
160 using EquivPtr = std::shared_ptr<Equiv>;
161 using PrimitiveVarMap = std::unordered_map<PrimitivePtr, VarPtr>;
162 using PrimitiveVarMapPtr = std::shared_ptr<PrimitiveVarMap>;
163
DefaultTypeEq(const BaseRef & x,const BaseRef & y)164 inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); }
165
166 class PatternEngine {
167 public:
PatternEngine(const std::shared_ptr<Visitor> & visitor)168 explicit PatternEngine(const std::shared_ptr<Visitor> &visitor) : visitor_(visitor) {}
169 ~PatternEngine() = default;
170
171 EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
172 EquivPtr equiv) const;
173 // Replace pattern with equivalent
174
175 private:
176 EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
177 const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const;
178 bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern,
179 VectorRef *const values_expr) const;
180 bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
181 VectorRef *const values_expr) const;
182 static bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);
183 std::shared_ptr<Visitor> visitor_;
184 };
185 } // namespace mindspore
186 namespace std {
187 using mindspore::ERROR;
188 using mindspore::LogStream;
189 using mindspore::NoExceptionType;
190 template <>
191 struct hash<mindspore::VarPtr> {
192 std::size_t operator()(const mindspore::VarPtr var) const {
193 if (var == nullptr) {
194 MS_LOG(ERROR) << "Invalid var ptr";
195 return 0;
196 }
197 return std::hash<std::string>{}(var->tag());
198 }
199 };
200 } // namespace std
201 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_
202