1 /** 2 * Copyright 2021-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_ 18 19 #include <memory> 20 #include <vector> 21 #include <set> 22 #include <string> 23 #include "ir/dtype/type_id.h" 24 #include "ir/anf.h" 25 #include "ir/value.h" 26 #include "ir/tensor.h" 27 #include "utils/hash_map.h" 28 #include "utils/shape_utils.h" 29 #include "include/common/utils/utils.h" 30 #include "include/backend/visible.h" 31 #include "mindspore/core/symbolic_shape/symbol.h" 32 33 namespace mindspore::graphkernel::inner { 34 enum class NType { 35 Base, 36 Primitive, 37 Parameter, 38 Tensor, 39 Scalar, 40 Tuple, 41 Output, 42 }; 43 44 using DFormat = std::string; 45 using DShape = ShapeVector; 46 using DAttrs = mindspore::HashMap<std::string, ValuePtr>; 47 48 struct BACKEND_EXPORT NodeBase { 49 DShape shape; 50 TypeId type; 51 DFormat format; 52 ListSymbolPtr symbolic_shape{nullptr}; 53 }; 54 using NodeBaseList = std::vector<NodeBase>; 55 56 class BACKEND_EXPORT Node; 57 using NodePtr = std::shared_ptr<Node>; 58 using NodePtrList = std::vector<NodePtr>; 59 class BACKEND_EXPORT Node : public NodeBase, public std::enable_shared_from_this<Node> { 60 public: Node(const NodeBase & baseinfo)61 explicit Node(const NodeBase &baseinfo) : NodeBase(baseinfo) {} ~Node()62 virtual ~Node() { ClearInputs(); } // remove this node from the previous nodes' user. 63 NodeType()64 virtual NType NodeType() { return NType::Base; } 65 virtual std::string ToString() const; 66 virtual abstract::AbstractBasePtr ToAbstract() const; 67 68 virtual void SetBaseInfo(const NodeBaseList &baseinfo); 69 void AddInput(const NodePtr &new_input); 70 void SetInput(size_t i, const NodePtr &new_input); 71 void SetInputs(const NodePtrList &inputs); 72 void ClearInputs() noexcept; 73 void ReplaceWith(const NodePtr &other_node); SetAttrs(const DAttrs & attrs)74 void SetAttrs(const DAttrs &attrs) { attrs_ = attrs; } SetAttr(const std::string & key,const ValuePtr & value)75 void SetAttr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; } SetDebugName(const std::string & debug_name)76 void SetDebugName(const std::string &debug_name) { debug_name_ = debug_name; } 77 78 template <typename T> As()79 std::shared_ptr<T> As() { 80 return std::static_pointer_cast<T>(shared_from_this()); 81 } 82 debug_name()83 const std::string &debug_name() const { return debug_name_; } attrs()84 const DAttrs &attrs() const { return attrs_; } input(size_t i)85 const NodePtr &input(size_t i) const { return inputs_[i]; } inputs()86 const NodePtrList &inputs() const { return inputs_; } users()87 const mindspore::HashMap<Node *, std::set<size_t>> &users() const { return users_; } 88 size_t tensor_size(bool in_bytes = false) const; outputs()89 const NodeBaseList &outputs() const { return outputs_; } 90 91 protected: 92 // only used in Dump function 93 mutable std::string debug_name_; 94 DAttrs attrs_; 95 NodePtrList inputs_; 96 // {user_node: {input edge index set}} 97 mindspore::HashMap<Node *, std::set<size_t>> users_; 98 // save output tensor info when the node is a multi-output operator. 99 // it should keep empty when the node is single-output. 100 NodeBaseList outputs_; 101 102 private: 103 // the nodes' users are only maintained by AddInput/SetInput. AddUser(Node * const user,size_t index)104 void AddUser(Node *const user, size_t index) { (void)users_[user].insert(index); } 105 void RemoveUser(Node *const user, size_t index); 106 }; 107 108 class BACKEND_EXPORT ConstTensorNode : public Node { 109 public: ConstTensorNode(const tensor::TensorPtr & data)110 explicit ConstTensorNode(const tensor::TensorPtr &data) 111 : Node({data->DataSize() == 1 ? DShape({1}) : data->shape(), data->data_type(), kOpFormat_DEFAULT}), 112 data_(data) {} 113 ~ConstTensorNode() = default; 114 NodeType()115 NType NodeType() override { return NType::Tensor; } ToString()116 std::string ToString() const override { return data_->data().ToString(data_->data_type(), data_->shape(), false); } data()117 const tensor::TensorPtr data() const { return data_; } ToAbstract()118 abstract::AbstractBasePtr ToAbstract() const override { return data_->ToAbstract(); } 119 120 protected: 121 tensor::TensorPtr data_; 122 }; 123 124 class ConstScalarNode : public Node { 125 public: 126 explicit ConstScalarNode(const ValuePtr &data); 127 ~ConstScalarNode() = default; 128 NodeType()129 NType NodeType() override { return NType::Scalar; } data()130 const ValuePtr data() const { return data_; } ToAbstract()131 abstract::AbstractBasePtr ToAbstract() const override { return data_->ToAbstract(); } 132 133 protected: 134 ValuePtr data_; 135 }; 136 137 class ConstTupleNode : public Node { 138 public: 139 explicit ConstTupleNode(const ValuePtr &data, const size_t len); 140 ~ConstTupleNode() = default; 141 NodeType()142 NType NodeType() override { return NType::Tuple; } data()143 const ValuePtr data() const { return data_; } ToAbstract()144 abstract::AbstractBasePtr ToAbstract() const override { return data_->ToAbstract(); } 145 146 protected: 147 ValuePtr data_; 148 }; 149 150 class ParamNode : public Node { 151 public: ParamNode(const NodeBase & baseinfo)152 explicit ParamNode(const NodeBase &baseinfo) : Node(baseinfo) {} 153 ~ParamNode() = default; 154 NodeType()155 NType NodeType() override { return NType::Parameter; } 156 }; 157 158 // the OutputNode's inputs are the real outputs of graph, like the `make_tuple` in FuncGraph. 159 class OutputNode : public Node { 160 public: OutputNode()161 OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}) { debug_name_ = "Output"; } 162 ~OutputNode() = default; 163 NodeType()164 NType NodeType() override { return NType::Output; } 165 }; 166 } // namespace mindspore::graphkernel::inner 167 #endif 168