1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_CTRL_FLOW_H_ 17 #define MINDSPORE_PI_JIT_CTRL_FLOW_H_ 18 19 #include <map> 20 #include <memory> 21 #include <string> 22 #include <vector> 23 #include "pipeline/jit/pi/graph_compiler/pi_ir/custom_nodes.h" 24 25 namespace mindspore { 26 namespace pijit { 27 namespace ir { 28 29 /// \brief Parameter is is the class which represent a parameter of function or method 30 class Parameter : public Node { 31 public: 32 /** 33 * \brief The constructor of parameter node. 34 * 35 * \param[in] index the index of parameter. 36 * \param[in] name the name of parameter. 37 * 38 * \return The instance of function node. 39 */ Parameter(size_t index,const std::string & name)40 Parameter(size_t index, const std::string &name) 41 : index_(index), name_(name), value_(nullptr), default_value_(nullptr), category_(0) {} 42 43 // \brief Destructor. 44 ~Parameter() override = default; 45 JIT_DECLARE_PARENT(Parameter, Node); 46 47 static constexpr int POSITIONAL = 0; 48 static constexpr int VARIABLE = 1; 49 static constexpr int KEYWORD_ONLY = 2; 50 static constexpr int KEYWORD = 3; 51 52 /** 53 * \brief Set the id of this node. 54 * 55 * \note This method should not be actively called by the program writer, it should only be called by the method 56 * Sort() 57 */ SetNodeId(size_t * id)58 void SetNodeId(size_t *id) override { 59 if (value_ != nullptr) { 60 value_->SetNodeId(id); 61 } 62 if (default_value_ != nullptr) { 63 default_value_->SetNodeId(id); 64 } 65 Node::SetNodeId(id); 66 } 67 68 /** 69 * \brief Get the index of parameter. 70 * 71 * \return the index of parameter. 72 */ GetIndex()73 size_t GetIndex() const { return index_; } 74 75 /** 76 * \brief Get the name of parameter. 77 * 78 * \return the name of parameter. 79 */ GetName()80 const std::string &GetName() const { return name_; } 81 82 /** 83 * \brief Set the name of parameter. 84 * 85 * \param[in] name the name of parameter. 86 */ SetName(const std::string & name)87 void SetName(const std::string &name) { name_ = name; } 88 89 /** 90 * \brief Get the value of parameter. 91 * 92 * \return the value of parameter. 93 */ GetValue()94 const NodePtr &GetValue() const { return value_; } 95 96 /** 97 * \brief Set the value of parameter. 98 * 99 * \param[in] value the value of parameter. 100 */ SetValue(const NodePtr & value)101 void SetValue(const NodePtr &value) { value_ = value; } 102 103 /** 104 * \brief Get the default value of parameter. 105 * 106 * \return the default value of parameter. 107 */ GetDefaultValue()108 const NodePtr &GetDefaultValue() const { return default_value_; } 109 110 /** 111 * \brief Set the default value of parameter. 112 * 113 * \param[in] default_value the default value of parameter. 114 */ SetDefaultValue(const NodePtr & default_value)115 void SetDefaultValue(const NodePtr &default_value) { default_value_ = default_value; } 116 117 /** 118 * \brief Get the category of parameter. 119 * 120 * \return the category of parameter. 121 */ GetCategory()122 int GetCategory() const { return category_; } 123 124 /** 125 * \brief Set the category of parameter. 126 * 127 * \param[in] category the category of parameter. 128 */ SetCategory(int category)129 void SetCategory(int category) { category_ = category; } 130 131 /** 132 * \brief Get the description of this parameter. 133 * \return The description. 134 */ ToString()135 std::string ToString() const override { 136 std::string str = (value_ == nullptr ? "" : value_->ToString()) + "\n"; 137 str += (default_value_ == nullptr ? "" : default_value_->ToString()) + "\n"; 138 str += "%" + std::to_string(GetNodeId()) + " = Parameter[" + GetType()->GetName() + "](Name : " + name_; 139 str += " Value : " + (value_ == nullptr ? "Null" : "%" + std::to_string(value_->GetNodeId())); 140 str += 141 " Default Value : " + (default_value_ == nullptr ? "Null" : "%" + std::to_string(default_value_->GetNodeId())) + 142 ")"; 143 return str; 144 } 145 146 private: 147 /// \brief The index of parameter. 148 size_t index_; 149 /// \brief The name of parameter. 150 std::string name_; 151 /// \brief The value of parameter. 152 NodePtr value_; 153 /// \brief The default value of parameter. 154 NodePtr default_value_; 155 /// \brief The category of parameter, 0 : positional, 1 : varargs, 2 : kwonly, 3 : kw 156 int category_; 157 }; 158 159 using ParameterPtr = std::shared_ptr<Parameter>; 160 161 using ParameterPtrList = std::vector<ParameterPtr>; 162 163 /// \brief FunctionNode is is the class which represent a python function or method 164 class FunctionNode : public Node { 165 public: 166 /** 167 * \brief The constructor of function node. 168 * 169 * \param[in] name the name of function. 170 * 171 * \return The instance of function node. 172 */ FunctionNode(const std::string & name)173 explicit FunctionNode(const std::string &name) : FunctionNode(name, {}) {} 174 175 /** 176 * \brief The constructor of function node. 177 * 178 * \param[in] name the name of function. 179 * \param[in] nodes the body of function. 180 * \param[in] use_global the global will be used in code generator. 181 * 182 * \return The instance of function node. 183 */ 184 FunctionNode(const std::string &name, const NodePtrList &nodes, const NodePtr &use_global = nullptr) name_(name)185 : name_(name), nodes_(nodes), use_global_(use_global) {} 186 187 // \brief Destructor. 188 ~FunctionNode() override = default; 189 JIT_DECLARE_PARENT(FunctionNode, Node); 190 191 /** 192 * \brief Set the id of this node. 193 * 194 * \note This method should not be actively called by the program writer, it should only be called by the method 195 * Sort() 196 */ SetNodeId(size_t * id)197 void SetNodeId(size_t *id) override { 198 for (const auto ¶meter : parameters_) { 199 parameter->SetNodeId(id); 200 } 201 for (const auto &node : nodes_) { 202 node->SetNodeId(id); 203 } 204 Node::SetNodeId(id); 205 } 206 207 /** 208 * \brief Set the offset of this node. 209 * 210 * \note This method should not be actively called by the program writer, it should only be called by the method 211 * Sort() 212 */ SetOffset(size_t * offset)213 void SetOffset(size_t *offset) override { 214 /// Inputs must be valueNodes, no need to set offset 215 /// Only the operation need to be set offset 216 for (const auto &node : nodes_) { 217 node->SetOffset(offset); 218 } 219 } 220 221 /** 222 * \brief Get the name of function. 223 * 224 * \return the name of function. 225 */ GetName()226 const std::string &GetName() const { return name_; } 227 228 /** 229 * \brief Get the count of positional args. 230 * 231 * \return the count of positional args. 232 */ GetPosArgsCnt()233 int GetPosArgsCnt() const { return pos_args_cnt_; } 234 235 /** 236 * \brief Set the count of positional args. 237 * 238 * \param[in] cnt the count of positional args. 239 */ SetPosArgsCnt(int cnt)240 void SetPosArgsCnt(int cnt) { pos_args_cnt_ = cnt; } 241 242 /** 243 * \brief Get the count of keyword only args. 244 * 245 * \return the count of keyword only args. 246 */ GetKwOnlyArgsCnt()247 int GetKwOnlyArgsCnt() const { return kw_only_args_cnt_; } 248 249 /** 250 * \brief Set the count of keyword only args. 251 * 252 * \param[in] cnt the count of keyword only args. 253 */ SetKwOnlyArgsCnt(int cnt)254 void SetKwOnlyArgsCnt(int cnt) { kw_only_args_cnt_ = cnt; } 255 256 /** 257 * \brief Get the flags of function. 258 * 259 * \return The flags of function. 260 */ GetFlags()261 int GetFlags() const { return flags_; } 262 263 /** 264 * \brief Set the flags of function. 265 * 266 * \param[in] flags the flags of function. 267 */ SetFlags(uint32_t flags)268 void SetFlags(uint32_t flags) { flags_ = flags; } 269 270 /** 271 * \brief Judgment whether has var args. 272 * 273 * \return The result of the judgment. 274 */ HasVarArg()275 bool HasVarArg() const { return (flags_ & 0x0004) != 0x0; } 276 277 /** 278 * \brief Set whether has var args. 279 * 280 * \param[in] has_var_arg the result of whether has var args. 281 */ SetHasVarArg(bool has_var_arg)282 void SetHasVarArg(bool has_var_arg) { flags_ = has_var_arg ? flags_ | 0x0004 : flags_ & 0xFFFB; } 283 284 /** 285 * \brief Judgment whether has kw args. 286 * 287 * \return The result of the judgment. 288 */ HasKwArg()289 bool HasKwArg() const { return (flags_ & 0x0008) != 0x0; } 290 291 /** 292 * \brief Set whether has kw args. 293 * 294 * \param[in] has_kw_arg the result of whether has kw args. 295 */ SetHasKwArg(bool has_kw_arg)296 void SetHasKwArg(bool has_kw_arg) { flags_ = has_kw_arg ? flags_ | 0x0008 : flags_ & 0xFFF7; } 297 298 /** 299 * \brief Judgment whether has the attr whose name is key. 300 * 301 * \param[in] key the name of the attr. 302 * 303 * \return The result of the judgment. 304 */ HasAttr(const std::string & key)305 bool HasAttr(const std::string &key) const { return attrs_.find(key) != attrs_.end(); } 306 307 /** 308 * \brief Get the value of the attr whose name is key. 309 * 310 * \param[in] key the name of the attr. 311 * 312 * \return The value of the attr. 313 */ GetAttr(const std::string & key)314 bool GetAttr(const std::string &key) const { return HasAttr(key) && attrs_.at(key); } 315 316 /** 317 * \brief Set the attr whose name is key. 318 * 319 * \param[in] key the name of the attr. 320 * \param[in] value the value of the attr. 321 */ SetAttr(const std::string & key,bool value)322 void SetAttr(const std::string &key, bool value) { attrs_[key] = value; } 323 324 /** 325 * \brief Judgment whether need generate parameters. 326 * 327 * \return Whether need generate parameters. 328 */ NeedGenParameters()329 bool NeedGenParameters() const { return !without_params_gen_; } 330 331 /** 332 * \brief Mark no need generate parameters. 333 */ MarkNoNeedGenParameters()334 void MarkNoNeedGenParameters() { without_params_gen_ = true; } 335 336 /** 337 * \brief Get the parameters of function. 338 * 339 * \return the parameters of function. 340 */ GetParameters()341 const NodePtrList &GetParameters() const { return parameters_; } 342 343 /** 344 * \brief Get the parameters of function. 345 * 346 * \return the parameters of function. 347 */ GetParameters()348 NodePtrList &GetParameters() { return parameters_; } 349 350 /** 351 * \brief Get the specified positional parameter of function. 352 * 353 * \return the specified positional parameter of function. 354 */ GetParameter(size_t index)355 const NodePtr &GetParameter(size_t index) const { return parameters_[index]; } 356 357 /** 358 * \brief Add the new input to function node. 359 * 360 * \param[in] parameter the new parameter of function. 361 */ AddParameter(const ParameterPtr & parameter)362 void AddParameter(const ParameterPtr ¶meter) { parameters_.push_back(parameter); } 363 364 /** 365 * \brief Set the specified positional parameters of function. 366 * 367 * \param[in] input the new parameter of function. 368 */ SetParameter(size_t index,const ParameterPtr & parameter)369 void SetParameter(size_t index, const ParameterPtr ¶meter) { parameters_[index] = parameter; } 370 371 /** 372 * \brief Set the parameters of function. 373 * 374 * \param[in] parameters the new parameters of function. 375 */ SetParameters(const NodePtrList & parameters)376 void SetParameters(const NodePtrList ¶meters) { parameters_ = parameters; } 377 378 /** 379 * \brief Get the nodes of function. 380 * 381 * \return the nodes of function. 382 */ GetNodes()383 const NodePtrList &GetNodes() const { return nodes_; } 384 385 /** 386 * \brief Get the nodes of function. 387 * 388 * \return the nodes of function. 389 */ GetNodes()390 NodePtrList &GetNodes() { return nodes_; } 391 392 /** 393 * \brief Add the new node to function node. 394 * 395 * \param[in] node the new node of function. 396 * 397 * \note The node after the return will be ignored. 398 */ AddNode(const NodePtr & node)399 void AddNode(const NodePtr &node) { 400 if (nodes_.empty() || !nodes_.back()->isa<ReturnNode>()) { 401 nodes_.push_back(node); 402 } 403 } 404 405 /** 406 * \brief Get the global will be used in code generator. 407 * 408 * \return The global will be used in code generator. 409 */ GetUseGlobal()410 const NodePtr &GetUseGlobal() const { return use_global_; } 411 412 /** 413 * \brief Get the global will be used in code generator. 414 * 415 * \param[in] use_global the global will be used in code generator. 416 */ SetUseGlobal(const NodePtr & use_global)417 void SetUseGlobal(const NodePtr &use_global) { use_global_ = use_global; } 418 419 /** 420 * \brief Get the file name of the function. 421 * 422 * \return The file name of the function. 423 */ GetFileName()424 const std::string &GetFileName() const { return file_names_[0]; } 425 426 /** 427 * \brief Get the file names of the function, maybe include inline functions. 428 * 429 * \return The the file names of the function. 430 */ GetFileNames()431 const std::vector<std::string> &GetFileNames() const { return file_names_; } 432 433 /** 434 * \brief Add file name to file names of function. 435 * 436 * \param[in] name the file name of sub function. 437 */ AddFileName(const std::string & name)438 void AddFileName(const std::string &name) { file_names_.push_back(name); } 439 440 /** 441 * \brief Get the number of the first line. 442 * 443 * \return the number of the first line. 444 */ GetFirstLineNo()445 int GetFirstLineNo() const { return first_line_no_; } 446 447 /** 448 * \brief Set the number of the first line. 449 * 450 * \param[in] line_no the number of the first line. 451 */ SetFirstLineNo(int line_no)452 void SetFirstLineNo(int line_no) { first_line_no_ = line_no; } 453 454 /** 455 * \brief Get the stack size of function. 456 * 457 * \return The stack size of function. 458 */ GetStackSize()459 int GetStackSize() const { return stack_size_; } 460 461 /** 462 * \brief Set the stack size of function. 463 * 464 * \param[in] size the stack size of function. 465 */ SetStackSize(int size)466 void SetStackSize(int size) { stack_size_ = size; } 467 468 /** 469 * \brief Get the description of this function. 470 * \return The description. 471 */ ToString()472 std::string ToString() const override { 473 std::string str; 474 for (const auto ¶meter : parameters_) { 475 str += parameter->ToString() + "\n"; 476 } 477 str += "%" + std::to_string(GetNodeId()) + " = FunctionNode " + name_ + "("; 478 for (const auto ¶meter : parameters_) { 479 str += "%" + std::to_string(parameter->GetNodeId()) + ", "; 480 } 481 str += ") {\n"; 482 for (const auto &node : nodes_) { 483 str += node->ToString() + "\n"; 484 } 485 str += "}\n"; 486 return str; 487 } 488 489 private: 490 /// \brief The name of function. 491 const std::string name_; 492 /// \brief whether the node represents a method. 493 bool is_method_{false}; 494 /// \brief The count of positional args. 495 int pos_args_cnt_{0}; 496 /// \brief The count of keyword only args. 497 int kw_only_args_cnt_{0}; 498 /// \brief An integer encoding a number of flags for the function. 499 uint32_t flags_{0}; 500 /// \brief the attrs of function. 501 std::map<std::string, bool> attrs_; 502 /// \brief the flag whether generate the parameters of function. 503 bool without_params_gen_{false}; 504 /// \brief The parameters of function 505 NodePtrList parameters_; 506 /// \brief The body of function 507 NodePtrList nodes_; 508 /// \brief The global will be used in code generator 509 NodePtr use_global_; 510 /// \brief The name of the file where the function resides. 511 std::vector<std::string> file_names_; 512 /// \brief The number of the first line. 513 int first_line_no_{0}; 514 /// \brief The size of stack. 515 int stack_size_{0}; 516 }; 517 518 using FunctionNodePtr = std::shared_ptr<FunctionNode>; 519 520 /// \brief IfNode is is the class which represent a if statement 521 class IfNode : public Node { 522 public: 523 /** 524 * \brief The constructor of if node. 525 * 526 * \param[in] condition the condition of if node. 527 * 528 * \return The instance of if node. 529 */ IfNode(const NodePtr & condition)530 explicit IfNode(const NodePtr &condition) : condition_jump_(condition) {} 531 532 // \brief Destructor. 533 ~IfNode() override = default; 534 JIT_DECLARE_PARENT(IfNode, Node); 535 536 /** 537 * \brief Set the id of this node. 538 * 539 * \note This method should not be actively called by the program writer, it should only be called by the method 540 * Sort() 541 */ SetNodeId(size_t * id)542 void SetNodeId(size_t *id) override { 543 condition_jump_->SetNodeId(id); 544 for (const auto &node : then_) { 545 node->SetNodeId(id); 546 } 547 for (const auto &node : else_) { 548 node->SetNodeId(id); 549 } 550 Node::SetNodeId(id); 551 } 552 553 /** 554 * \brief Set the offset of this node. 555 * 556 * \note This method should not be actively called by the program writer, it should only be called by the method 557 * Sort() 558 */ SetOffset(size_t * offset)559 void SetOffset(size_t *offset) override { 560 /// Only the operation need to be set offset 561 condition_jump_->SetOffset(offset); 562 for (const auto &node : then_) { 563 node->SetOffset(offset); 564 } 565 for (const auto &node : else_) { 566 node->SetOffset(offset); 567 } 568 } 569 570 /** 571 * \brief Get the condition of if node. 572 * 573 * \return the condition of if node. 574 */ GetCondition()575 const NodePtr &GetCondition() const { return condition_jump_; } 576 577 /** 578 * \brief Set the condition of if node. 579 * 580 * \param[in] condition the condition of if node. 581 */ SetCondition(const NodePtr & condition)582 void SetCondition(const NodePtr &condition) { condition_jump_ = condition; } 583 584 /** 585 * \brief Get the then body of if node. 586 * 587 * \return the nodes of then body of if node. 588 */ GetThen()589 const NodePtrList &GetThen() const { return then_; } 590 591 /** 592 * \brief Get the then body of if node. 593 * 594 * \return the nodes of then body of if node. 595 */ GetThen()596 NodePtrList &GetThen() { return then_; } 597 598 /** 599 * \brief Add the new node to then body of if node. 600 * 601 * \param[in] node the new node. 602 * 603 * \note The node after the return will be ignored. 604 */ AddThen(const NodePtr & node)605 void AddThen(const NodePtr &node) { 606 if (then_.empty() || !then_.back()->isa<ReturnNode>()) { 607 then_.push_back(node); 608 } 609 } 610 611 /** 612 * \brief Get the else body of if node. 613 * 614 * \return the nodes of else body. 615 */ GetElse()616 const NodePtrList &GetElse() const { return else_; } 617 618 /** 619 * \brief Get the else body of if node. 620 * 621 * \return the nodes of else body. 622 */ GetElse()623 NodePtrList &GetElse() { return else_; } 624 625 /** 626 * \brief Add the new node to else body of if node. 627 * 628 * \param[in] node the new node. 629 * 630 * \note The node after the return will be ignored. 631 */ AddElse(const NodePtr & node)632 void AddElse(const NodePtr &node) { 633 if (else_.empty() || !else_.back()->isa<ReturnNode>()) { 634 else_.push_back(node); 635 } 636 } 637 638 /** 639 * \brief Get the description of this If node. 640 * \return The description. 641 */ ToString()642 std::string ToString() const override { 643 std::string str = condition_jump_->ToString(); 644 str += "%" + std::to_string(GetNodeId()) + " = If (%" + std::to_string(condition_jump_->GetNodeId()) + ") {\n"; 645 for (const auto &node : then_) { 646 str += node->ToString(); 647 } 648 str += "} else {\n"; 649 for (const auto &node : else_) { 650 str += node->ToString(); 651 } 652 str += "}\n"; 653 return str; 654 } 655 656 private: 657 /// \brief The condition of if, it must be a jump. 658 NodePtr condition_jump_; 659 /// \brief The body of if will be executed when no need to jump, maybe empty 660 NodePtrList then_; 661 /// \brief The body of if will be executed when need to jump, maybe empty 662 NodePtrList else_; 663 }; 664 665 using IfNodePtr = std::shared_ptr<IfNode>; 666 667 /// \brief IfNode is is the class which represent a if statement 668 class WhileNode : public Node { 669 public: 670 /** 671 * \brief The constructor of while node. 672 * 673 * \param[in] condition the condition of function. 674 * 675 * \return The instance of while node. 676 */ WhileNode(const NodePtr & condition)677 explicit WhileNode(const NodePtr &condition) : condition_jump_(condition) {} 678 679 // \brief Destructor. 680 ~WhileNode() override = default; 681 JIT_DECLARE_PARENT(WhileNode, Node); 682 683 /** 684 * \brief Set the id of this node. 685 * 686 * \note This method should not be actively called by the program writer, it should only be called by the method 687 * Sort() 688 */ SetNodeId(size_t * id)689 void SetNodeId(size_t *id) override { 690 condition_jump_->SetNodeId(id); 691 for (const auto &node : body_) { 692 node->SetNodeId(id); 693 } 694 Node::SetNodeId(id); 695 } 696 697 /** 698 * \brief Set the offset of this node. 699 * 700 * \note This method should not be actively called by the program writer, it should only be called by the method 701 * Sort() 702 */ SetOffset(size_t * offset)703 void SetOffset(size_t *offset) override { 704 /// Only the operation need to be set offset 705 condition_jump_->SetOffset(offset); 706 for (const auto &node : body_) { 707 node->SetOffset(offset); 708 } 709 } 710 711 /** 712 * \brief Get the condition of while node. 713 * 714 * \return the condition of while node. 715 */ GetCondition()716 const NodePtr &GetCondition() const { return condition_jump_; } 717 718 /** 719 * \brief Set the condition of while node. 720 * 721 * \param[in] condition the condition of while node. 722 */ SetCondition(const NodePtr & condition)723 void SetCondition(const NodePtr &condition) { condition_jump_ = condition; } 724 725 /** 726 * \brief Get the body of while node. 727 * 728 * \return the body nodes of while node. 729 */ GetBody()730 const NodePtrList &GetBody() const { return body_; } 731 732 /** 733 * \brief Get the body of while node. 734 * 735 * \return the body nodes of while node. 736 */ GetBody()737 NodePtrList &GetBody() { return body_; } 738 739 /** 740 * \brief Add the new node to body of while node. 741 * 742 * \param[in] node the new node. 743 */ AddBody(const NodePtr & node)744 void AddBody(const NodePtr &node) { body_.push_back(node); } 745 746 /** 747 * \brief Set the new nodes as body of while node. 748 * 749 * \param[in] nodes the new nodes of then body. 750 */ SetBody(const NodePtrList & nodes)751 void SetBody(const NodePtrList &nodes) { body_ = nodes; } 752 753 /** 754 * \brief Get the description of this While node. 755 * \return The description. 756 */ ToString()757 std::string ToString() const override { 758 std::string str = condition_jump_->ToString(); 759 str += "%" + std::to_string(GetNodeId()) + " = While (%" + std::to_string(condition_jump_->GetNodeId()) + ") {"; 760 for (const auto &node : body_) { 761 str += node->ToString(); 762 } 763 str += "}\n"; 764 return str; 765 } 766 767 private: 768 /// \brief The condition of while, it must be a jump. 769 NodePtr condition_jump_; 770 /// \brief The body executed in a loop. 771 NodePtrList body_; 772 }; 773 774 using WhileNodePtr = std::shared_ptr<WhileNode>; 775 } // namespace ir 776 } // namespace pijit 777 } // namespace mindspore 778 779 #endif // MINDSPORE_PI_JIT_CTRL_FLOW_H_ 780