1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_H_
18
19 #include <functional>
20 #include <iterator>
21 #include <set>
22 #include <unordered_set>
23 #include <utility>
24 #include <vector>
25 #include "absl/types/span.h"
26 #include "tensorflow/core/framework/graph.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/graph/tensor_id.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/core/stringpiece.h"
33 #include "tensorflow/core/lib/core/threadpool.h"
34 #include "tensorflow/core/lib/gtl/flatmap.h"
35 #include "tensorflow/core/lib/gtl/flatset.h"
36 #include "tensorflow/core/lib/gtl/inlined_vector.h"
37 #include "tensorflow/core/platform/types.h"
38
39 namespace tensorflow {
40 namespace grappler {
41
// A utility class to look up a node and its outputs by node name.
class NodeMap {
 public:
  // Note: The NodeMap will store pointers to nodes in graph, which may become
  // invalid if graph is changed.
  explicit NodeMap(GraphDef* graph);
  // Returns the node registered under `name`. NOTE(review): behavior for an
  // unknown name (presumably nullptr) is defined in utils.cc — confirm.
  NodeDef* GetNode(const string& name) const;
  bool NodeExists(const string& name) const;
  // Returns a reference to the set of nodes recorded as outputs (consumers) of
  // `node_name`. The reference points into this NodeMap. NOTE(review):
  // presumably `empty_set_` is returned when nothing is recorded — confirm.
  const std::set<NodeDef*>& GetOutputs(const string& node_name) const;
  // This method doesn't record the outputs of the added node; the outputs need
  // to be explicitly added by the AddOutput method.
  void AddNode(const string& name, NodeDef* node);
  void RemoveNode(const string& name);
  void UpdateInput(const string& node_name, const string& old_input_name,
                   const string& new_input_name);
  void AddOutput(const string& node_name, const string& output_name);
  void RemoveInputs(const string& node_name);
  void RemoveOutput(const string& node_name, const string& output_name);
  void RemoveOutputs(const string& node_name);
  void UpdateOutput(const string& node_name, const string& old_output_name,
                    const string& new_output_name);

 private:
  // Shared empty set; presumably what GetOutputs returns for names without
  // recorded outputs (TODO confirm in utils.cc).
  const std::set<NodeDef*> empty_set_;
  // Node name -> node pointer. Pointers refer to nodes in the GraphDef passed
  // to the constructor (see the constructor note about invalidation).
  gtl::FlatMap<string, NodeDef*> nodes_;
  // Node name -> set of nodes recorded as its outputs.
  gtl::FlatMap<string, std::set<NodeDef*>> outputs_;
};
69
70 // A vector with a set. The set stores the same elements as the vector, and
71 // quickly answers whether a value is in the vector. Duplicated elements are not
72 // allowed for now.
73 template <class T, class Hash = std::hash<T>>
74 class SetVector {
75 public:
76 // Returns false if value already existed in the set, true otherwise.
PushBack(const T & value)77 bool PushBack(const T& value) {
78 if (!set_.insert(value).second) {
79 return false;
80 }
81 vector_.push_back(value);
82 return true;
83 }
84
PopBack()85 T PopBack() {
86 T back = vector_.back();
87 set_.erase(back);
88 vector_.pop_back();
89 return back;
90 }
91
Exists(const T & value)92 bool Exists(const T& value) const { return set_.find(value) != set_.end(); }
93
Empty()94 bool Empty() const { return vector_.empty(); }
95
Reserve(int64 size)96 void Reserve(int64 size) { vector_.reserve(size); }
97
98 private:
99 gtl::FlatSet<T, Hash> set_;
100 std::vector<T> vector_;
101 };
102
// Returns a grappler-specific string form of a TensorId. Specifically, for
// port 0 (the first output), only the node name is returned.
string TensorIdToString(const TensorId& tensor_id);

// True iff 'name' refers to a control input, i.e. a node name prefixed with
// the ^ character.
bool IsControlInput(const string& name);

// True iff 'tensor_id' refers to a control input.
bool IsControlInput(const TensorId& tensor_id);

// True iff 'name1' and 'name2' refer to the same input.
bool IsSameInput(const string& name1, const string& name2);
116
117 // Returns the trailing position number (or zero if no number is present) if
118 // NodeName(input_name) is equal to node_name. Returns -1 for control inputs.
119 // Returns -2 if NodeName(input_name) is not equal to node_name.
120 // Note: This function is used very heavily, and this hand-optimized
121 // version is 3-4x faster than the version using Scanner, which it replaced.
122 // This is worth the reduction in readability.
NodePositionIfSameNode(const string & input_name,const string & node_name)123 inline int NodePositionIfSameNode(const string& input_name,
124 const string& node_name) {
125 if (input_name.empty()) return -2;
126 const bool is_ctrl = input_name[0] == '^';
127 auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin();
128 auto node_it = node_name.begin();
129 if (node_name.empty() ||
130 std::distance(input_it, input_name.end()) < node_name.size()) {
131 return -2;
132 }
133 while (node_it != node_name.end()) {
134 if (*input_it++ != *node_it++) {
135 return -2;
136 }
137 }
138 if (input_it == input_name.end()) {
139 return is_ctrl ? -1 : 0;
140 } else if (*input_it++ == ':') {
141 StringPiece remaining(&(*input_it),
142 std::distance(input_it, input_name.end()));
143 int position;
144 if (!strings::safe_strto32(remaining, &position)) {
145 return -2;
146 }
147 return is_ctrl ? -1 : position;
148 } else {
149 return -2;
150 }
151 }
152
153 // Return the node name corresponding to 'name' if name is valid, or the empty
154 // string otherwise.
NodeNameAsStringPiece(const string & name)155 inline StringPiece NodeNameAsStringPiece(const string& name) {
156 static const string empty;
157 if (name.empty()) return StringPiece(empty);
158 const auto begin_it = name[0] == '^' ? name.begin() + 1 : name.begin();
159 auto end_it = begin_it;
160 while (end_it != name.end() && *end_it != ':') {
161 ++end_it;
162 }
163 if (end_it != name.end() && *end_it != ':') {
164 return StringPiece(empty);
165 }
166 return StringPiece(&(*begin_it), std::distance(begin_it, end_it));
167 }
168
169 // Return the node name corresponding to 'name' if name is valid, or the empty
170 // string otherwise.
NodeName(const string & name)171 inline string NodeName(const string& name) {
172 return string(NodeNameAsStringPiece(name));
173 }
174
175 // Returns the node name and position in a single call.
176 // DEPRECATED(ezhulenev): Use TensorId and ParseTensorName.
ParseNodeNameAsStringPiece(const string & name,int * position)177 inline StringPiece ParseNodeNameAsStringPiece(const string& name,
178 int* position) {
179 static const string empty;
180 if (name.empty()) {
181 *position = 0;
182 return StringPiece(empty);
183 }
184 const bool is_ctrl = name[0] == '^';
185 const auto begin_it = is_ctrl ? name.begin() + 1 : name.begin();
186 *position = is_ctrl ? -1 : 0;
187 auto end_it = begin_it;
188 while (end_it != name.end() && *end_it != ':') {
189 ++end_it;
190 }
191 const StringPiece node_name(&(*begin_it), std::distance(begin_it, end_it));
192 if (end_it != name.end()) {
193 if (*end_it != ':') {
194 return StringPiece(empty);
195 } else if (!is_ctrl) {
196 ++end_it;
197 StringPiece remaining(&(*end_it), std::distance(end_it, name.end()));
198 if (!strings::safe_strto32(remaining, position)) {
199 return StringPiece(empty);
200 }
201 }
202 }
203 return node_name;
204 }
205
206 // Returns the node name and position in a single call.
207 // DEPRECATED(ezhulenev): Use SafeTensorId and ParseTensorName.
ParseNodeName(const string & name,int * position)208 inline string ParseNodeName(const string& name, int* position) {
209 return string(ParseNodeNameAsStringPiece(name, position));
210 }
211
NodePosition(const string & name)212 inline int NodePosition(const string& name) {
213 int position;
214 ParseNodeNameAsStringPiece(name, &position);
215 return position;
216 }
217
// Adds `prefix` to the node name `name`, joined by `delimiter`.
string AddPrefixToNodeName(const string& name, const string& prefix,
                           const string& delimiter);

// Adds `prefix` to the node name `name`, using the default delimiter
// (defined in utils.cc).
string AddPrefixToNodeName(const string& name, const string& prefix);

// Executes 'fn' in 'thread_pool' and waits up to 'timeout_in_ms' milliseconds
// for it to complete. Returns false if 'fn' did not complete within the
// timeout.
//
// If returning false, 'fn' may still continue to execute in the thread pool.
// It is the responsibility of the caller to reset the thread pool as
// appropriate.
bool ExecuteWithTimeout(std::function<void()> fn, int64 timeout_in_ms,
                        thread::ThreadPool* thread_pool);
233
// Returns the name of `node` prefixed with '^', the conventional marker for a
// control dependency.
string AsControlDependency(const NodeDef& node);

// Returns the node name `node` prefixed with '^', the conventional marker for
// a control dependency.
string AsControlDependency(const string& node);

// Returns true if the node is assigned to run on a CPU device.
bool NodeIsOnCpu(const NodeDef* node);

// Returns true if the node is assigned to run on a GPU device.
bool NodeIsOnGpu(const NodeDef* node);

// Returns the number of outputs of a node according to its OpDef. Note that
// some of the outputs may be unconnected.
int NumOutputs(const NodeDef& node, GraphDef* graph);

// Returns true iff the node has at least one control input.
bool HasControlInputs(const NodeDef& node);

// Number of connected non-control inputs.
int NumNonControlInputs(const NodeDef& node);

// Number of connected non-control outputs.
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);

// Number of connected non-control data outputs (ops that consume the output
// tensor's data, not just its shape).
int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map);

// Removes redundant control inputs from node.
void DedupControlInputs(NodeDef* node);
267
// Returns an error if an attribute with the given key does not exist in node.
Status CheckAttrExists(const NodeDef& node, const string& key);

// Returns an error if attributes with the given keys do not exist in node.
Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys);

// Returns the data type stored in attribute `type_attr` of `node`. If that
// attribute doesn't exist, returns DT_INVALID.
DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr);
277
// Returns the last node in the simple chain starting at source and traversing
// through the input(0) edge from each node, as long as the next node satisfies
// the predicate given in pred_fn. If no nodes satisfy the predicate, &source
// will be returned. Example: For the chain
//    source <- a <- b <- ... <- y <- z
// where
//    pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true,
//    pred_fn(z) = false,
// the return value will be a pointer to y.
// NOTE(review): `follow_control_input` presumably controls whether an input(0)
// that is a control input ("^name") is followed — confirm in utils.cc.
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
                        bool follow_control_input,
                        const std::function<bool(const NodeDef&)>& pred_fn);

// Permutes the nodes of graph in place according to the permutation.
// NOTE(review): the meaning of `invert_permutation`, and whether
// `*permutation` is modified (it is passed as a non-const pointer), are
// defined in utils.cc — confirm before relying on them.
void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
                         bool invert_permutation);

// Returns Status::OK() if a kernel is registered for node.op() on the device
// type corresponding to node.device().
Status IsKernelRegisteredForNode(const NodeDef& node);

// NOTE(review): presumably fills `tensor` with `value` converted to `dtype`,
// returning an error for unsupported dtypes — confirm in utils.cc.
Status SetTensorValue(DataType dtype, int value, Tensor* tensor);

// Removes nodes from `graph`. NOTE(review): the integer overloads presumably
// take indices into graph->node(), and the string overload node names —
// confirm in utils.cc.
void EraseNodesFromGraph(const std::set<int>& nodes_to_delete, GraphDef* graph);

void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph);

void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
                         GraphDef* graph);
307
308 } // end namespace grappler
309 } // end namespace tensorflow
310
311 #endif // TENSORFLOW_CORE_GRAPPLER_UTILS_H_
312