1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <complex>
17 #include <vector>
18
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 #include "tensorflow/lite/interpreter.h"
22 #include "tensorflow/lite/kernels/test_util.h"
23 #include "tensorflow/lite/schema/schema_generated.h"
24 #include "tensorflow/lite/testing/util.h"
25
26 namespace tflite {
27 namespace {
28
29 template <typename T>
30 class RealOpModel : public SingleOpModel {
31 public:
RealOpModel(const TensorData & input,const TensorData & output)32 RealOpModel(const TensorData& input, const TensorData& output) {
33 input_ = AddInput(input);
34
35 output_ = AddOutput(output);
36
37 const std::vector<uint8_t> custom_option;
38 SetBuiltinOp(BuiltinOperator_REAL, BuiltinOptions_NONE, 0);
39
40 BuildInterpreter({GetShape(input_)});
41 }
42
input()43 int input() { return input_; }
44
GetOutput()45 std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
46
47 private:
48 int input_;
49 int output_;
50 };
51
// REAL applied to a 2x4 complex64 tensor: the float32 output is expected to
// hold the real component of each input element.
TEST(RealOpTest, SimpleFloatTest) {
  RealOpModel<float> model({TensorType_COMPLEX64, {2, 4}},
                           {TensorType_FLOAT32, {}});

  model.PopulateTensor<std::complex<float>>(
      model.input(),
      {{75, 0}, {-6, -1}, {9, 0}, {-10, 5},
       {-3, 2}, {-6, 11}, {0, 0}, {22.1, 33.3}});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // First component of every (real, imag) pair above, in the same order.
  const std::vector<float> expected_real = {75, -6, 9, -10, -3, -6, 0, 22.1f};
  EXPECT_THAT(model.GetOutput(),
              ::testing::ElementsAreArray(ArrayFloatNear(expected_real)));
}
70
// REAL applied to a 2x4 complex128 tensor: the float64 output is expected to
// hold the real component of each input element.
TEST(RealOpTest, SimpleDoubleTest) {
  RealOpModel<double> model({TensorType_COMPLEX128, {2, 4}},
                            {TensorType_FLOAT64, {}});

  model.PopulateTensor<std::complex<double>>(
      model.input(),
      {{75, 0}, {-6, -1}, {9, 0}, {-10, 5},
       {-3, 2}, {-6, 11}, {0, 0}, {22.1, 33.3}});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // First component of every (real, imag) pair above, in the same order.
  const std::vector<float> expected_real = {75, -6, 9, -10, -3, -6, 0, 22.1f};
  EXPECT_THAT(model.GetOutput(),
              ::testing::ElementsAreArray(ArrayFloatNear(expected_real)));
}
89
90 template <typename T>
91 class ImagOpModel : public SingleOpModel {
92 public:
ImagOpModel(const TensorData & input,const TensorData & output)93 ImagOpModel(const TensorData& input, const TensorData& output) {
94 input_ = AddInput(input);
95
96 output_ = AddOutput(output);
97
98 const std::vector<uint8_t> custom_option;
99 SetBuiltinOp(BuiltinOperator_IMAG, BuiltinOptions_NONE, 0);
100
101 BuildInterpreter({GetShape(input_)});
102 }
103
input()104 int input() { return input_; }
105
GetOutput()106 std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
107
108 private:
109 int input_;
110 int output_;
111 };
112
// IMAG applied to a 2x4 complex64 tensor: the float32 output is expected to
// hold the imaginary component of each input element.
TEST(ImagOpTest, SimpleFloatTest) {
  ImagOpModel<float> model({TensorType_COMPLEX64, {2, 4}},
                           {TensorType_FLOAT32, {}});

  model.PopulateTensor<std::complex<float>>(
      model.input(),
      {{75, 7}, {-6, -1}, {9, 3.5}, {-10, 5},
       {-3, 2}, {-6, 11}, {0, 0}, {22.1, 33.3}});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // Second component of every (real, imag) pair above, in the same order.
  const std::vector<float> expected_imag = {7, -1, 3.5f, 5, 2, 11, 0, 33.3f};
  EXPECT_THAT(model.GetOutput(),
              ::testing::ElementsAreArray(ArrayFloatNear(expected_imag)));
}
131
// IMAG applied to a 2x4 complex128 tensor: the float64 output is expected to
// hold the imaginary component of each input element.
TEST(ImagOpTest, SimpleDoubleTest) {
  ImagOpModel<double> model({TensorType_COMPLEX128, {2, 4}},
                            {TensorType_FLOAT64, {}});

  model.PopulateTensor<std::complex<double>>(
      model.input(),
      {{75, 7}, {-6, -1}, {9, 3.5}, {-10, 5},
       {-3, 2}, {-6, 11}, {0, 0}, {22.1, 33.3}});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // Second component of every (real, imag) pair above, in the same order.
  const std::vector<float> expected_imag = {7, -1, 3.5f, 5, 2, 11, 0, 33.3f};
  EXPECT_THAT(model.GetOutput(),
              ::testing::ElementsAreArray(ArrayFloatNear(expected_imag)));
}
150
151 template <typename T>
152 class ComplexAbsOpModel : public SingleOpModel {
153 public:
ComplexAbsOpModel(const TensorData & input,const TensorData & output)154 ComplexAbsOpModel(const TensorData& input, const TensorData& output) {
155 input_ = AddInput(input);
156
157 output_ = AddOutput(output);
158
159 const std::vector<uint8_t> custom_option;
160 SetBuiltinOp(BuiltinOperator_COMPLEX_ABS, BuiltinOptions_NONE, 0);
161
162 BuildInterpreter({GetShape(input_)});
163 }
164
input()165 int input() { return input_; }
166
GetOutput()167 std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
168
GetOutputShape()169 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
170
171 private:
172 int input_;
173 int output_;
174 };
175
// Constructing a COMPLEX_ABS model that pairs a complex64 input with a
// float64 output is expected to die with a message matching the regex below.
TEST(ComplexAbsOpTest, IncompatibleType64Test) {
  EXPECT_DEATH_IF_SUPPORTED(
      ComplexAbsOpModel<float> model({TensorType_COMPLEX64, {2, 4}},
                                     {TensorType_FLOAT64, {}}),
      "output->type != kTfLiteFloat32");
}
182
// Constructing a COMPLEX_ABS model that pairs a complex128 input with a
// float32 output is expected to die with a message matching the regex below.
TEST(ComplexAbsOpTest, IncompatibleType128Test) {
  EXPECT_DEATH_IF_SUPPORTED(
      ComplexAbsOpModel<float> model({TensorType_COMPLEX128, {2, 4}},
                                     {TensorType_FLOAT32, {}}),
      "output->type != kTfLiteFloat64");
}
189
// COMPLEX_ABS applied to a 2x4 complex64 tensor: checks both the resulting
// output shape and the per-element float32 values.
TEST(ComplexAbsOpTest, SimpleFloatTest) {
  ComplexAbsOpModel<float> model({TensorType_COMPLEX64, {2, 4}},
                                 {TensorType_FLOAT32, {}});

  model.PopulateTensor<std::complex<float>>(
      model.input(),
      {{75, 7}, {-6, -1}, {9, 3.5}, {-10, 5},
       {-3, 2}, {-6, 11}, {0, 0}, {22.1, 33.3}});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // The output was declared with an empty shape; after invocation it is
  // expected to report the input's 2x4 shape.
  EXPECT_THAT(model.GetOutputShape(), ::testing::ElementsAre(2, 4));

  const std::vector<float> expected_abs = {
      75.32596f,  6.0827627f, 9.656604f, 11.18034f,
      3.6055512f, 12.529964f, 0.f,       39.966236f};
  EXPECT_THAT(model.GetOutput(),
              ::testing::ElementsAreArray(ArrayFloatNear(expected_abs)));
}
210
// COMPLEX_ABS applied to a 2x4 complex128 tensor: checks both the resulting
// output shape and the per-element float64 values.
TEST(ComplexAbsOpTest, SimpleDoubleTest) {
  ComplexAbsOpModel<double> model({TensorType_COMPLEX128, {2, 4}},
                                  {TensorType_FLOAT64, {}});

  model.PopulateTensor<std::complex<double>>(
      model.input(),
      {{75, 7}, {-6, -1}, {9, 3.5}, {-10, 5},
       {-3, 2}, {-6, 11}, {0, 0}, {22.1, 33.3}});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // The output was declared with an empty shape; after invocation it is
  // expected to report the input's 2x4 shape.
  EXPECT_THAT(model.GetOutputShape(), ::testing::ElementsAre(2, 4));

  const std::vector<float> expected_abs = {
      75.32596f,  6.0827627f, 9.656604f, 11.18034f,
      3.6055512f, 12.529964f, 0.f,       39.966236f};
  EXPECT_THAT(model.GetOutput(),
              ::testing::ElementsAreArray(ArrayFloatNear(expected_abs)));
}
231
232 } // namespace
233 } // namespace tflite
234