1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_H_
18
#include <algorithm>
#include <functional>
#include <iterator>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
40
41 namespace tensorflow {
42 namespace grappler {
43
44 // Utilities for manipulating node name and input strings.
45
46 // Returns the trailing position number (or zero if no number is present) if
47 // NodeName(input_name) is equal to node_name. Returns -1 for control inputs.
48 // Returns -2 if input_name is empty or NodeName(input_name) is not equal to
49 // node_name.
NodePositionIfSameNode(absl::string_view input_name,absl::string_view node_name)50 inline int NodePositionIfSameNode(absl::string_view input_name,
51 absl::string_view node_name) {
52 bool is_control = absl::StartsWith(input_name, "^");
53 if (is_control) input_name.remove_prefix(1);
54 if (input_name.empty() || node_name.empty() ||
55 input_name.size() < node_name.size()) {
56 return -2;
57 }
58 TensorId id = ParseTensorName(input_name);
59 if (id.first != node_name) return -2;
60 if (is_control) return -1;
61 return id.second;
62 }
63
64 // Returns the node name and position in a single call.
ParseNodeNameAsStringPiece(absl::string_view name,int * position)65 inline StringPiece ParseNodeNameAsStringPiece(absl::string_view name,
66 int* position) {
67 const bool is_control = absl::StartsWith(name, "^");
68 TensorId id = ParseTensorName(name);
69 if (position) {
70 *position = is_control ? -1 : id.second;
71 }
72 if (is_control && id.second >= 0) {
73 id.first.remove_prefix(1);
74 }
75 return id.first;
76 }
77
78 // Returns the node name and position in a single call.
ParseNodeName(const string & name,int * position)79 inline string ParseNodeName(const string& name, int* position) {
80 return string(ParseNodeNameAsStringPiece(name, position));
81 }
82
83 // Return the node name corresponding to 'name' if name is valid, or the empty
84 // string otherwise.
NodeNameAsStringPiece(const string & name)85 inline StringPiece NodeNameAsStringPiece(const string& name) {
86 return ParseNodeNameAsStringPiece(name, nullptr);
87 }
88
89 // Return the node name corresponding to 'name' if name is valid, or the empty
90 // string otherwise.
NodeName(const string & name)91 inline string NodeName(const string& name) {
92 return string(NodeNameAsStringPiece(name));
93 }
94
NodePosition(const string & name)95 inline int NodePosition(const string& name) {
96 int position;
97 ParseNodeNameAsStringPiece(name, &position);
98 return position;
99 }
100
101 namespace internal {
102 // Base template class for NodeMap and ImmutableNodeMap.
103 template <typename GraphDefT, typename NodeDefT>
104 class NodeMapInternal {
105 public:
106 // Note: The NodeMap will store pointers to nodes in graph, which may become
107 // invalid if graph is changed.
NodeMapInternal(GraphDefT * graph)108 explicit NodeMapInternal(GraphDefT* graph) {
109 if (graph == nullptr) {
110 LOG(WARNING) << "NodeMapInternal constructor is called with a nullptr!";
111 return;
112 }
113 nodes_.reserve(graph->node_size());
114 outputs_.reserve(graph->node_size());
115 for (int i = 0; i < graph->node_size(); i++) {
116 NodeDefT* node = GetNodeDefFromGraph(graph, i);
117 const string& node_name = node->name();
118 auto rslt = nodes_.emplace(node_name, node);
119 // Check that the graph doesn't contain multiple nodes with the same name.
120 if (!rslt.second) {
121 // The first node found with a given name becomes the canonical.
122 LOG(WARNING) << "Duplicated node in the graph: " << node_name;
123 }
124 NodeDefT* canonical = rslt.second ? node : rslt.first->second;
125 for (const auto& input : node->input()) {
126 outputs_[NodeName(input)].insert(canonical);
127 }
128 }
129 }
130
131 // Get unordered list of fanouts from node. Notice, that the order is
132 // non-deterministic.
GetOutputs(const string & node_name)133 const absl::flat_hash_set<NodeDefT*>& GetOutputs(
134 const string& node_name) const {
135 auto it = outputs_.find(node_name);
136 if (it == outputs_.end()) {
137 return empty_set_;
138 }
139 return it->second;
140 }
141
142 // Get fanouts ordered by name.
GetOutputsOrderedByNodeName(const string & node_name)143 std::vector<NodeDefT*> GetOutputsOrderedByNodeName(
144 const string& node_name) const {
145 std::vector<NodeDefT*> result;
146 auto it = outputs_.find(node_name);
147 if (it != outputs_.end()) {
148 const absl::flat_hash_set<NodeDefT*>& outputs = it->second;
149 result.reserve(outputs.size());
150 result.assign(outputs.begin(), outputs.end());
151 std::sort(result.begin(), result.end(),
152 [](const NodeDef* n1, const NodeDef* n2) {
153 return n1->name() < n2->name();
154 });
155 }
156 return result;
157 }
158
159 // This method doesn't record the outputs of the added node; the outputs need
160 // to be explicitly added by the AddOutput method.
AddNode(const string & node_name,NodeDefT * node)161 void AddNode(const string& node_name, NodeDefT* node) {
162 DCHECK(node != nullptr);
163 auto ret = nodes_.emplace(node_name, node);
164 DCHECK(ret.second)
165 << "Pair (" << node_name << "," << node
166 << ") is not inserted because the same key already exists.";
167 }
168
RemoveNode(const string & name)169 void RemoveNode(const string& name) {
170 nodes_.erase(NodeName(name));
171 outputs_.erase(NodeName(name));
172 }
173
GetNode(const string & name)174 NodeDefT* GetNode(const string& name) const {
175 const string node_name = NodeName(name);
176 auto it = nodes_.find(node_name);
177 if (it == nodes_.end()) {
178 VLOG(1) << "Node could not be found: " << name;
179 return nullptr;
180 }
181 return it->second;
182 }
183
NodeExists(const string & name)184 bool NodeExists(const string& name) const {
185 const string node_name = NodeName(name);
186 return nodes_.find(node_name) != nodes_.end();
187 }
188
AddOutput(const string & node_name,const string & output_name)189 void AddOutput(const string& node_name, const string& output_name) {
190 auto output_node = nodes_[NodeName(output_name)];
191 DCHECK(output_node) << "Output node " << output_name
192 << " is missing in NodeMap.";
193 outputs_[node_name].insert(output_node);
194 }
195
RemoveOutput(const string & node_name,const string & output_name)196 void RemoveOutput(const string& node_name, const string& output_name) {
197 outputs_[node_name].erase(nodes_[NodeName(output_name)]);
198 }
199
UpdateInput(const string & node_name,const string & old_input_name,const string & new_input_name)200 void UpdateInput(const string& node_name, const string& old_input_name,
201 const string& new_input_name) {
202 RemoveOutput(NodeName(old_input_name), node_name);
203 AddOutput(NodeName(new_input_name), node_name);
204 }
205
RemoveInputs(const string & node_name)206 void RemoveInputs(const string& node_name) {
207 auto node = nodes_[node_name];
208 for (const auto& input : node->input()) {
209 RemoveOutput(NodeName(input), node->name());
210 }
211 }
212
RemoveOutputs(const string & node_name)213 void RemoveOutputs(const string& node_name) { outputs_.erase(node_name); }
214
UpdateOutput(const string & node_name,const string & old_output_name,const string & new_output_name)215 void UpdateOutput(const string& node_name, const string& old_output_name,
216 const string& new_output_name) {
217 absl::flat_hash_set<NodeDef*>& outputs = outputs_[node_name];
218 outputs.erase(nodes_[NodeName(old_output_name)]);
219 outputs.insert(nodes_[NodeName(new_output_name)]);
220 }
221
222 private:
223 // Helper method to get the NodeDef pointer of i-th node in a graph.
224 inline NodeDefT* GetNodeDefFromGraph(GraphDefT* graph, int64_t i) const;
225
226 const absl::flat_hash_set<NodeDefT*> empty_set_;
227 absl::node_hash_map<string, NodeDefT*> nodes_;
228 absl::node_hash_map<string, absl::flat_hash_set<NodeDefT*>> outputs_;
229 };
230
231 // Specialized template class method GetNodeDefFromGraph.
232 template <>
GetNodeDefFromGraph(GraphDef * graph,int64_t i)233 inline NodeDef* NodeMapInternal<GraphDef, NodeDef>::GetNodeDefFromGraph(
234 GraphDef* graph, int64_t i) const {
235 return graph->mutable_node(i);
236 }
237
238 template <>
239 inline const NodeDef*
GetNodeDefFromGraph(const GraphDef * graph,int64_t i)240 NodeMapInternal<const GraphDef, const NodeDef>::GetNodeDefFromGraph(
241 const GraphDef* graph, int64_t i) const {
242 return &graph->node(i);
243 }
244 } // namespace internal
245
246 // A utility class to lookup a node and its outputs by node name.
247 class NodeMap : public internal::NodeMapInternal<GraphDef, NodeDef> {
248 public:
NodeMap(GraphDef * graph)249 explicit NodeMap(GraphDef* graph) : NodeMapInternal(graph) {}
250 };
251
252 // Same to NodeMap, but uses const GraphDef.
253 class ImmutableNodeMap
254 : public internal::NodeMapInternal<const GraphDef, const NodeDef> {
255 public:
ImmutableNodeMap(const GraphDef * graph)256 explicit ImmutableNodeMap(const GraphDef* graph) : NodeMapInternal(graph) {}
257 };
258
259 // A vector with a set. The set stores the same elements as the vector, and
260 // quickly answers whether a value is in the vector. Duplicated elements are not
261 // allowed for now.
262 template <class T, class Hash = std::hash<T>>
263 class SetVector {
264 public:
265 // Returns false if value already existed in the set, true otherwise.
PushBack(const T & value)266 bool PushBack(const T& value) {
267 if (!set_.insert(value).second) {
268 return false;
269 }
270 vector_.push_back(value);
271 return true;
272 }
273
PopBack()274 T PopBack() {
275 T back = vector_.back();
276 set_.erase(back);
277 vector_.pop_back();
278 return back;
279 }
280
Exists(const T & value)281 bool Exists(const T& value) const { return set_.find(value) != set_.end(); }
282
Empty()283 bool Empty() const { return vector_.empty(); }
284
Reserve(int64_t size)285 void Reserve(int64_t size) { vector_.reserve(size); }
286
287 private:
288 gtl::FlatSet<T, Hash> set_;
289 std::vector<T> vector_;
290 };
291
292 // Returns formatted string from TensorId specific to grappler. Specifically,
293 // for the 0 port (first output), only the node name is returned.
294 string TensorIdToString(const TensorId& tensor_id);
295
296 // Returns formatted string from SafeTensorId specific to grappler.
297 // Specifically, for the 0 port (first output), only the node name is returned.
298 string SafeTensorIdToString(const SafeTensorId& tensor_id);
299
// True iff 'name' refers to a control input, i.e. a node name prefixed with
// the ^ character.
302 bool IsControlInput(absl::string_view name);
303
304 // True iff tensor index refers to a control input.
305 bool IsControlInput(const TensorId& tensor_id);
306
307 // True iff 'name1' and 'name2' refer to the same input.
308 bool IsSameInput(const string& name1, const string& name2);
309
310
311 // Add a prefix to a node name with a custom delimiter.
312 string AddPrefixToNodeName(const string& name, const string& prefix,
313 const string& delimiter);
314
315 // Add a prefix to a node name.
316 string AddPrefixToNodeName(const string& name, const string& prefix);
317
// Executes 'fn' in 'thread_pool' and waits up to 'timeout_in_ms' milliseconds
// for it to complete. Returns false if 'fn' does not complete within the
// timeout.
//
// If returning false, 'fn' may still continue to execute in the thread-pool.
// It is the responsibility of the caller to reset the thread-pool as
// appropriate.
324 bool ExecuteWithTimeout(std::function<void()> fn, int64_t timeout_in_ms,
325 thread::ThreadPool* thread_pool);
326
327 // Returns the node name prefixed with conventional symbol '^'
328 // for control dependency, given a NodeDef.
329 string AsControlDependency(const NodeDef& node);
330
331 // Returns the node name prefixed with conventional symbol '^'
332 // for control dependency, given a node name
333 string AsControlDependency(const string& node);
334
335 // Returns true if the node is assigned to run on CPU device.
336 bool NodeIsOnCpu(const NodeDef* node);
337
338 // Returns true if the node is assigned to run on GPU device.
339 bool NodeIsOnGpu(const NodeDef* node);
340
341 // Returns the number of outputs of a node according to its OpDef. Note that
342 // some of the outputs may be unconnected.
343 int NumOutputs(const NodeDef& node, GraphDef* graph);
344
345 // Returns true iff the node has at least one control input.
346 bool HasControlInputs(const NodeDef& node);
347
348 // Returns true iff the node has at least one regular input.
349 bool HasRegularInputs(const NodeDef& node);
350
351 // Returns true iff the node has at least one regular output.
352 bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map);
353
354 // Returns true iff the node has at least one control output.
355 bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map);
356
357 // Number of connected control inputs.
358 int NumControlInputs(const NodeDef& node);
359
360 // Number of connected non-control inputs.
361 int NumNonControlInputs(const NodeDef& node);
362
363 // Number of connected control outputs.
364 int NumControlOutputs(const NodeDef& node, const NodeMap& node_map);
365
366 // Number of connected non-control outputs.
367 int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);
368
// Number of connected non-control data outputs (ops that consume the output
// tensor's data, not just its shape).
371 int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map);
372
373 // Removes redundant control inputs from node.
374 void DedupControlInputs(NodeDef* node);
375
376 // Returns an error if an attribute with the given key does not exist in node.
377 Status CheckAttrExists(const NodeDef& node, const string& key);
378
379 // Returns an error if attributes with the given keys do not exist in node.
380 Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys);
381
// Returns the data type stored in attribute `type_attr` of `node`. If that
// attribute doesn't exist, returns DT_INVALID.
384 DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr);
385
386 // Returns the last node in the simple chain starting at source and traversing
387 // through the input(0) edge from each node as long as the next node satisfies
388 // the predicate given in pred_fn. If no nodes satisfy the predicate, &source
389 // will be returned. Example: For the chain
390 // source <- a <- b <- ... <- y <- z
391 // where
392 // pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true,
393 // pred_fn(z) = false,
394 // the return value will be a pointer to y.
395 NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
396 bool follow_control_input,
397 const std::function<bool(const NodeDef&)>& pred_fn);
398
399 // Permute the nodes of graph in place according to the permutation.
400 void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
401 bool invert_permutation);
402
403 // Returns Status::OK() if a kernel is registered for node.op() on the device
404 // type corresponding to node.device().
405 Status IsKernelRegisteredForNode(
406 absl::string_view node_name, bool has_experimental_debug_info,
407 const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
408 absl::string_view node_op, absl::string_view node_device,
409 AttrSlice node_attrs);
410 Status IsKernelRegisteredForNode(const NodeDef& node);
411
412 Status SetTensorValue(DataType dtype, int value, Tensor* tensor);
413
414 void EraseNodesFromGraph(const std::set<int>& nodes_to_delete, GraphDef* graph);
415
416 void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph);
417
418 void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
419 GraphDef* graph);
420
// Erases all attributes without a leading underscore. Returns the number of
// attributes erased.
423 int EraseRegularNodeAttributes(NodeDef* node);
424
// Erases attribute "_xla_inferred_shapes" as well as all attributes starting
// with "_output_".
427 int EraseNodeOutputAttributes(NodeDef* node);
428
429 } // end namespace grappler
430 } // end namespace tensorflow
431
432 #endif // TENSORFLOW_CORE_GRAPPLER_UTILS_H_
433