1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/node_builder.h"
17
18 #include <vector>
19 #include "tensorflow/core/framework/node_def_util.h"
20 #include "tensorflow/core/framework/versions.pb.h"
21 #include "tensorflow/core/lib/core/errors.h"
22
23 namespace tensorflow {
24
NodeOut(Node * n,int32 i)25 NodeBuilder::NodeOut::NodeOut(Node* n, int32 i) // NOLINT(runtime/explicit)
26 : node(n),
27 error(false),
28 name(node != nullptr ? node->name() : (error = true, "")),
29 index(i),
30 dt(SafeGetOutput(node, i, &error)) {}
31
NodeOut(OutputTensor t)32 NodeBuilder::NodeOut::NodeOut(OutputTensor t) : NodeOut(t.node, t.index) {}
33
NodeOut(StringPiece n,int32 i,DataType t)34 NodeBuilder::NodeOut::NodeOut(StringPiece n, int32 i, DataType t)
35 : node(nullptr), error(false), name(n), index(i), dt(t) {}
36
NodeOut()37 NodeBuilder::NodeOut::NodeOut()
38 : node(nullptr), error(true), index(0), dt(DT_FLOAT) {}
39
NodeBuilder(StringPiece name,StringPiece op_name,const OpRegistryInterface * op_registry,const NodeDebugInfo * debug)40 NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name,
41 const OpRegistryInterface* op_registry,
42 const NodeDebugInfo* debug)
43 : def_builder_(name, op_name, op_registry, debug) {}
44
NodeBuilder(StringPiece name,const OpDef * op_def)45 NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def)
46 : def_builder_(name, op_def) {}
47
NodeBuilder(const NodeDefBuilder & def_builder)48 NodeBuilder::NodeBuilder(const NodeDefBuilder& def_builder)
49 : def_builder_(def_builder) {}
50
Input(Node * src_node,int src_index)51 NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) {
52 inputs_.emplace_back(src_node, src_index);
53 DataType dt;
54 if (GetOutputType(src_node, src_index, &dt)) {
55 def_builder_.Input(src_node->name(), src_index, dt);
56 }
57 return *this;
58 }
59
Input(NodeOut src)60 NodeBuilder& NodeBuilder::Input(NodeOut src) {
61 if (src.error) {
62 AddIndexError(src.node, src.index);
63 } else {
64 inputs_.emplace_back(src.node, src.index);
65 def_builder_.Input(src.name, src.index, src.dt);
66 }
67 return *this;
68 }
69
Input(gtl::ArraySlice<NodeOut> src_list)70 NodeBuilder& NodeBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
71 std::vector<NodeDefBuilder::NodeOut> srcs;
72 srcs.reserve(src_list.size());
73 for (const auto& node_out : src_list) {
74 if (node_out.error) {
75 AddIndexError(node_out.node, node_out.index);
76 } else {
77 srcs.emplace_back(node_out.name, node_out.index, node_out.dt);
78 inputs_.emplace_back(node_out.node, node_out.index);
79 }
80 }
81 def_builder_.Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs));
82 return *this;
83 }
84
ControlInput(Node * src_node)85 NodeBuilder& NodeBuilder::ControlInput(Node* src_node) {
86 control_inputs_.emplace_back(src_node);
87 def_builder_.ControlInput(src_node->name());
88 return *this;
89 }
90
ControlInputs(gtl::ArraySlice<Node * > src_nodes)91 NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) {
92 control_inputs_.insert(control_inputs_.end(), src_nodes.begin(),
93 src_nodes.end());
94 for (const Node* src_node : src_nodes) {
95 def_builder_.ControlInput(src_node->name());
96 }
97 return *this;
98 }
99
Device(StringPiece device_spec)100 NodeBuilder& NodeBuilder::Device(StringPiece device_spec) {
101 def_builder_.Device(device_spec);
102 return *this;
103 }
104
AssignedDevice(StringPiece device)105 NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) {
106 assigned_device_ = string(device);
107 return *this;
108 }
109
XlaCluster(StringPiece xla_cluster)110 NodeBuilder& NodeBuilder::XlaCluster(StringPiece xla_cluster) {
111 def_builder_.Attr("_XlaCluster", xla_cluster);
112 return *this;
113 }
114
Finalize(Graph * graph,Node ** created_node,bool consume)115 Status NodeBuilder::Finalize(Graph* graph, Node** created_node, bool consume) {
116 // In case of error, set *created_node to nullptr.
117 if (created_node != nullptr) *created_node = nullptr;
118 if (!errors_.empty()) {
119 return errors::InvalidArgument(absl::StrJoin(errors_, "\n"));
120 }
121
122 NodeDef node_def;
123 TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def, consume));
124 TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def()));
125 TF_RETURN_IF_ERROR(
126 CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer()));
127 Status status;
128 Node* node = graph->AddNode(std::move(node_def), &status);
129 if (!status.ok()) return status;
130
131 node->set_assigned_device_name(assigned_device_);
132
133 for (size_t i = 0; i < inputs_.size(); ++i) {
134 if (inputs_[i].node != nullptr) { // Skip back edges.
135 graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i);
136 }
137 }
138 for (Node* control_input : control_inputs_) {
139 graph->AddControlEdge(control_input, node);
140 }
141 if (created_node != nullptr) *created_node = node;
142 return Status::OK();
143 }
144
AddIndexError(const Node * node,int i)145 void NodeBuilder::AddIndexError(const Node* node, int i) {
146 if (node == nullptr) {
147 errors_.emplace_back(
148 strings::StrCat("Attempt to add nullptr Node to node with type ",
149 def_builder_.op_def().name()));
150 } else {
151 errors_.emplace_back(strings::StrCat(
152 "Attempt to add output ", i, " of ", node->name(), " not in range [0, ",
153 node->num_outputs(), ") to node with type ",
154 def_builder_.op_def().name(), ". Node: ", FormatNodeForError(*node)));
155 }
156 }
157
GetOutputType(const Node * node,int i,DataType * dt)158 bool NodeBuilder::GetOutputType(const Node* node, int i, DataType* dt) {
159 bool error;
160 *dt = SafeGetOutput(node, i, &error);
161 if (error) AddIndexError(node, i);
162 return !error;
163 }
164
165 } // namespace tensorflow
166