/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include "backend/common/graph_kernel/model/node.h"
#include "abstract/utils.h"

namespace mindspore::graphkernel::inner {
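// Constant scalar node: rank-0 shape, with the element type derived from the stored value's abstract type.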
ConstScalarNode::ConstScalarNode(const ValuePtr &data)
    : Node({DShape({}), kNumberTypeEnd, kOpFormat_DEFAULT}), data_(data) {
  auto type_ptr = data->ToAbstract()->BuildType();
  MS_EXCEPTION_IF_NULL(type_ptr);
  type = type_ptr->type_id();
}

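// Constant tuple node: the shape is the tuple length, with the element type derived from the stored value.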
ConstTupleNode::ConstTupleNode(const ValuePtr &data, const size_t len)
    : Node({DShape({SizeToLong(len)}), kNumberTypeEnd, kOpFormat_DEFAULT}), data_(data) {
  auto type_ptr = data->ToAbstract()->BuildType();
  MS_EXCEPTION_IF_NULL(type_ptr);
  type = type_ptr->type_id();
}

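// Take the first output's shape/type/format/symbolic shape as the node's own base info;
// the full list is kept in outputs_ only when the node has multiple outputs.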
void Node::SetBaseInfo(const NodeBaseList &baseinfo) {
  this->shape = baseinfo[0].shape;
  this->type = baseinfo[0].type;
  this->format = baseinfo[0].format;
  this->symbolic_shape = baseinfo[0].symbolic_shape;
  if (baseinfo.size() > 1) {
    outputs_ = baseinfo;
  }
}

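// Dump the node as "name[d0,d1,...]{type x format}" for debugging.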
std::string Node::ToString() const {
  std::ostringstream oss;
  oss << debug_name() << "[";
  for (size_t i = 0; i < shape.size(); i++) {
    oss << shape[i];
    if (i + 1 < shape.size()) {
      oss << ",";
    }
  }
  auto type_str = (type == TypeId::kNumberTypeBegin) ? "NOTYPE" : TypeIdToString(type);
  oss << "]{" << type_str << "x" << format << "}";
  return oss.str();
}

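// Build the abstract: a single AbstractTensor for single-output nodes,
// or an AbstractTuple of tensors for multi-output nodes.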
abstract::AbstractBasePtr Node::ToAbstract() const {
  if (outputs_.empty()) {
    return std::make_shared<abstract::AbstractTensor>(TypeIdToType(this->type), this->shape);
  }
  AbstractBasePtrList abs_list(outputs_.size());
  (void)std::transform(outputs_.cbegin(), outputs_.cend(), abs_list.begin(), [](const NodeBase &node) {
    return std::make_shared<abstract::AbstractTensor>(TypeIdToType(node.type), node.shape);
  });
  return std::make_shared<abstract::AbstractTuple>(std::move(abs_list));
}

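// Append an input and register this node as its user at the new index.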
void Node::AddInput(const NodePtr &new_input) {
  MS_EXCEPTION_IF_NULL(new_input);
  new_input->AddUser(this, inputs_.size());
  (void)inputs_.emplace_back(new_input);
}

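// Replace the i-th input, keeping the user records of both the old and the new input consistent.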
void Node::SetInput(size_t i, const NodePtr &new_input) {
  MS_EXCEPTION_IF_NULL(new_input);
  if (i >= inputs_.size()) {
    MS_LOG(EXCEPTION) << "The index " << i << " is out of the inputs range [0, " << inputs_.size() << ")";
  }
  auto &old_input = inputs_[i];
  old_input->RemoveUser(this, i);
  new_input->AddUser(this, i);
  inputs_[i] = new_input;
}

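// Reset the whole input list.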
void Node::SetInputs(const NodePtrList &inputs) {
  ClearInputs();
  inputs_.reserve(inputs.size());
  for (const auto &inp : inputs) {
    AddInput(inp);
  }
}

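// Detach this node from all of its inputs' user lists and drop the inputs.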
void Node::ClearInputs() noexcept {
  if (!inputs_.empty()) {
    // remove the original inputs
    for (size_t i = 0; i < inputs_.size(); i++) {
      inputs_[i]->RemoveUser(this, i);
    }
    inputs_.clear();
  }
}

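// Redirect every user of this node to other_node.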
void Node::ReplaceWith(const NodePtr &other_node) {
  if (this->users_.empty()) {
    return;
  }
  // the users_ will be changed, so we copy the users before traversal
  auto users = this->users_;
  for (auto &user : users) {
    for (const auto &idx : user.second) {
      user.first->SetInput(idx, other_node);
    }
  }
}

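// Remove the record that `user` consumes this node at `index`; drop the user entry when no index remains.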
void Node::RemoveUser(Node *const user, size_t index) {
  if (auto iter = users_.find(user); iter != users_.end()) {
    (void)iter->second.erase(index);
    if (iter->second.empty()) {
      (void)users_.erase(iter);
    }
  }
}

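// Element count of the tensor, or its size in bytes when in_bytes is true; returns 0 for dynamic shapes.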
size_t Node::tensor_size(bool in_bytes) const {
  if (IsDynamic(this->shape)) {
    return 0;
  }
  size_t size = LongToSize(abstract::ShapeSize(this->shape));
  return in_bytes ? abstract::TypeIdSize(this->type) * size : size;
}
}  // namespace mindspore::graphkernel::inner