• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <functional>
17 #include <memory>
18 #include <vector>
19 
20 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
21 #include "tensorflow/core/framework/allocator.h"
22 #include "tensorflow/core/framework/fake_input.h"
23 #include "tensorflow/core/framework/node_def_builder.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/graph/testlib.h"
29 #include "tensorflow/core/kernels/ops_testutil.h"
30 #include "tensorflow/core/kernels/ops_util.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32 #include "tensorflow/core/lib/gtl/array_slice.h"
33 #include "tensorflow/core/lib/random/simple_philox.h"
34 #include "tensorflow/core/lib/strings/str_util.h"
35 #include "tensorflow/core/platform/test.h"
36 #include "tensorflow/core/platform/test_benchmark.h"
37 
38 namespace tensorflow {
39 namespace {
40 
41 class GatherOpTest : public OpsTestBase {
42  protected:
MakeOp(DataType data_type,DataType index_type,int batch_dims=0)43   void MakeOp(DataType data_type, DataType index_type, int batch_dims = 0) {
44     TF_ASSERT_OK(NodeDefBuilder("myop", "GatherV2")
45                      .Input(FakeInput(data_type))
46                      .Input(FakeInput(index_type))
47                      .Input(FakeInput(index_type))
48                      .Attr("batch_dims", batch_dims)
49                      .Finalize(node_def()));
50     TF_ASSERT_OK(InitOp());
51   }
52 };
53 
// A scalar index into a 1-D float tensor along axis 0 yields a scalar output.
TEST_F(GatherOpTest, ScalarIndices) {
  MakeOp(DT_FLOAT, DT_INT32);

  // params = [0, 1, 2, 3, 4], indices = 3 (scalar), axis = 0.
  AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 3, 4});
  AddInputFromArray<int32>(TensorShape({}), {3});
  AddInputFromArray<int32>(TensorShape({}), {0});
  TF_ASSERT_OK(RunOpKernel());

  // The result is the single element params[3].
  Tensor want(allocator(), DT_FLOAT, TensorShape({}));
  want.scalar<float>()() = 3;
  test::ExpectTensorEqual<float>(want, *GetOutput(0));
}
68 
// Same scalar-index gather as ScalarIndices, but over complex64 params.
TEST_F(GatherOpTest, ScalarIndices_Complex) {
  using C64 = std::complex<float>;
  MakeOp(DT_COMPLEX64, DT_INT32);

  // params[i] == (i, 10 + i) for i in [0, 5); indices = 3; axis = 0.
  AddInputFromArray<C64>(TensorShape({5}), {C64(0, 10), C64(1, 11), C64(2, 12),
                                            C64(3, 13), C64(4, 14)});
  AddInputFromArray<int32>(TensorShape({}), {3});
  AddInputFromArray<int32>(TensorShape({}), {0});
  TF_ASSERT_OK(RunOpKernel());

  // The result is the single element params[3].
  Tensor want(allocator(), DT_COMPLEX64, TensorShape({}));
  want.scalar<C64>()() = C64(3, 13);
  test::ExpectTensorEqual<C64>(want, *GetOutput(0));
}
87 
// Gathers rows {0, 4, 0, 2} of a 5x3 float matrix along axis 0 using int32
// indices.
TEST_F(GatherOpTest, Simple_TwoD32_Axis0) {
  MakeOp(DT_FLOAT, DT_INT32);

  // params is the 5x3 matrix [[0, 1, 2], [3, 4, 5], ..., [12, 13, 14]].
  const std::vector<float> params = {0, 1,  2,  3,  4,  5,  6, 7,
                                     8, 9, 10, 11, 12, 13, 14};
  const std::vector<int32> row_ids = {0, 4, 0, 2};
  AddInputFromArray<float>(TensorShape({5, 3}), params);
  AddInputFromArray<int32>(TensorShape({4}), row_ids);
  AddInputFromArray<int32>(TensorShape({}), {0});  // axis
  TF_ASSERT_OK(RunOpKernel());

  // Output row i is params row row_ids[i].
  Tensor want(allocator(), DT_FLOAT, TensorShape({4, 3}));
  test::FillValues<float>(&want, {0,  1,  2,    // row 0
                                  12, 13, 14,   // row 4
                                  0,  1,  2,    // row 0
                                  6,  7,  8});  // row 2
  test::ExpectTensorEqual<float>(want, *GetOutput(0));
}
103 
// Gathers columns {0, 1, 0, 2} of a 5x3 float matrix along axis 1.
TEST_F(GatherOpTest, Simple_TwoD32_Axis1) {
  MakeOp(DT_FLOAT, DT_INT32);

  const int32 axis = 1;
  AddInputFromArray<float>(TensorShape({5, 3}),
                           {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
  AddInputFromArray<int32>(TensorShape({4}), {0, 1, 0, 2});
  AddInputFromArray<int32>(TensorShape({}), {axis});
  TF_ASSERT_OK(RunOpKernel());

  // Each of the 5 rows keeps its columns {0, 1, 0, 2}, giving a 5x4 result.
  Tensor want(allocator(), DT_FLOAT, TensorShape({5, 4}));
  test::FillValues<float>(&want, {0,  1,  0,  2,
                                  3,  4,  3,  5,
                                  6,  7,  6,  8,
                                  9,  10, 9,  11,
                                  12, 13, 12, 14});
  test::ExpectTensorEqual<float>(want, *GetOutput(0));
}
120 
// Gathering from params whose inner dimension is zero succeeds and yields an
// empty [4, 0] result.
TEST_F(GatherOpTest, ZeroSize_TwoD32) {
  MakeOp(DT_FLOAT, DT_INT32);

  // A 5x0 params tensor holds no elements, so no data is fed for it.
  const TensorShape params_shape({5, 0});
  AddInputFromArray<float>(params_shape, {});
  AddInputFromArray<int32>(TensorShape({4}), {0, 4, 0, 2});
  AddInputFromArray<int32>(TensorShape({}), {0});  // axis
  TF_ASSERT_OK(RunOpKernel());

  // Only the shape [4, 0] (and dtype) can be compared; there are no values.
  Tensor want(allocator(), DT_FLOAT, TensorShape({4, 0}));
  test::ExpectTensorEqual<float>(want, *GetOutput(0));
}
134 
// Gathers rows {0, 4, 0, 2} of a 5x3 float matrix along axis 0, with both
// the indices and the axis supplied as int64.
TEST_F(GatherOpTest, Simple_TwoD64) {
  MakeOp(DT_FLOAT, DT_INT64);

  const std::vector<int64> row_ids = {0, 4, 0, 2};
  AddInputFromArray<float>(TensorShape({5, 3}),
                           {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
  AddInputFromArray<int64>(TensorShape({4}), row_ids);
  AddInputFromArray<int64>(TensorShape({}), {0});  // axis
  TF_ASSERT_OK(RunOpKernel());

  // Output row i is params row row_ids[i].
  Tensor want(allocator(), DT_FLOAT, TensorShape({4, 3}));
  test::FillValues<float>(&want, {0,  1,  2,    // row 0
                                  12, 13, 14,   // row 4
                                  0,  1,  2,    // row 0
                                  6,  7,  8});  // row 2
  test::ExpectTensorEqual<float>(want, *GetOutput(0));
}
150 
// A rank-2 indices tensor into 1-D params produces an output shaped like the
// indices.
TEST_F(GatherOpTest, HighRank) {
  MakeOp(DT_FLOAT, DT_INT32);

  // params[i] == i, so every gathered value equals the index used to fetch it.
  const std::vector<int32> indices = {1, 2, 0, 2, 3, 0};
  AddInputFromArray<float>(TensorShape({4}), {0, 1, 2, 3});
  AddInputFromArray<int32>(TensorShape({2, 3}), indices);
  AddInputFromArray<int32>(TensorShape({}), {0});  // axis
  TF_ASSERT_OK(RunOpKernel());

  // Hence the expected 2x3 output is the indices converted to float.
  const std::vector<float> want_values(indices.begin(), indices.end());
  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 3}));
  test::FillValues<float>(&want, want_values);
  test::ExpectTensorEqual<float>(want, *GetOutput(0));
}
165 
// An index beyond the gathered dimension is reported with its position, its
// value, and the valid range.
TEST_F(GatherOpTest, Error_IndexOutOfRange) {
  MakeOp(DT_FLOAT, DT_INT32);

  // indices[2] == 99, but params has only 5 rows along axis 0.
  AddInputFromArray<float>(TensorShape({5, 3}),
                           {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
  AddInputFromArray<int32>(TensorShape({4}), {0, 4, 99, 2});
  AddInputFromArray<int32>(TensorShape({}), {0});  // axis

  const Status status = RunOpKernel();
  const std::string message = status.ToString();
  EXPECT_TRUE(absl::StrContains(message, "indices[2] = 99 is not in [0, 5)"))
      << status;
}
179 
// batch_dims = 10 is rejected for 1-D indices, and the error names the accepted
// range (reported as [-1, 1] for these inputs).
TEST_F(GatherOpTest, Error_BatchDimsOutOfRange) {
  MakeOp(DT_FLOAT, DT_INT32, /*batch_dims=*/10);

  // Every index is valid for the 5 rows of params, so batch_dims is the only
  // invalid argument this test feeds; a failure can only come from the
  // batch_dims validation.
  AddInputFromArray<float>(TensorShape({5, 3}),
                           {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
  AddInputFromArray<int32>(TensorShape({4}), {0, 4, 0, 2});
  AddInputFromArray<int32>(TensorShape({}), {0});  // axis

  Status s = RunOpKernel();
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(
      s.ToString(), "Expected batch_dims in the range [-1, 1], but got 10"))
      << s;
}
193 
194 constexpr int kLookups = 2000;
195 
196 template <typename Index>
Gather(int dim)197 static Graph* Gather(int dim) {
198   Graph* g = new Graph(OpRegistry::Global());
199   // Always use a 512MB buffer.
200   const int kRows = ((512 << 20) / sizeof(float)) / dim;
201   Tensor params(DT_FLOAT, TensorShape({kRows, dim}));
202   params.flat<float>().setRandom();
203 
204   random::PhiloxRandom philox(301, 17);
205   random::SimplePhilox rnd(&philox);
206   std::vector<Index> indices_vec;
207   indices_vec.reserve(kLookups);
208   for (int i = 0; i < kLookups; i++) {
209     indices_vec.push_back(rnd.Uniform(kRows));
210   }
211   Tensor indices(DataTypeToEnum<Index>::value, TensorShape({kLookups}));
212   for (int i = 0; i < indices_vec.size(); i++) {
213     indices.flat<Index>()(i) = indices_vec[i];
214   }
215 
216   Tensor axis(DataTypeToEnum<Index>::value, TensorShape({}));
217   axis.scalar<Index>()() = 0;
218 
219   test::graph::Gather(g, test::graph::Constant(g, params),
220                       test::graph::Constant(g, indices),
221                       test::graph::HostConstant(g, axis));
222   return g;
223 }
224 
225 #define BM_GATHER(DEVICE, INDEX)                                               \
226   static void BM_##DEVICE##_gather_##INDEX(                                    \
227       ::testing::benchmark::State& state) {                                    \
228     const int dim = state.range(0);                                            \
229     test::Benchmark(#DEVICE, Gather<INDEX>(dim), /*old_benchmark_api=*/false)  \
230         .Run(state);                                                           \
231     const int64 tot = static_cast<int64>(state.iterations()) * kLookups * dim; \
232     state.SetItemsProcessed(tot);                                              \
233     state.SetBytesProcessed(tot * sizeof(float));                              \
234   }                                                                            \
235   BENCHMARK(BM_##DEVICE##_gather_##INDEX)                                      \
236       ->UseRealTime()                                                          \
237       ->Arg(1)                                                                 \
238       ->Arg(10)                                                                \
239       ->Arg(20)                                                                \
240       ->Arg(64)                                                                \
241       ->Arg(100)                                                               \
242       ->Arg(200)                                                               \
243       ->Arg(1000)
244 
245 BM_GATHER(cpu, int32);
246 BM_GATHER(gpu, int32);
247 BM_GATHER(cpu, int64);
248 BM_GATHER(gpu, int64);
249 
250 }  // namespace
251 }  // namespace tensorflow
252