• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_PI_JIT_CUSTOM_NODES_H_
17 #define MINDSPORE_PI_JIT_CUSTOM_NODES_H_
18 
19 #include <memory>
20 #include <string>
21 #include "pipeline/jit/pi/pydef.h"
22 #include "pipeline/jit/pi/graph_compiler/pi_ir/operation.h"
23 
24 namespace mindspore {
25 namespace pijit {
26 namespace ir {
27 /// \brief RefNode is the class which represent that this node is defined elsewhere and is only used here.
28 class RefNode : public Node {
29  public:
30   /**
31    * \brief The constructor of reference node.
32    *
33    * \return The instance of reference node.
34    */
RefNode(const NodePtr & node)35   explicit RefNode(const NodePtr &node) : real_node_(node) {}
36 
37   // \brief Destructor.
38   ~RefNode() override = default;
39   JIT_DECLARE_PARENT(RefNode, Node);
40 
41   /**
42    * \brief Get the node this reference node represents.
43    *
44    * \return The node this reference node represents.
45    */
GetRealNode()46   const NodePtr &GetRealNode() const { return real_node_; }
47 
48   /**
49    * \brief Set the real node of the ref node.
50    *
51    * \param[in] node the real object.
52    */
SetRealNode(const NodePtr & node)53   void SetRealNode(const NodePtr &node) { real_node_ = node; }
54 
55   /**
56    * \brief Get the description of this node.
57    * \return The description.
58    */
ToString()59   std::string ToString() const override {
60     return "%" + std::to_string(GetNodeId()) + " = [" + GetType()->GetName() + "](" + GetNodeName() + ", " +
61            std::to_string(real_node_->GetNodeId()) + ")\n";
62   }
63 
64  private:
65   /// \brief The node this reference node represents
66   NodePtr real_node_;
67 };
68 
69 using RefNodePtr = std::shared_ptr<RefNode>;
70 
71 /// \brief PlaceHolder is the class which represent a symbol, and don't care about object specific information.
72 class PlaceHolder : public Node {
73  public:
74   /**
75    * \brief The constructor of PlaceHolder node.
76    *
77    * \return The instance of PlaceHolder node.
78    */
PlaceHolder(const std::string & tag)79   explicit PlaceHolder(const std::string &tag) : tag_(tag) {}
80 
81   // \brief Destructor.
82   ~PlaceHolder() override = default;
83   JIT_DECLARE_PARENT(PlaceHolder, Node);
84 
85   /**
86    * \brief Get the tag of PlaceHolder node.
87    *
88    * \return The tag of PlaceHolder node.
89    */
GetTag()90   const std::string &GetTag() const { return tag_; }
91 
92   /**
93    * \brief Set the id of this node.
94    *
95    * \note This method should not be actively called by the program writer, it should only be called by the method
96    * Sort()
97    */
SetNodeId(size_t * id)98   void SetNodeId(size_t *id) override {}
99 
100   /**
101    * \brief Get the description of this node.
102    * \return The description.
103    */
ToString()104   std::string ToString() const override {
105     return "%" + std::to_string(GetNodeId()) + " = [" + GetType()->GetName() + "](" + GetNodeName() + ", " + tag_ +
106            ")\n";
107   }
108 
109  private:
110   /// \brief The mark of PlaceHolder used to explain the special meaning
111   const std::string tag_;
112 };
113 
114 using PlaceHolderPtr = std::shared_ptr<PlaceHolder>;
115 
116 /// \brief SubscrNode is the class which represent a subscript access of object.
117 class SubscrNode : public Node {
118  public:
119   /**
120    * \brief The constructor of subscript node.
121    *
122    * \param[in] base the object being accessed.
123    * \param[in] subscr the subscript.
124    *
125    * \return The instance of subscript node.
126    */
SubscrNode(const NodePtr & base,const NodePtr & subscr)127   SubscrNode(const NodePtr &base, const NodePtr &subscr) : base_(base), subscr_(subscr) {}
128 
129   // \brief Destructor.
130   ~SubscrNode() override = default;
131   JIT_DECLARE_PARENT(SubscrNode, Node);
132 
133   /**
134    * \brief Get the object being accessed.
135    *
136    * \return The object being accessed.
137    */
GetObject()138   const NodePtr &GetObject() const { return base_; }
139 
140   /**
141    * \brief Set the the object being accessed.
142    *
143    * \param[in] obj the object.
144    */
SetObject(const NodePtr & obj)145   void SetObject(const NodePtr &obj) { base_ = obj; }
146 
147   /**
148    * \brief Get the subscr want to accessed.
149    *
150    * \return The subscr want to accessed.
151    */
GetSubscr()152   const NodePtr &GetSubscr() const { return subscr_; }
153 
154   /**
155    * \brief Set the subscr want to accessed.
156    *
157    * \param[in] subscr the element.
158    */
SetSubscr(const NodePtr & subscr)159   void SetSubscr(const NodePtr &subscr) { subscr_ = subscr; }
160 
161   /**
162    * \brief Set the id of this node.
163    *
164    * \note This method should not be actively called by the program writer, it should only be called by the method
165    * Sort()
166    */
SetNodeId(size_t * id)167   void SetNodeId(size_t *id) override {
168     base_->SetNodeId(id);
169     subscr_->SetNodeId(id);
170   }
171 
172   /**
173    * \brief Set the offset of this node.
174    *
175    * \note This method should not be actively called by the program writer, it should only be called by the method
176    * Sort()
177    */
SetOffset(size_t * offset)178   void SetOffset(size_t *offset) override {
179     base_->SetOffset(offset);
180     subscr_->SetOffset(offset);
181   }
182 
183   /**
184    * \brief Get the description of this node.
185    * \return The description.
186    */
ToString()187   std::string ToString() const override {
188     return base_->ToString() + "\n" + subscr_->ToString() + "\n%" + std::to_string(GetNodeId()) + " = %" +
189            std::to_string(base_->GetNodeId()) + "[%" + std::to_string(subscr_->GetNodeId()) + "]\n";
190   }
191 
192  private:
193   NodePtr base_;
194   NodePtr subscr_;
195 };
196 
197 using SubscrNodePtr = std::shared_ptr<SubscrNode>;
198 
199 /// \brief SubscrNode is the class which represent a attr or method of the object.
200 class AttrNode : public Node {
201  public:
202   /**
203    * \brief The constructor of attribute node.
204    *
205    * \param[in] base the object being accessed.
206    * \param[in] attr the attribute name.
207    *
208    * \return The instance of attribute node.
209    */
AttrNode(const NodePtr & base,const NodePtr & attr)210   AttrNode(const NodePtr &base, const NodePtr &attr) : base_(base), attr_(attr) {}
211 
212   // \brief Destructor.
213   ~AttrNode() override = default;
214   JIT_DECLARE_PARENT(AttrNode, Node);
215 
216   /**
217    * \brief Get the object being accessed.
218    *
219    * \return The object being accessed.
220    */
GetObject()221   const NodePtr &GetObject() const { return base_; }
222 
223   /**
224    * \brief Set the object being accessed.
225    *
226    * \param[in] obj the object.
227    */
SetObject(const NodePtr & obj)228   void SetObject(const NodePtr &obj) { base_ = obj; }
229 
230   /**
231    * \brief Get the attribute name of the object.
232    *
233    * \return The attribute name of the object.
234    */
GetAttr()235   const NodePtr &GetAttr() const { return attr_; }
236 
237   /**
238    * \brief Set the attribute name of the object.
239    *
240    * \param[in] attr the attribute name.
241    */
SetAttr(const NodePtr & attr)242   void SetAttr(const NodePtr &attr) { attr_ = attr; }
243 
244   /**
245    * \brief Set the id of this node.
246    *
247    * \note This method should not be actively called by the program writer, it should only be called by the method
248    * Sort()
249    */
SetNodeId(size_t * id)250   void SetNodeId(size_t *id) override {
251     base_->SetNodeId(id);
252     attr_->SetNodeId(id);
253     Node::SetNodeId(id);
254   }
255 
256   /**
257    * \brief Set the offset of this node.
258    *
259    * \note This method should not be actively called by the program writer, it should only be called by the method
260    * Sort()
261    */
SetOffset(size_t * offset)262   void SetOffset(size_t *offset) override { base_->SetOffset(offset); }
263 
264   /**
265    * \brief Get the description of this node.
266    * \return The description.
267    */
ToString()268   std::string ToString() const override {
269     return base_->ToString() + "\n" + attr_->ToString() + "\n%" + std::to_string(GetNodeId()) + " = %" +
270            std::to_string(base_->GetNodeId()) + ".%" + std::to_string(attr_->GetNodeId()) + "\n";
271   }
272 
273  private:
274   NodePtr base_;
275   NodePtr attr_;
276 };
277 
278 using AttrNodePtr = std::shared_ptr<AttrNode>;
279 
280 /// \brief PairNode is the class which represent the object subscript access.
281 class PairNode : public Node {
282  public:
283   /**
284    * \brief The constructor of pair node.
285    *
286    * \param[in] first the first element of the pair.
287    * \param[in] second the second element of the pair.
288    *
289    * \return The instance of pair node.
290    */
PairNode(const NodePtr & first,const NodePtr & second)291   PairNode(const NodePtr &first, const NodePtr &second) : first_(first), second_(second) {}
292 
293   // \brief Destructor.
294   ~PairNode() override = default;
295   JIT_DECLARE_PARENT(PairNode, Node);
296 
297   /**
298    * \brief Get the first element of the pair.
299    *
300    * \return The first element of the pair.
301    */
GetFirst()302   const NodePtr &GetFirst() const { return first_; }
303 
304   /**
305    * \brief Set the first element of the pair.
306    *
307    * \param[in] arg the element.
308    */
SetFirst(const NodePtr & arg)309   void SetFirst(const NodePtr &arg) { first_ = arg; }
310 
311   /**
312    * \brief Get the second element of the pair.
313    *
314    * \return The second element of the pair.
315    */
GetSecond()316   const NodePtr &GetSecond() const { return second_; }
317 
318   /**
319    * \brief Set the second element of the pair.
320    *
321    * \param[in] arg the element.
322    */
SetSecond(const NodePtr & arg)323   void SetSecond(const NodePtr &arg) { second_ = arg; }
324 
325   /**
326    * \brief Get the description of this node.
327    * \return The description.
328    */
ToString()329   std::string ToString() const override {
330     return first_->ToString() + "\n" + second_->ToString() + "\n%" + std::to_string(GetNodeId()) + " = (" +
331            std::to_string(first_->GetNodeId()) + ", " + std::to_string(second_->GetNodeId()) + ")\n";
332   }
333 
334  private:
335   NodePtr first_;
336   NodePtr second_;
337 };
338 
339 using PairNodePtr = std::shared_ptr<PairNode>;
340 
/// \brief InstrArg is a base class that stores the integer argument of an instruction.
class InstrArg {
 public:
  /// \brief Create an InstrArg holding \p arg.
  ///
  /// \param[in] arg the initial value of the instruction argument.
  explicit InstrArg(int arg) : instr_arg_{arg} {}

  /// \brief Virtual destructor, so derived objects can be destroyed through an InstrArg pointer.
  virtual ~InstrArg() = default;

  /// \brief Read the stored instruction argument.
  ///
  /// \return The current value of the instruction argument.
  int GetInstrArg() const { return instr_arg_; }

  /// \brief Overwrite the stored instruction argument.
  ///
  /// \param[in] arg the new value of the instruction argument.
  void SetInstrArg(int arg) { instr_arg_ = arg; }

 private:
  /// \brief Current value of the instruction argument.
  int instr_arg_;
};
373 
374 /// \brief NegativeNode is the class which represent operation that take negative value.
375 class NegativeNode : public UnaryOperation {
376  public:
377   /**
378    * \brief The constructor of negative node.
379    *
380    * \param[in] opnd the value of negative node.
381    *
382    * \return The instance of negative node.
383    */
NegativeNode(const NodePtr & opnd)384   explicit NegativeNode(const NodePtr &opnd) : UnaryOperation(UNARY_NEGATIVE, opnd) {}
385 
386   // \brief Destructor.
387   ~NegativeNode() override = default;
388   JIT_DECLARE_PARENT(NegativeNode, UnaryOperation);
389 };
390 
391 using NegativeNodePtr = std::shared_ptr<NegativeNode>;
392 
393 /// \brief NotNode is the class which represent the operation that take logical negation.
394 class NotNode : public UnaryOperation {
395  public:
396   /**
397    * \brief The constructor of logical not node.
398    *
399    * \param[in] opnd the value of logical not node.
400    *
401    * \return The instance of logical not node.
402    */
NotNode(const NodePtr & opnd)403   explicit NotNode(const NodePtr &opnd) : UnaryOperation(UNARY_NOT, opnd) {}
404 
405   // \brief Destructor.
406   ~NotNode() override = default;
407   JIT_DECLARE_PARENT(NotNode, UnaryOperation);
408 };
409 
410 using NotNodePtr = std::shared_ptr<NotNode>;
411 
412 /// \brief InvertNode is the class which represent the operation that take bitwise inversion.
413 class InvertNode : public UnaryOperation {
414  public:
415   /**
416    * \brief The constructor of invert node.
417    *
418    * \param[in] opnd the value of invert node.
419    *
420    * \return The instance of invert node.
421    */
InvertNode(const NodePtr & opnd)422   explicit InvertNode(const NodePtr &opnd) : UnaryOperation(UNARY_INVERT, opnd) {}
423 
424   // \brief Destructor.
425   ~InvertNode() override = default;
426   JIT_DECLARE_PARENT(InvertNode, UnaryOperation);
427 };
428 
429 using InvertNodePtr = std::shared_ptr<InvertNode>;
430 
431 /// \brief ReturnNode is the class which represent the return of function.
432 class ReturnNode : public UnaryOperation {
433  public:
434   /**
435    * \brief The constructor of return node.
436    *
437    * \param[in] res the value of return node.
438    *
439    * \return The instance of return node.
440    */
ReturnNode(const NodePtr & res)441   explicit ReturnNode(const NodePtr &res) : UnaryOperation(RETURN_VALUE, res) {}
442 
443   // \brief Destructor.
444   ~ReturnNode() override = default;
445   JIT_DECLARE_PARENT(ReturnNode, UnaryOperation);
446 
447   /**
448    * \brief Get the value of return node.
449    *
450    * \return the return value.
451    */
GetReturn()452   const NodePtr &GetReturn() const { return GetArg(); }
453 };
454 
455 using ReturnNodePtr = std::shared_ptr<ReturnNode>;
456 
457 /// \brief CastNode is the class which represent convert one type to another.
458 class CastNode : public UnaryOperation {
459  public:
460   /**
461    * \brief The constructor of cast node.
462    *
463    * \param[in] opnd the value of cast node.
464    *
465    * \return The instance of cast node.
466    */
CastNode(const NodePtr & opnd)467   explicit CastNode(const NodePtr &opnd) : UnaryOperation(LIST_TO_TUPLE, opnd) {}
468 
469   // \brief Destructor.
470   ~CastNode() override = default;
471   JIT_DECLARE_PARENT(CastNode, UnaryOperation);
472 };
473 
474 using CastNodePtr = std::shared_ptr<CastNode>;
475 
476 /// \brief DeleteNode is the class which represent delete a object.
477 class DeleteNode : public UnaryOperation {
478  public:
479   /**
480    * \brief The constructor of delete node.
481    *
482    * \param[in] opnd the object will be deleted.
483    *
484    * \return The instance of cast node.
485    */
DeleteNode(OpCode op,const NodePtr & opnd)486   explicit DeleteNode(OpCode op, const NodePtr &opnd) : UnaryOperation(op, opnd) {}
487 
488   // \brief Destructor.
489   ~DeleteNode() override = default;
490   JIT_DECLARE_PARENT(DeleteNode, UnaryOperation);
491 };
492 
493 using DeleteNodePtr = std::shared_ptr<DeleteNode>;
494 
495 /// \brief GetNode is the class which represent get a property of an object with `Get_*`.
496 class GetNode : public UnaryOperation {
497  public:
498   /**
499    * \brief The constructor of get node.
500    *
501    * \param[in] opnd the object.
502    *
503    * \return The instance of get node.
504    */
GetNode(OpCode op,const NodePtr & opnd)505   explicit GetNode(OpCode op, const NodePtr &opnd) : UnaryOperation(op, opnd) {}
506 
507   // \brief Destructor.
508   ~GetNode() override = default;
509   JIT_DECLARE_PARENT(GetNode, UnaryOperation);
510 };
511 
512 using GetNodePtr = std::shared_ptr<GetNode>;
513 
514 /// \brief LoadValueNode is the class which represent load a value to stack.
515 class LoadValueNode : public UnaryOperation {
516  public:
517   /**
518    * \brief The constructor of load node.
519    *
520    * \param[in] value the value will be load.
521    *
522    * \return The instance of load node.
523    */
LoadValueNode(OpCode op,const NodePtr & value)524   LoadValueNode(OpCode op, const NodePtr &value) : UnaryOperation(op, value) {}
525 
526   // \brief Destructor.
527   ~LoadValueNode() override = default;
528   JIT_DECLARE_PARENT(LoadValueNode, NaryOperation);
529 };
530 
531 using LoadValueNodePtr = std::shared_ptr<LoadValueNode>;
532 
533 /// \brief LoadFieldNode is the class which represent load a filed of class to stack.
534 class LoadFieldNode : public BinaryOperation {
535  public:
536   /**
537    * \brief The constructor of load node.
538    *
539    * \param[in] cls_ins the instance of class.
540    * \param[in] field the field will be load.
541    *
542    * \return The instance of load node.
543    */
LoadFieldNode(OpCode op,const NodePtr & cls_ins,const NodePtr & field)544   LoadFieldNode(OpCode op, const NodePtr &cls_ins, const NodePtr &field) : BinaryOperation(op, cls_ins, field) {}
545 
546   // \brief Destructor.
547   ~LoadFieldNode() override = default;
548   JIT_DECLARE_PARENT(LoadFieldNode, BinaryOperation);
549 };
550 
551 using LoadFieldNodePtr = std::shared_ptr<LoadFieldNode>;
552 
553 /// \brief AddNode is the class which represent the addition of two operands.
554 class AddNode : public BinaryOperation {
555  public:
556   /**
557    * \brief The constructor of add node.
558    *
559    * \param[in] left the first operand of add.
560    * \param[in] right the second operand of add.
561    * \param[in] is_inplace whether the sum store to the first operand.
562    *
563    * \return The instance of add node.
564    */
AddNode(OpCode op,const NodePtr & left,const NodePtr & right)565   AddNode(OpCode op, const NodePtr &left, const NodePtr &right) : BinaryOperation(op, left, right) {}
566 
567   // \brief Destructor.
568   ~AddNode() override = default;
569   JIT_DECLARE_PARENT(AddNode, BinaryOperation);
570 
571   /**
572    * \brief Judge whether the opcode of this node is INPLACE_ADD.
573    *
574    * \return The result of the judgment.
575    */
IsInplace()576   bool IsInplace() const { return INPLACE_ADD == GetOpCode(); }
577 };
578 
579 using AddNodePtr = std::shared_ptr<AddNode>;
580 
581 /// \brief SubNode is the class which represent the subtraction of two operands.
582 class SubNode : public BinaryOperation {
583  public:
584   /**
585    * \brief The constructor of sub node.
586    *
587    * \param[in] left the first operand of sub.
588    * \param[in] right the second operand of sub.
589    * \param[in] is_inplace whether the difference store to the first operand.
590    *
591    * \return The instance of sub node.
592    */
SubNode(OpCode op,const NodePtr & left,const NodePtr & right)593   SubNode(OpCode op, const NodePtr &left, const NodePtr &right) : BinaryOperation(op, left, right) {}
594 
595   // \brief Destructor.
596   ~SubNode() override = default;
597   JIT_DECLARE_PARENT(SubNode, BinaryOperation);
598 
599   /**
600    * \brief Judge whether the opcode of this node is INPLACE_ADD.
601    *
602    * \return The result of the judgment.
603    */
IsInplace()604   bool IsInplace() const { return INPLACE_SUBTRACT == GetOpCode(); }
605 };
606 
607 using SubNodePtr = std::shared_ptr<SubNode>;
608 
609 /// \brief MulNode is the class which represent the multiplication of two operands.
610 class MulNode : public BinaryOperation {
611  public:
612   /**
613    * \brief The constructor of mul node.
614    *
615    * \param[in] left the first operand of mul.
616    * \param[in] right the second operand of mul.
617    * \param[in] is_inplace whether the product store to the first operand.
618    *
619    * \return The instance of mul node.
620    */
MulNode(OpCode op,const NodePtr & left,const NodePtr & right)621   MulNode(OpCode op, const NodePtr &left, const NodePtr &right) : BinaryOperation(op, left, right) {}
622 
623   // \brief Destructor.
624   ~MulNode() override = default;
625   JIT_DECLARE_PARENT(MulNode, BinaryOperation);
626 
627   /**
628    * \brief Judge whether the opcode of this node is INPLACE_MULTIPLY.
629    *
630    * \return The result of the judgment.
631    */
IsInplace()632   bool IsInplace() const { return (INPLACE_MULTIPLY == GetOpCode()) || (INPLACE_MATRIX_MULTIPLY == GetOpCode()); }
633 };
634 
635 using MulNodePtr = std::shared_ptr<MulNode>;
636 
637 /// \brief DivNode is the class which represent the division of two operands.
638 class DivNode : public BinaryOperation {
639  public:
640   /**
641    * \brief The constructor of div node.
642    *
643    * \param[in] left the first operand of div.
644    * \param[in] right the second operand of div.
645    * \param[in] is_inplace whether the quotient of division store to the first operand.
646    *
647    * \return The instance of div node.
648    */
DivNode(OpCode op,const NodePtr & left,const NodePtr & right)649   DivNode(OpCode op, const NodePtr &left, const NodePtr &right) : BinaryOperation(op, left, right) {}
650 
651   // \brief Destructor.
652   ~DivNode() override = default;
653   JIT_DECLARE_PARENT(DivNode, BinaryOperation);
654 
655   /**
656    * \brief Judge whether the opcode of this node is INPLACE_TRUE_DIVIDE.
657    *
658    * \return The result of the judgment.
659    */
IsInplace()660   bool IsInplace() const { return INPLACE_TRUE_DIVIDE == GetOpCode(); }
661 };
662 
663 using DivNodePtr = std::shared_ptr<DivNode>;
664 
665 /// \brief BitwiseNode is the class which represent the addition of two operands.
666 class BitwiseNode : public BinaryOperation {
667  public:
668   /**
669    * \brief The constructor of add node.
670    *
671    * \param[in] left the first operand of add.
672    * \param[in] right the second operand of add.
673    * \param[in] is_inplace whether the sum store to the first operand.
674    *
675    * \return The instance of add node.
676    */
BitwiseNode(OpCode op,const NodePtr & left,const NodePtr & right)677   BitwiseNode(OpCode op, const NodePtr &left, const NodePtr &right) : BinaryOperation(op, left, right) {}
678 
679   // \brief Destructor.
680   ~BitwiseNode() override = default;
681   JIT_DECLARE_PARENT(BitwiseNode, BinaryOperation);
682 
683   /**
684    * \brief Judge whether the opcode of this node is INPLACE_ADD.
685    *
686    * \return The result of the judgment.
687    */
IsInplace()688   bool IsInplace() const {
689     return INPLACE_LSHIFT == GetOpCode() || INPLACE_RSHIFT == GetOpCode() || INPLACE_AND == GetOpCode() ||
690            INPLACE_XOR == GetOpCode() || INPLACE_OR == GetOpCode();
691   }
692 };
693 
694 using BitwiseNodePtr = std::shared_ptr<BitwiseNode>;
695 
696 /// \brief IsNode is the class which represent whether two operands are same or not.
697 class IsNode : public BinaryOperation, public InstrArg {
698  public:
699   /**
700    * \brief The constructor of is node.
701    *
702    * \param[in] left the first operand of is node.
703    * \param[in] right the second operand of is node.
704    * \param[in] is_invert the flag whether invert the result.
705    *
706    * \return The instance of is node.
707    */
IsNode(const NodePtr & left,const NodePtr & right,int arg)708   IsNode(const NodePtr &left, const NodePtr &right, int arg) : BinaryOperation(IS_OP, left, right), InstrArg(arg) {}
709 
710   // \brief Destructor.
711   ~IsNode() override = default;
712   JIT_DECLARE_PARENT(IsNode, BinaryOperation);
713 
714   /**
715    * \brief Judge whether invert the result.
716    *
717    * \return The result of the judgment.
718    */
IsInvert()719   bool IsInvert() const { return GetInstrArg() != 0; }
720 };
721 
722 using IsNodePtr = std::shared_ptr<IsNode>;
723 
724 /// \brief ContainsNode is the class which represent whether one contains another or not.
725 class ContainsNode : public BinaryOperation, public InstrArg {
726  public:
727   /**
728    * \brief The constructor of is node.
729    *
730    * \param[in] left the first operand of is node.
731    * \param[in] right the second operand of is node.
732    * \param[in] is_invert the flag whether invert the result.
733    *
734    * \return The instance of contains node.
735    */
ContainsNode(const NodePtr & left,const NodePtr & right,int arg)736   ContainsNode(const NodePtr &left, const NodePtr &right, int arg)
737       : BinaryOperation(CONTAINS_OP, left, right), InstrArg(arg) {}
738 
739   // \brief Destructor.
740   ~ContainsNode() override = default;
741   JIT_DECLARE_PARENT(ContainsNode, BinaryOperation);
742 
743   /**
744    * \brief Judge whether invert the result.
745    *
746    * \return The result of the judgment.
747    */
IsInvert()748   bool IsInvert() const { return GetInstrArg() != 0; }
749 };
750 
751 using ContainsNodePtr = std::shared_ptr<ContainsNode>;
752 
753 /// \brief StoreNode is the class which represent whether two operands are same.
754 class StoreNode : public BinaryOperation {
755  public:
756   /**
757    * \brief The constructor of store node.
758    *
759    * \param[in] left the first operand of store node.
760    * \param[in] right the second operand of store node.
761    *
762    * \return The instance of store node.
763    */
StoreNode(OpCode op,const NodePtr & source,const NodePtr & target)764   StoreNode(OpCode op, const NodePtr &source, const NodePtr &target) : BinaryOperation(op, source, target) {}
765 
766   // \brief Destructor.
767   ~StoreNode() override = default;
768   JIT_DECLARE_PARENT(StoreNode, BinaryOperation);
769 };
770 
771 using StoreNodePtr = std::shared_ptr<StoreNode>;
772 
773 /// \brief JumpNode is the class which represent jump stmt.
774 class JumpNode : public BinaryOperation {
775  public:
776   /**
777    * \brief The constructor of jump node.
778    *
779    * \param[in] condition the condition for judging whether to jump.
780    * \param[in] target the jump target.
781    *
782    * \return The instance of jump node.
783    */
JumpNode(OpCode op,const NodePtr & condition,const NodePtr & target)784   JumpNode(OpCode op, const NodePtr &condition, const NodePtr &target) : BinaryOperation(op, condition, target) {}
785 
786   // \brief Destructor.
787   ~JumpNode() override = default;
788   JIT_DECLARE_PARENT(JumpNode, BinaryOperation);
789 
790   /**
791    * \brief Get the condition for judging whether to jump.
792    *
793    * \return The condition for judging whether to jump.
794    */
GetCondition()795   NodePtr GetCondition() const { return GetLeftArg(); }
796 
797   /**
798    * \brief Get the target for jump.
799    *
800    * \return The target for jump.
801    */
GetTarget()802   NodePtr GetTarget() const { return GetRightArg(); }
803 
804   /**
805    * \brief Set the target of jump.
806    *
807    * \param[in] target the jump target.
808    */
SetTarget(const NodePtr & target)809   void SetTarget(const NodePtr &target) { SetRightArg(target); }
810 
811   /**
812    * \brief Set the id of this node.
813    *
814    * \note This method should not be actively called by the program writer, it should only be called by the method
815    * Sort()
816    */
SetNodeId(size_t * id)817   void SetNodeId(size_t *id) override {
818     auto left = GetLeftArg();
819     if (left != nullptr) {
820       left->SetNodeId(id);
821     }
822     Node::SetNodeId(id);
823   }
824 
825   /**
826    * \brief Set the offset of this node.
827    *
828    * \note This method should not be actively called by the program writer, it should only be called by the method
829    * Sort()
830    */
SetOffset(size_t * offset)831   void SetOffset(size_t *offset) override {
832     auto left = GetLeftArg();
833     if (left != nullptr) {
834       left->SetOffset(offset);
835     }
836     Node::SetOffset(offset);
837   }
838 
839   /**
840    * \brief Get the description of this jump node.
841    * \return The description.
842    */
ToString()843   std::string ToString() const override {
844     std::string str;
845     auto left = GetLeftArg();
846     if (left != nullptr) {
847       str += left->ToString() + "\n";
848     }
849     str += "%" + std::to_string(GetNodeId()) + " = " + GetNodeName() + "[" + GetType()->GetName() + "](" +
850            GetOpName(GetOpCode());
851     if (left != nullptr) {
852       str += ", %" + std::to_string(left->GetNodeId());
853     } else {
854       str += ", nullptr";
855     }
856     auto right = GetRightArg();
857     if (right != nullptr) {
858       str += ", %" + std::to_string(right->GetNodeId());
859     } else {
860       str += ", nullptr";
861     }
862     return str + ")\n";
863   }
864 };
865 
866 using JumpNodePtr = std::shared_ptr<JumpNode>;
867 
868 class CompareNode : public BinaryOperation, public InstrArg {
869  public:
870   /**
871    * \brief The constructor of compare node.
872    *
873    * \param[in] category the category of compare.
874    * \param[in] left the first operand of compare node.
875    * \param[in] right the second operand of compare node.
876    *
877    * \return The instance of compare node.
878    */
CompareNode(int arg,const NodePtr & left,const NodePtr & right)879   CompareNode(int arg, const NodePtr &left, const NodePtr &right)
880       : BinaryOperation(COMPARE_OP, left, right), InstrArg(arg) {}
881 
882   // \brief Destructor.
883   ~CompareNode() override = default;
884   JIT_DECLARE_PARENT(CompareNode, BinaryOperation);
885 
886   /**
887    * \brief Get the description of this node.
888    * \return The description.
889    */
ToString()890   std::string ToString() const override {
891     auto left = GetLeftArg();
892     auto right = GetRightArg();
893     return left->ToString() + "\n" + right->ToString() + "\n%" + std::to_string(GetNodeId()) + " = " + GetNodeName() +
894            "[" + GetType()->GetName() + "](" + GetOpName(GetOpCode()) + ", " + std::to_string(GetInstrArg()) + ", %" +
895            std::to_string(left->GetNodeId()) + ", %" + std::to_string(right->GetNodeId()) + ")\n";
896   }
897 };
898 
899 using CompareNodePtr = std::shared_ptr<CompareNode>;
900 
901 /// \brief CallNode is the class which represent merge several dicts/lists into one.
902 class UpdateNode : public BinaryOperation, public InstrArg {
903  public:
904   /**
905    * \brief The constructor of build node.
906    *
907    * \param[in] opnds the operand of build node.
908    *
909    * \return The instance of build node.
910    */
UpdateNode(OpCode op,const NodePtr & left,const NodePtr & right,int arg)911   UpdateNode(OpCode op, const NodePtr &left, const NodePtr &right, int arg)
912       : BinaryOperation(op, left, right), InstrArg(arg) {}
913 
914   // \brief Destructor.
915   ~UpdateNode() override = default;
916   JIT_DECLARE_PARENT(UpdateNode, BinaryOperation);
917 };
918 
919 using UpdateNodePtr = std::shared_ptr<UpdateNode>;
920 
921 /// \brief FormatNode is the class which represent format an object as required.
922 class FormatNode : public NaryOperation, public InstrArg {
923  public:
924   /**
925    * \brief The constructor of format node.
926    *
927    * \param[in] opnd the value of format node.
928    * \param[in] fmt the format type.
929    *
930    * \return The instance of format node.
931    */
FormatNode(const NodePtrList & opnds,int fmt)932   FormatNode(const NodePtrList &opnds, int fmt) : NaryOperation(FORMAT_VALUE, opnds), InstrArg(fmt) {}
933 
934   // \brief Destructor.
935   ~FormatNode() override = default;
936   JIT_DECLARE_PARENT(FormatNode, NaryOperation);
937 
938   /**
939    * \brief Get the format type of format node.
940    *
941    * \return the format type.
942    */
GetFormatType()943   int GetFormatType() const { return GetInstrArg(); }
944 };
945 
946 using FormatNodePtr = std::shared_ptr<FormatNode>;
947 
948 /// \brief BuildNode is the class which represent build a value.
949 class BuildNode : public NaryOperation {
950  public:
951   /**
952    * \brief The constructor of build node.
953    *
954    * \param[in] opnds the operand of build node.
955    *
956    * \return The instance of build node.
957    */
BuildNode(OpCode op,const NodePtrList & opnds)958   BuildNode(OpCode op, const NodePtrList &opnds) : NaryOperation(op, opnds) {}
959 
960   // \brief Destructor.
961   ~BuildNode() override = default;
962   JIT_DECLARE_PARENT(BuildNode, NaryOperation);
963 };
964 
965 using BuildNodePtr = std::shared_ptr<BuildNode>;
966 
967 /// \brief CallNode is the class which represent call a function.
968 class CallNode : public NaryOperation {
969  public:
970   /**
971    * \brief The constructor of build node.
972    *
973    * \param[in] opnds the operand of build node.
974    *
975    * \return The instance of build node.
976    */
CallNode(OpCode op,const NodePtrList & opnds)977   CallNode(OpCode op, const NodePtrList &opnds) : NaryOperation(op, opnds) {}
978 
979   // \brief Destructor.
980   ~CallNode() override = default;
981   JIT_DECLARE_PARENT(CallNode, NaryOperation);
982 };
983 
984 using CallNodePtr = std::shared_ptr<CallNode>;
985 
986 /// \brief NaryWithFlagNode is the class which represent make function.
987 class NaryWithFlagNode : public NaryOperation, public InstrArg {
988  public:
989   /**
990    * \brief The constructor of nary with flag node.
991    *
992    * \param[in] opnds the operand of nary with flag node.
993    *
994    * \return The instance of nary with flag node.
995    */
NaryWithFlagNode(OpCode op,const NodePtrList & opnds,int flag)996   NaryWithFlagNode(OpCode op, const NodePtrList &opnds, int flag) : NaryOperation(op, opnds), InstrArg(flag) {}
997 
998   // \brief Destructor.
999   ~NaryWithFlagNode() override = default;
1000   JIT_DECLARE_PARENT(NaryWithFlagNode, NaryOperation);
1001 
1002   /**
1003    * \brief Get the flag of make function node.
1004    *
1005    * \return the flag.
1006    */
GetFlag()1007   int GetFlag() const { return GetInstrArg(); }
1008 };
1009 
1010 using NaryWithFlagNodePtr = std::shared_ptr<NaryWithFlagNode>;
1011 }  // namespace ir
1012 }  // namespace pijit
1013 }  // namespace mindspore
1014 
1015 #endif  // MINDSPORE_PI_JIT_CUSTOM_NODES_H_
1016