• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Unit test for TFLite Bidirectional RNN op.
16 
17 #include <algorithm>
18 #include <functional>
19 #include <initializer_list>
20 #include <iterator>
21 #include <tuple>
22 #include <vector>
23 
24 #include <gmock/gmock.h>
25 #include <gtest/gtest.h>
26 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
27 #include "tensorflow/lite/kernels/test_util.h"
28 #include "tensorflow/lite/schema/schema_generated.h"
29 
30 namespace tflite {
31 namespace {
32 
// Selects which auxiliary-input configuration a test case exercises.
// Enumerator values are spelled out explicitly; they match the implicit
// 0/1/2 numbering, and the underlying type stays the scoped-enum default.
// NOTE(review): the descriptions below are read off the enumerator names;
// confirm the exact wiring against the test bodies that switch on this enum.
enum class AuxInputMode {
  // No auxiliary input is supplied.
  kNoAuxInput = 0,
  // Auxiliary input is supplied with cross-linking enabled.
  kCrossLinking = 1,
  // Auxiliary input is supplied without cross-linking.
  kNoCrossLinking = 2
};
38 
39 using ::testing::ElementsAreArray;
40 
// Input samples for the RNN tests: 128 values in total (25 rows of 5 plus a
// final row of 3).
// NOTE(review): presumably one batch's input sequence (time steps x input
// features) that the tests slice per step; confirm the split against the
// input size / sequence length used in the test bodies.
// Declared as a mutable `float[]`. NOTE(review): presumably because callers
// take non-const `float*` pointers into it — confirm before making it const.
41 static float rnn_input[] = {
42     0.23689353,   0.285385,     0.037029743, -0.19858193,  -0.27569133,
43     0.43773448,   0.60379338,   0.35562468,  -0.69424844,  -0.93421471,
44     -0.87287879,  0.37144363,   -0.62476718, 0.23791671,   0.40060222,
45     0.1356622,    -0.99774903,  -0.98858172, -0.38952237,  -0.47685933,
46     0.31073618,   0.71511042,   -0.63767755, -0.31729108,  0.33468103,
47     0.75801885,   0.30660987,   -0.37354088, 0.77002847,   -0.62747043,
48     -0.68572164,  0.0069220066, 0.65791464,  0.35130811,   0.80834007,
49     -0.61777675,  -0.21095741,  0.41213346,  0.73784804,   0.094794154,
50     0.47791874,   0.86496925,   -0.53376222, 0.85315156,   0.10288584,
51     0.86684,      -0.011186242, 0.10513687,  0.87825835,   0.59929144,
52     0.62827742,   0.18899453,   0.31440187,  0.99059987,   0.87170351,
53     -0.35091716,  0.74861872,   0.17831337,  0.2755419,    0.51864719,
54     0.55084288,   0.58982027,   -0.47443086, 0.20875752,   -0.058871567,
55     -0.66609079,  0.59098077,   0.73017097,  0.74604273,   0.32882881,
56     -0.17503482,  0.22396147,   0.19379807,  0.29120302,   0.077113032,
57     -0.70331609,  0.15804303,   -0.93407321, 0.40182066,   0.036301374,
58     0.66521823,   0.0300982,    -0.7747041,  -0.02038002,  0.020698071,
59     -0.90300065,  0.62870288,   -0.23068321, 0.27531278,   -0.095755219,
60     -0.712036,    -0.17384434,  -0.50593495, -0.18646687,  -0.96508682,
61     0.43519354,   0.14744234,   0.62589407,  0.1653645,    -0.10651493,
62     -0.045277178, 0.99032974,   -0.88255352, -0.85147917,  0.28153265,
63     0.19455957,   -0.55479527,  -0.56042433, 0.26048636,   0.84702539,
64     0.47587705,   -0.074295521, -0.12287641, 0.70117295,   0.90532446,
65     0.89782166,   0.79817224,   0.53402734,  -0.33286154,  0.073485017,
66     -0.56172788,  -0.044897556, 0.89964068,  -0.067662835, 0.76863563,
67     0.93455386,   -0.6324693,   -0.083922029};
68 
// Expected ("golden") output of the forward direction, per the `fw` in the
// name: 256 values, written as 16 blank-line-separated groups of 16.
// NOTE(review): presumably one group per time step with one value per
// forward cell unit; confirm against the cell size and sequence length the
// tests use.
// Element order and exact digits are what the tests compare against; do not
// reformat or re-round these literals.
69 static float rnn_golden_fw_output[] = {
70     0.496726,   0,          0.965996,  0,         0.0584254, 0,
71     0,          0.12315,    0,         0,         0.612266,  0.456601,
72     0,          0.52286,    1.16099,   0.0291232,
73

74     0,          0,          0.524901,  0,         0,         0,
75     0,          1.02116,    0,         1.35762,   0,         0.356909,
76     0.436415,   0.0355727,  0,         0,
77

78     0,          0,          0,         0.262335,  0,         0,
79     0,          1.33992,    0,         2.9739,    0,         0,
80     1.31914,    2.66147,    0,         0,
81

82     0.942568,   0,          0,         0,         0.025507,  0,
83     0,          0,          0.321429,  0.569141,  1.25274,   1.57719,
84     0.8158,     1.21805,    0.586239,  0.25427,
85

86     1.04436,    0,          0.630725,  0,         0.133801,  0.210693,
87     0.363026,   0,          0.533426,  0,         1.25926,   0.722707,
88     0,          1.22031,    1.30117,   0.495867,
89

90     0.222187,   0,          0.72725,   0,         0.767003,  0,
91     0,          0.147835,   0,         0,         0,         0.608758,
92     0.469394,   0.00720298, 0.927537,  0,
93

94     0.856974,   0.424257,   0,         0,         0.937329,  0,
95     0,          0,          0.476425,  0,         0.566017,  0.418462,
96     0.141911,   0.996214,   1.13063,   0,
97

98     0.967899,   0,          0,         0,         0.0831304, 0,
99     0,          1.00378,    0,         0,         0,         1.44818,
100     1.01768,    0.943891,   0.502745,  0,
101

102     0.940135,   0,          0,         0,         0,         0,
103     0,          2.13243,    0,         0.71208,   0.123918,  1.53907,
104     1.30225,    1.59644,    0.70222,   0,
105

106     0.804329,   0,          0.430576,  0,         0.505872,  0.509603,
107     0.343448,   0,          0.107756,  0.614544,  1.44549,   1.52311,
108     0.0454298,  0.300267,   0.562784,  0.395095,
109

110     0.228154,   0,          0.675323,  0,         1.70536,   0.766217,
111     0,          0,          0,         0.735363,  0.0759267, 1.91017,
112     0.941888,   0,          0,         0,
113

114     0,          0,          1.5909,    0,         0,         0,
115     0,          0.5755,     0,         0.184687,  0,         1.56296,
116     0.625285,   0,          0,         0,
117

118     0,          0,          0.0857888, 0,         0,         0,
119     0,          0.488383,   0.252786,  0,         0,         0,
120     1.02817,    1.85665,    0,         0,
121

122     0.00981836, 0,          1.06371,   0,         0,         0,
123     0,          0,          0,         0.290445,  0.316406,  0,
124     0.304161,   1.25079,    0.0707152, 0,
125

126     0.986264,   0.309201,   0,         0,         0,         0,
127     0,          1.64896,    0.346248,  0,         0.918175,  0.78884,
128     0.524981,   1.92076,    2.07013,   0.333244,
129

130     0.415153,   0.210318,   0,         0,         0,         0,
131     0,          2.02616,    0,         0.728256,  0.84183,   0.0907453,
132     0.628881,   3.58099,    1.49974,   0};
133 
// Expected ("golden") output of the backward direction, per the `bw` in the
// name: 256 values, written 7 per row with no grouping (36 full rows plus a
// final row of 4).
// NOTE(review): presumably the same 16 x 16 arrangement as the forward golden
// table; confirm the element order the tests expect (e.g. whether it is
// already reversed in time) against the test bodies.
// Element order and exact digits are what the tests compare against; do not
// reformat or re-round these literals.
134 static float rnn_golden_bw_output[] = {
135     0.496726, 0,          1.00883,   0,         0.0584256, 0,         0,
136     0.236412, 0,          0,         0.612267,  0.487726,  0,         0.54883,
137     1.16099,  0.0291233,  0,         0,         0.428302,  0,         0,
138     0,        0,          1.13262,   0,         1.64415,   0,         0.311249,
139     0.570804, 0.259696,   0,         0,         0,         0,         0,
140     0.262334, 0,          0,         0,         1.23781,   0,         2.86532,
141     0,        0,          1.34389,   2.76409,   0,         0,         1.03969,
142     0,        0.00410865, 0,         0.0470295, 0,         0,         0,
143     0.371556, 0.27175,    1.36614,   1.63956,   0.683887,  1.06176,   0.719552,
144     0.301314, 0.971195,   0,         0.697143,  0,         0.215219,  0.210693,
145     0.363027, 0,          0.501283,  0,         1.13399,   0.623774,  0,
146     1.09851,  1.33313,    0.470441,  0.210965,  0,         0.664178,  0,
147     0.839686, 0,          0,         0.147834,  0,         0,         0,
148     0.58786,  0.490128,   0,         0.905806,  0,         0.932134,  0.424257,
149     0,        0,          0.860629,  0,         0,         0,         0.476425,
150     0,        0.566017,   0.513721,  0.207341,  1.09508,   1.08385,   0,
151     0.973787, 0,          0,         0,         0,         0,         0,
152     1.20698,  0,          0,         0,         1.56135,   1.12369,   0.99588,
153     0.459803, 0,          0.915854,  0,         0,         0,         0,
154     0,        0,          2.03206,   0,         0.773264,  0.267228,  1.55012,
155     1.202,    1.51611,    0.701202,  0,         0.725088,  0,         0.509069,
156     0,        0.671349,   0.581129,  0.343447,  0,         0.107755,  0.611838,
157     1.4331,   1.55871,    0.015242,  0.140624,  0.492562,  0.395095,  0.147722,
158     0,        0.784925,   0,         1.65477,   0.715257,  0,         0,
159     0,        0.685024,   0,         1.89505,   1.00037,   0,         0,
160     0,        0,          0,         1.52659,   0,         0,         0,
161     0,        0.618583,   0,         0.11115,   0,         1.37194,   0.630225,
162     0,        0,          0,         0,         0,         0.0322124, 0,
163     0,        0,          0,         0.430834,  0.252786,  0,         0,
164     0,        0.991297,   1.98451,   0,         0,         0.111511,  0,
165     1.05513,  0,          0,         0,         0,         0,         0,
166     0.290445, 0.412559,   0.0429958, 0.256564,  1.27858,   0.289948,  0,
167     1.01693,  0.327141,   0,         0,         0,         0,         0,
168     1.83508,  0.346248,   0,         0.961535,  0.790026,  0.552203,  2.13457,
169     2.19233,  0.333244,   0.316526,  0.179398,  0,         0,         0,
170     0,        0,          1.86126,   0,         0.728256,  0.750013,  0.011861,
171     0.576383, 3.38891,    1.29273,   0};
172 
// Weight values for the RNN tests: 128 floats (21 rows of 6 plus a final row
// of 2).
// NOTE(review): which weight tensor(s) these populate (for example the input
// weights of one or both directions) is decided by the test bodies, not here
// — confirm there before changing the count.
// Declared as a namespace-scope std::initializer_list initialized directly
// from a braced list, so its backing array has the same (static) lifetime as
// `weights` itself and it is safe to pass around for the whole test run.
173 const std::initializer_list<float> weights = {
174     0.461459,    0.153381,   0.529743,    -0.00371218, 0.676267,   -0.211346,
175     0.317493,    0.969689,   -0.343251,   0.186423,    0.398151,   0.152399,
176     0.448504,    0.317662,   0.523556,    -0.323514,   0.480877,   0.333113,
177     -0.757714,   -0.674487,  -0.643585,   0.217766,    -0.0251462, 0.79512,
178     -0.595574,   -0.422444,  0.371572,    -0.452178,   -0.556069,  -0.482188,
179     -0.685456,   -0.727851,  0.841829,    0.551535,    -0.232336,  0.729158,
180     -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,   -0.423241,
181     0.548547,    -0.0152023, -0.757482,   -0.85491,    0.251331,   -0.989183,
182     0.306261,    -0.340716,  0.886103,    -0.0726757,  -0.723523,  -0.784303,
183     0.0354295,   0.566564,   -0.485469,   -0.620498,   0.832546,   0.697884,
184     -0.279115,   0.294415,   -0.584313,   0.548772,    0.0648819,  0.968726,
185     0.723834,    -0.0080452, -0.350386,   -0.272803,   0.115121,   -0.412644,
186     -0.824713,   -0.992843,  -0.592904,   -0.417893,   0.863791,   -0.423461,
187     -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,   -0.639158,
188     0.816969,    -0.337228,  0.659878,    0.73107,     0.754768,   -0.337042,
189     0.0960841,   0.368357,   0.244191,    -0.817703,   -0.211223,  0.442012,
190     0.37225,     -0.623598,  -0.405423,   0.455101,    0.673656,   -0.145345,
191     -0.511346,   -0.901675,  -0.81252,    -0.127006,   0.809865,   -0.721884,
192     0.636255,    0.868989,   -0.347973,   -0.10179,    -0.777449,  0.917274,
193     0.819286,    0.206218,   -0.00785118, 0.167141,    0.45872,    0.972934,
194     -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057,  -0.469077,
195     0.277308,    0.415818};
196 
197 static float endtoend_input[] = {
198     0.996808, 0.060710, 0.981855, 0.570017, 0.525164, 0.796859, 0.696547,
199     0.505925, 0.991844, 0.461208, 0.949371, 0.027624, 0.539236, 0.841854,
200     0.915222, 0.538569, 0.069375, 0.237905, 0.903700, 0.441703, 0.536196,
201     0.402724, 0.761635, 0.025063, 0.082592, 0.688245, 0.239310, 0.256931,
202     0.658900, 0.105695, 0.301983, 0.655708, 0.166405, 0.283837, 0.225725,
203     0.691569, 0.080696, 0.922272, 0.197494, 0.072540, 0.383481, 0.146865,
204     0.100163, 0.922717, 0.988720, 0.015386, 0.461286, 0.058095, 0.253290,
205     0.364986, 0.499797, 0.789487, 0.767709, 0.261433, 0.814549, 0.850302,
206     0.949678, 0.053859, 0.107233, 0.608577, 0.159554, 0.409215, 0.264285,
207     0.325960, 0.693053, 0.490011, 0.017529, 0.773749, 0.412283, 0.215023,
208     0.846288, 0.795764, 0.361889, 0.946452, 0.718481, 0.350608, 0.961837,
209     0.179767, 0.408703, 0.215128, 0.544753, 0.908500, 0.004614, 0.312462,
210     0.169933, 0.819163, 0.162764, 0.119611, 0.873022, 0.269997, 0.728188,
211     0.032576, 0.679212, 0.992474, 0.358536, 0.372265, 0.482484, 0.376065,
212     0.146014, 0.894767, 0.591088, 0.992302, 0.690531, 0.952977, 0.938754,
213     0.409012, 0.303585, 0.900591, 0.588780, 0.712287, 0.115719, 0.133533,
214     0.620788, 0.120334, 0.445995, 0.790720, 0.939497, 0.608759, 0.910331,
215     0.812519, 0.878756, 0.638519, 0.845096, 0.557968, 0.630993, 0.203632,
216     0.930233, 0.113477, 0.579697, 0.076247, 0.008244, 0.170785, 0.068549,
217     0.698776, 0.123761, 0.007303, 0.107788, 0.427346, 0.907894, 0.696568,
218     0.139633, 0.023613, 0.830100, 0.760421, 0.143947, 0.276096, 0.551141,
219     0.083444, 0.884855, 0.461472, 0.895963, 0.763611, 0.099992, 0.741059,
220     0.321579, 0.730984, 0.944691, 0.251812, 0.844461, 0.524388, 0.328059,
221     0.852706, 0.695172, 0.396607, 0.551482, 0.818934, 0.403910, 0.659270,
222     0.246280, 0.311804, 0.355838, 0.385913, 0.335418, 0.185938, 0.146334,
223     0.479364, 0.462034, 0.697475, 0.562808, 0.346888, 0.158948, 0.458771,
224     0.110499, 0.258939, 0.199830, 0.432078, 0.989924, 0.144521, 0.683890,
225     0.834385, 0.668908, 0.011949, 0.687091, 0.364081, 0.408556, 0.238572,
226     0.183015, 0.812466, 0.897842, 0.429294, 0.124271, 0.253680, 0.815207,
227     0.459688, 0.439618, 0.961541, 0.939053, 0.901651, 0.659016, 0.501861,
228     0.248539, 0.817964, 0.960632, 0.359038, 0.076903, 0.160462, 0.791117,
229     0.066826, 0.304983, 0.475007, 0.901211, 0.973891, 0.486955, 0.588302,
230     0.337972, 0.895512, 0.826874, 0.520987, 0.707978, 0.724716, 0.950281,
231     0.832249, 0.978396, 0.765488, 0.291937, 0.418014, 0.727029, 0.230990,
232     0.319665, 0.386045, 0.732850, 0.568204, 0.204009, 0.693482, 0.927242,
233     0.280912, 0.853944, 0.718359, 0.347738, 0.158927, 0.193366, 0.248950,
234     0.132818, 0.680321, 0.837252, 0.470790, 0.575833, 0.664126, 0.991777,
235     0.283811, 0.388843, 0.942058, 0.116060, 0.367239, 0.707546, 0.407997,
236     0.785253, 0.434575, 0.638986, 0.104917, 0.820620, 0.371837, 0.673121,
237     0.024629, 0.065319, 0.600363, 0.305541, 0.919263, 0.318722, 0.653279,
238     0.078190, 0.512088, 0.902229, 0.211009, 0.192409, 0.739480, 0.681799,
239     0.768242, 0.403607, 0.673576, 0.052052, 0.792450, 0.615634, 0.168112,
240     0.159689, 0.323180, 0.576109, 0.944941, 0.757755, 0.215095, 0.049858,
241     0.578375, 0.586932, 0.722979, 0.603003, 0.652251, 0.323343, 0.908544,
242     0.571514, 0.642065, 0.561823, 0.649704, 0.154153, 0.464051, 0.860713,
243     0.346562, 0.203532, 0.542512, 0.114804, 0.607139, 0.216088, 0.166856,
244     0.399588, 0.831722, 0.334968, 0.559277, 0.154902, 0.911077, 0.504218,
245     0.912656, 0.126172, 0.554076, 0.491031, 0.713104, 0.277055, 0.094034,
246     0.365355, 0.600398, 0.002578, 0.936869, 0.242463, 0.564401, 0.586574,
247     0.396616, 0.028452, 0.447287, 0.743178, 0.231984, 0.989799, 0.857982,
248     0.839122, 0.205887, 0.024838, 0.238711, 0.037608, 0.359806, 0.797987,
249     0.192510, 0.270883, 0.302205, 0.105166, 0.397055, 0.856281, 0.596197,
250     0.110160, 0.133336, 0.690231, 0.475515, 0.733734, 0.692809, 0.412384,
251     0.976196, 0.257209, 0.998958, 0.372812, 0.285661, 0.446245, 0.115990,
252     0.517645, 0.436044, 0.973972, 0.356767, 0.641930, 0.998810, 0.595478,
253     0.679539, 0.358617, 0.393465, 0.872049, 0.629500, 0.695670, 0.977215,
254     0.026555, 0.551951, 0.573412, 0.136715, 0.685287, 0.263643, 0.612229,
255     0.419020, 0.956451, 0.024613, 0.395216, 0.213661, 0.023572, 0.768029,
256     0.499322, 0.469816, 0.884019, 0.016967, 0.905860, 0.857991, 0.373734,
257     0.547791, 0.856802, 0.969211, 0.227330, 0.215418, 0.362676, 0.099378,
258     0.844918, 0.058346, 0.076594, 0.871473, 0.610297, 0.650006, 0.008188,
259     0.295583, 0.913648, 0.620417, 0.714603, 0.870100, 0.645031, 0.109820,
260     0.083760, 0.668602, 0.877849, 0.583082, 0.138419, 0.761868, 0.600049,
261     0.044279, 0.619859, 0.973783, 0.592069, 0.476661, 0.942994, 0.819399,
262     0.692079, 0.305670, 0.918778, 0.536997, 0.364016, 0.995371, 0.408470,
263     0.974313, 0.645377, 0.416658, 0.269896, 0.559025, 0.037075, 0.984499,
264     0.429125, 0.682105, 0.094319, 0.512885, 0.350707, 0.972168, 0.095967,
265     0.489126, 0.734035, 0.696016, 0.533405, 0.353894, 0.669799, 0.125474,
266     0.830555, 0.612793, 0.944873, 0.522634, 0.918463, 0.863651, 0.059631,
267     0.282479, 0.859022, 0.468101, 0.256791, 0.504398, 0.884758, 0.526687,
268     0.063423, 0.921833, 0.511186, 0.492548, 0.603939, 0.605505, 0.005433,
269     0.954646, 0.577673, 0.101400, 0.443772, 0.311708, 0.797417, 0.977176,
270     0.665602, 0.467216, 0.102650, 0.496157, 0.080009, 0.047524, 0.018791,
271     0.998471, 0.911174, 0.078422, 0.280950, 0.770196, 0.546523, 0.537741,
272     0.274594, 0.431281, 0.064428, 0.338017, 0.353115, 0.575615, 0.830565,
273     0.957053, 0.181120, 0.835998, 0.911699, 0.758793, 0.937398, 0.355471,
274     0.070501, 0.734815, 0.332647, 0.736103, 0.202031, 0.435297, 0.232261,
275     0.282039, 0.482821, 0.251052, 0.280511, 0.393995, 0.329474, 0.561460,
276     0.164191, 0.875997, 0.099202, 0.438785, 0.307278, 0.163630, 0.776802,
277     0.660393, 0.739244, 0.607367, 0.617446, 0.920364, 0.443365, 0.529145,
278     0.679157, 0.380763, 0.884616, 0.749658, 0.115578, 0.217263, 0.485761,
279     0.317609, 0.652560, 0.718021, 0.599648, 0.135381, 0.969073, 0.880159,
280     0.529376, 0.298547, 0.441619, 0.693567, 0.174544, 0.540821, 0.132351,
281     0.481822, 0.704450, 0.909153, 0.142215, 0.443695, 0.516520, 0.759661,
282     0.364059, 0.959885, 0.288806, 0.043216, 0.340648, 0.173422, 0.792874,
283     0.456226, 0.390685, 0.278634, 0.773834, 0.043245, 0.996656, 0.373483,
284     0.178625, 0.965729, 0.253641, 0.708001, 0.264276, 0.695260, 0.401568,
285     0.438820, 0.236081, 0.533919, 0.920642, 0.940531, 0.443072, 0.062857,
286     0.384226, 0.959592, 0.822518, 0.748285, 0.919477, 0.111325, 0.791501,
287     0.260124, 0.284747, 0.584375, 0.716350, 0.675431, 0.863009, 0.490184,
288     0.718676, 0.859665, 0.863666, 0.897301, 0.825393, 0.117308, 0.605302,
289     0.089669, 0.812568, 0.006870, 0.528489, 0.048649, 0.540788, 0.449131,
290     0.989180, 0.983860, 0.511988, 0.373407, 0.943452, 0.334506, 0.121692,
291     0.862929, 0.445831, 0.913193, 0.123053, 0.730578, 0.497568, 0.839402,
292     0.406009, 0.360577, 0.329586, 0.124685, 0.220241, 0.193253, 0.021986,
293     0.045634, 0.310560, 0.627288, 0.135303, 0.123128, 0.634158, 0.663792,
294     0.171777, 0.174946, 0.112923, 0.160958, 0.158806, 0.624911, 0.534364,
295     0.102259, 0.959418, 0.656056, 0.965187, 0.405249, 0.569249, 0.088240,
296     0.135827, 0.066817, 0.927642, 0.541836, 0.427393, 0.257229, 0.666520,
297     0.647634, 0.450481, 0.688506, 0.693269, 0.761042, 0.315794, 0.828572,
298     0.884170, 0.949952, 0.492364, 0.055947, 0.124898, 0.605288, 0.216905,
299     0.283705, 0.230199, 0.751269, 0.385963, 0.189616, 0.407326, 0.351151,
300     0.594865, 0.976575, 0.439391, 0.730692, 0.043392, 0.367033, 0.272527,
301     0.470785, 0.624261, 0.939048, 0.118419, 0.074743, 0.627554, 0.811688,
302     0.835784, 0.943348, 0.640260, 0.719954, 0.893300, 0.132625, 0.775901,
303     0.018199, 0.737913, 0.992806, 0.301903, 0.968111, 0.744076, 0.687867,
304     0.157728, 0.151401, 0.039017, 0.752593, 0.127976, 0.478408, 0.483284,
305     0.171368, 0.845441, 0.755811, 0.642153, 0.469702, 0.694859, 0.760572,
306     0.544445, 0.322413, 0.572260, 0.380229, 0.265761, 0.212521, 0.100183,
307     0.159062, 0.345146, 0.876084, 0.177261, 0.083058, 0.868891, 0.479164,
308     0.051169, 0.612966, 0.167030, 0.208897, 0.764367, 0.206048, 0.961490,
309     0.892343, 0.684456, 0.444774, 0.063711, 0.529896, 0.200585, 0.705863,
310     0.999598, 0.895444, 0.466435, 0.544043, 0.217857, 0.038696, 0.924272,
311     0.483618, 0.251217, 0.024455, 0.642680, 0.596362, 0.900539, 0.819941,
312     0.679420, 0.769430, 0.299105, 0.730590, 0.382396, 0.466135, 0.939487,
313     0.146763, 0.672183, 0.900977, 0.039106, 0.356638, 0.345750, 0.102817,
314     0.886535, 0.546336, 0.808681, 0.886133, 0.441780, 0.275116, 0.430176,
315     0.659637, 0.313812, 0.354448, 0.143255, 0.565028, 0.378903, 0.785935,
316     0.161391, 0.279443, 0.605876, 0.840811, 0.048873, 0.904980, 0.571401,
317     0.431269, 0.371115, 0.510887, 0.578032, 0.043298, 0.411864, 0.617138,
318     0.399936, 0.757614, 0.719955, 0.286471, 0.303950, 0.528636, 0.172604,
319     0.745730, 0.803752, 0.602780, 0.405367, 0.117564, 0.957228, 0.548622,
320     0.682592, 0.336131, 0.334557, 0.843983, 0.615574, 0.940433, 0.684794,
321     0.664447, 0.845413, 0.256194, 0.095715, 0.216529, 0.767082, 0.673747,
322     0.259827, 0.178946, 0.290885, 0.659763, 0.936560, 0.010840, 0.946234,
323     0.240510, 0.539476, 0.118838, 0.986240, 0.343228, 0.721618, 0.391606,
324     0.460792, 0.678846, 0.940228, 0.143384, 0.014977, 0.274785, 0.987367,
325     0.630551, 0.215218, 0.672161, 0.294998, 0.060631, 0.928355, 0.390713,
326     0.277160, 0.695436, 0.064460, 0.536987, 0.874382, 0.355345, 0.196751,
327     0.810942, 0.366185, 0.142985, 0.051452, 0.905661, 0.261823, 0.037691,
328     0.248889, 0.983441, 0.429297, 0.709681, 0.662286, 0.369525, 0.853066,
329     0.677263, 0.644310, 0.840433, 0.307814, 0.859528, 0.512593, 0.602812,
330     0.920160, 0.440948, 0.993525, 0.197320, 0.136384, 0.057984, 0.734307,
331     0.010766, 0.413329, 0.931058, 0.821707, 0.779514, 0.074043, 0.873159,
332     0.685175, 0.335865, 0.910850, 0.934065, 0.319306, 0.340147, 0.643746,
333     0.981592, 0.709673, 0.496812, 0.658856, 0.353983, 0.337245, 0.966670,
334     0.213511, 0.849838, 0.569482, 0.133671, 0.290786, 0.563007, 0.330991,
335     0.427170, 0.620991, 0.065299, 0.437936, 0.034320, 0.996356, 0.259643,
336     0.813834, 0.070399, 0.132802, 0.499009, 0.406265, 0.043652, 0.433074,
337     0.725570, 0.383800, 0.076820, 0.707163, 0.093473, 0.573632, 0.366018,
338     0.447456, 0.910877, 0.332688, 0.660967, 0.760714, 0.902170, 0.794638,
339     0.051500, 0.465177, 0.125630, 0.478670, 0.086168, 0.190928, 0.916605,
340     0.120488, 0.187285, 0.176248, 0.934322, 0.257684, 0.309050, 0.433331,
341     0.663949, 0.352703, 0.866405, 0.389519, 0.736502, 0.943226, 0.096682,
342     0.829975, 0.516858, 0.462700, 0.277430, 0.427734, 0.795388, 0.938398,
343     0.188449, 0.697558, 0.733036, 0.239948, 0.162735, 0.858666, 0.718618,
344     0.248903, 0.049594, 0.635223, 0.369391, 0.236879, 0.811472, 0.303713,
345     0.494563, 0.120522, 0.737044, 0.158511, 0.473225, 0.603450, 0.548030,
346     0.209727, 0.546675, 0.644712, 0.039702, 0.063533, 0.107412, 0.317132,
347     0.491267, 0.902800, 0.255530, 0.679716, 0.600359, 0.988566, 0.919664,
348     0.763094, 0.847232, 0.638283, 0.011997, 0.896825, 0.273506, 0.381388,
349     0.133704, 0.084978, 0.685101, 0.628267, 0.205500, 0.422145, 0.786778,
350     0.678725, 0.025595, 0.334808, 0.888452, 0.572271, 0.979520, 0.928154,
351     0.635804, 0.086932, 0.245286, 0.127071, 0.989732, 0.500816, 0.806787,
352     0.590091, 0.489382, 0.726451, 0.353185, 0.336614, 0.364734, 0.365182,
353     0.233439, 0.638240, 0.746570, 0.367143, 0.723218, 0.431671, 0.995410,
354     0.928718, 0.853816, 0.782188, 0.607442, 0.879411, 0.116995, 0.495894,
355     0.451682, 0.096515, 0.424048, 0.087485, 0.183447, 0.669334, 0.214556,
356     0.173179, 0.170151, 0.021343, 0.763269, 0.659533, 0.747794, 0.116454,
357     0.996147, 0.112528, 0.481635, 0.229586, 0.750768, 0.228205, 0.596730,
358     0.473985, 0.659876, 0.592139, 0.402703, 0.513692, 0.374327, 0.010145,
359     0.393103, 0.491322, 0.506039, 0.844785, 0.587837, 0.930088, 0.932270,
360     0.771284, 0.599422, 0.146826, 0.944463, 0.769573, 0.168169, 0.707732,
361     0.429106, 0.915964, 0.824186, 0.425253, 0.028492, 0.305821, 0.654839,
362     0.779259, 0.534026, 0.251569, 0.253245, 0.193901, 0.843708, 0.655947,
363     0.707593, 0.218035, 0.666093, 0.100696, 0.709357, 0.172132, 0.945481,
364     0.297195, 0.102220, 0.877751, 0.068479, 0.701642, 0.024577, 0.012941,
365     0.471215, 0.192747, 0.720673, 0.900321, 0.108710, 0.544859, 0.325574,
366     0.137202, 0.850679, 0.980413, 0.916462, 0.384705, 0.231982, 0.169706,
367     0.578607, 0.075690, 0.825654, 0.286200, 0.293725, 0.491746, 0.386896,
368     0.003083, 0.663878, 0.332377, 0.300278, 0.766098, 0.210128, 0.368756,
369     0.467740, 0.234705, 0.381697, 0.938955, 0.427451, 0.102370, 0.839275,
370     0.536162, 0.647229, 0.164849, 0.673364, 0.497908, 0.145262, 0.589825,
371     0.882613, 0.377244, 0.759532, 0.461220, 0.452934, 0.585185, 0.747420,
372     0.746660, 0.076932, 0.134316, 0.749743, 0.740810, 0.466692, 0.050020,
373     0.506908, 0.676820, 0.418776, 0.974648, 0.911525, 0.800474, 0.913602,
374     0.338976, 0.902844, 0.752878, 0.875138, 0.550072, 0.917727, 0.548502,
375     0.047981, 0.062989, 0.138327, 0.930594, 0.440233, 0.897859, 0.391814,
376     0.893168, 0.483044, 0.139234, 0.639828, 0.559975, 0.273549, 0.389570,
377     0.300785, 0.740242, 0.439590, 0.807693, 0.417062, 0.858367, 0.782341,
378     0.328586, 0.658840, 0.695943, 0.667562, 0.561684, 0.448821, 0.542700,
379     0.111756, 0.366548, 0.091202, 0.159737, 0.429537, 0.229529, 0.090331,
380     0.869770, 0.127388, 0.482145, 0.762938, 0.610432, 0.621379, 0.402765,
381     0.170407, 0.894928, 0.792336, 0.471192, 0.635170, 0.231926, 0.278886,
382     0.052232, 0.090293, 0.061226, 0.380818, 0.749133, 0.757170, 0.048380,
383     0.310817, 0.205990, 0.591080, 0.422573, 0.572538, 0.682282, 0.582310,
384     0.002075, 0.911812, 0.672641, 0.871845, 0.039199, 0.154786, 0.634783,
385     0.649631, 0.776165, 0.037548, 0.820038, 0.671093, 0.829884, 0.291231,
386     0.306263, 0.061810, 0.570116, 0.358495, 0.152103, 0.631343, 0.739313,
387     0.901236, 0.388512, 0.787693, 0.212053, 0.594503, 0.378773, 0.634626,
388     0.167040, 0.061056, 0.216937, 0.169115, 0.972867, 0.889578, 0.040960,
389     0.012067, 0.044364, 0.675743, 0.661698, 0.820529, 0.713291, 0.481736,
390     0.491623, 0.543175, 0.772966, 0.797886, 0.604985, 0.343083, 0.156380,
391     0.757088, 0.974425, 0.895693, 0.658324, 0.362938, 0.683386, 0.870376,
392     0.957440, 0.062159, 0.505002, 0.124481, 0.123215, 0.721939, 0.293596,
393     0.096082, 0.611517, 0.334556, 0.108149, 0.655881, 0.010299, 0.769846,
394     0.476411, 0.723590, 0.251582, 0.968033, 0.266765, 0.024548, 0.765919,
395     0.871750, 0.367631, 0.922299, 0.628838, 0.342056, 0.817992, 0.287162,
396     0.704994, 0.501378, 0.157538, 0.662434, 0.563537, 0.662541, 0.786915,
397     0.686752, 0.384480, 0.080511, 0.782834, 0.995997, 0.415067, 0.890983,
398     0.651878, 0.425365, 0.660829, 0.128289, 0.148956, 0.912411, 0.096322,
399     0.415721, 0.936959, 0.862241, 0.287471, 0.304590, 0.784540, 0.916309,
400     0.646646, 0.602533, 0.203471, 0.351640, 0.103911, 0.361009, 0.014074,
401     0.667448, 0.023550, 0.800989, 0.354200, 0.408030, 0.881500, 0.137034,
402     0.404026, 0.296566, 0.028017, 0.055904, 0.721932, 0.688846, 0.184193,
403     0.870887, 0.601257, 0.280515, 0.286608, 0.538216, 0.142755, 0.574079,
404     0.842806, 0.927296, 0.490388, 0.489452, 0.529828, 0.693859, 0.841092,
405     0.633739, 0.054869, 0.855167, 0.301187, 0.078419, 0.656156, 0.655388,
406     0.486448, 0.537656, 0.792422, 0.890475, 0.834222, 0.820439, 0.946379,
407     0.556153, 0.509285, 0.130571, 0.427041, 0.110542, 0.411086, 0.713648,
408     0.648758, 0.553842, 0.287727, 0.491563, 0.481137, 0.778116, 0.981015,
409     0.010966, 0.471975, 0.822107, 0.644705, 0.526844, 0.677274, 0.945892,
410     0.605263, 0.333430, 0.601280, 0.091711, 0.871086, 0.393702, 0.982186,
411     0.705307, 0.214141, 0.928564, 0.261461, 0.723426, 0.059136, 0.688501,
412     0.833968, 0.470222, 0.402150, 0.482725, 0.024063, 0.689877, 0.974289,
413     0.505201, 0.467993, 0.955304, 0.516166, 0.939968, 0.777411, 0.160871,
414     0.466812, 0.454685, 0.106763, 0.072075, 0.788115, 0.708043, 0.163786,
415     0.659201, 0.101744, 0.145971, 0.364508, 0.315885, 0.074536, 0.625969,
416     0.039311, 0.133672, 0.314471, 0.873279, 0.603893, 0.716620, 0.356004,
417     0.627957, 0.406498, 0.330292, 0.133157, 0.874490, 0.285596, 0.649324,
418     0.814458, 0.063007, 0.810195, 0.281270, 0.517693, 0.916958, 0.353345,
419     0.305808, 0.625000, 0.517131, 0.965009, 0.726745, 0.663102, 0.329518,
420     0.042630, 0.737638, 0.955487, 0.081940, 0.871310, 0.269957, 0.955219,
421     0.475203, 0.986578, 0.311223, 0.103160, 0.393075, 0.641515, 0.236317,
422     0.267566, 0.927112, 0.885641, 0.082024, 0.990119, 0.695835, 0.363295,
423     0.507812, 0.612793, 0.716640, 0.813620, 0.237793, 0.233770, 0.778629,
424     0.964538, 0.896872, 0.108147, 0.007167, 0.634510, 0.063633, 0.089108,
425     0.505820, 0.333591, 0.044327, 0.981023, 0.320168, 0.355550, 0.084182,
426     0.713244, 0.997065, 0.320499, 0.980810, 0.924177, 0.206140, 0.062834,
427     0.914296, 0.901975, 0.426129, 0.422107, 0.514768, 0.142768, 0.235727,
428     0.752561, 0.376539, 0.014356, 0.717099, 0.273411, 0.122502, 0.724266,
429     0.907921, 0.186136, 0.813374, 0.413741, 0.519726, 0.857701, 0.394764,
430     0.839895, 0.213251, 0.478946, 0.553139, 0.210317, 0.799446, 0.533948,
431     0.134493, 0.005586, 0.596782, 0.048789, 0.907561, 0.022911, 0.470896,
432     0.422329, 0.165679, 0.706623, 0.174890, 0.542218, 0.720979, 0.891989,
433     0.815629, 0.843481, 0.616255, 0.723551, 0.029617, 0.429630, 0.137292,
434     0.549343, 0.287331, 0.532056, 0.389238, 0.500583, 0.011002, 0.942377,
435     0.710899, 0.810448, 0.476326, 0.845392, 0.816033, 0.073108, 0.894181,
436     0.723594, 0.096019, 0.365077, 0.145923, 0.261699, 0.071700, 0.320813,
437     0.803917, 0.792679, 0.212802, 0.619546, 0.636160, 0.829057, 0.343096,
438     0.665777, 0.258687, 0.480388, 0.215121, 0.546018, 0.012444, 0.604359,
439     0.046601, 0.023446, 0.546736, 0.757500, 0.833893, 0.023062, 0.602892,
440     0.649927, 0.096170, 0.497074, 0.373521, 0.192189, 0.862151, 0.519444,
441     0.453887, 0.933851, 0.840257, 0.257804, 0.726531, 0.053058, 0.877350,
442     0.362691, 0.882115, 0.220446, 0.028468, 0.140802, 0.700834, 0.243589,
443     0.686821, 0.713278, 0.847948, 0.733421, 0.736723, 0.394684, 0.490921,
444     0.570617, 0.417746, 0.093813, 0.220543, 0.513916, 0.590887, 0.594064,
445     0.706105, 0.453038, 0.113508, 0.159992, 0.386889, 0.953765, 0.417796,
446     0.113420, 0.006823, 0.295146, 0.476111, 0.888938, 0.515592, 0.504579,
447     0.029741, 0.216426, 0.748168, 0.716561, 0.929703, 0.596117, 0.449982,
448     0.666427, 0.990801, 0.940903, 0.237043, 0.408547, 0.034717, 0.457587,
449     0.922463, 0.625603, 0.051651, 0.628568, 0.078641, 0.165159, 0.788560,
450     0.465530, 0.118923, 0.206356, 0.578950, 0.125746, 0.501502, 0.055060,
451     0.014685, 0.017094, 0.559640, 0.044425, 0.233519, 0.307808, 0.760986,
452     0.163223, 0.903925, 0.210969, 0.829650, 0.894726, 0.151872, 0.066693,
453     0.303273, 0.186589, 0.524279, 0.225736, 0.812192, 0.575930, 0.854304,
454     0.890833, 0.741089, 0.642864, 0.356363, 0.860012, 0.849220, 0.935313,
455     0.985758, 0.350722, 0.990373, 0.000443, 0.367815, 0.550013, 0.044868,
456     0.601335, 0.857820, 0.805855, 0.764557, 0.761745, 0.016823, 0.594207,
457     0.656471, 0.168696, 0.660900, 0.959744, 0.355284, 0.185179, 0.185480,
458     0.167477, 0.761110, 0.039784, 0.058310, 0.502199, 0.682648, 0.414673,
459     0.362211, 0.531868, 0.349985, 0.347969, 0.882589, 0.340358, 0.348412,
460     0.250404, 0.890371, 0.393280, 0.851739, 0.748191, 0.199135, 0.616297,
461     0.509936, 0.215958, 0.210504, 0.166407, 0.384654, 0.871404, 0.126151,
462     0.739938, 0.056583, 0.311631, 0.907415, 0.817693, 0.351415, 0.965724,
463     0.319891, 0.034062, 0.380397, 0.682102, 0.565930, 0.730382, 0.030072,
464     0.448519, 0.070741, 0.378484, 0.698924, 0.961112, 0.771764, 0.550663,
465     0.709303, 0.970899, 0.166959, 0.219239, 0.186857, 0.377463, 0.385647,
466     0.571511, 0.248867, 0.511798, 0.311449, 0.305450, 0.823429, 0.218864,
467     0.123142, 0.174844, 0.184588, 0.443034, 0.208906, 0.564986, 0.125136,
468     0.774836, 0.295368, 0.155207, 0.223355, 0.366109, 0.533691, 0.922279,
469     0.327221, 0.305455, 0.472942, 0.036524, 0.276354, 0.639901, 0.255763,
470     0.463211, 0.017364, 0.641410, 0.034722, 0.266231, 0.153207, 0.346171,
471     0.571680, 0.976636, 0.565036, 0.694822, 0.151480, 0.749624, 0.137856,
472     0.360386, 0.314610, 0.262992, 0.135222, 0.609978, 0.418200, 0.358578,
473     0.976087, 0.951891, 0.280856, 0.303307, 0.257346, 0.753798, 0.339831,
474     0.533700, 0.393699, 0.595594, 0.996911, 0.411063, 0.237003, 0.031634,
475     0.677294, 0.390211, 0.377805, 0.248974, 0.366847, 0.942841, 0.943796,
476     0.518327, 0.692465, 0.081653, 0.878713, 0.007074, 0.344645, 0.013936,
477     0.617052, 0.762845, 0.372513, 0.593138, 0.714736, 0.653370, 0.896446,
478     0.972082, 0.407168, 0.236276, 0.505782, 0.800867, 0.831870, 0.502693,
479     0.211930, 0.068873, 0.534327, 0.889224, 0.459084, 0.912132, 0.138197,
480     0.825931, 0.854972, 0.081994, 0.344259, 0.547437, 0.163646, 0.222972,
481     0.554511, 0.508291, 0.236908, 0.171563, 0.271135, 0.609421, 0.764701,
482     0.985871, 0.262790, 0.661147, 0.957953, 0.669958, 0.897423, 0.463734,
483     0.470825, 0.729293, 0.966427, 0.682755, 0.798166, 0.500754, 0.571978,
484     0.257251, 0.412886, 0.710176, 0.083182, 0.267858, 0.792169, 0.427441,
485     0.815295, 0.955815, 0.650413, 0.369805, 0.464106, 0.887320, 0.541368,
486     0.735242, 0.496741, 0.306069, 0.721113, 0.759531, 0.967216, 0.679065,
487     0.429489, 0.864639, 0.142799, 0.900314, 0.593932, 0.109227, 0.583069,
488     0.392098, 0.609981, 0.155047, 0.649349, 0.022867, 0.865222, 0.732531,
489     0.290725, 0.657392, 0.159972, 0.106019, 0.613207, 0.810384, 0.475824,
490     0.077313, 0.697704, 0.017192, 0.812555};
491 
// Expected ("golden") output values for an end-to-end run of the op, stored
// as one flat float array. NOTE(review): presumably consumed by an end-to-end
// test later in this file together with the end-to-end input array defined
// just above — confirm the batch/step/unit layout against that test.
492 static float golden_endtoend_output[] = {
493     -1.881211, -0.028385, -3.585066, 1.939770,  -3.461155, 1.280415,  -4.408978,
494     0.608663,  -2.704937, 1.859742,  -5.777429, 2.691839,  -1.049012, 1.640870,
495     -4.856245, 1.604236,  0.992707,  0.422858,  -4.307465, 1.887332,  -0.884831,
496     -0.154277, -2.634801, 0.586827,  -1.849960, 1.399608,  -4.531559, 1.943591,
497     0.271676,  -2.893054, -2.066826, 0.235467,  -1.248263, -1.164534, -2.640174,
498     -0.112878, -4.386484, 1.253024,  -4.135623, 1.068984,  -0.043579, -0.832957,
499     -3.257258, -0.514396, -1.651174, 0.638630,  -4.364372, 1.548441,  -0.289455,
500     0.539845,  -4.097627, 0.635001,  -0.465071, -0.927701, -2.481498, 0.356616,
501     -2.355012, 0.728806,  -3.340283, 1.609038,  -4.786268, -0.532272, -1.886150,
502     0.254797,  0.746620,  -1.657134, -3.264265, 0.525551,  -1.756837, 0.845446,
503     -5.572190, 1.715797,  -2.856942, 3.394245,  -5.803662, 2.281806,  -3.014739,
504     2.616136,  -4.728482, 1.659984,  -2.106307, 2.711709,  -6.173832, 1.352869,
505     -0.038035, 0.107619,  -4.279774, 2.341930,  -0.980413, -0.119538, -4.049717,
506     1.172128,  -3.477744, 2.602274,  -6.231380, 2.537300,  -0.862214, 0.568722,
507     -3.858362, 0.197867,  -1.725885, 3.687312,  -7.067363, 2.403544,  -0.944963,
508     0.235639,  -3.250094, 0.659117,  -1.459576, 0.426128,  -3.637207, 1.030386,
509     -4.224351, 3.516220,  -6.053367, 0.993473,  -2.182416, -0.762625, -1.884405,
510     -0.113736, -2.572602, 0.329290,  -1.913233, 0.517418,  -0.019757, 0.203176,
511     -3.715881, 0.482136,  -1.912823, 1.357907,  -5.473043, 1.714658,  -3.177160,
512     0.089285,  -3.127669, 1.268076,  0.772498,  -1.622712, -3.850314, 0.436124,
513     -1.495983, 3.439982,  -7.623405, 1.726721,  -0.423979, 0.180201,  -2.902406,
514     0.986457,  -1.845638, 0.460903,  -5.359343, -1.133931, -1.074456, 0.717304,
515     -3.519856, 1.012126,  -0.562301, 1.881967,  -6.716627, 2.525036,  0.945480,
516     0.337081,  -5.210562, 2.572035,  -0.943370, 0.442026,  -2.666313, 0.411296,
517     0.002787,  -0.000735, -2.498933, 0.771719,  -3.568153, 3.833721,  -6.617026,
518     2.813922,  -0.573970, 1.025208,  -3.909923, 1.722648,  -1.406849, 0.719783,
519     -5.207438, 1.819442,  -0.530895, -0.010887, -2.939614, 0.971225,  -1.660297,
520     1.345243,  -4.454571, 2.244876,  -2.021213, 1.756090,  -4.880947, 0.364597,
521     -2.380270, 2.763117,  -5.613013, 2.137534,  0.289101,  -2.279400, -3.365582,
522     0.170028,  -1.142254, -0.709604, -3.656223, 1.804870,  -0.854690, 0.592102,
523     -5.010415, 2.462687,  -1.474710, 0.566002,  -3.621819, -0.391946, -0.423524,
524     -0.631428, -3.513310, 0.962825,  -1.480262, 0.319791,  -3.610137, 1.842339,
525     -0.250073, 1.182022,  -6.249267, 1.604172,  1.153759,  -0.734054, -4.620415,
526     -0.030858, 0.050911,  1.524406,  -4.724010, 1.451846,  -3.277104, 2.414182,
527     -4.605285, 1.846092,  -1.503047, -0.618200, -2.746546, -0.459332, -0.980326,
528     -1.199977, -2.043865, -0.165793, -2.214698, 3.108281,  -7.127830, -0.123065,
529     1.244948,  -3.039923, -4.660061, -0.225957, -0.307210, -1.513205, -2.456005,
530     0.840048,  -0.741445, 2.328635,  -6.015267, 2.723240,  -1.381171, -0.728878,
531     -5.114925, -0.362034, -0.574923, 0.518080,  -3.892457, 1.798948,  0.435119,
532     -0.371696, -2.807571, 1.302864,  -2.063052, 1.036388,  -4.232038, 1.397059,
533     -1.615668, -1.511019, -3.095508, 1.290955,  -3.428723, 2.000287,  -4.196487,
534     1.566983,  0.196957,  0.224343,  -4.926359, -0.691975, -0.214941, 1.546821,
535     -5.384868, 2.290820,  -1.878865, 0.493692,  -4.129823, 2.112036,  0.516558,
536     -2.553077, -2.717338, 0.017146,  -2.016057, 1.628995,  -4.240602, 1.189533,
537     -5.460220, 1.254738,  -4.214903, 0.755659,  -2.893235, 2.937762,  -6.169453,
538     2.035456,  -5.613212, -0.122254, -1.973646, -0.060619, -2.119598, 1.413512,
539     -4.938738, 1.890244,  0.544169,  -2.062413, -3.329637, -0.062515, -1.855805,
540     -0.791297, -2.570353, 0.607615,  0.305812,  0.338930,  -4.150270, 2.274937,
541     0.042653,  0.133825,  -3.538155, 1.523639,  -3.173690, -1.496599, -2.414655,
542     0.464687,  -1.448998, -0.368907, -3.520129, 0.203382,  -2.443626, 1.266233,
543     -3.393848, 0.605911,  -0.015353, 1.402006,  -4.441003, 1.419281,  0.603587,
544     0.434146,  -4.966566, 2.171872,  -0.688264, -0.009981, -4.461103, 1.538354,
545     -5.029816, -0.264424, -1.713510, -0.315258, -1.891606, 0.252074,  -2.419428,
546     0.043970,  -1.291143, 2.048704,  -4.590105, 0.524734,  -1.889576, 0.134836,
547     -3.462745, 1.390663,  -0.112773, 0.402735,  -4.203784, 1.381043,  -1.201634,
548     -1.968277, -1.425637, -0.181725, -1.250742, -2.102041, -3.925464, -1.256797,
549     -3.701354, -1.754610, -1.917231, -1.455910, -1.838006, 2.041781,  -5.666212,
550     2.752957,  -2.659553, 2.553637,  -4.872212, 1.443437,  -2.081846, 3.311263,
551     -5.912457, 1.871049,  0.196148,  -0.307044, -4.024967, 2.149149,  0.361809,
552     0.620415,  -5.939984, 0.180672,  -1.209180, -0.269122, -3.240285, 1.460315,
553     -1.040803, 1.125700,  -6.060366, 0.887767,  -3.214111, 1.314368,  -3.026808,
554     1.023640,  -3.815175, 1.795642,  -4.355603, 1.064454,  -0.046472, 0.618463,
555     -5.941646, 2.861891,  -2.852155, -0.990457, -2.624445, 1.794494,  -1.176747,
556     -0.358159, -3.206776, 1.138721,  -2.819523, -1.825522, -1.450902, -0.187312,
557     -0.808727, 0.636872,  -4.120567, 1.192623,  0.810731,  -1.768519, -3.699450,
558     1.527116,  -2.772720, 3.012835,  -5.912736, 1.599365,  -4.696381, 2.234591,
559     -4.139552, 1.061768,  -1.880089, 3.596274,  -7.006379, 2.382152,  -3.158115,
560     3.844430,  -7.044156, 2.307596,  -2.473970, 1.312644,  -5.467269, 0.197154,
561     -1.530040, 1.762275,  -5.550757, 0.630276,  -3.048947, 1.043777,  -3.096658,
562     1.345893,  -1.329494, 2.065748,  -4.711032, 2.227600,  -0.413321, -0.032428,
563     -4.599650, 1.668734,  -4.351490, -0.200022, -2.359903, 0.021997,  0.116028,
564     1.159718,  -5.093972, -0.142951, -2.409895, 0.906133,  -2.728812, 0.809932,
565     -2.597363, 0.494130,  -2.357861, 0.369825,  -2.165235, 1.148522,  -3.130562,
566     0.759034,  0.646335,  -1.463660, -3.508299, 1.059679,  -1.485465, 1.007319,
567     -4.340716, 1.789864,  -1.590654, 1.612324,  -4.452007, 2.389805,  -5.200148,
568     -1.068398, -1.306923, -0.472408, -0.392165, -0.524996, -2.933478, 1.518430,
569     -1.287781, 0.113422,  -3.020525, 1.338359,  -0.105982, 0.936014,  -4.132197,
570     1.836807,  -0.616589, -1.029716, -3.271347, 0.284889,  -2.653359, 2.135829,
571     -4.643613, 1.627981,  0.287733,  -2.017263, -2.776574, 1.184792,  1.004161,
572     -1.483019, -4.339290, -0.787322, 0.582420,  1.137839,  -5.673941, -0.001862,
573     -1.219142, 0.532561,  -4.457245, 1.826807,  -3.343291, 3.034610,  -6.179855,
574     2.235917,  -4.369989, 4.018128,  -6.632714, 0.926585,  -0.485469, 0.536073,
575     -4.179557, 1.489637,  -0.521762, 1.636089,  -6.137912, 1.500867,  -4.086009,
576     1.961372,  -3.688977, 1.358220,  -1.544034, 1.763837,  -4.357567, 1.852201,
577     -2.018725, 1.046264,  -6.211127, 1.609419,  -0.118441, 1.602284,  -6.242423,
578     1.518578,  -0.604078, 1.106613,  -5.393445, 2.595629,  0.142712,  -1.903953,
579     -2.821177, 0.032758,  -0.009152, 0.184628,  -4.227636, 2.046843,  -2.240138,
580     1.256176,  -5.108516, -0.308447, -2.998571, 4.657396,  -7.582112, 2.510951,
581     -3.535784, 1.704560,  -5.068484, 1.318466,  -3.058265, 3.073172,  -6.998089,
582     3.178849,  -2.420286, 2.277806,  -4.999528, 1.423890,  -1.672914, 0.447460,
583     -4.088940, 1.351087,  -1.051546, -0.417955, -4.042147, 1.604102,  -1.700931,
584     2.796663,  -6.497579, 2.857974,  -0.240828, 0.858001,  -5.778933, 2.778508,
585     -0.406211, 1.300766,  -5.073671, 2.089362,  -0.201673, 1.588396,  -6.000150,
586     2.185055,  -2.332125, 0.768216,  -2.609184, 0.327277,  -3.358943, -1.020736,
587     -2.389984, 0.315512,  -0.561905, 1.948740,  -6.408485, 2.231985,  -0.603652,
588     0.661829,  -5.070386, -1.063058, -0.624796, 1.375772,  -4.379606, 1.929358,
589     -1.047263, 0.739100,  -5.217857, 2.127625,  -5.025338, 0.650344,  -2.068460,
590     0.076936,  -0.457505, -1.050984, -1.917765, 1.150908,  0.782625,  0.855595,
591     -5.321719, 0.787209,  -0.460232, 1.106736,  -5.552326, 2.801043,  -0.360217,
592     -0.434432, -4.273378, 0.967556,  -0.972652, 0.874811,  -5.429918, -0.331039,
593     0.115477,  0.111883,  -5.418786, 1.240546,  -1.842794, 0.505880,  -3.676064,
594     -0.682369, 1.858984,  -0.742566, -5.784060, 0.673239,  -1.280398, 0.280842,
595     -4.848077, 2.214860,  -0.785100, -0.588488, -2.438206, 0.786651,  -1.568752,
596     1.935400,  -6.320256, 2.125338,  -1.476457, -1.651941, -2.695734, 0.007338,
597     -3.280860, 2.310385,  -5.319578, 1.890123,  -0.775723, 0.630606,  -4.321582,
598     1.085521,  -1.847371, 1.188521,  -4.596577, 2.056443,  -2.340172, -0.108501,
599     -3.156392, 0.933279,  -0.495331, 0.122405,  -5.171133, 1.763245,  -0.796913,
600     2.310487,  -7.247197, 2.401678,  -1.908860, 0.043798,  -2.393796, 0.573806,
601     -0.608531, 0.154710,  -4.669001, 0.750680,  0.468380,  0.392591,  -4.755001,
602     2.615217,  -1.957774, 1.153513,  -4.530099, 1.124362,  -3.569415, 1.697154,
603     -3.536335, 0.910758,  -2.976264, 1.833129,  -4.287203, -0.547050, -2.409768,
604     0.061585,  -1.324116, 0.268497,  -2.962222, -1.524245, -2.063413, 0.442058,
605     -4.292337, 3.538863,  -6.699603, 1.718664,  -2.290363, 1.994596,  -6.245037,
606     -0.433084, -0.367059, 1.020297,  -4.940721, 2.902264,  -0.577056, -0.709887,
607     -5.001413, -0.268316, -1.112048, -1.083307, -1.753492, 0.209973,  0.139540,
608     0.917602,  -5.232745, 2.538467,  -2.139234, -0.187388, -1.837249, -0.478582,
609     -0.731653, -0.481550, -2.531261, 1.044770,  0.707750,  0.279971,  -3.221119,
610     1.552074,  -2.373144, 0.859518,  -3.665156, 1.620278,  -1.440871, -0.525581,
611     -2.758271, 1.491873,  -2.302013, 1.119935,  -5.257080, 2.627170,  -3.174739,
612     1.363282,  -4.831639, 1.101076,  -4.337008, 2.689639,  -5.165915, 1.069201,
613     -1.882078, -0.120370, -2.287967, 1.147619,  -1.403616, 1.077150,  -5.084296,
614     1.658236,  -0.919642, 0.487423,  -3.001075, 0.741268,  0.107300,  0.943556,
615     -3.544311, 1.000239,  -1.627171, 2.871253,  -5.179172, 1.429893,  -0.826040,
616     0.188670,  -4.499894, 1.013447,  -2.101299, 0.317516,  -3.452141, -0.833776,
617     -1.362144, 1.272437,  -4.449355, 1.613591,  -2.039873, 2.613175,  -6.229640,
618     1.659790,  -1.595520, -0.237462, -2.744997, 0.337841,  0.148981,  -1.703771,
619     -2.388023, 1.276469,  1.058508,  -0.401642, -4.680769, 0.861881,  -1.336381,
620     1.153080,  -2.834378, 0.721075,  0.900115,  1.360511,  -5.573611, 0.949182,
621     -2.970844, 2.017563,  -5.186108, -0.201038, -1.192824, 0.610142,  -4.450919,
622     -0.897114, -1.812093, 0.422310,  -5.245487, 0.256549,  0.320275,  -2.324150,
623     -2.967040, -0.260536, -0.721467, 0.454148,  -5.058031, 0.526370,  -0.895656,
624     0.732240,  -3.327363, 1.353953,  -1.277912, -0.483171, -1.926713, 0.065044,
625     -2.167506, -0.196606, -1.923437, 0.604962,  -2.088319, 1.406834,  -5.227296,
626     2.247351,  -4.421744, 1.729791,  -5.007922, 1.264769,  -0.897019, 0.922902,
627     -3.887108, 2.087432,  -1.310226, -0.101938, -3.359082, -0.079662, -0.514988,
628     -0.963179, -4.038209, 2.223278,  -0.590083, -2.310458, -1.748338, 0.363406,
629     -0.540731, -0.885913, -4.179595, 2.216781,  -3.044339, -0.447100, -2.446098,
630     0.931101,  -1.676190, 2.096175,  -4.980755, 2.262151,  -1.095047, 1.897516,
631     -5.996138, 2.191038,  0.297128,  -0.780974, -2.884299, 1.195408,  -0.521065,
632     -1.955837, -3.091064, -0.404183, -1.961519, 4.076096,  -7.521851, 2.242064,
633     -1.988043, 0.303300,  -2.422585, 0.322230,  -3.377634, 3.499955,  -7.084434,
634     2.375587,  -0.718851, 2.150076,  -5.412241, 2.374280,  -2.006088, 2.229828,
635     -5.848188, 2.543077,  -2.171042, 2.096026,  -5.300007, 0.141405,  -1.187745,
636     0.105340,  -4.003816, 1.034281,  -3.980804, 1.856709,  -5.103042, 0.623737,
637     -2.080307, 0.896140,  -3.104050, 0.983158,  -0.424898, -1.154270, -3.805728,
638     1.978917,  -1.314387, 1.235096,  -3.148906, 1.113173,  0.111713,  2.055213,
639     -7.565283, 2.100342};
// Bias vector used for both the forward and the backward cell in the
// ClosedBox tests: one value per unit (the models there use 16 units).
const std::initializer_list<float> biases = {
    0.065691948, -0.69055247, 0.1107955,  -0.97084129,
    -0.23957068, -0.23566568, -0.389184,  0.47481549,
    -0.4791103,  0.29931796,  0.10463274, 0.83918178,
    0.37197268,  0.61957061,  0.3956964,  -0.37609905};
644 
// Recurrent weights used for both cells in the ClosedBox tests: a 16x16
// row-major matrix with 0.1 on the diagonal and 0 elsewhere (0.1 * identity).
// Laid out one matrix row per line, grouped in fours, to make that visible.
// clang-format off
const std::initializer_list<float> recurrent_weights = {
    0.1, 0, 0, 0,  0, 0, 0, 0,  0, 0, 0, 0,  0, 0, 0, 0,
    0, 0.1, 0, 0,  0, 0, 0, 0,  0, 0, 0, 0,  0, 0, 0, 0,
    0, 0, 0.1, 0,  0, 0, 0, 0,  0, 0, 0, 0,  0, 0, 0, 0,
    0, 0, 0, 0.1,  0, 0, 0, 0,  0, 0, 0, 0,  0, 0, 0, 0,
    0, 0, 0, 0,  0.1, 0, 0, 0,  0, 0, 0, 0,  0, 0, 0, 0,
    0, 0, 0, 0,  0, 0.1, 0, 0,  0, 0, 0, 0,  0, 0, 0, 0,
    0, 0, 0, 0,  0, 0, 0.1, 0,  0, 0, 0, 0,  0, 0, 0, 0,
    0, 0, 0, 0,  0, 0, 0, 0.1,  0, 0, 0, 0,  0, 0, 0, 0,
    0, 0, 0, 0,  0, 0, 0, 0,  0.1, 0, 0, 0,  0, 0, 0, 0,
    0, 0, 0, 0,  0, 0, 0, 0,  0, 0.1, 0, 0,  0, 0, 0, 0,
    0, 0, 0, 0,  0, 0, 0, 0,  0, 0, 0.1, 0,  0, 0, 0, 0,
    0, 0, 0, 0,  0, 0, 0, 0,  0, 0, 0, 0.1,  0, 0, 0, 0,
    0, 0, 0, 0,  0, 0, 0, 0,  0, 0, 0, 0,  0.1, 0, 0, 0,
    0, 0, 0, 0,  0, 0, 0, 0,  0, 0, 0, 0,  0, 0.1, 0, 0,
    0, 0, 0, 0,  0, 0, 0, 0,  0, 0, 0, 0,  0, 0, 0.1, 0,
    0, 0, 0, 0,  0, 0, 0, 0,  0, 0, 0, 0,  0, 0, 0, 0.1};
// clang-format on
662 
663 class BidirectionalRNNOpModel : public SingleOpModel {
664  public:
BidirectionalRNNOpModel(int batches,int sequence_len,int fw_units,int bw_units,int input_size,int aux_input_size,AuxInputMode aux_input_mode,bool time_major,bool merge_outputs,bool quantize_weights=false,bool asymmetric_quantize_weights=false)665   BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units,
666                           int bw_units, int input_size, int aux_input_size,
667                           AuxInputMode aux_input_mode, bool time_major,
668                           bool merge_outputs, bool quantize_weights = false,
669                           bool asymmetric_quantize_weights = false)
670       : batches_(batches),
671         sequence_len_(sequence_len),
672         fw_units_(fw_units),
673         bw_units_(bw_units),
674         input_size_(input_size),
675         aux_input_size_(aux_input_size),
676         quantize_weights_(quantize_weights) {
677     const TensorType tensor_type =
678         quantize_weights ? TensorType_UINT8 : TensorType_FLOAT32;
679     input_ = AddInput(TensorType_FLOAT32);
680     fw_weights_ = AddInput(tensor_type);
681     fw_recurrent_weights_ = AddInput(tensor_type);
682     fw_bias_ = AddInput(TensorType_FLOAT32);
683     fw_hidden_state_ = AddVariableInput(TensorType_FLOAT32);
684     bw_weights_ = AddInput(tensor_type);
685     bw_recurrent_weights_ = AddInput(tensor_type);
686     bw_bias_ = AddInput(TensorType_FLOAT32);
687     bw_hidden_state_ = AddVariableInput(TensorType_FLOAT32);
688 
689     const auto input_shape =
690         (time_major) ? std::vector<int>({sequence_len_, batches_, input_size_})
691                      : std::vector<int>({batches_, sequence_len_, input_size_});
692 
693     std::vector<int> aux_input_shape = {0};
694     std::vector<int> aux_fw_weights_shape = {0};
695     std::vector<int> aux_bw_weights_shape = {0};
696     if (aux_input_mode != AuxInputMode::kNoAuxInput) {
697       aux_input_ = AddInput(TensorType_FLOAT32);
698       aux_input_shape =
699           (time_major)
700               ? std::vector<int>({sequence_len_, batches_, aux_input_size_})
701               : std::vector<int>({batches_, sequence_len_, aux_input_size_});
702     } else {
703       aux_input_ = AddNullInput();
704     }
705 
706     if (aux_input_mode == AuxInputMode::kCrossLinking) {
707       aux_fw_weights_ = AddInput(tensor_type);
708       aux_bw_weights_ = AddInput(tensor_type);
709 
710       aux_fw_weights_shape = {fw_units, aux_input_size_};
711       aux_bw_weights_shape = {bw_units, aux_input_size_};
712     } else {
713       aux_fw_weights_ = AddNullInput();
714       aux_bw_weights_ = AddNullInput();
715     }
716 
717     fw_output_ = AddOutput(TensorType_FLOAT32);
718     if (!merge_outputs) {
719       bw_output_ = AddOutput(TensorType_FLOAT32);
720     }
721 
722     SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
723                  BuiltinOptions_BidirectionalSequenceRNNOptions,
724                  CreateBidirectionalSequenceRNNOptions(
725                      builder_, time_major, ActivationFunctionType_RELU,
726                      merge_outputs, asymmetric_quantize_weights)
727                      .Union());
728 
729     BuildInterpreter({
730         input_shape,               // input
731         {fw_units_, input_size_},  // fw_weights
732         {fw_units_, fw_units_},    // fw_recurrent_weights
733         {fw_units_},               // fw_bias
734         {batches_, fw_units_},     // fw_hidden_state
735         {bw_units_, input_size_},  // bw_weights
736         {bw_units_, bw_units_},    // bw_recurrent_weights
737         {bw_units_},               // bw_bias
738         {batches_, bw_units_},     // bw_hidden_state
739         aux_input_shape,           // aux_input
740         aux_fw_weights_shape,      // aux_fw_weights
741         aux_bw_weights_shape,      // aux_bw_weights
742     });
743   }
744 
SetFwBias(std::initializer_list<float> f)745   void SetFwBias(std::initializer_list<float> f) {
746     PopulateTensor(fw_bias_, f);
747   }
748 
SetBwBias(std::initializer_list<float> f)749   void SetBwBias(std::initializer_list<float> f) {
750     PopulateTensor(bw_bias_, f);
751   }
752 
SetFwWeights(const std::vector<float> & f)753   void SetFwWeights(const std::vector<float>& f) {
754     if (quantize_weights_) {
755       SymmetricQuantizeAndPopulate(fw_weights_, f);
756     } else {
757       PopulateTensor(fw_weights_, f);
758     }
759   }
760 
SetBwWeights(const std::vector<float> & f)761   void SetBwWeights(const std::vector<float>& f) {
762     if (quantize_weights_) {
763       SymmetricQuantizeAndPopulate(bw_weights_, f);
764     } else {
765       PopulateTensor(bw_weights_, f);
766     }
767   }
768 
SetFwRecurrentWeights(const std::vector<float> & f)769   void SetFwRecurrentWeights(const std::vector<float>& f) {
770     if (quantize_weights_) {
771       SymmetricQuantizeAndPopulate(fw_recurrent_weights_, f);
772     } else {
773       PopulateTensor(fw_recurrent_weights_, f);
774     }
775   }
776 
SetBwRecurrentWeights(const std::vector<float> & f)777   void SetBwRecurrentWeights(const std::vector<float>& f) {
778     if (quantize_weights_) {
779       SymmetricQuantizeAndPopulate(bw_recurrent_weights_, f);
780     } else {
781       PopulateTensor(bw_recurrent_weights_, f);
782     }
783   }
784 
SetInput(std::initializer_list<float> data)785   void SetInput(std::initializer_list<float> data) {
786     PopulateTensor(input_, data);
787   }
788 
SetInput(int offset,float * begin,float * end)789   void SetInput(int offset, float* begin, float* end) {
790     PopulateTensor(input_, offset, begin, end);
791   }
792 
SetAuxInput(int offset,float * begin,float * end)793   void SetAuxInput(int offset, float* begin, float* end) {
794     PopulateTensor(aux_input_, offset, begin, end);
795   }
796 
SetAuxFwWeights(const std::vector<float> & f)797   void SetAuxFwWeights(const std::vector<float>& f) {
798     if (quantize_weights_) {
799       SymmetricQuantizeAndPopulate(aux_fw_weights_, f);
800     } else {
801       PopulateTensor(aux_fw_weights_, f);
802     }
803   }
804 
SetAuxBwWeights(const std::vector<float> & f)805   void SetAuxBwWeights(const std::vector<float>& f) {
806     if (quantize_weights_) {
807       SymmetricQuantizeAndPopulate(aux_bw_weights_, f);
808     } else {
809       PopulateTensor(aux_bw_weights_, f);
810     }
811   }
812 
GetFwOutput()813   std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); }
GetBwOutput()814   std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); }
815 
input_size()816   int input_size() { return input_size_; }
aux_input_size()817   int aux_input_size() { return aux_input_size_; }
num_fw_units()818   int num_fw_units() { return fw_units_; }
num_bw_units()819   int num_bw_units() { return bw_units_; }
num_batches()820   int num_batches() { return batches_; }
sequence_len()821   int sequence_len() { return sequence_len_; }
822 
823  private:
824   int input_;
825   int fw_weights_;
826   int fw_recurrent_weights_;
827   int fw_bias_;
828   int fw_hidden_state_;
829   int fw_output_;
830   int bw_weights_;
831   int bw_recurrent_weights_;
832   int bw_bias_;
833   int bw_hidden_state_;
834   int bw_output_;
835   int aux_input_;
836   int aux_fw_weights_;
837   int aux_bw_weights_;
838 
839   int batches_;
840   int sequence_len_;
841   int fw_units_;
842   int bw_units_;
843   int input_size_;
844   int aux_input_size_;
845   bool quantize_weights_;
846 };
847 
848 // Declare LSTMOpTest as a parameterized test.
849 class BidirectionalRNNOpTest
850     : public ::testing::TestWithParam<::testing::tuple<bool, bool>> {};
851 
852 INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, BidirectionalRNNOpTest,
853                          ::testing::Combine(
854                              /*quantize_weights*/ ::testing::Bool(),
855                              /*asymmetric_quantize_inputs*/ ::testing::Bool()));
856 
857 // TODO(mirkov): add another test which directly compares to TF once TOCO
858 // supports the conversion from dynamic_rnn with BasicRNNCell.
TEST_P(BidirectionalRNNOpTest,ClosedBoxTest)859 TEST_P(BidirectionalRNNOpTest, ClosedBoxTest) {
860   auto params = GetParam();
861   const bool quantize_weights = std::get<0>(params);
862   const bool asymmetric_quantize_inputs = std::get<1>(params);
863   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
864                               /*fw_units=*/16, /*bw_units=*/16,
865                               /*input_size=*/8, /*aux_input_size=*/0,
866                               /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
867                               /*time_major=*/false,
868                               /*merge_outputs=*/false, quantize_weights,
869                               asymmetric_quantize_inputs);
870   rnn.SetFwWeights(weights);
871   rnn.SetBwWeights(weights);
872   rnn.SetFwBias(biases);
873   rnn.SetBwBias(biases);
874   rnn.SetFwRecurrentWeights(recurrent_weights);
875   rnn.SetBwRecurrentWeights(recurrent_weights);
876 
877   const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
878   float* batch_start = rnn_input;
879   float* batch_end = batch_start + input_sequence_size;
880   rnn.SetInput(0, batch_start, batch_end);
881   rnn.SetInput(input_sequence_size, batch_start, batch_end);
882 
883   rnn.Invoke();
884 
885   float* golden_fw_start = rnn_golden_fw_output;
886   float* golden_fw_end =
887       golden_fw_start + rnn.num_fw_units() * rnn.sequence_len();
888   std::vector<float> fw_expected;
889   fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
890   fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
891   EXPECT_THAT(rnn.GetFwOutput(),
892               ElementsAreArray(ArrayFloatNear(
893                   fw_expected, quantize_weights ? 1.42e-2 : 1e-5)));
894 
895   float* golden_bw_start = rnn_golden_bw_output;
896   float* golden_bw_end =
897       golden_bw_start + rnn.num_bw_units() * rnn.sequence_len();
898   std::vector<float> bw_expected;
899   bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
900   bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
901   EXPECT_THAT(rnn.GetBwOutput(),
902               ElementsAreArray(ArrayFloatNear(
903                   bw_expected, quantize_weights ? 1.42e-2 : 1e-5)));
904 }
905 
906 // Same as ClosedBox test, but input is reshuffled to time_major format.
TEST_P(BidirectionalRNNOpTest,ClosedBoxTestTimeMajor)907 TEST_P(BidirectionalRNNOpTest, ClosedBoxTestTimeMajor) {
908   auto params = GetParam();
909   const bool quantize_weights = std::get<0>(params);
910   const bool asymmetric_quantize_inputs = std::get<1>(params);
911   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
912                               /*fw_units=*/16, /*bw_units=*/16,
913                               /*input_size=*/8, /*aux_input_size=*/0,
914                               /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
915                               /*time_major=*/true,
916                               /*merge_outputs=*/false, quantize_weights,
917                               asymmetric_quantize_inputs);
918   rnn.SetFwWeights(weights);
919   rnn.SetBwWeights(weights);
920   rnn.SetFwBias(biases);
921   rnn.SetBwBias(biases);
922   rnn.SetFwRecurrentWeights(recurrent_weights);
923   rnn.SetBwRecurrentWeights(recurrent_weights);
924 
925   // Insert the inputs in time_major format. The batch_major format is:
926   // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
927   // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
928   for (int i = 0; i < rnn.sequence_len(); i++) {
929     float* batch_start = rnn_input + i * rnn.input_size();
930     float* batch_end = batch_start + rnn.input_size();
931     // The two batches are identical.
932     rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
933     rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
934   }
935 
936   rnn.Invoke();
937 
938   std::vector<float> fw_expected;
939   for (int i = 0; i < rnn.sequence_len(); i++) {
940     float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units();
941     float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
942     fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
943     fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
944   }
945   constexpr float kHybridTolerance = 3.57e-1;
946   constexpr float kFloatTolerance = 1e-5;
947   EXPECT_THAT(
948       rnn.GetFwOutput(),
949       ElementsAreArray(ArrayFloatNear(
950           fw_expected, quantize_weights ? kHybridTolerance : kFloatTolerance)));
951 }
952 
953 // Same as ClosedBox test, yet with merged outputs.
TEST_P(BidirectionalRNNOpTest,ClosedBoxTestMergeOutputs)954 TEST_P(BidirectionalRNNOpTest, ClosedBoxTestMergeOutputs) {
955   auto params = GetParam();
956   const bool quantize_weights = std::get<0>(params);
957   const bool asymmetric_quantize_inputs = std::get<1>(params);
958   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
959                               /*fw_units=*/16, /*bw_units=*/16,
960                               /*input_size=*/8, /*aux_input_size=*/0,
961                               /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
962                               /*time_major=*/false,
963                               /*merge_outputs=*/true, quantize_weights,
964                               asymmetric_quantize_inputs);
965   rnn.SetFwWeights(weights);
966   rnn.SetBwWeights(weights);
967   rnn.SetFwBias(biases);
968   rnn.SetBwBias(biases);
969   rnn.SetFwRecurrentWeights(recurrent_weights);
970   rnn.SetBwRecurrentWeights(recurrent_weights);
971 
972   const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
973   float* batch_start = rnn_input;
974   float* batch_end = batch_start + input_sequence_size;
975   rnn.SetInput(0, batch_start, batch_end);
976   rnn.SetInput(input_sequence_size, batch_start, batch_end);
977 
978   rnn.Invoke();
979 
980   std::vector<float> merged_expected;
981   for (int bid = 0; bid < rnn.num_batches(); bid++) {
982     for (int step = 0; step < rnn.sequence_len(); step++) {
983       merged_expected.insert(
984           merged_expected.end(),
985           rnn_golden_fw_output + rnn.num_fw_units() * step,
986           rnn_golden_fw_output + rnn.num_fw_units() * (step + 1));
987       merged_expected.insert(
988           merged_expected.end(),
989           rnn_golden_bw_output + rnn.num_bw_units() * step,
990           rnn_golden_bw_output + rnn.num_bw_units() * (step + 1));
991     }
992   }
993   EXPECT_THAT(rnn.GetFwOutput(),
994               ElementsAreArray(ArrayFloatNear(
995                   merged_expected, quantize_weights ? 1.42e-2 : 1e-5)));
996 }
997 
998 // Same as ClosedBox test, but input is reshuffled to time_major format.
TEST(BidirectionalRNNOpTest, ClosedBoxTestTimeMajorMergeOutputs) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/0,
                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                              /*time_major=*/true,
                              /*merge_outputs=*/true);
  // Forward and backward cells share the same parameters.
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  // Time-major layout interleaves the two batches per step:
  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15]. Both batches receive the
  // same sequence.
  const int step_size = rnn.input_size();
  for (int t = 0; t < rnn.sequence_len(); ++t) {
    float* const step_begin = rnn_input + t * step_size;
    float* const step_end = step_begin + step_size;
    const int batch0_offset = 2 * t * step_size;
    const int batch1_offset = batch0_offset + step_size;
    rnn.SetInput(batch0_offset, step_begin, step_end);
    rnn.SetInput(batch1_offset, step_begin, step_end);
  }

  rnn.Invoke();

  // With merged outputs, every (step, batch) row is the forward golden row
  // for that step followed by the backward golden row for that step.
  const int fw_width = rnn.num_fw_units();
  const int bw_width = rnn.num_bw_units();
  std::vector<float> merged_expected;
  for (int t = 0; t < rnn.sequence_len(); ++t) {
    const float* const fw_row = rnn_golden_fw_output + fw_width * t;
    const float* const bw_row = rnn_golden_bw_output + bw_width * t;
    for (int b = 0; b < rnn.num_batches(); ++b) {
      merged_expected.insert(merged_expected.end(), fw_row, fw_row + fw_width);
      merged_expected.insert(merged_expected.end(), bw_row, bw_row + bw_width);
    }
  }
  EXPECT_THAT(rnn.GetFwOutput(),
              ElementsAreArray(ArrayFloatNear(merged_expected)));
}
1042 
// Checks that reversing the input sequence yields the same outputs, except that
// the forward and backward outputs are swapped (and reversed in time).
TEST(BidirectionalRNNOpTest, ClosedBoxTestReverseInputs) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/0,
                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                              /*time_major=*/false,
                              /*merge_outputs=*/false);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  // Reverse inputs in each batch: in_1, in_2,..., in_k is inserted in the
  // following order: [in_k,..., in_2, in_1, in_k,...,in_2, in_1].
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    const int reverse_idx = rnn.sequence_len() - i - 1;
    rnn.SetInput(reverse_idx * rnn.input_size(), batch_start, batch_end);
    rnn.SetInput((rnn.sequence_len() + reverse_idx) * rnn.input_size(),
                 batch_start, batch_end);
  }

  rnn.Invoke();

  // Builds the expected output for one direction: the rows of `golden`
  // (`num_units` floats per step) visited from the last step to the first,
  // repeated once per batch because both batches receive identical input.
  // Rows are appended, not prepended, which avoids quadratic front-insertion.
  // The rows are also never copied from the vector into itself; inserting a
  // vector's own range into it is undefined behavior.
  auto reversed_golden = [&rnn](const float* golden, int num_units) {
    std::vector<float> expected;
    expected.reserve(num_units * rnn.sequence_len() * rnn.num_batches());
    for (int b = 0; b < rnn.num_batches(); ++b) {
      for (int step = rnn.sequence_len() - 1; step >= 0; --step) {
        const float* row = golden + step * num_units;
        expected.insert(expected.end(), row, row + num_units);
      }
    }
    return expected;
  };

  // The forward and backward outputs are swapped: the forward output should
  // match the reversed backward golden output, and vice versa.
  const std::vector<float> fw_expected =
      reversed_golden(rnn_golden_bw_output, rnn.num_fw_units());
  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));

  const std::vector<float> bw_expected =
      reversed_golden(rnn_golden_fw_output, rnn.num_bw_units());
  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}
1091 
1092 // Tests an end-to-end neural network with a Bidirectional RNN followed by a
1093 // DNN that aggregates the outputs from the two sequences.
TEST(BidirectionalRNNOpTest,EndToEndTest)1094 TEST(BidirectionalRNNOpTest, EndToEndTest) {
1095   BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4,
1096                               /*fw_units=*/16, /*bw_units=*/16,
1097                               /*input_size=*/8, /*aux_input_size=*/0,
1098                               /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
1099                               /*time_major=*/false,
1100                               /*merge_outputs=*/false);
1101   const int output_size = 4;
1102   float dnn_weights[] = {
1103       -0.5782342,  -0.052212059, 0.73036242,  -0.81216097, -0.80088139,
1104       -0.23420811, -0.39647382,  0.31423986,  0.61819065,  -0.73659575,
1105       -0.89698344, -0.8931554,   -0.0845688,  0.5617367,   0.38415289,
1106       -0.11487955, -0.7617774,   0.17927337,  0.15726972,  0.059798479,
1107       0.19009054,  -0.27616632,  -0.39142907, 0.77744663,  -0.046830714,
1108       -0.6603595,  0.21945822,   0.051494241, 0.23785079,  0.19239247,
1109       -0.53268754, 0.65961659,   -0.85981959, -0.80232513, 0.84745562,
1110       -0.66070104, -0.036533296, -0.54901814, 0.65353882,  -0.41834265,
1111       -0.28561389, 0.75655544,   -0.31149811, 0.62981737,  0.31829214,
1112       -0.92734522, -0.48506218,  0.55651462,  0.25192821,  0.67220747,
1113       -0.3836869,  -0.55798125,  -0.60395885, 0.22488403,  -0.78053463,
1114       0.3492105,   0.56452453,   0.4389236,   -0.59929526, -0.19762468,
1115       -0.36868393, -0.13198286,  -0.53800809, -0.22850353};
1116 
1117   std::initializer_list<float> dnn_biases = {0.29177809, -0.98799044,
1118                                              0.065919638, 0.68781924};
1119 
1120   rnn.SetFwWeights(weights);
1121   rnn.SetBwWeights(weights);
1122   rnn.SetFwBias(biases);
1123   rnn.SetBwBias(biases);
1124   rnn.SetFwRecurrentWeights(recurrent_weights);
1125   rnn.SetBwRecurrentWeights(recurrent_weights);
1126 
1127   const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
1128   const int output_sequence_size = output_size * rnn.sequence_len();
1129   const int num_examples = 64;
1130   for (int k = 0; k < num_examples; k++) {
1131     float* batch_start = endtoend_input + k * input_sequence_size;
1132     float* batch_end = batch_start + input_sequence_size;
1133     rnn.SetInput(0, batch_start, batch_end);
1134 
1135     rnn.Invoke();
1136 
1137     std::vector<float> fw_output = rnn.GetFwOutput();
1138     std::vector<float> bw_output = rnn.GetBwOutput();
1139     EXPECT_EQ(fw_output.size(), bw_output.size());
1140 
1141     std::transform(fw_output.begin(), fw_output.end(), bw_output.begin(),
1142                    fw_output.begin(), std::plus<float>());
1143 
1144     std::vector<float> sequence_result;
1145     for (int s = 0; s < rnn.sequence_len(); s++) {
1146       const float* rnn_output = fw_output.data() + s * rnn.num_fw_units();
1147       std::vector<float> results(dnn_biases);
1148       for (int i = 0; i < output_size; i++) {
1149         for (int j = 0; j < rnn.num_fw_units(); j++) {
1150           results[i] += *(rnn_output + j) * dnn_weights[output_size * j + i];
1151         }
1152       }
1153       sequence_result.insert(sequence_result.end(), results.begin(),
1154                              results.end());
1155     }
1156 
1157     float* golden_start = golden_endtoend_output + k * output_sequence_size;
1158     float* golden_end = golden_start + output_sequence_size;
1159 
1160     std::vector<float> expected;
1161     expected.insert(expected.end(), golden_start, golden_end);
1162     EXPECT_THAT(sequence_result, ElementsAreArray(ArrayFloatNear(expected)));
1163   }
1164 }
1165 
1166 // Same as ClosedBox test, but has an auxiliary input. The layer has no
1167 // cross-linking, i.e. the regular input is passed as an input to the forward
1168 // network only and the auxiliary input is passed as an input to the backward
1169 // network only.
TEST(BidirectionalRNNOpTest, ClosedBoxTestNoCrossLinkingRegularAndAuxInput) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/8,
                              /*aux_input_mode=*/AuxInputMode::kNoCrossLinking,
                              /*time_major=*/true,
                              /*merge_outputs=*/false);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  // Insert the inputs in time_major format. The batch_major format is:
  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // The two batches are identical.
    // Also make aux input the same as input.
    rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
    rnn.SetAuxInput(2 * i * rnn.input_size(), batch_start, batch_end);
    rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
    rnn.SetAuxInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
  }

  rnn.Invoke();

  // Time-major outputs: each golden row appears once per (identical) batch.
  std::vector<float> fw_expected;
  std::vector<float> bw_expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units();
    float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);

    // Backward rows are num_bw_units wide (not num_fw_units).
    float* golden_bw_start = rnn_golden_bw_output + i * rnn.num_bw_units();
    float* golden_bw_end = golden_bw_start + rnn.num_bw_units();
    bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
    bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  }
  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}
1216 
1217 // Same as above but the auxiliary input is set to zeroes. This test makes sure
1218 // that the forward network works as expected in a no-cross-linking mode.
TEST(BidirectionalRNNOpTest, ClosedBoxTestNoCrossLinkingRegularInputOnly) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/8,
                              /*aux_input_mode=*/AuxInputMode::kNoCrossLinking,
                              /*time_major=*/true,
                              /*merge_outputs=*/false);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  // Initialize bw (auxiliary) inputs with zeros.
  std::vector<float> bw_inputs(rnn.input_size(), 0);
  // One-past-the-end pointer, so that every element of each aux slot is
  // written (not just the first input_size - 1 of them).
  float* const bw_inputs_end = bw_inputs.data() + bw_inputs.size();

  // Insert the inputs in time_major format. The batch_major format is:
  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // The two batches are identical; the aux input is all zeros.
    rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
    rnn.SetAuxInput(2 * i * rnn.input_size(), bw_inputs.data(),
                    bw_inputs_end);
    rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
    rnn.SetAuxInput((2 * i + 1) * rnn.input_size(), bw_inputs.data(),
                    bw_inputs_end);
  }

  rnn.Invoke();

  // Time-major forward output: each golden row once per (identical) batch.
  std::vector<float> fw_expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units();
    float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  }
  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
}
1263 
1264 // Same as above but the regular (i.e. not auxiliary) input is set to zeroes.
1265 // This test makes sure that the backward network works as expected in a
1266 // no-cross-linking mode.
TEST(BidirectionalRNNOpTest, ClosedBoxTestNoCrossLinkingAuxInputOnly) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/8,
                              /*aux_input_mode=*/AuxInputMode::kNoCrossLinking,
                              /*time_major=*/true,
                              /*merge_outputs=*/false);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  // Initialize fw (regular) inputs with zeros.
  std::vector<float> fw_inputs(rnn.input_size(), 0);
  // One-past-the-end pointer, so that every element of each regular-input
  // slot is written (not just the first input_size - 1 of them).
  float* const fw_inputs_end = fw_inputs.data() + fw_inputs.size();

  // Insert the inputs in time_major format. The batch_major format is:
  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // The two batches are identical; the aux input carries the real data and
    // the regular input is all zeros.
    rnn.SetAuxInput(2 * i * rnn.input_size(), batch_start, batch_end);
    rnn.SetInput(2 * i * rnn.input_size(), fw_inputs.data(), fw_inputs_end);
    rnn.SetAuxInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
    rnn.SetInput((2 * i + 1) * rnn.input_size(), fw_inputs.data(),
                 fw_inputs_end);
  }

  rnn.Invoke();

  // Time-major backward output: each golden row once per (identical) batch.
  // Backward rows are num_bw_units wide.
  std::vector<float> bw_expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_bw_start = rnn_golden_bw_output + i * rnn.num_bw_units();
    float* golden_bw_end = golden_bw_start + rnn.num_bw_units();
    bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
    bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  }
  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}
1311 
1312 // Same as ClosedBox test, but an input is passed to auxiliary input instead of
1313 // the regular one. Regular input and weights are set to zero.
TEST(BidirectionalRNNOpTest, ClosedBoxTestCrossLinkingAuxInputOnly) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/8,
                              /*aux_input_mode=*/AuxInputMode::kCrossLinking,
                              /*time_major=*/false,
                              /*merge_outputs=*/false);
  // Regular-input weights are zeroed; the ClosedBox weights go on the
  // auxiliary path of both directions instead.
  const std::vector<float> zero_weights(weights.size(), 0.0);
  rnn.SetFwWeights(zero_weights);
  rnn.SetBwWeights(zero_weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);
  rnn.SetAuxFwWeights(weights);
  rnn.SetAuxBwWeights(weights);

  // Batch-major layout: each batch holds its whole sequence contiguously.
  // Every batch gets zeros on the regular input and rnn_input on the aux one.
  const int seq_elems = rnn.input_size() * rnn.sequence_len();
  std::vector<float> zeros(seq_elems, 0.f);
  float* const aux_begin = rnn_input;
  float* const aux_end = aux_begin + seq_elems;
  for (int b = 0; b < 2; ++b) {
    const int offset = b * seq_elems;
    rnn.SetInput(offset, zeros.data(), zeros.data() + zeros.size());
    rnn.SetAuxInput(offset, aux_begin, aux_end);
  }

  rnn.Invoke();

  // Each direction should reproduce its golden sequence once per batch.
  const int fw_len = rnn.num_fw_units() * rnn.sequence_len();
  std::vector<float> fw_expected(rnn_golden_fw_output,
                                 rnn_golden_fw_output + fw_len);
  fw_expected.insert(fw_expected.end(), rnn_golden_fw_output,
                     rnn_golden_fw_output + fw_len);
  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));

  const int bw_len = rnn.num_bw_units() * rnn.sequence_len();
  std::vector<float> bw_expected(rnn_golden_bw_output,
                                 rnn_golden_bw_output + bw_len);
  bw_expected.insert(bw_expected.end(), rnn_golden_bw_output,
                     rnn_golden_bw_output + bw_len);
  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}
1360 
1361 // Same as ClosedBox test, but an input is passed to auxiliary input instead of
1362 // the regular one. Regular input and weights are set to zero. Time major inputs
1363 // and outputs.
TEST(BidirectionalRNNOpTest, ClosedBoxTestCrossLinkingAuxInputOnlyTimeMajor) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/8,
                              /*aux_input_mode=*/AuxInputMode::kCrossLinking,
                              /*time_major=*/true,
                              /*merge_outputs=*/false);
  // Regular-input weights are zeroed; the ClosedBox weights go on the
  // auxiliary path of both directions instead.
  rnn.SetFwWeights(std::vector<float>(weights.size(), 0.0));
  rnn.SetBwWeights(std::vector<float>(weights.size(), 0.0));
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);
  rnn.SetAuxFwWeights(weights);
  rnn.SetAuxBwWeights(weights);

  std::vector<float> zero_input(rnn.input_size(), 0.f);
  float* const zero_end = zero_input.data() + zero_input.size();

  // Insert the inputs in time_major format. The batch_major format is:
  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // The two batches are identical.
    // Set batch 0 inputs
    rnn.SetInput(2 * i * rnn.input_size(), zero_input.data(), zero_end);
    rnn.SetAuxInput(2 * i * rnn.input_size(), batch_start, batch_end);
    // Set batch 1 inputs
    rnn.SetInput((2 * i + 1) * rnn.input_size(), zero_input.data(), zero_end);
    rnn.SetAuxInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
  }

  rnn.Invoke();

  // Time-major outputs: each golden row appears once per (identical) batch.
  // With cross-linking the aux input feeds both directions, so the backward
  // output is checked too (as in the batch-major variant of this test).
  std::vector<float> fw_expected;
  std::vector<float> bw_expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units();
    float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);

    float* golden_bw_start = rnn_golden_bw_output + i * rnn.num_bw_units();
    float* golden_bw_end = golden_bw_start + rnn.num_bw_units();
    bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
    bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  }
  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}
1410 
// Same as ClosedBox test, but the input tensor and weights tensor are split
// along the last dimension and passed to both regular and auxiliary inputs and
// weights. The output in this case is the same. To see why, let W and V be the
// regular and auxiliary input weight matrices, respectively. The layer is then
// equivalent to a regular RNN with weights U = (W|V) and input
// z = (x^T|y^T)^T, where .|. denotes concatenation along the horizontal axis:
//   f(z) = Uz + b
// is equivalent to:
//   f((x^T|y^T)^T) = (Wx + Vy) + b.
run_closedbox_test_with_input_split(int input_size,int aux_input_size)1421 void run_closedbox_test_with_input_split(int input_size, int aux_input_size) {
1422   const int num_units = 16;
1423   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
1424                               /*fw_units=*/num_units, /*bw_units=*/num_units,
1425                               input_size, aux_input_size,
1426                               /*aux_input_mode=*/AuxInputMode::kCrossLinking,
1427                               /*time_major=*/false,
1428                               /*merge_outputs=*/false);
1429   std::vector<float> reg_weights(num_units * rnn.input_size());
1430   std::vector<float> aux_weights(num_units * rnn.aux_input_size());
1431   int full_weights_size = weights.size();
1432   int reg_weights_offset = 0;
1433   int aux_weights_offset = 0;
1434   int weights_offset = 0;
1435   // Alternating copying to regular input weights and auxiliary input weights to
1436   // split the original weight matrix in half along the last axis.
1437   while (weights_offset < full_weights_size) {
1438     std::copy(weights.begin() + weights_offset,
1439               weights.begin() + weights_offset + rnn.input_size(),
1440               reg_weights.begin() + reg_weights_offset);
1441     weights_offset += rnn.input_size();
1442     reg_weights_offset += rnn.input_size();
1443 
1444     std::copy(weights.begin() + weights_offset,
1445               weights.begin() + weights_offset + rnn.aux_input_size(),
1446               aux_weights.begin() + aux_weights_offset);
1447     weights_offset += rnn.aux_input_size();
1448     aux_weights_offset += rnn.aux_input_size();
1449   }
1450 
1451   rnn.SetFwWeights(reg_weights);
1452   rnn.SetBwWeights(reg_weights);
1453   rnn.SetFwBias(biases);
1454   rnn.SetBwBias(biases);
1455   rnn.SetFwRecurrentWeights(recurrent_weights);
1456   rnn.SetBwRecurrentWeights(recurrent_weights);
1457   rnn.SetAuxFwWeights(aux_weights);
1458   rnn.SetAuxBwWeights(aux_weights);
1459 
1460   int full_input_size =
1461       (rnn.input_size() + rnn.aux_input_size()) * rnn.sequence_len();
1462   int reg_input_offset = 0;
1463   int aux_input_offset = 0;
1464   // Alternating copying to regular input tensor and auxiliary input tensor to
1465   // split the original input matrix in half along the last axis.
1466   for (int batch = 0; batch < 2; ++batch) {
1467     int input_offset = 0;
1468     while (input_offset < full_input_size) {
1469       rnn.SetInput(reg_input_offset, rnn_input + input_offset,
1470                    rnn_input + input_offset + rnn.input_size());
1471       input_offset += rnn.input_size();
1472       reg_input_offset += rnn.input_size();
1473 
1474       rnn.SetAuxInput(aux_input_offset, rnn_input + input_offset,
1475                       rnn_input + input_offset + rnn.aux_input_size());
1476       input_offset += rnn.aux_input_size();
1477       aux_input_offset += rnn.aux_input_size();
1478     }
1479   }
1480 
1481   rnn.Invoke();
1482 
1483   float* golden_fw_start = rnn_golden_fw_output;
1484   float* golden_fw_end =
1485       golden_fw_start + rnn.num_fw_units() * rnn.sequence_len();
1486   std::vector<float> fw_expected;
1487   fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
1488   fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
1489   EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
1490 
1491   float* golden_bw_start = rnn_golden_bw_output;
1492   float* golden_bw_end =
1493       golden_bw_start + rnn.num_bw_units() * rnn.sequence_len();
1494   std::vector<float> bw_expected;
1495   bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
1496   bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
1497   EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
1498 }
1499 
TEST(BidirectionalRNNOpTest,
     ClosedBoxTestCrossLinkingRegularAndAuxInputEvenSplit) {
  // 4 regular + 4 auxiliary input features.
  constexpr int kRegularInputSize = 4;
  constexpr int kAuxInputSize = 4;
  run_closedbox_test_with_input_split(kRegularInputSize, kAuxInputSize);
}
1504 
1505 // Same as above but the input tensor and the weights tensor are split unevenly.
TEST(BidirectionalRNNOpTest,
     ClosedBoxTestCrossLinkingRegularAndAuxInputUnevenSplit) {
  // 2 regular + 6 auxiliary input features.
  constexpr int kRegularInputSize = 2;
  constexpr int kAuxInputSize = 6;
  run_closedbox_test_with_input_split(kRegularInputSize, kAuxInputSize);
}
1510 
1511 }  // namespace
1512 }  // namespace tflite
1513