/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <atomic>

#include "tensorflow/core/common_runtime/device/device_event_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

// Subclass EventMgr to access its private constructor.
class TEST_EventMgr : public EventMgr {
 public:
  TEST_EventMgr(se::StreamExecutor* se, const GPUOptions& gpu_options)
      : EventMgr(se, gpu_options) {}
};

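// Gives tests access to EventMgr internals: the used/free event queue sizes,
// manual event polling, and control of the background polling loop.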
class TEST_EventMgrHelper {
 public:
  explicit TEST_EventMgrHelper(EventMgr* em) : em_(em) {
    // The polling loop can interfere with the measurements made here, and
    // isn't needed since the member PollEvents() always clears the queue.
    // The tested behavior is slightly different from what may occur in
    // ordinary execution.
    StopPollingLoop();
  }

  size_t queue_size() {
    mutex_lock l(em_->mu_);
    return em_->used_events_.size();
  }

  size_t free_size() {
    mutex_lock l(em_->mu_);
    return em_->free_events_.size();
  }

  void PollEvents() {
    while (queue_size() > 0) {
      // For ordinary tensor frees, this function
      // should synchronously harvest all complete
      // events and execute the corresponding memory frees.
      EventMgr::ToFreeVector to_free;
      {
        mutex_lock l(em_->mu_);
        em_->PollEvents(true, &to_free);
      }
      em_->FreeMemory(to_free);
    }
  }

  void StopPollingLoop() { return em_->StopPollingLoop(); }

  void StartPollingLoop() { return em_->StartPollingLoop(); }

 private:
  EventMgr* em_;
};

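// Total bytes held by live TestTensorBuffer instances; updated by the
// TestTensorBuffer constructor and destructor below.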
static std::atomic_int_fast64_t live_tensor_bytes(0);

// A TensorBuffer that counts live memory usage for testing
class TestTensorBuffer : public TensorBuffer {
 public:
  explicit TestTensorBuffer(size_t bytes)
      : TensorBuffer(nullptr), bytes_(bytes) {
    live_tensor_bytes += bytes_;
  }
  ~TestTensorBuffer() override { live_tensor_bytes -= bytes_; }

  size_t size() const override { return bytes_; }

  // Not used in this test
  TensorBuffer* root_buffer() override { return nullptr; }
  void FillAllocationDescription(AllocationDescription* arg) const override {}

 private:
  size_t bytes_;
};

namespace {

TEST(EventMgr, Empty) {
  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  TEST_EventMgr em(stream_exec, GPUOptions());
  TEST_EventMgrHelper th(&em);
  EXPECT_EQ(0, th.queue_size());
  EXPECT_EQ(0, th.free_size());
}

// Tests that WarnIfInCallback() triggers correctly.
TEST(EventMgr, WarnIfInCallback) {
  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  TEST_EventMgr em(stream_exec, GPUOptions());
  TEST_EventMgrHelper th(&em);
  std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
  CHECK(stream);
  stream->Init();
  bool hit = false;
  th.StartPollingLoop();
  device_event_mgr::WarnIfInCallback([&hit] { hit = true; });
  EXPECT_FALSE(hit);
  Notification note;
  em.ThenExecute(stream.get(), [&hit, &note]() {
    device_event_mgr::WarnIfInCallback([&hit, &note] {
      hit = true;
      note.Notify();
    });
  });
  note.WaitForNotification();
  EXPECT_TRUE(hit);
}
}  // namespace

// Provides access to private resources of BaseGPUDevice.
class GPUDeviceTestHelper {
 public:
  GPUDeviceTestHelper(size_t memory_limit, int pending_cap) {
    SessionOptions sops;
    device_ =
        DeviceFactory::NewDevice(DEVICE_GPU, sops, "/job:a/replica:0/task:0");
    gpu_.reset(reinterpret_cast<BaseGPUDevice*>(device_.release()));
    gpu_allocator_ = GPUProcessState::singleton()->GetGPUAllocator(
        GPUOptions(), TfDeviceId(0), memory_limit, /*peer_gpu_ids=*/{});
    host_allocator_ = GPUProcessState::singleton()->GetGpuHostAllocator(0);
  }

  BaseGPUDevice* gpu() { return gpu_.get(); }
  Allocator* gpu_allocator() { return gpu_allocator_; }
  Allocator* host_allocator() { return host_allocator_; }
  se::Stream* compute_stream() { return gpu_->stream_->compute; }
  se::Stream* h2d_stream() { return gpu_->stream_->host_to_device; }
  se::Stream* d2h_stream() { return gpu_->stream_->device_to_host; }
  se::Stream* d2d_stream() { return gpu_->stream_->device_to_device[0]; }
  EventMgr* event_mgr() { return gpu_->em_; }
  int pending_cap() { return gpu_->pending_cap_; }

 private:
  std::unique_ptr<Device> device_;
  std::unique_ptr<BaseGPUDevice> gpu_;
  Allocator* gpu_allocator_;
  Allocator* host_allocator_;
};

namespace {

// Class that can queue some GPU data transfers and simple kernels.
class EMBenchmarkHelper {
  GPUDeviceTestHelper* gpu_helper_;
  // We need one of these for each Add op in the chain.
  std::vector<std::unique_ptr<OpKernel>> add_kernels_;
  std::vector<OpKernelContext::Params*> add_params_;
  std::vector<std::unique_ptr<OpKernelContext>> add_contexts_;
  // The rest of these are one per chain.
  NodeDef add_node_def_;
  NodeDef id_node_def_;
  gtl::InlinedVector<TensorValue, 4> add_inputs_;
  std::vector<AllocatorAttributes> allocator_attrs_;
  gtl::InlinedVector<Tensor, 4> gpu_inputs_;
  gtl::InlinedVector<Tensor, 4> gpu_outputs_;
  gtl::InlinedVector<Tensor, 4> host_inputs_;
  gtl::InlinedVector<Tensor, 4> host_outputs_;

 public:
  // Length of tensors.  TODO(tucker): make this a variable parameter.
  static constexpr int kTDim = 1024;

  int num_ops() const { return add_kernels_.size(); }
  size_t tensor_size() const {
    return add_inputs_.empty() ? 0 : add_inputs_[0]->NumElements();
  }

  Tensor& host_outputs(int i) { return host_outputs_[i]; }
  Tensor& host_inputs(int i) { return host_inputs_[i]; }

  EMBenchmarkHelper(GPUDeviceTestHelper* h) : gpu_helper_(h) {}

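  // Rebuilds the GPU and host input/output tensors at the requested size and
  // creates num_ops Add kernels along with their OpKernelContext::Params.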
  void ReInit(int num_ops, int tensor_size) {
    gpu_inputs_.clear();
    while (gpu_inputs_.size() < 2) {
      gpu_inputs_.push_back(Tensor(gpu_helper_->gpu_allocator(), DT_FLOAT,
                                   {tensor_size}, AllocationAttributes()));
    }
    gpu_outputs_.clear();
    while (gpu_outputs_.size() < 1) {
      gpu_outputs_.push_back(Tensor(gpu_helper_->gpu_allocator(), DT_FLOAT,
                                    {tensor_size}, AllocationAttributes()));
    }
    host_inputs_.clear();
    while (host_inputs_.size() < 2) {
      int instance_index = host_inputs_.size();
      host_inputs_.push_back(Tensor(gpu_helper_->host_allocator(), DT_FLOAT,
                                    {tensor_size}, AllocationAttributes()));
      for (int i = 0; i < tensor_size; ++i) {
        host_inputs_.back().flat<float>()(i) =
            i * (1.0 + (0.5 * instance_index));
      }
    }
    host_outputs_.clear();
    while (host_outputs_.size() < 1) {
      host_outputs_.push_back(Tensor(gpu_helper_->host_allocator(), DT_FLOAT,
                                     {tensor_size}, AllocationAttributes()));
      for (int i = 0; i < tensor_size; ++i) {
        host_outputs_.back().flat<float>()(i) = -1;
      }
    }
    add_kernels_.clear();
    add_params_.clear();
    while (add_kernels_.size() < num_ops) {
      MakeAddOp();
    }
  }

  std::unique_ptr<OpKernel> GetOpKernel(const NodeDef& node_def,
                                        Status* status) {
    return CreateOpKernel("GPU", gpu_helper_->gpu(),
                          gpu_helper_->gpu_allocator(), node_def,
                          TF_GRAPH_DEF_VERSION, status);
  }

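  // Appends one Add kernel, built from add_node_def_, together with the
  // OpKernelContext::Params needed to run it.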
  void MakeAddOp() {
    if (add_kernels_.empty()) {
      TF_ASSERT_OK(NodeDefBuilder("add_op", "Add")
                       .Input(FakeInput(DT_FLOAT))
                       .Input(FakeInput(DT_FLOAT))
                       .Device("/job:a/replica:0/task:0/GPU:0")
                       .Finalize(&add_node_def_));
    }
    Status status;
    add_kernels_.emplace_back(GetOpKernel(add_node_def_, &status));
    TF_ASSERT_OK(status);
    add_params_.push_back(new OpKernelContext::Params);
    PrepOpKernel(add_params_.back(), add_kernels_.back().get());
  }

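  // Fills `attrs` with one AllocatorAttributes entry per kernel output,
  // marking whether the output lives in host memory, and points
  // params->output_attr_array at the result.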
  void SetOutputAttrs(OpKernelContext::Params* params,
                      std::vector<AllocatorAttributes>* attrs) {
    attrs->clear();
    for (int index = 0; index < params->op_kernel->num_outputs(); index++) {
      AllocatorAttributes attr;
      const bool on_host =
          (params->op_kernel->output_memory_types()[index] == HOST_MEMORY);
      attr.set_on_host(on_host);
      attrs->push_back(attr);
    }
    params->output_attr_array = attrs->data();
    params->forward_from_array = {};
  }

  void PrepOpKernel(OpKernelContext::Params* params, OpKernel* kernel) {
    // This mimics what happens in ExecutorState::Process to run
    // a single graph node.
    params->step_id = 1;
    params->device = gpu_helper_->gpu();
    params->log_memory = false;
    params->rendezvous = nullptr;
    params->collective_executor = nullptr;
    params->session_state = nullptr;  // ???
    params->session_handle = "session_handle";
    params->tensor_store = nullptr;
    params->cancellation_manager = nullptr;

    params->call_frame = nullptr;
    params->function_library = nullptr;
    params->runner = nullptr;
    params->graph_collector = nullptr;

    params->step_container = nullptr;
    params->slice_reader_cache = nullptr;
    params->resource_manager = gpu_helper_->gpu()->resource_manager();

    params->stats_collector = nullptr;
    params->inc_num_deferred_ops_function = nullptr;
    params->dec_num_deferred_ops_function = nullptr;

    params->op_device_context = nullptr;
    params->track_allocations = false;
    params->op_kernel = kernel;
    params->frame_iter = FrameAndIter(0, 0);
    params->is_input_dead = false;

    if (add_inputs_.empty()) {
      add_inputs_.resize(2);
      add_inputs_[0] = TensorValue(&gpu_inputs_[0]);
      add_inputs_[1] = TensorValue(&gpu_inputs_[1]);
    }
    params->inputs = &add_inputs_;
    params->input_alloc_attrs = nullptr;
    SetOutputAttrs(params, &allocator_attrs_);
  }

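  // Microsecond timestamps sampled at stages of one DoAddChain round;
  // DisplayTimes() converts them to per-stage durations before printing.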
  struct TimeSet {
    int iter = 0;
    int64 start = 0;
    int64 copy_done = 0;
    int64 compute_done = 0;
    int64 final_copy = 0;
    int64 all_done = 0;
  };

  // Display sampled iteration times giving the approximate breakdown
  // within iterations and overall curve.
  void DisplayTimes(std::vector<TimeSet>* times) {
    LOG(INFO) << "Summarize set of " << times->size() << " iters";
    for (auto& ts : *times) {
      ts.final_copy = ts.all_done - ts.compute_done;
      ts.compute_done = ts.compute_done - ts.copy_done;
      ts.copy_done = ts.copy_done - ts.start;
      ts.all_done = ts.all_done - ts.start;
    }
    struct TSSort {
      bool operator()(const TimeSet& a, const TimeSet& b) {
        return a.all_done < b.all_done;
      }
    };
    std::sort(times->begin(), times->end(), TSSort());
    int64_t last_time = 0;
    // Display first, last and every > 5% change.
    for (int i = 0; i < times->size(); ++i) {
      if (i == (times->size() - 1) ||
          (times->at(i).all_done >= (1.05 * last_time))) {
        LOG(INFO) << "rank " << i << " iter: " << times->at(i).iter
                  << " copy: " << times->at(i).copy_done
                  << " compute: " << times->at(i).compute_done
                  << " copy back: " << times->at(i).final_copy
                  << " sum: " << times->at(i).all_done;
        last_time = times->at(i).all_done;
      }
    }
  }

  // Queue one work unit on the GPU as follows:
  // 1. Copy 2 input tensors from CPU to GPU using h2d stream.
  // 2. Instruct compute stream to wait on h2d stream.
  // 3. Queue a sequence of Add ops on the compute stream, all using
  //    the same input tensors, allocating their own output tensors.
  // 4. Instruct d2h stream to wait on the compute stream.
  // 5. Copy final output tensor back to the CPU.
  // 6. Instruct the EventMgr to execute callback when the final tensor
  //    copy completes.
  // If event_after_add == true then additionally instruct the EventMgr
  //    to execute the callback after each Add completes.
  // The optional times parameter is used for gathering detailed timing
  // data.
  void DoAddChain(int adds_per_copy, int rounds, bool event_after_add,
                  std::function<void()> callback, std::vector<TimeSet>* times) {
    // Take an extra ref on the inputs so that the add doesn't compute in place.
    Tensor alias0(gpu_inputs_[0]);
    Tensor alias1(gpu_inputs_[1]);
    for (int r = 0; r < rounds; ++r) {
      if (times) {
        times->at(r).iter = r;
        times->at(r).start = Env::Default()->NowMicros();
      }
      gpu_helper_->h2d_stream()->ThenWaitFor(gpu_helper_->compute_stream());
      // Begin by copying the input values from CPU to GPU.
      const int64_t src_bytes = host_inputs_[0].TotalBytes();
      se::DeviceMemoryBase gpu_dst_ptr0(DMAHelper::base(&gpu_inputs_[0]),
                                        src_bytes);
      gpu_helper_->h2d_stream()->ThenMemcpy(
          &gpu_dst_ptr0, DMAHelper::base(&host_inputs_[0]), src_bytes);
      se::DeviceMemoryBase gpu_dst_ptr1(DMAHelper::base(&gpu_inputs_[1]),
                                        src_bytes);
      gpu_helper_->h2d_stream()->ThenMemcpy(
          &gpu_dst_ptr1, DMAHelper::base(&host_inputs_[1]), src_bytes);
      gpu_helper_->compute_stream()->ThenWaitFor(gpu_helper_->h2d_stream());
      if (times) {
        gpu_helper_->event_mgr()->ThenExecute(
            gpu_helper_->compute_stream(), [times, r]() {
              times->at(r).copy_done = Env::Default()->NowMicros();
            });
      }
      std::unique_ptr<OpKernelContext> ctx;
      for (int apc = 0; apc < adds_per_copy; ++apc) {
        ctx.reset(new OpKernelContext(add_params_[apc], 1));
        gpu_helper_->gpu()->Compute(add_kernels_[apc].get(), ctx.get());
        TF_ASSERT_OK(ctx->status());
        if (event_after_add) {
          gpu_helper_->event_mgr()->ThenExecute(gpu_helper_->compute_stream(),
                                                callback);
        }
      }
      // Finish by copying output back to CPU.
      if (times) {
        gpu_helper_->event_mgr()->ThenExecute(
            gpu_helper_->compute_stream(), [times, r]() {
              times->at(r).compute_done = Env::Default()->NowMicros();
            });
      }
      gpu_helper_->d2h_stream()->ThenWaitFor(gpu_helper_->compute_stream());
      const int64_t return_bytes = ctx->mutable_output(0)->TotalBytes();
      se::DeviceMemoryBase gpu_src_ptr(DMAHelper::base(ctx->mutable_output(0)),
                                       return_bytes);
      gpu_helper_->d2h_stream()->ThenMemcpy(DMAHelper::base(&host_outputs_[0]),
                                            gpu_src_ptr, return_bytes);
      gpu_helper_->event_mgr()->ThenExecute(gpu_helper_->d2h_stream(),
                                            callback);
      if (times) {
        gpu_helper_->event_mgr()->ThenExecute(
            gpu_helper_->d2h_stream(), [times, r]() {
              times->at(r).all_done = Env::Default()->NowMicros();
            });
      }
    }
  }
};

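// Measures the overhead of the EventMgr callback path by itself: each of
// `threads` scheduled closures enqueues the benchmark's iteration count of
// trivial callbacks on a single stream, and we spin until all have fired.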
static void BM_no_ops(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  const int iters = state.max_iterations;

  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
  CHECK(stream);
  stream->Init();
  TEST_EventMgr em(stream_exec, GPUOptions());

  auto benchmark_exec = [&]() {
    std::atomic<int> counter;
    counter.store(0, std::memory_order_seq_cst);
    se::Stream* stream_ptr = stream.get();
    auto runner = [&em, &counter, stream_ptr, iters]() {
      auto callback = [&counter]() { counter.fetch_add(1); };
      for (int i = 0; i < iters; ++i) {
        em.ThenExecute(stream_ptr, callback);
      }
    };
    for (int t = 0; t < threads; ++t) {
      Env::Default()->SchedClosure(runner);
    }
    int expected = iters * threads;
    while (counter < expected) {
      Env::Default()->SleepForMicroseconds(1);
    }
  };

#ifdef PLATFORM_GOOGLE

  // The timer starts automatically
  while (state.KeepRunningBatch(state.max_iterations)) {
    benchmark_exec();
  }
#else
  // TensorFlow's own benchmark implementation does not support
  // KeepRunningBatch yet, so we pause and resume the timer around the run.
  // FIXME: Remove this ifdef once all of TensorFlow's benchmarks have been
  // switched to the OSS benchmark library.

  state.ResumeTiming();
  benchmark_exec();
  state.PauseTiming();
#endif
}
BENCHMARK(BM_no_ops)->UseRealTime()->Arg(4)->Arg(8)->Arg(32);

// Benchmark functions are defined at top level.  In order to provide a real,
// persistent GPUDevice to the functions below, these helpers also need to be
// at top level.  But then we can't clean them up without a CUDA runtime
// error, so we just leak them.
GPUDeviceTestHelper* gpu_helper = nullptr;
EMBenchmarkHelper* bm_helper = nullptr;
mutex helper_mu;

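// Shared driver for the BM_chain_* benchmarks: (re)builds the leaked helpers
// if pending_cap, tensor_size, or adds_per_round changed, runs one untimed
// warm-up round, then times DoAddChain while waiting for every queued
// callback to complete.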
#ifdef PLATFORM_GOOGLE
static void BM_chain_ops(::testing::benchmark::State& state, int tensor_size,
                         int adds_per_round, bool event_after_add,
                         int pending_cap) {
#else
static void BM_chain_ops(::testing::benchmark::State& state, int tensor_size,
                         int adds_per_round, bool event_after_add,
                         int pending_cap, int threads) {
#endif
  const int iters = state.max_iterations;
  {
    mutex_lock l(helper_mu);
    if (gpu_helper && gpu_helper->pending_cap() != pending_cap) {
      delete bm_helper;
      bm_helper = nullptr;
      delete gpu_helper;
      gpu_helper = nullptr;
    }
    if (!gpu_helper) {
      gpu_helper = new GPUDeviceTestHelper(1 << 24, pending_cap);
      bm_helper = new EMBenchmarkHelper(gpu_helper);
    }
    if (bm_helper->num_ops() != adds_per_round ||
        bm_helper->tensor_size() != tensor_size) {
      bm_helper->ReInit(adds_per_round, tensor_size);
    }
  }
  std::vector<EMBenchmarkHelper::TimeSet> times;
  std::vector<EMBenchmarkHelper::TimeSet>* time_ptr = nullptr;
  if (VLOG_IS_ON(1)) {
    times.resize(iters);
    time_ptr = &times;
  }
  std::atomic<int> counter;
  counter.store(0, std::memory_order_seq_cst);
  auto callback = [&counter]() { counter.fetch_add(1); };
  // First iter is always slow, so do one prior to the timed loop.
  int expected = 1 + (event_after_add ? adds_per_round : 0);
  bm_helper->DoAddChain(adds_per_round, 1, event_after_add, callback, nullptr);
  while (counter < expected) {
    Env::Default()->SleepForMicroseconds(1);
  }
  counter = 0;

#ifdef PLATFORM_GOOGLE
  while (state.KeepRunningBatch(state.max_iterations)) {
    expected = iters * (1 + (event_after_add ? adds_per_round : 0));
    bm_helper->DoAddChain(adds_per_round, iters, event_after_add, callback,
                          time_ptr);
    while (counter < expected) {
      Env::Default()->SleepForMicroseconds(1);
    }
  }
#else
  state.ResumeTiming();
  expected = threads * iters * (1 + (event_after_add ? adds_per_round : 0));
  for (int i = 0; i < threads; ++i) {
    Env::Default()->SchedClosure(
        [callback, iters, adds_per_round, event_after_add, time_ptr]() {
          bm_helper->DoAddChain(adds_per_round, iters, event_after_add,
                                callback, time_ptr);
        });
  }
  while (counter < expected) {
    Env::Default()->SleepForMicroseconds(1);
  }
  state.PauseTiming();
#endif
  VLOG(1) << "counter = " << counter << " post_execute Output: "
          << bm_helper->host_outputs(0).SummarizeValue(64);
  if (time_ptr) bm_helper->DisplayTimes(time_ptr);
}

#ifdef PLATFORM_GOOGLE
static void BM_chain_1024_1_false(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1024, 1, false, 0);
}

static void BM_chain_1024_1_true(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1024, 1, true, 0);
}

static void BM_chain_1024_10_false(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1024, 10, false, 0);
}

static void BM_chain_1024_10_true(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1024, 10, true, 0);
}

static void BM_chain_1024_100_false(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1024, 100, false, 0);
}

static void BM_chain_1024_100_true(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1024, 100, true, 0);
}

static void BM_chain_1M_1_false(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1 << 20, 1, false, 0);
}

static void BM_chain_1M_1_true(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1 << 20, 1, true, 0);
}

static void BM_chain_1M_10_false(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1 << 20, 10, false, 0);
}

static void BM_chain_1M_10_true(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1 << 20, 10, true, 0);
}

static void BM_chain_1M_100_false(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1 << 20, 100, false, 0);
}

static void BM_chain_1M_100_true(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1 << 20, 100, true, 0);
}

BENCHMARK(BM_chain_1024_1_false)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1024_1_true)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1024_1_false)->UseRealTime()->Threads(2);
BENCHMARK(BM_chain_1024_1_true)->UseRealTime()->Threads(2);
BENCHMARK(BM_chain_1024_1_false)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1024_1_true)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1024_10_false)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1024_10_true)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1024_10_false)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1024_10_true)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1024_100_false)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1024_100_true)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1024_100_false)->UseRealTime()->Threads(2);
BENCHMARK(BM_chain_1024_100_true)->UseRealTime()->Threads(2);
BENCHMARK(BM_chain_1024_100_false)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1024_100_true)->UseRealTime()->Threads(8);

BENCHMARK(BM_chain_1M_1_false)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1M_1_true)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1M_1_false)->UseRealTime()->Threads(2);
BENCHMARK(BM_chain_1M_1_true)->UseRealTime()->Threads(2);
BENCHMARK(BM_chain_1M_1_false)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1M_1_true)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1M_10_false)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1M_10_true)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1M_10_false)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1M_10_true)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1M_100_false)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1M_100_true)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1M_100_false)->UseRealTime()->Threads(2);
BENCHMARK(BM_chain_1M_100_true)->UseRealTime()->Threads(2);
BENCHMARK(BM_chain_1M_100_false)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1M_100_true)->UseRealTime()->Threads(8);
#else
static void BM_chain_1024_1_false(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1024, 1, false, 0, threads);
}

static void BM_chain_1024_1_true(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1024, 1, true, 0, threads);
}

static void BM_chain_1024_10_false(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1024, 10, false, 0, threads);
}

static void BM_chain_1024_10_true(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1024, 10, true, 0, threads);
}

static void BM_chain_1024_100_false(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1024, 100, false, 0, threads);
}

static void BM_chain_1024_100_true(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1024, 100, true, 0, threads);
}

static void BM_chain_1M_1_false(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1 << 20, 1, false, 0, threads);
}

static void BM_chain_1M_1_true(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1 << 20, 1, true, 0, threads);
}

static void BM_chain_1M_10_false(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1 << 20, 10, false, 0, threads);
}

static void BM_chain_1M_10_true(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1 << 20, 10, true, 0, threads);
}

static void BM_chain_1M_100_false(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1 << 20, 100, false, 0, threads);
}

static void BM_chain_1M_100_true(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1 << 20, 100, true, 0, threads);
}

BENCHMARK(BM_chain_1024_1_false)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1024_1_true)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1024_1_false)->UseRealTime()->Arg(2);
BENCHMARK(BM_chain_1024_1_true)->UseRealTime()->Arg(2);
BENCHMARK(BM_chain_1024_1_false)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1024_1_true)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1024_10_false)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1024_10_true)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1024_10_false)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1024_10_true)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1024_100_false)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1024_100_true)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1024_100_false)->UseRealTime()->Arg(2);
BENCHMARK(BM_chain_1024_100_true)->UseRealTime()->Arg(2);
BENCHMARK(BM_chain_1024_100_false)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1024_100_true)->UseRealTime()->Arg(8);

BENCHMARK(BM_chain_1M_1_false)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1M_1_true)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1M_1_false)->UseRealTime()->Arg(2);
BENCHMARK(BM_chain_1M_1_true)->UseRealTime()->Arg(2);
BENCHMARK(BM_chain_1M_1_false)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1M_1_true)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1M_10_false)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1M_10_true)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1M_10_false)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1M_10_true)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1M_100_false)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1M_100_true)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1M_100_false)->UseRealTime()->Arg(2);
BENCHMARK(BM_chain_1M_100_true)->UseRealTime()->Arg(2);
BENCHMARK(BM_chain_1M_100_false)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1M_100_true)->UseRealTime()->Arg(8);
#endif
}  // namespace
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM