• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
17 
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
21 #include "tensorflow/cc/ops/function_ops.h"
22 #include "tensorflow/cc/ops/functional_ops.h"
23 #include "tensorflow/cc/ops/resource_variable_ops.h"
24 #include "tensorflow/cc/ops/sendrecv_ops.h"
25 #include "tensorflow/cc/ops/standard_ops.h"
26 #include "tensorflow/compiler/jit/defs.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
28 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
29 #include "tensorflow/core/common_runtime/graph_constructor.h"
30 #include "tensorflow/core/framework/node_def_util.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/graph/algorithm.h"
33 #include "tensorflow/core/graph/graph_def_builder.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/lib/strings/str_util.h"
36 #include "tensorflow/core/platform/test.h"
37 
38 namespace tensorflow {
39 namespace {
40 
MakeRead(const Scope & scope,const string & id)41 Node* MakeRead(const Scope& scope, const string& id) {
42   Output var_handle =
43       ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
44   Output read =
45       ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
46   return read.node();
47 }
48 
MakeWrite(const Scope & scope,const string & id)49 Node* MakeWrite(const Scope& scope, const string& id) {
50   Output var_handle =
51       ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
52   Output value_to_write =
53       ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
54   ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle,
55                                   value_to_write);
56   return assign_op.operation.node();
57 }
58 
MakeModify(const Scope & scope,const string & id)59 Node* MakeModify(const Scope& scope, const string& id) {
60   Output var_handle =
61       ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
62   Output value_to_write = ops::Const(scope.WithOpName("Increment" + id), 1.0f);
63   ops::AssignAddVariableOp assign_add_op(scope.WithOpName("Increment" + id),
64                                          var_handle, value_to_write);
65   return assign_add_op.operation.node();
66 }
67 
MakeNeutral(const Scope & scope,const string & id)68 Node* MakeNeutral(const Scope& scope, const string& id) {
69   return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
70 }
71 
ComputeIncompatiblePairs(Graph * g,std::vector<std::pair<int,int>> * result)72 Status ComputeIncompatiblePairs(Graph* g,
73                                 std::vector<std::pair<int, int>>* result) {
74   FixupSourceAndSinkEdges(g);
75   return ComputeIncompatibleResourceOperationPairs(*g, &g->flib_def(), {},
76                                                    result);
77 }
78 
// write --ctrl--> read: expects exactly one incompatible pair, (write, read).
TEST(ResourceOperationSafetyAnalysisTest, WriteRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(write, read);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  const std::vector<std::pair<int, int>> expected = {
      std::make_pair(write->id(), read->id())};
  EXPECT_EQ(actual, expected);
}
94 
// read --ctrl--> write: expects no incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ReadWrite) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(read, write);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  EXPECT_TRUE(actual.empty());
}
108 
// A read and a write with no edge between them: expects no incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ReadWriteNoEdges) {
  Scope root = Scope::NewRootScope().ExitOnError();

  // Only the graph side effects matter here; the returned nodes are unused.
  MakeRead(root, "R");
  MakeWrite(root, "W");

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &actual));

  EXPECT_TRUE(actual.empty());
}
120 
// read --ctrl--> modify: expects no incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ReadModify) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  graph->AddControlEdge(read, modify);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  EXPECT_TRUE(actual.empty());
}
134 
// modify --ctrl--> read: expects exactly one pair, (modify, read).
TEST(ResourceOperationSafetyAnalysisTest, ModifyRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  graph->AddControlEdge(modify, read);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  const std::vector<std::pair<int, int>> expected = {
      std::make_pair(modify->id(), read->id())};
  EXPECT_EQ(actual, expected);
}
150 
// modify --ctrl--> write: expects no incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ModifyWrite) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(modify, write);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  EXPECT_TRUE(actual.empty());
}
164 
// write --ctrl--> modify: expects exactly one pair, (write, modify).
TEST(ResourceOperationSafetyAnalysisTest, WriteModify) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  // Creation order (modify before write) is kept; it determines node ids.
  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(write, modify);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  const std::vector<std::pair<int, int>> expected = {
      std::make_pair(write->id(), modify->id())};
  EXPECT_EQ(actual, expected);
}
180 
// read -> modify -> write (control edges): expects no incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ReadModifyWrite) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");

  graph->AddControlEdge(read, modify);
  graph->AddControlEdge(modify, write);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  EXPECT_TRUE(actual.empty());
}
196 
// write -> modify -> read (control edges): expects three pairs, including the
// transitive (write, read), in exactly the order listed below.
TEST(ResourceOperationSafetyAnalysisTest, WriteModifyRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");

  graph->AddControlEdge(write, modify);
  graph->AddControlEdge(modify, read);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  const std::vector<std::pair<int, int>> expected = {
      std::make_pair(modify->id(), read->id()),
      std::make_pair(write->id(), read->id()),
      std::make_pair(write->id(), modify->id())};
  EXPECT_EQ(actual, expected);
}
219 
// write -> read -> modify (control edges): expects (write, read) then
// (write, modify); the read-before-modify edge is not reported.
TEST(ResourceOperationSafetyAnalysisTest, WriteReadModify) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");

  graph->AddControlEdge(write, read);
  graph->AddControlEdge(read, modify);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  const std::vector<std::pair<int, int>> expected = {
      std::make_pair(write->id(), read->id()),
      std::make_pair(write->id(), modify->id())};
  EXPECT_EQ(actual, expected);
}
240 
CreateFunctionDefLibWithConstFunction(const string & name)241 FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) {
242   FunctionDefLibrary flib_def;
243   FunctionDef func = FunctionDefHelper::Create(
244       /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"},
245       /*attr_def*/
246       {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)},
247       /*ret_def=*/{{"out", "out:output:0"}});
248   *flib_def.add_function() = std::move(func);
249   return flib_def;
250 }
251 
MakeCall(Graph * graph,const string & callee_name,const string & node_name,Status * status)252 Node* MakeCall(Graph* graph, const string& callee_name, const string& node_name,
253                Status* status) {
254   NodeDef call_node;
255   call_node.set_name(node_name);
256   call_node.set_op(callee_name);
257   return graph->AddNode(call_node, status);
258 }
259 
// function call --ctrl--> read: expects exactly one pair, (call, read).
TEST(ResourceOperationSafetyAnalysisTest, CallRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  TF_ASSERT_OK(graph->AddFunctionLibrary(
      CreateFunctionDefLibWithConstFunction("Const_func")));

  Node* read = MakeRead(root, "R");
  Status call_status;
  Node* call = MakeCall(graph, "Const_func", "C", &call_status);
  TF_ASSERT_OK(call_status);

  graph->AddControlEdge(call, read);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  const std::vector<std::pair<int, int>> expected = {
      std::make_pair(call->id(), read->id())};
  EXPECT_EQ(actual, expected);
}
281 
// read --ctrl--> function call: expects no incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ReadCall) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  TF_ASSERT_OK(graph->AddFunctionLibrary(
      CreateFunctionDefLibWithConstFunction("Const_func")));

  Node* read = MakeRead(root, "R");
  Status call_status;
  Node* call = MakeCall(graph, "Const_func", "C", &call_status);
  TF_ASSERT_OK(call_status);

  graph->AddControlEdge(read, call);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  EXPECT_TRUE(actual.empty());
}
301 
// function call --ctrl--> write: expects no incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, CallWrite) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  TF_ASSERT_OK(graph->AddFunctionLibrary(
      CreateFunctionDefLibWithConstFunction("Const_func")));

  Node* write = MakeWrite(root, "W");
  Status call_status;
  Node* call = MakeCall(graph, "Const_func", "C", &call_status);
  TF_ASSERT_OK(call_status);

  graph->AddControlEdge(call, write);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  EXPECT_TRUE(actual.empty());
}
321 
// write --ctrl--> function call: expects exactly one pair, (write, call).
TEST(ResourceOperationSafetyAnalysisTest, WriteCall) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  TF_ASSERT_OK(graph->AddFunctionLibrary(
      CreateFunctionDefLibWithConstFunction("Const_func")));

  Node* write = MakeWrite(root, "W");
  Status call_status;
  Node* call = MakeCall(graph, "Const_func", "C", &call_status);
  TF_ASSERT_OK(call_status);

  graph->AddControlEdge(write, call);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  const std::vector<std::pair<int, int>> expected = {
      std::make_pair(write->id(), call->id())};
  EXPECT_EQ(actual, expected);
}
343 
// SymbolicGradient of "Const_func" --ctrl--> read: expects exactly one pair,
// (symbolic_gradient, read).
TEST(ResourceOperationSafetyAnalysisTest, SymbolicGradientRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  TF_ASSERT_OK(graph->AddFunctionLibrary(
      CreateFunctionDefLibWithConstFunction("Const_func")));

  Node* read = MakeRead(root, "R");
  NameAttrList const_func;
  const_func.set_name("Const_func");
  // The gradient's input Const is created before the SymbolicGradient node,
  // matching the original construction order.
  Output gradient_input = ops::Const(root, 1.0f);
  ops::SymbolicGradient gradient(root, /*input=*/{gradient_input},
                                 /*Tout=*/{DT_FLOAT}, const_func);
  Node* symbolic_gradient = gradient.output[0].node();

  graph->AddControlEdge(symbolic_gradient, read);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  const std::vector<std::pair<int, int>> expected = {
      std::make_pair(symbolic_gradient->id(), read->id())};
  EXPECT_EQ(actual, expected);
}
370 
// write --ctrl--> SymbolicGradient of "Const_func": expects exactly one pair,
// (write, symbolic_gradient).
TEST(ResourceOperationSafetyAnalysisTest, WriteSymbolicGradient) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  TF_ASSERT_OK(graph->AddFunctionLibrary(
      CreateFunctionDefLibWithConstFunction("Const_func")));

  Node* write = MakeWrite(root, "W");
  NameAttrList const_func;
  const_func.set_name("Const_func");
  // The gradient's input Const is created before the SymbolicGradient node,
  // matching the original construction order.
  Output gradient_input = ops::Const(root, 1.0f);
  ops::SymbolicGradient gradient(root, /*input=*/{gradient_input},
                                 /*Tout=*/{DT_FLOAT}, const_func);
  Node* symbolic_gradient = gradient.output[0].node();

  graph->AddControlEdge(write, symbolic_gradient);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  const std::vector<std::pair<int, int>> expected = {
      std::make_pair(write->id(), symbolic_gradient->id())};
  EXPECT_EQ(actual, expected);
}
397 
// Linear chain W0 -> N0 -> R0 -> W1 -> N1 -> R1 (control edges, N* are plain
// Consts): every write before a later read is reported, through the neutral
// nodes and across R0 -> W1.
TEST(ResourceOperationSafetyAnalysisTest, ChainOfOps) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* write_0 = MakeWrite(root, "W0");
  Node* neutral_0 = MakeNeutral(root, "N0");
  Node* read_0 = MakeRead(root, "R0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral_1 = MakeNeutral(root, "N1");
  Node* read_1 = MakeRead(root, "R1");

  graph->AddControlEdge(write_0, neutral_0);
  graph->AddControlEdge(neutral_0, read_0);
  graph->AddControlEdge(read_0, write_1);
  graph->AddControlEdge(write_1, neutral_1);
  graph->AddControlEdge(neutral_1, read_1);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  const std::vector<std::pair<int, int>> expected = {
      std::make_pair(write_0->id(), read_0->id()),
      std::make_pair(write_0->id(), read_1->id()),
      std::make_pair(write_1->id(), read_1->id())};
  EXPECT_EQ(actual, expected);
}
426 
// Diamond-shaped DAG: {W0, W1} -> N -> {R0, R1}. Expects all four
// (write, read) combinations.
TEST(ResourceOperationSafetyAnalysisTest, DagOfOps) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* write_0 = MakeWrite(root, "W0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral = MakeNeutral(root, "N");
  Node* read_0 = MakeRead(root, "R0");
  Node* read_1 = MakeRead(root, "R1");

  graph->AddControlEdge(write_0, neutral);
  graph->AddControlEdge(write_1, neutral);
  graph->AddControlEdge(neutral, read_0);
  graph->AddControlEdge(neutral, read_1);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  const std::vector<std::pair<int, int>> expected = {
      std::make_pair(write_0->id(), read_0->id()),
      std::make_pair(write_0->id(), read_1->id()),
      std::make_pair(write_1->id(), read_0->id()),
      std::make_pair(write_1->id(), read_1->id())};
  EXPECT_EQ(actual, expected);
}
455 
// Same diamond as DagOfOps plus a direct W1 -> R1 edge. The second path from
// W1 to R1 must not produce a duplicate: still exactly four pairs.
TEST(ResourceOperationSafetyAnalysisTest, DagOfOpsWithRepeatedPaths) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* write_0 = MakeWrite(root, "W0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral = MakeNeutral(root, "N");
  Node* read_0 = MakeRead(root, "R0");
  Node* read_1 = MakeRead(root, "R1");

  graph->AddControlEdge(write_0, neutral);
  graph->AddControlEdge(write_1, neutral);
  graph->AddControlEdge(neutral, read_0);
  graph->AddControlEdge(neutral, read_1);
  graph->AddControlEdge(write_1, read_1);

  std::vector<std::pair<int, int>> actual;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &actual));

  const std::vector<std::pair<int, int>> expected = {
      std::make_pair(write_0->id(), read_0->id()),
      std::make_pair(write_0->id(), read_1->id()),
      std::make_pair(write_1->id(), read_0->id()),
      std::make_pair(write_1->id(), read_1->id())};
  EXPECT_EQ(actual, expected);
}
485 
// Builds a minimal while-loop frame: Enter -> Merge -> Switch, with Exit off
// the Merge and NextIteration fed back into Merge input 1. Inside the frame,
// iv -> write -> read -> next_iteration are chained by control edges.
// Expects exactly one pair, (write, read), despite the back edge.
TEST(ResourceOperationSafetyAnalysisTest, Loop) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output init_value = ops::Placeholder(root.WithOpName("init"), DT_FLOAT);
  // The predicate used to request the op name "init" as well (copy-paste
  // slip); give it its own name so the graph reads unambiguously.
  Output loop_cond = ops::Placeholder(root.WithOpName("loop_cond"), DT_BOOL);
  Output enter_value =
      ops::internal::Enter(root.WithOpName("enter"), init_value, "fr");
  ops::Merge iv(root.WithOpName("iv"), {enter_value, enter_value});
  ops::Switch latch(root.WithOpName("latch"), iv.output, loop_cond);
  // Named exit_op (not `exit`) to avoid shadowing ::exit. Only its creation
  // matters; the object itself is not used afterwards.
  ops::internal::Exit exit_op(root.WithOpName("exit"), iv.output);
  Output next_iteration =
      ops::NextIteration(root.WithOpName("next_iteration"), latch.output_true);
  // Close the back edge: Merge input 1 now comes from NextIteration output 0.
  TF_ASSERT_OK(
      root.graph()->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1));

  Node* write = MakeWrite(root, "W");
  Node* read = MakeRead(root, "R");

  root.graph()->AddControlEdge(iv.output.node(), write);
  root.graph()->AddControlEdge(write, read);
  root.graph()->AddControlEdge(read, next_iteration.node());

  std::vector<std::pair<int, int>> incompatible_pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));

  ASSERT_EQ(incompatible_pairs.size(), 1);

  std::pair<int, int> write_read_pair = {write->id(), read->id()};
  EXPECT_EQ(incompatible_pairs[0], write_read_pair);
}
516 
IsResourceArgDef(const OpDef::ArgDef & arg_def)517 bool IsResourceArgDef(const OpDef::ArgDef& arg_def) {
518   return arg_def.type() == DT_RESOURCE;
519 }
520 }  // namespace
521 }  // namespace tensorflow
522