• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 
17 #include <initializer_list>
18 #include <vector>
19 
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "tensorflow/lite/kernels/test_util.h"
23 #include "tensorflow/lite/schema/schema_generated.h"
24 
25 namespace tflite {
26 namespace {
27 
28 using ::testing::ElementsAreArray;
29 
// Selects how the `k` operand of the TOPK_V2 model under test is provided.
enum class TestType : int {
  kConst = 0,    // `k` is a constant tensor embedded in the model.
  kDynamic = 1,  // `k` is a regular input, populated after the interpreter is built.
};
34 
35 template <typename InputType>
36 class TopKV2OpModel : public SingleOpModel {
37  public:
TopKV2OpModel(int top_k,std::initializer_list<int> input_shape,std::initializer_list<InputType> input_data,TestType input_tensor_types)38   TopKV2OpModel(int top_k, std::initializer_list<int> input_shape,
39                 std::initializer_list<InputType> input_data,
40                 TestType input_tensor_types) {
41     input_ = AddInput(GetTensorType<InputType>());
42     if (input_tensor_types == TestType::kDynamic) {
43       top_k_ = AddInput(TensorType_INT32);
44     } else {
45       top_k_ = AddConstInput(TensorType_INT32, {top_k}, {1});
46     }
47     output_values_ = AddOutput(GetTensorType<InputType>());
48     output_indexes_ = AddOutput(TensorType_INT32);
49     SetBuiltinOp(BuiltinOperator_TOPK_V2, BuiltinOptions_TopKV2Options, 0);
50     BuildInterpreter({input_shape, {1}});
51 
52     PopulateTensor<InputType>(input_, input_data);
53     if (input_tensor_types == TestType::kDynamic) {
54       PopulateTensor<int32_t>(top_k_, {top_k});
55     }
56   }
57 
GetIndexes()58   std::vector<int32_t> GetIndexes() {
59     return ExtractVector<int32_t>(output_indexes_);
60   }
61 
GetValues()62   std::vector<InputType> GetValues() {
63     return ExtractVector<InputType>(output_values_);
64   }
65 
66  protected:
67   int input_;
68   int top_k_;
69   int output_indexes_;
70   int output_values_;
71 };
72 
73 class TopKV2OpTest : public ::testing::TestWithParam<TestType> {};
74 
75 // The test where the tensor dimension is equal to top.
TEST_P(TopKV2OpTest,EqualFloat)76 TEST_P(TopKV2OpTest, EqualFloat) {
77   TopKV2OpModel<float> m(2, {2, 2}, {-2.0, 0.2, 0.8, 0.1}, GetParam());
78   m.Invoke();
79   EXPECT_THAT(m.GetIndexes(), ElementsAreArray({1, 0, 0, 1}));
80   EXPECT_THAT(m.GetValues(),
81               ElementsAreArray(ArrayFloatNear({0.2, -2.0, 0.8, 0.1})));
82 }
83 
84 // Test when internal dimension is k+1.
TEST_P(TopKV2OpTest,BorderFloat)85 TEST_P(TopKV2OpTest, BorderFloat) {
86   TopKV2OpModel<float> m(2, {2, 3}, {-2.0, -3.0, 0.2, 0.8, 0.1, -0.1},
87                          GetParam());
88   m.Invoke();
89   EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 0, 0, 1}));
90   EXPECT_THAT(m.GetValues(),
91               ElementsAreArray(ArrayFloatNear({0.2, -2.0, 0.8, 0.1})));
92 }
93 // Test when internal dimension is higher than k.
TEST_P(TopKV2OpTest,LargeFloat)94 TEST_P(TopKV2OpTest, LargeFloat) {
95   TopKV2OpModel<float> m(
96       2, {2, 4}, {-2.0, -3.0, -4.0, 0.2, 0.8, 0.1, -0.1, -0.8}, GetParam());
97   m.Invoke();
98   EXPECT_THAT(m.GetIndexes(), ElementsAreArray({3, 0, 0, 1}));
99   EXPECT_THAT(m.GetValues(),
100               ElementsAreArray(ArrayFloatNear({0.2, -2.0, 0.8, 0.1})));
101 }
102 
103 // Test 1D case.
TEST_P(TopKV2OpTest,VectorFloat)104 TEST_P(TopKV2OpTest, VectorFloat) {
105   TopKV2OpModel<float> m(2, {8}, {-2.0, -3.0, -4.0, 0.2, 0.8, 0.1, -0.1, -0.8},
106                          GetParam());
107   m.Invoke();
108   EXPECT_THAT(m.GetIndexes(), ElementsAreArray({4, 3}));
109   EXPECT_THAT(m.GetValues(), ElementsAreArray(ArrayFloatNear({0.8, 0.2})));
110 }
111 
112 // Check that int32_t works.
TEST_P(TopKV2OpTest,TypeInt32)113 TEST_P(TopKV2OpTest, TypeInt32) {
114   TopKV2OpModel<int32_t> m(2, {2, 3}, {1, 2, 3, 10251, 10250, 10249},
115                            GetParam());
116   m.Invoke();
117   EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1, 0, 1}));
118   EXPECT_THAT(m.GetValues(), ElementsAreArray({3, 2, 10251, 10250}));
119 }
120 
121 INSTANTIATE_TEST_SUITE_P(TopKV2OpTest, TopKV2OpTest,
122                          ::testing::Values(TestType::kConst,
123                                            TestType::kDynamic));
124 
125 // Check that uint8_t works.
TEST_P(TopKV2OpTest,TypeUint8)126 TEST_P(TopKV2OpTest, TypeUint8) {
127   TopKV2OpModel<uint8_t> m(2, {2, 3}, {1, 2, 3, 251, 250, 249}, GetParam());
128   m.Invoke();
129   EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1, 0, 1}));
130   EXPECT_THAT(m.GetValues(), ElementsAreArray({3, 2, 251, 250}));
131 }
132 
TEST_P(TopKV2OpTest,TypeInt8)133 TEST_P(TopKV2OpTest, TypeInt8) {
134   TopKV2OpModel<int8_t> m(2, {2, 3}, {1, 2, 3, -126, 125, -24}, GetParam());
135   m.Invoke();
136   EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1, 1, 2}));
137   EXPECT_THAT(m.GetValues(), ElementsAreArray({3, 2, 125, -24}));
138 }
139 
140 // Check that int64 works.
TEST_P(TopKV2OpTest,TypeInt64)141 TEST_P(TopKV2OpTest, TypeInt64) {
142   TopKV2OpModel<int64_t> m(2, {2, 3}, {1, 2, 3, -1, -2, -3}, GetParam());
143   m.Invoke();
144   EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1, 0, 1}));
145   EXPECT_THAT(m.GetValues(), ElementsAreArray({3, 2, -1, -2}));
146 }
147 
148 }  // namespace
149 }  // namespace tflite
150