• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_PI_JIT_IR_FUNCTOR_H_
17 #define MINDSPORE_PI_JIT_IR_FUNCTOR_H_
18 
19 #include <map>
20 #include <utility>
21 #include "pipeline/jit/pi/graph_compiler/pi_ir/node.h"
22 #include "utils/log_adapter.h"
23 
24 namespace mindspore {
25 namespace pijit {
26 namespace ir {
27 template <typename FType>
28 class NodeFunctor;
29 
30 template <typename R, typename... Args>
31 class NodeFunctor<R(const NodePtr &node, Args...)> {
32  private:
33   /*! \brief internal function pointer type */
34   typedef R (*FPointer)(const NodePtr &node, Args...);
35   /*! \brief refer to itself. */
36   using TSelf = NodeFunctor<R(const NodePtr &node, Args...)>;
37   /*! \brief internal function table */
38   std::map<uint32_t, FPointer> func_;
39 
40  public:
41   /*!
42    * \brief Whether the functor can dispatch the corresponding Node
43    * \param n The node to be dispatched
44    * \return Whether dispatching function is registered for n's type.
45    */
can_dispatch(const NodePtr & node)46   bool can_dispatch(const NodePtr &node) const { return func_.find(node->GetClassId()) != func_.end(); }
47   /*!
48    * \brief invoke the functor, dispatch on type of n
49    * \param n The Node argument
50    * \return The result.
51    */
operator()52   R operator()(const NodePtr &node, Args... args) const {
53     MS_EXCEPTION_IF_CHECK_FAIL(can_dispatch(node), "NodeFunctor not defined for " + node->GetNodeName() + ".");
54     return (*func_.at(node->GetClassId()))(node, std::forward<Args>(args)...);
55   }
56   /*!
57    * \brief set the dispacher for type TNode
58    * \param f The function to be set.
59    * \tparam TNode the type of Node to be dispatched.
60    * \return reference to self.
61    */
62   template <typename OP>
set_dispatch(FPointer f)63   TSelf &set_dispatch(FPointer f) {  // NOLINT(*)
64     func_[OP::kClassId] = f;
65     return *this;
66   }
67   /*!
68    * \brief unset the dispacher for type TNode
69    *
70    * \tparam TNode the type of Node to be dispatched.
71    * \return reference to self.
72    */
73   template <typename OP>
clear_dispatch()74   TSelf &clear_dispatch() {  // NOLINT(*)
75     func_.erase(OP::kClassId);
76     return *this;
77   }
78 };
79 
/*! \brief helper macro to suppress unused warning */
#if defined(__GNUC__)
#define IR_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define IR_ATTRIBUTE_UNUSED
#endif

/*!
 * \brief Token-paste helpers.
 *
 * The two-level form makes macro arguments such as __COUNTER__ expand before they are pasted. Writing
 * `x##__COUNTER__` directly pastes the literal token text `__COUNTER__` instead of its value.
 */
#define IR_STR_CONCAT_(a, b) a##b
#define IR_STR_CONCAT(a, b) IR_STR_CONCAT_(a, b)

/*!
 * \brief Define a static reference initialized from ClsName::FField().
 *
 * Each expansion gets a distinct variable name through __COUNTER__, so the macro may be used several times
 * for the same ClsName within one translation unit. Presumably it is used to chain set_dispatch<...>() calls
 * on a functor returned by reference from FField(); confirm at call sites.
 */
#define STATIC_IR_FUNCTOR(ClsName, FField) \
  static IR_ATTRIBUTE_UNUSED auto &IR_STR_CONCAT(ir_make_functor_##ClsName##_, __COUNTER__) = ClsName::FField()
89 }  // namespace ir
90 }  // namespace pijit
91 }  // namespace mindspore
92 
93 #endif  // MINDSPORE_PI_JIT_IR_FUNCTOR_H_
94