/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"

#include <atomic>
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

// Subclass EventMgr to access its private constructor.
class TEST_EventMgr : public EventMgr {
 public:
  TEST_EventMgr(se::StreamExecutor* se, const GPUOptions& gpu_options)
      : EventMgr(se, gpu_options) {}
};

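// Exposes EventMgr's private queues and polling controls to the tests below.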
class TEST_EventMgrHelper {
 public:
  explicit TEST_EventMgrHelper(EventMgr* em) : em_(em) {
    // The polling loop can interfere with the measurements made here, and
    // isn't needed since the member PollEvents() always clears the queue.
    // The tested behavior is slightly different from what may occur in
    // ordinary execution.
    StopPollingLoop();
  }

  size_t queue_size() {
    mutex_lock l(em_->mu_);
    return em_->used_events_.size();
  }

  size_t free_size() {
    mutex_lock l(em_->mu_);
    return em_->free_events_.size();
  }

  void QueueTensors(se::Stream* stream, TensorReferenceVector* tensors) {
    mutex_lock l(em_->mu_);
    em_->QueueTensors(stream, tensors);
  }

  void PollEvents() {
    while (queue_size() > 0) {
      // For ordinary tensor frees, this function
      // should synchronously harvest all complete
      // events and execute the corresponding memory frees.
      EventMgr::ToFreeVector to_free;
      {
        mutex_lock l(em_->mu_);
        em_->PollEvents(true, &to_free);
      }
      em_->FreeMemory(to_free);
    }
  }

  void StopPollingLoop() { return em_->StopPollingLoop(); }

  void StartPollingLoop() { return em_->StartPollingLoop(); }

 private:
  EventMgr* em_;
};

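// Running total of bytes held by live TestTensorBuffer instances; the tests
// below use it to verify that queued tensors are actually freed.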
static std::atomic_int_fast64_t live_tensor_bytes(0);

// A TensorBuffer that counts live memory usage for testing
class TestTensorBuffer : public TensorBuffer {
 public:
  explicit TestTensorBuffer(size_t bytes)
      : TensorBuffer(nullptr), bytes_(bytes) {
    live_tensor_bytes += bytes_;
  }
  ~TestTensorBuffer() override { live_tensor_bytes -= bytes_; }

  size_t size() const override { return bytes_; }

  // Not used in this test
  TensorBuffer* root_buffer() override { return nullptr; }
  void FillAllocationDescription(AllocationDescription* arg) const override {}

 private:
  size_t bytes_;
};

namespace {

TEST(EventMgr, Empty) {
  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  TEST_EventMgr em(stream_exec, GPUOptions());
  TEST_EventMgrHelper th(&em);
  EXPECT_EQ(0, th.queue_size());
  EXPECT_EQ(0, th.free_size());
}

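// Wraps a freshly allocated TestTensorBuffer of `size` bytes in a
// TensorReference and appends it to `v`.  The local Unref leaves the
// TensorReference (which takes its own reference) as the remaining owner.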
static void AddTensorReference(TensorReferenceVector* v, int64 size) {
  TestTensorBuffer* buf = new TestTensorBuffer(size);
  v->push_back(TensorReference(buf));
  buf->Unref();
}

// Delaying polling until after several enqueuings should grow the
// total number of allocated events.  Once we have enough events for
// the max simultaneously pending, we should not allocate any more.
TEST(EventMgr, DelayedPolling) {
  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  TEST_EventMgr em(stream_exec, GPUOptions());
  TEST_EventMgrHelper th(&em);
  EXPECT_EQ(0, th.queue_size());
  TensorReferenceVector* v = nullptr;
  std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
  CHECK(stream);
  stream->Init();
  for (int i = 0; i < 5; ++i) {
    v = new TensorReferenceVector;
    AddTensorReference(v, 100 * 1048576);
    th.QueueTensors(stream.get(), v);
    EXPECT_EQ(i + 1, th.queue_size());
    EXPECT_EQ(0, th.free_size());
  }
  th.PollEvents();
  EXPECT_EQ(0, th.queue_size());
  EXPECT_EQ(5, th.free_size());
  for (int j = 0; j < 2; ++j) {
    for (int i = 0; i < 5; ++i) {
      v = new TensorReferenceVector;
      AddTensorReference(v, 100 * 1048576);
      th.QueueTensors(stream.get(), v);
      EXPECT_EQ(i + 1, th.queue_size());
      EXPECT_EQ(4 - i, th.free_size());
    }
    th.PollEvents();
    EXPECT_EQ(0, th.queue_size());
    EXPECT_EQ(5, th.free_size());
  }
}

TEST(EventMgr, FlushLargeTensorImmediately) {
  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  TEST_EventMgr em(stream_exec, GPUOptions());
  TEST_EventMgrHelper th(&em);
  EXPECT_EQ(0, live_tensor_bytes);
  std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
  CHECK(stream);
  stream->Init();
  for (int i = 0; i < 5; ++i) {
    TensorReferenceVector v;
    AddTensorReference(&v, 100 * 1048576);
    em.ThenDeleteTensors(stream.get(), v);
    th.PollEvents();  // Ensure things get registered to be freed by Poll
    EXPECT_EQ(0, live_tensor_bytes);
  }
}

TEST(EventMgr, ManySmallTensorsFlushedImmediately) {
  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  TEST_EventMgr em(stream_exec, GPUOptions());
  TEST_EventMgrHelper th(&em);
  EXPECT_EQ(0, live_tensor_bytes);
  std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
  CHECK(stream);
  stream->Init();
  for (int i = 0; i < 5; ++i) {
    TensorReferenceVector v;
    for (int i = 0; i < 1000; i++) {
      AddTensorReference(&v, 100 * 1024);
    }
    em.ThenDeleteTensors(stream.get(), v);
    th.PollEvents();  // Harvest the tensors ready to be freed.
    EXPECT_EQ(0, live_tensor_bytes);
  }
}

TEST(EventMgr, StreamSwitchingFlushesImmediately) {
  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  TEST_EventMgr em(stream_exec, GPUOptions());
  TEST_EventMgrHelper th(&em);
  EXPECT_EQ(0, live_tensor_bytes);
  std::unique_ptr<se::Stream> stream1(new se::Stream(stream_exec));
  std::unique_ptr<se::Stream> stream2(new se::Stream(stream_exec));
  stream1->Init();
  stream2->Init();
  TensorReferenceVector v1;
  AddTensorReference(&v1, 1024);
  em.ThenDeleteTensors(stream1.get(), v1);

  TensorReferenceVector v2;
  AddTensorReference(&v2, 1024);
  int64 initial_live_bytes = live_tensor_bytes;
  em.ThenDeleteTensors(stream2.get(), v2);
  th.PollEvents();  // Ensure things get registered to be freed by Poll
  // Different stream should cause first tensor to get deleted
  EXPECT_GT(initial_live_bytes, live_tensor_bytes);
}

TEST(EventMgr, ManySmallTensorsSeparateCallsFlushed) {
  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  TEST_EventMgr em(stream_exec, GPUOptions());
  TEST_EventMgrHelper th(&em);
  EXPECT_EQ(0, live_tensor_bytes);
  std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
  CHECK(stream);
  stream->Init();
  for (int i = 0; i < 5; ++i) {
    for (int i = 0; i < 1000; i++) {
      TensorReferenceVector v;
      AddTensorReference(&v, 100 * 1024);
      em.ThenDeleteTensors(stream.get(), v);
    }
    th.PollEvents();  // Ensure things get registered to be freed by Poll
    // Some of the tensors at least should be flushed
    EXPECT_GT(1000 * 100 * 1024, live_tensor_bytes);
  }
}

// Deleting the EventMgr when events are still pending should shut
// down gracefully.
TEST(EventMgr, NonEmptyShutdown) {
  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  TEST_EventMgr em(stream_exec, GPUOptions());
  TEST_EventMgrHelper th(&em);
  EXPECT_EQ(0, th.queue_size());
  EXPECT_EQ(0, th.free_size());
  std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
  CHECK(stream);
  stream->Init();
  for (int i = 0; i < 5; ++i) {
    TensorReferenceVector* v = new TensorReferenceVector;
    AddTensorReference(v, 100 * 1048576);
    th.QueueTensors(stream.get(), v);
    EXPECT_EQ(1 + i, th.queue_size());
    EXPECT_EQ(0, th.free_size());
  }
}

// Tests that WarnIfInCallback() triggers correctly.
TEST(EventMgr, WarnIfInCallback) {
  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  TEST_EventMgr em(stream_exec, GPUOptions());
  TEST_EventMgrHelper th(&em);
  std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
  CHECK(stream);
  stream->Init();
  bool hit = false;
  th.StartPollingLoop();
  gpu_event_mgr::WarnIfInCallback([&hit] { hit = true; });
  EXPECT_FALSE(hit);
  Notification note;
  em.ThenExecute(stream.get(), [&hit, &note]() {
    gpu_event_mgr::WarnIfInCallback([&hit, &note] {
      hit = true;
      note.Notify();
    });
  });
  note.WaitForNotification();
  EXPECT_TRUE(hit);
}
}  // namespace

// Provides access to private resources of BaseGPUDevice.
class GPUDeviceTestHelper {
 public:
  GPUDeviceTestHelper(size_t memory_limit, int pending_cap) {
    SessionOptions sops;
    device_ =
        DeviceFactory::NewDevice(DEVICE_GPU, sops, "/job:a/replica:0/task:0");
    gpu_.reset(reinterpret_cast<BaseGPUDevice*>(device_.release()));
    gpu_allocator_ = GPUProcessState::singleton()->GetGPUAllocator(
        GPUOptions(), TfGpuId(0), memory_limit);
    host_allocator_ = GPUProcessState::singleton()->GetGpuHostAllocator(0);
  }

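  // Accessors for the device's otherwise-private streams, EventMgr and
  // allocators used by the benchmark code below.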
  BaseGPUDevice* gpu() { return gpu_.get(); }
  Allocator* gpu_allocator() { return gpu_allocator_; }
  Allocator* host_allocator() { return host_allocator_; }
  se::Stream* compute_stream() { return gpu_->stream_->compute; }
  se::Stream* h2d_stream() { return gpu_->stream_->host_to_device; }
  se::Stream* d2h_stream() { return gpu_->stream_->device_to_host; }
  se::Stream* d2d_stream() { return gpu_->stream_->device_to_device[0]; }
  EventMgr* event_mgr() { return gpu_->em_; }
  int pending_cap() { return gpu_->pending_cap_; }

 private:
  std::unique_ptr<Device> device_;
  std::unique_ptr<BaseGPUDevice> gpu_;
  Allocator* gpu_allocator_;
  Allocator* host_allocator_;
};

namespace {

// Class that can queue some GPU data transfers and simple kernels.
class EMBenchmarkHelper {
  GPUDeviceTestHelper* gpu_helper_;
  // We need one of these for each Add op in the chain.
  std::vector<std::unique_ptr<OpKernel>> add_kernels_;
  std::vector<OpKernelContext::Params*> add_params_;
  std::vector<std::unique_ptr<OpKernelContext>> add_contexts_;
  // The rest of these are one per chain.
  NodeDef add_node_def_;
  NodeDef id_node_def_;
  gtl::InlinedVector<TensorValue, 4> add_inputs_;
  std::vector<AllocatorAttributes> allocator_attrs_;
  gtl::InlinedVector<Tensor, 4> gpu_inputs_;
  gtl::InlinedVector<Tensor, 4> gpu_outputs_;
  gtl::InlinedVector<Tensor, 4> host_inputs_;
  gtl::InlinedVector<Tensor, 4> host_outputs_;

 public:
  // Length of tensors.  TODO(tucker): make this a variable parameter.
  static const int kTDim = 1024;

  int num_ops() const { return add_kernels_.size(); }
  size_t tensor_size() const {
    return add_inputs_.empty() ? 0 : add_inputs_[0]->NumElements();
  }

  Tensor& host_outputs(int i) { return host_outputs_[i]; }
  Tensor& host_inputs(int i) { return host_inputs_[i]; }

  EMBenchmarkHelper(GPUDeviceTestHelper* h) : gpu_helper_(h) {}

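  // Rebuilds the benchmark fixture: two GPU input tensors, one GPU output
  // tensor, matching host staging tensors from the GPU host allocator, and
  // `num_ops` Add kernels.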
  void ReInit(int num_ops, int tensor_size) {
    gpu_inputs_.clear();
    while (gpu_inputs_.size() < 2) {
      gpu_inputs_.push_back(Tensor(gpu_helper_->gpu_allocator(), DT_FLOAT,
                                   {tensor_size}, AllocationAttributes()));
    }
    gpu_outputs_.clear();
    while (gpu_outputs_.size() < 1) {
      gpu_outputs_.push_back(Tensor(gpu_helper_->gpu_allocator(), DT_FLOAT,
                                    {tensor_size}, AllocationAttributes()));
    }
    host_inputs_.clear();
    while (host_inputs_.size() < 2) {
      int instance_index = host_inputs_.size();
      host_inputs_.push_back(Tensor(gpu_helper_->host_allocator(), DT_FLOAT,
                                    {tensor_size}, AllocationAttributes()));
      for (int i = 0; i < tensor_size; ++i) {
        host_inputs_.back().flat<float>()(i) =
            i * (1.0 + (0.5 * instance_index));
      }
    }
    host_outputs_.clear();
    while (host_outputs_.size() < 1) {
      host_outputs_.push_back(Tensor(gpu_helper_->host_allocator(), DT_FLOAT,
                                     {tensor_size}, AllocationAttributes()));
      for (int i = 0; i < tensor_size; ++i) {
        host_outputs_.back().flat<float>()(i) = -1;
      }
    }
    add_kernels_.clear();
    add_params_.clear();
    while (add_kernels_.size() < num_ops) {
      MakeAddOp();
    }
  }

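  // Creates an OpKernel for `node_def` on the test GPU device.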
  std::unique_ptr<OpKernel> GetOpKernel(const NodeDef& node_def,
                                        Status* status) {
    return CreateOpKernel("GPU", gpu_helper_->gpu(),
                          gpu_helper_->gpu_allocator(), node_def,
                          TF_GRAPH_DEF_VERSION, status);
  }

  void MakeAddOp() {
    if (add_kernels_.empty()) {
      TF_ASSERT_OK(NodeDefBuilder("add_op", "Add")
                       .Input(FakeInput(DT_FLOAT))
                       .Input(FakeInput(DT_FLOAT))
                       .Device("/job:a/replica:0/task:0/GPU:0")
                       .Finalize(&add_node_def_));
    }
    Status status;
    add_kernels_.emplace_back(GetOpKernel(add_node_def_, &status));
    TF_ASSERT_OK(status);
    add_params_.push_back(new OpKernelContext::Params);
    PrepOpKernel(add_params_.back(), add_kernels_.back().get());
  }

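  // Fills `attrs` with one AllocatorAttributes per kernel output, matching
  // the kernel's declared output memory types, and points `params` at them.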
  void SetOutputAttrs(OpKernelContext::Params* params,
                      std::vector<AllocatorAttributes>* attrs) {
    attrs->clear();
    for (int index = 0; index < params->op_kernel->num_outputs(); index++) {
      AllocatorAttributes attr;
      const bool on_host =
          (params->op_kernel->output_memory_types()[index] == HOST_MEMORY);
      attr.set_on_host(on_host);
      attrs->push_back(attr);
    }
    params->output_attr_array = attrs->data();
    params->forward_from_array = {};
  }

  void PrepOpKernel(OpKernelContext::Params* params, OpKernel* kernel) {
    // This mimics what happens in ExecutorState::Process to run
    // a single graph node.
    params->step_id = 1;
    params->device = gpu_helper_->gpu();
    params->log_memory = false;
    params->record_tensor_accesses = false;
    params->rendezvous = nullptr;
    params->collective_executor = nullptr;
    params->session_state = nullptr;  // ???
    params->session_handle = "session_handle";
    params->tensor_store = nullptr;
    params->cancellation_manager = nullptr;

    params->call_frame = nullptr;
    params->function_library = nullptr;
    params->runner = nullptr;
    params->graph_collector = nullptr;

    params->step_container = nullptr;
    params->slice_reader_cache = nullptr;
    params->resource_manager = gpu_helper_->gpu()->resource_manager();

    params->stats_collector = nullptr;
    params->inc_num_deferred_ops_function = nullptr;
    params->dec_num_deferred_ops_function = nullptr;

    params->op_device_context = nullptr;
    params->track_allocations = false;
    params->op_kernel = kernel;
    params->frame_iter = FrameAndIter(0, 0);
    params->is_input_dead = false;

    if (add_inputs_.empty()) {
      add_inputs_.resize(2);
      add_inputs_[0] = TensorValue(&gpu_inputs_[0]);
      add_inputs_[1] = TensorValue(&gpu_inputs_[1]);
    }
    params->inputs = &add_inputs_;
    params->input_alloc_attrs = nullptr;
    SetOutputAttrs(params, &allocator_attrs_);
  }

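  // Timestamps (in microseconds) sampled at key points of one DoAddChain
  // round; DisplayTimes() converts them into per-phase durations.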
  struct TimeSet {
    int iter = 0;
    int64 start = 0;
    int64 copy_done = 0;
    int64 compute_done = 0;
    int64 final_copy = 0;
    int64 all_done = 0;
  };

  // Display sampled iteration times giving the approximate breakdown
  // within iterations and overall curve.
  void DisplayTimes(std::vector<TimeSet>* times) {
    LOG(INFO) << "Summarize set of " << times->size() << " iters";
    for (auto& ts : *times) {
      ts.final_copy = ts.all_done - ts.compute_done;
      ts.compute_done = ts.compute_done - ts.copy_done;
      ts.copy_done = ts.copy_done - ts.start;
      ts.all_done = ts.all_done - ts.start;
    }
    struct TSSort {
      bool operator()(const TimeSet& a, const TimeSet& b) {
        return a.all_done < b.all_done;
      }
    };
    std::sort(times->begin(), times->end(), TSSort());
    int64 last_time = 0;
    // Display first, last and every > 5% change.
    for (int i = 0; i < times->size(); ++i) {
      if (i == (times->size() - 1) ||
          (times->at(i).all_done >= (1.05 * last_time))) {
        LOG(INFO) << "rank " << i << " iter: " << times->at(i).iter
                  << " copy: " << times->at(i).copy_done
                  << " compute: " << times->at(i).compute_done
                  << " copy back: " << times->at(i).final_copy
                  << " sum: " << times->at(i).all_done;
        last_time = times->at(i).all_done;
      }
    }
  }

  // Queue one work unit on the GPU as follows:
  // 1. Copy 2 input tensors from CPU to GPU using h2d stream.
  // 2. Instruct compute stream to wait on h2d stream.
  // 3. Queue a sequence of Add ops on the compute stream, all using
  //    the same input tensors, allocating their own output tensors.
  // 4. Instruct d2h stream to wait on the compute stream.
  // 5. Copy final output tensor back to the CPU.
  // 6. Instruct the EventMgr to execute callback when the final tensor
  //    copy completes.
  // If event_after_add == true then additionally instruct the EventMgr
  //    to execute the callback after each Add completes.
  // The optional times parameter is used for gathering detailed timing
  // data.
  void DoAddChain(int adds_per_copy, int rounds, bool event_after_add,
                  std::function<void()> callback, std::vector<TimeSet>* times) {
    // Take an extra ref on the inputs so that the add doesn't compute in place.
    Tensor alias0(gpu_inputs_[0]);
    Tensor alias1(gpu_inputs_[1]);
    for (int r = 0; r < rounds; ++r) {
      if (times) {
        times->at(r).iter = r;
        times->at(r).start = Env::Default()->NowMicros();
      }
      gpu_helper_->h2d_stream()->ThenWaitFor(gpu_helper_->compute_stream());
      // Begin by copying the input values from CPU to GPU.
      const int64 src_bytes = host_inputs_[0].TotalBytes();
      se::DeviceMemoryBase gpu_dst_ptr0(DMAHelper::base(&gpu_inputs_[0]),
                                        src_bytes);
      gpu_helper_->h2d_stream()->ThenMemcpy(
          &gpu_dst_ptr0, DMAHelper::base(&host_inputs_[0]), src_bytes);
      se::DeviceMemoryBase gpu_dst_ptr1(DMAHelper::base(&gpu_inputs_[1]),
                                        src_bytes);
      gpu_helper_->h2d_stream()->ThenMemcpy(
          &gpu_dst_ptr1, DMAHelper::base(&host_inputs_[1]), src_bytes);
      gpu_helper_->compute_stream()->ThenWaitFor(gpu_helper_->h2d_stream());
      if (times) {
        gpu_helper_->event_mgr()->ThenExecute(
            gpu_helper_->compute_stream(), [times, r]() {
              times->at(r).copy_done = Env::Default()->NowMicros();
            });
      }
      std::unique_ptr<OpKernelContext> ctx;
      for (int apc = 0; apc < adds_per_copy; ++apc) {
        ctx.reset(new OpKernelContext(add_params_[apc], 1));
        gpu_helper_->gpu()->Compute(add_kernels_[apc].get(), ctx.get());
        TF_ASSERT_OK(ctx->status());
        if (event_after_add) {
          gpu_helper_->event_mgr()->ThenExecute(gpu_helper_->compute_stream(),
                                                callback);
        }
      }
      // Finish by copying output back to CPU.
      if (times) {
        gpu_helper_->event_mgr()->ThenExecute(
            gpu_helper_->compute_stream(), [times, r]() {
              times->at(r).compute_done = Env::Default()->NowMicros();
            });
      }
      gpu_helper_->d2h_stream()->ThenWaitFor(gpu_helper_->compute_stream());
      const int64 return_bytes = ctx->mutable_output(0)->TotalBytes();
      se::DeviceMemoryBase gpu_src_ptr(DMAHelper::base(ctx->mutable_output(0)),
                                       return_bytes);
      gpu_helper_->d2h_stream()->ThenMemcpy(DMAHelper::base(&host_outputs_[0]),
                                            gpu_src_ptr, return_bytes);
      gpu_helper_->event_mgr()->ThenExecute(gpu_helper_->d2h_stream(),
                                            callback);
      if (times) {
        gpu_helper_->event_mgr()->ThenExecute(
            gpu_helper_->d2h_stream(), [times, r]() {
              times->at(r).all_done = Env::Default()->NowMicros();
            });
      }
    }
  }
};

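// Measures the overhead of EventMgr::ThenExecute alone: `threads` closures
// each enqueue `iters` no-op callbacks on a single stream, and the benchmark
// waits until every callback has incremented the shared counter.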
static void BM_no_ops(int iters, int threads) {
  testing::StopTiming();
#ifdef PLATFORM_GOOGLE
  BenchmarkUseRealTime();
#else
  testing::UseRealTime();
#endif  // PLATFORM_GOOGLE
  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
  CHECK(stream);
  stream->Init();
  TEST_EventMgr em(stream_exec, GPUOptions());
  testing::StartTiming();
  std::atomic<int> counter;
  counter.store(0, std::memory_order_seq_cst);
  se::Stream* stream_ptr = stream.get();
  auto runner = [&em, &counter, stream_ptr, iters]() {
    auto callback = [&counter]() { counter.fetch_add(1); };
    for (int i = 0; i < iters; ++i) {
      em.ThenExecute(stream_ptr, callback);
    }
  };
  for (int t = 0; t < threads; ++t) {
    Env::Default()->SchedClosure(runner);
  }
  int expected = iters * threads;
  while (counter < expected) {
    Env::Default()->SleepForMicroseconds(1);
  }
}
BENCHMARK(BM_no_ops)->Arg(4);
BENCHMARK(BM_no_ops)->Arg(8);
BENCHMARK(BM_no_ops)->Arg(32);

// Benchmark functions are defined at top level.  In order to provide a real,
// persistent GPUDevice to the following function it also needs to be at top
// level.  But then we can't clean it up without a cuda runtime error, so we
// just leak it.
GPUDeviceTestHelper* gpu_helper = nullptr;
EMBenchmarkHelper* bm_helper = nullptr;
mutex helper_mu;

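// Times DoAddChain: a chain of `adds_per_round` Add ops bracketed by
// host-to-device and device-to-host copies, with EventMgr callbacks after the
// final copy (and optionally after every Add).  Under PLATFORM_GOOGLE the
// benchmark framework supplies the thread fan-out itself; the open-source
// variant takes an explicit `threads` argument instead.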
#ifdef PLATFORM_GOOGLE
static void BM_chain_ops(int iters, int tensor_size, int adds_per_round,
                         bool event_after_add, int pending_cap) {
#else
static void BM_chain_ops(int iters, int tensor_size, int adds_per_round,
                         bool event_after_add, int pending_cap, int threads) {
#endif
  testing::StopTiming();
#ifdef PLATFORM_GOOGLE
  BenchmarkUseRealTime();
#else
  testing::UseRealTime();
#endif  // PLATFORM_GOOGLE
  {
    mutex_lock l(helper_mu);
    if (gpu_helper && gpu_helper->pending_cap() != pending_cap) {
      delete bm_helper;
      bm_helper = nullptr;
      delete gpu_helper;
      gpu_helper = nullptr;
    }
    if (!gpu_helper) {
      gpu_helper = new GPUDeviceTestHelper(1 << 24, pending_cap);
      bm_helper = new EMBenchmarkHelper(gpu_helper);
    }
    if (bm_helper->num_ops() != adds_per_round ||
        bm_helper->tensor_size() != tensor_size) {
      bm_helper->ReInit(adds_per_round, tensor_size);
    }
  }
  std::vector<EMBenchmarkHelper::TimeSet> times;
  std::vector<EMBenchmarkHelper::TimeSet>* time_ptr = nullptr;
  if (VLOG_IS_ON(1)) {
    times.resize(iters);
    time_ptr = &times;
  }
  std::atomic<int> counter;
  counter.store(0, std::memory_order_seq_cst);
  auto callback = [&counter]() { counter.fetch_add(1); };
  // First iter is always slow, so do one prior to the timed loop.
  int expected = 1 + (event_after_add ? adds_per_round : 0);
  bm_helper->DoAddChain(adds_per_round, 1, event_after_add, callback, nullptr);
  while (counter < expected) {
    Env::Default()->SleepForMicroseconds(1);
  }
  counter = 0;
  testing::StartTiming();
#ifdef PLATFORM_GOOGLE
  expected = iters * (1 + (event_after_add ? adds_per_round : 0));
  bm_helper->DoAddChain(adds_per_round, iters, event_after_add, callback,
                        time_ptr);
#else
  expected = threads * iters * (1 + (event_after_add ? adds_per_round : 0));
  for (int i = 0; i < threads; ++i) {
    Env::Default()->SchedClosure(
        [callback, iters, adds_per_round, event_after_add, time_ptr]() {
          bm_helper->DoAddChain(adds_per_round, iters, event_after_add,
                                callback, time_ptr);
        });
  }
#endif
  while (counter < expected) {
    Env::Default()->SleepForMicroseconds(1);
  }
  testing::StopTiming();
  VLOG(1) << "counter = " << counter << " post_execute Output: "
          << bm_helper->host_outputs(0).SummarizeValue(64);
  if (time_ptr) bm_helper->DisplayTimes(time_ptr);
}

#ifdef PLATFORM_GOOGLE
static void BM_chain_1024_1_false(int iters) {
  BM_chain_ops(iters, 1024, 1, false, 0);
}

static void BM_chain_1024_1_true(int iters) {
  BM_chain_ops(iters, 1024, 1, true, 0);
}

static void BM_chain_1024_10_false(int iters) {
  BM_chain_ops(iters, 1024, 10, false, 0);
}

static void BM_chain_1024_10_true(int iters) {
  BM_chain_ops(iters, 1024, 10, true, 0);
}

static void BM_chain_1024_100_false(int iters) {
  BM_chain_ops(iters, 1024, 100, false, 0);
}

static void BM_chain_1024_100_true(int iters) {
  BM_chain_ops(iters, 1024, 100, true, 0);
}

static void BM_chain_1M_1_false(int iters) {
  BM_chain_ops(iters, 1 << 20, 1, false, 0);
}

static void BM_chain_1M_1_true(int iters) {
  BM_chain_ops(iters, 1 << 20, 1, true, 0);
}

static void BM_chain_1M_10_false(int iters) {
  BM_chain_ops(iters, 1 << 20, 10, false, 0);
}

static void BM_chain_1M_10_true(int iters) {
  BM_chain_ops(iters, 1 << 20, 10, true, 0);
}

static void BM_chain_1M_100_false(int iters) {
  BM_chain_ops(iters, 1 << 20, 100, false, 0);
}

static void BM_chain_1M_100_true(int iters) {
  BM_chain_ops(iters, 1 << 20, 100, true, 0);
}

BENCHMARK(BM_chain_1024_1_false)->Threads(1);
BENCHMARK(BM_chain_1024_1_true)->Threads(1);
BENCHMARK(BM_chain_1024_1_false)->Threads(2);
BENCHMARK(BM_chain_1024_1_true)->Threads(2);
BENCHMARK(BM_chain_1024_1_false)->Threads(8);
BENCHMARK(BM_chain_1024_1_true)->Threads(8);
BENCHMARK(BM_chain_1024_10_false)->Threads(1);
BENCHMARK(BM_chain_1024_10_true)->Threads(1);
BENCHMARK(BM_chain_1024_10_false)->Threads(8);
BENCHMARK(BM_chain_1024_10_true)->Threads(8);
BENCHMARK(BM_chain_1024_100_false)->Threads(1);
BENCHMARK(BM_chain_1024_100_true)->Threads(1);
BENCHMARK(BM_chain_1024_100_false)->Threads(2);
BENCHMARK(BM_chain_1024_100_true)->Threads(2);
BENCHMARK(BM_chain_1024_100_false)->Threads(8);
BENCHMARK(BM_chain_1024_100_true)->Threads(8);

BENCHMARK(BM_chain_1M_1_false)->Threads(1);
BENCHMARK(BM_chain_1M_1_true)->Threads(1);
BENCHMARK(BM_chain_1M_1_false)->Threads(2);
BENCHMARK(BM_chain_1M_1_true)->Threads(2);
BENCHMARK(BM_chain_1M_1_false)->Threads(8);
BENCHMARK(BM_chain_1M_1_true)->Threads(8);
BENCHMARK(BM_chain_1M_10_false)->Threads(1);
BENCHMARK(BM_chain_1M_10_true)->Threads(1);
BENCHMARK(BM_chain_1M_10_false)->Threads(8);
BENCHMARK(BM_chain_1M_10_true)->Threads(8);
BENCHMARK(BM_chain_1M_100_false)->Threads(1);
BENCHMARK(BM_chain_1M_100_true)->Threads(1);
BENCHMARK(BM_chain_1M_100_false)->Threads(2);
BENCHMARK(BM_chain_1M_100_true)->Threads(2);
BENCHMARK(BM_chain_1M_100_false)->Threads(8);
BENCHMARK(BM_chain_1M_100_true)->Threads(8);
#else
static void BM_chain_1024_1_false(int iters, int threads) {
  BM_chain_ops(iters, 1024, 1, false, 0, threads);
}

static void BM_chain_1024_1_true(int iters, int threads) {
  BM_chain_ops(iters, 1024, 1, true, 0, threads);
}

static void BM_chain_1024_10_false(int iters, int threads) {
  BM_chain_ops(iters, 1024, 10, false, 0, threads);
}

static void BM_chain_1024_10_true(int iters, int threads) {
  BM_chain_ops(iters, 1024, 10, true, 0, threads);
}

static void BM_chain_1024_100_false(int iters, int threads) {
  BM_chain_ops(iters, 1024, 100, false, 0, threads);
}

static void BM_chain_1024_100_true(int iters, int threads) {
  BM_chain_ops(iters, 1024, 100, true, 0, threads);
}

static void BM_chain_1M_1_false(int iters, int threads) {
  BM_chain_ops(iters, 1 << 20, 1, false, 0, threads);
}

static void BM_chain_1M_1_true(int iters, int threads) {
  BM_chain_ops(iters, 1 << 20, 1, true, 0, threads);
}

static void BM_chain_1M_10_false(int iters, int threads) {
  BM_chain_ops(iters, 1 << 20, 10, false, 0, threads);
}

static void BM_chain_1M_10_true(int iters, int threads) {
  BM_chain_ops(iters, 1 << 20, 10, true, 0, threads);
}

static void BM_chain_1M_100_false(int iters, int threads) {
  BM_chain_ops(iters, 1 << 20, 100, false, 0, threads);
}

static void BM_chain_1M_100_true(int iters, int threads) {
  BM_chain_ops(iters, 1 << 20, 100, true, 0, threads);
}

BENCHMARK(BM_chain_1024_1_false)->Arg(1);
BENCHMARK(BM_chain_1024_1_true)->Arg(1);
BENCHMARK(BM_chain_1024_1_false)->Arg(2);
BENCHMARK(BM_chain_1024_1_true)->Arg(2);
BENCHMARK(BM_chain_1024_1_false)->Arg(8);
BENCHMARK(BM_chain_1024_1_true)->Arg(8);
BENCHMARK(BM_chain_1024_10_false)->Arg(1);
BENCHMARK(BM_chain_1024_10_true)->Arg(1);
BENCHMARK(BM_chain_1024_10_false)->Arg(8);
BENCHMARK(BM_chain_1024_10_true)->Arg(8);
BENCHMARK(BM_chain_1024_100_false)->Arg(1);
BENCHMARK(BM_chain_1024_100_true)->Arg(1);
BENCHMARK(BM_chain_1024_100_false)->Arg(2);
BENCHMARK(BM_chain_1024_100_true)->Arg(2);
BENCHMARK(BM_chain_1024_100_false)->Arg(8);
BENCHMARK(BM_chain_1024_100_true)->Arg(8);

BENCHMARK(BM_chain_1M_1_false)->Arg(1);
BENCHMARK(BM_chain_1M_1_true)->Arg(1);
BENCHMARK(BM_chain_1M_1_false)->Arg(2);
BENCHMARK(BM_chain_1M_1_true)->Arg(2);
BENCHMARK(BM_chain_1M_1_false)->Arg(8);
BENCHMARK(BM_chain_1M_1_true)->Arg(8);
BENCHMARK(BM_chain_1M_10_false)->Arg(1);
BENCHMARK(BM_chain_1M_10_true)->Arg(1);
BENCHMARK(BM_chain_1M_10_false)->Arg(8);
BENCHMARK(BM_chain_1M_10_true)->Arg(8);
BENCHMARK(BM_chain_1M_100_false)->Arg(1);
BENCHMARK(BM_chain_1M_100_true)->Arg(1);
BENCHMARK(BM_chain_1M_100_false)->Arg(2);
BENCHMARK(BM_chain_1M_100_true)->Arg(2);
BENCHMARK(BM_chain_1M_100_false)->Arg(8);
BENCHMARK(BM_chain_1M_100_true)->Arg(8);
#endif
}  // namespace
}  // namespace tensorflow

#endif  // GOOGLE_CUDA