• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <complex>
17 #include <vector>
18 
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 #include "tensorflow/lite/interpreter.h"
22 #include "tensorflow/lite/kernels/test_util.h"
23 #include "tensorflow/lite/schema/schema_generated.h"
24 #include "tensorflow/lite/testing/util.h"
25 
26 namespace tflite {
27 namespace {
28 
29 template <typename T>
30 class RealOpModel : public SingleOpModel {
31  public:
RealOpModel(const TensorData & input,const TensorData & output)32   RealOpModel(const TensorData& input, const TensorData& output) {
33     input_ = AddInput(input);
34 
35     output_ = AddOutput(output);
36 
37     const std::vector<uint8_t> custom_option;
38     SetBuiltinOp(BuiltinOperator_REAL, BuiltinOptions_NONE, 0);
39 
40     BuildInterpreter({GetShape(input_)});
41   }
42 
input()43   int input() { return input_; }
44 
GetOutput()45   std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
46 
47  private:
48   int input_;
49   int output_;
50 };
51 
// REAL on a 2x4 complex64 tensor: the float32 output should hold the real
// component of each input element, in order.
TEST(RealOpTest, SimpleFloatTest) {
  RealOpModel<float> m({TensorType_COMPLEX64, {2, 4}},
                       {TensorType_FLOAT32, {}});

  const std::vector<std::complex<float>> complex_values = {
      {75, 0}, {-6, -1}, {9, 0}, {-10, 5},
      {-3, 2}, {-6, 11}, {0, 0}, {22.1, 33.3}};
  m.PopulateTensor<std::complex<float>>(m.input(), complex_values);

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  // Real parts of `complex_values`.
  const std::vector<float> expected_real = {75, -6, 9, -10, -3, -6, 0, 22.1f};
  EXPECT_THAT(m.GetOutput(),
              testing::ElementsAreArray(ArrayFloatNear(expected_real)));
}
70 
// REAL on a 2x4 complex128 tensor with a float64 output. The expected values
// are still supplied as floats because they are matched via ArrayFloatNear.
TEST(RealOpTest, SimpleDoubleTest) {
  RealOpModel<double> m({TensorType_COMPLEX128, {2, 4}},
                        {TensorType_FLOAT64, {}});

  const std::vector<std::complex<double>> complex_values = {
      {75, 0}, {-6, -1}, {9, 0}, {-10, 5},
      {-3, 2}, {-6, 11}, {0, 0}, {22.1, 33.3}};
  m.PopulateTensor<std::complex<double>>(m.input(), complex_values);

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  // Real parts of `complex_values`.
  const std::vector<float> expected_real = {75, -6, 9, -10, -3, -6, 0, 22.1f};
  EXPECT_THAT(m.GetOutput(),
              testing::ElementsAreArray(ArrayFloatNear(expected_real)));
}
89 
90 template <typename T>
91 class ImagOpModel : public SingleOpModel {
92  public:
ImagOpModel(const TensorData & input,const TensorData & output)93   ImagOpModel(const TensorData& input, const TensorData& output) {
94     input_ = AddInput(input);
95 
96     output_ = AddOutput(output);
97 
98     const std::vector<uint8_t> custom_option;
99     SetBuiltinOp(BuiltinOperator_IMAG, BuiltinOptions_NONE, 0);
100 
101     BuildInterpreter({GetShape(input_)});
102   }
103 
input()104   int input() { return input_; }
105 
GetOutput()106   std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
107 
108  private:
109   int input_;
110   int output_;
111 };
112 
// IMAG on a 2x4 complex64 tensor: the float32 output should hold the
// imaginary component of each input element, in order.
TEST(ImagOpTest, SimpleFloatTest) {
  ImagOpModel<float> m({TensorType_COMPLEX64, {2, 4}},
                       {TensorType_FLOAT32, {}});

  const std::vector<std::complex<float>> complex_values = {
      {75, 7}, {-6, -1}, {9, 3.5}, {-10, 5},
      {-3, 2}, {-6, 11}, {0, 0},   {22.1, 33.3}};
  m.PopulateTensor<std::complex<float>>(m.input(), complex_values);

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  // Imaginary parts of `complex_values`.
  const std::vector<float> expected_imag = {7, -1, 3.5f, 5, 2, 11, 0, 33.3f};
  EXPECT_THAT(m.GetOutput(),
              testing::ElementsAreArray(ArrayFloatNear(expected_imag)));
}
131 
// IMAG on a 2x4 complex128 tensor with a float64 output. The expected values
// are still supplied as floats because they are matched via ArrayFloatNear.
TEST(ImagOpTest, SimpleDoubleTest) {
  ImagOpModel<double> m({TensorType_COMPLEX128, {2, 4}},
                        {TensorType_FLOAT64, {}});

  const std::vector<std::complex<double>> complex_values = {
      {75, 7}, {-6, -1}, {9, 3.5}, {-10, 5},
      {-3, 2}, {-6, 11}, {0, 0},   {22.1, 33.3}};
  m.PopulateTensor<std::complex<double>>(m.input(), complex_values);

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  // Imaginary parts of `complex_values`.
  const std::vector<float> expected_imag = {7, -1, 3.5f, 5, 2, 11, 0, 33.3f};
  EXPECT_THAT(m.GetOutput(),
              testing::ElementsAreArray(ArrayFloatNear(expected_imag)));
}
150 
151 template <typename T>
152 class ComplexAbsOpModel : public SingleOpModel {
153  public:
ComplexAbsOpModel(const TensorData & input,const TensorData & output)154   ComplexAbsOpModel(const TensorData& input, const TensorData& output) {
155     input_ = AddInput(input);
156 
157     output_ = AddOutput(output);
158 
159     const std::vector<uint8_t> custom_option;
160     SetBuiltinOp(BuiltinOperator_COMPLEX_ABS, BuiltinOptions_NONE, 0);
161 
162     BuildInterpreter({GetShape(input_)});
163   }
164 
input()165   int input() { return input_; }
166 
GetOutput()167   std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
168 
GetOutputShape()169   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
170 
171  private:
172   int input_;
173   int output_;
174 };
175 
// Building a COMPLEX_ABS model with a complex64 input and a float64 output is
// expected to die with a type-mismatch message (where death tests are
// supported).
TEST(ComplexAbsOpTest, IncompatibleType64Test) {
  constexpr char kExpectedError[] = "output->type != kTfLiteFloat32";
  EXPECT_DEATH_IF_SUPPORTED(
      ComplexAbsOpModel<float> m({TensorType_COMPLEX64, {2, 4}},
                                 {TensorType_FLOAT64, {}}),
      kExpectedError);
}
182 
// Building a COMPLEX_ABS model with a complex128 input and a float32 output
// is expected to die with a type-mismatch message (where death tests are
// supported).
TEST(ComplexAbsOpTest, IncompatibleType128Test) {
  constexpr char kExpectedError[] = "output->type != kTfLiteFloat64";
  EXPECT_DEATH_IF_SUPPORTED(
      ComplexAbsOpModel<float> m({TensorType_COMPLEX128, {2, 4}},
                                 {TensorType_FLOAT32, {}}),
      kExpectedError);
}
189 
// COMPLEX_ABS on a 2x4 complex64 tensor: checks that the output tensor has
// shape {2, 4} after Invoke() and holds the magnitude of each input element.
TEST(ComplexAbsOpTest, SimpleFloatTest) {
  ComplexAbsOpModel<float> m({TensorType_COMPLEX64, {2, 4}},
                             {TensorType_FLOAT32, {}});

  const std::vector<std::complex<float>> complex_values = {
      {75, 7}, {-6, -1}, {9, 3.5}, {-10, 5},
      {-3, 2}, {-6, 11}, {0, 0},   {22.1, 33.3}};
  m.PopulateTensor<std::complex<float>>(m.input(), complex_values);

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), testing::ElementsAre(2, 4));

  // Magnitudes of `complex_values`.
  const std::vector<float> expected_abs = {
      75.32596f,  6.0827627f, 9.656604f, 11.18034f,
      3.6055512f, 12.529964f, 0.f,       39.966236f};
  EXPECT_THAT(m.GetOutput(),
              testing::ElementsAreArray(ArrayFloatNear(expected_abs)));
}
210 
// COMPLEX_ABS on a 2x4 complex128 tensor with a float64 output: checks that
// the output tensor has shape {2, 4} after Invoke() and holds the magnitude
// of each input element. Expected values are floats because they are matched
// via ArrayFloatNear.
TEST(ComplexAbsOpTest, SimpleDoubleTest) {
  ComplexAbsOpModel<double> m({TensorType_COMPLEX128, {2, 4}},
                              {TensorType_FLOAT64, {}});

  const std::vector<std::complex<double>> complex_values = {
      {75, 7}, {-6, -1}, {9, 3.5}, {-10, 5},
      {-3, 2}, {-6, 11}, {0, 0},   {22.1, 33.3}};
  m.PopulateTensor<std::complex<double>>(m.input(), complex_values);

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), testing::ElementsAre(2, 4));

  // Magnitudes of `complex_values`.
  const std::vector<float> expected_abs = {
      75.32596f,  6.0827627f, 9.656604f, 11.18034f,
      3.6055512f, 12.529964f, 0.f,       39.966236f};
  EXPECT_THAT(m.GetOutput(),
              testing::ElementsAreArray(ArrayFloatNear(expected_abs)));
}
231 
232 }  // namespace
233 }  // namespace tflite
234