• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_H_
18 
19 #include <functional>
20 #include <iterator>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/container/node_hash_map.h"
26 #include "absl/strings/string_view.h"
27 #include "absl/types/span.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/node_def.pb.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/graph/tensor_id.h"
33 #include "tensorflow/core/lib/core/status.h"
34 #include "tensorflow/core/lib/core/stringpiece.h"
35 #include "tensorflow/core/lib/core/threadpool.h"
36 #include "tensorflow/core/lib/gtl/flatmap.h"
37 #include "tensorflow/core/lib/gtl/flatset.h"
38 #include "tensorflow/core/lib/gtl/inlined_vector.h"
39 #include "tensorflow/core/platform/types.h"
40 
41 namespace tensorflow {
42 namespace grappler {
43 
44 // Utilities for manipulating node name and input strings.
45 
46 // Returns the trailing position number (or zero if no number is present) if
47 // NodeName(input_name) is equal to node_name. Returns -1 for control inputs.
48 // Returns -2 if input_name is empty or NodeName(input_name) is not equal to
49 // node_name.
NodePositionIfSameNode(absl::string_view input_name,absl::string_view node_name)50 inline int NodePositionIfSameNode(absl::string_view input_name,
51                                   absl::string_view node_name) {
52   bool is_control = absl::StartsWith(input_name, "^");
53   if (is_control) input_name.remove_prefix(1);
54   if (input_name.empty() || node_name.empty() ||
55       input_name.size() < node_name.size()) {
56     return -2;
57   }
58   TensorId id = ParseTensorName(input_name);
59   if (id.first != node_name) return -2;
60   if (is_control) return -1;
61   return id.second;
62 }
63 
64 // Returns the node name and position in a single call.
ParseNodeNameAsStringPiece(absl::string_view name,int * position)65 inline StringPiece ParseNodeNameAsStringPiece(absl::string_view name,
66                                               int* position) {
67   const bool is_control = absl::StartsWith(name, "^");
68   TensorId id = ParseTensorName(name);
69   if (position) {
70     *position = is_control ? -1 : id.second;
71   }
72   if (is_control && id.second >= 0) {
73     id.first.remove_prefix(1);
74   }
75   return id.first;
76 }
77 
78 // Returns the node name and position in a single call.
ParseNodeName(const string & name,int * position)79 inline string ParseNodeName(const string& name, int* position) {
80   return string(ParseNodeNameAsStringPiece(name, position));
81 }
82 
83 // Return the node name corresponding to 'name' if name is valid, or the empty
84 // string otherwise.
NodeNameAsStringPiece(const string & name)85 inline StringPiece NodeNameAsStringPiece(const string& name) {
86   return ParseNodeNameAsStringPiece(name, nullptr);
87 }
88 
89 // Return the node name corresponding to 'name' if name is valid, or the empty
90 // string otherwise.
NodeName(const string & name)91 inline string NodeName(const string& name) {
92   return string(NodeNameAsStringPiece(name));
93 }
94 
NodePosition(const string & name)95 inline int NodePosition(const string& name) {
96   int position;
97   ParseNodeNameAsStringPiece(name, &position);
98   return position;
99 }
100 
101 namespace internal {
102 // Base template class for NodeMap and ImmutableNodeMap.
103 template <typename GraphDefT, typename NodeDefT>
104 class NodeMapInternal {
105  public:
106   // Note: The NodeMap will store pointers to nodes in graph, which may become
107   // invalid if graph is changed.
NodeMapInternal(GraphDefT * graph)108   explicit NodeMapInternal(GraphDefT* graph) {
109     if (graph == nullptr) {
110       LOG(WARNING) << "NodeMapInternal constructor is called with a nullptr!";
111       return;
112     }
113     nodes_.reserve(graph->node_size());
114     outputs_.reserve(graph->node_size());
115     for (int i = 0; i < graph->node_size(); i++) {
116       NodeDefT* node = GetNodeDefFromGraph(graph, i);
117       const string& node_name = node->name();
118       auto rslt = nodes_.emplace(node_name, node);
119       // Check that the graph doesn't contain multiple nodes with the same name.
120       if (!rslt.second) {
121         // The first node found with a given name becomes the canonical.
122         LOG(WARNING) << "Duplicated node in the graph: " << node_name;
123       }
124       NodeDefT* canonical = rslt.second ? node : rslt.first->second;
125       for (const auto& input : node->input()) {
126         outputs_[NodeName(input)].insert(canonical);
127       }
128     }
129   }
130 
131   // Get unordered list of fanouts from node. Notice, that the order is
132   // non-deterministic.
GetOutputs(const string & node_name)133   const absl::flat_hash_set<NodeDefT*>& GetOutputs(
134       const string& node_name) const {
135     auto it = outputs_.find(node_name);
136     if (it == outputs_.end()) {
137       return empty_set_;
138     }
139     return it->second;
140   }
141 
142   // Get fanouts ordered by name.
GetOutputsOrderedByNodeName(const string & node_name)143   std::vector<NodeDefT*> GetOutputsOrderedByNodeName(
144       const string& node_name) const {
145     std::vector<NodeDefT*> result;
146     auto it = outputs_.find(node_name);
147     if (it != outputs_.end()) {
148       const absl::flat_hash_set<NodeDefT*>& outputs = it->second;
149       result.reserve(outputs.size());
150       result.assign(outputs.begin(), outputs.end());
151       std::sort(result.begin(), result.end(),
152                 [](const NodeDef* n1, const NodeDef* n2) {
153                   return n1->name() < n2->name();
154                 });
155     }
156     return result;
157   }
158 
159   // This method doesn't record the outputs of the added node; the outputs need
160   // to be explicitly added by the AddOutput method.
AddNode(const string & node_name,NodeDefT * node)161   void AddNode(const string& node_name, NodeDefT* node) {
162     DCHECK(node != nullptr);
163     auto ret = nodes_.emplace(node_name, node);
164     DCHECK(ret.second)
165         << "Pair (" << node_name << "," << node
166         << ") is not inserted because the same key already exists.";
167   }
168 
RemoveNode(const string & name)169   void RemoveNode(const string& name) {
170     nodes_.erase(NodeName(name));
171     outputs_.erase(NodeName(name));
172   }
173 
GetNode(const string & name)174   NodeDefT* GetNode(const string& name) const {
175     const string node_name = NodeName(name);
176     auto it = nodes_.find(node_name);
177     if (it == nodes_.end()) {
178       VLOG(1) << "Node could not be found: " << name;
179       return nullptr;
180     }
181     return it->second;
182   }
183 
NodeExists(const string & name)184   bool NodeExists(const string& name) const {
185     const string node_name = NodeName(name);
186     return nodes_.find(node_name) != nodes_.end();
187   }
188 
AddOutput(const string & node_name,const string & output_name)189   void AddOutput(const string& node_name, const string& output_name) {
190     auto output_node = nodes_[NodeName(output_name)];
191     DCHECK(output_node) << "Output node " << output_name
192                         << " is missing in NodeMap.";
193     outputs_[node_name].insert(output_node);
194   }
195 
RemoveOutput(const string & node_name,const string & output_name)196   void RemoveOutput(const string& node_name, const string& output_name) {
197     outputs_[node_name].erase(nodes_[NodeName(output_name)]);
198   }
199 
UpdateInput(const string & node_name,const string & old_input_name,const string & new_input_name)200   void UpdateInput(const string& node_name, const string& old_input_name,
201                    const string& new_input_name) {
202     RemoveOutput(NodeName(old_input_name), node_name);
203     AddOutput(NodeName(new_input_name), node_name);
204   }
205 
RemoveInputs(const string & node_name)206   void RemoveInputs(const string& node_name) {
207     auto node = nodes_[node_name];
208     for (const auto& input : node->input()) {
209       RemoveOutput(NodeName(input), node->name());
210     }
211   }
212 
RemoveOutputs(const string & node_name)213   void RemoveOutputs(const string& node_name) { outputs_.erase(node_name); }
214 
UpdateOutput(const string & node_name,const string & old_output_name,const string & new_output_name)215   void UpdateOutput(const string& node_name, const string& old_output_name,
216                     const string& new_output_name) {
217     absl::flat_hash_set<NodeDef*>& outputs = outputs_[node_name];
218     outputs.erase(nodes_[NodeName(old_output_name)]);
219     outputs.insert(nodes_[NodeName(new_output_name)]);
220   }
221 
222  private:
223   // Helper method to get the NodeDef pointer of i-th node in a graph.
224   NodeDefT* GetNodeDefFromGraph(GraphDefT* graph, int64 i) const;
225 
226   const absl::flat_hash_set<NodeDefT*> empty_set_;
227   absl::node_hash_map<string, NodeDefT*> nodes_;
228   absl::node_hash_map<string, absl::flat_hash_set<NodeDefT*>> outputs_;
229 };
230 }  // namespace internal
231 
232 // A utility class to lookup a node and its outputs by node name.
233 class NodeMap : public internal::NodeMapInternal<GraphDef, NodeDef> {
234  public:
NodeMap(GraphDef * graph)235   explicit NodeMap(GraphDef* graph) : NodeMapInternal(graph) {}
236 };
237 
238 // Same to NodeMap, but uses const GraphDef.
239 class ImmutableNodeMap
240     : public internal::NodeMapInternal<const GraphDef, const NodeDef> {
241  public:
ImmutableNodeMap(const GraphDef * graph)242   explicit ImmutableNodeMap(const GraphDef* graph) : NodeMapInternal(graph) {}
243 };
244 
245 // A vector with a set. The set stores the same elements as the vector, and
246 // quickly answers whether a value is in the vector. Duplicated elements are not
247 // allowed for now.
248 template <class T, class Hash = std::hash<T>>
249 class SetVector {
250  public:
251   // Returns false if value already existed in the set, true otherwise.
PushBack(const T & value)252   bool PushBack(const T& value) {
253     if (!set_.insert(value).second) {
254       return false;
255     }
256     vector_.push_back(value);
257     return true;
258   }
259 
PopBack()260   T PopBack() {
261     T back = vector_.back();
262     set_.erase(back);
263     vector_.pop_back();
264     return back;
265   }
266 
Exists(const T & value)267   bool Exists(const T& value) const { return set_.find(value) != set_.end(); }
268 
Empty()269   bool Empty() const { return vector_.empty(); }
270 
Reserve(int64 size)271   void Reserve(int64 size) { vector_.reserve(size); }
272 
273  private:
274   gtl::FlatSet<T, Hash> set_;
275   std::vector<T> vector_;
276 };
277 
278 // Returns formatted string from TensorId specific to grappler. Specifically,
279 // for the 0 port (first output), only the node name is returned.
280 string TensorIdToString(const TensorId& tensor_id);
281 
282 // Returns formatted string from SafeTensorId specific to grappler.
283 // Specifically, for the 0 port (first output), only the node name is returned.
284 string SafeTensorIdToString(const SafeTensorId& tensor_id);
285 
286 // True iff 'name' refers to a control input, i.e. a node name prefixed with
287 // the ^ character.
288 bool IsControlInput(const string& name);
289 
290 // True iff tensor index refers to a control input.
291 bool IsControlInput(const TensorId& tensor_id);
292 
293 // True iff 'name1' and 'name2' refer to the same input.
294 bool IsSameInput(const string& name1, const string& name2);
295 
296 
297 // Add a prefix to a node name with a custom delimiter.
298 string AddPrefixToNodeName(const string& name, const string& prefix,
299                            const string& delimiter);
300 
301 // Add a prefix to a node name.
302 string AddPrefixToNodeName(const string& name, const string& prefix);
303 
304 // Executes 'fn' in 'thread_pool' and waits up to 'timeout_in_ms' milliseconds
305 // for it to complete, returning false if it does not complete in time.
306 //
307 // If returning false, the 'fn' may still continue to execute in the
308 // thread-pool. It is the responsibility of the caller to reset the thread-pool
309 // as appropriate.
310 bool ExecuteWithTimeout(std::function<void()> fn, int64 timeout_in_ms,
311                         thread::ThreadPool* thread_pool);
312 
313 // Returns the node name prefixed with conventional symbol '^'
314 // for control dependency, given a NodeDef.
315 string AsControlDependency(const NodeDef& node);
316 
317 // Returns the node name prefixed with conventional symbol '^'
318 // for control dependency, given a node name
319 string AsControlDependency(const string& node);
320 
321 // Returns true if the node is assigned to run on CPU device.
322 bool NodeIsOnCpu(const NodeDef* node);
323 
324 // Returns true if the node is assigned to run on GPU device.
325 bool NodeIsOnGpu(const NodeDef* node);
326 
327 // Returns the number of outputs of a node according to its OpDef. Note that
328 // some of the outputs may be unconnected.
329 int NumOutputs(const NodeDef& node, GraphDef* graph);
330 
331 // Returns true iff the node has at least one control input.
332 bool HasControlInputs(const NodeDef& node);
333 
334 // Returns true iff the node has at least one regular input.
335 bool HasRegularInputs(const NodeDef& node);
336 
337 // Returns true iff the node has at least one regular output.
338 bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map);
339 
340 // Returns true iff the node has at least one control output.
341 bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map);
342 
343 // Number of connected control inputs.
344 int NumControlInputs(const NodeDef& node);
345 
346 // Number of connected non-control inputs.
347 int NumNonControlInputs(const NodeDef& node);
348 
349 // Number of connected control outputs.
350 int NumControlOutputs(const NodeDef& node, const NodeMap& node_map);
351 
352 // Number of connected non-control outputs.
353 int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);
354 
355 // Number of connected non-control data outputs (Ops that consume output tensor
356 // data, not just its shape).
357 int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map);
358 
359 // Removes redundant control inputs from node.
360 void DedupControlInputs(NodeDef* node);
361 
362 // Returns an error if an attribute with the given key does not exist in node.
363 Status CheckAttrExists(const NodeDef& node, const string& key);
364 
365 // Returns an error if attributes with the given keys do not exist in node.
366 Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys);
367 
368 // Returns the data type in attribute `type_attr` of `node`. If that attribute
369 // doesn't exist, returns DT_INVALID.
370 DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr);
371 
372 // Returns the last node in the simple chain starting at source and traversing
373 // through the input(0) edge from each node as long as the next node satisfies
374 // the predicate given in pred_fn. If no nodes satisfy the predicate, &source
375 // will be returned. Example: For the chain
376 //    source <- a <- b <- ... <- y <- z
377 // where
378 //    pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true,
379 //    pred_fn(z) = false,
380 // the return value will be a pointer to y.
381 NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
382                         bool follow_control_input,
383                         const std::function<bool(const NodeDef&)>& pred_fn);
384 
385 // Permute the nodes of graph in place according to the permutation.
386 void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
387                          bool invert_permutation);
388 
389 // Returns Status::OK() if a kernel is registered for node.op() on the device
390 // type corresponding to node.device().
391 Status IsKernelRegisteredForNode(
392     absl::string_view node_name, bool has_experimental_debug_info,
393     const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
394     absl::string_view node_op, absl::string_view node_device,
395     AttrSlice node_attrs);
396 Status IsKernelRegisteredForNode(const NodeDef& node);
397 
398 Status SetTensorValue(DataType dtype, int value, Tensor* tensor);
399 
400 void EraseNodesFromGraph(const std::set<int>& nodes_to_delete, GraphDef* graph);
401 
402 void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph);
403 
404 void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
405                          GraphDef* graph);
406 
407 // Erase all attributes without leading underscore. Returns the number of
408 // attributes erased.
409 int EraseRegularNodeAttributes(NodeDef* node);
410 
411 // Erase attribute "_xla_inferred_shapes" as well as all attributes starting with
412 // "_output_".
413 int EraseNodeOutputAttributes(NodeDef* node);
414 
415 }  // end namespace grappler
416 }  // end namespace tensorflow
417 
418 #endif  // TENSORFLOW_CORE_GRAPPLER_UTILS_H_
419