• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/utils.h"
17 
18 #include <iterator>
19 #include <memory>
20 #include <queue>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/match.h"
25 #include "absl/strings/str_cat.h"
26 #include "tensorflow/core/framework/attr_value.pb.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/framework/op.h"
30 #include "tensorflow/core/framework/op_def.pb.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/framework/types.h"
33 #include "tensorflow/core/lib/core/stringpiece.h"
34 #include "tensorflow/core/lib/strings/numbers.h"
35 #include "tensorflow/core/lib/strings/scanner.h"
36 #include "tensorflow/core/lib/strings/strcat.h"
37 #include "tensorflow/core/platform/notification.h"
38 #include "tensorflow/core/util/device_name_utils.h"
39 
40 namespace tensorflow {
41 namespace grappler {
42 namespace {
43 template <typename T>
SafeSetDoubleScalarTensorValue(double value,Tensor * tensor)44 bool SafeSetDoubleScalarTensorValue(double value, Tensor* tensor) {
45   using RealType = typename Eigen::NumTraits<T>::Real;
46   if (value > static_cast<double>(Eigen::NumTraits<RealType>::highest()) ||
47       value < static_cast<double>(Eigen::NumTraits<RealType>::lowest())) {
48     return false;
49   }
50   tensor->flat<T>()(0) = static_cast<T>(value);
51   return true;
52 }
53 
54 template <typename T>
SafeSetIntScalarTensorValue(int value,Tensor * tensor)55 bool SafeSetIntScalarTensorValue(int value, Tensor* tensor) {
56   using RealType = typename Eigen::NumTraits<T>::Real;
57   if (value > static_cast<int>(Eigen::NumTraits<RealType>::highest()) ||
58       value < static_cast<int>(Eigen::NumTraits<RealType>::lowest())) {
59     return false;
60   }
61   tensor->flat<T>()(0) = static_cast<T>(value);
62   return true;
63 }
64 
// Reports whether `node` is an operator that reads only the shape of its input
// rather than the data itself: one of Shape, ShapeN, Rank or Size.
// TODO(ezhulenev): move to op_types.h. Requires to break circular dependency.
// TODO(ezhulenev): what about Identity passing tensor to Shape consumer?
bool IsShapeConsumer(const NodeDef& node) {
  const string& op_name = node.op();
  if (op_name == "Shape" || op_name == "ShapeN") {
    return true;
  }
  return op_name == "Rank" || op_name == "Size";
}
73 
74 }  // namespace
75 
// Formats `tensor_id` as a string. Output 0 is written as the bare node name;
// any other index falls back to TensorId::ToString().
string TensorIdToString(const TensorId& tensor_id) {
  if (tensor_id.index() != 0) {
    return tensor_id.ToString();
  }
  return string(tensor_id.node());
}
80 
// SafeTensorId counterpart of TensorIdToString. Output 0 is written as the
// bare node name; any other index uses SafeTensorId::ToString().
string SafeTensorIdToString(const SafeTensorId& tensor_id) {
  if (tensor_id.index() != 0) {
    return tensor_id.ToString();
  }
  return tensor_id.node();
}
84 
// Returns true when `name1` and `name2` refer to the same input. Identical
// spellings match immediately; otherwise the two names are parsed with
// ParseTensorName and the resulting TensorIds are compared.
bool IsSameInput(const string& name1, const string& name2) {
  return name1 == name2 || ParseTensorName(name1) == ParseTensorName(name2);
}
91 
IsControlInput(absl::string_view name)92 bool IsControlInput(absl::string_view name) {
93   return !name.empty() && name[0] == '^';
94 }
95 
IsControlInput(const TensorId & tensor_id)96 bool IsControlInput(const TensorId& tensor_id) { return tensor_id.index() < 0; }
97 
// Returns `name` with `prefix` + `delimiter` prepended. A leading '^' (control
// input marker) stays at the front, so "^x" becomes "^" + prefix + delimiter + "x".
string AddPrefixToNodeName(const string& name, const string& prefix,
                           const string& delimiter) {
  const bool is_control = !name.empty() && name[0] == '^';
  if (!is_control) {
    return absl::StrCat(prefix, delimiter, name);
  }
  return absl::StrCat("^", prefix, delimiter, name.substr(1));
}
107 
// Convenience overload that joins `prefix` and `name` with "/".
string AddPrefixToNodeName(const string& name, const string& prefix) {
  const string delimiter = "/";
  return AddPrefixToNodeName(name, prefix, delimiter);
}
111 
// Runs `fn`, optionally bounded by a timeout.
//
// If `timeout_in_ms` is not positive, `fn` runs synchronously on the calling
// thread and the function returns true. Otherwise `fn` is scheduled on
// `thread_pool` (which must then be non-null). The caller waits up to
// `timeout_in_ms` milliseconds and gets back whether `fn` finished within that
// time. On timeout `fn` is not cancelled: it keeps running on the pool. The
// Notification is shared with the scheduled closure, so it stays alive until
// the closure finishes.
bool ExecuteWithTimeout(std::function<void()> fn, const int64_t timeout_in_ms,
                        thread::ThreadPool* const thread_pool) {
  if (timeout_in_ms > 0) {
    auto done = std::make_shared<Notification>();
    thread_pool->Schedule([done, fn]() {
      fn();
      done->Notify();
    });
    const int64_t timeout_in_us = timeout_in_ms * 1000;
    return WaitForNotificationWithTimeout(done.get(), timeout_in_us);
  }
  fn();
  return true;
}
127 
// Returns the control-input spelling of `node`: its name prefixed with '^'.
string AsControlDependency(const NodeDef& node) {
  const string& name = node.name();
  return absl::StrCat("^", name);
}
131 
// Returns `node_name` as a control input. Names that already start with '^'
// are returned unchanged; others get a '^' prefix. An empty name is a fatal
// error (CHECK).
string AsControlDependency(const string& node_name) {
  CHECK(!node_name.empty());
  // The CHECK above guarantees node_name[0] exists.
  if (node_name[0] == '^') {
    return node_name;
  }
  return absl::StrCat("^", node_name);
}
138 
// True if `node`'s device string splits cleanly via
// DeviceNameUtils::SplitDeviceName and its device part begins with DEVICE_CPU.
bool NodeIsOnCpu(const NodeDef* node) {
  string task;
  string device;
  if (!DeviceNameUtils::SplitDeviceName(node->device(), &task, &device)) {
    return false;
  }
  return absl::StartsWith(device, DEVICE_CPU);
}
144 
// True if `node`'s device string splits cleanly via
// DeviceNameUtils::SplitDeviceName and its device part begins with DEVICE_GPU.
bool NodeIsOnGpu(const NodeDef* node) {
  string task;
  string device;
  if (!DeviceNameUtils::SplitDeviceName(node->device(), &task, &device)) {
    return false;
  }
  return absl::StartsWith(device, DEVICE_GPU);
}
150 
NumOutputs(const NodeDef & node,GraphDef * graph)151 int NumOutputs(const NodeDef& node, GraphDef* graph) {
152   int num_outputs = 0;
153   const OpDef* op_def = nullptr;
154   auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
155   if (status.ok()) {
156     for (const auto& output : op_def->output_arg()) {
157       if (!output.type_list_attr().empty()) {
158         num_outputs +=
159             node.attr().at(output.type_list_attr()).list().type_size();
160       } else if (!output.number_attr().empty()) {
161         num_outputs += node.attr().at(output.number_attr()).i();
162       } else {
163         num_outputs++;
164       }
165     }
166   } else {
167     FunctionLibraryDefinition fdef(OpRegistry::Global(), graph->library());
168     auto status = fdef.LookUpOpDef(node.op(), &op_def);
169     if (status.ok()) {
170       num_outputs = op_def->output_arg_size();
171     }
172   }
173   return num_outputs;
174 }
175 
HasControlInputs(const NodeDef & node)176 bool HasControlInputs(const NodeDef& node) {
177   const int num_inputs = node.input_size();
178   if (num_inputs > 0 && IsControlInput(node.input(num_inputs - 1))) {
179     return true;
180   }
181   return false;
182 }
183 
HasRegularInputs(const NodeDef & node)184 bool HasRegularInputs(const NodeDef& node) {
185   const int num_inputs = node.input_size();
186   if (num_inputs > 0 && !IsControlInput(node.input(0))) {
187     return true;
188   }
189   return false;
190 }
191 
NumNonControlInputs(const NodeDef & node)192 int NumNonControlInputs(const NodeDef& node) {
193   int num_inputs = 0;
194   for (; num_inputs < node.input_size(); ++num_inputs) {
195     const string& input = node.input(num_inputs);
196     if (IsControlInput(input)) {
197       return num_inputs;
198     }
199   }
200   return num_inputs;
201 }
202 
NumControlInputs(const NodeDef & node)203 int NumControlInputs(const NodeDef& node) {
204   int num_inputs = 0;
205   for (; num_inputs < node.input_size(); ++num_inputs) {
206     const string& input = node.input(node.input_size() - num_inputs - 1);
207     if (!IsControlInput(input)) {
208       return num_inputs;
209     }
210   }
211   return num_inputs;
212 }
213 
// Returns true if some consumer listed in `node_map` reads `node` through a
// regular (non-control) input. For each consumer only the leading regular
// inputs are scanned; the scan stops at the first control input.
bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map) {
  const string& name = node.name();
  for (const NodeDef* consumer : node_map.GetOutputs(name)) {
    for (const string& input : consumer->input()) {
      if (IsControlInput(input)) {
        break;
      }
      if (ParseTensorName(input).node() == name) {
        return true;
      }
    }
  }
  return false;
}
227 
// Returns true if some consumer listed in `node_map` depends on `node` through
// a control input. Each consumer's inputs are scanned from the end backwards;
// the scan stops at the first regular input.
bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map) {
  const string& name = node.name();
  for (const NodeDef* consumer : node_map.GetOutputs(name)) {
    int idx = consumer->input_size();
    while (--idx >= 0) {
      const string& input = consumer->input(idx);
      if (!IsControlInput(input)) {
        break;
      }
      if (ParseTensorName(input).node() == name) {
        return true;
      }
    }
  }
  return false;
}
242 
// Counts the control inputs, across all consumers listed in `node_map`, that
// refer to `node`. Each consumer's trailing control inputs are scanned from
// the end backwards until the first regular input.
int NumControlOutputs(const NodeDef& node, const NodeMap& node_map) {
  const string& name = node.name();
  int count = 0;
  for (const NodeDef* consumer : node_map.GetOutputs(name)) {
    for (int idx = consumer->input_size(); idx-- > 0;) {
      const string& input = consumer->input(idx);
      if (!IsControlInput(input)) {
        break;
      }
      if (ParseTensorName(input).node() == name) {
        ++count;
      }
    }
  }
  return count;
}
258 
// Counts the regular (non-control) input slots, across all consumers listed in
// `node_map`, that read any output of `node`. A consumer reading `node` twice
// contributes two. Each consumer is scanned only up to its first control input.
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
  const string& name = node.name();
  int count = 0;
  for (const NodeDef* consumer : node_map.GetOutputs(name)) {
    for (const string& input : consumer->input()) {
      if (IsControlInput(input)) {
        break;
      }
      // Fast path: an exact match with the bare node name needs no parsing.
      const bool refers_to_node =
          input == name || ParseTensorName(input).node() == name;
      if (refers_to_node) {
        ++count;
      }
    }
  }
  return count;
}
278 
// Counts the consumers of `node` (per `node_map`) that read its data through
// at least one non-control input. Each consumer counts at most once.
// Consumers that only read the shape (see IsShapeConsumer) are skipped.
int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map) {
  int count = 0;
  for (const NodeDef* consumer : node_map.GetOutputs(node.name())) {
    if (IsShapeConsumer(*consumer)) {
      continue;
    }
    for (const string& input : consumer->input()) {
      if (IsControlInput(input)) {
        continue;
      }
      if (NodeName(input) == node.name()) {
        ++count;
        break;
      }
    }
  }
  return count;
}
294 
// Returns the data type stored in the attribute named `type_attr` of `node`.
// Returns DT_INVALID if `node` has no such attribute, or if the attribute does
// not hold a type value (its value case is not AttrValue::kType).
DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr) {
  // A single find() replaces the count() + at() pair, which hashed twice.
  const auto it = node.attr().find(type_attr);
  if (it == node.attr().end()) {
    return DT_INVALID;
  }
  const AttrValue& attr = it->second;
  if (attr.value_case() != AttrValue::kType) {
    return DT_INVALID;
  }
  return attr.type();
}
307 
// Walks from `source` along first inputs (input(0)), resolving each one
// through `node_map`. Returns the last node reached that was accepted.
//
// `source` itself is always accepted without consulting `pred_fn`; every
// other reached node must satisfy `pred_fn`. The walk stops when:
// - the current node has no inputs;
// - its first input is a control input and `follow_control_input` is false;
// - the first input cannot be resolved in `node_map` (logged as an error);
// - `pred_fn` rejects the next node.
// NOTE(review): a node that links back to `source` also bypasses `pred_fn`, so
// a cycle through `source` made entirely of accepted nodes never terminates.
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
                        bool follow_control_input,
                        const std::function<bool(const NodeDef&)>& pred_fn) {
  const NodeDef* tail = &source;
  while (true) {
    if (tail->input_size() == 0) {
      break;
    }
    const string& first_input = tail->input(0);
    if (!follow_control_input && IsControlInput(first_input)) {
      break;
    }
    const NodeDef* candidate = node_map.GetNode(first_input);
    if (candidate == nullptr) {
      LOG(ERROR) << "Node not found: " << first_input;
      break;
    }
    if (candidate != &source && !pred_fn(*candidate)) {
      break;
    }
    tail = candidate;
  }
  return const_cast<NodeDef*>(tail);
}
326 
// Reorders graph->node() in place according to `*permutation`.
//
// Without inversion, the node originally at index i ends up at index
// (*permutation)[i]. With `invert_permutation`, `*permutation` is first
// replaced by its inverse, so the node at index j afterwards is the one
// originally at (*permutation)[j]. The permutation's size must equal the node
// count (CHECKed).
//
// Every permutation is a product of one or more cycles. Each cycle is resolved
// by a series of transpositions (swaps):
// https://en.wikipedia.org/wiki/Cyclic_permutation
// The same swaps are applied to `*permutation`, which is therefore left as the
// identity permutation on return.
void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
                         bool invert_permutation) {
  CHECK_EQ(graph->node_size(), permutation->size());
  std::vector<int>& perm = *permutation;
  if (invert_permutation) {
    std::vector<int> inverse(perm.size(), 0);
    for (size_t i = 0; i < perm.size(); ++i) {
      inverse[perm[i]] = i;
    }
    perm.swap(inverse);
  }
  const int size = perm.size();
  // Once positions [0, size - 1) are settled, the last one is too.
  for (int pos = 0; pos + 1 < size; ++pos) {
    // Keep sending the node at `pos` to its destination until the node that
    // belongs at `pos` arrives.
    while (perm[pos] != pos) {
      const std::size_t target = perm[pos];
      graph->mutable_node()->SwapElements(pos, target);
      std::swap(perm[pos], perm[target]);
    }
  }
}
348 
// Removes redundant control inputs from `node`. A control input is dropped
// when its node name (per NodeName) already appeared among the earlier
// inputs, whether those were regular or control inputs. Regular inputs are
// never removed. A dropped entry is overwritten by swapping in the current
// last input, so the order of the inputs after it may change.
void DedupControlInputs(NodeDef* node) {
  absl::flat_hash_set<string> seen_names;
  int idx = 0;
  while (idx < node->input_size()) {
    const string& input = node->input(idx);
    const bool is_duplicate = !seen_names.insert(NodeName(input)).second;
    if (!is_duplicate || !IsControlInput(input)) {
      ++idx;
      continue;
    }
    // Do not advance: the element swapped into `idx` still has to be checked.
    node->mutable_input()->SwapElements(idx, node->input_size() - 1);
    node->mutable_input()->RemoveLast();
  }
}
362 
363 namespace {
364 
// Deletes the nodes at the given indices from `graph`.
//
// `nodes_to_delete` is expected to iterate in ascending order with no
// duplicates, as std::set<int> or a sorted, deduplicated vector does. The
// indices are visited from largest to smallest. Each doomed node is swapped
// into a tail region that grows from the end, and the tail is then removed
// with a single DeleteSubrange call. The order of the surviving nodes is not
// preserved.
template <typename UniqueContainer>
void EraseNodesFromGraphImpl(const UniqueContainer& nodes_to_delete,
                             GraphDef* graph) {
  static_assert(std::is_same<typename UniqueContainer::value_type, int>::value,
                "Need to pass container of ints");

  auto* nodes = graph->mutable_node();
  int tail_start = graph->node_size();
  for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) {
    --tail_start;
    nodes->SwapElements(*it, tail_start);
  }
  nodes->DeleteSubrange(tail_start, nodes_to_delete.size());
}
379 
// Sorts `*v` in ascending order and erases repeated elements, leaving each
// distinct value exactly once.
template <typename T>
inline void STLSortAndRemoveDuplicates(T* v) {
  auto& values = *v;
  std::sort(values.begin(), values.end());
  const auto new_end = std::unique(values.begin(), values.end());
  values.erase(new_end, values.end());
}
385 
386 }  // namespace
387 
// Erases the nodes at the indices in `nodes_to_delete` from `graph`. A
// std::set already iterates in ascending order without duplicates, so it can
// be handed to the implementation directly.
void EraseNodesFromGraph(const std::set<int>& nodes_to_delete,
                         GraphDef* graph) {
  EraseNodesFromGraphImpl<std::set<int>>(nodes_to_delete, graph);
}
392 
// Erases the nodes at the indices in `nodes_to_delete` from `graph`. The
// vector is sorted and deduplicated in place first, as the implementation
// expects.
void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph) {
  std::vector<int>& indices = nodes_to_delete;
  STLSortAndRemoveDuplicates(&indices);
  EraseNodesFromGraphImpl(indices, graph);
}
397 
// Erases every node of `graph` whose name is in `nodes_to_delete`. Names with
// no matching node are ignored. The matching indices are collected in
// ascending order, which is the order the implementation expects.
void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
                         GraphDef* graph) {
  std::vector<int> indices;
  indices.reserve(nodes_to_delete.size());
  const int num_nodes = graph->node_size();
  for (int i = 0; i < num_nodes; ++i) {
    const string& name = graph->node(i).name();
    if (nodes_to_delete.find(name) != nodes_to_delete.end()) {
      indices.push_back(i);
    }
  }
  EraseNodesFromGraphImpl(indices, graph);
}
408 
409 #define HANDLE_DOUBLE_CASE(DTYPE)                                     \
410   case DTYPE:                                                         \
411     if (!SafeSetDoubleScalarTensorValue<EnumToDataType<DTYPE>::Type>( \
412             static_cast<double>(value), tensor)) {                    \
413       return errors::InvalidArgument("Cannot store value ", value,    \
414                                      " in tensor of type " #DTYPE);   \
415     }                                                                 \
416     break
417 
418 #define HANDLE_INT_CASE(DTYPE)                                               \
419   case DTYPE:                                                                \
420     if (!SafeSetIntScalarTensorValue<EnumToDataType<DTYPE>::Type>(value,     \
421                                                                   tensor)) { \
422       return errors::InvalidArgument("Cannot store value ", value,           \
423                                      " in tensor of type " #DTYPE);          \
424     }                                                                        \
425     break
426 
SetTensorValue(DataType dtype,int value,Tensor * tensor)427 Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
428   // TODO(rmlarsen): Support more general shapes.
429   // TODO(lyandy): Change `value` to be int64 once int64 -> qint32 is supported.
430   if (tensor->NumElements() != 1) {
431     return errors::InvalidArgument(
432         "Expected scalar tensor, got num_elements = ", tensor->NumElements());
433   }
434   switch (dtype) {
435     HANDLE_DOUBLE_CASE(DT_HALF);
436     HANDLE_DOUBLE_CASE(DT_BFLOAT16);
437     HANDLE_DOUBLE_CASE(DT_BOOL);
438     HANDLE_DOUBLE_CASE(DT_FLOAT);
439     HANDLE_DOUBLE_CASE(DT_DOUBLE);
440     HANDLE_DOUBLE_CASE(DT_UINT8);
441     HANDLE_DOUBLE_CASE(DT_INT8);
442     HANDLE_DOUBLE_CASE(DT_UINT16);
443     HANDLE_DOUBLE_CASE(DT_INT16);
444     HANDLE_DOUBLE_CASE(DT_INT32);
445     HANDLE_DOUBLE_CASE(DT_INT64);
446     HANDLE_DOUBLE_CASE(DT_COMPLEX64);
447     HANDLE_DOUBLE_CASE(DT_COMPLEX128);
448     HANDLE_INT_CASE(DT_QINT8);
449     HANDLE_INT_CASE(DT_QUINT8);
450     HANDLE_INT_CASE(DT_QINT16);
451     HANDLE_INT_CASE(DT_QUINT16);
452     HANDLE_INT_CASE(DT_QINT32);
453     default:
454       return errors::InvalidArgument("Unsupported type ",
455                                      DataTypeString(dtype));
456   }
457   return OkStatus();
458 }
459 
460 #undef HANDLE_CASE
461 
// Returns OK if `node` has an attribute named `key`. Otherwise returns
// InvalidArgument naming the node and key, followed by the node's short debug
// string.
Status CheckAttrExists(const NodeDef& node, const string& key) {
  if (HasNodeAttr(node, key)) {
    return OkStatus();
  }
  return errors::InvalidArgument("Node '", node.name(), "' lacks '", key,
                                 "' attr: ", node.ShortDebugString());
}
469 
// Checks each key in `keys` with CheckAttrExists, in order, and returns the
// first failure. Returns OK if all are present (or `keys` is empty).
Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys) {
  for (size_t i = 0; i < keys.size(); ++i) {
    TF_RETURN_IF_ERROR(CheckAttrExists(node, keys[i]));
  }
  return OkStatus();
}
476 
IsKernelRegisteredForNode(absl::string_view node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info,absl::string_view node_op,absl::string_view node_device,AttrSlice node_attrs)477 Status IsKernelRegisteredForNode(
478     absl::string_view node_name, bool has_experimental_debug_info,
479     const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
480     absl::string_view node_op, absl::string_view node_device,
481     AttrSlice node_attrs) {
482   DeviceNameUtils::ParsedName parsed_name;
483   if (!DeviceNameUtils::ParseFullName(node_device, &parsed_name)) {
484     return errors::InvalidArgument("Could not parse device name: ",
485                                    node_device);
486   }
487   return FindKernelDef(DeviceType(parsed_name.type), node_name,
488                        has_experimental_debug_info, experimental_debug_info,
489                        node_op, node_device, node_attrs, nullptr, nullptr);
490 }
491 
IsKernelRegisteredForNode(const NodeDef & node)492 Status IsKernelRegisteredForNode(const NodeDef& node) {
493   return IsKernelRegisteredForNode(node.name(),
494                                    node.has_experimental_debug_info(),
495                                    node.experimental_debug_info(), node.op(),
496                                    node.device(), AttrSlice(&node.attr()));
497 }
498 
499 namespace {
// Erases the attributes named in `to_remove` from `node`. When `to_remove`
// holds as many names as `node` has attributes, the whole map is cleared
// instead of erasing key by key.
// NOTE(review): that shortcut is only correct if `to_remove` lists distinct,
// existing keys. The callers in this file build it from node->attr(), which
// presumably guarantees this.
void RemoveAttributes(const std::vector<absl::string_view>& to_remove,
                      NodeDef* node) {
  const int num_attrs = node->attr_size();
  if (to_remove.size() != static_cast<size_t>(num_attrs)) {
    for (const absl::string_view key : to_remove) {
      node->mutable_attr()->erase(string(key));
    }
    return;
  }
  node->clear_attr();
}
510 }  // namespace
511 
// Erases every attribute of `node` whose name does not start with '_'
// (attributes with an empty name are also kept). Returns the number of
// attributes erased.
int EraseRegularNodeAttributes(NodeDef* node) {
  std::vector<absl::string_view> to_remove;
  for (const auto& entry : node->attr()) {
    const string& attr_name = entry.first;
    if (attr_name.empty() || attr_name[0] == '_') {
      continue;
    }
    to_remove.push_back(attr_name);
  }
  RemoveAttributes(to_remove, node);
  return static_cast<int>(to_remove.size());
}
522 
// Erases the output-describing attributes of `node`: "_xla_inferred_shapes"
// and every attribute whose name starts with "_output_". Returns the number of
// attributes erased.
int EraseNodeOutputAttributes(NodeDef* node) {
  std::vector<absl::string_view> to_remove;
  for (const auto& entry : node->attr()) {
    const string& attr_name = entry.first;
    const bool is_output_attr = attr_name == "_xla_inferred_shapes" ||
                                absl::StartsWith(attr_name, "_output_");
    if (is_output_attr) {
      to_remove.push_back(attr_name);
    }
  }
  RemoveAttributes(to_remove, node);
  return static_cast<int>(to_remove.size());
}
535 
536 }  // end namespace grappler
537 }  // end namespace tensorflow
538