• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/graph.h"
17 
18 #include <vector>
19 
20 #include "tensorflow/core/framework/graph.pb.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/framework/op_def_builder.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/graph/graph_node_util.h"
27 #include "tensorflow/core/graph/while_context.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/gtl/map_util.h"
30 #include "tensorflow/core/lib/hash/hash.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/lib/strings/stringprintf.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/public/version.h"
35 
36 namespace tensorflow {
37 
38 const int Graph::kControlSlot = -1;
39 
40 struct NodeProperties {
41  public:
NodePropertiestensorflow::NodeProperties42   NodeProperties(const OpDef* op_def, NodeDef node_def,
43                  const DataTypeSlice inputs, const DataTypeSlice outputs)
44       : op_def(op_def),
45         node_def(std::move(node_def)),
46         input_types(inputs.begin(), inputs.end()),
47         output_types(outputs.begin(), outputs.end()) {}
48 
49   const OpDef* op_def;  // not owned
50   NodeDef node_def;
51   const DataTypeVector input_types;
52   const DataTypeVector output_types;
53 };
54 
55 // Node
56 
57 #define REF_CLASS(key, value) \
58   {key, value}, { "Ref" key, value }
59 
60 const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
61     *new std::unordered_map<string, Node::NodeClass>({
62         // Keep in same order as NodeClass values
63         REF_CLASS("Switch", NC_SWITCH),
64         REF_CLASS("_SwitchN", NC_SWITCH),
65         REF_CLASS("Merge", NC_MERGE),
66         REF_CLASS("Enter", NC_ENTER),
67         REF_CLASS("Exit", NC_EXIT),
68         REF_CLASS("NextIteration", NC_NEXT_ITERATION),
69         {"LoopCond", NC_LOOP_COND},
70         {"ControlTrigger", NC_CONTROL_TRIGGER},
71         {"_Send", NC_SEND},
72         {"_HostSend", NC_HOST_SEND},
73         {"_Recv", NC_RECV},
74         {"_HostRecv", NC_HOST_RECV},
75         {"Const", NC_CONSTANT},
76         {"HostConst", NC_CONSTANT},
77         {"Variable", NC_VARIABLE},
78         {"VariableV2", NC_VARIABLE},
79         REF_CLASS("Identity", NC_IDENTITY),
80         {"GetSessionHandle", NC_GET_SESSION_HANDLE},
81         {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
82         {"GetSessionTensor", NC_GET_SESSION_TENSOR},
83         {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
84         {"Size", NC_METADATA},
85         {"Shape", NC_METADATA},
86         {"Rank", NC_METADATA},
87         {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
88         {"CollectiveReduce", NC_COLLECTIVE},
89         {"CollectiveBcastSend", NC_COLLECTIVE},
90         {"CollectiveBcastRecv", NC_COLLECTIVE},
91         {"CollectiveGather", NC_COLLECTIVE},
92         {"FakeParam", NC_FAKE_PARAM},
93         {"PartitionedCall", NC_PARTITIONED_CALL},
94         {"StatefulPartitionedCall", NC_PARTITIONED_CALL},
95         {"SymbolicGradient", NC_SYMBOLIC_GRADIENT},
96         {"If", NC_IF},
97         {"StatelessIf", NC_IF},
98         {"While", NC_WHILE},
99         {"StatelessWhile", NC_WHILE},
100         // Not using the constants defined in FunctionLibraryDefinition for the
101         // 4 ops below because android inference library does not link
102         // tf.function related files.
103         {"_Arg", NC_ARG},
104         {"_DeviceArg", NC_ARG},
105         {"_Retval", NC_RETVAL},
106         {"_DeviceRetval", NC_RETVAL},
107         {"_XlaMerge", NC_MERGE},
108     });
109 
110 #undef REF_CLASS
111 
GetNodeClassForOp(const string & ts)112 Node::NodeClass Node::GetNodeClassForOp(const string& ts) {
113   auto it = kNodeClassTable.find(ts);
114   if (it != kNodeClassTable.end()) {
115     return it->second;
116   } else {
117     return NC_OTHER;
118   }
119 }
120 
DebugString() const121 string Node::DebugString() const {
122   string ret = strings::StrCat("{name:'", name(), "' id:", id_);
123   if (IsSource()) {
124     strings::StrAppend(&ret, " source}");
125   } else if (IsSink()) {
126     strings::StrAppend(&ret, " sink}");
127   } else {
128     strings::StrAppend(&ret, " op device:");
129     strings::StrAppend(&ret, "{", assigned_device_name(), "}");
130     strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}");
131   }
132   return ret;
133 }
134 
Node()135 Node::Node()
136     : id_(-1),
137       cost_id_(-1),
138       class_(NC_UNINITIALIZED),
139       props_(nullptr),
140       assigned_device_name_index_(0),
141       while_ctx_(nullptr) {}
142 
Initialize(int id,int cost_id,std::shared_ptr<NodeProperties> props,bool is_function_op)143 void Node::Initialize(int id, int cost_id,
144                       std::shared_ptr<NodeProperties> props,
145                       bool is_function_op) {
146   DCHECK_EQ(id_, -1);
147   DCHECK(in_edges_.empty());
148   DCHECK(out_edges_.empty());
149   id_ = id;
150   cost_id_ = cost_id;
151 
152   props_ = std::move(props);
153   // Initialize the class_ based on the type string
154   if (is_function_op) {
155     class_ = NC_FUNCTION_OP;
156   } else {
157     class_ = GetNodeClassForOp(props_->node_def.op());
158   }
159 }
160 
Clear()161 void Node::Clear() {
162   in_edges_.clear();
163   out_edges_.clear();
164   id_ = -1;
165   cost_id_ = -1;
166   class_ = NC_UNINITIALIZED;
167   props_.reset();
168   assigned_device_name_index_ = 0;
169 }
170 
UpdateProperties()171 void Node::UpdateProperties() {
172   DataTypeVector inputs;
173   DataTypeVector outputs;
174   Status status =
175       InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
176   if (!status.ok()) {
177     LOG(ERROR) << "Failed at updating node: " << status;
178     return;
179   }
180   props_ = std::make_shared<NodeProperties>(props_->op_def, props_->node_def,
181                                             inputs, outputs);
182 }
183 
name() const184 const string& Node::name() const { return props_->node_def.name(); }
type_string() const185 const string& Node::type_string() const { return props_->node_def.op(); }
def() const186 const NodeDef& Node::def() const { return props_->node_def; }
op_def() const187 const OpDef& Node::op_def() const { return *props_->op_def; }
188 
num_inputs() const189 int32 Node::num_inputs() const { return props_->input_types.size(); }
input_type(int32 i) const190 DataType Node::input_type(int32 i) const { return props_->input_types[i]; }
input_types() const191 const DataTypeVector& Node::input_types() const { return props_->input_types; }
192 
num_outputs() const193 int32 Node::num_outputs() const { return props_->output_types.size(); }
output_type(int32 o) const194 DataType Node::output_type(int32 o) const { return props_->output_types[o]; }
output_types() const195 const DataTypeVector& Node::output_types() const {
196   return props_->output_types;
197 }
198 
attrs() const199 AttrSlice Node::attrs() const { return AttrSlice(def()); }
200 
requested_inputs() const201 const protobuf::RepeatedPtrField<string>& Node::requested_inputs() const {
202   return def().input();
203 }
204 
requested_device() const205 const string& Node::requested_device() const { return def().device(); }
206 
out_nodes() const207 gtl::iterator_range<NeighborIter> Node::out_nodes() const {
208   return gtl::make_range(NeighborIter(out_edges_.begin(), false),
209                          NeighborIter(out_edges_.end(), false));
210 }
211 
in_nodes() const212 gtl::iterator_range<NeighborIter> Node::in_nodes() const {
213   return gtl::make_range(NeighborIter(in_edges_.begin(), true),
214                          NeighborIter(in_edges_.end(), true));
215 }
216 
MaybeCopyOnWrite()217 void Node::MaybeCopyOnWrite() {
218   // NodeProperties may be shared between Nodes. Make a copy if so.
219   if (!props_.unique()) {
220     props_ = std::make_shared<NodeProperties>(*props_);
221   }
222 }
223 
AddAttrHelper(const string & name)224 AttrValue* Node::AddAttrHelper(const string& name) {
225   MaybeCopyOnWrite();
226   return &((*props_->node_def.mutable_attr())[name]);
227 }
228 
ClearAttr(const string & name)229 void Node::ClearAttr(const string& name) {
230   MaybeCopyOnWrite();
231   (*props_->node_def.mutable_attr()).erase(name);
232 }
233 
set_name(string name)234 void Node::set_name(string name) {
235   MaybeCopyOnWrite();
236   props_->node_def.set_name(std::move(name));
237 }
238 
set_requested_device(const string & device)239 void Node::set_requested_device(const string& device) {
240   MaybeCopyOnWrite();
241   props_->node_def.set_device(device);
242 }
243 
set_original_node_names(const std::vector<string> & names)244 void Node::set_original_node_names(const std::vector<string>& names) {
245   MaybeCopyOnWrite();
246   props_->node_def.mutable_experimental_debug_info()
247       ->clear_original_node_names();
248   if (!names.empty()) {
249     *props_->node_def.mutable_experimental_debug_info()
250          ->mutable_original_node_names() = {names.begin(), names.end()};
251   }
252 }
253 
input_edge(int idx,const Edge ** e) const254 Status Node::input_edge(int idx, const Edge** e) const {
255   if (idx < 0 || idx >= num_inputs()) {
256     return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ",
257                                    name(), " only has ", num_inputs(),
258                                    " inputs.");
259   }
260 
261   // This does a linear search over the edges.  In the common case,
262   // the number of elements is small enough that this search isn't
263   // expensive.  Should it become a bottleneck, one can make an
264   // optimization where, if the number of edges is small, we use
265   // linear iteration, and if the number of edges is large, we perform
266   // an indexing step during construction that keeps an array of Edges
267   // indexed by pointer.  This would keep the size of each Node small
268   // in the common case but make this function faster when the number
269   // of edges is large.
270   for (const Edge* edge : in_edges()) {
271     if (edge->dst_input() == idx) {
272       *e = edge;
273       return Status::OK();
274     }
275   }
276 
277   return errors::NotFound("Could not find input edge ", idx, " for ", name());
278 }
279 
280 // Returns a vector of the non-control input edges to a node, indexed by ID.
input_edges(std::vector<const Edge * > * input_edges) const281 Status Node::input_edges(std::vector<const Edge*>* input_edges) const {
282   input_edges->clear();
283   input_edges->resize(num_inputs(), nullptr);
284 
285   for (const Edge* edge : in_edges()) {
286     if (edge->IsControlEdge()) continue;
287     if (edge->dst_input() < 0 || edge->dst_input() >= num_inputs()) {
288       return errors::Internal("Invalid edge input number ", edge->dst_input());
289     }
290     if ((*input_edges)[edge->dst_input()] != nullptr) {
291       return errors::Internal("Duplicate edge input number: ",
292                               edge->dst_input());
293     }
294     (*input_edges)[edge->dst_input()] = edge;
295   }
296 
297   for (int i = 0; i < num_inputs(); ++i) {
298     if ((*input_edges)[i] == nullptr) {
299       return errors::InvalidArgument("Missing edge input number: ", i);
300     }
301   }
302   return Status::OK();
303 }
304 
input_node(int idx,Node ** n) const305 Status Node::input_node(int idx, Node** n) const {
306   const Edge* e;
307   TF_RETURN_IF_ERROR(input_edge(idx, &e));
308   if (e == nullptr) {
309     *n = nullptr;
310   } else {
311     *n = e->src();
312   }
313   return Status::OK();
314 }
315 
input_node(int idx,const Node ** const_n) const316 Status Node::input_node(int idx, const Node** const_n) const {
317   Node* n;
318   TF_RETURN_IF_ERROR(input_node(idx, &n));
319   *const_n = n;
320   return Status::OK();
321 }
322 
input_tensor(int idx,OutputTensor * t) const323 Status Node::input_tensor(int idx, OutputTensor* t) const {
324   const Edge* e;
325   TF_RETURN_IF_ERROR(input_edge(idx, &e));
326   DCHECK(e != nullptr);
327   *t = OutputTensor(e->src(), e->src_output());
328   return Status::OK();
329 }
330 
331 // NodeDebugInfo
332 
NodeDebugInfo(const Node & n)333 NodeDebugInfo::NodeDebugInfo(const Node& n) : NodeDebugInfo(n.def()) {}
NodeDebugInfo(const NodeDef & ndef)334 NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef)
335     : NodeDebugInfo(ndef.name(), ndef.has_experimental_debug_info(),
336                     ndef.experimental_debug_info()) {}
NodeDebugInfo(StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info)337 NodeDebugInfo::NodeDebugInfo(
338     StringPiece node_name, bool has_experimental_debug_info,
339     const NodeDef_ExperimentalDebugInfo& experimental_debug_info)
340     : name(node_name) {
341   if (has_experimental_debug_info) {
342     const auto& names = experimental_debug_info.original_node_names();
343     original_node_names.assign(names.begin(), names.end());
344   }
345 }
346 
347 // InputTensor
348 
operator ==(const InputTensor & other) const349 bool InputTensor::operator==(const InputTensor& other) const {
350   return node == other.node && index == other.index;
351 }
352 
operator ()(InputTensor const & s) const353 uint64 InputTensor::Hash::operator()(InputTensor const& s) const {
354   return Hash64Combine(std::hash<const Node*>()(s.node),
355                        std::hash<int>()(s.index));
356 }
357 
358 // OutputTensor
359 
operator ==(const OutputTensor & other) const360 bool OutputTensor::operator==(const OutputTensor& other) const {
361   return node == other.node && index == other.index;
362 }
363 
operator ()(OutputTensor const & s) const364 uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const {
365   return Hash64Combine(std::hash<const Node*>()(s.node),
366                        std::hash<int>()(s.index));
367 }
368 
369 // Graph
370 
Graph(const OpRegistryInterface * ops)371 Graph::Graph(const OpRegistryInterface* ops)
372     : ops_(ops, FunctionDefLibrary()),
373       versions_(new VersionDef),
374       arena_(8 << 10 /* 8kB */) {
375   versions_->set_producer(TF_GRAPH_DEF_VERSION);
376   versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
377 
378   // Initialize the name interning table for assigned_device_name.
379   device_names_.push_back("");
380   DCHECK_EQ(0, InternDeviceName(""));
381 
382   // Source and sink have no endpoints, just control edges.
383   NodeDef def;
384   def.set_name("_SOURCE");
385   def.set_op("NoOp");
386   Status status;
387   Node* source = AddNode(def, &status);
388   TF_CHECK_OK(status);
389   CHECK_EQ(source->id(), kSourceId);
390 
391   def.set_name("_SINK");
392   Node* sink = AddNode(def, &status);
393   TF_CHECK_OK(status);
394   CHECK_EQ(sink->id(), kSinkId);
395 
396   AddControlEdge(source, sink);
397 }
398 
Graph(const FunctionLibraryDefinition & flib_def)399 Graph::Graph(const FunctionLibraryDefinition& flib_def)
400     : Graph(flib_def.default_registry()) {
401   // Need a new-enough consumer to support the functions we add to the graph.
402   if (flib_def.num_functions() > 0 && versions_->min_consumer() < 12) {
403     versions_->set_min_consumer(12);
404   }
405   Status s = ops_.AddLibrary(flib_def);
406   CHECK(s.ok()) << s.error_message();
407 }
408 
~Graph()409 Graph::~Graph() {
410   // Manually call the destructors for all the Nodes we constructed using
411   // placement new.
412   for (Node* node : nodes_) {
413     if (node != nullptr) {
414       node->~Node();
415     }
416   }
417   for (Node* node : free_nodes_) {
418     node->~Node();
419   }
420   // Edges have no destructor, and we arena-allocated them, so no need to
421   // destroy them.
422 }
423 
versions() const424 const VersionDef& Graph::versions() const { return *versions_; }
set_versions(const VersionDef & versions)425 void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }
426 
AddNode(NodeDef node_def,Status * status)427 Node* Graph::AddNode(NodeDef node_def, Status* status) {
428   const OpRegistrationData* op_reg_data;
429   status->Update(ops_.LookUp(node_def.op(), &op_reg_data));
430   if (!status->ok()) return nullptr;
431 
432   DataTypeVector inputs;
433   DataTypeVector outputs;
434   status->Update(
435       InOutTypesForNode(node_def, op_reg_data->op_def, &inputs, &outputs));
436   if (!status->ok()) {
437     *status = AttachDef(*status, node_def);
438     return nullptr;
439   }
440 
441   Node* node = AllocateNode(
442       std::make_shared<NodeProperties>(&op_reg_data->op_def,
443                                        std::move(node_def), inputs, outputs),
444       nullptr, op_reg_data->is_function_op);
445   return node;
446 }
447 
// Adds a copy of `node`, which may come from another graph, to this graph.
// The copy initially shares node's NodeProperties, inherits its cost id and
// assigned device, and gets no edges.
Node* Graph::CopyNode(const Node* node) {
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());

  const bool is_function_op = node->class_ == Node::NC_FUNCTION_OP;
  Node* copy = AllocateNode(node->props_, node, is_function_op);
  copy->set_assigned_device_name(node->assigned_device_name());

  // Since the OpDef of a function may be owned by the Graph that owns 'node',
  // relookup the OpDef in the target graph. If it differs, then clone the
  // node properties with the updated OpDef.
  const OpDef* local_op_def;
  TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &local_op_def));
  if (local_op_def == node->props_->op_def) return copy;

  copy->MaybeCopyOnWrite();
  copy->props_->op_def = local_op_def;
  return copy;
}
467 
RemoveNode(Node * node)468 void Graph::RemoveNode(Node* node) {
469   TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
470   DCHECK(!node->IsSource());
471   DCHECK(!node->IsSink());
472 
473   // Remove any edges involving this node.
474   for (const Edge* e : node->in_edges_) {
475     CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
476     edges_[e->id_] = nullptr;
477     RecycleEdge(e);
478     --num_edges_;
479   }
480   node->in_edges_.clear();
481   for (const Edge* e : node->out_edges_) {
482     CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
483     edges_[e->id_] = nullptr;
484     RecycleEdge(e);
485     --num_edges_;
486   }
487   node->out_edges_.clear();
488   ReleaseNode(node);
489 }
490 
AddEdge(Node * source,int x,Node * dest,int y)491 const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
492   TF_DCHECK_OK(IsValidNode(source)) << source->DebugString();
493   TF_DCHECK_OK(IsValidNode(dest)) << dest->DebugString();
494 
495   // source/sink must only be linked via control slots, and
496   // control slots must only be linked to control slots.
497   if (source == source_node() || dest == sink_node() || x == kControlSlot ||
498       y == kControlSlot) {
499     DCHECK_EQ(x, kControlSlot) << source->DebugString();
500     DCHECK_EQ(y, kControlSlot) << dest->DebugString();
501   }
502 
503   Edge* e = nullptr;
504   if (free_edges_.empty()) {
505     e = new (arena_.Alloc(sizeof(Edge))) Edge;  // placement new
506   } else {
507     e = free_edges_.back();
508     free_edges_.pop_back();
509   }
510   e->id_ = edges_.size();
511   e->src_ = source;
512   e->dst_ = dest;
513   e->src_output_ = x;
514   e->dst_input_ = y;
515   CHECK(source->out_edges_.insert(e).second);
516   CHECK(dest->in_edges_.insert(e).second);
517   edges_.push_back(e);
518   ++num_edges_;
519   return e;
520 }
521 
RemoveEdge(const Edge * e)522 void Graph::RemoveEdge(const Edge* e) {
523   TF_DCHECK_OK(IsValidNode(e->src_)) << e->src_->DebugString();
524   TF_DCHECK_OK(IsValidNode(e->dst_)) << e->dst_->DebugString();
525   CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
526   CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
527   CHECK_EQ(e, edges_[e->id_]);
528   CHECK_GT(num_edges_, 0);
529 
530   edges_[e->id_] = nullptr;
531   RecycleEdge(e);
532   --num_edges_;
533 }
534 
RecycleEdge(const Edge * e)535 void Graph::RecycleEdge(const Edge* e) {
536   free_edges_.push_back(const_cast<Edge*>(e));
537 }
538 
AddControlEdge(Node * source,Node * dest,bool allow_duplicates)539 const Edge* Graph::AddControlEdge(Node* source, Node* dest,
540                                   bool allow_duplicates) {
541   if (!allow_duplicates) {
542     for (const Edge* edge : dest->in_edges()) {
543       if (edge->IsControlEdge() && edge->src() == source) {
544         // The requested edge already exists.
545         return nullptr;
546       }
547     }
548   }
549   // Modify dest's NodeDef if necessary.
550   if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) {
551     // Check if this input is already in dest's NodeDef.
552     const string new_input = strings::StrCat("^", source->name());
553     bool input_exists = false;
554     for (const string& input : dest->props_->node_def.input()) {
555       if (input == new_input) {
556         input_exists = true;
557         break;
558       }
559     }
560     if (!input_exists) {
561       dest->MaybeCopyOnWrite();
562       dest->props_->node_def.add_input(new_input);
563     }
564   }
565   return AddEdge(source, kControlSlot, dest, kControlSlot);
566 }
567 
RemoveControlEdge(const Edge * e)568 void Graph::RemoveControlEdge(const Edge* e) {
569   if (!e->src_->IsSource() && !e->dst_->IsSink()) {
570     e->dst_->MaybeCopyOnWrite();
571     string e_src_name = strings::StrCat("^", e->src_->name());
572     auto* inputs = e->dst_->props_->node_def.mutable_input();
573     for (auto it = inputs->begin(); it != inputs->end(); ++it) {
574       if (*it == e_src_name) {
575         inputs->erase(it);
576         break;
577       }
578     }
579   }
580   RemoveEdge(e);
581 }
582 
583 namespace {
FindEdge(const Node * dst,int index)584 const Edge* FindEdge(const Node* dst, int index) {
585   for (const Edge* e : dst->in_edges()) {
586     if (e->dst_input() == index) return e;
587   }
588   return nullptr;
589 }
590 }  // namespace
591 
UpdateEdge(Node * new_src,int new_src_index,Node * dst,int dst_index)592 Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
593                          int dst_index) {
594   TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
595   TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
596   const Edge* e = FindEdge(dst, dst_index);
597   if (e == nullptr) {
598     return errors::InvalidArgument("Couldn't find edge to ",
599                                    FormatNodeForError(*dst));
600   }
601   RemoveEdge(e);
602   AddEdge(new_src, new_src_index, dst, dst_index);
603   dst->MaybeCopyOnWrite();
604   (*dst->props_->node_def.mutable_input())[dst_index] =
605       strings::StrCat(new_src->name(), ":", new_src_index);
606   return Status::OK();
607 }
608 
AddWhileInputHack(Node * new_src,int new_src_index,Node * dst)609 Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) {
610   if (!dst->IsWhileNode()) {
611     return errors::Internal(
612         "dst argument to AddWhileEdgeHack should be a While op, got: ",
613         dst->DebugString());
614   }
615   TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
616   // Find the current number of data inputs. We'll add the new edge to the next
617   // missing data input.
618   int dst_index = 0;
619   for (const Edge* edge : dst->in_edges()) {
620     if (edge->IsControlEdge()) continue;
621     ++dst_index;
622   }
623   TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
624   AddEdge(new_src, new_src_index, dst, dst_index);
625   dst->MaybeCopyOnWrite();
626   dst->props_->node_def.add_input(
627       strings::StrCat(new_src->name(), ":", new_src_index));
628   return Status::OK();
629 }
630 
AddFunctionLibrary(const FunctionDefLibrary & fdef_lib)631 Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
632   // Need a new-enough consumer to support the functions we add to the graph.
633   if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
634     versions_->set_min_consumer(12);
635   }
636   return ops_.AddLibrary(fdef_lib);
637 }
638 
639 namespace {
640 
AddInput(NodeDef * dst,StringPiece src_name,int src_slot)641 void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
642   if (src_slot == Graph::kControlSlot) {
643     dst->add_input(strings::StrCat("^", src_name));
644   } else if (src_slot == 0) {
645     dst->add_input(src_name.data(), src_name.size());
646   } else {
647     dst->add_input(strings::StrCat(src_name, ":", src_slot));
648   }
649 }
650 
651 }  // namespace
652 
ToGraphDef(GraphDef * graph_def) const653 void Graph::ToGraphDef(GraphDef* graph_def) const {
654   ToGraphDefSubRange(graph_def, 0);
655 }
656 
ToGraphDefDebug() const657 GraphDef Graph::ToGraphDefDebug() const {
658   GraphDef ret;
659   ToGraphDef(&ret);
660   return ret;
661 }
662 
ToGraphDefSubRange(GraphDef * graph_def,int from_node_id) const663 void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
664   graph_def->Clear();
665   *graph_def->mutable_versions() = versions();
666   *graph_def->mutable_library() = ops_.ToProto();
667 
668   graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id));
669 
670   std::vector<const Edge*>
671       inputs;  // Construct this outside the loop for speed.
672   for (auto id = from_node_id; id < num_node_ids(); ++id) {
673     const Node* node = FindNodeId(id);
674     if (node == nullptr || !node->IsOp()) continue;
675     NodeDef* node_def = graph_def->add_node();
676     *node_def = node->def();
677 
678     // Use the node's assigned device, if any, instead of the device requested
679     // in the NodeDef.
680     if (!node->assigned_device_name().empty()) {
681       node_def->set_device(node->assigned_device_name());
682     }
683 
684     // Get the inputs for this Node.  We make sure control inputs are
685     // after data inputs, as required by GraphDef.
686     inputs.clear();
687     inputs.resize(node->num_inputs(), nullptr);
688     for (const Edge* edge : node->in_edges()) {
689       if (edge->IsControlEdge()) {
690         inputs.push_back(edge);
691       } else {
692         DCHECK(edge->dst_input() < inputs.size())
693             << "Edge " << edge->DebugString()
694             << " is overflowing the expected number of inputs ("
695             << node->num_inputs() << ") for node " << node->DebugString();
696         CHECK(inputs[edge->dst_input()] == nullptr)
697             << "Edge " << edge->src()->name() << "->" << edge->dst()->name()
698             << " conflicts with pre-existing input edge "
699             << inputs[edge->dst_input()]->src()->name() << "->"
700             << inputs[edge->dst_input()]->dst()->name();
701 
702         inputs[edge->dst_input()] = edge;
703       }
704     }
705     // Sort the control inputs for more predictable serialization.
706     std::sort(inputs.begin() + node->num_inputs(), inputs.end(),
707               [](const Edge* a, const Edge* b) -> bool {
708                 return a->src()->name() < b->src()->name();
709               });
710     node_def->clear_input();
711     node_def->mutable_input()->Reserve(inputs.size());
712 
713     for (size_t i = 0; i < inputs.size(); ++i) {
714       const Edge* edge = inputs[i];
715       if (edge == nullptr) {
716         if (i < node->requested_inputs().size()) {
717           node_def->add_input(node->requested_inputs()[i]);
718         } else {
719           node_def->add_input("");
720         }
721       } else {
722         const Node* src = edge->src();
723         if (!src->IsOp()) continue;
724         AddInput(node_def, src->name(), edge->src_output());
725       }
726     }
727   }
728 }
729 
NewName(StringPiece prefix)730 string Graph::NewName(StringPiece prefix) {
731   return strings::StrCat(prefix, "/_", name_counter_++);
732 }
733 
IsValidNode(const Node * node) const734 Status Graph::IsValidNode(const Node* node) const {
735   if (node == nullptr) {
736     return errors::InvalidArgument("Node is null");
737   }
738   const int id = node->id();
739   if (id < 0) {
740     return errors::InvalidArgument("node id ", id, " is less than zero");
741   }
742   if (static_cast<size_t>(id) >= nodes_.size()) {
743     return errors::InvalidArgument(
744         "node id ", id, " is >= than number of nodes in graph ", nodes_.size());
745   }
746   if (nodes_[id] != node) {
747     return errors::InvalidArgument("Node with id ", id,
748                                    " is different from the passed in node. "
749                                    "Does it belong to a different graph?");
750   }
751   return Status::OK();
752 }
753 
IsValidOutputTensor(const Node * node,int idx) const754 Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
755   TF_RETURN_IF_ERROR(IsValidNode(node));
756   if (idx >= node->num_outputs() || idx < 0) {
757     return errors::OutOfRange("Node '", node->name(), "' (type: '",
758                               node->op_def().name(),
759                               "', num of outputs: ", node->num_outputs(),
760                               ") does not have ", "output ", idx);
761   }
762   return Status::OK();
763 }
764 
IsValidInputTensor(const Node * node,int idx) const765 Status Graph::IsValidInputTensor(const Node* node, int idx) const {
766   TF_RETURN_IF_ERROR(IsValidNode(node));
767   if (idx >= node->num_inputs() || idx < 0) {
768     return errors::OutOfRange("Node '", node->name(), "' (type: '",
769                               node->op_def().name(),
770                               "', num of inputs: ", node->num_inputs(),
771                               ") does not have ", "input ", idx);
772   }
773   return Status::OK();
774 }
775 
AllocateNode(std::shared_ptr<NodeProperties> props,const Node * cost_node,bool is_function_op)776 Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
777                           const Node* cost_node, bool is_function_op) {
778   Node* node = nullptr;
779   if (free_nodes_.empty()) {
780     node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new
781   } else {
782     node = free_nodes_.back();
783     free_nodes_.pop_back();
784   }
785   node->graph_ = this;
786   const int id = nodes_.size();
787   int cost_id = cost_node ? cost_node->cost_id() : id;
788   node->Initialize(id, cost_id, std::move(props), is_function_op);
789   nodes_.push_back(node);
790   ++num_nodes_;
791   return node;
792 }
793 
ReleaseNode(Node * node)794 void Graph::ReleaseNode(Node* node) {
795   TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
796   nodes_[node->id()] = nullptr;
797   free_nodes_.push_back(node);
798   --num_nodes_;
799   node->Clear();
800 }
801 
802 // Ensures that 'device_name' is present in the device name table, and returns
803 // the index of that device name. The index is stable, and can be used in
804 // calls to Node::set_assigned_device_name_index().
InternDeviceName(const string & device_name)805 int Graph::InternDeviceName(const string& device_name) {
806   // Special case, very common.  Also, this allows us to use a single map
807   // lookup below, instead of two.  The 'if (index_cell > 0)' test below
808   // relies on this check.
809   if (device_name.empty()) {
810     return 0;
811   }
812 
813   int& index_cell = device_names_map_[device_name];
814   if (index_cell > 0) {
815     return index_cell;
816   }
817 
818   const int index = device_names_map_.size();
819   index_cell = index;
820   device_names_.push_back(device_name);
821   return index;
822 }
823 
AddWhileContext(StringPiece frame_name,std::vector<Node * > enter_nodes,std::vector<Node * > exit_nodes,OutputTensor cond_output,std::vector<OutputTensor> body_inputs,std::vector<OutputTensor> body_outputs,WhileContext ** result)824 Status Graph::AddWhileContext(StringPiece frame_name,
825                               std::vector<Node*> enter_nodes,
826                               std::vector<Node*> exit_nodes,
827                               OutputTensor cond_output,
828                               std::vector<OutputTensor> body_inputs,
829                               std::vector<OutputTensor> body_outputs,
830                               WhileContext** result) {
831   auto pair = while_ctxs_.insert(std::pair<string, WhileContext>(
832       string(frame_name),
833       WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
834                    cond_output, std::move(body_inputs),
835                    std::move(body_outputs))));
836   if (!pair.second) {
837     *result = nullptr;
838     return errors::InvalidArgument("WhileContext with frame name '", frame_name,
839                                    "' already exists");
840   }
841   *result = &pair.first->second;
842   return Status::OK();
843 }
844 
BuildNodeNameIndex() const845 std::unordered_map<string, Node*> Graph::BuildNodeNameIndex() const {
846   std::unordered_map<string, Node*> result;
847   for (Node* n : nodes()) {
848     result[n->name()] = n;
849   }
850   return result;
851 }
852 
DebugString() const853 string Edge::DebugString() const {
854   return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
855                          src_output_, dst_->name().c_str(), dst_input_);
856 }
857 
858 }  // namespace tensorflow
859