/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/utils.h"

#include <iterator>
#include <memory>
#include <queue>
#include <vector>

#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {
namespace {
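// Sets the single element of a scalar tensor to `value`, cast to type T.
// Returns false and leaves the tensor untouched if `value` is outside the
// representable range of T.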
template <typename T>
bool SafeSetDoubleScalarTensorValue(double value, Tensor* tensor) {
  using RealType = typename Eigen::NumTraits<T>::Real;
  if (value > static_cast<double>(Eigen::NumTraits<RealType>::highest()) ||
      value < static_cast<double>(Eigen::NumTraits<RealType>::lowest())) {
    return false;
  }
  tensor->flat<T>()(0) = static_cast<T>(value);
  return true;
}

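// Integral counterpart of SafeSetDoubleScalarTensorValue, used below for the
// quantized types.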
template <typename T>
bool SafeSetIntScalarTensorValue(int value, Tensor* tensor) {
  using RealType = typename Eigen::NumTraits<T>::Real;
  if (value > static_cast<int>(Eigen::NumTraits<RealType>::highest()) ||
      value < static_cast<int>(Eigen::NumTraits<RealType>::lowest())) {
    return false;
  }
  tensor->flat<T>()(0) = static_cast<T>(value);
  return true;
}

// Is 'node' an operator that consumes only the shape of its input, not the
// data itself?
// TODO(ezhulenev): move to op_types.h. Requires breaking a circular
// dependency first.
// TODO(ezhulenev): what about an Identity node passing a tensor to a Shape
// consumer?
bool IsShapeConsumer(const NodeDef& node) {
  const string& op = node.op();
  return op == "Shape" || op == "ShapeN" || op == "Rank" || op == "Size";
}

}  // namespace

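// Builds the name -> node index and, from every input edge, the inverse
// node -> fanout map. The map stores raw pointers into `graph`, which must
// therefore outlive this NodeMap.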
NodeMap::NodeMap(GraphDef* graph) {
  CHECK(graph != nullptr);
  for (int i = 0; i < graph->node_size(); i++) {
    NodeDef* node = graph->mutable_node(i);
    const string& node_name = node->name();
    auto rslt = nodes_.emplace(node_name, node);
    // Check that the graph doesn't contain multiple nodes with the same name.
    if (!rslt.second) {
      LOG(WARNING) << "Duplicated node in the graph: " << node_name;
    }
    for (const auto& input : node->input()) {
      outputs_[NodeName(input)].insert(nodes_[node_name]);
    }
  }
}

void NodeMap::RemoveNode(const string& name) {
  nodes_.erase(NodeName(name));
  outputs_.erase(NodeName(name));
}

NodeDef* NodeMap::GetNode(const string& name) const {
  const string node_name = NodeName(name);
  auto it = nodes_.find(node_name);
  if (it == nodes_.end()) {
    return nullptr;
  }
  return it->second;
}

bool NodeMap::NodeExists(const string& name) const {
  const string node_name = NodeName(name);
  return nodes_.find(node_name) != nodes_.end();
}

const std::set<NodeDef*>& NodeMap::GetOutputs(const string& node_name) const {
  auto it = outputs_.find(node_name);
  if (it == outputs_.end()) {
    return empty_set_;
  }
  return it->second;
}

void NodeMap::AddNode(const string& node_name, NodeDef* node) {
  auto ret = nodes_.emplace(node_name, CHECK_NOTNULL(node));
  CHECK(ret.second) << "Pair (" << node_name << "," << node
                    << ") is not inserted because the same key already exists.";
}

void NodeMap::AddOutput(const string& node_name, const string& output_name) {
  auto output_node = nodes_[NodeName(output_name)];
  CHECK(output_node) << "Output node " << output_name
                     << " is missing in NodeMap.";
  outputs_[node_name].insert(output_node);
}

void NodeMap::RemoveOutput(const string& node_name,
                           const string& output_name) {
  outputs_[node_name].erase(nodes_[NodeName(output_name)]);
}

void NodeMap::UpdateInput(const string& node_name, const string& old_input_name,
                          const string& new_input_name) {
  RemoveOutput(NodeName(old_input_name), node_name);
  AddOutput(NodeName(new_input_name), node_name);
}

void NodeMap::RemoveInputs(const string& node_name) {
  auto node = nodes_[node_name];
  for (const auto& input : node->input()) {
    RemoveOutput(NodeName(input), node->name());
  }
}

void NodeMap::RemoveOutputs(const string& node_name) {
  outputs_.erase(node_name);
}

void NodeMap::UpdateOutput(const string& node_name,
                           const string& old_output_name,
                           const string& new_output_name) {
  std::set<NodeDef*>& outputs = outputs_[node_name];
  outputs.erase(nodes_[NodeName(old_output_name)]);
  outputs.insert(nodes_[NodeName(new_output_name)]);
}

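// Formats a TensorId as an input string: the bare node name when the output
// index is 0, and TensorId::ToString() otherwise.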
string TensorIdToString(const TensorId& tensor_id) {
  return tensor_id.index() == 0 ? string(tensor_id.node())
                                : tensor_id.ToString();
}

bool IsSameInput(const string& name1, const string& name2) {
  if (name1 == name2) return true;
  TensorId tensor1 = ParseTensorName(name1);
  TensorId tensor2 = ParseTensorName(name2);
  return tensor1 == tensor2;
}

bool IsControlInput(const string& name) {
  return !name.empty() && name[0] == '^';
}

bool IsControlInput(const TensorId& tensor_id) { return tensor_id.index() < 0; }

string AddPrefixToNodeName(const string& name, const string& prefix,
                           const string& delimiter) {
  if (!name.empty() && name[0] == '^') {
    return absl::StrCat("^", prefix, delimiter, name.substr(1));
  }
  return absl::StrCat(prefix, delimiter, name);
}

string AddPrefixToNodeName(const string& name, const string& prefix) {
  return AddPrefixToNodeName(name, prefix, "/");
}

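// Runs `fn` on `thread_pool` and blocks until it completes or the timeout
// expires (WaitForNotificationWithTimeout takes microseconds, hence the
// * 1000). Returns false on timeout, in which case `fn` keeps running in the
// background. A non-positive timeout executes `fn` synchronously.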
bool ExecuteWithTimeout(std::function<void()> fn, const int64 timeout_in_ms,
                        thread::ThreadPool* const thread_pool) {
  if (timeout_in_ms <= 0) {
    fn();
    return true;
  }
  auto done = std::make_shared<Notification>();
  thread_pool->Schedule([done, fn]() {
    fn();
    done->Notify();
  });
  const bool notified =
      WaitForNotificationWithTimeout(done.get(), timeout_in_ms * 1000);
  return notified;
}

string AsControlDependency(const NodeDef& node) {
  return absl::StrCat("^", node.name());
}

string AsControlDependency(const string& node_name) {
  CHECK(!node_name.empty());
  return (!node_name.empty() && node_name[0] == '^')
             ? node_name
             : absl::StrCat("^", node_name);
}

bool NodeIsOnCpu(const NodeDef* node) {
  string task, device;
  return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
         absl::StartsWith(device, DEVICE_CPU);
}

bool NodeIsOnGpu(const NodeDef* node) {
  string task, device;
  return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
         absl::StartsWith(device, DEVICE_GPU);
}

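// Returns the number of output tensors of `node`, resolving list(type)- and
// number-typed output args from the node's attributes. Falls back to the
// graph's function library if the op is not in the global registry, and
// returns 0 if the op cannot be found at all.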
int NumOutputs(const NodeDef& node, GraphDef* graph) {
  int num_outputs = 0;
  const OpDef* op_def = nullptr;
  auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
  if (status.ok()) {
    for (const auto& output : op_def->output_arg()) {
      if (!output.type_list_attr().empty()) {
        num_outputs +=
            node.attr().at(output.type_list_attr()).list().type_size();
      } else if (!output.number_attr().empty()) {
        num_outputs += node.attr().at(output.number_attr()).i();
      } else {
        num_outputs++;
      }
    }
  } else {
    FunctionLibraryDefinition fdef(OpRegistry::Global(), graph->library());
    auto status = fdef.LookUpOpDef(node.op(), &op_def);
    if (status.ok()) {
      num_outputs = op_def->output_arg_size();
    }
  }
  return num_outputs;
}

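// Control inputs are always appended after the data inputs of a NodeDef, so
// checking the last input suffices.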
bool HasControlInputs(const NodeDef& node) {
  int num_inputs = node.input_size();
  if (num_inputs > 0 && IsControlInput(node.input(num_inputs - 1))) {
    return true;
  }
  return false;
}

int NumNonControlInputs(const NodeDef& node) {
  int num_inputs = node.input_size();
  for (const string& input : node.input()) {
    if (IsControlInput(input)) {
      --num_inputs;
    }
  }
  return num_inputs;
}

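// Counts the data (non-control) edges from `node` into its fanout: each input
// position of a consumer that references one of `node`'s outputs counts once.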
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
  int num_outputs = 0;
  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
    for (const string& node_as_input : output->input()) {
      if (IsControlInput(node_as_input)) {
        break;
      }
      if (node_as_input == node.name()) {
        ++num_outputs;
      } else {
        const TensorId tensor = ParseTensorName(node_as_input);
        if (tensor.node() == node.name()) {
          ++num_outputs;
        }
      }
    }
  }
  return num_outputs;
}

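// Like NumNonControlOutputs, but counts each consumer at most once and
// ignores consumers that only use the shape of the tensor (see
// IsShapeConsumer above).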
int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map) {
  int num_data_outputs = 0;
  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
    if (IsShapeConsumer(*output)) continue;

    for (int i = 0; i < output->input_size(); ++i) {
      const string& input = output->input(i);
      if (!IsControlInput(input) && NodeName(input) == node.name()) {
        ++num_data_outputs;
        break;
      }
    }
  }
  return num_data_outputs;
}

// Returns the data type in attribute `type_attr` of `node`, or DT_INVALID if
// that attribute doesn't exist or doesn't hold a type.
DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr) {
  if (!node.attr().count(type_attr)) {
    return DT_INVALID;
  }
  const auto& attr = node.attr().at(type_attr);
  if (attr.value_case() != AttrValue::kType) {
    return DT_INVALID;
  }
  return attr.type();
}

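// Starting at `source`, follows the chain of first inputs while `pred_fn`
// holds, and returns the deepest node reached (which is `source` itself if
// the walk cannot advance). The walk stops at missing nodes, at nodes
// without inputs, and, unless `follow_control_input` is set, at control
// inputs.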
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
                        bool follow_control_input,
                        const std::function<bool(const NodeDef&)>& pred_fn) {
  const NodeDef* current = &source;
  const NodeDef* next = current;
  while (next == &source || (next != nullptr && pred_fn(*next))) {
    current = next;
    if (current->input_size() == 0 ||
        (!follow_control_input && IsControlInput(current->input(0)))) {
      break;
    }
    next = node_map.GetNode(current->input(0));
    if (next == nullptr) {
      LOG(ERROR) << "Node not found: " << current->input(0);
    }
  }
  return const_cast<NodeDef*>(current);
}

// Every permutation is a product of one or more cycles. Iterate over the
// cycles in the permutation, and convert each of those into a product of
// transpositions (swaps): https://en.wikipedia.org/wiki/Cyclic_permutation
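// With invert_permutation == false, the node at index i is moved to index
// (*permutation)[i]; with invert_permutation == true, index i instead
// receives the node at index (*permutation)[i].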
void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
                         bool invert_permutation) {
  CHECK_EQ(graph->node_size(), permutation->size());
  std::vector<int> inv_perm(permutation->size(), 0);
  if (invert_permutation) {
    for (size_t n = 0; n < permutation->size(); ++n) {
      inv_perm[(*permutation)[n]] = n;
    }
    permutation->swap(inv_perm);
  }
  for (std::size_t n = 0; n + 1 < permutation->size(); ++n) {
    while (n != (*permutation)[n]) {
      std::size_t r = (*permutation)[n];
      graph->mutable_node()->SwapElements(n, r);
      std::swap((*permutation)[n], (*permutation)[r]);
    }
  }
}

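// Removes control inputs that duplicate another (control or data) input of
// `node`. Removal is done via swap-and-pop, so the relative order of the
// remaining control inputs may change.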
void DedupControlInputs(NodeDef* node) {
  std::unordered_set<string> inputs;
  int pos = 0;
  while (pos < node->input_size()) {
    const string& input = node->input(pos);
    if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) {
      node->mutable_input()->SwapElements(pos, node->input_size() - 1);
      node->mutable_input()->RemoveLast();
    } else {
      ++pos;
    }
  }
}

namespace {

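// Erases the nodes at the given indices by swapping each of them to the tail
// of the repeated node field and trimming the tail with one DeleteSubrange
// call. `nodes_to_delete` must be sorted in ascending order and contain no
// duplicates.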
template <typename UniqueContainer>
void EraseNodesFromGraphImpl(const UniqueContainer& nodes_to_delete,
                             GraphDef* graph) {
  static_assert(std::is_same<typename UniqueContainer::value_type, int>::value,
                "Need to pass container of ints");

  int last = graph->node_size() - 1;
  for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) {
    const int index = *it;
    graph->mutable_node()->SwapElements(index, last);
    last--;
  }
  graph->mutable_node()->DeleteSubrange(last + 1, nodes_to_delete.size());
}

template <typename T>
inline void STLSortAndRemoveDuplicates(T* v) {
  std::sort(v->begin(), v->end());
  v->erase(std::unique(v->begin(), v->end()), v->end());
}

}  // namespace

void EraseNodesFromGraph(const std::set<int>& nodes_to_delete,
                         GraphDef* graph) {
  EraseNodesFromGraphImpl(nodes_to_delete, graph);
}

void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph) {
  STLSortAndRemoveDuplicates(&nodes_to_delete);
  EraseNodesFromGraphImpl(nodes_to_delete, graph);
}

void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
                         GraphDef* graph) {
  std::vector<int> nodes_idx_to_delete;
  nodes_idx_to_delete.reserve(nodes_to_delete.size());
  for (int i = 0; i < graph->node_size(); ++i) {
    if (nodes_to_delete.count(graph->node(i).name()))
      nodes_idx_to_delete.push_back(i);
  }
  EraseNodesFromGraphImpl(nodes_idx_to_delete, graph);
}

#define HANDLE_DOUBLE_CASE(DTYPE)                                     \
  case DTYPE:                                                         \
    if (!SafeSetDoubleScalarTensorValue<EnumToDataType<DTYPE>::Type>( \
            static_cast<double>(value), tensor)) {                    \
      return errors::InvalidArgument("Cannot store value ", value,    \
                                     " in tensor of type " #DTYPE);   \
    }                                                                 \
    break

#define HANDLE_INT_CASE(DTYPE)                                               \
  case DTYPE:                                                                \
    if (!SafeSetIntScalarTensorValue<EnumToDataType<DTYPE>::Type>(value,     \
                                                                  tensor)) { \
      return errors::InvalidArgument("Cannot store value ", value,           \
                                     " in tensor of type " #DTYPE);          \
    }                                                                        \
    break

Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
  // TODO(rmlarsen): Support more general shapes.
  // TODO(lyandy): Change `value` to be int64 once int64 -> qint32 is
  // supported.
  if (tensor->NumElements() != 1) {
    return errors::InvalidArgument(
        "Expected scalar tensor, got num_elements = ", tensor->NumElements());
  }
  switch (dtype) {
    HANDLE_DOUBLE_CASE(DT_HALF);
    HANDLE_DOUBLE_CASE(DT_BFLOAT16);
    HANDLE_DOUBLE_CASE(DT_BOOL);
    HANDLE_DOUBLE_CASE(DT_FLOAT);
    HANDLE_DOUBLE_CASE(DT_DOUBLE);
    HANDLE_DOUBLE_CASE(DT_UINT8);
    HANDLE_DOUBLE_CASE(DT_INT8);
    HANDLE_DOUBLE_CASE(DT_UINT16);
    HANDLE_DOUBLE_CASE(DT_INT16);
    HANDLE_DOUBLE_CASE(DT_INT32);
    HANDLE_DOUBLE_CASE(DT_INT64);
    HANDLE_DOUBLE_CASE(DT_COMPLEX64);
    HANDLE_DOUBLE_CASE(DT_COMPLEX128);
    HANDLE_INT_CASE(DT_QINT8);
    HANDLE_INT_CASE(DT_QUINT8);
    HANDLE_INT_CASE(DT_QINT16);
    HANDLE_INT_CASE(DT_QUINT16);
    HANDLE_INT_CASE(DT_QINT32);
    default:
      return errors::InvalidArgument("Unsupported type ",
                                     DataTypeString(dtype));
  }
  return Status::OK();
}

#undef HANDLE_DOUBLE_CASE
#undef HANDLE_INT_CASE

Status CheckAttrExists(const NodeDef& node, const string& key) {
  if (!HasNodeAttr(node, key)) {
    return errors::InvalidArgument("Node '", node.name(), "' lacks '", key,
                                   "' attr: ", node.ShortDebugString());
  }
  return Status::OK();
}

Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys) {
  for (const string& key : keys) {
    TF_RETURN_IF_ERROR(CheckAttrExists(node, key));
  }
  return Status::OK();
}

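// Returns OK iff a kernel for `node`'s op is registered for the device type
// parsed from the node's device string.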
Status IsKernelRegisteredForNode(const NodeDef& node) {
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
    return errors::InvalidArgument("Could not parse device name: ",
                                   node.device());
  }
  return FindKernelDef(DeviceType(parsed_name.type), node, nullptr, nullptr);
}

}  // end namespace grappler
}  // end namespace tensorflow