1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
17
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
21 #include "tensorflow/cc/ops/function_ops.h"
22 #include "tensorflow/cc/ops/functional_ops.h"
23 #include "tensorflow/cc/ops/resource_variable_ops.h"
24 #include "tensorflow/cc/ops/sendrecv_ops.h"
25 #include "tensorflow/cc/ops/standard_ops.h"
26 #include "tensorflow/compiler/jit/defs.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
28 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
29 #include "tensorflow/core/common_runtime/graph_constructor.h"
30 #include "tensorflow/core/framework/node_def_util.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/graph/algorithm.h"
33 #include "tensorflow/core/graph/graph_def_builder.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/lib/strings/str_util.h"
36 #include "tensorflow/core/platform/test.h"
37
38 namespace tensorflow {
39 namespace {
40
MakeRead(const Scope & scope,const string & id)41 Node* MakeRead(const Scope& scope, const string& id) {
42 Output var_handle =
43 ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
44 Output read =
45 ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
46 return read.node();
47 }
48
MakeWrite(const Scope & scope,const string & id)49 Node* MakeWrite(const Scope& scope, const string& id) {
50 Output var_handle =
51 ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
52 Output value_to_write =
53 ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
54 ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle,
55 value_to_write);
56 return assign_op.operation.node();
57 }
58
MakeModify(const Scope & scope,const string & id)59 Node* MakeModify(const Scope& scope, const string& id) {
60 Output var_handle =
61 ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
62 Output value_to_write = ops::Const(scope.WithOpName("Increment" + id), 1.0f);
63 ops::AssignAddVariableOp assign_add_op(scope.WithOpName("Increment" + id),
64 var_handle, value_to_write);
65 return assign_add_op.operation.node();
66 }
67
MakeNeutral(const Scope & scope,const string & id)68 Node* MakeNeutral(const Scope& scope, const string& id) {
69 return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
70 }
71
ComputeIncompatiblePairs(Graph * g,std::vector<std::pair<int,int>> * result)72 Status ComputeIncompatiblePairs(Graph* g,
73 std::vector<std::pair<int, int>>* result) {
74 FixupSourceAndSinkEdges(g);
75 return ComputeIncompatibleResourceOperationPairs(*g, &g->flib_def(), {},
76 result);
77 }
78
// A write ordered before a read must be reported as a (write, read) pair.
TEST(ResourceOperationSafetyAnalysisTest, WriteRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(write, read);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  ASSERT_EQ(pairs.size(), 1);
  const std::pair<int, int> expected(write->id(), read->id());
  EXPECT_EQ(pairs[0], expected);
}
94
// A read ordered before a write must not be reported.
TEST(ResourceOperationSafetyAnalysisTest, ReadWrite) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(read, write);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  EXPECT_TRUE(pairs.empty());
}
108
// With no ordering edge between a read and a write, nothing is reported.
TEST(ResourceOperationSafetyAnalysisTest, ReadWriteNoEdges) {
  Scope root = Scope::NewRootScope().ExitOnError();

  // Only the nodes' presence in the graph matters; the handles are unused.
  (void)MakeRead(root, "R");
  (void)MakeWrite(root, "W");

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &pairs));

  EXPECT_TRUE(pairs.empty());
}
120
// A read ordered before an AssignAdd-style modify must not be reported.
TEST(ResourceOperationSafetyAnalysisTest, ReadModify) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  graph->AddControlEdge(read, modify);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  EXPECT_TRUE(pairs.empty());
}
134
// A modify ordered before a read must be reported as a (modify, read) pair.
TEST(ResourceOperationSafetyAnalysisTest, ModifyRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  graph->AddControlEdge(modify, read);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  ASSERT_EQ(pairs.size(), 1);
  const std::pair<int, int> expected(modify->id(), read->id());
  EXPECT_EQ(pairs[0], expected);
}
150
// A modify ordered before a write must not be reported.
TEST(ResourceOperationSafetyAnalysisTest, ModifyWrite) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(modify, write);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  EXPECT_TRUE(pairs.empty());
}
164
// A write ordered before a modify must be reported as a (write, modify) pair.
TEST(ResourceOperationSafetyAnalysisTest, WriteModify) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(write, modify);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  ASSERT_EQ(pairs.size(), 1);
  const std::pair<int, int> expected(write->id(), modify->id());
  EXPECT_EQ(pairs[0], expected);
}
180
// The chain read -> modify -> write produces no incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ReadModifyWrite) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(read, modify);
  graph->AddControlEdge(modify, write);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  EXPECT_TRUE(pairs.empty());
}
196
// write -> modify -> read: every adjacent pair is reported. The write also
// reaches the read transitively through the modify, so that pair is reported
// too.
TEST(ResourceOperationSafetyAnalysisTest, WriteModifyRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(write, modify);
  graph->AddControlEdge(modify, read);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  // Listed in the exact order the analysis is expected to report them.
  const std::vector<std::pair<int, int>> expected = {
      {modify->id(), read->id()},
      {write->id(), read->id()},
      {write->id(), modify->id()},
  };
  ASSERT_EQ(pairs.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    EXPECT_EQ(pairs[i], expected[i]);
  }
}
219
// write -> read -> modify: the write conflicts with both later ops. The
// read -> modify ordering is not reported.
TEST(ResourceOperationSafetyAnalysisTest, WriteReadModify) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(write, read);
  graph->AddControlEdge(read, modify);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {write->id(), read->id()},
      {write->id(), modify->id()},
  };
  ASSERT_EQ(pairs.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    EXPECT_EQ(pairs[i], expected[i]);
  }
}
240
CreateFunctionDefLibWithConstFunction(const string & name)241 FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) {
242 FunctionDefLibrary flib_def;
243 FunctionDef func = FunctionDefHelper::Create(
244 /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"},
245 /*attr_def*/
246 {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)},
247 /*ret_def=*/{{"out", "out:output:0"}});
248 *flib_def.add_function() = std::move(func);
249 return flib_def;
250 }
251
MakeCall(Graph * graph,const string & callee_name,const string & node_name,Status * status)252 Node* MakeCall(Graph* graph, const string& callee_name, const string& node_name,
253 Status* status) {
254 NodeDef call_node;
255 call_node.set_name(node_name);
256 call_node.set_op(callee_name);
257 return graph->AddNode(call_node, status);
258 }
259
// A function call ordered before a read is reported as a (call, read) pair.
TEST(ResourceOperationSafetyAnalysisTest, CallRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  const FunctionDefLibrary library =
      CreateFunctionDefLibWithConstFunction("Const_func");
  TF_ASSERT_OK(graph->AddFunctionLibrary(library));

  Node* read = MakeRead(root, "R");
  Status call_status;
  Node* call = MakeCall(graph, "Const_func", "C", &call_status);
  TF_ASSERT_OK(call_status);

  graph->AddControlEdge(call, read);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  ASSERT_EQ(pairs.size(), 1);
  const std::pair<int, int> expected(call->id(), read->id());
  EXPECT_EQ(pairs[0], expected);
}
281
// A read ordered before a function call is not reported.
TEST(ResourceOperationSafetyAnalysisTest, ReadCall) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  const FunctionDefLibrary library =
      CreateFunctionDefLibWithConstFunction("Const_func");
  TF_ASSERT_OK(graph->AddFunctionLibrary(library));

  Node* read = MakeRead(root, "R");
  Status call_status;
  Node* call = MakeCall(graph, "Const_func", "C", &call_status);
  TF_ASSERT_OK(call_status);

  graph->AddControlEdge(read, call);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  EXPECT_TRUE(pairs.empty());
}
301
// A function call ordered before a write is not reported.
TEST(ResourceOperationSafetyAnalysisTest, CallWrite) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  const FunctionDefLibrary library =
      CreateFunctionDefLibWithConstFunction("Const_func");
  TF_ASSERT_OK(graph->AddFunctionLibrary(library));

  Node* write = MakeWrite(root, "W");
  Status call_status;
  Node* call = MakeCall(graph, "Const_func", "C", &call_status);
  TF_ASSERT_OK(call_status);

  graph->AddControlEdge(call, write);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  EXPECT_TRUE(pairs.empty());
}
321
// A write ordered before a function call is reported as a (write, call) pair.
TEST(ResourceOperationSafetyAnalysisTest, WriteCall) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  const FunctionDefLibrary library =
      CreateFunctionDefLibWithConstFunction("Const_func");
  TF_ASSERT_OK(graph->AddFunctionLibrary(library));

  Node* write = MakeWrite(root, "W");
  Status call_status;
  Node* call = MakeCall(graph, "Const_func", "C", &call_status);
  TF_ASSERT_OK(call_status);

  graph->AddControlEdge(write, call);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  ASSERT_EQ(pairs.size(), 1);
  const std::pair<int, int> expected(write->id(), call->id());
  EXPECT_EQ(pairs[0], expected);
}
343
// A SymbolicGradient node ordered before a read is reported as a
// (symbolic_gradient, read) pair.
TEST(ResourceOperationSafetyAnalysisTest, SymbolicGradientRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  const FunctionDefLibrary library =
      CreateFunctionDefLibWithConstFunction("Const_func");
  TF_ASSERT_OK(graph->AddFunctionLibrary(library));

  Node* read = MakeRead(root, "R");

  NameAttrList fn;
  fn.set_name("Const_func");
  // The gradient's input constant is created before the gradient node itself.
  Output gradient_input = ops::Const(root, 1.0f);
  ops::SymbolicGradient gradient(root, /*input=*/{gradient_input},
                                 /*Tout=*/{DT_FLOAT}, fn);
  Node* symbolic_gradient = gradient.output[0].node();

  graph->AddControlEdge(symbolic_gradient, read);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  ASSERT_EQ(pairs.size(), 1);
  const std::pair<int, int> expected(symbolic_gradient->id(), read->id());
  EXPECT_EQ(pairs[0], expected);
}
370
// A write ordered before a SymbolicGradient node is reported as a
// (write, symbolic_gradient) pair.
TEST(ResourceOperationSafetyAnalysisTest, WriteSymbolicGradient) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  const FunctionDefLibrary library =
      CreateFunctionDefLibWithConstFunction("Const_func");
  TF_ASSERT_OK(graph->AddFunctionLibrary(library));

  Node* write = MakeWrite(root, "W");

  NameAttrList fn;
  fn.set_name("Const_func");
  // The gradient's input constant is created before the gradient node itself.
  Output gradient_input = ops::Const(root, 1.0f);
  ops::SymbolicGradient gradient(root, /*input=*/{gradient_input},
                                 /*Tout=*/{DT_FLOAT}, fn);
  Node* symbolic_gradient = gradient.output[0].node();

  graph->AddControlEdge(write, symbolic_gradient);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  ASSERT_EQ(pairs.size(), 1);
  const std::pair<int, int> expected(write->id(), symbolic_gradient->id());
  EXPECT_EQ(pairs[0], expected);
}
397
// Chain: W0 -> N0 -> R0 -> W1 -> N1 -> R1, where N* are resource-free nodes.
// Pairs are found across the neutral nodes and along the whole chain
// (W0 reaches R1). The R0 -> W1 ordering is not reported.
TEST(ResourceOperationSafetyAnalysisTest, ChainOfOps) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* write_0 = MakeWrite(root, "W0");
  Node* neutral_0 = MakeNeutral(root, "N0");
  Node* read_0 = MakeRead(root, "R0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral_1 = MakeNeutral(root, "N1");
  Node* read_1 = MakeRead(root, "R1");

  graph->AddControlEdge(write_0, neutral_0);
  graph->AddControlEdge(neutral_0, read_0);
  graph->AddControlEdge(read_0, write_1);
  graph->AddControlEdge(write_1, neutral_1);
  graph->AddControlEdge(neutral_1, read_1);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {write_0->id(), read_0->id()},
      {write_0->id(), read_1->id()},
      {write_1->id(), read_1->id()},
  };
  ASSERT_EQ(pairs.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    EXPECT_EQ(pairs[i], expected[i]);
  }
}
426
// Diamond-like DAG: two writes fan into a neutral node, which fans out to two
// reads. Every write/read combination is reported.
TEST(ResourceOperationSafetyAnalysisTest, DagOfOps) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* write_0 = MakeWrite(root, "W0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral = MakeNeutral(root, "N");
  Node* read_0 = MakeRead(root, "R0");
  Node* read_1 = MakeRead(root, "R1");

  graph->AddControlEdge(write_0, neutral);
  graph->AddControlEdge(write_1, neutral);
  graph->AddControlEdge(neutral, read_0);
  graph->AddControlEdge(neutral, read_1);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {write_0->id(), read_0->id()},
      {write_0->id(), read_1->id()},
      {write_1->id(), read_0->id()},
      {write_1->id(), read_1->id()},
  };
  ASSERT_EQ(pairs.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    EXPECT_EQ(pairs[i], expected[i]);
  }
}
455
// Same DAG as DagOfOps plus a direct W1 -> R1 edge, which duplicates an
// existing path. The (W1, R1) pair must still appear exactly once.
TEST(ResourceOperationSafetyAnalysisTest, DagOfOpsWithRepeatedPaths) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* write_0 = MakeWrite(root, "W0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral = MakeNeutral(root, "N");
  Node* read_0 = MakeRead(root, "R0");
  Node* read_1 = MakeRead(root, "R1");

  graph->AddControlEdge(write_0, neutral);
  graph->AddControlEdge(write_1, neutral);
  graph->AddControlEdge(neutral, read_0);
  graph->AddControlEdge(neutral, read_1);
  graph->AddControlEdge(write_1, read_1);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {write_0->id(), read_0->id()},
      {write_0->id(), read_1->id()},
      {write_1->id(), read_0->id()},
      {write_1->id(), read_1->id()},
  };
  ASSERT_EQ(pairs.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    EXPECT_EQ(pairs[i], expected[i]);
  }
}
485
// Builds a minimal loop frame: Enter -> Merge -> Switch, with NextIteration
// rewired back into the Merge. The body orders merge -> write -> read ->
// next_iteration. Only the in-body (write, read) pair is expected.
TEST(ResourceOperationSafetyAnalysisTest, Loop) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output init_value = ops::Placeholder(root.WithOpName("init"), DT_FLOAT);
  // Distinct name for the loop condition. Reusing "init" would only work
  // because Scope silently de-duplicates repeated op names.
  Output loop_cond = ops::Placeholder(root.WithOpName("loop_cond"), DT_BOOL);
  Output enter_value =
      ops::internal::Enter(root.WithOpName("enter"), init_value, "fr");
  ops::Merge iv(root.WithOpName("iv"), {enter_value, enter_value});
  ops::Switch latch(root.WithOpName("latch"), iv.output, loop_cond);
  // Named `loop_exit` rather than `exit` to avoid shadowing ::exit.
  ops::internal::Exit loop_exit(root.WithOpName("exit"), iv.output);
  Output next_iteration =
      ops::NextIteration(root.WithOpName("next_iteration"), latch.output_true);
  // Replace the Merge's second input with NextIteration to close the loop.
  TF_ASSERT_OK(
      root.graph()->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1));

  Node* write = MakeWrite(root, "W");
  Node* read = MakeRead(root, "R");

  root.graph()->AddControlEdge(iv.output.node(), write);
  root.graph()->AddControlEdge(write, read);
  root.graph()->AddControlEdge(read, next_iteration.node());

  std::vector<std::pair<int, int>> incompatible_pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));

  ASSERT_EQ(incompatible_pairs.size(), 1);

  std::pair<int, int> write_read_pair = {write->id(), read->id()};
  EXPECT_EQ(incompatible_pairs[0], write_read_pair);
}
516
IsResourceArgDef(const OpDef::ArgDef & arg_def)517 bool IsResourceArgDef(const OpDef::ArgDef& arg_def) {
518 return arg_def.type() == DT_RESOURCE;
519 }
520 } // namespace
521 } // namespace tensorflow
522