1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2020 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "backend/optimizer/common/pattern_engine.h"
20 #include "frontend/optimizer/opt.h"
21 #include "ir/anf.h"
22 #include "utils/convert_utils_base.h"
23 #include "utils/overload.h"
24 #include "backend/optimizer/common/helper.h"
25
26 namespace mindspore {
// Hands out consecutive integers (0, 1, 2, ...) on successive calls; used to build default Var tags.
// NOTE(review): the counter is a plain function-local static int with no synchronization, so
// concurrent callers are not protected.
static int GetNextTag() {
  static int next_tag = 0;
  const int issued = next_tag;
  ++next_tag;
  return issued;
}
31
EnsureTag()32 void Var::EnsureTag() {
33 if (tag_.length() == 0) {
34 std::ostringstream buffer;
35 buffer << "_" << GetNextTag();
36 tag_ = buffer.str();
37 }
38 }
39
operator ==(const VarPtr & lhs,const VarPtr & rhs)40 bool operator==(const VarPtr &lhs, const VarPtr &rhs) {
41 if (lhs->isa<CondVar>() && rhs->isa<CondVar>()) {
42 CondVarPtr v1 = dyn_cast<CondVar>(lhs);
43 CondVarPtr v2 = dyn_cast<CondVar>(rhs);
44 return *v1 == *v2;
45 }
46
47 if (lhs->isa<SeqVar>() && rhs->isa<SeqVar>()) {
48 SVarPtr v1 = dyn_cast<SeqVar>(lhs);
49 SVarPtr v2 = dyn_cast<SeqVar>(rhs);
50 return *v1 == *v2;
51 }
52 return (*lhs == *rhs);
53 }
54
ToString() const55 std::string SeqVar::ToString() const {
56 std::ostringstream buffer;
57 buffer << "SeqVar(" << tag() << ", " << subvar_->ToString() << ")";
58 return buffer.str();
59 }
60
operator <<(std::ostream & os,const VarPtr & var)61 std::ostream &operator<<(std::ostream &os, const VarPtr &var) {
62 if (var == nullptr) {
63 os << "";
64 } else {
65 os << var->ToString();
66 }
67 return os;
68 }
69
70 template <>
71 std::ostream &operator<<<VarPtr, BaseRef>(std::ostream &os, const Equiv &equiv) {
72 os << "[Equiv]"
73 << "\n";
74 for (auto &equiv_item : equiv) {
75 auto k = equiv_item.first;
76 os << k << ":";
77 BaseRef x = equiv_item.second;
78 if (utils::isa<AnfNodePtr>(x)) {
79 auto node = utils::cast<AnfNodePtr>(x);
80 os << "TypeString[" << node->type_name() << "]";
81 if (IsValueNode<FuncGraph>(node)) {
82 os << "IsValueNodeGraph ";
83 }
84 os << "type " << node->type_name();
85 if (node->isa<ValueNode>()) {
86 os << " value " << GetValueNode(node);
87 }
88 os << " addr: " << node;
89 } else if (utils::isa<Named>(x)) {
90 os << "Named " << x.ToString().c_str();
91 } else if (utils::isa<VarPtr>(x)) {
92 os << "TypeString[Var]";
93 os << (utils::cast<VarPtr>(x));
94 } else if (utils::isa<FuncGraphPtr>(x)) {
95 os << "TypeString[Graph]";
96 }
97 os << "\n";
98 }
99 return os;
100 }
101
// Resolves a pattern element before matching. If `x` holds a VarNode, returns the Var it wraps;
// otherwise returns `x` unchanged. The remaining branches only emit DEBUG logs describing `x`.
static BaseRef GetVar(const BaseRef &x) {
  // The message used to contain a printf-style "%s", which a stream logger never substitutes.
  MS_LOG(DEBUG) << "GetVar start: " << x.ToString();
  if (utils::isa<AnfNodePtr>(x)) {
    auto node = utils::cast<AnfNodePtr>(x);
    MS_LOG(DEBUG) << "TypeString [" << node->type_name() << "]";
    if (node->isa<VarNode>()) {
      // Cast once and reuse it for both the log and the return value.
      auto var_node = node->cast<VarNodePtr>();
      MS_LOG(DEBUG) << "IsVarNode " << var_node->var_->ToString();
      return var_node->var_;
    }
    if (node->isa<ValueNode>()) {
      MS_LOG(DEBUG) << "value " << GetValueNode(node)->ToString() << " addr: " << node->ToString();
    } else {
      MS_LOG(DEBUG) << "type " << node->type_name();
    }
  } else if (utils::isa<Named>(x)) {
    MS_LOG(DEBUG) << "Named " << x.ToString();
  } else if (utils::isa<VectorRef>(x)) {
    MS_LOG(DEBUG) << "VectorRef";
  } else if (utils::isa<VarPtr>(x)) {
    MS_LOG(DEBUG) << "TypeString[Var] " << x.ToString();
  }
  MS_LOG(DEBUG) << "GetVar end: " << x.ToString();
  return x;
}
126
MatchOnVar(const BaseRef & pattern,const BaseRef & expr,EquivPtr equiv)127 EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) {
128 MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString();
129 MS_EXCEPTION_IF_NULL(equiv);
130 if (utils::isa<VarPtr>(pattern)) {
131 VarPtr var = utils::cast<VarPtr>(pattern);
132 if (var->matches(expr)) {
133 (*equiv)[var] = expr;
134 MS_LOG(DEBUG) << "pattern is var match: " + pattern.ToString() + ", " + expr.ToString();
135 return equiv;
136 }
137 }
138
139 return nullptr;
140 }
141
ToVector(const VectorRef & pattern_ref,const VectorRef & expr_ref,VectorRef * const values_pattern,VectorRef * const values_expr) const142 bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
143 VectorRef *const values_expr) const {
144 MS_EXCEPTION_IF_NULL(values_expr);
145 if (utils::isa<SeqPtr>(pattern_ref)) {
146 *values_pattern = pattern_ref;
147 *values_expr = expr_ref;
148 return true;
149 }
150 return false;
151 }
152
ToVector(const BaseRef & pattern_ref,const BaseRef & expr_ref,VectorRef * const values_pattern,VectorRef * const values_expr) const153 bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern,
154 VectorRef *const values_expr) const {
155 MS_EXCEPTION_IF_NULL(values_expr);
156 MS_LOG(DEBUG) << "visit pattern_ref";
157 bool success = visitor_->Visit(pattern_ref, values_pattern, nullptr);
158 if (!success) {
159 return false;
160 }
161 MS_LOG(DEBUG) << "visit expr_ref";
162 return visitor_->Visit(expr_ref, values_expr, nullptr);
163 }
164
GetSVarStartIndex(const VectorRef & values)165 static int GetSVarStartIndex(const VectorRef &values) {
166 int index = -1;
167 int count = 0;
168 for (auto &value : values) {
169 if (utils::isa<VarPtr>(value) && utils::cast<VarPtr>(value)->isa<SeqVar>()) {
170 if (index != -1) {
171 MS_LOG(DEBUG) << "Multiple SVars in sequence";
172 return kInvalidVarIndex;
173 }
174 index = count;
175 }
176 count++;
177 }
178 return index;
179 }
180
UpdateEquivMap(const VectorRef & values_pattern,const BaseRef & expr_ref,const PrimitiveVarMap & primitive_vars,const EquivPtr & equiv)181 void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars,
182 const EquivPtr &equiv) {
183 if (equiv == nullptr || values_pattern.empty() || !utils::isa<AnfNodePtr>(values_pattern[0]) ||
184 !utils::isa<AnfNodePtr>(expr_ref)) {
185 return;
186 }
187 auto real_node = utils::cast<AnfNodePtr>(expr_ref);
188 MS_EXCEPTION_IF_NULL(real_node);
189 if (!real_node->isa<CNode>()) {
190 return;
191 }
192 auto prim_node = utils::cast<AnfNodePtr>(values_pattern[0]);
193 MS_EXCEPTION_IF_NULL(prim_node);
194 if (!IsValueNode<Primitive>(prim_node)) {
195 return;
196 }
197 ValuePtr value = GetValueNode(prim_node);
198 MS_EXCEPTION_IF_NULL(value);
199 auto prim = value->cast<PrimitivePtr>();
200 MS_EXCEPTION_IF_NULL(prim);
201 auto iter = primitive_vars.find(prim);
202 if (iter == primitive_vars.end()) {
203 return;
204 }
205 (*equiv)[iter->second] = real_node;
206 }
207
AlignSVar(const VectorRef & values_pattern,const VectorRef & values_expr,const PrimitiveVarMap & primitive_vars,EquivPtr equiv) const208 EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
209 const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const {
210 int svar_index = GetSVarStartIndex(values_pattern);
211 if (svar_index == kInvalidVarIndex) {
212 return nullptr;
213 }
214
215 size_t values_pattern_len = values_pattern.size();
216 size_t values_expr_len = values_expr.size();
217
218 if (svar_index == -1) {
219 if (values_pattern_len != values_expr_len) {
220 MS_LOG(DEBUG) << "Structures of differing size: pattern len " << values_pattern_len << ", expr len "
221 << values_expr_len;
222 return nullptr;
223 }
224 }
225 if (values_expr_len < values_pattern_len - 1) {
226 MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len;
227 return nullptr;
228 }
229 size_t diff = values_expr_len - values_pattern_len + 1;
230 for (size_t i = 0; i < values_pattern_len; i++) {
231 size_t expr_i = i;
232 if (svar_index != -1 && i == IntToSize(svar_index)) {
233 auto seq =
234 std::vector<BaseRef>(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff));
235 equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv);
236 } else {
237 if (svar_index != -1 && i > IntToSize(svar_index)) {
238 expr_i = i + diff - 1;
239 }
240 equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv);
241 }
242 if (equiv == nullptr) {
243 return nullptr;
244 }
245 }
246 return equiv;
247 }
248
Match(const BaseRef & pattern,const BaseRef & expr,const PrimitiveVarMap & primitive_vars,EquivPtr equiv) const249 EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
250 EquivPtr equiv) const {
251 MS_LOG(DEBUG) << "-----[in Match]";
252 MS_LOG(DEBUG) << "GetVar w";
253 BaseRef pattern_ref = GetVar(pattern);
254 MS_LOG(DEBUG) << "GetVar v";
255 BaseRef expr_ref = expr;
256
257 if (equiv == nullptr) {
258 MS_LOG(EXCEPTION) << "Equiv pointer is null";
259 }
260
261 MS_LOG(DEBUG) << "Pattern ref " + pattern_ref.ToString() + ", expr ref" + expr_ref.ToString();
262 // 1. if pattern_ref is var and already in equiv, replace it.
263 if (utils::isa<VarPtr>(pattern_ref)) {
264 VarPtr var = utils::cast<VarPtr>(pattern_ref);
265 auto iter = equiv->find(var);
266 if (iter != equiv->end()) {
267 pattern_ref = iter->second;
268 }
269 }
270
271 // 2. check equal
272 if (opt::AnfEqual(pattern_ref, expr_ref)) {
273 return equiv;
274 }
275
276 // 3. match var
277 EquivPtr ret_equiv = MatchOnVar(pattern_ref, expr_ref, equiv);
278 if (ret_equiv) {
279 return ret_equiv;
280 }
281
282 // 4. here the type can be std:vector, std:list,
283 // or cnode.
284 if (!PatternEngine::CNodeTypeEqual(pattern_ref, expr_ref)) {
285 MS_LOG(DEBUG) << "Type mismatch";
286 return nullptr;
287 }
288
289 // 5. transfer the Containers by visitor to std::vector
290 VectorRef values_pattern;
291 VectorRef values_expr;
292 if (!ToVector(pattern_ref, expr_ref, &values_pattern, &values_expr)) {
293 return nullptr;
294 }
295
296 // 6. if any svar in both side, find the SeqVar index,
297 // try to pack the Var s in std::vector to a Seq and match elements one by one.
298 // check svar
299 equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv);
300 UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv);
301 return equiv;
302 }
303
CNodeTypeEqual(const BaseRef & a,const BaseRef & b)304 bool PatternEngine::CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
305 // To matchCNode and Kernel's type
306 if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
307 return true;
308 }
309 return a.type() == b.type();
310 }
311 } // namespace mindspore
312