/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Unit test for TFLite Bidirectional RNN op.

#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <tuple>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
namespace {
// Describes how the auxiliary input (if any) is wired into the op; the
// meanings below follow how the model constructor handles each mode.
enum class AuxInputMode {
  // No auxiliary input is provided.
  kNoAuxInput,
  // An auxiliary input is provided together with auxiliary forward/backward
  // weights that cross-link it into both directions.
  kCrossLinking,
  // An auxiliary input is provided, but the auxiliary weight tensors are
  // left empty.
  kNoCrossLinking,
};

using ::testing::ElementsAreArray;

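// Test input for one batch: 16 time steps x 8 features (the tests below feed
// the same data to both batches).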
static float rnn_input[] = {
    0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133,
    0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471,
    -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222,
    0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933,
    0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103,
    0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043,
    -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007,
    -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154,
    0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584,
    0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144,
    0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351,
    -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719,
    0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567,
    -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881,
    -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032,
    -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374,
    0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071,
    -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219,
    -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682,
    0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493,
    -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265,
    0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539,
    0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446,
    0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017,
    -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563,
    0.93455386, -0.6324693, -0.083922029};

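// Golden forward output for one batch: 16 time steps x 16 fw_units.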
static float rnn_golden_fw_output[] = {
    0.496726, 0, 0.965996, 0, 0.0584254, 0,
    0, 0.12315, 0, 0, 0.612266, 0.456601,
    0, 0.52286, 1.16099, 0.0291232,

    0, 0, 0.524901, 0, 0, 0,
    0, 1.02116, 0, 1.35762, 0, 0.356909,
    0.436415, 0.0355727, 0, 0,

    0, 0, 0, 0.262335, 0, 0,
    0, 1.33992, 0, 2.9739, 0, 0,
    1.31914, 2.66147, 0, 0,

    0.942568, 0, 0, 0, 0.025507, 0,
    0, 0, 0.321429, 0.569141, 1.25274, 1.57719,
    0.8158, 1.21805, 0.586239, 0.25427,

    1.04436, 0, 0.630725, 0, 0.133801, 0.210693,
    0.363026, 0, 0.533426, 0, 1.25926, 0.722707,
    0, 1.22031, 1.30117, 0.495867,

    0.222187, 0, 0.72725, 0, 0.767003, 0,
    0, 0.147835, 0, 0, 0, 0.608758,
    0.469394, 0.00720298, 0.927537, 0,

    0.856974, 0.424257, 0, 0, 0.937329, 0,
    0, 0, 0.476425, 0, 0.566017, 0.418462,
    0.141911, 0.996214, 1.13063, 0,

    0.967899, 0, 0, 0, 0.0831304, 0,
    0, 1.00378, 0, 0, 0, 1.44818,
    1.01768, 0.943891, 0.502745, 0,

    0.940135, 0, 0, 0, 0, 0,
    0, 2.13243, 0, 0.71208, 0.123918, 1.53907,
    1.30225, 1.59644, 0.70222, 0,

    0.804329, 0, 0.430576, 0, 0.505872, 0.509603,
    0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311,
    0.0454298, 0.300267, 0.562784, 0.395095,

    0.228154, 0, 0.675323, 0, 1.70536, 0.766217,
    0, 0, 0, 0.735363, 0.0759267, 1.91017,
    0.941888, 0, 0, 0,

    0, 0, 1.5909, 0, 0, 0,
    0, 0.5755, 0, 0.184687, 0, 1.56296,
    0.625285, 0, 0, 0,

    0, 0, 0.0857888, 0, 0, 0,
    0, 0.488383, 0.252786, 0, 0, 0,
    1.02817, 1.85665, 0, 0,

    0.00981836, 0, 1.06371, 0, 0, 0,
    0, 0, 0, 0.290445, 0.316406, 0,
    0.304161, 1.25079, 0.0707152, 0,

    0.986264, 0.309201, 0, 0, 0, 0,
    0, 1.64896, 0.346248, 0, 0.918175, 0.78884,
    0.524981, 1.92076, 2.07013, 0.333244,

    0.415153, 0.210318, 0, 0, 0, 0,
    0, 2.02616, 0, 0.728256, 0.84183, 0.0907453,
    0.628881, 3.58099, 1.49974, 0};

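// Golden backward output for one batch: 16 time steps x 16 bw_units.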
static float rnn_golden_bw_output[] = {
    0.496726, 0, 1.00883, 0, 0.0584256, 0, 0,
    0.236412, 0, 0, 0.612267, 0.487726, 0, 0.54883,
    1.16099, 0.0291233, 0, 0, 0.428302, 0, 0,
    0, 0, 1.13262, 0, 1.64415, 0, 0.311249,
    0.570804, 0.259696, 0, 0, 0, 0, 0,
    0.262334, 0, 0, 0, 1.23781, 0, 2.86532,
    0, 0, 1.34389, 2.76409, 0, 0, 1.03969,
    0, 0.00410865, 0, 0.0470295, 0, 0, 0,
    0.371556, 0.27175, 1.36614, 1.63956, 0.683887, 1.06176, 0.719552,
    0.301314, 0.971195, 0, 0.697143, 0, 0.215219, 0.210693,
    0.363027, 0, 0.501283, 0, 1.13399, 0.623774, 0,
    1.09851, 1.33313, 0.470441, 0.210965, 0, 0.664178, 0,
    0.839686, 0, 0, 0.147834, 0, 0, 0,
    0.58786, 0.490128, 0, 0.905806, 0, 0.932134, 0.424257,
    0, 0, 0.860629, 0, 0, 0, 0.476425,
    0, 0.566017, 0.513721, 0.207341, 1.09508, 1.08385, 0,
    0.973787, 0, 0, 0, 0, 0, 0,
    1.20698, 0, 0, 0, 1.56135, 1.12369, 0.99588,
    0.459803, 0, 0.915854, 0, 0, 0, 0,
    0, 0, 2.03206, 0, 0.773264, 0.267228, 1.55012,
    1.202, 1.51611, 0.701202, 0, 0.725088, 0, 0.509069,
    0, 0.671349, 0.581129, 0.343447, 0, 0.107755, 0.611838,
    1.4331, 1.55871, 0.015242, 0.140624, 0.492562, 0.395095, 0.147722,
    0, 0.784925, 0, 1.65477, 0.715257, 0, 0,
    0, 0.685024, 0, 1.89505, 1.00037, 0, 0,
    0, 0, 0, 1.52659, 0, 0, 0,
    0, 0.618583, 0, 0.11115, 0, 1.37194, 0.630225,
    0, 0, 0, 0, 0, 0.0322124, 0,
    0, 0, 0, 0.430834, 0.252786, 0, 0,
    0, 0.991297, 1.98451, 0, 0, 0.111511, 0,
    1.05513, 0, 0, 0, 0, 0, 0,
    0.290445, 0.412559, 0.0429958, 0.256564, 1.27858, 0.289948, 0,
    1.01693, 0.327141, 0, 0, 0, 0, 0,
    1.83508, 0.346248, 0, 0.961535, 0.790026, 0.552203, 2.13457,
    2.19233, 0.333244, 0.316526, 0.179398, 0, 0, 0,
    0, 0, 1.86126, 0, 0.728256, 0.750013, 0.011861,
    0.576383, 3.38891, 1.29273, 0};

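// Input-to-hidden weights (16 units x 8 inputs); the tests use the same
// matrix for the forward and backward layers.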
const std::initializer_list<float> weights = {
    0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
    0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
    0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
    -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
    -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
    -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
    -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
    0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
    0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
    0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
    -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
    0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
    -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
    -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
    0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
    0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
    0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
    -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
    0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
    0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
    -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
    0.277308, 0.415818};

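// Input data for EndToEndTest (the model there uses batches=1,
// sequence_len=4, input_size=8).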
static float endtoend_input[] = {
    0.996808, 0.060710, 0.981855, 0.570017, 0.525164, 0.796859, 0.696547,
    0.505925, 0.991844, 0.461208, 0.949371, 0.027624, 0.539236, 0.841854,
    0.915222, 0.538569, 0.069375, 0.237905, 0.903700, 0.441703, 0.536196,
    0.402724, 0.761635, 0.025063, 0.082592, 0.688245, 0.239310, 0.256931,
    0.658900, 0.105695, 0.301983, 0.655708, 0.166405, 0.283837, 0.225725,
    0.691569, 0.080696, 0.922272, 0.197494, 0.072540, 0.383481, 0.146865,
    0.100163, 0.922717, 0.988720, 0.015386, 0.461286, 0.058095, 0.253290,
    0.364986, 0.499797, 0.789487, 0.767709, 0.261433, 0.814549, 0.850302,
    0.949678, 0.053859, 0.107233, 0.608577, 0.159554, 0.409215, 0.264285,
    0.325960, 0.693053, 0.490011, 0.017529, 0.773749, 0.412283, 0.215023,
    0.846288, 0.795764, 0.361889, 0.946452, 0.718481, 0.350608, 0.961837,
    0.179767, 0.408703, 0.215128, 0.544753, 0.908500, 0.004614, 0.312462,
    0.169933, 0.819163, 0.162764, 0.119611, 0.873022, 0.269997, 0.728188,
    0.032576, 0.679212, 0.992474, 0.358536, 0.372265, 0.482484, 0.376065,
    0.146014, 0.894767, 0.591088, 0.992302, 0.690531, 0.952977, 0.938754,
    0.409012, 0.303585, 0.900591, 0.588780, 0.712287, 0.115719, 0.133533,
    0.620788, 0.120334, 0.445995, 0.790720, 0.939497, 0.608759, 0.910331,
    0.812519, 0.878756, 0.638519, 0.845096, 0.557968, 0.630993, 0.203632,
    0.930233, 0.113477, 0.579697, 0.076247, 0.008244, 0.170785, 0.068549,
    0.698776, 0.123761, 0.007303, 0.107788, 0.427346, 0.907894, 0.696568,
    0.139633, 0.023613, 0.830100, 0.760421, 0.143947, 0.276096, 0.551141,
    0.083444, 0.884855, 0.461472, 0.895963, 0.763611, 0.099992, 0.741059,
    0.321579, 0.730984, 0.944691, 0.251812, 0.844461, 0.524388, 0.328059,
    0.852706, 0.695172, 0.396607, 0.551482, 0.818934, 0.403910, 0.659270,
    0.246280, 0.311804, 0.355838, 0.385913, 0.335418, 0.185938, 0.146334,
    0.479364, 0.462034, 0.697475, 0.562808, 0.346888, 0.158948, 0.458771,
    0.110499, 0.258939, 0.199830, 0.432078, 0.989924, 0.144521, 0.683890,
    0.834385, 0.668908, 0.011949, 0.687091, 0.364081, 0.408556, 0.238572,
    0.183015, 0.812466, 0.897842, 0.429294, 0.124271, 0.253680, 0.815207,
    0.459688, 0.439618, 0.961541, 0.939053, 0.901651, 0.659016, 0.501861,
    0.248539, 0.817964, 0.960632, 0.359038, 0.076903, 0.160462, 0.791117,
    0.066826, 0.304983, 0.475007, 0.901211, 0.973891, 0.486955, 0.588302,
    0.337972, 0.895512, 0.826874, 0.520987, 0.707978, 0.724716, 0.950281,
    0.832249, 0.978396, 0.765488, 0.291937, 0.418014, 0.727029, 0.230990,
    0.319665, 0.386045, 0.732850, 0.568204, 0.204009, 0.693482, 0.927242,
    0.280912, 0.853944, 0.718359, 0.347738, 0.158927, 0.193366, 0.248950,
    0.132818, 0.680321, 0.837252, 0.470790, 0.575833, 0.664126, 0.991777,
    0.283811, 0.388843, 0.942058, 0.116060, 0.367239, 0.707546, 0.407997,
    0.785253, 0.434575, 0.638986, 0.104917, 0.820620, 0.371837, 0.673121,
    0.024629, 0.065319, 0.600363, 0.305541, 0.919263, 0.318722, 0.653279,
    0.078190, 0.512088, 0.902229, 0.211009, 0.192409, 0.739480, 0.681799,
    0.768242, 0.403607, 0.673576, 0.052052, 0.792450, 0.615634, 0.168112,
    0.159689, 0.323180, 0.576109, 0.944941, 0.757755, 0.215095, 0.049858,
    0.578375, 0.586932, 0.722979, 0.603003, 0.652251, 0.323343, 0.908544,
    0.571514, 0.642065, 0.561823, 0.649704, 0.154153, 0.464051, 0.860713,
    0.346562, 0.203532, 0.542512, 0.114804, 0.607139, 0.216088, 0.166856,
    0.399588, 0.831722, 0.334968, 0.559277, 0.154902, 0.911077, 0.504218,
    0.912656, 0.126172, 0.554076, 0.491031, 0.713104, 0.277055, 0.094034,
    0.365355, 0.600398, 0.002578, 0.936869, 0.242463, 0.564401, 0.586574,
    0.396616, 0.028452, 0.447287, 0.743178, 0.231984, 0.989799, 0.857982,
    0.839122, 0.205887, 0.024838, 0.238711, 0.037608, 0.359806, 0.797987,
    0.192510, 0.270883, 0.302205, 0.105166, 0.397055, 0.856281, 0.596197,
    0.110160, 0.133336, 0.690231, 0.475515, 0.733734, 0.692809, 0.412384,
    0.976196, 0.257209, 0.998958, 0.372812, 0.285661, 0.446245, 0.115990,
    0.517645, 0.436044, 0.973972, 0.356767, 0.641930, 0.998810, 0.595478,
    0.679539, 0.358617, 0.393465, 0.872049, 0.629500, 0.695670, 0.977215,
    0.026555, 0.551951, 0.573412, 0.136715, 0.685287, 0.263643, 0.612229,
    0.419020, 0.956451, 0.024613, 0.395216, 0.213661, 0.023572, 0.768029,
    0.499322, 0.469816, 0.884019, 0.016967, 0.905860, 0.857991, 0.373734,
    0.547791, 0.856802, 0.969211, 0.227330, 0.215418, 0.362676, 0.099378,
    0.844918, 0.058346, 0.076594, 0.871473, 0.610297, 0.650006, 0.008188,
    0.295583, 0.913648, 0.620417, 0.714603, 0.870100, 0.645031, 0.109820,
    0.083760, 0.668602, 0.877849, 0.583082, 0.138419, 0.761868, 0.600049,
    0.044279, 0.619859, 0.973783, 0.592069, 0.476661, 0.942994, 0.819399,
    0.692079, 0.305670, 0.918778, 0.536997, 0.364016, 0.995371, 0.408470,
    0.974313, 0.645377, 0.416658, 0.269896, 0.559025, 0.037075, 0.984499,
    0.429125, 0.682105, 0.094319, 0.512885, 0.350707, 0.972168, 0.095967,
    0.489126, 0.734035, 0.696016, 0.533405, 0.353894, 0.669799, 0.125474,
    0.830555, 0.612793, 0.944873, 0.522634, 0.918463, 0.863651, 0.059631,
    0.282479, 0.859022, 0.468101, 0.256791, 0.504398, 0.884758, 0.526687,
    0.063423, 0.921833, 0.511186, 0.492548, 0.603939, 0.605505, 0.005433,
    0.954646, 0.577673, 0.101400, 0.443772, 0.311708, 0.797417, 0.977176,
    0.665602, 0.467216, 0.102650, 0.496157, 0.080009, 0.047524, 0.018791,
    0.998471, 0.911174, 0.078422, 0.280950, 0.770196, 0.546523, 0.537741,
    0.274594, 0.431281, 0.064428, 0.338017, 0.353115, 0.575615, 0.830565,
    0.957053, 0.181120, 0.835998, 0.911699, 0.758793, 0.937398, 0.355471,
    0.070501, 0.734815, 0.332647, 0.736103, 0.202031, 0.435297, 0.232261,
    0.282039, 0.482821, 0.251052, 0.280511, 0.393995, 0.329474, 0.561460,
    0.164191, 0.875997, 0.099202, 0.438785, 0.307278, 0.163630, 0.776802,
    0.660393, 0.739244, 0.607367, 0.617446, 0.920364, 0.443365, 0.529145,
    0.679157, 0.380763, 0.884616, 0.749658, 0.115578, 0.217263, 0.485761,
    0.317609, 0.652560, 0.718021, 0.599648, 0.135381, 0.969073, 0.880159,
    0.529376, 0.298547, 0.441619, 0.693567, 0.174544, 0.540821, 0.132351,
    0.481822, 0.704450, 0.909153, 0.142215, 0.443695, 0.516520, 0.759661,
    0.364059, 0.959885, 0.288806, 0.043216, 0.340648, 0.173422, 0.792874,
    0.456226, 0.390685, 0.278634, 0.773834, 0.043245, 0.996656, 0.373483,
    0.178625, 0.965729, 0.253641, 0.708001, 0.264276, 0.695260, 0.401568,
    0.438820, 0.236081, 0.533919, 0.920642, 0.940531, 0.443072, 0.062857,
    0.384226, 0.959592, 0.822518, 0.748285, 0.919477, 0.111325, 0.791501,
    0.260124, 0.284747, 0.584375, 0.716350, 0.675431, 0.863009, 0.490184,
    0.718676, 0.859665, 0.863666, 0.897301, 0.825393, 0.117308, 0.605302,
    0.089669, 0.812568, 0.006870, 0.528489, 0.048649, 0.540788, 0.449131,
    0.989180, 0.983860, 0.511988, 0.373407, 0.943452, 0.334506, 0.121692,
    0.862929, 0.445831, 0.913193, 0.123053, 0.730578, 0.497568, 0.839402,
    0.406009, 0.360577, 0.329586, 0.124685, 0.220241, 0.193253, 0.021986,
    0.045634, 0.310560, 0.627288, 0.135303, 0.123128, 0.634158, 0.663792,
    0.171777, 0.174946, 0.112923, 0.160958, 0.158806, 0.624911, 0.534364,
    0.102259, 0.959418, 0.656056, 0.965187, 0.405249, 0.569249, 0.088240,
    0.135827, 0.066817, 0.927642, 0.541836, 0.427393, 0.257229, 0.666520,
    0.647634, 0.450481, 0.688506, 0.693269, 0.761042, 0.315794, 0.828572,
    0.884170, 0.949952, 0.492364, 0.055947, 0.124898, 0.605288, 0.216905,
    0.283705, 0.230199, 0.751269, 0.385963, 0.189616, 0.407326, 0.351151,
    0.594865, 0.976575, 0.439391, 0.730692, 0.043392, 0.367033, 0.272527,
    0.470785, 0.624261, 0.939048, 0.118419, 0.074743, 0.627554, 0.811688,
    0.835784, 0.943348, 0.640260, 0.719954, 0.893300, 0.132625, 0.775901,
    0.018199, 0.737913, 0.992806, 0.301903, 0.968111, 0.744076, 0.687867,
    0.157728, 0.151401, 0.039017, 0.752593, 0.127976, 0.478408, 0.483284,
    0.171368, 0.845441, 0.755811, 0.642153, 0.469702, 0.694859, 0.760572,
    0.544445, 0.322413, 0.572260, 0.380229, 0.265761, 0.212521, 0.100183,
    0.159062, 0.345146, 0.876084, 0.177261, 0.083058, 0.868891, 0.479164,
    0.051169, 0.612966, 0.167030, 0.208897, 0.764367, 0.206048, 0.961490,
    0.892343, 0.684456, 0.444774, 0.063711, 0.529896, 0.200585, 0.705863,
    0.999598, 0.895444, 0.466435, 0.544043, 0.217857, 0.038696, 0.924272,
    0.483618, 0.251217, 0.024455, 0.642680, 0.596362, 0.900539, 0.819941,
    0.679420, 0.769430, 0.299105, 0.730590, 0.382396, 0.466135, 0.939487,
    0.146763, 0.672183, 0.900977, 0.039106, 0.356638, 0.345750, 0.102817,
    0.886535, 0.546336, 0.808681, 0.886133, 0.441780, 0.275116, 0.430176,
    0.659637, 0.313812, 0.354448, 0.143255, 0.565028, 0.378903, 0.785935,
    0.161391, 0.279443, 0.605876, 0.840811, 0.048873, 0.904980, 0.571401,
    0.431269, 0.371115, 0.510887, 0.578032, 0.043298, 0.411864, 0.617138,
    0.399936, 0.757614, 0.719955, 0.286471, 0.303950, 0.528636, 0.172604,
    0.745730, 0.803752, 0.602780, 0.405367, 0.117564, 0.957228, 0.548622,
    0.682592, 0.336131, 0.334557, 0.843983, 0.615574, 0.940433, 0.684794,
    0.664447, 0.845413, 0.256194, 0.095715, 0.216529, 0.767082, 0.673747,
    0.259827, 0.178946, 0.290885, 0.659763, 0.936560, 0.010840, 0.946234,
    0.240510, 0.539476, 0.118838, 0.986240, 0.343228, 0.721618, 0.391606,
    0.460792, 0.678846, 0.940228, 0.143384, 0.014977, 0.274785, 0.987367,
    0.630551, 0.215218, 0.672161, 0.294998, 0.060631, 0.928355, 0.390713,
    0.277160, 0.695436, 0.064460, 0.536987, 0.874382, 0.355345, 0.196751,
    0.810942, 0.366185, 0.142985, 0.051452, 0.905661, 0.261823, 0.037691,
    0.248889, 0.983441, 0.429297, 0.709681, 0.662286, 0.369525, 0.853066,
    0.677263, 0.644310, 0.840433, 0.307814, 0.859528, 0.512593, 0.602812,
    0.920160, 0.440948, 0.993525, 0.197320, 0.136384, 0.057984, 0.734307,
    0.010766, 0.413329, 0.931058, 0.821707, 0.779514, 0.074043, 0.873159,
    0.685175, 0.335865, 0.910850, 0.934065, 0.319306, 0.340147, 0.643746,
    0.981592, 0.709673, 0.496812, 0.658856, 0.353983, 0.337245, 0.966670,
    0.213511, 0.849838, 0.569482, 0.133671, 0.290786, 0.563007, 0.330991,
    0.427170, 0.620991, 0.065299, 0.437936, 0.034320, 0.996356, 0.259643,
    0.813834, 0.070399, 0.132802, 0.499009, 0.406265, 0.043652, 0.433074,
    0.725570, 0.383800, 0.076820, 0.707163, 0.093473, 0.573632, 0.366018,
    0.447456, 0.910877, 0.332688, 0.660967, 0.760714, 0.902170, 0.794638,
    0.051500, 0.465177, 0.125630, 0.478670, 0.086168, 0.190928, 0.916605,
    0.120488, 0.187285, 0.176248, 0.934322, 0.257684, 0.309050, 0.433331,
    0.663949, 0.352703, 0.866405, 0.389519, 0.736502, 0.943226, 0.096682,
    0.829975, 0.516858, 0.462700, 0.277430, 0.427734, 0.795388, 0.938398,
    0.188449, 0.697558, 0.733036, 0.239948, 0.162735, 0.858666, 0.718618,
    0.248903, 0.049594, 0.635223, 0.369391, 0.236879, 0.811472, 0.303713,
    0.494563, 0.120522, 0.737044, 0.158511, 0.473225, 0.603450, 0.548030,
    0.209727, 0.546675, 0.644712, 0.039702, 0.063533, 0.107412, 0.317132,
    0.491267, 0.902800, 0.255530, 0.679716, 0.600359, 0.988566, 0.919664,
    0.763094, 0.847232, 0.638283, 0.011997, 0.896825, 0.273506, 0.381388,
    0.133704, 0.084978, 0.685101, 0.628267, 0.205500, 0.422145, 0.786778,
    0.678725, 0.025595, 0.334808, 0.888452, 0.572271, 0.979520, 0.928154,
    0.635804, 0.086932, 0.245286, 0.127071, 0.989732, 0.500816, 0.806787,
    0.590091, 0.489382, 0.726451, 0.353185, 0.336614, 0.364734, 0.365182,
    0.233439, 0.638240, 0.746570, 0.367143, 0.723218, 0.431671, 0.995410,
    0.928718, 0.853816, 0.782188, 0.607442, 0.879411, 0.116995, 0.495894,
    0.451682, 0.096515, 0.424048, 0.087485, 0.183447, 0.669334, 0.214556,
    0.173179, 0.170151, 0.021343, 0.763269, 0.659533, 0.747794, 0.116454,
    0.996147, 0.112528, 0.481635, 0.229586, 0.750768, 0.228205, 0.596730,
    0.473985, 0.659876, 0.592139, 0.402703, 0.513692, 0.374327, 0.010145,
    0.393103, 0.491322, 0.506039, 0.844785, 0.587837, 0.930088, 0.932270,
    0.771284, 0.599422, 0.146826, 0.944463, 0.769573, 0.168169, 0.707732,
    0.429106, 0.915964, 0.824186, 0.425253, 0.028492, 0.305821, 0.654839,
    0.779259, 0.534026, 0.251569, 0.253245, 0.193901, 0.843708, 0.655947,
    0.707593, 0.218035, 0.666093, 0.100696, 0.709357, 0.172132, 0.945481,
    0.297195, 0.102220, 0.877751, 0.068479, 0.701642, 0.024577, 0.012941,
    0.471215, 0.192747, 0.720673, 0.900321, 0.108710, 0.544859, 0.325574,
    0.137202, 0.850679, 0.980413, 0.916462, 0.384705, 0.231982, 0.169706,
    0.578607, 0.075690, 0.825654, 0.286200, 0.293725, 0.491746, 0.386896,
    0.003083, 0.663878, 0.332377, 0.300278, 0.766098, 0.210128, 0.368756,
    0.467740, 0.234705, 0.381697, 0.938955, 0.427451, 0.102370, 0.839275,
    0.536162, 0.647229, 0.164849, 0.673364, 0.497908, 0.145262, 0.589825,
    0.882613, 0.377244, 0.759532, 0.461220, 0.452934, 0.585185, 0.747420,
    0.746660, 0.076932, 0.134316, 0.749743, 0.740810, 0.466692, 0.050020,
    0.506908, 0.676820, 0.418776, 0.974648, 0.911525, 0.800474, 0.913602,
    0.338976, 0.902844, 0.752878, 0.875138, 0.550072, 0.917727, 0.548502,
    0.047981, 0.062989, 0.138327, 0.930594, 0.440233, 0.897859, 0.391814,
    0.893168, 0.483044, 0.139234, 0.639828, 0.559975, 0.273549, 0.389570,
    0.300785, 0.740242, 0.439590, 0.807693, 0.417062, 0.858367, 0.782341,
    0.328586, 0.658840, 0.695943, 0.667562, 0.561684, 0.448821, 0.542700,
    0.111756, 0.366548, 0.091202, 0.159737, 0.429537, 0.229529, 0.090331,
    0.869770, 0.127388, 0.482145, 0.762938, 0.610432, 0.621379, 0.402765,
    0.170407, 0.894928, 0.792336, 0.471192, 0.635170, 0.231926, 0.278886,
    0.052232, 0.090293, 0.061226, 0.380818, 0.749133, 0.757170, 0.048380,
    0.310817, 0.205990, 0.591080, 0.422573, 0.572538, 0.682282, 0.582310,
    0.002075, 0.911812, 0.672641, 0.871845, 0.039199, 0.154786, 0.634783,
    0.649631, 0.776165, 0.037548, 0.820038, 0.671093, 0.829884, 0.291231,
    0.306263, 0.061810, 0.570116, 0.358495, 0.152103, 0.631343, 0.739313,
    0.901236, 0.388512, 0.787693, 0.212053, 0.594503, 0.378773, 0.634626,
    0.167040, 0.061056, 0.216937, 0.169115, 0.972867, 0.889578, 0.040960,
    0.012067, 0.044364, 0.675743, 0.661698, 0.820529, 0.713291, 0.481736,
    0.491623, 0.543175, 0.772966, 0.797886, 0.604985, 0.343083, 0.156380,
    0.757088, 0.974425, 0.895693, 0.658324, 0.362938, 0.683386, 0.870376,
    0.957440, 0.062159, 0.505002, 0.124481, 0.123215, 0.721939, 0.293596,
    0.096082, 0.611517, 0.334556, 0.108149, 0.655881, 0.010299, 0.769846,
    0.476411, 0.723590, 0.251582, 0.968033, 0.266765, 0.024548, 0.765919,
    0.871750, 0.367631, 0.922299, 0.628838, 0.342056, 0.817992, 0.287162,
    0.704994, 0.501378, 0.157538, 0.662434, 0.563537, 0.662541, 0.786915,
    0.686752, 0.384480, 0.080511, 0.782834, 0.995997, 0.415067, 0.890983,
    0.651878, 0.425365, 0.660829, 0.128289, 0.148956, 0.912411, 0.096322,
    0.415721, 0.936959, 0.862241, 0.287471, 0.304590, 0.784540, 0.916309,
    0.646646, 0.602533, 0.203471, 0.351640, 0.103911, 0.361009, 0.014074,
    0.667448, 0.023550, 0.800989, 0.354200, 0.408030, 0.881500, 0.137034,
    0.404026, 0.296566, 0.028017, 0.055904, 0.721932, 0.688846, 0.184193,
    0.870887, 0.601257, 0.280515, 0.286608, 0.538216, 0.142755, 0.574079,
    0.842806, 0.927296, 0.490388, 0.489452, 0.529828, 0.693859, 0.841092,
    0.633739, 0.054869, 0.855167, 0.301187, 0.078419, 0.656156, 0.655388,
    0.486448, 0.537656, 0.792422, 0.890475, 0.834222, 0.820439, 0.946379,
    0.556153, 0.509285, 0.130571, 0.427041, 0.110542, 0.411086, 0.713648,
    0.648758, 0.553842, 0.287727, 0.491563, 0.481137, 0.778116, 0.981015,
    0.010966, 0.471975, 0.822107, 0.644705, 0.526844, 0.677274, 0.945892,
    0.605263, 0.333430, 0.601280, 0.091711, 0.871086, 0.393702, 0.982186,
    0.705307, 0.214141, 0.928564, 0.261461, 0.723426, 0.059136, 0.688501,
    0.833968, 0.470222, 0.402150, 0.482725, 0.024063, 0.689877, 0.974289,
    0.505201, 0.467993, 0.955304, 0.516166, 0.939968, 0.777411, 0.160871,
    0.466812, 0.454685, 0.106763, 0.072075, 0.788115, 0.708043, 0.163786,
    0.659201, 0.101744, 0.145971, 0.364508, 0.315885, 0.074536, 0.625969,
    0.039311, 0.133672, 0.314471, 0.873279, 0.603893, 0.716620, 0.356004,
    0.627957, 0.406498, 0.330292, 0.133157, 0.874490, 0.285596, 0.649324,
    0.814458, 0.063007, 0.810195, 0.281270, 0.517693, 0.916958, 0.353345,
    0.305808, 0.625000, 0.517131, 0.965009, 0.726745, 0.663102, 0.329518,
    0.042630, 0.737638, 0.955487, 0.081940, 0.871310, 0.269957, 0.955219,
    0.475203, 0.986578, 0.311223, 0.103160, 0.393075, 0.641515, 0.236317,
    0.267566, 0.927112, 0.885641, 0.082024, 0.990119, 0.695835, 0.363295,
    0.507812, 0.612793, 0.716640, 0.813620, 0.237793, 0.233770, 0.778629,
    0.964538, 0.896872, 0.108147, 0.007167, 0.634510, 0.063633, 0.089108,
    0.505820, 0.333591, 0.044327, 0.981023, 0.320168, 0.355550, 0.084182,
    0.713244, 0.997065, 0.320499, 0.980810, 0.924177, 0.206140, 0.062834,
    0.914296, 0.901975, 0.426129, 0.422107, 0.514768, 0.142768, 0.235727,
    0.752561, 0.376539, 0.014356, 0.717099, 0.273411, 0.122502, 0.724266,
    0.907921, 0.186136, 0.813374, 0.413741, 0.519726, 0.857701, 0.394764,
    0.839895, 0.213251, 0.478946, 0.553139, 0.210317, 0.799446, 0.533948,
    0.134493, 0.005586, 0.596782, 0.048789, 0.907561, 0.022911, 0.470896,
    0.422329, 0.165679, 0.706623, 0.174890, 0.542218, 0.720979, 0.891989,
    0.815629, 0.843481, 0.616255, 0.723551, 0.029617, 0.429630, 0.137292,
    0.549343, 0.287331, 0.532056, 0.389238, 0.500583, 0.011002, 0.942377,
    0.710899, 0.810448, 0.476326, 0.845392, 0.816033, 0.073108, 0.894181,
    0.723594, 0.096019, 0.365077, 0.145923, 0.261699, 0.071700, 0.320813,
    0.803917, 0.792679, 0.212802, 0.619546, 0.636160, 0.829057, 0.343096,
    0.665777, 0.258687, 0.480388, 0.215121, 0.546018, 0.012444, 0.604359,
    0.046601, 0.023446, 0.546736, 0.757500, 0.833893, 0.023062, 0.602892,
    0.649927, 0.096170, 0.497074, 0.373521, 0.192189, 0.862151, 0.519444,
    0.453887, 0.933851, 0.840257, 0.257804, 0.726531, 0.053058, 0.877350,
    0.362691, 0.882115, 0.220446, 0.028468, 0.140802, 0.700834, 0.243589,
    0.686821, 0.713278, 0.847948, 0.733421, 0.736723, 0.394684, 0.490921,
    0.570617, 0.417746, 0.093813, 0.220543, 0.513916, 0.590887, 0.594064,
    0.706105, 0.453038, 0.113508, 0.159992, 0.386889, 0.953765, 0.417796,
    0.113420, 0.006823, 0.295146, 0.476111, 0.888938, 0.515592, 0.504579,
    0.029741, 0.216426, 0.748168, 0.716561, 0.929703, 0.596117, 0.449982,
    0.666427, 0.990801, 0.940903, 0.237043, 0.408547, 0.034717, 0.457587,
    0.922463, 0.625603, 0.051651, 0.628568, 0.078641, 0.165159, 0.788560,
    0.465530, 0.118923, 0.206356, 0.578950, 0.125746, 0.501502, 0.055060,
    0.014685, 0.017094, 0.559640, 0.044425, 0.233519, 0.307808, 0.760986,
    0.163223, 0.903925, 0.210969, 0.829650, 0.894726, 0.151872, 0.066693,
    0.303273, 0.186589, 0.524279, 0.225736, 0.812192, 0.575930, 0.854304,
    0.890833, 0.741089, 0.642864, 0.356363, 0.860012, 0.849220, 0.935313,
    0.985758, 0.350722, 0.990373, 0.000443, 0.367815, 0.550013, 0.044868,
    0.601335, 0.857820, 0.805855, 0.764557, 0.761745, 0.016823, 0.594207,
    0.656471, 0.168696, 0.660900, 0.959744, 0.355284, 0.185179, 0.185480,
    0.167477, 0.761110, 0.039784, 0.058310, 0.502199, 0.682648, 0.414673,
    0.362211, 0.531868, 0.349985, 0.347969, 0.882589, 0.340358, 0.348412,
    0.250404, 0.890371, 0.393280, 0.851739, 0.748191, 0.199135, 0.616297,
    0.509936, 0.215958, 0.210504, 0.166407, 0.384654, 0.871404, 0.126151,
    0.739938, 0.056583, 0.311631, 0.907415, 0.817693, 0.351415, 0.965724,
    0.319891, 0.034062, 0.380397, 0.682102, 0.565930, 0.730382, 0.030072,
    0.448519, 0.070741, 0.378484, 0.698924, 0.961112, 0.771764, 0.550663,
    0.709303, 0.970899, 0.166959, 0.219239, 0.186857, 0.377463, 0.385647,
    0.571511, 0.248867, 0.511798, 0.311449, 0.305450, 0.823429, 0.218864,
    0.123142, 0.174844, 0.184588, 0.443034, 0.208906, 0.564986, 0.125136,
    0.774836, 0.295368, 0.155207, 0.223355, 0.366109, 0.533691, 0.922279,
    0.327221, 0.305455, 0.472942, 0.036524, 0.276354, 0.639901, 0.255763,
    0.463211, 0.017364, 0.641410, 0.034722, 0.266231, 0.153207, 0.346171,
    0.571680, 0.976636, 0.565036, 0.694822, 0.151480, 0.749624, 0.137856,
    0.360386, 0.314610, 0.262992, 0.135222, 0.609978, 0.418200, 0.358578,
    0.976087, 0.951891, 0.280856, 0.303307, 0.257346, 0.753798, 0.339831,
    0.533700, 0.393699, 0.595594, 0.996911, 0.411063, 0.237003, 0.031634,
    0.677294, 0.390211, 0.377805, 0.248974, 0.366847, 0.942841, 0.943796,
    0.518327, 0.692465, 0.081653, 0.878713, 0.007074, 0.344645, 0.013936,
    0.617052, 0.762845, 0.372513, 0.593138, 0.714736, 0.653370, 0.896446,
    0.972082, 0.407168, 0.236276, 0.505782, 0.800867, 0.831870, 0.502693,
    0.211930, 0.068873, 0.534327, 0.889224, 0.459084, 0.912132, 0.138197,
    0.825931, 0.854972, 0.081994, 0.344259, 0.547437, 0.163646, 0.222972,
    0.554511, 0.508291, 0.236908, 0.171563, 0.271135, 0.609421, 0.764701,
    0.985871, 0.262790, 0.661147, 0.957953, 0.669958, 0.897423, 0.463734,
    0.470825, 0.729293, 0.966427, 0.682755, 0.798166, 0.500754, 0.571978,
    0.257251, 0.412886, 0.710176, 0.083182, 0.267858, 0.792169, 0.427441,
    0.815295, 0.955815, 0.650413, 0.369805, 0.464106, 0.887320, 0.541368,
    0.735242, 0.496741, 0.306069, 0.721113, 0.759531, 0.967216, 0.679065,
    0.429489, 0.864639, 0.142799, 0.900314, 0.593932, 0.109227, 0.583069,
    0.392098, 0.609981, 0.155047, 0.649349, 0.022867, 0.865222, 0.732531,
    0.290725, 0.657392, 0.159972, 0.106019, 0.613207, 0.810384, 0.475824,
    0.077313, 0.697704, 0.017192, 0.812555};

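// Expected output of the end-to-end model exercised in EndToEndTest.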
static float golden_endtoend_output[] = {
    -1.881211, -0.028385, -3.585066, 1.939770, -3.461155, 1.280415, -4.408978,
    0.608663, -2.704937, 1.859742, -5.777429, 2.691839, -1.049012, 1.640870,
    -4.856245, 1.604236, 0.992707, 0.422858, -4.307465, 1.887332, -0.884831,
    -0.154277, -2.634801, 0.586827, -1.849960, 1.399608, -4.531559, 1.943591,
    0.271676, -2.893054, -2.066826, 0.235467, -1.248263, -1.164534, -2.640174,
    -0.112878, -4.386484, 1.253024, -4.135623, 1.068984, -0.043579, -0.832957,
    -3.257258, -0.514396, -1.651174, 0.638630, -4.364372, 1.548441, -0.289455,
    0.539845, -4.097627, 0.635001, -0.465071, -0.927701, -2.481498, 0.356616,
    -2.355012, 0.728806, -3.340283, 1.609038, -4.786268, -0.532272, -1.886150,
    0.254797, 0.746620, -1.657134, -3.264265, 0.525551, -1.756837, 0.845446,
    -5.572190, 1.715797, -2.856942, 3.394245, -5.803662, 2.281806, -3.014739,
    2.616136, -4.728482, 1.659984, -2.106307, 2.711709, -6.173832, 1.352869,
    -0.038035, 0.107619, -4.279774, 2.341930, -0.980413, -0.119538, -4.049717,
    1.172128, -3.477744, 2.602274, -6.231380, 2.537300, -0.862214, 0.568722,
    -3.858362, 0.197867, -1.725885, 3.687312, -7.067363, 2.403544, -0.944963,
    0.235639, -3.250094, 0.659117, -1.459576, 0.426128, -3.637207, 1.030386,
    -4.224351, 3.516220, -6.053367, 0.993473, -2.182416, -0.762625, -1.884405,
    -0.113736, -2.572602, 0.329290, -1.913233, 0.517418, -0.019757, 0.203176,
    -3.715881, 0.482136, -1.912823, 1.357907, -5.473043, 1.714658, -3.177160,
    0.089285, -3.127669, 1.268076, 0.772498, -1.622712, -3.850314, 0.436124,
    -1.495983, 3.439982, -7.623405, 1.726721, -0.423979, 0.180201, -2.902406,
    0.986457, -1.845638, 0.460903, -5.359343, -1.133931, -1.074456, 0.717304,
    -3.519856, 1.012126, -0.562301, 1.881967, -6.716627, 2.525036, 0.945480,
    0.337081, -5.210562, 2.572035, -0.943370, 0.442026, -2.666313, 0.411296,
    0.002787, -0.000735, -2.498933, 0.771719, -3.568153, 3.833721, -6.617026,
    2.813922, -0.573970, 1.025208, -3.909923, 1.722648, -1.406849, 0.719783,
    -5.207438, 1.819442, -0.530895, -0.010887, -2.939614, 0.971225, -1.660297,
    1.345243, -4.454571, 2.244876, -2.021213, 1.756090, -4.880947, 0.364597,
    -2.380270, 2.763117, -5.613013, 2.137534, 0.289101, -2.279400, -3.365582,
    0.170028, -1.142254, -0.709604, -3.656223, 1.804870, -0.854690, 0.592102,
    -5.010415, 2.462687, -1.474710, 0.566002, -3.621819, -0.391946, -0.423524,
    -0.631428, -3.513310, 0.962825, -1.480262, 0.319791, -3.610137, 1.842339,
    -0.250073, 1.182022, -6.249267, 1.604172, 1.153759, -0.734054, -4.620415,
    -0.030858, 0.050911, 1.524406, -4.724010, 1.451846, -3.277104, 2.414182,
    -4.605285, 1.846092, -1.503047, -0.618200, -2.746546, -0.459332, -0.980326,
    -1.199977, -2.043865, -0.165793, -2.214698, 3.108281, -7.127830, -0.123065,
    1.244948, -3.039923, -4.660061, -0.225957, -0.307210, -1.513205, -2.456005,
    0.840048, -0.741445, 2.328635, -6.015267, 2.723240, -1.381171, -0.728878,
    -5.114925, -0.362034, -0.574923, 0.518080, -3.892457, 1.798948, 0.435119,
    -0.371696, -2.807571, 1.302864, -2.063052, 1.036388, -4.232038, 1.397059,
    -1.615668, -1.511019, -3.095508, 1.290955, -3.428723, 2.000287, -4.196487,
    1.566983, 0.196957, 0.224343, -4.926359, -0.691975, -0.214941, 1.546821,
    -5.384868, 2.290820, -1.878865, 0.493692, -4.129823, 2.112036, 0.516558,
    -2.553077, -2.717338, 0.017146, -2.016057, 1.628995, -4.240602, 1.189533,
    -5.460220, 1.254738, -4.214903, 0.755659, -2.893235, 2.937762, -6.169453,
    2.035456, -5.613212, -0.122254, -1.973646, -0.060619, -2.119598, 1.413512,
    -4.938738, 1.890244, 0.544169, -2.062413, -3.329637, -0.062515, -1.855805,
    -0.791297, -2.570353, 0.607615, 0.305812, 0.338930, -4.150270, 2.274937,
    0.042653, 0.133825, -3.538155, 1.523639, -3.173690, -1.496599, -2.414655,
    0.464687, -1.448998, -0.368907, -3.520129, 0.203382, -2.443626, 1.266233,
    -3.393848, 0.605911, -0.015353, 1.402006, -4.441003, 1.419281, 0.603587,
    0.434146, -4.966566, 2.171872, -0.688264, -0.009981, -4.461103, 1.538354,
    -5.029816, -0.264424, -1.713510, -0.315258, -1.891606, 0.252074, -2.419428,
    0.043970, -1.291143, 2.048704, -4.590105, 0.524734, -1.889576, 0.134836,
    -3.462745, 1.390663, -0.112773, 0.402735, -4.203784, 1.381043, -1.201634,
    -1.968277, -1.425637, -0.181725, -1.250742, -2.102041, -3.925464, -1.256797,
    -3.701354, -1.754610, -1.917231, -1.455910, -1.838006, 2.041781, -5.666212,
    2.752957, -2.659553, 2.553637, -4.872212, 1.443437, -2.081846, 3.311263,
    -5.912457, 1.871049, 0.196148, -0.307044, -4.024967, 2.149149, 0.361809,
    0.620415, -5.939984, 0.180672, -1.209180, -0.269122, -3.240285, 1.460315,
    -1.040803, 1.125700, -6.060366, 0.887767, -3.214111, 1.314368, -3.026808,
    1.023640, -3.815175, 1.795642, -4.355603, 1.064454, -0.046472, 0.618463,
    -5.941646, 2.861891, -2.852155, -0.990457, -2.624445, 1.794494, -1.176747,
    -0.358159, -3.206776, 1.138721, -2.819523, -1.825522, -1.450902, -0.187312,
    -0.808727, 0.636872, -4.120567, 1.192623, 0.810731, -1.768519, -3.699450,
    1.527116, -2.772720, 3.012835, -5.912736, 1.599365, -4.696381, 2.234591,
    -4.139552, 1.061768, -1.880089, 3.596274, -7.006379, 2.382152, -3.158115,
    3.844430, -7.044156, 2.307596, -2.473970, 1.312644, -5.467269, 0.197154,
    -1.530040, 1.762275, -5.550757, 0.630276, -3.048947, 1.043777, -3.096658,
    1.345893, -1.329494, 2.065748, -4.711032, 2.227600, -0.413321, -0.032428,
    -4.599650, 1.668734, -4.351490, -0.200022, -2.359903, 0.021997, 0.116028,
    1.159718, -5.093972, -0.142951, -2.409895, 0.906133, -2.728812, 0.809932,
    -2.597363, 0.494130, -2.357861, 0.369825, -2.165235, 1.148522, -3.130562,
    0.759034, 0.646335, -1.463660, -3.508299, 1.059679, -1.485465, 1.007319,
    -4.340716, 1.789864, -1.590654, 1.612324, -4.452007, 2.389805, -5.200148,
    -1.068398, -1.306923, -0.472408, -0.392165, -0.524996, -2.933478, 1.518430,
    -1.287781, 0.113422, -3.020525, 1.338359, -0.105982, 0.936014, -4.132197,
    1.836807, -0.616589, -1.029716, -3.271347, 0.284889, -2.653359, 2.135829,
    -4.643613, 1.627981, 0.287733, -2.017263, -2.776574, 1.184792, 1.004161,
    -1.483019, -4.339290, -0.787322, 0.582420, 1.137839, -5.673941, -0.001862,
    -1.219142, 0.532561, -4.457245, 1.826807, -3.343291, 3.034610, -6.179855,
    2.235917, -4.369989, 4.018128, -6.632714, 0.926585, -0.485469, 0.536073,
    -4.179557, 1.489637, -0.521762, 1.636089, -6.137912, 1.500867, -4.086009,
    1.961372, -3.688977, 1.358220, -1.544034, 1.763837, -4.357567, 1.852201,
    -2.018725, 1.046264, -6.211127, 1.609419, -0.118441, 1.602284, -6.242423,
    1.518578, -0.604078, 1.106613, -5.393445, 2.595629, 0.142712, -1.903953,
    -2.821177, 0.032758, -0.009152, 0.184628, -4.227636, 2.046843, -2.240138,
    1.256176, -5.108516, -0.308447, -2.998571, 4.657396, -7.582112, 2.510951,
    -3.535784, 1.704560, -5.068484, 1.318466, -3.058265, 3.073172, -6.998089,
    3.178849, -2.420286, 2.277806, -4.999528, 1.423890, -1.672914, 0.447460,
    -4.088940, 1.351087, -1.051546, -0.417955, -4.042147, 1.604102, -1.700931,
    2.796663, -6.497579, 2.857974, -0.240828, 0.858001, -5.778933, 2.778508,
    -0.406211, 1.300766, -5.073671, 2.089362, -0.201673, 1.588396, -6.000150,
    2.185055, -2.332125, 0.768216, -2.609184, 0.327277, -3.358943, -1.020736,
    -2.389984, 0.315512, -0.561905, 1.948740, -6.408485, 2.231985, -0.603652,
    0.661829, -5.070386, -1.063058, -0.624796, 1.375772, -4.379606, 1.929358,
    -1.047263, 0.739100, -5.217857, 2.127625, -5.025338, 0.650344, -2.068460,
    0.076936, -0.457505, -1.050984, -1.917765, 1.150908, 0.782625, 0.855595,
    -5.321719, 0.787209, -0.460232, 1.106736, -5.552326, 2.801043, -0.360217,
    -0.434432, -4.273378, 0.967556, -0.972652, 0.874811, -5.429918, -0.331039,
    0.115477, 0.111883, -5.418786, 1.240546, -1.842794, 0.505880, -3.676064,
    -0.682369, 1.858984, -0.742566, -5.784060, 0.673239, -1.280398, 0.280842,
    -4.848077, 2.214860, -0.785100, -0.588488, -2.438206, 0.786651, -1.568752,
    1.935400, -6.320256, 2.125338, -1.476457, -1.651941, -2.695734, 0.007338,
    -3.280860, 2.310385, -5.319578, 1.890123, -0.775723, 0.630606, -4.321582,
    1.085521, -1.847371, 1.188521, -4.596577, 2.056443, -2.340172, -0.108501,
    -3.156392, 0.933279, -0.495331, 0.122405, -5.171133, 1.763245, -0.796913,
    2.310487, -7.247197, 2.401678, -1.908860, 0.043798, -2.393796, 0.573806,
    -0.608531, 0.154710, -4.669001, 0.750680, 0.468380, 0.392591, -4.755001,
    2.615217, -1.957774, 1.153513, -4.530099, 1.124362, -3.569415, 1.697154,
    -3.536335, 0.910758, -2.976264, 1.833129, -4.287203, -0.547050, -2.409768,
    0.061585, -1.324116, 0.268497, -2.962222, -1.524245, -2.063413, 0.442058,
    -4.292337, 3.538863, -6.699603, 1.718664, -2.290363, 1.994596, -6.245037,
    -0.433084, -0.367059, 1.020297, -4.940721, 2.902264, -0.577056, -0.709887,
    -5.001413, -0.268316, -1.112048, -1.083307, -1.753492, 0.209973, 0.139540,
    0.917602, -5.232745, 2.538467, -2.139234, -0.187388, -1.837249, -0.478582,
    -0.731653, -0.481550, -2.531261, 1.044770, 0.707750, 0.279971, -3.221119,
    1.552074, -2.373144, 0.859518, -3.665156, 1.620278, -1.440871, -0.525581,
    -2.758271, 1.491873, -2.302013, 1.119935, -5.257080, 2.627170, -3.174739,
    1.363282, -4.831639, 1.101076, -4.337008, 2.689639, -5.165915, 1.069201,
    -1.882078, -0.120370, -2.287967, 1.147619, -1.403616, 1.077150, -5.084296,
    1.658236, -0.919642, 0.487423, -3.001075, 0.741268, 0.107300, 0.943556,
    -3.544311, 1.000239, -1.627171, 2.871253, -5.179172, 1.429893, -0.826040,
    0.188670, -4.499894, 1.013447, -2.101299, 0.317516, -3.452141, -0.833776,
    -1.362144, 1.272437, -4.449355, 1.613591, -2.039873, 2.613175, -6.229640,
    1.659790, -1.595520, -0.237462, -2.744997, 0.337841, 0.148981, -1.703771,
    -2.388023, 1.276469, 1.058508, -0.401642, -4.680769, 0.861881, -1.336381,
    1.153080, -2.834378, 0.721075, 0.900115, 1.360511, -5.573611, 0.949182,
    -2.970844, 2.017563, -5.186108, -0.201038, -1.192824, 0.610142, -4.450919,
    -0.897114, -1.812093, 0.422310, -5.245487, 0.256549, 0.320275, -2.324150,
    -2.967040, -0.260536, -0.721467, 0.454148, -5.058031, 0.526370, -0.895656,
    0.732240, -3.327363, 1.353953, -1.277912, -0.483171, -1.926713, 0.065044,
    -2.167506, -0.196606, -1.923437, 0.604962, -2.088319, 1.406834, -5.227296,
    2.247351, -4.421744, 1.729791, -5.007922, 1.264769, -0.897019, 0.922902,
    -3.887108, 2.087432, -1.310226, -0.101938, -3.359082, -0.079662, -0.514988,
    -0.963179, -4.038209, 2.223278, -0.590083, -2.310458, -1.748338, 0.363406,
    -0.540731, -0.885913, -4.179595, 2.216781, -3.044339, -0.447100, -2.446098,
    0.931101, -1.676190, 2.096175, -4.980755, 2.262151, -1.095047, 1.897516,
    -5.996138, 2.191038, 0.297128, -0.780974, -2.884299, 1.195408, -0.521065,
    -1.955837, -3.091064, -0.404183, -1.961519, 4.076096, -7.521851, 2.242064,
    -1.988043, 0.303300, -2.422585, 0.322230, -3.377634, 3.499955, -7.084434,
    2.375587, -0.718851, 2.150076, -5.412241, 2.374280, -2.006088, 2.229828,
    -5.848188, 2.543077, -2.171042, 2.096026, -5.300007, 0.141405, -1.187745,
    0.105340, -4.003816, 1.034281, -3.980804, 1.856709, -5.103042, 0.623737,
    -2.080307, 0.896140, -3.104050, 0.983158, -0.424898, -1.154270, -3.805728,
    1.978917, -1.314387, 1.235096, -3.148906, 1.113173, 0.111713, 2.055213,
    -7.565283, 2.100342};
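
// One bias value per unit (16); the tests share it between the forward and
// backward layers.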
const std::initializer_list<float> biases = {
    0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568,
    -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178,
    0.37197268, 0.61957061, 0.3956964, -0.37609905};

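// 16x16 recurrent (hidden-to-hidden) weights: 0.1 on the diagonal, zero
// elsewhere.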
const std::initializer_list<float> recurrent_weights = {
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1};

class BidirectionalRNNOpModel : public SingleOpModel {
 public:
  BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units,
                          int bw_units, int input_size, int aux_input_size,
                          AuxInputMode aux_input_mode, bool time_major,
                          bool merge_outputs, bool quantize_weights = false,
                          bool asymmetric_quantize_weights = false)
      : batches_(batches),
        sequence_len_(sequence_len),
        fw_units_(fw_units),
        bw_units_(bw_units),
        input_size_(input_size),
        aux_input_size_(aux_input_size),
        quantize_weights_(quantize_weights) {
    const TensorType tensor_type =
        quantize_weights ? TensorType_UINT8 : TensorType_FLOAT32;
    input_ = AddInput(TensorType_FLOAT32);
    fw_weights_ = AddInput(tensor_type);
    fw_recurrent_weights_ = AddInput(tensor_type);
    fw_bias_ = AddInput(TensorType_FLOAT32);
    fw_hidden_state_ = AddVariableInput(TensorType_FLOAT32);
    bw_weights_ = AddInput(tensor_type);
    bw_recurrent_weights_ = AddInput(tensor_type);
    bw_bias_ = AddInput(TensorType_FLOAT32);
    bw_hidden_state_ = AddVariableInput(TensorType_FLOAT32);

    const auto input_shape =
        (time_major) ? std::vector<int>({sequence_len_, batches_, input_size_})
                     : std::vector<int>({batches_, sequence_len_, input_size_});

    std::vector<int> aux_input_shape = {0};
    std::vector<int> aux_fw_weights_shape = {0};
    std::vector<int> aux_bw_weights_shape = {0};
    if (aux_input_mode != AuxInputMode::kNoAuxInput) {
      aux_input_ = AddInput(TensorType_FLOAT32);
      aux_input_shape =
          (time_major)
              ? std::vector<int>({sequence_len_, batches_, aux_input_size_})
              : std::vector<int>({batches_, sequence_len_, aux_input_size_});
    } else {
      aux_input_ = AddNullInput();
    }

    if (aux_input_mode == AuxInputMode::kCrossLinking) {
      aux_fw_weights_ = AddInput(tensor_type);
      aux_bw_weights_ = AddInput(tensor_type);

      aux_fw_weights_shape = {fw_units, aux_input_size_};
      aux_bw_weights_shape = {bw_units, aux_input_size_};
    } else {
      aux_fw_weights_ = AddNullInput();
      aux_bw_weights_ = AddNullInput();
    }

    fw_output_ = AddOutput(TensorType_FLOAT32);
    if (!merge_outputs) {
      bw_output_ = AddOutput(TensorType_FLOAT32);
    }

    SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
                 BuiltinOptions_BidirectionalSequenceRNNOptions,
                 CreateBidirectionalSequenceRNNOptions(
                     builder_, time_major, ActivationFunctionType_RELU,
                     merge_outputs, asymmetric_quantize_weights)
                     .Union());

    BuildInterpreter({
        input_shape,               // input
        {fw_units_, input_size_},  // fw_weights
        {fw_units_, fw_units_},    // fw_recurrent_weights
        {fw_units_},               // fw_bias
        {batches_, fw_units_},     // fw_hidden_state
        {bw_units_, input_size_},  // bw_weights
        {bw_units_, bw_units_},    // bw_recurrent_weights
        {bw_units_},               // bw_bias
        {batches_, bw_units_},     // bw_hidden_state
        aux_input_shape,           // aux_input
        aux_fw_weights_shape,      // aux_fw_weights
        aux_bw_weights_shape,      // aux_bw_weights
    });
  }

  void SetFwBias(std::initializer_list<float> f) {
    PopulateTensor(fw_bias_, f);
  }

  void SetBwBias(std::initializer_list<float> f) {
    PopulateTensor(bw_bias_, f);
  }

  void SetFwWeights(const std::vector<float>& f) {
    if (quantize_weights_) {
      SymmetricQuantizeAndPopulate(fw_weights_, f);
    } else {
      PopulateTensor(fw_weights_, f);
    }
  }

  void SetBwWeights(const std::vector<float>& f) {
    if (quantize_weights_) {
      SymmetricQuantizeAndPopulate(bw_weights_, f);
    } else {
      PopulateTensor(bw_weights_, f);
    }
  }

  void SetFwRecurrentWeights(const std::vector<float>& f) {
    if (quantize_weights_) {
      SymmetricQuantizeAndPopulate(fw_recurrent_weights_, f);
    } else {
      PopulateTensor(fw_recurrent_weights_, f);
    }
  }

  void SetBwRecurrentWeights(const std::vector<float>& f) {
    if (quantize_weights_) {
      SymmetricQuantizeAndPopulate(bw_recurrent_weights_, f);
    } else {
      PopulateTensor(bw_recurrent_weights_, f);
    }
  }

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }

  void SetInput(int offset, float* begin, float* end) {
    PopulateTensor(input_, offset, begin, end);
  }

  void SetAuxInput(int offset, float* begin, float* end) {
    PopulateTensor(aux_input_, offset, begin, end);
  }

  void SetAuxFwWeights(const std::vector<float>& f) {
    if (quantize_weights_) {
      SymmetricQuantizeAndPopulate(aux_fw_weights_, f);
    } else {
      PopulateTensor(aux_fw_weights_, f);
    }
  }

  void SetAuxBwWeights(const std::vector<float>& f) {
    if (quantize_weights_) {
      SymmetricQuantizeAndPopulate(aux_bw_weights_, f);
    } else {
      PopulateTensor(aux_bw_weights_, f);
    }
  }

  std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); }
  std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); }

  int input_size() { return input_size_; }
  int aux_input_size() { return aux_input_size_; }
  int num_fw_units() { return fw_units_; }
  int num_bw_units() { return bw_units_; }
  int num_batches() { return batches_; }
  int sequence_len() { return sequence_len_; }

 private:
  int input_;
  int fw_weights_;
  int fw_recurrent_weights_;
  int fw_bias_;
  int fw_hidden_state_;
  int fw_output_;
  int bw_weights_;
  int bw_recurrent_weights_;
  int bw_bias_;
  int bw_hidden_state_;
  int bw_output_;
  int aux_input_;
  int aux_fw_weights_;
  int aux_bw_weights_;

  int batches_;
  int sequence_len_;
  int fw_units_;
  int bw_units_;
  int input_size_;
  int aux_input_size_;
  bool quantize_weights_;
};

// Declare BidirectionalRNNOpTest as a parameterized test.
class BidirectionalRNNOpTest
    : public ::testing::TestWithParam<::testing::tuple<bool, bool>> {};

INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, BidirectionalRNNOpTest,
                         ::testing::Combine(
                             /*quantize_weights*/ ::testing::Bool(),
                             /*asymmetric_quantize_inputs*/ ::testing::Bool()));

// TODO(mirkov): add another test which directly compares to TF once TOCO
// supports the conversion from dynamic_rnn with BasicRNNCell.
TEST_P(BidirectionalRNNOpTest, ClosedBoxTest) {
  auto params = GetParam();
  const bool quantize_weights = std::get<0>(params);
  const bool asymmetric_quantize_inputs = std::get<1>(params);
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/0,
                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                              /*time_major=*/false,
                              /*merge_outputs=*/false, quantize_weights,
                              asymmetric_quantize_inputs);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
  float* batch_start = rnn_input;
  float* batch_end = batch_start + input_sequence_size;
  rnn.SetInput(0, batch_start, batch_end);
  rnn.SetInput(input_sequence_size, batch_start, batch_end);

  rnn.Invoke();

  float* golden_fw_start = rnn_golden_fw_output;
  float* golden_fw_end =
      golden_fw_start + rnn.num_fw_units() * rnn.sequence_len();
  std::vector<float> fw_expected;
  fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  EXPECT_THAT(rnn.GetFwOutput(),
              ElementsAreArray(ArrayFloatNear(
                  fw_expected, quantize_weights ? 1.42e-2 : 1e-5)));

  float* golden_bw_start = rnn_golden_bw_output;
  float* golden_bw_end =
      golden_bw_start + rnn.num_bw_units() * rnn.sequence_len();
  std::vector<float> bw_expected;
  bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  EXPECT_THAT(rnn.GetBwOutput(),
              ElementsAreArray(ArrayFloatNear(
                  bw_expected, quantize_weights ? 1.42e-2 : 1e-5)));
}

// Same as ClosedBox test, but input is reshuffled to time_major format.
TEST_P(BidirectionalRNNOpTest, ClosedBoxTestTimeMajor) {
  auto params = GetParam();
  const bool quantize_weights = std::get<0>(params);
  const bool asymmetric_quantize_inputs = std::get<1>(params);
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/0,
                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                              /*time_major=*/true,
                              /*merge_outputs=*/false, quantize_weights,
                              asymmetric_quantize_inputs);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  // Insert the inputs in time_major format. The batch_major format is:
  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // The two batches are identical.
    rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
    rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
  }

  rnn.Invoke();

  std::vector<float> fw_expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units();
    float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  }
  constexpr float kHybridTolerance = 3.57e-1;
  constexpr float kFloatTolerance = 1e-5;
  EXPECT_THAT(
      rnn.GetFwOutput(),
      ElementsAreArray(ArrayFloatNear(
          fw_expected, quantize_weights ? kHybridTolerance : kFloatTolerance)));
}

// Same as ClosedBox test, but with merged outputs.
TEST_P(BidirectionalRNNOpTest, ClosedBoxTestMergeOutputs) {
  auto params = GetParam();
  const bool quantize_weights = std::get<0>(params);
  const bool asymmetric_quantize_inputs = std::get<1>(params);
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/0,
                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                              /*time_major=*/false,
                              /*merge_outputs=*/true, quantize_weights,
                              asymmetric_quantize_inputs);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
  float* batch_start = rnn_input;
  float* batch_end = batch_start + input_sequence_size;
  rnn.SetInput(0, batch_start, batch_end);
  rnn.SetInput(input_sequence_size, batch_start, batch_end);

  rnn.Invoke();

  std::vector<float> merged_expected;
  for (int bid = 0; bid < rnn.num_batches(); bid++) {
    for (int step = 0; step < rnn.sequence_len(); step++) {
      merged_expected.insert(
          merged_expected.end(),
          rnn_golden_fw_output + rnn.num_fw_units() * step,
          rnn_golden_fw_output + rnn.num_fw_units() * (step + 1));
      merged_expected.insert(
          merged_expected.end(),
          rnn_golden_bw_output + rnn.num_bw_units() * step,
          rnn_golden_bw_output + rnn.num_bw_units() * (step + 1));
    }
  }
  EXPECT_THAT(rnn.GetFwOutput(),
              ElementsAreArray(ArrayFloatNear(
                  merged_expected, quantize_weights ? 1.42e-2 : 1e-5)));
}

998 // Same as ClosedBox test, but input is reshuffled to time_major format.
TEST(BidirectionalRNNOpTest,ClosedBoxTestTimeMajorMergeOutputs)999 TEST(BidirectionalRNNOpTest, ClosedBoxTestTimeMajorMergeOutputs) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/0,
                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                              /*time_major=*/true,
                              /*merge_outputs=*/true);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  // Insert the inputs in time_major format. The batch_major format is:
  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // The two batches are identical.
    rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
    rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
  }

  rnn.Invoke();

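  // In time_major mode the merged output is step-major:
  // [t0b0, t0b1, t1b0, t1b1, ...], so the golden data is assembled per step
  // across both batches.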
  std::vector<float> merged_expected;
  for (int step = 0; step < rnn.sequence_len(); step++) {
    for (int bid = 0; bid < rnn.num_batches(); bid++) {
      merged_expected.insert(
          merged_expected.end(),
          rnn_golden_fw_output + rnn.num_fw_units() * step,
          rnn_golden_fw_output + rnn.num_fw_units() * (step + 1));
      merged_expected.insert(
          merged_expected.end(),
          rnn_golden_bw_output + rnn.num_bw_units() * step,
          rnn_golden_bw_output + rnn.num_bw_units() * (step + 1));
    }
  }
  EXPECT_THAT(rnn.GetFwOutput(),
              ElementsAreArray(ArrayFloatNear(merged_expected)));
}

// Check that if the input sequence is reversed, the outputs are the same, just
// with forward and backward swapped (and reversed).
TEST(BidirectionalRNNOpTest, ClosedBoxTestReverseInputs) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/0,
                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                              /*time_major=*/false,
                              /*merge_outputs=*/false);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  // Reverse the inputs in each batch: in_1, in_2, ..., in_k is inserted in the
  // following order: [in_k, ..., in_2, in_1, in_k, ..., in_2, in_1].
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    const int reverse_idx = rnn.sequence_len() - i - 1;
    rnn.SetInput(reverse_idx * rnn.input_size(), batch_start, batch_end);
    rnn.SetInput((rnn.sequence_len() + reverse_idx) * rnn.input_size(),
                 batch_start, batch_end);
  }

  rnn.Invoke();

  // The forward and backward outputs are swapped.
  // Inserting at begin() builds the expected sequence in reverse order.
  std::vector<float> fw_expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_fw_start = rnn_golden_bw_output + i * rnn.num_fw_units();
    float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
    fw_expected.insert(fw_expected.begin(), golden_fw_start, golden_fw_end);
  }
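  // The two batches are identical, so the expected output is the reversed
  // sequence twice.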
  // Copy before appending: inserting a vector's own range into itself is
  // undefined behavior.
  std::vector<float> fw_batch0(fw_expected);
  fw_expected.insert(fw_expected.end(), fw_batch0.begin(), fw_batch0.end());
  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));

  std::vector<float> bw_expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_bw_start = rnn_golden_fw_output + i * rnn.num_bw_units();
    float* golden_bw_end = golden_bw_start + rnn.num_bw_units();
    bw_expected.insert(bw_expected.begin(), golden_bw_start, golden_bw_end);
  }
  std::vector<float> bw_batch0(bw_expected);
  bw_expected.insert(bw_expected.end(), bw_batch0.begin(), bw_batch0.end());
  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}

// Tests an end-to-end neural network with a Bidirectional RNN followed by a
// DNN that aggregates the outputs from the two directions.
TEST(BidirectionalRNNOpTest, EndToEndTest) {
  BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/0,
                              /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
                              /*time_major=*/false,
                              /*merge_outputs=*/false);
  const int output_size = 4;
  float dnn_weights[] = {
      -0.5782342, -0.052212059, 0.73036242, -0.81216097, -0.80088139,
      -0.23420811, -0.39647382, 0.31423986, 0.61819065, -0.73659575,
      -0.89698344, -0.8931554, -0.0845688, 0.5617367, 0.38415289,
      -0.11487955, -0.7617774, 0.17927337, 0.15726972, 0.059798479,
      0.19009054, -0.27616632, -0.39142907, 0.77744663, -0.046830714,
      -0.6603595, 0.21945822, 0.051494241, 0.23785079, 0.19239247,
      -0.53268754, 0.65961659, -0.85981959, -0.80232513, 0.84745562,
      -0.66070104, -0.036533296, -0.54901814, 0.65353882, -0.41834265,
      -0.28561389, 0.75655544, -0.31149811, 0.62981737, 0.31829214,
      -0.92734522, -0.48506218, 0.55651462, 0.25192821, 0.67220747,
      -0.3836869, -0.55798125, -0.60395885, 0.22488403, -0.78053463,
      0.3492105, 0.56452453, 0.4389236, -0.59929526, -0.19762468,
      -0.36868393, -0.13198286, -0.53800809, -0.22850353};

  std::initializer_list<float> dnn_biases = {0.29177809, -0.98799044,
                                             0.065919638, 0.68781924};

  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
  const int output_sequence_size = output_size * rnn.sequence_len();
  const int num_examples = 64;
  for (int k = 0; k < num_examples; k++) {
    float* batch_start = endtoend_input + k * input_sequence_size;
    float* batch_end = batch_start + input_sequence_size;
    rnn.SetInput(0, batch_start, batch_end);

    rnn.Invoke();

    std::vector<float> fw_output = rnn.GetFwOutput();
    std::vector<float> bw_output = rnn.GetBwOutput();
    EXPECT_EQ(fw_output.size(), bw_output.size());

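    // Aggregate the two directions by summing the forward and backward
    // activations elementwise; the DNN below consumes the sum.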
    std::transform(fw_output.begin(), fw_output.end(), bw_output.begin(),
                   fw_output.begin(), std::plus<float>());

    std::vector<float> sequence_result;
    for (int s = 0; s < rnn.sequence_len(); s++) {
      const float* rnn_output = fw_output.data() + s * rnn.num_fw_units();
      std::vector<float> results(dnn_biases);
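      // Fully connected layer: results[i] += sum_j rnn_output[j] * W[j][i],
      // with dnn_weights stored row-major with shape
      // [num_fw_units, output_size].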
      for (int i = 0; i < output_size; i++) {
        for (int j = 0; j < rnn.num_fw_units(); j++) {
          results[i] += *(rnn_output + j) * dnn_weights[output_size * j + i];
        }
      }
      sequence_result.insert(sequence_result.end(), results.begin(),
                             results.end());
    }

    float* golden_start = golden_endtoend_output + k * output_sequence_size;
    float* golden_end = golden_start + output_sequence_size;

    std::vector<float> expected;
    expected.insert(expected.end(), golden_start, golden_end);
    EXPECT_THAT(sequence_result, ElementsAreArray(ArrayFloatNear(expected)));
  }
}

// Same as the ClosedBox test, but with an auxiliary input. The layer has no
// cross-linking, i.e. the regular input is passed as an input to the forward
// network only and the auxiliary input is passed as an input to the backward
// network only.
TEST(BidirectionalRNNOpTest, ClosedBoxTestNoCrossLinkingRegularAndAuxInput) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/8,
                              /*aux_input_mode=*/AuxInputMode::kNoCrossLinking,
                              /*time_major=*/true,
                              /*merge_outputs=*/false);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  // Insert the inputs in time_major format. The batch_major format is:
  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // The two batches are identical. Also make the aux input the same as the
    // regular input.
    rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
    rnn.SetAuxInput(2 * i * rnn.input_size(), batch_start, batch_end);
    rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
    rnn.SetAuxInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
  }

  rnn.Invoke();

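  // The aux input equals the regular input, so the forward net (fed the
  // regular input) and the backward net (fed the aux input) should both match
  // their golden outputs.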
  std::vector<float> fw_expected;
  std::vector<float> bw_expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units();
    float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);

    float* golden_bw_start = rnn_golden_bw_output + i * rnn.num_bw_units();
    float* golden_bw_end = golden_bw_start + rnn.num_bw_units();
    bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
    bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  }
  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}

// Same as above, but the auxiliary input is set to zeros. This test makes sure
// that the forward network works as expected in no-cross-linking mode.
TEST(BidirectionalRNNOpTest, ClosedBoxTestNoCrossLinkingRegularInputOnly) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/8,
                              /*aux_input_mode=*/AuxInputMode::kNoCrossLinking,
                              /*time_major=*/true,
                              /*merge_outputs=*/false);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  // Initialize the bw (auxiliary) inputs with zeros.
  std::vector<float> bw_inputs(rnn.input_size(), 0);

  // Insert the inputs in time_major format. The batch_major format is:
  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // The two batches are identical; the aux input is all zeros.
    rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
    rnn.SetAuxInput(2 * i * rnn.input_size(), bw_inputs.data(),
                    bw_inputs.data() + bw_inputs.size());
    rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
    rnn.SetAuxInput((2 * i + 1) * rnn.input_size(), bw_inputs.data(),
                    bw_inputs.data() + bw_inputs.size());
  }

  rnn.Invoke();

  std::vector<float> fw_expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units();
    float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  }
  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
}

// Same as above, but the regular (i.e. non-auxiliary) input is set to zeros.
// This test makes sure that the backward network works as expected in
// no-cross-linking mode.
TEST(BidirectionalRNNOpTest, ClosedBoxTestNoCrossLinkingAuxInputOnly) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/8,
                              /*aux_input_mode=*/AuxInputMode::kNoCrossLinking,
                              /*time_major=*/true,
                              /*merge_outputs=*/false);
  rnn.SetFwWeights(weights);
  rnn.SetBwWeights(weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);

  // Initialize the fw (regular) inputs with zeros.
  std::vector<float> fw_inputs(rnn.input_size(), 0);

  // Insert the inputs in time_major format. The batch_major format is:
  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // The two batches are identical; the real data goes to the aux input and
    // the regular input is all zeros.
    rnn.SetAuxInput(2 * i * rnn.input_size(), batch_start, batch_end);
    rnn.SetInput(2 * i * rnn.input_size(), fw_inputs.data(),
                 fw_inputs.data() + fw_inputs.size());
    rnn.SetAuxInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
    rnn.SetInput((2 * i + 1) * rnn.input_size(), fw_inputs.data(),
                 fw_inputs.data() + fw_inputs.size());
  }

  rnn.Invoke();

  std::vector<float> bw_expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_bw_start = rnn_golden_bw_output + i * rnn.num_bw_units();
    float* golden_bw_end = golden_bw_start + rnn.num_bw_units();
    bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
    bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  }
  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}

// Same as the ClosedBox test, but the input is passed to the auxiliary input
// instead of the regular one. The regular input and its weights are set to
// zero.
TEST(BidirectionalRNNOpTest, ClosedBoxTestCrossLinkingAuxInputOnly) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/8,
                              /*aux_input_mode=*/AuxInputMode::kCrossLinking,
                              /*time_major=*/false,
                              /*merge_outputs=*/false);
  rnn.SetFwWeights(std::vector<float>(weights.size(), 0.0));
  rnn.SetBwWeights(std::vector<float>(weights.size(), 0.0));
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);
  rnn.SetAuxFwWeights(weights);
  rnn.SetAuxBwWeights(weights);

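  // With the regular weights zeroed, only the cross-linked auxiliary path
  // contributes, so the outputs should still match the original golden data.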
  const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
  std::vector<float> zero_input(input_sequence_size, 0.f);
  float* batch_start = rnn_input;
  float* batch_end = batch_start + input_sequence_size;
  // Set batch 0 inputs.
  rnn.SetInput(0, zero_input.data(), zero_input.data() + zero_input.size());
  rnn.SetAuxInput(0, batch_start, batch_end);
  // Set batch 1 inputs.
  rnn.SetInput(input_sequence_size, zero_input.data(),
               zero_input.data() + zero_input.size());
  rnn.SetAuxInput(input_sequence_size, batch_start, batch_end);

  rnn.Invoke();

  float* golden_fw_start = rnn_golden_fw_output;
  float* golden_fw_end =
      golden_fw_start + rnn.num_fw_units() * rnn.sequence_len();
  std::vector<float> fw_expected;
  fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));

  float* golden_bw_start = rnn_golden_bw_output;
  float* golden_bw_end =
      golden_bw_start + rnn.num_bw_units() * rnn.sequence_len();
  std::vector<float> bw_expected;
  bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}

// Same as the ClosedBox test, but the input is passed to the auxiliary input
// instead of the regular one. The regular input and its weights are set to
// zero. Inputs and outputs are in time_major format.
TEST(BidirectionalRNNOpTest, ClosedBoxTestCrossLinkingAuxInputOnlyTimeMajor) {
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/16, /*bw_units=*/16,
                              /*input_size=*/8, /*aux_input_size=*/8,
                              /*aux_input_mode=*/AuxInputMode::kCrossLinking,
                              /*time_major=*/true,
                              /*merge_outputs=*/false);
  rnn.SetFwWeights(std::vector<float>(weights.size(), 0.0));
  rnn.SetBwWeights(std::vector<float>(weights.size(), 0.0));
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);
  rnn.SetAuxFwWeights(weights);
  rnn.SetAuxBwWeights(weights);

  std::vector<float> zero_input(rnn.input_size(), 0.f);

  // Insert the inputs in time_major format. The batch_major format is:
  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // The two batches are identical.
    // Set batch 0 inputs.
    rnn.SetInput(2 * i * rnn.input_size(), &zero_input.front(),
                 &zero_input.back() + 1);
    rnn.SetAuxInput(2 * i * rnn.input_size(), batch_start, batch_end);
    // Set batch 1 inputs.
    rnn.SetInput((2 * i + 1) * rnn.input_size(), &zero_input.front(),
                 &zero_input.back() + 1);
    rnn.SetAuxInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
  }

  rnn.Invoke();

  std::vector<float> fw_expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units();
    float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  }
  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
}

// Same as the ClosedBox test, but the input tensor and the weights tensor are
// split along the last dimension and passed to both the regular and the
// auxiliary inputs and weights. The output is the same. To see why, define W
// and V as the regular and auxiliary input weight matrices, respectively. This
// setup is equivalent to a regular RNN with weights U = (W|V) and input
// z = (x^T|y^T)^T, where (.|.) denotes concatenation along the horizontal
// axis, since
//   f(z) = Uz + b
// is equivalent to
//   f((x^T|y^T)^T) = (Wx + Vy) + b.
void run_closedbox_test_with_input_split(int input_size, int aux_input_size) {
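  // For instance, the uneven-split case below uses input_size=2 and
  // aux_input_size=6: each 8-wide row of the original weight matrix is split
  // into a 2-wide regular chunk and a 6-wide auxiliary chunk, so (W|V)
  // reconstructs the original row.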
  const int num_units = 16;
  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                              /*fw_units=*/num_units, /*bw_units=*/num_units,
                              input_size, aux_input_size,
                              /*aux_input_mode=*/AuxInputMode::kCrossLinking,
                              /*time_major=*/false,
                              /*merge_outputs=*/false);
  std::vector<float> reg_weights(num_units * rnn.input_size());
  std::vector<float> aux_weights(num_units * rnn.aux_input_size());
  int full_weights_size = weights.size();
  int reg_weights_offset = 0;
  int aux_weights_offset = 0;
  int weights_offset = 0;
  // Alternate between copying to the regular and the auxiliary input weights
  // to split the original weight matrix along the last axis.
  while (weights_offset < full_weights_size) {
    std::copy(weights.begin() + weights_offset,
              weights.begin() + weights_offset + rnn.input_size(),
              reg_weights.begin() + reg_weights_offset);
    weights_offset += rnn.input_size();
    reg_weights_offset += rnn.input_size();

    std::copy(weights.begin() + weights_offset,
              weights.begin() + weights_offset + rnn.aux_input_size(),
              aux_weights.begin() + aux_weights_offset);
    weights_offset += rnn.aux_input_size();
    aux_weights_offset += rnn.aux_input_size();
  }

  rnn.SetFwWeights(reg_weights);
  rnn.SetBwWeights(reg_weights);
  rnn.SetFwBias(biases);
  rnn.SetBwBias(biases);
  rnn.SetFwRecurrentWeights(recurrent_weights);
  rnn.SetBwRecurrentWeights(recurrent_weights);
  rnn.SetAuxFwWeights(aux_weights);
  rnn.SetAuxBwWeights(aux_weights);

  int full_input_size =
      (rnn.input_size() + rnn.aux_input_size()) * rnn.sequence_len();
  int reg_input_offset = 0;
  int aux_input_offset = 0;
  // Alternate between copying to the regular and the auxiliary input tensor to
  // split the original input matrix along the last axis.
  for (int batch = 0; batch < 2; ++batch) {
    int input_offset = 0;
    while (input_offset < full_input_size) {
      rnn.SetInput(reg_input_offset, rnn_input + input_offset,
                   rnn_input + input_offset + rnn.input_size());
      input_offset += rnn.input_size();
      reg_input_offset += rnn.input_size();

      rnn.SetAuxInput(aux_input_offset, rnn_input + input_offset,
                      rnn_input + input_offset + rnn.aux_input_size());
      input_offset += rnn.aux_input_size();
      aux_input_offset += rnn.aux_input_size();
    }
  }

  rnn.Invoke();

  float* golden_fw_start = rnn_golden_fw_output;
  float* golden_fw_end =
      golden_fw_start + rnn.num_fw_units() * rnn.sequence_len();
  std::vector<float> fw_expected;
  fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));

  float* golden_bw_start = rnn_golden_bw_output;
  float* golden_bw_end =
      golden_bw_start + rnn.num_bw_units() * rnn.sequence_len();
  std::vector<float> bw_expected;
  bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
  EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}

TEST(BidirectionalRNNOpTest,
     ClosedBoxTestCrossLinkingRegularAndAuxInputEvenSplit) {
  run_closedbox_test_with_input_split(/*input_size=*/4, /*aux_input_size=*/4);
}

// Same as above, but the input tensor and the weights tensor are split
// unevenly.
TEST(BidirectionalRNNOpTest,
     ClosedBoxTestCrossLinkingRegularAndAuxInputUnevenSplit) {
  run_closedbox_test_with_input_split(/*input_size=*/2, /*aux_input_size=*/6);
}

}  // namespace
}  // namespace tflite