/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/deadness_analysis.h"

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

using deadness_analysis_internal::ComputePredicates;
using deadness_analysis_internal::PredicateMapTy;

Status AnalyzeDeadness(Graph* graph,
                       std::unique_ptr<DeadnessAnalysis>* result) {
  FixupSourceAndSinkEdges(graph);
  return DeadnessAnalysis::Run(*graph, result);
}

ops::Switch CreateSwitch(const Scope& root, const string& prefix) {
  Output value = ops::Placeholder(root.WithOpName(prefix + "/value"), DT_FLOAT);
  Output predicate =
      ops::Placeholder(root.WithOpName(prefix + "/pred"), DT_BOOL);
  return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate);
}

TensorId ControlOutputFor(const Output& o) {
  return {o.node()->name(), Graph::kControlSlot};
}

void VLogGraphIfAsked(const Graph& graph) {
  if (VLOG_IS_ON(3)) {
    GraphDef graph_def;
    graph.ToGraphDef(&graph_def);
    string serialized;
    ::tensorflow::protobuf::TextFormat::PrintToString(graph_def, &serialized);
    LOG(INFO) << serialized;
  }
}

struct InductionVarInfo {
  Output induction_var;
  Output loop_cond;
};

// Creates an induction variable with the following structure (simplified for
// brevity):
//
//            +---------------+
//            | initial_value |
//            +---------------+
//              |
//              |
//              v
//            +---------------+
//            |     Enter     |
//            +---------------+
//              |
//              |
//              v
//            +---------------+
//         +> |     Merge     | -+
//         |  +---------------+  |
//         |    |                |
//         |    |                |
//         |    v                |
//         |  +---------------+  |
//         |  |  LessThan10   |  |
//         |  +---------------+  |
//         |    |                |
//         |    |                |
//         |    v                |
//         |  +---------------+  |
//    +----+- |    Switch     | <+
//    |    |  +---------------+
//    |    |    |
//    |    |    |
//    |    |    v
//    |    |  +---------------+
//    |    +- |    AddOne     |
//    |       +---------------+
//    |       +---------------+
//    +-----> |     Exit      |
//            +---------------+
InductionVarInfo CreateInductionVariable(const Scope& root,
                                         const string& prefix,
                                         const string& frame_name,
                                         const Output& initial_value) {
  Output enter_initial_value = ops::internal::Enter(
      root.WithOpName(prefix + "/enter"), initial_value, frame_name);

  ops::Merge iv(root.WithOpName(prefix + "/iv"),
                {enter_initial_value, enter_initial_value});
  Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1);
  Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10);
  Output loop_cond_expr =
      ops::Less(root.WithOpName(prefix + "/cond"), iv.output, final_value);
  ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output,
                    loop_cond_expr);
  ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
                           latch.output_false);
  Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"),
                            latch.output_true, increment_by);
  Output next_iteration =
      ops::NextIteration(root.WithOpName(prefix + "/next_iteration"), iv_next);

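  // Turn the Merge into a real loop header: replace its second input (built
  // above as a duplicate of the Enter) with the backedge coming from
  // NextIteration, and add control edges so the two constants are anchored
  // inside the loop's frame.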
  CHECK(root.graph()
            ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)
            .ok());
  root.graph()->AddControlEdge(iv.output.node(), increment_by.node());
  root.graph()->AddControlEdge(iv.output.node(), final_value.node());

  return {iv.output, loop_cond_expr};
}

InductionVarInfo CreateInductionVariable(const Scope& root,
                                         const string& prefix,
                                         const string& frame_name, int32 init) {
  return CreateInductionVariable(
      root, prefix, frame_name,
      ops::Const(root.WithOpName(prefix + "/init"), init));
}

// Creates a loop-invariant value (the same value on every iteration) gated by
// `loop_cond`, with the following structure:
//
//                       +---------------+
//                       | initial_value |
//                       +---------------+
//                         |
//                         |
//                         v
//                       +---------------+
//                       |     Enter     |
//                       +---------------+
//                         |
//                         |
//                         v
//                       +---------------+
//                       |     Merge     | <+
//                       +---------------+  |
//                         |                |
//                         |                |
//                         v                |
//     +-----------+     +---------------+  |
//     | loop_cond | --> |    Switch     | -+
//     +-----------+     +---------------+
//                         |
//                         |
//                         v
//                       +---------------+
//                       |     Exit      |
//                       +---------------+
struct DependentInductionVar {
  Output induction_var;
  ops::Switch latch;
};

DependentInductionVar CreateDependentLoopInvariantValue(
    const Scope& root, const string& prefix, const string& frame_name,
    const Output& loop_cond, const Output& value) {
  Output enter_value = ops::internal::Enter(root.WithOpName(prefix + "/enter"),
                                            value, frame_name);
  ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value});
  ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
  ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
                           latch.output_false);
  Output next_iteration = ops::NextIteration(
      root.WithOpName(prefix + "/next_iteration"), latch.output_true);
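  // As above, close the loop by rerouting the Merge's second input to the
  // NextIteration backedge.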
  CHECK(root.graph()
            ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)
            .ok());
  return {iv.output, latch};
}

DependentInductionVar CreateDependentLoopInvariantValue(
    const Scope& root, const string& prefix, const string& frame_name,
    const Output& loop_cond, int32 value) {
  return CreateDependentLoopInvariantValue(
      root, prefix, frame_name, loop_cond,
      ops::Const(root.WithOpName(prefix + "/init"), value));
}

TEST(DeadnessAnalysisTest, BasicPositive) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw = CreateSwitch(root, "0");
  Output add =
      ops::Add(root.WithOpName("add"), sw.output_true, sw.output_false);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, BasicNegative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("b"), DT_FLOAT);
  Output add = ops::Add(root.WithOpName("add"), a, b);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, AndIsCommutative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");

  Output a0 =
      ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
  Output a1 =
      ops::Add(root.WithOpName("a1"), sw_1.output_false, sw_0.output_false);

  Output b0 =
      ops::Add(root.WithOpName("b0"), sw_0.output_false, sw_1.output_true);
  Output b1 =
      ops::Add(root.WithOpName("b1"), sw_1.output_true, sw_0.output_false);

  Output live0 = ops::Add(root.WithOpName("live0"), a0, a1);
  Output live1 = ops::Add(root.WithOpName("live1"), b0, b1);

  Output halfdead0 = ops::Add(root.WithOpName("halfdead0"), a0, b0);
  Output halfdead1 = ops::Add(root.WithOpName("halfdead1"), a1, b1);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));

  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
}

TEST(DeadnessAnalysisTest, AndIsAssociative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  Output a0 =
      ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
  Output a1 = ops::Add(root.WithOpName("a1"), a0, sw_2.output_false);

  Output b0 =
      ops::Add(root.WithOpName("b0"), sw_1.output_false, sw_2.output_false);
  Output b1 = ops::Add(root.WithOpName("b1"), sw_0.output_false, b0);

  Output add = ops::Add(root.WithOpName("add"), a1, b1);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, OrIsCommutative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");

  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  ops::Merge m1(root.WithOpName("m1"), {sw_1.output_false, sw_0.output_false});
  ops::Merge m2(root.WithOpName("m2"), {sw_0.output_false, sw_1.output_true});
  ops::Merge m3(root.WithOpName("m3"), {sw_1.output_true, sw_0.output_false});

  Output live0 = ops::Add(root.WithOpName("live0"), m0.output, m1.output);
  Output live1 = ops::Add(root.WithOpName("live1"), m2.output, m3.output);

  Output halfdead0 =
      ops::Add(root.WithOpName("halfdead0"), m0.output, m2.output);
  Output halfdead1 =
      ops::Add(root.WithOpName("halfdead1"), m1.output, m3.output);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));

  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
}

TEST(DeadnessAnalysisTest, OrIsAssociative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  ops::Merge m1(root.WithOpName("m1"), {m0.output, sw_2.output_false});
  ops::Merge m2(root.WithOpName("m2"), {sw_1.output_false, sw_2.output_false});
  ops::Merge m3(root.WithOpName("m3"), {sw_0.output_false, m2.output});

  Output add = ops::Add(root.WithOpName("add"), m1.output, m3.output);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, AndOfOr) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");
  ops::Switch sw_3 = CreateSwitch(root, "3");

  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  ops::Merge m1(root.WithOpName("m1"), {sw_2.output_false, sw_3.output_false});

  Output add0 = ops::Add(root.WithOpName("add0"), m0.output, m1.output);
  Output add1 = ops::Add(root.WithOpName("add1"), m0.output, m1.output);

  Output add2 = ops::Add(root.WithOpName("add2"), add0, add1);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
}

TEST(DeadnessAnalysisTest, OrOfAnd) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");
  ops::Switch sw_3 = CreateSwitch(root, "3");

  Output add0 =
      ops::Add(root.WithOpName("add0"), sw_0.output_false, sw_1.output_false);
  Output add1 =
      ops::Add(root.WithOpName("add1"), sw_2.output_false, sw_3.output_false);

  ops::Merge m0(root.WithOpName("m0"), {add0, add1});
  ops::Merge m1(root.WithOpName("m1"), {add0, add1});

  Output add2 = ops::Add(root.WithOpName("add2"), m0.output, m1.output);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
}

TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) {
  // (*A | (~*A & ((~*B & ~*A) | (~*A & *B)))) == #true
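  //
  // Working it out: (~*B & ~*A) | (~*A & *B) simplifies to ~*A, so the whole
  // expression is *A | (~*A & ~*A) == *A | ~*A == #true, which is what the
  // analysis should report for or4 below.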
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "A");
  ops::Switch sw_1 = CreateSwitch(root, "B");
  Output add0 =
      ops::Add(root.WithOpName("and0"), sw_0.output_false, sw_1.output_true);
  Output add1 =
      ops::Add(root.WithOpName("and1"), sw_0.output_false, sw_1.output_false);
  ops::Merge or2(root.WithOpName("or2"), {add0, add1});
  Output add3 =
      ops::Add(root.WithOpName("and3"), or2.output, sw_0.output_false);
  ops::Merge or4(root.WithOpName("or4"), {add3, sw_0.output_true});

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
  EXPECT_EQ(predicate_map[ControlOutputFor(or4.output)], "#true");
}

TEST(DeadnessAnalysisTest, AndOrDistributive) {
  // (A|B)&C == (A&C)|(B&C)
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  Output add0 = ops::Add(root.WithOpName("add0"), m0.output, sw_2.output_false);

  Output add1 =
      ops::Add(root.WithOpName("add1"), sw_0.output_false, sw_2.output_false);
  Output add2 =
      ops::Add(root.WithOpName("add2"), sw_1.output_false, sw_2.output_false);
  ops::Merge m1(root.WithOpName("m1"), {add1, add2});

  Output add3 = ops::Add(root.WithOpName("add3"), add0, m1.output);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add3.node()));
}

TEST(DeadnessAnalysisTest, Ternary) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output predicate = ops::Placeholder(root.WithOpName("predicate"), DT_BOOL);
  Output true_value = ops::Placeholder(root.WithOpName("true_value"), DT_FLOAT);
  Output false_value =
      ops::Placeholder(root.WithOpName("false_value"), DT_FLOAT);

  ops::Switch predicated_true(root.WithOpName("predicated_true"), true_value,
                              predicate);

  ops::Switch predicated_false(root.WithOpName("predicated_false"),
                               false_value, predicate);
  ops::Merge merge(root.WithOpName("ternary"), {predicated_true.output_true,
                                                predicated_false.output_false});
  Output addend = ops::Placeholder(root.WithOpName("addend"), DT_FLOAT);
  Output add = ops::Add(root.WithOpName("add"), merge.output, addend);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, Recv) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_FLOAT, "tensor_a",
                             "sender", 0, "receiver");
  Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_FLOAT, "tensor_b",
                             "sender", 0, "receiver");
  Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, HostRecv) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_HostRecv(root.WithOpName("recv_a"), DT_FLOAT,
                                 "tensor_a", "sender", 0, "receiver");
  Output recv_b = ops::_HostRecv(root.WithOpName("recv_b"), DT_FLOAT,
                                 "tensor_b", "sender", 0, "receiver");
  Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, Loop) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0).induction_var;
  Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0).induction_var;
  Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1).induction_var;
  Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1);
  Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2);

  // NB! iv0 and iv1 are equivalent and a smarter deadness analysis would have
  // noticed that. Today we are pessimistic here because we assign an
  // uninterpreted symbol to merges with backedges.

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node()));
    EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node()));
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    // In theory we should be able to tell that iv0/cond:0 and iv1/cond:0
    // produce the same deadness. But we're not that smart today.
    EXPECT_EQ(predicate_map[ControlOutputFor(iv0)],
              "{#true,&,*iv0/cond:0}<fr0>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv1)],
              "{#true,&,*iv1/cond:0}<fr0>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv2)],
              "{#true,&,*iv2/cond:0}<fr0>");
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              "({#true,&,*iv0/cond:0}<fr0> & {#true,&,*iv1/cond:0}<fr0>)");
    EXPECT_EQ(predicate_map[ControlOutputFor(add1)],
              "({#true,&,*iv1/cond:0}<fr0> & {#true,&,*iv2/cond:0}<fr0>)");
  }
}

TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0);
  Output dependent_iv0 =
      CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0)
          .induction_var;
  Output dependent_iv1 =
      CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0)
          .induction_var;
  Output add0 = ops::Add(root.WithOpName("add0"), dependent_iv0, dependent_iv1);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node()));
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)],
              "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)],
              "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
  }
}

TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
  // Create a merge that "looks like" a loop but isn't really. It has a value
  // that does not depend on the merge on its backedge.
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0);
  DependentInductionVar dependent_iv =
      CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0);
  FixupSourceAndSinkEdges(root.graph());

  // To make deadness analysis think that dependent_iv is a loop we need an RPO
  // that visits the merge before the backedge. This is a legal RPO for
  // deadness analysis since it ignores NextIteration->Merge edges during RPO.
  // Right now dependent_iv has an edge from Merge to NextIteration so do the
  // RPO with this edge in place. Then remove this edge to get our test case.
  std::vector<Node*> rpo;
  GetReversePostOrder(*root.graph(), &rpo, /*stable_comparator=*/{},
                      /*edge_filter=*/[](const Edge& edge) {
                        return !edge.src()->IsNextIteration();
                      });
  TF_ASSERT_OK(root.graph()->UpdateEdge(
      iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0));

  VLogGraphIfAsked(*root.graph());

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), rpo, &predicate_map));

    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
              "div0/iv:0");
  }
}

TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv_outer =
      CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
  Output enter_constant_outer_loop = ops::internal::Enter(
      root.WithOpName("constant_enter_outer_loop"),
      ops::Const(root.WithOpName("constant"), 5), "outer_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  ops::Switch inner_value(root.WithOpName("outer_is_live"),
                          enter_constant_outer_loop, iv_outer.loop_cond);
  InductionVarInfo iv_inner = CreateInductionVariable(
      root, "iv_inner", "inner_loop", inner_value.output_true);

  Output dependent_outer_iv0 =
      CreateDependentLoopInvariantValue(root, "dependent_outer_iv0",
                                        "outer_loop", iv_outer.loop_cond, 0)
          .induction_var;
  Output dependent_outer_iv1 =
      CreateDependentLoopInvariantValue(root, "dependent_outer_iv1",
                                        "outer_loop", iv_outer.loop_cond, 0)
          .induction_var;

  Output dependent_inner_iv0 = CreateDependentLoopInvariantValue(
                                   root, "dependent_inner_iv0", "inner_loop",
                                   iv_inner.loop_cond, dependent_outer_iv0)
                                   .induction_var;
  Output dependent_inner_iv1 = CreateDependentLoopInvariantValue(
                                   root, "dependent_inner_iv1", "inner_loop",
                                   iv_inner.loop_cond, dependent_outer_iv1)
                                   .induction_var;

  Output add0 = ops::Add(root.WithOpName("add0"), dependent_inner_iv0,
                         dependent_inner_iv1);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node()));
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
              "{#true,&,*iv_outer/cond:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
              "{(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop>");

    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
              "{{#true,&,(iv_outer/iv:0 & "
              "*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
              "*iv_inner/cond:0)}<inner_loop;outer_loop>");

    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
              "{{#true,&,(iv_outer/iv:0 & "
              "*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
              "*iv_inner/cond:0)}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              "{{#true,&,(iv_outer/iv:0 & "
              "*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
              "*iv_inner/cond:0)}<inner_loop;outer_loop>");
  }
}

TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) {
  Scope root = Scope::NewRootScope().ExitOnError();

  std::array<Output, 2> outer_iv;
  std::array<Output, 2> inner_iv;

  for (int i : {0, 1}) {
    InductionVarInfo iv_outer =
        CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
    Output enter_constant_outer_loop = ops::internal::Enter(
        root.WithOpName("constant_enter_outer_loop"),
        ops::Const(root.WithOpName("constant"), 5), "outer_loop",
        ops::internal::Enter::Attrs().IsConstant(true));
    ops::Switch inner_value(root.WithOpName("outer_is_live"),
                            enter_constant_outer_loop, iv_outer.loop_cond);
    InductionVarInfo iv_inner = CreateInductionVariable(
        root, "iv_inner", "inner_loop", inner_value.output_true);

    outer_iv[i] = iv_outer.induction_var;
    inner_iv[i] = iv_inner.induction_var;
  }

  Output add0 = ops::Add(root.WithOpName("add0"), inner_iv[0], inner_iv[1]);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node()));
  }

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[0])],
              "{#true,&,*iv_outer/cond:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[0])],
              "{(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[1])],
              "{#true,&,*iv_outer/cond_1:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[1])],
              "{(*iv_outer/cond_1:0 & "
              "{#true,&,*iv_outer/cond_1:0}<outer_loop>),&,*iv_inner/"
              "cond_1:0}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              "({(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop> & {(*iv_outer/cond_1:0 & "
              "{#true,&,*iv_outer/cond_1:0}<outer_loop>),&,*iv_inner/"
              "cond_1:0}<inner_loop;outer_loop>)");
  }
}

TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10);
  InductionVarInfo iv_1 = CreateInductionVariable(root, "iv_1", "frame_1", 9);

  Output init = CreateSwitch(root, "init").output_true;
  Output step = CreateSwitch(root, "step").output_true;

  std::array<Output, 2> exits;
  std::array<Output, 2> next_iterations;

  for (int i : {0, 1}) {
    Output init_enter = ops::internal::Enter(
        root.WithOpName(absl::StrCat("init_enter_frame_", i)), init,
        absl::StrCat("frame_", i),
        ops::internal::Enter::Attrs().IsConstant(true));
    Output step_enter = ops::internal::Enter(
        root.WithOpName(absl::StrCat("step_enter_frame_", i)), step,
        absl::StrCat("frame_", i),
        ops::internal::Enter::Attrs().IsConstant(true));

    ops::Merge iv(root.WithOpName(absl::StrCat("expr_", i)),
                  {init_enter, init_enter});
    Output add = ops::Add(root.WithOpName(absl::StrCat("add_", i)), iv.output,
                          step_enter);
    next_iterations[i] = ops::NextIteration(
        root.WithOpName(absl::StrCat("expr_", i, "_next_iteration")), add);
    EXPECT_TRUE(
        root.graph()
            ->UpdateEdge(next_iterations[i].node(), 0, iv.output.node(), 1)
            .ok());
    exits[i] = ops::internal::Exit(root.WithOpName(absl::StrCat("exit_", i)),
                                   iv.output);
  }

  FixupSourceAndSinkEdges(root.graph());

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    EXPECT_NE(predicate_map[ControlOutputFor(exits[0])],
              predicate_map[ControlOutputFor(exits[1])]);
    EXPECT_NE(predicate_map[ControlOutputFor(exits[0])], "");
    EXPECT_NE(predicate_map[ControlOutputFor(exits[1])], "");

    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])],
              predicate_map[ControlOutputFor(next_iterations[1])]);
    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])], "");
    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[1])], "");
  }
}

TEST(DeadnessAnalysisTest, ControlInputs) {
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch sw = CreateSwitch(root, "0");

  Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
  Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);

  Output const0 = ops::Const(root.WithOpName("const0"), 1);
  Output const1 = ops::Const(root.WithOpName("const1"), 2);

  Output add = ops::Add(root.WithOpName("add"), const0, const1);

  root.graph()->AddControlEdge(id0.node(), const0.node());
  root.graph()->AddControlEdge(id1.node(), const1.node());

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, ControlTrigger) {
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch sw = CreateSwitch(root, "0");

  Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
  Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);

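  // A ControlTrigger fires even when its control inputs are dead, so const0
  // and const1 below are unconditionally live and the Add should not see
  // mismatching deadness.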
  ops::ControlTrigger ctrl_trigger0(root.WithOpName("ctrl_trigger0"));
  ops::ControlTrigger ctrl_trigger1(root.WithOpName("ctrl_trigger1"));

  Output const0 = ops::Const(root.WithOpName("const0"), 1);
  Output const1 = ops::Const(root.WithOpName("const1"), 2);

  Output add = ops::Add(root.WithOpName("add"), const0, const1);

  root.graph()->AddControlEdge(id0.node(), ctrl_trigger0.operation.node());
  root.graph()->AddControlEdge(ctrl_trigger0.operation.node(), const0.node());

  root.graph()->AddControlEdge(id1.node(), ctrl_trigger1.operation.node());
  root.graph()->AddControlEdge(ctrl_trigger1.operation.node(), const1.node());

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, ControlInputsToMerge) {
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch sw = CreateSwitch(root, "0");

  Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
  Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);

  Output constant = ops::Const(root.WithOpName("constant"), 5);
  ops::Merge m0(root.WithOpName("m0"), {constant});
  ops::Merge m1(root.WithOpName("m1"), {constant});
  Output add = ops::Add(root.WithOpName("add"), m0.output, m1.output);

  root.graph()->AddControlEdge(id0.node(), m0.output.node());
  root.graph()->AddControlEdge(id1.node(), m1.output.node());

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, RecvVsSwitch) {
  // Demonstrates why we need the must_be_true bit on SymbolP.
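  // Roughly: the Recv output being alive is the symbol recv:0, while the
  // Switch's true output additionally requires the received value to be true
  // (*recv:0); without the must_be_true bit these two predicates would be
  // conflated.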
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender",
                           0, "receiver");
  Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
  ops::Switch sw(root.WithOpName("switch"), value, recv);
  Output logical_and =
      ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*logical_and.node()));
}

TEST(DeadnessAnalysisTest, RecvVsSwitchText) {
  // Demonstrates why we need the must_be_true bit on SymbolP.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender",
                           0, "receiver");
  Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
  ops::Switch sw(root.WithOpName("switch"), value, recv);
  Output logical_and =
      ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  TensorId logical_and_output_0 = {logical_and.node()->name(),
                                   Graph::kControlSlot};
  EXPECT_EQ(predicate_map[logical_and_output_0], "(recv:0 & *recv:0)");
}

TEST(DeadnessAnalysisTest, DeMorgan) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond_0 = ops::Placeholder(root.WithOpName("cond_0"), DT_BOOL);
  Output cond_1 = ops::Placeholder(root.WithOpName("cond_1"), DT_BOOL);
  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);

  ops::Switch sw_0(root.WithOpName("switch_0"), value, cond_0);
  ops::Switch sw_1(root.WithOpName("switch_1"), value, cond_1);

  Output and_0_1 =
      ops::Add(root.WithOpName("and_0_1"), sw_0.output_true, sw_1.output_true);

  Output or_not0_not1 = ops::Merge(root.WithOpName("or_not0_not1"),
                                   {sw_0.output_false, sw_1.output_false})
                            .output;

  // Predicate(should_always_be_dead) =
  // (A & B) & (~A | ~B) = (A & B) & ~(A & B) = False
  Output should_always_be_dead =
      ops::Add(root.WithOpName("should_always_be_dead"), and_0_1, or_not0_not1);

  // Predicate(should_always_be_alive) =
  // (A & B) | (~A | ~B) = (A & B) | ~(A & B) = True
  Output should_always_be_alive =
      ops::Merge(root.WithOpName("should_always_be_alive"),
                 {and_0_1, or_not0_not1})
          .output;

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_dead)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_alive)], "#true");
}

TEST(DeadnessAnalysisTest, ConstantTrueSwitchCondition) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output constant_true = ops::Const(root.WithOpName("const_true"), true);
  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  ops::Switch sw(root.WithOpName("switch"), value, constant_true);

  Output id_false = ops::Identity(root.WithOpName("id_false"), sw.output_false);
  Output id_true = ops::Identity(root.WithOpName("id_true"), sw.output_true);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(id_false)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "#true");
}

TEST(DeadnessAnalysisTest, ConstantFalseSwitchCondition) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output constant_false = ops::Const(root.WithOpName("const_false"), false);
  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  ops::Switch sw(root.WithOpName("switch"), value, constant_false);

  Output id_false = ops::Identity(root.WithOpName("id_false"), sw.output_false);
  Output id_true = ops::Identity(root.WithOpName("id_true"), sw.output_true);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(id_false)], "#true");
  EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "#false");
}

}  // namespace
}  // namespace tensorflow