/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/utils.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <queue>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {
namespace {
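// Stores `value` into the scalar `tensor` of type T, returning false if
// `value` falls outside T's representable range.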
template <typename T>
bool SafeSetDoubleScalarTensorValue(double value, Tensor* tensor) {
  using RealType = typename Eigen::NumTraits<T>::Real;
  if (value > static_cast<double>(Eigen::NumTraits<RealType>::highest()) ||
      value < static_cast<double>(Eigen::NumTraits<RealType>::lowest())) {
    return false;
  }
  tensor->flat<T>()(0) = static_cast<T>(value);
  return true;
}

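// Integer counterpart of SafeSetDoubleScalarTensorValue; range-checks `value`
// against T's limits in int arithmetic (used for the quantized types).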
template <typename T>
bool SafeSetIntScalarTensorValue(int value, Tensor* tensor) {
  using RealType = typename Eigen::NumTraits<T>::Real;
  if (value > static_cast<int>(Eigen::NumTraits<RealType>::highest()) ||
      value < static_cast<int>(Eigen::NumTraits<RealType>::lowest())) {
    return false;
  }
  tensor->flat<T>()(0) = static_cast<T>(value);
  return true;
}

// Is 'node' an operator that consumes only the shape of its input, not the
// data itself?
// TODO(ezhulenev): move to op_types.h. Requires breaking a circular
// dependency.
// TODO(ezhulenev): what about Identity passing tensor to Shape consumer?
bool IsShapeConsumer(const NodeDef& node) {
  const string& op = node.op();
  return op == "Shape" || op == "ShapeN" || op == "Rank" || op == "Size";
}

}  // namespace

namespace internal {
// Template specializations of NodeMapInternal<>::GetNodeDefFromGraph for
// mutable and const graphs.
template <>
NodeDef* NodeMapInternal<GraphDef, NodeDef>::GetNodeDefFromGraph(
    GraphDef* graph, int64 i) const {
  return graph->mutable_node(i);
}

template <>
const NodeDef*
NodeMapInternal<const GraphDef, const NodeDef>::GetNodeDefFromGraph(
    const GraphDef* graph, int64 i) const {
  return &graph->node(i);
}
}  // namespace internal

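// Formats a TensorId as it appears in a NodeDef input: output index 0 is
// rendered as the bare node name, any other index as "node:index".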
string TensorIdToString(const TensorId& tensor_id) {
  return tensor_id.index() == 0 ? string(tensor_id.node())
                                : tensor_id.ToString();
}

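// As above, for SafeTensorId.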
string SafeTensorIdToString(const SafeTensorId& tensor_id) {
  return tensor_id.index() == 0 ? tensor_id.node() : tensor_id.ToString();
}

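// Returns true if the two input names refer to the same tensor, e.g. "foo"
// and "foo:0".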
bool IsSameInput(const string& name1, const string& name2) {
  if (name1 == name2) return true;
  TensorId tensor1 = ParseTensorName(name1);
  TensorId tensor2 = ParseTensorName(name2);
  return tensor1 == tensor2;
}

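// Returns true if the input denotes a control dependency: a leading '^' in
// the string form, a negative output index in the TensorId form.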
bool IsControlInput(const string& name) {
  return !name.empty() && name[0] == '^';
}

bool IsControlInput(const TensorId& tensor_id) { return tensor_id.index() < 0; }

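// Returns `name` prefixed with `prefix` and `delimiter`, keeping a leading
// control-dependency '^' (if any) at the front of the result.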
string AddPrefixToNodeName(const string& name, const string& prefix,
                           const string& delimiter) {
  if (!name.empty()) {
    if (name[0] == '^') {
      return absl::StrCat("^", prefix, delimiter, name.substr(1));
    }
  }
  return absl::StrCat(prefix, delimiter, name);
}

string AddPrefixToNodeName(const string& name, const string& prefix) {
  return AddPrefixToNodeName(name, prefix, "/");
}

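// Executes `fn` on `thread_pool`, returning true if it completed within
// `timeout_in_ms`. A non-positive timeout runs `fn` synchronously and always
// returns true. On timeout `fn` keeps running on the pool; the shared
// Notification keeps its state alive until `fn` finishes.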
bool ExecuteWithTimeout(std::function<void()> fn, const int64 timeout_in_ms,
                        thread::ThreadPool* const thread_pool) {
  if (timeout_in_ms <= 0) {
    fn();
    return true;
  }
  auto done = std::make_shared<Notification>();
  thread_pool->Schedule([done, fn]() {
    fn();
    done->Notify();
  });
  const bool notified =
      WaitForNotificationWithTimeout(done.get(), timeout_in_ms * 1000);
  return notified;
}

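// Returns the control-dependency form of a node name, i.e. "^name". The
// string overload is a no-op if `node_name` is already in that form.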
string AsControlDependency(const NodeDef& node) {
  return absl::StrCat("^", node.name());
}

string AsControlDependency(const string& node_name) {
  CHECK(!node_name.empty());
  return (!node_name.empty() && node_name[0] == '^')
             ? node_name
             : absl::StrCat("^", node_name);
}

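// Returns true if the node is placed on a CPU (resp. GPU), based on the
// device string recorded in the NodeDef.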
bool NodeIsOnCpu(const NodeDef* node) {
  string task, device;
  return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
         absl::StartsWith(device, DEVICE_CPU);
}

bool NodeIsOnGpu(const NodeDef* node) {
  string task, device;
  return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
         absl::StartsWith(device, DEVICE_GPU);
}

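// Returns the number of outputs of `node`, expanding list(type) and
// number_attr output args from the node's attrs. For ops not in the global
// registry, falls back to the graph's function library, where each function
// output arg contributes exactly one output.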
int NumOutputs(const NodeDef& node, GraphDef* graph) {
  int num_outputs = 0;
  const OpDef* op_def = nullptr;
  auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
  if (status.ok()) {
    for (const auto& output : op_def->output_arg()) {
      if (!output.type_list_attr().empty()) {
        num_outputs +=
            node.attr().at(output.type_list_attr()).list().type_size();
      } else if (!output.number_attr().empty()) {
        num_outputs += node.attr().at(output.number_attr()).i();
      } else {
        num_outputs++;
      }
    }
  } else {
    FunctionLibraryDefinition fdef(OpRegistry::Global(), graph->library());
    auto status = fdef.LookUpOpDef(node.op(), &op_def);
    if (status.ok()) {
      num_outputs = op_def->output_arg_size();
    }
  }
  return num_outputs;
}

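// In a valid NodeDef all control inputs come after the regular inputs, so the
// presence checks below only need to inspect one end of the input list.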
bool HasControlInputs(const NodeDef& node) {
  const int num_inputs = node.input_size();
  if (num_inputs > 0 && IsControlInput(node.input(num_inputs - 1))) {
    return true;
  }
  return false;
}

bool HasRegularInputs(const NodeDef& node) {
  const int num_inputs = node.input_size();
  if (num_inputs > 0 && !IsControlInput(node.input(0))) {
    return true;
  }
  return false;
}

int NumNonControlInputs(const NodeDef& node) {
  int num_inputs = 0;
  for (; num_inputs < node.input_size(); ++num_inputs) {
    const string& input = node.input(num_inputs);
    if (IsControlInput(input)) {
      return num_inputs;
    }
  }
  return num_inputs;
}

int NumControlInputs(const NodeDef& node) {
  int num_inputs = 0;
  for (; num_inputs < node.input_size(); ++num_inputs) {
    const string& input = node.input(node.input_size() - num_inputs - 1);
    if (!IsControlInput(input)) {
      return num_inputs;
    }
  }
  return num_inputs;
}

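// The output-side helpers below scan the fanout recorded in `node_map` and
// rely on the same ordering invariant when classifying each consumer's edge
// as regular or control.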
bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map) {
  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
    for (const string& node_as_input : output->input()) {
      if (IsControlInput(node_as_input)) break;

      TensorId tensor = ParseTensorName(node_as_input);
      if (tensor.node() == node.name()) {
        return true;
      }
    }
  }
  return false;
}

bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map) {
  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
    for (int idx = output->input_size() - 1; idx >= 0; --idx) {
      const string& node_as_input = output->input(idx);
      if (!IsControlInput(node_as_input)) break;

      TensorId tensor = ParseTensorName(node_as_input);
      if (tensor.node() == node.name()) {
        return true;
      }
    }
  }
  return false;
}

int NumControlOutputs(const NodeDef& node, const NodeMap& node_map) {
  int num_outputs = 0;
  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
    for (int idx = output->input_size() - 1; idx >= 0; --idx) {
      const string& node_as_input = output->input(idx);
      if (!IsControlInput(node_as_input)) break;

      TensorId tensor = ParseTensorName(node_as_input);
      if (tensor.node() == node.name()) {
        ++num_outputs;
      }
    }
  }
  return num_outputs;
}

int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
  int num_outputs = 0;
  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
    for (const string& node_as_input : output->input()) {
      if (IsControlInput(node_as_input)) {
        break;
      }
      if (node_as_input == node.name()) {
        ++num_outputs;
      } else {
        const TensorId tensor = ParseTensorName(node_as_input);
        if (tensor.node() == node.name()) {
          ++num_outputs;
        }
      }
    }
  }
  return num_outputs;
}

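// Like NumNonControlOutputs, but counts consumers rather than edges, and
// skips consumers that only read the shape, rank, or size of `node`.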
int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map) {
  int num_data_outputs = 0;
  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
    if (IsShapeConsumer(*output)) continue;

    for (int i = 0; i < output->input_size(); ++i) {
      const string& input = output->input(i);
      if (!IsControlInput(input) && NodeName(input) == node.name()) {
        ++num_data_outputs;
        break;
      }
    }
  }
  return num_data_outputs;
}

// Returns the data type in attribute `type_attr` of `node`. If that attribute
// doesn't exist, returns DT_INVALID.
DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr) {
  if (!node.attr().count(type_attr)) {
    return DT_INVALID;
  }
  const auto& attr = node.attr().at(type_attr);
  if (attr.value_case() != AttrValue::kType) {
    return DT_INVALID;
  }
  return attr.type();
}

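// Starting from `source`, walks up the chain of first inputs for as long as
// `pred_fn` holds, and returns the last node reached (the tail of the chain).
// The walk stops at nodes without inputs and, unless `follow_control_input`
// is set, at control inputs; a missing input node is logged and ends the walk.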
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
                        bool follow_control_input,
                        const std::function<bool(const NodeDef&)>& pred_fn) {
  const NodeDef* current = &source;
  const NodeDef* next = current;
  while (next == &source || (next != nullptr && pred_fn(*next))) {
    current = next;
    if (current->input_size() == 0 ||
        (!follow_control_input && IsControlInput(current->input(0)))) {
      break;
    }
    next = node_map.GetNode(current->input(0));
    if (next == nullptr) {
      LOG(ERROR) << "Node not found: " << current->input(0);
    }
  }
  return const_cast<NodeDef*>(current);
}

// Every permutation is a product of one or more cycles. Iterate over the
// cycles in the permutation, and convert each of those into a product of
// transpositions (swaps): https://en.wikipedia.org/wiki/Cyclic_permutation
void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
                         bool invert_permutation) {
  CHECK_EQ(graph->node_size(), permutation->size());
  std::vector<int> inv_perm(permutation->size(), 0);
  if (invert_permutation) {
    for (size_t n = 0; n < permutation->size(); ++n) {
      inv_perm[(*permutation)[n]] = n;
    }
    permutation->swap(inv_perm);
  }
  for (int n = 0, end = permutation->size(); n + 1 < end; ++n) {
    while (n != (*permutation)[n]) {
      std::size_t r = (*permutation)[n];
      graph->mutable_node()->SwapElements(n, r);
      std::swap((*permutation)[n], (*permutation)[r]);
    }
  }
}

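// Removes duplicated control inputs from `node` in place. A control input is
// dropped if its node name was already seen earlier in the input list;
// removal swaps with the last input, so the relative order of the remaining
// inputs may change.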
void DedupControlInputs(NodeDef* node) {
  absl::flat_hash_set<string> inputs;
  int pos = 0;
  while (pos < node->input_size()) {
    const string& input = node->input(pos);
    if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) {
      node->mutable_input()->SwapElements(pos, node->input_size() - 1);
      node->mutable_input()->RemoveLast();
    } else {
      ++pos;
    }
  }
}

namespace {

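// Deletes the nodes at the given indices from `graph`. `nodes_to_delete` must
// be sorted in ascending order and free of duplicates: each doomed node is
// swapped to the back of the array (iterating from the largest index down, so
// a pending index never points into the already-swapped tail), then the tail
// is removed with a single DeleteSubrange call.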
template <typename UniqueContainer>
void EraseNodesFromGraphImpl(const UniqueContainer& nodes_to_delete,
                             GraphDef* graph) {
  static_assert(std::is_same<typename UniqueContainer::value_type, int>::value,
                "Need to pass container of ints");

  int last = graph->node_size() - 1;
  for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) {
    const int index = *it;
    graph->mutable_node()->SwapElements(index, last);
    last--;
  }
  graph->mutable_node()->DeleteSubrange(last + 1, nodes_to_delete.size());
}

template <typename T>
inline void STLSortAndRemoveDuplicates(T* v) {
  std::sort(v->begin(), v->end());
  v->erase(std::unique(v->begin(), v->end()), v->end());
}

}  // namespace

void EraseNodesFromGraph(const std::set<int>& nodes_to_delete,
                         GraphDef* graph) {
  EraseNodesFromGraphImpl(nodes_to_delete, graph);
}

void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph) {
  STLSortAndRemoveDuplicates(&nodes_to_delete);
  EraseNodesFromGraphImpl(nodes_to_delete, graph);
}

void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
                         GraphDef* graph) {
  std::vector<int> nodes_idx_to_delete;
  nodes_idx_to_delete.reserve(nodes_to_delete.size());
  for (int i = 0; i < graph->node_size(); ++i) {
    if (nodes_to_delete.count(graph->node(i).name()))
      nodes_idx_to_delete.push_back(i);
  }
  EraseNodesFromGraphImpl(nodes_idx_to_delete, graph);
}

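// Each HANDLE_*_CASE macro expands to a switch case that stores `value` into
// `tensor` as the given DTYPE via the matching SafeSet*ScalarTensorValue
// helper, returning InvalidArgument if the value is out of range.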
#define HANDLE_DOUBLE_CASE(DTYPE)                                     \
  case DTYPE:                                                         \
    if (!SafeSetDoubleScalarTensorValue<EnumToDataType<DTYPE>::Type>( \
            static_cast<double>(value), tensor)) {                    \
      return errors::InvalidArgument("Cannot store value ", value,    \
                                     " in tensor of type " #DTYPE);   \
    }                                                                 \
    break

#define HANDLE_INT_CASE(DTYPE)                                               \
  case DTYPE:                                                                \
    if (!SafeSetIntScalarTensorValue<EnumToDataType<DTYPE>::Type>(value,     \
                                                                  tensor)) { \
      return errors::InvalidArgument("Cannot store value ", value,           \
                                     " in tensor of type " #DTYPE);          \
    }                                                                        \
    break

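// Sets the scalar `tensor` to `value`, range-checking against `dtype`.
// Returns InvalidArgument for non-scalar tensors, out-of-range values, and
// unsupported dtypes.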
Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
  // TODO(rmlarsen): Support more general shapes.
  // TODO(lyandy): Change `value` to be int64 once int64 -> qint32 is supported.
  if (tensor->NumElements() != 1) {
    return errors::InvalidArgument(
        "Expected scalar tensor, got num_elements = ", tensor->NumElements());
  }
  switch (dtype) {
    HANDLE_DOUBLE_CASE(DT_HALF);
    HANDLE_DOUBLE_CASE(DT_BFLOAT16);
    HANDLE_DOUBLE_CASE(DT_BOOL);
    HANDLE_DOUBLE_CASE(DT_FLOAT);
    HANDLE_DOUBLE_CASE(DT_DOUBLE);
    HANDLE_DOUBLE_CASE(DT_UINT8);
    HANDLE_DOUBLE_CASE(DT_INT8);
    HANDLE_DOUBLE_CASE(DT_UINT16);
    HANDLE_DOUBLE_CASE(DT_INT16);
    HANDLE_DOUBLE_CASE(DT_INT32);
    HANDLE_DOUBLE_CASE(DT_INT64);
    HANDLE_DOUBLE_CASE(DT_COMPLEX64);
    HANDLE_DOUBLE_CASE(DT_COMPLEX128);
    HANDLE_INT_CASE(DT_QINT8);
    HANDLE_INT_CASE(DT_QUINT8);
    HANDLE_INT_CASE(DT_QINT16);
    HANDLE_INT_CASE(DT_QUINT16);
    HANDLE_INT_CASE(DT_QINT32);
    default:
      return errors::InvalidArgument("Unsupported type ",
                                     DataTypeString(dtype));
  }
  return Status::OK();
}

#undef HANDLE_DOUBLE_CASE
#undef HANDLE_INT_CASE

Status CheckAttrExists(const NodeDef& node, const string& key) {
  if (!HasNodeAttr(node, key)) {
    return errors::InvalidArgument("Node '", node.name(), "' lacks '", key,
                                   "' attr: ", node.ShortDebugString());
  }
  return Status::OK();
}

Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys) {
  for (const string& key : keys) {
    TF_RETURN_IF_ERROR(CheckAttrExists(node, key));
  }
  return Status::OK();
}

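// Returns OK if a kernel implementing `node_op` with the given attrs is
// registered for the device type parsed from `node_device`, and a non-OK
// status otherwise.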
Status IsKernelRegisteredForNode(
    absl::string_view node_name, bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
    absl::string_view node_op, absl::string_view node_device,
    AttrSlice node_attrs) {
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(node_device, &parsed_name)) {
    return errors::InvalidArgument("Could not parse device name: ",
                                   node_device);
  }
  return FindKernelDef(DeviceType(parsed_name.type), node_name,
                       has_experimental_debug_info, experimental_debug_info,
                       node_op, node_device, node_attrs, nullptr, nullptr);
}

Status IsKernelRegisteredForNode(const NodeDef& node) {
  return IsKernelRegisteredForNode(node.name(),
                                   node.has_experimental_debug_info(),
                                   node.experimental_debug_info(), node.op(),
                                   node.device(), AttrSlice(&node.attr()));
}

namespace {
// Removes the attributes named in `to_remove` from `node`; clears the whole
// attr map at once when every attribute is being removed.
void RemoveAttributes(const std::vector<absl::string_view>& to_remove,
                      NodeDef* node) {
  if (to_remove.size() == node->attr_size()) {
    node->clear_attr();
  } else {
    for (const auto& key : to_remove) {
      node->mutable_attr()->erase(string(key));
    }
  }
}
}  // namespace

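// Removes all attributes whose names do not start with '_' from `node`,
// returning the number of attributes removed.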
int EraseRegularNodeAttributes(NodeDef* node) {
  std::vector<absl::string_view> to_remove;
  for (const auto& attr : node->attr()) {
    if (!attr.first.empty() && (attr.first)[0] != '_') {
      to_remove.push_back(attr.first);
    }
  }
  RemoveAttributes(to_remove, node);
  return to_remove.size();
}

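// Removes the "_xla_inferred_shapes" attribute and all "_output_*" attributes
// from `node`, returning the number of attributes removed.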
int EraseNodeOutputAttributes(NodeDef* node) {
  std::vector<absl::string_view> to_remove;
  for (const auto& attr : node->attr()) {
    const string& attr_name = attr.first;
    if (attr_name == "_xla_inferred_shapes" ||
        absl::StartsWith(attr_name, "_output_")) {
      to_remove.push_back(attr_name);
    }
  }
  RemoveAttributes(to_remove, node);
  return to_remove.size();
}

}  // end namespace grappler
}  // end namespace tensorflow