// NOTE(review): code-browser navigation chrome ("Home / Line# / Scopes /
// Navigate / Raw / Download") was scraped into this file; it is not C++ and
// has been reduced to this comment.
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/utils.h"
17 
18 #include <iterator>
19 #include <memory>
20 #include <queue>
21 #include <vector>
22 
23 #include "absl/strings/match.h"
24 #include "absl/strings/str_cat.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/node_def_util.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/op_def.pb.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/lib/core/stringpiece.h"
33 #include "tensorflow/core/lib/strings/numbers.h"
34 #include "tensorflow/core/lib/strings/scanner.h"
35 #include "tensorflow/core/lib/strings/strcat.h"
36 #include "tensorflow/core/platform/notification.h"
37 #include "tensorflow/core/util/device_name_utils.h"
38 
39 namespace tensorflow {
40 namespace grappler {
41 namespace {
// Writes `value` into scalar `tensor` as type T. Returns false (leaving the
// tensor untouched) when `value` falls outside T's representable range.
template <typename T>
bool SafeSetDoubleScalarTensorValue(double value, Tensor* tensor) {
  // Range-check against the real component's bounds; for complex types
  // Eigen::NumTraits<T>::Real is the underlying scalar type.
  using RealType = typename Eigen::NumTraits<T>::Real;
  if (value > static_cast<double>(Eigen::NumTraits<RealType>::highest()) ||
      value < static_cast<double>(Eigen::NumTraits<RealType>::lowest())) {
    return false;
  }
  tensor->flat<T>()(0) = static_cast<T>(value);
  return true;
}
52 
// Integer counterpart of SafeSetDoubleScalarTensorValue: writes `value` into
// scalar `tensor` as type T, returning false when the value is out of range.
// The comparison is done in `int` (callers below use this for quantized
// dtypes, whose ranges fit in int).
template <typename T>
bool SafeSetIntScalarTensorValue(int value, Tensor* tensor) {
  using RealType = typename Eigen::NumTraits<T>::Real;
  if (value > static_cast<int>(Eigen::NumTraits<RealType>::highest()) ||
      value < static_cast<int>(Eigen::NumTraits<RealType>::lowest())) {
    return false;
  }
  tensor->flat<T>()(0) = static_cast<T>(value);
  return true;
}
63 
64 // Is 'node' an operator that consumes only the shape of its input, not the
65 // data itself?
66 // TODO(ezhulenev): move to op_types.h. Requires to break circular dependency.
67 // TODO(ezhulenev): what about Identity passing tensor to Shape consumer?
IsShapeConsumer(const NodeDef & node)68 bool IsShapeConsumer(const NodeDef& node) {
69   const string& op = node.op();
70   return op == "Shape" || op == "ShapeN" || op == "Rank" || op == "Size";
71 }
72 
73 }  // namespace
74 
// Builds the name->node and node->fanout indices over `graph`. The graph
// must outlive this NodeMap: only raw NodeDef pointers into it are stored.
NodeMap::NodeMap(GraphDef* graph) {
  CHECK(graph != nullptr);
  for (int i = 0; i < graph->node_size(); i++) {
    NodeDef* node = graph->mutable_node(i);
    const string& node_name = node->name();
    auto rslt = nodes_.emplace(node_name, node);
    // Check that the graph doesn't contain multiple nodes with the same name.
    if (!rslt.second) {
      LOG(WARNING) << "Duplicated node in the graph: " << node_name;
    }
    for (const auto& input : node->input()) {
      // Register this node as a fanout of each input. NodeName() strips the
      // control '^' prefix and any ":port" suffix from the input string.
      outputs_[NodeName(input)].insert(nodes_[node_name]);
    }
  }
}
90 
RemoveNode(const string & name)91 void NodeMap::RemoveNode(const string& name) {
92   nodes_.erase(NodeName(name));
93   outputs_.erase(NodeName(name));
94 }
95 
GetNode(const string & name) const96 NodeDef* NodeMap::GetNode(const string& name) const {
97   const string node_name = NodeName(name);
98   auto it = nodes_.find(node_name);
99   if (it == nodes_.end()) {
100     return nullptr;
101   }
102   return it->second;
103 }
104 
NodeExists(const string & name) const105 bool NodeMap::NodeExists(const string& name) const {
106   const string node_name = NodeName(name);
107   return nodes_.find(node_name) != nodes_.end();
108 }
109 
// Returns the set of nodes consuming an output (or control signal) of
// `node_name`, or a reference to a shared empty set when none are recorded.
// The name is used verbatim (not normalized with NodeName()).
const std::set<NodeDef*>& NodeMap::GetOutputs(const string& node_name) const {
  auto it = outputs_.find(node_name);
  if (it == outputs_.end()) {
    return empty_set_;
  }
  return it->second;
}
117 
// Registers `node` under `node_name`; the node must be non-null and the name
// must not already be present (both are fatal CHECK failures). The name is
// used verbatim, without NodeName() normalization.
void NodeMap::AddNode(const string& node_name, NodeDef* node) {
  auto ret = nodes_.emplace(node_name, CHECK_NOTNULL(node));
  CHECK(ret.second) << "Pair (" << node_name << "," << node
                    << ") is not inserted because the same key already exists.";
}
123 
// Records that node `output_name` consumes an output of `node_name`. The
// output node must already be registered; dies otherwise. Note the
// operator[] lookup default-inserts a null entry into nodes_ for an unknown
// output name before the CHECK fires.
void NodeMap::AddOutput(const string& node_name, const string& output_name) {
  auto output_node = nodes_[NodeName(output_name)];
  CHECK(output_node) << "Output node " << output_name
                     << " is missing in NodeMap.";
  outputs_[node_name].insert(output_node);
}
130 
// Removes `output_name` from the fanout set of `node_name`.
// NOTE(review): operator[] is used on both maps, so passing a name that is
// absent default-inserts an entry (including a null NodeDef* in nodes_) —
// presumably callers only pass names already present; verify before relying
// on this with arbitrary input.
void NodeMap::RemoveOutput(const string& node_name, const string& output_name) {
  outputs_[node_name].erase(nodes_[NodeName(output_name)]);
}
134 
// Updates fanout bookkeeping after `node_name`'s input changed from
// `old_input_name` to `new_input_name`. Only the index is touched — the
// NodeDef's input list itself is not modified here.
void NodeMap::UpdateInput(const string& node_name, const string& old_input_name,
                          const string& new_input_name) {
  RemoveOutput(NodeName(old_input_name), node_name);
  AddOutput(NodeName(new_input_name), node_name);
}
140 
// Deregisters `node_name` from the fanout sets of all of its inputs. The
// node must already be in the map (operator[] would otherwise insert a null
// entry and the loop would dereference it).
void NodeMap::RemoveInputs(const string& node_name) {
  auto node = nodes_[node_name];
  for (const auto& input : node->input()) {
    RemoveOutput(NodeName(input), node->name());
  }
}
147 
// Drops the entire recorded fanout set of `node_name` (name used verbatim).
void NodeMap::RemoveOutputs(const string& node_name) {
  outputs_.erase(node_name);
}
151 
// Replaces `old_output_name` with `new_output_name` in the fanout set of
// `node_name`. Both output nodes should already be registered: the
// operator[] lookups on nodes_ default-insert null entries for unknown
// names.
void NodeMap::UpdateOutput(const string& node_name,
                           const string& old_output_name,
                           const string& new_output_name) {
  std::set<NodeDef*>& outputs = outputs_[node_name];
  outputs.erase(nodes_[NodeName(old_output_name)]);
  outputs.insert(nodes_[NodeName(new_output_name)]);
}
159 
// Renders `tensor_id` as an input string, eliding the port suffix for the
// default output: index 0 yields just the node name, anything else uses
// TensorId::ToString().
string TensorIdToString(const TensorId& tensor_id) {
  return tensor_id.index() == 0 ? string(tensor_id.node())
                                : tensor_id.ToString();
}
164 
IsSameInput(const string & name1,const string & name2)165 bool IsSameInput(const string& name1, const string& name2) {
166   if (name1 == name2) return true;
167   TensorId tensor1 = ParseTensorName(name1);
168   TensorId tensor2 = ParseTensorName(name2);
169   return tensor1 == tensor2;
170 }
171 
IsControlInput(const string & name)172 bool IsControlInput(const string& name) {
173   return !name.empty() && name[0] == '^';
174 }
175 
// A parsed control input is represented with a negative output index.
bool IsControlInput(const TensorId& tensor_id) { return tensor_id.index() < 0; }
177 
// Returns "<prefix><delimiter><name>", preserving a leading control '^' so
// that "^node" becomes "^<prefix><delimiter>node".
string AddPrefixToNodeName(const string& name, const string& prefix,
                           const string& delimiter) {
  if (!name.empty() && name[0] == '^') {
    // Control dependency: keep '^' in front of the prefixed name.
    return absl::StrCat("^", prefix, delimiter, name.substr(1));
  }
  return absl::StrCat(prefix, delimiter, name);
}
187 
// Convenience overload using "/" as the delimiter between prefix and name.
string AddPrefixToNodeName(const string& name, const string& prefix) {
  return AddPrefixToNodeName(name, prefix, "/");
}
191 
// Runs `fn` with a timeout. When timeout_in_ms <= 0, `fn` runs synchronously
// on the calling thread and the function returns true. Otherwise `fn` is
// scheduled on `thread_pool` and we wait up to the timeout; returns false if
// the deadline expires first. NOTE: on timeout `fn` keeps running in the
// pool — the shared_ptr keeps the Notification alive until it finishes.
bool ExecuteWithTimeout(std::function<void()> fn, const int64 timeout_in_ms,
                        thread::ThreadPool* const thread_pool) {
  if (timeout_in_ms <= 0) {
    fn();
    return true;
  }
  auto done = std::make_shared<Notification>();
  thread_pool->Schedule([done, fn]() {
    fn();
    done->Notify();
  });
  // WaitForNotificationWithTimeout takes microseconds; convert from ms.
  const bool notified =
      WaitForNotificationWithTimeout(done.get(), timeout_in_ms * 1000);
  return notified;
}
207 
// Returns the control-dependency form of `node`'s name, i.e. "^<name>".
string AsControlDependency(const NodeDef& node) {
  return absl::StrCat("^", node.name());
}
211 
// Returns the control-dependency form of `node_name`, prepending '^' unless
// it is already present. Empty names are a fatal error.
string AsControlDependency(const string& node_name) {
  CHECK(!node_name.empty());
  // The CHECK above guarantees non-emptiness, so only the leading-'^' test
  // is needed (the original redundantly re-tested !node_name.empty()).
  return node_name[0] == '^' ? node_name : absl::StrCat("^", node_name);
}
218 
// Returns true when `node`'s assigned device string parses successfully and
// its device component starts with the CPU device type.
bool NodeIsOnCpu(const NodeDef* node) {
  string task, device;
  return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
         absl::StartsWith(device, DEVICE_CPU);
}
224 
// Returns true when `node`'s assigned device string parses successfully and
// its device component starts with the GPU device type.
bool NodeIsOnGpu(const NodeDef* node) {
  string task, device;
  return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
         absl::StartsWith(device, DEVICE_GPU);
}
230 
NumOutputs(const NodeDef & node,GraphDef * graph)231 int NumOutputs(const NodeDef& node, GraphDef* graph) {
232   int num_outputs = 0;
233   const OpDef* op_def = nullptr;
234   auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
235   if (status.ok()) {
236     for (const auto& output : op_def->output_arg()) {
237       if (!output.type_list_attr().empty()) {
238         num_outputs +=
239             node.attr().at(output.type_list_attr()).list().type_size();
240       } else if (!output.number_attr().empty()) {
241         num_outputs += node.attr().at(output.number_attr()).i();
242       } else {
243         num_outputs++;
244       }
245     }
246   } else {
247     FunctionLibraryDefinition fdef(OpRegistry::Global(), graph->library());
248     auto status = fdef.LookUpOpDef(node.op(), &op_def);
249     if (status.ok()) {
250       num_outputs = op_def->output_arg_size();
251     }
252   }
253   return num_outputs;
254 }
255 
HasControlInputs(const NodeDef & node)256 bool HasControlInputs(const NodeDef& node) {
257   int num_inputs = node.input_size();
258   if (num_inputs > 0 && IsControlInput(node.input(num_inputs - 1))) {
259     return true;
260   }
261   return false;
262 }
263 
NumNonControlInputs(const NodeDef & node)264 int NumNonControlInputs(const NodeDef& node) {
265   int num_inputs = node.input_size();
266   for (const string& input : node.input()) {
267     if (IsControlInput(input)) {
268       --num_inputs;
269     }
270   }
271   return num_inputs;
272 }
273 
// Counts the data (non-control) edges from `node` to its fanouts. A fanout
// that reads k tensors of `node` contributes k to the count.
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
  int num_outputs = 0;
  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
    for (const string& node_as_input : output->input()) {
      if (IsControlInput(node_as_input)) {
        // Control inputs are expected after all data inputs, so no further
        // data edges can follow in this fanout — stop scanning it.
        break;
      }
      if (node_as_input == node.name()) {
        // Exact match: the bare-name spelling of output 0.
        ++num_outputs;
      } else {
        // Otherwise parse "name:port" and compare the node component.
        const TensorId tensor = ParseTensorName(node_as_input);
        if (tensor.node() == node.name()) {
          ++num_outputs;
        }
      }
    }
  }
  return num_outputs;
}
293 
// Counts fanouts of `node` that consume its data rather than just its shape:
// Shape/ShapeN/Rank/Size consumers are skipped, and each remaining fanout
// counts at most once regardless of how many of `node`'s outputs it reads.
int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map) {
  int num_data_outputs = 0;
  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
    if (IsShapeConsumer(*output)) continue;

    for (int i = 0; i < output->input_size(); ++i) {
      const string& input = output->input(i);
      if (!IsControlInput(input) && NodeName(input) == node.name()) {
        ++num_data_outputs;
        // One data edge is enough to count this fanout; move on.
        break;
      }
    }
  }
  return num_data_outputs;
}
309 
310 // Returns the data type in attribute `attr_name` of `node`. If that attribute
311 // doesn't exist, returns DT_INVALID.
GetDataTypeFromAttr(const NodeDef & node,const string & type_attr)312 DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr) {
313   if (!node.attr().count(type_attr)) {
314     return DT_INVALID;
315   }
316   const auto& attr = node.attr().at(type_attr);
317   if (attr.value_case() != AttrValue::kType) {
318     return DT_INVALID;
319   }
320   return attr.type();
321 }
322 
// Walks backwards from `source` along each node's first input while
// `pred_fn` holds, returning the last node on the chain that satisfies it
// (or `source` itself; pred_fn is never applied to `source`). Traversal
// stops at a node with no inputs, at a control input unless
// follow_control_input is set, at a pred_fn failure, or at a dangling input.
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
                        bool follow_control_input,
                        const std::function<bool(const NodeDef&)>& pred_fn) {
  const NodeDef* current = &source;
  const NodeDef* next = current;
  // The `next == &source` disjunct lets the body run once without testing
  // pred_fn on the source node itself.
  while (next == &source || (next != nullptr && pred_fn(*next))) {
    current = next;
    if (current->input_size() == 0 ||
        (!follow_control_input && IsControlInput(current->input(0)))) {
      break;
    }
    next = node_map.GetNode(current->input(0));
    if (next == nullptr) {
      // Dangling input: log it; the loop condition then stops at `current`.
      LOG(ERROR) << "Node not found: " << current->input(0);
    }
  }
  return const_cast<NodeDef*>(current);
}
341 
// Every permutation is a product of one or more cycles. Iterate over the cycles
// in the permutation, and convert each of those into a product of
// transpositions (swaps): https://en.wikipedia.org/wiki/Cyclic_permutation
//
// Reorders graph->node() in place so that the node previously at position i
// moves to position (*permutation)[i] (or the inverse mapping when
// invert_permutation is set). `*permutation` is consumed as scratch space;
// its contents on return are not meaningful.
void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
                         bool invert_permutation) {
  CHECK_EQ(graph->node_size(), permutation->size());
  std::vector<int> inv_perm(permutation->size(), 0);
  if (invert_permutation) {
    for (size_t n = 0; n < permutation->size(); ++n) {
      inv_perm[(*permutation)[n]] = n;
    }
    permutation->swap(inv_perm);
  }
  for (std::size_t n = 0; n + 1 < permutation->size(); ++n) {
    // Rotate the cycle containing n into place with successive swaps,
    // updating the permutation as we go so each slot is handled once.
    while (n != (*permutation)[n]) {
      std::size_t r = (*permutation)[n];
      graph->mutable_node()->SwapElements(n, r);
      std::swap((*permutation)[n], (*permutation)[r]);
    }
  }
}
363 
// Removes control inputs of `node` whose node name already appeared among
// the node's other inputs (data or control). Removal is swap-with-last +
// RemoveLast, so the relative order of trailing inputs is not preserved.
void DedupControlInputs(NodeDef* node) {
  std::unordered_set<string> inputs;
  int pos = 0;
  while (pos < node->input_size()) {
    const string& input = node->input(pos);
    // insert() fails iff this node name was already seen; only control
    // inputs are dropped on duplicates (a duplicate data input is kept).
    if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) {
      node->mutable_input()->SwapElements(pos, node->input_size() - 1);
      node->mutable_input()->RemoveLast();
      // Do not advance pos: re-examine the element just swapped into it.
    } else {
      ++pos;
    }
  }
}
377 
378 namespace {
379 
// Erases the nodes at the given indices from `graph`. `nodes_to_delete`
// must contain unique int indices in ascending iteration order: iterating
// in reverse, each doomed node is swapped to the current tail, then the
// whole tail subrange is deleted in a single call.
template <typename UniqueContainer>
void EraseNodesFromGraphImpl(const UniqueContainer& nodes_to_delete,
                             GraphDef* graph) {
  static_assert(std::is_same<typename UniqueContainer::value_type, int>::value,
                "Need to pass container of ints");

  int last = graph->node_size() - 1;
  for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) {
    const int index = *it;
    graph->mutable_node()->SwapElements(index, last);
    last--;
  }
  graph->mutable_node()->DeleteSubrange(last + 1, nodes_to_delete.size());
}
394 
// Sorts `*v` ascending and erases duplicate elements, keeping one copy of
// each distinct value.
template <typename T>
inline void STLSortAndRemoveDuplicates(T* v) {
  std::sort(v->begin(), v->end());
  const auto unique_end = std::unique(v->begin(), v->end());
  v->erase(unique_end, v->end());
}
400 
401 }  // namespace
402 
// Erases the nodes at the given indices. A std::set is already sorted and
// unique, satisfying EraseNodesFromGraphImpl's precondition directly.
void EraseNodesFromGraph(const std::set<int>& nodes_to_delete,
                         GraphDef* graph) {
  EraseNodesFromGraphImpl(nodes_to_delete, graph);
}
407 
// Erases the nodes at the given indices; the vector is taken by rvalue so it
// can be sorted and deduplicated in place before deletion.
void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph) {
  STLSortAndRemoveDuplicates(&nodes_to_delete);
  EraseNodesFromGraphImpl(nodes_to_delete, graph);
}
412 
EraseNodesFromGraph(const std::set<string> & nodes_to_delete,GraphDef * graph)413 void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
414                          GraphDef* graph) {
415   std::vector<int> nodes_idx_to_delete;
416   nodes_idx_to_delete.reserve(nodes_to_delete.size());
417   for (int i = 0; i < graph->node_size(); ++i) {
418     if (nodes_to_delete.count(graph->node(i).name()))
419       nodes_idx_to_delete.push_back(i);
420   }
421   EraseNodesFromGraphImpl(nodes_idx_to_delete, graph);
422 }
423 
424 #define HANDLE_DOUBLE_CASE(DTYPE)                                     \
425   case DTYPE:                                                         \
426     if (!SafeSetDoubleScalarTensorValue<EnumToDataType<DTYPE>::Type>( \
427             static_cast<double>(value), tensor)) {                    \
428       return errors::InvalidArgument("Cannot store value ", value,    \
429                                      " in tensor of type " #DTYPE);   \
430     }                                                                 \
431     break
432 
433 #define HANDLE_INT_CASE(DTYPE)                                               \
434   case DTYPE:                                                                \
435     if (!SafeSetIntScalarTensorValue<EnumToDataType<DTYPE>::Type>(value,     \
436                                                                   tensor)) { \
437       return errors::InvalidArgument("Cannot store value ", value,           \
438                                      " in tensor of type " #DTYPE);          \
439     }                                                                        \
440     break
441 
SetTensorValue(DataType dtype,int value,Tensor * tensor)442 Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
443   // TODO(rmlarsen): Support more general shapes.
444   // TODO(lyandy): Change `value` to be int64 once int64 -> qint32 is supported.
445   if (tensor->NumElements() != 1) {
446     return errors::InvalidArgument(
447         "Expected scalar tensor, got num_elements = ", tensor->NumElements());
448   }
449   switch (dtype) {
450     HANDLE_DOUBLE_CASE(DT_HALF);
451     HANDLE_DOUBLE_CASE(DT_BFLOAT16);
452     HANDLE_DOUBLE_CASE(DT_BOOL);
453     HANDLE_DOUBLE_CASE(DT_FLOAT);
454     HANDLE_DOUBLE_CASE(DT_DOUBLE);
455     HANDLE_DOUBLE_CASE(DT_UINT8);
456     HANDLE_DOUBLE_CASE(DT_INT8);
457     HANDLE_DOUBLE_CASE(DT_UINT16);
458     HANDLE_DOUBLE_CASE(DT_INT16);
459     HANDLE_DOUBLE_CASE(DT_INT32);
460     HANDLE_DOUBLE_CASE(DT_INT64);
461     HANDLE_DOUBLE_CASE(DT_COMPLEX64);
462     HANDLE_DOUBLE_CASE(DT_COMPLEX128);
463     HANDLE_INT_CASE(DT_QINT8);
464     HANDLE_INT_CASE(DT_QUINT8);
465     HANDLE_INT_CASE(DT_QINT16);
466     HANDLE_INT_CASE(DT_QUINT16);
467     HANDLE_INT_CASE(DT_QINT32);
468     default:
469       return errors::InvalidArgument("Unsupported type ",
470                                      DataTypeString(dtype));
471   }
472   return Status::OK();
473 }
474 
475 #undef HANDLE_CASE
476 
// Returns OK when `node` carries attribute `key`, otherwise an
// InvalidArgument error naming the node and the missing attr.
Status CheckAttrExists(const NodeDef& node, const string& key) {
  if (!HasNodeAttr(node, key)) {
    return errors::InvalidArgument("Node '", node.name(), "' lacks '", key,
                                   "' attr: ", node.ShortDebugString());
  }
  return Status::OK();
}
484 
// Returns OK when `node` carries every attribute in `keys`; otherwise
// returns the error for the first missing one.
Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys) {
  for (const string& key : keys) {
    TF_RETURN_IF_ERROR(CheckAttrExists(node, key));
  }
  return Status::OK();
}
491 
// Returns OK when a kernel is registered for `node`'s op on the device type
// parsed from its device string; InvalidArgument when the device string does
// not parse, otherwise the status from the kernel lookup.
Status IsKernelRegisteredForNode(const NodeDef& node) {
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
    return errors::InvalidArgument("Could not parse device name: ",
                                   node.device());
  }
  return FindKernelDef(DeviceType(parsed_name.type), node, nullptr, nullptr);
}
500 
501 }  // end namespace grappler
502 }  // end namespace tensorflow
503