/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include <cstring>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
21 namespace tflite {
22 namespace testing {
23 namespace {
24
// Naming convention: <tensor name>_<input size>x<batch size>x<batch count>.

// 10 inputs, each with shape {2, 2} (40 values total).
const float input_data_2x2x10[] = {
    0.12609188,  -0.46347019, 0.35867718,  0.36897406,

    0.14278367,  -1.64410412, -0.57290924, 0.12729003,

    0.49837467,  0.19278903,  0.17660543,  0.52949083,

    -0.11186574, 0.13164264,  -0.72674477, -0.5683046,

    -0.68892461, 0.37783599,  -0.63690937, 0.44483393,

    -0.81299269, -0.86831826, -0.95760226, 1.82078898,

    -1.45006323, -0.82251364, -1.65087092, -1.89238167,

    0.03966608,  -0.24936394, 2.06740379,  -1.51439476,

    0.11771342,  -0.23761693, 0.31088525,  -1.55601168,

    -0.89477462, 1.67204106,  -0.6230064,  0.29819036,
};
49
// Feature filter of shape {8, 2} (16 values).
const float feature_weights_data_2x2x10[] = {
    -0.31930989, 0.0079667,  0.39296314,  0.37613347,  0.12416199,  0.15785322,
    0.27901134,  0.3905206,  0.21931258,  -0.36137494, -0.10640851, 0.31053296,
    -0.36118156, -0.0976817, -0.36916667, 0.22197971};
55
// Time filter of shape {8, 10} (80 values).
const float time_weights_data_2x2x10[] = {
    -0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
    0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

    0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
    -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

    -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
    0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

    -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
    -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657,

    -0.14884081, 0.19931212,  -0.36002168, 0.34663299,  -0.11405486,
    0.12672701,  0.39463779,  -0.07886535, -0.06384811, 0.08249187,

    -0.26816407, -0.19905911, 0.29211238,  0.31264046,  -0.28664589,
    0.05698794,  0.11613581,  0.14078894,  0.02187902,  -0.21781836,

    -0.15567942, 0.08693647,  -0.38256618, 0.36580828,  -0.22922277,
    -0.0226903,  0.12878349,  -0.28122205, -0.10850525, -0.11955214,

    0.27179423,  -0.04710215, 0.31069002,  0.22672787,  0.09580326,
    0.08682203,  0.1258215,   0.1851041,   0.29228821,  0.12366763};
81
// Activation state with shape {2, 80} (160 zeros). These initial values must
// be copied into a mutable activation state tensor before each test run.
const float initial_activation_state_data_2x2x10[] = {
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
93
// Bias with shape {8}; all zeros so the kernel output is purely the
// feature/time filter response.
const float bias_data_2x2x10[] = {0, 0, 0, 0, 0, 0, 0, 0};
96
// 10 golden outputs, each of shape {2, 4} (80 values total); one per input in
// input_data_2x2x10 above.
const float golden_output_2x2x10[] = {
    -0.044205, -0.013757, 0.050369,  -0.018447,
    0.073010,  0.025142,  -0.021154, 0.013551,

    -0.209613, -0.062421, 0.150209,  -0.108334,
    0.028256,  -0.006950, -0.030885, 0.009603,

    -0.076800, -0.037075, -0.087198, -0.155183,
    0.091069,  0.098446,  -0.016083, 0.106475,

    -0.082123, -0.162238, -0.084434, -0.141074,
    -0.029340, -0.090685, 0.053302,  -0.030604,

    -0.201440, 0.088424,  0.139877,  0.012416,
    -0.113212, 0.103893,  -0.100842, 0.122780,

    -0.166632, -0.116705, 0.175298,  -0.047163,
    0.313077,  -0.166485, -0.285860, 0.129069,

    -0.625911, 0.046134,  0.138081,  -0.129581,
    -0.521455, -0.061579, 0.230289,  0.114963,

    -0.216693, -0.161643, -0.179177, -0.052599,
    -0.213239, 0.029502,  0.260858,  0.275045,

    -0.213689, -0.323608, -0.285635, -0.317687,
    -0.324092, -0.317972, -0.208450, -0.462504,

    -0.255126, -0.218576, -0.041528, 0.179421,
    -0.440583, 0.072127,  -0.284136, 0.241570};
128
// Simulated real-world inputs, weights and expected outputs.

// One input of shape {1, 16}.
const float input_data_16x1x1[] = {
    -0.488494, 2.023762,  -2.233117, -0.488494, 3.559030, 9.490748,
    -3.210106, -1.953977, -0.279140, 0.907204,  1.674838, 0.000000,
    -0.279140, -0.628064, -0.069785, -0.628064,
};
137
// Feature filter of shape {64, 16} (1024 values).
const float feature_weights_data_16x1x1[] = {
    0.173588, 0.173588, -0.024798, 0.193426, -0.099193, 0.044637, 0.183507,
    0.183507, 0.044637, 0.198386, -0.069435, 0.084314, 0.312458, 0.024798,
    0.173588, -0.049596, -0.352135, -0.550521, -0.009919, -0.099193, -0.074395,
    -0.128951, 0.193426, 0.357095, -0.317418, -0.119032, -0.218225, -0.004960,
    -0.386853, -0.133911, 0.252942, -0.019839, -0.024798, -0.054556, -0.069435,
    -0.128951, 0.029758, -0.099193, -0.312458, -0.029758, 0.064475, 0.183507,
    0.114072, -0.178547, -0.247982, -0.119032, 0.243023, -0.119032, -0.034718,
    -0.178547, 0.019839, 0.128951, -0.223184, -0.009919, -0.213265, 0.168628,
    -0.143830, -0.322377, -0.218225, -0.193426, -0.252942, -0.049596, 0.064475,
    -0.267821, -0.580279, -0.099193, 0.213265, 0.119032, -0.119032, -0.178547,
    0.610037, 0.109112, 0.049596, -0.014879, -0.049596, -0.193426, 0.039677,
    -0.148789, -0.114072, -0.158709, -0.158709, 0.094233, 0.099193, -0.114072,
    0.104153, -0.123991, 0.198386, -0.173588, 0.089274, -0.247982, -0.054556,
    0.123991, 0.183507, 0.114072, 0.188467, 0.302539, 0.044637, 0.039677,
    -0.099193, 0.168628, -0.024798, -0.054556, -0.109112, 0.014879, -0.009919,
    0.069435, -0.396772, -0.287660, -0.079354, -0.104153, 0.054556, 0.089274,
    -0.099193, 0.114072, 0.034718, 0.119032, 0.282700, -0.119032, -0.505884,
    -0.233104, -0.114072, -0.257902, -0.233104, -0.178547, 0.153749, 0.128951,
    0.143830, -0.188467, -0.183507, 0.104153, -0.024798, 0.193426, -0.287660,
    0.168628, -0.009919, 0.119032, -0.024798, -0.099193, -0.203346, 0.099193,
    0.084314, -0.168628, 0.123991, -0.148789, 0.114072, -0.029758, 0.228144,
    -0.238063, 0.089274, -0.064475, 0.307498, -0.188467, -0.004960, -0.252942,
    -0.173588, -0.158709, -0.044637, -0.009919, 0.312458, -0.262861, 0.059516,
    0.158709, 0.069435, -0.282700, 0.074395, -0.322377, -0.183507, -0.123991,
    -0.233104, 0.009919, 0.252942, -0.243023, 0.555481, -0.099193, -0.119032,
    -0.441409, 0.148789, 0.084314, -0.168628, -0.183507, 0.188467, 0.024798,
    -0.302539, 0.223184, 0.143830, -0.193426, -0.054556, -0.218225, -0.297579,
    0.104153, 0.272781, -0.034718, 0.114072, -0.059516, 0.044637, 0.342216,
    0.421570, 0.138870, -0.024798, -0.039677, -0.163668, -0.034718, 0.396772,
    -0.128951, -0.044637, -0.173588, 0.302539, 0.079354, 0.049596, 0.133911,
    -0.029758, -0.312458, -0.029758, 0.079354, 0.128951, 0.252942, 0.213265,
    0.014879, 0.287660, 0.178547, 0.297579, 0.352135, 0.401732, 0.024798,
    -0.277740, -0.411651, -0.069435, 0.342216, -0.158709, -0.104153, -0.009919,
    0.223184, 0.228144, -0.019839, 0.059516, -0.104153, -0.510844, 0.029758,
    -0.406691, 0.089274, 0.421570, 0.163668, -0.143830, -0.019839, -0.039677,
    0.104153, -0.044637, -0.128951, 0.203346, 0.079354, -0.069435, 0.094233,
    -0.138870, 0.466207, -0.163668, 0.049596, 0.029758, 0.267821, 0.029758,
    -0.049596, 0.009919, 0.004960, -0.099193, 0.094233, -0.262861, 0.089274,
    -0.302539, 0.332297, -0.307498, -0.014879, 0.168628, -0.094233, -0.272781,
    0.034718, -0.133911, -0.228144, 0.094233, 0.257902, -0.228144, 0.153749,
    -0.054556, -0.252942, 0.054556, 0.218225, -0.054556, 0.302539, 0.282700,
    0.054556, -0.044637, -0.133911, 0.233104, -0.049596, 0.411651, 0.044637,
    -0.297579, -0.029758, -0.114072, 0.114072, -0.580279, 0.079354, -0.024798,
    -0.347175, -0.128951, -0.099193, 0.238063, -0.104153, -0.009919, 0.158709,
    -0.034718, 0.123991, -0.163668, 0.059516, 0.342216, 0.009919, 0.064475,
    -0.307498, -0.520763, -0.238063, 0.163668, 0.362054, 0.034718, -0.178547,
    -0.104153, -0.257902, 0.322377, 0.054556, 0.148789, -0.178547, 0.084314,
    0.004960, 0.257902, 0.029758, 0.079354, -0.223184, -0.193426, 0.282700,
    0.000000, -0.019839, -0.114072, 0.491005, -0.193426, -0.029758, -0.243023,
    0.009919, 0.089274, -0.277740, -0.089274, 0.104153, 0.337256, 0.138870,
    -0.307498, -0.054556, 0.352135, 0.133911, -0.044637, 0.133911, -0.089274,
    -0.357095, -0.272781, 0.069435, 0.059516, -0.109112, 0.148789, -0.044637,
    -0.019839, -0.153749, 0.123991, -0.223184, 0.322377, 0.074395, -0.312458,
    0.024798, -0.223184, 0.109112, -0.138870, 0.218225, -0.074395, -0.406691,
    0.009919, -0.198386, -0.009919, 0.416611, 0.178547, 0.148789, 0.133911,
    -0.004960, 0.069435, -0.054556, -0.044637, 0.297579, 0.059516, -0.456288,
    -0.148789, -0.004960, 0.054556, 0.094233, -0.104153, 0.198386, -0.302539,
    0.133911, 0.411651, 0.054556, 0.525723, -0.089274, 0.079354, 0.238063,
    0.079354, -0.039677, 0.039677, 0.029758, 0.332297, -0.014879, -0.367014,
    -0.143830, -0.123991, -0.064475, 0.014879, 0.173588, -0.168628, 0.386853,
    0.009919, 0.173588, 0.163668, 0.123991, 0.163668, 0.198386, 0.203346,
    -0.401732, -0.009919, 0.272781, -0.173588, 0.044637, 0.238063, 0.133911,
    0.049596, 0.208305, -0.024798, 0.049596, -0.049596, 0.034718, -0.446368,
    0.466207, -0.089274, -0.099193, -0.128951, -0.228144, 0.014879, -0.252942,
    0.074395, -0.223184, -0.168628, -0.292619, 0.178547, 0.153749, -0.014879,
    0.054556, 0.000000, 0.193426, 0.158709, 0.178547, -0.327337, -0.138870,
    -0.114072, 0.168628, 0.297579, -0.109112, -0.029758, -0.029758, -0.416611,
    0.059516, 0.000000, -0.168628, -0.322377, 0.238063, -0.128951, -0.029758,
    0.500925, 0.292619, 0.123991, -0.099193, 0.074395, 0.317418, -0.148789,
    0.064475, -0.104153, -0.044637, -0.094233, 0.188467, -0.044637, 0.213265,
    -0.233104, -0.049596, 0.004960, -0.198386, 0.287660, -0.148789, -0.257902,
    0.004960, -0.218225, -0.044637, -0.386853, -0.243023, -0.163668, 0.094233,
    0.029758, -0.019839, -0.009919, -0.143830, -0.158709, 0.158709, -0.243023,
    -0.039677, -0.297579, 0.069435, 0.049596, 0.302539, 0.059516, 0.074395,
    -0.019839, 0.352135, -0.019839, -0.138870, -0.178547, -0.243023, 0.233104,
    0.252942, -0.228144, -0.049596, 0.173588, 0.173588, -0.074395, -0.034718,
    -0.292619, 0.362054, 0.183507, 0.243023, -0.203346, -0.044637, 0.054556,
    0.059516, -0.158709, -0.158709, 0.000000, 0.327337, 0.119032, 0.034718,
    -0.044637, -0.089274, 0.089274, -0.233104, 0.000000, -0.317418, 0.371974,
    0.213265, 0.307498, -0.178547, -0.367014, 0.039677, -0.059516, 0.168628,
    -0.014879, 0.143830, 0.123991, -0.084314, -0.332297, -0.416611, 0.183507,
    0.109112, -0.039677, 0.014879, 0.292619, -0.213265, -0.054556, 0.004960,
    0.123991, 0.119032, 0.000000, -0.332297, -0.312458, -0.198386, -0.213265,
    0.119032, 0.322377, 0.168628, 0.104153, -0.262861, 0.327337, -0.049596,
    -0.228144, -0.074395, 0.168628, 0.123991, 0.396772, 0.044637, 0.322377,
    0.193426, 0.267821, -0.178547, 0.297579, 0.148789, -0.218225, -0.138870,
    0.044637, 0.049596, 0.133911, 0.064475, 0.069435, 0.064475, -0.158709,
    -0.044637, -0.173588, 0.267821, 0.327337, 0.079354, -0.228144, 0.029758,
    0.014879, 0.198386, -0.109112, -0.133911, 0.431490, 0.099193, 0.421570,
    0.233104, -0.054556, 0.054556, -0.317418, -0.133911, -0.123991, -0.287660,
    0.342216, -0.049596, -0.153749, 0.228144, -0.213265, 0.262861, 0.406691,
    -0.084314, -0.004960, 0.193426, 0.188467, -0.099193, -0.223184, 0.163668,
    -0.257902, -0.153749, 0.441409, 0.099193, 0.128951, -0.089274, -0.208305,
    -0.009919, -0.004960, -0.109112, 0.024798, -0.119032, 0.019839, 0.391812,
    -0.024798, 0.198386, 0.327337, -0.505884, -0.099193, 0.510844, -0.148789,
    0.094233, -0.153749, -0.039677, 0.352135, 0.272781, -0.228144, -0.287660,
    -0.272781, 0.148789, 0.277740, 0.074395, 0.109112, -0.064475, 0.044637,
    0.074395, -0.292619, 0.153749, -0.064475, -0.114072, 0.198386, -0.039677,
    -0.128951, -0.004960, 0.257902, -0.228144, -0.094233, 0.064475, 0.014879,
    0.188467, -0.416611, 0.099193, 0.362054, -0.208305, 0.198386, -0.079354,
    0.009919, 0.119032, 0.332297, 0.243023, -0.168628, 0.158709, 0.039677,
    0.143830, 0.277740, -0.168628, 0.009919, 0.099193, -0.004960, -0.257902,
    -0.297579, 0.208305, -0.104153, 0.119032, 0.247982, 0.381893, -0.223184,
    -0.367014, -0.327337, -0.168628, -0.094233, 0.208305, -0.019839, 0.183507,
    0.084314, 0.133911, 0.109112, -0.148789, -0.183507, -0.411651, -0.024798,
    -0.114072, -0.029758, -0.009919, 0.173588, -0.059516, -0.049596, 0.039677,
    0.317418, 0.138870, -0.247982, -0.084314, 0.158709, 0.054556, -0.084314,
    -0.049596, 0.074395, 0.019839, -0.282700, -0.119032, -0.262861, 0.163668,
    -0.069435, -0.064475, -0.059516, 0.094233, 0.123991, -0.079354, -0.272781,
    -0.267821, 0.233104, 0.114072, -0.218225, 0.540602, 0.089274, 0.262861,
    0.079354, 0.267821, -0.119032, -0.109112, -0.128951, 0.128951, -0.044637,
    -0.272781, 0.277740, 0.297579, -0.054556, -0.084314, -0.049596, 0.123991,
    0.059516, 0.238063, -0.168628, -0.009919, 0.163668, -0.307498, 0.109112,
    -0.064475, 0.218225, -0.168628, -0.004960, -0.168628, 0.119032, 0.094233,
    -0.183507, -0.089274, -0.292619, -0.094233, 0.064475, -0.183507, -0.168628,
    0.089274, 0.074395, -0.367014, -0.024798, -0.069435, 0.119032, -0.302539,
    -0.376933, -0.123991, -0.009919, -0.069435, -0.208305, -0.119032, 0.014879,
    -0.183507, -0.238063, 0.163668, -0.332297, -0.148789, -0.391812, -0.024798,
    -0.133911, -0.059516, -0.123991, 0.123991, -0.292619, -0.044637, 0.059516,
    -0.069435, 0.049596, -0.069435, 0.034718, 0.158709, -0.347175, -0.044637,
    0.352135, -0.347175, -0.282700, -0.054556, 0.307498, 0.029758, 0.357095,
    -0.148789, 0.208305, -0.317418, 0.009919, 0.004960, -0.243023, 0.049596,
    -0.099193, 0.213265, -0.342216, 0.158709, 0.123991, -0.332297, 0.386853,
    -0.262861, -0.208305, 0.123991, -0.044637, 0.148789, 0.084314, -0.297579,
    -0.307498, -0.163668, 0.337256, -0.014879, 0.074395, 0.178547, -0.004960,
    -0.257902, -0.019839, -0.228144, -0.034718, -0.277740, -0.158709, -0.119032,
    -0.153749, 0.629876, 0.277740, 0.178547, -0.267821, -0.004960, 0.247982,
    0.084314, -0.094233, 0.000000, -0.039677, 0.332297, 0.178547, 0.009919,
    -0.213265, -0.208305, -0.044637, 0.019839, 0.218225, -0.297579, 0.014879,
    -0.247982, -0.004960, -0.128951, 0.421570, -0.059516, 0.362054, -0.203346,
    -0.143830, -0.099193, -0.024798, 0.094233, -0.123991, 0.163668, 0.109112,
    -0.104153, -0.233104, 0.009919, -0.218225, 0.376933, 0.104153, -0.059516,
    0.049596, -0.054556, 0.019839, -0.044637, -0.019839, 0.371974, -0.019839,
    0.104153, 0.168628, -0.024798, -0.272781, -0.158709, 0.223184, 0.044637,
    0.039677, -0.168628, -0.287660, -0.109112, 0.094233, -0.089274, -0.148789,
    0.178547, -0.039677, -0.089274, -0.049596, -0.024798, 0.064475, -0.158709,
    0.089274, 0.029758, -0.247982, 0.362054, 0.024798, -0.004960, -0.099193,
    0.173588, -0.059516, 0.188467, -0.629876, 0.094233, 0.371974, 0.069435,
    0.252942, -0.357095, -0.272781, -0.367014, 0.014879, -0.049596, -0.262861,
    0.009919, -0.094233, -0.094233, 0.059516, 0.223184, 0.133911, 0.411651,
    -0.044637, -0.044637, 0.109112, 0.228144, 0.386853, -0.233104, 0.069435,
    0.228144, -0.302539, 0.029758, 0.089274, 0.044637, -0.238063, -0.138870,
    -0.158709, -0.019839, 0.049596, 0.039677, 0.000000, -0.069435, 0.109112,
    -0.213265, -0.188467, -0.262861, -0.267821, -0.094233, 0.133911, 0.391812,
    0.123991, -0.317418, 0.233104, -0.029758, -0.099193, -0.193426, 0.074395,
    -0.009919, 0.252942, 0.322377, -0.530683, 0.208305, 0.252942, 0.203346,
    -0.069435, -0.262861};
287
// Time filter of shape {64, 8} (512 values).
const float time_weights_data_16x1x1[] = {
    -0.052026, 0.043107, 0.053512, 0.013378, 0.011892, -0.182834, -0.108511,
    0.153105, 0.050539, -0.173915, 0.145672, 0.208103, -0.221481, 0.108511,
    -0.496475, 0.181347, -0.016351, -0.132294, -0.234859, -0.243778, 0.028243,
    -0.228914, -0.130808, -0.167969, -0.041621, -0.306209, -0.193239, -0.028243,
    -0.057972, -0.057972, -0.497962, 0.054999, 0.181347, 0.047566, -0.099592,
    -0.111484, -0.130808, -0.071350, 0.380532, 0.010405, 0.041621, 0.052026,
    0.022297, 0.081755, 0.098106, 0.099592, -0.584176, -0.023783, 0.062431,
    -0.090674, -0.279453, -0.486070, -0.273507, 0.004459, -0.062431, 0.095133,
    0.056485, 0.022297, -0.105538, -0.184320, 0.358235, 0.254183, 0.049053,
    0.084728, 0.218508, 0.078782, -0.136754, -0.017837, -0.124862, -0.118916,
    -0.001486, 0.043107, 0.254183, 0.087701, 0.261616, 0.309182, -0.404315,
    -0.040134, -0.046080, -0.052026, -0.034188, -0.475665, -0.025270, -0.049053,
    -0.046080, -0.062431, 0.020810, 0.040134, -0.135267, -0.169456, -0.050539,
    -0.576743, 0.034188, 0.075809, 0.101079, 0.136754, 0.083241, 0.077296,
    -0.050539, 0.761064, -0.335938, -0.080268, 0.025270, 0.257156, 0.227427,
    0.252697, 0.065404, 0.115943, 0.222968, -0.026756, -0.054999, 0.107025,
    -0.093646, 0.041621, -0.092160, -0.474178, -0.016351, 0.004459, 0.049053,
    0.019324, 0.019324, 0.074323, 0.038648, -0.613905, 0.182834, 0.075809,
    0.028243, 0.019324, 0.010405, -0.011892, 0.001486, -0.492016, -0.224454,
    -0.474178, -0.147159, 0.002973, 0.102565, 0.136754, -0.267561, -0.001486,
    -0.095133, -0.040134, 0.066890, 0.074323, 0.104052, 0.532150, 0.090674,
    0.072836, -0.053512, -0.004459, 0.020810, 0.046080, 0.062431, 0.477151,
    0.133781, -0.029729, -0.026756, 0.031215, 0.156077, 0.096619, 0.251210,
    0.352289, 0.657012, 0.047566, -0.014865, -0.072836, -0.016351, 0.008919,
    -0.053512, 0.016351, 0.300263, 0.047566, 0.020810, 0.169456, 0.001486,
    0.007432, 0.111484, 0.044594, -0.188779, -0.096619, 0.074323, -0.040134,
    0.160537, 0.138240, 0.184320, 0.377559, -0.092160, -0.049053, 0.056485,
    -0.032702, 0.001486, -0.083241, -0.472692, -0.114457, -0.117430, -0.075809,
    0.026756, 0.163510, 0.172428, 0.127835, -0.199185, -0.218508, -0.057972,
    -0.132294, -0.162023, -0.019324, -0.245265, -0.395396, -0.254183, 0.084728,
    0.248238, 0.191752, 0.221481, 0.173915, 0.173915, -0.208103, -0.077296,
    0.384991, -0.313641, -0.313641, -0.147159, -0.090674, 0.035675, 0.059458,
    -0.010405, 0.019324, 0.087701, 0.016351, 0.037161, 0.469719, -0.074323,
    0.092160, 0.026756, 0.090674, 0.098106, 0.004459, -0.034188, 0.492016,
    -0.367154, -0.093646, -0.063917, 0.041621, 0.017837, 0.026756, -0.062431,
    -0.350803, 0.425125, 0.002973, 0.083241, 0.075809, 0.016351, 0.047566,
    -0.185807, -0.107025, -0.098106, -0.144186, 0.255670, 0.020810, 0.105538,
    0.029729, 0.129321, 0.156077, 0.141213, 0.334452, 0.147159, -0.066890,
    0.035675, 0.115943, 0.240805, 0.328506, 0.162023, -0.237832, 0.218508,
    0.233373, 0.214049, 0.099592, 0.026756, -0.322560, -0.236346, -0.166483,
    0.225941, 0.109997, -0.147159, 0.147159, -0.266075, 0.111484, 0.078782,
    -0.120403, 0.022297, -0.075809, -0.148645, -0.251210, -0.176888, -0.044594,
    -0.023783, 0.016351, 0.026756, -0.013378, -0.069863, -0.112970, 0.013378,
    0.086214, 0.014865, 0.352289, -0.240805, -0.135267, -0.114457, -0.472692,
    0.334452, 0.095133, 0.047566, 0.130808, -0.068377, -0.007432, -0.130808,
    -0.121889, -0.053512, -0.245265, -0.371613, -0.083241, 0.000000, -0.028243,
    0.029729, -0.093646, -0.004459, -0.038648, -0.108511, -0.475665, -0.169456,
    -0.047566, -0.010405, -0.114457, -0.353776, -0.034188, -0.044594, 0.041621,
    -0.047566, -0.107025, 0.004459, 0.053512, 0.047566, -0.358235, -0.193239,
    0.040134, -0.096619, -0.054999, 0.099592, 0.032702, 0.205130, -0.170942,
    -0.237832, -0.405801, -0.126348, -0.072836, -0.203644, -0.169456, -0.093646,
    -0.074323, 0.078782, 0.607959, -0.437017, -0.164996, -0.166483, 0.043107,
    -0.016351, 0.258643, 0.065404, -0.057972, 0.017837, 0.080268, 0.050539,
    -0.013378, -0.215536, -0.524718, 0.260129, 0.040134, -0.002973, -0.046080,
    0.020810, 0.025270, 0.145672, 0.515799, 0.233373, 0.011892, 0.139727,
    0.126348, 0.065404, -0.007432, -0.008919, 0.035675, 0.083241, 0.040134,
    -0.005946, 0.503907, -0.490529, -0.181347, -0.092160, -0.038648, 0.019324,
    0.133781, -0.011892, 0.041621, 0.062431, -0.062431, -0.040134, -0.092160,
    -0.111484, -0.133781, -0.130808, -0.484583, -0.248238, 0.037161, -0.092160,
    -0.056485, -0.041621, 0.112970, 0.248238, 0.438503, 0.258643, -0.013378,
    0.004459, 0.043107, 0.040134, 0.017837, 0.101079, 0.264589, 0.212563,
    0.014865, 0.285399, 0.153105, 0.170942, 0.358235, 0.334452, 0.086214,
    0.132294, 0.098106, -0.001486, 0.107025, 0.200671, -0.026756, 0.344857,
    0.227427, -0.041621, 0.098106, 0.063917, -0.093646, 0.130808, 0.285399,
    -0.319587, 0.035675, -0.017837, -0.319587, 0.016351, -0.098106, -0.017837,
    0.083241, 0.074323, -0.054999, 0.276480, 0.316614, -0.099592, -0.059458,
    0.156077, -0.043107, 0.035675, 0.056485, -0.022297, 0.017837, -0.001486,
    0.340398, 0.492016, 0.004459, 0.057972, -0.150132, -0.206617, -0.257156,
    -0.248238, -0.080268, -0.164996, 0.352289, -0.054999, -0.056485, 0.010405,
    -0.049053, -0.041621, -0.099592, 0.013378, -0.089187, 0.057972, -0.413234,
    0.217022, 0.013378, -0.080268, -0.035675, 0.035675, 0.007432, 0.002973,
    -0.469719, 0.141213, 0.136754, 0.153105, 0.130808, -0.104052, -0.508367,
    -0.291345, -0.072836, -0.019324, -0.252697, -0.214049, -0.214049, 0.130808,
    0.484583};
364
// Bias of shape {64}.
const float bias_data_16x1x1[] = {
    -0.245395, -0.083545, -0.262522, -0.407912, -0.560898, -0.364789, -0.037964,
    -0.378594, 0.178152, 0.400380, -0.301349, -0.240913, -0.159454, -0.158757,
    -0.073665, 0.455906, -0.061232, 0.318907, -0.226993, -0.344644, 0.140316,
    0.559608, 0.109774, 0.437391, 0.113849, -0.162068, 0.039572, 0.569472,
    0.460205, 0.113459, 0.370469, 0.176811, 0.203063, -0.296975, -0.271655,
    0.059862, -0.159912, -0.077310, -0.338314, -0.195477, -0.256762, 0.233834,
    0.083172, 0.029040, -0.236288, -0.267054, -0.166627, 0.188319, -0.271391,
    -0.222920, 0.106463, 0.263614, 0.384986, -0.125957, -0.095890, 0.363686,
    -0.036990, -0.358884, -0.178254, 0.305596, 0.390088, -0.189437, 0.613409,
    0.399639};
377
// Activation state with shape {64, 8} (512 values). These initial values must
// be copied into a mutable activation state tensor before each test run.
const float initial_activation_state_data_16x1x1[] = {
    -0.582275, -0.586623, -1.262373, -1.277279, -1.542175, -1.271999, -1.429757,
    -1.184425, -0.462094, -1.443421, 0.230736, -0.494701, -0.354955, -2.534061,
    -4.277471, -4.218467, 0.403711, -0.248748, -0.330111, -0.467683, 0.549047,
    0.733511, -0.230115, 0.793136, -1.126353, -0.984123, -0.081984, -0.222351,
    0.692830, 0.517060, 1.367958, 2.118860, -0.116766, -0.826365, -2.402700,
    -2.313884, -2.898954, -2.076005, -2.405185, -2.755481, 0.329490, 0.085400,
    -1.485966, -2.034702, -2.161405, -1.269515, -1.151818, -1.823841, 0.561469,
    1.109273, 1.693411, -0.082605, -0.069252, -1.225107, -1.330693, -1.411435,
    0.253406, -0.357439, -1.593415, -0.879779, -1.111136, 1.821357, 2.471952,
    1.236908, -4.014127, -2.810448, -2.944604, -1.930980, -1.566398, -0.838166,
    -0.319242, 0.749349, 1.156476, 0.658670, 1.997437, 2.080663, 2.912618,
    2.677224, 2.642442, 2.796163, -0.272349, -0.473273, 3.120063, 2.747097,
    3.595510, 1.874150, 2.049919, 2.093396, -1.049959, 0.277939, -1.255541,
    -1.052443, -1.810177, -0.883505, -0.538178, 0.524203, -1.017662, -0.269244,
    0.039129, -0.227941, -0.114592, -2.018243, -2.548968, -0.706804, 0.890959,
    0.102480, 0.349986, 0.405885, 1.287216, 0.756181, 0.319242, -0.641590,
    -3.841774, -2.716042, -4.342065, -3.826557, -2.924729, -1.643724, -1.237839,
    -0.597492, -1.954892, -1.215169, -1.528201, -1.018904, -0.863941, -0.293467,
    0.039439, 0.672023, 1.408019, 1.362679, 1.467644, 1.006171, 0.310236,
    -0.249990, -1.048406, -0.752144, -1.831605, -1.058033, -1.096541, -0.293467,
    0.051551, 0.232600, 0.088816, 2.570395, 0.704009, 2.465120, 3.010751,
    2.139357, 0.630410, 1.006171, 1.545281, 1.486898, -1.162998, -2.344317,
    -4.593918, -3.522842, -2.872247, -1.416714, -0.642521, -0.230115, 0.315205,
    -0.368930, -0.162726, 0.396879, 0.505570, 0.534451, 0.554947, 1.270447,
    0.388805, 0.531967, -1.243119, -0.671713, -1.214859, -0.238189, 0.016459,
    -1.164550, 0.609603, 3.293348, 2.600208, 1.454290, -1.034121, -1.760179,
    -1.192500, -0.613951, 3.449553, 2.912618, 1.917937, 1.435968, 0.879158,
    1.118279, 0.102791, -0.502465, -0.239121, -0.092853, 1.786265, 1.943091,
    2.547104, 2.630641, 2.585302, 2.965411, -0.945615, -2.538720, -2.474126,
    -1.088156, 0.056209, 0.864873, 0.170490, 0.457435, 0.545941, 0.752765,
    1.569503, 1.129459, 0.662086, -0.527929, -0.810838, -1.662978, 1.285042,
    1.653040, 4.130893, 2.961995, 4.147041, 3.256393, 3.881524, 2.522571,
    -0.875431, -1.112378, 2.105817, 2.180970, 3.121926, 1.577577, 1.639376,
    2.906407, -0.142230, 0.421101, 2.212335, 2.311399, 3.993321, 3.651719,
    4.206666, 4.678387, -1.304917, -1.130701, -2.543067, -2.500212, -2.197118,
    -1.197158, -0.949652, -0.282908, 0.320795, -1.543728, 1.290322, 1.788128,
    3.957297, 3.205774, 2.892432, 2.297114, 0.138814, -0.139435, 0.936920,
    0.344707, 0.723263, -1.772290, -3.138385, -2.287177, -2.405806, -1.859864,
    -4.572801, -3.410424, -3.855748, -2.239663, -2.269786, -1.582857, 4.238342,
    3.858543, 2.499901, 1.087535, 0.290051, -0.026086, -0.880400, -2.602692,
    -1.404292, 0.253096, -0.665502, -1.443421, -0.925119, -0.096580, 1.115484,
    1.846200, -1.604284, -1.244671, -0.464888, 0.326385, 0.168006, -0.262723,
    -0.744691, 0.953379, -0.407127, -0.349986, -1.154302, 0.831023, 1.590931,
    2.538720, 2.063583, 3.697680, -0.752455, -1.293117, -1.330693, -1.869802,
    -0.592523, 0.631652, 1.198089, -0.481347, 3.738983, 4.153252, 2.782499,
    2.244321, 0.709289, 1.650245, 1.700865, 0.385078, 2.192460, 2.610456,
    4.009780, 3.492719, 2.574743, 2.116687, 1.856138, 1.205853, 2.722563,
    4.075305, 5.415935, 3.009198, 2.715421, 1.571056, 0.897170, -2.430339,
    0.749970, 0.425760, -0.302783, 0.817359, 1.031636, 1.913589, 2.686229,
    1.631923, -1.459259, -1.793097, -1.187531, -1.553355, -0.844998, -1.296843,
    -1.805519, -0.486627, 0.909591, 2.082837, -1.473855, -2.456735, -3.851401,
    -2.760139, -3.060438, -2.605487, -2.138735, -2.441519, -1.333177, -1.353984,
    -0.245642, -0.588486, 0.033850, 2.084700, 0.076084, 0.690035, 0.747797,
    0.594697, -1.016109, -1.348083, -1.201195, -1.088466, 2.045571, 2.460772,
    0.717984, 0.041613, -0.721711, 1.134738, 2.322269, 1.112378, -0.307441,
    -0.581033, -0.868599, -0.018633, 0.856488, 0.919839, 0.303094, -0.433213,
    0.811148, -0.508986, -1.060828, -1.227591, -1.566087, -1.117968, -1.385038,
    -2.011101, -0.490353, -1.849616, -0.594697, -1.055859, 1.110205, 0.622646,
    0.145957, 0.359303, 1.012072, 0.774814, -0.400295, -1.484103, -2.007374,
    -1.441247, -0.997787, -0.581033, -0.545941, -0.306510, 0.693451, 0.087264,
    -0.227320, -1.211753, -1.532859, -1.688753, 0.065215, 0.134777, 0.608051,
    -0.393152, -0.214588, -0.635689, -1.499320, 0.069562, -1.555839, -2.633126,
    -2.966032, -1.550870, -0.101549, 0.874189, 0.436318, 0.299367, 2.289972,
    2.339659, 2.602071, 1.564535, 0.019254, -0.583207, -1.295912, -2.424749,
    -1.221070, -1.175109, -0.577306, -0.102791, 1.877876, 2.568222, 2.173827,
    3.131243, 2.637784, 2.088737, 3.679047, 3.218506, 2.483442, 1.650556,
    1.363611, -0.027328, 1.486898, -0.721711, -3.684327, -3.006093, -3.777491,
    -2.327548, -2.737470, -4.549510, -0.060867, 0.127635, 0.680408, 0.581344,
    0.320174, -0.403090, -0.838166, 0.293777, -0.995613, -0.165521, -0.419859,
    1.110515, 1.203679, 1.749931, 2.467294, 4.276539, 0.031055, -0.967664,
    1.167035, 1.865144, 3.221923, 3.248630, 4.121266, 4.187723, 0.749039,
    -1.571056, 0.785994, 1.568572, 3.759479, 3.588678, 4.116608, 3.864444,
    -0.290051, -0.271107, 0.375140, 0.537556, 0.536314, 0.095959, 0.054656,
    0.088816};
455
// One golden output with shape {1, 64} (no fused activation).
const float golden_output_16x1x1[] = {
    -0.087914, 1.145864, -0.418088, -1.556392, -0.925298, 0.205252, 0.289119,
    1.331180, -0.218010, 0.963057, -2.225886, 1.248478, 1.448983, 0.355467,
    1.682174, 0.803739, 0.449738, 0.543566, 1.916269, -2.975136, 0.222774,
    0.241589, -0.104216, 1.561748, 0.936818, -0.089907, -0.520117, -0.870353,
    1.606074, 0.895770, 0.521297, -0.369994, -0.889351, -2.809309, 2.404628,
    1.069754, -0.195456, -1.105652, 1.272715, -1.233177, 1.271416, -1.691805,
    -1.058125, -0.716227, 0.052540, 1.262483, 0.540555, 1.735760, -0.539197,
    -0.014367, -0.243002, 1.072254, 0.528985, -0.731151, -1.262649, 2.338702,
    -0.603093, 0.970736, -3.567897, 0.035085, -0.201711, -0.550400, 1.545573,
    -1.805005};
468
// One golden output with shape {1, 64} for the kTfLiteActRelu case: identical
// to golden_output_16x1x1 with all negative values clamped to zero.
const float golden_output_relu_16x1x1[] = {
    0.000000, 1.145864, 0.000000, 0.000000, 0.000000, 0.205252, 0.289119,
    1.331180, 0.000000, 0.963057, 0.000000, 1.248478, 1.448983, 0.355467,
    1.682174, 0.803739, 0.449738, 0.543566, 1.916269, 0.000000, 0.222774,
    0.241589, 0.000000, 1.561748, 0.936818, 0.000000, 0.000000, 0.000000,
    1.606074, 0.895770, 0.521297, 0.000000, 0.000000, 0.000000, 2.404628,
    1.069754, 0.000000, 0.000000, 1.272715, 0.000000, 1.271416, 0.000000,
    0.000000, 0.000000, 0.052540, 1.262483, 0.540555, 1.735760, 0.000000,
    0.000000, 0.000000, 1.072254, 0.528985, 0.000000, 0.000000, 2.338702,
    0.000000, 0.970736, 0.000000, 0.035085, 0.000000, 0.000000, 1.545573,
    0.000000};
481
482 template <typename T>
ValidateSVDFGoldens(const int batch_size,const int num_units,const int input_size,const int rank,TfLiteTensor * tensors,const int tensor_count,TfLiteFusedActivation activaiton,const T * input_sequences_data,const int input_sequences_len,T * output_data,const T * expected_output,float tolerance=1e-5f)483 void ValidateSVDFGoldens(const int batch_size, const int num_units,
484 const int input_size, const int rank,
485 TfLiteTensor* tensors, const int tensor_count,
486 TfLiteFusedActivation activaiton,
487 const T* input_sequences_data,
488 const int input_sequences_len, T* output_data,
489 const T* expected_output, float tolerance = 1e-5f) {
490 TfLiteSVDFParams params;
491 params.rank = rank;
492 params.activation = activaiton;
493
494 int inputs_array_data[] = {5, 0, 1, 2, 3, 4};
495 TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
496
497 int outputs_array_data[] = {1, 5};
498 TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
499
500 const TfLiteRegistration registration = Register_SVDF();
501 micro::KernelRunner runner(registration, tensors, tensor_count, inputs_array,
502 outputs_array, ¶ms);
503
504 TfLiteStatus init_and_prepare_status = runner.InitAndPrepare();
505 TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, init_and_prepare_status);
506
507 // Abort early to make it clear init and prepare failed.
508 if (init_and_prepare_status != kTfLiteOk) {
509 return;
510 }
511
512 int num_inputs = input_sequences_len / (input_size * batch_size);
513
514 for (int i = 0; i < num_inputs; ++i) {
515 const T* input_batch_start =
516 input_sequences_data + i * input_size * batch_size;
517
518 memcpy(tensors[0].data.raw, input_batch_start, tensors[0].bytes);
519 TfLiteStatus status = runner.Invoke();
520 TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, status);
521
522 // Only validate outputs when invoke has succeeded.
523 if (status == kTfLiteOk) {
524 int output_idx = 0;
525 int golden_idx = i * batch_size * num_units;
526 for (int j = golden_idx; j < golden_idx + batch_size * num_units; ++j) {
527 TF_LITE_MICRO_EXPECT_NEAR(expected_output[j], output_data[output_idx],
528 tolerance);
529 output_idx++;
530 }
531 }
532 }
533 }
534
535 #if !defined(XTENSA) // Needed to avoid build errors from unused functions.
TestSVDF(const int batch_size,const int num_units,const int input_size,const int memory_size,const int rank,TfLiteFusedActivation activation,float * input_data,const float * feature_weights_data,const float * time_weights_data,float * activation_state_data,const float * bias_data,float * scratch_data,float * output_data,const float * input_sequences_data,int input_sequences_len,const float * expected_output,float tolerance=1e-5f)536 void TestSVDF(const int batch_size, const int num_units, const int input_size,
537 const int memory_size, const int rank,
538 TfLiteFusedActivation activation, float* input_data,
539 const float* feature_weights_data, const float* time_weights_data,
540 float* activation_state_data, const float* bias_data,
541 float* scratch_data, float* output_data,
542 const float* input_sequences_data, int input_sequences_len,
543 const float* expected_output, float tolerance = 1e-5f) {
544 const int num_filters = num_units * rank;
545
546 const int input_dims_arg[] = {2, batch_size, input_size};
547 TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_arg);
548
549 const int feature_weights_dims_args[] = {2, num_filters, input_size};
550 TfLiteIntArray* feature_weights_dims =
551 IntArrayFromInts(feature_weights_dims_args);
552
553 const int time_weights_dims_args[] = {2, num_filters, memory_size};
554 TfLiteIntArray* time_weights_dims = IntArrayFromInts(time_weights_dims_args);
555
556 const int activation_state_dims_args[] = {2, batch_size,
557 memory_size * num_filters};
558 TfLiteIntArray* activation_state_dims =
559 IntArrayFromInts(activation_state_dims_args);
560
561 const int bias_dims_args[] = {1, num_units};
562 TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_args);
563
564 const int output_dims_args[] = {2, batch_size, num_units};
565 TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_args);
566
567 const int tensor_count = 6; // 5 inputs, 1 output
568 TfLiteTensor tensors[] = {
569 CreateTensor(input_data, input_dims),
570 CreateTensor(feature_weights_data, feature_weights_dims),
571 CreateTensor(time_weights_data, time_weights_dims),
572 CreateTensor(bias_data, bias_dims),
573 CreateTensor(activation_state_data, activation_state_dims,
574 /*is_variable=*/true),
575 CreateTensor(output_data, output_dims),
576 };
577
578 ValidateSVDFGoldens(batch_size, num_units, input_size, rank, tensors,
579 tensor_count, activation, input_sequences_data,
580 input_sequences_len, output_data, expected_output,
581 tolerance);
582 }
583 #endif
584
// The pattern to this method's arguments is:
// <kernel metadata>
// for each tensor in
//    {input, feature weights, time weights, bias, activation state, output}:
// <tensor float values> <tensor quantized buffer> <tensor quantization data>
TestIntegerSVDF(const int batch_size,const int num_units,const int input_size,const int memory_size,const int rank,TfLiteFusedActivation activation,int8_t * input_quantized,float input_scale,int input_zero_point,const float * feature_weights_data,int8_t * feature_weights_quantized,const float feature_weights_scale,const float * time_weights_data,int16_t * time_weights_quantized,float time_weights_scale,const float * bias_data,int32_t * bias_quantized,const float * initial_activation_state_data,int16_t * activation_state_quantized,float activation_state_scale,int8_t * output_data,float output_scale,int output_zero_point,const float * input_sequences_data,int8_t * input_sequences_quantized,const int input_sequences_len,const float * golden_output,int8_t * golden_output_quantized,int golden_output_len)590 inline void TestIntegerSVDF(
591 const int batch_size, const int num_units, const int input_size,
592 const int memory_size, const int rank, TfLiteFusedActivation activation,
593 int8_t* input_quantized, float input_scale, int input_zero_point,
594 const float* feature_weights_data, int8_t* feature_weights_quantized,
595 const float feature_weights_scale, const float* time_weights_data,
596 int16_t* time_weights_quantized, float time_weights_scale,
597 const float* bias_data, int32_t* bias_quantized,
598 const float* initial_activation_state_data,
599 int16_t* activation_state_quantized, float activation_state_scale,
600 int8_t* output_data, float output_scale, int output_zero_point,
601 const float* input_sequences_data, int8_t* input_sequences_quantized,
602 const int input_sequences_len, const float* golden_output,
603 int8_t* golden_output_quantized, int golden_output_len) {
604 const int num_filters = num_units * rank;
605
606 const int input_dims_arg[] = {2, batch_size, input_size};
607 TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_arg);
608
609 const int feature_weights_dims_args[] = {2, num_filters, input_size};
610 TfLiteIntArray* feature_weights_dims =
611 IntArrayFromInts(feature_weights_dims_args);
612
613 const int time_weights_dims_args[] = {2, num_filters, memory_size};
614 TfLiteIntArray* time_weights_dims = IntArrayFromInts(time_weights_dims_args);
615
616 const int bias_dims_data[] = {1, num_units};
617 TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
618
619 const int activation_state_dims_args[] = {2, batch_size,
620 memory_size * num_filters};
621 TfLiteIntArray* activation_state_dims =
622 IntArrayFromInts(activation_state_dims_args);
623
624 const int output_dims_args[] = {2, batch_size, num_units};
625 TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_args);
626
627 const int tensor_count = 6; // 5 inputs, 1 output
628
629 TfLiteTensor tensors[] = {
630 CreateQuantizedTensor(input_quantized, input_dims, input_scale,
631 input_zero_point),
632 CreateQuantizedTensor(feature_weights_data, feature_weights_quantized,
633 feature_weights_dims, feature_weights_scale, 0),
634 CreateQuantizedTensor(time_weights_data, time_weights_quantized,
635 time_weights_dims, time_weights_scale, 0),
636 CreateQuantizedBiasTensor(bias_data, bias_quantized, bias_dims,
637 time_weights_scale, activation_state_scale),
638 CreateQuantizedTensor(initial_activation_state_data,
639 activation_state_quantized, activation_state_dims,
640 activation_state_scale, 0,
641 /*is_variable=*/true),
642 CreateQuantizedTensor(output_data, output_dims, output_scale,
643 output_zero_point)};
644
645 tflite::Quantize(golden_output, golden_output_quantized, golden_output_len,
646 output_scale, output_zero_point);
647 tflite::Quantize(input_sequences_data, input_sequences_quantized,
648 input_sequences_len, input_scale, input_zero_point);
649
650 ValidateSVDFGoldens(batch_size, num_units, input_size, rank, tensors,
651 tensor_count, activation, input_sequences_quantized,
652 input_sequences_len, output_data, golden_output_quantized,
653 /*tolerance*/ 1);
654 }
655
656 } // namespace
657 } // namespace testing
658 } // namespace tflite
659
660 TF_LITE_MICRO_TESTS_BEGIN
661
662 #if !defined(XTENSA) // TODO(b/170332589): xtensa kernels are less general than
663 // reference kernels and we ifdef out test cases that are
664 // currently known to fail.
// Float SVDF, batch 2, input size 2, 4 units, rank 2: feeds the 2x2x10 input
// sequence and checks against the precomputed goldens.
TF_LITE_MICRO_TEST(SvdfFloat2x2Input2x4OutputShouldMatchGolden) {
  constexpr int batch_size = 2;
  constexpr int num_units = 4;
  constexpr int input_size = 2;
  constexpr int memory_size = 10;
  constexpr int rank = 2;
  constexpr int num_filters = num_units * rank;

  const int input_size_dims_count = batch_size * input_size;
  float input_data[input_size_dims_count];

  const int activation_state_dims_count =
      batch_size * memory_size * num_filters;
  float activation_state_data[activation_state_dims_count];

  // Seed the (variable) activation state with its initial values.
  memcpy(activation_state_data,
         tflite::testing::initial_activation_state_data_2x2x10,
         sizeof(tflite::testing::initial_activation_state_data_2x2x10));

  const int scratch_dims_count = batch_size * num_filters;
  float scratch_data[scratch_dims_count];

  const int output_dims_count = batch_size * num_units;
  float output_data[output_dims_count];

  tflite::testing::TestSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActNone,
      input_data, tflite::testing::feature_weights_data_2x2x10,
      tflite::testing::time_weights_data_2x2x10, activation_state_data,
      tflite::testing::bias_data_2x2x10, scratch_data, output_data,
      tflite::testing::input_data_2x2x10,
      sizeof(tflite::testing::input_data_2x2x10) / sizeof(float),
      tflite::testing::golden_output_2x2x10);
}
699 #endif
700
// Quantized (int8) SVDF, batch 2, input size 2, 4 units, rank 2, with ReLU
// activation, checked against the quantized 2x2x10 goldens.
TF_LITE_MICRO_TEST(SvdfQuantized2x2Input2x4OutputShouldMatchGolden) {
  constexpr int batch_size = 2;
  constexpr int num_units = 4;
  constexpr int input_size = 2;
  constexpr int memory_size = 10;
  constexpr int rank = 2;
  constexpr int num_filters = num_units * rank;

  const int input_size_dims_count = batch_size * input_size;

  const int activation_state_dims_count =
      batch_size * memory_size * num_filters;

  const int output_dims_count = batch_size * num_units;
  int8_t output_data[output_dims_count];

  // Scales chosen from each tensor's expected value range.
  float input_scale = 2.5f / INT8_MAX;              // Range is [-2.5, 2.5]
  float feature_weights_scale = 1.f / INT8_MAX;     // Range is [-1, 1]
  float time_weights_scale = 1.f / INT16_MAX;       // Range is [-1, 1]
  float activation_state_scale = 16.f / INT16_MAX;  // Range is [-16, 16]
  float output_scale = 1.f / INT8_MAX;              // Range is [-1, 1]

  int input_zero_point = 0;
  int output_zero_point = 0;

  // Quantization buffers sized to match the float source arrays.
  int8_t input_quantized[input_size_dims_count];
  int8_t input_sequences_quantized[sizeof(tflite::testing::input_data_2x2x10) /
                                   sizeof(float)];
  int8_t feature_weights_quantized
      [sizeof(tflite::testing::feature_weights_data_2x2x10) / sizeof(float)];
  int16_t
      time_weights_quantized[sizeof(tflite::testing::time_weights_data_2x2x10) /
                             sizeof(float)];
  int16_t activation_state_quantized[activation_state_dims_count];
  int32_t
      bias_quantized[sizeof(tflite::testing::bias_data_2x2x10) / sizeof(float)];
  int8_t golden_quantized[sizeof(tflite::testing::golden_output_2x2x10) /
                          sizeof(float)];

  tflite::testing::TestIntegerSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu,
      input_quantized, input_scale, input_zero_point,
      tflite::testing::feature_weights_data_2x2x10, feature_weights_quantized,
      feature_weights_scale, tflite::testing::time_weights_data_2x2x10,
      time_weights_quantized, time_weights_scale,
      tflite::testing::bias_data_2x2x10, bias_quantized,
      tflite::testing::initial_activation_state_data_2x2x10,
      activation_state_quantized, activation_state_scale, output_data,
      output_scale, output_zero_point, tflite::testing::input_data_2x2x10,
      input_sequences_quantized,
      sizeof(tflite::testing::input_data_2x2x10) / sizeof(float),
      tflite::testing::golden_output_2x2x10, golden_quantized,
      sizeof(tflite::testing::golden_output_2x2x10) / sizeof(float));
}
755
756 #if !defined(XTENSA) // TODO(b/170332589): xtensa kernels are less general than
757 // reference kernels and we ifdef out test cases that are
758 // currently known to fail.
// Float SVDF, batch 1, input size 16, 64 units, rank 1, no activation.
TF_LITE_MICRO_TEST(SvdfFloat1x16Input64x1OutputShouldMatchGolden) {
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;
  constexpr int activation_state_dims_count =
      batch_size * memory_size * num_filters;
  constexpr int output_dims_count = batch_size * num_units;
  constexpr int input_dims_count = batch_size * input_size;

  float input_data[input_dims_count];
  float output_data[output_dims_count];
  float scratch_buffer[batch_size * num_filters];
  float activation_state_data_mutable[activation_state_dims_count];

  // Initialize activation state to starting values.
  memcpy(activation_state_data_mutable,
         tflite::testing::initial_activation_state_data_16x1x1,
         sizeof(tflite::testing::initial_activation_state_data_16x1x1));

  tflite::testing::TestSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActNone,
      input_data, tflite::testing::feature_weights_data_16x1x1,
      tflite::testing::time_weights_data_16x1x1, activation_state_data_mutable,
      tflite::testing::bias_data_16x1x1, scratch_buffer, output_data,
      tflite::testing::input_data_16x1x1, input_size,
      tflite::testing::golden_output_16x1x1);
}
789
// Same configuration as the test above, but with a ReLU activation and the
// corresponding ReLU goldens.
TF_LITE_MICRO_TEST(SvdfFloat1x16Input64x1OutputReluShouldMatchGolden) {
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;
  constexpr int activation_state_dims_count =
      batch_size * memory_size * num_filters;
  constexpr int output_dims_count = batch_size * num_units;
  constexpr int input_dims_count = batch_size * input_size;

  float input_data[input_dims_count];
  float output_data[output_dims_count];
  float scratch_buffer[batch_size * num_filters];
  float activation_state_data_mutable[activation_state_dims_count];

  // Initialize activation state to starting values.
  memcpy(activation_state_data_mutable,
         tflite::testing::initial_activation_state_data_16x1x1,
         sizeof(tflite::testing::initial_activation_state_data_16x1x1));

  tflite::testing::TestSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu,
      input_data, tflite::testing::feature_weights_data_16x1x1,
      tflite::testing::time_weights_data_16x1x1, activation_state_data_mutable,
      tflite::testing::bias_data_16x1x1, scratch_buffer, output_data,
      tflite::testing::input_data_16x1x1, input_size,
      tflite::testing::golden_output_relu_16x1x1);
}
820 #endif
821
// Quantized (int8) SVDF, batch 1, input size 16, 64 units, rank 1, no
// activation, using scales/zero-points captured from a trained model.
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputShouldMatchGolden) {
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;
  constexpr int activation_state_dims_count =
      batch_size * memory_size * num_filters;
  constexpr int output_dims_count = batch_size * num_units;
  constexpr int input_dims_count = batch_size * input_size;

  int8_t output_data[output_dims_count];

  float input_scale = 0.10075444;
  float feature_weights_scale = 0.00649388;
  float time_weights_scale = 0.001571355;
  float activation_state_scale = 0.00045896982;
  float output_scale = 0.051445257;

  int input_zero_point = 2;
  int output_zero_point = 0;

  // Quantization buffers sized to match the float source arrays.
  int8_t input_quantized[input_dims_count];
  int8_t input_sequences_quantized[sizeof(tflite::testing::input_data_16x1x1) /
                                   sizeof(float)];
  int8_t feature_weights_quantized
      [sizeof(tflite::testing::feature_weights_data_16x1x1) / sizeof(float)];
  int16_t
      time_weights_quantized[sizeof(tflite::testing::time_weights_data_16x1x1) /
                             sizeof(float)];
  int16_t activation_state_quantized[activation_state_dims_count];
  int32_t
      bias_quantized[sizeof(tflite::testing::bias_data_16x1x1) / sizeof(float)];
  int8_t golden_quantized[sizeof(tflite::testing::golden_output_16x1x1) /
                          sizeof(float)];

  tflite::testing::TestIntegerSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActNone,
      input_quantized, input_scale, input_zero_point,
      tflite::testing::feature_weights_data_16x1x1, feature_weights_quantized,
      feature_weights_scale, tflite::testing::time_weights_data_16x1x1,
      time_weights_quantized, time_weights_scale,
      tflite::testing::bias_data_16x1x1, bias_quantized,
      tflite::testing::initial_activation_state_data_16x1x1,
      activation_state_quantized, activation_state_scale, output_data,
      output_scale, output_zero_point, tflite::testing::input_data_16x1x1,
      input_sequences_quantized,
      sizeof(tflite::testing::input_data_16x1x1) / sizeof(float),
      tflite::testing::golden_output_16x1x1, golden_quantized,
      sizeof(tflite::testing::golden_output_16x1x1) / sizeof(float));
}
874
// Same quantized configuration as the test above, but with a ReLU activation;
// the output zero point moves to -128 so the non-negative range uses the full
// int8 span.
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputReluShouldMatchGolden) {
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;
  constexpr int activation_state_dims_count =
      batch_size * memory_size * num_filters;
  constexpr int output_dims_count = batch_size * num_units;
  constexpr int input_dims_count = batch_size * input_size;

  int8_t output_data[output_dims_count];

  float input_scale = 0.10075444;
  float feature_weights_scale = 0.00649388;
  float time_weights_scale = 0.001571355;
  float activation_state_scale = 0.00045896982;
  float output_scale = 0.051445257;

  int input_zero_point = 2;
  int output_zero_point = -128;

  // Quantization buffers sized to match the float source arrays.
  int8_t input_quantized[input_dims_count];
  int8_t input_sequences_quantized[sizeof(tflite::testing::input_data_16x1x1) /
                                   sizeof(float)];
  int8_t feature_weights_quantized
      [sizeof(tflite::testing::feature_weights_data_16x1x1) / sizeof(float)];
  int16_t
      time_weights_quantized[sizeof(tflite::testing::time_weights_data_16x1x1) /
                             sizeof(float)];
  int16_t activation_state_quantized[activation_state_dims_count];
  int32_t
      bias_quantized[sizeof(tflite::testing::bias_data_16x1x1) / sizeof(float)];
  int8_t golden_quantized[sizeof(tflite::testing::golden_output_relu_16x1x1) /
                          sizeof(float)];

  tflite::testing::TestIntegerSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu,
      input_quantized, input_scale, input_zero_point,
      tflite::testing::feature_weights_data_16x1x1, feature_weights_quantized,
      feature_weights_scale, tflite::testing::time_weights_data_16x1x1,
      time_weights_quantized, time_weights_scale,
      tflite::testing::bias_data_16x1x1, bias_quantized,
      tflite::testing::initial_activation_state_data_16x1x1,
      activation_state_quantized, activation_state_scale, output_data,
      output_scale, output_zero_point, tflite::testing::input_data_16x1x1,
      input_sequences_quantized,
      sizeof(tflite::testing::input_data_16x1x1) / sizeof(float),
      tflite::testing::golden_output_relu_16x1x1, golden_quantized,
      sizeof(tflite::testing::golden_output_relu_16x1x1) / sizeof(float));
}
927
928 TF_LITE_MICRO_TESTS_END
929