• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/deadness_analysis.h"
17 
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
21 #include "tensorflow/cc/ops/function_ops.h"
22 #include "tensorflow/cc/ops/sendrecv_ops.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/compiler/jit/deadness_analysis_internal.h"
25 #include "tensorflow/compiler/jit/defs.h"
26 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/framework/op.h"
30 #include "tensorflow/core/graph/algorithm.h"
31 #include "tensorflow/core/graph/graph_constructor.h"
32 #include "tensorflow/core/graph/graph_def_builder.h"
33 #include "tensorflow/core/graph/graph_def_builder_util.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/platform/test.h"
36 
37 namespace tensorflow {
38 namespace {
39 
40 using deadness_analysis_internal::ComputePredicates;
41 using deadness_analysis_internal::PredicateMapTy;
42 
AnalyzeDeadness(Graph * graph,std::unique_ptr<DeadnessAnalysis> * result)43 Status AnalyzeDeadness(Graph* graph,
44                        std::unique_ptr<DeadnessAnalysis>* result) {
45   FixupSourceAndSinkEdges(graph);
46   return DeadnessAnalysis::Run(*graph, result);
47 }
48 
CreateSwitch(const Scope & root,const string & prefix)49 ops::Switch CreateSwitch(const Scope& root, const string& prefix) {
50   Output value = ops::Placeholder(root.WithOpName(prefix + "/value"), DT_FLOAT);
51   Output predicate =
52       ops::Placeholder(root.WithOpName(prefix + "/pred"), DT_BOOL);
53   return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate);
54 }
55 
ControlOutputFor(const Output & o)56 TensorId ControlOutputFor(const Output& o) {
57   return {o.node()->name(), Graph::kControlSlot};
58 }
59 
VLogGraphIfAsked(const Graph & graph)60 void VLogGraphIfAsked(const Graph& graph) {
61   if (VLOG_IS_ON(3)) {
62     GraphDef graph_def;
63     graph.ToGraphDef(&graph_def);
64     string serialized;
65     ::tensorflow::protobuf::TextFormat::PrintToString(graph_def, &serialized);
66     LOG(INFO) << serialized;
67   }
68 }
69 
70 struct InductionVarInfo {
71   Output induction_var;
72   Output loop_cond;
73 };
74 
75 // Creates an induction variable with the following structure (simplified for
76 // brevity):
77 //
78 //            +---------------+
79 //            | initial_value |
80 //            +---------------+
81 //              |
82 //              |
83 //              v
84 //            +---------------+
85 //            |     Enter     |
86 //            +---------------+
87 //              |
88 //              |
89 //              v
90 //            +---------------+
91 //         +> |     Merge     | -+
92 //         |  +---------------+  |
93 //         |    |                |
94 //         |    |                |
95 //         |    v                |
96 //         |  +---------------+  |
97 //         |  |  LessThan10   |  |
98 //         |  +---------------+  |
99 //         |    |                |
100 //         |    |                |
101 //         |    v                |
102 //         |  +---------------+  |
103 //    +----+- |    Switch     | <+
104 //    |    |  +---------------+
105 //    |    |    |
106 //    |    |    |
107 //    |    |    v
108 //    |    |  +---------------+
109 //    |    +- |    AddOne     |
110 //    |       +---------------+
111 //    |       +---------------+
112 //    +-----> |     Exit      |
113 //            +---------------+
CreateInductionVariable(const Scope & root,const string & prefix,const string & frame_name,const Output & initial_value)114 InductionVarInfo CreateInductionVariable(const Scope& root,
115                                          const string& prefix,
116                                          const string& frame_name,
117                                          const Output& initial_value) {
118   Output enter_initial_value = ops::internal::Enter(
119       root.WithOpName(prefix + "/enter"), initial_value, frame_name);
120 
121   ops::Merge iv(root.WithOpName(prefix + "/iv"),
122                 {enter_initial_value, enter_initial_value});
123   Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1);
124   Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10);
125   Output loop_cond_expr =
126       ops::Less(root.WithOpName(prefix + "/cond"), iv.output, final_value);
127   ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output,
128                     loop_cond_expr);
129   ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
130                            latch.output_false);
131   Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"),
132                             latch.output_true, increment_by);
133   Output next_iteration =
134       ops::NextIteration(root.WithOpName(prefix + "/next_iteration"), iv_next);
135 
136   CHECK(root.graph()
137             ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)
138             .ok());
139   root.graph()->AddControlEdge(iv.output.node(), increment_by.node());
140   root.graph()->AddControlEdge(iv.output.node(), final_value.node());
141 
142   return {iv.output, loop_cond_expr};
143 }
144 
CreateInductionVariable(const Scope & root,const string & prefix,const string & frame_name,int32 init)145 InductionVarInfo CreateInductionVariable(const Scope& root,
146                                          const string& prefix,
147                                          const string& frame_name, int32 init) {
148   return CreateInductionVariable(
149       root, prefix, frame_name,
150       ops::Const(root.WithOpName(prefix + "/init"), init));
151 }
152 
153 // Creates a loop-invariant value whose liveness follows `loop_cond`, with the following structure:
154 //
155 //                           +---------------+
156 //                           | initial_value |
157 //                           +---------------+
158 //                             |
159 //                             |
160 //                             v
161 //                           +---------------+
162 //                           |     Enter     |
163 //                           +---------------+
164 //                             |
165 //                             |
166 //                             v
167 //                           +---------------+
168 //                           |     Merge     | <+
169 //                           +---------------+  |
170 //                             |                |
171 //                             |                |
172 //                             v                |
173 //         +-----------+     +---------------+  |
174 //         | loop_cond | --> |    Switch     | -+
175 //         +-----------+     +---------------+
176 //                             |
177 //                             |
178 //                             v
179 //                           +---------------+
180 //                           |     Exit      |
181 //                           +---------------+
182 struct DependentInductionVar {
183   Output induction_var;
184   ops::Switch latch;
185 };
186 
CreateDependentLoopInvariantValue(const Scope & root,const string & prefix,const string & frame_name,const Output & loop_cond,const Output & value)187 DependentInductionVar CreateDependentLoopInvariantValue(
188     const Scope& root, const string& prefix, const string& frame_name,
189     const Output& loop_cond, const Output& value) {
190   Output enter_value = ops::internal::Enter(root.WithOpName(prefix + "/enter"),
191                                             value, frame_name);
192   ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value});
193   ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
194   ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
195                            latch.output_false);
196   Output next_iteration = ops::NextIteration(
197       root.WithOpName(prefix + "/next_iteration"), latch.output_true);
198   CHECK(root.graph()
199             ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)
200             .ok());
201   return {iv.output, latch};
202 }
203 
CreateDependentLoopInvariantValue(const Scope & root,const string & prefix,const string & frame_name,const Output & loop_cond,int32 value)204 DependentInductionVar CreateDependentLoopInvariantValue(
205     const Scope& root, const string& prefix, const string& frame_name,
206     const Output& loop_cond, int32 value) {
207   return CreateDependentLoopInvariantValue(
208       root, prefix, frame_name, loop_cond,
209       ops::Const(root.WithOpName(prefix + "/init"), value));
210 }
211 
// The true and false outputs of one Switch are never live together, so an Add
// consuming both must be reported as having mismatching deadness.
TEST(DeadnessAnalysisTest, BasicPositive) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw = CreateSwitch(root, "0");
  Output sum =
      ops::Add(root.WithOpName("add"), sw.output_true, sw.output_false);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));
  EXPECT_TRUE(analysis->HasInputsWithMismatchingDeadness(*sum.node()));
}
224 
// Two plain placeholders feed the Add, so no deadness mismatch is expected.
TEST(DeadnessAnalysisTest, BasicNegative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output lhs = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
  Output rhs = ops::Placeholder(root.WithOpName("b"), DT_FLOAT);
  Output sum = ops::Add(root.WithOpName("add"), lhs, rhs);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));
  EXPECT_FALSE(analysis->HasInputsWithMismatchingDeadness(*sum.node()));
}
237 
// a0/a1 (and b0/b1) AND the same switch outputs in opposite orders. The
// analysis must see each pair as equal. Mixing an a-node with a b-node is
// still a mismatch.
TEST(DeadnessAnalysisTest, AndIsCommutative) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto add = [&](const char* name, const Output& x, const Output& y) -> Output {
    return ops::Add(root.WithOpName(name), x, y);
  };

  ops::Switch s0 = CreateSwitch(root, "0");
  ops::Switch s1 = CreateSwitch(root, "1");

  Output a0 = add("a0", s0.output_false, s1.output_false);
  Output a1 = add("a1", s1.output_false, s0.output_false);
  Output b0 = add("b0", s0.output_false, s1.output_true);
  Output b1 = add("b1", s1.output_true, s0.output_false);

  Output live0 = add("live0", a0, a1);
  Output live1 = add("live1", b0, b1);
  Output halfdead0 = add("halfdead0", a0, b0);
  Output halfdead1 = add("halfdead1", a1, b1);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  EXPECT_FALSE(analysis->HasInputsWithMismatchingDeadness(*live0.node()));
  EXPECT_FALSE(analysis->HasInputsWithMismatchingDeadness(*live1.node()));
  EXPECT_TRUE(analysis->HasInputsWithMismatchingDeadness(*halfdead0.node()));
  EXPECT_TRUE(analysis->HasInputsWithMismatchingDeadness(*halfdead1.node()));
}
269 
// (s0 & s1) & s2 and s0 & (s1 & s2) must be recognized as the same predicate.
TEST(DeadnessAnalysisTest, AndIsAssociative) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto add = [&](const char* name, const Output& x, const Output& y) -> Output {
    return ops::Add(root.WithOpName(name), x, y);
  };

  ops::Switch s0 = CreateSwitch(root, "0");
  ops::Switch s1 = CreateSwitch(root, "1");
  ops::Switch s2 = CreateSwitch(root, "2");

  Output left_inner = add("a0", s0.output_false, s1.output_false);
  Output left = add("a1", left_inner, s2.output_false);

  Output right_inner = add("b0", s1.output_false, s2.output_false);
  Output right = add("b1", s0.output_false, right_inner);

  Output combined = add("add", left, right);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));
  EXPECT_FALSE(analysis->HasInputsWithMismatchingDeadness(*combined.node()));
}
292 
// Merges of the same inputs in opposite orders (m0/m1 and m2/m3) must yield
// equal predicates. Mixing the two families is a mismatch.
TEST(DeadnessAnalysisTest, OrIsCommutative) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto merge = [&](const char* name, const Output& x, const Output& y) {
    return ops::Merge(root.WithOpName(name), {x, y}).output;
  };
  auto add = [&](const char* name, const Output& x, const Output& y) -> Output {
    return ops::Add(root.WithOpName(name), x, y);
  };

  ops::Switch s0 = CreateSwitch(root, "0");
  ops::Switch s1 = CreateSwitch(root, "1");

  Output m0 = merge("m0", s0.output_false, s1.output_false);
  Output m1 = merge("m1", s1.output_false, s0.output_false);
  Output m2 = merge("m2", s0.output_false, s1.output_true);
  Output m3 = merge("m3", s1.output_true, s0.output_false);

  Output live0 = add("live0", m0, m1);
  Output live1 = add("live1", m2, m3);
  Output halfdead0 = add("halfdead0", m0, m2);
  Output halfdead1 = add("halfdead1", m1, m3);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  EXPECT_FALSE(analysis->HasInputsWithMismatchingDeadness(*live0.node()));
  EXPECT_FALSE(analysis->HasInputsWithMismatchingDeadness(*live1.node()));
  EXPECT_TRUE(analysis->HasInputsWithMismatchingDeadness(*halfdead0.node()));
  EXPECT_TRUE(analysis->HasInputsWithMismatchingDeadness(*halfdead1.node()));
}
321 
// (s0 | s1) | s2 and s0 | (s1 | s2), built from nested Merges, must be
// recognized as the same predicate.
TEST(DeadnessAnalysisTest, OrIsAssociative) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto merge = [&](const char* name, const Output& x, const Output& y) {
    return ops::Merge(root.WithOpName(name), {x, y}).output;
  };

  ops::Switch s0 = CreateSwitch(root, "0");
  ops::Switch s1 = CreateSwitch(root, "1");
  ops::Switch s2 = CreateSwitch(root, "2");

  Output left_inner = merge("m0", s0.output_false, s1.output_false);
  Output left = merge("m1", left_inner, s2.output_false);
  Output right_inner = merge("m2", s1.output_false, s2.output_false);
  Output right = merge("m3", s0.output_false, right_inner);

  Output combined = ops::Add(root.WithOpName("add"), left, right);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));
  EXPECT_FALSE(analysis->HasInputsWithMismatchingDeadness(*combined.node()));
}
341 
// Two identical ANDs of ORs must compare equal.
TEST(DeadnessAnalysisTest, AndOfOr) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto merge = [&](const char* name, const Output& x, const Output& y) {
    return ops::Merge(root.WithOpName(name), {x, y}).output;
  };
  auto add = [&](const char* name, const Output& x, const Output& y) -> Output {
    return ops::Add(root.WithOpName(name), x, y);
  };

  ops::Switch s0 = CreateSwitch(root, "0");
  ops::Switch s1 = CreateSwitch(root, "1");
  ops::Switch s2 = CreateSwitch(root, "2");
  ops::Switch s3 = CreateSwitch(root, "3");

  Output or01 = merge("m0", s0.output_false, s1.output_false);
  Output or23 = merge("m1", s2.output_false, s3.output_false);

  Output and_a = add("add0", or01, or23);
  Output and_b = add("add1", or01, or23);
  Output both = add("add2", and_a, and_b);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));
  EXPECT_FALSE(analysis->HasInputsWithMismatchingDeadness(*both.node()));
}
363 
// Two identical ORs of ANDs must compare equal.
TEST(DeadnessAnalysisTest, OrOfAnd) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto merge = [&](const char* name, const Output& x, const Output& y) {
    return ops::Merge(root.WithOpName(name), {x, y}).output;
  };
  auto add = [&](const char* name, const Output& x, const Output& y) -> Output {
    return ops::Add(root.WithOpName(name), x, y);
  };

  ops::Switch s0 = CreateSwitch(root, "0");
  ops::Switch s1 = CreateSwitch(root, "1");
  ops::Switch s2 = CreateSwitch(root, "2");
  ops::Switch s3 = CreateSwitch(root, "3");

  Output and01 = add("add0", s0.output_false, s1.output_false);
  Output and23 = add("add1", s2.output_false, s3.output_false);

  Output or_a = merge("m0", and01, and23);
  Output or_b = merge("m1", and01, and23);
  Output both = add("add2", or_a, or_b);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));
  EXPECT_FALSE(analysis->HasInputsWithMismatchingDeadness(*both.node()));
}
387 
// (*A | (~*A & ((~*B & ~*A) | (~*A & *B)))) == #true
// The predicate of the final Merge must simplify all the way to #true.
TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch a = CreateSwitch(root, "A");
  ops::Switch b = CreateSwitch(root, "B");

  Output not_a_and_b =
      ops::Add(root.WithOpName("and0"), a.output_false, b.output_true);
  Output not_a_and_not_b =
      ops::Add(root.WithOpName("and1"), a.output_false, b.output_false);
  Output either_b =
      ops::Merge(root.WithOpName("or2"), {not_a_and_b, not_a_and_not_b})
          .output;
  Output not_a_and_either_b =
      ops::Add(root.WithOpName("and3"), either_b, a.output_false);
  ops::Merge everything(root.WithOpName("or4"),
                        {not_a_and_either_b, a.output_true});

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  PredicateMapTy predicates;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicates));
  EXPECT_EQ(predicates[ControlOutputFor(everything.output)], "#true");
}
410 
// (A|B)&C == (A&C)|(B&C): both formulations must have matching deadness.
TEST(DeadnessAnalysisTest, AndOrDistributive) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto merge = [&](const char* name, const Output& x, const Output& y) {
    return ops::Merge(root.WithOpName(name), {x, y}).output;
  };
  auto add = [&](const char* name, const Output& x, const Output& y) -> Output {
    return ops::Add(root.WithOpName(name), x, y);
  };

  ops::Switch s0 = CreateSwitch(root, "0");
  ops::Switch s1 = CreateSwitch(root, "1");
  ops::Switch s2 = CreateSwitch(root, "2");

  // (A|B)&C
  Output a_or_b = merge("m0", s0.output_false, s1.output_false);
  Output factored = add("add0", a_or_b, s2.output_false);

  // (A&C)|(B&C)
  Output a_and_c = add("add1", s0.output_false, s2.output_false);
  Output b_and_c = add("add2", s1.output_false, s2.output_false);
  Output distributed = merge("m1", a_and_c, b_and_c);

  Output combined = add("add3", factored, distributed);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));
  EXPECT_FALSE(analysis->HasInputsWithMismatchingDeadness(*combined.node()));
}
435 
// Models `predicate ? true_value : false_value`. `true_value` goes through the
// true branch of one Switch and `false_value` through the false branch of
// another Switch on the same predicate. The two are joined by a Merge.
// Exactly one Merge input is live for either predicate value, so adding the
// Merge result to an independent placeholder must not be flagged.
TEST(DeadnessAnalysisTest, Ternary) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output predicate = ops::Placeholder(root.WithOpName("predicate"), DT_BOOL);
  Output true_value = ops::Placeholder(root.WithOpName("true_value"), DT_FLOAT);
  Output false_value =
      ops::Placeholder(root.WithOpName("false_value"), DT_FLOAT);

  ops::Switch predicated_true(root.WithOpName("predicated_true"), true_value,
                              predicate);

  // The "else" arm must switch `false_value`. Switching `true_value` here left
  // `false_value` unused, so the graph was not actually a ternary.
  ops::Switch predicated_false(root.WithOpName("predicated_false"),
                               false_value, predicate);
  ops::Merge merge(root.WithOpName("ternary"), {predicated_true.output_true,
                                                predicated_false.output_false});
  Output addend = ops::Placeholder(root.WithOpName("addend"), DT_FLOAT);
  Output add = ops::Add(root.WithOpName("add"), merge.output, addend);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}
459 
// Adding the outputs of two independent _Recv ops must be reported as having
// mismatching deadness. The analysis gives each received tensor its own
// unknown liveness.
TEST(DeadnessAnalysisTest, Recv) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto recv = [&](const char* name, const char* tensor_name) -> Output {
    return ops::_Recv(root.WithOpName(name), DT_FLOAT, tensor_name, "sender",
                      0, "receiver");
  };

  Output first = recv("recv_a", "tensor_a");
  Output second = recv("recv_b", "tensor_b");
  Output sum = ops::Add(root.WithOpName("add"), first, second);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));
  EXPECT_TRUE(analysis->HasInputsWithMismatchingDeadness(*sum.node()));
}
474 
// Same as the Recv test but with _HostRecv. The two received tensors must be
// reported as having mismatching deadness.
TEST(DeadnessAnalysisTest, HostRecv) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto host_recv = [&](const char* name, const char* tensor_name) -> Output {
    return ops::_HostRecv(root.WithOpName(name), DT_FLOAT, tensor_name,
                          "sender", 0, "receiver");
  };

  Output first = host_recv("recv_a", "tensor_a");
  Output second = host_recv("recv_b", "tensor_b");
  Output sum = ops::Add(root.WithOpName("add"), first, second);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));
  EXPECT_TRUE(analysis->HasInputsWithMismatchingDeadness(*sum.node()));
}
489 
// Three induction variables in frame "fr0". Each loop's Merge gets its own
// recurrence predicate over its own condition.
TEST(DeadnessAnalysisTest, Loop) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto make_iv = [&](const char* prefix, int32 init) {
    return CreateInductionVariable(root, prefix, "fr0", init).induction_var;
  };

  Output iv0 = make_iv("iv0", 0);
  Output iv1 = make_iv("iv1", 0);
  Output iv2 = make_iv("iv2", 1);
  Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1);
  Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2);

  // NB!  iv0 and iv1 are equivalent and a smarter deadness analysis would have
  // noticed that.  Today we are pessimistic here because we assign an
  // uninterpreted symbol to merges with backedges.

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> analysis;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

    EXPECT_TRUE(analysis->HasInputsWithMismatchingDeadness(*add0.node()));
    EXPECT_TRUE(analysis->HasInputsWithMismatchingDeadness(*add1.node()));
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
    auto predicate_of = [&](const Output& o) {
      return predicate_map[ControlOutputFor(o)];
    };

    // In theory we should be able to tell that iv0/cond:0 and iv1/cond:0
    // produce the same deadness.  But we're not that smart today.
    EXPECT_EQ(predicate_of(iv0), "{#true,&,*iv0/cond:0}<fr0>");
    EXPECT_EQ(predicate_of(iv1), "{#true,&,*iv1/cond:0}<fr0>");
    EXPECT_EQ(predicate_of(iv2), "{#true,&,*iv2/cond:0}<fr0>");
    EXPECT_EQ(predicate_of(add0),
              "({#true,&,*iv0/cond:0}<fr0> & {#true,&,*iv1/cond:0}<fr0>)");
    EXPECT_EQ(predicate_of(add1),
              "({#true,&,*iv1/cond:0}<fr0> & {#true,&,*iv2/cond:0}<fr0>)");
  }
}
529 
// Two loop-invariant values gated by the same loop condition get identical
// predicates, so adding them is not a mismatch.
TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0);
  auto make_dependent = [&](const char* prefix) {
    return CreateDependentLoopInvariantValue(root, prefix, "loop",
                                             iv.loop_cond, 0)
        .induction_var;
  };

  Output dependent_iv0 = make_dependent("div0");
  Output dependent_iv1 = make_dependent("div1");
  Output add0 = ops::Add(root.WithOpName("add0"), dependent_iv0, dependent_iv1);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> analysis;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));
    EXPECT_FALSE(analysis->HasInputsWithMismatchingDeadness(*add0.node()));
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
    auto predicate_of = [&](const Output& o) {
      return predicate_map[ControlOutputFor(o)];
    };

    const char kBodyPredicate[] = "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>";
    EXPECT_EQ(predicate_of(iv.induction_var), "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_of(dependent_iv0), kBodyPredicate);
    EXPECT_EQ(predicate_of(dependent_iv1), kBodyPredicate);
    EXPECT_EQ(predicate_of(add0), kBodyPredicate);
  }
}
563 
// Builds a Merge that "looks like" a loop but isn't really one. The value on
// its backedge does not depend on the Merge itself.
TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0);
  DependentInductionVar dependent_iv =
      CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0);
  FixupSourceAndSinkEdges(graph);

  // Deadness analysis must treat div0/iv as a loop. For that it needs an RPO
  // that visits the Merge before its backedge. Ignoring NextIteration->Merge
  // edges gives a legal RPO. The RPO is computed while div0/latch still reads
  // div0/iv. Afterwards the latch's data input (input 0) is rewired to iv0/iv,
  // so the value flowing around div0's backedge no longer comes from div0/iv.
  auto skip_next_iteration_edges = [](const Edge& edge) {
    return !edge.src()->IsNextIteration();
  };
  std::vector<Node*> rpo;
  GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/{},
                      /*edge_filter=*/skip_next_iteration_edges);

  Node* latch_node = dependent_iv.latch.output_true.node();
  TF_ASSERT_OK(graph->UpdateEdge(iv.induction_var.node(), 0, latch_node, 0));

  VLogGraphIfAsked(*graph);

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*graph, rpo, &predicate_map));
  EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
            "div0/iv:0");
}
596 
TEST(DeadnessAnalysisTest,ControlEquivalentNestedLoopBodies)597 TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
598   Scope root = Scope::NewRootScope().ExitOnError();
599   InductionVarInfo iv_outer =
600       CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
601   Output enter_constant_outer_loop = ops::internal::Enter(
602       root.WithOpName("constant_enter_outer_loop"),
603       ops::Const(root.WithOpName("constant"), 5), "outer_loop",
604       ops::internal::Enter::Attrs().IsConstant(true));
605   ops::Switch inner_value(root.WithOpName("outer_is_live"),
606                           enter_constant_outer_loop, iv_outer.loop_cond);
607   InductionVarInfo iv_inner = CreateInductionVariable(
608       root, "iv_inner", "inner_loop", inner_value.output_true);
609 
610   Output dependent_outer_iv0 =
611       CreateDependentLoopInvariantValue(root, "dependent_outer_iv0",
612                                         "outer_loop", iv_outer.loop_cond, 0)
613           .induction_var;
614   Output dependent_outer_iv1 =
615       CreateDependentLoopInvariantValue(root, "dependent_outer_iv1",
616                                         "outer_loop", iv_outer.loop_cond, 0)
617           .induction_var;
618 
619   Output dependent_inner_iv0 = CreateDependentLoopInvariantValue(
620                                    root, "dependent_inner_iv0", "inner_loop",
621                                    iv_inner.loop_cond, dependent_outer_iv0)
622                                    .induction_var;
623   Output dependent_inner_iv1 = CreateDependentLoopInvariantValue(
624                                    root, "dependent_inner_iv1", "inner_loop",
625                                    iv_inner.loop_cond, dependent_outer_iv1)
626                                    .induction_var;
627 
628   Output add0 = ops::Add(root.WithOpName("add0"), dependent_inner_iv0,
629                          dependent_inner_iv1);
630 
631   VLogGraphIfAsked(*root.graph());
632 
633   {
634     std::unique_ptr<DeadnessAnalysis> result;
635     TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
636 
637     EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node()));
638   }
639   {
640     PredicateMapTy predicate_map;
641     TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
642 
643     EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
644               "{#true,&,*iv_outer/cond:0}<outer_loop>");
645     EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
646               "{(*iv_outer/cond:0 & "
647               "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
648               "cond:0}<inner_loop;outer_loop>");
649 
650     EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
651               "{{#true,&,(iv_outer/iv:0 & "
652               "*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
653               "*iv_inner/cond:0)}<inner_loop;outer_loop>");
654 
655     EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
656               "{{#true,&,(iv_outer/iv:0 & "
657               "*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
658               "*iv_inner/cond:0)}<inner_loop;outer_loop>");
659     EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
660               "{{#true,&,(iv_outer/iv:0 & "
661               "*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
662               "*iv_inner/cond:0)}<inner_loop;outer_loop>");
663   }
664 }
665 
TEST(DeadnessAnalysisTest,ControlNonEquivalentNestedLoopBodies)666 TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) {
667   Scope root = Scope::NewRootScope().ExitOnError();
668 
669   std::array<Output, 2> outer_iv;
670   std::array<Output, 2> inner_iv;
671 
672   for (int i : {0, 1}) {
673     InductionVarInfo iv_outer =
674         CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
675     Output enter_constant_outer_loop = ops::internal::Enter(
676         root.WithOpName("constant_enter_outer_loop"),
677         ops::Const(root.WithOpName("constant"), 5), "outer_loop",
678         ops::internal::Enter::Attrs().IsConstant(true));
679     ops::Switch inner_value(root.WithOpName("outer_is_live"),
680                             enter_constant_outer_loop, iv_outer.loop_cond);
681     InductionVarInfo iv_inner = CreateInductionVariable(
682         root, "iv_inner", "inner_loop", inner_value.output_true);
683 
684     outer_iv[i] = iv_outer.induction_var;
685     inner_iv[i] = iv_inner.induction_var;
686   }
687 
688   Output add0 = ops::Add(root.WithOpName("add0"), inner_iv[0], inner_iv[1]);
689 
690   VLogGraphIfAsked(*root.graph());
691 
692   {
693     std::unique_ptr<DeadnessAnalysis> result;
694     TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
695 
696     EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node()));
697   }
698 
699   {
700     PredicateMapTy predicate_map;
701     TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
702 
703     EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[0])],
704               "{#true,&,*iv_outer/cond:0}<outer_loop>");
705     EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[0])],
706               "{(*iv_outer/cond:0 & "
707               "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
708               "cond:0}<inner_loop;outer_loop>");
709     EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[1])],
710               "{#true,&,*iv_outer/cond_1:0}<outer_loop>");
711     EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[1])],
712               "{(*iv_outer/cond_1:0 & "
713               "{#true,&,*iv_outer/cond_1:0}<outer_loop>),&,*iv_inner/"
714               "cond_1:0}<inner_loop;outer_loop>");
715     EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
716               "({(*iv_outer/cond:0 & "
717               "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
718               "cond:0}<inner_loop;outer_loop> & {(*iv_outer/cond_1:0 & "
719               "{#true,&,*iv_outer/cond_1:0}<outer_loop>),&,*iv_inner/"
720               "cond_1:0}<inner_loop;outer_loop>)");
721   }
722 }
723 
TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) {
  // Builds, in each of frame_0 and frame_1, the same cycle
  //   Enter(init) -> Merge -> Add(step) -> NextIteration -> Merge
  // with an Exit off the Merge. The two cycles differ only in their frame, yet
  // their predicates must differ, i.e. the frame has to be part of the
  // recurrence predicate.
  Scope root = Scope::NewRootScope().ExitOnError();

  // Only the graph side effects of these calls matter (presumably they give
  // frame_0 and frame_1 a loop condition); the returned handles are unused.
  CreateInductionVariable(root, "iv_0", "frame_0", 10);
  CreateInductionVariable(root, "iv_1", "frame_1", 9);

  Output init = CreateSwitch(root, "init").output_true;
  Output step = CreateSwitch(root, "step").output_true;

  std::array<Output, 2> exits;
  std::array<Output, 2> next_iterations;

  for (int i : {0, 1}) {
    Output init_enter = ops::internal::Enter(
        root.WithOpName(absl::StrCat("init_enter_frame_", i)), init,
        absl::StrCat("frame_", i),
        ops::internal::Enter::Attrs().IsConstant(true));
    Output step_enter = ops::internal::Enter(
        root.WithOpName(absl::StrCat("step_enter_frame_", i)), step,
        absl::StrCat("frame_", i),
        ops::internal::Enter::Attrs().IsConstant(true));

    // Both Merge inputs start out as init_enter; input 1 is rewired to the
    // NextIteration below to close the loop.
    ops::Merge iv(root.WithOpName(absl::StrCat("expr_", i)),
                  {init_enter, init_enter});
    Output add = ops::Add(root.WithOpName(absl::StrCat("add_", i)), iv.output,
                          step_enter);
    next_iterations[i] = ops::NextIteration(
        root.WithOpName(absl::StrCat("expr_", i, "_next_iteration")), add);
    // Fatal on failure: report the Status and do not go on to analyze a graph
    // whose back edge was never wired.
    TF_ASSERT_OK(root.graph()->UpdateEdge(next_iterations[i].node(), 0,
                                          iv.output.node(), 1));
    exits[i] = ops::internal::Exit(root.WithOpName(absl::StrCat("exit_", i)),
                                   iv.output);
  }

  FixupSourceAndSinkEdges(root.graph());

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    // The != "" checks ensure each predicate was actually computed, so the
    // inequality checks cannot pass merely because an entry is missing.
    EXPECT_NE(predicate_map[ControlOutputFor(exits[0])],
              predicate_map[ControlOutputFor(exits[1])]);
    EXPECT_NE(predicate_map[ControlOutputFor(exits[0])], "");
    EXPECT_NE(predicate_map[ControlOutputFor(exits[1])], "");

    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])],
              predicate_map[ControlOutputFor(next_iterations[1])]);
    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])], "");
    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[1])], "");
  }
}
776 
TEST(DeadnessAnalysisTest, ControlInputs) {
  // const0 and const1 have no data inputs. Each is control-dependent on an
  // Identity of a different Switch output, and the test expects the Add that
  // consumes both constants to be flagged as having mismatching deadness.
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch sw = CreateSwitch(root, "0");

  Output false_branch = ops::Identity(root.WithOpName("id0"), sw.output_false);
  Output true_branch = ops::Identity(root.WithOpName("id1"), sw.output_true);

  Output lhs = ops::Const(root.WithOpName("const0"), 1);
  Output rhs = ops::Const(root.WithOpName("const1"), 2);

  Output sum = ops::Add(root.WithOpName("add"), lhs, rhs);

  Graph* graph = root.graph();
  graph->AddControlEdge(false_branch.node(), lhs.node());
  graph->AddControlEdge(true_branch.node(), rhs.node());

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(graph, &analysis));

  EXPECT_TRUE(analysis->HasInputsWithMismatchingDeadness(*sum.node()));
}
797 
TEST(DeadnessAnalysisTest, ControlTrigger) {
  // Each Switch branch reaches its constant only through a ControlTrigger
  // (branch -> ControlTrigger -> constant, all control edges). The test
  // expects the Add over the two constants NOT to be flagged, i.e. branch
  // deadness must not flow through a ControlTrigger.
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch sw = CreateSwitch(root, "0");

  Output false_branch = ops::Identity(root.WithOpName("id0"), sw.output_false);
  Output true_branch = ops::Identity(root.WithOpName("id1"), sw.output_true);

  ops::ControlTrigger trigger0(root.WithOpName("ctrl_trigger0"));
  ops::ControlTrigger trigger1(root.WithOpName("ctrl_trigger1"));

  Output lhs = ops::Const(root.WithOpName("const0"), 1);
  Output rhs = ops::Const(root.WithOpName("const1"), 2);

  Output sum = ops::Add(root.WithOpName("add"), lhs, rhs);

  // Adds the control edges `from` -> `via` and `via` -> `to`, in that order.
  Graph* graph = root.graph();
  auto add_control_chain = [graph](Node* from, Node* via, Node* to) {
    graph->AddControlEdge(from, via);
    graph->AddControlEdge(via, to);
  };
  add_control_chain(false_branch.node(), trigger0.operation.node(), lhs.node());
  add_control_chain(true_branch.node(), trigger1.operation.node(), rhs.node());

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(graph, &analysis));

  EXPECT_FALSE(analysis->HasInputsWithMismatchingDeadness(*sum.node()));
}
824 
TEST(DeadnessAnalysisTest, ControlInputsToMerge) {
  // Two single-input Merge nodes over the same constant, each with a control
  // input from a different Switch branch. The test expects the Add over the
  // two Merges NOT to be flagged: control inputs into a Merge are not expected
  // to make its output take on the branch deadness.
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch sw = CreateSwitch(root, "0");

  Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
  Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);

  Output constant = ops::Const(root.WithOpName("constant"), 5);
  ops::Merge m0(root.WithOpName("m0"), {constant});
  ops::Merge m1(root.WithOpName("m1"), {constant});
  Output add = ops::Add(root.WithOpName("add"), m0.output, m1.output);

  root.graph()->AddControlEdge(id0.node(), m0.output.node());
  root.graph()->AddControlEdge(id1.node(), m1.output.node());

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}
845 
TEST(DeadnessAnalysisTest, RecvVsSwitch) {
  // Demonstrates why we need the must_be_true bit on SymbolP.
  //
  // The received tensor is used both as the Switch predicate and as a direct
  // input of the LogicalAnd, next to the Switch's true output. The test
  // expects the two LogicalAnd inputs to have mismatching deadness.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output received = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor",
                               "sender", 0, "receiver");
  Output data = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
  ops::Switch gate(root.WithOpName("switch"), data, received);
  Output conjunction =
      ops::LogicalAnd(root.WithOpName("and"), received, gate.output_true);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  EXPECT_TRUE(analysis->HasInputsWithMismatchingDeadness(*conjunction.node()));
}
862 
TEST(DeadnessAnalysisTest, RecvVsSwitchText) {
  // Demonstrates why we need the must_be_true bit on SymbolP.
  //
  // Same graph as RecvVsSwitch, but checks the textual predicate of the
  // LogicalAnd. The expected "(recv:0 & *recv:0)" holds two distinct symbols
  // for the same tensor that differ only in the '*' marker.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender",
                           0, "receiver");
  Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
  ops::Switch sw(root.WithOpName("switch"), value, recv);
  Output logical_and =
      ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  // Use the shared helper for the control-output key, like the other tests.
  EXPECT_EQ(predicate_map[ControlOutputFor(logical_and)], "(recv:0 & *recv:0)");
}
884 
TEST(DeadnessAnalysisTest, DeMorgan) {
  // Checks that the predicate simplifier applies De Morgan's law. With
  // A = cond_0 and B = cond_1: an Add over both true branches has predicate
  // A & B, and a Merge over both false branches has predicate ~A | ~B.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond_0 = ops::Placeholder(root.WithOpName("cond_0"), DT_BOOL);
  Output cond_1 = ops::Placeholder(root.WithOpName("cond_1"), DT_BOOL);
  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);

  ops::Switch sw_0(root.WithOpName("switch_0"), value, cond_0);
  ops::Switch sw_1(root.WithOpName("switch_1"), value, cond_1);

  Output and_0_1 =
      ops::Add(root.WithOpName("and_0_1"), sw_0.output_true, sw_1.output_true);

  Output or_not0_not1 = ops::Merge(root.WithOpName("or_not0_not1"),
                                   {sw_0.output_false, sw_1.output_false})
                            .output;

  // Predicate(should_always_be_dead) =
  // (A & B) & (~A | ~B) = (A & B) & ~(A & B) = False
  Output should_always_be_dead =
      ops::Add(root.WithOpName("should_always_be_dead"), and_0_1, or_not0_not1);

  // Predicate(should_always_be_alive) =
  // (A & B) | (~A | ~B) = (A & B) | ~(A & B) = True
  Output should_always_be_alive =
      ops::Merge(root.WithOpName("should_always_be_alive"),
                 {and_0_1, or_not0_not1})
          .output;

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_dead)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_alive)], "#true");
}
923 
TEST(DeadnessAnalysisTest, ConstantTrueSwitchCondition) {
  // With a constant `true` predicate on the Switch, the test expects the false
  // output to be always dead ("#false") and the true output always live
  // ("#true").
  Scope root = Scope::NewRootScope().ExitOnError();

  Output pred = ops::Const(root.WithOpName("const_true"), true);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  ops::Switch branch(root.WithOpName("switch"), data, pred);

  Output on_false =
      ops::Identity(root.WithOpName("id_false"), branch.output_false);
  Output on_true =
      ops::Identity(root.WithOpName("id_true"), branch.output_true);

  Graph* graph = root.graph();
  FixupSourceAndSinkEdges(graph);

  PredicateMapTy predicates;
  TF_ASSERT_OK(ComputePredicates(*graph, &predicates));

  EXPECT_EQ(predicates[ControlOutputFor(on_false)], "#false");
  EXPECT_EQ(predicates[ControlOutputFor(on_true)], "#true");
}
942 
TEST(DeadnessAnalysisTest, ConstantFalseSwitchCondition) {
  // With a constant `false` predicate on the Switch, the test expects the
  // false output to be always live ("#true") and the true output always dead
  // ("#false").
  Scope root = Scope::NewRootScope().ExitOnError();

  Output pred = ops::Const(root.WithOpName("const_false"), false);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  ops::Switch branch(root.WithOpName("switch"), data, pred);

  Output on_false =
      ops::Identity(root.WithOpName("id_false"), branch.output_false);
  Output on_true =
      ops::Identity(root.WithOpName("id_true"), branch.output_true);

  Graph* graph = root.graph();
  FixupSourceAndSinkEdges(graph);

  PredicateMapTy predicates;
  TF_ASSERT_OK(ComputePredicates(*graph, &predicates));

  EXPECT_EQ(predicates[ControlOutputFor(on_false)], "#true");
  EXPECT_EQ(predicates[ControlOutputFor(on_true)], "#false");
}
961 
962 }  // namespace
963 }  // namespace tensorflow
964