1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
17
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
21 #include "tensorflow/cc/ops/function_ops.h"
22 #include "tensorflow/cc/ops/functional_ops.h"
23 #include "tensorflow/cc/ops/resource_variable_ops.h"
24 #include "tensorflow/cc/ops/sendrecv_ops.h"
25 #include "tensorflow/cc/ops/standard_ops.h"
26 #include "tensorflow/compiler/jit/defs.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
28 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/graph/algorithm.h"
32 #include "tensorflow/core/graph/graph_constructor.h"
33 #include "tensorflow/core/graph/graph_def_builder.h"
34 #include "tensorflow/core/graph/graph_def_builder_util.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/lib/strings/str_util.h"
37 #include "tensorflow/core/platform/test.h"
38
39 namespace tensorflow {
40 namespace {
41
MakeRead(const Scope & scope,const string & id)42 Node* MakeRead(const Scope& scope, const string& id) {
43 Output var_handle =
44 ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
45 Output read =
46 ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
47 return read.node();
48 }
49
MakeWrite(const Scope & scope,const string & id)50 Node* MakeWrite(const Scope& scope, const string& id) {
51 Output var_handle =
52 ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
53 Output value_to_write =
54 ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
55 ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle,
56 value_to_write);
57 return assign_op.operation.node();
58 }
59
MakeModify(const Scope & scope,const string & id)60 Node* MakeModify(const Scope& scope, const string& id) {
61 Output var_handle =
62 ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
63 Output value_to_write = ops::Const(scope.WithOpName("Increment" + id), 1.0f);
64 ops::AssignAddVariableOp assign_add_op(scope.WithOpName("Increment" + id),
65 var_handle, value_to_write);
66 return assign_add_op.operation.node();
67 }
68
MakeNeutral(const Scope & scope,const string & id)69 Node* MakeNeutral(const Scope& scope, const string& id) {
70 return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
71 }
72
ComputeIncompatiblePairs(Graph * g,std::vector<std::pair<int,int>> * result)73 Status ComputeIncompatiblePairs(Graph* g,
74 std::vector<std::pair<int, int>>* result) {
75 FixupSourceAndSinkEdges(g);
76 return ComputeIncompatibleResourceOperationPairs(*g, &g->flib_def(), {},
77 result);
78 }
79
// A write ordered before a read by a control edge is expected to yield exactly
// one incompatible pair: (write, read). The two ops touch distinct variables
// ("VarW" and "VarR").
TEST(ResourceOperationSafetyAnalysisTest, WriteRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(write, read);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {write->id(), read->id()}};
  EXPECT_EQ(pairs, expected);
}
95
// A read ordered before a write is expected to yield no incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ReadWrite) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(read, write);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));
  EXPECT_TRUE(pairs.empty());
}
109
// A read and a write with no edge between them are expected to yield no
// incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ReadWriteNoEdges) {
  Scope root = Scope::NewRootScope().ExitOnError();

  // The returned nodes are not needed; only their presence in the graph is.
  MakeRead(root, "R");
  MakeWrite(root, "W");

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &pairs));
  EXPECT_TRUE(pairs.empty());
}
121
// A read ordered before a modify (AssignAdd) is expected to yield no
// incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ReadModify) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  graph->AddControlEdge(read, modify);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));
  EXPECT_TRUE(pairs.empty());
}
135
// A modify (AssignAdd) ordered before a read is expected to yield exactly one
// incompatible pair: (modify, read).
TEST(ResourceOperationSafetyAnalysisTest, ModifyRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  graph->AddControlEdge(modify, read);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {modify->id(), read->id()}};
  EXPECT_EQ(pairs, expected);
}
151
// A modify (AssignAdd) ordered before a write is expected to yield no
// incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ModifyWrite) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(modify, write);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));
  EXPECT_TRUE(pairs.empty());
}
165
// A write ordered before a modify (AssignAdd) is expected to yield exactly one
// incompatible pair: (write, modify).
TEST(ResourceOperationSafetyAnalysisTest, WriteModify) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");
  graph->AddControlEdge(write, modify);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {write->id(), modify->id()}};
  EXPECT_EQ(pairs, expected);
}
181
// The chain read -> modify -> write is expected to yield no incompatible pairs.
TEST(ResourceOperationSafetyAnalysisTest, ReadModifyWrite) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  // Creation order fixes node ids; keep it as read, modify, write.
  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");

  graph->AddControlEdge(read, modify);
  graph->AddControlEdge(modify, write);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));
  EXPECT_TRUE(pairs.empty());
}
197
// The chain write -> modify -> read is expected to yield three incompatible
// pairs, including the transitive (write, read). They are expected in the
// order listed in `expected`.
TEST(ResourceOperationSafetyAnalysisTest, WriteModifyRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  // Creation order fixes node ids (and hence the expected ordering below).
  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");

  graph->AddControlEdge(write, modify);
  graph->AddControlEdge(modify, read);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {modify->id(), read->id()},
      {write->id(), read->id()},
      {write->id(), modify->id()},
  };
  EXPECT_EQ(pairs, expected);
}
220
// The chain write -> read -> modify is expected to yield exactly two pairs:
// (write, read) and the transitive (write, modify). The read -> modify edge is
// not flagged.
TEST(ResourceOperationSafetyAnalysisTest, WriteReadModify) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  // Creation order fixes node ids (and hence the expected ordering below).
  Node* read = MakeRead(root, "R");
  Node* modify = MakeModify(root, "M");
  Node* write = MakeWrite(root, "W");

  graph->AddControlEdge(write, read);
  graph->AddControlEdge(read, modify);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {write->id(), read->id()},
      {write->id(), modify->id()},
  };
  EXPECT_EQ(pairs, expected);
}
241
CreateFunctionDefLibWithConstFunction(const string & name)242 FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) {
243 FunctionDefLibrary flib_def;
244 FunctionDef func = FunctionDefHelper::Create(
245 /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"},
246 /*attr_def*/
247 {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)},
248 /*ret_def=*/{{"out", "out:output:0"}});
249 *flib_def.add_function() = std::move(func);
250 return flib_def;
251 }
252
MakeCall(Graph * graph,const string & callee_name,const string & node_name,Status * status)253 Node* MakeCall(Graph* graph, const string& callee_name, const string& node_name,
254 Status* status) {
255 NodeDef call_node;
256 call_node.set_name(node_name);
257 call_node.set_op(callee_name);
258 return graph->AddNode(call_node, status);
259 }
260
// A function call ordered before a read is expected to yield exactly one
// incompatible pair: (call, read).
TEST(ResourceOperationSafetyAnalysisTest, CallRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();
  TF_ASSERT_OK(graph->AddFunctionLibrary(
      CreateFunctionDefLibWithConstFunction("Const_func")));

  Node* read = MakeRead(root, "R");

  Status call_status;
  Node* call = MakeCall(graph, "Const_func", "C", &call_status);
  TF_ASSERT_OK(call_status);

  graph->AddControlEdge(call, read);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {{call->id(), read->id()}};
  EXPECT_EQ(pairs, expected);
}
282
// A read ordered before a function call is expected to yield no incompatible
// pairs.
TEST(ResourceOperationSafetyAnalysisTest, ReadCall) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();
  TF_ASSERT_OK(graph->AddFunctionLibrary(
      CreateFunctionDefLibWithConstFunction("Const_func")));

  Node* read = MakeRead(root, "R");

  Status call_status;
  Node* call = MakeCall(graph, "Const_func", "C", &call_status);
  TF_ASSERT_OK(call_status);

  graph->AddControlEdge(read, call);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));
  EXPECT_TRUE(pairs.empty());
}
302
// A function call ordered before a write is expected to yield no incompatible
// pairs.
TEST(ResourceOperationSafetyAnalysisTest, CallWrite) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();
  TF_ASSERT_OK(graph->AddFunctionLibrary(
      CreateFunctionDefLibWithConstFunction("Const_func")));

  Node* write = MakeWrite(root, "W");

  Status call_status;
  Node* call = MakeCall(graph, "Const_func", "C", &call_status);
  TF_ASSERT_OK(call_status);

  graph->AddControlEdge(call, write);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));
  EXPECT_TRUE(pairs.empty());
}
322
// A write ordered before a function call is expected to yield exactly one
// incompatible pair: (write, call).
TEST(ResourceOperationSafetyAnalysisTest, WriteCall) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();
  TF_ASSERT_OK(graph->AddFunctionLibrary(
      CreateFunctionDefLibWithConstFunction("Const_func")));

  Node* write = MakeWrite(root, "W");

  Status call_status;
  Node* call = MakeCall(graph, "Const_func", "C", &call_status);
  TF_ASSERT_OK(call_status);

  graph->AddControlEdge(write, call);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {write->id(), call->id()}};
  EXPECT_EQ(pairs, expected);
}
344
// A SymbolicGradient of a library function, ordered before a read, is expected
// to yield exactly one incompatible pair: (symbolic_gradient, read).
TEST(ResourceOperationSafetyAnalysisTest, SymbolicGradientRead) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();
  TF_ASSERT_OK(graph->AddFunctionLibrary(
      CreateFunctionDefLibWithConstFunction("Const_func")));

  Node* read = MakeRead(root, "R");

  NameAttrList fn;
  fn.set_name("Const_func");
  const ops::SymbolicGradient gradient(root,
                                       /*input=*/{ops::Const(root, 1.0f)},
                                       /*Tout=*/{DT_FLOAT}, fn);
  Node* symbolic_gradient = gradient.output[0].node();

  graph->AddControlEdge(symbolic_gradient, read);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {symbolic_gradient->id(), read->id()}};
  EXPECT_EQ(pairs, expected);
}
371
// A write ordered before a SymbolicGradient of a library function is expected
// to yield exactly one incompatible pair: (write, symbolic_gradient).
TEST(ResourceOperationSafetyAnalysisTest, WriteSymbolicGradient) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();
  TF_ASSERT_OK(graph->AddFunctionLibrary(
      CreateFunctionDefLibWithConstFunction("Const_func")));

  Node* write = MakeWrite(root, "W");

  NameAttrList fn;
  fn.set_name("Const_func");
  const ops::SymbolicGradient gradient(root,
                                       /*input=*/{ops::Const(root, 1.0f)},
                                       /*Tout=*/{DT_FLOAT}, fn);
  Node* symbolic_gradient = gradient.output[0].node();

  graph->AddControlEdge(write, symbolic_gradient);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {write->id(), symbolic_gradient->id()}};
  EXPECT_EQ(pairs, expected);
}
398
// The linear chain is W0 -> N0 -> R0 -> W1 -> N1 -> R1, where N* are neutral
// constants. Every write is expected to pair with every later read, through
// the neutral nodes. The read R0 followed by write W1 is not flagged.
TEST(ResourceOperationSafetyAnalysisTest, ChainOfOps) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  // Creation order fixes node ids (and hence the expected ordering below).
  Node* write_0 = MakeWrite(root, "W0");
  Node* neutral_0 = MakeNeutral(root, "N0");
  Node* read_0 = MakeRead(root, "R0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral_1 = MakeNeutral(root, "N1");
  Node* read_1 = MakeRead(root, "R1");

  const std::vector<Node*> chain = {write_0, neutral_0, read_0,
                                    write_1, neutral_1, read_1};
  for (size_t i = 0; i + 1 < chain.size(); ++i) {
    graph->AddControlEdge(chain[i], chain[i + 1]);
  }

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {write_0->id(), read_0->id()},
      {write_0->id(), read_1->id()},
      {write_1->id(), read_1->id()},
  };
  EXPECT_EQ(pairs, expected);
}
427
// Two writes fan in to a neutral node, which fans out to two reads. Each
// write is expected to pair with each read: four pairs in total.
TEST(ResourceOperationSafetyAnalysisTest, DagOfOps) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  // Creation order fixes node ids (and hence the expected ordering below).
  Node* write_0 = MakeWrite(root, "W0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral = MakeNeutral(root, "N");
  Node* read_0 = MakeRead(root, "R0");
  Node* read_1 = MakeRead(root, "R1");

  for (Node* write : {write_0, write_1}) {
    graph->AddControlEdge(write, neutral);
  }
  for (Node* read : {read_0, read_1}) {
    graph->AddControlEdge(neutral, read);
  }

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {write_0->id(), read_0->id()},
      {write_0->id(), read_1->id()},
      {write_1->id(), read_0->id()},
      {write_1->id(), read_1->id()},
  };
  EXPECT_EQ(pairs, expected);
}
456
// This is the DagOfOps shape plus a direct W1 -> R1 edge, so (W1, R1) is
// reachable along two paths. It is still expected to appear only once: four
// pairs in total.
TEST(ResourceOperationSafetyAnalysisTest, DagOfOpsWithRepeatedPaths) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Graph* graph = root.graph();

  // Creation order fixes node ids (and hence the expected ordering below).
  Node* write_0 = MakeWrite(root, "W0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral = MakeNeutral(root, "N");
  Node* read_0 = MakeRead(root, "R0");
  Node* read_1 = MakeRead(root, "R1");

  graph->AddControlEdge(write_0, neutral);
  graph->AddControlEdge(write_1, neutral);
  graph->AddControlEdge(neutral, read_0);
  graph->AddControlEdge(neutral, read_1);
  // Second route from write_1 to read_1, bypassing the neutral node.
  graph->AddControlEdge(write_1, read_1);

  std::vector<std::pair<int, int>> pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(graph, &pairs));

  const std::vector<std::pair<int, int>> expected = {
      {write_0->id(), read_0->id()},
      {write_0->id(), read_1->id()},
      {write_1->id(), read_0->id()},
      {write_1->id(), read_1->id()},
  };
  EXPECT_EQ(pairs, expected);
}
486
// Builds a minimal loop frame "fr":
//   Enter -> Merge "iv" -> Switch "latch"; latch's true output feeds
//   NextIteration, which is rewired back into the Merge.
// Inside the body, control edges order iv -> write -> read -> NextIteration.
// Only the in-iteration (write, read) pair is expected. The back edge through
// NextIteration is not expected to add a (read, write) pair.
TEST(ResourceOperationSafetyAnalysisTest, Loop) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output init_value = ops::Placeholder(root.WithOpName("init"), DT_FLOAT);
  // Named distinctly from "init" above so the two placeholders do not request
  // the same node name. Creation order (and thus node ids) is unchanged.
  Output loop_cond = ops::Placeholder(root.WithOpName("loop_cond"), DT_BOOL);
  Output enter_value =
      ops::internal::Enter(root.WithOpName("enter"), init_value, "fr");
  // Both Merge inputs start as `enter_value`. Input 1 is rewired to the
  // NextIteration node below to close the loop.
  ops::Merge iv(root.WithOpName("iv"), {enter_value, enter_value});
  ops::Switch latch(root.WithOpName("latch"), iv.output, loop_cond);
  ops::internal::Exit loop_exit(root.WithOpName("exit"), iv.output);
  Output next_iteration =
      ops::NextIteration(root.WithOpName("next_iteration"), latch.output_true);
  TF_ASSERT_OK(
      root.graph()->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1));

  Node* write = MakeWrite(root, "W");
  Node* read = MakeRead(root, "R");

  root.graph()->AddControlEdge(iv.output.node(), write);
  root.graph()->AddControlEdge(write, read);
  root.graph()->AddControlEdge(read, next_iteration.node());

  std::vector<std::pair<int, int>> incompatible_pairs;
  TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));

  const std::vector<std::pair<int, int>> expected = {
      {write->id(), read->id()}};
  EXPECT_EQ(incompatible_pairs, expected);
}
517
IsResourceArgDef(const OpDef::ArgDef & arg_def)518 bool IsResourceArgDef(const OpDef::ArgDef& arg_def) {
519 return arg_def.type() == DT_RESOURCE;
520 }
521 } // namespace
522 } // namespace tensorflow
523