• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/graph.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "tensorflow/core/framework/graph.pb.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 #include "tensorflow/core/framework/node_properties.h"
25 #include "tensorflow/core/framework/op_def_builder.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/versions.pb.h"
28 #include "tensorflow/core/graph/graph_node_util.h"
29 #include "tensorflow/core/graph/while_context.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/gtl/map_util.h"
32 #include "tensorflow/core/lib/hash/hash.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/lib/strings/stringprintf.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/public/version.h"
38 
39 namespace tensorflow {
40 
41 const int Graph::kControlSlot = -1;
42 
43 // Node
GetNodeClassForOp(const std::string & ts)44 Node::NodeClass Node::GetNodeClassForOp(const std::string& ts) {
45   static const absl::flat_hash_map<std::string, Node::NodeClass>*
46       kNodeClassTable =
47 #define REF_CLASS(key, value) \
48   {key, value}, { "Ref" key, value }
49           new absl::flat_hash_map<std::string, Node::NodeClass>({
50               // Keep in same order as NodeClass values
51               REF_CLASS("Switch", NC_SWITCH),
52               REF_CLASS("_SwitchN", NC_SWITCH),
53               REF_CLASS("Merge", NC_MERGE),
54               REF_CLASS("Enter", NC_ENTER),
55               REF_CLASS("Exit", NC_EXIT),
56               REF_CLASS("NextIteration", NC_NEXT_ITERATION),
57               {"LoopCond", NC_LOOP_COND},
58               {"ControlTrigger", NC_CONTROL_TRIGGER},
59               {"_Send", NC_SEND},
60               {"_HostSend", NC_HOST_SEND},
61               {"_Recv", NC_RECV},
62               {"_HostRecv", NC_HOST_RECV},
63               {"Const", NC_CONSTANT},
64               {"HostConst", NC_CONSTANT},
65               {"Variable", NC_VARIABLE},
66               {"VariableV2", NC_VARIABLE},
67               REF_CLASS("Identity", NC_IDENTITY),
68               {"GetSessionHandle", NC_GET_SESSION_HANDLE},
69               {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
70               {"GetSessionTensor", NC_GET_SESSION_TENSOR},
71               {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
72               {"Size", NC_METADATA},
73               {"Shape", NC_METADATA},
74               {"Rank", NC_METADATA},
75               {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
76               {"CollectiveReduce", NC_COLLECTIVE},
77               {"CollectiveBcastSend", NC_COLLECTIVE},
78               {"CollectiveBcastRecv", NC_COLLECTIVE},
79               {"CollectiveGather", NC_COLLECTIVE},
80               {"FakeParam", NC_FAKE_PARAM},
81               {"PartitionedCall", NC_PARTITIONED_CALL},
82               {"StatefulPartitionedCall", NC_PARTITIONED_CALL},
83               {"SymbolicGradient", NC_SYMBOLIC_GRADIENT},
84               {"If", NC_IF},
85               {"StatelessIf", NC_IF},
86               {"While", NC_WHILE},
87               {"StatelessWhile", NC_WHILE},
88               {"Case", NC_CASE},
89               {"StatelessCase", NC_CASE},
90               // Not using the constants defined in FunctionLibraryDefinition
91               // for the
92               // 4 ops below because android inference library does not link
93               // tf.function related files.
94               {"_Arg", NC_ARG},
95               {"_DeviceArg", NC_ARG},
96               {"_Retval", NC_RETVAL},
97               {"_DeviceRetval", NC_RETVAL},
98               {"_XlaMerge", NC_MERGE},
99           });
100 #undef REF_CLASS
101 
102   auto it = kNodeClassTable->find(ts);
103   if (it != kNodeClassTable->end()) {
104     return it->second;
105   } else {
106     return NC_OTHER;
107   }
108 }
109 
DebugString() const110 std::string Node::DebugString() const {
111   std::string ret = strings::StrCat("{name:'", name(), "' id:", id_);
112   if (IsSource()) {
113     strings::StrAppend(&ret, " source}");
114   } else if (IsSink()) {
115     strings::StrAppend(&ret, " sink}");
116   } else {
117     strings::StrAppend(&ret, " op device:", "{requested: '", requested_device(),
118                        "', assigned: '", assigned_device_name(), "'}", " def:{",
119                        SummarizeNode(*this), "}}");
120   }
121   return ret;
122 }
123 
Node()124 Node::Node()
125     : id_(-1),
126       cost_id_(-1),
127       class_(NC_UNINITIALIZED),
128       props_(nullptr),
129       assigned_device_name_index_(0),
130       while_ctx_(nullptr) {}
131 
Initialize(int id,int cost_id,std::shared_ptr<NodeProperties> props,Node::NodeClass node_class)132 void Node::Initialize(int id, int cost_id,
133                       std::shared_ptr<NodeProperties> props,
134                       Node::NodeClass node_class) {
135   DCHECK_EQ(id_, -1);
136   DCHECK(in_edges_.empty());
137   DCHECK(out_edges_.empty());
138   id_ = id;
139   cost_id_ = cost_id;
140 
141   props_ = std::move(props);
142   class_ = node_class;
143 }
144 
Clear()145 void Node::Clear() {
146   in_edges_.clear();
147   out_edges_.clear();
148   id_ = -1;
149   cost_id_ = -1;
150   class_ = NC_UNINITIALIZED;
151   props_.reset();
152   assigned_device_name_index_ = 0;
153 }
154 
UpdateProperties()155 void Node::UpdateProperties() {
156   DataTypeVector inputs;
157   DataTypeVector outputs;
158   Status status =
159       InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
160   if (!status.ok()) {
161     LOG(ERROR) << "Failed at updating node: " << status;
162     return;
163   }
164   if (props_->input_types != inputs || props_->output_types != outputs) {
165     if (TF_PREDICT_TRUE(props_.use_count() == 1)) {
166       props_->input_types = inputs;
167       props_->input_types_slice = props_->input_types;
168       props_->output_types = outputs;
169       props_->output_types_slice = props_->output_types;
170     } else {
171       props_ = std::make_shared<NodeProperties>(
172           props_->op_def, std::move(props_->node_def), inputs, outputs);
173     }
174   }
175 }
176 
name() const177 const std::string& Node::name() const { return props_->node_def.name(); }
type_string() const178 const std::string& Node::type_string() const { return props_->node_def.op(); }
def() const179 const NodeDef& Node::def() const { return props_->node_def; }
op_def() const180 const OpDef& Node::op_def() const { return *props_->op_def; }
181 
mutable_def()182 NodeDef* Node::mutable_def() { return &props_->node_def; }
183 
num_inputs() const184 int32 Node::num_inputs() const { return props_->input_types.size(); }
input_type(int32_t i) const185 DataType Node::input_type(int32_t i) const { return props_->input_types[i]; }
input_types() const186 const DataTypeVector& Node::input_types() const { return props_->input_types; }
187 
num_outputs() const188 int32 Node::num_outputs() const { return props_->output_types.size(); }
output_type(int32_t o) const189 DataType Node::output_type(int32_t o) const { return props_->output_types[o]; }
output_types() const190 const DataTypeVector& Node::output_types() const {
191   return props_->output_types;
192 }
193 
attrs() const194 AttrSlice Node::attrs() const { return AttrSlice(def()); }
195 
requested_inputs() const196 const protobuf::RepeatedPtrField<std::string>& Node::requested_inputs() const {
197   return def().input();
198 }
199 
requested_device() const200 const std::string& Node::requested_device() const { return def().device(); }
201 
out_nodes() const202 gtl::iterator_range<NeighborIter> Node::out_nodes() const {
203   return gtl::make_range(NeighborIter(out_edges_.begin(), false),
204                          NeighborIter(out_edges_.end(), false));
205 }
206 
in_nodes() const207 gtl::iterator_range<NeighborIter> Node::in_nodes() const {
208   return gtl::make_range(NeighborIter(in_edges_.begin(), true),
209                          NeighborIter(in_edges_.end(), true));
210 }
211 
MaybeCopyOnWrite()212 void Node::MaybeCopyOnWrite() {
213   // NodeProperties may be shared between Nodes. Make a copy if so.
214   if (!props_.unique()) {
215     props_ = std::make_shared<NodeProperties>(*props_);
216   }
217 }
218 
AddAttrHelper(const std::string & name)219 AttrValue* Node::AddAttrHelper(const std::string& name) {
220   MaybeCopyOnWrite();
221   return &((*props_->node_def.mutable_attr())[name]);
222 }
223 
ClearAttr(const std::string & name)224 void Node::ClearAttr(const std::string& name) {
225   MaybeCopyOnWrite();
226   (*props_->node_def.mutable_attr()).erase(name);
227 }
228 
set_name(std::string name)229 void Node::set_name(std::string name) {
230   MaybeCopyOnWrite();
231   props_->node_def.set_name(std::move(name));
232 }
233 
set_requested_device(const std::string & device)234 void Node::set_requested_device(const std::string& device) {
235   MaybeCopyOnWrite();
236   props_->node_def.set_device(device);
237 }
238 
set_original_node_names(const std::vector<std::string> & names)239 void Node::set_original_node_names(const std::vector<std::string>& names) {
240   MaybeCopyOnWrite();
241   props_->node_def.mutable_experimental_debug_info()
242       ->clear_original_node_names();
243   if (!names.empty()) {
244     *props_->node_def.mutable_experimental_debug_info()
245          ->mutable_original_node_names() = {names.begin(), names.end()};
246   }
247 }
248 
input_edge(int idx,const Edge ** e) const249 Status Node::input_edge(int idx, const Edge** e) const {
250   if (idx < 0 || idx >= num_inputs()) {
251     return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ",
252                                    name(), " only has ", num_inputs(),
253                                    " inputs.");
254   }
255 
256   // This does a linear search over the edges.  In the common case,
257   // the number of elements is small enough that this search isn't
258   // expensive.  Should it become a bottleneck, one can make an
259   // optimization where, if the number of edges is small, we use
260   // linear iteration, and if the number of edges is large, we perform
261   // an indexing step during construction that keeps an array of Edges
262   // indexed by pointer.  This would keep the size of each Node small
263   // in the common case but make this function faster when the number
264   // of edges is large.
265   for (const Edge* edge : in_edges()) {
266     if (edge->dst_input() == idx) {
267       *e = edge;
268       return Status::OK();
269     }
270   }
271 
272   return errors::NotFound("Could not find input edge ", idx, " for ", name());
273 }
274 
275 // Returns a vector of the non-control input edges to a node, indexed by ID.
input_edges(std::vector<const Edge * > * input_edges) const276 Status Node::input_edges(std::vector<const Edge*>* input_edges) const {
277   input_edges->clear();
278   input_edges->resize(num_inputs(), nullptr);
279 
280   for (const Edge* edge : in_edges()) {
281     if (edge->IsControlEdge()) continue;
282     if (edge->dst_input() < 0 || edge->dst_input() >= num_inputs()) {
283       return errors::Internal("Invalid edge input number ", edge->dst_input());
284     }
285     if ((*input_edges)[edge->dst_input()] != nullptr) {
286       return errors::Internal("Duplicate edge input number: ",
287                               edge->dst_input());
288     }
289     (*input_edges)[edge->dst_input()] = edge;
290   }
291 
292   for (int i = 0; i < num_inputs(); ++i) {
293     if ((*input_edges)[i] == nullptr) {
294       return errors::InvalidArgument("Missing edge input number: ", i);
295     }
296   }
297   return Status::OK();
298 }
299 
input_node(int idx,Node ** n) const300 Status Node::input_node(int idx, Node** n) const {
301   const Edge* e;
302   TF_RETURN_IF_ERROR(input_edge(idx, &e));
303   if (e == nullptr) {
304     *n = nullptr;
305   } else {
306     *n = e->src();
307   }
308   return Status::OK();
309 }
310 
input_node(int idx,const Node ** const_n) const311 Status Node::input_node(int idx, const Node** const_n) const {
312   Node* n;
313   TF_RETURN_IF_ERROR(input_node(idx, &n));
314   *const_n = n;
315   return Status::OK();
316 }
317 
input_tensor(int idx,OutputTensor * t) const318 Status Node::input_tensor(int idx, OutputTensor* t) const {
319   const Edge* e;
320   TF_RETURN_IF_ERROR(input_edge(idx, &e));
321   DCHECK(e != nullptr);
322   *t = OutputTensor(e->src(), e->src_output());
323   return Status::OK();
324 }
325 
326 // NodeDebugInfo
327 
NodeDebugInfo(const Node & n)328 NodeDebugInfo::NodeDebugInfo(const Node& n) : NodeDebugInfo(n.def()) {}
NodeDebugInfo(const NodeDef & ndef)329 NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef)
330     : NodeDebugInfo(ndef.name(), ndef.has_experimental_debug_info(),
331                     ndef.experimental_debug_info()) {}
NodeDebugInfo(StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info)332 NodeDebugInfo::NodeDebugInfo(
333     StringPiece node_name, bool has_experimental_debug_info,
334     const NodeDef_ExperimentalDebugInfo& experimental_debug_info)
335     : name(node_name) {
336   if (has_experimental_debug_info) {
337     const auto& names = experimental_debug_info.original_node_names();
338     original_node_names.assign(names.begin(), names.end());
339   }
340 }
341 
342 // InputTensor
343 
operator ==(const InputTensor & other) const344 bool InputTensor::operator==(const InputTensor& other) const {
345   return node == other.node && index == other.index;
346 }
347 
operator ()(InputTensor const & s) const348 uint64 InputTensor::Hash::operator()(InputTensor const& s) const {
349   return Hash64Combine(std::hash<const Node*>()(s.node),
350                        std::hash<int>()(s.index));
351 }
352 
353 // OutputTensor
354 
operator ==(const OutputTensor & other) const355 bool OutputTensor::operator==(const OutputTensor& other) const {
356   return node == other.node && index == other.index;
357 }
358 
operator ()(OutputTensor const & s) const359 uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const {
360   return Hash64Combine(std::hash<const Node*>()(s.node),
361                        std::hash<int>()(s.index));
362 }
363 
364 // Graph
365 
Graph(const OpRegistryInterface * ops)366 Graph::Graph(const OpRegistryInterface* ops)
367     : ops_(ops, FunctionDefLibrary()),
368       versions_(new VersionDef),
369       arena_(8 << 10 /* 8kB */) {
370   versions_->set_producer(TF_GRAPH_DEF_VERSION);
371   versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
372 
373   // Initialize the name interning table for assigned_device_name.
374   device_names_.push_back("");
375   DCHECK_EQ(0, InternDeviceName(""));
376 
377   // Source and sink have no endpoints, just control edges.
378   NodeDef def;
379   def.set_name("_SOURCE");
380   def.set_op("NoOp");
381   Status status;
382   Node* source = AddNode(def, &status);
383   TF_CHECK_OK(status);
384   CHECK_EQ(source->id(), kSourceId);
385 
386   def.set_name("_SINK");
387   Node* sink = AddNode(def, &status);
388   TF_CHECK_OK(status);
389   CHECK_EQ(sink->id(), kSinkId);
390 
391   AddControlEdge(source, sink);
392 }
393 
Graph(const FunctionLibraryDefinition & flib_def)394 Graph::Graph(const FunctionLibraryDefinition& flib_def)
395     : Graph(flib_def.default_registry()) {
396   // Need a new-enough consumer to support the functions we add to the graph.
397   if (flib_def.num_functions() > 0 && versions_->min_consumer() < 12) {
398     versions_->set_min_consumer(12);
399   }
400   Status s = ops_.AddLibrary(flib_def);
401   CHECK(s.ok()) << s.error_message();
402 }
403 
~Graph()404 Graph::~Graph() {
405   // Manually call the destructors for all the Nodes we constructed using
406   // placement new.
407   for (Node* node : nodes_) {
408     if (node != nullptr) {
409       node->~Node();
410     }
411   }
412   for (Node* node : free_nodes_) {
413     node->~Node();
414   }
415   // Edges have no destructor, and we arena-allocated them, so no need to
416   // destroy them.
417 }
418 
Clone()419 std::unique_ptr<Graph> Graph::Clone() {
420   std::unique_ptr<Graph> new_graph(new Graph(flib_def()));
421   new_graph->Copy(*this);
422   return new_graph;
423 }
424 
versions() const425 const VersionDef& Graph::versions() const { return *versions_; }
set_versions(const VersionDef & versions)426 void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }
427 
Copy(const Graph & src)428 void Graph::Copy(const Graph& src) {
429   SetConstructionContext(src.GetConstructionContextInternal());
430   for (Node* n : nodes()) {
431     CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
432   }
433 
434   // Copy GraphDef versions
435   set_versions(src.versions());
436 
437   // Copy the nodes.
438   // "Node in src" -> "Node in *dest"
439   gtl::FlatMap<const Node*, Node*> node_map;
440   node_map.reserve(src.num_nodes());
441   node_map[src.source_node()] = source_node();
442   node_map[src.sink_node()] = sink_node();
443   for (Node* n : src.op_nodes()) {
444     auto copy = CopyNode(n);
445     copy->in_edges_.reserve(n->in_edges().size());
446     copy->out_edges_.reserve(n->out_edges().size());
447     node_map[n] = copy;
448   }
449 
450   // Copy the edges
451   edges_.reserve(src.num_edges());
452   for (const Edge* e : src.edges()) {
453     Node* src_copy = node_map[e->src()];
454     Node* dst_copy = node_map[e->dst()];
455     AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
456   }
457 
458   types_ = src.types_;
459   node_name_to_out_type_ = src.node_name_to_out_type_;
460 }
461 
AddNode(NodeDef node_def,Status * status)462 Node* Graph::AddNode(NodeDef node_def, Status* status) {
463   const OpRegistrationData* op_reg_data;
464   status->Update(ops_.LookUp(node_def.op(), &op_reg_data));
465   if (!status->ok()) return nullptr;
466 
467   DataTypeVector inputs;
468   DataTypeVector outputs;
469   status->Update(
470       InOutTypesForNode(node_def, op_reg_data->op_def, &inputs, &outputs));
471   if (!status->ok()) {
472     *status = AttachDef(*status, node_def);
473     return nullptr;
474   }
475 
476   Node::NodeClass node_class = op_reg_data->is_function_op
477                                    ? Node::NC_FUNCTION_OP
478                                    : Node::GetNodeClassForOp(node_def.op());
479 
480   Node* node = AllocateNode(
481       std::make_shared<NodeProperties>(&op_reg_data->op_def,
482                                        std::move(node_def), inputs, outputs),
483       nullptr, node_class);
484   return node;
485 }
486 
// Adds to this graph a copy of op node `node` (never source/sink). The copy
// shares node's properties and cost id, and takes its assigned device and
// stack trace.
Node* Graph::CopyNode(const Node* node) {
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());
  Node* copy = AllocateNode(node->props_, node, node->class_);
  copy->set_assigned_device_name(node->assigned_device_name());

  // A function's OpDef may be owned by the graph that owns `node`, so look it
  // up again here. If it resolves differently, give the copy its own
  // properties that point at this graph's OpDef.
  const OpDef* local_op_def = nullptr;
  TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &local_op_def));
  if (local_op_def != node->props_->op_def) {
    copy->MaybeCopyOnWrite();
    copy->props_->op_def = local_op_def;
  }
  copy->SetStackTrace(node->GetStackTrace());
  return copy;
}
506 
RemoveNode(Node * node)507 void Graph::RemoveNode(Node* node) {
508   TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
509   DCHECK(!node->IsSource());
510   DCHECK(!node->IsSink());
511 
512   // Remove any edges involving this node.
513   for (const Edge* e : node->in_edges_) {
514     CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
515     edges_[e->id_] = nullptr;
516     RecycleEdge(e);
517     --num_edges_;
518   }
519   node->in_edges_.clear();
520   for (const Edge* e : node->out_edges_) {
521     CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
522     edges_[e->id_] = nullptr;
523     RecycleEdge(e);
524     --num_edges_;
525   }
526   node->out_edges_.clear();
527   ReleaseNode(node);
528 }
529 
AddEdge(Node * source,int x,Node * dest,int y)530 const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
531   TF_DCHECK_OK(IsValidNode(source)) << source->DebugString();
532   TF_DCHECK_OK(IsValidNode(dest)) << dest->DebugString();
533 
534   // source/sink must only be linked via control slots, and
535   // control slots must only be linked to control slots.
536   if (source == source_node() || dest == sink_node() || x == kControlSlot ||
537       y == kControlSlot) {
538     DCHECK_EQ(x, kControlSlot) << source->DebugString();
539     DCHECK_EQ(y, kControlSlot) << dest->DebugString();
540   }
541 
542   Edge* e = nullptr;
543   if (free_edges_.empty()) {
544     e = new (arena_.Alloc(sizeof(Edge))) Edge;  // placement new
545   } else {
546     e = free_edges_.back();
547     free_edges_.pop_back();
548   }
549   e->id_ = edges_.size();
550   e->src_ = source;
551   e->dst_ = dest;
552   e->src_output_ = x;
553   e->dst_input_ = y;
554   CHECK(source->out_edges_.insert(e).second);
555   CHECK(dest->in_edges_.insert(e).second);
556   edges_.push_back(e);
557   ++num_edges_;
558   return e;
559 }
560 
RemoveEdge(const Edge * e)561 void Graph::RemoveEdge(const Edge* e) {
562   TF_DCHECK_OK(IsValidNode(e->src_)) << e->src_->DebugString();
563   TF_DCHECK_OK(IsValidNode(e->dst_)) << e->dst_->DebugString();
564   CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
565   CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
566   CHECK_EQ(e, edges_[e->id_]);
567   CHECK_GT(num_edges_, 0);
568 
569   edges_[e->id_] = nullptr;
570   RecycleEdge(e);
571   --num_edges_;
572 }
573 
RecycleEdge(const Edge * e)574 void Graph::RecycleEdge(const Edge* e) {
575   free_edges_.push_back(const_cast<Edge*>(e));
576 }
577 
AddControlEdge(Node * source,Node * dest,bool allow_duplicates)578 const Edge* Graph::AddControlEdge(Node* source, Node* dest,
579                                   bool allow_duplicates) {
580   if (!allow_duplicates) {
581     for (const Edge* edge : dest->in_edges()) {
582       if (edge->IsControlEdge() && edge->src() == source) {
583         // The requested edge already exists.
584         return nullptr;
585       }
586     }
587   }
588   // Modify dest's NodeDef if necessary.
589   if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) {
590     // Check if this input is already in dest's NodeDef.
591     const std::string new_input = strings::StrCat("^", source->name());
592     bool input_exists = false;
593     for (const std::string& input : dest->props_->node_def.input()) {
594       if (input == new_input) {
595         input_exists = true;
596         break;
597       }
598     }
599     if (!input_exists) {
600       dest->MaybeCopyOnWrite();
601       dest->props_->node_def.add_input(new_input);
602     }
603   }
604   return AddEdge(source, kControlSlot, dest, kControlSlot);
605 }
606 
RemoveControlEdge(const Edge * e)607 void Graph::RemoveControlEdge(const Edge* e) {
608   if (!e->src_->IsSource() && !e->dst_->IsSink()) {
609     e->dst_->MaybeCopyOnWrite();
610     std::string e_src_name = strings::StrCat("^", e->src_->name());
611     auto* inputs = e->dst_->props_->node_def.mutable_input();
612     for (auto it = inputs->begin(); it != inputs->end(); ++it) {
613       if (*it == e_src_name) {
614         inputs->erase(it);
615         break;
616       }
617     }
618   }
619   RemoveEdge(e);
620 }
621 
622 namespace {
FindEdge(const Node * dst,int index)623 const Edge* FindEdge(const Node* dst, int index) {
624   for (const Edge* e : dst->in_edges()) {
625     if (e->dst_input() == index) return e;
626   }
627   return nullptr;
628 }
629 }  // namespace
630 
UpdateEdge(Node * new_src,int new_src_index,Node * dst,int dst_index)631 Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
632                          int dst_index) {
633   TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
634   TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
635   const Edge* e = FindEdge(dst, dst_index);
636   if (e == nullptr) {
637     return errors::InvalidArgument("Couldn't find edge to ",
638                                    FormatNodeForError(*dst));
639   }
640   RemoveEdge(e);
641   AddEdge(new_src, new_src_index, dst, dst_index);
642   dst->MaybeCopyOnWrite();
643   (*dst->props_->node_def.mutable_input())[dst_index] =
644       strings::StrCat(new_src->name(), ":", new_src_index);
645   return Status::OK();
646 }
647 
AddWhileInputHack(Node * new_src,int new_src_index,Node * dst)648 Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) {
649   if (!dst->IsWhileNode()) {
650     return errors::Internal(
651         "dst argument to AddWhileEdgeHack should be a While op, got: ",
652         dst->DebugString());
653   }
654   TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
655   // Find the current number of data inputs. We'll add the new edge to the next
656   // missing data input.
657   int dst_index = 0;
658   for (const Edge* edge : dst->in_edges()) {
659     if (edge->IsControlEdge()) continue;
660     ++dst_index;
661   }
662   TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
663   AddEdge(new_src, new_src_index, dst, dst_index);
664   dst->MaybeCopyOnWrite();
665   dst->props_->node_def.add_input(
666       strings::StrCat(new_src->name(), ":", new_src_index));
667   return Status::OK();
668 }
669 
AddFunctionLibrary(const FunctionDefLibrary & fdef_lib)670 Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
671   // Need a new-enough consumer to support the functions we add to the graph.
672   if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
673     versions_->set_min_consumer(12);
674   }
675   return ops_.AddLibrary(fdef_lib);
676 }
677 
678 namespace {
679 
AddInput(NodeDef * dst,StringPiece src_name,int src_slot)680 void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
681   if (src_slot == Graph::kControlSlot) {
682     dst->add_input(strings::StrCat("^", src_name));
683   } else if (src_slot == 0) {
684     dst->add_input(src_name.data(), src_name.size());
685   } else {
686     dst->add_input(strings::StrCat(src_name, ":", src_slot));
687   }
688 }
689 
690 }  // namespace
691 
ToGraphDef(GraphDef * graph_def) const692 void Graph::ToGraphDef(GraphDef* graph_def) const {
693   ToGraphDefSubRange(graph_def, 0);
694 }
695 
ToGraphDefDebug() const696 GraphDef Graph::ToGraphDefDebug() const {
697   GraphDef ret;
698   ToGraphDef(&ret);
699   return ret;
700 }
701 
ToGraphDefSubRange(GraphDef * graph_def,int from_node_id) const702 void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
703   graph_def->Clear();
704   *graph_def->mutable_versions() = versions();
705   *graph_def->mutable_library() = ops_.ToProto();
706 
707   graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id));
708 
709   std::vector<const Edge*>
710       inputs;  // Construct this outside the loop for speed.
711   for (auto id = from_node_id; id < num_node_ids(); ++id) {
712     const Node* node = FindNodeId(id);
713     if (node == nullptr || !node->IsOp()) continue;
714     NodeDef* node_def = graph_def->add_node();
715     *node_def = node->def();
716 
717     // Use the node's assigned device, if any, instead of the device requested
718     // in the NodeDef.
719     if (!node->assigned_device_name().empty()) {
720       node_def->set_device(node->assigned_device_name());
721     }
722 
723     // Get the inputs for this Node.  We make sure control inputs are
724     // after data inputs, as required by GraphDef.
725     inputs.clear();
726     inputs.resize(node->num_inputs(), nullptr);
727     for (const Edge* edge : node->in_edges()) {
728       if (edge->IsControlEdge()) {
729         inputs.push_back(edge);
730       } else {
731         DCHECK(edge->dst_input() < inputs.size())
732             << "Edge " << edge->DebugString()
733             << " is overflowing the expected number of inputs ("
734             << node->num_inputs() << ") for node " << node->DebugString();
735         CHECK(inputs[edge->dst_input()] == nullptr)
736             << "Edge " << edge->src()->name() << "->" << edge->dst()->name()
737             << " conflicts with pre-existing input edge "
738             << inputs[edge->dst_input()]->src()->name() << "->"
739             << inputs[edge->dst_input()]->dst()->name();
740 
741         inputs[edge->dst_input()] = edge;
742       }
743     }
744     // Sort the control inputs for more predictable serialization.
745     std::sort(inputs.begin() + node->num_inputs(), inputs.end(),
746               [](const Edge* a, const Edge* b) -> bool {
747                 return a->src()->name() < b->src()->name();
748               });
749     node_def->clear_input();
750     node_def->mutable_input()->Reserve(inputs.size());
751 
752     for (size_t i = 0; i < inputs.size(); ++i) {
753       const Edge* edge = inputs[i];
754       if (edge == nullptr) {
755         if (i < node->requested_inputs().size()) {
756           node_def->add_input(node->requested_inputs()[i]);
757         } else {
758           node_def->add_input("");
759         }
760       } else {
761         const Node* src = edge->src();
762         if (!src->IsOp()) continue;
763         AddInput(node_def, src->name(), edge->src_output());
764       }
765     }
766   }
767 }
768 
NewName(StringPiece prefix)769 std::string Graph::NewName(StringPiece prefix) {
770   return strings::StrCat(prefix, "/_", name_counter_++);
771 }
772 
IsValidNode(const Node * node) const773 Status Graph::IsValidNode(const Node* node) const {
774   if (node == nullptr) {
775     return errors::InvalidArgument("Node is null");
776   }
777   const int id = node->id();
778   if (id < 0) {
779     return errors::InvalidArgument("node id ", id, " is less than zero");
780   }
781   if (static_cast<size_t>(id) >= nodes_.size()) {
782     return errors::InvalidArgument(
783         "node id ", id, " is >= than number of nodes in graph ", nodes_.size());
784   }
785   if (nodes_[id] != node) {
786     return errors::InvalidArgument("Node with id ", id,
787                                    " is different from the passed in node. "
788                                    "Does it belong to a different graph?");
789   }
790   return Status::OK();
791 }
792 
IsValidOutputTensor(const Node * node,int idx) const793 Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
794   TF_RETURN_IF_ERROR(IsValidNode(node));
795   if (idx >= node->num_outputs() || idx < 0) {
796     return errors::OutOfRange("Node '", node->name(), "' (type: '",
797                               node->op_def().name(),
798                               "', num of outputs: ", node->num_outputs(),
799                               ") does not have ", "output ", idx);
800   }
801   return Status::OK();
802 }
803 
IsValidInputTensor(const Node * node,int idx) const804 Status Graph::IsValidInputTensor(const Node* node, int idx) const {
805   TF_RETURN_IF_ERROR(IsValidNode(node));
806   if (idx >= node->num_inputs() || idx < 0) {
807     return errors::OutOfRange("Node '", node->name(), "' (type: '",
808                               node->op_def().name(),
809                               "', num of inputs: ", node->num_inputs(),
810                               ") does not have ", "input ", idx);
811   }
812   return Status::OK();
813 }
814 
// Obtains a Node, initializes it, and registers it in nodes_.
// - Storage: a previously released Node is reused from free_nodes_ when one
//   is available; otherwise a new Node is placement-constructed in arena_.
// - Id: the new node's id is nodes_.size() at the time of the call.
// - Cost id: copied from `cost_node` when provided, else the node's own id.
// Also increments num_nodes_.
Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
                          const Node* cost_node, Node::NodeClass node_class) {
  Node* node = nullptr;
  if (!free_nodes_.empty()) {
    node = free_nodes_.back();
    free_nodes_.pop_back();
  } else {
    node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new into arena_
  }

  node->graph_ = this;
  const int new_id = static_cast<int>(nodes_.size());
  const int cost_id = (cost_node != nullptr) ? cost_node->cost_id() : new_id;
  node->Initialize(new_id, cost_id, std::move(props), node_class);

  nodes_.push_back(node);
  ++num_nodes_;
  return node;
}
832 
ReleaseNode(Node * node)833 void Graph::ReleaseNode(Node* node) {
834   TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
835   nodes_[node->id()] = nullptr;
836   free_nodes_.push_back(node);
837   --num_nodes_;
838   node->Clear();
839 }
840 
841 // Ensures that 'device_name' is present in the device name table, and returns
842 // the index of that device name. The index is stable, and can be used in
843 // calls to Node::set_assigned_device_name_index().
InternDeviceName(const std::string & device_name)844 int Graph::InternDeviceName(const std::string& device_name) {
845   // Special case, very common.  Also, this allows us to use a single map
846   // lookup below, instead of two.  The 'if (index_cell > 0)' test below
847   // relies on this check.
848   if (device_name.empty()) {
849     return 0;
850   }
851 
852   int& index_cell = device_names_map_[device_name];
853   if (index_cell > 0) {
854     return index_cell;
855   }
856 
857   const int index = device_names_map_.size();
858   index_cell = index;
859   device_names_.push_back(device_name);
860   return index;
861 }
862 
AddWhileContext(StringPiece frame_name,std::vector<Node * > enter_nodes,std::vector<Node * > exit_nodes,OutputTensor cond_output,std::vector<OutputTensor> body_inputs,std::vector<OutputTensor> body_outputs,WhileContext ** result)863 Status Graph::AddWhileContext(StringPiece frame_name,
864                               std::vector<Node*> enter_nodes,
865                               std::vector<Node*> exit_nodes,
866                               OutputTensor cond_output,
867                               std::vector<OutputTensor> body_inputs,
868                               std::vector<OutputTensor> body_outputs,
869                               WhileContext** result) {
870   auto pair = while_ctxs_.insert(std::pair<std::string, WhileContext>(
871       std::string(frame_name),
872       WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
873                    cond_output, std::move(body_inputs),
874                    std::move(body_outputs))));
875   if (!pair.second) {
876     *result = nullptr;
877     return errors::InvalidArgument("WhileContext with frame name '", frame_name,
878                                    "' already exists");
879   }
880   *result = &pair.first->second;
881   return Status::OK();
882 }
883 
BuildNodeNameIndex() const884 std::unordered_map<std::string, Node*> Graph::BuildNodeNameIndex() const {
885   std::unordered_map<std::string, Node*> result;
886   for (Node* n : nodes()) {
887     result[n->name()] = n;
888   }
889   return result;
890 }
891 
SetNodeType(StringPiece name,const FullTypeDef & ft)892 void Graph::SetNodeType(StringPiece name, const FullTypeDef& ft) {
893   TypeRef t = {std::make_shared<FullTypeDef>(ft)};
894   auto ret = types_.emplace(t);
895   if (ret.second == false) {
896     t = *ret.first;
897   }
898 
899   node_name_to_out_type_.emplace(string(name), t);
900 }
901 
NodeType(StringPiece name,FullTypeDef ** result)902 void Graph::NodeType(StringPiece name, FullTypeDef** result) {
903   *result = nullptr;
904   auto it = node_name_to_out_type_.find(string(name));
905   if (it == node_name_to_out_type_.end()) {
906     *result = nullptr;
907     return;
908   }
909   *result = it->second.full_type.get();
910 }
911 
DebugString() const912 std::string Edge::DebugString() const {
913   return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
914                          src_output_, dst_->name().c_str(), dst_input_);
915 }
916 
917 }  // namespace tensorflow
918