/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"

#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"

namespace tensorflow {
namespace grappler {
namespace {

40 class RecomputeSubgraphTest : public GrapplerTest {};
41 
// A node explicitly tagged with "_recompute_hint" (MANUAL mode) should be
// duplicated as "Recomputed/<name>", fed by a "RecomputeTrigger/<name>"
// control node so the recomputation happens during the backward pass.
TEST_F(RecomputeSubgraphTest, SimpleSubgraph) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Variable(s.WithOpName("a"), {2, 3, 4}, DT_FLOAT);
  Output b = ops::Identity(s.WithOpName("b"), a);  // Recomputed
  Output c = ops::Identity(s.WithOpName("c"), b);
  // The "gradients/" name prefix marks these as backward-pass consumers.
  Output d = ops::AddN(s.WithOpName("gradients/d"), {c});
  Output e = ops::AddN(s.WithOpName("gradients/e"), {d, b});
  Output f = ops::AddN(s.WithOpName("gradients/f"), {e, a});

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  EXPECT_EQ(6, item.graph.node_size());
  NodeMap pre_transform_node_map(&item.graph);
  // Tag "b" so the MANUAL optimizer recomputes it.
  (*pre_transform_node_map.GetNode("b")->mutable_attr())["_recompute_hint"]
      .set_i(0);

  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);

  TF_EXPECT_OK(status);
  NodeMap post_transform_node_map(&output);
  // Two nodes added: Recomputed/b and its trigger.
  EXPECT_EQ(8, output.node_size());
  NodeDef* transformed_e = post_transform_node_map.GetNode(e.name());
  EXPECT_EQ(2, transformed_e->input_size());
  EXPECT_EQ("gradients/d", transformed_e->input(0));
  EXPECT_EQ("Recomputed/b", transformed_e->input(1));
  NodeDef* recomputed_b = post_transform_node_map.GetNode("Recomputed/b");
  EXPECT_EQ(2, recomputed_b->input_size());
  EXPECT_EQ("a", recomputed_b->input(0));
  EXPECT_EQ("^RecomputeTrigger/b", recomputed_b->input(1));
  // The trigger fires once the first gradient node has run.
  NodeDef* recompute_trigger =
      post_transform_node_map.GetNode("RecomputeTrigger/b");
  EXPECT_EQ(1, recompute_trigger->input_size());
  EXPECT_EQ("^gradients/d", recompute_trigger->input(0));
}

// A node that is fed (appears in item.feed) must never be recomputed, even
// when tagged with "_recompute_hint": the graph should pass through unchanged.
TEST_F(RecomputeSubgraphTest, NoFeedsRecomputed) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Variable(s.WithOpName("a"), {2, 3, 4}, DT_FLOAT);
  Output b = ops::Identity(s.WithOpName("b"), a);  // Would be recomputed, but
                                                   // for being fed
  Output c = ops::Identity(s.WithOpName("c"), b);
  Output d = ops::AddN(s.WithOpName("gradients/d"), {c});
  Output e = ops::AddN(s.WithOpName("gradients/e"), {d, b});
  Output f = ops::AddN(s.WithOpName("gradients/f"), {e, a});

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Feeding "b" disqualifies it from recomputation.
  item.feed.emplace_back("b", Tensor());
  EXPECT_EQ(6, item.graph.node_size());
  NodeMap pre_transform_node_map(&item.graph);
  (*pre_transform_node_map.GetNode("b")->mutable_attr())["_recompute_hint"]
      .set_i(0);

  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);

  TF_EXPECT_OK(status);
  // No nodes added: the rewrite was skipped.
  EXPECT_EQ(6, output.node_size());
}

// Two independent tagged inputs feeding the same gradient node, with a
// custom gradient-scope prefix. Mostly a no-crash regression test.
TEST_F(RecomputeSubgraphTest, TwoInputSubgraphs) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Variable(s.WithOpName("a"), {2, 3, 4}, DT_FLOAT);
  Output b = ops::Variable(s.WithOpName("b"), {2, 3, 4}, DT_FLOAT);
  Output d = ops::AddN(
      s.WithOpName("some_name_scope/gradients/two_subgraph_inputs"), {a, b});

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  EXPECT_EQ(3, item.graph.node_size());
  NodeMap pre_transform_node_map(&item.graph);
  (*pre_transform_node_map.GetNode("a")->mutable_attr())["_recompute_hint"]
      .set_i(0);
  (*pre_transform_node_map.GetNode("b")->mutable_attr())["_recompute_hint"]
      .set_i(0);

  // Use a non-default gradient name-scope prefix.
  MemoryOptimizer optimizer(RewriterConfig::MANUAL,
                            "some_name_scope/gradients");
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);

  TF_EXPECT_OK(status);
  NodeMap post_transform_node_map(&output);
  // Mostly checking that this case does not crash.
  EXPECT_EQ(7, output.node_size());
  EXPECT_NE(post_transform_node_map.GetNode("Recomputed/a"), nullptr);
  EXPECT_NE(post_transform_node_map.GetNode("Recomputed/b"), nullptr);
  EXPECT_NE(post_transform_node_map.GetNode("RecomputeTrigger/a"), nullptr);
  EXPECT_NE(post_transform_node_map.GetNode("RecomputeTrigger/b"), nullptr);
}

TEST_F(RecomputeSubgraphTest,MultiNode)139 TEST_F(RecomputeSubgraphTest, MultiNode) {
140   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
141 
142   Output a = ops::Variable(s.WithOpName("Conv"), {2, 3, 4}, DT_FLOAT);
143   Output b = ops::Identity(s.WithOpName("BN"), a);    // Recomputed
144   Output c = ops::Identity(s.WithOpName("ReLU"), b);  // Recomputed
145   Output d = ops::Identity(s.WithOpName("Conv1"), c);
146 
147   // The "gradients/" prefix means the heuristic will pick these up as
148   // candidates to have their inputs recomputed.
149   Output trigger = ops::AddN(s.WithOpName("gradients/BN1Grad"), {d});
150   Output e = ops::AddN(s.WithOpName("gradients/Conv1Grad"), {trigger, c});
151   Output f = ops::AddN(s.WithOpName("gradients/ReLUGrad"), {e, c});
152   Output g = ops::AddN(s.WithOpName("gradients/BNGrad"), {f, a});
153   Output h = ops::AddN(s.WithOpName("gradients/ConvGrad"), {g});
154 
155   GrapplerItem item;
156   TF_CHECK_OK(s.ToGraphDef(&item.graph));
157   EXPECT_EQ(9, item.graph.node_size());
158   NodeMap pre_transform_node_map(&item.graph);
159   // Set op types so that the heuristic will pick these nodes up to be
160   // recomputed
161   pre_transform_node_map.GetNode("BN")->set_op("FusedBatchNorm");
162   pre_transform_node_map.GetNode("ReLU")->set_op("Relu");
163 
164   MemoryOptimizer optimizer(RewriterConfig::RECOMPUTATION_HEURISTICS);
165   GraphDef first_pass_output;
166   Status first_pass_status =
167       optimizer.Optimize(nullptr, item, &first_pass_output);
168   TF_EXPECT_OK(first_pass_status);
169 
170   NodeMap post_transform_node_map(&first_pass_output);
171   EXPECT_EQ(13, first_pass_output.node_size());
172   NodeDef* transformed_e = post_transform_node_map.GetNode(e.name());
173   EXPECT_EQ(2, transformed_e->input_size());
174   EXPECT_EQ("gradients/BN1Grad", transformed_e->input(0));
175   EXPECT_EQ("Recomputed/ReLU", transformed_e->input(1));
176   NodeDef* transformed_f = post_transform_node_map.GetNode(f.name());
177   EXPECT_EQ(2, transformed_f->input_size());
178   EXPECT_EQ("gradients/Conv1Grad", transformed_f->input(0));
179   EXPECT_EQ("Recomputed/ReLU", transformed_f->input(1));
180   NodeDef* transformed_g = post_transform_node_map.GetNode(g.name());
181   EXPECT_EQ(2, transformed_g->input_size());
182   EXPECT_EQ("gradients/ReLUGrad", transformed_g->input(0));
183   EXPECT_EQ("Conv", transformed_g->input(1));
184 
185   NodeDef* recomputed_b = post_transform_node_map.GetNode("Recomputed/BN");
186   EXPECT_EQ(2, recomputed_b->input_size());
187   EXPECT_EQ("Conv", recomputed_b->input(0));
188   EXPECT_EQ("^RecomputeTrigger/BN", recomputed_b->input(1));
189   NodeDef* recompute_trigger_b =
190       post_transform_node_map.GetNode("RecomputeTrigger/BN");
191   EXPECT_EQ(1, recompute_trigger_b->input_size());
192   EXPECT_EQ("^RecomputeTrigger/ReLU", recompute_trigger_b->input(0));
193 
194   NodeDef* recomputed_c = post_transform_node_map.GetNode("Recomputed/ReLU");
195   EXPECT_EQ(2, recomputed_c->input_size());
196   EXPECT_EQ("Recomputed/BN", recomputed_c->input(0));
197   EXPECT_EQ("^RecomputeTrigger/ReLU", recomputed_c->input(1));
198   NodeDef* recompute_trigger_c =
199       post_transform_node_map.GetNode("RecomputeTrigger/ReLU");
200   EXPECT_EQ(1, recompute_trigger_c->input_size());
201   EXPECT_EQ("^gradients/BN1Grad", recompute_trigger_c->input(0));
202 }
203 
204 class MemoryOptimizerTest : public GrapplerTest {
205  public:
CreateVirtualCluster()206   static std::unique_ptr<VirtualCluster> CreateVirtualCluster() {
207     DeviceProperties cpu_device;
208     cpu_device.set_type("CPU");
209     cpu_device.set_frequency(1000);
210     cpu_device.set_num_cores(4);
211     cpu_device.set_bandwidth(32);
212     cpu_device.set_memory_size(1024 * 1024);
213     DeviceProperties gpu_device;
214     gpu_device.set_type("GPU");
215     gpu_device.set_frequency(1000);
216     gpu_device.set_num_cores(24);
217     gpu_device.set_bandwidth(128);
218     gpu_device.set_memory_size(1024 * 1024);
219     gpu_device.mutable_environment()->insert({"architecture", "6"});
220     std::unordered_map<string, DeviceProperties> devices;
221     devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device;
222     devices["/job:localhost/replica:0/task:0/gpu:0"] = gpu_device;
223     return std::unique_ptr<VirtualCluster>(new VirtualCluster(devices));
224   }
225 };
226 
TEST_F(MemoryOptimizerTest,SimpleSwapping)227 TEST_F(MemoryOptimizerTest, SimpleSwapping) {
228   // Build a simple graph with an op that's marked for swapping.
229   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
230 
231   Output a =
232       ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"), {10, 10}, DT_FLOAT);
233   Output b = ops::AddN(s.WithOpName("b").WithDevice("/gpu:0"), {a});
234   Output c = ops::AddN(s.WithOpName("c").WithDevice("/gpu:0"), {b});
235   Output d = ops::AddN(s.WithOpName("d").WithDevice("/gpu:0"), {c});
236   Output e = ops::AddN(s.WithOpName("e").WithDevice("/gpu:0"), {b, d});
237 
238   Output constant = ops::Const(s.WithOpName("constant"), 0.0f, {10, 10});
239   Output init = ops::Assign(s.WithOpName("init"), a, constant);
240 
241   GrapplerItem item;
242   TF_CHECK_OK(s.ToGraphDef(&item.graph));
243 
244   EXPECT_EQ(7, item.graph.node_size());
245   EXPECT_EQ(NodeName(e.name()), item.graph.node(4).name());
246   AttrValue& val =
247       (*item.graph.mutable_node(4)->mutable_attr())["_swap_to_host"];
248   val.mutable_list()->add_i(0);
249 
250   std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());
251 
252   MemoryOptimizer optimizer(RewriterConfig::MANUAL);
253   GraphDef output;
254   Status status = optimizer.Optimize(cluster.get(), item, &output);
255   TF_EXPECT_OK(status);
256 
257   EXPECT_EQ(9, output.node_size());
258   const NodeDef& new_e = output.node(6);
259   EXPECT_EQ(NodeName(e.name()), new_e.name());
260 
261   EXPECT_EQ(2, new_e.input_size());
262   EXPECT_EQ(NodeName(d.name()), new_e.input(1));
263   EXPECT_EQ("swap_in_e_0", new_e.input(0));
264 
265   const NodeDef& swap_out = output.node(7);
266   EXPECT_EQ("swap_out_e_0", swap_out.name());
267   EXPECT_EQ("_CopyFromGpuToHost", swap_out.op());
268 
269   const NodeDef& swap_in = output.node(8);
270   EXPECT_EQ("swap_in_e_0", swap_in.name());
271   EXPECT_EQ("_CopyFromHostToGpu", swap_in.op());
272 
273   EXPECT_EQ(NodeName(b.name()), swap_out.input(0));
274   EXPECT_EQ(NodeName(swap_out.name()), swap_in.input(0));
275   EXPECT_EQ("^c", swap_in.input(1));
276 
277   const NodeDef& new_c = output.node(4);
278   EXPECT_EQ(NodeName(c.name()), new_c.name());
279   EXPECT_EQ("^swap_out_e_0", new_c.input(1));
280 
281   // Run the optimizer a second time to ensure it's idempotent.
282   GrapplerItem item_copy = item.WithGraph(std::move(output));
283   status = optimizer.Optimize(cluster.get(), item_copy, &output);
284   TF_EXPECT_OK(status);
285 
286 #if GOOGLE_CUDA
287   item.fetch = {"e"};
288   item.init_ops = {init.name()};
289   auto tensors_expected = EvaluateFetchNodes(item);
290   GrapplerItem optimized = item.WithGraph(std::move(output));
291   auto tensors = EvaluateFetchNodes(optimized);
292   test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
293 #endif
294 }
295 
// SWAPPING_HEURISTICS: with only 1 MiB of simulated GPU memory, the large
// inputs of the concat "e" should be automatically swapped to the host.
TEST_F(MemoryOptimizerTest, SwappingHeuristics) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output v = ops::Variable(s.WithOpName("v").WithDevice("/gpu:0"),
                           {128, 128, 8}, DT_FLOAT);
  Output a = ops::Identity(s.WithOpName("a").WithDevice("/gpu:0"), v);
  Output b = ops::Square(s.WithOpName("b").WithDevice("/gpu:0"), v);
  Output c = ops::Sqrt(s.WithOpName("c").WithDevice("/gpu:0"), a);
  Output d = ops::Identity(s.WithOpName("d").WithDevice("/gpu:0"), b);
  Output axis = ops::Const(s.WithOpName("axis"), 0);
  Output e =
      ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {a, b, c, d}, axis);
  // Extra consumers keep a/b/c/d live, increasing memory pressure.
  Output f = ops::Square(s.WithOpName("f").WithDevice("/gpu:0"), a);
  Output g = ops::Sqrt(s.WithOpName("g").WithDevice("/gpu:0"), b);
  Output h = ops::Exp(s.WithOpName("h").WithDevice("/gpu:0"), c);
  Output i = ops::Log(s.WithOpName("i").WithDevice("/gpu:0"), d);

  Output constant = ops::Const(s.WithOpName("constant"), 0.0f, {128, 128, 8});
  Output init = ops::Assign(s.WithOpName("init"), v, constant);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"e", "f", "g", "h", "i"};
  item.init_ops = {init.name()};

  std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());

  MemoryOptimizer optimizer(RewriterConfig::SWAPPING_HEURISTICS);
  GraphDef output;
  Status status = optimizer.Optimize(cluster.get(), item, &output);
  TF_EXPECT_OK(status);

  for (const auto& node : output.node()) {
    if (node.name() == "e") {
      // Inputs 1-3 of the concat are swapped; input 0 and the axis are not.
      EXPECT_EQ(5, node.input_size());
      EXPECT_EQ("a", node.input(0));
      EXPECT_EQ("swap_in_e_1", node.input(1));
      EXPECT_EQ("swap_in_e_2", node.input(2));
      EXPECT_EQ("swap_in_e_3", node.input(3));
      EXPECT_EQ("axis", node.input(4));
    }
  }

#if GOOGLE_CUDA
  // On a real GPU, verify all fetches are numerically unchanged.
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
  }
#endif
}

// An input produced by a reference op (ScatterAdd output "d") cannot be
// swapped by value; the optimizer should keep the direct input and only add
// a control dependency on the swap-out node.
TEST_F(MemoryOptimizerTest, UnswappableInputs) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output v = ops::Variable(s.WithOpName("v").WithDevice("/gpu:0"),
                           {128, 128, 8}, DT_FLOAT);
  Output a = ops::Square(s.WithOpName("a").WithDevice("/gpu:0"), v);
  Output b = ops::Identity(s.WithOpName("b").WithDevice("/gpu:0"), {a});
  Output c = ops::Identity(s.WithOpName("c").WithDevice("/gpu:0"), {a});
  Output index = ops::Const(s.WithOpName("index"), {0});
  Output indices = ops::Tile(s.WithOpName("indices"), index, {128});
  // ScatterAdd mutates the variable in place: its output aliases "v".
  Output d =
      ops::ScatterAdd(s.WithOpName("d").WithDevice("/gpu:0"), v, indices, c);
  Output axis = ops::Const(s.WithOpName("axis"), 0);
  Output e =
      ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis);

  Output constant = ops::Const(s.WithOpName("constant"), 0.0f, {128, 128, 8});
  Output init = ops::Assign(s.WithOpName("init"), v, constant);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"e"};
  item.init_ops = {init.name()};

  std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());

  MemoryOptimizer optimizer(RewriterConfig::SWAPPING_HEURISTICS);
  GraphDef output;
  Status status = optimizer.Optimize(cluster.get(), item, &output);
  TF_EXPECT_OK(status);

  for (const auto& node : output.node()) {
    if (node.name() == "e") {
      // The d node isn't swappable.
      EXPECT_EQ(5, node.input_size());
      EXPECT_EQ("d", node.input(2));
      EXPECT_EQ("^swap_out_d_2", node.input(4));
    }
  }

#if GOOGLE_CUDA
  // On a real GPU, verify the fetch is numerically unchanged.
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
#endif
}

// SCHEDULING_HEURISTICS: an AddN with several large inputs is rewritten into
// an in-place accumulation (TemporaryVariable + Assign/AssignAdd +
// DestroyTemporaryVariable) to avoid keeping all inputs live at once.
TEST_F(MemoryOptimizerTest, AccumulationRewrites) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::RandomNormal(s.WithOpName("a").WithDevice("/cpu:0"),
                               {128, 128, 8}, DT_FLOAT);
  Output b = ops::RandomNormal(s.WithOpName("b").WithDevice("/cpu:0"),
                               {128, 128, 8}, DT_FLOAT);
  Output c = ops::RandomNormal(s.WithOpName("c").WithDevice("/cpu:0"),
                               {128, 128, 8}, DT_FLOAT);
  Output d = ops::AddN(s.WithOpName("d").WithDevice("/cpu:0"), {a, b, c});
  Output e = ops::Square(s.WithOpName("e").WithDevice("/cpu:0"), d);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"e"};

  std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());
  MemoryOptimizer optimizer(RewriterConfig::SCHEDULING_HEURISTICS);
  GraphDef output;
  Status status = optimizer.Optimize(cluster.get(), item, &output);
  TF_EXPECT_OK(status);

  // Check the expected rewritten structure; "d" keeps its name so that the
  // consumer "e" is unaffected.
  int count = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "d") {
      EXPECT_EQ("DestroyTemporaryVariable", node.op());
      count++;
    } else if (node.name() == "d/tmp_var_initializer") {
      EXPECT_EQ("Assign", node.op());
      count++;
    } else if (node.name() == "d/tmp_var") {
      EXPECT_EQ("TemporaryVariable", node.op());
      count++;
    } else if (node.name() == "e") {
      EXPECT_EQ("Square", node.op());
      EXPECT_EQ("d", node.input(0));
      count++;
    }
  }
  EXPECT_EQ(4, count);

  // Numerical check on CPU: e must equal (a + b + c)^2 elementwise. Fetch
  // a/b/c from the same run since RandomNormal is nondeterministic.
  std::vector<string> fetch = {"a", "b", "c", "e"};
  auto tensors = EvaluateNodes(output, fetch, {});
  EXPECT_EQ(4, tensors.size());

  for (int i = 0; i < tensors[0].NumElements(); ++i) {
    float actual = tensors[3].flat<float>()(i);
    float expected = 0.0f;
    for (int j = 0; j < 3; ++j) {
      expected += tensors[j].flat<float>()(i);
    }
    expected *= expected;
    EXPECT_NEAR(actual, expected, 1e-4);
  }
}

450 class RelaxAllocatorConstraintsTest : public GrapplerTest {};
451 
// Assign and its consumer live on the same device, so the allocator
// constraint on the Assign can be relaxed.
TEST_F(RelaxAllocatorConstraintsTest, SameDevice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output constant = ops::Const(s.WithOpName("constant").WithDevice("/cpu:0"),
                               -3.14f, {128, 128});
  Output variable = ops::Variable(s.WithOpName("variable").WithDevice("/cpu:0"),
                                  {128, 128}, DT_FLOAT);
  Output assign = ops::Assign(s.WithOpName("assign").WithDevice("/cpu:0"),
                              variable, constant);
  Output exp = ops::Exp(s.WithOpName("exp").WithDevice("/cpu:0"), assign);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  auto node = output.node(2);
  EXPECT_EQ("assign", node.name());
  EXPECT_EQ(1, node.attr().count("_grappler_relax_allocator_constraints"));
  EXPECT_EQ(true, node.attr().at("_grappler_relax_allocator_constraints").b());

  // The relaxed graph must still produce the same result.
  item.fetch = {"exp"};
  item.init_ops = {"variable"};
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

// The consumer runs on a GPU while the Assign is on the CPU: the output may
// need pinned memory for the transfer, so the constraint must NOT be relaxed.
TEST_F(RelaxAllocatorConstraintsTest, DifferentDevice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output constant = ops::Const(s.WithOpName("constant").WithDevice("/cpu:0"),
                               -3.14f, {128, 128});
  Output variable = ops::Variable(s.WithOpName("variable").WithDevice("/cpu:0"),
                                  {128, 128}, DT_FLOAT);
  Output assign = ops::Assign(s.WithOpName("assign").WithDevice("/cpu:0"),
                              variable, constant);
  // exp runs on a different device, so we cannot relax the allocation
  // constraints on assign.
  Output exp = ops::Exp(s.WithOpName("exp").WithDevice("/gpu:0"), assign);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  auto node = output.node(2);
  EXPECT_EQ("assign", node.name());
  EXPECT_EQ(0, node.attr().count("_grappler_relax_allocator_constraints"));
#if GOOGLE_CUDA
  // On a real GPU, the untouched graph must still evaluate correctly.
  item.fetch = {"exp"};
  item.init_ops = {"variable"};
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
#endif
}

// Assign on /cpu:0 consumed on /cpu:1: different devices but the same device
// type, so no CPU:GPU boundary is crossed and the constraint can be relaxed.
TEST_F(RelaxAllocatorConstraintsTest, SameDeviceType) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output constant = ops::Const(s.WithOpName("constant").WithDevice("/cpu:0"),
                               -3.14f, {128, 128});
  Output variable = ops::Variable(s.WithOpName("variable").WithDevice("/cpu:0"),
                                  {128, 128}, DT_FLOAT);
  Output assign = ops::Assign(s.WithOpName("assign").WithDevice("/cpu:0"),
                              variable, constant);
  // Assign and Exp run on different devices, but do not straddle a CPU:GPU
  // boundary, so we do not need to enforce allocation in pinned memory.
  Output exp = ops::Exp(s.WithOpName("exp").WithDevice("/cpu:1"), assign);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  auto node = output.node(2);
  EXPECT_EQ("assign", node.name());
  EXPECT_EQ(1, node.attr().count("_grappler_relax_allocator_constraints"));
  EXPECT_TRUE(node.attr().at("_grappler_relax_allocator_constraints").b());
}

// A _Send node in the fanout means the tensor may leave the device, so the
// Assign's allocator constraint must NOT be relaxed.
TEST_F(RelaxAllocatorConstraintsTest, SendNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output constant = ops::Const(s.WithOpName("constant").WithDevice("/cpu:0"),
                               -3.14f, {128, 128});
  Output variable = ops::Variable(s.WithOpName("variable").WithDevice("/cpu:0"),
                                  {128, 128}, DT_FLOAT);
  Output assign = ops::Assign(s.WithOpName("assign").WithDevice("/cpu:0"),
                              variable, constant);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  NodeDef* send = item.graph.add_node();
  // Add a send node to the graph in the fanout of "assign".
  send->set_name("send");
  send->set_op("_Send");
  send->add_input("assign");

  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  auto node = output.node(2);
  EXPECT_EQ("assign", node.name());
  EXPECT_EQ(0, node.attr().count("_grappler_relax_allocator_constraints"));
}

TEST_F(RelaxAllocatorConstraintsTest,AssignNodeInFanout)565 TEST_F(RelaxAllocatorConstraintsTest, AssignNodeInFanout) {
566   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
567   Output constant0 = ops::Const(s.WithOpName("constant0").WithDevice("/cpu:0"),
568                                 -42.0f, {128, 128});
569   Output variable0 = ops::Variable(
570       s.WithOpName("variable0").WithDevice("/cpu:0"), {128, 128}, DT_FLOAT);
571   Output assign0 = ops::Assign(s.WithOpName("assign0").WithDevice("/cpu:0"),
572                                variable0, constant0);
573   Output assign2 = ops::Assign(s.WithOpName("assign2").WithDevice("/cpu:0"),
574                                variable0, constant0);
575   Output assign3 = ops::Assign(s.WithOpName("assign3").WithDevice("/cpu:0"),
576                                variable0, constant0);
577   Output assign4 = ops::Assign(s.WithOpName("assign4").WithDevice("/cpu:0"),
578                                variable0, constant0);
579   // Rank does not forward its input buffer, so assign3 can be relaxed.
580   Output rank_cpu =
581       ops::Rank(s.WithOpName("rank_cpu").WithDevice("/cpu:0"), assign3);
582   // Exp could forward its input buffer, so we cannot relax assign4.
583   Output exp_cpu =
584       ops::Exp(s.WithOpName("exp_cpu").WithDevice("/cpu:0"), assign4);
585 
586   // The rest of the graph is on a second device, so we can relax the
587   // constraint for assign1, but not for assign0. Assign2 only has a
588   // control dependency crossing the device boundary, so it can be relaxed too.
589   Output rank_gpu = ops::Rank(s.WithOpName("rank_gpu")
590                                   .WithDevice("/gpu:0")
591                                   .WithControlDependencies(assign2),
592                               assign0);
593   Output id_gpu = ops::Identity(s.WithOpName("id_gpu"), rank_cpu);
594   Output id_gpu2 = ops::Identity(s.WithOpName("id_gpu2"), exp_cpu);
595   Output variable_gpu = ops::Variable(
596       s.WithOpName("variable_gpu").WithDevice("/gpu:0"), {128, 128}, DT_FLOAT);
597   Output assign_gpu = ops::Assign(
598       s.WithOpName("assign_gpu").WithDevice("/gpu:0"), variable_gpu, exp_cpu);
599 
600   GrapplerItem item;
601   TF_CHECK_OK(s.ToGraphDef(&item.graph));
602   item.fetch = {"assign0", "assign_gpu", "rank_gpu", "id_gpu", "id_gpu2"};
603 
604   MemoryOptimizer optimizer(RewriterConfig::MANUAL);
605   GraphDef output;
606   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
607 
608   auto node = output.node(3);
609   EXPECT_EQ("assign0", node.name());
610   EXPECT_EQ(0, node.attr().count("_grappler_relax_allocator_constraints"));
611 
612   node = output.node(4);
613   EXPECT_EQ("assign2", node.name());
614   EXPECT_EQ(1, node.attr().count("_grappler_relax_allocator_constraints"));
615   EXPECT_EQ(true, node.attr().at("_grappler_relax_allocator_constraints").b());
616 
617   node = output.node(5);
618   EXPECT_EQ("assign3", node.name());
619   EXPECT_EQ(1, node.attr().count("_grappler_relax_allocator_constraints"));
620   EXPECT_EQ(true, node.attr().at("_grappler_relax_allocator_constraints").b());
621 
622   node = output.node(6);
623   EXPECT_EQ("assign4", node.name());
624   EXPECT_EQ(0, node.attr().count("_grappler_relax_allocator_constraints"));
625 
626   node = output.node(12);
627   EXPECT_EQ("assign_gpu", node.name());
628   EXPECT_EQ(1, node.attr().count("_grappler_relax_allocator_constraints"));
629   EXPECT_EQ(true, node.attr().at("_grappler_relax_allocator_constraints").b());
630 
631 #if GOOGLE_CUDA
632   item.init_ops = {"exp_cpu", "variable_gpu"};
633   auto tensors_expected = EvaluateFetchNodes(item);
634   GrapplerItem optimized = item.WithGraph(std::move(output));
635   auto tensors = EvaluateFetchNodes(optimized);
636   for (int i = 0; i < tensors_expected.size(); ++i) {
637     if (i == 2 || i == 3) {
638       test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
639     } else {
640       test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
641     }
642   }
643 #endif
644 }
645 
}  // namespace
}  // namespace grappler
}  // namespace tensorflow