1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <initializer_list>
17
18 #include <gtest/gtest.h>
19 #include "tensorflow/lite/interpreter.h"
20 #include "tensorflow/lite/kernels/register.h"
21 #include "tensorflow/lite/kernels/test_util.h"
22 #include "tensorflow/lite/model.h"
23
24 namespace tflite {
25 namespace {
26
27 using ::testing::ElementsAreArray;
28
29 template <typename T>
30 class OneHotOpModel : public SingleOpModel {
31 public:
OneHotOpModel(std::initializer_list<int> input_shape,int depth_value,TensorType dtype,int axis=-1,T on_value=1,T off_value=0,TensorType indices_type=TensorType_INT32)32 OneHotOpModel(std::initializer_list<int> input_shape, int depth_value,
33 TensorType dtype, int axis = -1, T on_value = 1,
34 T off_value = 0, TensorType indices_type = TensorType_INT32) {
35 indices_ = AddInput(indices_type);
36 int depth = AddInput(TensorType_INT32);
37 int on = AddInput(dtype);
38 int off = AddInput(dtype);
39 output_ = AddOutput(dtype);
40 SetBuiltinOp(BuiltinOperator_ONE_HOT, BuiltinOptions_OneHotOptions,
41 CreateOneHotOptions(builder_, axis).Union());
42 BuildInterpreter({input_shape});
43
44 PopulateTensor<int>(depth, {depth_value});
45 PopulateTensor<T>(on, {on_value});
46 PopulateTensor<T>(off, {off_value});
47 }
48
49 template <typename TI>
SetIndices(std::initializer_list<TI> data)50 void SetIndices(std::initializer_list<TI> data) {
51 PopulateTensor<TI>(indices_, data);
52 }
53
InvokeWithResult()54 TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); }
55
GetOutputSize()56 int32_t GetOutputSize() { return GetTensorSize(output_); }
GetOutput()57 std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
GetOutputShape()58 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
59
60 private:
61 int indices_;
62 int output_;
63 };
64
// Float on/off tensors with the model's default axis, on value and off value:
// indices {0, 1, 2} at depth 3 are expected to give a 3x3 identity pattern.
TEST(OneHotOpTest, BasicFloat) {
  constexpr int kDepth = 3;
  OneHotOpModel<float> one_hot({3}, kDepth, TensorType_FLOAT32);
  one_hot.SetIndices({0, 1, 2});
  one_hot.Invoke();

  EXPECT_THAT(one_hot.GetOutputShape(), ElementsAreArray({3, 3}));
  EXPECT_THAT(one_hot.GetOutput(), ElementsAreArray({1.f, 0.f, 0.f,    //
                                                     0.f, 1.f, 0.f,    //
                                                     0.f, 0.f, 1.f}));
}
75
// Same scenario as the float case but with INT32 on/off/output tensors.
TEST(OneHotOpTest, BasicInt) {
  constexpr int kDepth = 3;
  OneHotOpModel<int> one_hot({3}, kDepth, TensorType_INT32);
  one_hot.SetIndices({0, 1, 2});
  one_hot.Invoke();

  EXPECT_THAT(one_hot.GetOutputShape(), ElementsAreArray({3, 3}));
  EXPECT_THAT(one_hot.GetOutput(), ElementsAreArray({1, 0, 0,    //
                                                     0, 1, 0,    //
                                                     0, 0, 1}));
}
85
// BOOL on/off/output tensors: the default on/off values of 1 and 0 are
// expected to surface as true and false in the output.
TEST(OneHotOpTest, BasicBool) {
  constexpr int kDepth = 3;
  OneHotOpModel<bool> one_hot({3}, kDepth, TensorType_BOOL);
  one_hot.SetIndices({0, 1, 2});
  one_hot.Invoke();

  EXPECT_THAT(one_hot.GetOutputShape(), ElementsAreArray({3, 3}));
  EXPECT_THAT(one_hot.GetOutput(),
              ElementsAreArray({true, false, false,    //
                                false, true, false,    //
                                false, false, true}));
}
97
// Depth of 1 with indices {0, 1, 2}: only index 0 is expected to produce the
// on value; the rows for indices 1 and 2 are expected to hold the off value.
TEST(OneHotOpTest, SmallDepth) {
  constexpr int kDepth = 1;
  OneHotOpModel<int> one_hot({3}, kDepth, TensorType_INT32);
  one_hot.SetIndices({0, 1, 2});
  one_hot.Invoke();

  EXPECT_THAT(one_hot.GetOutputShape(), ElementsAreArray({3, 1}));
  EXPECT_THAT(one_hot.GetOutput(), ElementsAreArray({1, 0, 0}));
}
107
// Depth (4) larger than the number of indices (2): each expected output row
// is four wide with a single on value.
TEST(OneHotOpTest, BigDepth) {
  constexpr int kDepth = 4;
  OneHotOpModel<int> one_hot({2}, kDepth, TensorType_INT32);
  one_hot.SetIndices({0, 1});
  one_hot.Invoke();

  EXPECT_THAT(one_hot.GetOutputShape(), ElementsAreArray({2, 4}));
  EXPECT_THAT(one_hot.GetOutput(), ElementsAreArray({1, 0, 0, 0,    //
                                                     0, 1, 0, 0}));
}
117
// Explicit on value (5) and off value (0) with axis -1. The index -1 is
// expected to yield a row made up entirely of the off value.
TEST(OneHotOpTest, OnOffValues) {
  constexpr int kDepth = 3;
  constexpr int kAxis = -1;
  constexpr int kOnValue = 5;
  constexpr int kOffValue = 0;
  OneHotOpModel<int> one_hot({4}, kDepth, TensorType_INT32, kAxis, kOnValue,
                             kOffValue);
  one_hot.SetIndices({0, 2, -1, 1});
  one_hot.Invoke();

  EXPECT_THAT(one_hot.GetOutputShape(), ElementsAreArray({4, 3}));
  EXPECT_THAT(one_hot.GetOutput(), ElementsAreArray({5, 0, 0,    //
                                                     0, 0, 5,    //
                                                     0, 0, 0,    //
                                                     0, 5, 0}));
}
131
// Axis 0: the depth dimension is expected to come first, giving a {3, 4}
// output for four indices at depth 3.
TEST(OneHotOpTest, ZeroAxis) {
  constexpr int kDepth = 3;
  constexpr int kAxis = 0;
  constexpr int kOnValue = 5;
  constexpr int kOffValue = 0;
  OneHotOpModel<int> one_hot({4}, kDepth, TensorType_INT32, kAxis, kOnValue,
                             kOffValue);
  one_hot.SetIndices({0, 2, -1, 1});
  one_hot.Invoke();

  EXPECT_THAT(one_hot.GetOutputShape(), ElementsAreArray({3, 4}));
  EXPECT_THAT(one_hot.GetOutput(), ElementsAreArray({5, 0, 0, 0,    //
                                                     0, 0, 0, 5,    //
                                                     0, 5, 0, 0}));
}
145
// A {2, 2} indices tensor with axis -1: the expected output appends a
// depth-sized dimension, giving shape {2, 2, 3}.
TEST(OneHotOpTest, MultiDimensionalIndices) {
  constexpr int kDepth = 3;
  constexpr int kAxis = -1;
  constexpr float kOnValue = 2;
  constexpr float kOffValue = 0;
  OneHotOpModel<float> one_hot({2, 2}, kDepth, TensorType_FLOAT32, kAxis,
                               kOnValue, kOffValue);
  one_hot.SetIndices({0, 2, 1, -1});
  one_hot.Invoke();

  EXPECT_THAT(one_hot.GetOutputShape(), ElementsAreArray({2, 2, 3}));
  EXPECT_THAT(one_hot.GetOutput(), ElementsAreArray({2, 0, 0,    //
                                                     0, 0, 2,    //
                                                     0, 2, 0,    //
                                                     0, 0, 0}));
}
159
// The indices tensor is declared INT64 and populated with int64_t values;
// the on/off/output tensors stay INT32.
TEST(OneHotOpTest, Int64Indices) {
  constexpr int kDepth = 3;
  constexpr int kAxis = -1;
  constexpr int kOnValue = 1;
  constexpr int kOffValue = 0;
  OneHotOpModel<int> one_hot({3}, kDepth, TensorType_INT32, kAxis, kOnValue,
                             kOffValue, TensorType_INT64);
  // The element type is spelled out so the literals are stored as int64_t
  // rather than being deduced as int.
  one_hot.SetIndices<int64_t>({0, 1, 2});
  one_hot.Invoke();

  EXPECT_THAT(one_hot.GetOutputShape(), ElementsAreArray({3, 3}));
  EXPECT_THAT(one_hot.GetOutput(), ElementsAreArray({1, 0, 0,    //
                                                     0, 1, 0,    //
                                                     0, 0, 1}));
}
174
175 } // namespace
176 } // namespace tflite
177