1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/platform/default/test_benchmark.h"
17
18 #include <algorithm>
19 #include <cstdio>
20 #include <cstdlib>
21 #include <vector>
22
23 #include "tensorflow/core/platform/env.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/str_util.h"
26 #include "tensorflow/core/util/reporter.h"
27
28 namespace tensorflow {
29 namespace testing {
30 namespace internal {
31
// Accepts a pointer and ignores it; a caller can pass a value's address here
// so the compiler treats that value as used.
void UseCharPointer(const volatile char* /*unused*/) {}
33
34 } // namespace internal
35
36 static std::vector<Benchmark*>* all_benchmarks = nullptr;
37 static std::string label;
38 static int64 bytes_processed;
39 static int64 items_processed;
40 static int64 accum_time = 0;
41 static int64 start_time = 0;
42 static Env* env;
43
Benchmark(const char * name,void (* fn)(int))44 Benchmark::Benchmark(const char* name, void (*fn)(int))
45 : name_(name), num_args_(0), fn0_(fn) {
46 args_.push_back(std::make_pair(-1, -1));
47 Register();
48 }
49
Benchmark(const char * name,void (* fn)(int,int))50 Benchmark::Benchmark(const char* name, void (*fn)(int, int))
51 : name_(name), num_args_(1), fn1_(fn) {
52 Register();
53 }
54
Benchmark(const char * name,void (* fn)(int,int,int))55 Benchmark::Benchmark(const char* name, void (*fn)(int, int, int))
56 : name_(name), num_args_(2), fn2_(fn) {
57 Register();
58 }
59
Benchmark(const char * name,void (* fn)(::testing::benchmark::State &))60 Benchmark::Benchmark(const char* name, void (*fn)(::testing::benchmark::State&))
61 : name_(name),
62 // -1 because the number of parameters is not part of the benchmark
63 // routine signature.
64 num_args_(-1),
65 fn_state_(fn) {
66 Register();
67 }
68
CheckArgCount(int expected)69 void Benchmark::CheckArgCount(int expected) {
70 if (num_args_ == expected) return;
71
72 // Number of args is not part of function signature.
73 // Verify that if benchmark instantiation has previously provided args, they
74 // match "args".
75 if (num_args_ < 0) {
76 if (args_.empty() || instantiated_num_args_ == expected) return;
77 }
78 CHECK(false) << "Expected " << expected << " args for benchmark, but got "
79 << instantiated_num_args_;
80 }
81
Arg(int x)82 Benchmark* Benchmark::Arg(int x) {
83 CheckArgCount(/*expected=*/1);
84 args_.push_back(std::make_pair(x, -1));
85 instantiated_num_args_ = 1;
86 return this;
87 }
88
ArgPair(int x,int y)89 Benchmark* Benchmark::ArgPair(int x, int y) {
90 CheckArgCount(/*expected=*/2);
91 instantiated_num_args_ = 2;
92 args_.push_back(std::make_pair(x, y));
93 return this;
94 }
95
UseRealTime()96 Benchmark* Benchmark::UseRealTime() {
97 // Do nothing.
98 // This only exists for API compatibility with internal benchmarks.
99 return this;
100 }
101
102 namespace {
103
AddRange(std::vector<int> * dst,int lo,int hi,int mult)104 void AddRange(std::vector<int>* dst, int lo, int hi, int mult) {
105 CHECK_GE(lo, 0);
106 CHECK_GE(hi, lo);
107
108 // Add "lo"
109 dst->push_back(lo);
110
111 // Now space out the benchmarks in multiples of "mult"
112 for (int32 i = 1; i < kint32max / mult; i *= mult) {
113 if (i >= hi) break;
114 if (i > lo) {
115 dst->push_back(i);
116 }
117 }
118 // Add "hi" (if different from "lo")
119 if (hi != lo) {
120 dst->push_back(hi);
121 }
122 }
123
124 } // namespace
125
Range(int lo,int hi)126 Benchmark* Benchmark::Range(int lo, int hi) {
127 std::vector<int> args;
128 AddRange(&args, lo, hi, 8);
129 for (int arg : args) {
130 Arg(arg);
131 }
132 return this;
133 }
134
RangePair(int lo1,int hi1,int lo2,int hi2)135 Benchmark* Benchmark::RangePair(int lo1, int hi1, int lo2, int hi2) {
136 std::vector<int> args1;
137 std::vector<int> args2;
138 AddRange(&args1, lo1, hi1, 8);
139 AddRange(&args2, lo2, hi2, 8);
140 for (int arg1 : args1) {
141 for (int arg2 : args2) {
142 ArgPair(arg1, arg2);
143 }
144 }
145 return this;
146 }
147
Run(const char * pattern)148 void Benchmark::Run(const char* pattern) {
149 if (!all_benchmarks) return;
150
151 // Converts "all" into the wildcard '.*'. Currently pattern isn't
152 // specified by clients, but we keep this here to match the internal
153 // Google implementation, should we ever enable user-specified
154 // pattern specification.
155 if (StringPiece(pattern) == "all") {
156 pattern = ".*";
157 }
158
159 // Compute name width.
160 int width = 10;
161 string name;
162 for (auto b : *all_benchmarks) {
163 name = b->name_;
164 for (auto arg : b->args_) {
165 name.resize(b->name_.size());
166 if (arg.first >= 0) {
167 strings::StrAppend(&name, "/", arg.first);
168 if (arg.second >= 0) {
169 strings::StrAppend(&name, "/", arg.second);
170 }
171 }
172
173 // TODO(vrv): Check against 'pattern' using a regex before
174 // computing the width, if we start allowing clients to pass in
175 // a custom pattern.
176 width = std::max<int>(width, name.size());
177 }
178 }
179
180 printf("%-*s %10s %10s\n", width, "Benchmark", "Time(ns)", "Iterations");
181 printf("%s\n", string(width + 22, '-').c_str());
182 for (auto b : *all_benchmarks) {
183 name = b->name_;
184 if (b->instantiated_num_args_ == -1 && b->args_.empty()) {
185 // The BM_*(int) interface (ie, benchmark without params) automatically
186 // adds a default (-1, -1) arg pair to b->args_.
187 // The BM_(benchmark::State&) interface does not do this because it does
188 // not know how many parameters are going to be registered.
189 // So we just add the place holder here.
190 b->args_.push_back(std::make_pair(-1, -1));
191 }
192 for (auto arg : b->args_) {
193 name.resize(b->name_.size());
194 if (arg.first >= 0) {
195 strings::StrAppend(&name, "/", arg.first);
196 if (arg.second >= 0) {
197 strings::StrAppend(&name, "/", arg.second);
198 }
199 }
200
201 // TODO(vrv): Match 'name' against 'pattern' using a regex
202 // before continuing, if we start allowing clients to pass in a
203 // custom pattern.
204
205 int iters;
206 double seconds;
207 b->Run(arg.first, arg.second, &iters, &seconds);
208
209 char buf[100];
210 std::string full_label = label;
211 if (bytes_processed > 0) {
212 snprintf(buf, sizeof(buf), " %.5fMB/s",
213 (bytes_processed * 1e-6) / seconds);
214 full_label += buf;
215 }
216 if (items_processed > 0) {
217 snprintf(buf, sizeof(buf), " %.5fM items/s",
218 (items_processed * 1e-6) / seconds);
219 full_label += buf;
220 }
221 printf("%-*s %10.0f %10d\t%s\n", width, name.c_str(),
222 seconds * 1e9 / iters, iters, full_label.c_str());
223
224 TestReporter reporter(name);
225 Status s = reporter.Initialize();
226 if (!s.ok()) {
227 LOG(ERROR) << s.ToString();
228 exit(EXIT_FAILURE);
229 }
230 s = reporter.Benchmark(iters, 0.0, seconds,
231 items_processed * 1e-6 / seconds);
232 if (!s.ok()) {
233 LOG(ERROR) << s.ToString();
234 exit(EXIT_FAILURE);
235 }
236 s = reporter.Close();
237 if (!s.ok()) {
238 LOG(ERROR) << s.ToString();
239 exit(EXIT_FAILURE);
240 }
241 }
242 }
243 }
244
Register()245 void Benchmark::Register() {
246 if (!all_benchmarks) all_benchmarks = new std::vector<Benchmark*>;
247 all_benchmarks->push_back(this);
248 }
249
Run(int arg1,int arg2,int * run_count,double * run_seconds)250 void Benchmark::Run(int arg1, int arg2, int* run_count, double* run_seconds) {
251 env = Env::Default();
252 static const int64 kMinIters = 100;
253 static const int64 kMaxIters = 1000000000;
254 static const double kMinTime = 0.5;
255 int64 iters = kMinIters;
256
257 while (true) {
258 accum_time = 0;
259 start_time = env->NowMicros();
260 bytes_processed = -1;
261 items_processed = -1;
262 label.clear();
263 if (fn0_) {
264 (*fn0_)(iters);
265 } else if (fn1_) {
266 (*fn1_)(iters, arg1);
267 } else if (fn2_) {
268 (*fn2_)(iters, arg1, arg2);
269 } else if (fn_state_) {
270 std::vector<int> arg_list = {arg1, arg2};
271 ::testing::benchmark::State state(iters, instantiated_num_args_,
272 std::move(arg_list));
273 (*fn_state_)(state);
274 }
275 StopTiming();
276 const double seconds = accum_time * 1e-6;
277 if (seconds >= kMinTime || iters >= kMaxIters) {
278 *run_count = iters;
279 *run_seconds = seconds;
280 return;
281 }
282
283 // Update number of iterations. Overshoot by 40% in an attempt
284 // to succeed the next time.
285 double multiplier = 1.4 * kMinTime / std::max(seconds, 1e-9);
286 multiplier = std::min(10.0, multiplier);
287 if (multiplier <= 1.0) multiplier *= 2.0;
288 iters = std::max<int64>(multiplier * iters, iters + 1);
289 iters = std::min(iters, kMaxIters);
290 }
291 }
292
293 // TODO(vrv): Add support for running a subset of benchmarks by having
294 // RunBenchmarks take in a spec (and maybe other options such as
295 // benchmark_min_time, etc).
RunBenchmarks()296 void RunBenchmarks() { Benchmark::Run("all"); }
SetLabel(const std::string & l)297 void SetLabel(const std::string& l) { label = l; }
BytesProcessed(int64 n)298 void BytesProcessed(int64 n) { bytes_processed = n; }
ItemsProcessed(int64 n)299 void ItemsProcessed(int64 n) { items_processed = n; }
StartTiming()300 void StartTiming() {
301 if (start_time == 0) start_time = env->NowMicros();
302 }
StopTiming()303 void StopTiming() {
304 if (start_time != 0) {
305 accum_time += (env->NowMicros() - start_time);
306 start_time = 0;
307 }
308 }
// Does nothing. Presumably kept, like Benchmark::UseRealTime(), only for API
// compatibility.
void UseRealTime() {
  // Intentionally empty.
}
310
311 } // namespace testing
312 } // namespace tensorflow
313
314 namespace testing {
315 namespace benchmark {
State(size_t max_iterations,int formal_arg_count,std::vector<int> args)316 State::State(size_t max_iterations, int formal_arg_count, std::vector<int> args)
317 : max_iterations(max_iterations),
318 formal_arg_count_(formal_arg_count),
319 args_(std::move(args)) {
320 completed_iterations_ = 0;
321 }
322
PauseTiming()323 void State::PauseTiming() { ::tensorflow::testing::StopTiming(); }
324
ResumeTiming()325 void State::ResumeTiming() { ::tensorflow::testing::StartTiming(); }
326
SetBytesProcessed(::tensorflow::int64 bytes)327 void State::SetBytesProcessed(::tensorflow::int64 bytes) {
328 ::tensorflow::testing::BytesProcessed(bytes);
329 }
330
SetItemsProcessed(::tensorflow::int64 items)331 void State::SetItemsProcessed(::tensorflow::int64 items) {
332 ::tensorflow::testing::ItemsProcessed(items);
333 }
334
SetLabel(absl::string_view label)335 void State::SetLabel(absl::string_view label) {
336 ::tensorflow::testing::SetLabel(std::string(label));
337 }
338
range(size_t i) const339 int State::range(size_t i) const {
340 if (i >= formal_arg_count_) {
341 LOG(FATAL) << "argument for range " << i << " is not set";
342 }
343 return args_[i];
344 }
345
RunSpecifiedBenchmarks()346 void RunSpecifiedBenchmarks() { ::tensorflow::testing::Benchmark::Run("all"); }
347
348 } // namespace benchmark
349 } // namespace testing
350