1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_H_
18
19 #include <functional>
20 #include <iterator>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/container/node_hash_map.h"
26 #include "absl/strings/string_view.h"
27 #include "absl/types/span.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/node_def.pb.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/graph/tensor_id.h"
33 #include "tensorflow/core/lib/core/status.h"
34 #include "tensorflow/core/lib/core/stringpiece.h"
35 #include "tensorflow/core/lib/core/threadpool.h"
36 #include "tensorflow/core/lib/gtl/flatmap.h"
37 #include "tensorflow/core/lib/gtl/flatset.h"
38 #include "tensorflow/core/lib/gtl/inlined_vector.h"
39 #include "tensorflow/core/platform/types.h"
40
41 namespace tensorflow {
42 namespace grappler {
43
44 // Utilities for manipulating node name and input strings.
45
46 // Returns the trailing position number (or zero if no number is present) if
47 // NodeName(input_name) is equal to node_name. Returns -1 for control inputs.
48 // Returns -2 if input_name is empty or NodeName(input_name) is not equal to
49 // node_name.
NodePositionIfSameNode(absl::string_view input_name,absl::string_view node_name)50 inline int NodePositionIfSameNode(absl::string_view input_name,
51 absl::string_view node_name) {
52 bool is_control = absl::StartsWith(input_name, "^");
53 if (is_control) input_name.remove_prefix(1);
54 if (input_name.empty() || node_name.empty() ||
55 input_name.size() < node_name.size()) {
56 return -2;
57 }
58 TensorId id = ParseTensorName(input_name);
59 if (id.first != node_name) return -2;
60 if (is_control) return -1;
61 return id.second;
62 }
63
64 // Returns the node name and position in a single call.
ParseNodeNameAsStringPiece(absl::string_view name,int * position)65 inline StringPiece ParseNodeNameAsStringPiece(absl::string_view name,
66 int* position) {
67 const bool is_control = absl::StartsWith(name, "^");
68 TensorId id = ParseTensorName(name);
69 if (position) {
70 *position = is_control ? -1 : id.second;
71 }
72 if (is_control && id.second >= 0) {
73 id.first.remove_prefix(1);
74 }
75 return id.first;
76 }
77
78 // Returns the node name and position in a single call.
ParseNodeName(const string & name,int * position)79 inline string ParseNodeName(const string& name, int* position) {
80 return string(ParseNodeNameAsStringPiece(name, position));
81 }
82
83 // Return the node name corresponding to 'name' if name is valid, or the empty
84 // string otherwise.
NodeNameAsStringPiece(const string & name)85 inline StringPiece NodeNameAsStringPiece(const string& name) {
86 return ParseNodeNameAsStringPiece(name, nullptr);
87 }
88
89 // Return the node name corresponding to 'name' if name is valid, or the empty
90 // string otherwise.
NodeName(const string & name)91 inline string NodeName(const string& name) {
92 return string(NodeNameAsStringPiece(name));
93 }
94
NodePosition(const string & name)95 inline int NodePosition(const string& name) {
96 int position;
97 ParseNodeNameAsStringPiece(name, &position);
98 return position;
99 }
100
101 namespace internal {
102 // Base template class for NodeMap and ImmutableNodeMap.
103 template <typename GraphDefT, typename NodeDefT>
104 class NodeMapInternal {
105 public:
106 // Note: The NodeMap will store pointers to nodes in graph, which may become
107 // invalid if graph is changed.
NodeMapInternal(GraphDefT * graph)108 explicit NodeMapInternal(GraphDefT* graph) {
109 if (graph == nullptr) {
110 LOG(WARNING) << "NodeMapInternal constructor is called with a nullptr!";
111 return;
112 }
113 nodes_.reserve(graph->node_size());
114 outputs_.reserve(graph->node_size());
115 for (int i = 0; i < graph->node_size(); i++) {
116 NodeDefT* node = GetNodeDefFromGraph(graph, i);
117 const string& node_name = node->name();
118 auto rslt = nodes_.emplace(node_name, node);
119 // Check that the graph doesn't contain multiple nodes with the same name.
120 if (!rslt.second) {
121 // The first node found with a given name becomes the canonical.
122 LOG(WARNING) << "Duplicated node in the graph: " << node_name;
123 }
124 NodeDefT* canonical = rslt.second ? node : rslt.first->second;
125 for (const auto& input : node->input()) {
126 outputs_[NodeName(input)].insert(canonical);
127 }
128 }
129 }
130
131 // Get unordered list of fanouts from node. Notice, that the order is
132 // non-deterministic.
GetOutputs(const string & node_name)133 const absl::flat_hash_set<NodeDefT*>& GetOutputs(
134 const string& node_name) const {
135 auto it = outputs_.find(node_name);
136 if (it == outputs_.end()) {
137 return empty_set_;
138 }
139 return it->second;
140 }
141
142 // Get fanouts ordered by name.
GetOutputsOrderedByNodeName(const string & node_name)143 std::vector<NodeDefT*> GetOutputsOrderedByNodeName(
144 const string& node_name) const {
145 std::vector<NodeDefT*> result;
146 auto it = outputs_.find(node_name);
147 if (it != outputs_.end()) {
148 const absl::flat_hash_set<NodeDefT*>& outputs = it->second;
149 result.reserve(outputs.size());
150 result.assign(outputs.begin(), outputs.end());
151 std::sort(result.begin(), result.end(),
152 [](const NodeDef* n1, const NodeDef* n2) {
153 return n1->name() < n2->name();
154 });
155 }
156 return result;
157 }
158
159 // This method doesn't record the outputs of the added node; the outputs need
160 // to be explicitly added by the AddOutput method.
AddNode(const string & node_name,NodeDefT * node)161 void AddNode(const string& node_name, NodeDefT* node) {
162 DCHECK(node != nullptr);
163 auto ret = nodes_.emplace(node_name, node);
164 DCHECK(ret.second)
165 << "Pair (" << node_name << "," << node
166 << ") is not inserted because the same key already exists.";
167 }
168
RemoveNode(const string & name)169 void RemoveNode(const string& name) {
170 nodes_.erase(NodeName(name));
171 outputs_.erase(NodeName(name));
172 }
173
GetNode(const string & name)174 NodeDefT* GetNode(const string& name) const {
175 const string node_name = NodeName(name);
176 auto it = nodes_.find(node_name);
177 if (it == nodes_.end()) {
178 VLOG(1) << "Node could not be found: " << name;
179 return nullptr;
180 }
181 return it->second;
182 }
183
NodeExists(const string & name)184 bool NodeExists(const string& name) const {
185 const string node_name = NodeName(name);
186 return nodes_.find(node_name) != nodes_.end();
187 }
188
AddOutput(const string & node_name,const string & output_name)189 void AddOutput(const string& node_name, const string& output_name) {
190 auto output_node = nodes_[NodeName(output_name)];
191 DCHECK(output_node) << "Output node " << output_name
192 << " is missing in NodeMap.";
193 outputs_[node_name].insert(output_node);
194 }
195
RemoveOutput(const string & node_name,const string & output_name)196 void RemoveOutput(const string& node_name, const string& output_name) {
197 outputs_[node_name].erase(nodes_[NodeName(output_name)]);
198 }
199
UpdateInput(const string & node_name,const string & old_input_name,const string & new_input_name)200 void UpdateInput(const string& node_name, const string& old_input_name,
201 const string& new_input_name) {
202 RemoveOutput(NodeName(old_input_name), node_name);
203 AddOutput(NodeName(new_input_name), node_name);
204 }
205
RemoveInputs(const string & node_name)206 void RemoveInputs(const string& node_name) {
207 auto node = nodes_[node_name];
208 for (const auto& input : node->input()) {
209 RemoveOutput(NodeName(input), node->name());
210 }
211 }
212
RemoveOutputs(const string & node_name)213 void RemoveOutputs(const string& node_name) { outputs_.erase(node_name); }
214
UpdateOutput(const string & node_name,const string & old_output_name,const string & new_output_name)215 void UpdateOutput(const string& node_name, const string& old_output_name,
216 const string& new_output_name) {
217 absl::flat_hash_set<NodeDef*>& outputs = outputs_[node_name];
218 outputs.erase(nodes_[NodeName(old_output_name)]);
219 outputs.insert(nodes_[NodeName(new_output_name)]);
220 }
221
222 private:
223 // Helper method to get the NodeDef pointer of i-th node in a graph.
224 NodeDefT* GetNodeDefFromGraph(GraphDefT* graph, int64 i) const;
225
226 const absl::flat_hash_set<NodeDefT*> empty_set_;
227 absl::node_hash_map<string, NodeDefT*> nodes_;
228 absl::node_hash_map<string, absl::flat_hash_set<NodeDefT*>> outputs_;
229 };
230 } // namespace internal
231
232 // A utility class to lookup a node and its outputs by node name.
233 class NodeMap : public internal::NodeMapInternal<GraphDef, NodeDef> {
234 public:
NodeMap(GraphDef * graph)235 explicit NodeMap(GraphDef* graph) : NodeMapInternal(graph) {}
236 };
237
238 // Same to NodeMap, but uses const GraphDef.
239 class ImmutableNodeMap
240 : public internal::NodeMapInternal<const GraphDef, const NodeDef> {
241 public:
ImmutableNodeMap(const GraphDef * graph)242 explicit ImmutableNodeMap(const GraphDef* graph) : NodeMapInternal(graph) {}
243 };
244
245 // A vector with a set. The set stores the same elements as the vector, and
246 // quickly answers whether a value is in the vector. Duplicated elements are not
247 // allowed for now.
248 template <class T, class Hash = std::hash<T>>
249 class SetVector {
250 public:
251 // Returns false if value already existed in the set, true otherwise.
PushBack(const T & value)252 bool PushBack(const T& value) {
253 if (!set_.insert(value).second) {
254 return false;
255 }
256 vector_.push_back(value);
257 return true;
258 }
259
PopBack()260 T PopBack() {
261 T back = vector_.back();
262 set_.erase(back);
263 vector_.pop_back();
264 return back;
265 }
266
Exists(const T & value)267 bool Exists(const T& value) const { return set_.find(value) != set_.end(); }
268
Empty()269 bool Empty() const { return vector_.empty(); }
270
Reserve(int64 size)271 void Reserve(int64 size) { vector_.reserve(size); }
272
273 private:
274 gtl::FlatSet<T, Hash> set_;
275 std::vector<T> vector_;
276 };
277
278 // Returns formatted string from TensorId specific to grappler. Specifically,
279 // for the 0 port (first output), only the node name is returned.
280 string TensorIdToString(const TensorId& tensor_id);
281
282 // Returns formatted string from SafeTensorId specific to grappler.
283 // Specifically, for the 0 port (first output), only the node name is returned.
284 string SafeTensorIdToString(const SafeTensorId& tensor_id);
285
// True iff 'name' refers to a control input, i.e. a node name prefixed with
// the ^ character.
288 bool IsControlInput(const string& name);
289
290 // True iff tensor index refers to a control input.
291 bool IsControlInput(const TensorId& tensor_id);
292
293 // True iff 'name1' and 'name2' refer to the same input.
294 bool IsSameInput(const string& name1, const string& name2);
295
296
297 // Add a prefix to a node name with a custom delimiter.
298 string AddPrefixToNodeName(const string& name, const string& prefix,
299 const string& delimiter);
300
301 // Add a prefix to a node name.
302 string AddPrefixToNodeName(const string& name, const string& prefix);
303
304 // Executes a 'fn' in the 'thread_pool'. The method waits for the configured
305 // timeout (in milliseconds) for 'fn' to complete, before returning false.
306 //
307 // If returning false, the 'fn' may still continue to execute in the
308 // thread-pool. It is the responsibility of the caller to reset the thread-pool
309 // as appropriate.
310 bool ExecuteWithTimeout(std::function<void()> fn, int64 timeout_in_ms,
311 thread::ThreadPool* thread_pool);
312
313 // Returns the node name prefixed with conventional symbol '^'
314 // for control dependency, given a NodeDef.
315 string AsControlDependency(const NodeDef& node);
316
// Returns the node name prefixed with conventional symbol '^'
// for control dependency, given a node name.
319 string AsControlDependency(const string& node);
320
321 // Returns true if the node is assigned to run on CPU device.
322 bool NodeIsOnCpu(const NodeDef* node);
323
324 // Returns true if the node is assigned to run on GPU device.
325 bool NodeIsOnGpu(const NodeDef* node);
326
327 // Returns the number of outputs of a node according to its OpDef. Note that
328 // some of the outputs may be unconnected.
329 int NumOutputs(const NodeDef& node, GraphDef* graph);
330
331 // Returns true iff the node has at least one control input.
332 bool HasControlInputs(const NodeDef& node);
333
334 // Returns true iff the node has at least one regular input.
335 bool HasRegularInputs(const NodeDef& node);
336
337 // Returns true iff the node has at least one regular output.
338 bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map);
339
340 // Returns true iff the node has at least one control output.
341 bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map);
342
343 // Number of connected control inputs.
344 int NumControlInputs(const NodeDef& node);
345
346 // Number of connected non-control inputs.
347 int NumNonControlInputs(const NodeDef& node);
348
349 // Number of connected control outputs.
350 int NumControlOutputs(const NodeDef& node, const NodeMap& node_map);
351
352 // Number of connected non-control outputs.
353 int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);
354
// Number of connected non-control data outputs (Ops that consume output tensor
// data, not just its shape).
357 int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map);
358
359 // Removes redundant control inputs from node.
360 void DedupControlInputs(NodeDef* node);
361
362 // Returns an error if an attribute with the given key does not exist in node.
363 Status CheckAttrExists(const NodeDef& node, const string& key);
364
365 // Returns an error if attributes with the given keys do not exist in node.
366 Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys);
367
// Returns the data type in attribute `type_attr` of `node`. If that attribute
// doesn't exist, returns DT_INVALID.
370 DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr);
371
372 // Returns the last node in the simple chain starting at source and traversing
373 // through the input(0) edge from each node as long as the next node satisfies
374 // the predicate given in pred_fn. If no nodes satisfy the predicate, &source
375 // will be returned. Example: For the chain
376 // source <- a <- b <- ... <- y <- z
377 // where
378 // pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true,
379 // pred_fn(z) = false,
380 // the return value will be a pointer to y.
381 NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
382 bool follow_control_input,
383 const std::function<bool(const NodeDef&)>& pred_fn);
384
385 // Permute the nodes of graph in place according to the permutation.
386 void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
387 bool invert_permutation);
388
389 // Returns Status::OK() if a kernel is registered for node.op() on the device
390 // type corresponding to node.device().
391 Status IsKernelRegisteredForNode(
392 absl::string_view node_name, bool has_experimental_debug_info,
393 const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
394 absl::string_view node_op, absl::string_view node_device,
395 AttrSlice node_attrs);
396 Status IsKernelRegisteredForNode(const NodeDef& node);
397
398 Status SetTensorValue(DataType dtype, int value, Tensor* tensor);
399
400 void EraseNodesFromGraph(const std::set<int>& nodes_to_delete, GraphDef* graph);
401
402 void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph);
403
404 void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
405 GraphDef* graph);
406
407 // Erase all attributes without leading underscore. Returns the number of
408 // attributes erased.
409 int EraseRegularNodeAttributes(NodeDef* node);
410
// Erase attribute "_xla_inferred_shapes" as well as all attributes starting
// with "_output_".
413 int EraseNodeOutputAttributes(NodeDef* node);
414
415 } // end namespace grappler
416 } // end namespace tensorflow
417
418 #endif // TENSORFLOW_CORE_GRAPPLER_UTILS_H_
419