1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_ 18 19 #include <memory> 20 #include <algorithm> 21 #include <functional> 22 #include <sstream> 23 #include <vector> 24 #include <unordered_map> 25 #include <set> 26 #include <iostream> 27 #include <utility> 28 #include <string> 29 #include <stdexcept> 30 31 #include "mindspore/core/ir/dtype/type_id.h" 32 #include "mindspore/core/ir/value.h" 33 #include "mindspore/core/ir/tensor.h" 34 #include "mindspore/core/utils/shape_utils.h" 35 #include "utils/utils.h" 36 #include "backend/kernel_compiler/common_utils.h" 37 38 namespace mindspore { 39 namespace opt { 40 namespace graphkernel { 41 enum class NType { 42 Base, 43 Primitive, 44 Parameter, 45 Value, 46 Output, 47 }; 48 49 using DFormat = std::string; 50 using DShape = ShapeVector; 51 using DAttrs = std::unordered_map<std::string, ValuePtr>; 52 53 struct NodeBase { 54 DShape shape; 55 TypeId type; 56 DFormat format; 57 }; 58 59 class Node; 60 using NodePtr = std::shared_ptr<Node>; 61 using NodePtrList = std::vector<NodePtr>; 62 class Node : public NodeBase, public std::enable_shared_from_this<Node> { 63 public: Node(const NodeBase & baseinfo,const std::string & name)64 Node(const NodeBase &baseinfo, const std::string &name) : NodeBase(baseinfo), name_(name) {} ~Node()65 virtual ~Node() { 66 // remove this node from the previous nodes' user. 67 SetInputs({}); 68 } 69 SetBaseInfo(NodeBase baseinfo)70 void SetBaseInfo(NodeBase baseinfo) { 71 this->shape = std::move(baseinfo.shape); 72 this->type = std::move(baseinfo.type); 73 this->format = std::move(baseinfo.format); 74 } NodeType()75 virtual NType NodeType() { return NType::Base; } 76 friend std::ostream &operator<<(std::ostream &output, const Node &n) { 77 std::ostringstream os; 78 n.Dump(os); 79 output << os.str(); 80 return output; 81 } 82 virtual void Dump(std::ostringstream &os) const = 0; 83 virtual void DumpTensor(std::ostringstream &os) const; 84 85 void AddInput(const NodePtr &new_input); 86 void SetInput(size_t i, const NodePtr &new_input); 87 void SetInputs(const NodePtrList &inputs); 88 void ReplaceWith(const NodePtr &other_node); SetAttrs(const DAttrs & attrs)89 void SetAttrs(const DAttrs &attrs) { attrs_ = attrs; } SetAttr(const std::string & key,const ValuePtr & value)90 void SetAttr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; } 91 92 template <typename T> As()93 std::shared_ptr<T> As() { 94 return std::static_pointer_cast<T>(shared_from_this()); 95 } 96 name()97 const std::string &name() const { return name_; } attrs()98 const DAttrs &attrs() const { return attrs_; } input(size_t i)99 const NodePtr &input(size_t i) const { return inputs_[i]; } inputs()100 const NodePtrList &inputs() const { return inputs_; } users()101 const std::unordered_map<Node *, std::set<size_t>> &users() const { return users_; } 102 103 protected: 104 std::string name_; 105 DAttrs attrs_; 106 NodePtrList inputs_; 107 std::unordered_map<Node *, std::set<size_t>> users_; 108 109 private: 110 // the nodes' users are only maintained by AddInput/SetInput. AddUser(Node * user,size_t index)111 void AddUser(Node *user, size_t index) { users_[user].insert(index); } RemoveUser(Node * user,size_t index)112 void RemoveUser(Node *user, size_t index) { 113 if (auto iter = users_.find(user); iter != users_.end()) { 114 iter->second.erase(index); 115 if (iter->second.empty()) { 116 users_.erase(iter); 117 } 118 } 119 } 120 }; 121 122 class ConstTensorNode : public Node { 123 public: 124 explicit ConstTensorNode(const tensor::TensorPtr &data, const std::string &name = "") 125 : Node({data->shape(), data->data_type(), kOpFormat_DEFAULT}, name), data_(data) {} 126 ~ConstTensorNode() = default; 127 NodeType()128 NType NodeType() override { return NType::Value; } Dump(std::ostringstream & os)129 void Dump(std::ostringstream &os) const override { os << ToString(); } DumpTensor(std::ostringstream & os)130 void DumpTensor(std::ostringstream &os) const override { os << ToString(); } ToString()131 std::string ToString() const { return data_->data().ToString(this->type, this->shape, false); } data()132 const tensor::TensorPtr data() const { return data_; } 133 134 protected: 135 tensor::TensorPtr data_; 136 }; 137 138 class ParamNode : public Node { 139 public: ParamNode(const std::string & name,const NodeBase & baseinfo)140 ParamNode(const std::string &name, const NodeBase &baseinfo) : Node(baseinfo, name) {} 141 ~ParamNode() = default; 142 Dump(std::ostringstream & os)143 void Dump(std::ostringstream &os) const override { DumpTensor(os); } NodeType()144 NType NodeType() override { return NType::Parameter; } 145 }; 146 147 class OutputNode : public Node { 148 public: OutputNode()149 OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, "Output") {} 150 ~OutputNode() = default; 151 Dump(std::ostringstream & os)152 void Dump(std::ostringstream &os) const override { ; } NodeType()153 NType NodeType() override { return NType::Output; } 154 }; 155 } // namespace graphkernel 156 } // namespace opt 157 } // namespace mindspore 158 #endif 159