• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Unit test for TFLite RNN op.
16 
17 #include <string.h>
18 #include <initializer_list>
19 #include <memory>
20 #include <vector>
21 
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "tensorflow/lite/interpreter.h"
25 #include "tensorflow/lite/kernels/register.h"
26 #include "tensorflow/lite/kernels/test_util.h"
27 #include "tensorflow/lite/model.h"
28 
29 namespace tflite {
30 namespace {
31 
32 using ::testing::ElementsAreArray;
33 
// Flat buffer of float input activations for the RNN under test. The tests
// below read consecutive `input_size`-sized slices of this buffer, one slice
// per Invoke(), feeding the same slice to both batches.
static float rnn_input[] = {
    0.23689353,   0.285385,     0.037029743, -0.19858193,  -0.27569133,
    0.43773448,   0.60379338,   0.35562468,  -0.69424844,  -0.93421471,
    -0.87287879,  0.37144363,   -0.62476718, 0.23791671,   0.40060222,
    0.1356622,    -0.99774903,  -0.98858172, -0.38952237,  -0.47685933,
    0.31073618,   0.71511042,   -0.63767755, -0.31729108,  0.33468103,
    0.75801885,   0.30660987,   -0.37354088, 0.77002847,   -0.62747043,
    -0.68572164,  0.0069220066, 0.65791464,  0.35130811,   0.80834007,
    -0.61777675,  -0.21095741,  0.41213346,  0.73784804,   0.094794154,
    0.47791874,   0.86496925,   -0.53376222, 0.85315156,   0.10288584,
    0.86684,      -0.011186242, 0.10513687,  0.87825835,   0.59929144,
    0.62827742,   0.18899453,   0.31440187,  0.99059987,   0.87170351,
    -0.35091716,  0.74861872,   0.17831337,  0.2755419,    0.51864719,
    0.55084288,   0.58982027,   -0.47443086, 0.20875752,   -0.058871567,
    -0.66609079,  0.59098077,   0.73017097,  0.74604273,   0.32882881,
    -0.17503482,  0.22396147,   0.19379807,  0.29120302,   0.077113032,
    -0.70331609,  0.15804303,   -0.93407321, 0.40182066,   0.036301374,
    0.66521823,   0.0300982,    -0.7747041,  -0.02038002,  0.020698071,
    -0.90300065,  0.62870288,   -0.23068321, 0.27531278,   -0.095755219,
    -0.712036,    -0.17384434,  -0.50593495, -0.18646687,  -0.96508682,
    0.43519354,   0.14744234,   0.62589407,  0.1653645,    -0.10651493,
    -0.045277178, 0.99032974,   -0.88255352, -0.85147917,  0.28153265,
    0.19455957,   -0.55479527,  -0.56042433, 0.26048636,   0.84702539,
    0.47587705,   -0.074295521, -0.12287641, 0.70117295,   0.90532446,
    0.89782166,   0.79817224,   0.53402734,  -0.33286154,  0.073485017,
    -0.56172788,  -0.044897556, 0.89964068,  -0.067662835, 0.76863563,
    0.93455386,   -0.6324693,   -0.083922029};
61 
// Expected per-step output activations for one batch; each blank-line-
// separated group holds `num_units` (16) values for one Invoke(). The tests
// duplicate a group to build the two-batch expected output. The many exact
// zeros are consistent with the model's ReLU activation clamping negatives.
static float rnn_golden_output[] = {
    0.496726,   0,          0.965996,  0,         0.0584254, 0,
    0,          0.12315,    0,         0,         0.612266,  0.456601,
    0,          0.52286,    1.16099,   0.0291232,

    0,          0,          0.524901,  0,         0,         0,
    0,          1.02116,    0,         1.35762,   0,         0.356909,
    0.436415,   0.0355727,  0,         0,

    0,          0,          0,         0.262335,  0,         0,
    0,          1.33992,    0,         2.9739,    0,         0,
    1.31914,    2.66147,    0,         0,

    0.942568,   0,          0,         0,         0.025507,  0,
    0,          0,          0.321429,  0.569141,  1.25274,   1.57719,
    0.8158,     1.21805,    0.586239,  0.25427,

    1.04436,    0,          0.630725,  0,         0.133801,  0.210693,
    0.363026,   0,          0.533426,  0,         1.25926,   0.722707,
    0,          1.22031,    1.30117,   0.495867,

    0.222187,   0,          0.72725,   0,         0.767003,  0,
    0,          0.147835,   0,         0,         0,         0.608758,
    0.469394,   0.00720298, 0.927537,  0,

    0.856974,   0.424257,   0,         0,         0.937329,  0,
    0,          0,          0.476425,  0,         0.566017,  0.418462,
    0.141911,   0.996214,   1.13063,   0,

    0.967899,   0,          0,         0,         0.0831304, 0,
    0,          1.00378,    0,         0,         0,         1.44818,
    1.01768,    0.943891,   0.502745,  0,

    0.940135,   0,          0,         0,         0,         0,
    0,          2.13243,    0,         0.71208,   0.123918,  1.53907,
    1.30225,    1.59644,    0.70222,   0,

    0.804329,   0,          0.430576,  0,         0.505872,  0.509603,
    0.343448,   0,          0.107756,  0.614544,  1.44549,   1.52311,
    0.0454298,  0.300267,   0.562784,  0.395095,

    0.228154,   0,          0.675323,  0,         1.70536,   0.766217,
    0,          0,          0,         0.735363,  0.0759267, 1.91017,
    0.941888,   0,          0,         0,

    0,          0,          1.5909,    0,         0,         0,
    0,          0.5755,     0,         0.184687,  0,         1.56296,
    0.625285,   0,          0,         0,

    0,          0,          0.0857888, 0,         0,         0,
    0,          0.488383,   0.252786,  0,         0,         0,
    1.02817,    1.85665,    0,         0,

    0.00981836, 0,          1.06371,   0,         0,         0,
    0,          0,          0,         0.290445,  0.316406,  0,
    0.304161,   1.25079,    0.0707152, 0,

    0.986264,   0.309201,   0,         0,         0,         0,
    0,          1.64896,    0.346248,  0,         0.918175,  0.78884,
    0.524981,   1.92076,    2.07013,   0.333244,

    0.415153,   0.210318,   0,         0,         0,         0,
    0,          2.02616,    0,         0.728256,  0.84183,   0.0907453,
    0.628881,   3.58099,    1.49974,   0};
126 
// Input-to-hidden weight matrix, row-major with shape {units, input_size} =
// {16, 8} for the models built below (128 values total).
static std::initializer_list<float> rnn_weights = {
    0.461459,    0.153381,   0.529743,    -0.00371218, 0.676267,   -0.211346,
    0.317493,    0.969689,   -0.343251,   0.186423,    0.398151,   0.152399,
    0.448504,    0.317662,   0.523556,    -0.323514,   0.480877,   0.333113,
    -0.757714,   -0.674487,  -0.643585,   0.217766,    -0.0251462, 0.79512,
    -0.595574,   -0.422444,  0.371572,    -0.452178,   -0.556069,  -0.482188,
    -0.685456,   -0.727851,  0.841829,    0.551535,    -0.232336,  0.729158,
    -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,   -0.423241,
    0.548547,    -0.0152023, -0.757482,   -0.85491,    0.251331,   -0.989183,
    0.306261,    -0.340716,  0.886103,    -0.0726757,  -0.723523,  -0.784303,
    0.0354295,   0.566564,   -0.485469,   -0.620498,   0.832546,   0.697884,
    -0.279115,   0.294415,   -0.584313,   0.548772,    0.0648819,  0.968726,
    0.723834,    -0.0080452, -0.350386,   -0.272803,   0.115121,   -0.412644,
    -0.824713,   -0.992843,  -0.592904,   -0.417893,   0.863791,   -0.423461,
    -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,   -0.639158,
    0.816969,    -0.337228,  0.659878,    0.73107,     0.754768,   -0.337042,
    0.0960841,   0.368357,   0.244191,    -0.817703,   -0.211223,  0.442012,
    0.37225,     -0.623598,  -0.405423,   0.455101,    0.673656,   -0.145345,
    -0.511346,   -0.901675,  -0.81252,    -0.127006,   0.809865,   -0.721884,
    0.636255,    0.868989,   -0.347973,   -0.10179,    -0.777449,  0.917274,
    0.819286,    0.206218,   -0.00785118, 0.167141,    0.45872,    0.972934,
    -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057,  -0.469077,
    0.277308,    0.415818};
150 
// Hidden-to-hidden (recurrent) weight matrix, row-major {units, units} =
// {16, 16}. The 0.1 entries fall at stride 17 (256 values total), i.e. this
// is 0.1 times the 16x16 identity matrix.
static std::initializer_list<float> rnn_recurrent_weights = {
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1};
168 
// Bias vector, one value per unit ({units} = {16}).
static std::initializer_list<float> rnn_bias = {
    0.065691948, -0.69055247, 0.1107955,  -0.97084129, -0.23957068, -0.23566568,
    -0.389184,   0.47481549,  -0.4791103, 0.29931796,  0.10463274,  0.83918178,
    0.37197268,  0.61957061,  0.3956964,  -0.37609905};
173 
174 class RNNOpModel : public SingleOpModel {
175  public:
RNNOpModel(int batches,int units,int size,const TensorType & weights=TensorType_FLOAT32,const TensorType & recurrent_weights=TensorType_FLOAT32)176   RNNOpModel(int batches, int units, int size,
177              const TensorType& weights = TensorType_FLOAT32,
178              const TensorType& recurrent_weights = TensorType_FLOAT32)
179       : batches_(batches), units_(units), input_size_(size) {
180     input_ = AddInput(TensorType_FLOAT32);
181     weights_ = AddInput(weights);
182     recurrent_weights_ = AddInput(recurrent_weights);
183     bias_ = AddInput(TensorType_FLOAT32);
184     hidden_state_ = AddInput(TensorType_FLOAT32, true);
185     output_ = AddOutput(TensorType_FLOAT32);
186     SetBuiltinOp(
187         BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
188         CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
189     BuildInterpreter({{batches_, input_size_},  // input tensor
190                       {units_, input_size_},    // weights tensor
191                       {units_, units_},         // recurrent weights tensor
192                       {units_},                 // bias tensor
193                       {batches_, units_}});     // hidden state tensor
194   }
195 
SetBias(std::initializer_list<float> f)196   void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
197 
SetWeights(std::initializer_list<float> f)198   void SetWeights(std::initializer_list<float> f) {
199     PopulateTensor(weights_, f);
200   }
201 
SetRecurrentWeights(std::initializer_list<float> f)202   void SetRecurrentWeights(std::initializer_list<float> f) {
203     PopulateTensor(recurrent_weights_, f);
204   }
205 
SetInput(std::initializer_list<float> data)206   void SetInput(std::initializer_list<float> data) {
207     PopulateTensor(input_, data);
208   }
209 
SetInput(int offset,float * begin,float * end)210   void SetInput(int offset, float* begin, float* end) {
211     PopulateTensor(input_, offset, begin, end);
212   }
213 
GetOutput()214   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
215 
input_size()216   int input_size() { return input_size_; }
num_units()217   int num_units() { return units_; }
num_batches()218   int num_batches() { return batches_; }
219 
220  protected:
221   int input_;
222   int weights_;
223   int recurrent_weights_;
224   int bias_;
225   int hidden_state_;
226   int output_;
227 
228   int batches_;
229   int units_;
230   int input_size_;
231 };
232 
233 // The hybrid model has quantized weights and recurrent_weights.
234 class HybridRNNOpModel : public RNNOpModel {
235  public:
HybridRNNOpModel(int batches,int units,int size,TensorType tensor_type)236   HybridRNNOpModel(int batches, int units, int size, TensorType tensor_type)
237       : RNNOpModel(batches, units, size, tensor_type, tensor_type) {
238     tensor_type_ = tensor_type;
239   }
240 
241   TensorType tensor_type_;
242 
SetWeights(int weights_idx,const std::vector<float> & f)243   void SetWeights(int weights_idx, const std::vector<float>& f) {
244     if (tensor_type_ == TensorType_UINT8) {
245       SymmetricQuantizeAndPopulate(weights_idx, f);
246     } else {
247       SignedSymmetricQuantizeAndPopulate(weights_idx, f);
248     }
249   }
250 
SetWeights(std::initializer_list<float> f)251   void SetWeights(std::initializer_list<float> f) { SetWeights(weights_, f); }
252 
SetRecurrentWeights(std::initializer_list<float> f)253   void SetRecurrentWeights(std::initializer_list<float> f) {
254     SetWeights(recurrent_weights_, f);
255   }
256 };
257 
// Runs the float RNN over the canned input sequence and checks each step's
// output against the golden values (same data fed to both batches).
TEST(RnnOpTest, BlackBoxTest) {
  RNNOpModel rnn(/*batches=*/2, /*units=*/16, /*size=*/8);
  rnn.SetWeights(rnn_weights);
  rnn.SetBias(rnn_bias);
  rnn.SetRecurrentWeights(rnn_recurrent_weights);

  const int total_input_values = sizeof(rnn_input) / sizeof(float);
  const int input_sequence_size =
      total_input_values / (rnn.input_size() * rnn.num_batches());

  for (int step = 0; step < input_sequence_size; ++step) {
    float* batch_start = rnn_input + step * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // Both batches receive identical input, so both should produce the
    // same golden output.
    rnn.SetInput(0, batch_start, batch_end);
    rnn.SetInput(rnn.input_size(), batch_start, batch_end);

    rnn.Invoke();

    float* golden_start = rnn_golden_output + step * rnn.num_units();
    float* golden_end = golden_start + rnn.num_units();
    std::vector<float> expected(golden_start, golden_end);
    expected.insert(expected.end(), golden_start, golden_end);

    EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
  }
}
284 
// Same black-box sequence as the float test, but with UINT8-quantized
// weights; a looser tolerance absorbs quantization error.
TEST(HybridRnnOpTest, BlackBoxTestUint8) {
  HybridRNNOpModel rnn(/*batches=*/2, /*units=*/16, /*size=*/8,
                       TensorType_UINT8);
  rnn.SetWeights(rnn_weights);
  rnn.SetBias(rnn_bias);
  rnn.SetRecurrentWeights(rnn_recurrent_weights);

  const int total_input_values = sizeof(rnn_input) / sizeof(float);
  const int input_sequence_size =
      total_input_values / (rnn.input_size() * rnn.num_batches());

  for (int step = 0; step < input_sequence_size; ++step) {
    float* batch_start = rnn_input + step * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // Both batches receive identical input.
    rnn.SetInput(0, batch_start, batch_end);
    rnn.SetInput(rnn.input_size(), batch_start, batch_end);

    rnn.Invoke();

    float* golden_start = rnn_golden_output + step * rnn.num_units();
    float* golden_end = golden_start + rnn.num_units();
    std::vector<float> expected(golden_start, golden_end);
    expected.insert(expected.end(), golden_start, golden_end);

    EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                     expected, /*max_abs_error=*/0.0104)));
  }
}
312 
// Same black-box sequence as the float test, but with INT8-quantized
// weights; a looser tolerance absorbs quantization error.
TEST(HybridRnnOpTest, BlackBoxTestInt8) {
  HybridRNNOpModel rnn(/*batches=*/2, /*units=*/16, /*size=*/8,
                       TensorType_INT8);
  rnn.SetWeights(rnn_weights);
  rnn.SetBias(rnn_bias);
  rnn.SetRecurrentWeights(rnn_recurrent_weights);

  const int total_input_values = sizeof(rnn_input) / sizeof(float);
  const int input_sequence_size =
      total_input_values / (rnn.input_size() * rnn.num_batches());

  for (int step = 0; step < input_sequence_size; ++step) {
    float* batch_start = rnn_input + step * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // Both batches receive identical input.
    rnn.SetInput(0, batch_start, batch_end);
    rnn.SetInput(rnn.input_size(), batch_start, batch_end);

    rnn.Invoke();

    float* golden_start = rnn_golden_output + step * rnn.num_units();
    float* golden_end = golden_start + rnn.num_units();
    std::vector<float> expected(golden_start, golden_end);
    expected.insert(expected.end(), golden_start, golden_end);

    EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                     expected, /*max_abs_error=*/0.0104)));
  }
}
340 
341 }  // namespace
342 }  // namespace tflite
343