1 /**
2 * Copyright 2019-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_
19
#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "base/base.h"
#include "utils/log_adapter.h"
#include "base/base_ref.h"
#include "include/backend/visible.h"
#include "include/backend/optimizer/visitor.h"
39
40 namespace mindspore {
41 class CondVar;
42 class SeqVar;
43 using CondVarPtr = std::shared_ptr<CondVar>;
44 using SVarPtr = std::shared_ptr<SeqVar>;
45 const int kInvalidVarIndex = -2;
46
47 using PatternConditionFunc = std::function<bool(const BaseRef &)>;
48
49 // Base wildcard variable which could match any anf node.
50 class BACKEND_EXPORT Var : public Base {
51 public:
tag_(std::move (tag))52 explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); }
tag_(std::move (tag))53 explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) {
54 EnsureTag();
55 }
Var(const Var & other)56 Var(const Var &other) : Base(other), tag_(other.tag_), primitive_(other.primitive_) {}
57 Var &operator=(const Var &other) {
58 if (&other == this) {
59 return *this;
60 }
61 this->tag_ = other.tag_;
62 this->primitive_ = other.primitive_;
63 return *this;
64 }
65 ~Var() override = default;
66 MS_DECLARE_PARENT(Var, Base);
67
matches(const BaseRef &)68 virtual bool matches(const BaseRef &) { return true; }
69
70 virtual bool operator==(const Var &other) const { return tag_ == other.tag_; }
71 bool operator!=(const Var &other) const { return !(&other == this); }
72
tag()73 std::string tag() const { return tag_; }
primitive()74 PrimitivePtr primitive() const { return primitive_; }
ToString()75 std::string ToString() const override {
76 std::ostringstream buffer;
77 buffer << "Var(" << tag_ << ")";
78 return buffer.str();
79 }
hash()80 std::size_t hash() const override { return std::hash<std::string>()(tag_); }
81
82 protected:
83 void EnsureTag();
84
85 std::string tag_;
86 PrimitivePtr primitive_;
87 };
88
89 // VarNode means variable node, a subclass of AnfNode
90 class VarNode : public AnfNode {
91 public:
VarNode(const VarPtr & value,const FuncGraphPtr & func_graph)92 VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {}
93 ~VarNode() override = default;
94 MS_DECLARE_PARENT(VarNode, AnfNode);
95
96 const VarPtr var_;
97 };
98 using VarNodePtr = std::shared_ptr<VarNode>;
99
100 // Condition Var, match an anf node when condition function return true.
101 class CondVar : public Var {
102 public:
CondVar(const PatternConditionFunc & cond)103 explicit CondVar(const PatternConditionFunc &cond) : cond_fn_(cond) {}
CondVar(const PatternConditionFunc & cond,std::string tag)104 explicit CondVar(const PatternConditionFunc &cond, std::string tag) : Var(tag), cond_fn_(cond) {}
105 ~CondVar() override = default;
106 MS_DECLARE_PARENT(CondVar, Var);
matches(const BaseRef & value)107 bool matches(const BaseRef &value) override {
108 MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString();
109 if (utils::isa<Var>(value)) {
110 return false;
111 }
112 return cond_fn_(value);
113 }
114
115 private:
116 PatternConditionFunc cond_fn_;
117 };
118
119 using Seq = VectorRef;
120 using SeqPtr = std::shared_ptr<Seq>;
121
122 // Sequence Var which could match multiple consecutive input nodes of a CNode.
123 class BACKEND_EXPORT SeqVar : public Var {
124 public:
SeqVar()125 SeqVar() { subvar_ = std::make_shared<Var>(); }
126 ~SeqVar() override = default;
127 MS_DECLARE_PARENT(SeqVar, Var);
SeqVar(const VarPtr subvar)128 explicit SeqVar(const VarPtr subvar) : subvar_(nullptr) { subvar_ = subvar; }
SeqVar(const PatternConditionFunc & cond)129 explicit SeqVar(const PatternConditionFunc &cond) { subvar_ = std::make_shared<CondVar>(cond); }
matches(const BaseRef & value)130 bool matches(const BaseRef &value) override {
131 // match Seq.
132 if (utils::isa<Seq>(value)) {
133 const Seq &seq = utils::cast<Seq>(value);
134 return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) {
135 auto eq = subvar_->matches(v);
136 return eq;
137 });
138 }
139 return false;
140 }
141 bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; }
142 std::string ToString() const override;
143
144 private:
145 VarPtr subvar_;
146 };
147
148 bool operator==(const VarPtr &lhs, const VarPtr &rhs);
149
150 inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); }
151
152 std::ostream &operator<<(std::ostream &os, const VarPtr &var);
153
154 using Equiv = std::map<VarPtr, BaseRef>;
155 using EquivPtr = std::shared_ptr<Equiv>;
156 using PrimitiveVarMap = mindspore::HashMap<PrimitivePtr, VarPtr>;
157 using PrimitiveVarMapPtr = std::shared_ptr<PrimitiveVarMap>;
158
DefaultTypeEq(const BaseRef & x,const BaseRef & y)159 inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); }
160
161 class PatternEngine {
162 public:
PatternEngine(const std::shared_ptr<Visitor> & visitor)163 explicit PatternEngine(const std::shared_ptr<Visitor> &visitor) : visitor_(visitor) {}
164 ~PatternEngine() = default;
165
166 EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
167 EquivPtr equiv) const;
168 // Replace pattern with equivalent
169
170 private:
171 EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
172 const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const;
173 bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern,
174 VectorRef *const values_expr) const;
175 bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
176 VectorRef *const values_expr) const;
177 static bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);
178 std::shared_ptr<Visitor> visitor_;
179 };
180 } // namespace mindspore
181 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_
182