• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_
18 
19 #include <memory>
20 #include <algorithm>
21 #include <functional>
22 #include <sstream>
23 #include <vector>
24 #include <unordered_map>
25 #include <set>
26 #include <iostream>
27 #include <utility>
28 #include <string>
29 #include <stdexcept>
30 
31 #include "mindspore/core/ir/dtype/type_id.h"
32 #include "mindspore/core/ir/value.h"
33 #include "mindspore/core/ir/tensor.h"
34 #include "mindspore/core/utils/shape_utils.h"
35 #include "utils/utils.h"
36 #include "backend/kernel_compiler/common_utils.h"
37 
38 namespace mindspore {
39 namespace opt {
40 namespace graphkernel {
// Kind of a model node, as reported by the virtual Node::NodeType().
// Enumerator values are spelled out so their order is explicit.
enum class NType {
  Base = 0,       // reported by Node itself
  Primitive = 1,  // not reported by any class in this header
  Parameter = 2,  // reported by ParamNode
  Value = 3,      // reported by ConstTensorNode
  Output = 4,     // reported by OutputNode
};
48 
49 using DFormat = std::string;
50 using DShape = ShapeVector;
51 using DAttrs = std::unordered_map<std::string, ValuePtr>;
52 
53 struct NodeBase {
54   DShape shape;
55   TypeId type;
56   DFormat format;
57 };
58 
59 class Node;
60 using NodePtr = std::shared_ptr<Node>;
61 using NodePtrList = std::vector<NodePtr>;
62 class Node : public NodeBase, public std::enable_shared_from_this<Node> {
63  public:
Node(const NodeBase & baseinfo,const std::string & name)64   Node(const NodeBase &baseinfo, const std::string &name) : NodeBase(baseinfo), name_(name) {}
~Node()65   virtual ~Node() {
66     // remove this node from the previous nodes' user.
67     SetInputs({});
68   }
69 
SetBaseInfo(NodeBase baseinfo)70   void SetBaseInfo(NodeBase baseinfo) {
71     this->shape = std::move(baseinfo.shape);
72     this->type = std::move(baseinfo.type);
73     this->format = std::move(baseinfo.format);
74   }
NodeType()75   virtual NType NodeType() { return NType::Base; }
76   friend std::ostream &operator<<(std::ostream &output, const Node &n) {
77     std::ostringstream os;
78     n.Dump(os);
79     output << os.str();
80     return output;
81   }
82   virtual void Dump(std::ostringstream &os) const = 0;
83   virtual void DumpTensor(std::ostringstream &os) const;
84 
85   void AddInput(const NodePtr &new_input);
86   void SetInput(size_t i, const NodePtr &new_input);
87   void SetInputs(const NodePtrList &inputs);
88   void ReplaceWith(const NodePtr &other_node);
SetAttrs(const DAttrs & attrs)89   void SetAttrs(const DAttrs &attrs) { attrs_ = attrs; }
SetAttr(const std::string & key,const ValuePtr & value)90   void SetAttr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; }
91 
92   template <typename T>
As()93   std::shared_ptr<T> As() {
94     return std::static_pointer_cast<T>(shared_from_this());
95   }
96 
name()97   const std::string &name() const { return name_; }
attrs()98   const DAttrs &attrs() const { return attrs_; }
input(size_t i)99   const NodePtr &input(size_t i) const { return inputs_[i]; }
inputs()100   const NodePtrList &inputs() const { return inputs_; }
users()101   const std::unordered_map<Node *, std::set<size_t>> &users() const { return users_; }
102 
103  protected:
104   std::string name_;
105   DAttrs attrs_;
106   NodePtrList inputs_;
107   std::unordered_map<Node *, std::set<size_t>> users_;
108 
109  private:
110   // the nodes' users are only maintained by AddInput/SetInput.
AddUser(Node * user,size_t index)111   void AddUser(Node *user, size_t index) { users_[user].insert(index); }
RemoveUser(Node * user,size_t index)112   void RemoveUser(Node *user, size_t index) {
113     if (auto iter = users_.find(user); iter != users_.end()) {
114       iter->second.erase(index);
115       if (iter->second.empty()) {
116         users_.erase(iter);
117       }
118     }
119   }
120 };
121 
122 class ConstTensorNode : public Node {
123  public:
124   explicit ConstTensorNode(const tensor::TensorPtr &data, const std::string &name = "")
125       : Node({data->shape(), data->data_type(), kOpFormat_DEFAULT}, name), data_(data) {}
126   ~ConstTensorNode() = default;
127 
NodeType()128   NType NodeType() override { return NType::Value; }
Dump(std::ostringstream & os)129   void Dump(std::ostringstream &os) const override { os << ToString(); }
DumpTensor(std::ostringstream & os)130   void DumpTensor(std::ostringstream &os) const override { os << ToString(); }
ToString()131   std::string ToString() const { return data_->data().ToString(this->type, this->shape, false); }
data()132   const tensor::TensorPtr data() const { return data_; }
133 
134  protected:
135   tensor::TensorPtr data_;
136 };
137 
138 class ParamNode : public Node {
139  public:
ParamNode(const std::string & name,const NodeBase & baseinfo)140   ParamNode(const std::string &name, const NodeBase &baseinfo) : Node(baseinfo, name) {}
141   ~ParamNode() = default;
142 
Dump(std::ostringstream & os)143   void Dump(std::ostringstream &os) const override { DumpTensor(os); }
NodeType()144   NType NodeType() override { return NType::Parameter; }
145 };
146 
147 class OutputNode : public Node {
148  public:
OutputNode()149   OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, "Output") {}
150   ~OutputNode() = default;
151 
Dump(std::ostringstream & os)152   void Dump(std::ostringstream &os) const override { ; }
NodeType()153   NType NodeType() override { return NType::Output; }
154 };
155 }  // namespace graphkernel
156 }  // namespace opt
157 }  // namespace mindspore
158 #endif
159