1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <functional>
17 #include <memory>
18 #include <vector>
19
20 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
21 #include "tensorflow/core/framework/allocator.h"
22 #include "tensorflow/core/framework/fake_input.h"
23 #include "tensorflow/core/framework/node_def_builder.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/graph/graph.h"
29 #include "tensorflow/core/graph/node_builder.h"
30 #include "tensorflow/core/graph/testlib.h"
31 #include "tensorflow/core/kernels/ops_testutil.h"
32 #include "tensorflow/core/kernels/ops_util.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/lib/gtl/array_slice.h"
35 #include "tensorflow/core/lib/random/simple_philox.h"
36 #include "tensorflow/core/platform/test.h"
37 #include "tensorflow/core/platform/test_benchmark.h"
38
39 namespace tensorflow {
40
41 namespace test {
42 namespace graph {
43
GatherNd(Graph * g,class Node * in0,class Node * in1)44 class Node* GatherNd(Graph* g, class Node* in0, class Node* in1) {
45 class Node* ret;
46 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GatherNd")
47 .Input(in0)
48 .Input(in1)
49 .Finalize(g, &ret));
50 return ret;
51 }
52
53 } // namespace graph
54 } // namespace test
55
56 namespace {
57
58 class GatherNdOpTest : public OpsTestBase {
59 protected:
MakeOp(DataType param_type,DataType index_type)60 void MakeOp(DataType param_type, DataType index_type) {
61 TF_ASSERT_OK(NodeDefBuilder("myop", "GatherNd")
62 .Input(FakeInput(param_type))
63 .Input(FakeInput(index_type))
64 .Finalize(node_def()));
65 TF_ASSERT_OK(InitOp());
66 }
67 };
68
TEST_F(GatherNdOpTest, Simple) {
  MakeOp(DT_FLOAT, DT_INT32);

  // params = [0, 1, 2, 8, 4]; each row of the [2, 1] indices picks a single
  // element, so indices [[3], [4]] are expected to produce [8, 4].
  const TensorShape params_shape({5});
  const TensorShape indices_shape({2, 1});
  AddInputFromArray<float>(params_shape, {0, 1, 2, 8, 4});
  AddInputFromArray<int32>(indices_shape, {3, 4});
  TF_ASSERT_OK(RunOpKernel());

  Tensor want(allocator(), DT_FLOAT, TensorShape({2}));
  test::FillValues<float>(&want, {8, 4});
  test::ExpectTensorEqual<float>(want, *GetOutput(0));
}
82
TEST_F(GatherNdOpTest, Quantized_UINT8) {
  MakeOp(DT_QUINT8, DT_INT32);

  // Same lookup as the float test, but with quint8 params: indices [[3], [4]]
  // into [0, 1, 2, 8, 4] are expected to produce [8, 4].
  const TensorShape params_shape({5});
  const TensorShape indices_shape({2, 1});
  AddInputFromArray<quint8>(params_shape, {0, 1, 2, 8, 4});
  AddInputFromArray<int32>(indices_shape, {3, 4});
  TF_ASSERT_OK(RunOpKernel());

  Tensor want(allocator(), DT_QUINT8, TensorShape({2}));
  test::FillValues<quint8>(&want, {8, 4});
  test::ExpectTensorEqual<quint8>(want, *GetOutput(0));
}
96
TEST_F(GatherNdOpTest, Quantized_INT8) {
  MakeOp(DT_QINT8, DT_INT32);

  // Same lookup as the float test, but with qint8 params: indices [[3], [4]]
  // into [0, 1, 2, 8, 4] are expected to produce [8, 4].
  const TensorShape params_shape({5});
  const TensorShape indices_shape({2, 1});
  AddInputFromArray<qint8>(params_shape, {0, 1, 2, 8, 4});
  AddInputFromArray<int32>(indices_shape, {3, 4});
  TF_ASSERT_OK(RunOpKernel());

  Tensor want(allocator(), DT_QINT8, TensorShape({2}));
  test::FillValues<qint8>(&want, {8, 4});
  test::ExpectTensorEqual<qint8>(want, *GetOutput(0));
}
108
109 constexpr int kLookups = 2000;
110
111 template <typename Index>
GatherNd(int dim)112 static Graph* GatherNd(int dim) {
113 Graph* g = new Graph(OpRegistry::Global());
114 // Always use a 512MB buffer.
115 // const int kRows = ((512 << 20) / sizeof(float)) / dim;
116 Tensor params(DT_FLOAT, TensorShape({dim, 8, 16, 32}));
117 params.flat<float>().setRandom();
118
119 random::PhiloxRandom philox(301, 17);
120 random::SimplePhilox rnd(&philox);
121 Tensor indices(DataTypeToEnum<Index>::value, TensorShape({kLookups, 4}));
122 auto indices_mat = indices.matrix<Index>();
123 for (int i = 0; i < kLookups; i++) {
124 indices_mat(i, 0) = rnd.Uniform(dim);
125 indices_mat(i, 1) = rnd.Uniform(8);
126 indices_mat(i, 2) = rnd.Uniform(16);
127 indices_mat(i, 3) = rnd.Uniform(32);
128 }
129
130 test::graph::GatherNd(g, test::graph::Constant(g, params),
131 test::graph::Constant(g, indices));
132 return g;
133 }
134
135 #define BM_GATHER_ND(DEVICE, INDEX) \
136 static void BM_##DEVICE##_gather_nd_##INDEX( \
137 ::testing::benchmark::State& state) { \
138 const int dim = state.range(0); \
139 test::Benchmark(#DEVICE, GatherNd<INDEX>(dim), \
140 /*old_benchmark_api=*/false) \
141 .Run(state); \
142 const int64 tot = static_cast<int64>(state.iterations()) * kLookups * 4; \
143 state.SetItemsProcessed(tot); \
144 state.SetBytesProcessed(tot * sizeof(float)); \
145 } \
146 BENCHMARK(BM_##DEVICE##_gather_nd_##INDEX) \
147 ->UseRealTime() \
148 ->Arg(10) \
149 ->Arg(100) \
150 ->Arg(1000) \
151 ->Arg(10000)
152
153 BM_GATHER_ND(cpu, int32);
154 BM_GATHER_ND(gpu, int32);
155 BM_GATHER_ND(cpu, int64);
156 BM_GATHER_ND(gpu, int64);
157
158 } // namespace
159 } // namespace tensorflow
160