• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/utils.h"
17 
18 #include <iterator>
19 #include <memory>
20 #include <queue>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/match.h"
25 #include "absl/strings/str_cat.h"
26 #include "tensorflow/core/framework/attr_value.pb.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/framework/op.h"
30 #include "tensorflow/core/framework/op_def.pb.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/framework/types.h"
33 #include "tensorflow/core/lib/core/stringpiece.h"
34 #include "tensorflow/core/lib/strings/numbers.h"
35 #include "tensorflow/core/lib/strings/scanner.h"
36 #include "tensorflow/core/lib/strings/strcat.h"
37 #include "tensorflow/core/platform/notification.h"
38 #include "tensorflow/core/util/device_name_utils.h"
39 
40 namespace tensorflow {
41 namespace grappler {
42 namespace {
43 template <typename T>
SafeSetDoubleScalarTensorValue(double value,Tensor * tensor)44 bool SafeSetDoubleScalarTensorValue(double value, Tensor* tensor) {
45   using RealType = typename Eigen::NumTraits<T>::Real;
46   if (value > static_cast<double>(Eigen::NumTraits<RealType>::highest()) ||
47       value < static_cast<double>(Eigen::NumTraits<RealType>::lowest())) {
48     return false;
49   }
50   tensor->flat<T>()(0) = static_cast<T>(value);
51   return true;
52 }
53 
54 template <typename T>
SafeSetIntScalarTensorValue(int value,Tensor * tensor)55 bool SafeSetIntScalarTensorValue(int value, Tensor* tensor) {
56   using RealType = typename Eigen::NumTraits<T>::Real;
57   if (value > static_cast<int>(Eigen::NumTraits<RealType>::highest()) ||
58       value < static_cast<int>(Eigen::NumTraits<RealType>::lowest())) {
59     return false;
60   }
61   tensor->flat<T>()(0) = static_cast<T>(value);
62   return true;
63 }
64 
65 // Is 'node' an operator that consumes only the shape of its input, not the
66 // data itself?
67 // TODO(ezhulenev): move to op_types.h. Requires to break circular dependency.
68 // TODO(ezhulenev): what about Identity passing tensor to Shape consumer?
IsShapeConsumer(const NodeDef & node)69 bool IsShapeConsumer(const NodeDef& node) {
70   const string& op = node.op();
71   return op == "Shape" || op == "ShapeN" || op == "Rank" || op == "Size";
72 }
73 
74 }  // namespace
75 
76 namespace internal {
77 // Specialized template class method GetNodeDefFromGraph.
78 template <>
GetNodeDefFromGraph(GraphDef * graph,int64 i) const79 NodeDef* NodeMapInternal<GraphDef, NodeDef>::GetNodeDefFromGraph(
80     GraphDef* graph, int64 i) const {
81   return graph->mutable_node(i);
82 }
83 
84 template <>
85 const NodeDef*
GetNodeDefFromGraph(const GraphDef * graph,int64 i) const86 NodeMapInternal<const GraphDef, const NodeDef>::GetNodeDefFromGraph(
87     const GraphDef* graph, int64 i) const {
88   return &graph->node(i);
89 }
90 }  // namespace internal
TensorIdToString(const TensorId & tensor_id)91 string TensorIdToString(const TensorId& tensor_id) {
92   return tensor_id.index() == 0 ? string(tensor_id.node())
93                                 : tensor_id.ToString();
94 }
95 
SafeTensorIdToString(const SafeTensorId & tensor_id)96 string SafeTensorIdToString(const SafeTensorId& tensor_id) {
97   return tensor_id.index() == 0 ? tensor_id.node() : tensor_id.ToString();
98 }
99 
IsSameInput(const string & name1,const string & name2)100 bool IsSameInput(const string& name1, const string& name2) {
101   if (name1 == name2) return true;
102   TensorId tensor1 = ParseTensorName(name1);
103   TensorId tensor2 = ParseTensorName(name2);
104   return tensor1 == tensor2;
105 }
106 
// Returns true if `name` is spelled as a control input, i.e. starts with '^'.
bool IsControlInput(const string& name) {
  if (name.empty()) return false;
  return name.front() == '^';
}
110 
IsControlInput(const TensorId & tensor_id)111 bool IsControlInput(const TensorId& tensor_id) { return tensor_id.index() < 0; }
112 
// Returns `prefix` + `delimiter` + `name`. If `name` is a control input
// ("^node"), the '^' marker stays in front: "^" + prefix + delimiter + "node".
string AddPrefixToNodeName(const string& name, const string& prefix,
                           const string& delimiter) {
  const bool is_control = !name.empty() && name[0] == '^';
  if (!is_control) {
    return absl::StrCat(prefix, delimiter, name);
  }
  return absl::StrCat("^", prefix, delimiter, name.substr(1));
}
122 
// Convenience overload that joins `prefix` and `name` with "/".
string AddPrefixToNodeName(const string& name, const string& prefix) {
  static constexpr char kDefaultDelimiter[] = "/";
  return AddPrefixToNodeName(name, prefix, kDefaultDelimiter);
}
126 
ExecuteWithTimeout(std::function<void ()> fn,const int64 timeout_in_ms,thread::ThreadPool * const thread_pool)127 bool ExecuteWithTimeout(std::function<void()> fn, const int64 timeout_in_ms,
128                         thread::ThreadPool* const thread_pool) {
129   if (timeout_in_ms <= 0) {
130     fn();
131     return true;
132   }
133   auto done = std::make_shared<Notification>();
134   thread_pool->Schedule([done, fn]() {
135     fn();
136     done->Notify();
137   });
138   const bool notified =
139       WaitForNotificationWithTimeout(done.get(), timeout_in_ms * 1000);
140   return notified;
141 }
142 
AsControlDependency(const NodeDef & node)143 string AsControlDependency(const NodeDef& node) {
144   return absl::StrCat("^", node.name());
145 }
146 
// Returns `node_name` as a control input. A name that already starts with '^'
// is returned unchanged; otherwise '^' is prepended. CHECK-fails on an empty
// name.
string AsControlDependency(const string& node_name) {
  CHECK(!node_name.empty());
  // The CHECK above guarantees that node_name[0] exists.
  if (node_name[0] == '^') return node_name;
  return absl::StrCat("^", node_name);
}
153 
NodeIsOnCpu(const NodeDef * node)154 bool NodeIsOnCpu(const NodeDef* node) {
155   string task, device;
156   return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
157          absl::StartsWith(device, DEVICE_CPU);
158 }
159 
NodeIsOnGpu(const NodeDef * node)160 bool NodeIsOnGpu(const NodeDef* node) {
161   string task, device;
162   return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
163          absl::StartsWith(device, DEVICE_GPU);
164 }
165 
NumOutputs(const NodeDef & node,GraphDef * graph)166 int NumOutputs(const NodeDef& node, GraphDef* graph) {
167   int num_outputs = 0;
168   const OpDef* op_def = nullptr;
169   auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
170   if (status.ok()) {
171     for (const auto& output : op_def->output_arg()) {
172       if (!output.type_list_attr().empty()) {
173         num_outputs +=
174             node.attr().at(output.type_list_attr()).list().type_size();
175       } else if (!output.number_attr().empty()) {
176         num_outputs += node.attr().at(output.number_attr()).i();
177       } else {
178         num_outputs++;
179       }
180     }
181   } else {
182     FunctionLibraryDefinition fdef(OpRegistry::Global(), graph->library());
183     auto status = fdef.LookUpOpDef(node.op(), &op_def);
184     if (status.ok()) {
185       num_outputs = op_def->output_arg_size();
186     }
187   }
188   return num_outputs;
189 }
190 
HasControlInputs(const NodeDef & node)191 bool HasControlInputs(const NodeDef& node) {
192   const int num_inputs = node.input_size();
193   if (num_inputs > 0 && IsControlInput(node.input(num_inputs - 1))) {
194     return true;
195   }
196   return false;
197 }
198 
HasRegularInputs(const NodeDef & node)199 bool HasRegularInputs(const NodeDef& node) {
200   const int num_inputs = node.input_size();
201   if (num_inputs > 0 && !IsControlInput(node.input(0))) {
202     return true;
203   }
204   return false;
205 }
206 
NumNonControlInputs(const NodeDef & node)207 int NumNonControlInputs(const NodeDef& node) {
208   int num_inputs = 0;
209   for (; num_inputs < node.input_size(); ++num_inputs) {
210     const string& input = node.input(num_inputs);
211     if (IsControlInput(input)) {
212       return num_inputs;
213     }
214   }
215   return num_inputs;
216 }
217 
NumControlInputs(const NodeDef & node)218 int NumControlInputs(const NodeDef& node) {
219   int num_inputs = 0;
220   for (; num_inputs < node.input_size(); ++num_inputs) {
221     const string& input = node.input(node.input_size() - num_inputs - 1);
222     if (!IsControlInput(input)) {
223       return num_inputs;
224     }
225   }
226   return num_inputs;
227 }
228 
HasRegularOutputs(const NodeDef & node,const NodeMap & node_map)229 bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map) {
230   for (const NodeDef* output : node_map.GetOutputs(node.name())) {
231     for (const string& node_as_input : output->input()) {
232       if (IsControlInput(node_as_input)) break;
233 
234       TensorId tensor = ParseTensorName(node_as_input);
235       if (tensor.node() == node.name()) {
236         return true;
237       }
238     }
239   }
240   return false;
241 }
242 
HasControlOutputs(const NodeDef & node,const NodeMap & node_map)243 bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map) {
244   for (const NodeDef* output : node_map.GetOutputs(node.name())) {
245     for (int idx = output->input_size() - 1; idx >= 0; --idx) {
246       const string& node_as_input = output->input(idx);
247       if (!IsControlInput(node_as_input)) break;
248 
249       TensorId tensor = ParseTensorName(node_as_input);
250       if (tensor.node() == node.name()) {
251         return true;
252       }
253     }
254   }
255   return false;
256 }
257 
NumControlOutputs(const NodeDef & node,const NodeMap & node_map)258 int NumControlOutputs(const NodeDef& node, const NodeMap& node_map) {
259   int num_outputs = 0;
260   for (const NodeDef* output : node_map.GetOutputs(node.name())) {
261     for (int idx = output->input_size() - 1; idx >= 0; --idx) {
262       const string& node_as_input = output->input(idx);
263       if (!IsControlInput(node_as_input)) break;
264 
265       TensorId tensor = ParseTensorName(node_as_input);
266       if (tensor.node() == node.name()) {
267         ++num_outputs;
268       }
269     }
270   }
271   return num_outputs;
272 }
273 
NumNonControlOutputs(const NodeDef & node,const NodeMap & node_map)274 int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
275   int num_outputs = 0;
276   for (const NodeDef* output : node_map.GetOutputs(node.name())) {
277     for (const string& node_as_input : output->input()) {
278       if (IsControlInput(node_as_input)) {
279         break;
280       }
281       if (node_as_input == node.name()) {
282         ++num_outputs;
283       } else {
284         const TensorId tensor = ParseTensorName(node_as_input);
285         if (tensor.node() == node.name()) {
286           ++num_outputs;
287         }
288       }
289     }
290   }
291   return num_outputs;
292 }
293 
NumNonControlDataOutputs(const NodeDef & node,const NodeMap & node_map)294 int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map) {
295   int num_data_outputs = 0;
296   for (const NodeDef* output : node_map.GetOutputs(node.name())) {
297     if (IsShapeConsumer(*output)) continue;
298 
299     for (int i = 0; i < output->input_size(); ++i) {
300       const string& input = output->input(i);
301       if (!IsControlInput(input) && NodeName(input) == node.name()) {
302         ++num_data_outputs;
303         break;
304       }
305     }
306   }
307   return num_data_outputs;
308 }
309 
310 // Returns the data type in attribute `attr_name` of `node`. If that attribute
311 // doesn't exist, returns DT_INVALID.
GetDataTypeFromAttr(const NodeDef & node,const string & type_attr)312 DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr) {
313   if (!node.attr().count(type_attr)) {
314     return DT_INVALID;
315   }
316   const auto& attr = node.attr().at(type_attr);
317   if (attr.value_case() != AttrValue::kType) {
318     return DT_INVALID;
319   }
320   return attr.type();
321 }
322 
// Walks from `source` along first inputs (input 0) and returns the last node
// of the chain.
//
// `source` is always part of the chain. Each further node joins the chain only
// if `pred_fn` accepts it; arriving back at `source` is accepted without
// consulting `pred_fn`. The walk stops at:
// - a node with no inputs;
// - a node whose first input is a control input, when `follow_control_input`
//   is false;
// - a fanin rejected by `pred_fn`;
// - a fanin missing from `node_map` (logged as an error).
// NOTE(review): there is no cycle detection. A cycle whose nodes other than
// `source` all satisfy `pred_fn` would never terminate.
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
                        bool follow_control_input,
                        const std::function<bool(const NodeDef&)>& pred_fn) {
  const NodeDef* tail = &source;
  while (tail->input_size() > 0 &&
         (follow_control_input || !IsControlInput(tail->input(0)))) {
    const string& fanin_name = tail->input(0);
    const NodeDef* fanin = node_map.GetNode(fanin_name);
    if (fanin == nullptr) {
      LOG(ERROR) << "Node not found: " << fanin_name;
      break;
    }
    if (fanin != &source && !pred_fn(*fanin)) break;
    tail = fanin;
  }
  return const_cast<NodeDef*>(tail);
}
341 
342 // Every permutation is a product of one or more cycles. Iterate over the cycles
343 // in the permutation, and convert each of those into a product of
344 // transpositions (swaps): https://en.wikipedia.org/wiki/Cyclic_permutation
PermuteNodesInPlace(GraphDef * graph,std::vector<int> * permutation,bool invert_permutation)345 void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
346                          bool invert_permutation) {
347   CHECK_EQ(graph->node_size(), permutation->size());
348   std::vector<int> inv_perm(permutation->size(), 0);
349   if (invert_permutation) {
350     for (size_t n = 0; n < permutation->size(); ++n) {
351       inv_perm[(*permutation)[n]] = n;
352     }
353     permutation->swap(inv_perm);
354   }
355   for (int n = 0, end = permutation->size(); n + 1 < end; ++n) {
356     while (n != (*permutation)[n]) {
357       std::size_t r = (*permutation)[n];
358       graph->mutable_node()->SwapElements(n, r);
359       std::swap((*permutation)[n], (*permutation)[r]);
360     }
361   }
362 }
363 
DedupControlInputs(NodeDef * node)364 void DedupControlInputs(NodeDef* node) {
365   absl::flat_hash_set<string> inputs;
366   int pos = 0;
367   while (pos < node->input_size()) {
368     const string& input = node->input(pos);
369     if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) {
370       node->mutable_input()->SwapElements(pos, node->input_size() - 1);
371       node->mutable_input()->RemoveLast();
372     } else {
373       ++pos;
374     }
375   }
376 }
377 
378 namespace {
379 
380 template <typename UniqueContainer>
EraseNodesFromGraphImpl(const UniqueContainer & nodes_to_delete,GraphDef * graph)381 void EraseNodesFromGraphImpl(const UniqueContainer& nodes_to_delete,
382                              GraphDef* graph) {
383   static_assert(std::is_same<typename UniqueContainer::value_type, int>::value,
384                 "Need to pass container of ints");
385 
386   int last = graph->node_size() - 1;
387   for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) {
388     const int index = *it;
389     graph->mutable_node()->SwapElements(index, last);
390     last--;
391   }
392   graph->mutable_node()->DeleteSubrange(last + 1, nodes_to_delete.size());
393 }
394 
// Sorts `*v` in ascending order and removes repeated elements, leaving each
// distinct value exactly once.
template <typename T>
inline void STLSortAndRemoveDuplicates(T* v) {
  std::sort(v->begin(), v->end());
  const auto new_end = std::unique(v->begin(), v->end());
  v->erase(new_end, v->end());
}
400 
401 }  // namespace
402 
EraseNodesFromGraph(const std::set<int> & nodes_to_delete,GraphDef * graph)403 void EraseNodesFromGraph(const std::set<int>& nodes_to_delete,
404                          GraphDef* graph) {
405   EraseNodesFromGraphImpl(nodes_to_delete, graph);
406 }
407 
EraseNodesFromGraph(std::vector<int> && nodes_to_delete,GraphDef * graph)408 void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph) {
409   STLSortAndRemoveDuplicates(&nodes_to_delete);
410   EraseNodesFromGraphImpl(nodes_to_delete, graph);
411 }
412 
EraseNodesFromGraph(const std::set<string> & nodes_to_delete,GraphDef * graph)413 void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
414                          GraphDef* graph) {
415   std::vector<int> nodes_idx_to_delete;
416   nodes_idx_to_delete.reserve(nodes_to_delete.size());
417   for (int i = 0; i < graph->node_size(); ++i) {
418     if (nodes_to_delete.count(graph->node(i).name()))
419       nodes_idx_to_delete.push_back(i);
420   }
421   EraseNodesFromGraphImpl(nodes_idx_to_delete, graph);
422 }
423 
424 #define HANDLE_DOUBLE_CASE(DTYPE)                                     \
425   case DTYPE:                                                         \
426     if (!SafeSetDoubleScalarTensorValue<EnumToDataType<DTYPE>::Type>( \
427             static_cast<double>(value), tensor)) {                    \
428       return errors::InvalidArgument("Cannot store value ", value,    \
429                                      " in tensor of type " #DTYPE);   \
430     }                                                                 \
431     break
432 
433 #define HANDLE_INT_CASE(DTYPE)                                               \
434   case DTYPE:                                                                \
435     if (!SafeSetIntScalarTensorValue<EnumToDataType<DTYPE>::Type>(value,     \
436                                                                   tensor)) { \
437       return errors::InvalidArgument("Cannot store value ", value,           \
438                                      " in tensor of type " #DTYPE);          \
439     }                                                                        \
440     break
441 
SetTensorValue(DataType dtype,int value,Tensor * tensor)442 Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
443   // TODO(rmlarsen): Support more general shapes.
444   // TODO(lyandy): Change `value` to be int64 once int64 -> qint32 is supported.
445   if (tensor->NumElements() != 1) {
446     return errors::InvalidArgument(
447         "Expected scalar tensor, got num_elements = ", tensor->NumElements());
448   }
449   switch (dtype) {
450     HANDLE_DOUBLE_CASE(DT_HALF);
451     HANDLE_DOUBLE_CASE(DT_BFLOAT16);
452     HANDLE_DOUBLE_CASE(DT_BOOL);
453     HANDLE_DOUBLE_CASE(DT_FLOAT);
454     HANDLE_DOUBLE_CASE(DT_DOUBLE);
455     HANDLE_DOUBLE_CASE(DT_UINT8);
456     HANDLE_DOUBLE_CASE(DT_INT8);
457     HANDLE_DOUBLE_CASE(DT_UINT16);
458     HANDLE_DOUBLE_CASE(DT_INT16);
459     HANDLE_DOUBLE_CASE(DT_INT32);
460     HANDLE_DOUBLE_CASE(DT_INT64);
461     HANDLE_DOUBLE_CASE(DT_COMPLEX64);
462     HANDLE_DOUBLE_CASE(DT_COMPLEX128);
463     HANDLE_INT_CASE(DT_QINT8);
464     HANDLE_INT_CASE(DT_QUINT8);
465     HANDLE_INT_CASE(DT_QINT16);
466     HANDLE_INT_CASE(DT_QUINT16);
467     HANDLE_INT_CASE(DT_QINT32);
468     default:
469       return errors::InvalidArgument("Unsupported type ",
470                                      DataTypeString(dtype));
471   }
472   return Status::OK();
473 }
474 
475 #undef HANDLE_CASE
476 
CheckAttrExists(const NodeDef & node,const string & key)477 Status CheckAttrExists(const NodeDef& node, const string& key) {
478   if (!HasNodeAttr(node, key)) {
479     return errors::InvalidArgument("Node '", node.name(), "' lacks '", key,
480                                    "' attr: ", node.ShortDebugString());
481   }
482   return Status::OK();
483 }
484 
CheckAttrsExist(const NodeDef & node,absl::Span<const string> keys)485 Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys) {
486   for (const string& key : keys) {
487     TF_RETURN_IF_ERROR(CheckAttrExists(node, key));
488   }
489   return Status::OK();
490 }
491 
IsKernelRegisteredForNode(absl::string_view node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info,absl::string_view node_op,absl::string_view node_device,AttrSlice node_attrs)492 Status IsKernelRegisteredForNode(
493     absl::string_view node_name, bool has_experimental_debug_info,
494     const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
495     absl::string_view node_op, absl::string_view node_device,
496     AttrSlice node_attrs) {
497   DeviceNameUtils::ParsedName parsed_name;
498   if (!DeviceNameUtils::ParseFullName(node_device, &parsed_name)) {
499     return errors::InvalidArgument("Could not parse device name: ",
500                                    node_device);
501   }
502   return FindKernelDef(DeviceType(parsed_name.type), node_name,
503                        has_experimental_debug_info, experimental_debug_info,
504                        node_op, node_device, node_attrs, nullptr, nullptr);
505 }
506 
IsKernelRegisteredForNode(const NodeDef & node)507 Status IsKernelRegisteredForNode(const NodeDef& node) {
508   return IsKernelRegisteredForNode(node.name(),
509                                    node.has_experimental_debug_info(),
510                                    node.experimental_debug_info(), node.op(),
511                                    node.device(), AttrSlice(&node.attr()));
512 }
513 
514 namespace {
RemoveAttributes(const std::vector<absl::string_view> & to_remove,NodeDef * node)515 void RemoveAttributes(const std::vector<absl::string_view>& to_remove,
516                       NodeDef* node) {
517   if (to_remove.size() == node->attr_size()) {
518     node->clear_attr();
519   } else {
520     for (const auto& key : to_remove) {
521       node->mutable_attr()->erase(string(key));
522     }
523   }
524 }
525 }  // namespace
526 
EraseRegularNodeAttributes(NodeDef * node)527 int EraseRegularNodeAttributes(NodeDef* node) {
528   std::vector<absl::string_view> to_remove;
529   for (const auto& attr : node->attr()) {
530     if (!attr.first.empty() && (attr.first)[0] != '_') {
531       to_remove.push_back(attr.first);
532     }
533   }
534   RemoveAttributes(to_remove, node);
535   return to_remove.size();
536 }
537 
EraseNodeOutputAttributes(NodeDef * node)538 int EraseNodeOutputAttributes(NodeDef* node) {
539   std::vector<absl::string_view> to_remove;
540   for (const auto& attr : node->attr()) {
541     const string& attr_name = attr.first;
542     if (attr_name == "_xla_inferred_shapes" ||
543         absl::StartsWith(attr_name, "_output_")) {
544       to_remove.push_back(attr_name);
545     }
546   }
547   RemoveAttributes(to_remove, node);
548   return to_remove.size();
549 }
550 
551 }  // end namespace grappler
552 }  // end namespace tensorflow
553