• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/rendezvous.h"
17 
18 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
19 #include "tensorflow/core/framework/cancellation.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_shape.h"
22 #include "tensorflow/core/framework/tensor_types.h"
23 #include "tensorflow/core/framework/types.pb.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/notification.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/lib/core/threadpool.h"
28 #include "tensorflow/core/lib/random/simple_philox.h"
29 #include "tensorflow/core/lib/strings/strcat.h"
30 #include "tensorflow/core/platform/env.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/notification.h"
34 #include "tensorflow/core/platform/test.h"
35 #include "tensorflow/core/platform/test_benchmark.h"
36 #include "tensorflow/core/platform/types.h"
37 
38 namespace tensorflow {
39 namespace {
40 
// Checks the string produced by Rendezvous::CreateKey, the fields that
// Rendezvous::ParseKey recovers from it, and that several malformed keys are
// rejected by ParseKey.
TEST(RendezvousTest, Key) {
  const string expected_key =
      "/job:mnist/replica:1/task:2/CPU:0;"
      "0000000000001ed2;"  // 7890 = 0x1ed2
      "/job:mnist/replica:1/task:2/device:GPU:0;"
      "var0;"
      "0:0";
  const string key = Rendezvous::CreateKey(
      "/job:mnist/replica:1/task:2/CPU:0", 7890,
      "/job:mnist/replica:1/task:2/device:GPU:0", "var0", FrameAndIter(0, 0));
  EXPECT_EQ(key, expected_key);

  // Round-trip: parse the key we just built and inspect its parts.
  Rendezvous::ParsedKey parsed;
  TF_EXPECT_OK(Rendezvous::ParseKey(key, &parsed));
  EXPECT_EQ(parsed.src_device, "/job:mnist/replica:1/task:2/CPU:0");
  EXPECT_EQ(parsed.src_incarnation, 7890);
  EXPECT_EQ(parsed.src.type, "CPU");
  EXPECT_EQ(parsed.dst_device, "/job:mnist/replica:1/task:2/device:GPU:0");
  EXPECT_EQ(parsed.dst.type, "GPU");

  // Malformed inputs must not parse.
  EXPECT_FALSE(Rendezvous::ParseKey("foo;bar;baz", &parsed).ok());
  EXPECT_FALSE(Rendezvous::ParseKey("/job:mnist/replica:1/task:2/CPU:0;"
                                    "/job:mnist/replica:1/task:2/device:GPU:0;",
                                    &parsed)
                   .ok());
  const string doubled_key = strings::StrCat(key, ";", key);
  EXPECT_FALSE(Rendezvous::ParseKey(doubled_key, &parsed).ok());
}
67 
68 class LocalRendezvousTest : public ::testing::Test {
69  public:
LocalRendezvousTest()70   LocalRendezvousTest() : threads_(Env::Default(), "test", 16) {
71     rendez_ = NewLocalRendezvous();
72   }
73 
~LocalRendezvousTest()74   ~LocalRendezvousTest() override { rendez_->Unref(); }
75 
SchedClosure(std::function<void ()> fn)76   void SchedClosure(std::function<void()> fn) {
77     threads_.Schedule(std::move(fn));
78   }
79 
80   Rendezvous* rendez_;
81 
82  private:
83   thread::ThreadPool threads_;
84 };
85 
86 // string -> Tensor<string>
V(const string & content)87 Tensor V(const string& content) {
88   Tensor tensor(DT_STRING, TensorShape({}));
89   tensor.scalar<tstring>()() = content;
90   return tensor;
91 }
92 
93 // Tensor<string> -> string
V(const Tensor & tensor)94 string V(const Tensor& tensor) {
95   CHECK_EQ(tensor.dtype(), DT_STRING);
96   CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
97   return tensor.scalar<tstring>()();
98 }
99 
MakeKey(const string & name)100 Rendezvous::ParsedKey MakeKey(const string& name) {
101   string s = Rendezvous::CreateKey("/job:mnist/replica:1/task:2/CPU:0", 7890,
102                                    "/job:mnist/replica:1/task:2/device:GPU:0",
103                                    name, FrameAndIter(0, 0));
104   Rendezvous::ParsedKey k;
105   TF_EXPECT_OK(Rendezvous::ParseKey(s, &k));
106   return k;
107 }
108 
KeyFoo()109 const Rendezvous::ParsedKey& KeyFoo() {
110   static auto* key = new Rendezvous::ParsedKey(MakeKey("foo"));
111   return *key;
112 }
113 
KeyBar()114 const Rendezvous::ParsedKey& KeyBar() {
115   static auto* key = new Rendezvous::ParsedKey(MakeKey("bar"));
116   return *key;
117 }
118 
// Send first, then Recv on the same key from the same thread.
TEST_F(LocalRendezvousTest, SendRecv) {
  Rendezvous::Args args;
  TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));

  Tensor received(DT_STRING);
  bool dead = false;
  TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &received, &dead));
  EXPECT_EQ("hello", V(received));
}
127 
// The main thread calls Recv while a pool thread sends the value after a
// 10ms sleep.
TEST_F(LocalRendezvousTest, RecvSend) {
  SchedClosure([this]() {
    Env::Default()->SleepForMicroseconds(10000);
    Rendezvous::Args send_args;
    TF_ASSERT_OK(rendez_->Send(KeyFoo(), send_args, V("hello"), false));
  });

  Rendezvous::Args recv_args;
  Tensor received(DT_STRING);
  bool dead = false;
  TF_ASSERT_OK(rendez_->Recv(KeyFoo(), recv_args, &received, &dead));
  EXPECT_EQ("hello", V(received));
}
140 
// A pool thread receives on "foo" and echoes the tensor back on "bar"; the
// main thread sends on "foo" after a one-second sleep and reads the echo.
TEST_F(LocalRendezvousTest, PingPong) {
  SchedClosure([this]() {
    Rendezvous::Args echo_args;
    Tensor relayed(DT_STRING);
    bool relayed_dead = false;
    TF_ASSERT_OK(rendez_->Recv(KeyFoo(), echo_args, &relayed, &relayed_dead));
    TF_ASSERT_OK(rendez_->Send(KeyBar(), echo_args, relayed, relayed_dead));
  });
  Env::Default()->SleepForMicroseconds(1000000);

  Rendezvous::Args args;
  Tensor reply(DT_STRING);
  bool reply_dead = false;
  TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("secret msg"), reply_dead));
  TF_ASSERT_OK(rendez_->Recv(KeyBar(), args, &reply, &reply_dead));
  EXPECT_EQ("secret msg", V(reply));
}
157 
// Recv issued with an already-cancelled CancellationManager must fail with a
// Cancelled status carrying the expected message.
TEST_F(LocalRendezvousTest, CancelBeforeRecv) {
  auto* cm = new CancellationManager();
  cm->StartCancel();

  Rendezvous::Args args;
  args.cancellation_manager = cm;
  Tensor received(DT_STRING);
  bool dead = false;
  const Status status = rendez_->Recv(KeyFoo(), args, &received, &dead);

  EXPECT_FALSE(status.ok());
  EXPECT_TRUE(errors::IsCancelled(status));
  EXPECT_EQ("RecvAsync is cancelled.", status.error_message());
  delete cm;
}
171 
// The main thread calls Recv; a pool thread cancels the associated
// CancellationManager after a 10ms sleep. Recv must return Cancelled.
TEST_F(LocalRendezvousTest, CancelAfterRecv) {
  auto* cm = new CancellationManager();
  Notification cancel_issued;
  SchedClosure([cm, &cancel_issued]() {
    Env::Default()->SleepForMicroseconds(10000);
    cm->StartCancel();
    cancel_issued.Notify();
  });

  Rendezvous::Args args;
  args.cancellation_manager = cm;
  Tensor received(DT_STRING);
  bool dead = false;
  const Status status = rendez_->Recv(KeyFoo(), args, &received, &dead);

  EXPECT_FALSE(status.ok());
  EXPECT_TRUE(errors::IsCancelled(status));
  EXPECT_EQ("RecvAsync is cancelled.", status.error_message());

  // The closure still uses `cm` and `cancel_issued`; wait for it to finish
  // before releasing them.
  cancel_issued.WaitForNotification();
  delete cm;
}
191 
// A pool thread sends the value and only then starts cancellation; the Recv
// on the main thread (registered with that CancellationManager) must still
// succeed and deliver the value.
//
// Non-fatal expectations are used on purpose:
//  * inside the closure, a fatal assertion would return before n.Notify()
//    and leave the main thread blocked in WaitForNotification();
//  * on the main thread, a fatal assertion would return while the closure
//    may still be using `cm` and `n`.
TEST_F(LocalRendezvousTest, CancelEmptyQueue) {
  // Stack-allocated: released on every exit path. Declared before `n` and
  // kept alive until after n.WaitForNotification() below.
  CancellationManager cm;
  Notification n;
  SchedClosure([this, &cm, &n]() {
    Env::Default()->SleepForMicroseconds(10000);
    Rendezvous::Args args;
    TF_EXPECT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
    cm.StartCancel();
    n.Notify();
  });
  Tensor val(DT_STRING);
  bool is_dead = false;
  Rendezvous::Args args;
  args.cancellation_manager = &cm;
  const Status recv_status = rendez_->Recv(KeyFoo(), args, &val, &is_dead);
  TF_EXPECT_OK(recv_status);
  if (recv_status.ok()) {
    EXPECT_EQ("hello", V(val));
  }
  // Wait for the closure to finish with `cm` and `n` before they go out of
  // scope.
  n.WaitForNotification();
}
211 
TEST_F(LocalRendezvousTest,CancelMultiple)212 TEST_F(LocalRendezvousTest, CancelMultiple) {
213   auto* cm = new CancellationManager();
214   SchedClosure([this, cm]() {
215     Env::Default()->SleepForMicroseconds(10000);
216     Rendezvous::Args args;
217     cm->StartCancel();
218     TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
219     TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
220   });
221   Tensor val(DT_STRING);
222   Rendezvous::Args args;
223   Rendezvous::Args args_with_cancellation;
224   args_with_cancellation.cancellation_manager = cm;
225   Notification n0;
226   Notification n1;
227   Notification n2;
228   Notification n3;
229   Status s0;
230   Status s1;
231   Status s2;
232   Status s3;
233 
234   rendez_->RecvAsync(
235       KeyFoo(), args,
236       [&n0, &s0](const Status& s, const Rendezvous::Args& send_args,
237                  const Rendezvous::Args& recv_args, const Tensor& v,
238                  const bool dead) {
239         s0.Update(s);
240         n0.Notify();
241       });
242   rendez_->RecvAsync(
243       KeyFoo(), args_with_cancellation,
244       [&n1, &s1](const Status& s, const Rendezvous::Args& send_args,
245                  const Rendezvous::Args& recv_args, const Tensor& v,
246                  const bool dead) {
247         s1.Update(s);
248         n1.Notify();
249       });
250   rendez_->RecvAsync(
251       KeyFoo(), args,
252       [&n2, &s2](const Status& s, const Rendezvous::Args& send_args,
253                  const Rendezvous::Args& recv_args, const Tensor& v,
254                  const bool dead) {
255         s2.Update(s);
256         n2.Notify();
257       });
258   rendez_->RecvAsync(
259       KeyFoo(), args_with_cancellation,
260       [&n3, &s3](const Status& s, const Rendezvous::Args& send_args,
261                  const Rendezvous::Args& recv_args, const Tensor& v,
262                  const bool dead) {
263         s3.Update(s);
264         n3.Notify();
265       });
266   n0.WaitForNotification();
267   n1.WaitForNotification();
268   n2.WaitForNotification();
269   n3.WaitForNotification();
270   TF_ASSERT_OK(s0);
271   TF_ASSERT_OK(s2);
272   EXPECT_FALSE(s1.ok());
273   EXPECT_FALSE(s3.ok());
274 
275   delete cm;
276 }
277 
278 // A simple structure that behaves a bit like a blocking counter.  The
279 // user that decrements counter to 0 does done.Notify(), and the main
280 // thread waits for done to be notified.
281 struct BlockingState {
282   mutex lock;
283   int counter = 0;
284   Notification done;
285 };
286 
// Schedules N senders and N receivers, each on its own key and each after a
// random short sleep, then waits until every receive callback has run.
TEST_F(LocalRendezvousTest, RandomSendRecv) {
  // 2*N closures go to this->threads_, which only has 16 threads, and the
  // pool may run them in any order. The receivers therefore use RecvAsync:
  // a blocking Recv() could occupy every thread before the matching Send()
  // calls get to run, and deadlock.
  static const int N = 100;
  random::PhiloxRandom philox(testing::RandomSeed(), 17);
  random::SimplePhilox rnd(&philox);
  BlockingState state;
  state.counter = N;
  for (int i = 0; i < N; ++i) {
    // Sender for key "i".
    const int send_delay_us = 100 + rnd.Uniform(1000);
    SchedClosure([this, i, send_delay_us]() {
      Env::Default()->SleepForMicroseconds(send_delay_us);
      Rendezvous::Args send_args;
      TF_ASSERT_OK(rendez_->Send(MakeKey(strings::StrCat(i)), send_args,
                                 V(strings::StrCat(i)), false));
    });

    // Receive callback: checks the payload and counts down; the callback
    // that reaches zero wakes the main thread.
    auto on_received = [this, &state, i](const Status& status,
                                         const Rendezvous::Args& sender_args,
                                         const Rendezvous::Args& recver_args,
                                         const Tensor& val,
                                         const bool val_dead) {
      EXPECT_EQ(strings::StrCat(i), V(val));
      bool is_last = false;
      {
        mutex_lock guard(state.lock);
        is_last = (--state.counter == 0);
      }
      // Notify outside the critical section.
      if (is_last) {
        state.done.Notify();
      }
    };

    // Receiver for key "i".
    const int recv_delay_us = 100 + rnd.Uniform(1000);
    SchedClosure([this, i, recv_delay_us, on_received]() {
      Env::Default()->SleepForMicroseconds(recv_delay_us);
      rendez_->RecvAsync(MakeKey(strings::StrCat(i)), Rendezvous::Args(),
                         on_received);
    });
  }

  state.done.WaitForNotification();
}
333 
RandomSleep()334 void RandomSleep() {
335   if (std::rand() % 10 == 0) {
336     Env::Default()->SleepForMicroseconds(1000);
337   }
338 }
339 
// A pool thread sends N values on the same key while the main thread
// receives N values on that key; both sides sleep at random between calls.
TEST_F(LocalRendezvousTest, MultiSends) {
  static const int N = 100;
  const auto& key_foo = KeyFoo();
  Rendezvous::Args args;
  // Explicit captures instead of `[=]`: implicit capture of `this` through
  // `[=]` is deprecated in C++20. `key_foo` refers to the never-deleted
  // object returned by KeyFoo(), so capturing it by reference is safe and
  // avoids copying the parsed key into the closure.
  SchedClosure([this, &key_foo, args]() {
    for (int i = 0; i < N; ++i) {
      TF_ASSERT_OK(rendez_->Send(key_foo, args, V(strings::StrCat(i)), false));
      RandomSleep();
    }
  });
  Tensor val;
  bool val_dead = false;  // Initialized so it never holds an indeterminate value.
  for (int i = 0; i < N; ++i) {
    TF_ASSERT_OK(rendez_->Recv(key_foo, args, &val, &val_dead));
    RandomSleep();
  }
}
357 
// A pool thread aborts the rendezvous while the main thread calls Recv; the
// Recv must report an Aborted status.
TEST_F(LocalRendezvousTest, RecvAbort) {
  // Extra reference for the closure, which releases it when done.
  rendez_->Ref();
  SchedClosure([this]() {
    rendez_->StartAbort(errors::Aborted(""));  // abort
    rendez_->Unref();
  });

  Rendezvous::Args args;
  Tensor received(DT_STRING);
  bool dead = false;
  const Status status = rendez_->Recv(KeyFoo(), args, &received, &dead);
  EXPECT_TRUE(errors::IsAborted(status));
}
370 
371 // Similar to RecvAbort. But this test case ensures the main thread
372 // Recv() call happens after StartAbort().
TEST_F(LocalRendezvousTest,RecvSleepAbort)373 TEST_F(LocalRendezvousTest, RecvSleepAbort) {
374   rendez_->Ref();
375   SchedClosure([this]() {
376     Env::Default()->SleepForMicroseconds(1000000);
377     rendez_->StartAbort(errors::Aborted(""));  // abort
378     rendez_->Unref();
379   });
380   Tensor val(DT_STRING);
381   bool val_dead = false;
382   Rendezvous::Args args;
383   Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead);
384   EXPECT_TRUE(errors::IsAborted(status));
385 }
386 
// Once the rendezvous has been aborted, both Send and Recv must return an
// Aborted status.
TEST_F(LocalRendezvousTest, AbortThenRecvOrSend) {
  rendez_->StartAbort(errors::Aborted(""));

  Rendezvous::Args args;
  Tensor tensor(DT_STRING);
  bool dead = false;

  const Status send_status = rendez_->Send(KeyFoo(), args, tensor, dead);
  EXPECT_TRUE(errors::IsAborted(send_status));

  const Status recv_status = rendez_->Recv(KeyFoo(), args, &tensor, &dead);
  EXPECT_TRUE(errors::IsAborted(recv_status));
}
396 
397 class DummyDeviceContext : public DeviceContext {
398  public:
DummyDeviceContext(int stream_id)399   explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {}
~DummyDeviceContext()400   ~DummyDeviceContext() override {}
stream_id() const401   int stream_id() const { return stream_id_; }
402 
CopyTensorInSameDevice(const Tensor * input_tensor,Device * device,Tensor * output_tensor,StatusCallback done) const403   void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
404                               Tensor* output_tensor,
405                               StatusCallback done) const override {
406     done(OkStatus());
407   }
408 
409  private:
410   const int stream_id_;
411 };
412 
// Sends with a DummyDeviceContext (stream id 123) in the sender's args and
// checks that the receive callback sees that same context in `send_args`.
TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) {
  Rendezvous::Args sender_args;
  sender_args.device_context = new DummyDeviceContext(123);
  TF_ASSERT_OK(rendez_->Send(KeyFoo(), sender_args, V("hello"), false));

  Rendezvous::Args receiver_args;
  receiver_args.device_context = new DummyDeviceContext(1);

  Notification received;
  rendez_->RecvAsync(
      KeyFoo(), receiver_args,
      [&received](const Status& s, const Rendezvous::Args& send_args,
                  const Rendezvous::Args& recv_args, const Tensor& val,
                  bool is_dead) {
        const auto* sender_ctx =
            dynamic_cast<const DummyDeviceContext*>(send_args.device_context);
        CHECK_EQ(123, sender_ctx->stream_id());
        received.Notify();
      });
  received.WaitForNotification();

  // Release the references taken by the two `new` expressions above.
  sender_args.device_context->Unref();
  receiver_args.device_context->Unref();
}
436 
BM_SendRecv(::testing::benchmark::State & state)437 void BM_SendRecv(::testing::benchmark::State& state) {
438   Rendezvous* rendez = NewLocalRendezvous();
439   Tensor orig = V("val");
440   Tensor val(DT_STRING, TensorShape({}));
441   bool is_dead = false;
442   Rendezvous::Args args;
443 
444   for (auto s : state) {
445     TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead));
446     TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &val, &is_dead));
447   }
448   CHECK_EQ(V(val), V(orig));
449 
450   rendez->Unref();
451 }
452 BENCHMARK(BM_SendRecv);
453 
BM_RecvSend(::testing::benchmark::State & state)454 void BM_RecvSend(::testing::benchmark::State& state) {
455   Rendezvous* rendez = NewLocalRendezvous();
456   Tensor orig = V("val");
457   Tensor val(DT_STRING, TensorShape({}));
458   bool is_dead = false;
459   Rendezvous::Args args;
460 
461   for (auto s : state) {
462     bool received = false;
463     rendez->RecvAsync(
464         KeyFoo(), args,
465         [&val, &received](const Status& /*s*/,
466                           const Rendezvous::Args& /*send_args*/,
467                           const Rendezvous::Args& /*recv_args*/,
468                           const Tensor& tensor, bool /*is_dead*/) {
469           val = tensor;
470           received = true;
471         });
472     TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead));
473     CHECK(received);
474   }
475   CHECK_EQ(V(val), V(orig));
476 
477   rendez->Unref();
478 }
479 BENCHMARK(BM_RecvSend);
480 
BM_PingPong(::testing::benchmark::State & state)481 void BM_PingPong(::testing::benchmark::State& state) {
482   const int messages_count = state.range(0);
483   auto* cm = new CancellationManager();
484   thread::ThreadPool* pool = new thread::ThreadPool(Env::Default(), "test", 1);
485 
486   // Benchmark loop
487   // In each iteration:
488   // The main thread sends "foo" for messages_count times and receives "bar"
489   // for messages_count times.  The other thread sends "bar" for
490   // messages_count times and receives "foo" for messages_count times.
491   for (auto s : state) {
492     Rendezvous* rendez = NewLocalRendezvous();
493     pool->Schedule([rendez, messages_count]() {
494       Tensor bar = V("bar");
495       Tensor foo(DT_STRING, TensorShape({}));
496       bool is_dead = false;
497       Rendezvous::Args args;
498       for (int i = 0; i < messages_count; ++i) {
499         TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &foo, &is_dead));
500         TF_CHECK_OK(rendez->Send(KeyBar(), args, bar, is_dead));
501       }
502       CHECK_EQ("foo", V(foo));
503     });
504     Tensor foo = V("foo");
505     Tensor bar(DT_STRING, TensorShape({}));
506     bool is_dead = false;
507     Rendezvous::Args args;
508     args.cancellation_manager = cm;
509     for (int i = 0; i < messages_count; ++i) {
510       TF_CHECK_OK(rendez->Send(KeyFoo(), args, foo, is_dead));
511       TF_CHECK_OK(rendez->Recv(KeyBar(), args, &bar, &is_dead));
512     }
513     CHECK_EQ("bar", V(bar));
514   }
515   state.SetItemsProcessed(messages_count * state.iterations());
516   delete pool;
517   delete cm;
518 }
519 BENCHMARK(BM_PingPong)->Arg(100)->Arg(200)->Arg(300);
520 
521 }  // namespace
522 }  // namespace tensorflow
523