• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/node_builder.h"
17 
18 #include <unordered_map>
19 #include <vector>
20 
21 #include "tensorflow/core/framework/full_type.pb.h"
22 #include "tensorflow/core/framework/full_type_util.h"
23 #include "tensorflow/core/framework/node_def_util.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/platform/statusor.h"
28 #include "tensorflow/core/protobuf/error_codes.pb.h"
29 
30 namespace tensorflow {
31 
NodeOut(Node * n,int32_t i)32 NodeBuilder::NodeOut::NodeOut(Node* n, int32_t i)  // NOLINT(runtime/explicit)
33     : node(n),
34       error(false),
35       name(node != nullptr ? node->name() : (error = true, "")),
36       index(i),
37       dt(SafeGetOutput(node, i, &error)) {}
38 
NodeOut(OutputTensor t)39 NodeBuilder::NodeOut::NodeOut(OutputTensor t) : NodeOut(t.node, t.index) {}
40 
NodeOut(StringPiece n,int32_t i,DataType t)41 NodeBuilder::NodeOut::NodeOut(StringPiece n, int32_t i, DataType t)
42     : node(nullptr), error(false), name(n), index(i), dt(t) {}
43 
NodeOut()44 NodeBuilder::NodeOut::NodeOut()
45     : node(nullptr), error(true), index(0), dt(DT_FLOAT) {}
46 
NodeBuilder(StringPiece name,StringPiece op_name,const OpRegistryInterface * op_registry,const NodeDebugInfo * debug)47 NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name,
48                          const OpRegistryInterface* op_registry,
49                          const NodeDebugInfo* debug)
50     : def_builder_(name, op_name, op_registry, debug) {}
51 
NodeBuilder(StringPiece name,const OpDef * op_def)52 NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def)
53     : def_builder_(name, op_def) {}
54 
NodeBuilder(const NodeDefBuilder & def_builder)55 NodeBuilder::NodeBuilder(const NodeDefBuilder& def_builder)
56     : def_builder_(def_builder) {}
57 
Input(Node * src_node,int src_index)58 NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) {
59   inputs_.emplace_back(src_node, src_index);
60   DataType dt;
61   if (GetOutputType(src_node, src_index, &dt)) {
62     def_builder_.Input(src_node->name(), src_index, dt);
63   }
64   return *this;
65 }
66 
Input(NodeOut src)67 NodeBuilder& NodeBuilder::Input(NodeOut src) {
68   if (src.error) {
69     AddIndexError(src.node, src.index);
70   } else {
71     inputs_.emplace_back(src.node, src.index);
72     def_builder_.Input(src.name, src.index, src.dt);
73   }
74   return *this;
75 }
76 
Input(gtl::ArraySlice<NodeOut> src_list)77 NodeBuilder& NodeBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
78   std::vector<NodeDefBuilder::NodeOut> srcs;
79   srcs.reserve(src_list.size());
80   for (const auto& node_out : src_list) {
81     if (node_out.error) {
82       AddIndexError(node_out.node, node_out.index);
83     } else {
84       srcs.emplace_back(node_out.name, node_out.index, node_out.dt);
85       inputs_.emplace_back(node_out.node, node_out.index);
86     }
87   }
88   def_builder_.Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs));
89   return *this;
90 }
91 
ControlInput(Node * src_node)92 NodeBuilder& NodeBuilder::ControlInput(Node* src_node) {
93   control_inputs_.emplace_back(src_node);
94   def_builder_.ControlInput(src_node->name());
95   return *this;
96 }
97 
ControlInputs(gtl::ArraySlice<Node * > src_nodes)98 NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) {
99   control_inputs_.insert(control_inputs_.end(), src_nodes.begin(),
100                          src_nodes.end());
101   for (const Node* src_node : src_nodes) {
102     def_builder_.ControlInput(src_node->name());
103   }
104   return *this;
105 }
106 
Device(StringPiece device_spec)107 NodeBuilder& NodeBuilder::Device(StringPiece device_spec) {
108   def_builder_.Device(device_spec);
109   return *this;
110 }
111 
AssignedDevice(StringPiece device)112 NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) {
113   assigned_device_ = string(device);
114   return *this;
115 }
116 
XlaCluster(StringPiece xla_cluster)117 NodeBuilder& NodeBuilder::XlaCluster(StringPiece xla_cluster) {
118   def_builder_.Attr("_XlaCluster", xla_cluster);
119   return *this;
120 }
121 
122 namespace {
123 
run_type_constructor(Graph * graph,const NodeDef & node_def)124 StatusOr<FullTypeDef> run_type_constructor(Graph* graph,
125                                            const NodeDef& node_def) {
126   // TODO(mdan): Decouple this from graph building, or run again after.
127   const auto* op_registry = graph->op_registry();
128   const tensorflow::OpRegistrationData* op_reg_data;
129   TF_RETURN_IF_ERROR(op_registry->LookUp(node_def.op(), &op_reg_data));
130   if (op_reg_data->type_ctor == nullptr) {
131     // Default to the default unset type.
132     return FullTypeDef();
133   }
134 
135   // TODO(mdan): Do we still need to save this info in the Graph object?
136   return full_type::SpecializeType(AttrSlice(node_def), op_reg_data->op_def);
137 }
138 
139 }  // namespace
140 
Finalize(Graph * graph,Node ** created_node,bool consume)141 Status NodeBuilder::Finalize(Graph* graph, Node** created_node, bool consume) {
142   // In case of error, set *created_node to nullptr.
143   if (created_node != nullptr) *created_node = nullptr;
144   if (!errors_.empty()) {
145     return errors::InvalidArgument(absl::StrJoin(errors_, "\n"));
146   }
147 
148   NodeDef node_def;
149   TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def, consume));
150   TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def()));
151   TF_RETURN_IF_ERROR(
152       CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer()));
153 
154   const auto ret = run_type_constructor(graph, node_def);
155   TF_RETURN_IF_ERROR(ret.status());
156 
157   Status status;
158   Node* node = graph->AddNode(std::move(node_def), &status);
159   TF_RETURN_IF_ERROR(status);
160 
161   FullTypeDef ft = ret.ValueOrDie();
162   if (ft.type_id() != TFT_UNSET) {
163     graph->SetNodeType(node->name(), ft);
164   }
165 
166   node->set_assigned_device_name(assigned_device_);
167 
168   for (size_t i = 0; i < inputs_.size(); ++i) {
169     if (inputs_[i].node != nullptr) {  // Skip back edges.
170       graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i);
171     }
172   }
173   for (Node* control_input : control_inputs_) {
174     graph->AddControlEdge(control_input, node);
175   }
176   if (created_node != nullptr) *created_node = node;
177   return Status::OK();
178 }
179 
AddIndexError(const Node * node,int i)180 void NodeBuilder::AddIndexError(const Node* node, int i) {
181   if (node == nullptr) {
182     errors_.emplace_back(
183         strings::StrCat("Attempt to add nullptr Node to node with type ",
184                         def_builder_.op_def().name()));
185   } else {
186     errors_.emplace_back(strings::StrCat(
187         "Attempt to add output ", i, " of ", node->name(), " not in range [0, ",
188         node->num_outputs(), ") to node with type ",
189         def_builder_.op_def().name(), ". Node: ", FormatNodeForError(*node)));
190   }
191 }
192 
GetOutputType(const Node * node,int i,DataType * dt)193 bool NodeBuilder::GetOutputType(const Node* node, int i, DataType* dt) {
194   bool error;
195   *dt = SafeGetOutput(node, i, &error);
196   if (error) AddIndexError(node, i);
197   return !error;
198 }
199 
200 }  // namespace tensorflow
201