1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_CUSTOM_NODES_H_ 17 #define MINDSPORE_PI_JIT_CUSTOM_NODES_H_ 18 19 #include <memory> 20 #include <string> 21 #include "pipeline/jit/pi/pydef.h" 22 #include "pipeline/jit/pi/graph_compiler/pi_ir/operation.h" 23 24 namespace mindspore { 25 namespace pijit { 26 namespace ir { 27 /// \brief RefNode is the class which represent that this node is defined elsewhere and is only used here. 28 class RefNode : public Node { 29 public: 30 /** 31 * \brief The constructor of reference node. 32 * 33 * \return The instance of reference node. 34 */ RefNode(const NodePtr & node)35 explicit RefNode(const NodePtr &node) : real_node_(node) {} 36 37 // \brief Destructor. 38 ~RefNode() override = default; 39 JIT_DECLARE_PARENT(RefNode, Node); 40 41 /** 42 * \brief Get the node this reference node represents. 43 * 44 * \return The node this reference node represents. 45 */ GetRealNode()46 const NodePtr &GetRealNode() const { return real_node_; } 47 48 /** 49 * \brief Set the real node of the ref node. 50 * 51 * \param[in] node the real object. 52 */ SetRealNode(const NodePtr & node)53 void SetRealNode(const NodePtr &node) { real_node_ = node; } 54 55 /** 56 * \brief Get the description of this node. 57 * \return The description. 58 */ ToString()59 std::string ToString() const override { 60 return "%" + std::to_string(GetNodeId()) + " = [" + GetType()->GetName() + "](" + GetNodeName() + ", " + 61 std::to_string(real_node_->GetNodeId()) + ")\n"; 62 } 63 64 private: 65 /// \brief The node this reference node represents 66 NodePtr real_node_; 67 }; 68 69 using RefNodePtr = std::shared_ptr<RefNode>; 70 71 /// \brief PlaceHolder is the class which represent a symbol, and don't care about object specific information. 72 class PlaceHolder : public Node { 73 public: 74 /** 75 * \brief The constructor of PlaceHolder node. 76 * 77 * \return The instance of PlaceHolder node. 78 */ PlaceHolder(const std::string & tag)79 explicit PlaceHolder(const std::string &tag) : tag_(tag) {} 80 81 // \brief Destructor. 82 ~PlaceHolder() override = default; 83 JIT_DECLARE_PARENT(PlaceHolder, Node); 84 85 /** 86 * \brief Get the tag of PlaceHolder node. 87 * 88 * \return The tag of PlaceHolder node. 89 */ GetTag()90 const std::string &GetTag() const { return tag_; } 91 92 /** 93 * \brief Set the id of this node. 94 * 95 * \note This method should not be actively called by the program writer, it should only be called by the method 96 * Sort() 97 */ SetNodeId(size_t * id)98 void SetNodeId(size_t *id) override {} 99 100 /** 101 * \brief Get the description of this node. 102 * \return The description. 103 */ ToString()104 std::string ToString() const override { 105 return "%" + std::to_string(GetNodeId()) + " = [" + GetType()->GetName() + "](" + GetNodeName() + ", " + tag_ + 106 ")\n"; 107 } 108 109 private: 110 /// \brief The mark of PlaceHolder used to explain the special meaning 111 const std::string tag_; 112 }; 113 114 using PlaceHolderPtr = std::shared_ptr<PlaceHolder>; 115 116 /// \brief SubscrNode is the class which represent a subscript access of object. 117 class SubscrNode : public Node { 118 public: 119 /** 120 * \brief The constructor of subscript node. 121 * 122 * \param[in] base the object being accessed. 123 * \param[in] subscr the subscript. 124 * 125 * \return The instance of subscript node. 126 */ SubscrNode(const NodePtr & base,const NodePtr & subscr)127 SubscrNode(const NodePtr &base, const NodePtr &subscr) : base_(base), subscr_(subscr) {} 128 129 // \brief Destructor. 130 ~SubscrNode() override = default; 131 JIT_DECLARE_PARENT(SubscrNode, Node); 132 133 /** 134 * \brief Get the object being accessed. 135 * 136 * \return The object being accessed. 137 */ GetObject()138 const NodePtr &GetObject() const { return base_; } 139 140 /** 141 * \brief Set the the object being accessed. 142 * 143 * \param[in] obj the object. 144 */ SetObject(const NodePtr & obj)145 void SetObject(const NodePtr &obj) { base_ = obj; } 146 147 /** 148 * \brief Get the subscr want to accessed. 149 * 150 * \return The subscr want to accessed. 151 */ GetSubscr()152 const NodePtr &GetSubscr() const { return subscr_; } 153 154 /** 155 * \brief Set the subscr want to accessed. 156 * 157 * \param[in] subscr the element. 158 */ SetSubscr(const NodePtr & subscr)159 void SetSubscr(const NodePtr &subscr) { subscr_ = subscr; } 160 161 /** 162 * \brief Set the id of this node. 163 * 164 * \note This method should not be actively called by the program writer, it should only be called by the method 165 * Sort() 166 */ SetNodeId(size_t * id)167 void SetNodeId(size_t *id) override { 168 base_->SetNodeId(id); 169 subscr_->SetNodeId(id); 170 } 171 172 /** 173 * \brief Set the offset of this node. 174 * 175 * \note This method should not be actively called by the program writer, it should only be called by the method 176 * Sort() 177 */ SetOffset(size_t * offset)178 void SetOffset(size_t *offset) override { 179 base_->SetOffset(offset); 180 subscr_->SetOffset(offset); 181 } 182 183 /** 184 * \brief Get the description of this node. 185 * \return The description. 186 */ ToString()187 std::string ToString() const override { 188 return base_->ToString() + "\n" + subscr_->ToString() + "\n%" + std::to_string(GetNodeId()) + " = %" + 189 std::to_string(base_->GetNodeId()) + "[%" + std::to_string(subscr_->GetNodeId()) + "]\n"; 190 } 191 192 private: 193 NodePtr base_; 194 NodePtr subscr_; 195 }; 196 197 using SubscrNodePtr = std::shared_ptr<SubscrNode>; 198 199 /// \brief SubscrNode is the class which represent a attr or method of the object. 200 class AttrNode : public Node { 201 public: 202 /** 203 * \brief The constructor of attribute node. 204 * 205 * \param[in] base the object being accessed. 206 * \param[in] attr the attribute name. 207 * 208 * \return The instance of attribute node. 209 */ AttrNode(const NodePtr & base,const NodePtr & attr)210 AttrNode(const NodePtr &base, const NodePtr &attr) : base_(base), attr_(attr) {} 211 212 // \brief Destructor. 213 ~AttrNode() override = default; 214 JIT_DECLARE_PARENT(AttrNode, Node); 215 216 /** 217 * \brief Get the object being accessed. 218 * 219 * \return The object being accessed. 220 */ GetObject()221 const NodePtr &GetObject() const { return base_; } 222 223 /** 224 * \brief Set the object being accessed. 225 * 226 * \param[in] obj the object. 227 */ SetObject(const NodePtr & obj)228 void SetObject(const NodePtr &obj) { base_ = obj; } 229 230 /** 231 * \brief Get the attribute name of the object. 232 * 233 * \return The attribute name of the object. 234 */ GetAttr()235 const NodePtr &GetAttr() const { return attr_; } 236 237 /** 238 * \brief Set the attribute name of the object. 239 * 240 * \param[in] attr the attribute name. 241 */ SetAttr(const NodePtr & attr)242 void SetAttr(const NodePtr &attr) { attr_ = attr; } 243 244 /** 245 * \brief Set the id of this node. 246 * 247 * \note This method should not be actively called by the program writer, it should only be called by the method 248 * Sort() 249 */ SetNodeId(size_t * id)250 void SetNodeId(size_t *id) override { 251 base_->SetNodeId(id); 252 attr_->SetNodeId(id); 253 Node::SetNodeId(id); 254 } 255 256 /** 257 * \brief Set the offset of this node. 258 * 259 * \note This method should not be actively called by the program writer, it should only be called by the method 260 * Sort() 261 */ SetOffset(size_t * offset)262 void SetOffset(size_t *offset) override { base_->SetOffset(offset); } 263 264 /** 265 * \brief Get the description of this node. 266 * \return The description. 267 */ ToString()268 std::string ToString() const override { 269 return base_->ToString() + "\n" + attr_->ToString() + "\n%" + std::to_string(GetNodeId()) + " = %" + 270 std::to_string(base_->GetNodeId()) + ".%" + std::to_string(attr_->GetNodeId()) + "\n"; 271 } 272 273 private: 274 NodePtr base_; 275 NodePtr attr_; 276 }; 277 278 using AttrNodePtr = std::shared_ptr<AttrNode>; 279 280 /// \brief PairNode is the class which represent the object subscript access. 281 class PairNode : public Node { 282 public: 283 /** 284 * \brief The constructor of pair node. 285 * 286 * \param[in] first the first element of the pair. 287 * \param[in] second the second element of the pair. 288 * 289 * \return The instance of pair node. 290 */ PairNode(const NodePtr & first,const NodePtr & second)291 PairNode(const NodePtr &first, const NodePtr &second) : first_(first), second_(second) {} 292 293 // \brief Destructor. 294 ~PairNode() override = default; 295 JIT_DECLARE_PARENT(PairNode, Node); 296 297 /** 298 * \brief Get the first element of the pair. 299 * 300 * \return The first element of the pair. 301 */ GetFirst()302 const NodePtr &GetFirst() const { return first_; } 303 304 /** 305 * \brief Set the first element of the pair. 306 * 307 * \param[in] arg the element. 308 */ SetFirst(const NodePtr & arg)309 void SetFirst(const NodePtr &arg) { first_ = arg; } 310 311 /** 312 * \brief Get the second element of the pair. 313 * 314 * \return The second element of the pair. 315 */ GetSecond()316 const NodePtr &GetSecond() const { return second_; } 317 318 /** 319 * \brief Set the second element of the pair. 320 * 321 * \param[in] arg the element. 322 */ SetSecond(const NodePtr & arg)323 void SetSecond(const NodePtr &arg) { second_ = arg; } 324 325 /** 326 * \brief Get the description of this node. 327 * \return The description. 328 */ ToString()329 std::string ToString() const override { 330 return first_->ToString() + "\n" + second_->ToString() + "\n%" + std::to_string(GetNodeId()) + " = (" + 331 std::to_string(first_->GetNodeId()) + ", " + std::to_string(second_->GetNodeId()) + ")\n"; 332 } 333 334 private: 335 NodePtr first_; 336 NodePtr second_; 337 }; 338 339 using PairNodePtr = std::shared_ptr<PairNode>; 340 341 /// \brief InstrArg is the base class which represent the arg of instruction. 342 class InstrArg { 343 public: 344 /** 345 * \brief The constructor of InstrArg. 346 * 347 * \param[in] arg the value of arg. 348 * 349 * \return The instance of InstrArg. 350 */ InstrArg(int arg)351 explicit InstrArg(int arg) : instr_arg_(arg) {} 352 // \brief Destructor. 353 virtual ~InstrArg() = default; 354 355 /** 356 * \brief Get the value of the instruction arg. 357 * 358 * \return The value of the instruction arg. 359 */ GetInstrArg()360 int GetInstrArg() const { return instr_arg_; } 361 362 /** 363 * \brief Set the value of the instruction arg. 364 * 365 * \param[in] arg the value of the instruction arg. 366 */ SetInstrArg(int arg)367 void SetInstrArg(int arg) { instr_arg_ = arg; } 368 369 private: 370 /// \brief The value of the instruction arg. 371 int instr_arg_; 372 }; 373 374 /// \brief NegativeNode is the class which represent operation that take negative value. 375 class NegativeNode : public UnaryOperation { 376 public: 377 /** 378 * \brief The constructor of negative node. 379 * 380 * \param[in] opnd the value of negative node. 381 * 382 * \return The instance of negative node. 383 */ NegativeNode(const NodePtr & opnd)384 explicit NegativeNode(const NodePtr &opnd) : UnaryOperation(UNARY_NEGATIVE, opnd) {} 385 386 // \brief Destructor. 387 ~NegativeNode() override = default; 388 JIT_DECLARE_PARENT(NegativeNode, UnaryOperation); 389 }; 390 391 using NegativeNodePtr = std::shared_ptr<NegativeNode>; 392 393 /// \brief NotNode is the class which represent the operation that take logical negation. 394 class NotNode : public UnaryOperation { 395 public: 396 /** 397 * \brief The constructor of logical not node. 398 * 399 * \param[in] opnd the value of logical not node. 400 * 401 * \return The instance of logical not node. 402 */ NotNode(const NodePtr & opnd)403 explicit NotNode(const NodePtr &opnd) : UnaryOperation(UNARY_NOT, opnd) {} 404 405 // \brief Destructor. 406 ~NotNode() override = default; 407 JIT_DECLARE_PARENT(NotNode, UnaryOperation); 408 }; 409 410 using NotNodePtr = std::shared_ptr<NotNode>; 411 412 /// \brief InvertNode is the class which represent the operation that take bitwise inversion. 413 class InvertNode : public UnaryOperation { 414 public: 415 /** 416 * \brief The constructor of invert node. 417 * 418 * \param[in] opnd the value of invert node. 419 * 420 * \return The instance of invert node. 421 */ InvertNode(const NodePtr & opnd)422 explicit InvertNode(const NodePtr &opnd) : UnaryOperation(UNARY_INVERT, opnd) {} 423 424 // \brief Destructor. 425 ~InvertNode() override = default; 426 JIT_DECLARE_PARENT(InvertNode, UnaryOperation); 427 }; 428 429 using InvertNodePtr = std::shared_ptr<InvertNode>; 430 431 /// \brief ReturnNode is the class which represent the return of function. 432 class ReturnNode : public UnaryOperation { 433 public: 434 /** 435 * \brief The constructor of return node. 436 * 437 * \param[in] res the value of return node. 438 * 439 * \return The instance of return node. 440 */ ReturnNode(const NodePtr & res)441 explicit ReturnNode(const NodePtr &res) : UnaryOperation(RETURN_VALUE, res) {} 442 443 // \brief Destructor. 444 ~ReturnNode() override = default; 445 JIT_DECLARE_PARENT(ReturnNode, UnaryOperation); 446 447 /** 448 * \brief Get the value of return node. 449 * 450 * \return the return value. 451 */ GetReturn()452 const NodePtr &GetReturn() const { return GetArg(); } 453 }; 454 455 using ReturnNodePtr = std::shared_ptr<ReturnNode>; 456 457 /// \brief CastNode is the class which represent convert one type to another. 458 class CastNode : public UnaryOperation { 459 public: 460 /** 461 * \brief The constructor of cast node. 462 * 463 * \param[in] opnd the value of cast node. 464 * 465 * \return The instance of cast node. 466 */ CastNode(const NodePtr & opnd)467 explicit CastNode(const NodePtr &opnd) : UnaryOperation(LIST_TO_TUPLE, opnd) {} 468 469 // \brief Destructor. 470 ~CastNode() override = default; 471 JIT_DECLARE_PARENT(CastNode, UnaryOperation); 472 }; 473 474 using CastNodePtr = std::shared_ptr<CastNode>; 475 476 /// \brief DeleteNode is the class which represent delete a object. 477 class DeleteNode : public UnaryOperation { 478 public: 479 /** 480 * \brief The constructor of delete node. 481 * 482 * \param[in] opnd the object will be deleted. 483 * 484 * \return The instance of cast node. 485 */ DeleteNode(OpCode op,const NodePtr & opnd)486 explicit DeleteNode(OpCode op, const NodePtr &opnd) : UnaryOperation(op, opnd) {} 487 488 // \brief Destructor. 489 ~DeleteNode() override = default; 490 JIT_DECLARE_PARENT(DeleteNode, UnaryOperation); 491 }; 492 493 using DeleteNodePtr = std::shared_ptr<DeleteNode>; 494 495 /// \brief GetNode is the class which represent get a property of an object with `Get_*`. 496 class GetNode : public UnaryOperation { 497 public: 498 /** 499 * \brief The constructor of get node. 500 * 501 * \param[in] opnd the object. 502 * 503 * \return The instance of get node. 504 */ GetNode(OpCode op,const NodePtr & opnd)505 explicit GetNode(OpCode op, const NodePtr &opnd) : UnaryOperation(op, opnd) {} 506 507 // \brief Destructor. 508 ~GetNode() override = default; 509 JIT_DECLARE_PARENT(GetNode, UnaryOperation); 510 }; 511 512 using GetNodePtr = std::shared_ptr<GetNode>; 513 514 /// \brief LoadValueNode is the class which represent load a value to stack. 515 class LoadValueNode : public UnaryOperation { 516 public: 517 /** 518 * \brief The constructor of load node. 519 * 520 * \param[in] value the value will be load. 521 * 522 * \return The instance of load node. 523 */ LoadValueNode(OpCode op,const NodePtr & value)524 LoadValueNode(OpCode op, const NodePtr &value) : UnaryOperation(op, value) {} 525 526 // \brief Destructor. 527 ~LoadValueNode() override = default; 528 JIT_DECLARE_PARENT(LoadValueNode, NaryOperation); 529 }; 530 531 using LoadValueNodePtr = std::shared_ptr<LoadValueNode>; 532 533 /// \brief LoadFieldNode is the class which represent load a filed of class to stack. 534 class LoadFieldNode : public BinaryOperation { 535 public: 536 /** 537 * \brief The constructor of load node. 538 * 539 * \param[in] cls_ins the instance of class. 540 * \param[in] field the field will be load. 541 * 542 * \return The instance of load node. 543 */ LoadFieldNode(OpCode op,const NodePtr & cls_ins,const NodePtr & field)544 LoadFieldNode(OpCode op, const NodePtr &cls_ins, const NodePtr &field) : BinaryOperation(op, cls_ins, field) {} 545 546 // \brief Destructor. 547 ~LoadFieldNode() override = default; 548 JIT_DECLARE_PARENT(LoadFieldNode, BinaryOperation); 549 }; 550 551 using LoadFieldNodePtr = std::shared_ptr<LoadFieldNode>; 552 553 /// \brief AddNode is the class which represent the addition of two operands. 554 class AddNode : public BinaryOperation { 555 public: 556 /** 557 * \brief The constructor of add node. 558 * 559 * \param[in] left the first operand of add. 560 * \param[in] right the second operand of add. 561 * \param[in] is_inplace whether the sum store to the first operand. 562 * 563 * \return The instance of add node. 564 */ AddNode(OpCode op,const NodePtr & left,const NodePtr & right)565 AddNode(OpCode op, const NodePtr &left, const NodePtr &right) : BinaryOperation(op, left, right) {} 566 567 // \brief Destructor. 568 ~AddNode() override = default; 569 JIT_DECLARE_PARENT(AddNode, BinaryOperation); 570 571 /** 572 * \brief Judge whether the opcode of this node is INPLACE_ADD. 573 * 574 * \return The result of the judgment. 575 */ IsInplace()576 bool IsInplace() const { return INPLACE_ADD == GetOpCode(); } 577 }; 578 579 using AddNodePtr = std::shared_ptr<AddNode>; 580 581 /// \brief SubNode is the class which represent the subtraction of two operands. 582 class SubNode : public BinaryOperation { 583 public: 584 /** 585 * \brief The constructor of sub node. 586 * 587 * \param[in] left the first operand of sub. 588 * \param[in] right the second operand of sub. 589 * \param[in] is_inplace whether the difference store to the first operand. 590 * 591 * \return The instance of sub node. 592 */ SubNode(OpCode op,const NodePtr & left,const NodePtr & right)593 SubNode(OpCode op, const NodePtr &left, const NodePtr &right) : BinaryOperation(op, left, right) {} 594 595 // \brief Destructor. 596 ~SubNode() override = default; 597 JIT_DECLARE_PARENT(SubNode, BinaryOperation); 598 599 /** 600 * \brief Judge whether the opcode of this node is INPLACE_ADD. 601 * 602 * \return The result of the judgment. 603 */ IsInplace()604 bool IsInplace() const { return INPLACE_SUBTRACT == GetOpCode(); } 605 }; 606 607 using SubNodePtr = std::shared_ptr<SubNode>; 608 609 /// \brief MulNode is the class which represent the multiplication of two operands. 610 class MulNode : public BinaryOperation { 611 public: 612 /** 613 * \brief The constructor of mul node. 614 * 615 * \param[in] left the first operand of mul. 616 * \param[in] right the second operand of mul. 617 * \param[in] is_inplace whether the product store to the first operand. 618 * 619 * \return The instance of mul node. 620 */ MulNode(OpCode op,const NodePtr & left,const NodePtr & right)621 MulNode(OpCode op, const NodePtr &left, const NodePtr &right) : BinaryOperation(op, left, right) {} 622 623 // \brief Destructor. 624 ~MulNode() override = default; 625 JIT_DECLARE_PARENT(MulNode, BinaryOperation); 626 627 /** 628 * \brief Judge whether the opcode of this node is INPLACE_MULTIPLY. 629 * 630 * \return The result of the judgment. 631 */ IsInplace()632 bool IsInplace() const { return (INPLACE_MULTIPLY == GetOpCode()) || (INPLACE_MATRIX_MULTIPLY == GetOpCode()); } 633 }; 634 635 using MulNodePtr = std::shared_ptr<MulNode>; 636 637 /// \brief DivNode is the class which represent the division of two operands. 638 class DivNode : public BinaryOperation { 639 public: 640 /** 641 * \brief The constructor of div node. 642 * 643 * \param[in] left the first operand of div. 644 * \param[in] right the second operand of div. 645 * \param[in] is_inplace whether the quotient of division store to the first operand. 646 * 647 * \return The instance of div node. 648 */ DivNode(OpCode op,const NodePtr & left,const NodePtr & right)649 DivNode(OpCode op, const NodePtr &left, const NodePtr &right) : BinaryOperation(op, left, right) {} 650 651 // \brief Destructor. 652 ~DivNode() override = default; 653 JIT_DECLARE_PARENT(DivNode, BinaryOperation); 654 655 /** 656 * \brief Judge whether the opcode of this node is INPLACE_TRUE_DIVIDE. 657 * 658 * \return The result of the judgment. 659 */ IsInplace()660 bool IsInplace() const { return INPLACE_TRUE_DIVIDE == GetOpCode(); } 661 }; 662 663 using DivNodePtr = std::shared_ptr<DivNode>; 664 665 /// \brief BitwiseNode is the class which represent the addition of two operands. 666 class BitwiseNode : public BinaryOperation { 667 public: 668 /** 669 * \brief The constructor of add node. 670 * 671 * \param[in] left the first operand of add. 672 * \param[in] right the second operand of add. 673 * \param[in] is_inplace whether the sum store to the first operand. 674 * 675 * \return The instance of add node. 676 */ BitwiseNode(OpCode op,const NodePtr & left,const NodePtr & right)677 BitwiseNode(OpCode op, const NodePtr &left, const NodePtr &right) : BinaryOperation(op, left, right) {} 678 679 // \brief Destructor. 680 ~BitwiseNode() override = default; 681 JIT_DECLARE_PARENT(BitwiseNode, BinaryOperation); 682 683 /** 684 * \brief Judge whether the opcode of this node is INPLACE_ADD. 685 * 686 * \return The result of the judgment. 687 */ IsInplace()688 bool IsInplace() const { 689 return INPLACE_LSHIFT == GetOpCode() || INPLACE_RSHIFT == GetOpCode() || INPLACE_AND == GetOpCode() || 690 INPLACE_XOR == GetOpCode() || INPLACE_OR == GetOpCode(); 691 } 692 }; 693 694 using BitwiseNodePtr = std::shared_ptr<BitwiseNode>; 695 696 /// \brief IsNode is the class which represent whether two operands are same or not. 697 class IsNode : public BinaryOperation, public InstrArg { 698 public: 699 /** 700 * \brief The constructor of is node. 701 * 702 * \param[in] left the first operand of is node. 703 * \param[in] right the second operand of is node. 704 * \param[in] is_invert the flag whether invert the result. 705 * 706 * \return The instance of is node. 707 */ IsNode(const NodePtr & left,const NodePtr & right,int arg)708 IsNode(const NodePtr &left, const NodePtr &right, int arg) : BinaryOperation(IS_OP, left, right), InstrArg(arg) {} 709 710 // \brief Destructor. 711 ~IsNode() override = default; 712 JIT_DECLARE_PARENT(IsNode, BinaryOperation); 713 714 /** 715 * \brief Judge whether invert the result. 716 * 717 * \return The result of the judgment. 718 */ IsInvert()719 bool IsInvert() const { return GetInstrArg() != 0; } 720 }; 721 722 using IsNodePtr = std::shared_ptr<IsNode>; 723 724 /// \brief ContainsNode is the class which represent whether one contains another or not. 725 class ContainsNode : public BinaryOperation, public InstrArg { 726 public: 727 /** 728 * \brief The constructor of is node. 729 * 730 * \param[in] left the first operand of is node. 731 * \param[in] right the second operand of is node. 732 * \param[in] is_invert the flag whether invert the result. 733 * 734 * \return The instance of contains node. 735 */ ContainsNode(const NodePtr & left,const NodePtr & right,int arg)736 ContainsNode(const NodePtr &left, const NodePtr &right, int arg) 737 : BinaryOperation(CONTAINS_OP, left, right), InstrArg(arg) {} 738 739 // \brief Destructor. 740 ~ContainsNode() override = default; 741 JIT_DECLARE_PARENT(ContainsNode, BinaryOperation); 742 743 /** 744 * \brief Judge whether invert the result. 745 * 746 * \return The result of the judgment. 747 */ IsInvert()748 bool IsInvert() const { return GetInstrArg() != 0; } 749 }; 750 751 using ContainsNodePtr = std::shared_ptr<ContainsNode>; 752 753 /// \brief StoreNode is the class which represent whether two operands are same. 754 class StoreNode : public BinaryOperation { 755 public: 756 /** 757 * \brief The constructor of store node. 758 * 759 * \param[in] left the first operand of store node. 760 * \param[in] right the second operand of store node. 761 * 762 * \return The instance of store node. 763 */ StoreNode(OpCode op,const NodePtr & source,const NodePtr & target)764 StoreNode(OpCode op, const NodePtr &source, const NodePtr &target) : BinaryOperation(op, source, target) {} 765 766 // \brief Destructor. 767 ~StoreNode() override = default; 768 JIT_DECLARE_PARENT(StoreNode, BinaryOperation); 769 }; 770 771 using StoreNodePtr = std::shared_ptr<StoreNode>; 772 773 /// \brief JumpNode is the class which represent jump stmt. 774 class JumpNode : public BinaryOperation { 775 public: 776 /** 777 * \brief The constructor of jump node. 778 * 779 * \param[in] condition the condition for judging whether to jump. 780 * \param[in] target the jump target. 781 * 782 * \return The instance of jump node. 783 */ JumpNode(OpCode op,const NodePtr & condition,const NodePtr & target)784 JumpNode(OpCode op, const NodePtr &condition, const NodePtr &target) : BinaryOperation(op, condition, target) {} 785 786 // \brief Destructor. 787 ~JumpNode() override = default; 788 JIT_DECLARE_PARENT(JumpNode, BinaryOperation); 789 790 /** 791 * \brief Get the condition for judging whether to jump. 792 * 793 * \return The condition for judging whether to jump. 794 */ GetCondition()795 NodePtr GetCondition() const { return GetLeftArg(); } 796 797 /** 798 * \brief Get the target for jump. 799 * 800 * \return The target for jump. 801 */ GetTarget()802 NodePtr GetTarget() const { return GetRightArg(); } 803 804 /** 805 * \brief Set the target of jump. 806 * 807 * \param[in] target the jump target. 808 */ SetTarget(const NodePtr & target)809 void SetTarget(const NodePtr &target) { SetRightArg(target); } 810 811 /** 812 * \brief Set the id of this node. 813 * 814 * \note This method should not be actively called by the program writer, it should only be called by the method 815 * Sort() 816 */ SetNodeId(size_t * id)817 void SetNodeId(size_t *id) override { 818 auto left = GetLeftArg(); 819 if (left != nullptr) { 820 left->SetNodeId(id); 821 } 822 Node::SetNodeId(id); 823 } 824 825 /** 826 * \brief Set the offset of this node. 827 * 828 * \note This method should not be actively called by the program writer, it should only be called by the method 829 * Sort() 830 */ SetOffset(size_t * offset)831 void SetOffset(size_t *offset) override { 832 auto left = GetLeftArg(); 833 if (left != nullptr) { 834 left->SetOffset(offset); 835 } 836 Node::SetOffset(offset); 837 } 838 839 /** 840 * \brief Get the description of this jump node. 841 * \return The description. 842 */ ToString()843 std::string ToString() const override { 844 std::string str; 845 auto left = GetLeftArg(); 846 if (left != nullptr) { 847 str += left->ToString() + "\n"; 848 } 849 str += "%" + std::to_string(GetNodeId()) + " = " + GetNodeName() + "[" + GetType()->GetName() + "](" + 850 GetOpName(GetOpCode()); 851 if (left != nullptr) { 852 str += ", %" + std::to_string(left->GetNodeId()); 853 } else { 854 str += ", nullptr"; 855 } 856 auto right = GetRightArg(); 857 if (right != nullptr) { 858 str += ", %" + std::to_string(right->GetNodeId()); 859 } else { 860 str += ", nullptr"; 861 } 862 return str + ")\n"; 863 } 864 }; 865 866 using JumpNodePtr = std::shared_ptr<JumpNode>; 867 868 class CompareNode : public BinaryOperation, public InstrArg { 869 public: 870 /** 871 * \brief The constructor of compare node. 872 * 873 * \param[in] category the category of compare. 874 * \param[in] left the first operand of compare node. 875 * \param[in] right the second operand of compare node. 876 * 877 * \return The instance of compare node. 878 */ CompareNode(int arg,const NodePtr & left,const NodePtr & right)879 CompareNode(int arg, const NodePtr &left, const NodePtr &right) 880 : BinaryOperation(COMPARE_OP, left, right), InstrArg(arg) {} 881 882 // \brief Destructor. 883 ~CompareNode() override = default; 884 JIT_DECLARE_PARENT(CompareNode, BinaryOperation); 885 886 /** 887 * \brief Get the description of this node. 888 * \return The description. 889 */ ToString()890 std::string ToString() const override { 891 auto left = GetLeftArg(); 892 auto right = GetRightArg(); 893 return left->ToString() + "\n" + right->ToString() + "\n%" + std::to_string(GetNodeId()) + " = " + GetNodeName() + 894 "[" + GetType()->GetName() + "](" + GetOpName(GetOpCode()) + ", " + std::to_string(GetInstrArg()) + ", %" + 895 std::to_string(left->GetNodeId()) + ", %" + std::to_string(right->GetNodeId()) + ")\n"; 896 } 897 }; 898 899 using CompareNodePtr = std::shared_ptr<CompareNode>; 900 901 /// \brief CallNode is the class which represent merge several dicts/lists into one. 902 class UpdateNode : public BinaryOperation, public InstrArg { 903 public: 904 /** 905 * \brief The constructor of build node. 906 * 907 * \param[in] opnds the operand of build node. 908 * 909 * \return The instance of build node. 910 */ UpdateNode(OpCode op,const NodePtr & left,const NodePtr & right,int arg)911 UpdateNode(OpCode op, const NodePtr &left, const NodePtr &right, int arg) 912 : BinaryOperation(op, left, right), InstrArg(arg) {} 913 914 // \brief Destructor. 915 ~UpdateNode() override = default; 916 JIT_DECLARE_PARENT(UpdateNode, BinaryOperation); 917 }; 918 919 using UpdateNodePtr = std::shared_ptr<UpdateNode>; 920 921 /// \brief FormatNode is the class which represent format an object as required. 922 class FormatNode : public NaryOperation, public InstrArg { 923 public: 924 /** 925 * \brief The constructor of format node. 926 * 927 * \param[in] opnd the value of format node. 928 * \param[in] fmt the format type. 929 * 930 * \return The instance of format node. 931 */ FormatNode(const NodePtrList & opnds,int fmt)932 FormatNode(const NodePtrList &opnds, int fmt) : NaryOperation(FORMAT_VALUE, opnds), InstrArg(fmt) {} 933 934 // \brief Destructor. 935 ~FormatNode() override = default; 936 JIT_DECLARE_PARENT(FormatNode, NaryOperation); 937 938 /** 939 * \brief Get the format type of format node. 940 * 941 * \return the format type. 942 */ GetFormatType()943 int GetFormatType() const { return GetInstrArg(); } 944 }; 945 946 using FormatNodePtr = std::shared_ptr<FormatNode>; 947 948 /// \brief BuildNode is the class which represent build a value. 949 class BuildNode : public NaryOperation { 950 public: 951 /** 952 * \brief The constructor of build node. 953 * 954 * \param[in] opnds the operand of build node. 955 * 956 * \return The instance of build node. 957 */ BuildNode(OpCode op,const NodePtrList & opnds)958 BuildNode(OpCode op, const NodePtrList &opnds) : NaryOperation(op, opnds) {} 959 960 // \brief Destructor. 961 ~BuildNode() override = default; 962 JIT_DECLARE_PARENT(BuildNode, NaryOperation); 963 }; 964 965 using BuildNodePtr = std::shared_ptr<BuildNode>; 966 967 /// \brief CallNode is the class which represent call a function. 968 class CallNode : public NaryOperation { 969 public: 970 /** 971 * \brief The constructor of build node. 972 * 973 * \param[in] opnds the operand of build node. 974 * 975 * \return The instance of build node. 976 */ CallNode(OpCode op,const NodePtrList & opnds)977 CallNode(OpCode op, const NodePtrList &opnds) : NaryOperation(op, opnds) {} 978 979 // \brief Destructor. 980 ~CallNode() override = default; 981 JIT_DECLARE_PARENT(CallNode, NaryOperation); 982 }; 983 984 using CallNodePtr = std::shared_ptr<CallNode>; 985 986 /// \brief NaryWithFlagNode is the class which represent make function. 987 class NaryWithFlagNode : public NaryOperation, public InstrArg { 988 public: 989 /** 990 * \brief The constructor of nary with flag node. 991 * 992 * \param[in] opnds the operand of nary with flag node. 993 * 994 * \return The instance of nary with flag node. 995 */ NaryWithFlagNode(OpCode op,const NodePtrList & opnds,int flag)996 NaryWithFlagNode(OpCode op, const NodePtrList &opnds, int flag) : NaryOperation(op, opnds), InstrArg(flag) {} 997 998 // \brief Destructor. 999 ~NaryWithFlagNode() override = default; 1000 JIT_DECLARE_PARENT(NaryWithFlagNode, NaryOperation); 1001 1002 /** 1003 * \brief Get the flag of make function node. 1004 * 1005 * \return the flag. 1006 */ GetFlag()1007 int GetFlag() const { return GetInstrArg(); } 1008 }; 1009 1010 using NaryWithFlagNodePtr = std::shared_ptr<NaryWithFlagNode>; 1011 } // namespace ir 1012 } // namespace pijit 1013 } // namespace mindspore 1014 1015 #endif // MINDSPORE_PI_JIT_CUSTOM_NODES_H_ 1016