1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/rendezvous.h"
17
18 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
19 #include "tensorflow/core/framework/cancellation.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_shape.h"
22 #include "tensorflow/core/framework/tensor_types.h"
23 #include "tensorflow/core/framework/types.pb.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/notification.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/lib/core/threadpool.h"
28 #include "tensorflow/core/lib/random/simple_philox.h"
29 #include "tensorflow/core/lib/strings/strcat.h"
30 #include "tensorflow/core/platform/env.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/notification.h"
34 #include "tensorflow/core/platform/test.h"
35 #include "tensorflow/core/platform/test_benchmark.h"
36 #include "tensorflow/core/platform/types.h"
37
38 namespace tensorflow {
39 namespace {
40
// Checks the textual layout produced by Rendezvous::CreateKey(), that
// Rendezvous::ParseKey() recovers the individual fields from that text, and
// that ParseKey() reports an error for several malformed inputs.
TEST(RendezvousTest, Key) {
  const string full_key = Rendezvous::CreateKey(
      "/job:mnist/replica:1/task:2/CPU:0", 7890,
      "/job:mnist/replica:1/task:2/device:GPU:0", "var0", FrameAndIter(0, 0));
  EXPECT_EQ(full_key,
            "/job:mnist/replica:1/task:2/CPU:0;"
            "0000000000001ed2;"  // 7890 = 0x1ed2
            "/job:mnist/replica:1/task:2/device:GPU:0;"
            "var0;"
            "0:0");

  // Round trip: parsing the generated key must succeed and expose the
  // source/destination devices and the source incarnation.
  Rendezvous::ParsedKey parsed_key;
  TF_EXPECT_OK(Rendezvous::ParseKey(full_key, &parsed_key));
  EXPECT_EQ(parsed_key.src_device, "/job:mnist/replica:1/task:2/CPU:0");
  EXPECT_EQ(parsed_key.src_incarnation, 7890);
  EXPECT_EQ(parsed_key.src.type, "CPU");
  EXPECT_EQ(parsed_key.dst_device, "/job:mnist/replica:1/task:2/device:GPU:0");
  EXPECT_EQ(parsed_key.dst.type, "GPU");

  // Malformed inputs: each of these must be rejected.
  const Status three_parts_status =
      Rendezvous::ParseKey("foo;bar;baz", &parsed_key);
  EXPECT_FALSE(three_parts_status.ok());

  const Status devices_only_status =
      Rendezvous::ParseKey("/job:mnist/replica:1/task:2/CPU:0;"
                           "/job:mnist/replica:1/task:2/device:GPU:0;",
                           &parsed_key);
  EXPECT_FALSE(devices_only_status.ok());

  const Status doubled_key_status = Rendezvous::ParseKey(
      strings::StrCat(full_key, ";", full_key), &parsed_key);
  EXPECT_FALSE(doubled_key_status.ok());
}
67
68 class LocalRendezvousTest : public ::testing::Test {
69 public:
LocalRendezvousTest()70 LocalRendezvousTest() : threads_(Env::Default(), "test", 16) {
71 rendez_ = NewLocalRendezvous();
72 }
73
~LocalRendezvousTest()74 ~LocalRendezvousTest() override { rendez_->Unref(); }
75
SchedClosure(std::function<void ()> fn)76 void SchedClosure(std::function<void()> fn) {
77 threads_.Schedule(std::move(fn));
78 }
79
80 Rendezvous* rendez_;
81
82 private:
83 thread::ThreadPool threads_;
84 };
85
86 // string -> Tensor<string>
V(const string & content)87 Tensor V(const string& content) {
88 Tensor tensor(DT_STRING, TensorShape({}));
89 tensor.scalar<tstring>()() = content;
90 return tensor;
91 }
92
93 // Tensor<string> -> string
V(const Tensor & tensor)94 string V(const Tensor& tensor) {
95 CHECK_EQ(tensor.dtype(), DT_STRING);
96 CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
97 return tensor.scalar<tstring>()();
98 }
99
MakeKey(const string & name)100 Rendezvous::ParsedKey MakeKey(const string& name) {
101 string s = Rendezvous::CreateKey("/job:mnist/replica:1/task:2/CPU:0", 7890,
102 "/job:mnist/replica:1/task:2/device:GPU:0",
103 name, FrameAndIter(0, 0));
104 Rendezvous::ParsedKey k;
105 TF_EXPECT_OK(Rendezvous::ParseKey(s, &k));
106 return k;
107 }
108
KeyFoo()109 const Rendezvous::ParsedKey& KeyFoo() {
110 static auto* key = new Rendezvous::ParsedKey(MakeKey("foo"));
111 return *key;
112 }
113
KeyBar()114 const Rendezvous::ParsedKey& KeyBar() {
115 static auto* key = new Rendezvous::ParsedKey(MakeKey("bar"));
116 return *key;
117 }
118
// Send() followed by Recv() on the same key, both from the test thread: the
// received tensor must carry the value that was sent.
TEST_F(LocalRendezvousTest, SendRecv) {
  Rendezvous::Args args;
  const Tensor sent = V("hello");
  TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, sent, /*is_dead=*/false));

  Tensor received(DT_STRING);
  bool received_dead = false;
  TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &received, &received_dead));
  EXPECT_EQ("hello", V(received));
}
127
// The test thread calls the blocking Recv() while a pool thread performs the
// matching Send() after a 10 ms delay; Recv() must return the sent value.
TEST_F(LocalRendezvousTest, RecvSend) {
  SchedClosure([this]() {
    // Delay the Send so that the Recv below is most likely issued first.
    Env::Default()->SleepForMicroseconds(10000);
    Rendezvous::Args send_args;
    const Tensor sent = V("hello");
    TF_ASSERT_OK(rendez_->Send(KeyFoo(), send_args, sent, /*is_dead=*/false));
  });

  Rendezvous::Args recv_args;
  Tensor received(DT_STRING);
  bool received_dead = false;
  TF_ASSERT_OK(rendez_->Recv(KeyFoo(), recv_args, &received, &received_dead));
  EXPECT_EQ("hello", V(received));
}
140
// A pool thread receives on "foo" and echoes the tensor back on "bar"; the
// test thread sends on "foo" after one second and expects to read the same
// value from "bar".
TEST_F(LocalRendezvousTest, PingPong) {
  SchedClosure([this]() {
    Rendezvous::Args echo_args;
    Tensor relayed(DT_STRING);
    bool relayed_dead = false;
    TF_ASSERT_OK(rendez_->Recv(KeyFoo(), echo_args, &relayed, &relayed_dead));
    TF_ASSERT_OK(rendez_->Send(KeyBar(), echo_args, relayed, relayed_dead));
  });

  // Give the echo closure time to start (and most likely block in Recv).
  Env::Default()->SleepForMicroseconds(1000000);

  Rendezvous::Args args;
  Tensor reply(DT_STRING);
  bool reply_dead = false;
  TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("secret msg"), reply_dead));
  TF_ASSERT_OK(rendez_->Recv(KeyBar(), args, &reply, &reply_dead));
  EXPECT_EQ("secret msg", V(reply));
}
157
// Recv() issued with a cancellation manager that has already been cancelled
// must fail with a Cancelled status and the message "RecvAsync is cancelled.".
TEST_F(LocalRendezvousTest, CancelBeforeRecv) {
  // Automatic storage instead of new/delete: the manager is released on every
  // exit path and no manual cleanup is needed.
  CancellationManager cm;
  Tensor val(DT_STRING);
  bool is_dead = false;
  Rendezvous::Args args;
  args.cancellation_manager = &cm;
  cm.StartCancel();
  const Status s = rendez_->Recv(KeyFoo(), args, &val, &is_dead);
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(errors::IsCancelled(s));
  EXPECT_EQ("RecvAsync is cancelled.", s.error_message());
}
171
// The test thread calls the blocking Recv() on a key that is never sent while
// a pool thread starts cancellation after a 10 ms delay. Recv() must fail with
// a Cancelled status and the message "RecvAsync is cancelled.".
TEST_F(LocalRendezvousTest, CancelAfterRecv) {
  // Both objects live on the stack and are captured by reference; the
  // WaitForNotification() at the end of the test guarantees that the closure
  // has finished with them before they are destroyed.
  CancellationManager cm;
  Notification n;
  SchedClosure([&cm, &n]() {
    Env::Default()->SleepForMicroseconds(10000);
    cm.StartCancel();
    n.Notify();
  });
  Tensor val(DT_STRING);
  bool is_dead = false;
  Rendezvous::Args args;
  args.cancellation_manager = &cm;
  const Status s = rendez_->Recv(KeyFoo(), args, &val, &is_dead);
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(errors::IsCancelled(s));
  EXPECT_EQ("RecvAsync is cancelled.", s.error_message());
  n.WaitForNotification();
}
191
// A pool thread sends the value and only afterwards starts cancellation. The
// Recv() on the test thread, which carries the cancellation manager, must
// still succeed and return the sent value.
TEST_F(LocalRendezvousTest, CancelEmptyQueue) {
  CancellationManager cm;
  Notification n;
  SchedClosure([this, &cm, &n]() {
    Env::Default()->SleepForMicroseconds(10000);
    Rendezvous::Args args;
    // Non-fatal check: even if the Send fails we must still cancel and notify,
    // otherwise the test thread would block forever.
    TF_EXPECT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
    cm.StartCancel();
    n.Notify();
  });
  Tensor val(DT_STRING);
  bool is_dead = false;
  Rendezvous::Args args;
  args.cancellation_manager = &cm;
  const Status recv_status = rendez_->Recv(KeyFoo(), args, &val, &is_dead);
  // Wait for the closure before any assertion can return from the test: the
  // closure holds references to `cm` and `n`, which live on this stack frame.
  n.WaitForNotification();
  TF_ASSERT_OK(recv_status);
  EXPECT_EQ("hello", V(val));
}
211
// Registers four RecvAsync() calls on the same key, alternating between plain
// arguments and arguments that carry a cancellation manager. A pool thread
// then starts cancellation and performs two Sends. The two plain receives must
// complete successfully and the two cancellable ones must fail.
TEST_F(LocalRendezvousTest, CancelMultiple) {
  // Automatic storage: released on every exit path, including an early return
  // from TF_ASSERT_OK below. All notifications are awaited before any such
  // return, and the closure does not touch `cm` after StartCancel() returns.
  CancellationManager cm;
  SchedClosure([this, &cm]() {
    Env::Default()->SleepForMicroseconds(10000);
    Rendezvous::Args args;
    cm.StartCancel();
    TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
    TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
  });

  Rendezvous::Args args;
  Rendezvous::Args args_with_cancellation;
  args_with_cancellation.cancellation_manager = &cm;

  constexpr int kNumRecvs = 4;
  Notification done[kNumRecvs];
  Status statuses[kNumRecvs];

  // Even indices receive with plain args, odd indices with the cancellable
  // args; each callback records its status and signals its notification.
  for (int i = 0; i < kNumRecvs; ++i) {
    const bool cancellable = (i % 2 == 1);
    Notification* notification = &done[i];
    Status* status = &statuses[i];
    rendez_->RecvAsync(
        KeyFoo(), cancellable ? args_with_cancellation : args,
        [notification, status](const Status& s,
                               const Rendezvous::Args& /*send_args*/,
                               const Rendezvous::Args& /*recv_args*/,
                               const Tensor& /*v*/, const bool /*dead*/) {
          status->Update(s);
          notification->Notify();
        });
  }

  for (Notification& notification : done) {
    notification.WaitForNotification();
  }

  TF_ASSERT_OK(statuses[0]);
  TF_ASSERT_OK(statuses[2]);
  EXPECT_FALSE(statuses[1].ok());
  EXPECT_FALSE(statuses[3].ok());
}
277
278 // A simple structure that behaves a bit like a blocking counter. The
279 // user that decrements counter to 0 does done.Notify(), and the main
280 // thread waits for done to be notified.
281 struct BlockingState {
282 mutex lock;
283 int counter = 0;
284 Notification done;
285 };
286
// Schedules N senders and N receivers, each on its own key and each delayed by
// a random 100-1099 microseconds, and waits until every receive callback has
// run. Each callback checks that it got the value sent for its key.
TEST_F(LocalRendezvousTest, RandomSendRecv) {
  // We are scheduling 2*N closures in the this->threads_, which is
  // configured with only 16 threads. Furthermore, because the
  // threadpool may execute the closures in an arbitrary order, we
  // must use RecvAsync below. Otherwise, blocking Recv() may run
  // before all the Send() and deadlock.
  static const int N = 100;
  random::PhiloxRandom philox(testing::RandomSeed(), 17);
  random::SimplePhilox rnd(&philox);
  BlockingState state;
  state.counter = N;
  for (int i = 0; i < N; ++i) {
    int micros = 100 + rnd.Uniform(1000);
    SchedClosure([this, i, micros]() {
      Env::Default()->SleepForMicroseconds(micros);
      Rendezvous::Args args;
      TF_ASSERT_OK(rendez_->Send(MakeKey(strings::StrCat(i)), args,
                                 V(strings::StrCat(i)), false));
    });
    // The callback only needs `state` and `i`; it does not use the fixture.
    auto recv_done = [&state, i](const Status& status,
                                 const Rendezvous::Args& /*sender_args*/,
                                 const Rendezvous::Args& /*recver_args*/,
                                 const Tensor& val, const bool /*val_dead*/) {
      // Report a failed receive as a test failure. V() CHECK-fails on a
      // tensor that is not a DT_STRING scalar, so only inspect `val` when
      // the receive succeeded.
      TF_EXPECT_OK(status);
      if (status.ok()) {
        EXPECT_EQ(strings::StrCat(i), V(val));
      }
      bool done = false;
      {
        mutex_lock l(state.lock);
        state.counter--;
        if (state.counter == 0) {
          done = true;
        }
      }
      // Notify outside the lock, and only from the callback that brought the
      // counter down to zero.
      if (done) {
        state.done.Notify();
      }
    };
    micros = 100 + rnd.Uniform(1000);
    SchedClosure([this, i, micros, recv_done]() {
      Env::Default()->SleepForMicroseconds(micros);
      rendez_->RecvAsync(MakeKey(strings::StrCat(i)), Rendezvous::Args(),
                         recv_done);
    });
  }

  state.done.WaitForNotification();
}
333
RandomSleep()334 void RandomSleep() {
335 if (std::rand() % 10 == 0) {
336 Env::Default()->SleepForMicroseconds(1000);
337 }
338 }
339
// A pool thread performs N Sends on one key while the test thread performs N
// blocking Recvs on the same key; both sides occasionally sleep between
// operations. Every Send and every Recv must succeed.
TEST_F(LocalRendezvousTest, MultiSends) {
  static const int N = 100;
  const auto& key_foo = KeyFoo();
  Rendezvous::Args args;
  // Explicit captures: `this` for rendez_, the key by reference (KeyFoo()
  // returns a reference to an object that is never deleted) and a private
  // copy of `args` for the closure.
  SchedClosure([this, &key_foo, args]() {
    for (int i = 0; i < N; ++i) {
      TF_ASSERT_OK(rendez_->Send(key_foo, args, V(strings::StrCat(i)), false));
      RandomSleep();
    }
  });
  Tensor val;
  bool val_dead = false;  // Initialized so it is never read indeterminate.
  for (int i = 0; i < N; ++i) {
    TF_ASSERT_OK(rendez_->Recv(key_foo, args, &val, &val_dead));
    RandomSleep();
  }
}
357
// A pool thread aborts the rendezvous while the test thread calls Recv() on a
// key that is never sent. No ordering between the two is enforced; Recv() must
// return an Aborted status.
TEST_F(LocalRendezvousTest, RecvAbort) {
  // Extra reference owned by the closure, which releases it when done.
  rendez_->Ref();
  SchedClosure([this]() {
    const Status abort_status = errors::Aborted("");
    rendez_->StartAbort(abort_status);  // abort
    rendez_->Unref();
  });

  Rendezvous::Args args;
  Tensor received(DT_STRING);
  bool received_dead = false;
  const Status recv_status =
      rendez_->Recv(KeyFoo(), args, &received, &received_dead);
  EXPECT_TRUE(errors::IsAborted(recv_status));
}
370
371 // Similar to RecvAbort. But this test case ensures the main thread
372 // Recv() call happens after StartAbort().
TEST_F(LocalRendezvousTest,RecvSleepAbort)373 TEST_F(LocalRendezvousTest, RecvSleepAbort) {
374 rendez_->Ref();
375 SchedClosure([this]() {
376 Env::Default()->SleepForMicroseconds(1000000);
377 rendez_->StartAbort(errors::Aborted("")); // abort
378 rendez_->Unref();
379 });
380 Tensor val(DT_STRING);
381 bool val_dead = false;
382 Rendezvous::Args args;
383 Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead);
384 EXPECT_TRUE(errors::IsAborted(status));
385 }
386
// Once StartAbort() has been called, both Send() and Recv() must return an
// Aborted status.
TEST_F(LocalRendezvousTest, AbortThenRecvOrSend) {
  rendez_->StartAbort(errors::Aborted(""));

  Rendezvous::Args args;
  Tensor tensor(DT_STRING);
  bool tensor_dead = false;

  const Status send_status =
      rendez_->Send(KeyFoo(), args, tensor, tensor_dead);
  EXPECT_TRUE(errors::IsAborted(send_status));

  const Status recv_status =
      rendez_->Recv(KeyFoo(), args, &tensor, &tensor_dead);
  EXPECT_TRUE(errors::IsAborted(recv_status));
}
396
397 class DummyDeviceContext : public DeviceContext {
398 public:
DummyDeviceContext(int stream_id)399 explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {}
~DummyDeviceContext()400 ~DummyDeviceContext() override {}
stream_id() const401 int stream_id() const { return stream_id_; }
402
CopyTensorInSameDevice(const Tensor * input_tensor,Device * device,Tensor * output_tensor,StatusCallback done) const403 void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
404 Tensor* output_tensor,
405 StatusCallback done) const override {
406 done(OkStatus());
407 }
408
409 private:
410 const int stream_id_;
411 };
412
// Sends with a DummyDeviceContext whose stream id is 123 and receives with a
// different context; the receive callback must see the sender's context
// (stream id 123) in `send_args`.
TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) {
  Rendezvous::Args sender_side;
  sender_side.device_context = new DummyDeviceContext(123);
  TF_ASSERT_OK(rendez_->Send(KeyFoo(), sender_side, V("hello"), false));

  Rendezvous::Args receiver_side;
  receiver_side.device_context = new DummyDeviceContext(1);

  Notification received;
  rendez_->RecvAsync(
      KeyFoo(), receiver_side,
      [&received](const Status& /*s*/, const Rendezvous::Args& send_args,
                  const Rendezvous::Args& /*recv_args*/, const Tensor& /*val*/,
                  bool /*is_dead*/) {
        const auto* sender_context =
            dynamic_cast<const DummyDeviceContext*>(send_args.device_context);
        CHECK_EQ(123, sender_context->stream_id());
        received.Notify();
      });
  received.WaitForNotification();

  // Release the references created by the two `new` expressions above.
  sender_side.device_context->Unref();
  receiver_side.device_context->Unref();
}
436
BM_SendRecv(::testing::benchmark::State & state)437 void BM_SendRecv(::testing::benchmark::State& state) {
438 Rendezvous* rendez = NewLocalRendezvous();
439 Tensor orig = V("val");
440 Tensor val(DT_STRING, TensorShape({}));
441 bool is_dead = false;
442 Rendezvous::Args args;
443
444 for (auto s : state) {
445 TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead));
446 TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &val, &is_dead));
447 }
448 CHECK_EQ(V(val), V(orig));
449
450 rendez->Unref();
451 }
452 BENCHMARK(BM_SendRecv);
453
BM_RecvSend(::testing::benchmark::State & state)454 void BM_RecvSend(::testing::benchmark::State& state) {
455 Rendezvous* rendez = NewLocalRendezvous();
456 Tensor orig = V("val");
457 Tensor val(DT_STRING, TensorShape({}));
458 bool is_dead = false;
459 Rendezvous::Args args;
460
461 for (auto s : state) {
462 bool received = false;
463 rendez->RecvAsync(
464 KeyFoo(), args,
465 [&val, &received](const Status& /*s*/,
466 const Rendezvous::Args& /*send_args*/,
467 const Rendezvous::Args& /*recv_args*/,
468 const Tensor& tensor, bool /*is_dead*/) {
469 val = tensor;
470 received = true;
471 });
472 TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead));
473 CHECK(received);
474 }
475 CHECK_EQ(V(val), V(orig));
476
477 rendez->Unref();
478 }
479 BENCHMARK(BM_RecvSend);
480
BM_PingPong(::testing::benchmark::State & state)481 void BM_PingPong(::testing::benchmark::State& state) {
482 const int messages_count = state.range(0);
483 auto* cm = new CancellationManager();
484 thread::ThreadPool* pool = new thread::ThreadPool(Env::Default(), "test", 1);
485
486 // Benchmark loop
487 // In each iteration:
488 // The main thread sends "foo" for messages_count times and receives "bar"
489 // for messages_count times. The other thread sends "bar" for
490 // messages_count times and receives "foo" for messages_count times.
491 for (auto s : state) {
492 Rendezvous* rendez = NewLocalRendezvous();
493 pool->Schedule([rendez, messages_count]() {
494 Tensor bar = V("bar");
495 Tensor foo(DT_STRING, TensorShape({}));
496 bool is_dead = false;
497 Rendezvous::Args args;
498 for (int i = 0; i < messages_count; ++i) {
499 TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &foo, &is_dead));
500 TF_CHECK_OK(rendez->Send(KeyBar(), args, bar, is_dead));
501 }
502 CHECK_EQ("foo", V(foo));
503 });
504 Tensor foo = V("foo");
505 Tensor bar(DT_STRING, TensorShape({}));
506 bool is_dead = false;
507 Rendezvous::Args args;
508 args.cancellation_manager = cm;
509 for (int i = 0; i < messages_count; ++i) {
510 TF_CHECK_OK(rendez->Send(KeyFoo(), args, foo, is_dead));
511 TF_CHECK_OK(rendez->Recv(KeyBar(), args, &bar, &is_dead));
512 }
513 CHECK_EQ("bar", V(bar));
514 }
515 state.SetItemsProcessed(messages_count * state.iterations());
516 delete pool;
517 delete cm;
518 }
519 BENCHMARK(BM_PingPong)->Arg(100)->Arg(200)->Arg(300);
520
521 } // namespace
522 } // namespace tensorflow
523