• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_H_
18 
19 #include <functional>
20 #include <iterator>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/container/node_hash_map.h"
26 #include "absl/strings/string_view.h"
27 #include "absl/types/span.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/node_def.pb.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/graph/tensor_id.h"
33 #include "tensorflow/core/lib/core/status.h"
34 #include "tensorflow/core/lib/core/stringpiece.h"
35 #include "tensorflow/core/lib/core/threadpool.h"
36 #include "tensorflow/core/lib/gtl/flatmap.h"
37 #include "tensorflow/core/lib/gtl/flatset.h"
38 #include "tensorflow/core/lib/gtl/inlined_vector.h"
39 #include "tensorflow/core/platform/types.h"
40 
41 namespace tensorflow {
42 namespace grappler {
43 
44 // Utilities for manipulating node name and input strings.
45 
46 // Returns the trailing position number (or zero if no number is present) if
47 // NodeName(input_name) is equal to node_name. Returns -1 for control inputs.
48 // Returns -2 if input_name is empty or NodeName(input_name) is not equal to
49 // node_name.
NodePositionIfSameNode(absl::string_view input_name,absl::string_view node_name)50 inline int NodePositionIfSameNode(absl::string_view input_name,
51                                   absl::string_view node_name) {
52   bool is_control = absl::StartsWith(input_name, "^");
53   if (is_control) input_name.remove_prefix(1);
54   if (input_name.empty() || node_name.empty() ||
55       input_name.size() < node_name.size()) {
56     return -2;
57   }
58   TensorId id = ParseTensorName(input_name);
59   if (id.first != node_name) return -2;
60   if (is_control) return -1;
61   return id.second;
62 }
63 
64 // Returns the node name and position in a single call.
ParseNodeNameAsStringPiece(absl::string_view name,int * position)65 inline StringPiece ParseNodeNameAsStringPiece(absl::string_view name,
66                                               int* position) {
67   const bool is_control = absl::StartsWith(name, "^");
68   TensorId id = ParseTensorName(name);
69   if (position) {
70     *position = is_control ? -1 : id.second;
71   }
72   if (is_control && id.second >= 0) {
73     id.first.remove_prefix(1);
74   }
75   return id.first;
76 }
77 
78 // Returns the node name and position in a single call.
ParseNodeName(const string & name,int * position)79 inline string ParseNodeName(const string& name, int* position) {
80   return string(ParseNodeNameAsStringPiece(name, position));
81 }
82 
83 // Return the node name corresponding to 'name' if name is valid, or the empty
84 // string otherwise.
NodeNameAsStringPiece(const string & name)85 inline StringPiece NodeNameAsStringPiece(const string& name) {
86   return ParseNodeNameAsStringPiece(name, nullptr);
87 }
88 
89 // Return the node name corresponding to 'name' if name is valid, or the empty
90 // string otherwise.
NodeName(const string & name)91 inline string NodeName(const string& name) {
92   return string(NodeNameAsStringPiece(name));
93 }
94 
NodePosition(const string & name)95 inline int NodePosition(const string& name) {
96   int position;
97   ParseNodeNameAsStringPiece(name, &position);
98   return position;
99 }
100 
101 namespace internal {
102 // Base template class for NodeMap and ImmutableNodeMap.
103 template <typename GraphDefT, typename NodeDefT>
104 class NodeMapInternal {
105  public:
106   // Note: The NodeMap will store pointers to nodes in graph, which may become
107   // invalid if graph is changed.
NodeMapInternal(GraphDefT * graph)108   explicit NodeMapInternal(GraphDefT* graph) {
109     if (graph == nullptr) {
110       LOG(WARNING) << "NodeMapInternal constructor is called with a nullptr!";
111       return;
112     }
113     nodes_.reserve(graph->node_size());
114     outputs_.reserve(graph->node_size());
115     for (int i = 0; i < graph->node_size(); i++) {
116       NodeDefT* node = GetNodeDefFromGraph(graph, i);
117       const string& node_name = node->name();
118       auto rslt = nodes_.emplace(node_name, node);
119       // Check that the graph doesn't contain multiple nodes with the same name.
120       if (!rslt.second) {
121         // The first node found with a given name becomes the canonical.
122         LOG(WARNING) << "Duplicated node in the graph: " << node_name;
123       }
124       NodeDefT* canonical = rslt.second ? node : rslt.first->second;
125       for (const auto& input : node->input()) {
126         outputs_[NodeName(input)].insert(canonical);
127       }
128     }
129   }
130 
131   // Get unordered list of fanouts from node. Notice, that the order is
132   // non-deterministic.
GetOutputs(const string & node_name)133   const absl::flat_hash_set<NodeDefT*>& GetOutputs(
134       const string& node_name) const {
135     auto it = outputs_.find(node_name);
136     if (it == outputs_.end()) {
137       return empty_set_;
138     }
139     return it->second;
140   }
141 
142   // Get fanouts ordered by name.
GetOutputsOrderedByNodeName(const string & node_name)143   std::vector<NodeDefT*> GetOutputsOrderedByNodeName(
144       const string& node_name) const {
145     std::vector<NodeDefT*> result;
146     auto it = outputs_.find(node_name);
147     if (it != outputs_.end()) {
148       const absl::flat_hash_set<NodeDefT*>& outputs = it->second;
149       result.reserve(outputs.size());
150       result.assign(outputs.begin(), outputs.end());
151       std::sort(result.begin(), result.end(),
152                 [](const NodeDef* n1, const NodeDef* n2) {
153                   return n1->name() < n2->name();
154                 });
155     }
156     return result;
157   }
158 
159   // This method doesn't record the outputs of the added node; the outputs need
160   // to be explicitly added by the AddOutput method.
AddNode(const string & node_name,NodeDefT * node)161   void AddNode(const string& node_name, NodeDefT* node) {
162     DCHECK(node != nullptr);
163     auto ret = nodes_.emplace(node_name, node);
164     DCHECK(ret.second)
165         << "Pair (" << node_name << "," << node
166         << ") is not inserted because the same key already exists.";
167   }
168 
RemoveNode(const string & name)169   void RemoveNode(const string& name) {
170     nodes_.erase(NodeName(name));
171     outputs_.erase(NodeName(name));
172   }
173 
GetNode(const string & name)174   NodeDefT* GetNode(const string& name) const {
175     const string node_name = NodeName(name);
176     auto it = nodes_.find(node_name);
177     if (it == nodes_.end()) {
178       VLOG(1) << "Node could not be found: " << name;
179       return nullptr;
180     }
181     return it->second;
182   }
183 
NodeExists(const string & name)184   bool NodeExists(const string& name) const {
185     const string node_name = NodeName(name);
186     return nodes_.find(node_name) != nodes_.end();
187   }
188 
AddOutput(const string & node_name,const string & output_name)189   void AddOutput(const string& node_name, const string& output_name) {
190     auto output_node = nodes_[NodeName(output_name)];
191     DCHECK(output_node) << "Output node " << output_name
192                         << " is missing in NodeMap.";
193     outputs_[node_name].insert(output_node);
194   }
195 
RemoveOutput(const string & node_name,const string & output_name)196   void RemoveOutput(const string& node_name, const string& output_name) {
197     outputs_[node_name].erase(nodes_[NodeName(output_name)]);
198   }
199 
UpdateInput(const string & node_name,const string & old_input_name,const string & new_input_name)200   void UpdateInput(const string& node_name, const string& old_input_name,
201                    const string& new_input_name) {
202     RemoveOutput(NodeName(old_input_name), node_name);
203     AddOutput(NodeName(new_input_name), node_name);
204   }
205 
RemoveInputs(const string & node_name)206   void RemoveInputs(const string& node_name) {
207     auto node = nodes_[node_name];
208     for (const auto& input : node->input()) {
209       RemoveOutput(NodeName(input), node->name());
210     }
211   }
212 
RemoveOutputs(const string & node_name)213   void RemoveOutputs(const string& node_name) { outputs_.erase(node_name); }
214 
UpdateOutput(const string & node_name,const string & old_output_name,const string & new_output_name)215   void UpdateOutput(const string& node_name, const string& old_output_name,
216                     const string& new_output_name) {
217     absl::flat_hash_set<NodeDef*>& outputs = outputs_[node_name];
218     outputs.erase(nodes_[NodeName(old_output_name)]);
219     outputs.insert(nodes_[NodeName(new_output_name)]);
220   }
221 
222  private:
223   // Helper method to get the NodeDef pointer of i-th node in a graph.
224   inline NodeDefT* GetNodeDefFromGraph(GraphDefT* graph, int64_t i) const;
225 
226   const absl::flat_hash_set<NodeDefT*> empty_set_;
227   absl::node_hash_map<string, NodeDefT*> nodes_;
228   absl::node_hash_map<string, absl::flat_hash_set<NodeDefT*>> outputs_;
229 };
230 
231 // Specialized template class method GetNodeDefFromGraph.
232 template <>
GetNodeDefFromGraph(GraphDef * graph,int64_t i)233 inline NodeDef* NodeMapInternal<GraphDef, NodeDef>::GetNodeDefFromGraph(
234     GraphDef* graph, int64_t i) const {
235   return graph->mutable_node(i);
236 }
237 
238 template <>
239 inline const NodeDef*
GetNodeDefFromGraph(const GraphDef * graph,int64_t i)240 NodeMapInternal<const GraphDef, const NodeDef>::GetNodeDefFromGraph(
241     const GraphDef* graph, int64_t i) const {
242   return &graph->node(i);
243 }
244 }  // namespace internal
245 
246 // A utility class to lookup a node and its outputs by node name.
247 class NodeMap : public internal::NodeMapInternal<GraphDef, NodeDef> {
248  public:
NodeMap(GraphDef * graph)249   explicit NodeMap(GraphDef* graph) : NodeMapInternal(graph) {}
250 };
251 
252 // Same to NodeMap, but uses const GraphDef.
253 class ImmutableNodeMap
254     : public internal::NodeMapInternal<const GraphDef, const NodeDef> {
255  public:
ImmutableNodeMap(const GraphDef * graph)256   explicit ImmutableNodeMap(const GraphDef* graph) : NodeMapInternal(graph) {}
257 };
258 
259 // A vector with a set. The set stores the same elements as the vector, and
260 // quickly answers whether a value is in the vector. Duplicated elements are not
261 // allowed for now.
262 template <class T, class Hash = std::hash<T>>
263 class SetVector {
264  public:
265   // Returns false if value already existed in the set, true otherwise.
PushBack(const T & value)266   bool PushBack(const T& value) {
267     if (!set_.insert(value).second) {
268       return false;
269     }
270     vector_.push_back(value);
271     return true;
272   }
273 
PopBack()274   T PopBack() {
275     T back = vector_.back();
276     set_.erase(back);
277     vector_.pop_back();
278     return back;
279   }
280 
Exists(const T & value)281   bool Exists(const T& value) const { return set_.find(value) != set_.end(); }
282 
Empty()283   bool Empty() const { return vector_.empty(); }
284 
Reserve(int64_t size)285   void Reserve(int64_t size) { vector_.reserve(size); }
286 
287  private:
288   gtl::FlatSet<T, Hash> set_;
289   std::vector<T> vector_;
290 };
291 
292 // Returns formatted string from TensorId specific to grappler. Specifically,
293 // for the 0 port (first output), only the node name is returned.
294 string TensorIdToString(const TensorId& tensor_id);
295 
296 // Returns formatted string from SafeTensorId specific to grappler.
297 // Specifically, for the 0 port (first output), only the node name is returned.
298 string SafeTensorIdToString(const SafeTensorId& tensor_id);
299 
300 // True iff 'name' refers to a control inputs, i.e. a node name prefixed with
301 // the ^ character.
302 bool IsControlInput(absl::string_view name);
303 
304 // True iff tensor index refers to a control input.
305 bool IsControlInput(const TensorId& tensor_id);
306 
307 // True iff 'name1' and 'name2' refer to the same input.
308 bool IsSameInput(const string& name1, const string& name2);
309 
310 
311 // Add a prefix to a node name with a custom delimiter.
312 string AddPrefixToNodeName(const string& name, const string& prefix,
313                            const string& delimiter);
314 
315 // Add a prefix to a node name.
316 string AddPrefixToNodeName(const string& name, const string& prefix);
317 
318 // Executes a 'fn' in the 'thread_pool'. The method waits for the configured
319 // timeout (in milliseconds) for 'fn' to complete, before returning false.
320 //
321 // If returning false, the 'fn' may still continue to execute in the
322 // thread-pool. It is the responsibility of the caller to reset the thread-pool
323 // as appropriate.
324 bool ExecuteWithTimeout(std::function<void()> fn, int64_t timeout_in_ms,
325                         thread::ThreadPool* thread_pool);
326 
327 // Returns the node name prefixed with conventional symbol '^'
328 // for control dependency, given a NodeDef.
329 string AsControlDependency(const NodeDef& node);
330 
331 // Returns the node name prefixed with conventional symbol '^'
332 // for control dependency, given a node name
333 string AsControlDependency(const string& node);
334 
335 // Returns true if the node is assigned to run on CPU device.
336 bool NodeIsOnCpu(const NodeDef* node);
337 
338 // Returns true if the node is assigned to run on GPU device.
339 bool NodeIsOnGpu(const NodeDef* node);
340 
341 // Returns the number of outputs of a node according to its OpDef. Note that
342 // some of the outputs may be unconnected.
343 int NumOutputs(const NodeDef& node, GraphDef* graph);
344 
345 // Returns true iff the node has at least one control input.
346 bool HasControlInputs(const NodeDef& node);
347 
348 // Returns true iff the node has at least one regular input.
349 bool HasRegularInputs(const NodeDef& node);
350 
351 // Returns true iff the node has at least one regular output.
352 bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map);
353 
354 // Returns true iff the node has at least one control output.
355 bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map);
356 
357 // Number of connected control inputs.
358 int NumControlInputs(const NodeDef& node);
359 
360 // Number of connected non-control inputs.
361 int NumNonControlInputs(const NodeDef& node);
362 
363 // Number of connected control outputs.
364 int NumControlOutputs(const NodeDef& node, const NodeMap& node_map);
365 
366 // Number of connected non-control outputs.
367 int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);
368 
369 // Number of connected non-control data outputs (Ops that consume output tensor
370 // data, not just it's shape).
371 int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map);
372 
373 // Removes redundant control inputs from node.
374 void DedupControlInputs(NodeDef* node);
375 
376 // Returns an error if an attribute with the given key does not exist in node.
377 Status CheckAttrExists(const NodeDef& node, const string& key);
378 
379 // Returns an error if attributes with the given keys do not exist in node.
380 Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys);
381 
382 // Returns the data type in attribute `attr_name` of `node`. If that attribute
383 // doesn't exist, returns DT_INVALID.
384 DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr);
385 
386 // Returns the last node in the simple chain starting at source and traversing
387 // through the input(0) edge from each node as long as the next node satisfies
388 // the predicate given in pred_fn. If no nodes satisfy the predicate, &source
389 // will be returned. Example: For the chain
390 //    source <- a <- b <- ... <- y <- z
391 // where
392 //    pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true,
393 //    pred_fn(z) = false,
394 // the return value will be a pointer to y.
395 NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
396                         bool follow_control_input,
397                         const std::function<bool(const NodeDef&)>& pred_fn);
398 
399 // Permute the nodes of graph in place according to the permutation.
400 void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
401                          bool invert_permutation);
402 
403 // Returns Status::OK() if a kernel is registered for node.op() on the device
404 // type corresponding to node.device().
405 Status IsKernelRegisteredForNode(
406     absl::string_view node_name, bool has_experimental_debug_info,
407     const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
408     absl::string_view node_op, absl::string_view node_device,
409     AttrSlice node_attrs);
410 Status IsKernelRegisteredForNode(const NodeDef& node);
411 
412 Status SetTensorValue(DataType dtype, int value, Tensor* tensor);
413 
414 void EraseNodesFromGraph(const std::set<int>& nodes_to_delete, GraphDef* graph);
415 
416 void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph);
417 
418 void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
419                          GraphDef* graph);
420 
421 // Erase all attributes without leading underscore. Returns the number of
422 // attributes erased.
423 int EraseRegularNodeAttributes(NodeDef* node);
424 
425 // Erase attribute "_xla_inferred_shapes" as well as all attributes starting in
426 // "_output_".
427 int EraseNodeOutputAttributes(NodeDef* node);
428 
429 }  // end namespace grappler
430 }  // end namespace tensorflow
431 
432 #endif  // TENSORFLOW_CORE_GRAPPLER_UTILS_H_
433