1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_NODE_H_ 17 #define MINDSPORE_PI_JIT_NODE_H_ 18 19 #include <memory> 20 #include <limits> 21 #include <list> 22 #include <string> 23 #include <vector> 24 #include "pipeline/jit/pi/graph_compiler/pi_ir/debug_info.h" 25 #include "pipeline/jit/pi/graph_compiler/pi_ir/type.h" 26 #include "utils/hashing.h" 27 28 namespace mindspore { 29 namespace pijit { 30 namespace ir { 31 template <typename T> 32 struct is_shared_ptr : public std::false_type {}; 33 template <typename T> 34 struct is_shared_ptr<std::shared_ptr<T>> : public std::true_type {}; 35 36 /// \brief Node is a base class of all classes, which represent value, instruction, and so on. 37 class Node : public std::enable_shared_from_this<Node> { 38 public: 39 /** 40 * \brief The constructor of Node. 41 * 42 * \return The instance of Node. 43 */ 44 Node() 45 : type_(std::make_shared<Type>()), 46 node_id_(0), 47 offset_(std::numeric_limits<size_t>::max()), 48 debug_info_(std::make_shared<DebugInfo>("")) {} 49 50 /// \brief Destructor. 51 virtual ~Node() = default; 52 53 /// \brief The description id of this class. 54 static constexpr uint32_t kClassId = ConstStringHash("Node"); 55 56 /** 57 * \brief Get type of this node. 58 * 59 * \return The type of this node. 60 */ 61 const TypePtr &GetType() const { return type_; } 62 63 /** 64 * \brief Set type of this node. 65 */ 66 void SetType(const TypePtr type) { type_ = type; } 67 68 /** 69 * \brief Get the id of this node. 70 * 71 * \return The id of this node. 72 */ 73 size_t GetNodeId() const { return node_id_; } 74 75 /** 76 * \brief Set the id of this node. 77 * 78 * \note This method should not be actively called by the program writer, it should only be called by the method 79 * Sort() 80 */ 81 virtual void SetNodeId(size_t *id) { 82 node_id_ = *id; 83 (*id)++; 84 } 85 86 /** 87 * \brief Get the offset if this node is a instruction, else returns an invalid value. 88 * 89 * \return The offset of this node. 90 */ 91 size_t GetOffset() const { return offset_; } 92 93 /** 94 * \brief Set the offset of this node. 95 * 96 * \note This method should not be actively called by the program writer, it should only be called by the method 97 * Sort() 98 */ 99 virtual void SetOffset(size_t *offset) { 100 if (IsOperation()) { 101 if (NeedExtInstr()) { 102 (*offset)++; 103 } 104 offset_ = *offset; 105 (*offset)++; 106 } 107 } 108 109 /** 110 * \brief Sort all nodes, and give them a id and a offset if this node is a instruction. 111 * 112 * \note This method should only be called on the root function node 113 */ 114 void Sort(size_t index = 0, size_t offset = 0) { 115 SetNodeId(&index); 116 SetOffset(&offset); 117 } 118 119 /** 120 * \brief Get the debug information of this node. 121 * 122 * \return The debug information of this node. 123 */ 124 const DebugInfoPtr &GetDebugInfo() const { return debug_info_; } 125 126 /** 127 * \brief Set the debug information of this node. 128 * 129 * \param[in] debug_info The debug information of this node. 130 */ 131 void SetDebugInfo(const DebugInfoPtr &debug_info) { debug_info_ = debug_info; } 132 133 /** 134 * \brief Judge whether this class is derived from class with the given class id. 135 * 136 * \param[in] id Define a class id. 137 * 138 * \return The result of the judgment. 139 */ 140 static bool IsDerivedFrom(uint32_t id) { return id == Node::kClassId; } 141 142 /// \brief Judge whether this object is an instance of class with the given type id. 143 /// 144 /// \param[in] id Define a type id. 145 /// 146 /// \return The result of the judgment. 147 virtual bool IsFromClass(uint32_t id) const { return Node::IsDerivedFrom(id); } 148 149 /// \brief Get the id of this class. 150 /// 151 /// \return The id of this class. 152 virtual uint32_t GetClassId() const { return Node::kClassId; } 153 154 /** 155 * \brief Judge whether the class id of this node is same as the given class id. 156 * 157 * \param[in] id Define a class id. 158 * 159 * \return The result of the judgment. 160 */ 161 virtual bool IsSameClass(uint32_t id) const { return id == Node::kClassId; } 162 163 /// \brief Get the name of this class. 164 /// 165 /// \return The node name. 166 virtual std::string GetNodeName() const { return "Node"; } 167 168 /** 169 * \brief Judge whether this node is an instance of a given class which is derived from Node. 170 * 171 * \return The result of the judgment. 172 */ 173 template <typename T, 174 typename std::enable_if<!is_shared_ptr<T>::value && std::is_base_of<Node, T>::value, T>::type * = nullptr> 175 inline bool isa() const { 176 if constexpr (std::is_final<T>::value) { 177 return this->IsSameClass(T::kClassId); 178 } else { 179 return this->IsFromClass(T::kClassId); 180 } 181 } 182 183 /// \brief Cast a shared_ptr of this object to a given class. 184 /// 185 /// \return If success, a shared_ptr of the given class will be returned. Otherwise a nullptr will be returned. 186 template <typename T, typename U = typename std::enable_if<is_shared_ptr<T>::value, typename T::element_type>::type> 187 inline T cast() { 188 if (isa<U>()) { 189 return std::static_pointer_cast<U>(shared_from_this()); 190 } 191 return nullptr; 192 } 193 194 /** 195 * \brief Judge whether this node is an operation(instruction). 196 * 197 * \return The result of the judgment. 198 */ 199 virtual bool IsOperation() const { return false; } 200 201 /** 202 * \brief Judge whether need to insert a EXTENDED_ARG instruction before this operation. 203 * 204 * \return The result of the judgment. 205 */ 206 virtual bool NeedExtInstr() const { return false; } 207 208 /** 209 * \brief Mark whether this operation need to insert a EXTENDED_ARG instruction. 210 * 211 * \param[in] need the result. 212 */ 213 virtual void SetNeedExtInstr(bool need) {} 214 215 /** 216 * \brief Get the description of this node. 217 * \return The description. 218 */ 219 virtual std::string ToString() const = 0; 220 221 private: 222 /// \brief The type of this node. 223 TypePtr type_; 224 /// \brief The id of this node, used to describe node when dump. 225 size_t node_id_; 226 /// \brief The offset of this node, only makes sense when the node is an operation. 227 size_t offset_; 228 /// \brief The debug information of this node. 229 DebugInfoPtr debug_info_; 230 }; 231 232 using NodePtr = std::shared_ptr<Node>; 233 using NodePtrList = std::vector<NodePtr>; 234 235 #define JIT_DECLARE_PARENT(current_t, parent_t) \ 236 static constexpr uint32_t kClassId = ConstStringHash(#parent_t "_" #current_t); \ 237 static bool IsDerivedFrom(uint32_t id) { return (id == current_t::kClassId) || parent_t::IsDerivedFrom(id); } \ 238 bool IsFromClass(uint32_t id) const override { return current_t::IsDerivedFrom(id); } \ 239 bool IsSameClass(uint32_t id) const override { return id == current_t::kClassId; } \ 240 uint32_t GetClassId() const override { return current_t::kClassId; } \ 241 std::string GetNodeName() const override { return #current_t; } 242 } // namespace ir 243 } // namespace pijit 244 } // namespace mindspore 245 246 #endif // MINDSPORE_PI_JIT_NODE_H_ 247