• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_PI_JIT_NODE_H_
17 #define MINDSPORE_PI_JIT_NODE_H_
18 
19 #include <memory>
20 #include <limits>
21 #include <list>
22 #include <string>
23 #include <vector>
24 #include "pipeline/jit/pi/graph_compiler/pi_ir/debug_info.h"
25 #include "pipeline/jit/pi/graph_compiler/pi_ir/type.h"
26 #include "utils/hashing.h"
27 
28 namespace mindspore {
29 namespace pijit {
30 namespace ir {
/// \brief Type trait whose value is true exactly when \p Candidate is std::shared_ptr<X> for some X.
///
/// Only the unqualified std::shared_ptr type is matched; e.g. `const std::shared_ptr<X>` yields false.
template <typename Candidate>
struct is_shared_ptr : std::false_type {};

/// \brief Partial specialization selected for std::shared_ptr<Pointee>.
template <typename Pointee>
struct is_shared_ptr<std::shared_ptr<Pointee>> : std::true_type {};
35 
36 /// \brief Node is a base class of all classes, which represent value, instruction, and so on.
37 class Node : public std::enable_shared_from_this<Node> {
38  public:
39   /**
40    * \brief The constructor of Node.
41    *
42    * \return The instance of Node.
43    */
44   Node()
45       : type_(std::make_shared<Type>()),
46         node_id_(0),
47         offset_(std::numeric_limits<size_t>::max()),
48         debug_info_(std::make_shared<DebugInfo>("")) {}
49 
50   /// \brief Destructor.
51   virtual ~Node() = default;
52 
53   /// \brief The description id of this class.
54   static constexpr uint32_t kClassId = ConstStringHash("Node");
55 
56   /**
57    * \brief Get type of this node.
58    *
59    * \return The type of this node.
60    */
61   const TypePtr &GetType() const { return type_; }
62 
63   /**
64    * \brief Set type of this node.
65    */
66   void SetType(const TypePtr type) { type_ = type; }
67 
68   /**
69    * \brief Get the id of this node.
70    *
71    * \return The id of this node.
72    */
73   size_t GetNodeId() const { return node_id_; }
74 
75   /**
76    * \brief Set the id of this node.
77    *
78    * \note This method should not be actively called by the program writer, it should only be called by the method
79    * Sort()
80    */
81   virtual void SetNodeId(size_t *id) {
82     node_id_ = *id;
83     (*id)++;
84   }
85 
86   /**
87    * \brief Get the offset if this node is a instruction, else returns an invalid value.
88    *
89    * \return The offset of this node.
90    */
91   size_t GetOffset() const { return offset_; }
92 
93   /**
94    * \brief Set the offset of this node.
95    *
96    * \note This method should not be actively called by the program writer, it should only be called by the method
97    * Sort()
98    */
99   virtual void SetOffset(size_t *offset) {
100     if (IsOperation()) {
101       if (NeedExtInstr()) {
102         (*offset)++;
103       }
104       offset_ = *offset;
105       (*offset)++;
106     }
107   }
108 
109   /**
110    * \brief Sort all nodes, and give them a id and a offset if this node is a instruction.
111    *
112    * \note This method should only be called on the root function node
113    */
114   void Sort(size_t index = 0, size_t offset = 0) {
115     SetNodeId(&index);
116     SetOffset(&offset);
117   }
118 
119   /**
120    * \brief Get the debug information of this node.
121    *
122    * \return The debug information of this node.
123    */
124   const DebugInfoPtr &GetDebugInfo() const { return debug_info_; }
125 
126   /**
127    * \brief Set the debug information of this node.
128    *
129    * \param[in] debug_info The debug information of this node.
130    */
131   void SetDebugInfo(const DebugInfoPtr &debug_info) { debug_info_ = debug_info; }
132 
133   /**
134    * \brief Judge whether this class is derived from class with the given class id.
135    *
136    * \param[in] id Define a class id.
137    *
138    * \return The result of the judgment.
139    */
140   static bool IsDerivedFrom(uint32_t id) { return id == Node::kClassId; }
141 
142   /// \brief Judge whether this object is an instance of class with the given type id.
143   ///
144   /// \param[in] id Define a type id.
145   ///
146   /// \return The result of the judgment.
147   virtual bool IsFromClass(uint32_t id) const { return Node::IsDerivedFrom(id); }
148 
149   /// \brief Get the id of this class.
150   ///
151   /// \return The id of this class.
152   virtual uint32_t GetClassId() const { return Node::kClassId; }
153 
154   /**
155    * \brief Judge whether the class id of this node is same as the given class id.
156    *
157    * \param[in] id Define a class id.
158    *
159    * \return The result of the judgment.
160    */
161   virtual bool IsSameClass(uint32_t id) const { return id == Node::kClassId; }
162 
163   /// \brief Get the name of this class.
164   ///
165   /// \return The node name.
166   virtual std::string GetNodeName() const { return "Node"; }
167 
168   /**
169    * \brief Judge whether this node is an instance of a given class which is derived from Node.
170    *
171    * \return The result of the judgment.
172    */
173   template <typename T,
174             typename std::enable_if<!is_shared_ptr<T>::value && std::is_base_of<Node, T>::value, T>::type * = nullptr>
175   inline bool isa() const {
176     if constexpr (std::is_final<T>::value) {
177       return this->IsSameClass(T::kClassId);
178     } else {
179       return this->IsFromClass(T::kClassId);
180     }
181   }
182 
183   /// \brief Cast a shared_ptr of this object to a given class.
184   ///
185   /// \return If success, a shared_ptr of the given class will be returned. Otherwise a nullptr will be returned.
186   template <typename T, typename U = typename std::enable_if<is_shared_ptr<T>::value, typename T::element_type>::type>
187   inline T cast() {
188     if (isa<U>()) {
189       return std::static_pointer_cast<U>(shared_from_this());
190     }
191     return nullptr;
192   }
193 
194   /**
195    * \brief Judge whether this node is an operation(instruction).
196    *
197    * \return The result of the judgment.
198    */
199   virtual bool IsOperation() const { return false; }
200 
201   /**
202    * \brief Judge whether need to insert a EXTENDED_ARG instruction before this operation.
203    *
204    * \return The result of the judgment.
205    */
206   virtual bool NeedExtInstr() const { return false; }
207 
208   /**
209    * \brief Mark whether this operation need to insert a EXTENDED_ARG instruction.
210    *
211    * \param[in] need the result.
212    */
213   virtual void SetNeedExtInstr(bool need) {}
214 
215   /**
216    * \brief Get the description of this node.
217    * \return The description.
218    */
219   virtual std::string ToString() const = 0;
220 
221  private:
222   /// \brief The type of this node.
223   TypePtr type_;
224   /// \brief The id of this node, used to describe node when dump.
225   size_t node_id_;
226   /// \brief The offset of this node, only makes sense when the node is an operation.
227   size_t offset_;
228   /// \brief The debug information of this node.
229   DebugInfoPtr debug_info_;
230 };
231 
232 using NodePtr = std::shared_ptr<Node>;
233 using NodePtrList = std::vector<NodePtr>;
234 
/// \brief Declare the class-identification boilerplate for \p current_t, a class derived from \p parent_t.
///
/// Used inside the body of \p current_t, it declares:
///  - kClassId: a compile-time id computed by ConstStringHash over the string "<parent_t>_<current_t>";
///  - IsDerivedFrom(id): true for current_t's own id, otherwise defers to parent_t::IsDerivedFrom(id);
///  - overrides of GetClassId(), IsSameClass(), IsFromClass() and GetNodeName() built on the two members above.
#define JIT_DECLARE_PARENT(current_t, parent_t)                                         \
  static constexpr uint32_t kClassId = ConstStringHash(#parent_t "_" #current_t);       \
  static bool IsDerivedFrom(uint32_t id) {                                              \
    return (id == current_t::kClassId) || parent_t::IsDerivedFrom(id);                  \
  }                                                                                     \
  uint32_t GetClassId() const override { return current_t::kClassId; }                  \
  bool IsSameClass(uint32_t id) const override { return current_t::kClassId == id; }    \
  bool IsFromClass(uint32_t id) const override { return current_t::IsDerivedFrom(id); } \
  std::string GetNodeName() const override { return #current_t; }
242 }  // namespace ir
243 }  // namespace pijit
244 }  // namespace mindspore
245 
246 #endif  // MINDSPORE_PI_JIT_NODE_H_
247