• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <functional>
17 #include <memory>
18 #include <vector>
19 
20 #include "tensorflow/core/framework/allocator.h"
21 #include "tensorflow/core/framework/fake_input.h"
22 #include "tensorflow/core/framework/node_def_builder.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/framework/types.pb.h"
27 #include "tensorflow/core/kernels/ops_testutil.h"
28 #include "tensorflow/core/kernels/ops_util.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/lib/random/simple_philox.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/test.h"
34 #include "tensorflow/core/platform/test_benchmark.h"
35 
36 namespace tensorflow {
37 namespace {
38 
39 class ScatterUpdateOpTest : public OpsTestBase {
40  protected:
MakeOp(DataType variable_ref_type,DataType index_type)41   void MakeOp(DataType variable_ref_type, DataType index_type) {
42     TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterUpdate")
43                      .Input(FakeInput(variable_ref_type))
44                      .Input(FakeInput(index_type))
45                      .Input(FakeInput(RemoveRefType(variable_ref_type)))
46                      .Finalize(node_def()));
47     TF_ASSERT_OK(InitOp());
48   }
49 };
50 class ScatterSubOpTest : public OpsTestBase {
51  protected:
MakeOp(DataType variable_ref_type,DataType index_type)52   void MakeOp(DataType variable_ref_type, DataType index_type) {
53     TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterSub")
54                      .Input(FakeInput(variable_ref_type))
55                      .Input(FakeInput(index_type))
56                      .Input(FakeInput(RemoveRefType(variable_ref_type)))
57                      .Finalize(node_def()));
58     TF_ASSERT_OK(InitOp());
59   }
60 };
61 
// Overwriting the single element of a one-element string variable.
TEST_F(ScatterUpdateOpTest, Simple_StringType) {
  MakeOp(DT_STRING_REF, DT_INT32);
  const TensorShape shape({1});
  AddInputFromArray<tstring>(shape, {"Brain"});       // params
  AddInputFromArray<int32>(shape, {0});               // indices
  AddInputFromArray<tstring>(shape, {"TensorFlow"});  // updates
  TF_ASSERT_OK(RunOpKernel());

  // Compare against the new state of the variable (input 0).
  Tensor expected(allocator(), DT_STRING, shape);
  test::FillValues<tstring>(&expected, {"TensorFlow"});
  test::ExpectTensorEqual<tstring>(expected, *mutable_input(0).tensor);
}
74 
// Overwriting the single element of a one-element bool variable.
TEST_F(ScatterUpdateOpTest, Simple_BoolType) {
  MakeOp(DT_BOOL_REF, DT_INT32);
  const TensorShape shape({1});
  AddInputFromArray<bool>(shape, {false});  // params
  AddInputFromArray<int32>(shape, {0});     // indices
  AddInputFromArray<bool>(shape, {true});   // updates
  TF_ASSERT_OK(RunOpKernel());

  // Compare against the new state of the variable (input 0).
  Tensor expected(allocator(), DT_BOOL, shape);
  test::FillValues<bool>(&expected, {true});
  test::ExpectTensorEqual<bool>(expected, *mutable_input(0).tensor);
}
87 
// Scatters three rows into rows 0, 4 and 2 of a zero-filled 5x3 float
// variable using int32 indices; rows 1 and 3 keep their zeros.
TEST_F(ScatterUpdateOpTest, Simple_TwoD32) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  const TensorShape params_shape({5, 3});
  AddInputFromArray<float>(params_shape, std::vector<float>(5 * 3, 0.0f));
  AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2});
  AddInputFromArray<float>(TensorShape({3, 3}),
                           {100, 101, 102,          // -> row 0
                            777, 778, 779,          // -> row 4
                            10000, 10001, 10002});  // -> row 2
  TF_ASSERT_OK(RunOpKernel());

  Tensor expected(allocator(), DT_FLOAT, params_shape);
  test::FillValues<float>(&expected, {100, 101, 102,        // row 0
                                      0, 0, 0,              // row 1
                                      10000, 10001, 10002,  // row 2
                                      0, 0, 0,              // row 3
                                      777, 778, 779});      // row 4
  test::ExpectTensorEqual<float>(expected, *mutable_input(0).tensor);
}
106 
// Same scenario as Simple_TwoD32, but with int64 indices: three rows land in
// rows 0, 4 and 2 of a zero-filled 5x3 float variable.
TEST_F(ScatterUpdateOpTest, Simple_Two64) {
  MakeOp(DT_FLOAT_REF, DT_INT64);

  const TensorShape params_shape({5, 3});
  AddInputFromArray<float>(params_shape, std::vector<float>(5 * 3, 0.0f));
  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 2});
  AddInputFromArray<float>(TensorShape({3, 3}),
                           {100, 101, 102,          // -> row 0
                            777, 778, 779,          // -> row 4
                            10000, 10001, 10002});  // -> row 2
  TF_ASSERT_OK(RunOpKernel());

  Tensor expected(allocator(), DT_FLOAT, params_shape);
  test::FillValues<float>(&expected, {100, 101, 102,        // row 0
                                      0, 0, 0,              // row 1
                                      10000, 10001, 10002,  // row 2
                                      0, 0, 0,              // row 3
                                      777, 778, 779});      // row 4
  test::ExpectTensorEqual<float>(expected, *mutable_input(0).tensor);
}
125 
// A scalar index paired with a scalar update writes one element of a 1-D
// variable.
TEST_F(ScatterUpdateOpTest, Simple_ZeroD) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  const TensorShape params_shape({5});
  AddInputFromArray<float>(params_shape, std::vector<float>(5, 0.0f));
  AddInputFromArray<int32>(TensorShape({}), {3});
  AddInputFromArray<float>(TensorShape({}), {101});
  TF_ASSERT_OK(RunOpKernel());

  Tensor expected(allocator(), DT_FLOAT, params_shape);
  test::FillValues<float>(&expected, {0, 0, 0, 101, 0});
  test::ExpectTensorEqual<float>(expected, *mutable_input(0).tensor);
}
141 
// Three scalar updates land at positions 0, 4 and 2 of a zero-filled 1-D
// variable.
TEST_F(ScatterUpdateOpTest, Simple_OneD) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  const TensorShape params_shape({5});
  const TensorShape update_shape({3});
  AddInputFromArray<float>(params_shape, std::vector<float>(5, 0.0f));
  AddInputFromArray<int32>(update_shape, {0, 4, 2});
  AddInputFromArray<float>(update_shape, {100, 101, 102});
  TF_ASSERT_OK(RunOpKernel());

  Tensor expected(allocator(), DT_FLOAT, params_shape);
  test::FillValues<float>(&expected, {100, 0, 102, 0, 101});
  test::ExpectTensorEqual<float>(expected, *mutable_input(0).tensor);
}
157 
// Indices and updates of rank 2: a [2, 3] block of indices scatters the
// matching [2, 3] block of scalar updates into an 8-element vector.
TEST_F(ScatterUpdateOpTest, HigherRank) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  const TensorShape params_shape({8});
  const TensorShape block_shape({2, 3});
  AddInputFromArray<float>(params_shape, std::vector<float>(8, 0.0f));
  AddInputFromArray<int32>(block_shape, {0, 4, 2, 1, 3, 6});
  AddInputFromArray<float>(block_shape, {10, 20, 30, 40, 50, 60});
  TF_ASSERT_OK(RunOpKernel());

  // Positions 5 and 7 are never indexed and keep their zeros.
  Tensor expected(allocator(), DT_FLOAT, params_shape);
  test::FillValues<float>(&expected, {10, 40, 30, 50, 20, 0, 60, 0});
  test::ExpectTensorEqual<float>(expected, *mutable_input(0).tensor);
}
173 
// An index past the first dimension of params (99 vs. 5 rows) must be
// rejected with a message naming the offending index.
TEST_F(ScatterUpdateOpTest, Error_IndexOutOfRange) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  AddInputFromArray<float>(TensorShape({5, 3}), std::vector<float>(15, 0.0f));
  AddInputFromArray<int32>(TensorShape({3}), {0, 4, 99});
  AddInputFromArray<float>(TensorShape({3, 3}),
                           {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
  const Status status = RunOpKernel();
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "indices[2] = 99 is not in [0, 5)"))
      << status;
}
188 
// ScatterSub must reject an index (99) outside a 14-element variable.
TEST_F(ScatterSubOpTest, Error_IndexOutOfRange) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  const TensorShape update_shape({3});
  AddInputFromArray<float>(TensorShape({14}), std::vector<float>(14, 0.0f));
  AddInputFromArray<int32>(update_shape, {0, 1, 99});
  AddInputFromArray<float>(update_shape, {100, 101, 102});
  const Status status = RunOpKernel();
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "indices[2] = 99 is not in [0, 14)"))
      << status;
}
201 
// Applies one million ScatterSub updates of 1, all targeting index 0 of a
// one-element int32 variable, and checks the variable ends at -kNumUpdates.
TEST_F(ScatterSubOpTest, StressIndexTest) {
  MakeOp(DT_INT32_REF, DT_INT32);
  const int kRows = 1;
  std::vector<int32> values(kRows, 0);
  const int kNumUpdates = 1000000;
  std::vector<int32> indices(kNumUpdates, 0);
  std::vector<int32> updates(kNumUpdates, 1);
  AddInputFromArray<int32>(TensorShape({kRows}), values);
  AddInputFromArray<int32>(TensorShape({kNumUpdates}), indices);
  AddInputFromArray<int32>(TensorShape({kNumUpdates}), updates);
  // Fail immediately if the kernel reports an error, instead of surfacing it
  // only as a confusing value mismatch below.
  TF_ASSERT_OK(RunOpKernel());

  Tensor params_tensor = *mutable_input(0).tensor;
  Tensor expected(allocator(), DT_INT32, TensorShape({kRows}));
  // Each update subtracts 1 from the single element, which starts at 0.
  test::FillValues<int32>(&expected, {-kNumUpdates});
  test::ExpectTensorEqual<int32>(expected, params_tensor);
}
219 
// Indices of shape [1, 3] combined with updates of shape [3, 3] do not satisfy
// the updates-shape contract, so the kernel must report a shape error.
TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  AddInputFromArray<float>(TensorShape({2, 3}), std::vector<float>(6, 0.0f));
  AddInputFromArray<int32>(TensorShape({1, 3}), {0, 4, 99});
  AddInputFromArray<float>(TensorShape({3, 3}),
                           {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
  const Status status = RunOpKernel();
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "Must have updates.shape = indices.shape + "
                                "params.shape[1:] or updates.shape = [], got "))
      << status;
}
234 
// Updates rows of width 4 cannot be scattered into params rows of width 3;
// the kernel must report a shape error.
TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  AddInputFromArray<float>(TensorShape({5, 3}), std::vector<float>(15, 0.0f));
  AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2});
  AddInputFromArray<float>(
      TensorShape({3, 4}),
      {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004});
  const Status status = RunOpKernel();
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "Must have updates.shape = indices.shape + "
                                "params.shape[1:] or updates.shape = [], got "))
      << status;
}
252 
// Three indices but only two update rows: the leading dimensions disagree,
// so the kernel must report a shape error.
TEST_F(ScatterUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

  AddInputFromArray<float>(TensorShape({5, 3}), std::vector<float>(15, 0.0f));
  AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2});
  AddInputFromArray<float>(TensorShape({2, 3}),
                           {100, 101, 102, 10000, 10001, 10002});
  const Status status = RunOpKernel();
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "Must have updates.shape = indices.shape + "
                                "params.shape[1:] or updates.shape = [], got "))
      << status;
}
268 
269 class ScatterUpdateBM : public ScatterUpdateOpTest {
270  public:
TestBody()271   void TestBody() override {}
MakeBenchmarkOp(const char * op,DataType index_type)272   void MakeBenchmarkOp(const char* op, DataType index_type) {
273     TF_ASSERT_OK(NodeDefBuilder("myop", op)
274                      .Input(FakeInput(DT_FLOAT_REF))
275                      .Input(FakeInput(index_type))
276                      .Input(FakeInput(DT_FLOAT))
277                      .Finalize(node_def()));
278     TF_CHECK_OK(InitOp());
279   }
280 };
281 
282 template <typename Index>
BM_ScatterHelper(::testing::benchmark::State & state,int embedding_size,const char * op,bool big_num_updates=false)283 void BM_ScatterHelper(::testing::benchmark::State& state, int embedding_size,
284                       const char* op, bool big_num_updates = false) {
285   const int kRows = 10000000 / embedding_size;
286   std::vector<float> values;
287   values.reserve(kRows);
288   for (int i = 0; i < kRows * embedding_size; i++) {
289     values.push_back(i);
290   }
291   const int kNumUpdates = big_num_updates ? 1000000 : 1000;
292   random::PhiloxRandom philox(301, 17);
293   random::SimplePhilox rnd(&philox);
294   std::vector<Index> indices;
295   std::vector<float> updates;
296   for (int i = 0; i < kNumUpdates; i++) {
297     indices.push_back(rnd.Uniform(kRows));
298     for (int j = 0; j < embedding_size; j++) {
299       updates.push_back(i * 10 + j);
300     }
301   }
302 
303   ScatterUpdateBM bm;
304   bm.MakeBenchmarkOp(op, DataTypeToEnum<Index>::v());
305   bm.AddInputFromArray<float>(TensorShape({kRows, embedding_size}), values);
306   bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices);
307   bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}),
308                               updates);
309   for (auto i : state) {
310     Status s = bm.RunOpKernel();
311   }
312   state.SetItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
313                           state.iterations());
314 }
315 
BM_ScatterUpdateInt32(::testing::benchmark::State & state)316 void BM_ScatterUpdateInt32(::testing::benchmark::State& state) {
317   const int embedding_size = state.range(0);
318 
319   BM_ScatterHelper<int32>(state, embedding_size, "ScatterUpdate");
320 }
BM_ScatterUpdateInt64(::testing::benchmark::State & state)321 void BM_ScatterUpdateInt64(::testing::benchmark::State& state) {
322   const int embedding_size = state.range(0);
323 
324   BM_ScatterHelper<int64>(state, embedding_size, "ScatterUpdate");
325 }
326 
BM_ScatterAddInt32(::testing::benchmark::State & state)327 void BM_ScatterAddInt32(::testing::benchmark::State& state) {
328   const int embedding_size = state.range(0);
329 
330   BM_ScatterHelper<int32>(state, embedding_size, "ScatterAdd");
331 }
332 
BM_ScatterAddInt32Large(::testing::benchmark::State & state)333 void BM_ScatterAddInt32Large(::testing::benchmark::State& state) {
334   const int embedding_size = state.range(0);
335 
336   BM_ScatterHelper<int32>(state, embedding_size, "ScatterAdd", true);
337 }
BM_ScatterAddInt64(::testing::benchmark::State & state)338 void BM_ScatterAddInt64(::testing::benchmark::State& state) {
339   const int embedding_size = state.range(0);
340 
341   BM_ScatterHelper<int64>(state, embedding_size, "ScatterAdd");
342 }
343 
BM_ScatterMulInt32(::testing::benchmark::State & state)344 void BM_ScatterMulInt32(::testing::benchmark::State& state) {
345   const int embedding_size = state.range(0);
346 
347   BM_ScatterHelper<int32>(state, embedding_size, "ScatterMul");
348 }
BM_ScatterMulInt64(::testing::benchmark::State & state)349 void BM_ScatterMulInt64(::testing::benchmark::State& state) {
350   const int embedding_size = state.range(0);
351 
352   BM_ScatterHelper<int64>(state, embedding_size, "ScatterMul");
353 }
354 
BM_ScatterDivInt32(::testing::benchmark::State & state)355 void BM_ScatterDivInt32(::testing::benchmark::State& state) {
356   const int embedding_size = state.range(0);
357 
358   BM_ScatterHelper<int32>(state, embedding_size, "ScatterDiv");
359 }
BM_ScatterDivInt64(::testing::benchmark::State & state)360 void BM_ScatterDivInt64(::testing::benchmark::State& state) {
361   const int embedding_size = state.range(0);
362 
363   BM_ScatterHelper<int64>(state, embedding_size, "ScatterDiv");
364 }
365 
BM_ScatterMinInt32(::testing::benchmark::State & state)366 void BM_ScatterMinInt32(::testing::benchmark::State& state) {
367   const int embedding_size = state.range(0);
368 
369   BM_ScatterHelper<int32>(state, embedding_size, "ScatterMin");
370 }
BM_ScatterMinInt64(::testing::benchmark::State & state)371 void BM_ScatterMinInt64(::testing::benchmark::State& state) {
372   const int embedding_size = state.range(0);
373 
374   BM_ScatterHelper<int64>(state, embedding_size, "ScatterMin");
375 }
376 
BM_ScatterMaxInt32(::testing::benchmark::State & state)377 void BM_ScatterMaxInt32(::testing::benchmark::State& state) {
378   const int embedding_size = state.range(0);
379 
380   BM_ScatterHelper<int32>(state, embedding_size, "ScatterMax");
381 }
BM_ScatterMaxInt64(::testing::benchmark::State & state)382 void BM_ScatterMaxInt64(::testing::benchmark::State& state) {
383   const int embedding_size = state.range(0);
384 
385   BM_ScatterHelper<int64>(state, embedding_size, "ScatterMax");
386 }
387 
388 BENCHMARK(BM_ScatterUpdateInt32)
389     ->Arg(1)
390     ->Arg(10)
391     ->Arg(32)
392     ->Arg(50)
393     ->Arg(64)
394     ->Arg(80)
395     ->Arg(96)
396     ->Arg(112)
397     ->Arg(192)
398     ->Arg(256)
399     ->Arg(1024)
400     ->Arg(10000)
401     ->Arg(100000)
402     ->Arg(1000000);
403 BENCHMARK(BM_ScatterUpdateInt64)
404     ->Arg(1)
405     ->Arg(10)
406     ->Arg(64)
407     ->Arg(256)
408     ->Arg(1024)
409     ->Arg(100000);
410 
411 BENCHMARK(BM_ScatterAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
412 
413 BENCHMARK(BM_ScatterAddInt32Large)
414     ->Arg(1)
415     ->Arg(10)
416     ->Arg(64)
417     ->Arg(256)
418     ->Arg(1024);
419 
420 BENCHMARK(BM_ScatterAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
421 
422 BENCHMARK(BM_ScatterMulInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
423 BENCHMARK(BM_ScatterMulInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
424 
425 BENCHMARK(BM_ScatterDivInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
426 BENCHMARK(BM_ScatterDivInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
427 
428 BENCHMARK(BM_ScatterMinInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
429 BENCHMARK(BM_ScatterMinInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
430 
431 BENCHMARK(BM_ScatterMaxInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
432 BENCHMARK(BM_ScatterMaxInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
433 
434 }  // namespace
435 }  // namespace tensorflow
436