• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_PATTERN_MATCHER_H_
18 #define MINDSPORE_CORE_IR_PATTERN_MATCHER_H_
19 
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <functional>
#include <memory>
#include <numeric>
#include <tuple>
#include <utility>
#include <vector>
26 
27 #include "ir/visitor.h"
28 #include "mindspore/core/ops/math_ops.h"
29 #include "ir/func_graph.h"
30 #include "utils/shape_utils.h"
31 
32 namespace mindspore {
33 ///
34 ///  Base class for all recognizable patterns.
35 ///  We implement an Expression Template approach using static polymorphism based on
36 ///  the Curiously Recurring Template Pattern (CRTP) which "achieves a similar effect
37 ///  to the use of virtual functions without the costs..." as described in:
38 ///  https://en.wikipedia.org/wiki/Expression_templates and
39 ///  https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
40 ///  The TryCapture function tries to capture the pattern with the given node.
41 ///  The GetNode function builds a new node using the captured values.
42 ///
43 
44 template <typename T>
45 class PBase {
46  public:
CheckFunc(const PredicateFuncType & func,const AnfNodePtr & node)47   bool CheckFunc(const PredicateFuncType &func, const AnfNodePtr &node) const {
48     return func(get_object().GetNode(node));
49   }
50 
get_object()51   const T &get_object() const { return *static_cast<const T *>(this); }
52 
53   template <typename TN>
TryCapture(const TN & value)54   bool TryCapture(const TN &value) const {
55     const auto &self = get_object();
56     self.Reset();
57     return self.TryCapture_(value);
58   }
59 
60   using Internal = T;
61 };
62 
63 template <typename T = AnfNodePtr>
64 class PatternNode : public PBase<PatternNode<T> > {
65  public:
66   virtual ~PatternNode() = default;
GetNode(const AnfNodePtr &)67   T GetNode(const AnfNodePtr &) const {
68     if (!captured_) {
69       MS_EXCEPTION(ValueError) << "A Pattern wasn't captured for this Token before the call to GetNode.";
70     }
71     return captured_node_;
72   }
73 
TryCapture_(const T & node)74   bool TryCapture_(const T &node) const {
75     if (!captured_) {
76       captured_node_ = node;
77       captured_ = true;
78       return true;
79     }
80     return captured_node_ == node;
81   }
82 
Reset()83   void Reset() const { captured_ = false; }
84   using Internal = const PatternNode<T> &;
85 
86  protected:
87   mutable T captured_node_;
88   mutable bool captured_{false};
89 };
90 
91 template <typename T, typename T2>
92 class PBinOperation : public PBase<PBinOperation<T, T2> > {
93  public:
94   PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y, bool is_commutative = false)
prim_(prim)95       : prim_(prim), x_(x), y_(y), is_commutative_(is_commutative) {}
96   ~PBinOperation() = default;
97 
GetNode(const AnfNodePtr & node)98   AnfNodePtr GetNode(const AnfNodePtr &node) const {
99     AnfNodePtr lhs = x_.GetNode(node);
100     AnfNodePtr rhs = y_.GetNode(node);
101     AnfNodePtrList inputs{NewValueNode(prim_), lhs, rhs};
102     return NewCNode(std::move(inputs), node->func_graph());
103   }
104 
TryCapture_(const AnfNodePtr & node)105   bool TryCapture_(const AnfNodePtr &node) const {
106     if (!IsPrimitiveCNode(node, prim_)) {
107       return false;
108     }
109     auto cnode = node->cast<CNodePtr>();
110     const auto &inputs = cnode->inputs();
111     // Binary Prim assumes only two inputs.
112     constexpr size_t bin_op_input_size = 3;
113     if (inputs.size() != bin_op_input_size) {
114       return false;
115     }
116     constexpr size_t bin_op_lhs_input = 1;
117     constexpr size_t bin_op_rhs_input = 2;
118     const auto &input1 = inputs[bin_op_lhs_input];
119     const auto &input2 = inputs[bin_op_rhs_input];
120     // Try capture inputs.
121     if (x_.TryCapture(input1) && y_.TryCapture(input2)) {
122       captured_binop_node_ = node;
123       return true;
124     }
125     if (!is_commutative_) {
126       return false;
127     }
128     // If the operation is commutative, then check with inversed operands.
129     Reset();
130     if (x_.TryCapture(input2) && y_.TryCapture(input1)) {
131       captured_binop_node_ = node;
132       return true;
133     }
134     return false;
135   }
136 
137   /// Returns the original node captured by this Binary Operation Pattern.
138   /// Throws exception if a node was not captured before.
GetOriginalNode()139   AnfNodePtr GetOriginalNode() const {
140     if (captured_binop_node_ == nullptr) {
141       MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it.";
142     }
143     return captured_binop_node_;
144   }
145 
Reset()146   void Reset() const {
147     x_.Reset();
148     y_.Reset();
149     captured_binop_node_ = nullptr;
150   }
151 
152   using Internal = const PBinOperation<T, T2> &;
153 
154  private:
155   const PrimitivePtr prim_;
156   typename T::Internal x_;
157   typename T2::Internal y_;
158   bool is_commutative_{false};
159   mutable AnfNodePtr captured_binop_node_{nullptr};
160 };
161 
162 ///
163 /// Helper functions to apply a pattern function on all elements of a tuple.
164 ///
165 namespace tuple_utils {
/// Calls `func` on every element of `tuple` in index order, starting at position `Index`.
/// Does nothing when `Index` is at or past the end of the tuple.
template <size_t Index = 0, typename TTuple, typename Func>
inline constexpr void ForEach(const TTuple &tuple, Func &&func) {
  constexpr size_t kTupleSize = std::tuple_size<TTuple>::value;
  if constexpr (Index >= kTupleSize) {
    // Past the last element: the compile-time recursion stops here.
    return;
  } else {
    func(std::get<Index>(tuple));
    ForEach<Index + 1>(tuple, std::forward<Func>(func));
  }
}
173 
/// Calls `Reset()` on every pattern stored in `tuple`.
template <typename TTuple>
inline void ResetAll(const TTuple &tuple) {
  const auto reset_one = [](auto &pattern) { pattern.Reset(); };
  ForEach(tuple, reset_one);
}
178 
179 template <bool First = true, size_t Index = 0, typename TTuple, typename Iter>
CaptureAll(const TTuple & tuple,Iter iter)180 inline bool CaptureAll(const TTuple &tuple, Iter iter) {
181   if constexpr (Index < std::tuple_size<TTuple>::value) {
182     const auto &input = *iter;
183     const auto &pattern = std::get<Index>(tuple);
184     if constexpr (First) {
185       // Check if the first node is a Primitive.
186       if (input->template isa<Primitive>()) {
187         if (input != pattern.GetNode(input)) {
188           return false;
189         }
190       } else if (!pattern.TryCapture_(input)) {
191         return false;
192       }
193     } else {
194       if (!pattern.TryCapture_(input)) {
195         return false;
196       }
197     }
198     return CaptureAll<false, Index + 1, TTuple, Iter>(tuple, ++iter);
199   } else {
200     return true;
201   }
202 }
203 }  // namespace tuple_utils
204 
205 template <typename... TArgs>
206 class PCNode : public PBase<PCNode<TArgs...> > {
207  public:
PCNode(const TArgs &...args)208   explicit PCNode(const TArgs &... args) : args_(args...) {}
209   virtual ~PCNode() = default;
210 
GetNode(const AnfNodePtr & node)211   AnfNodePtr GetNode(const AnfNodePtr &node) const {
212     constexpr auto n_args = sizeof...(TArgs);
213     const auto n_extra_nodes = extra_nodes_.size();
214     std::vector<AnfNodePtr> inputs;
215     inputs.reserve(n_args + n_extra_nodes);
216     // Get nodes from input patterns.
217     tuple_utils::ForEach(args_, [&inputs, &node](auto &arg) { (void)inputs.emplace_back(arg.GetNode(node)); });
218     if (n_extra_nodes > 0) {
219       // In case this PCNode has captured extra nodes.
220       inputs.insert(inputs.begin(), extra_nodes_.begin(), extra_nodes_.end());
221     }
222     return NewCNode(std::move(inputs), node->func_graph());
223   }
224 
TryCapture_(const AnfNodePtr & node)225   bool TryCapture_(const AnfNodePtr &node) const {
226     auto cnode = dyn_cast<CNode>(node);
227     if (cnode == nullptr) {
228       return false;
229     }
230     const auto &inputs = cnode->inputs();
231     const auto inputs_size = inputs.size();
232     constexpr auto pattern_arg_len = sizeof...(TArgs);
233     // There aren't enough inputs in Node to fill up the Pattern.
234     if (inputs_size < pattern_arg_len) {
235       return false;
236     }
237     // Pattern must exactly match the number of Node inputs.
238     if (!has_min_extra_nodes_) {
239       // Inputs in Node should perfectly match number of tokens in Pattern.
240       if (inputs_size != pattern_arg_len) {
241         return false;
242       }
243       return tuple_utils::CaptureAll(args_, inputs.begin());
244     }
245     // Pattern may accept extra (non specified) nodes at the end of the CNode,
246     // There must be at least `min_extra_nodes` additional nodes in the inputs.
247     if (inputs_size < pattern_arg_len + min_extra_nodes_) {
248       return false;
249     }
250     auto captured = tuple_utils::CaptureAll(args_, inputs.begin());
251     // If it could capture the initial set of nodes specified in the Pattern
252     // and there are enough extra inputs to add.
253     if (captured && inputs_size > pattern_arg_len) {
254       (void)extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + SizeToLong(pattern_arg_len), inputs.end());
255     }
256     return captured;
257   }
258 
259   /// This function sets the PCNode object to capture at least `min_extra_nodes_` nodes after the last one
260   /// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or
261   /// more nodes after the last one specified when building the PCNode.
262   const PCNode<TArgs...> &MinExtraNodes(const size_t &min_extra_nodes = 0) const {
263     has_min_extra_nodes_ = true;
264     min_extra_nodes_ = min_extra_nodes;
265     return *this;
266   }
267 
268   using Internal = const PCNode<TArgs...> &;
269 
Reset()270   void Reset() const {
271     tuple_utils::ResetAll(args_);
272     extra_nodes_.clear();
273   }
274 
275  private:
276   std::tuple<typename TArgs::Internal...> args_;
277   mutable AnfNodePtrList extra_nodes_;
278   mutable bool has_min_extra_nodes_{false};
279   mutable size_t min_extra_nodes_{0};
280 };
281 
282 template <typename... TArgs>
283 class PPrimitive : public PBase<PPrimitive<TArgs...> > {
284  public:
PPrimitive(const PrimitivePtr & prim,const TArgs &...args)285   explicit PPrimitive(const PrimitivePtr &prim, const TArgs &... args) : prim_(prim), args_(args...) {}
286   virtual ~PPrimitive() = default;
287 
GetNode(const AnfNodePtr & node)288   AnfNodePtr GetNode(const AnfNodePtr &node) const {
289     std::vector<AnfNodePtr> inputs;
290     constexpr auto n_args = sizeof...(TArgs);
291     const auto n_extra_nodes = extra_nodes_.size();
292     inputs.reserve(1 + n_args + n_extra_nodes);
293     // The first input is the value of primitive.
294     (void)inputs.emplace_back(NewValueNode(prim_));
295     // Get nodes from input patterns.
296     tuple_utils::ForEach(args_, [&inputs, &node](auto &arg) { (void)inputs.emplace_back(arg.GetNode(node)); });
297     if (n_extra_nodes > 0) {
298       // In case this PPrimitive has captured extra nodes.
299       (void)inputs.insert(inputs.begin(), extra_nodes_.begin(), extra_nodes_.end());
300     }
301     return NewCNode(std::move(inputs), node->func_graph());
302   }
303 
TryCapture_(const AnfNodePtr & node)304   bool TryCapture_(const AnfNodePtr &node) const {
305     if (!IsPrimitiveCNode(node, prim_)) {
306       return false;
307     }
308     auto cnode = node->cast<CNodePtr>();
309     const auto &inputs = cnode->inputs();
310     const auto inputs_size = inputs.size();
311     // Number of arguments in Primitive Pattern (not including the Primitive node).
312     constexpr auto pattern_arg_len = sizeof...(TArgs);
313     // There aren't enough inputs in Node to fill up the Pattern.
314     if (inputs_size < (pattern_arg_len + 1)) {
315       return false;
316     }
317     // Pattern must exactly match the number of Node inputs.
318     if (!has_min_extra_nodes_) {
319       if (inputs_size != (pattern_arg_len + 1)) {
320         return false;
321       }
322       // Inputs in Node perfectly match number of tokens in Pattern.
323       auto captured = tuple_utils::CaptureAll<false>(args_, inputs.begin() + 1);
324       if (captured) {
325         captured_prim_node_ = node;
326       }
327       return captured;
328     }
329     // Pattern may accept extra (non specified) nodes at the end of the Primitive,
330     // There must be at least `min_extra_nodes` additional nodes in the inputs.
331     if (inputs_size < (pattern_arg_len + 1 + min_extra_nodes_)) {
332       return false;
333     }
334     auto captured = tuple_utils::CaptureAll<false>(args_, inputs.begin() + 1);
335     // If it could capture the initial set of nodes specified in the Pattern
336     // and there are enough extra inputs to add.
337     if (captured) {
338       captured_prim_node_ = node;
339       if (inputs_size > pattern_arg_len + 1) {
340         (void)extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + SizeToLong(pattern_arg_len), inputs.end());
341       }
342     }
343     return captured;
344   }
345 
346   /// This function sets the PPrimitive object to capture at least `min_extra_nodes_` nodes after the last one
347   /// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or
348   /// more nodes after the last one specified when building the PPrimitive.
349   const PPrimitive<TArgs...> &MinExtraNodes(const size_t &min_extra_nodes = 0) const {
350     has_min_extra_nodes_ = true;
351     min_extra_nodes_ = min_extra_nodes;
352     return *this;
353   }
354 
GetCapturedExtraNodes()355   const AnfNodePtrList &GetCapturedExtraNodes() const { return extra_nodes_; }
356 
357   /// Returns the FuncGraph of the original node captured by this Primitive Pattern.
358   /// Throws exception if a node was not captured before.
GetFuncGraph()359   FuncGraphPtr GetFuncGraph() const {
360     if (captured_prim_node_ == nullptr) {
361       MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get its FuncGraph.";
362     }
363     return captured_prim_node_->func_graph();
364   }
365 
366   /// Returns the original node captured by this Primitive Pattern.
367   /// Throws exception if a node was not captured before.
GetOriginalNode()368   AnfNodePtr GetOriginalNode() const {
369     if (captured_prim_node_ == nullptr) {
370       MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it.";
371     }
372     return captured_prim_node_;
373   }
374 
Reset()375   void Reset() const {
376     tuple_utils::ResetAll(args_);
377     extra_nodes_.clear();
378     captured_prim_node_ = nullptr;
379   }
380 
381   using Internal = const PPrimitive<TArgs...> &;
382 
383  private:
384   const PrimitivePtr prim_;
385   std::tuple<typename TArgs::Internal...> args_;
386   mutable AnfNodePtrList extra_nodes_;
387   mutable bool has_min_extra_nodes_{false};
388   mutable size_t min_extra_nodes_{0};
389   mutable AnfNodePtr captured_prim_node_{nullptr};
390 };
391 
392 ///
393 /// PConstant class can capture a value node of a specified value (check_value_)
394 /// or a non-specified one (any_value = true).
395 /// It can be configured to capture a scalar constant as well (is_scalar_ = true)
396 ///
397 template <typename T = AnfNodePtr>
398 class PConstant : public PBase<PConstant<T> > {
399  public:
400   explicit PConstant(const AnfNodePtr &as_node, const bool any_value = true, const int64_t check_value = 0,
401                      const bool is_scalar = false)
as_node_(as_node)402       : as_node_(as_node),
403         captured_node_(as_node),
404         any_value_(any_value),
405         check_value_(check_value),
406         is_scalar_(is_scalar) {}
407 
408   virtual ~PConstant() = default;
409   // Sets as_node_ as the node received as argument to produce a same-shape node with GetNode
WithShapeAs(const AnfNodePtr & node)410   const PConstant<T> &WithShapeAs(const AnfNodePtr &node) const {
411     if (node == nullptr) {
412       MS_EXCEPTION(ValueError) << "WithShapeAs is trying to use a nullptr node.";
413     }
414     as_node_ = node;
415     changed_shape_ = true;
416     return *this;
417   }
418 
419   // Sets as_node_ as the node caputred by the received Pattern token to produce a same-shape node with GetNode
WithShapeAs(const PatternNode<T> & pnode)420   const PConstant<T> &WithShapeAs(const PatternNode<T> &pnode) const {
421     if (captured_node_ == nullptr) {
422       MS_EXCEPTION(ValueError) << "WithShapeAs is trying to use a Pattern token without previously capturing a node.";
423     }
424     as_node_ = pnode.GetNode(captured_node_);
425     changed_shape_ = true;
426     return *this;
427   }
428 
429   /// Sets captured_node_ as the node captured by the Pattern received as argument
430   /// to produce a new node with its contents when calling GetNode.
WithValueOf(const PatternNode<T> & pnode)431   const PConstant<T> &WithValueOf(const PatternNode<T> &pnode) const {
432     if (!any_value_) {
433       MS_EXCEPTION(ValueError) << "Must use a PConstant with `any_value = true` to use the value of another node.";
434     }
435     if (captured_node_ == nullptr) {
436       MS_EXCEPTION(ValueError) << "WithValueOf is trying to use a Pattern token without previously capturing a node.";
437     }
438     captured_node_ = pnode.GetNode(captured_node_);
439     changed_shape_ = true;
440     return *this;
441   }
442 
443   /// Create a new Value Node filled up with check_value.
444   /// This function must be used immediately before GetNode to avoid replacing the expected result.
445   /// Only valid for scalar constants. For tensors use WithShapeAs or WithValueOf.
NewValue()446   const PConstant<T> &NewValue() const {
447     if (!is_scalar_) {
448       MS_EXCEPTION(ValueError) << "NewValue is valid only for scalar PConstants.";
449     }
450     auto value_node_ = MakeValue(check_value_);
451     captured_node_ = NewValueNode(value_node_);
452     is_new_value_node_ = true;
453     return *this;
454   }
455 
GetNode(const AnfNodePtr &)456   AnfNodePtr GetNode(const AnfNodePtr &) const {
457     // If a NewValueNode was requested (using NewValue function) then return that created node.
458     if (is_new_value_node_) {
459       return captured_node_;
460     }
461     /// Return a NewTensorFilledWithData if the node was initialized to have a specific value
462     /// even if it wasn't captured. Usually for zero constants (x - x => zero).
463     /// If the shape was changed, use the new shape.
464     if (changed_shape_ || !captured_) {
465       if (!any_value_) {
466         return NewTensorFilledWithData(as_node_, check_value_);
467       }
468       return NewTensorFilledWithData(as_node_, captured_node_);
469     }
470     return captured_node_;
471   }
472 
TryCapture_(const AnfNodePtr & node)473   bool TryCapture_(const AnfNodePtr &node) const {
474     if (!node->isa<ValueNode>()) {
475       return false;
476     }
477     // If any_value_ is set don't check for the node's value. Just capture it.
478     if (any_value_) {
479       captured_node_ = node;
480       captured_ = true;
481       return true;
482     }
483     // Check value.
484     auto value = node->cast<ValueNodePtr>()->value();
485     if ((is_scalar_ && IsTensorScalarConstant(value)) || (!is_scalar_ && IsTensorConstant(value))) {
486       captured_node_ = node;
487       captured_ = true;
488       return true;
489     }
490     auto value_node_ = MakeValue(check_value_);
491     if (*GetValueNode(node) == *value_node_) {
492       captured_node_ = node;
493       captured_ = true;
494       return true;
495     }
496     return false;
497   }
498 
Reset()499   void Reset() const {
500     captured_ = false;
501     changed_shape_ = false;
502     is_new_value_node_ = false;
503   }
504 
505   // Support function used for checking if all values of a Tensor are equal to `check_value_`
506   // Supported data types: double, float/float32, int/int32
IsTensorConstant(const ValuePtr & value)507   bool IsTensorConstant(const ValuePtr &value) const {
508     if (!value->isa<tensor::Tensor>()) {
509       return false;
510     }
511     auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
512     TypeId tensor_type = tensor_ptr->Dtype()->type_id();
513     if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
514       float *data2 = static_cast<float *>(tensor_ptr->data_c());
515       auto threshold = FLT_MIN;
516       for (size_t i = 0; i < tensor_ptr->DataSize(); i++) {
517         if (fabs(data2[i] - check_value_) > threshold) {
518           return false;
519         }
520       }
521       return true;
522     } else if (tensor_type == TypeId::kNumberTypeFloat64) {
523       double *data2 = static_cast<double *>(tensor_ptr->data_c());
524       auto threshold = DBL_MIN;
525       for (size_t i = 0; i < tensor_ptr->DataSize(); i++) {
526         if (fabs(data2[i] - check_value_) > threshold) {
527           return false;
528         }
529       }
530       return true;
531     } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) {
532       int *data2 = static_cast<int *>(tensor_ptr->data_c());
533       for (size_t i = 0; i < tensor_ptr->DataSize(); i++) {
534         if (data2[i] != check_value_) {
535           return false;
536         }
537       }
538       return true;
539     }
540     // Input Data Type is not supported
541     return false;
542   }
543 
IsTensorScalarConstant(const ValuePtr & value)544   bool IsTensorScalarConstant(const ValuePtr &value) const {
545     if (!value->isa<tensor::Tensor>()) {
546       return false;
547     }
548     auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
549     if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) {
550       return false;
551     }
552     return IsTensorConstant(value);
553   }
554 
GetPointerToTensorData(const AnfNodePtr & node)555   void *GetPointerToTensorData(const AnfNodePtr &node) const {
556     if (!node->isa<ValueNode>()) {
557       return nullptr;
558     }
559     auto value = node->cast<ValueNodePtr>()->value();
560     if (!value->isa<tensor::Tensor>()) {
561       return nullptr;
562     }
563 
564     tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value);
565     return tensor_ptr->data_c();
566   }
567 
568   // Make a new tensor (when possible) with the same shape as of `node`
569   // If x is nullptr then fill new tensor will "0"
570   // If x is a tensor with empty shape then fill new tensor with the single value of x
571   // If x is a tensor with same shape as `node` then return x as result
572   AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr) const {
573     if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) {
574       return nullptr;
575     }
576 
577     auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
578     TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
579     ShapeVector tensor_shape = tensor_abstract->shape()->shape();
580     if (x == nullptr) {
581       auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
582       char *data = static_cast<char *>(new_tensor_ptr->data_c());
583       if (memset_s(data, new_tensor_ptr->Size(), 0, new_tensor_ptr->Size()) != EOK) {
584         return nullptr;
585       }
586       auto new_vnode = NewValueNode(new_tensor_ptr);
587       new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
588       return new_vnode;
589     }
590     // x is not nullptr
591     if (x->isa<CNode>() || x->isa<Parameter>()) {
592       if ((x->abstract() == nullptr) || !x->abstract()->isa<abstract::AbstractTensor>()) {
593         return nullptr;
594       }
595       auto x_abstract = x->abstract()->cast<abstract::AbstractTensorPtr>();
596       ShapeVector x_shape = x_abstract->shape()->shape();
597       if (x_shape != tensor_shape) {
598         return nullptr;
599       }
600       return x;
601     }
602 
603     if (!x->isa<ValueNode>()) {
604       return nullptr;
605     }
606     auto x_value = x->cast<ValueNodePtr>()->value();
607     if (!x_value->isa<tensor::Tensor>()) {
608       return nullptr;
609     }
610     auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
611     auto x_tensor_ptr = dyn_cast<tensor::Tensor>(x_value);
612     if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) {
613       return nullptr;
614     }
615     int ret = 0;
616     char *source_data = static_cast<char *>(GetPointerToTensorData(x));
617     MS_EXCEPTION_IF_NULL(source_data);
618     if (x_tensor_ptr->DataSize() == 1) {
619       auto tensor_type_byte = GetTypeByte(tensor_type_ptr);
620       char *data = static_cast<char *>(new_tensor_ptr->data_c());
621       for (int64_t i = 0; i < new_tensor_ptr->ElementsNum(); i++) {
622         ret = memcpy_s(data + LongToSize(i) * tensor_type_byte, tensor_type_byte, source_data, tensor_type_byte);
623         if (ret != EOK) {
624           MS_LOG(INFO) << "memcpy_s error, error no " << ret << ", source size " << tensor_type_byte << ", dest size "
625                        << tensor_type_byte;
626         }
627       }
628     } else {
629       char *data = static_cast<char *>(new_tensor_ptr->data_c());
630       ret = memcpy_s(data, new_tensor_ptr->Size(), source_data, new_tensor_ptr->Size());
631       if (ret != EOK) {
632         MS_LOG(INFO) << "memcpy_s error, error no " << ret << ", source size " << new_tensor_ptr->Size()
633                      << ", dest size " << new_tensor_ptr->DataSize();
634         return nullptr;
635       }
636     }
637     auto new_vnode = NewValueNode(new_tensor_ptr);
638     new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
639     return new_vnode;
640   }
641 
NewTensorFilledWithData(const AnfNodePtr & node,const int & value)642   AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const int &value) const {
643     if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) {
644       return nullptr;
645     }
646 
647     auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
648     TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
649     ShapeVector tensor_shape = tensor_abstract->shape()->shape();
650 
651     auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
652     size_t mem_size = GetTypeByte(tensor_type_ptr) * LongToSize(new_tensor_ptr->ElementsNum());
653     char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
654 
655     if (memset_s(data, mem_size, value, mem_size) != EOK) {
656       return nullptr;
657     }
658     auto new_vnode = NewValueNode(new_tensor_ptr);
659     new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
660     return new_vnode;
661   }
662 
663   enum BinOperator {
664     ADD,
665     MULTIPLY,
666   };
667 
668   // Support function to add/multiply two constant tensors: partially support broadcasting shapes
669   template <typename TM>
CalcByOperator(void * in_data_1,int in_data_1_size,void * in_data_2,int in_data_2_size,void ** out_data,int out_data_size,BinOperator bin_operator)670   void CalcByOperator(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
671                       int out_data_size, BinOperator bin_operator) const {
672     if (out_data_size <= 0) {
673       MS_EXCEPTION(ValueError) << "out_data_size should be greater than zeros";
674     }
675     TM *data_1 = static_cast<TM *>(in_data_1);
676     TM *data_2 = static_cast<TM *>(in_data_2);
677     TM *data_out = new TM[out_data_size];
678 
679     if (in_data_1_size == 1) {
680       for (int i = 0; i < out_data_size; i++) {
681         data_out[i] = data_1[0];
682       }
683     } else {
684       for (int i = 0; i < out_data_size; i++) {
685         data_out[i] = data_1[i];
686       }
687     }
688     if (in_data_2_size == 1) {
689       for (int i = 0; i < out_data_size; i++) {
690         if (bin_operator == ADD) {
691           data_out[i] += data_2[0];
692         } else {
693           data_out[i] *= data_2[0];
694         }
695       }
696     } else {
697       if (in_data_2_size < out_data_size) {
698         MS_LOG(INFO) << "in_data_2_size:" << in_data_2_size << " is smaller than out_data_size:" << out_data_size
699                      << ".in_data2 will be broadcast.";
700       }
701       auto min_size = std::min<int>(in_data_2_size, out_data_size);
702       for (int i = 0; i < min_size; i++) {
703         if (bin_operator == ADD) {
704           data_out[i] += data_2[i];
705         } else {
706           data_out[i] *= data_2[i];
707         }
708       }
709       // In case of in_data2_size < out_data_size
710       for (int i = min_size; i < out_data_size; i++) {
711         if (bin_operator != ADD) {
712           // if operator is MUL, data_out[i] *= 0, => data_out[i] = 0.
713           data_out[i] = 0;
714         }
715         // if operator is ADD, data_out[i] += 0, => data_out[i] = data_out[i], => NoChange.
716       }
717     }
718     *out_data = static_cast<void *>(data_out);
719   }
720 
MulByPatternConst(const PConstant<T> & vpnode_2,const AnfNodePtr & node_3)721   AnfNodePtr MulByPatternConst(const PConstant<T> &vpnode_2, const AnfNodePtr &node_3) const {
722     AnfNodePtr vnode_1 = this->GetNode(captured_node_);
723     AnfNodePtr vnode_2 = vpnode_2.GetNode(captured_node_);
724     return CalcConstantTensors(vnode_1, vnode_2, node_3, MULTIPLY);
725   }
726 
GetTensorFromValueNode(const AnfNodePtr & vnode)727   tensor::TensorPtr GetTensorFromValueNode(const AnfNodePtr &vnode) const {
728     if (!vnode->isa<ValueNode>() || vnode->abstract() == nullptr) {
729       return nullptr;
730     }
731     auto value = GetValueNode(vnode);
732     if (!value->isa<tensor::Tensor>()) {
733       return nullptr;
734     }
735     return dyn_cast<tensor::Tensor>(value);
736   }
737 
/// \brief Fold two constant tensor nodes into a single constant node using `bin_operator`.
///
/// The output tensor is allocated by GetNewTensor. The element-wise computation is delegated to
/// CalcByOperator and the result is copied into the output tensor's buffer.
/// Only float32, float64 and int32 output tensors are supported.
///
/// \param[in] vnode_1 First operand; must be a value node holding a tensor.
/// \param[in] vnode_2 Second operand; must be a value node holding a tensor.
/// \param[in] node_3 Node handed to GetNewTensor together with the operands; must carry an abstract.
/// \param[in] bin_operator Binary operator forwarded to CalcByOperator.
/// \return A new value node holding the folded tensor, or nullptr when an operand is not a constant
///         tensor, `node_3` is null or has no abstract, GetNewTensor fails, or the data type is unsupported.
/// \note Raises through MS_LOG(EXCEPTION) when copying the result into the output tensor fails.
AnfNodePtr CalcConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3,
                               BinOperator bin_operator) const {
  auto tensor_ptr_1 = GetTensorFromValueNode(vnode_1);
  auto tensor_ptr_2 = GetTensorFromValueNode(vnode_2);
  if (tensor_ptr_1 == nullptr || tensor_ptr_2 == nullptr || node_3 == nullptr || node_3->abstract() == nullptr) {
    return nullptr;
  }
  tensor::TensorPtr new_tensor_ptr = GetNewTensor(vnode_1, vnode_2, node_3);
  if (new_tensor_ptr == nullptr) {
    return nullptr;
  }

  // Derive the element count once and narrow it with the checked SizeToInt helper, instead of
  // multiplying the dimensions in a plain `int`, which silently overflows for large shapes.
  const size_t elem_num = LongToSize(new_tensor_ptr->ElementsNum());
  const int data_out_size = SizeToInt(elem_num);
  const size_t mem_size = GetTypeByte(new_tensor_ptr->Dtype()) * elem_num;
  char *data = static_cast<char *>(new_tensor_ptr->data_c());

  // Runs CalcByOperator for the element type of `type_tag` and copies the result into `data`.
  // CalcByOperator hands back a heap array through its out-parameter; owning it in a unique_ptr
  // guarantees it is released with delete[] on every path.
  auto fold_as = [&](auto type_tag) -> int {
    using ElemT = decltype(type_tag);
    void *data_out = nullptr;
    this->template CalcByOperator<ElemT>(tensor_ptr_1->data_c(), SizeToInt(tensor_ptr_1->DataSize()),
                                         tensor_ptr_2->data_c(), SizeToInt(tensor_ptr_2->DataSize()), &data_out,
                                         data_out_size, bin_operator);
    std::unique_ptr<ElemT[]> owned_out(static_cast<ElemT *>(data_out));
    return memcpy_s(data, mem_size, data_out, mem_size);
  };

  const auto out_type = new_tensor_ptr->data_type();
  int ret = 0;
  if ((out_type == TypeId::kNumberTypeFloat32) || (out_type == TypeId::kNumberTypeFloat)) {
    ret = fold_as(float{});
  } else if (out_type == TypeId::kNumberTypeFloat64) {
    ret = fold_as(double{});
  } else if ((out_type == TypeId::kNumberTypeInt32) || (out_type == TypeId::kNumberTypeInt)) {
    ret = fold_as(int{});
  } else {
    // Unsupported data types
    return nullptr;
  }
  if (ret != EOK) {
    MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret << ", source size " << mem_size << ", dest size "
                      << new_tensor_ptr->DataSize();
  }
  auto new_vnode = NewValueNode(new_tensor_ptr);
  new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
  return new_vnode;
}
789 
790   using Internal = const PConstant<T> &;  // Alias for `const PConstant<T> &`.
791 
792  protected:
793   mutable AnfNodePtr as_node_;  // NOTE(review): not referenced by the methods visible in this part of the file; role to be confirmed.
794   mutable AnfNodePtr captured_node_;  // Node that MulByPatternConst passes to GetNode() for both operands.
795   bool any_value_{true};  // Defaults to true. NOTE(review): presumably "accept any constant value" -- confirm.
796   int64_t check_value_{0};  // Defaults to 0. NOTE(review): presumably the value checked when any_value_ is false -- confirm.
797   bool is_scalar_{false};  // Defaults to false. NOTE(review): meaning not visible in this part of the file.
798   mutable bool is_new_value_node_{false};  // Starts false; `mutable`, so const methods may change it.
799   mutable bool captured_{false};  // Starts false; `mutable`. NOTE(review): presumably set once a node is captured -- confirm.
800   mutable bool changed_shape_{false};  // Starts false; `mutable`, so const methods may change it.
801 
802  private:
GetNewTensor(const AnfNodePtr & vnode_1,const AnfNodePtr & vnode_2,const AnfNodePtr & node_3)803   tensor::TensorPtr GetNewTensor(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3) const {
804     auto value_1 = GetValueNode(vnode_1);
805     auto value_2 = GetValueNode(vnode_2);
806     auto tensor_ptr_1 = dyn_cast<tensor::Tensor>(value_1);
807     auto tensor_ptr_2 = dyn_cast<tensor::Tensor>(value_2);
808     auto tensor_1_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>();
809     auto tensor_2_abstract = vnode_2->abstract()->cast<abstract::AbstractTensorPtr>();
810 
811     TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType();
812     TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType();
813     if ((tensor_1_abstract->shape()->shape() == tensor_2_abstract->shape()->shape()) &&
814         (tensor_1_type_ptr->type_id() == tensor_2_type_ptr->type_id())) {
815       // If two constant nodes have the same shape, then create a new one with this shape
816       auto tensor_out_shape = tensor_1_abstract->shape()->shape();
817 
818       return std::make_shared<tensor::Tensor>(tensor_1_type_ptr->type_id(), tensor_out_shape);
819     } else {
820       // If two constant nodes have different shapes, then create a new one node with the shape of the 3rd node
821       auto tensor_3_abstract = node_3->abstract()->cast<abstract::AbstractTensorPtr>();
822 
823       TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType();
824       if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) ||
825           (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) {
826         return nullptr;
827       }
828       auto tensor_out_shape = tensor_3_abstract->shape()->shape();
829       // if dynamic, return nullptr
830       if (std::any_of(tensor_out_shape.begin(), tensor_out_shape.end(), [](int64_t i) { return i < 0; })) {
831         return nullptr;
832       }
833       size_t data_out_size =
834         LongToSize(std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies<int64_t>()));
835       if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) {
836         return nullptr;
837       }
838       if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) {
839         return nullptr;
840       }
841       return std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape);
842     }
843   }
844 };
845 
846 // Macro for binary operation functions
847 #define BIN_OPERATION_PATTERN(Operator, MSPrimitive, Commutative)                   \
848   template <typename T, typename T2>                                                \
849   inline PBinOperation<T, T2> Operator(const PBase<T> &x, const PBase<T2> &y) {     \
850     return PBinOperation(MSPrimitive, x.get_object(), y.get_object(), Commutative); \
851   }
852 
853 // Arithmetic operations
854 BIN_OPERATION_PATTERN(operator+, prim::kPrimAdd, true);
855 BIN_OPERATION_PATTERN(operator*, prim::kPrimMul, true);
856 BIN_OPERATION_PATTERN(operator/, prim::kPrimRealDiv, false);
857 BIN_OPERATION_PATTERN(operator-, prim::kPrimSub, false);
858 
// Macros for match and replace.
// MATCH_REPLACE: if CaptureNode captures OrigNode, build the replacement with ReplaceWith and
// return it from the enclosing function unless it is null; otherwise fall through.
#define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith)                  \
  do {                                                                     \
    if ((CaptureNode).TryCapture(OrigNode)) {                              \
      if (auto rep = (ReplaceWith).GetNode(OrigNode); rep != nullptr) {    \
        return rep;                                                        \
      }                                                                    \
    }                                                                      \
  } while (0)
869 
// MATCH_REPLACE_IF: like MATCH_REPLACE, but Condition must also hold. Condition is evaluated only
// after CaptureNode has captured OrigNode.
#define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition)    \
  do {                                                                     \
    if ((CaptureNode).TryCapture(OrigNode) && (Condition)) {               \
      if (auto rep = (ReplaceWith).GetNode(OrigNode); rep != nullptr) {    \
        return rep;                                                        \
      }                                                                    \
    }                                                                      \
  } while (0)
879 
// MATCH_REPLACE_LAMBDA: if CaptureNode captures OrigNode, call Lambda with no arguments and return
// its result from the enclosing function unless that result is null.
#define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda)    \
  do {                                                         \
    if ((CaptureNode).TryCapture(OrigNode)) {                  \
      if (auto rep = (Lambda)(); rep != nullptr) {             \
        return rep;                                            \
      }                                                        \
    }                                                          \
  } while (0)
889 
// MATCH_REPLACE_LAMBDA_IF: like MATCH_REPLACE_LAMBDA, but Condition must also hold. Condition is
// evaluated only after CaptureNode has captured OrigNode, and Lambda only when both succeed.
#define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition)    \
  do {                                                                       \
    if ((CaptureNode).TryCapture(OrigNode) && (Condition)) {                 \
      if (auto rep = (Lambda)(); rep != nullptr) {                           \
        return rep;                                                          \
      }                                                                      \
    }                                                                        \
  } while (0)
899 }  // namespace mindspore
900 
901 #endif  // MINDSPORE_CORE_IR_PATTERN_MATCHER_H_
902