1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/platform/test_benchmark.h"
17
18 #include <cstdio>
19 #include <cstdlib>
20
21 #include <algorithm>
22 #include <vector>
23 #include "tensorflow/core/lib/strings/str_util.h"
24 #include "tensorflow/core/platform/env.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/util/reporter.h"
27
28 namespace tensorflow {
29 namespace testing {
30
31 static std::vector<Benchmark*>* all_benchmarks = nullptr;
32 static std::string label;
33 static int64 bytes_processed;
34 static int64 items_processed;
35 static int64 accum_time = 0;
36 static int64 start_time = 0;
37 static Env* env;
38
Benchmark(const char * name,void (* fn)(int))39 Benchmark::Benchmark(const char* name, void (*fn)(int))
40 : name_(name), num_args_(0), fn0_(fn) {
41 args_.push_back(std::make_pair(-1, -1));
42 Register();
43 }
44
Benchmark(const char * name,void (* fn)(int,int))45 Benchmark::Benchmark(const char* name, void (*fn)(int, int))
46 : name_(name), num_args_(1), fn1_(fn) {
47 Register();
48 }
49
Benchmark(const char * name,void (* fn)(int,int,int))50 Benchmark::Benchmark(const char* name, void (*fn)(int, int, int))
51 : name_(name), num_args_(2), fn2_(fn) {
52 Register();
53 }
54
Arg(int x)55 Benchmark* Benchmark::Arg(int x) {
56 CHECK_EQ(num_args_, 1);
57 args_.push_back(std::make_pair(x, -1));
58 return this;
59 }
60
ArgPair(int x,int y)61 Benchmark* Benchmark::ArgPair(int x, int y) {
62 CHECK_EQ(num_args_, 2);
63 args_.push_back(std::make_pair(x, y));
64 return this;
65 }
66
67 namespace {
68
AddRange(std::vector<int> * dst,int lo,int hi,int mult)69 void AddRange(std::vector<int>* dst, int lo, int hi, int mult) {
70 CHECK_GE(lo, 0);
71 CHECK_GE(hi, lo);
72
73 // Add "lo"
74 dst->push_back(lo);
75
76 // Now space out the benchmarks in multiples of "mult"
77 for (int32 i = 1; i < kint32max / mult; i *= mult) {
78 if (i >= hi) break;
79 if (i > lo) {
80 dst->push_back(i);
81 }
82 }
83 // Add "hi" (if different from "lo")
84 if (hi != lo) {
85 dst->push_back(hi);
86 }
87 }
88
89 } // namespace
90
Range(int lo,int hi)91 Benchmark* Benchmark::Range(int lo, int hi) {
92 std::vector<int> args;
93 AddRange(&args, lo, hi, 8);
94 for (int arg : args) {
95 Arg(arg);
96 }
97 return this;
98 }
99
RangePair(int lo1,int hi1,int lo2,int hi2)100 Benchmark* Benchmark::RangePair(int lo1, int hi1, int lo2, int hi2) {
101 std::vector<int> args1;
102 std::vector<int> args2;
103 AddRange(&args1, lo1, hi1, 8);
104 AddRange(&args2, lo2, hi2, 8);
105 for (int arg1 : args1) {
106 for (int arg2 : args2) {
107 ArgPair(arg1, arg2);
108 }
109 }
110 return this;
111 }
112
Run(const char * pattern)113 void Benchmark::Run(const char* pattern) {
114 if (!all_benchmarks) return;
115
116 // Converts "all" into the wildcard '.*'. Currently pattern isn't
117 // specified by clients, but we keep this here to match the internal
118 // Google implementation, should we ever enable user-specified
119 // pattern specification.
120 if (StringPiece(pattern) == "all") {
121 pattern = ".*";
122 }
123
124 // Compute name width.
125 int width = 10;
126 string name;
127 for (auto b : *all_benchmarks) {
128 name = b->name_;
129 for (auto arg : b->args_) {
130 name.resize(b->name_.size());
131 if (arg.first >= 0) {
132 strings::StrAppend(&name, "/", arg.first);
133 if (arg.second >= 0) {
134 strings::StrAppend(&name, "/", arg.second);
135 }
136 }
137
138 // TODO(vrv): Check against 'pattern' using a regex before
139 // computing the width, if we start allowing clients to pass in
140 // a custom pattern.
141 width = std::max<int>(width, name.size());
142 }
143 }
144
145 printf("%-*s %10s %10s\n", width, "Benchmark", "Time(ns)", "Iterations");
146 printf("%s\n", string(width + 22, '-').c_str());
147 for (auto b : *all_benchmarks) {
148 name = b->name_;
149 for (auto arg : b->args_) {
150 name.resize(b->name_.size());
151 if (arg.first >= 0) {
152 strings::StrAppend(&name, "/", arg.first);
153 if (arg.second >= 0) {
154 strings::StrAppend(&name, "/", arg.second);
155 }
156 }
157
158 // TODO(vrv): Match 'name' against 'pattern' using a regex
159 // before continuing, if we start allowing clients to pass in a
160 // custom pattern.
161
162 int iters;
163 double seconds;
164 b->Run(arg.first, arg.second, &iters, &seconds);
165
166 char buf[100];
167 std::string full_label = label;
168 if (bytes_processed > 0) {
169 snprintf(buf, sizeof(buf), " %.1fMB/s",
170 (bytes_processed * 1e-6) / seconds);
171 full_label += buf;
172 }
173 if (items_processed > 0) {
174 snprintf(buf, sizeof(buf), " %.1fM items/s",
175 (items_processed * 1e-6) / seconds);
176 full_label += buf;
177 }
178 printf("%-*s %10.0f %10d\t%s\n", width, name.c_str(),
179 seconds * 1e9 / iters, iters, full_label.c_str());
180
181 TestReporter reporter(name);
182 Status s = reporter.Initialize();
183 if (!s.ok()) {
184 LOG(ERROR) << s.ToString();
185 exit(EXIT_FAILURE);
186 }
187 s = reporter.Benchmark(iters, 0.0, seconds,
188 items_processed * 1e-6 / seconds);
189 if (!s.ok()) {
190 LOG(ERROR) << s.ToString();
191 exit(EXIT_FAILURE);
192 }
193 s = reporter.Close();
194 if (!s.ok()) {
195 LOG(ERROR) << s.ToString();
196 exit(EXIT_FAILURE);
197 }
198 }
199 }
200 }
201
Register()202 void Benchmark::Register() {
203 if (!all_benchmarks) all_benchmarks = new std::vector<Benchmark*>;
204 all_benchmarks->push_back(this);
205 }
206
Run(int arg1,int arg2,int * run_count,double * run_seconds)207 void Benchmark::Run(int arg1, int arg2, int* run_count, double* run_seconds) {
208 env = Env::Default();
209 static const int64 kMinIters = 100;
210 static const int64 kMaxIters = 1000000000;
211 static const double kMinTime = 0.5;
212 int64 iters = kMinIters;
213 while (true) {
214 accum_time = 0;
215 start_time = env->NowMicros();
216 bytes_processed = -1;
217 items_processed = -1;
218 label.clear();
219 if (fn0_) {
220 (*fn0_)(iters);
221 } else if (fn1_) {
222 (*fn1_)(iters, arg1);
223 } else {
224 (*fn2_)(iters, arg1, arg2);
225 }
226 StopTiming();
227 const double seconds = accum_time * 1e-6;
228 if (seconds >= kMinTime || iters >= kMaxIters) {
229 *run_count = iters;
230 *run_seconds = seconds;
231 return;
232 }
233
234 // Update number of iterations. Overshoot by 40% in an attempt
235 // to succeed the next time.
236 double multiplier = 1.4 * kMinTime / std::max(seconds, 1e-9);
237 multiplier = std::min(10.0, multiplier);
238 if (multiplier <= 1.0) multiplier *= 2.0;
239 iters = std::max<int64>(multiplier * iters, iters + 1);
240 iters = std::min(iters, kMaxIters);
241 }
242 }
243
244 // TODO(vrv): Add support for running a subset of benchmarks by having
245 // RunBenchmarks take in a spec (and maybe other options such as
246 // benchmark_min_time, etc).
RunBenchmarks()247 void RunBenchmarks() { Benchmark::Run("all"); }
SetLabel(const std::string & l)248 void SetLabel(const std::string& l) { label = l; }
BytesProcessed(int64 n)249 void BytesProcessed(int64 n) { bytes_processed = n; }
ItemsProcessed(int64 n)250 void ItemsProcessed(int64 n) { items_processed = n; }
StartTiming()251 void StartTiming() {
252 if (start_time == 0) start_time = env->NowMicros();
253 }
StopTiming()254 void StopTiming() {
255 if (start_time != 0) {
256 accum_time += (env->NowMicros() - start_time);
257 start_time = 0;
258 }
259 }
UseRealTime()260 void UseRealTime() {}
261
262 } // namespace testing
263 } // namespace tensorflow
264