1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_IR_MUTATOR_H_ 17 #define MINDSPORE_PI_JIT_IR_MUTATOR_H_ 18 19 #include "pipeline/jit/pi/graph_compiler/pi_ir/ctrl_flow.h" 20 #include "pipeline/jit/pi/graph_compiler/pi_ir/custom_nodes.h" 21 #include "pipeline/jit/pi/graph_compiler/pi_ir/functor.h" 22 #include "pipeline/jit/pi/graph_compiler/pi_ir/value.h" 23 24 namespace mindspore { 25 namespace pijit { 26 namespace ir { 27 class IRMutator { 28 public: 29 /* 30 * \brief recursively Mutate an IR node 31 */ Mutate(const NodePtr & node)32 virtual NodePtr Mutate(const NodePtr &node) { 33 if (node == nullptr) { 34 return nullptr; 35 } 36 static const FMutate &f = vtable(); 37 return f(node, this); 38 } 39 40 /// \brief destructor ~IRMutator()41 virtual ~IRMutator() {} 42 43 /*! \brief functor type of visitor */ 44 using FMutate = NodeFunctor<NodePtr(const NodePtr &, IRMutator *)>; 45 /*! \return internal vtable */ 46 static FMutate &vtable(); 47 48 // overloadable Mutate function. 49 virtual NodePtr Mutate_(const RefNodePtr &node); 50 virtual NodePtr Mutate_(const ParameterPtr &node); 51 virtual NodePtr Mutate_(const FunctionNodePtr &node); 52 virtual NodePtr Mutate_(const ValuePtr &node); 53 virtual NodePtr Mutate_(const IfNodePtr &node); 54 virtual NodePtr Mutate_(const WhileNodePtr &node); 55 virtual NodePtr Mutate_(const UnaryOperationPtr &node); 56 virtual NodePtr Mutate_(const BinaryOperationPtr &node); 57 virtual NodePtr Mutate_(const NaryOperationPtr &node); 58 virtual NodePtr Mutate_(const NegativeNodePtr &node); 59 virtual NodePtr Mutate_(const NotNodePtr &node); 60 virtual NodePtr Mutate_(const InvertNodePtr &node); 61 virtual NodePtr Mutate_(const ReturnNodePtr &node); 62 virtual NodePtr Mutate_(const LoadValueNodePtr &node); 63 virtual NodePtr Mutate_(const CastNodePtr &node); 64 virtual NodePtr Mutate_(const FormatNodePtr &node); 65 virtual NodePtr Mutate_(const AddNodePtr &node); 66 virtual NodePtr Mutate_(const SubNodePtr &node); 67 virtual NodePtr Mutate_(const MulNodePtr &node); 68 virtual NodePtr Mutate_(const DivNodePtr &node); 69 virtual NodePtr Mutate_(const BitwiseNodePtr &node); 70 virtual NodePtr Mutate_(const IsNodePtr &node); 71 virtual NodePtr Mutate_(const ContainsNodePtr &node); 72 virtual NodePtr Mutate_(const StoreNodePtr &node); 73 virtual NodePtr Mutate_(const CompareNodePtr &node); 74 virtual NodePtr Mutate_(const LoadFieldNodePtr &node); 75 virtual NodePtr Mutate_(const BuildNodePtr &node); 76 virtual NodePtr Mutate_(const CallNodePtr &node); 77 virtual NodePtr Mutate_(const NaryWithFlagNodePtr &node); 78 virtual NodePtr Mutate_(const UpdateNodePtr &node); 79 virtual NodePtr Mutate_(const SubscrNodePtr &node); 80 virtual NodePtr Mutate_(const AttrNodePtr &node); 81 virtual NodePtr Mutate_(const PairNodePtr &node); 82 }; 83 84 #define MUTATE_NODE_LIST(LIST) \ 85 do { \ 86 for (ir::NodePtr & node : LIST) { \ 87 node = Mutate(node); \ 88 } \ 89 } while (0); 90 91 #define DISPATCH_TO_MUTATE(OP) \ 92 set_dispatch<OP>([](const NodePtr &node, IRMutator *m) { return m->Mutate_(std::static_pointer_cast<OP>(node)); }) 93 } // namespace ir 94 } // namespace pijit 95 } // namespace mindspore 96 97 #endif // MINDSPORE_PI_JIT_IR_MUTATOR_H_ 98