• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2020 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "backend/optimizer/common/pattern_engine.h"
20 #include "frontend/optimizer/opt.h"
21 #include "ir/anf.h"
22 #include "utils/convert_utils_base.h"
23 #include "utils/overload.h"
24 #include "backend/optimizer/common/helper.h"
25 
26 namespace mindspore {
// Returns the next value of a process-wide counter, starting at 0.
// The counter is a plain int with no synchronization, so concurrent callers may race.
static int GetNextTag() {
  static int kID = 0;
  const int issued = kID;
  ++kID;
  return issued;
}
31 
EnsureTag()32 void Var::EnsureTag() {
33   if (tag_.length() == 0) {
34     std::ostringstream buffer;
35     buffer << "_" << GetNextTag();
36     tag_ = buffer.str();
37   }
38 }
39 
operator ==(const VarPtr & lhs,const VarPtr & rhs)40 bool operator==(const VarPtr &lhs, const VarPtr &rhs) {
41   if (lhs->isa<CondVar>() && rhs->isa<CondVar>()) {
42     CondVarPtr v1 = dyn_cast<CondVar>(lhs);
43     CondVarPtr v2 = dyn_cast<CondVar>(rhs);
44     return *v1 == *v2;
45   }
46 
47   if (lhs->isa<SeqVar>() && rhs->isa<SeqVar>()) {
48     SVarPtr v1 = dyn_cast<SeqVar>(lhs);
49     SVarPtr v2 = dyn_cast<SeqVar>(rhs);
50     return *v1 == *v2;
51   }
52   return (*lhs == *rhs);
53 }
54 
ToString() const55 std::string SeqVar::ToString() const {
56   std::ostringstream buffer;
57   buffer << "SeqVar(" << tag() << ", " << subvar_->ToString() << ")";
58   return buffer.str();
59 }
60 
operator <<(std::ostream & os,const VarPtr & var)61 std::ostream &operator<<(std::ostream &os, const VarPtr &var) {
62   if (var == nullptr) {
63     os << "";
64   } else {
65     os << var->ToString();
66   }
67   return os;
68 }
69 
70 template <>
71 std::ostream &operator<<<VarPtr, BaseRef>(std::ostream &os, const Equiv &equiv) {
72   os << "[Equiv]"
73      << "\n";
74   for (auto &equiv_item : equiv) {
75     auto k = equiv_item.first;
76     os << k << ":";
77     BaseRef x = equiv_item.second;
78     if (utils::isa<AnfNodePtr>(x)) {
79       auto node = utils::cast<AnfNodePtr>(x);
80       os << "TypeString[" << node->type_name() << "]";
81       if (IsValueNode<FuncGraph>(node)) {
82         os << "IsValueNodeGraph ";
83       }
84       os << "type " << node->type_name();
85       if (node->isa<ValueNode>()) {
86         os << " value " << GetValueNode(node);
87       }
88       os << " addr: " << node;
89     } else if (utils::isa<Named>(x)) {
90       os << "Named " << x.ToString().c_str();
91     } else if (utils::isa<VarPtr>(x)) {
92       os << "TypeString[Var]";
93       os << (utils::cast<VarPtr>(x));
94     } else if (utils::isa<FuncGraphPtr>(x)) {
95       os << "TypeString[Graph]";
96     }
97     os << "\n";
98   }
99   return os;
100 }
101 
// Resolves a pattern element before matching.
// If `x` holds a VarNode, the Var it wraps is returned; otherwise `x` itself is returned.
// All other branches only emit DEBUG diagnostics describing what kind of element `x` is.
static BaseRef GetVar(const BaseRef &x) {
  // Stream-style logging: the previous "%s" placeholder was printed literally.
  MS_LOG(DEBUG) << "GetVar start: " << x.ToString();
  if (utils::isa<AnfNodePtr>(x)) {
    auto node = utils::cast<AnfNodePtr>(x);
    MS_LOG(DEBUG) << "TypeString [" << node->type_name() << "]";
    if (node->isa<VarNode>()) {
      // Cast once and reuse the result for both the log line and the return value.
      auto var_node = node->cast<VarNodePtr>();
      MS_LOG(DEBUG) << "IsVarNode " << var_node->var_->ToString();
      return var_node->var_;
    }
    if (node->isa<ValueNode>()) {
      MS_LOG(DEBUG) << "value " << GetValueNode(node)->ToString() << " addr: " << node->ToString();
    } else {
      MS_LOG(DEBUG) << "type " << node->type_name();
    }
  } else if (utils::isa<Named>(x)) {
    MS_LOG(DEBUG) << "Named " << x.ToString();
  } else if (utils::isa<VectorRef>(x)) {
    MS_LOG(DEBUG) << "VectorRef";
  } else if (utils::isa<VarPtr>(x)) {
    MS_LOG(DEBUG) << "TypeString[Var] " << x.ToString();
  }
  MS_LOG(DEBUG) << "GetVar end: " << x.ToString();
  return x;
}
126 
MatchOnVar(const BaseRef & pattern,const BaseRef & expr,EquivPtr equiv)127 EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) {
128   MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString();
129   MS_EXCEPTION_IF_NULL(equiv);
130   if (utils::isa<VarPtr>(pattern)) {
131     VarPtr var = utils::cast<VarPtr>(pattern);
132     if (var->matches(expr)) {
133       (*equiv)[var] = expr;
134       MS_LOG(DEBUG) << "pattern is var match: " + pattern.ToString() + ", " + expr.ToString();
135       return equiv;
136     }
137   }
138 
139   return nullptr;
140 }
141 
ToVector(const VectorRef & pattern_ref,const VectorRef & expr_ref,VectorRef * const values_pattern,VectorRef * const values_expr) const142 bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
143                              VectorRef *const values_expr) const {
144   MS_EXCEPTION_IF_NULL(values_expr);
145   if (utils::isa<SeqPtr>(pattern_ref)) {
146     *values_pattern = pattern_ref;
147     *values_expr = expr_ref;
148     return true;
149   }
150   return false;
151 }
152 
ToVector(const BaseRef & pattern_ref,const BaseRef & expr_ref,VectorRef * const values_pattern,VectorRef * const values_expr) const153 bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern,
154                              VectorRef *const values_expr) const {
155   MS_EXCEPTION_IF_NULL(values_expr);
156   MS_LOG(DEBUG) << "visit pattern_ref";
157   bool success = visitor_->Visit(pattern_ref, values_pattern, nullptr);
158   if (!success) {
159     return false;
160   }
161   MS_LOG(DEBUG) << "visit expr_ref";
162   return visitor_->Visit(expr_ref, values_expr, nullptr);
163 }
164 
GetSVarStartIndex(const VectorRef & values)165 static int GetSVarStartIndex(const VectorRef &values) {
166   int index = -1;
167   int count = 0;
168   for (auto &value : values) {
169     if (utils::isa<VarPtr>(value) && utils::cast<VarPtr>(value)->isa<SeqVar>()) {
170       if (index != -1) {
171         MS_LOG(DEBUG) << "Multiple SVars in sequence";
172         return kInvalidVarIndex;
173       }
174       index = count;
175     }
176     count++;
177   }
178   return index;
179 }
180 
UpdateEquivMap(const VectorRef & values_pattern,const BaseRef & expr_ref,const PrimitiveVarMap & primitive_vars,const EquivPtr & equiv)181 void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars,
182                     const EquivPtr &equiv) {
183   if (equiv == nullptr || values_pattern.empty() || !utils::isa<AnfNodePtr>(values_pattern[0]) ||
184       !utils::isa<AnfNodePtr>(expr_ref)) {
185     return;
186   }
187   auto real_node = utils::cast<AnfNodePtr>(expr_ref);
188   MS_EXCEPTION_IF_NULL(real_node);
189   if (!real_node->isa<CNode>()) {
190     return;
191   }
192   auto prim_node = utils::cast<AnfNodePtr>(values_pattern[0]);
193   MS_EXCEPTION_IF_NULL(prim_node);
194   if (!IsValueNode<Primitive>(prim_node)) {
195     return;
196   }
197   ValuePtr value = GetValueNode(prim_node);
198   MS_EXCEPTION_IF_NULL(value);
199   auto prim = value->cast<PrimitivePtr>();
200   MS_EXCEPTION_IF_NULL(prim);
201   auto iter = primitive_vars.find(prim);
202   if (iter == primitive_vars.end()) {
203     return;
204   }
205   (*equiv)[iter->second] = real_node;
206 }
207 
AlignSVar(const VectorRef & values_pattern,const VectorRef & values_expr,const PrimitiveVarMap & primitive_vars,EquivPtr equiv) const208 EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
209                                   const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const {
210   int svar_index = GetSVarStartIndex(values_pattern);
211   if (svar_index == kInvalidVarIndex) {
212     return nullptr;
213   }
214 
215   size_t values_pattern_len = values_pattern.size();
216   size_t values_expr_len = values_expr.size();
217 
218   if (svar_index == -1) {
219     if (values_pattern_len != values_expr_len) {
220       MS_LOG(DEBUG) << "Structures of differing size: pattern len " << values_pattern_len << ", expr len "
221                     << values_expr_len;
222       return nullptr;
223     }
224   }
225   if (values_expr_len < values_pattern_len - 1) {
226     MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len;
227     return nullptr;
228   }
229   size_t diff = values_expr_len - values_pattern_len + 1;
230   for (size_t i = 0; i < values_pattern_len; i++) {
231     size_t expr_i = i;
232     if (svar_index != -1 && i == IntToSize(svar_index)) {
233       auto seq =
234         std::vector<BaseRef>(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff));
235       equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv);
236     } else {
237       if (svar_index != -1 && i > IntToSize(svar_index)) {
238         expr_i = i + diff - 1;
239       }
240       equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv);
241     }
242     if (equiv == nullptr) {
243       return nullptr;
244     }
245   }
246   return equiv;
247 }
248 
Match(const BaseRef & pattern,const BaseRef & expr,const PrimitiveVarMap & primitive_vars,EquivPtr equiv) const249 EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
250                               EquivPtr equiv) const {
251   MS_LOG(DEBUG) << "-----[in Match]";
252   MS_LOG(DEBUG) << "GetVar w";
253   BaseRef pattern_ref = GetVar(pattern);
254   MS_LOG(DEBUG) << "GetVar v";
255   BaseRef expr_ref = expr;
256 
257   if (equiv == nullptr) {
258     MS_LOG(EXCEPTION) << "Equiv pointer is null";
259   }
260 
261   MS_LOG(DEBUG) << "Pattern ref " + pattern_ref.ToString() + ", expr ref" + expr_ref.ToString();
262   // 1. if pattern_ref is var and already in equiv, replace it.
263   if (utils::isa<VarPtr>(pattern_ref)) {
264     VarPtr var = utils::cast<VarPtr>(pattern_ref);
265     auto iter = equiv->find(var);
266     if (iter != equiv->end()) {
267       pattern_ref = iter->second;
268     }
269   }
270 
271   // 2. check equal
272   if (opt::AnfEqual(pattern_ref, expr_ref)) {
273     return equiv;
274   }
275 
276   // 3. match var
277   EquivPtr ret_equiv = MatchOnVar(pattern_ref, expr_ref, equiv);
278   if (ret_equiv) {
279     return ret_equiv;
280   }
281 
282   // 4. here the type can be std:vector, std:list,
283   // or cnode.
284   if (!PatternEngine::CNodeTypeEqual(pattern_ref, expr_ref)) {
285     MS_LOG(DEBUG) << "Type mismatch";
286     return nullptr;
287   }
288 
289   // 5. transfer the Containers by visitor to std::vector
290   VectorRef values_pattern;
291   VectorRef values_expr;
292   if (!ToVector(pattern_ref, expr_ref, &values_pattern, &values_expr)) {
293     return nullptr;
294   }
295 
296   // 6. if any svar in both side, find the SeqVar index,
297   // try to pack the Var s in std::vector to a Seq and match elements one by one.
298   // check svar
299   equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv);
300   UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv);
301   return equiv;
302 }
303 
CNodeTypeEqual(const BaseRef & a,const BaseRef & b)304 bool PatternEngine::CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
305   // To matchCNode and Kernel's type
306   if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
307     return true;
308   }
309   return a.type() == b.type();
310 }
311 }  // namespace mindspore
312