1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Unit test for TFLite RNN op.
16
17 #include <string.h>
18 #include <initializer_list>
19 #include <memory>
20 #include <vector>
21
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "tensorflow/lite/interpreter.h"
25 #include "tensorflow/lite/kernels/register.h"
26 #include "tensorflow/lite/kernels/test_util.h"
27 #include "tensorflow/lite/model.h"
28
29 namespace tflite {
30 namespace {
31
32 using ::testing::ElementsAreArray;
33
// Reference input sequence: 128 floats. The tests in this file walk it in
// consecutive chunks of input_size() values (8 for the models they build), so
// the data is laid out here as 16 chunks of 8, two lines per chunk.
static float rnn_input[] = {
    // chunk 0
    0.23689353, 0.285385, 0.037029743, -0.19858193,
    -0.27569133, 0.43773448, 0.60379338, 0.35562468,
    // chunk 1
    -0.69424844, -0.93421471, -0.87287879, 0.37144363,
    -0.62476718, 0.23791671, 0.40060222, 0.1356622,
    // chunk 2
    -0.99774903, -0.98858172, -0.38952237, -0.47685933,
    0.31073618, 0.71511042, -0.63767755, -0.31729108,
    // chunk 3
    0.33468103, 0.75801885, 0.30660987, -0.37354088,
    0.77002847, -0.62747043, -0.68572164, 0.0069220066,
    // chunk 4
    0.65791464, 0.35130811, 0.80834007, -0.61777675,
    -0.21095741, 0.41213346, 0.73784804, 0.094794154,
    // chunk 5
    0.47791874, 0.86496925, -0.53376222, 0.85315156,
    0.10288584, 0.86684, -0.011186242, 0.10513687,
    // chunk 6
    0.87825835, 0.59929144, 0.62827742, 0.18899453,
    0.31440187, 0.99059987, 0.87170351, -0.35091716,
    // chunk 7
    0.74861872, 0.17831337, 0.2755419, 0.51864719,
    0.55084288, 0.58982027, -0.47443086, 0.20875752,
    // chunk 8
    -0.058871567, -0.66609079, 0.59098077, 0.73017097,
    0.74604273, 0.32882881, -0.17503482, 0.22396147,
    // chunk 9
    0.19379807, 0.29120302, 0.077113032, -0.70331609,
    0.15804303, -0.93407321, 0.40182066, 0.036301374,
    // chunk 10
    0.66521823, 0.0300982, -0.7747041, -0.02038002,
    0.020698071, -0.90300065, 0.62870288, -0.23068321,
    // chunk 11
    0.27531278, -0.095755219, -0.712036, -0.17384434,
    -0.50593495, -0.18646687, -0.96508682, 0.43519354,
    // chunk 12
    0.14744234, 0.62589407, 0.1653645, -0.10651493,
    -0.045277178, 0.99032974, -0.88255352, -0.85147917,
    // chunk 13
    0.28153265, 0.19455957, -0.55479527, -0.56042433,
    0.26048636, 0.84702539, 0.47587705, -0.074295521,
    // chunk 14
    -0.12287641, 0.70117295, 0.90532446, 0.89782166,
    0.79817224, 0.53402734, -0.33286154, 0.073485017,
    // chunk 15
    -0.56172788, -0.044897556, 0.89964068, -0.067662835,
    0.76863563, 0.93455386, -0.6324693, -0.083922029};
61
// Golden outputs: 256 floats, arranged as 16 rows of 16 values. The tests read
// row i (num_units() values, 16 for the models they build) as the expected
// output of the i-th Invoke() and repeat that row once per batch.
static float rnn_golden_output[] = {
    // row 0
    0.496726, 0, 0.965996, 0, 0.0584254, 0, 0, 0.12315,
    0, 0, 0.612266, 0.456601, 0, 0.52286, 1.16099, 0.0291232,
    // row 1
    0, 0, 0.524901, 0, 0, 0, 0, 1.02116,
    0, 1.35762, 0, 0.356909, 0.436415, 0.0355727, 0, 0,
    // row 2
    0, 0, 0, 0.262335, 0, 0, 0, 1.33992,
    0, 2.9739, 0, 0, 1.31914, 2.66147, 0, 0,
    // row 3
    0.942568, 0, 0, 0, 0.025507, 0, 0, 0,
    0.321429, 0.569141, 1.25274, 1.57719, 0.8158, 1.21805, 0.586239, 0.25427,
    // row 4
    1.04436, 0, 0.630725, 0, 0.133801, 0.210693, 0.363026, 0,
    0.533426, 0, 1.25926, 0.722707, 0, 1.22031, 1.30117, 0.495867,
    // row 5
    0.222187, 0, 0.72725, 0, 0.767003, 0, 0, 0.147835,
    0, 0, 0, 0.608758, 0.469394, 0.00720298, 0.927537, 0,
    // row 6
    0.856974, 0.424257, 0, 0, 0.937329, 0, 0, 0,
    0.476425, 0, 0.566017, 0.418462, 0.141911, 0.996214, 1.13063, 0,
    // row 7
    0.967899, 0, 0, 0, 0.0831304, 0, 0, 1.00378,
    0, 0, 0, 1.44818, 1.01768, 0.943891, 0.502745, 0,
    // row 8
    0.940135, 0, 0, 0, 0, 0, 0, 2.13243,
    0, 0.71208, 0.123918, 1.53907, 1.30225, 1.59644, 0.70222, 0,
    // row 9
    0.804329, 0, 0.430576, 0, 0.505872, 0.509603, 0.343448, 0,
    0.107756, 0.614544, 1.44549, 1.52311, 0.0454298, 0.300267, 0.562784,
    0.395095,
    // row 10
    0.228154, 0, 0.675323, 0, 1.70536, 0.766217, 0, 0,
    0, 0.735363, 0.0759267, 1.91017, 0.941888, 0, 0, 0,
    // row 11
    0, 0, 1.5909, 0, 0, 0, 0, 0.5755,
    0, 0.184687, 0, 1.56296, 0.625285, 0, 0, 0,
    // row 12
    0, 0, 0.0857888, 0, 0, 0, 0, 0.488383,
    0.252786, 0, 0, 0, 1.02817, 1.85665, 0, 0,
    // row 13
    0.00981836, 0, 1.06371, 0, 0, 0, 0, 0,
    0, 0.290445, 0.316406, 0, 0.304161, 1.25079, 0.0707152, 0,
    // row 14
    0.986264, 0.309201, 0, 0, 0, 0, 0, 1.64896,
    0.346248, 0, 0.918175, 0.78884, 0.524981, 1.92076, 2.07013, 0.333244,
    // row 15
    0.415153, 0.210318, 0, 0, 0, 0, 0, 2.02616,
    0, 0.728256, 0.84183, 0.0907453, 0.628881, 3.58099, 1.49974, 0};
126
// Input weights: 128 floats handed to SetWeights() by every test in this file.
// Shown as 16 groups of 8, two lines per group.
// NOTE(review): the grouping assumes a row-major {units, input_size} = {16, 8}
// tensor (the shape RNNOpModel passes to BuildInterpreter) -- it is purely
// cosmetic; the flat order of the values is what matters.
static std::initializer_list<float> rnn_weights = {
    // group 0
    0.461459, 0.153381, 0.529743, -0.00371218,
    0.676267, -0.211346, 0.317493, 0.969689,
    // group 1
    -0.343251, 0.186423, 0.398151, 0.152399,
    0.448504, 0.317662, 0.523556, -0.323514,
    // group 2
    0.480877, 0.333113, -0.757714, -0.674487,
    -0.643585, 0.217766, -0.0251462, 0.79512,
    // group 3
    -0.595574, -0.422444, 0.371572, -0.452178,
    -0.556069, -0.482188, -0.685456, -0.727851,
    // group 4
    0.841829, 0.551535, -0.232336, 0.729158,
    -0.00294906, -0.69754, 0.766073, -0.178424,
    // group 5
    0.369513, -0.423241, 0.548547, -0.0152023,
    -0.757482, -0.85491, 0.251331, -0.989183,
    // group 6
    0.306261, -0.340716, 0.886103, -0.0726757,
    -0.723523, -0.784303, 0.0354295, 0.566564,
    // group 7
    -0.485469, -0.620498, 0.832546, 0.697884,
    -0.279115, 0.294415, -0.584313, 0.548772,
    // group 8
    0.0648819, 0.968726, 0.723834, -0.0080452,
    -0.350386, -0.272803, 0.115121, -0.412644,
    // group 9
    -0.824713, -0.992843, -0.592904, -0.417893,
    0.863791, -0.423461, -0.147601, -0.770664,
    // group 10
    -0.479006, 0.654782, 0.587314, -0.639158,
    0.816969, -0.337228, 0.659878, 0.73107,
    // group 11
    0.754768, -0.337042, 0.0960841, 0.368357,
    0.244191, -0.817703, -0.211223, 0.442012,
    // group 12
    0.37225, -0.623598, -0.405423, 0.455101,
    0.673656, -0.145345, -0.511346, -0.901675,
    // group 13
    -0.81252, -0.127006, 0.809865, -0.721884,
    0.636255, 0.868989, -0.347973, -0.10179,
    // group 14
    -0.777449, 0.917274, 0.819286, 0.206218,
    -0.00785118, 0.167141, 0.45872, 0.972934,
    // group 15
    -0.276798, 0.837861, 0.747958, -0.0151566,
    -0.330057, -0.469077, 0.277308, 0.415818};
150
// Recurrent weights: 256 floats in which every 17th value (flat indices 0, 17,
// 34, ..., 255) is 0.1 and all others are 0. Written out as a 16x16 grid this
// is 0.1 on the diagonal, matching the {units, units} shape used by the tests.
// clang-format off
static std::initializer_list<float> rnn_recurrent_weights = {
    0.1, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    0,   0.1, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    0,   0,   0.1, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    0,   0,   0,   0.1, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    0,   0,   0,   0,   0.1, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    0,   0,   0,   0,   0,   0.1, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    0,   0,   0,   0,   0,   0,   0.1, 0,   0,   0,   0,   0,   0,   0,   0,   0,
    0,   0,   0,   0,   0,   0,   0,   0.1, 0,   0,   0,   0,   0,   0,   0,   0,
    0,   0,   0,   0,   0,   0,   0,   0,   0.1, 0,   0,   0,   0,   0,   0,   0,
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0.1, 0,   0,   0,   0,   0,   0,
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0.1, 0,   0,   0,   0,   0,
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0.1, 0,   0,   0,   0,
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0.1, 0,   0,   0,
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0.1, 0,   0,
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0.1, 0,
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0.1};
// clang-format on
168
// Bias values handed to SetBias() by every test in this file: 16 floats, one
// per element of the {units} bias tensor of the 16-unit models built below.
static std::initializer_list<float> rnn_bias = {
    0.065691948, -0.69055247, 0.1107955, -0.97084129,
    -0.23957068, -0.23566568, -0.389184, 0.47481549,
    -0.4791103, 0.29931796, 0.10463274, 0.83918178,
    0.37197268, 0.61957061, 0.3956964, -0.37609905};
173
174 class RNNOpModel : public SingleOpModel {
175 public:
RNNOpModel(int batches,int units,int size,const TensorType & weights=TensorType_FLOAT32,const TensorType & recurrent_weights=TensorType_FLOAT32)176 RNNOpModel(int batches, int units, int size,
177 const TensorType& weights = TensorType_FLOAT32,
178 const TensorType& recurrent_weights = TensorType_FLOAT32)
179 : batches_(batches), units_(units), input_size_(size) {
180 input_ = AddInput(TensorType_FLOAT32);
181 weights_ = AddInput(weights);
182 recurrent_weights_ = AddInput(recurrent_weights);
183 bias_ = AddInput(TensorType_FLOAT32);
184 hidden_state_ = AddInput(TensorType_FLOAT32, true);
185 output_ = AddOutput(TensorType_FLOAT32);
186 SetBuiltinOp(
187 BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
188 CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
189 BuildInterpreter({{batches_, input_size_}, // input tensor
190 {units_, input_size_}, // weights tensor
191 {units_, units_}, // recurrent weights tensor
192 {units_}, // bias tensor
193 {batches_, units_}}); // hidden state tensor
194 }
195
SetBias(std::initializer_list<float> f)196 void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
197
SetWeights(std::initializer_list<float> f)198 void SetWeights(std::initializer_list<float> f) {
199 PopulateTensor(weights_, f);
200 }
201
SetRecurrentWeights(std::initializer_list<float> f)202 void SetRecurrentWeights(std::initializer_list<float> f) {
203 PopulateTensor(recurrent_weights_, f);
204 }
205
SetInput(std::initializer_list<float> data)206 void SetInput(std::initializer_list<float> data) {
207 PopulateTensor(input_, data);
208 }
209
SetInput(int offset,float * begin,float * end)210 void SetInput(int offset, float* begin, float* end) {
211 PopulateTensor(input_, offset, begin, end);
212 }
213
GetOutput()214 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
215
input_size()216 int input_size() { return input_size_; }
num_units()217 int num_units() { return units_; }
num_batches()218 int num_batches() { return batches_; }
219
220 protected:
221 int input_;
222 int weights_;
223 int recurrent_weights_;
224 int bias_;
225 int hidden_state_;
226 int output_;
227
228 int batches_;
229 int units_;
230 int input_size_;
231 };
232
233 // The hybrid model has quantized weights and recurrent_weights.
234 class HybridRNNOpModel : public RNNOpModel {
235 public:
HybridRNNOpModel(int batches,int units,int size,TensorType tensor_type)236 HybridRNNOpModel(int batches, int units, int size, TensorType tensor_type)
237 : RNNOpModel(batches, units, size, tensor_type, tensor_type) {
238 tensor_type_ = tensor_type;
239 }
240
241 TensorType tensor_type_;
242
SetWeights(int weights_idx,const std::vector<float> & f)243 void SetWeights(int weights_idx, const std::vector<float>& f) {
244 if (tensor_type_ == TensorType_UINT8) {
245 SymmetricQuantizeAndPopulate(weights_idx, f);
246 } else {
247 SignedSymmetricQuantizeAndPopulate(weights_idx, f);
248 }
249 }
250
SetWeights(std::initializer_list<float> f)251 void SetWeights(std::initializer_list<float> f) { SetWeights(weights_, f); }
252
SetRecurrentWeights(std::initializer_list<float> f)253 void SetRecurrentWeights(std::initializer_list<float> f) {
254 SetWeights(recurrent_weights_, f);
255 }
256 };
257
// Runs the float RNN (2 batches, 16 units, 8 inputs) step by step on the
// reference input and checks each step's output against the golden values.
TEST(RnnOpTest, BlackBoxTest) {
  RNNOpModel rnn(/*batches=*/2, /*units=*/16, /*size=*/8);
  rnn.SetWeights(rnn_weights);
  rnn.SetBias(rnn_bias);
  rnn.SetRecurrentWeights(rnn_recurrent_weights);

  const int slice_len = rnn.input_size();
  const int row_len = rnn.num_units();
  // NOTE(review): 128 floats / (8 * 2) gives 8 steps even though both batches
  // reuse the same slice, so the second half of rnn_input and
  // rnn_golden_output is never exercised -- confirm this is intended.
  const int num_steps =
      sizeof(rnn_input) / sizeof(float) / (slice_len * rnn.num_batches());

  for (int step = 0; step < num_steps; ++step) {
    // Feed the same input slice to batch 0 and batch 1.
    float* const slice_begin = rnn_input + step * slice_len;
    float* const slice_end = slice_begin + slice_len;
    rnn.SetInput(0, slice_begin, slice_end);
    rnn.SetInput(slice_len, slice_begin, slice_end);

    rnn.Invoke();

    // Expected output: this step's golden row, once per batch.
    float* const row_begin = rnn_golden_output + step * row_len;
    float* const row_end = row_begin + row_len;
    std::vector<float> expected(row_begin, row_end);
    expected.insert(expected.end(), row_begin, row_end);

    EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
  }
}
284
// Same black-box sequence as the float test, but with UINT8-quantized weights
// and recurrent weights; outputs are compared with a looser absolute tolerance.
TEST(HybridRnnOpTest, BlackBoxTestUint8) {
  HybridRNNOpModel rnn(/*batches=*/2, /*units=*/16, /*size=*/8,
                       TensorType_UINT8);
  rnn.SetWeights(rnn_weights);
  rnn.SetBias(rnn_bias);
  rnn.SetRecurrentWeights(rnn_recurrent_weights);

  const int slice_len = rnn.input_size();
  const int row_len = rnn.num_units();
  // NOTE(review): 128 floats / (8 * 2) gives 8 steps even though both batches
  // reuse the same slice, so only the first half of the reference data is
  // exercised -- confirm this is intended.
  const int num_steps =
      sizeof(rnn_input) / sizeof(float) / (slice_len * rnn.num_batches());

  for (int step = 0; step < num_steps; ++step) {
    // Feed the same input slice to batch 0 and batch 1.
    float* const slice_begin = rnn_input + step * slice_len;
    float* const slice_end = slice_begin + slice_len;
    rnn.SetInput(0, slice_begin, slice_end);
    rnn.SetInput(slice_len, slice_begin, slice_end);

    rnn.Invoke();

    // Expected output: this step's golden row, once per batch.
    float* const row_begin = rnn_golden_output + step * row_len;
    float* const row_end = row_begin + row_len;
    std::vector<float> expected(row_begin, row_end);
    expected.insert(expected.end(), row_begin, row_end);

    EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                     expected, /*max_abs_error=*/0.0104)));
  }
}
312
// Same black-box sequence as the float test, but with INT8-quantized weights
// and recurrent weights; outputs are compared with a looser absolute tolerance.
TEST(HybridRnnOpTest, BlackBoxTestInt8) {
  HybridRNNOpModel rnn(/*batches=*/2, /*units=*/16, /*size=*/8,
                       TensorType_INT8);
  rnn.SetWeights(rnn_weights);
  rnn.SetBias(rnn_bias);
  rnn.SetRecurrentWeights(rnn_recurrent_weights);

  const int slice_len = rnn.input_size();
  const int row_len = rnn.num_units();
  // NOTE(review): 128 floats / (8 * 2) gives 8 steps even though both batches
  // reuse the same slice, so only the first half of the reference data is
  // exercised -- confirm this is intended.
  const int num_steps =
      sizeof(rnn_input) / sizeof(float) / (slice_len * rnn.num_batches());

  for (int step = 0; step < num_steps; ++step) {
    // Feed the same input slice to batch 0 and batch 1.
    float* const slice_begin = rnn_input + step * slice_len;
    float* const slice_end = slice_begin + slice_len;
    rnn.SetInput(0, slice_begin, slice_end);
    rnn.SetInput(slice_len, slice_begin, slice_end);

    rnn.Invoke();

    // Expected output: this step's golden row, once per batch.
    float* const row_begin = rnn_golden_output + step * row_len;
    float* const row_end = row_begin + row_len;
    std::vector<float> expected(row_begin, row_end);
    expected.insert(expected.end(), row_begin, row_end);

    EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                     expected, /*max_abs_error=*/0.0104)));
  }
}
340
341 } // namespace
342 } // namespace tflite
343