/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_IR_PATTERN_MATCHER_H_
#define MINDSPORE_CORE_IR_PATTERN_MATCHER_H_

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <functional>
#include <memory>
#include <numeric>
#include <tuple>
#include <vector>

#include "base/core_ops.h"
#include "ir/visitor.h"
#include "utils/shape_utils.h"

namespace mindspore {
///
///  Base class for all recognizable patterns.
///  We implement an Expression Template approach using static polymorphism based on
///  the Curiously Recurring Template Pattern (CRTP) which "achieves a similar effect
///  to the use of virtual functions without the costs..." as described in:
///  https://en.wikipedia.org/wiki/Expression_templates and
///  https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
///  The TryCapture function tries to capture the pattern with the given node.
///  The GetNode function builds a new node using the captured values.
40 /// 41 42 template <typename T> 43 class PBase { 44 public: CheckFunc(const PredicateFuncType & func,const AnfNodePtr & node)45 bool CheckFunc(const PredicateFuncType &func, const AnfNodePtr &node) { return func(get_object().GetNode(node)); } 46 get_object()47 const T &get_object() const { return *static_cast<const T *>(this); } 48 49 template <typename TN> TryCapture(const TN & value)50 bool TryCapture(const TN &value) const { 51 get_object().Reset(); 52 return get_object().TryCapture_(value); 53 } 54 55 using Internal = T; 56 }; 57 58 template <typename T> 59 class PIsEqual { 60 public: operator()61 bool operator()(const T &lhs, const T &rhs) const { return lhs == rhs; } 62 }; 63 64 template <typename T = AnfNodePtr> 65 class PatternNode : public PBase<PatternNode<T> > { 66 public: GetNode(const AnfNodePtr &)67 T GetNode(const AnfNodePtr &) const { 68 if (!captured_) { 69 MS_EXCEPTION(ValueError) << "A Pattern wasn't captured for this Token before the call to GetNode."; 70 } 71 return captured_node_; 72 } 73 TryCapture_(const T & node)74 bool TryCapture_(const T &node) const { 75 if (!captured_) { 76 captured_node_ = node; 77 captured_ = true; 78 return true; 79 } 80 return PIsEqual<T>()(captured_node_, node); 81 } 82 Reset()83 void Reset() const { captured_ = false; } 84 using Internal = const PatternNode<T> &; 85 86 protected: 87 mutable T captured_node_; 88 mutable bool captured_{false}; 89 }; 90 91 template <typename T, typename T2> 92 class PBinOperation : public PBase<PBinOperation<T, T2> > { 93 public: 94 PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y, bool is_commutative = false) prim_(prim)95 : prim_(prim), x_(x), y_(y), is_commutative_(is_commutative) {} 96 ~PBinOperation() = default; 97 GetNode(const AnfNodePtr & node)98 AnfNodePtr GetNode(const AnfNodePtr &node) const { 99 AnfNodePtr lhs = x_.GetNode(node); 100 AnfNodePtr rhs = y_.GetNode(node); 101 AnfNodePtrList list = {NewValueNode(prim_), lhs, rhs}; 102 return NewCNode(list, 
node->func_graph()); 103 } 104 TryCapture_(const AnfNodePtr & node)105 bool TryCapture_(const AnfNodePtr &node) const { 106 if (IsPrimitiveCNode(node, prim_)) { 107 auto cnode = node->cast<CNodePtr>(); 108 auto inputs = cnode->inputs(); 109 if (inputs.size() == 3) { 110 // Binary Prim assumes only two inputs 111 if (!x_.TryCapture(inputs[1]) || !y_.TryCapture(inputs[2])) { 112 // If the operation is commutative, then check with inversed operands 113 if (is_commutative_) { 114 Reset(); 115 if (!x_.TryCapture(inputs[2]) || !y_.TryCapture(inputs[1])) { 116 return false; 117 } 118 captured_binop_node_ = node; 119 return true; 120 } 121 return false; 122 } 123 captured_binop_node_ = node; 124 return true; 125 } 126 } 127 return false; 128 } 129 130 /// Returns the original node captured by this Binary Operation Pattern. 131 /// Throws exception if a node was not captured before. GetOriginalNode()132 AnfNodePtr GetOriginalNode() const { 133 if (captured_binop_node_ == nullptr) { 134 MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it."; 135 } 136 137 return captured_binop_node_; 138 } 139 Reset()140 void Reset() const { 141 x_.Reset(); 142 y_.Reset(); 143 captured_binop_node_ = nullptr; 144 } 145 146 using Internal = const PBinOperation<T, T2> &; 147 148 private: 149 const PrimitivePtr prim_; 150 typename T::Internal x_; 151 typename T2::Internal y_; 152 bool is_commutative_{false}; 153 mutable AnfNodePtr captured_binop_node_{nullptr}; 154 }; 155 156 template <typename T> 157 class PUnaryOperation : public PBase<PUnaryOperation<T> > { 158 public: PUnaryOperation(const PrimitivePtr & prim,const T & x)159 PUnaryOperation(const PrimitivePtr &prim, const T &x) : prim_(prim), x_(x) {} 160 ~PUnaryOperation() = default; 161 GetNode(const AnfNodePtr & node)162 AnfNodePtr GetNode(const AnfNodePtr &node) const { 163 AnfNodePtrList list = {NewValueNode(prim_), x_.GetNode(node)}; 164 return NewCNode(list, node->func_graph()); 165 } 166 
TryCapture_(const AnfNodePtr & node)167 bool TryCapture_(const AnfNodePtr &node) const { 168 if (IsPrimitiveCNode(node, prim_)) { 169 auto cnode = node->cast<CNodePtr>(); 170 auto inputs = cnode->inputs(); 171 if (inputs.size() == 2 && x_.TryCapture(inputs[1])) { 172 captured_unaryop_node_ = node; 173 return true; 174 } 175 } 176 return false; 177 } 178 GetOriginalNode()179 AnfNodePtr GetOriginalNode() const { 180 if (captured_unaryop_node_ == nullptr) { 181 MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it."; 182 } 183 return captured_unaryop_node_; 184 } 185 Reset()186 void Reset() const { 187 x_.Reset(); 188 captured_unaryop_node_ = nullptr; 189 } 190 191 using Internal = const PUnaryOperation<T> &; 192 193 private: 194 const PrimitivePtr prim_; 195 typename T::Internal x_; 196 mutable AnfNodePtr captured_unaryop_node_{nullptr}; 197 }; 198 199 /// 200 /// Helper functions to apply a pattern function on all elements of a tuple 201 /// 202 namespace tuple_utils { 203 template <bool stop, size_t Index, typename Func> 204 struct apply_func_tuple_item { 205 template <typename TTuple> applyapply_func_tuple_item206 static void apply(Func *func, const TTuple &tuple) { 207 (*func)(Index, std::get<Index>(tuple)); 208 apply_func_tuple_item<(Index + 1) == std::tuple_size<TTuple>::value, (Index + 1), Func>::apply(func, tuple); 209 } 210 }; 211 212 template <size_t Index, typename Func> 213 struct apply_func_tuple_item<true, Index, Func> { 214 template <typename TTuple> 215 static void apply(Func *func, const TTuple &tuple) {} 216 }; 217 218 template <typename Func, typename TTuple> 219 inline void apply_func_tuple(Func *func, const TTuple &tuple) { 220 apply_func_tuple_item<std::tuple_size<TTuple>::value == 0, 0, Func>::apply(func, tuple); 221 } 222 223 struct PTupleResetCapture { 224 template <typename T> 225 void operator()(size_t i, const T &pattern) const { 226 pattern.Reset(); 227 } 228 }; 229 230 struct PTupleCapture { 231 
explicit PTupleCapture(const AnfNodePtrList tuple) : tuple_(tuple) {} 232 233 template <typename TPattern> 234 void operator()(size_t i, const TPattern &pattern) { 235 // Check if the first node is a Primitive 236 if (i == 0 && tuple_[i]->isa<Primitive>()) { 237 auto prim = tuple_[i]->cast<PrimitivePtr>(); 238 if (tuple_[i] != pattern.GetNode(tuple_[i])) { 239 captured_ = false; 240 } 241 } else { 242 captured_ = captured_ && pattern.TryCapture_(tuple_[i]); 243 } 244 } 245 246 const AnfNodePtrList tuple_; 247 bool captured_{true}; 248 }; 249 250 struct PTupleGetNode { 251 explicit PTupleGetNode(const AnfNodePtr &node) : node_(node) {} 252 253 template <typename TPattern> 254 void operator()(size_t, const TPattern &pattern) { 255 args_.push_back(pattern.GetNode(node_)); 256 } 257 258 const AnfNodePtr &node_; 259 std::vector<AnfNodePtr> args_; 260 }; 261 } // namespace tuple_utils 262 263 template <typename... TArgs> 264 class PCNode : public PBase<PCNode<TArgs...> > { 265 public: 266 explicit PCNode(const TArgs &... args) : args_(args...) {} 267 ~PCNode() = default; 268 269 AnfNodePtr GetNode(const AnfNodePtr &node) const { 270 tuple_utils::PTupleGetNode get_node(node); 271 tuple_utils::apply_func_tuple(&get_node, args_); 272 auto prim_cnode = get_node.args_; 273 // In case this PCNode has captured extra nodes 274 if (extra_nodes_.size() > 0) { 275 prim_cnode.insert(prim_cnode.begin(), extra_nodes_.begin(), extra_nodes_.end()); 276 } 277 return NewCNode(prim_cnode, node->func_graph()); 278 } 279 280 bool TryCapture_(const AnfNodePtr &node) const { 281 if (node->isa<CNode>()) { 282 auto cnode = node->cast<CNodePtr>(); 283 auto inputs = cnode->inputs(); 284 285 auto pattern_arg_len = sizeof...(TArgs); 286 // There aren't enough inputs in Node to fill up the Pattern 287 if (inputs.size() < pattern_arg_len) { 288 return false; 289 } 290 291 // Pattern must exactly match the number of Node inputs. 
292 if (!has_min_extra_nodes_) { 293 // Inputs in Node perfectly match number of tokens in Pattern. 294 if (inputs.size() == pattern_arg_len) { 295 AnfNodePtrList tokens(inputs.begin(), inputs.end()); 296 tuple_utils::PTupleCapture capture_func(tokens); 297 tuple_utils::apply_func_tuple(&capture_func, args_); 298 return capture_func.captured_; 299 } 300 return false; 301 } 302 303 // Pattern may accept extra (non specified) nodes at the end of the CNode 304 // There must be at least `min_extra_nodes` additional nodes in the inputs. 305 if (inputs.size() >= pattern_arg_len + min_extra_nodes_) { 306 AnfNodePtrList tokens(inputs.begin(), inputs.begin() + pattern_arg_len); 307 tuple_utils::PTupleCapture capture_func(tokens); 308 tuple_utils::apply_func_tuple(&capture_func, args_); 309 // If it could capture the initial set of nodes specified in the Pattern 310 // and there are enough extra inputs to add 311 if (capture_func.captured_ && inputs.size() > pattern_arg_len) { 312 extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + pattern_arg_len, inputs.end()); 313 return true; 314 } 315 return capture_func.captured_; 316 } 317 return false; 318 } 319 return false; 320 } 321 322 /// This function sets the PCNode object to capture at least `min_extra_nodes_` nodes after the last one 323 /// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or 324 /// more nodes after the last one specified when building the PCNode. 
325 const PCNode<TArgs...> &MinExtraNodes(const size_t &min_extra_nodes = 0) const { 326 has_min_extra_nodes_ = true; 327 min_extra_nodes_ = min_extra_nodes; 328 return *this; 329 } 330 331 using Internal = const PCNode<TArgs...> &; 332 333 void Reset() const { 334 tuple_utils::PTupleResetCapture reset; 335 tuple_utils::apply_func_tuple(&reset, args_); 336 extra_nodes_.clear(); 337 } 338 339 private: 340 std::tuple<typename TArgs::Internal...> args_; 341 mutable AnfNodePtrList extra_nodes_; 342 mutable bool has_min_extra_nodes_{false}; 343 mutable size_t min_extra_nodes_{0}; 344 }; 345 346 template <typename... TArgs> 347 class PPrimitive : public PBase<PPrimitive<TArgs...> > { 348 public: 349 explicit PPrimitive(const PrimitivePtr &prim, const TArgs &... args) : prim_(prim), args_(args...) {} 350 ~PPrimitive() = default; 351 352 AnfNodePtr GetNode(const AnfNodePtr &node) const { 353 tuple_utils::PTupleGetNode get_node(node); 354 tuple_utils::apply_func_tuple(&get_node, args_); 355 auto prim_cnode = get_node.args_; 356 prim_cnode.insert(prim_cnode.begin(), NewValueNode(prim_)); 357 358 // In case this PPrimitive has captured extra nodes 359 if (extra_nodes_.size() > 0) { 360 prim_cnode.insert(prim_cnode.begin(), extra_nodes_.begin(), extra_nodes_.end()); 361 } 362 return NewCNode(prim_cnode, node->func_graph()); 363 } 364 365 bool TryCapture_(const AnfNodePtr &node) const { 366 if (IsPrimitiveCNode(node, prim_)) { 367 auto cnode = node->cast<CNodePtr>(); 368 auto inputs = cnode->inputs(); 369 // Number of arguments in Primitive Pattern (not including the Primitive node) 370 auto pattern_arg_len = sizeof...(TArgs); 371 // There aren't enough inputs in Node to fill up the Pattern 372 if ((inputs.size() - 1) < pattern_arg_len) { 373 return false; 374 } 375 376 // Pattern must exactly match the number of Node inputs. 377 if (!has_min_extra_nodes_) { 378 // Inputs in Node perfectly match number of tokens in Pattern. 
379 if ((inputs.size() - 1) == pattern_arg_len) { 380 AnfNodePtrList tokens(inputs.begin() + 1, inputs.end()); 381 tuple_utils::PTupleCapture capture_func(tokens); 382 tuple_utils::apply_func_tuple(&capture_func, args_); 383 if (capture_func.captured_) { 384 captured_prim_node_ = node; 385 } 386 return capture_func.captured_; 387 } 388 return false; 389 } 390 391 // Pattern may accept extra (non specified) nodes at the end of the Primitive 392 // There must be at least `min_extra_nodes` additional nodes in the inputs. 393 if ((inputs.size() - 1) >= pattern_arg_len + min_extra_nodes_) { 394 AnfNodePtrList tokens(inputs.begin() + 1, inputs.begin() + 1 + pattern_arg_len); 395 tuple_utils::PTupleCapture capture_func(tokens); 396 tuple_utils::apply_func_tuple(&capture_func, args_); 397 // If it could capture the initial set of nodes specified in the Pattern 398 // and there are enough extra inputs to add 399 if (capture_func.captured_) { 400 captured_prim_node_ = node; 401 if (inputs.size() > pattern_arg_len + 1) { 402 extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end()); 403 } 404 } 405 return capture_func.captured_; 406 } 407 return false; 408 } 409 return false; 410 } 411 412 /// This function sets the PPrimitive object to capture at least `min_extra_nodes_` nodes after the last one 413 /// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or 414 /// more nodes after the last one specified when building the PPrimitive. 415 const PPrimitive<TArgs...> &MinExtraNodes(const size_t &min_extra_nodes = 0) const { 416 has_min_extra_nodes_ = true; 417 min_extra_nodes_ = min_extra_nodes; 418 return *this; 419 } 420 421 const AnfNodePtrList &GetCapturedExtraNodes() const { return extra_nodes_; } 422 423 /// Returns the FuncGraph of the original node captured by this Primitive Pattern. 424 /// Throws exception if a node was not captured before. 
425 FuncGraphPtr GetFuncGraph() const { 426 if (captured_prim_node_ == nullptr) { 427 MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get its FuncGraph."; 428 } 429 430 return captured_prim_node_->func_graph(); 431 } 432 433 /// Returns the original node captured by this Primitive Pattern. 434 /// Throws exception if a node was not captured before. 435 AnfNodePtr GetOriginalNode() const { 436 if (captured_prim_node_ == nullptr) { 437 MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it."; 438 } 439 440 return captured_prim_node_; 441 } 442 443 void Reset() const { 444 tuple_utils::PTupleResetCapture reset; 445 tuple_utils::apply_func_tuple(&reset, args_); 446 extra_nodes_.clear(); 447 captured_prim_node_ = nullptr; 448 } 449 450 using Internal = const PPrimitive<TArgs...> &; 451 452 private: 453 const PrimitivePtr prim_; 454 std::tuple<typename TArgs::Internal...> args_; 455 mutable AnfNodePtrList extra_nodes_; 456 mutable bool has_min_extra_nodes_{false}; 457 mutable size_t min_extra_nodes_{0}; 458 mutable AnfNodePtr captured_prim_node_{nullptr}; 459 }; 460 461 /// 462 /// PConstant class can capture a value node of a specified value (check_value_) 463 /// or a non-specified one (any_value = true). 
464 /// It can be configured to capture a scalar constant as well (is_scalar_ = true) 465 /// 466 template <typename T = AnfNodePtr> 467 class PConstant : public PBase<PConstant<T> > { 468 public: 469 explicit PConstant(const AnfNodePtr &as_node, const bool any_value = true, const int64_t check_value = 0, 470 const bool is_scalar = false) 471 : as_node_(as_node), 472 captured_node_(as_node), 473 any_value_(any_value), 474 check_value_(check_value), 475 is_scalar_(is_scalar) {} 476 477 ~PConstant() = default; 478 // Sets as_node_ as the node received as argument to produce a same-shape node with GetNode 479 const PConstant<T> &WithShapeAs(const AnfNodePtr &node) const { 480 if (node == nullptr) { 481 MS_EXCEPTION(ValueError) << "WithShapeAs is trying to use a nullptr node."; 482 } 483 as_node_ = node; 484 changed_shape_ = true; 485 return *this; 486 } 487 488 // Sets as_node_ as the node caputred by the received Pattern token to produce a same-shape node with GetNode 489 const PConstant<T> &WithShapeAs(const PatternNode<T> &pnode) const { 490 if (captured_node_ == nullptr) { 491 MS_EXCEPTION(ValueError) << "WithShapeAs is trying to use a Pattern token without previously capturing a node."; 492 } 493 as_node_ = pnode.GetNode(captured_node_); 494 changed_shape_ = true; 495 return *this; 496 } 497 498 /// Sets captured_node_ as the node captured by the Pattern received as argument 499 /// to produce a new node with its contents when calling GetNode. 
500 const PConstant<T> &WithValueOf(const PatternNode<T> &pnode) const { 501 if (!any_value_) { 502 MS_EXCEPTION(ValueError) << "Must use a PConstant with `any_value = true` to use the value of another node."; 503 } 504 if (captured_node_ == nullptr) { 505 MS_EXCEPTION(ValueError) << "WithValueOf is trying to use a Pattern token without previously capturing a node."; 506 } 507 captured_node_ = pnode.GetNode(captured_node_); 508 changed_shape_ = true; 509 return *this; 510 } 511 512 /// Create a new Value Node filled up with check_value. 513 /// This function must be used immediately before GetNode to avoid replacing the expected result. 514 /// Only valid for scalar constants. For tensors use WithShapeAs or WithValueOf. 515 const PConstant<T> &NewValue() const { 516 if (!is_scalar_) { 517 MS_EXCEPTION(ValueError) << "NewValue is valid only for scalar PConstants."; 518 } 519 auto value_node_ = MakeValue(check_value_); 520 captured_node_ = NewValueNode(value_node_); 521 is_new_value_node_ = true; 522 return *this; 523 } 524 525 AnfNodePtr GetNode(const AnfNodePtr &node) const { 526 // If a NewValueNode was requested (using NewValue function) then return that created node. 527 if (is_new_value_node_) { 528 return captured_node_; 529 } 530 /// Return a NewTensorFilledWithData if the node was initialized to have a specific value 531 /// even if it wasn't captured. Usually for zero constants (x - x => zero). 532 /// If the shape was changed, use the new shape. 533 if (changed_shape_ || !captured_) { 534 if (!any_value_) { 535 return NewTensorFilledWithData(as_node_, check_value_); 536 } 537 return NewTensorFilledWithData(as_node_, captured_node_); 538 } 539 return captured_node_; 540 } 541 542 bool TryCapture_(const AnfNodePtr &node) const { 543 if (node->isa<ValueNode>()) { 544 // If any_value_ is set don't check for the node's value. Just capture it. 
545 if (any_value_) { 546 captured_node_ = node; 547 captured_ = true; 548 return true; 549 } 550 551 auto value = node->cast<ValueNodePtr>()->value(); 552 if ((is_scalar_ && IsTensorScalarConstant(value)) || (!is_scalar_ && IsTensorConstant(value))) { 553 captured_node_ = node; 554 captured_ = true; 555 return true; 556 } 557 558 auto value_node_ = MakeValue(check_value_); 559 if (*GetValueNode(node) == *value_node_) { 560 captured_node_ = node; 561 captured_ = true; 562 return true; 563 } 564 } 565 return false; 566 } 567 568 void Reset() const { 569 captured_ = false; 570 changed_shape_ = false; 571 is_new_value_node_ = false; 572 } 573 574 // Support function used for checking if all values of a Tensor are equal to `check_value_` 575 // Supported data types: double, float/float32, int/int32 576 bool IsTensorConstant(const ValuePtr &value) const { 577 if (!value->isa<tensor::Tensor>()) { 578 return false; 579 } 580 auto tensor_ptr = dyn_cast<tensor::Tensor>(value); 581 TypeId tensor_type = tensor_ptr->Dtype()->type_id(); 582 if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { 583 float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c()); 584 auto threshold = FLT_MIN; 585 for (int i = 0; i < tensor_ptr->DataSize(); i++) { 586 if (fabs(data2[i] - check_value_) > threshold) { 587 return false; 588 } 589 } 590 return true; 591 } else if (tensor_type == TypeId::kNumberTypeFloat64) { 592 double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c()); 593 auto threshold = DBL_MIN; 594 for (int i = 0; i < tensor_ptr->DataSize(); i++) { 595 if (fabs(data2[i] - check_value_) > threshold) { 596 return false; 597 } 598 } 599 return true; 600 } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { 601 int *data2 = reinterpret_cast<int *>(tensor_ptr->data_c()); 602 for (int i = 0; i < tensor_ptr->DataSize(); i++) { 603 if (data2[i] != check_value_) { 604 return false; 605 } 606 } 607 
return true; 608 } 609 // Input Data Type is not supported 610 return false; 611 } 612 613 bool IsTensorScalarConstant(const ValuePtr &value) const { 614 if (!value->isa<tensor::Tensor>()) { 615 return false; 616 } 617 auto tensor_ptr = dyn_cast<tensor::Tensor>(value); 618 if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) { 619 return false; 620 } 621 return IsTensorConstant(value); 622 } 623 624 void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false) const { 625 if (!node->isa<ValueNode>()) { 626 return nullptr; 627 } 628 auto value = node->cast<ValueNodePtr>()->value(); 629 if (!value->isa<tensor::Tensor>()) { 630 return nullptr; 631 } 632 633 tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value); 634 return tensor_ptr->data_c(); 635 } 636 637 // Make a new tensor (when possible) with the same shape as of `node` 638 // If x is nullptr then fill new tensor will "0" 639 // If x is a tensor with empty shape then fill new tensor with the single value of x 640 // If x is a tensor with same shape as `node` then return x as result 641 AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr) const { 642 if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) { 643 return nullptr; 644 } 645 646 auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>(); 647 TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); 648 ShapeVector tensor_shape = tensor_abstract->shape()->shape(); 649 650 auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape); 651 size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); 652 char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c()); 653 654 if (x == nullptr) { 655 if (memset_s(data, mem_size, 0, mem_size) != 0) { 656 return nullptr; 657 } 658 auto new_vnode = NewValueNode(new_tensor_ptr); 659 
new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); 660 return new_vnode; 661 } 662 // x is not nullptr 663 if (x->isa<CNode>() || x->isa<Parameter>()) { 664 if ((x->abstract() == nullptr) || !x->abstract()->isa<abstract::AbstractTensor>()) { 665 return nullptr; 666 } 667 auto x_abstract = x->abstract()->cast<abstract::AbstractTensorPtr>(); 668 ShapeVector x_shape = x_abstract->shape()->shape(); 669 if (x_shape != tensor_shape) { 670 return nullptr; 671 } 672 return x; 673 } 674 675 if (!x->isa<ValueNode>()) { 676 return nullptr; 677 } 678 auto x_value = x->cast<ValueNodePtr>()->value(); 679 if (!x_value->isa<tensor::Tensor>()) { 680 return nullptr; 681 } 682 683 auto x_tensor_ptr = dyn_cast<tensor::Tensor>(x_value); 684 if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) { 685 return nullptr; 686 } 687 int ret = 0; 688 char *source_data = reinterpret_cast<char *>(GetPointerToTensorData(x)); 689 MS_EXCEPTION_IF_NULL(source_data); 690 if (x_tensor_ptr->DataSize() == 1) { 691 auto tensor_type_byte = GetTypeByte(tensor_type_ptr); 692 for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) { 693 ret = memcpy_s(data + i * tensor_type_byte, tensor_type_byte, source_data, tensor_type_byte); 694 if (ret != 0) { 695 MS_LOG(INFO) << "memcpy_s error, error no " << ret << ", source size " << tensor_type_byte << ", dest size " 696 << tensor_type_byte; 697 } 698 } 699 } else { 700 ret = memcpy_s(data, mem_size, source_data, mem_size); 701 if (ret != 0) { 702 MS_LOG(INFO) << "memcpy_s error, error no " << ret << ", source size " << mem_size << ", dest size " 703 << new_tensor_ptr->DataSize(); 704 return nullptr; 705 } 706 } 707 auto new_vnode = NewValueNode(new_tensor_ptr); 708 new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); 709 return new_vnode; 710 } 711 712 AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const int &value) const { 713 if ((node->abstract() == nullptr) || 
!node->abstract()->isa<abstract::AbstractTensor>()) { 714 return nullptr; 715 } 716 717 auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>(); 718 TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); 719 ShapeVector tensor_shape = tensor_abstract->shape()->shape(); 720 721 auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape); 722 size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); 723 char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c()); 724 725 if (memset_s(data, mem_size, value, mem_size) != 0) { 726 return nullptr; 727 } 728 auto new_vnode = NewValueNode(new_tensor_ptr); 729 new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); 730 return new_vnode; 731 } 732 733 template <typename TD> 734 TD CalcuConstant(const TD &data, const PrimitivePtr &calcu_type) { 735 TD tmp_data = data; 736 if (calcu_type == prim::kPrimReciprocal) { 737 if (data == 0) { 738 MS_EXCEPTION(ValueError); 739 } else { 740 tmp_data = 1 / data; 741 } 742 } 743 if (calcu_type == prim::kPrimNeg) { 744 tmp_data = -data; 745 } 746 return tmp_data; 747 } 748 749 template <typename TD> 750 bool TensorCopyData(const tensor::TensorPtr &src_tensor_ptr, const tensor::TensorPtr &dst_tensor_ptr, 751 const PrimitivePtr &calcu_type, size_t mem_size) { 752 auto *data = reinterpret_cast<TD *>(src_tensor_ptr->data_c()); 753 auto *data2 = reinterpret_cast<TD *>(dst_tensor_ptr->data_c()); 754 if (memcpy_s(data2, mem_size, data, mem_size) != 0) { 755 return false; 756 } 757 for (int i = 0; i < src_tensor_ptr->DataSize(); i++) { 758 if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) { 759 return false; 760 } 761 data2[i] = CalcuConstant(data2[i], calcu_type); 762 } 763 return true; 764 } 765 766 // calculate const with different operations 767 AnfNodePtr CalcuConstantTensor(const AnfNodePtr &node, const ValuePtr &value, const PrimitivePtr &calcu_type) { 768 tensor::TensorPtr 
tensor_ptr = dyn_cast<tensor::Tensor>(value); 769 TypeId tensor_type = tensor_ptr->Dtype()->type_id(); 770 auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>(); 771 TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); 772 ShapeVector tensor_shape = tensor_abstract->shape()->shape(); 773 auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape); 774 size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); 775 if (new_tensor_ptr->DataSize() < tensor_ptr->DataSize()) { 776 MS_EXCEPTION(ValueError) << "DataSize of new_tensor_ptr is smaller than DataSize of tensor_ptr"; 777 } 778 if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat) || 779 (tensor_type == TypeId::kNumberTypeFloat64)) { 780 if (!TensorCopyData<float>(tensor_ptr, new_tensor_ptr, calcu_type, mem_size)) { 781 return nullptr; 782 } 783 } 784 if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { 785 if (!TensorCopyData<int>(tensor_ptr, new_tensor_ptr, calcu_type, mem_size)) { 786 return nullptr; 787 } 788 } 789 if (tensor_type == TypeId::kNumberTypeFloat64) { 790 if (!TensorCopyData<double>(tensor_ptr, new_tensor_ptr, calcu_type, mem_size)) { 791 return nullptr; 792 } 793 } 794 auto new_vnode = NewValueNode(new_tensor_ptr); 795 new_vnode->set_abstract(tensor_ptr->ToAbstract()); 796 return new_vnode; 797 } 798 799 // calculate const with different operations 800 AnfNodePtr ValueNodeWithOprations(const PrimitivePtr &calcu_type) { 801 AnfNodePtr node = this->GetNode(captured_node_); 802 if (!node->isa<ValueNode>()) { 803 MS_EXCEPTION(ValueError) << "CalcuValue is trying to use a not ValueNode."; 804 } 805 auto value = node->cast<ValueNodePtr>()->value(); 806 if (value->isa<tensor::Tensor>()) { 807 return CalcuConstantTensor(node, value, calcu_type); 808 } 809 return nullptr; 810 } 811 812 enum BinOperator { 813 ADD = 0, 814 
MULTIPLY, 815 }; 816 817 // Support function to add/multiply two constant tensors: partially support broadcasting shapes 818 template <typename TM> 819 void CalcByOperator(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, 820 int out_data_size, BinOperator bin_operator) const { 821 if (out_data_size <= 0) { 822 MS_EXCEPTION(ValueError) << "out_data_size should be greater than zeros"; 823 } 824 TM *data_1 = reinterpret_cast<TM *>(in_data_1); 825 TM *data_2 = reinterpret_cast<TM *>(in_data_2); 826 TM *data_out = new TM[out_data_size]; 827 828 if (in_data_1_size == 1) { 829 for (int i = 0; i < out_data_size; i++) { 830 data_out[i] = data_1[0]; 831 } 832 } else { 833 for (int i = 0; i < out_data_size; i++) { 834 data_out[i] = data_1[i]; 835 } 836 } 837 if (in_data_2_size == 1) { 838 for (int i = 0; i < out_data_size; i++) { 839 if (bin_operator == ADD) { 840 data_out[i] += data_2[0]; 841 } else { 842 data_out[i] *= data_2[0]; 843 } 844 } 845 } else { 846 if (in_data_2_size < out_data_size) { 847 MS_LOG(INFO) << "in_data_2_size:" << in_data_2_size << " is smaller than out_data_size:" << out_data_size 848 << ".in_data2 will be broadcast."; 849 } 850 auto min_size = std::min<int>(in_data_2_size, out_data_size); 851 for (int i = 0; i < min_size; i++) { 852 if (bin_operator == ADD) { 853 data_out[i] += data_2[i]; 854 } else { 855 data_out[i] *= data_2[i]; 856 } 857 } 858 // In case of in_data2_size < out_data_size 859 for (int i = min_size; i < out_data_size; i++) { 860 if (bin_operator != ADD) { 861 // if operator is MUL, data_out[i] *= 0, => data_out[i] = 0. 862 data_out[i] = 0; 863 } 864 // if operator is ADD, data_out[i] += 0, => data_out[i] = data_out[i], => NoChange. 
865 } 866 } 867 *out_data = reinterpret_cast<void *>(data_out); 868 return; 869 } 870 871 AnfNodePtr AddByPatternConst(const PConstant<T> &vpnode_2, const AnfNodePtr &node_3) const { 872 AnfNodePtr vnode_1 = this->GetNode(captured_node_); 873 AnfNodePtr vnode_2 = vpnode_2.GetNode(captured_node_); 874 return CalcConstantTensors(vnode_1, vnode_2, node_3, ADD); 875 } 876 877 AnfNodePtr MulByPatternConst(const PConstant<T> &vpnode_2, const AnfNodePtr &node_3) const { 878 AnfNodePtr vnode_1 = this->GetNode(captured_node_); 879 AnfNodePtr vnode_2 = vpnode_2.GetNode(captured_node_); 880 return CalcConstantTensors(vnode_1, vnode_2, node_3, MULTIPLY); 881 } 882 883 tensor::TensorPtr GetTensorFromValueNode(const AnfNodePtr &vnode) const { 884 if (!vnode->isa<ValueNode>() || vnode->abstract() == nullptr) { 885 return nullptr; 886 } 887 auto value = GetValueNode(vnode); 888 if (!value->isa<tensor::Tensor>()) { 889 return nullptr; 890 } 891 return dyn_cast<tensor::Tensor>(value); 892 } 893 894 AnfNodePtr CalcConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3, 895 BinOperator bin_operator) const { 896 auto tensor_ptr_1 = GetTensorFromValueNode(vnode_1); 897 auto tensor_ptr_2 = GetTensorFromValueNode(vnode_2); 898 if (tensor_ptr_1 == nullptr || tensor_ptr_2 == nullptr || node_3->abstract() == nullptr) { 899 return nullptr; 900 } 901 tensor::TensorPtr new_tensor_ptr = GetNewTensor(vnode_1, vnode_2, node_3); 902 if (new_tensor_ptr == nullptr) { 903 return nullptr; 904 } 905 ShapeVector tensor_out_shape = new_tensor_ptr->shape(); 906 int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies<int>()); 907 size_t mem_size = GetTypeByte(new_tensor_ptr->Dtype()) * IntToSize(new_tensor_ptr->ElementsNum()); 908 char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c()); 909 910 int ret = 0; 911 void *data_out = nullptr; 912 if ((new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat32) || 913 
(new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat)) { 914 CalcByOperator<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), 915 tensor_ptr_2->DataSize(), &data_out, data_out_size, bin_operator); 916 ret = memcpy_s(data, mem_size, data_out, mem_size); 917 delete[] reinterpret_cast<float *>(data_out); 918 } else { 919 if (new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat64) { 920 CalcByOperator<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), 921 tensor_ptr_2->DataSize(), &data_out, data_out_size, bin_operator); 922 ret = memcpy_s(data, mem_size, data_out, mem_size); 923 delete[] reinterpret_cast<double *>(data_out); 924 } else { 925 if ((new_tensor_ptr->data_type() == TypeId::kNumberTypeInt32) || 926 (new_tensor_ptr->data_type() == TypeId::kNumberTypeInt)) { 927 CalcByOperator<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), 928 tensor_ptr_2->DataSize(), &data_out, data_out_size, bin_operator); 929 ret = memcpy_s(data, mem_size, data_out, mem_size); 930 delete[] reinterpret_cast<int *>(data_out); 931 } else { 932 // Unsupported data types 933 return nullptr; 934 } 935 } 936 } 937 if (ret != 0) { 938 MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret << ", source size " << mem_size << "dest size" 939 << new_tensor_ptr->DataSize(); 940 } 941 auto new_vnode = NewValueNode(new_tensor_ptr); 942 new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); 943 return new_vnode; 944 } 945 946 using Internal = const PConstant<T> &; 947 948 protected: 949 mutable AnfNodePtr as_node_; 950 mutable AnfNodePtr captured_node_; 951 bool any_value_{true}; 952 int64_t check_value_{0}; 953 bool is_scalar_{false}; 954 mutable bool is_new_value_node_{false}; 955 mutable bool captured_{false}; 956 mutable bool changed_shape_{false}; 957 958 private: 959 tensor::TensorPtr GetNewTensor(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3) const { 960 auto 
value_1 = GetValueNode(vnode_1); 961 auto value_2 = GetValueNode(vnode_2); 962 auto tensor_ptr_1 = dyn_cast<tensor::Tensor>(value_1); 963 auto tensor_ptr_2 = dyn_cast<tensor::Tensor>(value_2); 964 auto tensor_1_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>(); 965 auto tensor_2_abstract = vnode_2->abstract()->cast<abstract::AbstractTensorPtr>(); 966 967 TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType(); 968 TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType(); 969 if ((tensor_1_abstract->shape()->shape() == tensor_2_abstract->shape()->shape()) && 970 (tensor_1_type_ptr->type_id() == tensor_2_type_ptr->type_id())) { 971 // If two constant nodes have the same shape, then create a new one with this shape 972 auto tensor_out_shape = tensor_1_abstract->shape()->shape(); 973 974 return std::make_shared<tensor::Tensor>(tensor_1_type_ptr->type_id(), tensor_out_shape); 975 } else { 976 // If two constant nodes have different shapes, then create a new one node with the shape of the 3rd node 977 auto tensor_3_abstract = node_3->abstract()->cast<abstract::AbstractTensorPtr>(); 978 979 TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); 980 if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) || 981 (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) { 982 return nullptr; 983 } 984 auto tensor_out_shape = tensor_3_abstract->shape()->shape(); 985 int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies<int>()); 986 if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { 987 return nullptr; 988 } 989 if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { 990 return nullptr; 991 } 992 return std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape); 993 } 994 } 995 }; 996 997 // Macro for binary operation functions 998 #define 
BIN_OPERATION_PATTERN(Operator, MSPrimitive, Commutative) \ 999 template <typename T, typename T2> \ 1000 inline PBinOperation<T, T2> Operator(const PBase<T> &x, const PBase<T2> &y) { \ 1001 return PBinOperation(MSPrimitive, x.get_object(), y.get_object(), Commutative); \ 1002 } 1003 1004 // Arithmetic operations 1005 BIN_OPERATION_PATTERN(operator+, prim::kPrimAdd, true); 1006 BIN_OPERATION_PATTERN(operator*, prim::kPrimMul, true); 1007 BIN_OPERATION_PATTERN(operator/, prim::kPrimRealDiv, false); 1008 BIN_OPERATION_PATTERN(operator-, prim::kPrimSub, false); 1009 1010 // Macros for match and replace 1011 #define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \ 1012 if ((CaptureNode).TryCapture(OrigNode)) { \ 1013 auto rep = (ReplaceWith).GetNode(OrigNode); \ 1014 if (rep != nullptr) { \ 1015 return rep; \ 1016 } \ 1017 } 1018 1019 #define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition) \ 1020 if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ 1021 auto rep = (ReplaceWith).GetNode(OrigNode); \ 1022 if (rep != nullptr) { \ 1023 return rep; \ 1024 } \ 1025 } 1026 1027 #define MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode) \ 1028 if ((CaptureNode).TryCapture(OrigNode)) { \ 1029 if ((Condition)) { \ 1030 auto rep = (ReplaceWith).GetNode(OrigNode); \ 1031 if (rep != nullptr) { \ 1032 return (ReplaceWith).GetNode(OrigNode); \ 1033 } \ 1034 } else { \ 1035 auto rep = (ElseNode).GetNode(OrigNode); \ 1036 if (rep != nullptr) { \ 1037 return (ElseNode).GetNode(OrigNode); \ 1038 } \ 1039 } \ 1040 } 1041 1042 #define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda) \ 1043 if ((CaptureNode).TryCapture(OrigNode)) { \ 1044 auto rep = (Lambda)(); \ 1045 if (rep != nullptr) { \ 1046 return rep; \ 1047 } \ 1048 } 1049 1050 #define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition) \ 1051 if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ 1052 auto rep = (Lambda)(); \ 1053 if (rep != nullptr) { \ 1054 
return rep; \ 1055 } \ 1056 } 1057 1058 #define MATCH_REPLACE_LAMBDA_FLAG(OrigNode, CaptureNode, Lambda, Flag) \ 1059 if ((CaptureNode).TryCapture(OrigNode)) { \ 1060 auto rep = (Lambda)(Flag); \ 1061 if (rep != nullptr) { \ 1062 return rep; \ 1063 } \ 1064 } 1065 } // namespace mindspore 1066 1067 #endif // MINDSPORE_CORE_IR_PATTERN_MATCHER_H_ 1068