1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_IR_FUNCTOR_H_ 17 #define MINDSPORE_PI_JIT_IR_FUNCTOR_H_ 18 19 #include <map> 20 #include <utility> 21 #include "pipeline/jit/pi/graph_compiler/pi_ir/node.h" 22 #include "utils/log_adapter.h" 23 24 namespace mindspore { 25 namespace pijit { 26 namespace ir { 27 template <typename FType> 28 class NodeFunctor; 29 30 template <typename R, typename... Args> 31 class NodeFunctor<R(const NodePtr &node, Args...)> { 32 private: 33 /*! \brief internal function pointer type */ 34 typedef R (*FPointer)(const NodePtr &node, Args...); 35 /*! \brief refer to itself. */ 36 using TSelf = NodeFunctor<R(const NodePtr &node, Args...)>; 37 /*! \brief internal function table */ 38 std::map<uint32_t, FPointer> func_; 39 40 public: 41 /*! 42 * \brief Whether the functor can dispatch the corresponding Node 43 * \param n The node to be dispatched 44 * \return Whether dispatching function is registered for n's type. 45 */ can_dispatch(const NodePtr & node)46 bool can_dispatch(const NodePtr &node) const { return func_.find(node->GetClassId()) != func_.end(); } 47 /*! 48 * \brief invoke the functor, dispatch on type of n 49 * \param n The Node argument 50 * \return The result. 51 */ operator()52 R operator()(const NodePtr &node, Args... args) const { 53 MS_EXCEPTION_IF_CHECK_FAIL(can_dispatch(node), "NodeFunctor not defined for " + node->GetNodeName() + "."); 54 return (*func_.at(node->GetClassId()))(node, std::forward<Args>(args)...); 55 } 56 /*! 57 * \brief set the dispacher for type TNode 58 * \param f The function to be set. 59 * \tparam TNode the type of Node to be dispatched. 60 * \return reference to self. 61 */ 62 template <typename OP> set_dispatch(FPointer f)63 TSelf &set_dispatch(FPointer f) { // NOLINT(*) 64 func_[OP::kClassId] = f; 65 return *this; 66 } 67 /*! 68 * \brief unset the dispacher for type TNode 69 * 70 * \tparam TNode the type of Node to be dispatched. 71 * \return reference to self. 72 */ 73 template <typename OP> clear_dispatch()74 TSelf &clear_dispatch() { // NOLINT(*) 75 func_.erase(OP::kClassId); 76 return *this; 77 } 78 }; 79 80 /*! \brief helper macro to suppress unused warning */ 81 #if defined(__GNUC__) 82 #define IR_ATTRIBUTE_UNUSED __attribute__((unused)) 83 #else 84 #define IR_ATTRIBUTE_UNUSED 85 #endif 86 87 #define STATIC_IR_FUNCTOR(ClsName, FField) \ 88 static IR_ATTRIBUTE_UNUSED auto &__make_functor##_##ClsName##__COUNTER__ = ClsName::FField() 89 } // namespace ir 90 } // namespace pijit 91 } // namespace mindspore 92 93 #endif // MINDSPORE_PI_JIT_IR_FUNCTOR_H_ 94