1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
17
18 #include <unordered_set>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/function.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/node_def_builder.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/framework/versions.pb.h"
28 #include "tensorflow/core/grappler/grappler_item.h"
29 #include "tensorflow/core/grappler/mutable_graph_view.h"
30 #include "tensorflow/core/grappler/op_types.h"
31 #include "tensorflow/core/grappler/utils.h"
32 #include "tensorflow/core/grappler/utils/transitive_fanin.h"
33
34 namespace tensorflow {
35 namespace grappler {
36 namespace {
37
IsTrivialIdentity(const NodeDef & node,const GraphView & graph_view)38 bool IsTrivialIdentity(const NodeDef& node, const GraphView& graph_view) {
39 for (const auto input :
40 graph_view.GetFanins(node, /*include_controlling_nodes=*/true)) {
41 if (input.port_id == Graph::kControlSlot) {
42 // Node is driven by control dependency.
43 return false;
44 } else if (IsSwitch(*input.node)) { // Node is driven by switch.
45 return false;
46 }
47 }
48 for (const auto output :
49 graph_view.GetFanouts(node, /*include_controlled_nodes=*/true)) {
50 if (output.port_id == Graph::kControlSlot) {
51 // Node drives control dependency.
52 return false;
53 } else if (IsMerge(*output.node)) { // Node feeds merge.
54 return false;
55 }
56 }
57 return true;
58 }
59
IsTrivialOp(const NodeDef & node,const GraphView & graph_view)60 bool IsTrivialOp(const NodeDef& node, const GraphView& graph_view) {
61 // Remove the stop gradient nodes since they serve no purpose once the graph
62 // is built. Also remove Identity ops.
63 if (IsStopGradient(node)) {
64 return true;
65 }
66 if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
67 return IsTrivialIdentity(node, graph_view);
68 }
69 if (IsNoOp(node) && node.input().empty()) {
70 return true;
71 }
72 // Const nodes are always executed before anything else, so if they only
73 // have control outputs we can remove them.
74 if (IsConstant(node) && node.input().empty() &&
75 graph_view.NumFanouts(node, /*include_controlled_nodes=*/false) == 0) {
76 return true;
77 }
78 return IsAddN(node) && NumNonControlInputs(node) <= 1;
79 }
80
RemovalIncreasesEdgeCount(const NodeDef & node,const GraphView & graph_view)81 bool RemovalIncreasesEdgeCount(const NodeDef& node,
82 const GraphView& graph_view) {
83 int in_degree =
84 graph_view.NumFanins(node, /*include_controlling_nodes=*/true);
85 int out_degree =
86 graph_view.NumFanouts(node, /*include_controlled_nodes=*/true);
87 return in_degree * out_degree > in_degree + out_degree;
88 }
89
IsOutputPortRefValue(const NodeDef & node,int port_id,const OpRegistryInterface & op_registry)90 bool IsOutputPortRefValue(const NodeDef& node, int port_id,
91 const OpRegistryInterface& op_registry) {
92 const OpRegistrationData* op_reg_data = nullptr;
93 Status s = op_registry.LookUp(node.op(), &op_reg_data);
94 if (s.ok()) {
95 DataType output_type;
96 s = OutputTypeForNode(node, op_reg_data->op_def, port_id, &output_type);
97 if (s.ok() && IsRefType(output_type)) {
98 return true;
99 }
100 }
101 return false;
102 }
103
CanRemoveNode(const NodeDef & node,const GraphView & graph_view,const absl::flat_hash_set<string> & function_names,const OpRegistryInterface & op_registry)104 bool CanRemoveNode(const NodeDef& node, const GraphView& graph_view,
105 const absl::flat_hash_set<string>& function_names,
106 const OpRegistryInterface& op_registry) {
107 if (IsNoOp(node) &&
108 (node.input().empty() ||
109 graph_view.NumFanouts(node, /*include_controlled_nodes=*/true) == 0)) {
110 return true;
111 }
112 if (IsConstant(node) && node.input().empty() &&
113 graph_view.NumFanouts(node, /*include_controlled_nodes=*/false) == 0) {
114 return true;
115 }
116 if (RemovalIncreasesEdgeCount(node, graph_view)) {
117 return false;
118 }
119 for (const auto input :
120 graph_view.GetFanins(node, /*include_controlling_nodes=*/true)) {
121 if (node.device() != input.node->device()) {
122 // Node is driven by a different device.
123 return false;
124 } else if (input.port_id == Graph::kControlSlot) {
125 // Node is driven by control dependency.
126 continue;
127 } else if (function_names.find(input.node->op()) != function_names.end()) {
128 // Node input is a function call.
129 return false;
130 } else if (IsOutputPortRefValue(*input.node, input.port_id, op_registry)) {
131 return false;
132 }
133 }
134 for (const auto output :
135 graph_view.GetFanouts(node, /*include_controlled_nodes=*/false)) {
136 if (function_names.find(output.node->op()) != function_names.end()) {
137 // Node output is a function call.
138 return false;
139 }
140 }
141 return true;
142 }
143
ForwardInputsInternal(const NodeDef & node,const absl::flat_hash_set<const NodeDef * > & nodes_to_delete,bool add_as_control,NodeDef * new_node,const absl::flat_hash_map<string,const NodeDef * > & optimized_nodes,const GraphView & graph_view)144 void ForwardInputsInternal(
145 const NodeDef& node,
146 const absl::flat_hash_set<const NodeDef*>& nodes_to_delete,
147 bool add_as_control, NodeDef* new_node,
148 const absl::flat_hash_map<string, const NodeDef*>& optimized_nodes,
149 const GraphView& graph_view) {
150 // To speed things up, use the optimized version of the node if
151 // available.
152 auto itr = optimized_nodes.find(node.name());
153 if (itr != optimized_nodes.end()) {
154 for (const string& input : itr->second->input()) {
155 *new_node->add_input() =
156 add_as_control ? AsControlDependency(NodeName(input)) : input;
157 }
158 return;
159 }
160 for (const auto& input : node.input()) {
161 const NodeDef* input_node = graph_view.GetNode(NodeName(input));
162 if (input_node == nullptr) {
163 // Invalid input, preserve it as is.
164 *new_node->add_input() =
165 add_as_control ? AsControlDependency(NodeName(input)) : input;
166 continue;
167 }
168 if (nodes_to_delete.find(input_node) != nodes_to_delete.end()) {
169 ForwardInputsInternal(*input_node, nodes_to_delete,
170 add_as_control || IsControlInput(input), new_node,
171 optimized_nodes, graph_view);
172 } else {
173 *new_node->add_input() =
174 add_as_control ? AsControlDependency(NodeName(input)) : input;
175 }
176 }
177 }
178
ForwardInputs(const NodeDef & original_node,const absl::flat_hash_set<const NodeDef * > & nodes_to_delete,NodeDef * new_node,absl::flat_hash_map<string,const NodeDef * > * optimized_nodes,const GraphView & graph_view)179 void ForwardInputs(const NodeDef& original_node,
180 const absl::flat_hash_set<const NodeDef*>& nodes_to_delete,
181 NodeDef* new_node,
182 absl::flat_hash_map<string, const NodeDef*>* optimized_nodes,
183 const GraphView& graph_view) {
184 // Forwards inputs of nodes to be deleted to their respective outputs.
185 ForwardInputsInternal(original_node, nodes_to_delete,
186 /*add_as_control=*/false, new_node, *optimized_nodes,
187 graph_view);
188 if (!new_node->name().empty()) {
189 (*optimized_nodes)[new_node->name()] = new_node;
190 }
191 // Reorder inputs such that control inputs come after regular inputs.
192 int pos = 0;
193 for (int i = 0; i < new_node->input_size(); ++i) {
194 if (!IsControlInput(new_node->input(i))) {
195 new_node->mutable_input()->SwapElements(pos, i);
196 ++pos;
197 }
198 }
199 DedupControlInputs(new_node);
200 }
201
// Walks backwards from `terminal_nodes` and returns, for each IdentityN node
// encountered, the set of its input ports that lie on a path to a terminal
// node. Ports not in the returned set are candidates for being split off into
// standalone Identity nodes (see SplitIdentityNInputs).
absl::flat_hash_map<string, absl::flat_hash_set<int>> IdentityNTerminalPorts(
    const NodeMap& node_map, const std::vector<string>& terminal_nodes,
    int graph_size) {
  // Determines which ports for IdentityN nodes (that can be rewritten) lead to
  // a terminal node.
  std::vector<string> to_visit;
  to_visit.reserve(graph_size);
  // Set terminal nodes as visited so terminal nodes that may be IdentityN don't
  // get pruned later on.
  absl::flat_hash_set<string> visited(terminal_nodes.begin(),
                                      terminal_nodes.end());
  // Seed the worklist with the direct inputs of every terminal node.
  for (const string& terminal_node : terminal_nodes) {
    NodeDef* node = node_map.GetNode(terminal_node);
    if (node == nullptr) {
      continue;
    }
    for (const string& input : node->input()) {
      to_visit.push_back(input);
    }
  }

  // Input strings (possibly port-qualified, e.g. "node:1") that were reached
  // and refer to an IdentityN node.
  absl::flat_hash_set<string> identity_n_fanouts;
  // Depth-first traversal over the transitive fanin of the terminal nodes.
  while (!to_visit.empty()) {
    string curr = to_visit.back();
    to_visit.pop_back();
    NodeDef* curr_node = node_map.GetNode(curr);
    if (curr_node == nullptr ||
        visited.find(curr_node->name()) != visited.end()) {
      continue;
    }
    // For IdentityN nodes, only traverse up through the port that comes from a
    // terminal node along with control inputs. The IdentityN node is not marked
    // as visited so other node input traversals can go through the other ports
    // of the IdentityN node.
    if (IsIdentityN(*curr_node)) {
      if (identity_n_fanouts.find(curr) == identity_n_fanouts.end()) {
        identity_n_fanouts.emplace(curr);
        // Continue the walk only through the single input feeding the port
        // we arrived on (IdentityN's port i forwards input i).
        int pos = NodePositionIfSameNode(curr, curr_node->name());
        if (pos >= 0) {
          to_visit.push_back(curr_node->input(pos));
        }
        // Control inputs of IdentityN are always reachable, independent of
        // which data port was used.
        for (const string& input : curr_node->input()) {
          if (IsControlInput(input) &&
              identity_n_fanouts.find(input) == identity_n_fanouts.end()) {
            to_visit.push_back(input);
          }
        }
      }
    } else {
      // Non-IdentityN nodes: enqueue all inputs and mark visited so the node
      // is expanded at most once.
      for (const string& input : curr_node->input()) {
        to_visit.push_back(input);
      }
      visited.emplace(curr_node->name());
    }
  }

  // Convert the collected port-qualified fanout strings into a map of
  // IdentityN node name -> set of reachable data-port indices.
  absl::flat_hash_map<string, absl::flat_hash_set<int>> identity_n_ports;
  for (const auto& fanout : identity_n_fanouts) {
    int pos;
    string node_name = ParseNodeName(fanout, &pos);
    if (node_name.empty() || pos < 0) {  // Exclude control inputs.
      continue;
    }
    if (identity_n_ports.find(node_name) == identity_n_ports.end()) {
      identity_n_ports[node_name] = {pos};
    } else {
      identity_n_ports[node_name].emplace(pos);
    }
  }

  return identity_n_ports;
}
274
NewIdentityFromIdentityN(int pos,const NodeDef & identity_n,GraphDef * graph,NodeMap * node_map)275 string NewIdentityFromIdentityN(int pos, const NodeDef& identity_n,
276 GraphDef* graph, NodeMap* node_map) {
277 // TODO(lyandy): Migrate over to GrapplerOptimizerStage and use
278 // OptimizedNodeName for new node name.
279 string new_node_name =
280 strings::StrCat(identity_n.name(), "-", pos, "-grappler-ModelPruner");
281 if (node_map->NodeExists(new_node_name)) {
282 return "";
283 }
284 NodeDef* new_node = graph->add_node();
285 Status status = NodeDefBuilder(new_node_name, "Identity")
286 .Input(identity_n.input(pos), 0,
287 identity_n.attr().at("T").list().type(pos))
288 .Device(identity_n.device())
289 .Finalize(new_node);
290 if (!status.ok()) {
291 return "";
292 }
293 node_map->AddNode(new_node->name(), new_node);
294 node_map->AddOutput(NodeName(new_node->input(0)), new_node->name());
295 return new_node->name();
296 }
297
// Rewrites an IdentityN node in place so that it keeps only the inputs whose
// ports lead to a terminal node (`terminal_ports`); every other input is
// split off into a freshly created standalone Identity node. Consumers of the
// IdentityN are rewired accordingly: non-terminal ports point at the new
// Identity nodes, terminal ports get their port index renumbered to match the
// compacted input list.
//
// Args:
//   node: the IdentityN node to rewrite (mutated in place).
//   num_non_control_inputs: count of data inputs of `node`.
//   terminal_ports: data-port indices of `node` that reach a terminal node.
//   graph/node_map: graph and index to register new Identity nodes in.
// Returns an Internal error if a replacement Identity cannot be created.
Status RewriteIdentityNAndInputsOutputs(
    NodeDef* node, int num_non_control_inputs,
    const absl::flat_hash_set<int>& terminal_ports, GraphDef* graph,
    NodeMap* node_map) {
  // Rewrite IdentityN node and associated inputs and outputs. For inputs and
  // outputs that don't lead to a terminal node, a new Identity node is created
  // and those inputs and outputs are rewritten to use the new Identity node as
  // their outputs and inputs respectively. For the remaining nodes, the outputs
  // have their inputs updated with the adjusted port, from the IdentityN node
  // having less inputs.
  struct NodeOutputUpdate {
    string input;   // name of the node now feeding `output`
    string output;  // name of the consumer whose input was rewritten
  };

  // terminal_input_pos maps an old terminal port index to its new, compacted
  // index; new_identities maps a non-terminal port index to the name of the
  // Identity node created to replace it.
  absl::flat_hash_map<int, int> terminal_input_pos;
  absl::flat_hash_map<int, string> new_identities;
  int new_idx = 0;
  for (int i = 0; i < num_non_control_inputs; i++) {
    if (terminal_ports.find(i) != terminal_ports.end()) {
      terminal_input_pos[i] = new_idx++;
    } else {
      string identity = NewIdentityFromIdentityN(i, *node, graph, node_map);
      if (identity.empty()) {
        // Fail early when creating Identity from IdentityN errors.
        return errors::Internal(
            "Could not create Identity node from IdentityN node ", node->name(),
            " at port ", i);
      }
      new_identities[i] = identity;
    }
  }

  // Rewire every consumer of the IdentityN node. NodeMap output registration
  // is deferred (collected in `updates`) to avoid mutating the output set
  // while iterating over it.
  std::vector<NodeOutputUpdate> updates;
  for (NodeDef* output : node_map->GetOutputs(node->name())) {
    for (int i = 0; i < output->input_size(); i++) {
      string input = output->input(i);
      if (IsControlInput(input)) {
        continue;
      }
      TensorId input_tensor = ParseTensorName(input);
      if (input_tensor.node() == node->name()) {
        if (terminal_ports.find(input_tensor.index()) == terminal_ports.end()) {
          // Replace input that does not lead to a terminal node with newly
          // created identity.
          string new_identity = new_identities[input_tensor.index()];
          output->set_input(i, new_identity);
          updates.push_back({new_identity, output->name()});
        } else {
          // Update input ports that lead to a terminal node from splitting
          // inputs.
          int new_pos = terminal_input_pos[input_tensor.index()];
          // Port 0 is spelled without an explicit ":0" suffix.
          string updated_input_name =
              new_pos > 0 ? strings::StrCat(node->name(), ":", new_pos)
                          : node->name();
          output->set_input(i, updated_input_name);
        }
      }
    }
  }

  for (const NodeOutputUpdate& update : updates) {
    node_map->AddOutput(update.input, update.output);
  }

  // Update inputs and types by removing inputs that were split away from
  // main IdentityN node. Terminal inputs (and their "T" dtype entries) are
  // compacted to the front via swaps; control inputs are then shifted down
  // and the leftover tail is deleted.
  const int num_inputs = node->input_size();
  int curr_pos = 0;
  auto mutable_inputs = node->mutable_input();
  auto mutable_types =
      node->mutable_attr()->at("T").mutable_list()->mutable_type();
  for (int i = 0; i < num_non_control_inputs; i++) {
    if (terminal_input_pos.find(i) != terminal_input_pos.end()) {
      mutable_inputs->SwapElements(i, curr_pos);
      mutable_types->SwapElements(i, curr_pos);
      curr_pos++;
    }
  }
  mutable_types->Truncate(curr_pos);
  // Control inputs.
  for (int i = num_non_control_inputs; i < num_inputs; i++) {
    mutable_inputs->SwapElements(i, curr_pos++);
  }
  mutable_inputs->DeleteSubrange(curr_pos, num_inputs - curr_pos);

  return Status::OK();
}
386
SplitIdentityNInputs(GraphDef * graph,const std::vector<string> & terminal_nodes,bool * updated_graph)387 Status SplitIdentityNInputs(GraphDef* graph,
388 const std::vector<string>& terminal_nodes,
389 bool* updated_graph) {
390 // For inputs of IdentityN nodes that do not lead to a terminal node, remove
391 // them from IdentityN and create new individual Identity nodes. This will
392 // allow ModelPruner to possibly remove nodes in the transitive fanin of the
393 // newly created Identity nodes.
394 NodeMap node_map(graph);
395
396 for (auto const& terminal :
397 IdentityNTerminalPorts(node_map, terminal_nodes, graph->node_size())) {
398 NodeDef* node = node_map.GetNode(terminal.first);
399 if (node == nullptr) {
400 continue;
401 }
402
403 const int num_non_control_inputs = NumNonControlInputs(*node);
404 const int terminal_second_size = terminal.second.size();
405 if (node->attr().count("T") == 0 ||
406 node->attr().at("T").list().type_size() != num_non_control_inputs ||
407 terminal_second_size >= num_non_control_inputs) {
408 continue;
409 }
410
411 TF_RETURN_IF_ERROR(RewriteIdentityNAndInputsOutputs(
412 node, num_non_control_inputs, terminal.second, graph, &node_map));
413 *updated_graph = true;
414 }
415
416 return Status::OK();
417 }
418
419 } // namespace
420
// Prunes the graph in three stages: (1) drop everything outside the
// transitive fanin of the nodes to preserve, (2) split non-terminal inputs
// off IdentityN nodes and re-prune, (3) delete trivial ops (Identity,
// StopGradient, dead NoOp/Const, degenerate AddN) whose removal is safe,
// forwarding their inputs to their consumers.
Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
                             GraphDef* optimized_graph) {
  const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve();

  // Prune all the nodes that won't be executed, ie all the nodes that aren't in
  // the fanin of a fetch node. If fetch nodes aren't specified, we'll assume
  // the whole graph might be executed.
  std::unique_ptr<GraphDef> pruned_graph_release;
  GraphDef* pruned_graph;
  if (!nodes_to_preserve.empty()) {
    pruned_graph_release.reset(new GraphDef());
    pruned_graph = pruned_graph_release.get();
    pruned_graph->mutable_node()->Reserve(item.graph.node_size());
    // Sort for deterministic iteration order (NodesToPreserve is unordered).
    std::vector<string> terminal_nodes(nodes_to_preserve.begin(),
                                       nodes_to_preserve.end());
    std::sort(terminal_nodes.begin(), terminal_nodes.end());
    TF_RETURN_IF_ERROR(
        SetTransitiveFaninGraph(item.graph, pruned_graph, terminal_nodes));
    bool did_split_identity_n = false;
    TF_RETURN_IF_ERROR(SplitIdentityNInputs(pruned_graph, terminal_nodes,
                                            &did_split_identity_n));
    // Splitting IdentityN inputs may have disconnected parts of the fanin;
    // recompute the transitive fanin to drop them.
    if (did_split_identity_n) {
      GraphDef fanin_split_identity_n_graph;
      TF_RETURN_IF_ERROR(SetTransitiveFaninGraph(
          *pruned_graph, &fanin_split_identity_n_graph, terminal_nodes));
      pruned_graph->Swap(&fanin_split_identity_n_graph);
    }
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
  } else {
    // No fetch nodes: operate on item.graph directly, without copying.
    // NOTE(review): the const_cast is safe only because the no-delete path
    // below does not mutate through pruned_graph in this case — verify when
    // changing the code paths that touch pruned_graph.
    pruned_graph = const_cast<GraphDef*>(&item.graph);
  }

  GraphView graph_view(pruned_graph);
  // Names of functions in the library; nodes calling these must not have
  // their neighbors pruned (see CanRemoveNode).
  absl::flat_hash_set<string> function_names;
  for (const auto& function : item.graph.library().function()) {
    function_names.insert(function.signature().name());
  }
  OpRegistryInterface* op_registry = OpRegistry::Global();

  // Check if we can further prune the graph, by removing the trivial ops.
  absl::flat_hash_set<const NodeDef*> nodes_to_delete;
  for (int i = 0; i < pruned_graph->node_size(); ++i) {
    NodeDef* node = pruned_graph->mutable_node(i);
    // Remove redundant control inputs, since they may prevent pruning below.
    DedupControlInputs(node);

    if (!IsTrivialOp(*node, graph_view)) {
      VLOG(3) << node->name() << " is not trivial.";
      continue;
    }

    // Don't remove nodes that must be preserved.
    if (nodes_to_preserve.find(node->name()) != nodes_to_preserve.end()) {
      continue;
    }

    // - Don't remove nodes that drive control dependencies.
    // - Don't remove nodes that are driven by control dependencies either since
    //   we can't ensure (yet) that we won't increase the number of control
    //   dependency edges by deleting them (for example, removing a node driven
    //   by 10 control edges and driving 10 control edges would result in the
    //   creation of 100 edges).
    // - Don't modify nodes that are connected to functions since that can
    //   result in inlining failures later on.
    // - Don't prune nodes that are driven by another device since these could
    //   be used to reduce cross device communication.
    // - Don't remove nodes that receive reference values, as those can be
    //   converting references to non-references. It is important to preserve
    //   these non-references since the partitioner will avoid sending
    //   non-references across partitions more than once.
    if (CanRemoveNode(*node, graph_view, function_names, *op_registry)) {
      nodes_to_delete.insert(node);
    } else {
      VLOG(3) << node->name() << " cannot be removed";
    }
  }

  if (nodes_to_delete.empty() && nodes_to_preserve.empty()) {
    return errors::Aborted("Nothing to do.");
  }

  optimized_graph->Clear();
  *optimized_graph->mutable_library() = item.graph.library();
  *optimized_graph->mutable_versions() = item.graph.versions();
  // Fast path: fanin pruning happened but no trivial nodes to delete — move
  // the pruned nodes over wholesale.
  if (nodes_to_delete.empty()) {
    optimized_graph->mutable_node()->Swap(pruned_graph->mutable_node());
    return Status::OK();
  }

  // Without known fetches it is unsafe to actually drop nodes, so they are
  // all kept; ForwardInputs still rewires inputs around the "deleted" set.
  const bool fetches_are_known = !item.fetch.empty();
  absl::flat_hash_map<string, const NodeDef*> optimized_nodes;
  optimized_graph->mutable_node()->Reserve(pruned_graph->node_size());
  for (const auto& node : pruned_graph->node()) {
    if (!fetches_are_known ||
        nodes_to_delete.find(&node) == nodes_to_delete.end()) {
      NodeDef* new_node = optimized_graph->add_node();
      *new_node = node;
      new_node->clear_input();
      ForwardInputs(node, nodes_to_delete, new_node, &optimized_nodes,
                    graph_view);
    }
  }
  VLOG(1) << "Pruned " << nodes_to_delete.size()
          << " nodes from the graph. The graph now contains "
          << optimized_graph->node_size() << " nodes.";
  // Sanity check: pruning must never grow the graph.
  if (optimized_graph->node_size() > item.graph.node_size()) {
    return errors::Internal("Pruning increased graph size.");
  }
  return Status::OK();
}
531
Feedback(Cluster * cluster,const GrapplerItem & item,const GraphDef & optimized_graph,double result)532 void ModelPruner::Feedback(Cluster* cluster, const GrapplerItem& item,
533 const GraphDef& optimized_graph, double result) {
534 // Nothing to do for ModelPruner.
535 }
536
537 } // end namespace grappler
538 } // end namespace tensorflow
539