• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/node_builder.h"
17 
18 #include <vector>
19 #include "tensorflow/core/framework/node_def_util.h"
20 #include "tensorflow/core/framework/versions.pb.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 
23 namespace tensorflow {
24 
NodeOut(Node * n,int32 i)25 NodeBuilder::NodeOut::NodeOut(Node* n, int32 i)  // NOLINT(runtime/explicit)
26     : node(n),
27       error(false),
28       name(node != nullptr ? node->name() : (error = true, "")),
29       index(i),
30       dt(SafeGetOutput(node, i, &error)) {}
31 
NodeOut(OutputTensor t)32 NodeBuilder::NodeOut::NodeOut(OutputTensor t) : NodeOut(t.node, t.index) {}
33 
NodeOut(StringPiece n,int32 i,DataType t)34 NodeBuilder::NodeOut::NodeOut(StringPiece n, int32 i, DataType t)
35     : node(nullptr), error(false), name(n), index(i), dt(t) {}
36 
NodeOut()37 NodeBuilder::NodeOut::NodeOut()
38     : node(nullptr), error(true), index(0), dt(DT_FLOAT) {}
39 
NodeBuilder(StringPiece name,StringPiece op_name,const OpRegistryInterface * op_registry,const NodeDebugInfo * debug)40 NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name,
41                          const OpRegistryInterface* op_registry,
42                          const NodeDebugInfo* debug)
43     : def_builder_(name, op_name, op_registry, debug) {}
44 
NodeBuilder(StringPiece name,const OpDef * op_def)45 NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def)
46     : def_builder_(name, op_def) {}
47 
NodeBuilder(const NodeDefBuilder & def_builder)48 NodeBuilder::NodeBuilder(const NodeDefBuilder& def_builder)
49     : def_builder_(def_builder) {}
50 
Input(Node * src_node,int src_index)51 NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) {
52   inputs_.emplace_back(src_node, src_index);
53   DataType dt;
54   if (GetOutputType(src_node, src_index, &dt)) {
55     def_builder_.Input(src_node->name(), src_index, dt);
56   }
57   return *this;
58 }
59 
Input(NodeOut src)60 NodeBuilder& NodeBuilder::Input(NodeOut src) {
61   if (src.error) {
62     AddIndexError(src.node, src.index);
63   } else {
64     inputs_.emplace_back(src.node, src.index);
65     def_builder_.Input(src.name, src.index, src.dt);
66   }
67   return *this;
68 }
69 
Input(gtl::ArraySlice<NodeOut> src_list)70 NodeBuilder& NodeBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
71   std::vector<NodeDefBuilder::NodeOut> srcs;
72   srcs.reserve(src_list.size());
73   for (const auto& node_out : src_list) {
74     if (node_out.error) {
75       AddIndexError(node_out.node, node_out.index);
76     } else {
77       srcs.emplace_back(node_out.name, node_out.index, node_out.dt);
78       inputs_.emplace_back(node_out.node, node_out.index);
79     }
80   }
81   def_builder_.Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs));
82   return *this;
83 }
84 
ControlInput(Node * src_node)85 NodeBuilder& NodeBuilder::ControlInput(Node* src_node) {
86   control_inputs_.emplace_back(src_node);
87   def_builder_.ControlInput(src_node->name());
88   return *this;
89 }
90 
ControlInputs(gtl::ArraySlice<Node * > src_nodes)91 NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) {
92   control_inputs_.insert(control_inputs_.end(), src_nodes.begin(),
93                          src_nodes.end());
94   for (const Node* src_node : src_nodes) {
95     def_builder_.ControlInput(src_node->name());
96   }
97   return *this;
98 }
99 
Device(StringPiece device_spec)100 NodeBuilder& NodeBuilder::Device(StringPiece device_spec) {
101   def_builder_.Device(device_spec);
102   return *this;
103 }
104 
AssignedDevice(StringPiece device)105 NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) {
106   assigned_device_ = string(device);
107   return *this;
108 }
109 
XlaCluster(StringPiece xla_cluster)110 NodeBuilder& NodeBuilder::XlaCluster(StringPiece xla_cluster) {
111   def_builder_.Attr("_XlaCluster", xla_cluster);
112   return *this;
113 }
114 
Finalize(Graph * graph,Node ** created_node,bool consume)115 Status NodeBuilder::Finalize(Graph* graph, Node** created_node, bool consume) {
116   // In case of error, set *created_node to nullptr.
117   if (created_node != nullptr) *created_node = nullptr;
118   if (!errors_.empty()) {
119     return errors::InvalidArgument(absl::StrJoin(errors_, "\n"));
120   }
121 
122   NodeDef node_def;
123   TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def, consume));
124   TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def()));
125   TF_RETURN_IF_ERROR(
126       CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer()));
127   Status status;
128   Node* node = graph->AddNode(std::move(node_def), &status);
129   if (!status.ok()) return status;
130 
131   node->set_assigned_device_name(assigned_device_);
132 
133   for (size_t i = 0; i < inputs_.size(); ++i) {
134     if (inputs_[i].node != nullptr) {  // Skip back edges.
135       graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i);
136     }
137   }
138   for (Node* control_input : control_inputs_) {
139     graph->AddControlEdge(control_input, node);
140   }
141   if (created_node != nullptr) *created_node = node;
142   return Status::OK();
143 }
144 
AddIndexError(const Node * node,int i)145 void NodeBuilder::AddIndexError(const Node* node, int i) {
146   if (node == nullptr) {
147     errors_.emplace_back(
148         strings::StrCat("Attempt to add nullptr Node to node with type ",
149                         def_builder_.op_def().name()));
150   } else {
151     errors_.emplace_back(strings::StrCat(
152         "Attempt to add output ", i, " of ", node->name(), " not in range [0, ",
153         node->num_outputs(), ") to node with type ",
154         def_builder_.op_def().name(), ". Node: ", FormatNodeForError(*node)));
155   }
156 }
157 
GetOutputType(const Node * node,int i,DataType * dt)158 bool NodeBuilder::GetOutputType(const Node* node, int i, DataType* dt) {
159   bool error;
160   *dt = SafeGetOutput(node, i, &error);
161   if (error) AddIndexError(node, i);
162   return !error;
163 }
164 
165 }  // namespace tensorflow
166