1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <functional>
17 #include <memory>
18 #include <vector>
19
20 #include "tensorflow/core/framework/allocator.h"
21 #include "tensorflow/core/framework/fake_input.h"
22 #include "tensorflow/core/framework/node_def_builder.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/framework/types.pb.h"
27 #include "tensorflow/core/kernels/ops_testutil.h"
28 #include "tensorflow/core/kernels/ops_util.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/lib/random/simple_philox.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/test.h"
34 #include "tensorflow/core/platform/test_benchmark.h"
35
36 namespace tensorflow {
37 namespace {
38
39 class ScatterUpdateOpTest : public OpsTestBase {
40 protected:
MakeOp(DataType variable_ref_type,DataType index_type)41 void MakeOp(DataType variable_ref_type, DataType index_type) {
42 TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterUpdate")
43 .Input(FakeInput(variable_ref_type))
44 .Input(FakeInput(index_type))
45 .Input(FakeInput(RemoveRefType(variable_ref_type)))
46 .Finalize(node_def()));
47 TF_ASSERT_OK(InitOp());
48 }
49 };
50 class ScatterSubOpTest : public OpsTestBase {
51 protected:
MakeOp(DataType variable_ref_type,DataType index_type)52 void MakeOp(DataType variable_ref_type, DataType index_type) {
53 TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterSub")
54 .Input(FakeInput(variable_ref_type))
55 .Input(FakeInput(index_type))
56 .Input(FakeInput(RemoveRefType(variable_ref_type)))
57 .Finalize(node_def()));
58 TF_ASSERT_OK(InitOp());
59 }
60 };
61
// ScatterUpdate overwrites the single element of a one-element string
// variable.
TEST_F(ScatterUpdateOpTest, Simple_StringType) {
  MakeOp(DT_STRING_REF, DT_INT32);
  const TensorShape one_elem({1});
  AddInputFromArray<tstring>(one_elem, {"Brain"});       // variable (ref)
  AddInputFromArray<int32>(one_elem, {0});               // indices
  AddInputFromArray<tstring>(one_elem, {"TensorFlow"});  // updates
  TF_ASSERT_OK(RunOpKernel());

  // The ref input itself must now hold the update.
  const Tensor updated = *mutable_input(0).tensor;
  Tensor want(allocator(), DT_STRING, one_elem);
  test::FillValues<tstring>(&want, {"TensorFlow"});
  test::ExpectTensorEqual<tstring>(want, updated);
}
74
// ScatterUpdate flips the single element of a one-element bool variable.
TEST_F(ScatterUpdateOpTest, Simple_BoolType) {
  MakeOp(DT_BOOL_REF, DT_INT32);
  const TensorShape one_elem({1});
  AddInputFromArray<bool>(one_elem, {false});  // variable (ref)
  AddInputFromArray<int32>(one_elem, {0});     // indices
  AddInputFromArray<bool>(one_elem, {true});   // updates
  TF_ASSERT_OK(RunOpKernel());

  // The ref input itself must now hold the update.
  const Tensor updated = *mutable_input(0).tensor;
  Tensor want(allocator(), DT_BOOL, one_elem);
  test::FillValues<bool>(&want, {true});
  test::ExpectTensorEqual<bool>(want, updated);
}
87
// Rows 0, 4 and 2 of a zeroed 5x3 float variable are replaced, using int32
// indices.
TEST_F(ScatterUpdateOpTest, Simple_TwoD32) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  const std::vector<float> zeros(5 * 3, 0.0f);
  AddInputFromArray<float>(TensorShape({5, 3}), zeros);
  AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2});
  AddInputFromArray<float>(TensorShape({3, 3}),
                           {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
  TF_ASSERT_OK(RunOpKernel());

  // Update row k lands in variable row indices[k]; rows 1 and 3 stay zero.
  const Tensor updated = *mutable_input(0).tensor;
  Tensor want(allocator(), DT_FLOAT, TensorShape({5, 3}));
  test::FillValues<float>(&want, {100, 101, 102,        //
                                  0, 0, 0,              //
                                  10000, 10001, 10002,  //
                                  0, 0, 0,              //
                                  777, 778, 779});
  test::ExpectTensorEqual<float>(want, updated);
}
106
// Same scenario as Simple_TwoD32, but the indices tensor is int64.
TEST_F(ScatterUpdateOpTest, Simple_Two64) {
  MakeOp(DT_FLOAT_REF, DT_INT64);

  const std::vector<float> zeros(5 * 3, 0.0f);
  AddInputFromArray<float>(TensorShape({5, 3}), zeros);
  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 2});
  AddInputFromArray<float>(TensorShape({3, 3}),
                           {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
  TF_ASSERT_OK(RunOpKernel());

  // Update row k lands in variable row indices[k]; rows 1 and 3 stay zero.
  const Tensor updated = *mutable_input(0).tensor;
  Tensor want(allocator(), DT_FLOAT, TensorShape({5, 3}));
  test::FillValues<float>(&want, {100, 101, 102,        //
                                  0, 0, 0,              //
                                  10000, 10001, 10002,  //
                                  0, 0, 0,              //
                                  777, 778, 779});
  test::ExpectTensorEqual<float>(want, updated);
}
125
// A scalar index with a scalar update writes one element of a 1-D variable.
TEST_F(ScatterUpdateOpTest, Simple_ZeroD) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  const std::vector<float> zeros(5, 0.0f);
  AddInputFromArray<float>(TensorShape({5}), zeros);
  AddInputFromArray<int32>(TensorShape({}), {3});
  AddInputFromArray<float>(TensorShape({}), {101});
  TF_ASSERT_OK(RunOpKernel());

  // Only position 3 changes.
  const Tensor updated = *mutable_input(0).tensor;
  Tensor want(allocator(), DT_FLOAT, TensorShape({5}));
  test::FillValues<float>(&want, {0, 0, 0, 101, 0});
  test::ExpectTensorEqual<float>(want, updated);
}
141
// A 1-D indices vector scatters a 1-D update vector into a 1-D variable.
TEST_F(ScatterUpdateOpTest, Simple_OneD) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  const std::vector<float> zeros(5, 0.0f);
  AddInputFromArray<float>(TensorShape({5}), zeros);
  AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2});
  AddInputFromArray<float>(TensorShape({3}), {100, 101, 102});
  TF_ASSERT_OK(RunOpKernel());

  // updates[k] is written to position indices[k].
  const Tensor updated = *mutable_input(0).tensor;
  Tensor want(allocator(), DT_FLOAT, TensorShape({5}));
  test::FillValues<float>(&want, {100, 0, 102, 0, 101});
  test::ExpectTensorEqual<float>(want, updated);
}
157
// Indices and updates may be rank 2; each (i, j) pair writes
// updates[i][j] to variable[indices[i][j]].
TEST_F(ScatterUpdateOpTest, HigherRank) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  const std::vector<float> zeros(8, 0.0f);
  AddInputFromArray<float>(TensorShape({8}), zeros);
  AddInputFromArray<int32>(TensorShape({2, 3}), {0, 4, 2, 1, 3, 6});
  AddInputFromArray<float>(TensorShape({2, 3}), {10, 20, 30, 40, 50, 60});
  TF_ASSERT_OK(RunOpKernel());

  // Positions 5 and 7 are never indexed and stay zero.
  const Tensor updated = *mutable_input(0).tensor;
  Tensor want(allocator(), DT_FLOAT, TensorShape({8}));
  test::FillValues<float>(&want, {10, 40, 30, 50, 20, 0, 60, 0});
  test::ExpectTensorEqual<float>(want, updated);
}
173
// Index 99 lies outside the 5 rows of the variable; the kernel must report
// the offending position and the valid range.
TEST_F(ScatterUpdateOpTest, Error_IndexOutOfRange) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  const std::vector<float> zeros(5 * 3, 0.0f);
  AddInputFromArray<float>(TensorShape({5, 3}), zeros);
  AddInputFromArray<int32>(TensorShape({3}), {0, 4, 99});
  AddInputFromArray<float>(TensorShape({3, 3}),
                           {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});

  const Status status = RunOpKernel();
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "indices[2] = 99 is not in [0, 5)"))
      << status;
}
188
// ScatterSub must reject index 99 against a 14-element variable.
TEST_F(ScatterSubOpTest, Error_IndexOutOfRange) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  const std::vector<float> zeros(14, 0.0f);
  AddInputFromArray<float>(TensorShape({14}), zeros);
  AddInputFromArray<int32>(TensorShape({3}), {0, 1, 99});
  AddInputFromArray<float>(TensorShape({3}), {100, 101, 102});

  const Status status = RunOpKernel();
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "indices[2] = 99 is not in [0, 14)"))
      << status;
}
201
// Applies 1,000,000 ScatterSub updates of 1 that all target row 0 of a
// one-row int32 variable. Every update must take effect, so the row ends at
// -1000000.
TEST_F(ScatterSubOpTest, StressIndexTest) {
  MakeOp(DT_INT32_REF, DT_INT32);
  const int kRows = 1;
  const int kNumUpdates = 1000000;
  const std::vector<int32> values(kRows, 0);
  const std::vector<int32> indices(kNumUpdates, 0);  // all hit row 0
  const std::vector<int32> updates(kNumUpdates, 1);
  AddInputFromArray<int32>(TensorShape({kRows}), values);
  AddInputFromArray<int32>(TensorShape({kNumUpdates}), indices);
  AddInputFromArray<int32>(TensorShape({kNumUpdates}), updates);
  // Check the status so a kernel failure is reported directly instead of
  // showing up only as a value mismatch below.
  TF_ASSERT_OK(RunOpKernel());

  const Tensor params_tensor = *mutable_input(0).tensor;
  Tensor expected(allocator(), DT_INT32, TensorShape({kRows}));
  test::FillValues<int32>(&expected, {-kNumUpdates});
  test::ExpectTensorEqual<int32>(expected, params_tensor);
}
219
// The indices are shaped [1, 3] against a [2, 3] variable. The updates
// should therefore be [1, 3, 3], but a [3, 3] tensor is fed, so the kernel
// must reject the shapes.
TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  const std::vector<float> zeros(2 * 3, 0.0f);
  AddInputFromArray<float>(TensorShape({2, 3}), zeros);
  AddInputFromArray<int32>(TensorShape({1, 3}), {0, 4, 99});
  AddInputFromArray<float>(TensorShape({3, 3}),
                           {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});

  const Status status = RunOpKernel();
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "Must have updates.shape = indices.shape + "
                                "params.shape[1:] or updates.shape = [], got "))
      << status;
}
234
// The update rows have 4 columns while the variable rows have 3, so the
// shape check must fail.
TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  const std::vector<float> zeros(5 * 3, 0.0f);
  AddInputFromArray<float>(TensorShape({5, 3}), zeros);
  AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2});
  AddInputFromArray<float>(
      TensorShape({3, 4}),
      {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004});

  const Status status = RunOpKernel();
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "Must have updates.shape = indices.shape + "
                                "params.shape[1:] or updates.shape = [], got "))
      << status;
}
252
// Three indices but only two update rows, so the shape check must fail.
TEST_F(ScatterUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  const std::vector<float> zeros(5 * 3, 0.0f);
  AddInputFromArray<float>(TensorShape({5, 3}), zeros);
  AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2});
  AddInputFromArray<float>(TensorShape({2, 3}),
                           {100, 101, 102, 10000, 10001, 10002});

  const Status status = RunOpKernel();
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "Must have updates.shape = indices.shape + "
                                "params.shape[1:] or updates.shape = [], got "))
      << status;
}
268
269 class ScatterUpdateBM : public ScatterUpdateOpTest {
270 public:
TestBody()271 void TestBody() override {}
MakeBenchmarkOp(const char * op,DataType index_type)272 void MakeBenchmarkOp(const char* op, DataType index_type) {
273 TF_ASSERT_OK(NodeDefBuilder("myop", op)
274 .Input(FakeInput(DT_FLOAT_REF))
275 .Input(FakeInput(index_type))
276 .Input(FakeInput(DT_FLOAT))
277 .Finalize(node_def()));
278 TF_CHECK_OK(InitOp());
279 }
280 };
281
282 template <typename Index>
BM_ScatterHelper(::testing::benchmark::State & state,int embedding_size,const char * op,bool big_num_updates=false)283 void BM_ScatterHelper(::testing::benchmark::State& state, int embedding_size,
284 const char* op, bool big_num_updates = false) {
285 const int kRows = 10000000 / embedding_size;
286 std::vector<float> values;
287 values.reserve(kRows);
288 for (int i = 0; i < kRows * embedding_size; i++) {
289 values.push_back(i);
290 }
291 const int kNumUpdates = big_num_updates ? 1000000 : 1000;
292 random::PhiloxRandom philox(301, 17);
293 random::SimplePhilox rnd(&philox);
294 std::vector<Index> indices;
295 std::vector<float> updates;
296 for (int i = 0; i < kNumUpdates; i++) {
297 indices.push_back(rnd.Uniform(kRows));
298 for (int j = 0; j < embedding_size; j++) {
299 updates.push_back(i * 10 + j);
300 }
301 }
302
303 ScatterUpdateBM bm;
304 bm.MakeBenchmarkOp(op, DataTypeToEnum<Index>::v());
305 bm.AddInputFromArray<float>(TensorShape({kRows, embedding_size}), values);
306 bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices);
307 bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}),
308 updates);
309 for (auto i : state) {
310 Status s = bm.RunOpKernel();
311 }
312 state.SetItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
313 state.iterations());
314 }
315
BM_ScatterUpdateInt32(::testing::benchmark::State & state)316 void BM_ScatterUpdateInt32(::testing::benchmark::State& state) {
317 const int embedding_size = state.range(0);
318
319 BM_ScatterHelper<int32>(state, embedding_size, "ScatterUpdate");
320 }
BM_ScatterUpdateInt64(::testing::benchmark::State & state)321 void BM_ScatterUpdateInt64(::testing::benchmark::State& state) {
322 const int embedding_size = state.range(0);
323
324 BM_ScatterHelper<int64>(state, embedding_size, "ScatterUpdate");
325 }
326
BM_ScatterAddInt32(::testing::benchmark::State & state)327 void BM_ScatterAddInt32(::testing::benchmark::State& state) {
328 const int embedding_size = state.range(0);
329
330 BM_ScatterHelper<int32>(state, embedding_size, "ScatterAdd");
331 }
332
BM_ScatterAddInt32Large(::testing::benchmark::State & state)333 void BM_ScatterAddInt32Large(::testing::benchmark::State& state) {
334 const int embedding_size = state.range(0);
335
336 BM_ScatterHelper<int32>(state, embedding_size, "ScatterAdd", true);
337 }
BM_ScatterAddInt64(::testing::benchmark::State & state)338 void BM_ScatterAddInt64(::testing::benchmark::State& state) {
339 const int embedding_size = state.range(0);
340
341 BM_ScatterHelper<int64>(state, embedding_size, "ScatterAdd");
342 }
343
BM_ScatterMulInt32(::testing::benchmark::State & state)344 void BM_ScatterMulInt32(::testing::benchmark::State& state) {
345 const int embedding_size = state.range(0);
346
347 BM_ScatterHelper<int32>(state, embedding_size, "ScatterMul");
348 }
BM_ScatterMulInt64(::testing::benchmark::State & state)349 void BM_ScatterMulInt64(::testing::benchmark::State& state) {
350 const int embedding_size = state.range(0);
351
352 BM_ScatterHelper<int64>(state, embedding_size, "ScatterMul");
353 }
354
BM_ScatterDivInt32(::testing::benchmark::State & state)355 void BM_ScatterDivInt32(::testing::benchmark::State& state) {
356 const int embedding_size = state.range(0);
357
358 BM_ScatterHelper<int32>(state, embedding_size, "ScatterDiv");
359 }
BM_ScatterDivInt64(::testing::benchmark::State & state)360 void BM_ScatterDivInt64(::testing::benchmark::State& state) {
361 const int embedding_size = state.range(0);
362
363 BM_ScatterHelper<int64>(state, embedding_size, "ScatterDiv");
364 }
365
BM_ScatterMinInt32(::testing::benchmark::State & state)366 void BM_ScatterMinInt32(::testing::benchmark::State& state) {
367 const int embedding_size = state.range(0);
368
369 BM_ScatterHelper<int32>(state, embedding_size, "ScatterMin");
370 }
BM_ScatterMinInt64(::testing::benchmark::State & state)371 void BM_ScatterMinInt64(::testing::benchmark::State& state) {
372 const int embedding_size = state.range(0);
373
374 BM_ScatterHelper<int64>(state, embedding_size, "ScatterMin");
375 }
376
BM_ScatterMaxInt32(::testing::benchmark::State & state)377 void BM_ScatterMaxInt32(::testing::benchmark::State& state) {
378 const int embedding_size = state.range(0);
379
380 BM_ScatterHelper<int32>(state, embedding_size, "ScatterMax");
381 }
BM_ScatterMaxInt64(::testing::benchmark::State & state)382 void BM_ScatterMaxInt64(::testing::benchmark::State& state) {
383 const int embedding_size = state.range(0);
384
385 BM_ScatterHelper<int64>(state, embedding_size, "ScatterMax");
386 }
387
388 BENCHMARK(BM_ScatterUpdateInt32)
389 ->Arg(1)
390 ->Arg(10)
391 ->Arg(32)
392 ->Arg(50)
393 ->Arg(64)
394 ->Arg(80)
395 ->Arg(96)
396 ->Arg(112)
397 ->Arg(192)
398 ->Arg(256)
399 ->Arg(1024)
400 ->Arg(10000)
401 ->Arg(100000)
402 ->Arg(1000000);
403 BENCHMARK(BM_ScatterUpdateInt64)
404 ->Arg(1)
405 ->Arg(10)
406 ->Arg(64)
407 ->Arg(256)
408 ->Arg(1024)
409 ->Arg(100000);
410
411 BENCHMARK(BM_ScatterAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
412
413 BENCHMARK(BM_ScatterAddInt32Large)
414 ->Arg(1)
415 ->Arg(10)
416 ->Arg(64)
417 ->Arg(256)
418 ->Arg(1024);
419
420 BENCHMARK(BM_ScatterAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
421
422 BENCHMARK(BM_ScatterMulInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
423 BENCHMARK(BM_ScatterMulInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
424
425 BENCHMARK(BM_ScatterDivInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
426 BENCHMARK(BM_ScatterDivInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
427
428 BENCHMARK(BM_ScatterMinInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
429 BENCHMARK(BM_ScatterMinInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
430
431 BENCHMARK(BM_ScatterMaxInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
432 BENCHMARK(BM_ScatterMaxInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
433
434 } // namespace
435 } // namespace tensorflow
436