1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include <initializer_list>
18 #include <vector>
19
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "tensorflow/lite/kernels/test_util.h"
23 #include "tensorflow/lite/schema/schema_generated.h"
24
25 namespace tflite {
26 namespace {
27
28 using ::testing::ElementsAreArray;
29
// Selects how the `k` tensor is handed to the model under test.
enum class TestType {
  // `k` is added as a constant input when the model is constructed.
  kConst,
  // `k` is added as a regular input and populated after the interpreter is
  // built.
  kDynamic,
};
34
35 template <typename InputType>
36 class TopKV2OpModel : public SingleOpModel {
37 public:
TopKV2OpModel(int top_k,std::initializer_list<int> input_shape,std::initializer_list<InputType> input_data,TestType input_tensor_types)38 TopKV2OpModel(int top_k, std::initializer_list<int> input_shape,
39 std::initializer_list<InputType> input_data,
40 TestType input_tensor_types) {
41 input_ = AddInput(GetTensorType<InputType>());
42 if (input_tensor_types == TestType::kDynamic) {
43 top_k_ = AddInput(TensorType_INT32);
44 } else {
45 top_k_ = AddConstInput(TensorType_INT32, {top_k}, {1});
46 }
47 output_values_ = AddOutput(GetTensorType<InputType>());
48 output_indexes_ = AddOutput(TensorType_INT32);
49 SetBuiltinOp(BuiltinOperator_TOPK_V2, BuiltinOptions_TopKV2Options, 0);
50 BuildInterpreter({input_shape, {1}});
51
52 PopulateTensor<InputType>(input_, input_data);
53 if (input_tensor_types == TestType::kDynamic) {
54 PopulateTensor<int32_t>(top_k_, {top_k});
55 }
56 }
57
GetIndexes()58 std::vector<int32_t> GetIndexes() {
59 return ExtractVector<int32_t>(output_indexes_);
60 }
61
GetValues()62 std::vector<InputType> GetValues() {
63 return ExtractVector<InputType>(output_values_);
64 }
65
66 protected:
67 int input_;
68 int top_k_;
69 int output_indexes_;
70 int output_values_;
71 };
72
// Value-parameterized fixture; the parameter selects whether `k` is supplied
// to the model as a constant tensor or as a regular (dynamic) input.
class TopKV2OpTest : public ::testing::TestWithParam<TestType> {};
74
75 // The test where the tensor dimension is equal to top.
TEST_P(TopKV2OpTest,EqualFloat)76 TEST_P(TopKV2OpTest, EqualFloat) {
77 TopKV2OpModel<float> m(2, {2, 2}, {-2.0, 0.2, 0.8, 0.1}, GetParam());
78 m.Invoke();
79 EXPECT_THAT(m.GetIndexes(), ElementsAreArray({1, 0, 0, 1}));
80 EXPECT_THAT(m.GetValues(),
81 ElementsAreArray(ArrayFloatNear({0.2, -2.0, 0.8, 0.1})));
82 }
83
84 // Test when internal dimension is k+1.
TEST_P(TopKV2OpTest,BorderFloat)85 TEST_P(TopKV2OpTest, BorderFloat) {
86 TopKV2OpModel<float> m(2, {2, 3}, {-2.0, -3.0, 0.2, 0.8, 0.1, -0.1},
87 GetParam());
88 m.Invoke();
89 EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 0, 0, 1}));
90 EXPECT_THAT(m.GetValues(),
91 ElementsAreArray(ArrayFloatNear({0.2, -2.0, 0.8, 0.1})));
92 }
93 // Test when internal dimension is higher than k.
TEST_P(TopKV2OpTest,LargeFloat)94 TEST_P(TopKV2OpTest, LargeFloat) {
95 TopKV2OpModel<float> m(
96 2, {2, 4}, {-2.0, -3.0, -4.0, 0.2, 0.8, 0.1, -0.1, -0.8}, GetParam());
97 m.Invoke();
98 EXPECT_THAT(m.GetIndexes(), ElementsAreArray({3, 0, 0, 1}));
99 EXPECT_THAT(m.GetValues(),
100 ElementsAreArray(ArrayFloatNear({0.2, -2.0, 0.8, 0.1})));
101 }
102
103 // Test 1D case.
TEST_P(TopKV2OpTest,VectorFloat)104 TEST_P(TopKV2OpTest, VectorFloat) {
105 TopKV2OpModel<float> m(2, {8}, {-2.0, -3.0, -4.0, 0.2, 0.8, 0.1, -0.1, -0.8},
106 GetParam());
107 m.Invoke();
108 EXPECT_THAT(m.GetIndexes(), ElementsAreArray({4, 3}));
109 EXPECT_THAT(m.GetValues(), ElementsAreArray(ArrayFloatNear({0.8, 0.2})));
110 }
111
112 // Check that int32_t works.
TEST_P(TopKV2OpTest,TypeInt32)113 TEST_P(TopKV2OpTest, TypeInt32) {
114 TopKV2OpModel<int32_t> m(2, {2, 3}, {1, 2, 3, 10251, 10250, 10249},
115 GetParam());
116 m.Invoke();
117 EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1, 0, 1}));
118 EXPECT_THAT(m.GetValues(), ElementsAreArray({3, 2, 10251, 10250}));
119 }
120
// Instantiates the TopKV2OpTest suite for both parameter values: `k` supplied
// as a constant tensor (kConst) and `k` supplied as a dynamic input (kDynamic).
INSTANTIATE_TEST_SUITE_P(TopKV2OpTest, TopKV2OpTest,
                         ::testing::Values(TestType::kConst,
                                           TestType::kDynamic));
124
125 // Check that uint8_t works.
TEST_P(TopKV2OpTest,TypeUint8)126 TEST_P(TopKV2OpTest, TypeUint8) {
127 TopKV2OpModel<uint8_t> m(2, {2, 3}, {1, 2, 3, 251, 250, 249}, GetParam());
128 m.Invoke();
129 EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1, 0, 1}));
130 EXPECT_THAT(m.GetValues(), ElementsAreArray({3, 2, 251, 250}));
131 }
132
// int8_t element type, with a mix of positive and negative values.
TEST_P(TopKV2OpTest, TypeInt8) {
  constexpr int kTopK = 2;
  TopKV2OpModel<int8_t> model(kTopK, /*input_shape=*/{2, 3},
                              /*input_data=*/{1, 2, 3, -126, 125, -24},
                              /*input_tensor_types=*/GetParam());
  model.Invoke();

  const std::vector<int> expected_indexes = {2, 1, 1, 2};
  const std::vector<int> expected_values = {3, 2, 125, -24};
  EXPECT_THAT(model.GetIndexes(), ElementsAreArray(expected_indexes));
  EXPECT_THAT(model.GetValues(), ElementsAreArray(expected_values));
}
139
140 // Check that int64 works.
TEST_P(TopKV2OpTest,TypeInt64)141 TEST_P(TopKV2OpTest, TypeInt64) {
142 TopKV2OpModel<int64_t> m(2, {2, 3}, {1, 2, 3, -1, -2, -3}, GetParam());
143 m.Invoke();
144 EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1, 0, 1}));
145 EXPECT_THAT(m.GetValues(), ElementsAreArray({3, 2, -1, -2}));
146 }
147
148 } // namespace
149 } // namespace tflite
150