/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_IR_PATTERN_MATCHER_H_
#define MINDSPORE_CORE_IR_PATTERN_MATCHER_H_

#include <functional>
#include <memory>
#include <tuple>
#include <vector>
#include <algorithm>
#include <utility>

#include "ir/visitor.h"
#include "mindspore/core/ops/math_ops.h"
#include "ir/func_graph.h"
#include "utils/shape_utils.h"

namespace mindspore {
///
/// Base class for all recognizable patterns.
/// We implement an Expression Template approach using static polymorphism based on
/// the Curiously Recurring Template Pattern (CRTP) which "achieves a similar effect
/// to the use of virtual functions without the costs..." as described in:
/// https://en.wikipedia.org/wiki/Expression_templates and
/// https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
/// The TryCapture function tries to capture the pattern with the given node.
/// The GetNode function builds a new node using the captured values.
///
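///
/// A minimal usage sketch (hedged: `node` and the surrounding optimizer pass are
/// assumptions, not part of this header). It captures a node matching `a + b` and
/// rebuilds it with the operands swapped:
///
///   PatternNode<AnfNodePtr> x, y;
///   PPrimitive pattern(prim::kPrimAdd, x, y);
///   if (pattern.TryCapture(node)) {
///     AnfNodePtr swapped = PPrimitive(prim::kPrimAdd, y, x).GetNode(node);
///   }
///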

template <typename T>
class PBase {
 public:
  bool CheckFunc(const PredicateFuncType &func, const AnfNodePtr &node) const {
    return func(get_object().GetNode(node));
  }

  const T &get_object() const { return *static_cast<const T *>(this); }

  template <typename TN>
  bool TryCapture(const TN &value) const {
    const auto &self = get_object();
    self.Reset();
    return self.TryCapture_(value);
  }

  using Internal = T;
};

template <typename T = AnfNodePtr>
class PatternNode : public PBase<PatternNode<T> > {
 public:
  virtual ~PatternNode() = default;
  T GetNode(const AnfNodePtr &) const {
    if (!captured_) {
      MS_EXCEPTION(ValueError) << "A Pattern wasn't captured for this Token before the call to GetNode.";
    }
    return captured_node_;
  }

  bool TryCapture_(const T &node) const {
    if (!captured_) {
      captured_node_ = node;
      captured_ = true;
      return true;
    }
    return captured_node_ == node;
  }

  void Reset() const { captured_ = false; }
  using Internal = const PatternNode<T> &;

 protected:
  mutable T captured_node_;
  mutable bool captured_{false};
};

template <typename T, typename T2>
class PBinOperation : public PBase<PBinOperation<T, T2> > {
 public:
  PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y, bool is_commutative = false)
      : prim_(prim), x_(x), y_(y), is_commutative_(is_commutative) {}
  ~PBinOperation() = default;

  AnfNodePtr GetNode(const AnfNodePtr &node) const {
    AnfNodePtr lhs = x_.GetNode(node);
    AnfNodePtr rhs = y_.GetNode(node);
    AnfNodePtrList inputs{NewValueNode(prim_), lhs, rhs};
    return NewCNode(std::move(inputs), node->func_graph());
  }

  bool TryCapture_(const AnfNodePtr &node) const {
    if (!IsPrimitiveCNode(node, prim_)) {
      return false;
    }
    auto cnode = node->cast<CNodePtr>();
    const auto &inputs = cnode->inputs();
    // A binary Prim assumes exactly two inputs.
    constexpr size_t bin_op_input_size = 3;
    if (inputs.size() != bin_op_input_size) {
      return false;
    }
    constexpr size_t bin_op_lhs_input = 1;
    constexpr size_t bin_op_rhs_input = 2;
    const auto &input1 = inputs[bin_op_lhs_input];
    const auto &input2 = inputs[bin_op_rhs_input];
    // Try to capture the inputs.
    if (x_.TryCapture(input1) && y_.TryCapture(input2)) {
      captured_binop_node_ = node;
      return true;
    }
    if (!is_commutative_) {
      return false;
    }
    // If the operation is commutative, then check with reversed operands.
    Reset();
    if (x_.TryCapture(input2) && y_.TryCapture(input1)) {
      captured_binop_node_ = node;
      return true;
    }
    return false;
  }

  /// Returns the original node captured by this Binary Operation Pattern.
  /// Throws an exception if a node was not captured before.
  AnfNodePtr GetOriginalNode() const {
    if (captured_binop_node_ == nullptr) {
      MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it.";
    }
    return captured_binop_node_;
  }

  void Reset() const {
    x_.Reset();
    y_.Reset();
    captured_binop_node_ = nullptr;
  }

  using Internal = const PBinOperation<T, T2> &;

 private:
  const PrimitivePtr prim_;
  typename T::Internal x_;
  typename T2::Internal y_;
  bool is_commutative_{false};
  mutable AnfNodePtr captured_binop_node_{nullptr};
};

///
/// Helper functions to apply a pattern function on all elements of a tuple.
///
namespace tuple_utils {
template <size_t Index = 0, typename TTuple, typename Func>
inline constexpr void ForEach(const TTuple &tuple, Func &&func) {
  if constexpr (Index < std::tuple_size<TTuple>::value) {
    func(std::get<Index>(tuple));
    ForEach<Index + 1, TTuple, Func>(tuple, std::forward<Func>(func));
  }
}

template <typename TTuple>
inline void ResetAll(const TTuple &tuple) {
  ForEach(tuple, [](auto &arg) { arg.Reset(); });
}

template <bool First = true, size_t Index = 0, typename TTuple, typename Iter>
inline bool CaptureAll(const TTuple &tuple, Iter iter) {
  if constexpr (Index < std::tuple_size<TTuple>::value) {
    const auto &input = *iter;
    const auto &pattern = std::get<Index>(tuple);
    if constexpr (First) {
      // Check if the first node is a Primitive.
      if (input->template isa<Primitive>()) {
        if (input != pattern.GetNode(input)) {
          return false;
        }
      } else if (!pattern.TryCapture_(input)) {
        return false;
      }
    } else {
      if (!pattern.TryCapture_(input)) {
        return false;
      }
    }
    return CaptureAll<false, Index + 1, TTuple, Iter>(tuple, ++iter);
  } else {
    return true;
  }
}
}  // namespace tuple_utils

template <typename... TArgs>
class PCNode : public PBase<PCNode<TArgs...> > {
 public:
  explicit PCNode(const TArgs &... args) : args_(args...) {}
  virtual ~PCNode() = default;

  AnfNodePtr GetNode(const AnfNodePtr &node) const {
    constexpr auto n_args = sizeof...(TArgs);
    const auto n_extra_nodes = extra_nodes_.size();
    std::vector<AnfNodePtr> inputs;
    inputs.reserve(n_args + n_extra_nodes);
    // Get nodes from input patterns.
    tuple_utils::ForEach(args_, [&inputs, &node](auto &arg) { (void)inputs.emplace_back(arg.GetNode(node)); });
    if (n_extra_nodes > 0) {
      // In case this PCNode has captured extra nodes.
      inputs.insert(inputs.begin(), extra_nodes_.begin(), extra_nodes_.end());
    }
    return NewCNode(std::move(inputs), node->func_graph());
  }

  bool TryCapture_(const AnfNodePtr &node) const {
    auto cnode = dyn_cast<CNode>(node);
    if (cnode == nullptr) {
      return false;
    }
    const auto &inputs = cnode->inputs();
    const auto inputs_size = inputs.size();
    constexpr auto pattern_arg_len = sizeof...(TArgs);
    // There aren't enough inputs in the Node to fill up the Pattern.
    if (inputs_size < pattern_arg_len) {
      return false;
    }
    // Without extra nodes, the Pattern must exactly match the number of Node inputs.
    if (!has_min_extra_nodes_) {
      // Inputs in the Node should perfectly match the number of tokens in the Pattern.
      if (inputs_size != pattern_arg_len) {
        return false;
      }
      return tuple_utils::CaptureAll(args_, inputs.begin());
    }
    // The Pattern may accept extra (non-specified) nodes at the end of the CNode;
    // there must be at least `min_extra_nodes_` additional nodes in the inputs.
    if (inputs_size < pattern_arg_len + min_extra_nodes_) {
      return false;
    }
    auto captured = tuple_utils::CaptureAll(args_, inputs.begin());
    // If it could capture the initial set of nodes specified in the Pattern
    // and there are enough extra inputs to add.
    if (captured && inputs_size > pattern_arg_len) {
      (void)extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + SizeToLong(pattern_arg_len), inputs.end());
    }
    return captured;
  }

  /// This function sets the PCNode object to capture at least `min_extra_nodes_` nodes after the last one
  /// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or
  /// more nodes after the last one specified when building the PCNode.
  const PCNode<TArgs...> &MinExtraNodes(const size_t &min_extra_nodes = 0) const {
    has_min_extra_nodes_ = true;
    min_extra_nodes_ = min_extra_nodes;
    return *this;
  }

  using Internal = const PCNode<TArgs...> &;

  void Reset() const {
    tuple_utils::ResetAll(args_);
    extra_nodes_.clear();
  }

 private:
  std::tuple<typename TArgs::Internal...> args_;
  mutable AnfNodePtrList extra_nodes_;
  mutable bool has_min_extra_nodes_{false};
  mutable size_t min_extra_nodes_{0};
};

template <typename... TArgs>
class PPrimitive : public PBase<PPrimitive<TArgs...> > {
 public:
  explicit PPrimitive(const PrimitivePtr &prim, const TArgs &... args) : prim_(prim), args_(args...) {}
  virtual ~PPrimitive() = default;

  AnfNodePtr GetNode(const AnfNodePtr &node) const {
    std::vector<AnfNodePtr> inputs;
    constexpr auto n_args = sizeof...(TArgs);
    const auto n_extra_nodes = extra_nodes_.size();
    inputs.reserve(1 + n_args + n_extra_nodes);
    // The first input is the value of the primitive.
    (void)inputs.emplace_back(NewValueNode(prim_));
    // Get nodes from input patterns.
    tuple_utils::ForEach(args_, [&inputs, &node](auto &arg) { (void)inputs.emplace_back(arg.GetNode(node)); });
    if (n_extra_nodes > 0) {
      // In case this PPrimitive has captured extra nodes.
      (void)inputs.insert(inputs.begin(), extra_nodes_.begin(), extra_nodes_.end());
    }
    return NewCNode(std::move(inputs), node->func_graph());
  }

  bool TryCapture_(const AnfNodePtr &node) const {
    if (!IsPrimitiveCNode(node, prim_)) {
      return false;
    }
    auto cnode = node->cast<CNodePtr>();
    const auto &inputs = cnode->inputs();
    const auto inputs_size = inputs.size();
    // Number of arguments in the Primitive Pattern (not including the Primitive node).
    constexpr auto pattern_arg_len = sizeof...(TArgs);
    // There aren't enough inputs in the Node to fill up the Pattern.
    if (inputs_size < (pattern_arg_len + 1)) {
      return false;
    }
    // Without extra nodes, the Pattern must exactly match the number of Node inputs.
    if (!has_min_extra_nodes_) {
      if (inputs_size != (pattern_arg_len + 1)) {
        return false;
      }
      // Inputs in the Node perfectly match the number of tokens in the Pattern.
      auto captured = tuple_utils::CaptureAll<false>(args_, inputs.begin() + 1);
      if (captured) {
        captured_prim_node_ = node;
      }
      return captured;
    }
    // The Pattern may accept extra (non-specified) nodes at the end of the Primitive;
    // there must be at least `min_extra_nodes_` additional nodes in the inputs.
    if (inputs_size < (pattern_arg_len + 1 + min_extra_nodes_)) {
      return false;
    }
    auto captured = tuple_utils::CaptureAll<false>(args_, inputs.begin() + 1);
    // If it could capture the initial set of nodes specified in the Pattern
    // and there are enough extra inputs to add.
    if (captured) {
      captured_prim_node_ = node;
      if (inputs_size > pattern_arg_len + 1) {
        (void)extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + SizeToLong(pattern_arg_len), inputs.end());
      }
    }
    return captured;
  }

  /// This function sets the PPrimitive object to capture at least `min_extra_nodes_` nodes after the last one
  /// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or
  /// more nodes after the last one specified when building the PPrimitive.
  const PPrimitive<TArgs...> &MinExtraNodes(const size_t &min_extra_nodes = 0) const {
    has_min_extra_nodes_ = true;
    min_extra_nodes_ = min_extra_nodes;
    return *this;
  }

  const AnfNodePtrList &GetCapturedExtraNodes() const { return extra_nodes_; }

  /// Returns the FuncGraph of the original node captured by this Primitive Pattern.
  /// Throws an exception if a node was not captured before.
  FuncGraphPtr GetFuncGraph() const {
    if (captured_prim_node_ == nullptr) {
      MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get its FuncGraph.";
    }
    return captured_prim_node_->func_graph();
  }

  /// Returns the original node captured by this Primitive Pattern.
  /// Throws an exception if a node was not captured before.
  AnfNodePtr GetOriginalNode() const {
    if (captured_prim_node_ == nullptr) {
      MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it.";
    }
    return captured_prim_node_;
  }

  void Reset() const {
    tuple_utils::ResetAll(args_);
    extra_nodes_.clear();
    captured_prim_node_ = nullptr;
  }

  using Internal = const PPrimitive<TArgs...> &;

 private:
  const PrimitivePtr prim_;
  std::tuple<typename TArgs::Internal...> args_;
  mutable AnfNodePtrList extra_nodes_;
  mutable bool has_min_extra_nodes_{false};
  mutable size_t min_extra_nodes_{0};
  mutable AnfNodePtr captured_prim_node_{nullptr};
};

///
/// PConstant class can capture a value node of a specified value (check_value_)
/// or a non-specified one (any_value = true).
/// It can be configured to capture a scalar constant as well (is_scalar_ = true).
///
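/// A minimal, hedged sketch (the `node` being inspected and the surrounding pass are
/// assumptions, not part of this header): a PConstant built with `any_value = false`
/// and `check_value = 0` captures only constants equal to zero, so the pattern below
/// recognizes a multiplication of some operand by a zero constant:
///
///   PatternNode<AnfNodePtr> x;
///   PConstant<AnfNodePtr> zero(node, false, 0);
///   if (PPrimitive(prim::kPrimMul, x, zero).TryCapture(node)) {
///     // `node` multiplies x by a constant zero. Note the operand order is fixed here;
///     // the commutative form `x * zero` (see the operators below) also matches 0 * x.
///   }
///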
template <typename T = AnfNodePtr>
class PConstant : public PBase<PConstant<T> > {
 public:
  explicit PConstant(const AnfNodePtr &as_node, const bool any_value = true, const int64_t check_value = 0,
                     const bool is_scalar = false)
      : as_node_(as_node),
        captured_node_(as_node),
        any_value_(any_value),
        check_value_(check_value),
        is_scalar_(is_scalar) {}

  virtual ~PConstant() = default;
  // Sets as_node_ as the node received as argument to produce a same-shape node with GetNode.
  const PConstant<T> &WithShapeAs(const AnfNodePtr &node) const {
    if (node == nullptr) {
      MS_EXCEPTION(ValueError) << "WithShapeAs is trying to use a nullptr node.";
    }
    as_node_ = node;
    changed_shape_ = true;
    return *this;
  }

  // Sets as_node_ as the node captured by the received Pattern token to produce a same-shape node with GetNode.
  const PConstant<T> &WithShapeAs(const PatternNode<T> &pnode) const {
    if (captured_node_ == nullptr) {
      MS_EXCEPTION(ValueError) << "WithShapeAs is trying to use a Pattern token without previously capturing a node.";
    }
    as_node_ = pnode.GetNode(captured_node_);
    changed_shape_ = true;
    return *this;
  }

  /// Sets captured_node_ as the node captured by the Pattern received as argument
  /// to produce a new node with its contents when calling GetNode.
  const PConstant<T> &WithValueOf(const PatternNode<T> &pnode) const {
    if (!any_value_) {
      MS_EXCEPTION(ValueError) << "Must use a PConstant with `any_value = true` to use the value of another node.";
    }
    if (captured_node_ == nullptr) {
      MS_EXCEPTION(ValueError) << "WithValueOf is trying to use a Pattern token without previously capturing a node.";
    }
    captured_node_ = pnode.GetNode(captured_node_);
    changed_shape_ = true;
    return *this;
  }

  /// Create a new Value Node filled up with check_value.
  /// This function must be used immediately before GetNode to avoid replacing the expected result.
  /// Only valid for scalar constants. For tensors use WithShapeAs or WithValueOf.
  const PConstant<T> &NewValue() const {
    if (!is_scalar_) {
      MS_EXCEPTION(ValueError) << "NewValue is valid only for scalar PConstants.";
    }
    auto value_node_ = MakeValue(check_value_);
    captured_node_ = NewValueNode(value_node_);
    is_new_value_node_ = true;
    return *this;
  }

  AnfNodePtr GetNode(const AnfNodePtr &) const {
    // If a NewValueNode was requested (using the NewValue function) then return that created node.
    if (is_new_value_node_) {
      return captured_node_;
    }
    /// Return a NewTensorFilledWithData if the node was initialized to have a specific value
    /// even if it wasn't captured. Usually for zero constants (x - x => zero).
    /// If the shape was changed, use the new shape.
    if (changed_shape_ || !captured_) {
      if (!any_value_) {
        return NewTensorFilledWithData(as_node_, check_value_);
      }
      return NewTensorFilledWithData(as_node_, captured_node_);
    }
    return captured_node_;
  }

  bool TryCapture_(const AnfNodePtr &node) const {
    if (!node->isa<ValueNode>()) {
      return false;
    }
    // If any_value_ is set, don't check the node's value. Just capture it.
    if (any_value_) {
      captured_node_ = node;
      captured_ = true;
      return true;
    }
    // Check the value.
    auto value = node->cast<ValueNodePtr>()->value();
    if ((is_scalar_ && IsTensorScalarConstant(value)) || (!is_scalar_ && IsTensorConstant(value))) {
      captured_node_ = node;
      captured_ = true;
      return true;
    }
    auto value_node_ = MakeValue(check_value_);
    if (*GetValueNode(node) == *value_node_) {
      captured_node_ = node;
      captured_ = true;
      return true;
    }
    return false;
  }

  void Reset() const {
    captured_ = false;
    changed_shape_ = false;
    is_new_value_node_ = false;
  }

  // Support function used for checking if all values of a Tensor are equal to `check_value_`.
  // Supported data types: double, float/float32, int/int32
  bool IsTensorConstant(const ValuePtr &value) const {
    if (!value->isa<tensor::Tensor>()) {
      return false;
    }
    auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
    TypeId tensor_type = tensor_ptr->Dtype()->type_id();
    if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
      float *data2 = static_cast<float *>(tensor_ptr->data_c());
      auto threshold = FLT_MIN;
      for (size_t i = 0; i < tensor_ptr->DataSize(); i++) {
        if (fabs(data2[i] - check_value_) > threshold) {
          return false;
        }
      }
      return true;
    } else if (tensor_type == TypeId::kNumberTypeFloat64) {
      double *data2 = static_cast<double *>(tensor_ptr->data_c());
      auto threshold = DBL_MIN;
      for (size_t i = 0; i < tensor_ptr->DataSize(); i++) {
        if (fabs(data2[i] - check_value_) > threshold) {
          return false;
        }
      }
      return true;
    } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) {
      int *data2 = static_cast<int *>(tensor_ptr->data_c());
      for (size_t i = 0; i < tensor_ptr->DataSize(); i++) {
        if (data2[i] != check_value_) {
          return false;
        }
      }
      return true;
    }
    // Input data type is not supported.
    return false;
  }

  bool IsTensorScalarConstant(const ValuePtr &value) const {
    if (!value->isa<tensor::Tensor>()) {
      return false;
    }
    auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
    if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) {
      return false;
    }
    return IsTensorConstant(value);
  }

  void *GetPointerToTensorData(const AnfNodePtr &node) const {
    if (!node->isa<ValueNode>()) {
      return nullptr;
    }
    auto value = node->cast<ValueNodePtr>()->value();
    if (!value->isa<tensor::Tensor>()) {
      return nullptr;
    }

    tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value);
    return tensor_ptr->data_c();
  }

  // Make a new tensor (when possible) with the same shape as `node`.
  // If x is nullptr then fill the new tensor with "0".
  // If x is a tensor with an empty shape then fill the new tensor with the single value of x.
  // If x is a tensor with the same shape as `node` then return x as the result.
  AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr) const {
    if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) {
      return nullptr;
    }

    auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
    TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
    ShapeVector tensor_shape = tensor_abstract->shape()->shape();
    if (x == nullptr) {
      auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
      char *data = static_cast<char *>(new_tensor_ptr->data_c());
      if (memset_s(data, new_tensor_ptr->Size(), 0, new_tensor_ptr->Size()) != EOK) {
        return nullptr;
      }
      auto new_vnode = NewValueNode(new_tensor_ptr);
      new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
      return new_vnode;
    }
    // x is not nullptr.
    if (x->isa<CNode>() || x->isa<Parameter>()) {
      if ((x->abstract() == nullptr) || !x->abstract()->isa<abstract::AbstractTensor>()) {
        return nullptr;
      }
      auto x_abstract = x->abstract()->cast<abstract::AbstractTensorPtr>();
      ShapeVector x_shape = x_abstract->shape()->shape();
      if (x_shape != tensor_shape) {
        return nullptr;
      }
      return x;
    }

    if (!x->isa<ValueNode>()) {
      return nullptr;
    }
    auto x_value = x->cast<ValueNodePtr>()->value();
    if (!x_value->isa<tensor::Tensor>()) {
      return nullptr;
    }
    auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
    auto x_tensor_ptr = dyn_cast<tensor::Tensor>(x_value);
    if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) {
      return nullptr;
    }
    int ret = 0;
    char *source_data = static_cast<char *>(GetPointerToTensorData(x));
    MS_EXCEPTION_IF_NULL(source_data);
    if (x_tensor_ptr->DataSize() == 1) {
      // Broadcast the single value of x into every element of the new tensor.
      auto tensor_type_byte = GetTypeByte(tensor_type_ptr);
      char *data = static_cast<char *>(new_tensor_ptr->data_c());
      for (int64_t i = 0; i < new_tensor_ptr->ElementsNum(); i++) {
        ret = memcpy_s(data + LongToSize(i) * tensor_type_byte, tensor_type_byte, source_data, tensor_type_byte);
        if (ret != EOK) {
          MS_LOG(INFO) << "memcpy_s error, error no " << ret << ", source size " << tensor_type_byte << ", dest size "
                       << tensor_type_byte;
        }
      }
    } else {
      // Same number of elements: copy the whole buffer.
      char *data = static_cast<char *>(new_tensor_ptr->data_c());
      ret = memcpy_s(data, new_tensor_ptr->Size(), source_data, new_tensor_ptr->Size());
      if (ret != EOK) {
        MS_LOG(INFO) << "memcpy_s error, error no " << ret << ", source size " << new_tensor_ptr->Size()
                     << ", dest size " << new_tensor_ptr->DataSize();
        return nullptr;
      }
    }
    auto new_vnode = NewValueNode(new_tensor_ptr);
    new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
    return new_vnode;
  }

  AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const int &value) const {
    if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) {
      return nullptr;
    }

    auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
    TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
    ShapeVector tensor_shape = tensor_abstract->shape()->shape();

    auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
    size_t mem_size = GetTypeByte(tensor_type_ptr) * LongToSize(new_tensor_ptr->ElementsNum());
    char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());

    if (memset_s(data, mem_size, value, mem_size) != EOK) {
      return nullptr;
    }
    auto new_vnode = NewValueNode(new_tensor_ptr);
    new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
    return new_vnode;
  }

  enum BinOperator {
    ADD,
    MULTIPLY,
  };

  // Support function to add/multiply two constant tensors; partially supports broadcasting shapes.
  // The caller takes ownership of *out_data and must delete[] it.
  template <typename TM>
  void CalcByOperator(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
                      int out_data_size, BinOperator bin_operator) const {
    if (out_data_size <= 0) {
      MS_EXCEPTION(ValueError) << "out_data_size should be greater than zero";
    }
    TM *data_1 = static_cast<TM *>(in_data_1);
    TM *data_2 = static_cast<TM *>(in_data_2);
    TM *data_out = new TM[out_data_size];

    if (in_data_1_size == 1) {
      for (int i = 0; i < out_data_size; i++) {
        data_out[i] = data_1[0];
      }
    } else {
      for (int i = 0; i < out_data_size; i++) {
        data_out[i] = data_1[i];
      }
    }
    if (in_data_2_size == 1) {
      for (int i = 0; i < out_data_size; i++) {
        if (bin_operator == ADD) {
          data_out[i] += data_2[0];
        } else {
          data_out[i] *= data_2[0];
        }
      }
    } else {
      if (in_data_2_size < out_data_size) {
        MS_LOG(INFO) << "in_data_2_size:" << in_data_2_size << " is smaller than out_data_size:" << out_data_size
                     << ". in_data_2 will be broadcast.";
      }
      auto min_size = std::min<int>(in_data_2_size, out_data_size);
      for (int i = 0; i < min_size; i++) {
        if (bin_operator == ADD) {
          data_out[i] += data_2[i];
        } else {
          data_out[i] *= data_2[i];
        }
      }
      // In case in_data_2_size < out_data_size.
      for (int i = min_size; i < out_data_size; i++) {
        if (bin_operator != ADD) {
          // If the operator is MUL, data_out[i] *= 0 => data_out[i] = 0.
          data_out[i] = 0;
        }
        // If the operator is ADD, data_out[i] += 0 => data_out[i] is unchanged.
      }
    }
    *out_data = static_cast<void *>(data_out);
  }

  AnfNodePtr MulByPatternConst(const PConstant<T> &vpnode_2, const AnfNodePtr &node_3) const {
    AnfNodePtr vnode_1 = this->GetNode(captured_node_);
    AnfNodePtr vnode_2 = vpnode_2.GetNode(captured_node_);
    return CalcConstantTensors(vnode_1, vnode_2, node_3, MULTIPLY);
  }

  tensor::TensorPtr GetTensorFromValueNode(const AnfNodePtr &vnode) const {
    if (!vnode->isa<ValueNode>() || vnode->abstract() == nullptr) {
      return nullptr;
    }
    auto value = GetValueNode(vnode);
    if (!value->isa<tensor::Tensor>()) {
      return nullptr;
    }
    return dyn_cast<tensor::Tensor>(value);
  }

  AnfNodePtr CalcConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3,
                                 BinOperator bin_operator) const {
    auto tensor_ptr_1 = GetTensorFromValueNode(vnode_1);
    auto tensor_ptr_2 = GetTensorFromValueNode(vnode_2);
    if (tensor_ptr_1 == nullptr || tensor_ptr_2 == nullptr || node_3->abstract() == nullptr) {
      return nullptr;
    }
    tensor::TensorPtr new_tensor_ptr = GetNewTensor(vnode_1, vnode_2, node_3);
    if (new_tensor_ptr == nullptr) {
      return nullptr;
    }
    ShapeVector tensor_out_shape = new_tensor_ptr->shape();
    int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies<int>());
    size_t mem_size = GetTypeByte(new_tensor_ptr->Dtype()) * LongToSize(new_tensor_ptr->ElementsNum());
    char *data = static_cast<char *>(new_tensor_ptr->data_c());

    int ret = 0;
    void *data_out = nullptr;
    if ((new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat32) ||
        (new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat)) {
      CalcByOperator<float>(tensor_ptr_1->data_c(), SizeToInt(tensor_ptr_1->DataSize()), tensor_ptr_2->data_c(),
                            SizeToInt(tensor_ptr_2->DataSize()), &data_out, data_out_size, bin_operator);
      ret = memcpy_s(data, mem_size, data_out, mem_size);
      delete[] static_cast<float *>(data_out);
    } else {
      if (new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat64) {
        CalcByOperator<double>(tensor_ptr_1->data_c(), SizeToInt(tensor_ptr_1->DataSize()), tensor_ptr_2->data_c(),
                               SizeToInt(tensor_ptr_2->DataSize()), &data_out, data_out_size, bin_operator);
        ret = memcpy_s(data, mem_size, data_out, mem_size);
        delete[] static_cast<double *>(data_out);
      } else {
        if ((new_tensor_ptr->data_type() == TypeId::kNumberTypeInt32) ||
            (new_tensor_ptr->data_type() == TypeId::kNumberTypeInt)) {
          CalcByOperator<int>(tensor_ptr_1->data_c(), SizeToInt(tensor_ptr_1->DataSize()), tensor_ptr_2->data_c(),
                              SizeToInt(tensor_ptr_2->DataSize()), &data_out, data_out_size, bin_operator);
          ret = memcpy_s(data, mem_size, data_out, mem_size);
          delete[] static_cast<int *>(data_out);
        } else {
          // Unsupported data types.
          return nullptr;
        }
      }
    }
    if (ret != EOK) {
      MS_LOG(EXCEPTION) << "memcpy_s error, error no " << ret << ", source size " << mem_size << ", dest size "
                        << new_tensor_ptr->DataSize();
    }
    auto new_vnode = NewValueNode(new_tensor_ptr);
    new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
    return new_vnode;
  }

  using Internal = const PConstant<T> &;

 protected:
  mutable AnfNodePtr as_node_;
  mutable AnfNodePtr captured_node_;
  bool any_value_{true};
  int64_t check_value_{0};
  bool is_scalar_{false};
  mutable bool is_new_value_node_{false};
  mutable bool captured_{false};
  mutable bool changed_shape_{false};

 private:
  tensor::TensorPtr GetNewTensor(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2,
                                 const AnfNodePtr &node_3) const {
    auto value_1 = GetValueNode(vnode_1);
    auto value_2 = GetValueNode(vnode_2);
    auto tensor_ptr_1 = dyn_cast<tensor::Tensor>(value_1);
    auto tensor_ptr_2 = dyn_cast<tensor::Tensor>(value_2);
    auto tensor_1_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>();
    auto tensor_2_abstract = vnode_2->abstract()->cast<abstract::AbstractTensorPtr>();

    TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType();
    TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType();
    if ((tensor_1_abstract->shape()->shape() == tensor_2_abstract->shape()->shape()) &&
        (tensor_1_type_ptr->type_id() == tensor_2_type_ptr->type_id())) {
      // If the two constant nodes have the same shape, then create a new one with this shape.
      auto tensor_out_shape = tensor_1_abstract->shape()->shape();

      return std::make_shared<tensor::Tensor>(tensor_1_type_ptr->type_id(), tensor_out_shape);
    } else {
      // If the two constant nodes have different shapes, then create a new node with the shape of the 3rd node.
      auto tensor_3_abstract = node_3->abstract()->cast<abstract::AbstractTensorPtr>();

      TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType();
      if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) ||
          (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) {
        return nullptr;
      }
      auto tensor_out_shape = tensor_3_abstract->shape()->shape();
      // If the shape is dynamic, return nullptr.
      if (std::any_of(tensor_out_shape.begin(), tensor_out_shape.end(), [](int64_t i) { return i < 0; })) {
        return nullptr;
      }
      size_t data_out_size =
        LongToSize(std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies<int64_t>()));
      if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) {
        return nullptr;
      }
      if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) {
        return nullptr;
      }
      return std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape);
    }
  }
};

// Macro for binary operation functions
#define BIN_OPERATION_PATTERN(Operator, MSPrimitive, Commutative)                   \
  template <typename T, typename T2>                                                \
  inline PBinOperation<T, T2> Operator(const PBase<T> &x, const PBase<T2> &y) {     \
    return PBinOperation(MSPrimitive, x.get_object(), y.get_object(), Commutative); \
  }

// Arithmetic operations
BIN_OPERATION_PATTERN(operator+, prim::kPrimAdd, true);
BIN_OPERATION_PATTERN(operator*, prim::kPrimMul, true);
BIN_OPERATION_PATTERN(operator/, prim::kPrimRealDiv, false);
BIN_OPERATION_PATTERN(operator-, prim::kPrimSub, false);
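
// A hedged composition sketch (the `node` being matched is an assumption, not part of
// this header). Patterns built with these operators hold references to their operands,
// so compose and match them within a single expression, e.g. to recognize (a + b) * c:
//
//   PatternNode<AnfNodePtr> x, y, z;
//   if (((x + y) * z).TryCapture(node)) {
//     // x, y and z now hold the captured operands; since Add and Mul are registered
//     // as commutative above, the operand order may be swapped in the matched node.
//   }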

// Macros for match and replace
#define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \
  do {                                                    \
    if ((CaptureNode).TryCapture(OrigNode)) {             \
      auto rep = (ReplaceWith).GetNode(OrigNode);         \
      if (rep != nullptr) {                               \
        return rep;                                       \
      }                                                   \
    }                                                     \
  } while (0)

#define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition) \
  do {                                                                  \
    if ((CaptureNode).TryCapture(OrigNode) && (Condition)) {            \
      auto rep = (ReplaceWith).GetNode(OrigNode);                       \
      if (rep != nullptr) {                                             \
        return rep;                                                     \
      }                                                                 \
    }                                                                   \
  } while (0)

#define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda) \
  do {                                                      \
    if ((CaptureNode).TryCapture(OrigNode)) {               \
      auto rep = (Lambda)();                                \
      if (rep != nullptr) {                                 \
        return rep;                                         \
      }                                                     \
    }                                                       \
  } while (0)

#define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition) \
  do {                                                                    \
    if ((CaptureNode).TryCapture(OrigNode) && (Condition)) {              \
      auto rep = (Lambda)();                                              \
      if (rep != nullptr) {                                               \
        return rep;                                                       \
      }                                                                   \
    }                                                                     \
  } while (0)
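
// A minimal, hedged sketch of a rewrite rule built on these macros (the function name
// and the `node` argument are illustrative, not part of this header). The macros
// `return` the replacement when the pattern matches, so they must be used inside a
// function returning an AnfNodePtr:
//
//   AnfNodePtr SimplifyMulByOne(const AnfNodePtr &node) {
//     PatternNode<AnfNodePtr> x;
//     PConstant<AnfNodePtr> one(node, false, 1);
//     MATCH_REPLACE(node, x * one, x);  // a * 1 -> a (and 1 * a, since Mul is commutative)
//     return nullptr;                   // no rewrite applied
//   }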
}  // namespace mindspore

#endif  // MINDSPORE_CORE_IR_PATTERN_MATCHER_H_