• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_H_
18 
19 #include <functional>
20 #include <iterator>
21 #include <set>
22 #include <unordered_set>
23 #include <utility>
24 #include <vector>
25 #include "absl/types/span.h"
26 #include "tensorflow/core/framework/graph.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/graph/tensor_id.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/core/stringpiece.h"
33 #include "tensorflow/core/lib/core/threadpool.h"
34 #include "tensorflow/core/lib/gtl/flatmap.h"
35 #include "tensorflow/core/lib/gtl/flatset.h"
36 #include "tensorflow/core/lib/gtl/inlined_vector.h"
37 #include "tensorflow/core/platform/types.h"
38 
39 namespace tensorflow {
40 namespace grappler {
41 
// A utility class to look up a node and its outputs (consumers) by node name.
//
// NOTE(review): only the declarations are visible here; the per-method notes
// below are inferred from names/signatures and should be confirmed against the
// implementation.
class NodeMap {
 public:
  // Note: The NodeMap will store pointers to nodes in graph, which may become
  // invalid if graph is changed.
  explicit NodeMap(GraphDef* graph);
  // Returns the node registered under `name` (presumably nullptr when no such
  // node exists — TODO confirm).
  NodeDef* GetNode(const string& name) const;
  // Presumably true iff a node is registered under `name`.
  bool NodeExists(const string& name) const;
  // Returns a reference to the set of nodes recorded as outputs of
  // `node_name`. The reference points into this NodeMap; `empty_set_` is
  // presumably what is returned for unknown names — TODO confirm.
  const std::set<NodeDef*>& GetOutputs(const string& node_name) const;
  // This method doesn't record the outputs of the added node; the outputs need
  // to be explicitly added by the AddOutput method.
  void AddNode(const string& name, NodeDef* node);
  void RemoveNode(const string& name);
  // Presumably rewires the bookkeeping for `node_name` so that it consumes
  // `new_input_name` instead of `old_input_name` — verify in the .cc file.
  void UpdateInput(const string& node_name, const string& old_input_name,
                   const string& new_input_name);
  void AddOutput(const string& node_name, const string& output_name);
  void RemoveInputs(const string& node_name);
  void RemoveOutput(const string& node_name, const string& output_name);
  void RemoveOutputs(const string& node_name);
  void UpdateOutput(const string& node_name, const string& old_output_name,
                    const string& new_output_name);

 private:
  // Shared empty set; its type matches the return type of GetOutputs.
  const std::set<NodeDef*> empty_set_;
  // Node name -> node pointer (non-owning; see the constructor note).
  gtl::FlatMap<string, NodeDef*> nodes_;
  // Node name -> set of nodes recorded as its outputs.
  gtl::FlatMap<string, std::set<NodeDef*>> outputs_;
};
69 
70 // A vector with a set. The set stores the same elements as the vector, and
71 // quickly answers whether a value is in the vector. Duplicated elements are not
72 // allowed for now.
73 template <class T, class Hash = std::hash<T>>
74 class SetVector {
75  public:
76   // Returns false if value already existed in the set, true otherwise.
PushBack(const T & value)77   bool PushBack(const T& value) {
78     if (!set_.insert(value).second) {
79       return false;
80     }
81     vector_.push_back(value);
82     return true;
83   }
84 
PopBack()85   T PopBack() {
86     T back = vector_.back();
87     set_.erase(back);
88     vector_.pop_back();
89     return back;
90   }
91 
Exists(const T & value)92   bool Exists(const T& value) const { return set_.find(value) != set_.end(); }
93 
Empty()94   bool Empty() const { return vector_.empty(); }
95 
Reserve(int64 size)96   void Reserve(int64 size) { vector_.reserve(size); }
97 
98  private:
99   gtl::FlatSet<T, Hash> set_;
100   std::vector<T> vector_;
101 };
102 
103 // Returns formatted string from TensorId specific to grappler. Specifically,
104 // for the 0 port (first output), only the node name is returned.
105 string TensorIdToString(const TensorId& tensor_id);
106 
// True iff 'name' refers to a control input, i.e. a node name prefixed with
// the ^ character.
109 bool IsControlInput(const string& name);
110 
// True iff 'tensor_id' refers to a control input.
112 bool IsControlInput(const TensorId& tensor_id);
113 
114 // True iff 'name1' and 'name2' refer to the same input.
115 bool IsSameInput(const string& name1, const string& name2);
116 
117 // Returns the trailing position number (or zero if no number is present) if
118 // NodeName(input_name) is equal to node_name. Returns -1 for control inputs.
119 // Returns -2 if NodeName(input_name) is not equal to node_name.
120 // Note: This function is used very heavily, and this hand-optimized
121 // version is 3-4x faster than the version using Scanner, which it replaced.
122 // This is worth the reduction in readability.
NodePositionIfSameNode(const string & input_name,const string & node_name)123 inline int NodePositionIfSameNode(const string& input_name,
124                                   const string& node_name) {
125   if (input_name.empty()) return -2;
126   const bool is_ctrl = input_name[0] == '^';
127   auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin();
128   auto node_it = node_name.begin();
129   if (node_name.empty() ||
130       std::distance(input_it, input_name.end()) < node_name.size()) {
131     return -2;
132   }
133   while (node_it != node_name.end()) {
134     if (*input_it++ != *node_it++) {
135       return -2;
136     }
137   }
138   if (input_it == input_name.end()) {
139     return is_ctrl ? -1 : 0;
140   } else if (*input_it++ == ':') {
141     StringPiece remaining(&(*input_it),
142                           std::distance(input_it, input_name.end()));
143     int position;
144     if (!strings::safe_strto32(remaining, &position)) {
145       return -2;
146     }
147     return is_ctrl ? -1 : position;
148   } else {
149     return -2;
150   }
151 }
152 
153 // Return the node name corresponding to 'name' if name is valid, or the empty
154 // string otherwise.
NodeNameAsStringPiece(const string & name)155 inline StringPiece NodeNameAsStringPiece(const string& name) {
156   static const string empty;
157   if (name.empty()) return StringPiece(empty);
158   const auto begin_it = name[0] == '^' ? name.begin() + 1 : name.begin();
159   auto end_it = begin_it;
160   while (end_it != name.end() && *end_it != ':') {
161     ++end_it;
162   }
163   if (end_it != name.end() && *end_it != ':') {
164     return StringPiece(empty);
165   }
166   return StringPiece(&(*begin_it), std::distance(begin_it, end_it));
167 }
168 
169 // Return the node name corresponding to 'name' if name is valid, or the empty
170 // string otherwise.
NodeName(const string & name)171 inline string NodeName(const string& name) {
172   return string(NodeNameAsStringPiece(name));
173 }
174 
175 // Returns the node name and position in a single call.
176 // DEPRECATED(ezhulenev): Use TensorId and ParseTensorName.
ParseNodeNameAsStringPiece(const string & name,int * position)177 inline StringPiece ParseNodeNameAsStringPiece(const string& name,
178                                               int* position) {
179   static const string empty;
180   if (name.empty()) {
181     *position = 0;
182     return StringPiece(empty);
183   }
184   const bool is_ctrl = name[0] == '^';
185   const auto begin_it = is_ctrl ? name.begin() + 1 : name.begin();
186   *position = is_ctrl ? -1 : 0;
187   auto end_it = begin_it;
188   while (end_it != name.end() && *end_it != ':') {
189     ++end_it;
190   }
191   const StringPiece node_name(&(*begin_it), std::distance(begin_it, end_it));
192   if (end_it != name.end()) {
193     if (*end_it != ':') {
194       return StringPiece(empty);
195     } else if (!is_ctrl) {
196       ++end_it;
197       StringPiece remaining(&(*end_it), std::distance(end_it, name.end()));
198       if (!strings::safe_strto32(remaining, position)) {
199         return StringPiece(empty);
200       }
201     }
202   }
203   return node_name;
204 }
205 
206 // Returns the node name and position in a single call.
207 // DEPRECATED(ezhulenev): Use SafeTensorId and ParseTensorName.
ParseNodeName(const string & name,int * position)208 inline string ParseNodeName(const string& name, int* position) {
209   return string(ParseNodeNameAsStringPiece(name, position));
210 }
211 
NodePosition(const string & name)212 inline int NodePosition(const string& name) {
213   int position;
214   ParseNodeNameAsStringPiece(name, &position);
215   return position;
216 }
217 
218 // Add a prefix to a node name with a custom delimiter.
219 string AddPrefixToNodeName(const string& name, const string& prefix,
220                            const string& delimiter);
221 
222 // Add a prefix to a node name.
223 string AddPrefixToNodeName(const string& name, const string& prefix);
224 
225 // Executes a 'fn' in the 'thread_pool'. The method waits for the configured
226 // timeout (in milliseconds) for 'fn' to complete, before returning false.
227 //
228 // If returning false, the 'fn' may still continue to execute in the
229 // thread-pool. It is the responsibility of the caller to reset the thread-pool
230 // as appropriate.
231 bool ExecuteWithTimeout(std::function<void()> fn, int64 timeout_in_ms,
232                         thread::ThreadPool* thread_pool);
233 
234 // Returns the node name prefixed with conventional symbol '^'
235 // for control dependency, given a NodeDef.
236 string AsControlDependency(const NodeDef& node);
237 
238 // Returns the node name prefixed with conventional symbol '^'
239 // for control dependency, given a node name
240 string AsControlDependency(const string& node);
241 
242 // Returns true if the node is assigned to run on CPU device.
243 bool NodeIsOnCpu(const NodeDef* node);
244 
245 // Returns true if the node is assigned to run on GPU device.
246 bool NodeIsOnGpu(const NodeDef* node);
247 
248 // Returns the number of outputs of a node according to its OpDef. Note that
249 // some of the outputs may be unconnected.
250 int NumOutputs(const NodeDef& node, GraphDef* graph);
251 
252 // Returns true iff the node has at least one control input.
253 bool HasControlInputs(const NodeDef& node);
254 
255 // Number of connected non-control inputs.
256 int NumNonControlInputs(const NodeDef& node);
257 
258 // Number of connected non-control outputs.
259 int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);
260 
// Number of connected non-control data outputs (ops that consume the output
// tensor's data, not just its shape).
263 int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map);
264 
265 // Removes redundant control inputs from node.
266 void DedupControlInputs(NodeDef* node);
267 
268 // Returns an error if an attribute with the given key does not exist in node.
269 Status CheckAttrExists(const NodeDef& node, const string& key);
270 
271 // Returns an error if attributes with the given keys do not exist in node.
272 Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys);
273 
// Returns the data type in attribute `type_attr` of `node`. If that attribute
// doesn't exist, returns DT_INVALID.
276 DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr);
277 
278 // Returns the last node in the simple chain starting at source and traversing
279 // through the input(0) edge from each node as long as the next node satisfies
280 // the predicate given in pred_fn. If no nodes satisfy the predicate, &source
281 // will be returned. Example: For the chain
282 //    source <- a <- b <- ... <- y <- z
283 // where
284 //    pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true,
285 //    pred_fn(z) = false,
286 // the return value will be a pointer to y.
287 NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
288                         bool follow_control_input,
289                         const std::function<bool(const NodeDef&)>& pred_fn);
290 
291 // Permute the nodes of graph in place according to the permutation.
292 void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
293                          bool invert_permutation);
294 
295 // Returns Status::OK() if a kernel is registered for node.op() on the device
296 // type corresponding to node.device().
297 Status IsKernelRegisteredForNode(const NodeDef& node);
298 
299 Status SetTensorValue(DataType dtype, int value, Tensor* tensor);
300 
301 void EraseNodesFromGraph(const std::set<int>& nodes_to_delete, GraphDef* graph);
302 
303 void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph);
304 
305 void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
306                          GraphDef* graph);
307 
308 }  // end namespace grappler
309 }  // end namespace tensorflow
310 
311 #endif  // TENSORFLOW_CORE_GRAPPLER_UTILS_H_
312