1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
17
18 #include <algorithm>
19 #include <queue>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <vector>
23
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/op.h"
27 #include "tensorflow/core/framework/tensor.pb.h" // NOLINT
28 #include "tensorflow/core/framework/tensor_shape.pb.h"
29 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
30 #include "tensorflow/core/grappler/costs/graph_memory.h"
31 #include "tensorflow/core/grappler/costs/graph_properties.h"
32 #include "tensorflow/core/grappler/costs/utils.h"
33 #include "tensorflow/core/grappler/graph_topology_view.h"
34 #include "tensorflow/core/grappler/grappler_item.h"
35 #include "tensorflow/core/grappler/mutable_graph_view.h"
36 #include "tensorflow/core/grappler/op_types.h"
37 #include "tensorflow/core/grappler/optimizers/static_schedule.h"
38 #include "tensorflow/core/grappler/utils.h"
39 #include "tensorflow/core/grappler/utils/topological_sort.h"
40 #include "tensorflow/core/grappler/utils/traversal.h"
41 #include "tensorflow/core/lib/math/math_util.h"
42 #include "tensorflow/core/lib/strings/str_util.h"
43 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
44 #include "tensorflow/core/util/device_name_utils.h"
45
46 namespace tensorflow {
47 namespace grappler {
48
49 namespace {
50
// Prefix added to nodes which are recomputed.
constexpr char kRecomputedNodePrefix[] = "Recomputed";
// Prefix added to the NoOp trigger nodes that gate recomputed nodes.
constexpr char kRecomputeTriggerNodePrefix[] = "RecomputeTrigger";
// Attribute which may be added to nodes to manually allow them to be
// recomputed.
constexpr char kRecomputeHint[] = "_recompute_hint";
57
// Returns the op type names we are willing to recompute in order to save
// memory (kept in alphabetical order).
// TODO(allenl): Replace this list with a cost model.
std::unordered_set<string> GetCheapToRecomputeOps() {
  return std::unordered_set<string>{
      "Add",
      "AddN",
      "BiasAdd",
      "Cast",
      "Fill",
      "FloorDiv",
      "FloorMod",
      "FusedBatchNorm",
      "Mul",
      "Neg",
      "RealDiv",
      "Reciprocal",
      "Relu",
      "Relu6",
      "Reshape",
      "Rsqrt",
      "Sigmoid",
      "Sqrt",
      "Square",
      "SquaredDifference",
      "Sub",
      "Tile",
      "Transpose",
  };
}
69
70 // Find recomputable ops which feed into target nodes.
FindCandidateRecomputeNodes(const NodeMap & node_map,const GraphDef * graph,const std::function<bool (const NodeDef &)> & is_candidate,const std::function<bool (const NodeDef &)> & is_target)71 std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
72 const NodeMap& node_map, const GraphDef* graph,
73 const std::function<bool(const NodeDef&)>& is_candidate,
74 const std::function<bool(const NodeDef&)>& is_target) {
75 std::unordered_set<const NodeDef*> candidate_recompute_nodes;
76 for (const auto& node : graph->node()) {
77 if (!is_candidate(node)) {
78 continue;
79 }
80 bool has_target_output = false;
81 for (const NodeDef* output : node_map.GetOutputs(node.name())) {
82 // It only makes sense to recompute this if it feeds into a target
83 // node. We expand this to dependencies in GetOpGroupsToRecompute.
84 if (is_target(*output)) {
85 has_target_output = true;
86 break;
87 }
88 }
89 if (!has_target_output) {
90 continue;
91 }
92 bool has_target_input = false;
93 for (const string& input_name : node.input()) {
94 // Don't recompute nodes which depend on target nodes.
95 const NodeDef* input_node = node_map.GetNode(input_name);
96 if (is_target(*input_node)) {
97 has_target_input = true;
98 break;
99 }
100 }
101 if (has_target_input) {
102 continue;
103 }
104 candidate_recompute_nodes.insert(&node);
105 }
106 return candidate_recompute_nodes;
107 }
108
connected_subgraph(const NodeMap & node_map,bool collect_inputs,bool collect_outputs,const std::function<bool (const NodeDef &)> & is_candidate,std::unordered_set<const NodeDef * > * expanded_nodes)109 void connected_subgraph(const NodeMap& node_map, bool collect_inputs,
110 bool collect_outputs,
111 const std::function<bool(const NodeDef&)>& is_candidate,
112 std::unordered_set<const NodeDef*>* expanded_nodes) {
113 std::queue<const NodeDef*> to_visit;
114 for (const NodeDef* starting_node : *expanded_nodes) {
115 to_visit.push(starting_node);
116 }
117 expanded_nodes->clear();
118 while (!to_visit.empty()) {
119 const NodeDef* current_node = to_visit.front();
120 to_visit.pop();
121 if (!expanded_nodes->insert(current_node).second) {
122 // We already visited this node
123 continue;
124 }
125 if (collect_inputs) {
126 // Add inputs and outputs to this subgraph if they are candidates
127 for (const string& input_name_raw : current_node->input()) {
128 const NodeDef* input_node = node_map.GetNode(input_name_raw);
129 if (expanded_nodes->count(input_node) == 0 &&
130 is_candidate(*input_node)) {
131 to_visit.push(input_node);
132 }
133 }
134 }
135 if (collect_outputs) {
136 for (const NodeDef* output : node_map.GetOutputs(current_node->name())) {
137 if (expanded_nodes->count(output) == 0 && is_candidate(*output)) {
138 to_visit.push(output);
139 }
140 }
141 }
142 }
143 }
144
// One group of nodes to recompute together, as returned by
// GetOpGroupsToRecompute and consumed by RecomputeSubgraph.
struct RecomputedSubGraph {
  // Original nodes to duplicate: the group's nodes that feed a target node
  // directly, plus the group nodes they (transitively) take as inputs.
  std::unordered_set<const NodeDef*> recomputed_source_nodes;
  // Target nodes fed by the group; RecomputeSubgraph rewrites their inputs to
  // point at the recomputed copies.
  std::unordered_set<NodeDef*> target_nodes;
};
149
150 // Find groups of ops to recompute together based on `should_recompute`.
GetOpGroupsToRecompute(const GraphDef * graph,const NodeMap & node_map,const std::function<bool (const NodeDef &)> & should_recompute,const std::function<bool (const NodeDef &)> & is_target)151 std::vector<RecomputedSubGraph> GetOpGroupsToRecompute(
152 const GraphDef* graph, const NodeMap& node_map,
153 const std::function<bool(const NodeDef&)>& should_recompute,
154 const std::function<bool(const NodeDef&)>& is_target) {
155 std::unordered_set<const NodeDef*> visited_nodes;
156 std::vector<RecomputedSubGraph> subgraphs_to_recompute;
157 std::unordered_set<const NodeDef*> candidate_recompute_nodes =
158 FindCandidateRecomputeNodes(node_map, graph, should_recompute, is_target);
159 for (const NodeDef* recompute_node : candidate_recompute_nodes) {
160 if (visited_nodes.count(recompute_node) > 0) {
161 continue;
162 }
163 RecomputedSubGraph current_recomputation;
164 // Build out recomputation groups by expanding to inexpensive-to-recompute
165 // nodes which do not feed target nodes. The goal is to capture some
166 // intermediate activations within this graph.
167 std::unordered_set<const NodeDef*> unpruned_recompute_nodes;
168 unpruned_recompute_nodes.insert(recompute_node);
169 connected_subgraph(node_map,
170 true, // Collect inputs
171 true, // Collect outputs
172 should_recompute, &unpruned_recompute_nodes);
173 visited_nodes.insert(unpruned_recompute_nodes.begin(),
174 unpruned_recompute_nodes.end());
175 for (const NodeDef* recompute_node : unpruned_recompute_nodes) {
176 bool inserted_feed = false;
177 for (NodeDef* output : node_map.GetOutputs(recompute_node->name())) {
178 if (is_target(*output)) {
179 current_recomputation.target_nodes.insert(output);
180 if (!inserted_feed) {
181 // Keep track of nodes which feed directly into a target node. These
182 // and nodes which feed into them will define the recomputed
183 // subgraph.
184 current_recomputation.recomputed_source_nodes.insert(
185 recompute_node);
186 inserted_feed = true;
187 }
188 }
189 }
190 }
191 // Recompute only nodes which eventually feed into a target node.
192 connected_subgraph(
193 node_map,
194 true, // Collect inputs
195 false, // Collect outputs
196 [&unpruned_recompute_nodes](const NodeDef& node) {
197 return unpruned_recompute_nodes.count(&node) != 0;
198 },
199 ¤t_recomputation.recomputed_source_nodes);
200 if (current_recomputation.target_nodes.empty()) {
201 continue;
202 }
203 subgraphs_to_recompute.push_back(current_recomputation);
204 }
205 return subgraphs_to_recompute;
206 }
207
208 // Computes the maximum topological numbers of (1) target node components
209 // (gradient nodes being fed by the recomputation), and (2) child recompute node
210 // components for each recomputed node. We will not attach any control
211 // dependencies to a recomputation unless they have component numbers greater
212 // than this value (to prevent cycles).
GetMaxDownstreamComponents(const std::unordered_set<const NodeDef * > & recomputed_source_nodes,const std::unordered_set<NodeDef * > & target_nodes,const NodeMap & node_map,const std::unordered_map<const NodeDef *,int> & components)213 std::unordered_map<const NodeDef*, int> GetMaxDownstreamComponents(
214 const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
215 const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
216 const std::unordered_map<const NodeDef*, int>& components) {
217 std::unordered_map<const NodeDef*, int> recomputed_node_components;
218 // Start by setting component numbers to the maximum among target nodes.
219 for (const NodeDef* original_recompute_node : recomputed_source_nodes) {
220 int max_target_component = -1;
221 for (NodeDef* output :
222 node_map.GetOutputs(original_recompute_node->name())) {
223 if (target_nodes.count(output) != 0) {
224 int current_target_component = components.find(output)->second;
225 if (current_target_component > max_target_component) {
226 max_target_component = current_target_component;
227 }
228 }
229 }
230 if (max_target_component > -1) {
231 recomputed_node_components[original_recompute_node] =
232 max_target_component;
233 }
234 }
235 // Sort recomputed nodes topologically (based on the original graph) so we can
236 // efficiently assign to each node the maximum of its recomputed child
237 // components and its own targets.
238 std::vector<const NodeDef*> recomputed_source_nodes_topological(
239 recomputed_source_nodes.begin(), recomputed_source_nodes.end());
240 std::sort(recomputed_source_nodes_topological.begin(),
241 recomputed_source_nodes_topological.end(),
242 [&components](const NodeDef* first, const NodeDef* second) {
243 return components.find(first)->second <
244 components.find(second)->second;
245 });
246 for (const NodeDef* original_recompute_node :
247 recomputed_source_nodes_topological) {
248 int max_component;
249 auto recomputed_component_iterator =
250 recomputed_node_components.find(original_recompute_node);
251 if (recomputed_component_iterator != recomputed_node_components.end()) {
252 max_component = recomputed_component_iterator->second;
253 } else {
254 max_component = -1;
255 }
256 for (NodeDef* output :
257 node_map.GetOutputs(original_recompute_node->name())) {
258 if (recomputed_source_nodes.count(output) == 0) {
259 continue;
260 }
261 auto child_component_iterator = recomputed_node_components.find(output);
262 CHECK(child_component_iterator != recomputed_node_components.end());
263 int child_component = child_component_iterator->second;
264 if (child_component > max_component) {
265 max_component = child_component;
266 }
267 }
268 CHECK_GE(max_component, 0);
269 recomputed_node_components[original_recompute_node] = max_component;
270 }
271 return recomputed_node_components;
272 }
273
274 // Modifies `graph`, adding trigger nodes and returning a mapping from
275 // `recomputed_source_nodes` to trigger nodes which will not create loops in the
276 // graph (using the component numberings in `components` and
277 // `recomputed_node_max_feed_components`). The copied nodes (not the nodes in
278 // recomputed_source_nodes, which are the originals) eventually get these
279 // control dependencies.
280 std::unordered_map<const NodeDef*, const NodeDef*>
AddRecomputeControlDependencyNodes(const std::unordered_set<const NodeDef * > & recomputed_source_nodes,const std::unordered_set<NodeDef * > & target_nodes,const NodeMap & node_map,const std::unordered_map<const NodeDef *,int> & components,const std::unordered_map<const NodeDef *,int> & recomputed_node_max_feed_components,GraphDef * graph)281 AddRecomputeControlDependencyNodes(
282 const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
283 const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
284 const std::unordered_map<const NodeDef*, int>& components,
285 const std::unordered_map<const NodeDef*, int>&
286 recomputed_node_max_feed_components,
287 GraphDef* graph) {
288 // Sort recomputed nodes based on max downstream components.
289 std::vector<const NodeDef*> recomputed_source_nodes_topological(
290 recomputed_source_nodes.begin(), recomputed_source_nodes.end());
291 std::sort(recomputed_source_nodes_topological.begin(),
292 recomputed_source_nodes_topological.end(),
293 [&recomputed_node_max_feed_components](const NodeDef* first,
294 const NodeDef* second) {
295 int first_component =
296 recomputed_node_max_feed_components.find(first)->second;
297 int second_component =
298 recomputed_node_max_feed_components.find(second)->second;
299 return first_component > second_component
300 // Ensure a consistent ordering. This is necessary because
301 // we're working not with node component numbers (which are
302 // unique) but with the maximum across nodes they feed into
303 // (very much not unique).
304 || (first_component == second_component &&
305 first->name() > second->name());
306 });
307 // Create merged control dependency nodes by sorting target inputs
308 // topologically and zipper merging with the sorted recomputed nodes.
309 std::vector<const NodeDef*> target_inputs_topological;
310 for (const NodeDef* target_node : target_nodes) {
311 for (const string& target_input_name_raw : target_node->input()) {
312 const NodeDef* target_input = node_map.GetNode(target_input_name_raw);
313 // If this node has already had one of its inputs recomputed during this
314 // rewriting pass, we ignore that recomputed node here (it will not be in
315 // the NodeMap).
316 if (target_input == nullptr ||
317 recomputed_source_nodes.count(target_input) != 0 ||
318 components.find(target_node)->second ==
319 components.find(target_input)->second) {
320 continue;
321 }
322 target_inputs_topological.push_back(target_input);
323 }
324 }
325 std::sort(target_inputs_topological.begin(), target_inputs_topological.end(),
326 [&components](const NodeDef* first, const NodeDef* second) {
327 return components.find(first)->second >
328 components.find(second)->second;
329 });
330 auto target_input_iterator = target_inputs_topological.begin();
331 NodeDef* current_trigger_node = nullptr;
332 std::unordered_map<const NodeDef*, const NodeDef*> triggers;
333 for (const NodeDef* original_recomputed_node :
334 recomputed_source_nodes_topological) {
335 NodeDef* new_trigger_node = graph->add_node();
336 new_trigger_node->set_name(AddPrefixToNodeName(
337 original_recomputed_node->name(), kRecomputeTriggerNodePrefix));
338 new_trigger_node->set_op("NoOp");
339 new_trigger_node->set_device(original_recomputed_node->device());
340 if (current_trigger_node != nullptr) {
341 *new_trigger_node->add_input() =
342 strings::StrCat("^", current_trigger_node->name());
343 }
344 current_trigger_node = new_trigger_node;
345 triggers[original_recomputed_node] = current_trigger_node;
346 for (;
347 target_input_iterator != target_inputs_topological.end() &&
348 components.find(*target_input_iterator)->second >
349 recomputed_node_max_feed_components.find(original_recomputed_node)
350 ->second;
351 ++target_input_iterator) {
352 *current_trigger_node->add_input() =
353 strings::StrCat("^", (*target_input_iterator)->name());
354 VLOG(2) << " Recomputation trigger " << current_trigger_node->name()
355 << " depends on " << (*target_input_iterator)->name();
356 }
357 }
358 return triggers;
359 }
360
RecomputedOrOriginalNodeName(const std::unordered_set<string> & recomputed_node_names,const string & original_node_name)361 string RecomputedOrOriginalNodeName(
362 const std::unordered_set<string>& recomputed_node_names,
363 const string& original_node_name) {
364 if (recomputed_node_names.find(original_node_name) ==
365 recomputed_node_names.end()) {
366 return original_node_name;
367 } else {
368 return AddPrefixToNodeName(original_node_name, kRecomputedNodePrefix);
369 }
370 }
371
372 // Helper function to recompute a sub-graph (recomputed_source_nodes). Edges
373 // from recomputed_source_nodes to target_nodes are changed to start from the
374 // recomputed nodes.
RecomputeSubgraph(const std::unordered_set<const NodeDef * > & recomputed_source_nodes,const std::unordered_set<NodeDef * > & target_nodes,const NodeMap & node_map,const std::unordered_map<const NodeDef *,int> & components,GraphDef * graph)375 void RecomputeSubgraph(
376 const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
377 const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
378 const std::unordered_map<const NodeDef*, int>& components,
379 GraphDef* graph) {
380 std::unordered_set<string> recomputed_node_names;
381 VLOG(1) << "Recomputing a " << recomputed_source_nodes.size()
382 << " node subgraph";
383 std::unordered_map<const NodeDef*, int> recomputed_node_components =
384 GetMaxDownstreamComponents(recomputed_source_nodes, target_nodes,
385 node_map, components);
386 for (const NodeDef* original_node : recomputed_source_nodes) {
387 VLOG(2) << " " << original_node->name();
388 recomputed_node_names.insert(original_node->name());
389 }
390 std::unordered_map<const NodeDef*, const NodeDef*> triggers =
391 AddRecomputeControlDependencyNodes(recomputed_source_nodes, target_nodes,
392 node_map, components,
393 recomputed_node_components, graph);
394 // Create the recomputed sub-graph
395 for (const NodeDef* original_node : recomputed_source_nodes) {
396 NodeDef* copied_node = graph->add_node();
397 copied_node->set_name(
398 AddPrefixToNodeName(original_node->name(), kRecomputedNodePrefix));
399 copied_node->set_op(original_node->op());
400 *copied_node->mutable_attr() = original_node->attr();
401 copied_node->set_device(original_node->device());
402 for (const string& original_input_name : original_node->input()) {
403 // Set inputs which are internal to the copied subgraph to their copied
404 // versions.
405 *copied_node->add_input() = RecomputedOrOriginalNodeName(
406 recomputed_node_names, original_input_name);
407 }
408 // Each recomputed node gets a control dependency to prevent it from being
409 // recomputed immediately.
410 *copied_node->add_input() =
411 strings::StrCat("^", triggers[original_node]->name());
412 }
413 // Set the inputs of nodes in the target subgraph to the recomputed nodes
414 // where applicable.
415 for (NodeDef* target_node : target_nodes) {
416 for (string& target_input_name : *target_node->mutable_input()) {
417 target_input_name = RecomputedOrOriginalNodeName(recomputed_node_names,
418 target_input_name);
419 }
420 }
421 }
422
RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,const string & recomputation_targets_name_scope,GraphDef * graph,const GrapplerItem & item)423 void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
424 const string& recomputation_targets_name_scope,
425 GraphDef* graph, const GrapplerItem& item) {
426 if (optimization_level != RewriterConfig::RECOMPUTATION_HEURISTICS &&
427 optimization_level != RewriterConfig::HEURISTICS &&
428 optimization_level != RewriterConfig::MANUAL) {
429 // Nothing to do
430 return;
431 }
432 // The topological numberings and NodeMap will be stale as soon as we start
433 // modifying the graph in RecomputeSubgraph. However, RecomputeSubgraph only
434 // looks up nodes which were in the original graph, and preserves the graph
435 // topology it's interested in.
436 // We don't use the results of this topological sort until later, but this
437 // call invalidates all NodeDef pointers, so it needs to be done before we
438 // start collecting those.
439 TF_CHECK_OK(TopologicalSort(graph));
440 NodeMap node_map(graph);
441 std::vector<RecomputedSubGraph> recomputed_subgraphs;
442 // Do not recompute nodes which are fed, since the recomputed node would not
443 // take on the fed value (i.e. gradients would be incorrect).
444 std::unordered_set<string> feeds;
445 for (const auto& feed : item.feed) {
446 feeds.insert(NodeName(feed.first));
447 }
448 std::function<bool(const NodeDef&)> is_target =
449 [&recomputation_targets_name_scope](const NodeDef& node) {
450 // Nodes whose inputs we may want to recompute. This matches node names
451 // that contain recomputation_targets_name_scope as a name scope,
452 // meaning it either begins with or contains the name scope.
453 // Defaults to "gradients/" which will match any node names that begins
454 // with "gradients/" or contains "/gradients/".
455 return node.name().find(recomputation_targets_name_scope) == 0 ||
456 node.name().find("/" + recomputation_targets_name_scope) != -1;
457 };
458
459 if (optimization_level == RewriterConfig::RECOMPUTATION_HEURISTICS ||
460 optimization_level == RewriterConfig::HEURISTICS) {
461 // TODO(allenl): Handle ResNet-like architectures better. Right now all of
462 // the cheap forward ops get grouped into a single subgraph which must
463 // execute before gradients start executing (unless layers are manually
464 // separated by identity ops).
465 std::unordered_set<string> cheap_to_recompute_ops =
466 GetCheapToRecomputeOps();
467 recomputed_subgraphs = GetOpGroupsToRecompute(
468 graph, node_map,
469 [&cheap_to_recompute_ops, &feeds, &is_target](const NodeDef& node) {
470 return !is_target(node) && feeds.count(node.name()) == 0 &&
471 (cheap_to_recompute_ops.count(node.op()) > 0 ||
472 node.attr().count(kRecomputeHint) > 0);
473 },
474 is_target);
475 } else if (optimization_level == RewriterConfig::MANUAL) {
476 recomputed_subgraphs = GetOpGroupsToRecompute(
477 graph, node_map,
478 [&feeds, &is_target](const NodeDef& node) {
479 return !is_target(node) && feeds.count(node.name()) == 0 &&
480 node.attr().count(kRecomputeHint) > 0;
481 },
482 is_target);
483 }
484 if (!recomputed_subgraphs.empty()) {
485 std::unordered_map<const NodeDef*, int> topological_numbering;
486 for (int node_number = 0; node_number < graph->node().size();
487 ++node_number) {
488 topological_numbering[graph->mutable_node(node_number)] =
489 graph->node().size() - node_number - 1;
490 }
491 // Duplicate the indicated sub-graphs and set up control dependencies
492 for (const RecomputedSubGraph& subgraph : recomputed_subgraphs) {
493 RecomputeSubgraph(subgraph.recomputed_source_nodes, subgraph.target_nodes,
494 node_map, topological_numbering, graph);
495 }
496 }
497 }
498
SchedulingPass(Cluster * cluster,GrapplerItem * item)499 bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
500 // Look for AddN nodes (and equivalent) and record input names.
501 MutableGraphView view(&item->graph);
502
503 // It's ok to use immutable GraphTopologyView here, because we do not destroy
504 // any of the nodes in the underlying graph, we only add new nodes.
505 GraphTopologyView graph_topology;
506 Status initialized_topology = graph_topology.InitializeFromGraph(item->graph);
507 if (!initialized_topology.ok()) {
508 VLOG(1) << "Failed to initialize graph topology view: "
509 << initialized_topology.error_message();
510 return false;
511 }
512
513 std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list;
514 for (NodeDef& node : *item->graph.mutable_node()) {
515 if (!IsAddN(node) && node.op() != "AccumulateNV2") {
516 continue;
517 }
518 // There is nothing to gain by optimizing nodes with 2 or fewer inputs.
519 if (view.NumFanins(node, false) <= 2) {
520 continue;
521 }
522 for (const auto& input : view.GetFanins(node, false)) {
523 if (input.node->device() == node.device()) {
524 string tensor_name =
525 strings::StrCat(input.node->name(), ":", input.port_id);
526 addn_list[tensor_name].insert(&node);
527 }
528 }
529 }
530
531 if (addn_list.empty()) {
532 return false;
533 }
534
535 GraphMemory memory(*item);
536 const std::unordered_map<string, DeviceProperties>& devices =
537 cluster->GetDevices();
538 Status s = memory.InferStatically(devices);
539 if (!s.ok()) {
540 VLOG(1) << "Failed to infer memory usage: " << s.error_message();
541 return false;
542 }
543
544 std::unordered_set<NodeDef*> addn_to_rewrite;
545 for (const auto& device : devices) {
546 const string& name = device.first;
547 const DeviceProperties& prop = device.second;
548 if (prop.memory_size() <= 0) {
549 VLOG(1) << "Available memory unknown for device " << name;
550 continue;
551 }
552 const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);
553
554 if (mem_usage.used_memory <= prop.memory_size() * 0.8) {
555 continue;
556 }
557
558 for (const auto& live : mem_usage.live_tensors) {
559 string tensor_name = strings::StrCat(live.node, ":", live.output_id);
560 auto it = addn_list.find(tensor_name);
561 if (it != addn_list.end()) {
562 addn_to_rewrite.insert(it->second.begin(), it->second.end());
563 }
564 }
565 }
566
567 if (addn_to_rewrite.empty()) {
568 return false;
569 }
570 GraphProperties properties(*item);
571 s = properties.InferStatically(false);
572 if (!s.ok()) {
573 VLOG(1) << "Failed to infer shapes: " << s.error_message();
574 return false;
575 }
576
577 bool updated_graph = false;
578 // Rewrite the AddN.
579 for (NodeDef* node : addn_to_rewrite) {
580 if (!properties.HasOutputProperties(node->name())) {
581 VLOG(1) << "Missing properties for " << node->name();
582 continue;
583 }
584 const TensorShapeProto& shape =
585 properties.GetOutputProperties(node->name())[0].shape();
586 PartialTensorShape shp(shape);
587 if (!shp.IsFullyDefined()) {
588 VLOG(1) << "Shape not fully known for " << node->name();
589 continue;
590 }
591
592 // Compute a topological ordering for the node fanin.
593 std::unordered_map<const NodeDef*, int> topo_order;
594 DfsTraversal(graph_topology, {node}, TraversalDirection::kFollowInputs,
595 DfsCallbacks::PostOrder([&topo_order](const NodeDef* n) {
596 int topo_index = static_cast<int>(topo_order.size());
597 topo_order[n] = topo_index;
598 }));
599
600 std::vector<int> input_topo_index;
601
602 for (int i = 0; i < node->input_size(); ++i) {
603 const string& input = node->input(i);
604 const string node_name = NodeName(input);
605 const NodeDef* node = view.GetNode(node_name);
606 input_topo_index.push_back(topo_order.at(node));
607 }
608 int min_input_topo_index = INT_MAX;
609 int min_input_id = -1;
610 for (int i = 0; i < node->input_size(); ++i) {
611 if (IsControlInput(node->input(i))) {
612 // control inputs are always last.
613 break;
614 }
615 const int current = input_topo_index[i];
616 if (current < min_input_topo_index) {
617 min_input_topo_index = current;
618 min_input_id = i;
619 }
620 }
621 CHECK_LE(0, min_input_id);
622 std::vector<string> pre_ctrl_deps;
623 std::vector<string> post_ctrl_deps;
624 for (int i = node->input_size() - 1; i >= 0; --i) {
625 if (!IsControlInput(node->input(i))) {
626 // control inputs are always last.
627 break;
628 }
629 if (input_topo_index[i] < min_input_topo_index) {
630 // These control dependencies can be executed before the node.
631 pre_ctrl_deps.push_back(node->input(i));
632 } else {
633 // These control dependencies should be executed after the node.
634 post_ctrl_deps.push_back(node->input(i));
635 }
636 }
637
638 DataType dtype = node->attr().at("T").type();
639 const string& device = node->device();
640
641 // Create the temporary variable that will hold intermediate results
642 NodeDef* tmp_var = item->graph.add_node();
643 tmp_var->set_name(strings::StrCat(node->name(), "/tmp_var"));
644 tmp_var->set_op("TemporaryVariable");
645 tmp_var->set_device(device);
646 (*tmp_var->mutable_attr())["dtype"].set_type(dtype);
647 *(*tmp_var->mutable_attr())["shape"].mutable_shape() = shape;
648 (*tmp_var->mutable_attr())["var_name"].set_s(tmp_var->name());
649
650 for (const string& ctrl_dep : pre_ctrl_deps) {
651 *tmp_var->add_input() = ctrl_dep;
652 }
653 *tmp_var->add_input() =
654 AsControlDependency(NodeName(node->input(min_input_id)));
655
656 // Initialize it to zero
657 NodeDef* zeros = item->graph.add_node();
658 zeros->set_name(strings::StrCat(node->name(), "/tmp_var_zeros"));
659 zeros->set_op("ZerosLike");
660 zeros->set_device(device);
661 (*zeros->mutable_attr())["T"].set_type(dtype);
662 *zeros->add_input() = node->input(min_input_id);
663
664 NodeDef* initialize = item->graph.add_node();
665 initialize->set_name(strings::StrCat(node->name(), "/tmp_var_initializer"));
666 initialize->set_op("Assign");
667 initialize->set_device(device);
668 (*initialize->mutable_attr())["T"].set_type(dtype);
669 (*initialize->mutable_attr())["use_locking"].set_b(false);
670 (*initialize->mutable_attr())["validate_shape"].set_b(false);
671 *initialize->add_input() = tmp_var->name();
672 *initialize->add_input() = zeros->name();
673
674 // Add the assignadd nodes
675 std::vector<NodeDef*> accumulates;
676 for (int i = 0; i < node->input_size(); ++i) {
677 const string& input = node->input(i);
678 if (!IsControlInput(input)) {
679 NodeDef* accumulate = item->graph.add_node();
680 accumulate->set_name(
681 strings::StrCat(node->name(), "/tmp_var_accum_", i));
682 accumulate->set_op("AssignAdd");
683 accumulate->set_device(device);
684 (*accumulate->mutable_attr())["T"].set_type(dtype);
685 (*accumulate->mutable_attr())["use_locking"].set_b(true);
686 *accumulate->add_input() = initialize->name();
687 *accumulate->add_input() = input;
688 accumulates.push_back(accumulate);
689 }
690 }
691
692 // Rewrite the AddN node as a DestroyTemporaryVariable ops
693 node->set_op("DestroyTemporaryVariable");
694 node->clear_input();
695 node->clear_attr();
696 (*node->mutable_attr())["T"].set_type(dtype);
697 (*node->mutable_attr())["var_name"].set_s(tmp_var->name());
698 *node->add_input() = initialize->name();
699 for (const NodeDef* accum : accumulates) {
700 *node->add_input() = AsControlDependency(accum->name());
701 }
702 for (const string& ctrl_dep : post_ctrl_deps) {
703 *node->add_input() = ctrl_dep;
704 }
705
706 updated_graph = true;
707 }
708
709 return updated_graph;
710 }
711
BuildSwapPair(NodeDef * node,int input_to_swap,const std::unordered_map<string,const NodeDef * > & name_map,GraphDef * graph,std::pair<NodeDef *,NodeDef * > * swap_pair)712 Status BuildSwapPair(NodeDef* node, int input_to_swap,
713 const std::unordered_map<string, const NodeDef*>& name_map,
714 GraphDef* graph,
715 std::pair<NodeDef*, NodeDef*>* swap_pair) {
716 string task, device;
717 if (!DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) ||
718 !str_util::StrContains(device, DEVICE_GPU)) {
719 return errors::InvalidArgument("Can't swap input ", input_to_swap,
720 " of node ", node->name(),
721 " since it is not on GPU");
722 }
723 const OpDef* op_def;
724 TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
725 DataType input_type;
726 TF_RETURN_IF_ERROR(
727 InputTypeForNode(*node, *op_def, input_to_swap, &input_type));
728 if (IsRefType(input_type)) {
729 return errors::InvalidArgument("Can't swap input ", input_to_swap,
730 " of node ", node->name(),
731 " since it expects a reference");
732 }
733
734 string tensor_to_swap = strings::StrCat(node->name(), "_", input_to_swap);
735 string swap_out_name = strings::StrCat("swap_out_", tensor_to_swap);
736 string swap_in_name = strings::StrCat("swap_in_", tensor_to_swap);
737 if (name_map.find(swap_out_name) != name_map.end() ||
738 name_map.find(swap_in_name) != name_map.end()) {
739 return errors::InvalidArgument("Input ", input_to_swap, " of node ",
740 node->name(), " is already swapped");
741 }
742
743 // Force the tensor to be copied to cpu.
744 NodeDef* swap_out_node = graph->add_node();
745 swap_out_node->set_name(swap_out_name);
746 swap_out_node->set_op("_CopyFromGpuToHost");
747
748 // Force the tensor to be restored to the device.
749 NodeDef* swap_in_node = graph->add_node();
750 swap_in_node->set_name(swap_in_name);
751 swap_in_node->set_op("_CopyFromHostToGpu");
752 *swap_in_node->add_input() = swap_out_node->name();
753
754 // Colocate the swap_out_ and swap_in_ nodes with the node itself.
755 swap_out_node->set_device(node->device());
756 swap_in_node->set_device(node->device());
757 string coloc_group = strings::StrCat("loc@", tensor_to_swap);
758 (*swap_out_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
759 (*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
760 (*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
761
762 (*swap_in_node->mutable_attr())["T"].set_type(input_type);
763 (*swap_out_node->mutable_attr())["T"].set_type(input_type);
764 *swap_pair = std::make_pair(swap_out_node, swap_in_node);
765
766 return Status::OK();
767 }
768
769 struct SwapInfo {
770 std::vector<int> inputs_to_swap;
771 Costs::NanoSeconds time_to_swap = 0;
772 };
773
// Picks the node whose completion should trigger copying `node`'s swapped
// inputs back onto the device, or returns nullptr if none is suitable.
//
// The deadline is the latest completion time among all of `node`'s fanins
// (control fanins included), minus `swap_info.time_to_swap`. The search starts
// from the fanins that are not being swapped and walks the fanin graph
// upstream.
// - A node that completes strictly before the deadline becomes a candidate and
//   is not expanded further.
// - Frame-modifying, Switch and Merge nodes are neither expanded nor chosen.
// - Among the candidates, the one that completes last is returned.
//
// Returns nullptr if every fanin is being swapped, if a node needed by the
// search is missing from `name_map` or `execution_times`, or if no candidate
// is found.
static const NodeDef* FindSwapInTrigger(
    const NodeDef* node, const SwapInfo& swap_info,
    const std::unordered_map<string, const NodeDef*>& name_map,
    const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
        execution_times) {
  Costs::NanoSeconds deadline(0);
  std::set<string> frontier;
  for (const string& fanin : node->input()) {
    const string fanin_name = NodeName(fanin);
    const auto node_it = name_map.find(fanin_name);
    if (node_it == name_map.end()) {
      return nullptr;
    }
    const auto time_it = execution_times.find(node_it->second);
    if (time_it == execution_times.end()) {
      return nullptr;
    }
    deadline = std::max(deadline, time_it->second);
    frontier.insert(fanin_name);
  }

  // A tensor that is being swapped out can't be the one that brings it back.
  for (const int swapped_input : swap_info.inputs_to_swap) {
    frontier.erase(NodeName(node->input(swapped_input)));
  }
  if (frontier.empty()) {
    return nullptr;
  }

  deadline -= swap_info.time_to_swap;

  std::map<Costs::NanoSeconds, const NodeDef*> candidates;
  std::set<string> visited;
  while (!frontier.empty()) {
    const auto first = frontier.begin();
    const string current_name = *first;
    frontier.erase(first);
    visited.insert(current_name);

    const auto node_it = name_map.find(current_name);
    if (node_it == name_map.end()) {
      return nullptr;
    }
    const NodeDef* current = node_it->second;

    // Don't jump over frames, since adding a control dependency from one frame
    // to the next isn't supported. Don't go through branches, since we don't
    // know whether they'll be executed or not.
    const bool frame_or_branch = ModifiesFrameInfo(*current) ||
                                 IsSwitch(*current) || IsMerge(*current);
    if (frame_or_branch) {
      continue;
    }

    const auto time_it = execution_times.find(current);
    if (time_it == execution_times.end()) {
      return nullptr;
    }
    if (time_it->second < deadline) {
      candidates[time_it->second] = current;
      continue;
    }
    for (const string& fanin : current->input()) {
      const string fanin_name = NodeName(fanin);
      if (visited.count(fanin_name) == 0) {
        frontier.insert(fanin_name);
      }
    }
  }

  // Prefer the latest-finishing candidate: swap the data back in as late as
  // possible while still leaving enough time for the transfer.
  return candidates.empty() ? nullptr : candidates.rbegin()->second;
}
853
IsSwappable(const MutableGraphView & graph,MutableGraphView::OutputPort output)854 static bool IsSwappable(const MutableGraphView& graph,
855 MutableGraphView::OutputPort output) {
856 const NodeDef& node = *output.node;
857 // There is no point in swapping out persistent tensors, since the tensor will
858 // continue to use memory.
859 if (IsPersistent(node)) {
860 return false;
861 }
862
863 const OpDef* op_def;
864 if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
865 return false;
866 }
867 DataType dtype;
868 if (!OutputTypeForNode(node, *op_def, output.port_id, &dtype).ok()) {
869 return false;
870 }
871 // References can only refer to persistent memory: therefore the node isn't
872 // swappable.
873 if (IsRefType(dtype)) {
874 return false;
875 }
876
877 if (output.node->op() == "Identity" || output.node->op() == "Reshape") {
878 // If placed on the same device, these nodes are just forwarding references
879 // to their input. Therefore they are swappable iff their fanin is swappable
880 // or it resides on a different device.
881 MutableGraphView::InputPort input;
882 input.node = output.node;
883 input.port_id = 0;
884 MutableGraphView::OutputPort fanin = graph.GetRegularFanin(input);
885 if (fanin.node->device() == node.device()) {
886 return IsSwappable(graph, fanin);
887 }
888 }
889 return true;
890 }
891
// Finds the node that a swap-out of input `input_id` of `node` should be
// attached to. Making that node control-depend on the swap-out copy forces
// the copy to run soon after the tensor is produced.
//
// Returns the fanout of the tensor's producer, other than `node` itself, with
// the earliest entry in `execution_times`. Returns nullptr if the producer
// cannot be resolved or if no other fanout has a known execution time.
static NodeDef* FindSwapOutTrigger(
    const NodeDef* node, int input_id, const MutableGraphView& view,
    const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
        execution_times) {
  MutableGraphView::InputPort consumer_port;
  consumer_port.node = const_cast<NodeDef*>(node);
  consumer_port.port_id = input_id;
  const MutableGraphView::OutputPort producer =
      view.GetRegularFanin(consumer_port);
  if (producer.node == nullptr) {
    return nullptr;
  }

  NodeDef* best = nullptr;
  Costs::NanoSeconds best_time = Costs::NanoSeconds::infinity();
  for (const MutableGraphView::InputPort& sibling : view.GetFanout(producer)) {
    if (sibling.node == node) {
      continue;
    }
    const auto time_it = execution_times.find(sibling.node);
    if (time_it == execution_times.end()) {
      continue;
    }
    if (time_it->second < best_time) {
      best_time = time_it->second;
      best = sibling.node;
    }
  }
  return best;
}
923
IsSwappable(MutableGraphView::InputPort input)924 static bool IsSwappable(MutableGraphView::InputPort input) {
925 const NodeDef& node = *input.node;
926
927 const OpDef* op_def;
928 if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
929 return false;
930 }
931
932 DataType dtype;
933 if (!InputTypeForNode(node, *op_def, input.port_id, &dtype).ok()) {
934 return false;
935 }
936
937 return !IsRefType(dtype);
938 }
939
940 struct MemInfo {
941 MutableGraphView::OutputPort port;
942 int64 memory_used;
943 std::vector<MutableGraphView::InputPort> uses_left;
944 double fitness;
945
operator <tensorflow::grappler::__anon82b145960111::MemInfo946 bool operator<(const MemInfo& other) const { return fitness < other.fitness; }
947 };
948
IdentifySwappingCandidates(Cluster * cluster,GrapplerItem * item,std::unordered_set<string> * skip_list,std::unordered_map<NodeDef *,SwapInfo> * nodes_to_swap)949 static bool IdentifySwappingCandidates(
950 Cluster* cluster, GrapplerItem* item, std::unordered_set<string>* skip_list,
951 std::unordered_map<NodeDef*, SwapInfo>* nodes_to_swap) {
952 GraphMemory memory(*item);
953 const std::unordered_map<string, DeviceProperties>& devices =
954 cluster->GetDevices();
955 Status s = memory.InferStatically(devices);
956 if (!s.ok()) {
957 VLOG(1) << "Failed to infer memory usage: " << s.error_message();
958 return false;
959 }
960
961 bool updated_graph = false;
962 for (const auto& device : devices) {
963 const string& name = device.first;
964 const DeviceProperties& prop = device.second;
965 if (prop.type() != "GPU") {
966 continue;
967 }
968 if (prop.memory_size() <= 0) {
969 VLOG(1) << "Peak memory usage unknown for device " << name;
970 continue;
971 }
972 const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);
973
974 if (mem_usage.used_memory <= prop.memory_size()) {
975 continue;
976 }
977 int64 required_savings = mem_usage.used_memory - prop.memory_size();
978
979 std::unordered_map<string, Costs::NanoSeconds> op_completion_times;
980 {
981 VirtualCluster vcluster(cluster->GetDevices());
982 if (!vcluster.Provision().ok()) {
983 return false;
984 }
985 if (!vcluster.Initialize(*item).ok()) {
986 return false;
987 }
988 RunMetadata metadata;
989 Status s = vcluster.Run(item->graph, item->feed, item->fetch, &metadata);
990 if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) {
991 return false;
992 }
993
994 for (const auto& dev_stats : metadata.step_stats().dev_stats()) {
995 for (const auto& node_stats : dev_stats.node_stats()) {
996 Costs::NanoSeconds exec_time =
997 Costs::NanoSeconds(1) +
998 Costs::MicroSeconds(node_stats.all_start_micros() +
999 node_stats.op_end_rel_micros());
1000 op_completion_times.emplace(node_stats.node_name(), exec_time);
1001 }
1002 }
1003 }
1004
1005 Costs::Duration peak_time = -1;
1006 for (const auto& live_tensor : mem_usage.live_tensors) {
1007 if (live_tensor.allocation_time > peak_time) {
1008 peak_time = live_tensor.allocation_time;
1009 }
1010 }
1011
1012 std::vector<MemInfo> mem_state;
1013
1014 MutableGraphView graph(&item->graph);
1015 for (const auto& live_tensor : mem_usage.live_tensors) {
1016 if (live_tensor.memory_used <= 1024) {
1017 // Don't bother with small tensors.
1018 continue;
1019 }
1020 if (live_tensor.deallocation_time - live_tensor.allocation_time <=
1021 Costs::Duration(1e6)) {
1022 // Not enough time to swap.
1023 VLOG(1) << "Not enough time to swap: skipping " << live_tensor.node;
1024 continue;
1025 }
1026
1027 if (skip_list->find(live_tensor.node) != skip_list->end()) {
1028 continue;
1029 }
1030 MutableGraphView::OutputPort port =
1031 graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
1032 if (!IsSwappable(graph, port)) {
1033 continue;
1034 }
1035 MemInfo mem_info;
1036 mem_info.port = port;
1037 mem_info.memory_used = live_tensor.memory_used;
1038 Costs::Duration allocation_time = live_tensor.allocation_time;
1039 Costs::Duration earliest_use(Costs::Duration::infinity());
1040 bool valid = true;
1041 for (MutableGraphView::InputPort input : graph.GetFanout(port)) {
1042 // Get execution time.
1043 auto it = op_completion_times.find(input.node->name());
1044 if (it == op_completion_times.end()) {
1045 valid = false;
1046 break;
1047 }
1048 if (it->second <= peak_time) {
1049 continue;
1050 }
1051
1052 if (skip_list->find(input.node->name()) != skip_list->end()) {
1053 valid = false;
1054 break;
1055 }
1056 string input_name =
1057 strings::StrCat(input.node->name(), ":", input.port_id);
1058 if (skip_list->find(input_name) != skip_list->end()) {
1059 valid = false;
1060 break;
1061 }
1062 if (!IsSwappable(input)) {
1063 valid = false;
1064 break;
1065 }
1066
1067 // Set earliest use time that's after peak.
1068 mem_info.uses_left.emplace_back(input);
1069 earliest_use = std::min(earliest_use, it->second);
1070 }
1071 if (valid && !mem_info.uses_left.empty()) {
1072 // Compute the fitness: we need the tensor to be generated way away of
1073 // the time of peak memory usage (to ensure there is enough time to swap
1074 // it out). We also need to ensure it's used way after the peak time, to
1075 // ensure that swapping the tensor back in won't recreate the memory
1076 // bottleneck. Last but not least, we want the tensor to have as few
1077 // remaining uses as possible.
1078 //
1079 // Note that we must perform the arithmetic inexactly as "double", since
1080 // the values do not fit into any integral type.
1081 mem_info.fitness =
1082 MathUtil::IPow<double>((earliest_use - peak_time).count(), 2) /
1083 MathUtil::IPow<double>(mem_info.uses_left.size(), 2) +
1084 MathUtil::IPow<double>((allocation_time - peak_time).count(), 2);
1085 mem_info.fitness = -mem_info.fitness;
1086 mem_state.push_back(mem_info);
1087 }
1088 }
1089
1090 // Sort by fitness
1091 std::sort(mem_state.begin(), mem_state.end());
1092
1093 for (const MemInfo& mem_info : mem_state) {
1094 for (const MutableGraphView::InputPort fanout_to_swap :
1095 mem_info.uses_left) {
1096 VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":"
1097 << fanout_to_swap.port_id << " of tensor "
1098 << mem_info.port.node->name() << ":" << mem_info.port.port_id
1099 << " of size " << mem_info.memory_used;
1100
1101 (*nodes_to_swap)[fanout_to_swap.node].inputs_to_swap.push_back(
1102 fanout_to_swap.port_id);
1103 }
1104 required_savings -= mem_info.memory_used;
1105 updated_graph = true;
1106 if (required_savings < 0) {
1107 break;
1108 }
1109 }
1110 }
1111 return updated_graph;
1112 }
1113
SwappingPass(RewriterConfig::MemOptType optimization_level,Cluster * cluster,GrapplerItem * item,std::unordered_set<string> * skip_list)1114 bool SwappingPass(RewriterConfig::MemOptType optimization_level,
1115 Cluster* cluster, GrapplerItem* item,
1116 std::unordered_set<string>* skip_list) {
1117 std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;
1118 if (optimization_level == RewriterConfig::DEFAULT_MEM_OPT ||
1119 optimization_level == RewriterConfig::SWAPPING_HEURISTICS ||
1120 optimization_level == RewriterConfig::HEURISTICS) {
1121 // Use heuristics to figure out what needs to be swapped;
1122 IdentifySwappingCandidates(cluster, item, skip_list, &nodes_to_swap);
1123 }
1124 // Look for manual annotatations in the graph.
1125 for (auto& node : *item->graph.mutable_node()) {
1126 if (node.attr().count("_swap_to_host") != 0) {
1127 SwapInfo& swap_info = nodes_to_swap[&node];
1128 const AttrValue& val = node.attr().at("_swap_to_host");
1129 if (val.has_list()) {
1130 for (int64 input_id : val.list().i()) {
1131 swap_info.inputs_to_swap.push_back(input_id);
1132 }
1133 } else {
1134 int64 input_id = val.i();
1135 swap_info.inputs_to_swap.push_back(input_id);
1136 }
1137 }
1138 }
1139 if (nodes_to_swap.empty()) {
1140 // Nothing to do.
1141 return false;
1142 }
1143
1144 // Estimate the size of the data to swap for each node.
1145 GraphProperties properties(*item);
1146 if (!properties.InferStatically(true).ok()) {
1147 return false;
1148 }
1149 for (auto& swap : nodes_to_swap) {
1150 const NodeDef* node = swap.first;
1151 const std::vector<OpInfo::TensorProperties>& props =
1152 properties.GetInputProperties(node->name());
1153 SwapInfo& swap_info = swap.second;
1154 int64 bytes_to_swap = 0;
1155 for (int64 input_id : swap_info.inputs_to_swap) {
1156 const OpInfo::TensorProperties& t = props[input_id];
1157 bytes_to_swap += CalculateTensorSize(t);
1158 }
1159 // Let's assume we're going to swap over PCIe running at 16 GBps.
1160 swap_info.time_to_swap = bytes_to_swap / 16;
1161 }
1162
1163 std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
1164 if (!EstimateEarliestExecutionTimes(*item, cluster, &execution_times).ok()) {
1165 return false;
1166 }
1167
1168 std::unordered_map<string, const NodeDef*> name_map;
1169 for (const auto& node : item->graph.node()) {
1170 name_map[node.name()] = &node;
1171 }
1172 MutableGraphView view(&item->graph);
1173
1174 bool updated_graph = false;
1175
1176 for (auto& swap : nodes_to_swap) {
1177 NodeDef* node = swap.first;
1178 const SwapInfo& swap_info = swap.second;
1179 if (skip_list->find(node->name()) != skip_list->end()) {
1180 continue;
1181 }
1182
1183 // Make sure the tensor isn't swapped back in right away: look for node that
1184 // will execute just before we need to swap the data back, and add a control
1185 // dependency from that node to the swap node.
1186 const NodeDef* in_trigger =
1187 FindSwapInTrigger(node, swap_info, name_map, execution_times);
1188 // If we failed, don't attempt to reprocess this node in a subsequent pass.
1189 if (!in_trigger) {
1190 skip_list->insert(node->name());
1191 continue;
1192 }
1193
1194 // Swap all the tensors that are marked with the 'swap_to_host' attribute.
1195 for (int input_id : swap_info.inputs_to_swap) {
1196 string input_name = strings::StrCat(node->name(), ":", input_id);
1197 if (skip_list->find(input_name) != skip_list->end()) {
1198 continue;
1199 } else {
1200 // Don't attempt to reprocess this input in a subsequent pass.
1201 skip_list->insert(input_name);
1202 }
1203
1204 // Make sure the tensor is swapped out quickly: look for node that
1205 // will execute just after the tensor is generated and add a control
1206 // dependency from the swap out node to that node.
1207 NodeDef* out_trigger =
1208 FindSwapOutTrigger(node, input_id, view, execution_times);
1209 if (!out_trigger) {
1210 continue;
1211 }
1212
1213 std::pair<NodeDef*, NodeDef*> swap_nodes;
1214 if (!BuildSwapPair(node, input_id, name_map, &item->graph, &swap_nodes)
1215 .ok()) {
1216 continue;
1217 }
1218 *swap_nodes.first->add_input() = node->input(input_id);
1219 *node->mutable_input(input_id) = swap_nodes.second->name();
1220
1221 // Add the control dependencies needed to delay the execution of the swap.
1222 out_trigger->add_input(strings::StrCat("^", swap_nodes.first->name()));
1223 swap_nodes.second->add_input(strings::StrCat("^", in_trigger->name()));
1224
1225 // Make sure we won't try to swap the swap nodes in subsequent passes.
1226 skip_list->insert(swap_nodes.first->name());
1227 skip_list->insert(swap_nodes.second->name());
1228 }
1229 }
1230 return updated_graph;
1231 }
1232
CrossesTaskOrCpuGpuBoundary(const NodeDef & node1,const NodeDef & node2)1233 bool CrossesTaskOrCpuGpuBoundary(const NodeDef& node1, const NodeDef& node2) {
1234 string task1;
1235 string device1;
1236 DeviceNameUtils::SplitDeviceName(node1.device(), &task1, &device1);
1237 string task2;
1238 string device2;
1239 DeviceNameUtils::SplitDeviceName(node2.device(), &task2, &device2);
1240 return task1 != task2 ||
1241 (str_util::StrContains(device1, DEVICE_CPU) &&
1242 str_util::StrContains(device2, DEVICE_GPU)) ||
1243 (str_util::StrContains(device1, DEVICE_GPU) &&
1244 str_util::StrContains(device2, DEVICE_CPU));
1245 }
1246
1247 // TODO(rmlarsen): Add distributed TF test.
RelaxAllocatorConstraints(GraphDef * optimized_graph)1248 Status RelaxAllocatorConstraints(GraphDef* optimized_graph) {
1249 std::unordered_set<string> devices;
1250 std::vector<int> assign_nodes;
1251 bool found_send = false;
1252 for (int i = 0; i < optimized_graph->node_size(); ++i) {
1253 const NodeDef& node = optimized_graph->node(i);
1254 devices.insert(node.device());
1255 if (IsAssign(node)) {
1256 assign_nodes.push_back(i);
1257 }
1258 if (IsSend(node)) {
1259 found_send = true;
1260 break;
1261 }
1262 }
1263 if (!found_send && devices.size() == 1) {
1264 for (int assign_idx : assign_nodes) {
1265 // Set an attribute telling AssignOp to ignore allocator constraints.
1266 NodeDef* assign_node = optimized_graph->mutable_node(assign_idx);
1267 (*assign_node->mutable_attr())["_grappler_relax_allocator_constraints"]
1268 .set_b(true);
1269 }
1270 return Status::OK();
1271 }
1272
1273 GraphTopologyView graph_view;
1274 TF_RETURN_IF_ERROR(graph_view.InitializeFromGraph(
1275 *optimized_graph, /*ignore_control_edges=*/true));
1276 std::unordered_set<const NodeDef*> optimized_nodes;
1277
1278 for (int i : assign_nodes) {
1279 const NodeDef& assign_node = optimized_graph->node(i);
1280
1281 if (optimized_nodes.find(&assign_node) == optimized_nodes.end()) {
1282 std::vector<const NodeDef*> assign_nodes_in_fanout;
1283 optimized_nodes.insert(&assign_node);
1284 assign_nodes_in_fanout.push_back(&assign_node);
1285
1286 std::vector<const NodeDef*> transitive_fanout;
1287 // Find the nodes in transitive fanout. If a node is known to never
1288 // forward its inputs, we can skip its fanout.
1289 DfsTraversal(graph_view, {graph_view.GetNode(i)},
1290 TraversalDirection::kFollowOutputs,
1291 DfsPredicates::Advance([&](const NodeDef* node) {
1292 return !NeverForwardsInputs(*node);
1293 }),
1294 DfsCallbacks::PreOrder([&](const NodeDef* node) {
1295 transitive_fanout.push_back(node);
1296 }));
1297
1298 bool relax_constraint = true;
1299 // If all nodes in the transitive fanout are on the same device as the
1300 // assign node, there is no need to allocate the output in pinned memory.
1301 for (const NodeDef* fanout_node : transitive_fanout) {
1302 if (relax_constraint &&
1303 (IsSend(*fanout_node) ||
1304 CrossesTaskOrCpuGpuBoundary(*fanout_node, assign_node))) {
1305 relax_constraint = false;
1306 break;
1307 }
1308 if (optimized_nodes.find(fanout_node) == optimized_nodes.end() &&
1309 IsAssign(*fanout_node)) {
1310 assign_nodes_in_fanout.push_back(fanout_node);
1311 }
1312 }
1313
1314 if (relax_constraint) {
1315 for (const NodeDef* assign_node_in_fanout : assign_nodes_in_fanout) {
1316 // If all devices match in fanout of node(i) then, by transitivity,
1317 // they must also match in the fanout of other assign nodes
1318 // in the fanout of node(i), so we can process them here,
1319 // and save computing their transitive fanout later.
1320 optimized_nodes.insert(assign_node_in_fanout);
1321
1322 // Set an attribute telling AssignOp to ignore allocator constraints.
1323 const absl::optional<int> assign_node_idx =
1324 graph_view.GetNodeIndex(*assign_node_in_fanout);
1325 NodeDef* assign_node_to_relax =
1326 optimized_graph->mutable_node(assign_node_idx.value());
1327 (*assign_node_to_relax
1328 ->mutable_attr())["_grappler_relax_allocator_constraints"]
1329 .set_b(true);
1330 }
1331 }
1332 }
1333 }
1334 return Status::OK();
1335 }
1336
1337 } // namespace
1338
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)1339 Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
1340 GraphDef* optimized_graph) {
1341 GrapplerItem optimized_item(item);
1342
1343 RecomputationRewritingPass(optimization_level_,
1344 recomputation_targets_name_scope_,
1345 &optimized_item.graph, item);
1346
1347 std::unordered_set<string> skip_list;
1348 // Bound the number of rewrite passes to avoid long processing times on graphs
1349 // that simply won't fit in memory.
1350 bool updated_graph = true;
1351 for (int i = 0; i < 25 && updated_graph; ++i) {
1352 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
1353 updated_graph = false;
1354 if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT ||
1355 optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS ||
1356 optimization_level_ == RewriterConfig::HEURISTICS) &&
1357 cluster != nullptr) {
1358 updated_graph |= SchedulingPass(cluster, &optimized_item);
1359 }
1360
1361 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
1362 if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT ||
1363 optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS ||
1364 optimization_level_ == RewriterConfig::HEURISTICS ||
1365 optimization_level_ == RewriterConfig::MANUAL) &&
1366 cluster != nullptr) {
1367 updated_graph |= SwappingPass(optimization_level_, cluster,
1368 &optimized_item, &skip_list);
1369 }
1370 }
1371
1372 TF_RETURN_IF_ERROR(RelaxAllocatorConstraints(&optimized_item.graph));
1373
1374 optimized_graph->Swap(&optimized_item.graph);
1375 return Status::OK();
1376 }
1377
Feedback(Cluster * cluster,const GrapplerItem & item,const GraphDef & optimized_graph,double result)1378 void MemoryOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
1379 const GraphDef& optimized_graph, double result) {
1380 // Nothing to do for MemoryOptimizer.
1381 }
1382
1383 } // end namespace grappler
1384 } // end namespace tensorflow
1385