• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <functional>
17 #include <memory>
18 #include <vector>
19 
20 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
21 #include "tensorflow/core/framework/allocator.h"
22 #include "tensorflow/core/framework/fake_input.h"
23 #include "tensorflow/core/framework/node_def_builder.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/graph/graph.h"
29 #include "tensorflow/core/graph/node_builder.h"
30 #include "tensorflow/core/graph/testlib.h"
31 #include "tensorflow/core/kernels/ops_testutil.h"
32 #include "tensorflow/core/kernels/ops_util.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/lib/gtl/array_slice.h"
35 #include "tensorflow/core/lib/random/simple_philox.h"
36 #include "tensorflow/core/platform/test.h"
37 #include "tensorflow/core/platform/test_benchmark.h"
38 
39 namespace tensorflow {
40 
41 namespace test {
42 namespace graph {
43 
GatherNd(Graph * g,class Node * in0,class Node * in1)44 class Node* GatherNd(Graph* g, class Node* in0, class Node* in1) {
45   class Node* ret;
46   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GatherNd")
47                   .Input(in0)
48                   .Input(in1)
49                   .Finalize(g, &ret));
50   return ret;
51 }
52 
53 }  // namespace graph
54 }  // namespace test
55 
56 namespace {
57 
58 class GatherNdOpTest : public OpsTestBase {
59  protected:
MakeOp(DataType param_type,DataType index_type)60   void MakeOp(DataType param_type, DataType index_type) {
61     TF_ASSERT_OK(NodeDefBuilder("myop", "GatherNd")
62                      .Input(FakeInput(param_type))
63                      .Input(FakeInput(index_type))
64                      .Finalize(node_def()));
65     TF_ASSERT_OK(InitOp());
66   }
67 };
68 
TEST_F(GatherNdOpTest,Simple)69 TEST_F(GatherNdOpTest, Simple) {
70   MakeOp(DT_FLOAT, DT_INT32);
71 
72   // Feed and run
73   AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 8, 4});
74   AddInputFromArray<int32>(TensorShape({2, 1}), {3, 4});
75   TF_ASSERT_OK(RunOpKernel());
76 
77   // Check the output.
78   Tensor expected(allocator(), DT_FLOAT, TensorShape({2}));
79   test::FillValues<float>(&expected, {8, 4});
80   test::ExpectTensorEqual<float>(expected, *GetOutput(0));
81 }
82 
TEST_F(GatherNdOpTest,Quantized_UINT8)83 TEST_F(GatherNdOpTest, Quantized_UINT8) {
84   MakeOp(DT_QUINT8, DT_INT32);
85 
86   // Feed and run
87   AddInputFromArray<quint8>(TensorShape({5}), {0, 1, 2, 8, 4});
88   AddInputFromArray<int32>(TensorShape({2, 1}), {3, 4});
89   TF_ASSERT_OK(RunOpKernel());
90 
91   // Check the output.
92   Tensor expected(allocator(), DT_QUINT8, TensorShape({2}));
93   test::FillValues<quint8>(&expected, {8, 4});
94   test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
95 }
96 
TEST_F(GatherNdOpTest,Quantized_INT8)97 TEST_F(GatherNdOpTest, Quantized_INT8) {
98   MakeOp(DT_QINT8, DT_INT32);
99 
100   AddInputFromArray<qint8>(TensorShape({5}), {0, 1, 2, 8, 4});
101   AddInputFromArray<int32>(TensorShape({2, 1}), {3, 4});
102   TF_ASSERT_OK(RunOpKernel());
103 
104   Tensor expected(allocator(), DT_QINT8, TensorShape({2}));
105   test::FillValues<qint8>(&expected, {8, 4});
106   test::ExpectTensorEqual<qint8>(expected, *GetOutput(0));
107 }
108 
109 constexpr int kLookups = 2000;
110 
111 template <typename Index>
GatherNd(int dim)112 static Graph* GatherNd(int dim) {
113   Graph* g = new Graph(OpRegistry::Global());
114   // Always use a 512MB buffer.
115   // const int kRows = ((512 << 20) / sizeof(float)) / dim;
116   Tensor params(DT_FLOAT, TensorShape({dim, 8, 16, 32}));
117   params.flat<float>().setRandom();
118 
119   random::PhiloxRandom philox(301, 17);
120   random::SimplePhilox rnd(&philox);
121   Tensor indices(DataTypeToEnum<Index>::value, TensorShape({kLookups, 4}));
122   auto indices_mat = indices.matrix<Index>();
123   for (int i = 0; i < kLookups; i++) {
124     indices_mat(i, 0) = rnd.Uniform(dim);
125     indices_mat(i, 1) = rnd.Uniform(8);
126     indices_mat(i, 2) = rnd.Uniform(16);
127     indices_mat(i, 3) = rnd.Uniform(32);
128   }
129 
130   test::graph::GatherNd(g, test::graph::Constant(g, params),
131                         test::graph::Constant(g, indices));
132   return g;
133 }
134 
135 #define BM_GATHER_ND(DEVICE, INDEX)                                          \
136   static void BM_##DEVICE##_gather_nd_##INDEX(                               \
137       ::testing::benchmark::State& state) {                                  \
138     const int dim = state.range(0);                                          \
139     test::Benchmark(#DEVICE, GatherNd<INDEX>(dim),                           \
140                     /*old_benchmark_api=*/false)                             \
141         .Run(state);                                                         \
142     const int64 tot = static_cast<int64>(state.iterations()) * kLookups * 4; \
143     state.SetItemsProcessed(tot);                                            \
144     state.SetBytesProcessed(tot * sizeof(float));                            \
145   }                                                                          \
146   BENCHMARK(BM_##DEVICE##_gather_nd_##INDEX)                                 \
147       ->UseRealTime()                                                        \
148       ->Arg(10)                                                              \
149       ->Arg(100)                                                             \
150       ->Arg(1000)                                                            \
151       ->Arg(10000)
152 
153 BM_GATHER_ND(cpu, int32);
154 BM_GATHER_ND(gpu, int32);
155 BM_GATHER_ND(cpu, int64);
156 BM_GATHER_ND(gpu, int64);
157 
158 }  // namespace
159 }  // namespace tensorflow
160