• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_PI_JIT_IR_MUTATOR_H_
17 #define MINDSPORE_PI_JIT_IR_MUTATOR_H_
18 
19 #include "pipeline/jit/pi/graph_compiler/pi_ir/ctrl_flow.h"
20 #include "pipeline/jit/pi/graph_compiler/pi_ir/custom_nodes.h"
21 #include "pipeline/jit/pi/graph_compiler/pi_ir/functor.h"
22 #include "pipeline/jit/pi/graph_compiler/pi_ir/value.h"
23 
24 namespace mindspore {
25 namespace pijit {
26 namespace ir {
27 class IRMutator {
28  public:
29   /*
30    * \brief recursively Mutate an IR node
31    */
Mutate(const NodePtr & node)32   virtual NodePtr Mutate(const NodePtr &node) {
33     if (node == nullptr) {
34       return nullptr;
35     }
36     static const FMutate &f = vtable();
37     return f(node, this);
38   }
39 
40   /// \brief destructor
~IRMutator()41   virtual ~IRMutator() {}
42 
43   /*! \brief functor type of visitor */
44   using FMutate = NodeFunctor<NodePtr(const NodePtr &, IRMutator *)>;
45   /*! \return internal vtable */
46   static FMutate &vtable();
47 
48   // overloadable Mutate function.
49   virtual NodePtr Mutate_(const RefNodePtr &node);
50   virtual NodePtr Mutate_(const ParameterPtr &node);
51   virtual NodePtr Mutate_(const FunctionNodePtr &node);
52   virtual NodePtr Mutate_(const ValuePtr &node);
53   virtual NodePtr Mutate_(const IfNodePtr &node);
54   virtual NodePtr Mutate_(const WhileNodePtr &node);
55   virtual NodePtr Mutate_(const UnaryOperationPtr &node);
56   virtual NodePtr Mutate_(const BinaryOperationPtr &node);
57   virtual NodePtr Mutate_(const NaryOperationPtr &node);
58   virtual NodePtr Mutate_(const NegativeNodePtr &node);
59   virtual NodePtr Mutate_(const NotNodePtr &node);
60   virtual NodePtr Mutate_(const InvertNodePtr &node);
61   virtual NodePtr Mutate_(const ReturnNodePtr &node);
62   virtual NodePtr Mutate_(const LoadValueNodePtr &node);
63   virtual NodePtr Mutate_(const CastNodePtr &node);
64   virtual NodePtr Mutate_(const FormatNodePtr &node);
65   virtual NodePtr Mutate_(const AddNodePtr &node);
66   virtual NodePtr Mutate_(const SubNodePtr &node);
67   virtual NodePtr Mutate_(const MulNodePtr &node);
68   virtual NodePtr Mutate_(const DivNodePtr &node);
69   virtual NodePtr Mutate_(const BitwiseNodePtr &node);
70   virtual NodePtr Mutate_(const IsNodePtr &node);
71   virtual NodePtr Mutate_(const ContainsNodePtr &node);
72   virtual NodePtr Mutate_(const StoreNodePtr &node);
73   virtual NodePtr Mutate_(const CompareNodePtr &node);
74   virtual NodePtr Mutate_(const LoadFieldNodePtr &node);
75   virtual NodePtr Mutate_(const BuildNodePtr &node);
76   virtual NodePtr Mutate_(const CallNodePtr &node);
77   virtual NodePtr Mutate_(const NaryWithFlagNodePtr &node);
78   virtual NodePtr Mutate_(const UpdateNodePtr &node);
79   virtual NodePtr Mutate_(const SubscrNodePtr &node);
80   virtual NodePtr Mutate_(const AttrNodePtr &node);
81   virtual NodePtr Mutate_(const PairNodePtr &node);
82 };
83 
/*!
 * \brief Replace each element of LIST, in place, with the result of Mutate(element).
 *
 * Expand it only inside an IRMutator member function (or a subclass's), because it calls the
 * unqualified Mutate(). The elements of LIST must bind to `ir::NodePtr &`. LIST should name
 * the container to be updated; iterating a temporary copy would throw the results away.
 *
 * NOTE(review): the expansion already ends with ';' after `while (0)`, so it is not a true
 * single-statement macro. An unbraced `if (c) MUTATE_NODE_LIST(x); else ...` will not compile.
 * The semicolon is kept for compatibility with call sites that omit it.
 */
#define MUTATE_NODE_LIST(LIST)                 \
  do {                                         \
    for (ir::NodePtr & list_item : LIST) {     \
      list_item = Mutate(list_item);           \
    }                                          \
  } while (0);
90 
/*!
 * \brief Register a dispatch entry for node type OP.
 *
 * The entry downcasts the generic NodePtr with std::static_pointer_cast<OP> and forwards it to
 * the matching IRMutator::Mutate_ overload.
 *
 * It expands to an unqualified `set_dispatch<OP>(...)` call, so it is meant to be chained on
 * the FMutate table (presumably inside the definition of IRMutator::vtable() — TODO confirm).
 *
 * The downcast is unchecked. The entry must only be reached for nodes whose dynamic type is OP;
 * presumably the NodeFunctor's type-keyed dispatch guarantees this.
 */
#define DISPATCH_TO_MUTATE(OP)                                    \
  set_dispatch<OP>([](const NodePtr &node, IRMutator *mutator) {  \
    return mutator->Mutate_(std::static_pointer_cast<OP>(node));  \
  })
93 }  // namespace ir
94 }  // namespace pijit
95 }  // namespace mindspore
96 
97 #endif  // MINDSPORE_PI_JIT_IR_MUTATOR_H_
98