• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/graph.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "tensorflow/core/framework/full_type.pb.h"
23 #include "tensorflow/core/framework/graph.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/node_properties.h"
26 #include "tensorflow/core/framework/op_def_builder.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/versions.pb.h"
29 #include "tensorflow/core/graph/graph_node_util.h"
30 #include "tensorflow/core/graph/while_context.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/gtl/map_util.h"
33 #include "tensorflow/core/lib/hash/hash.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/lib/strings/stringprintf.h"
36 #include "tensorflow/core/platform/errors.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/public/version.h"
39 
40 namespace tensorflow {
41 
42 const int Graph::kControlSlot = -1;
43 
44 // Node
GetNodeClassForOp(const std::string & ts)45 Node::NodeClass Node::GetNodeClassForOp(const std::string& ts) {
46   static const absl::flat_hash_map<std::string, Node::NodeClass>*
47       kNodeClassTable =
48 #define REF_CLASS(key, value) \
49   {key, value}, { "Ref" key, value }
50           new absl::flat_hash_map<std::string, Node::NodeClass>({
51               // Keep in same order as NodeClass values
52               REF_CLASS("Switch", NC_SWITCH),
53               REF_CLASS("_SwitchN", NC_SWITCH),
54               REF_CLASS("Merge", NC_MERGE),
55               REF_CLASS("Enter", NC_ENTER),
56               REF_CLASS("Exit", NC_EXIT),
57               REF_CLASS("NextIteration", NC_NEXT_ITERATION),
58               {"LoopCond", NC_LOOP_COND},
59               {"ControlTrigger", NC_CONTROL_TRIGGER},
60               {"_Send", NC_SEND},
61               {"_HostSend", NC_HOST_SEND},
62               {"_Recv", NC_RECV},
63               {"_HostRecv", NC_HOST_RECV},
64               {"Const", NC_CONSTANT},
65               {"HostConst", NC_CONSTANT},
66               {"Variable", NC_VARIABLE},
67               {"VariableV2", NC_VARIABLE},
68               REF_CLASS("Identity", NC_IDENTITY),
69               {"GetSessionHandle", NC_GET_SESSION_HANDLE},
70               {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
71               {"GetSessionTensor", NC_GET_SESSION_TENSOR},
72               {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
73               {"Size", NC_METADATA},
74               {"Shape", NC_METADATA},
75               {"Rank", NC_METADATA},
76               {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
77               {"CollectiveReduce", NC_COLLECTIVE},
78               {"CollectiveBcastSend", NC_COLLECTIVE},
79               {"CollectiveBcastRecv", NC_COLLECTIVE},
80               {"CollectiveGather", NC_COLLECTIVE},
81               {"FakeParam", NC_FAKE_PARAM},
82               {"PartitionedCall", NC_PARTITIONED_CALL},
83               {"StatefulPartitionedCall", NC_PARTITIONED_CALL},
84               {"SymbolicGradient", NC_SYMBOLIC_GRADIENT},
85               {"If", NC_IF},
86               {"StatelessIf", NC_IF},
87               {"While", NC_WHILE},
88               {"StatelessWhile", NC_WHILE},
89               {"Case", NC_CASE},
90               {"StatelessCase", NC_CASE},
91               // Not using the constants defined in FunctionLibraryDefinition
92               // for the
93               // 4 ops below because android inference library does not link
94               // tf.function related files.
95               {"_Arg", NC_ARG},
96               {"_DeviceArg", NC_ARG},
97               {"_Retval", NC_RETVAL},
98               {"_DeviceRetval", NC_RETVAL},
99               {"_XlaMerge", NC_MERGE},
100           });
101 #undef REF_CLASS
102 
103   auto it = kNodeClassTable->find(ts);
104   if (it != kNodeClassTable->end()) {
105     return it->second;
106   } else {
107     return NC_OTHER;
108   }
109 }
110 
DebugString() const111 std::string Node::DebugString() const {
112   std::string ret = strings::StrCat("{name:'", name(), "' id:", id_);
113   if (IsSource()) {
114     strings::StrAppend(&ret, " source}");
115   } else if (IsSink()) {
116     strings::StrAppend(&ret, " sink}");
117   } else {
118     strings::StrAppend(&ret, " op device:", "{requested: '", requested_device(),
119                        "', assigned: '", assigned_device_name(), "'}", " def:{",
120                        SummarizeNode(*this), "}}");
121   }
122   return ret;
123 }
124 
Node()125 Node::Node()
126     : id_(-1),
127       cost_id_(-1),
128       class_(NC_UNINITIALIZED),
129       props_(nullptr),
130       assigned_device_name_index_(0),
131       while_ctx_(nullptr) {}
132 
Initialize(int id,int cost_id,std::shared_ptr<NodeProperties> props,Node::NodeClass node_class)133 void Node::Initialize(int id, int cost_id,
134                       std::shared_ptr<NodeProperties> props,
135                       Node::NodeClass node_class) {
136   DCHECK_EQ(id_, -1);
137   DCHECK(in_edges_.empty());
138   DCHECK(out_edges_.empty());
139   id_ = id;
140   cost_id_ = cost_id;
141 
142   props_ = std::move(props);
143   class_ = node_class;
144 }
145 
Clear()146 void Node::Clear() {
147   in_edges_.clear();
148   out_edges_.clear();
149   id_ = -1;
150   cost_id_ = -1;
151   class_ = NC_UNINITIALIZED;
152   props_.reset();
153   assigned_device_name_index_ = 0;
154 }
155 
UpdateProperties()156 void Node::UpdateProperties() {
157   DataTypeVector inputs;
158   DataTypeVector outputs;
159   Status status =
160       InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
161   if (!status.ok()) {
162     LOG(ERROR) << "Failed at updating node: " << status;
163     return;
164   }
165   if (props_->input_types != inputs || props_->output_types != outputs) {
166     if (TF_PREDICT_TRUE(props_.use_count() == 1)) {
167       props_->input_types = inputs;
168       props_->input_types_slice = props_->input_types;
169       props_->output_types = outputs;
170       props_->output_types_slice = props_->output_types;
171     } else {
172       props_ = std::make_shared<NodeProperties>(
173           props_->op_def, std::move(props_->node_def), inputs, outputs);
174     }
175   }
176 }
177 
ClearTypeInfo()178 void Node::ClearTypeInfo() {
179   if (props_->node_def.has_experimental_type()) {
180     MaybeCopyOnWrite();
181     props_->node_def.clear_experimental_type();
182   }
183 }
184 
name() const185 const std::string& Node::name() const { return props_->node_def.name(); }
type_string() const186 const std::string& Node::type_string() const { return props_->node_def.op(); }
def() const187 const NodeDef& Node::def() const { return props_->node_def; }
op_def() const188 const OpDef& Node::op_def() const { return *props_->op_def; }
189 
mutable_def()190 NodeDef* Node::mutable_def() { return &props_->node_def; }
191 
num_inputs() const192 int32 Node::num_inputs() const { return props_->input_types.size(); }
input_type(int32_t i) const193 DataType Node::input_type(int32_t i) const { return props_->input_types[i]; }
input_types() const194 const DataTypeVector& Node::input_types() const { return props_->input_types; }
195 
num_outputs() const196 int32 Node::num_outputs() const { return props_->output_types.size(); }
output_type(int32_t o) const197 DataType Node::output_type(int32_t o) const { return props_->output_types[o]; }
output_types() const198 const DataTypeVector& Node::output_types() const {
199   return props_->output_types;
200 }
201 
attrs() const202 AttrSlice Node::attrs() const { return AttrSlice(def()); }
203 
requested_inputs() const204 const protobuf::RepeatedPtrField<std::string>& Node::requested_inputs() const {
205   return def().input();
206 }
207 
requested_device() const208 const std::string& Node::requested_device() const { return def().device(); }
209 
out_nodes() const210 gtl::iterator_range<NeighborIter> Node::out_nodes() const {
211   return gtl::make_range(NeighborIter(out_edges_.begin(), false),
212                          NeighborIter(out_edges_.end(), false));
213 }
214 
in_nodes() const215 gtl::iterator_range<NeighborIter> Node::in_nodes() const {
216   return gtl::make_range(NeighborIter(in_edges_.begin(), true),
217                          NeighborIter(in_edges_.end(), true));
218 }
219 
MaybeCopyOnWrite()220 void Node::MaybeCopyOnWrite() {
221   // TODO(mdan): As nodes become more dynamic, this may not be worth the cost.
222   // NodeProperties may be shared between Nodes. Make a copy if so.
223   if (!props_.unique()) {
224     props_ = std::make_shared<NodeProperties>(*props_);
225   }
226 }
227 
AddAttrHelper(const std::string & name)228 AttrValue* Node::AddAttrHelper(const std::string& name) {
229   MaybeCopyOnWrite();
230   return &((*props_->node_def.mutable_attr())[name]);
231 }
232 
ClearAttr(const std::string & name)233 void Node::ClearAttr(const std::string& name) {
234   MaybeCopyOnWrite();
235   (*props_->node_def.mutable_attr()).erase(name);
236 }
237 
set_name(std::string name)238 void Node::set_name(std::string name) {
239   MaybeCopyOnWrite();
240   props_->node_def.set_name(std::move(name));
241 }
242 
set_requested_device(const std::string & device)243 void Node::set_requested_device(const std::string& device) {
244   MaybeCopyOnWrite();
245   props_->node_def.set_device(device);
246 }
247 
set_original_node_names(const std::vector<std::string> & names)248 void Node::set_original_node_names(const std::vector<std::string>& names) {
249   MaybeCopyOnWrite();
250   props_->node_def.mutable_experimental_debug_info()
251       ->clear_original_node_names();
252   if (!names.empty()) {
253     *props_->node_def.mutable_experimental_debug_info()
254          ->mutable_original_node_names() = {names.begin(), names.end()};
255   }
256 }
257 
set_original_func_names(const std::vector<std::string> & names)258 void Node::set_original_func_names(const std::vector<std::string>& names) {
259   MaybeCopyOnWrite();
260   props_->node_def.mutable_experimental_debug_info()
261       ->clear_original_func_names();
262   if (!names.empty()) {
263     *props_->node_def.mutable_experimental_debug_info()
264          ->mutable_original_func_names() = {names.begin(), names.end()};
265   }
266 }
267 
input_edge(int idx,const Edge ** e) const268 Status Node::input_edge(int idx, const Edge** e) const {
269   if (idx < 0 || idx >= num_inputs()) {
270     return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ",
271                                    name(), " only has ", num_inputs(),
272                                    " inputs.");
273   }
274 
275   // This does a linear search over the edges.  In the common case,
276   // the number of elements is small enough that this search isn't
277   // expensive.  Should it become a bottleneck, one can make an
278   // optimization where, if the number of edges is small, we use
279   // linear iteration, and if the number of edges is large, we perform
280   // an indexing step during construction that keeps an array of Edges
281   // indexed by pointer.  This would keep the size of each Node small
282   // in the common case but make this function faster when the number
283   // of edges is large.
284   for (const Edge* edge : in_edges()) {
285     if (edge->dst_input() == idx) {
286       *e = edge;
287       return OkStatus();
288     }
289   }
290 
291   return errors::NotFound("Could not find input edge ", idx, " for ", name());
292 }
293 
294 // Returns a vector of the non-control input edges to a node, indexed by ID.
input_edges(std::vector<const Edge * > * input_edges) const295 Status Node::input_edges(std::vector<const Edge*>* input_edges) const {
296   input_edges->clear();
297   input_edges->resize(num_inputs(), nullptr);
298 
299   for (const Edge* edge : in_edges()) {
300     if (edge->IsControlEdge()) continue;
301     if (edge->dst_input() < 0 || edge->dst_input() >= num_inputs()) {
302       return errors::Internal("Invalid edge input number ", edge->dst_input());
303     }
304     if ((*input_edges)[edge->dst_input()] != nullptr) {
305       return errors::Internal("Duplicate edge input number: ",
306                               edge->dst_input());
307     }
308     (*input_edges)[edge->dst_input()] = edge;
309   }
310 
311   for (int i = 0; i < num_inputs(); ++i) {
312     if ((*input_edges)[i] == nullptr) {
313       return errors::InvalidArgument("Missing edge input number: ", i);
314     }
315   }
316   return OkStatus();
317 }
318 
input_node(int idx,Node ** n) const319 Status Node::input_node(int idx, Node** n) const {
320   const Edge* e;
321   TF_RETURN_IF_ERROR(input_edge(idx, &e));
322   if (e == nullptr) {
323     *n = nullptr;
324   } else {
325     *n = e->src();
326   }
327   return OkStatus();
328 }
329 
input_node(int idx,const Node ** const_n) const330 Status Node::input_node(int idx, const Node** const_n) const {
331   Node* n;
332   TF_RETURN_IF_ERROR(input_node(idx, &n));
333   *const_n = n;
334   return OkStatus();
335 }
336 
input_tensor(int idx,OutputTensor * t) const337 Status Node::input_tensor(int idx, OutputTensor* t) const {
338   const Edge* e;
339   TF_RETURN_IF_ERROR(input_edge(idx, &e));
340   DCHECK(e != nullptr);
341   *t = OutputTensor(e->src(), e->src_output());
342   return OkStatus();
343 }
344 
345 // NodeDebugInfo
346 
NodeDebugInfo(const Node & n)347 NodeDebugInfo::NodeDebugInfo(const Node& n) : NodeDebugInfo(n.def()) {}
NodeDebugInfo(const NodeDef & ndef)348 NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef)
349     : NodeDebugInfo(ndef.name(), ndef.has_experimental_debug_info(),
350                     ndef.experimental_debug_info()) {}
NodeDebugInfo(StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info)351 NodeDebugInfo::NodeDebugInfo(
352     StringPiece node_name, bool has_experimental_debug_info,
353     const NodeDef_ExperimentalDebugInfo& experimental_debug_info)
354     : name(node_name) {
355   if (has_experimental_debug_info) {
356     const auto& node_names = experimental_debug_info.original_node_names();
357     original_node_names.assign(node_names.begin(), node_names.end());
358     const auto& func_names = experimental_debug_info.original_func_names();
359     original_func_names.assign(func_names.begin(), func_names.end());
360   }
361 }
362 // InputTensor
363 
operator ==(const InputTensor & other) const364 bool InputTensor::operator==(const InputTensor& other) const {
365   return node == other.node && index == other.index;
366 }
367 
operator ()(InputTensor const & s) const368 uint64 InputTensor::Hash::operator()(InputTensor const& s) const {
369   return Hash64Combine(std::hash<const Node*>()(s.node),
370                        std::hash<int>()(s.index));
371 }
372 
373 // OutputTensor
374 
operator ==(const OutputTensor & other) const375 bool OutputTensor::operator==(const OutputTensor& other) const {
376   return node == other.node && index == other.index;
377 }
378 
operator ()(OutputTensor const & s) const379 uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const {
380   return Hash64Combine(std::hash<const Node*>()(s.node),
381                        std::hash<int>()(s.index));
382 }
383 
384 // Graph
385 
Graph(const OpRegistryInterface * ops)386 Graph::Graph(const OpRegistryInterface* ops)
387     : ops_(ops, FunctionDefLibrary()),
388       versions_(new VersionDef),
389       arena_(8 << 10 /* 8kB */) {
390   versions_->set_producer(TF_GRAPH_DEF_VERSION);
391   versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
392 
393   // Initialize the name interning table for assigned_device_name.
394   device_names_.push_back("");
395   DCHECK_EQ(0, InternDeviceName(""));
396 
397   // Source and sink have no endpoints, just control edges.
398   NodeDef def;
399   def.set_name("_SOURCE");
400   def.set_op("NoOp");
401   Status status;
402   Node* source = AddNode(def, &status);
403   TF_CHECK_OK(status);
404   CHECK_EQ(source->id(), kSourceId);
405 
406   def.set_name("_SINK");
407   Node* sink = AddNode(def, &status);
408   TF_CHECK_OK(status);
409   CHECK_EQ(sink->id(), kSinkId);
410 
411   AddControlEdge(source, sink);
412 }
413 
Graph(const FunctionLibraryDefinition & flib_def)414 Graph::Graph(const FunctionLibraryDefinition& flib_def)
415     : Graph(flib_def.default_registry()) {
416   // Need a new-enough consumer to support the functions we add to the graph.
417   if (flib_def.num_functions() > 0 && versions_->min_consumer() < 12) {
418     versions_->set_min_consumer(12);
419   }
420   Status s = ops_.AddLibrary(flib_def);
421   CHECK(s.ok()) << s.error_message();
422 }
423 
~Graph()424 Graph::~Graph() {
425   // Manually call the destructors for all the Nodes we constructed using
426   // placement new.
427   for (Node* node : nodes_) {
428     if (node != nullptr) {
429       node->~Node();
430     }
431   }
432   for (Node* node : free_nodes_) {
433     node->~Node();
434   }
435   // Edges have no destructor, and we arena-allocated them, so no need to
436   // destroy them.
437 }
438 
Clone()439 std::unique_ptr<Graph> Graph::Clone() {
440   std::unique_ptr<Graph> new_graph(new Graph(flib_def()));
441   new_graph->Copy(*this);
442   return new_graph;
443 }
444 
Clear()445 void Graph::Clear() {
446   // Do a direct iteration clearing nodes removing the RemoveNode helper method.
447   // This could avoid this helper and clear directly if it becomes performance
448   // sensitive.
449   for (Node* n : nodes()) {
450     if (!n->IsSource() && !n->IsSink()) RemoveNode(n);
451   }
452 }
453 
versions() const454 const VersionDef& Graph::versions() const { return *versions_; }
set_versions(const VersionDef & versions)455 void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }
456 
Copy(const Graph & src)457 void Graph::Copy(const Graph& src) {
458   SetConstructionContext(src.GetConstructionContextInternal());
459   for (Node* n : nodes()) {
460     CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
461   }
462 
463   // Copy GraphDef versions
464   set_versions(src.versions());
465 
466   // Copy the nodes.
467   // "Node in src" -> "Node in *dest"
468   gtl::FlatMap<const Node*, Node*> node_map;
469   node_map.reserve(src.num_nodes());
470   node_map[src.source_node()] = source_node();
471   node_map[src.sink_node()] = sink_node();
472   for (Node* n : src.op_nodes()) {
473     auto copy = CopyNode(n);
474     copy->in_edges_.reserve(n->in_edges().size());
475     copy->out_edges_.reserve(n->out_edges().size());
476     node_map[n] = copy;
477   }
478 
479   // Copy the edges
480   edges_.reserve(src.num_edges());
481   for (const Edge* e : src.edges()) {
482     Node* src_copy = node_map[e->src()];
483     Node* dst_copy = node_map[e->dst()];
484     AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
485   }
486 }
487 
AddNode(NodeDef node_def)488 StatusOr<Node*> Graph::AddNode(NodeDef node_def) {
489   Status s;
490   Node* out = AddNode(std::move(node_def), &s);
491   TF_RETURN_IF_ERROR(s);
492   return out;
493 }
494 
AddNode(NodeDef node_def,Status * status)495 Node* Graph::AddNode(NodeDef node_def, Status* status) {
496   const OpRegistrationData* op_reg_data;
497   status->Update(ops_.LookUp(node_def.op(), &op_reg_data));
498   if (!status->ok()) return nullptr;
499 
500   DataTypeVector inputs;
501   DataTypeVector outputs;
502   status->Update(
503       InOutTypesForNode(node_def, op_reg_data->op_def, &inputs, &outputs));
504   if (!status->ok()) {
505     *status = AttachDef(*status, node_def);
506     return nullptr;
507   }
508 
509   Node::NodeClass node_class = op_reg_data->is_function_op
510                                    ? Node::NC_FUNCTION_OP
511                                    : Node::GetNodeClassForOp(node_def.op());
512 
513   if (node_def.has_experimental_type()) {
514     VLOG(3) << "AddNode: node has type set, skipping type constructor "
515             << node_def.name();
516   } else {
517     if (op_reg_data->type_ctor != nullptr) {
518       VLOG(3) << "AddNode: found type constructor for " << node_def.name();
519       Status s =
520           full_type::SpecializeType(AttrSlice(node_def), op_reg_data->op_def,
521                                     *(node_def.mutable_experimental_type()));
522       if (!s.ok()) {
523         *status = errors::InvalidArgument("type error: ", s.ToString());
524         VLOG(3) << "AddNode: type inference failed for " << node_def.name()
525                 << ": " << s;
526         return nullptr;
527       }
528     } else {
529       VLOG(3) << "AddNode: no type constructor for " << node_def.name();
530     }
531   }
532 
533   Node* node = AllocateNode(
534       std::make_shared<NodeProperties>(&op_reg_data->op_def,
535                                        std::move(node_def), inputs, outputs),
536       nullptr, node_class);
537   return node;
538 }
539 
// Adds a node that shares `node`'s properties and class and takes over its
// assigned device and stack trace. `node` may belong to another graph whose
// function OpDefs it owns, so the OpDef is looked up again in this graph's
// registry. If the result differs, the copy gets private properties that point
// at the local OpDef.
Node* Graph::CopyNode(const Node* node) {
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());
  Node* duplicate = AllocateNode(node->props_, node, node->class_);
  duplicate->set_assigned_device_name(node->assigned_device_name());

  const OpDef* local_op_def;
  TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &local_op_def));
  const bool op_def_differs = local_op_def != node->props_->op_def;
  if (op_def_differs) {
    duplicate->MaybeCopyOnWrite();
    duplicate->props_->op_def = local_op_def;
  }
  duplicate->SetStackTrace(node->GetStackTrace());
  return duplicate;
}
559 
RemoveNode(Node * node)560 void Graph::RemoveNode(Node* node) {
561   TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
562   DCHECK(!node->IsSource());
563   DCHECK(!node->IsSink());
564 
565   // Remove any edges involving this node.
566   for (const Edge* e : node->in_edges_) {
567     CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
568     edges_[e->id_] = nullptr;
569     RecycleEdge(e);
570     --num_edges_;
571   }
572   node->in_edges_.clear();
573   for (const Edge* e : node->out_edges_) {
574     CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
575     edges_[e->id_] = nullptr;
576     RecycleEdge(e);
577     --num_edges_;
578   }
579   node->out_edges_.clear();
580   ReleaseNode(node);
581 }
582 
AddEdge(Node * source,int x,Node * dest,int y)583 const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
584   TF_DCHECK_OK(IsValidNode(source)) << source->DebugString();
585   TF_DCHECK_OK(IsValidNode(dest)) << dest->DebugString();
586 
587   // source/sink must only be linked via control slots, and
588   // control slots must only be linked to control slots.
589   if (source == source_node() || dest == sink_node() || x == kControlSlot ||
590       y == kControlSlot) {
591     DCHECK_EQ(x, kControlSlot) << source->DebugString();
592     DCHECK_EQ(y, kControlSlot) << dest->DebugString();
593   }
594 
595   Edge* e = nullptr;
596   if (free_edges_.empty()) {
597     e = new (arena_.Alloc(sizeof(Edge))) Edge;  // placement new
598   } else {
599     e = free_edges_.back();
600     free_edges_.pop_back();
601   }
602   e->id_ = edges_.size();
603   e->src_ = source;
604   e->dst_ = dest;
605   e->src_output_ = x;
606   e->dst_input_ = y;
607   CHECK(source->out_edges_.insert(e).second);
608   CHECK(dest->in_edges_.insert(e).second);
609   edges_.push_back(e);
610   ++num_edges_;
611 
612   return e;
613 }
614 
RemoveEdge(const Edge * e)615 void Graph::RemoveEdge(const Edge* e) {
616   TF_DCHECK_OK(IsValidNode(e->src_)) << e->src_->DebugString();
617   TF_DCHECK_OK(IsValidNode(e->dst_)) << e->dst_->DebugString();
618   CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
619   CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
620   CHECK_EQ(e, edges_[e->id_]);
621   CHECK_GT(num_edges_, 0);
622 
623   edges_[e->id_] = nullptr;
624   RecycleEdge(e);
625   --num_edges_;
626 }
627 
RecycleEdge(const Edge * e)628 void Graph::RecycleEdge(const Edge* e) {
629   free_edges_.push_back(const_cast<Edge*>(e));
630 }
631 
AddControlEdge(Node * source,Node * dest,bool allow_duplicates)632 const Edge* Graph::AddControlEdge(Node* source, Node* dest,
633                                   bool allow_duplicates) {
634   if (!allow_duplicates) {
635     for (const Edge* edge : dest->in_edges()) {
636       if (edge->IsControlEdge() && edge->src() == source) {
637         // The requested edge already exists.
638         return nullptr;
639       }
640     }
641   }
642   // Modify dest's NodeDef if necessary.
643   if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) {
644     // Check if this input is already in dest's NodeDef.
645     const std::string new_input = strings::StrCat("^", source->name());
646     bool input_exists = false;
647     for (const std::string& input : dest->props_->node_def.input()) {
648       if (input == new_input) {
649         input_exists = true;
650         break;
651       }
652     }
653     if (!input_exists) {
654       dest->MaybeCopyOnWrite();
655       dest->props_->node_def.add_input(new_input);
656     }
657   }
658   return AddEdge(source, kControlSlot, dest, kControlSlot);
659 }
660 
RemoveControlEdge(const Edge * e)661 void Graph::RemoveControlEdge(const Edge* e) {
662   if (!e->src_->IsSource() && !e->dst_->IsSink()) {
663     e->dst_->MaybeCopyOnWrite();
664     std::string e_src_name = strings::StrCat("^", e->src_->name());
665     auto* inputs = e->dst_->props_->node_def.mutable_input();
666     for (auto it = inputs->begin(); it != inputs->end(); ++it) {
667       if (*it == e_src_name) {
668         inputs->erase(it);
669         break;
670       }
671     }
672   }
673   RemoveEdge(e);
674 }
675 
676 namespace {
FindEdge(const Node * dst,int index)677 const Edge* FindEdge(const Node* dst, int index) {
678   for (const Edge* e : dst->in_edges()) {
679     if (e->dst_input() == index) return e;
680   }
681   return nullptr;
682 }
683 }  // namespace
684 
UpdateEdge(Node * new_src,int new_src_index,Node * dst,int dst_index)685 Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
686                          int dst_index) {
687   TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
688   TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
689   const Edge* e = FindEdge(dst, dst_index);
690   if (e == nullptr) {
691     return errors::InvalidArgument("Couldn't find edge to ",
692                                    FormatNodeForError(*dst));
693   }
694   RemoveEdge(e);
695   AddEdge(new_src, new_src_index, dst, dst_index);
696   dst->MaybeCopyOnWrite();
697   (*dst->props_->node_def.mutable_input())[dst_index] =
698       strings::StrCat(new_src->name(), ":", new_src_index);
699   return OkStatus();
700 }
701 
AddWhileInputHack(Node * new_src,int new_src_index,Node * dst)702 Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) {
703   if (!dst->IsWhileNode()) {
704     return errors::Internal(
705         "dst argument to AddWhileEdgeHack should be a While op, got: ",
706         dst->DebugString());
707   }
708   TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
709   // Find the current number of data inputs. We'll add the new edge to the next
710   // missing data input.
711   int dst_index = 0;
712   for (const Edge* edge : dst->in_edges()) {
713     if (edge->IsControlEdge()) continue;
714     ++dst_index;
715   }
716   TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
717   AddEdge(new_src, new_src_index, dst, dst_index);
718   dst->MaybeCopyOnWrite();
719   dst->props_->node_def.add_input(
720       strings::StrCat(new_src->name(), ":", new_src_index));
721   return OkStatus();
722 }
723 
AddFunctionLibrary(const FunctionDefLibrary & fdef_lib)724 Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
725   // Need a new-enough consumer to support the functions we add to the graph.
726   if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
727     versions_->set_min_consumer(12);
728   }
729   return ops_.AddLibrary(fdef_lib);
730 }
731 
732 namespace {
733 
AddInput(NodeDef * dst,StringPiece src_name,int src_slot)734 void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
735   if (src_slot == Graph::kControlSlot) {
736     dst->add_input(strings::StrCat("^", src_name));
737   } else if (src_slot == 0) {
738     dst->add_input(src_name.data(), src_name.size());
739   } else {
740     dst->add_input(strings::StrCat(src_name, ":", src_slot));
741   }
742 }
743 
744 }  // namespace
745 
ToGraphDef(GraphDef * graph_def) const746 void Graph::ToGraphDef(GraphDef* graph_def) const {
747   ToGraphDefSubRange(graph_def, 0);
748 }
749 
ToGraphDefDebug() const750 GraphDef Graph::ToGraphDefDebug() const {
751   GraphDef ret;
752   ToGraphDef(&ret);
753   return ret;
754 }
755 
ToGraphDefSubRange(GraphDef * graph_def,int from_node_id) const756 void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
757   graph_def->Clear();
758   *graph_def->mutable_versions() = versions();
759   *graph_def->mutable_library() = ops_.ToProto();
760 
761   graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id));
762 
763   std::vector<const Edge*>
764       inputs;  // Construct this outside the loop for speed.
765   for (auto id = from_node_id; id < num_node_ids(); ++id) {
766     const Node* node = FindNodeId(id);
767     if (node == nullptr || !node->IsOp()) continue;
768     NodeDef* node_def = graph_def->add_node();
769     *node_def = node->def();
770 
771     // Use the node's assigned device, if any, instead of the device requested
772     // in the NodeDef.
773     if (!node->assigned_device_name().empty()) {
774       node_def->set_device(node->assigned_device_name());
775     }
776 
777     // Get the inputs for this Node.  We make sure control inputs are
778     // after data inputs, as required by GraphDef.
779     inputs.clear();
780     inputs.resize(node->num_inputs(), nullptr);
781     for (const Edge* edge : node->in_edges()) {
782       if (edge->IsControlEdge()) {
783         inputs.push_back(edge);
784       } else {
785         DCHECK(edge->dst_input() < inputs.size())
786             << "Edge " << edge->DebugString()
787             << " is overflowing the expected number of inputs ("
788             << node->num_inputs() << ") for node " << node->DebugString();
789         CHECK(inputs[edge->dst_input()] == nullptr)
790             << "Edge " << edge->src()->name() << "->" << edge->dst()->name()
791             << " conflicts with pre-existing input edge "
792             << inputs[edge->dst_input()]->src()->name() << "->"
793             << inputs[edge->dst_input()]->dst()->name();
794 
795         inputs[edge->dst_input()] = edge;
796       }
797     }
798     // Sort the control inputs for more predictable serialization.
799     std::sort(inputs.begin() + node->num_inputs(), inputs.end(),
800               [](const Edge* a, const Edge* b) -> bool {
801                 return a->src()->name() < b->src()->name();
802               });
803     node_def->clear_input();
804     node_def->mutable_input()->Reserve(inputs.size());
805 
806     for (size_t i = 0; i < inputs.size(); ++i) {
807       const Edge* edge = inputs[i];
808       if (edge == nullptr) {
809         if (i < node->requested_inputs().size()) {
810           node_def->add_input(node->requested_inputs()[i]);
811         } else {
812           node_def->add_input("");
813         }
814       } else {
815         const Node* src = edge->src();
816         if (!src->IsOp()) continue;
817         AddInput(node_def, src->name(), edge->src_output());
818       }
819     }
820   }
821 }
822 
NewName(StringPiece prefix)823 std::string Graph::NewName(StringPiece prefix) {
824   return strings::StrCat(prefix, "/_", name_counter_++);
825 }
826 
IsValidNode(const Node * node) const827 Status Graph::IsValidNode(const Node* node) const {
828   if (node == nullptr) {
829     return errors::InvalidArgument("Node is null");
830   }
831   const int id = node->id();
832   if (id < 0) {
833     return errors::InvalidArgument("node id ", id, " is less than zero");
834   }
835   if (static_cast<size_t>(id) >= nodes_.size()) {
836     return errors::InvalidArgument(
837         "node id ", id, " is >= than number of nodes in graph ", nodes_.size());
838   }
839   if (nodes_[id] != node) {
840     return errors::InvalidArgument("Node with id ", id,
841                                    " is different from the passed in node. "
842                                    "Does it belong to a different graph?");
843   }
844   return OkStatus();
845 }
846 
IsValidOutputTensor(const Node * node,int idx) const847 Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
848   TF_RETURN_IF_ERROR(IsValidNode(node));
849   if (idx >= node->num_outputs() || idx < 0) {
850     return errors::OutOfRange("Node '", node->name(), "' (type: '",
851                               node->op_def().name(),
852                               "', num of outputs: ", node->num_outputs(),
853                               ") does not have ", "output ", idx);
854   }
855   return OkStatus();
856 }
857 
IsValidInputTensor(const Node * node,int idx) const858 Status Graph::IsValidInputTensor(const Node* node, int idx) const {
859   TF_RETURN_IF_ERROR(IsValidNode(node));
860   if (idx >= node->num_inputs() || idx < 0) {
861     return errors::OutOfRange("Node '", node->name(), "' (type: '",
862                               node->op_def().name(),
863                               "', num of inputs: ", node->num_inputs(),
864                               ") does not have ", "input ", idx);
865   }
866   return OkStatus();
867 }
868 
// Produces a Node owned by this graph and appends it to nodes_.
//
// Storage is recycled from free_nodes_ when that list is non-empty. Otherwise
// a fresh Node is placement-constructed in memory from arena_. The node
// receives id == nodes_.size() at the time of the call. Its cost id is
// borrowed from `cost_node` when one is given, and otherwise equals its own id.
Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
                          const Node* cost_node, Node::NodeClass node_class) {
  Node* node = nullptr;
  if (!free_nodes_.empty()) {
    node = free_nodes_.back();
    free_nodes_.pop_back();
  } else {
    void* storage = arena_.Alloc(sizeof(Node));
    node = new (storage) Node;  // placement new into arena-owned memory
  }
  node->graph_ = this;

  const int new_id = nodes_.size();
  const int cost_id = (cost_node != nullptr) ? cost_node->cost_id() : new_id;
  node->Initialize(new_id, cost_id, std::move(props), node_class);

  nodes_.push_back(node);
  ++num_nodes_;
  return node;
}
886 
ReleaseNode(Node * node)887 void Graph::ReleaseNode(Node* node) {
888   TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
889   nodes_[node->id()] = nullptr;
890   free_nodes_.push_back(node);
891   --num_nodes_;
892   node->Clear();
893 }
894 
895 // Ensures that 'device_name' is present in the device name table, and returns
896 // the index of that device name. The index is stable, and can be used in
897 // calls to Node::set_assigned_device_name_index().
InternDeviceName(const std::string & device_name)898 int Graph::InternDeviceName(const std::string& device_name) {
899   // Special case, very common.  Also, this allows us to use a single map
900   // lookup below, instead of two.  The 'if (index_cell > 0)' test below
901   // relies on this check.
902   if (device_name.empty()) {
903     return 0;
904   }
905 
906   int& index_cell = device_names_map_[device_name];
907   if (index_cell > 0) {
908     return index_cell;
909   }
910 
911   const int index = device_names_map_.size();
912   index_cell = index;
913   device_names_.push_back(device_name);
914   return index;
915 }
916 
AddWhileContext(StringPiece frame_name,std::vector<Node * > enter_nodes,std::vector<Node * > exit_nodes,OutputTensor cond_output,std::vector<OutputTensor> body_inputs,std::vector<OutputTensor> body_outputs,WhileContext ** result)917 Status Graph::AddWhileContext(StringPiece frame_name,
918                               std::vector<Node*> enter_nodes,
919                               std::vector<Node*> exit_nodes,
920                               OutputTensor cond_output,
921                               std::vector<OutputTensor> body_inputs,
922                               std::vector<OutputTensor> body_outputs,
923                               WhileContext** result) {
924   auto pair = while_ctxs_.insert(std::pair<std::string, WhileContext>(
925       std::string(frame_name),
926       WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
927                    cond_output, std::move(body_inputs),
928                    std::move(body_outputs))));
929   if (!pair.second) {
930     *result = nullptr;
931     return errors::InvalidArgument("WhileContext with frame name '", frame_name,
932                                    "' already exists");
933   }
934   *result = &pair.first->second;
935   return OkStatus();
936 }
937 
BuildNodeNameIndex() const938 std::unordered_map<std::string, Node*> Graph::BuildNodeNameIndex() const {
939   std::unordered_map<std::string, Node*> result;
940   for (Node* n : nodes()) {
941     result[n->name()] = n;
942   }
943   return result;
944 }
945 
DebugString() const946 std::string Edge::DebugString() const {
947   return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
948                          src_output_, dst_->name().c_str(), dst_input_);
949 }
950 
951 }  // namespace tensorflow
952