• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_
18 
19 #include <memory>
20 #include <vector>
21 #include <set>
22 #include <string>
23 #include "ir/dtype/type_id.h"
24 #include "ir/anf.h"
25 #include "ir/value.h"
26 #include "ir/tensor.h"
27 #include "utils/hash_map.h"
28 #include "utils/shape_utils.h"
29 #include "include/common/utils/utils.h"
30 #include "include/backend/visible.h"
31 #include "mindspore/core/symbolic_shape/symbol.h"
32 
33 namespace mindspore::graphkernel::inner {
/// Discriminator returned by Node::NodeType(); one value per concrete node kind.
enum class NType {
  Base = 0,       // plain Node (the base-class default)
  Primitive = 1,  // operator node; presumably overridden by an op class defined elsewhere — confirm
  Parameter = 2,  // ParamNode
  Tensor = 3,     // ConstTensorNode
  Scalar = 4,     // ConstScalarNode
  Tuple = 5,      // ConstTupleNode
  Output = 6,     // OutputNode
};
43 
44 using DFormat = std::string;
45 using DShape = ShapeVector;
46 using DAttrs = mindspore::HashMap<std::string, ValuePtr>;
47 
48 struct BACKEND_EXPORT NodeBase {
49   DShape shape;
50   TypeId type;
51   DFormat format;
52   ListSymbolPtr symbolic_shape{nullptr};
53 };
54 using NodeBaseList = std::vector<NodeBase>;
55 
56 class BACKEND_EXPORT Node;
57 using NodePtr = std::shared_ptr<Node>;
58 using NodePtrList = std::vector<NodePtr>;
59 class BACKEND_EXPORT Node : public NodeBase, public std::enable_shared_from_this<Node> {
60  public:
Node(const NodeBase & baseinfo)61   explicit Node(const NodeBase &baseinfo) : NodeBase(baseinfo) {}
~Node()62   virtual ~Node() { ClearInputs(); }  // remove this node from the previous nodes' user.
63 
NodeType()64   virtual NType NodeType() { return NType::Base; }
65   virtual std::string ToString() const;
66   virtual abstract::AbstractBasePtr ToAbstract() const;
67 
68   virtual void SetBaseInfo(const NodeBaseList &baseinfo);
69   void AddInput(const NodePtr &new_input);
70   void SetInput(size_t i, const NodePtr &new_input);
71   void SetInputs(const NodePtrList &inputs);
72   void ClearInputs() noexcept;
73   void ReplaceWith(const NodePtr &other_node);
SetAttrs(const DAttrs & attrs)74   void SetAttrs(const DAttrs &attrs) { attrs_ = attrs; }
SetAttr(const std::string & key,const ValuePtr & value)75   void SetAttr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; }
SetDebugName(const std::string & debug_name)76   void SetDebugName(const std::string &debug_name) { debug_name_ = debug_name; }
77 
78   template <typename T>
As()79   std::shared_ptr<T> As() {
80     return std::static_pointer_cast<T>(shared_from_this());
81   }
82 
debug_name()83   const std::string &debug_name() const { return debug_name_; }
attrs()84   const DAttrs &attrs() const { return attrs_; }
input(size_t i)85   const NodePtr &input(size_t i) const { return inputs_[i]; }
inputs()86   const NodePtrList &inputs() const { return inputs_; }
users()87   const mindspore::HashMap<Node *, std::set<size_t>> &users() const { return users_; }
88   size_t tensor_size(bool in_bytes = false) const;
outputs()89   const NodeBaseList &outputs() const { return outputs_; }
90 
91  protected:
92   // only used in Dump function
93   mutable std::string debug_name_;
94   DAttrs attrs_;
95   NodePtrList inputs_;
96   // {user_node: {input edge index set}}
97   mindspore::HashMap<Node *, std::set<size_t>> users_;
98   // save output tensor info when the node is a multi-output operator.
99   // it should keep empty when the node is single-output.
100   NodeBaseList outputs_;
101 
102  private:
103   // the nodes' users are only maintained by AddInput/SetInput.
AddUser(Node * const user,size_t index)104   void AddUser(Node *const user, size_t index) { (void)users_[user].insert(index); }
105   void RemoveUser(Node *const user, size_t index);
106 };
107 
108 class BACKEND_EXPORT ConstTensorNode : public Node {
109  public:
ConstTensorNode(const tensor::TensorPtr & data)110   explicit ConstTensorNode(const tensor::TensorPtr &data)
111       : Node({data->DataSize() == 1 ? DShape({1}) : data->shape(), data->data_type(), kOpFormat_DEFAULT}),
112         data_(data) {}
113   ~ConstTensorNode() = default;
114 
NodeType()115   NType NodeType() override { return NType::Tensor; }
ToString()116   std::string ToString() const override { return data_->data().ToString(data_->data_type(), data_->shape(), false); }
data()117   const tensor::TensorPtr data() const { return data_; }
ToAbstract()118   abstract::AbstractBasePtr ToAbstract() const override { return data_->ToAbstract(); }
119 
120  protected:
121   tensor::TensorPtr data_;
122 };
123 
124 class ConstScalarNode : public Node {
125  public:
126   explicit ConstScalarNode(const ValuePtr &data);
127   ~ConstScalarNode() = default;
128 
NodeType()129   NType NodeType() override { return NType::Scalar; }
data()130   const ValuePtr data() const { return data_; }
ToAbstract()131   abstract::AbstractBasePtr ToAbstract() const override { return data_->ToAbstract(); }
132 
133  protected:
134   ValuePtr data_;
135 };
136 
137 class ConstTupleNode : public Node {
138  public:
139   explicit ConstTupleNode(const ValuePtr &data, const size_t len);
140   ~ConstTupleNode() = default;
141 
NodeType()142   NType NodeType() override { return NType::Tuple; }
data()143   const ValuePtr data() const { return data_; }
ToAbstract()144   abstract::AbstractBasePtr ToAbstract() const override { return data_->ToAbstract(); }
145 
146  protected:
147   ValuePtr data_;
148 };
149 
150 class ParamNode : public Node {
151  public:
ParamNode(const NodeBase & baseinfo)152   explicit ParamNode(const NodeBase &baseinfo) : Node(baseinfo) {}
153   ~ParamNode() = default;
154 
NodeType()155   NType NodeType() override { return NType::Parameter; }
156 };
157 
158 // the OutputNode's inputs are the real outputs of graph, like the `make_tuple` in FuncGraph.
159 class OutputNode : public Node {
160  public:
OutputNode()161   OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}) { debug_name_ = "Output"; }
162   ~OutputNode() = default;
163 
NodeType()164   NType NodeType() override { return NType::Output; }
165 };
166 }  // namespace mindspore::graphkernel::inner
167 #endif
168