• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <initializer_list>
17 
18 #include <gtest/gtest.h>
19 #include "tensorflow/lite/interpreter.h"
20 #include "tensorflow/lite/kernels/register.h"
21 #include "tensorflow/lite/kernels/test_util.h"
22 #include "tensorflow/lite/model.h"
23 
24 namespace tflite {
25 namespace {
26 
27 using ::testing::ElementsAreArray;
28 
29 template <typename T>
30 class OneHotOpModel : public SingleOpModel {
31  public:
OneHotOpModel(std::initializer_list<int> input_shape,int depth_value,TensorType dtype,int axis=-1,T on_value=1,T off_value=0,TensorType indices_type=TensorType_INT32)32   OneHotOpModel(std::initializer_list<int> input_shape, int depth_value,
33                 TensorType dtype, int axis = -1, T on_value = 1,
34                 T off_value = 0, TensorType indices_type = TensorType_INT32) {
35     indices_ = AddInput(indices_type);
36     int depth = AddInput(TensorType_INT32);
37     int on = AddInput(dtype);
38     int off = AddInput(dtype);
39     output_ = AddOutput(dtype);
40     SetBuiltinOp(BuiltinOperator_ONE_HOT, BuiltinOptions_OneHotOptions,
41                  CreateOneHotOptions(builder_, axis).Union());
42     BuildInterpreter({input_shape});
43 
44     PopulateTensor<int>(depth, {depth_value});
45     PopulateTensor<T>(on, {on_value});
46     PopulateTensor<T>(off, {off_value});
47   }
48 
49   template <typename TI>
SetIndices(std::initializer_list<TI> data)50   void SetIndices(std::initializer_list<TI> data) {
51     PopulateTensor<TI>(indices_, data);
52   }
53 
InvokeWithResult()54   TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); }
55 
GetOutputSize()56   int32_t GetOutputSize() { return GetTensorSize(output_); }
GetOutput()57   std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
GetOutputShape()58   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
59 
60  private:
61   int indices_;
62   int output_;
63 };
64 
// Indices 0, 1, 2 with depth 3 and the default on (1) / off (0) values must
// yield a 3x3 identity matrix in FLOAT32.
TEST(OneHotOpTest, BasicFloat) {
  constexpr int kDepth = 3;
  OneHotOpModel<float> model({3}, kDepth, TensorType_FLOAT32);
  model.SetIndices({0, 1, 2});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1.f, 0.f, 0.f,  //
                                                   0.f, 1.f, 0.f,  //
                                                   0.f, 0.f, 1.f}));
}
75 
// Same identity-matrix check as BasicFloat, but with an INT32 output.
TEST(OneHotOpTest, BasicInt) {
  constexpr int kDepth = 3;
  OneHotOpModel<int> model({3}, kDepth, TensorType_INT32);
  model.SetIndices({0, 1, 2});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0,  //
                                                   0, 1, 0,  //
                                                   0, 0, 1}));
}
85 
// Identity-matrix check with a BOOL output. The default on/off values
// (1 and 0) become true and false.
TEST(OneHotOpTest, BasicBool) {
  constexpr int kDepth = 3;
  OneHotOpModel<bool> model({3}, kDepth, TensorType_BOOL);
  model.SetIndices({0, 1, 2});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false,  //
                                                   false, true, false,  //
                                                   false, false, true}));
}
97 
// With depth 1, only index 0 is in range. Indices 1 and 2 must produce
// all-off rows.
TEST(OneHotOpTest, SmallDepth) {
  constexpr int kDepth = 1;
  OneHotOpModel<int> model({3}, kDepth, TensorType_INT32);
  model.SetIndices({0, 1, 2});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 1}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1,  //
                                                   0,  //
                                                   0}));
}
107 
// Depth larger than the number of indices: each row is 4 wide, with a single
// on value at the index position.
TEST(OneHotOpTest, BigDepth) {
  constexpr int kDepth = 4;
  OneHotOpModel<int> model({2}, kDepth, TensorType_INT32);
  model.SetIndices({0, 1});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0,  //
                                                   0, 1, 0, 0}));
}
117 
// Checks that custom on/off values are written to the output.
//
// The off value is deliberately non-zero (unlike the harness default of 0).
// Otherwise a kernel that ignored the off tensor and zero-filled the output
// would still pass. Index -1 is out of range for depth 3, so its row is
// entirely off values.
TEST(OneHotOpTest, OnOffValues) {
  const int depth = 3;
  const int axis = -1;
  const int on = 5;
  const int off = 2;
  OneHotOpModel<int> model({4}, depth, TensorType_INT32, axis, on, off);
  model.SetIndices({0, 2, -1, 1});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 2, 2,  // index 0
                                                   2, 2, 5,  // index 2
                                                   2, 2, 2,  // index -1
                                                   2, 5, 2}));  // index 1
}
131 
// With axis 0 the depth dimension comes first. The output is [depth, 4], and
// row d marks the positions whose index equals d.
TEST(OneHotOpTest, ZeroAxis) {
  constexpr int kDepth = 3;
  constexpr int kAxis = 0;
  constexpr int kOn = 5;
  constexpr int kOff = 0;
  OneHotOpModel<int> model({4}, kDepth, TensorType_INT32, kAxis, kOn, kOff);
  model.SetIndices({0, 2, -1, 1});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 4}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 0, 0, 0,  // d = 0
                                                   0, 0, 0, 5,  // d = 1
                                                   0, 5, 0, 0}));  // d = 2
}
145 
// A 2x2 indices tensor with the default last-axis placement yields a 2x2x3
// FLOAT32 output. The out-of-range index -1 produces an all-off row.
TEST(OneHotOpTest, MultiDimensionalIndices) {
  constexpr int kDepth = 3;
  constexpr int kAxis = -1;
  constexpr float kOn = 2;
  constexpr float kOff = 0;
  OneHotOpModel<float> model({2, 2}, kDepth, TensorType_FLOAT32, kAxis, kOn,
                             kOff);
  model.SetIndices({0, 2, 1, -1});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({2.f, 0.f, 0.f,  //
                                                   0.f, 0.f, 2.f,  //
                                                   0.f, 2.f, 0.f,  //
                                                   0.f, 0.f, 0.f}));
}
159 
// Indices supplied through an INT64 tensor must still produce the 3x3
// identity output.
TEST(OneHotOpTest, Int64Indices) {
  constexpr int kDepth = 3;
  constexpr int kAxis = -1;
  constexpr int kOn = 1;
  constexpr int kOff = 0;
  OneHotOpModel<int> model({3}, kDepth, TensorType_INT32, kAxis, kOn, kOff,
                           TensorType_INT64);
  // The explicit template argument makes the literals int64_t, matching the
  // INT64 indices tensor.
  model.SetIndices<int64_t>({0, 1, 2});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0,  //
                                                   0, 1, 0,  //
                                                   0, 0, 1}));
}
174 
175 }  // namespace
176 }  // namespace tflite
177