1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy
5 // of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 // ==============================================================================
15
16 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
17 #include "tensorflow/core/framework/tensor.h"
18 #include "tensorflow/core/framework/types.pb.h"
19 #include "tensorflow/core/graph/node_builder.h"
20 #include "tensorflow/core/lib/random/simple_philox.h"
21 #include "tensorflow/core/lib/strings/stringprintf.h"
22 #include "tensorflow/core/platform/test.h"
23 #include "tensorflow/core/platform/test_benchmark.h"
24
25 namespace tensorflow {
26 namespace {
27
// Point dimensionality; every benchmark configuration in this file uses it as
// num_dims.
constexpr int k100Dim{100};

// Number of input points swept by the benchmarks.
constexpr int k10Points{10};
constexpr int k100Points{100};
constexpr int k1kPoints{1000};
constexpr int k10kPoints{10 * 1000};
constexpr int k1MPoints{1000 * 1000};

// Number of centers (to sample, or to search over) swept by the benchmarks.
constexpr int k2Centers{2};
constexpr int k5Centers{5};
constexpr int k10Centers{10};
constexpr int k20Centers{20};
constexpr int k50Centers{50};
constexpr int k100Centers{100};
constexpr int k200Centers{200};
constexpr int k500Centers{500};
constexpr int k1kCenters{1000};
constexpr int k10kCenters{10 * 1000};

// Retries-per-sample values swept by the KmeansPlusPlusInitialization
// benchmarks.
constexpr int k0RetriesPerSample{0};
constexpr int k3RetriesPerSample{3};
49
SetUpKmeansPlusPlusInitialization(int num_dims,int num_points,int num_to_sample,int retries_per_sample)50 Graph* SetUpKmeansPlusPlusInitialization(int num_dims, int num_points,
51 int num_to_sample,
52 int retries_per_sample) {
53 Graph* g = new Graph(OpRegistry::Global());
54 Tensor points(DT_FLOAT, TensorShape({num_points, num_dims}));
55 Tensor sample_size(DT_INT64, TensorShape({}));
56 Tensor seed(DT_INT64, TensorShape({}));
57 Tensor num_retries_per_sample(DT_INT64, TensorShape({}));
58 points.flat<float>().setRandom();
59 sample_size.flat<int64>().setConstant(num_to_sample);
60 seed.flat<int64>().setConstant(12345);
61 num_retries_per_sample.flat<int64>().setConstant(retries_per_sample);
62
63 TF_CHECK_OK(NodeBuilder("kmeans_plus_plus_initialization_op",
64 "KmeansPlusPlusInitialization")
65 .Input(test::graph::Constant(g, points))
66 .Input(test::graph::Constant(g, sample_size))
67 .Input(test::graph::Constant(g, seed))
68 .Input(test::graph::Constant(g, num_retries_per_sample))
69 .Finalize(g, nullptr /* node */));
70 return g;
71 }
72
73 template <int num_points, int num_to_sample, int num_dims,
74 int retries_per_sample>
BM_KmeansPlusPlusInitialization(::testing::benchmark::State & state)75 void BM_KmeansPlusPlusInitialization(::testing::benchmark::State& state) {
76 Graph* g = SetUpKmeansPlusPlusInitialization(
77 num_dims, num_points, num_to_sample, retries_per_sample);
78 test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state);
79 state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_points *
80 num_dims * num_to_sample);
81 }
82
// Defines and registers one KmeansPlusPlusInitialization benchmark for a
// (points p, centers c, dims d, retries r) configuration. The wrapper's name is
// built by token-pasting the argument spellings, e.g.
// BM_KmeansPlusPlusInitialization_k10Points_k2Centers_k100Dim_k0RetriesPerSample,
// so each configuration is registered under its own name. UseRealTime() asks
// the benchmark framework to time by wall clock.
#define BENCHMARK_KMEANS_PLUS_PLUS(p, c, d, r)                     \
  void BM_KmeansPlusPlusInitialization_##p##_##c##_##d##_##r(      \
      ::testing::benchmark::State& state) {                        \
    BM_KmeansPlusPlusInitialization<p, c, d, r>(state);            \
  }                                                                \
  BENCHMARK(BM_KmeansPlusPlusInitialization_##p##_##c##_##d##_##r) \
      ->UseRealTime();

// Registers the full (points, centers) sweep at k100Dim for one
// retries-per-sample value. The k1MPoints configurations build a
// 10^6 x 100 float input matrix.
#define RUN_BM_KmeansPlusPlusInitialization(retries)                    \
  BENCHMARK_KMEANS_PLUS_PLUS(k10Points, k2Centers, k100Dim, retries);   \
  BENCHMARK_KMEANS_PLUS_PLUS(k10Points, k5Centers, k100Dim, retries);   \
  BENCHMARK_KMEANS_PLUS_PLUS(k10Points, k10Centers, k100Dim, retries);  \
  BENCHMARK_KMEANS_PLUS_PLUS(k100Points, k10Centers, k100Dim, retries); \
  BENCHMARK_KMEANS_PLUS_PLUS(k100Points, k20Centers, k100Dim, retries); \
  BENCHMARK_KMEANS_PLUS_PLUS(k100Points, k50Centers, k100Dim, retries); \
  BENCHMARK_KMEANS_PLUS_PLUS(k100Points, k100Centers, k100Dim, retries); \
  BENCHMARK_KMEANS_PLUS_PLUS(k1kPoints, k100Centers, k100Dim, retries);  \
  BENCHMARK_KMEANS_PLUS_PLUS(k1kPoints, k200Centers, k100Dim, retries);  \
  BENCHMARK_KMEANS_PLUS_PLUS(k1kPoints, k500Centers, k100Dim, retries);  \
  BENCHMARK_KMEANS_PLUS_PLUS(k1kPoints, k1kCenters, k100Dim, retries);   \
  BENCHMARK_KMEANS_PLUS_PLUS(k10kPoints, k100Centers, k100Dim, retries); \
  BENCHMARK_KMEANS_PLUS_PLUS(k10kPoints, k200Centers, k100Dim, retries); \
  BENCHMARK_KMEANS_PLUS_PLUS(k10kPoints, k500Centers, k100Dim, retries); \
  BENCHMARK_KMEANS_PLUS_PLUS(k10kPoints, k1kCenters, k100Dim, retries);  \
  BENCHMARK_KMEANS_PLUS_PLUS(k1MPoints, k100Centers, k100Dim, retries);  \
  BENCHMARK_KMEANS_PLUS_PLUS(k1MPoints, k200Centers, k100Dim, retries);  \
  BENCHMARK_KMEANS_PLUS_PLUS(k1MPoints, k500Centers, k100Dim, retries);  \
  BENCHMARK_KMEANS_PLUS_PLUS(k1MPoints, k1kCenters, k100Dim, retries)

// Run the sweep once with no retries and once with 3 retries per sample.
RUN_BM_KmeansPlusPlusInitialization(k0RetriesPerSample);
RUN_BM_KmeansPlusPlusInitialization(k3RetriesPerSample);

// The helper macros are only needed for the registrations above.
#undef RUN_BM_KmeansPlusPlusInitialization
#undef BENCHMARK_KMEANS_PLUS_PLUS
117
SetUpKMC2Initialization(int num_points)118 Graph* SetUpKMC2Initialization(int num_points) {
119 Graph* g = new Graph(OpRegistry::Global());
120 Tensor distances(DT_FLOAT, TensorShape({num_points}));
121 Tensor seed(DT_INT64, TensorShape({}));
122 distances.flat<float>().setRandom();
123 seed.flat<int64>().setConstant(12345);
124
125 TF_CHECK_OK(
126 NodeBuilder("KMC2ChainInitializationOp", "KMC2ChainInitialization")
127 .Input(test::graph::Constant(g, distances))
128 .Input(test::graph::Constant(g, seed))
129 .Finalize(g, nullptr /* node */));
130 return g;
131 }
132
133 template <int num_points, int num_to_sample, int num_dims>
BM_KMC2Initialization(::testing::benchmark::State & state)134 void BM_KMC2Initialization(::testing::benchmark::State& state) {
135 Graph* g = SetUpKMC2Initialization(num_points);
136 test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state);
137 state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_points *
138 num_dims * num_to_sample);
139 }
// Defines and registers one KMC2ChainInitialization benchmark for a
// (points p, centers c, dims d) configuration. The wrapper's name is built by
// token-pasting the argument spellings, e.g.
// BM_KMC2Initialization_k10Points_k2Centers_k100Dim, and UseRealTime() asks
// the benchmark framework to time by wall clock.
#define BENCHMARK_KMC2(p, c, d)                 \
  void BM_KMC2Initialization_##p##_##c##_##d(   \
      ::testing::benchmark::State& state) {     \
    BM_KMC2Initialization<p, c, d>(state);      \
  }                                             \
  BENCHMARK(BM_KMC2Initialization_##p##_##c##_##d)->UseRealTime();

// Registers the (points, centers) sweep at k100Dim.
// NOTE(review): the KMC2 graph is built from num_points alone
// (SetUpKMC2Initialization(num_points)). Registrations that share p but differ
// in c therefore execute the same graph; consider trimming the redundant ones.
#define RUN_BM_KMC2Initialization                 \
  BENCHMARK_KMC2(k10Points, k2Centers, k100Dim);    \
  BENCHMARK_KMC2(k10Points, k5Centers, k100Dim);    \
  BENCHMARK_KMC2(k10Points, k10Centers, k100Dim);   \
  BENCHMARK_KMC2(k100Points, k10Centers, k100Dim);  \
  BENCHMARK_KMC2(k100Points, k20Centers, k100Dim);  \
  BENCHMARK_KMC2(k100Points, k50Centers, k100Dim);  \
  BENCHMARK_KMC2(k100Points, k100Centers, k100Dim); \
  BENCHMARK_KMC2(k1kPoints, k100Centers, k100Dim);  \
  BENCHMARK_KMC2(k1kPoints, k200Centers, k100Dim);  \
  BENCHMARK_KMC2(k1kPoints, k500Centers, k100Dim);  \
  BENCHMARK_KMC2(k1kPoints, k1kCenters, k100Dim);   \
  BENCHMARK_KMC2(k10kPoints, k100Centers, k100Dim); \
  BENCHMARK_KMC2(k10kPoints, k200Centers, k100Dim); \
  BENCHMARK_KMC2(k10kPoints, k500Centers, k100Dim); \
  BENCHMARK_KMC2(k10kPoints, k1kCenters, k100Dim);  \
  BENCHMARK_KMC2(k1MPoints, k100Centers, k100Dim);  \
  BENCHMARK_KMC2(k1MPoints, k200Centers, k100Dim);  \
  BENCHMARK_KMC2(k1MPoints, k500Centers, k100Dim);  \
  BENCHMARK_KMC2(k1MPoints, k1kCenters, k100Dim)

RUN_BM_KMC2Initialization;
// The helper macros are only needed for the registrations above.
#undef RUN_BM_KMC2Initialization
#undef BENCHMARK_KMC2
171
SetUpNearestNeighbors(int num_dims,int num_points,int num_centers,int k)172 Graph* SetUpNearestNeighbors(int num_dims, int num_points, int num_centers,
173 int k) {
174 Graph* g = new Graph(OpRegistry::Global());
175 Tensor points(DT_FLOAT, TensorShape({num_points, num_dims}));
176 Tensor centers(DT_FLOAT, TensorShape({num_centers, num_dims}));
177 Tensor top(DT_INT64, TensorShape({}));
178 points.flat<float>().setRandom();
179 centers.flat<float>().setRandom();
180 top.flat<int64>().setConstant(k);
181
182 TF_CHECK_OK(NodeBuilder("nearest_centers_op", "NearestNeighbors")
183 .Input(test::graph::Constant(g, points))
184 .Input(test::graph::Constant(g, centers))
185 .Input(test::graph::Constant(g, top))
186 .Finalize(g, nullptr /* node */));
187 return g;
188 }
189
190 template <int num_dims, int num_points, int num_centers, int k>
BM_NearestNeighbors(::testing::benchmark::State & state)191 void BM_NearestNeighbors(::testing::benchmark::State& state) {
192 Graph* g = SetUpNearestNeighbors(num_dims, num_points, num_centers, k);
193 test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state);
194 state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_points *
195 num_dims * num_centers);
196 }
197
// Values of k swept by the NearestNeighbors benchmarks: kTop1 covers k == 1,
// the remaining constants cover k > 1.
constexpr int kTop1{1};
constexpr int kTop2{2};
constexpr int kTop5{5};
constexpr int kTop10{10};
202
// Defines and registers one NearestNeighbors benchmark for a
// (dims d, points p, centers c, k) configuration; UseRealTime() asks the
// benchmark framework to time by wall clock.
// NOTE(review): unlike the other benchmark macros in this file, there is no
// '_' between "BM_NearestNeighbors" and the first pasted argument, so names
// come out as e.g. BM_NearestNeighborsk100Dim_k1kPoints_k100Centers_kTop1.
// The spelling is left as-is because changing it renames the registered
// benchmarks.
#define BENCHMARK_NEAREST_NEIGHBORS(d, p, c, k)   \
  void BM_NearestNeighbors##d##_##p##_##c##_##k(  \
      ::testing::benchmark::State& state) {       \
    BM_NearestNeighbors<d, p, c, k>(state);       \
  }                                               \
  BENCHMARK(BM_NearestNeighbors##d##_##p##_##c##_##k)->UseRealTime();

// Registers the (points, centers) sweep at k100Dim for one value of k.
#define RUN_BM_NearestNeighbors(k)                                   \
  BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1kPoints, k100Centers, k);  \
  BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1kPoints, k1kCenters, k);   \
  BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1kPoints, k10kCenters, k);  \
  BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1MPoints, k100Centers, k);  \
  BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1MPoints, k1kCenters, k);   \
  BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1MPoints, k10kCenters, k)

// k == 1
RUN_BM_NearestNeighbors(kTop1);
// k > 1
RUN_BM_NearestNeighbors(kTop2);
RUN_BM_NearestNeighbors(kTop5);
RUN_BM_NearestNeighbors(kTop10);

// The helper macros are only needed for the registrations above.
#undef RUN_BM_NearestNeighbors
#undef BENCHMARK_NEAREST_NEIGHBORS
226 } // namespace
227 } // namespace tensorflow
228