• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/platform/default/test_benchmark.h"
17 
18 #include <algorithm>
19 #include <cstdio>
20 #include <cstdlib>
21 #include <vector>
22 
23 #include "tensorflow/core/platform/env.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/str_util.h"
26 #include "tensorflow/core/util/reporter.h"
27 
28 namespace tensorflow {
29 namespace testing {
namespace internal {

// Out-of-line sink that accepts any character pointer and ignores it.
// Presumably paired with a helper in test_benchmark.h so benchmark results
// are not optimized away — NOTE(review): confirm against the header.
void UseCharPointer(char const volatile* /*unused*/) {
  // Intentionally empty.
}

}  // namespace internal
35 
36 static std::vector<Benchmark*>* all_benchmarks = nullptr;
37 static std::string label;
38 static int64 bytes_processed;
39 static int64 items_processed;
40 static int64 accum_time = 0;
41 static int64 start_time = 0;
42 static Env* env;
43 
Benchmark(const char * name,void (* fn)(int))44 Benchmark::Benchmark(const char* name, void (*fn)(int))
45     : name_(name), num_args_(0), fn0_(fn) {
46   args_.push_back(std::make_pair(-1, -1));
47   Register();
48 }
49 
Benchmark(const char * name,void (* fn)(int,int))50 Benchmark::Benchmark(const char* name, void (*fn)(int, int))
51     : name_(name), num_args_(1), fn1_(fn) {
52   Register();
53 }
54 
Benchmark(const char * name,void (* fn)(int,int,int))55 Benchmark::Benchmark(const char* name, void (*fn)(int, int, int))
56     : name_(name), num_args_(2), fn2_(fn) {
57   Register();
58 }
59 
Benchmark(const char * name,void (* fn)(::testing::benchmark::State &))60 Benchmark::Benchmark(const char* name, void (*fn)(::testing::benchmark::State&))
61     : name_(name),
62       // -1 because the number of parameters is not part of the benchmark
63       // routine signature.
64       num_args_(-1),
65       fn_state_(fn) {
66   Register();
67 }
68 
CheckArgCount(int expected)69 void Benchmark::CheckArgCount(int expected) {
70   if (num_args_ == expected) return;
71 
72   // Number of args is not part of function signature.
73   // Verify that if benchmark instantiation has previously provided args, they
74   // match "args".
75   if (num_args_ < 0) {
76     if (args_.empty() || instantiated_num_args_ == expected) return;
77   }
78   CHECK(false) << "Expected " << expected << " args for benchmark, but got "
79                << instantiated_num_args_;
80 }
81 
Arg(int x)82 Benchmark* Benchmark::Arg(int x) {
83   CheckArgCount(/*expected=*/1);
84   args_.push_back(std::make_pair(x, -1));
85   instantiated_num_args_ = 1;
86   return this;
87 }
88 
ArgPair(int x,int y)89 Benchmark* Benchmark::ArgPair(int x, int y) {
90   CheckArgCount(/*expected=*/2);
91   instantiated_num_args_ = 2;
92   args_.push_back(std::make_pair(x, y));
93   return this;
94 }
95 
UseRealTime()96 Benchmark* Benchmark::UseRealTime() {
97   // Do nothing.
98   // This only exists for API compatibility with internal benchmarks.
99   return this;
100 }
101 
102 namespace {
103 
AddRange(std::vector<int> * dst,int lo,int hi,int mult)104 void AddRange(std::vector<int>* dst, int lo, int hi, int mult) {
105   CHECK_GE(lo, 0);
106   CHECK_GE(hi, lo);
107 
108   // Add "lo"
109   dst->push_back(lo);
110 
111   // Now space out the benchmarks in multiples of "mult"
112   for (int32 i = 1; i < kint32max / mult; i *= mult) {
113     if (i >= hi) break;
114     if (i > lo) {
115       dst->push_back(i);
116     }
117   }
118   // Add "hi" (if different from "lo")
119   if (hi != lo) {
120     dst->push_back(hi);
121   }
122 }
123 
124 }  // namespace
125 
Range(int lo,int hi)126 Benchmark* Benchmark::Range(int lo, int hi) {
127   std::vector<int> args;
128   AddRange(&args, lo, hi, 8);
129   for (int arg : args) {
130     Arg(arg);
131   }
132   return this;
133 }
134 
RangePair(int lo1,int hi1,int lo2,int hi2)135 Benchmark* Benchmark::RangePair(int lo1, int hi1, int lo2, int hi2) {
136   std::vector<int> args1;
137   std::vector<int> args2;
138   AddRange(&args1, lo1, hi1, 8);
139   AddRange(&args2, lo2, hi2, 8);
140   for (int arg1 : args1) {
141     for (int arg2 : args2) {
142       ArgPair(arg1, arg2);
143     }
144   }
145   return this;
146 }
147 
Run(const char * pattern)148 void Benchmark::Run(const char* pattern) {
149   if (!all_benchmarks) return;
150 
151   // Converts "all" into the wildcard '.*'.  Currently pattern isn't
152   // specified by clients, but we keep this here to match the internal
153   // Google implementation, should we ever enable user-specified
154   // pattern specification.
155   if (StringPiece(pattern) == "all") {
156     pattern = ".*";
157   }
158 
159   // Compute name width.
160   int width = 10;
161   string name;
162   for (auto b : *all_benchmarks) {
163     name = b->name_;
164     for (auto arg : b->args_) {
165       name.resize(b->name_.size());
166       if (arg.first >= 0) {
167         strings::StrAppend(&name, "/", arg.first);
168         if (arg.second >= 0) {
169           strings::StrAppend(&name, "/", arg.second);
170         }
171       }
172 
173       // TODO(vrv): Check against 'pattern' using a regex before
174       // computing the width, if we start allowing clients to pass in
175       // a custom pattern.
176       width = std::max<int>(width, name.size());
177     }
178   }
179 
180   printf("%-*s %10s %10s\n", width, "Benchmark", "Time(ns)", "Iterations");
181   printf("%s\n", string(width + 22, '-').c_str());
182   for (auto b : *all_benchmarks) {
183     name = b->name_;
184     if (b->instantiated_num_args_ == -1 && b->args_.empty()) {
185       // The BM_*(int) interface (ie, benchmark without params) automatically
186       // adds a default (-1, -1) arg pair to b->args_.
187       // The BM_(benchmark::State&) interface does not do this because it does
188       // not know how many parameters are going to be registered.
189       // So we just add the place holder here.
190       b->args_.push_back(std::make_pair(-1, -1));
191     }
192     for (auto arg : b->args_) {
193       name.resize(b->name_.size());
194       if (arg.first >= 0) {
195         strings::StrAppend(&name, "/", arg.first);
196         if (arg.second >= 0) {
197           strings::StrAppend(&name, "/", arg.second);
198         }
199       }
200 
201       // TODO(vrv): Match 'name' against 'pattern' using a regex
202       // before continuing, if we start allowing clients to pass in a
203       // custom pattern.
204 
205       int iters;
206       double seconds;
207       b->Run(arg.first, arg.second, &iters, &seconds);
208 
209       char buf[100];
210       std::string full_label = label;
211       if (bytes_processed > 0) {
212         snprintf(buf, sizeof(buf), " %.5fMB/s",
213                  (bytes_processed * 1e-6) / seconds);
214         full_label += buf;
215       }
216       if (items_processed > 0) {
217         snprintf(buf, sizeof(buf), " %.5fM items/s",
218                  (items_processed * 1e-6) / seconds);
219         full_label += buf;
220       }
221       printf("%-*s %10.0f %10d\t%s\n", width, name.c_str(),
222              seconds * 1e9 / iters, iters, full_label.c_str());
223 
224       TestReporter reporter(name);
225       Status s = reporter.Initialize();
226       if (!s.ok()) {
227         LOG(ERROR) << s.ToString();
228         exit(EXIT_FAILURE);
229       }
230       s = reporter.Benchmark(iters, 0.0, seconds,
231                              items_processed * 1e-6 / seconds);
232       if (!s.ok()) {
233         LOG(ERROR) << s.ToString();
234         exit(EXIT_FAILURE);
235       }
236       s = reporter.Close();
237       if (!s.ok()) {
238         LOG(ERROR) << s.ToString();
239         exit(EXIT_FAILURE);
240       }
241     }
242   }
243 }
244 
Register()245 void Benchmark::Register() {
246   if (!all_benchmarks) all_benchmarks = new std::vector<Benchmark*>;
247   all_benchmarks->push_back(this);
248 }
249 
Run(int arg1,int arg2,int * run_count,double * run_seconds)250 void Benchmark::Run(int arg1, int arg2, int* run_count, double* run_seconds) {
251   env = Env::Default();
252   static const int64 kMinIters = 100;
253   static const int64 kMaxIters = 1000000000;
254   static const double kMinTime = 0.5;
255   int64 iters = kMinIters;
256 
257   while (true) {
258     accum_time = 0;
259     start_time = env->NowMicros();
260     bytes_processed = -1;
261     items_processed = -1;
262     label.clear();
263     if (fn0_) {
264       (*fn0_)(iters);
265     } else if (fn1_) {
266       (*fn1_)(iters, arg1);
267     } else if (fn2_) {
268       (*fn2_)(iters, arg1, arg2);
269     } else if (fn_state_) {
270       std::vector<int> arg_list = {arg1, arg2};
271       ::testing::benchmark::State state(iters, instantiated_num_args_,
272                                         std::move(arg_list));
273       (*fn_state_)(state);
274     }
275     StopTiming();
276     const double seconds = accum_time * 1e-6;
277     if (seconds >= kMinTime || iters >= kMaxIters) {
278       *run_count = iters;
279       *run_seconds = seconds;
280       return;
281     }
282 
283     // Update number of iterations.  Overshoot by 40% in an attempt
284     // to succeed the next time.
285     double multiplier = 1.4 * kMinTime / std::max(seconds, 1e-9);
286     multiplier = std::min(10.0, multiplier);
287     if (multiplier <= 1.0) multiplier *= 2.0;
288     iters = std::max<int64>(multiplier * iters, iters + 1);
289     iters = std::min(iters, kMaxIters);
290   }
291 }
292 
293 // TODO(vrv): Add support for running a subset of benchmarks by having
294 // RunBenchmarks take in a spec (and maybe other options such as
295 // benchmark_min_time, etc).
RunBenchmarks()296 void RunBenchmarks() { Benchmark::Run("all"); }
SetLabel(const std::string & l)297 void SetLabel(const std::string& l) { label = l; }
BytesProcessed(int64 n)298 void BytesProcessed(int64 n) { bytes_processed = n; }
ItemsProcessed(int64 n)299 void ItemsProcessed(int64 n) { items_processed = n; }
StartTiming()300 void StartTiming() {
301   if (start_time == 0) start_time = env->NowMicros();
302 }
StopTiming()303 void StopTiming() {
304   if (start_time != 0) {
305     accum_time += (env->NowMicros() - start_time);
306     start_time = 0;
307   }
308 }
// No-op kept for API compatibility with the internal benchmark library.
void UseRealTime() {
  // Intentionally empty.
}
310 
311 }  // namespace testing
312 }  // namespace tensorflow
313 
314 namespace testing {
315 namespace benchmark {
State(size_t max_iterations,int formal_arg_count,std::vector<int> args)316 State::State(size_t max_iterations, int formal_arg_count, std::vector<int> args)
317     : max_iterations(max_iterations),
318       formal_arg_count_(formal_arg_count),
319       args_(std::move(args)) {
320   completed_iterations_ = 0;
321 }
322 
PauseTiming()323 void State::PauseTiming() { ::tensorflow::testing::StopTiming(); }
324 
ResumeTiming()325 void State::ResumeTiming() { ::tensorflow::testing::StartTiming(); }
326 
SetBytesProcessed(::tensorflow::int64 bytes)327 void State::SetBytesProcessed(::tensorflow::int64 bytes) {
328   ::tensorflow::testing::BytesProcessed(bytes);
329 }
330 
SetItemsProcessed(::tensorflow::int64 items)331 void State::SetItemsProcessed(::tensorflow::int64 items) {
332   ::tensorflow::testing::ItemsProcessed(items);
333 }
334 
SetLabel(absl::string_view label)335 void State::SetLabel(absl::string_view label) {
336   ::tensorflow::testing::SetLabel(std::string(label));
337 }
338 
range(size_t i) const339 int State::range(size_t i) const {
340   if (i >= formal_arg_count_) {
341     LOG(FATAL) << "argument for range " << i << " is not set";
342   }
343   return args_[i];
344 }
345 
RunSpecifiedBenchmarks()346 void RunSpecifiedBenchmarks() { ::tensorflow::testing::Benchmark::Run("all"); }
347 
348 }  // namespace benchmark
349 }  // namespace testing
350