1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <algorithm>
17 #include <utility>
18 #include "backend/common/graph_kernel/model/node.h"
19 #include "abstract/utils.h"
20
21 namespace mindspore::graphkernel::inner {
ConstScalarNode(const ValuePtr & data)22 ConstScalarNode::ConstScalarNode(const ValuePtr &data)
23 : Node({DShape({}), kNumberTypeEnd, kOpFormat_DEFAULT}), data_(data) {
24 auto type_ptr = data->ToAbstract()->BuildType();
25 MS_EXCEPTION_IF_NULL(type_ptr);
26 type = type_ptr->type_id();
27 }
28
ConstTupleNode(const ValuePtr & data,const size_t len)29 ConstTupleNode::ConstTupleNode(const ValuePtr &data, const size_t len)
30 : Node({DShape({SizeToLong(len)}), kNumberTypeEnd, kOpFormat_DEFAULT}), data_(data) {
31 auto type_ptr = data->ToAbstract()->BuildType();
32 MS_EXCEPTION_IF_NULL(type_ptr);
33 type = type_ptr->type_id();
34 }
35
SetBaseInfo(const NodeBaseList & baseinfo)36 void Node::SetBaseInfo(const NodeBaseList &baseinfo) {
37 this->shape = baseinfo[0].shape;
38 this->type = baseinfo[0].type;
39 this->format = baseinfo[0].format;
40 this->symbolic_shape = baseinfo[0].symbolic_shape;
41 if (baseinfo.size() > 1) {
42 outputs_ = baseinfo;
43 }
44 }
45
ToString() const46 std::string Node::ToString() const {
47 std::ostringstream oss;
48 oss << debug_name() << "[";
49 for (size_t i = 0; i < shape.size(); i++) {
50 oss << shape[i];
51 if (i + 1 < shape.size()) {
52 oss << ",";
53 }
54 }
55 auto type_str = (type == TypeId::kNumberTypeBegin) ? "NOTYPE" : TypeIdToString(type);
56 oss << "]{" << type_str << "x" << format << "}";
57 return oss.str();
58 }
59
ToAbstract() const60 abstract::AbstractBasePtr Node::ToAbstract() const {
61 if (outputs_.empty()) {
62 return std::make_shared<abstract::AbstractTensor>(TypeIdToType(this->type), this->shape);
63 }
64 AbstractBasePtrList abs_list(outputs_.size());
65 (void)std::transform(outputs_.cbegin(), outputs_.cend(), abs_list.begin(), [](const NodeBase &node) {
66 return std::make_shared<abstract::AbstractTensor>(TypeIdToType(node.type), node.shape);
67 });
68 return std::make_shared<abstract::AbstractTuple>(std::move(abs_list));
69 }
70
AddInput(const NodePtr & new_input)71 void Node::AddInput(const NodePtr &new_input) {
72 MS_EXCEPTION_IF_NULL(new_input);
73 new_input->AddUser(this, inputs_.size());
74 (void)inputs_.emplace_back(new_input);
75 }
76
SetInput(size_t i,const NodePtr & new_input)77 void Node::SetInput(size_t i, const NodePtr &new_input) {
78 MS_EXCEPTION_IF_NULL(new_input);
79 if (i >= inputs_.size()) {
80 MS_LOG(EXCEPTION) << "The index " << i << " is out of the inputs range [0, " << inputs_.size() << ")";
81 }
82 auto &old_input = inputs_[i];
83 old_input->RemoveUser(this, i);
84 new_input->AddUser(this, i);
85 inputs_[i] = new_input;
86 }
87
SetInputs(const NodePtrList & inputs)88 void Node::SetInputs(const NodePtrList &inputs) {
89 ClearInputs();
90 inputs_.reserve(inputs.size());
91 for (const auto &inp : inputs) {
92 AddInput(inp);
93 }
94 }
95
ClearInputs()96 void Node::ClearInputs() noexcept {
97 if (!inputs_.empty()) {
98 // remove the original inputs
99 for (size_t i = 0; i < inputs_.size(); i++) {
100 inputs_[i]->RemoveUser(this, i);
101 }
102 inputs_.clear();
103 }
104 }
105
ReplaceWith(const NodePtr & other_node)106 void Node::ReplaceWith(const NodePtr &other_node) {
107 if (this->users_.empty()) {
108 return;
109 }
110 // the users_ will be changed, so we copy the users before traversal
111 auto users = this->users_;
112 for (auto &user : users) {
113 for (const auto &idx : user.second) {
114 user.first->SetInput(idx, other_node);
115 }
116 }
117 }
118
RemoveUser(Node * const user,size_t index)119 void Node::RemoveUser(Node *const user, size_t index) {
120 if (auto iter = users_.find(user); iter != users_.end()) {
121 (void)iter->second.erase(index);
122 if (iter->second.empty()) {
123 (void)users_.erase(iter);
124 }
125 }
126 }
127
tensor_size(bool in_bytes) const128 size_t Node::tensor_size(bool in_bytes) const {
129 if (IsDynamic(this->shape)) {
130 return 0;
131 }
132 size_t size = LongToSize(abstract::ShapeSize(this->shape));
133 return in_bytes ? abstract::TypeIdSize(this->type) * size : size;
134 }
135 } // namespace mindspore::graphkernel::inner
136