/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/utils.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <queue>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {
namespace {
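// The two helpers below store a scalar `value` into `tensor`, first checking
// that the value is representable in the target type T. They return false and
// leave the tensor untouched if the value falls outside T's range.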
template <typename T>
bool SafeSetDoubleScalarTensorValue(double value, Tensor* tensor) {
  using RealType = typename Eigen::NumTraits<T>::Real;
  if (value > static_cast<double>(Eigen::NumTraits<RealType>::highest()) ||
      value < static_cast<double>(Eigen::NumTraits<RealType>::lowest())) {
    return false;
  }
  tensor->flat<T>()(0) = static_cast<T>(value);
  return true;
}

template <typename T>
bool SafeSetIntScalarTensorValue(int value, Tensor* tensor) {
  using RealType = typename Eigen::NumTraits<T>::Real;
  if (value > static_cast<int>(Eigen::NumTraits<RealType>::highest()) ||
      value < static_cast<int>(Eigen::NumTraits<RealType>::lowest())) {
    return false;
  }
  tensor->flat<T>()(0) = static_cast<T>(value);
  return true;
}

// Is 'node' an operator that consumes only the shape of its input, not the
// data itself?
// TODO(ezhulenev): Move this to op_types.h; doing so requires breaking a
// circular dependency.
// TODO(ezhulenev): What about an Identity node passing a tensor to a Shape
// consumer?
bool IsShapeConsumer(const NodeDef& node) {
  const string& op = node.op();
  return op == "Shape" || op == "ShapeN" || op == "Rank" || op == "Size";
}

}  // namespace

string TensorIdToString(const TensorId& tensor_id) {
  return tensor_id.index() == 0 ? string(tensor_id.node())
                                : tensor_id.ToString();
}

string SafeTensorIdToString(const SafeTensorId& tensor_id) {
  return tensor_id.index() == 0 ? tensor_id.node() : tensor_id.ToString();
}

bool IsSameInput(const string& name1, const string& name2) {
  if (name1 == name2) return true;
  TensorId tensor1 = ParseTensorName(name1);
  TensorId tensor2 = ParseTensorName(name2);
  return tensor1 == tensor2;
}

bool IsControlInput(absl::string_view name) {
  return !name.empty() && name[0] == '^';
}

bool IsControlInput(const TensorId& tensor_id) { return tensor_id.index() < 0; }

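// Prepends `prefix` and `delimiter` to `name`. A leading control-dependency
// marker ('^') is kept in front of the result, e.g.
// AddPrefixToNodeName("^foo", "bar", "/") returns "^bar/foo".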
string AddPrefixToNodeName(const string& name, const string& prefix,
                           const string& delimiter) {
  if (!name.empty()) {
    if (name[0] == '^') {
      return absl::StrCat("^", prefix, delimiter, name.substr(1));
    }
  }
  return absl::StrCat(prefix, delimiter, name);
}

string AddPrefixToNodeName(const string& name, const string& prefix) {
  return AddPrefixToNodeName(name, prefix, "/");
}

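// Schedules `fn` on `thread_pool` and waits up to `timeout_in_ms` milliseconds
// for it to finish, returning false if the deadline expires first. A timeout
// of zero or less runs `fn` synchronously on the calling thread.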
bool ExecuteWithTimeout(std::function<void()> fn, const int64_t timeout_in_ms,
                        thread::ThreadPool* const thread_pool) {
  if (timeout_in_ms <= 0) {
    fn();
    return true;
  }
  auto done = std::make_shared<Notification>();
  thread_pool->Schedule([done, fn]() {
    fn();
    done->Notify();
  });
  const bool notified =
      WaitForNotificationWithTimeout(done.get(), timeout_in_ms * 1000);
  return notified;
}

string AsControlDependency(const NodeDef& node) {
  return absl::StrCat("^", node.name());
}

string AsControlDependency(const string& node_name) {
  CHECK(!node_name.empty());
  return (!node_name.empty() && node_name[0] == '^')
             ? node_name
             : absl::StrCat("^", node_name);
}

bool NodeIsOnCpu(const NodeDef* node) {
  string task, device;
  return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
         absl::StartsWith(device, DEVICE_CPU);
}

bool NodeIsOnGpu(const NodeDef* node) {
  string task, device;
  return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
         absl::StartsWith(device, DEVICE_GPU);
}

int NumOutputs(const NodeDef& node, GraphDef* graph) {
  int num_outputs = 0;
  const OpDef* op_def = nullptr;
  auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
  if (status.ok()) {
    for (const auto& output : op_def->output_arg()) {
      if (!output.type_list_attr().empty()) {
        num_outputs +=
            node.attr().at(output.type_list_attr()).list().type_size();
      } else if (!output.number_attr().empty()) {
        num_outputs += node.attr().at(output.number_attr()).i();
      } else {
        num_outputs++;
      }
    }
  } else {
    FunctionLibraryDefinition fdef(OpRegistry::Global(), graph->library());
    auto status = fdef.LookUpOpDef(node.op(), &op_def);
    if (status.ok()) {
      num_outputs = op_def->output_arg_size();
    }
  }
  return num_outputs;
}

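// The input-counting helpers below assume the usual NodeDef layout in which
// all control inputs (names starting with '^') follow the regular inputs, so a
// forward scan can stop at the first control input and a backward scan at the
// first regular input.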
bool HasControlInputs(const NodeDef& node) {
  const int num_inputs = node.input_size();
  if (num_inputs > 0 && IsControlInput(node.input(num_inputs - 1))) {
    return true;
  }
  return false;
}

bool HasRegularInputs(const NodeDef& node) {
  const int num_inputs = node.input_size();
  if (num_inputs > 0 && !IsControlInput(node.input(0))) {
    return true;
  }
  return false;
}

int NumNonControlInputs(const NodeDef& node) {
  int num_inputs = 0;
  for (; num_inputs < node.input_size(); ++num_inputs) {
    const string& input = node.input(num_inputs);
    if (IsControlInput(input)) {
      return num_inputs;
    }
  }
  return num_inputs;
}

int NumControlInputs(const NodeDef& node) {
  int num_inputs = 0;
  for (; num_inputs < node.input_size(); ++num_inputs) {
    const string& input = node.input(node.input_size() - num_inputs - 1);
    if (!IsControlInput(input)) {
      return num_inputs;
    }
  }
  return num_inputs;
}

bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map) {
  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
    for (const string& node_as_input : output->input()) {
      if (IsControlInput(node_as_input)) break;

      TensorId tensor = ParseTensorName(node_as_input);
      if (tensor.node() == node.name()) {
        return true;
      }
    }
  }
  return false;
}

bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map) {
  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
    for (int idx = output->input_size() - 1; idx >= 0; --idx) {
      const string& node_as_input = output->input(idx);
      if (!IsControlInput(node_as_input)) break;

      TensorId tensor = ParseTensorName(node_as_input);
      if (tensor.node() == node.name()) {
        return true;
      }
    }
  }
  return false;
}

int NumControlOutputs(const NodeDef& node, const NodeMap& node_map) {
  int num_outputs = 0;
  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
    for (int idx = output->input_size() - 1; idx >= 0; --idx) {
      const string& node_as_input = output->input(idx);
      if (!IsControlInput(node_as_input)) break;

      TensorId tensor = ParseTensorName(node_as_input);
      if (tensor.node() == node.name()) {
        ++num_outputs;
      }
    }
  }
  return num_outputs;
}

int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
  int num_outputs = 0;
  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
    for (const string& node_as_input : output->input()) {
      if (IsControlInput(node_as_input)) {
        break;
      }
      if (node_as_input == node.name()) {
        ++num_outputs;
      } else {
        const TensorId tensor = ParseTensorName(node_as_input);
        if (tensor.node() == node.name()) {
          ++num_outputs;
        }
      }
    }
  }
  return num_outputs;
}

int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map) {
  int num_data_outputs = 0;
  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
    if (IsShapeConsumer(*output)) continue;

    for (int i = 0; i < output->input_size(); ++i) {
      const string& input = output->input(i);
      if (!IsControlInput(input) && NodeName(input) == node.name()) {
        ++num_data_outputs;
        break;
      }
    }
  }
  return num_data_outputs;
}

// Returns the data type stored in attribute `type_attr` of `node`. If that
// attribute doesn't exist or doesn't hold a type, returns DT_INVALID.
DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr) {
  if (!node.attr().count(type_attr)) {
    return DT_INVALID;
  }
  const auto& attr = node.attr().at(type_attr);
  if (attr.value_case() != AttrValue::kType) {
    return DT_INVALID;
  }
  return attr.type();
}

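// Starting at `source`, repeatedly follows the first input of the current node
// as long as the next node satisfies `pred_fn` (the predicate is not applied
// to `source` itself), and returns the last node reached. Traversal also stops
// at a node without inputs or, when `follow_control_input` is false, at a
// control input.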
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
                        bool follow_control_input,
                        const std::function<bool(const NodeDef&)>& pred_fn) {
  const NodeDef* current = &source;
  const NodeDef* next = current;
  while (next == &source || (next != nullptr && pred_fn(*next))) {
    current = next;
    if (current->input_size() == 0 ||
        (!follow_control_input && IsControlInput(current->input(0)))) {
      break;
    }
    next = node_map.GetNode(current->input(0));
    if (next == nullptr) {
      LOG(ERROR) << "Node not found: " << current->input(0);
    }
  }
  return const_cast<NodeDef*>(current);
}

// Every permutation is a product of one or more cycles. Iterate over the
// cycles in the permutation, and convert each of those into a product of
// transpositions (swaps): https://en.wikipedia.org/wiki/Cyclic_permutation
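// For example, {2, 0, 1} is the single cycle (0 2 1): with
// invert_permutation == false, the node at index i moves to index
// permutation[i], so nodes {A, B, C} become {B, C, A}. The loop below realizes
// this cycle as the two swaps (0 2) and (0 1).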
void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
                         bool invert_permutation) {
  CHECK_EQ(graph->node_size(), permutation->size());
  std::vector<int> inv_perm(permutation->size(), 0);
  if (invert_permutation) {
    for (size_t n = 0; n < permutation->size(); ++n) {
      inv_perm[(*permutation)[n]] = n;
    }
    permutation->swap(inv_perm);
  }
  for (int n = 0, end = permutation->size(); n + 1 < end; ++n) {
    while (n != (*permutation)[n]) {
      std::size_t r = (*permutation)[n];
      graph->mutable_node()->SwapElements(n, r);
      std::swap((*permutation)[n], (*permutation)[r]);
    }
  }
}

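// Removes control inputs of `node` that duplicate an earlier input (regular or
// control) referring to the same source node, e.g. a "^foo" that follows "foo"
// or another "^foo".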
void DedupControlInputs(NodeDef* node) {
  absl::flat_hash_set<string> inputs;
  int pos = 0;
  while (pos < node->input_size()) {
    const string& input = node->input(pos);
    if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) {
      node->mutable_input()->SwapElements(pos, node->input_size() - 1);
      node->mutable_input()->RemoveLast();
    } else {
      ++pos;
    }
  }
}

namespace {

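// Deletes the nodes at the given indices, which must be unique and sorted in
// ascending order, by swapping each one to the tail of the node list and then
// dropping the tail with a single DeleteSubrange call. The remaining nodes may
// be reordered as a result.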
template <typename UniqueContainer>
void EraseNodesFromGraphImpl(const UniqueContainer& nodes_to_delete,
                             GraphDef* graph) {
  static_assert(std::is_same<typename UniqueContainer::value_type, int>::value,
                "Need to pass container of ints");

  int last = graph->node_size() - 1;
  for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) {
    const int index = *it;
    graph->mutable_node()->SwapElements(index, last);
    last--;
  }
  graph->mutable_node()->DeleteSubrange(last + 1, nodes_to_delete.size());
}

template <typename T>
inline void STLSortAndRemoveDuplicates(T* v) {
  std::sort(v->begin(), v->end());
  v->erase(std::unique(v->begin(), v->end()), v->end());
}

}  // namespace

void EraseNodesFromGraph(const std::set<int>& nodes_to_delete,
                         GraphDef* graph) {
  EraseNodesFromGraphImpl(nodes_to_delete, graph);
}

void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph) {
  STLSortAndRemoveDuplicates(&nodes_to_delete);
  EraseNodesFromGraphImpl(nodes_to_delete, graph);
}

void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
                         GraphDef* graph) {
  std::vector<int> nodes_idx_to_delete;
  nodes_idx_to_delete.reserve(nodes_to_delete.size());
  for (int i = 0; i < graph->node_size(); ++i) {
    if (nodes_to_delete.count(graph->node(i).name()))
      nodes_idx_to_delete.push_back(i);
  }
  EraseNodesFromGraphImpl(nodes_idx_to_delete, graph);
}

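// HANDLE_DOUBLE_CASE and HANDLE_INT_CASE expand to switch cases that store
// `value` into the scalar `tensor` as DTYPE, returning an InvalidArgument
// error if the value falls outside DTYPE's representable range.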
#define HANDLE_DOUBLE_CASE(DTYPE)                                     \
  case DTYPE:                                                         \
    if (!SafeSetDoubleScalarTensorValue<EnumToDataType<DTYPE>::Type>( \
            static_cast<double>(value), tensor)) {                    \
      return errors::InvalidArgument("Cannot store value ", value,    \
                                     " in tensor of type " #DTYPE);   \
    }                                                                 \
    break

#define HANDLE_INT_CASE(DTYPE)                                               \
  case DTYPE:                                                                \
    if (!SafeSetIntScalarTensorValue<EnumToDataType<DTYPE>::Type>(value,     \
                                                                  tensor)) { \
      return errors::InvalidArgument("Cannot store value ", value,           \
                                     " in tensor of type " #DTYPE);          \
    }                                                                        \
    break

Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
  // TODO(rmlarsen): Support more general shapes.
  // TODO(lyandy): Change `value` to be int64 once int64 -> qint32 is supported.
  if (tensor->NumElements() != 1) {
    return errors::InvalidArgument(
        "Expected scalar tensor, got num_elements = ", tensor->NumElements());
  }
  switch (dtype) {
    HANDLE_DOUBLE_CASE(DT_HALF);
    HANDLE_DOUBLE_CASE(DT_BFLOAT16);
    HANDLE_DOUBLE_CASE(DT_BOOL);
    HANDLE_DOUBLE_CASE(DT_FLOAT);
    HANDLE_DOUBLE_CASE(DT_DOUBLE);
    HANDLE_DOUBLE_CASE(DT_UINT8);
    HANDLE_DOUBLE_CASE(DT_INT8);
    HANDLE_DOUBLE_CASE(DT_UINT16);
    HANDLE_DOUBLE_CASE(DT_INT16);
    HANDLE_DOUBLE_CASE(DT_INT32);
    HANDLE_DOUBLE_CASE(DT_INT64);
    HANDLE_DOUBLE_CASE(DT_COMPLEX64);
    HANDLE_DOUBLE_CASE(DT_COMPLEX128);
    HANDLE_INT_CASE(DT_QINT8);
    HANDLE_INT_CASE(DT_QUINT8);
    HANDLE_INT_CASE(DT_QINT16);
    HANDLE_INT_CASE(DT_QUINT16);
    HANDLE_INT_CASE(DT_QINT32);
    default:
      return errors::InvalidArgument("Unsupported type ",
                                     DataTypeString(dtype));
  }
  return OkStatus();
}

#undef HANDLE_DOUBLE_CASE
#undef HANDLE_INT_CASE

Status CheckAttrExists(const NodeDef& node, const string& key) {
  if (!HasNodeAttr(node, key)) {
    return errors::InvalidArgument("Node '", node.name(), "' lacks '", key,
                                   "' attr: ", node.ShortDebugString());
  }
  return OkStatus();
}

Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys) {
  for (const string& key : keys) {
    TF_RETURN_IF_ERROR(CheckAttrExists(node, key));
  }
  return OkStatus();
}

Status IsKernelRegisteredForNode(
    absl::string_view node_name, bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
    absl::string_view node_op, absl::string_view node_device,
    AttrSlice node_attrs) {
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(node_device, &parsed_name)) {
    return errors::InvalidArgument("Could not parse device name: ",
                                   node_device);
  }
  return FindKernelDef(DeviceType(parsed_name.type), node_name,
                       has_experimental_debug_info, experimental_debug_info,
                       node_op, node_device, node_attrs, nullptr, nullptr);
}

Status IsKernelRegisteredForNode(const NodeDef& node) {
  return IsKernelRegisteredForNode(node.name(),
                                   node.has_experimental_debug_info(),
                                   node.experimental_debug_info(), node.op(),
                                   node.device(), AttrSlice(&node.attr()));
}

namespace {
void RemoveAttributes(const std::vector<absl::string_view>& to_remove,
                      NodeDef* node) {
  if (to_remove.size() == node->attr_size()) {
    node->clear_attr();
  } else {
    for (const auto& key : to_remove) {
      node->mutable_attr()->erase(string(key));
    }
  }
}
}  // namespace

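// Removes all attributes of `node` whose names do not start with '_' and
// returns the number of attributes removed.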
int EraseRegularNodeAttributes(NodeDef* node) {
  std::vector<absl::string_view> to_remove;
  for (const auto& attr : node->attr()) {
    if (!attr.first.empty() && (attr.first)[0] != '_') {
      to_remove.push_back(attr.first);
    }
  }
  RemoveAttributes(to_remove, node);
  return to_remove.size();
}

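// Removes the "_xla_inferred_shapes" attribute and every attribute whose name
// starts with "_output_", returning the number of attributes removed.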
int EraseNodeOutputAttributes(NodeDef* node) {
  std::vector<absl::string_view> to_remove;
  for (const auto& attr : node->attr()) {
    const string& attr_name = attr.first;
    if (attr_name == "_xla_inferred_shapes" ||
        absl::StartsWith(attr_name, "_output_")) {
      to_remove.push_back(attr_name);
    }
  }
  RemoveAttributes(to_remove, node);
  return to_remove.size();
}

}  // end namespace grappler
}  // end namespace tensorflow