1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/training/queue_runner.h"
17
18 #include <string>
19 #include <vector>
20
21 #include "tensorflow/cc/framework/scope.h"
22 #include "tensorflow/cc/ops/standard_ops.h"
23 #include "tensorflow/cc/training/coordinator.h"
24 #include "tensorflow/core/framework/graph.pb.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/lib/core/error_codes.pb.h"
29 #include "tensorflow/core/lib/core/notification.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow/core/platform/env.h"
32 #include "tensorflow/core/platform/test.h"
33 #include "tensorflow/core/protobuf/queue_runner.pb.h"
34 #include "tensorflow/core/public/session.h"
35
36 namespace tensorflow {
37 namespace {
38
39 using error::Code;
40 using ops::Assign;
41 using ops::Const;
42 using ops::CountUpTo;
43 using ops::FIFOQueue;
44 using ops::QueueClose;
45 using ops::QueueDequeue;
46 using ops::QueueEnqueue;
47 using ops::RandomNormal;
48 using ops::Square;
49 using ops::Variable;
50
// Node names created by BuildSimpleGraph(). kSquareOpName is also reused as a
// node name in RunMetaDataTest's graph.
constexpr char kAssignOpName[] = "assign";
constexpr char kCountUpToOpName[] = "count";
constexpr char kSquareOpName[] = "square";
constexpr char kVarOpName[] = "var";

// Queue names. kQueueName0 / kQueueName1 name the two FIFO queues built by
// BuildDoubleQueueGraph().
constexpr char kQueueName[] = "unit_test";
constexpr char kQueueName0[] = "q0";
constexpr char kQueueName1[] = "q1";

// Per-queue op names in BuildDoubleQueueGraph(): suffix 0 targets q0 and
// suffix 1 targets q1.
constexpr char kEnqueueOp0[] = "enqueue0";
constexpr char kEnqueueOp1[] = "enqueue1";
constexpr char kDequeueOp0[] = "dequeue0";
constexpr char kDequeueOp1[] = "dequeue1";
constexpr char kCloseOp0[] = "close0";
constexpr char kCloseOp1[] = "close1";
constexpr char kCancelOp0[] = "cancel0";
constexpr char kCancelOp1[] = "cancel1";

// Names that match no node in any graph built in this file; tests use them to
// provoke NOT_FOUND errors.
constexpr char kIllegalOpName1[] = "would fail";
constexpr char kIllegalOpName2[] = "fail again";
68
BuildSimpleGraph()69 GraphDef BuildSimpleGraph() {
70 Scope root = Scope::NewRootScope();
71 auto init_value = Const(root, 0);
72 auto var = Variable(root.WithOpName(kVarOpName), TensorShape({}),
73 DataType::DT_INT32);
74 auto assign = Assign(root.WithOpName(kAssignOpName), var, init_value);
75 auto count = CountUpTo(root.WithOpName(kCountUpToOpName), var, 10);
76 Square(root.WithOpName(kSquareOpName), var); // NOLINT
77
78 GraphDef graph_def;
79 TF_EXPECT_OK(root.ToGraphDef(&graph_def));
80 return graph_def;
81 }
82
BuildQueueRunnerDef(const std::string & queue_name,const std::vector<std::string> & enqueue_ops,const std::string & close_op,const std::string & cancel_op,const std::vector<Code> & queue_closed_error_codes)83 QueueRunnerDef BuildQueueRunnerDef(
84 const std::string& queue_name, const std::vector<std::string>& enqueue_ops,
85 const std::string& close_op, const std::string& cancel_op,
86 const std::vector<Code>& queue_closed_error_codes) {
87 QueueRunnerDef queue_runner_def;
88 *queue_runner_def.mutable_queue_name() = queue_name;
89 for (const std::string& enqueue_op : enqueue_ops) {
90 *queue_runner_def.mutable_enqueue_op_name()->Add() = enqueue_op;
91 }
92 *queue_runner_def.mutable_close_op_name() = close_op;
93 *queue_runner_def.mutable_cancel_op_name() = cancel_op;
94 for (const auto& error_code : queue_closed_error_codes) {
95 *queue_runner_def.mutable_queue_closed_exception_types()->Add() =
96 error_code;
97 }
98 return queue_runner_def;
99 }
100
BuildSessionAndInitVariable(const GraphDef & graph_def)101 std::unique_ptr<Session> BuildSessionAndInitVariable(
102 const GraphDef& graph_def) {
103 SessionOptions options;
104 std::unique_ptr<Session> session(NewSession(options));
105 TF_CHECK_OK(session->Create(graph_def));
106
107 TF_CHECK_OK(session->Run({}, {}, {kAssignOpName}, nullptr));
108 return session;
109 }
110
// Drives the CountUpTo op from BuildSimpleGraph() as the runner's only enqueue
// op until the runner finishes. Then checks that the variable stopped at the
// limit of 10 by fetching its square.
TEST(QueueRunnerTest, BasicTest) {
  GraphDef graph_def = BuildSimpleGraph();
  auto session = BuildSessionAndInitVariable(graph_def);

  QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
      kQueueName, {kCountUpToOpName}, kSquareOpName, "", {});

  std::unique_ptr<QueueRunner> qr;
  // Fatal: `qr` is used below and must not be touched if New() failed.
  TF_ASSERT_OK(QueueRunner::New(queue_runner_def, &qr));
  TF_CHECK_OK(qr->Start(session.get()));
  TF_EXPECT_OK(qr->Join());

  std::vector<Tensor> outputs;
  // Fatal: `outputs[0]` is only populated by a successful Run().
  TF_ASSERT_OK(session->Run({}, {kSquareOpName}, {}, &outputs));
  const int square_value = outputs[0].scalar<int>()();
  EXPECT_EQ(square_value, 100);
}
128
// Runs two copies of the CountUpTo enqueue op. OUT_OF_RANGE and CANCELLED are
// listed as queue-closed codes, and Join() is expected to return OK. The
// variable must still end at 10 (square == 100).
TEST(QueueRunnerTest, QueueClosedCode) {
  GraphDef graph_def = BuildSimpleGraph();
  auto session = BuildSessionAndInitVariable(graph_def);

  // Start two queues so that multiple threads are in Run.
  QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
      kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, "",
      {Code::OUT_OF_RANGE, Code::CANCELLED});

  std::unique_ptr<QueueRunner> qr;
  // Fatal: `qr` is used below and must not be touched if New() failed.
  TF_ASSERT_OK(QueueRunner::New(queue_runner_def, &qr));
  TF_EXPECT_OK(qr->Start(session.get()));
  TF_EXPECT_OK(qr->Join());

  std::vector<Tensor> outputs;
  // Fatal: `outputs[0]` is only populated by a successful Run().
  TF_ASSERT_OK(session->Run({}, {kSquareOpName}, {}, &outputs));
  const int square_value = outputs[0].scalar<int>()();
  EXPECT_EQ(square_value, 100);
}
148
// The close op name (kIllegalOpName1) matches no node in the graph, so Join()
// is expected to report NOT_FOUND.
TEST(QueueRunnerTest, QueueCloseFails) {
  GraphDef graph_def = BuildSimpleGraph();
  auto session = BuildSessionAndInitVariable(graph_def);

  QueueRunnerDef queue_runner_def =
      BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kIllegalOpName1, "",
                          {Code::OUT_OF_RANGE});

  std::unique_ptr<QueueRunner> qr;
  // Fatal: `qr` is used below and must not be touched if New() failed.
  TF_ASSERT_OK(QueueRunner::New(queue_runner_def, &qr));
  TF_EXPECT_OK(qr->Start(session.get()));
  const Status status = qr->Join();
  EXPECT_EQ(status.code(), Code::NOT_FOUND) << status;
}
163
// Both enqueue op names match no node in the graph. Start() is expected to
// succeed, and the resulting NOT_FOUND error is expected to surface from
// Join().
TEST(QueueRunnerTest, CatchErrorInJoin) {
  GraphDef graph_def = BuildSimpleGraph();
  auto session = BuildSessionAndInitVariable(graph_def);

  QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
      kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, "", {});

  std::unique_ptr<QueueRunner> qr;
  // Fatal: `qr` is used below and must not be touched if New() failed.
  TF_ASSERT_OK(QueueRunner::New(queue_runner_def, &qr));
  TF_EXPECT_OK(qr->Start(session.get()));
  const Status status = qr->Join();
  EXPECT_EQ(status.code(), Code::NOT_FOUND) << status;
}
176
BuildDoubleQueueGraph()177 GraphDef BuildDoubleQueueGraph() {
178 Scope root = Scope::NewRootScope();
179 auto q0 = FIFOQueue(root.WithOpName(kQueueName0), {DataType::DT_INT32});
180 auto ten = Const(root, 10);
181 auto enqueue0 = QueueEnqueue(root.WithOpName(kEnqueueOp0), q0, {ten});
182 auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0);
183 auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0,
184 QueueClose::CancelPendingEnqueues(true));
185 auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32},
186 FIFOQueue::Capacity(3));
187 auto dequeue0 =
188 QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_INT32});
189 auto enqueue1 = QueueEnqueue(root.WithOpName(kEnqueueOp1), q1, {dequeue0[0]});
190 auto dequeue1 =
191 QueueDequeue(root.WithOpName(kDequeueOp1), q1, {DataType::DT_INT32});
192 auto close1 = QueueClose(root.WithOpName(kCloseOp1), q1);
193 auto cancel1 = QueueClose(root.WithOpName(kCancelOp1), q1,
194 QueueClose::CancelPendingEnqueues(true));
195
196 GraphDef graph_def;
197 TF_EXPECT_OK(root.ToGraphDef(&graph_def));
198 return graph_def;
199 }
200
// The runner repeatedly runs kEnqueueOp1, which moves one element from q0 to
// q1. The test feeds q0 twice and then closes it. Afterwards q1 must yield
// exactly the two fed values, followed by OUT_OF_RANGE.
TEST(QueueRunnerTest, RealEnqueueDequeue) {
  auto graph_def = BuildDoubleQueueGraph();

  SessionOptions options;
  std::unique_ptr<Session> session(NewSession(options));
  TF_CHECK_OK(session->Create(graph_def));

  QueueRunnerDef queue_runner_def =
      BuildQueueRunnerDef(kQueueName, {kEnqueueOp1}, kCloseOp1, "", {});
  std::unique_ptr<QueueRunner> qr;
  // Fatal: `qr` is used below and must not be touched if New() failed.
  TF_ASSERT_OK(QueueRunner::New(queue_runner_def, &qr));
  TF_CHECK_OK(qr->Start(session.get()));

  TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
  TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
  // Closing queue 0 would also close the queue runner.
  TF_EXPECT_OK(session->Run({}, {}, {kCloseOp0}, nullptr));

  TF_EXPECT_OK(qr->Join());

  // Fatal: each `[0]` access below needs a successful Run().
  std::vector<Tensor> dq1;
  TF_ASSERT_OK(session->Run({}, {kDequeueOp1}, {}, &dq1));
  EXPECT_EQ(dq1[0].scalar<int>()(), 10);
  std::vector<Tensor> dq2;
  TF_ASSERT_OK(session->Run({}, {kDequeueOp1}, {}, &dq2));
  EXPECT_EQ(dq2[0].scalar<int>()(), 10);

  EXPECT_EQ(session->Run({}, {kDequeueOp1}, {}, nullptr).code(),
            Code::OUT_OF_RANGE);
}
230
JoinThread(QueueRunner * queue_runner,bool * join_succeeded,Notification * join_done)231 void JoinThread(QueueRunner* queue_runner, bool* join_succeeded,
232 Notification* join_done) {
233 EXPECT_EQ(queue_runner->Join().code(), Code::CANCELLED);
234 *join_succeeded = true;
235 join_done->Notify();
236 }
237
// Verifies that QueueRunner::Join() stays blocked while the runner is waiting
// on the pipeline, and returns (with CANCELLED, checked in JoinThread) only
// after Session::Close().
TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) {
  auto graph_def = BuildDoubleQueueGraph();

  SessionOptions options;
  std::unique_ptr<Session> session(NewSession(options));
  TF_CHECK_OK(session->Create(graph_def));

  QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
      kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
  std::unique_ptr<QueueRunner> qr;
  // Fatal: `qr` is used below and must not be touched if New() failed.
  TF_ASSERT_OK(QueueRunner::New(queue_runner_def, &qr));
  TF_CHECK_OK(qr->Start(session.get()));

  TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));

  std::vector<Tensor> dq1;
  TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq1));
  // Non-fatal and guarded: the runner is still active, so the test must reach
  // session->Close() below even if Run() failed.
  if (!dq1.empty()) {
    EXPECT_EQ(dq1[0].scalar<int>()(), 10);
  }

  // The expected behavior is the QueueRunner::Join() call is blocked until
  // Session::Close() is called.
  bool join_succeeded = false;
  Notification join_done;
  Env::Default()->SchedClosure(
      std::bind(&JoinThread, qr.get(), &join_succeeded, &join_done));
  // From here until WaitForNotification() the scheduled closure holds pointers
  // to `qr`, `join_succeeded` and `join_done`. Only non-fatal EXPECT_* checks
  // may be used, because an early return would leave those pointers dangling.

  Env::Default()->SleepForMicroseconds(10000000);
  // Query the Notification rather than reading `join_succeeded`: if Join()
  // returned too early, the join thread could be writing that bool
  // concurrently (a data race).
  EXPECT_FALSE(join_done.HasBeenNotified());

  // Closing the session is required to cancel pending enqueue nodes.
  TF_EXPECT_OK(session->Close());

  join_done.WaitForNotification();
  // Safe to read now: JoinThread writes the flag before calling Notify().
  EXPECT_TRUE(join_succeeded);
}
273
// A QueueRunnerDef that lists no enqueue ops must be rejected by
// QueueRunner::New() with INVALID_ARGUMENT.
TEST(QueueRunnerTest, EmptyEnqueueOps) {
  const QueueRunnerDef queue_runner_def =
      BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {});

  std::unique_ptr<QueueRunner> qr;
  const Status status = QueueRunner::New(queue_runner_def, &qr);
  EXPECT_EQ(status.code(), Code::INVALID_ARGUMENT);
}
282
// Start() is called with a very small wait (second argument 1; presumably
// milliseconds, confirm against QueueRunner::Start). Because the runner's
// enqueue op cannot make progress, Start() is expected to return
// DEADLINE_EXCEEDED. The session is closed afterwards to release the runner.
TEST(QueueRunnerTest, StartTimeout) {
  GraphDef graph_def = BuildDoubleQueueGraph();
  SessionOptions options;
  std::unique_ptr<Session> session(NewSession(options));
  TF_CHECK_OK(session->Create(graph_def));

  QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
      kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});

  std::unique_ptr<QueueRunner> qr;
  // Fatal: `qr` is used below and must not be touched if New() failed.
  TF_ASSERT_OK(QueueRunner::New(queue_runner_def, &qr));
  // This will timeout since queue0 is not fed and queue1 is fetching data from
  // queue0.
  EXPECT_EQ(qr->Start(session.get(), 1).code(), Code::DEADLINE_EXCEEDED);
  TF_EXPECT_OK(session->Close());
}
299
// Two runners share a Coordinator: qr0 feeds q0, and qr1 moves elements from
// q0 to q1. After one element is read from q1, a stop request through the
// coordinator must let Coordinator::Join() return OK.
TEST(QueueRunnerTest, TestCoordinatorStop) {
  auto graph_def = BuildDoubleQueueGraph();
  SessionOptions options;
  std::unique_ptr<Session> session(NewSession(options));
  TF_CHECK_OK(session->Create(graph_def));

  QueueRunnerDef queue_runner0 =
      BuildQueueRunnerDef(kQueueName0, {kEnqueueOp0}, kCloseOp0, kCancelOp0,
                          {Code::OUT_OF_RANGE, Code::CANCELLED});
  QueueRunnerDef queue_runner1 =
      BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1,
                          {Code::OUT_OF_RANGE, Code::CANCELLED});

  Coordinator coord;
  // Construct both runners before starting either, so a fatal failure in
  // New() returns while no runner is running yet.
  std::unique_ptr<QueueRunner> qr0;
  TF_ASSERT_OK(QueueRunner::New(queue_runner0, &coord, &qr0));
  std::unique_ptr<QueueRunner> qr1;
  TF_ASSERT_OK(QueueRunner::New(queue_runner1, &coord, &qr1));
  TF_CHECK_OK(qr0->Start(session.get()));
  TF_CHECK_OK(qr1->Start(session.get()));

  TF_EXPECT_OK(coord.RegisterRunner(std::move(qr0)));
  TF_EXPECT_OK(coord.RegisterRunner(std::move(qr1)));

  std::vector<Tensor> dq;
  TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq));
  // Non-fatal and guarded: both runners are still active, so the test must
  // reach RequestStop()/Join() below even if Run() failed.
  if (!dq.empty()) {
    EXPECT_EQ(dq[0].scalar<int>()(), 10);
  }

  TF_EXPECT_OK(coord.RequestStop());
  TF_EXPECT_OK(coord.Join());
}
331
// Both enqueue op names match no node in the graph, so the runner hits an
// error. The registered error callback must be invoked, and Join() must
// return a non-OK status. `error_caught` is read only after Join() returns.
TEST(QueueRunnerTest, CallbackCalledOnError) {
  GraphDef graph_def = BuildSimpleGraph();
  auto session = BuildSessionAndInitVariable(graph_def);

  QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
      kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, "", {});

  std::unique_ptr<QueueRunner> qr;
  // Fatal: `qr` is used below and must not be touched if New() failed.
  TF_ASSERT_OK(QueueRunner::New(queue_runner_def, &qr));
  bool error_caught = false;
  qr->AddErrorCallback([&error_caught](const Status&) { error_caught = true; });
  TF_EXPECT_OK(qr->Start(session.get()));
  EXPECT_FALSE(qr->Join().ok());
  EXPECT_TRUE(error_caught);
}
347
// Starts a runner with cost-graph collection on a graph that enqueues squared
// RandomNormal values into a float queue. Checks that a non-empty cost graph
// can be exported.
TEST(QueueRunnerTest, RunMetaDataTest) {
  Scope root = Scope::NewRootScope();
  auto q0 = FIFOQueue(root.WithOpName(kQueueName), {DataType::DT_FLOAT});
  Output rnd = RandomNormal(root.WithOpName("rnd"), {1, 1}, DataType::DT_FLOAT);
  Output square = Square(root.WithOpName(kSquareOpName), rnd);
  auto enqueue0 = QueueEnqueue(root.WithOpName(kEnqueueOp0), q0, {square});
  auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0);
  auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0,
                            QueueClose::CancelPendingEnqueues(true));
  auto dequeue0 =
      QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_FLOAT});

  GraphDef graph_def;
  // Fatal: a broken GraphDef would only make session->Create() abort below.
  TF_ASSERT_OK(root.ToGraphDef(&graph_def));
  for (auto& node : *graph_def.mutable_node()) {
    node.set_device("/cpu:0");
  }
  SessionOptions sess_options;
  sess_options.config.mutable_graph_options()->set_build_cost_model(1);
  std::unique_ptr<Session> session(NewSession(sess_options));

  TF_CHECK_OK(session->Create(graph_def));

  QueueRunnerDef queue_runner_def =
      BuildQueueRunnerDef(kQueueName, {kEnqueueOp0}, kCloseOp0, kCancelOp0, {});
  std::unique_ptr<QueueRunner> qr;
  // Fatal: `qr` is used below and must not be touched if New() failed.
  TF_ASSERT_OK(QueueRunner::New(queue_runner_def, &qr));
  RunOptions run_options;
  TF_CHECK_OK(qr->StartAndCollectCostGraph(session.get(), run_options));

  // From here until Stop() the runner is active. Keep every check non-fatal so
  // that Stop() is always reached.

  // Make sure there was at least one element enqueued in q0: this prevents a
  // race condition where we close the queue before it was populated.
  std::vector<Tensor> dq0;
  TF_EXPECT_OK(session->Run({}, {kDequeueOp0}, {}, &dq0));
  // Second call to run dequeue op is to make sure the cost graph has been
  // stored.
  TF_EXPECT_OK(session->Run({}, {kDequeueOp0}, {}, &dq0));

  CostGraphDef cost_graph;
  TF_EXPECT_OK(qr->ExportCostGraph(&cost_graph));
  EXPECT_GT(cost_graph.node_size(), 0);

  // NOTE(review): the Status returned by Stop() is not checked here; confirm
  // which code the cancelled enqueue produces before asserting on it.
  qr->Stop(session.get());
}
392
// A runner started with plain Start(), rather than StartAndCollectCostGraph(),
// has no cost graph to export. ExportCostGraph() must therefore fail with
// FAILED_PRECONDITION.
TEST(QueueRunnerTest, NoRunMetaDataTest) {
  GraphDef graph_def = BuildSimpleGraph();
  auto session = BuildSessionAndInitVariable(graph_def);

  QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
      kQueueName, {kCountUpToOpName}, kSquareOpName, "", {});
  std::unique_ptr<QueueRunner> qr;
  // Fatal: `qr` is used below and must not be touched if New() failed.
  TF_ASSERT_OK(QueueRunner::New(queue_runner_def, &qr));
  TF_CHECK_OK(qr->Start(session.get()));

  TF_EXPECT_OK(qr->Join());
  CostGraphDef cost_graph;
  EXPECT_EQ(qr->ExportCostGraph(&cost_graph).code(),
            error::FAILED_PRECONDITION);
}
408
409 } // namespace
410 } // namespace tensorflow
411