1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <random>
17
18 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
19 #include "tensorflow/core/framework/tensor.h"
20 #include "tensorflow/core/lib/math/math_util.h"
21 #include "tensorflow/core/lib/random/philox_random.h"
22 #include "tensorflow/core/platform/test.h"
23 #include "tensorflow/core/platform/test_benchmark.h"
24
25 namespace tensorflow {
26 namespace {
27
VecShape(int64_t v)28 Tensor VecShape(int64_t v) {
29 if (v >= std::numeric_limits<int32>::max()) {
30 Tensor shape(DT_INT64, TensorShape({1}));
31 shape.vec<int64>()(0) = v;
32 return shape;
33 } else {
34 Tensor shape(DT_INT32, TensorShape({1}));
35 shape.vec<int32>()(0) = v;
36 return shape;
37 }
38 }
39
RandomUniform(int64_t n)40 Graph* RandomUniform(int64_t n) {
41 Graph* g = new Graph(OpRegistry::Global());
42 test::graph::RandomUniform(g, test::graph::Constant(g, VecShape(n)),
43 DT_FLOAT);
44 return g;
45 }
46
RandomNormal(int64_t n)47 Graph* RandomNormal(int64_t n) {
48 Graph* g = new Graph(OpRegistry::Global());
49 test::graph::RandomGaussian(g, test::graph::Constant(g, VecShape(n)),
50 DT_FLOAT);
51 return g;
52 }
53
TruncatedNormal(int64_t n)54 Graph* TruncatedNormal(int64_t n) {
55 Graph* g = new Graph(OpRegistry::Global());
56 test::graph::TruncatedNormal(g, test::graph::Constant(g, VecShape(n)),
57 DT_FLOAT);
58 return g;
59 }
60
61 #define BM_RNG(DEVICE, RNG) \
62 void BM_##DEVICE##_##RNG(::testing::benchmark::State& state) { \
63 const int arg = state.range(0); \
64 \
65 test::Benchmark(#DEVICE, RNG(arg), /*old_benchmark_api*/ false) \
66 .Run(state); \
67 state.SetItemsProcessed(static_cast<int64>(state.iterations()) * arg); \
68 } \
69 BENCHMARK(BM_##DEVICE##_##RNG)->Range(1 << 20, 8 << 20);
70
71 BM_RNG(cpu, RandomUniform);
72 BM_RNG(cpu, RandomNormal);
73 BM_RNG(cpu, TruncatedNormal);
74
75 BM_RNG(gpu, RandomUniform);
76 BM_RNG(gpu, RandomNormal);
77 BM_RNG(gpu, TruncatedNormal);
78
VecAlphas(int64_t n)79 Tensor VecAlphas(int64_t n) {
80 Tensor alphas(DT_DOUBLE, TensorShape({n}));
81 for (int i = 0; i < n; i++) {
82 // Alternate back and forth between small-and-growing (.25) and
83 // large-and-shrinking (26.67) alpha.
84 alphas.vec<double>()(i) =
85 0.25 + MathUtil::IPow(1.1, i % 2 == 0 ? i : n - i);
86 }
87 return alphas;
88 }
89
BM_cpu_RandomGamma(::testing::benchmark::State & state)90 void BM_cpu_RandomGamma(::testing::benchmark::State& state) {
91 const int nsamp = state.range(0);
92 const int nalpha = state.range(1);
93
94 Graph* g = new Graph(OpRegistry::Global());
95 test::graph::RandomGamma(g, test::graph::Constant(g, VecShape(nsamp)),
96 test::graph::Constant(g, VecAlphas(nalpha)));
97 test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state);
98 state.SetItemsProcessed(static_cast<int64>(state.iterations()) * nsamp *
99 nalpha);
100 }
101 BENCHMARK(BM_cpu_RandomGamma)->RangePair(1 << 14, 4 << 15, 2, 50);
102
BM_PhiloxRandom(::testing::benchmark::State & state)103 void BM_PhiloxRandom(::testing::benchmark::State& state) {
104 // Fill 2M random numbers
105 int count = 2 << 20;
106 random::PhiloxRandom gen(0x12345);
107
108 for (auto s : state) {
109 for (int j = 0; j < count; j += 4) {
110 /// each invocation of gen() returns 128-bit samples
111 auto samples = gen();
112 tensorflow::testing::DoNotOptimize(samples);
113 }
114 }
115 state.SetItemsProcessed(static_cast<int64>(state.iterations()) * count);
116 }
117 BENCHMARK(BM_PhiloxRandom);
118
BM_StdMTRandom(::testing::benchmark::State & state)119 void BM_StdMTRandom(::testing::benchmark::State& state) {
120 // Fill 2M random numbers
121 int count = 2 << 20;
122 std::mt19937 gen(0x12345);
123
124 for (auto s : state) {
125 for (int j = 0; j < count; ++j) {
126 /// each invocation of gen() returns 32-bit sample
127 uint_fast32_t sample = gen();
128 tensorflow::testing::DoNotOptimize(sample);
129 }
130 }
131 state.SetItemsProcessed(static_cast<int64>(state.iterations()) * count);
132 }
133 BENCHMARK(BM_StdMTRandom);
134
135 } // namespace
136 } // namespace tensorflow
137