• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_PATTERN_MATCHER_H_
18 #define MINDSPORE_CORE_IR_PATTERN_MATCHER_H_
19 
20 #include <functional>
21 #include <memory>
22 #include <tuple>
23 #include <vector>
24 #include <algorithm>
25 
26 #include "base/core_ops.h"
27 #include "ir/visitor.h"
28 #include "utils/shape_utils.h"
29 
30 namespace mindspore {
31 ///
32 ///  Base class for all recognizable patterns.
33 ///  We implement an Expression Template approach using static polymorphism based on
34 ///  the Curiously Recurring Template Pattern (CRTP) which "achieves a similar effect
35 ///  to the use of virtual functions without the costs..." as described in:
36 ///  https://en.wikipedia.org/wiki/Expression_templates and
37 ///  https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
38 ///  The TryCapture function tries to capture the pattern with the given node.
39 ///  The GetNode function builds a new node using the captured values.
40 ///
41 
42 template <typename T>
43 class PBase {
44  public:
CheckFunc(const PredicateFuncType & func,const AnfNodePtr & node)45   bool CheckFunc(const PredicateFuncType &func, const AnfNodePtr &node) { return func(get_object().GetNode(node)); }
46 
get_object()47   const T &get_object() const { return *static_cast<const T *>(this); }
48 
49   template <typename TN>
TryCapture(const TN & value)50   bool TryCapture(const TN &value) const {
51     get_object().Reset();
52     return get_object().TryCapture_(value);
53   }
54 
55   using Internal = T;
56 };
57 
/// Default equality predicate used by PatternNode to compare a newly offered value
/// with the one it already captured.
template <typename T>
class PIsEqual {
 public:
  bool operator()(const T &lhs, const T &rhs) const {
    const bool same = (lhs == rhs);
    return same;
  }
};
63 
64 template <typename T = AnfNodePtr>
65 class PatternNode : public PBase<PatternNode<T> > {
66  public:
GetNode(const AnfNodePtr &)67   T GetNode(const AnfNodePtr &) const {
68     if (!captured_) {
69       MS_EXCEPTION(ValueError) << "A Pattern wasn't captured for this Token before the call to GetNode.";
70     }
71     return captured_node_;
72   }
73 
TryCapture_(const T & node)74   bool TryCapture_(const T &node) const {
75     if (!captured_) {
76       captured_node_ = node;
77       captured_ = true;
78       return true;
79     }
80     return PIsEqual<T>()(captured_node_, node);
81   }
82 
Reset()83   void Reset() const { captured_ = false; }
84   using Internal = const PatternNode<T> &;
85 
86  protected:
87   mutable T captured_node_;
88   mutable bool captured_{false};
89 };
90 
91 template <typename T, typename T2>
92 class PBinOperation : public PBase<PBinOperation<T, T2> > {
93  public:
94   PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y, bool is_commutative = false)
prim_(prim)95       : prim_(prim), x_(x), y_(y), is_commutative_(is_commutative) {}
96   ~PBinOperation() = default;
97 
GetNode(const AnfNodePtr & node)98   AnfNodePtr GetNode(const AnfNodePtr &node) const {
99     AnfNodePtr lhs = x_.GetNode(node);
100     AnfNodePtr rhs = y_.GetNode(node);
101     AnfNodePtrList list = {NewValueNode(prim_), lhs, rhs};
102     return NewCNode(list, node->func_graph());
103   }
104 
TryCapture_(const AnfNodePtr & node)105   bool TryCapture_(const AnfNodePtr &node) const {
106     if (IsPrimitiveCNode(node, prim_)) {
107       auto cnode = node->cast<CNodePtr>();
108       auto inputs = cnode->inputs();
109       if (inputs.size() == 3) {
110         // Binary Prim assumes only two inputs
111         if (!x_.TryCapture(inputs[1]) || !y_.TryCapture(inputs[2])) {
112           // If the operation is commutative, then check with inversed operands
113           if (is_commutative_) {
114             Reset();
115             if (!x_.TryCapture(inputs[2]) || !y_.TryCapture(inputs[1])) {
116               return false;
117             }
118             captured_binop_node_ = node;
119             return true;
120           }
121           return false;
122         }
123         captured_binop_node_ = node;
124         return true;
125       }
126     }
127     return false;
128   }
129 
130   /// Returns the original node captured by this Binary Operation Pattern.
131   /// Throws exception if a node was not captured before.
GetOriginalNode()132   AnfNodePtr GetOriginalNode() const {
133     if (captured_binop_node_ == nullptr) {
134       MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it.";
135     }
136 
137     return captured_binop_node_;
138   }
139 
Reset()140   void Reset() const {
141     x_.Reset();
142     y_.Reset();
143     captured_binop_node_ = nullptr;
144   }
145 
146   using Internal = const PBinOperation<T, T2> &;
147 
148  private:
149   const PrimitivePtr prim_;
150   typename T::Internal x_;
151   typename T2::Internal y_;
152   bool is_commutative_{false};
153   mutable AnfNodePtr captured_binop_node_{nullptr};
154 };
155 
156 template <typename T>
157 class PUnaryOperation : public PBase<PUnaryOperation<T> > {
158  public:
PUnaryOperation(const PrimitivePtr & prim,const T & x)159   PUnaryOperation(const PrimitivePtr &prim, const T &x) : prim_(prim), x_(x) {}
160   ~PUnaryOperation() = default;
161 
GetNode(const AnfNodePtr & node)162   AnfNodePtr GetNode(const AnfNodePtr &node) const {
163     AnfNodePtrList list = {NewValueNode(prim_), x_.GetNode(node)};
164     return NewCNode(list, node->func_graph());
165   }
166 
TryCapture_(const AnfNodePtr & node)167   bool TryCapture_(const AnfNodePtr &node) const {
168     if (IsPrimitiveCNode(node, prim_)) {
169       auto cnode = node->cast<CNodePtr>();
170       auto inputs = cnode->inputs();
171       if (inputs.size() == 2 && x_.TryCapture(inputs[1])) {
172         captured_unaryop_node_ = node;
173         return true;
174       }
175     }
176     return false;
177   }
178 
GetOriginalNode()179   AnfNodePtr GetOriginalNode() const {
180     if (captured_unaryop_node_ == nullptr) {
181       MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it.";
182     }
183     return captured_unaryop_node_;
184   }
185 
Reset()186   void Reset() const {
187     x_.Reset();
188     captured_unaryop_node_ = nullptr;
189   }
190 
191   using Internal = const PUnaryOperation<T> &;
192 
193  private:
194   const PrimitivePtr prim_;
195   typename T::Internal x_;
196   mutable AnfNodePtr captured_unaryop_node_{nullptr};
197 };
198 
199 ///
200 /// Helper functions to apply a pattern function on all elements of a tuple
201 ///
202 namespace tuple_utils {
203 template <bool stop, size_t Index, typename Func>
204 struct apply_func_tuple_item {
205   template <typename TTuple>
applyapply_func_tuple_item206   static void apply(Func *func, const TTuple &tuple) {
207     (*func)(Index, std::get<Index>(tuple));
208     apply_func_tuple_item<(Index + 1) == std::tuple_size<TTuple>::value, (Index + 1), Func>::apply(func, tuple);
209   }
210 };
211 
212 template <size_t Index, typename Func>
213 struct apply_func_tuple_item<true, Index, Func> {
214   template <typename TTuple>
215   static void apply(Func *func, const TTuple &tuple) {}
216 };
217 
218 template <typename Func, typename TTuple>
219 inline void apply_func_tuple(Func *func, const TTuple &tuple) {
220   apply_func_tuple_item<std::tuple_size<TTuple>::value == 0, 0, Func>::apply(func, tuple);
221 }
222 
223 struct PTupleResetCapture {
224   template <typename T>
225   void operator()(size_t i, const T &pattern) const {
226     pattern.Reset();
227   }
228 };
229 
230 struct PTupleCapture {
231   explicit PTupleCapture(const AnfNodePtrList tuple) : tuple_(tuple) {}
232 
233   template <typename TPattern>
234   void operator()(size_t i, const TPattern &pattern) {
235     // Check if the first node is a Primitive
236     if (i == 0 && tuple_[i]->isa<Primitive>()) {
237       auto prim = tuple_[i]->cast<PrimitivePtr>();
238       if (tuple_[i] != pattern.GetNode(tuple_[i])) {
239         captured_ = false;
240       }
241     } else {
242       captured_ = captured_ && pattern.TryCapture_(tuple_[i]);
243     }
244   }
245 
246   const AnfNodePtrList tuple_;
247   bool captured_{true};
248 };
249 
250 struct PTupleGetNode {
251   explicit PTupleGetNode(const AnfNodePtr &node) : node_(node) {}
252 
253   template <typename TPattern>
254   void operator()(size_t, const TPattern &pattern) {
255     args_.push_back(pattern.GetNode(node_));
256   }
257 
258   const AnfNodePtr &node_;
259   std::vector<AnfNodePtr> args_;
260 };
261 }  // namespace tuple_utils
262 
263 template <typename... TArgs>
264 class PCNode : public PBase<PCNode<TArgs...> > {
265  public:
266   explicit PCNode(const TArgs &... args) : args_(args...) {}
267   ~PCNode() = default;
268 
269   AnfNodePtr GetNode(const AnfNodePtr &node) const {
270     tuple_utils::PTupleGetNode get_node(node);
271     tuple_utils::apply_func_tuple(&get_node, args_);
272     auto prim_cnode = get_node.args_;
273     // In case this PCNode has captured extra nodes
274     if (extra_nodes_.size() > 0) {
275       prim_cnode.insert(prim_cnode.begin(), extra_nodes_.begin(), extra_nodes_.end());
276     }
277     return NewCNode(prim_cnode, node->func_graph());
278   }
279 
280   bool TryCapture_(const AnfNodePtr &node) const {
281     if (node->isa<CNode>()) {
282       auto cnode = node->cast<CNodePtr>();
283       auto inputs = cnode->inputs();
284 
285       auto pattern_arg_len = sizeof...(TArgs);
286       // There aren't enough inputs in Node to fill up the Pattern
287       if (inputs.size() < pattern_arg_len) {
288         return false;
289       }
290 
291       // Pattern must exactly match the number of Node inputs.
292       if (!has_min_extra_nodes_) {
293         // Inputs in Node perfectly match number of tokens in Pattern.
294         if (inputs.size() == pattern_arg_len) {
295           AnfNodePtrList tokens(inputs.begin(), inputs.end());
296           tuple_utils::PTupleCapture capture_func(tokens);
297           tuple_utils::apply_func_tuple(&capture_func, args_);
298           return capture_func.captured_;
299         }
300         return false;
301       }
302 
303       // Pattern may accept extra (non specified) nodes at the end of the CNode
304       // There must be at least `min_extra_nodes` additional nodes in the inputs.
305       if (inputs.size() >= pattern_arg_len + min_extra_nodes_) {
306         AnfNodePtrList tokens(inputs.begin(), inputs.begin() + pattern_arg_len);
307         tuple_utils::PTupleCapture capture_func(tokens);
308         tuple_utils::apply_func_tuple(&capture_func, args_);
309         // If it could capture the initial set of nodes specified in the Pattern
310         // and there are enough extra inputs to add
311         if (capture_func.captured_ && inputs.size() > pattern_arg_len) {
312           extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + pattern_arg_len, inputs.end());
313           return true;
314         }
315         return capture_func.captured_;
316       }
317       return false;
318     }
319     return false;
320   }
321 
322   /// This function sets the PCNode object to capture at least `min_extra_nodes_` nodes after the last one
323   /// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or
324   /// more nodes after the last one specified when building the PCNode.
325   const PCNode<TArgs...> &MinExtraNodes(const size_t &min_extra_nodes = 0) const {
326     has_min_extra_nodes_ = true;
327     min_extra_nodes_ = min_extra_nodes;
328     return *this;
329   }
330 
331   using Internal = const PCNode<TArgs...> &;
332 
333   void Reset() const {
334     tuple_utils::PTupleResetCapture reset;
335     tuple_utils::apply_func_tuple(&reset, args_);
336     extra_nodes_.clear();
337   }
338 
339  private:
340   std::tuple<typename TArgs::Internal...> args_;
341   mutable AnfNodePtrList extra_nodes_;
342   mutable bool has_min_extra_nodes_{false};
343   mutable size_t min_extra_nodes_{0};
344 };
345 
346 template <typename... TArgs>
347 class PPrimitive : public PBase<PPrimitive<TArgs...> > {
348  public:
349   explicit PPrimitive(const PrimitivePtr &prim, const TArgs &... args) : prim_(prim), args_(args...) {}
350   ~PPrimitive() = default;
351 
352   AnfNodePtr GetNode(const AnfNodePtr &node) const {
353     tuple_utils::PTupleGetNode get_node(node);
354     tuple_utils::apply_func_tuple(&get_node, args_);
355     auto prim_cnode = get_node.args_;
356     prim_cnode.insert(prim_cnode.begin(), NewValueNode(prim_));
357 
358     // In case this PPrimitive has captured extra nodes
359     if (extra_nodes_.size() > 0) {
360       prim_cnode.insert(prim_cnode.begin(), extra_nodes_.begin(), extra_nodes_.end());
361     }
362     return NewCNode(prim_cnode, node->func_graph());
363   }
364 
365   bool TryCapture_(const AnfNodePtr &node) const {
366     if (IsPrimitiveCNode(node, prim_)) {
367       auto cnode = node->cast<CNodePtr>();
368       auto inputs = cnode->inputs();
369       // Number of arguments in Primitive Pattern (not including the Primitive node)
370       auto pattern_arg_len = sizeof...(TArgs);
371       // There aren't enough inputs in Node to fill up the Pattern
372       if ((inputs.size() - 1) < pattern_arg_len) {
373         return false;
374       }
375 
376       // Pattern must exactly match the number of Node inputs.
377       if (!has_min_extra_nodes_) {
378         // Inputs in Node perfectly match number of tokens in Pattern.
379         if ((inputs.size() - 1) == pattern_arg_len) {
380           AnfNodePtrList tokens(inputs.begin() + 1, inputs.end());
381           tuple_utils::PTupleCapture capture_func(tokens);
382           tuple_utils::apply_func_tuple(&capture_func, args_);
383           if (capture_func.captured_) {
384             captured_prim_node_ = node;
385           }
386           return capture_func.captured_;
387         }
388         return false;
389       }
390 
391       // Pattern may accept extra (non specified) nodes at the end of the Primitive
392       // There must be at least `min_extra_nodes` additional nodes in the inputs.
393       if ((inputs.size() - 1) >= pattern_arg_len + min_extra_nodes_) {
394         AnfNodePtrList tokens(inputs.begin() + 1, inputs.begin() + 1 + pattern_arg_len);
395         tuple_utils::PTupleCapture capture_func(tokens);
396         tuple_utils::apply_func_tuple(&capture_func, args_);
397         // If it could capture the initial set of nodes specified in the Pattern
398         // and there are enough extra inputs to add
399         if (capture_func.captured_) {
400           captured_prim_node_ = node;
401           if (inputs.size() > pattern_arg_len + 1) {
402             extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end());
403           }
404         }
405         return capture_func.captured_;
406       }
407       return false;
408     }
409     return false;
410   }
411 
412   /// This function sets the PPrimitive object to capture at least `min_extra_nodes_` nodes after the last one
413   /// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or
414   /// more nodes after the last one specified when building the PPrimitive.
415   const PPrimitive<TArgs...> &MinExtraNodes(const size_t &min_extra_nodes = 0) const {
416     has_min_extra_nodes_ = true;
417     min_extra_nodes_ = min_extra_nodes;
418     return *this;
419   }
420 
421   const AnfNodePtrList &GetCapturedExtraNodes() const { return extra_nodes_; }
422 
423   /// Returns the FuncGraph of the original node captured by this Primitive Pattern.
424   /// Throws exception if a node was not captured before.
425   FuncGraphPtr GetFuncGraph() const {
426     if (captured_prim_node_ == nullptr) {
427       MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get its FuncGraph.";
428     }
429 
430     return captured_prim_node_->func_graph();
431   }
432 
433   /// Returns the original node captured by this Primitive Pattern.
434   /// Throws exception if a node was not captured before.
435   AnfNodePtr GetOriginalNode() const {
436     if (captured_prim_node_ == nullptr) {
437       MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it.";
438     }
439 
440     return captured_prim_node_;
441   }
442 
443   void Reset() const {
444     tuple_utils::PTupleResetCapture reset;
445     tuple_utils::apply_func_tuple(&reset, args_);
446     extra_nodes_.clear();
447     captured_prim_node_ = nullptr;
448   }
449 
450   using Internal = const PPrimitive<TArgs...> &;
451 
452  private:
453   const PrimitivePtr prim_;
454   std::tuple<typename TArgs::Internal...> args_;
455   mutable AnfNodePtrList extra_nodes_;
456   mutable bool has_min_extra_nodes_{false};
457   mutable size_t min_extra_nodes_{0};
458   mutable AnfNodePtr captured_prim_node_{nullptr};
459 };
460 
461 ///
462 /// PConstant class can capture a value node of a specified value (check_value_)
463 /// or a non-specified one (any_value = true).
464 /// It can be configured to capture a scalar constant as well (is_scalar_ = true)
465 ///
466 template <typename T = AnfNodePtr>
467 class PConstant : public PBase<PConstant<T> > {
468  public:
469   explicit PConstant(const AnfNodePtr &as_node, const bool any_value = true, const int64_t check_value = 0,
470                      const bool is_scalar = false)
471       : as_node_(as_node),
472         captured_node_(as_node),
473         any_value_(any_value),
474         check_value_(check_value),
475         is_scalar_(is_scalar) {}
476 
477   ~PConstant() = default;
478   // Sets as_node_ as the node received as argument to produce a same-shape node with GetNode
479   const PConstant<T> &WithShapeAs(const AnfNodePtr &node) const {
480     if (node == nullptr) {
481       MS_EXCEPTION(ValueError) << "WithShapeAs is trying to use a nullptr node.";
482     }
483     as_node_ = node;
484     changed_shape_ = true;
485     return *this;
486   }
487 
488   // Sets as_node_ as the node caputred by the received Pattern token to produce a same-shape node with GetNode
489   const PConstant<T> &WithShapeAs(const PatternNode<T> &pnode) const {
490     if (captured_node_ == nullptr) {
491       MS_EXCEPTION(ValueError) << "WithShapeAs is trying to use a Pattern token without previously capturing a node.";
492     }
493     as_node_ = pnode.GetNode(captured_node_);
494     changed_shape_ = true;
495     return *this;
496   }
497 
498   /// Sets captured_node_ as the node captured by the Pattern received as argument
499   /// to produce a new node with its contents when calling GetNode.
500   const PConstant<T> &WithValueOf(const PatternNode<T> &pnode) const {
501     if (!any_value_) {
502       MS_EXCEPTION(ValueError) << "Must use a PConstant with `any_value = true` to use the value of another node.";
503     }
504     if (captured_node_ == nullptr) {
505       MS_EXCEPTION(ValueError) << "WithValueOf is trying to use a Pattern token without previously capturing a node.";
506     }
507     captured_node_ = pnode.GetNode(captured_node_);
508     changed_shape_ = true;
509     return *this;
510   }
511 
512   /// Create a new Value Node filled up with check_value.
513   /// This function must be used immediately before GetNode to avoid replacing the expected result.
514   /// Only valid for scalar constants. For tensors use WithShapeAs or WithValueOf.
515   const PConstant<T> &NewValue() const {
516     if (!is_scalar_) {
517       MS_EXCEPTION(ValueError) << "NewValue is valid only for scalar PConstants.";
518     }
519     auto value_node_ = MakeValue(check_value_);
520     captured_node_ = NewValueNode(value_node_);
521     is_new_value_node_ = true;
522     return *this;
523   }
524 
525   AnfNodePtr GetNode(const AnfNodePtr &node) const {
526     // If a NewValueNode was requested (using NewValue function) then return that created node.
527     if (is_new_value_node_) {
528       return captured_node_;
529     }
530     /// Return a NewTensorFilledWithData if the node was initialized to have a specific value
531     /// even if it wasn't captured. Usually for zero constants (x - x => zero).
532     /// If the shape was changed, use the new shape.
533     if (changed_shape_ || !captured_) {
534       if (!any_value_) {
535         return NewTensorFilledWithData(as_node_, check_value_);
536       }
537       return NewTensorFilledWithData(as_node_, captured_node_);
538     }
539     return captured_node_;
540   }
541 
542   bool TryCapture_(const AnfNodePtr &node) const {
543     if (node->isa<ValueNode>()) {
544       // If any_value_ is set don't check for the node's value. Just capture it.
545       if (any_value_) {
546         captured_node_ = node;
547         captured_ = true;
548         return true;
549       }
550 
551       auto value = node->cast<ValueNodePtr>()->value();
552       if ((is_scalar_ && IsTensorScalarConstant(value)) || (!is_scalar_ && IsTensorConstant(value))) {
553         captured_node_ = node;
554         captured_ = true;
555         return true;
556       }
557 
558       auto value_node_ = MakeValue(check_value_);
559       if (*GetValueNode(node) == *value_node_) {
560         captured_node_ = node;
561         captured_ = true;
562         return true;
563       }
564     }
565     return false;
566   }
567 
568   void Reset() const {
569     captured_ = false;
570     changed_shape_ = false;
571     is_new_value_node_ = false;
572   }
573 
574   // Support function used for checking if all values of a Tensor are equal to `check_value_`
575   // Supported data types: double, float/float32, int/int32
576   bool IsTensorConstant(const ValuePtr &value) const {
577     if (!value->isa<tensor::Tensor>()) {
578       return false;
579     }
580     auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
581     TypeId tensor_type = tensor_ptr->Dtype()->type_id();
582     if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
583       float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c());
584       auto threshold = FLT_MIN;
585       for (int i = 0; i < tensor_ptr->DataSize(); i++) {
586         if (fabs(data2[i] - check_value_) > threshold) {
587           return false;
588         }
589       }
590       return true;
591     } else if (tensor_type == TypeId::kNumberTypeFloat64) {
592       double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c());
593       auto threshold = DBL_MIN;
594       for (int i = 0; i < tensor_ptr->DataSize(); i++) {
595         if (fabs(data2[i] - check_value_) > threshold) {
596           return false;
597         }
598       }
599       return true;
600     } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) {
601       int *data2 = reinterpret_cast<int *>(tensor_ptr->data_c());
602       for (int i = 0; i < tensor_ptr->DataSize(); i++) {
603         if (data2[i] != check_value_) {
604           return false;
605         }
606       }
607       return true;
608     }
609     // Input Data Type is not supported
610     return false;
611   }
612 
613   bool IsTensorScalarConstant(const ValuePtr &value) const {
614     if (!value->isa<tensor::Tensor>()) {
615       return false;
616     }
617     auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
618     if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) {
619       return false;
620     }
621     return IsTensorConstant(value);
622   }
623 
624   void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false) const {
625     if (!node->isa<ValueNode>()) {
626       return nullptr;
627     }
628     auto value = node->cast<ValueNodePtr>()->value();
629     if (!value->isa<tensor::Tensor>()) {
630       return nullptr;
631     }
632 
633     tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value);
634     return tensor_ptr->data_c();
635   }
636 
637   // Make a new tensor (when possible) with the same shape as of `node`
638   // If x is nullptr then fill new tensor will "0"
639   // If x is a tensor with empty shape then fill new tensor with the single value of x
640   // If x is a tensor with same shape as `node` then return x as result
641   AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr) const {
642     if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) {
643       return nullptr;
644     }
645 
646     auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
647     TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
648     ShapeVector tensor_shape = tensor_abstract->shape()->shape();
649 
650     auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
651     size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
652     char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
653 
654     if (x == nullptr) {
655       if (memset_s(data, mem_size, 0, mem_size) != 0) {
656         return nullptr;
657       }
658       auto new_vnode = NewValueNode(new_tensor_ptr);
659       new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
660       return new_vnode;
661     }
662     // x is not nullptr
663     if (x->isa<CNode>() || x->isa<Parameter>()) {
664       if ((x->abstract() == nullptr) || !x->abstract()->isa<abstract::AbstractTensor>()) {
665         return nullptr;
666       }
667       auto x_abstract = x->abstract()->cast<abstract::AbstractTensorPtr>();
668       ShapeVector x_shape = x_abstract->shape()->shape();
669       if (x_shape != tensor_shape) {
670         return nullptr;
671       }
672       return x;
673     }
674 
675     if (!x->isa<ValueNode>()) {
676       return nullptr;
677     }
678     auto x_value = x->cast<ValueNodePtr>()->value();
679     if (!x_value->isa<tensor::Tensor>()) {
680       return nullptr;
681     }
682 
683     auto x_tensor_ptr = dyn_cast<tensor::Tensor>(x_value);
684     if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) {
685       return nullptr;
686     }
687     int ret = 0;
688     char *source_data = reinterpret_cast<char *>(GetPointerToTensorData(x));
689     MS_EXCEPTION_IF_NULL(source_data);
690     if (x_tensor_ptr->DataSize() == 1) {
691       auto tensor_type_byte = GetTypeByte(tensor_type_ptr);
692       for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) {
693         ret = memcpy_s(data + i * tensor_type_byte, tensor_type_byte, source_data, tensor_type_byte);
694         if (ret != 0) {
695           MS_LOG(INFO) << "memcpy_s error, error no " << ret << ", source size " << tensor_type_byte << ", dest size "
696                        << tensor_type_byte;
697         }
698       }
699     } else {
700       ret = memcpy_s(data, mem_size, source_data, mem_size);
701       if (ret != 0) {
702         MS_LOG(INFO) << "memcpy_s error, error no " << ret << ", source size " << mem_size << ", dest size "
703                      << new_tensor_ptr->DataSize();
704         return nullptr;
705       }
706     }
707     auto new_vnode = NewValueNode(new_tensor_ptr);
708     new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
709     return new_vnode;
710   }
711 
712   AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const int &value) const {
713     if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) {
714       return nullptr;
715     }
716 
717     auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
718     TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
719     ShapeVector tensor_shape = tensor_abstract->shape()->shape();
720 
721     auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
722     size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
723     char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
724 
725     if (memset_s(data, mem_size, value, mem_size) != 0) {
726       return nullptr;
727     }
728     auto new_vnode = NewValueNode(new_tensor_ptr);
729     new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
730     return new_vnode;
731   }
732 
733   template <typename TD>
734   TD CalcuConstant(const TD &data, const PrimitivePtr &calcu_type) {
735     TD tmp_data = data;
736     if (calcu_type == prim::kPrimReciprocal) {
737       if (data == 0) {
738         MS_EXCEPTION(ValueError);
739       } else {
740         tmp_data = 1 / data;
741       }
742     }
743     if (calcu_type == prim::kPrimNeg) {
744       tmp_data = -data;
745     }
746     return tmp_data;
747   }
748 
749   template <typename TD>
750   bool TensorCopyData(const tensor::TensorPtr &src_tensor_ptr, const tensor::TensorPtr &dst_tensor_ptr,
751                       const PrimitivePtr &calcu_type, size_t mem_size) {
752     auto *data = reinterpret_cast<TD *>(src_tensor_ptr->data_c());
753     auto *data2 = reinterpret_cast<TD *>(dst_tensor_ptr->data_c());
754     if (memcpy_s(data2, mem_size, data, mem_size) != 0) {
755       return false;
756     }
757     for (int i = 0; i < src_tensor_ptr->DataSize(); i++) {
758       if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) {
759         return false;
760       }
761       data2[i] = CalcuConstant(data2[i], calcu_type);
762     }
763     return true;
764   }
765 
766   // calculate const with different operations
767   AnfNodePtr CalcuConstantTensor(const AnfNodePtr &node, const ValuePtr &value, const PrimitivePtr &calcu_type) {
768     tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value);
769     TypeId tensor_type = tensor_ptr->Dtype()->type_id();
770     auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
771     TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
772     ShapeVector tensor_shape = tensor_abstract->shape()->shape();
773     auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
774     size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
775     if (new_tensor_ptr->DataSize() < tensor_ptr->DataSize()) {
776       MS_EXCEPTION(ValueError) << "DataSize of new_tensor_ptr is smaller than DataSize of tensor_ptr";
777     }
778     if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat) ||
779         (tensor_type == TypeId::kNumberTypeFloat64)) {
780       if (!TensorCopyData<float>(tensor_ptr, new_tensor_ptr, calcu_type, mem_size)) {
781         return nullptr;
782       }
783     }
784     if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) {
785       if (!TensorCopyData<int>(tensor_ptr, new_tensor_ptr, calcu_type, mem_size)) {
786         return nullptr;
787       }
788     }
789     if (tensor_type == TypeId::kNumberTypeFloat64) {
790       if (!TensorCopyData<double>(tensor_ptr, new_tensor_ptr, calcu_type, mem_size)) {
791         return nullptr;
792       }
793     }
794     auto new_vnode = NewValueNode(new_tensor_ptr);
795     new_vnode->set_abstract(tensor_ptr->ToAbstract());
796     return new_vnode;
797   }
798 
799   // calculate const with different operations
800   AnfNodePtr ValueNodeWithOprations(const PrimitivePtr &calcu_type) {
801     AnfNodePtr node = this->GetNode(captured_node_);
802     if (!node->isa<ValueNode>()) {
803       MS_EXCEPTION(ValueError) << "CalcuValue is trying to use a not ValueNode.";
804     }
805     auto value = node->cast<ValueNodePtr>()->value();
806     if (value->isa<tensor::Tensor>()) {
807       return CalcuConstantTensor(node, value, calcu_type);
808     }
809     return nullptr;
810   }
811 
812   enum BinOperator {
813     ADD = 0,
814     MULTIPLY,
815   };
816 
817   // Support function to add/multiply two constant tensors: partially support broadcasting shapes
818   template <typename TM>
819   void CalcByOperator(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
820                       int out_data_size, BinOperator bin_operator) const {
821     if (out_data_size <= 0) {
822       MS_EXCEPTION(ValueError) << "out_data_size should be greater than zeros";
823     }
824     TM *data_1 = reinterpret_cast<TM *>(in_data_1);
825     TM *data_2 = reinterpret_cast<TM *>(in_data_2);
826     TM *data_out = new TM[out_data_size];
827 
828     if (in_data_1_size == 1) {
829       for (int i = 0; i < out_data_size; i++) {
830         data_out[i] = data_1[0];
831       }
832     } else {
833       for (int i = 0; i < out_data_size; i++) {
834         data_out[i] = data_1[i];
835       }
836     }
837     if (in_data_2_size == 1) {
838       for (int i = 0; i < out_data_size; i++) {
839         if (bin_operator == ADD) {
840           data_out[i] += data_2[0];
841         } else {
842           data_out[i] *= data_2[0];
843         }
844       }
845     } else {
846       if (in_data_2_size < out_data_size) {
847         MS_LOG(INFO) << "in_data_2_size:" << in_data_2_size << " is smaller than out_data_size:" << out_data_size
848                      << ".in_data2 will be broadcast.";
849       }
850       auto min_size = std::min<int>(in_data_2_size, out_data_size);
851       for (int i = 0; i < min_size; i++) {
852         if (bin_operator == ADD) {
853           data_out[i] += data_2[i];
854         } else {
855           data_out[i] *= data_2[i];
856         }
857       }
858       // In case of in_data2_size < out_data_size
859       for (int i = min_size; i < out_data_size; i++) {
860         if (bin_operator != ADD) {
861           // if operator is MUL, data_out[i] *= 0, => data_out[i] = 0.
862           data_out[i] = 0;
863         }
864         // if operator is ADD, data_out[i] += 0, => data_out[i] = data_out[i], => NoChange.
865       }
866     }
867     *out_data = reinterpret_cast<void *>(data_out);
868     return;
869   }
870 
871   AnfNodePtr AddByPatternConst(const PConstant<T> &vpnode_2, const AnfNodePtr &node_3) const {
872     AnfNodePtr vnode_1 = this->GetNode(captured_node_);
873     AnfNodePtr vnode_2 = vpnode_2.GetNode(captured_node_);
874     return CalcConstantTensors(vnode_1, vnode_2, node_3, ADD);
875   }
876 
877   AnfNodePtr MulByPatternConst(const PConstant<T> &vpnode_2, const AnfNodePtr &node_3) const {
878     AnfNodePtr vnode_1 = this->GetNode(captured_node_);
879     AnfNodePtr vnode_2 = vpnode_2.GetNode(captured_node_);
880     return CalcConstantTensors(vnode_1, vnode_2, node_3, MULTIPLY);
881   }
882 
883   tensor::TensorPtr GetTensorFromValueNode(const AnfNodePtr &vnode) const {
884     if (!vnode->isa<ValueNode>() || vnode->abstract() == nullptr) {
885       return nullptr;
886     }
887     auto value = GetValueNode(vnode);
888     if (!value->isa<tensor::Tensor>()) {
889       return nullptr;
890     }
891     return dyn_cast<tensor::Tensor>(value);
892   }
893 
894   AnfNodePtr CalcConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3,
895                                  BinOperator bin_operator) const {
896     auto tensor_ptr_1 = GetTensorFromValueNode(vnode_1);
897     auto tensor_ptr_2 = GetTensorFromValueNode(vnode_2);
898     if (tensor_ptr_1 == nullptr || tensor_ptr_2 == nullptr || node_3->abstract() == nullptr) {
899       return nullptr;
900     }
901     tensor::TensorPtr new_tensor_ptr = GetNewTensor(vnode_1, vnode_2, node_3);
902     if (new_tensor_ptr == nullptr) {
903       return nullptr;
904     }
905     ShapeVector tensor_out_shape = new_tensor_ptr->shape();
906     int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies<int>());
907     size_t mem_size = GetTypeByte(new_tensor_ptr->Dtype()) * IntToSize(new_tensor_ptr->ElementsNum());
908     char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
909 
910     int ret = 0;
911     void *data_out = nullptr;
912     if ((new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat32) ||
913         (new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat)) {
914       CalcByOperator<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
915                             tensor_ptr_2->DataSize(), &data_out, data_out_size, bin_operator);
916       ret = memcpy_s(data, mem_size, data_out, mem_size);
917       delete[] reinterpret_cast<float *>(data_out);
918     } else {
919       if (new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat64) {
920         CalcByOperator<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
921                                tensor_ptr_2->DataSize(), &data_out, data_out_size, bin_operator);
922         ret = memcpy_s(data, mem_size, data_out, mem_size);
923         delete[] reinterpret_cast<double *>(data_out);
924       } else {
925         if ((new_tensor_ptr->data_type() == TypeId::kNumberTypeInt32) ||
926             (new_tensor_ptr->data_type() == TypeId::kNumberTypeInt)) {
927           CalcByOperator<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
928                               tensor_ptr_2->DataSize(), &data_out, data_out_size, bin_operator);
929           ret = memcpy_s(data, mem_size, data_out, mem_size);
930           delete[] reinterpret_cast<int *>(data_out);
931         } else {
932           // Unsupported data types
933           return nullptr;
934         }
935       }
936     }
937     if (ret != 0) {
938       MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret << ", source size " << mem_size << "dest size"
939                         << new_tensor_ptr->DataSize();
940     }
941     auto new_vnode = NewValueNode(new_tensor_ptr);
942     new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
943     return new_vnode;
944   }
945 
946   using Internal = const PConstant<T> &;
947 
948  protected:
949   mutable AnfNodePtr as_node_;
950   mutable AnfNodePtr captured_node_;
951   bool any_value_{true};
952   int64_t check_value_{0};
953   bool is_scalar_{false};
954   mutable bool is_new_value_node_{false};
955   mutable bool captured_{false};
956   mutable bool changed_shape_{false};
957 
958  private:
959   tensor::TensorPtr GetNewTensor(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3) const {
960     auto value_1 = GetValueNode(vnode_1);
961     auto value_2 = GetValueNode(vnode_2);
962     auto tensor_ptr_1 = dyn_cast<tensor::Tensor>(value_1);
963     auto tensor_ptr_2 = dyn_cast<tensor::Tensor>(value_2);
964     auto tensor_1_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>();
965     auto tensor_2_abstract = vnode_2->abstract()->cast<abstract::AbstractTensorPtr>();
966 
967     TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType();
968     TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType();
969     if ((tensor_1_abstract->shape()->shape() == tensor_2_abstract->shape()->shape()) &&
970         (tensor_1_type_ptr->type_id() == tensor_2_type_ptr->type_id())) {
971       // If two constant nodes have the same shape, then create a new one with this shape
972       auto tensor_out_shape = tensor_1_abstract->shape()->shape();
973 
974       return std::make_shared<tensor::Tensor>(tensor_1_type_ptr->type_id(), tensor_out_shape);
975     } else {
976       // If two constant nodes have different shapes, then create a new one node with the shape of the 3rd node
977       auto tensor_3_abstract = node_3->abstract()->cast<abstract::AbstractTensorPtr>();
978 
979       TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType();
980       if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) ||
981           (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) {
982         return nullptr;
983       }
984       auto tensor_out_shape = tensor_3_abstract->shape()->shape();
985       int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies<int>());
986       if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) {
987         return nullptr;
988       }
989       if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) {
990         return nullptr;
991       }
992       return std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape);
993     }
994   }
995 };
996 
997 // Macro for binary operation functions
998 #define BIN_OPERATION_PATTERN(Operator, MSPrimitive, Commutative)                   \
999   template <typename T, typename T2>                                                \
1000   inline PBinOperation<T, T2> Operator(const PBase<T> &x, const PBase<T2> &y) {     \
1001     return PBinOperation(MSPrimitive, x.get_object(), y.get_object(), Commutative); \
1002   }
1003 
1004 // Arithmetic operations
1005 BIN_OPERATION_PATTERN(operator+, prim::kPrimAdd, true);
1006 BIN_OPERATION_PATTERN(operator*, prim::kPrimMul, true);
1007 BIN_OPERATION_PATTERN(operator/, prim::kPrimRealDiv, false);
1008 BIN_OPERATION_PATTERN(operator-, prim::kPrimSub, false);
1009 
// Macros for match and replace.
//
// MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith): if `CaptureNode` captures `OrigNode`, builds
// `(ReplaceWith).GetNode(OrigNode)`. A non-null result is returned from the enclosing function;
// otherwise execution falls through. `OrigNode` is evaluated up to twice.
#define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith)            \
  if ((CaptureNode).TryCapture(OrigNode)) {                          \
    if (auto rep = (ReplaceWith).GetNode(OrigNode); rep != nullptr) { \
      return rep;                                                    \
    }                                                                \
  }
1018 
// MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition): if `CaptureNode` captures `OrigNode`
// and `Condition` holds, builds `(ReplaceWith).GetNode(OrigNode)`. A non-null result is returned from the
// enclosing function. `Condition` is evaluated only after a successful capture.
#define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition) \
  if ((CaptureNode).TryCapture(OrigNode) && (Condition)) {              \
    if (auto rep = (ReplaceWith).GetNode(OrigNode); rep != nullptr) {   \
      return rep;                                                       \
    }                                                                   \
  }
1026 
// MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode): if `CaptureNode` captures
// `OrigNode`, builds `(ReplaceWith).GetNode(OrigNode)` when `Condition` holds, else `(ElseNode).GetNode(OrigNode)`.
// A non-null result is returned from the enclosing function.
// The already-built `rep` is returned rather than calling GetNode a second time. A second call would redo the
// work, and GetNode implementations may construct fresh nodes on every call.
#define MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode) \
  if ((CaptureNode).TryCapture(OrigNode)) {                                            \
    if ((Condition)) {                                                                 \
      auto rep = (ReplaceWith).GetNode(OrigNode);                                      \
      if (rep != nullptr) {                                                            \
        return rep;                                                                    \
      }                                                                                \
    } else {                                                                           \
      auto rep = (ElseNode).GetNode(OrigNode);                                         \
      if (rep != nullptr) {                                                            \
        return rep;                                                                    \
      }                                                                                \
    }                                                                                  \
  }
1041 
// MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda): if `CaptureNode` captures `OrigNode`, calls `Lambda()`.
// A non-null result is returned from the enclosing function; otherwise execution falls through.
#define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda) \
  if ((CaptureNode).TryCapture(OrigNode)) {                 \
    if (auto rep = (Lambda)(); rep != nullptr) {            \
      return rep;                                           \
    }                                                       \
  }
1049 
// MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition): if `CaptureNode` captures `OrigNode` and
// `Condition` holds, calls `Lambda()`. A non-null result is returned from the enclosing function.
// `Condition` is evaluated only after a successful capture.
#define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition) \
  if ((CaptureNode).TryCapture(OrigNode) && (Condition)) {                \
    if (auto rep = (Lambda)(); rep != nullptr) {                          \
      return rep;                                                         \
    }                                                                     \
  }
1057 
// MATCH_REPLACE_LAMBDA_FLAG(OrigNode, CaptureNode, Lambda, Flag): if `CaptureNode` captures `OrigNode`,
// calls `Lambda(Flag)`. A non-null result is returned from the enclosing function; otherwise execution falls through.
#define MATCH_REPLACE_LAMBDA_FLAG(OrigNode, CaptureNode, Lambda, Flag) \
  if ((CaptureNode).TryCapture(OrigNode)) {                            \
    if (auto rep = (Lambda)(Flag); rep != nullptr) {                   \
      return rep;                                                      \
    }                                                                  \
  }
1065 }  // namespace mindspore
1066 
1067 #endif  // MINDSPORE_CORE_IR_PATTERN_MATCHER_H_
1068