• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/c/builtin_op_data.h"
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/micro/kernels/kernel_runner.h"
19 #include "tensorflow/lite/micro/test_helpers.h"
20 #include "tensorflow/lite/micro/testing/micro_test.h"
21 namespace tflite {
22 namespace testing {
23 namespace {
24 
// Naming convention: <tensor name>_<input size>x<batch size>x<batch count>,
// where <batch count> is the number of sequential input/output steps.
26 
// Input sequence: 10 inputs, each with shape {2, 2} (40 values total).
const float input_data_2x2x10[] = {
    0.12609188,  -0.46347019, 0.35867718,  0.36897406,

    0.14278367,  -1.64410412, -0.57290924, 0.12729003,

    0.49837467,  0.19278903,  0.17660543,  0.52949083,

    -0.11186574, 0.13164264,  -0.72674477, -0.5683046,

    -0.68892461, 0.37783599,  -0.63690937, 0.44483393,

    -0.81299269, -0.86831826, -0.95760226, 1.82078898,

    -1.45006323, -0.82251364, -1.65087092, -1.89238167,

    0.03966608,  -0.24936394, 2.06740379,  -1.51439476,

    0.11771342,  -0.23761693, 0.31088525,  -1.55601168,

    -0.89477462, 1.67204106,  -0.6230064,  0.29819036,
};
49 
// Feature filter of shape {8, 2} (16 values).
const float feature_weights_data_2x2x10[] = {
    -0.31930989, 0.0079667,  0.39296314,  0.37613347,  0.12416199,  0.15785322,
    0.27901134,  0.3905206,  0.21931258,  -0.36137494, -0.10640851, 0.31053296,
    -0.36118156, -0.0976817, -0.36916667, 0.22197971};
// Time filter of shape {8, 10} (80 values).
const float time_weights_data_2x2x10[] = {
    -0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
    0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

    0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
    -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

    -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
    0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

    -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
    -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657,

    -0.14884081, 0.19931212,  -0.36002168, 0.34663299,  -0.11405486,
    0.12672701,  0.39463779,  -0.07886535, -0.06384811, 0.08249187,

    -0.26816407, -0.19905911, 0.29211238,  0.31264046,  -0.28664589,
    0.05698794,  0.11613581,  0.14078894,  0.02187902,  -0.21781836,

    -0.15567942, 0.08693647,  -0.38256618, 0.36580828,  -0.22922277,
    -0.0226903,  0.12878349,  -0.28122205, -0.10850525, -0.11955214,

    0.27179423,  -0.04710215, 0.31069002,  0.22672787,  0.09580326,
    0.08682203,  0.1258215,   0.1851041,   0.29228821,  0.12366763};
81 
// Activation state with shape {2, 80}. These initial values must be copied into
// a mutable activation state tensor. All 2 * 80 = 160 entries start at zero:
// the empty brace initializer value-initializes every element to 0.0f, which
// matches the original hand-written list of zeros without risking a miscount.
const float initial_activation_state_data_2x2x10[2 * 80] = {};
93 
// Bias with shape {8}; every entry is zero.
const float bias_data_2x2x10[] = {0, 0, 0, 0, 0, 0, 0, 0};
96 
// Expected outputs: 10 outputs, each of shape {2, 4} (80 values total).
const float golden_output_2x2x10[] = {
    -0.044205, -0.013757, 0.050369,  -0.018447,
    0.073010,  0.025142,  -0.021154, 0.013551,

    -0.209613, -0.062421, 0.150209,  -0.108334,
    0.028256,  -0.006950, -0.030885, 0.009603,

    -0.076800, -0.037075, -0.087198, -0.155183,
    0.091069,  0.098446,  -0.016083, 0.106475,

    -0.082123, -0.162238, -0.084434, -0.141074,
    -0.029340, -0.090685, 0.053302,  -0.030604,

    -0.201440, 0.088424,  0.139877,  0.012416,
    -0.113212, 0.103893,  -0.100842, 0.122780,

    -0.166632, -0.116705, 0.175298,  -0.047163,
    0.313077,  -0.166485, -0.285860, 0.129069,

    -0.625911, 0.046134,  0.138081,  -0.129581,
    -0.521455, -0.061579, 0.230289,  0.114963,

    -0.216693, -0.161643, -0.179177, -0.052599,
    -0.213239, 0.029502,  0.260858,  0.275045,

    -0.213689, -0.323608, -0.285635, -0.317687,
    -0.324092, -0.317972, -0.208450, -0.462504,

    -0.255126, -0.218576, -0.041528, 0.179421,
    -0.440583, 0.072127,  -0.284136, 0.241570};
128 
129 // Simulated real-world inputs, weights and expected outputs.
130 
// Input of shape {1, 16}.
const float input_data_16x1x1[] = {
    -0.488494, 2.023762,  -2.233117, -0.488494, 3.559030, 9.490748,
    -3.210106, -1.953977, -0.279140, 0.907204,  1.674838, 0.000000,
    -0.279140, -0.628064, -0.069785, -0.628064,
};
137 
// Feature filter of shape {64, 16} (1024 values).
const float feature_weights_data_16x1x1[] = {
    0.173588,  0.173588,  -0.024798, 0.193426,  -0.099193, 0.044637,  0.183507,
    0.183507,  0.044637,  0.198386,  -0.069435, 0.084314,  0.312458,  0.024798,
    0.173588,  -0.049596, -0.352135, -0.550521, -0.009919, -0.099193, -0.074395,
    -0.128951, 0.193426,  0.357095,  -0.317418, -0.119032, -0.218225, -0.004960,
    -0.386853, -0.133911, 0.252942,  -0.019839, -0.024798, -0.054556, -0.069435,
    -0.128951, 0.029758,  -0.099193, -0.312458, -0.029758, 0.064475,  0.183507,
    0.114072,  -0.178547, -0.247982, -0.119032, 0.243023,  -0.119032, -0.034718,
    -0.178547, 0.019839,  0.128951,  -0.223184, -0.009919, -0.213265, 0.168628,
    -0.143830, -0.322377, -0.218225, -0.193426, -0.252942, -0.049596, 0.064475,
    -0.267821, -0.580279, -0.099193, 0.213265,  0.119032,  -0.119032, -0.178547,
    0.610037,  0.109112,  0.049596,  -0.014879, -0.049596, -0.193426, 0.039677,
    -0.148789, -0.114072, -0.158709, -0.158709, 0.094233,  0.099193,  -0.114072,
    0.104153,  -0.123991, 0.198386,  -0.173588, 0.089274,  -0.247982, -0.054556,
    0.123991,  0.183507,  0.114072,  0.188467,  0.302539,  0.044637,  0.039677,
    -0.099193, 0.168628,  -0.024798, -0.054556, -0.109112, 0.014879,  -0.009919,
    0.069435,  -0.396772, -0.287660, -0.079354, -0.104153, 0.054556,  0.089274,
    -0.099193, 0.114072,  0.034718,  0.119032,  0.282700,  -0.119032, -0.505884,
    -0.233104, -0.114072, -0.257902, -0.233104, -0.178547, 0.153749,  0.128951,
    0.143830,  -0.188467, -0.183507, 0.104153,  -0.024798, 0.193426,  -0.287660,
    0.168628,  -0.009919, 0.119032,  -0.024798, -0.099193, -0.203346, 0.099193,
    0.084314,  -0.168628, 0.123991,  -0.148789, 0.114072,  -0.029758, 0.228144,
    -0.238063, 0.089274,  -0.064475, 0.307498,  -0.188467, -0.004960, -0.252942,
    -0.173588, -0.158709, -0.044637, -0.009919, 0.312458,  -0.262861, 0.059516,
    0.158709,  0.069435,  -0.282700, 0.074395,  -0.322377, -0.183507, -0.123991,
    -0.233104, 0.009919,  0.252942,  -0.243023, 0.555481,  -0.099193, -0.119032,
    -0.441409, 0.148789,  0.084314,  -0.168628, -0.183507, 0.188467,  0.024798,
    -0.302539, 0.223184,  0.143830,  -0.193426, -0.054556, -0.218225, -0.297579,
    0.104153,  0.272781,  -0.034718, 0.114072,  -0.059516, 0.044637,  0.342216,
    0.421570,  0.138870,  -0.024798, -0.039677, -0.163668, -0.034718, 0.396772,
    -0.128951, -0.044637, -0.173588, 0.302539,  0.079354,  0.049596,  0.133911,
    -0.029758, -0.312458, -0.029758, 0.079354,  0.128951,  0.252942,  0.213265,
    0.014879,  0.287660,  0.178547,  0.297579,  0.352135,  0.401732,  0.024798,
    -0.277740, -0.411651, -0.069435, 0.342216,  -0.158709, -0.104153, -0.009919,
    0.223184,  0.228144,  -0.019839, 0.059516,  -0.104153, -0.510844, 0.029758,
    -0.406691, 0.089274,  0.421570,  0.163668,  -0.143830, -0.019839, -0.039677,
    0.104153,  -0.044637, -0.128951, 0.203346,  0.079354,  -0.069435, 0.094233,
    -0.138870, 0.466207,  -0.163668, 0.049596,  0.029758,  0.267821,  0.029758,
    -0.049596, 0.009919,  0.004960,  -0.099193, 0.094233,  -0.262861, 0.089274,
    -0.302539, 0.332297,  -0.307498, -0.014879, 0.168628,  -0.094233, -0.272781,
    0.034718,  -0.133911, -0.228144, 0.094233,  0.257902,  -0.228144, 0.153749,
    -0.054556, -0.252942, 0.054556,  0.218225,  -0.054556, 0.302539,  0.282700,
    0.054556,  -0.044637, -0.133911, 0.233104,  -0.049596, 0.411651,  0.044637,
    -0.297579, -0.029758, -0.114072, 0.114072,  -0.580279, 0.079354,  -0.024798,
    -0.347175, -0.128951, -0.099193, 0.238063,  -0.104153, -0.009919, 0.158709,
    -0.034718, 0.123991,  -0.163668, 0.059516,  0.342216,  0.009919,  0.064475,
    -0.307498, -0.520763, -0.238063, 0.163668,  0.362054,  0.034718,  -0.178547,
    -0.104153, -0.257902, 0.322377,  0.054556,  0.148789,  -0.178547, 0.084314,
    0.004960,  0.257902,  0.029758,  0.079354,  -0.223184, -0.193426, 0.282700,
    0.000000,  -0.019839, -0.114072, 0.491005,  -0.193426, -0.029758, -0.243023,
    0.009919,  0.089274,  -0.277740, -0.089274, 0.104153,  0.337256,  0.138870,
    -0.307498, -0.054556, 0.352135,  0.133911,  -0.044637, 0.133911,  -0.089274,
    -0.357095, -0.272781, 0.069435,  0.059516,  -0.109112, 0.148789,  -0.044637,
    -0.019839, -0.153749, 0.123991,  -0.223184, 0.322377,  0.074395,  -0.312458,
    0.024798,  -0.223184, 0.109112,  -0.138870, 0.218225,  -0.074395, -0.406691,
    0.009919,  -0.198386, -0.009919, 0.416611,  0.178547,  0.148789,  0.133911,
    -0.004960, 0.069435,  -0.054556, -0.044637, 0.297579,  0.059516,  -0.456288,
    -0.148789, -0.004960, 0.054556,  0.094233,  -0.104153, 0.198386,  -0.302539,
    0.133911,  0.411651,  0.054556,  0.525723,  -0.089274, 0.079354,  0.238063,
    0.079354,  -0.039677, 0.039677,  0.029758,  0.332297,  -0.014879, -0.367014,
    -0.143830, -0.123991, -0.064475, 0.014879,  0.173588,  -0.168628, 0.386853,
    0.009919,  0.173588,  0.163668,  0.123991,  0.163668,  0.198386,  0.203346,
    -0.401732, -0.009919, 0.272781,  -0.173588, 0.044637,  0.238063,  0.133911,
    0.049596,  0.208305,  -0.024798, 0.049596,  -0.049596, 0.034718,  -0.446368,
    0.466207,  -0.089274, -0.099193, -0.128951, -0.228144, 0.014879,  -0.252942,
    0.074395,  -0.223184, -0.168628, -0.292619, 0.178547,  0.153749,  -0.014879,
    0.054556,  0.000000,  0.193426,  0.158709,  0.178547,  -0.327337, -0.138870,
    -0.114072, 0.168628,  0.297579,  -0.109112, -0.029758, -0.029758, -0.416611,
    0.059516,  0.000000,  -0.168628, -0.322377, 0.238063,  -0.128951, -0.029758,
    0.500925,  0.292619,  0.123991,  -0.099193, 0.074395,  0.317418,  -0.148789,
    0.064475,  -0.104153, -0.044637, -0.094233, 0.188467,  -0.044637, 0.213265,
    -0.233104, -0.049596, 0.004960,  -0.198386, 0.287660,  -0.148789, -0.257902,
    0.004960,  -0.218225, -0.044637, -0.386853, -0.243023, -0.163668, 0.094233,
    0.029758,  -0.019839, -0.009919, -0.143830, -0.158709, 0.158709,  -0.243023,
    -0.039677, -0.297579, 0.069435,  0.049596,  0.302539,  0.059516,  0.074395,
    -0.019839, 0.352135,  -0.019839, -0.138870, -0.178547, -0.243023, 0.233104,
    0.252942,  -0.228144, -0.049596, 0.173588,  0.173588,  -0.074395, -0.034718,
    -0.292619, 0.362054,  0.183507,  0.243023,  -0.203346, -0.044637, 0.054556,
    0.059516,  -0.158709, -0.158709, 0.000000,  0.327337,  0.119032,  0.034718,
    -0.044637, -0.089274, 0.089274,  -0.233104, 0.000000,  -0.317418, 0.371974,
    0.213265,  0.307498,  -0.178547, -0.367014, 0.039677,  -0.059516, 0.168628,
    -0.014879, 0.143830,  0.123991,  -0.084314, -0.332297, -0.416611, 0.183507,
    0.109112,  -0.039677, 0.014879,  0.292619,  -0.213265, -0.054556, 0.004960,
    0.123991,  0.119032,  0.000000,  -0.332297, -0.312458, -0.198386, -0.213265,
    0.119032,  0.322377,  0.168628,  0.104153,  -0.262861, 0.327337,  -0.049596,
    -0.228144, -0.074395, 0.168628,  0.123991,  0.396772,  0.044637,  0.322377,
    0.193426,  0.267821,  -0.178547, 0.297579,  0.148789,  -0.218225, -0.138870,
    0.044637,  0.049596,  0.133911,  0.064475,  0.069435,  0.064475,  -0.158709,
    -0.044637, -0.173588, 0.267821,  0.327337,  0.079354,  -0.228144, 0.029758,
    0.014879,  0.198386,  -0.109112, -0.133911, 0.431490,  0.099193,  0.421570,
    0.233104,  -0.054556, 0.054556,  -0.317418, -0.133911, -0.123991, -0.287660,
    0.342216,  -0.049596, -0.153749, 0.228144,  -0.213265, 0.262861,  0.406691,
    -0.084314, -0.004960, 0.193426,  0.188467,  -0.099193, -0.223184, 0.163668,
    -0.257902, -0.153749, 0.441409,  0.099193,  0.128951,  -0.089274, -0.208305,
    -0.009919, -0.004960, -0.109112, 0.024798,  -0.119032, 0.019839,  0.391812,
    -0.024798, 0.198386,  0.327337,  -0.505884, -0.099193, 0.510844,  -0.148789,
    0.094233,  -0.153749, -0.039677, 0.352135,  0.272781,  -0.228144, -0.287660,
    -0.272781, 0.148789,  0.277740,  0.074395,  0.109112,  -0.064475, 0.044637,
    0.074395,  -0.292619, 0.153749,  -0.064475, -0.114072, 0.198386,  -0.039677,
    -0.128951, -0.004960, 0.257902,  -0.228144, -0.094233, 0.064475,  0.014879,
    0.188467,  -0.416611, 0.099193,  0.362054,  -0.208305, 0.198386,  -0.079354,
    0.009919,  0.119032,  0.332297,  0.243023,  -0.168628, 0.158709,  0.039677,
    0.143830,  0.277740,  -0.168628, 0.009919,  0.099193,  -0.004960, -0.257902,
    -0.297579, 0.208305,  -0.104153, 0.119032,  0.247982,  0.381893,  -0.223184,
    -0.367014, -0.327337, -0.168628, -0.094233, 0.208305,  -0.019839, 0.183507,
    0.084314,  0.133911,  0.109112,  -0.148789, -0.183507, -0.411651, -0.024798,
    -0.114072, -0.029758, -0.009919, 0.173588,  -0.059516, -0.049596, 0.039677,
    0.317418,  0.138870,  -0.247982, -0.084314, 0.158709,  0.054556,  -0.084314,
    -0.049596, 0.074395,  0.019839,  -0.282700, -0.119032, -0.262861, 0.163668,
    -0.069435, -0.064475, -0.059516, 0.094233,  0.123991,  -0.079354, -0.272781,
    -0.267821, 0.233104,  0.114072,  -0.218225, 0.540602,  0.089274,  0.262861,
    0.079354,  0.267821,  -0.119032, -0.109112, -0.128951, 0.128951,  -0.044637,
    -0.272781, 0.277740,  0.297579,  -0.054556, -0.084314, -0.049596, 0.123991,
    0.059516,  0.238063,  -0.168628, -0.009919, 0.163668,  -0.307498, 0.109112,
    -0.064475, 0.218225,  -0.168628, -0.004960, -0.168628, 0.119032,  0.094233,
    -0.183507, -0.089274, -0.292619, -0.094233, 0.064475,  -0.183507, -0.168628,
    0.089274,  0.074395,  -0.367014, -0.024798, -0.069435, 0.119032,  -0.302539,
    -0.376933, -0.123991, -0.009919, -0.069435, -0.208305, -0.119032, 0.014879,
    -0.183507, -0.238063, 0.163668,  -0.332297, -0.148789, -0.391812, -0.024798,
    -0.133911, -0.059516, -0.123991, 0.123991,  -0.292619, -0.044637, 0.059516,
    -0.069435, 0.049596,  -0.069435, 0.034718,  0.158709,  -0.347175, -0.044637,
    0.352135,  -0.347175, -0.282700, -0.054556, 0.307498,  0.029758,  0.357095,
    -0.148789, 0.208305,  -0.317418, 0.009919,  0.004960,  -0.243023, 0.049596,
    -0.099193, 0.213265,  -0.342216, 0.158709,  0.123991,  -0.332297, 0.386853,
    -0.262861, -0.208305, 0.123991,  -0.044637, 0.148789,  0.084314,  -0.297579,
    -0.307498, -0.163668, 0.337256,  -0.014879, 0.074395,  0.178547,  -0.004960,
    -0.257902, -0.019839, -0.228144, -0.034718, -0.277740, -0.158709, -0.119032,
    -0.153749, 0.629876,  0.277740,  0.178547,  -0.267821, -0.004960, 0.247982,
    0.084314,  -0.094233, 0.000000,  -0.039677, 0.332297,  0.178547,  0.009919,
    -0.213265, -0.208305, -0.044637, 0.019839,  0.218225,  -0.297579, 0.014879,
    -0.247982, -0.004960, -0.128951, 0.421570,  -0.059516, 0.362054,  -0.203346,
    -0.143830, -0.099193, -0.024798, 0.094233,  -0.123991, 0.163668,  0.109112,
    -0.104153, -0.233104, 0.009919,  -0.218225, 0.376933,  0.104153,  -0.059516,
    0.049596,  -0.054556, 0.019839,  -0.044637, -0.019839, 0.371974,  -0.019839,
    0.104153,  0.168628,  -0.024798, -0.272781, -0.158709, 0.223184,  0.044637,
    0.039677,  -0.168628, -0.287660, -0.109112, 0.094233,  -0.089274, -0.148789,
    0.178547,  -0.039677, -0.089274, -0.049596, -0.024798, 0.064475,  -0.158709,
    0.089274,  0.029758,  -0.247982, 0.362054,  0.024798,  -0.004960, -0.099193,
    0.173588,  -0.059516, 0.188467,  -0.629876, 0.094233,  0.371974,  0.069435,
    0.252942,  -0.357095, -0.272781, -0.367014, 0.014879,  -0.049596, -0.262861,
    0.009919,  -0.094233, -0.094233, 0.059516,  0.223184,  0.133911,  0.411651,
    -0.044637, -0.044637, 0.109112,  0.228144,  0.386853,  -0.233104, 0.069435,
    0.228144,  -0.302539, 0.029758,  0.089274,  0.044637,  -0.238063, -0.138870,
    -0.158709, -0.019839, 0.049596,  0.039677,  0.000000,  -0.069435, 0.109112,
    -0.213265, -0.188467, -0.262861, -0.267821, -0.094233, 0.133911,  0.391812,
    0.123991,  -0.317418, 0.233104,  -0.029758, -0.099193, -0.193426, 0.074395,
    -0.009919, 0.252942,  0.322377,  -0.530683, 0.208305,  0.252942,  0.203346,
    -0.069435, -0.262861};
287 
// Time filter of shape {64, 8} (512 values).
const float time_weights_data_16x1x1[] = {
    -0.052026, 0.043107,  0.053512,  0.013378,  0.011892,  -0.182834, -0.108511,
    0.153105,  0.050539,  -0.173915, 0.145672,  0.208103,  -0.221481, 0.108511,
    -0.496475, 0.181347,  -0.016351, -0.132294, -0.234859, -0.243778, 0.028243,
    -0.228914, -0.130808, -0.167969, -0.041621, -0.306209, -0.193239, -0.028243,
    -0.057972, -0.057972, -0.497962, 0.054999,  0.181347,  0.047566,  -0.099592,
    -0.111484, -0.130808, -0.071350, 0.380532,  0.010405,  0.041621,  0.052026,
    0.022297,  0.081755,  0.098106,  0.099592,  -0.584176, -0.023783, 0.062431,
    -0.090674, -0.279453, -0.486070, -0.273507, 0.004459,  -0.062431, 0.095133,
    0.056485,  0.022297,  -0.105538, -0.184320, 0.358235,  0.254183,  0.049053,
    0.084728,  0.218508,  0.078782,  -0.136754, -0.017837, -0.124862, -0.118916,
    -0.001486, 0.043107,  0.254183,  0.087701,  0.261616,  0.309182,  -0.404315,
    -0.040134, -0.046080, -0.052026, -0.034188, -0.475665, -0.025270, -0.049053,
    -0.046080, -0.062431, 0.020810,  0.040134,  -0.135267, -0.169456, -0.050539,
    -0.576743, 0.034188,  0.075809,  0.101079,  0.136754,  0.083241,  0.077296,
    -0.050539, 0.761064,  -0.335938, -0.080268, 0.025270,  0.257156,  0.227427,
    0.252697,  0.065404,  0.115943,  0.222968,  -0.026756, -0.054999, 0.107025,
    -0.093646, 0.041621,  -0.092160, -0.474178, -0.016351, 0.004459,  0.049053,
    0.019324,  0.019324,  0.074323,  0.038648,  -0.613905, 0.182834,  0.075809,
    0.028243,  0.019324,  0.010405,  -0.011892, 0.001486,  -0.492016, -0.224454,
    -0.474178, -0.147159, 0.002973,  0.102565,  0.136754,  -0.267561, -0.001486,
    -0.095133, -0.040134, 0.066890,  0.074323,  0.104052,  0.532150,  0.090674,
    0.072836,  -0.053512, -0.004459, 0.020810,  0.046080,  0.062431,  0.477151,
    0.133781,  -0.029729, -0.026756, 0.031215,  0.156077,  0.096619,  0.251210,
    0.352289,  0.657012,  0.047566,  -0.014865, -0.072836, -0.016351, 0.008919,
    -0.053512, 0.016351,  0.300263,  0.047566,  0.020810,  0.169456,  0.001486,
    0.007432,  0.111484,  0.044594,  -0.188779, -0.096619, 0.074323,  -0.040134,
    0.160537,  0.138240,  0.184320,  0.377559,  -0.092160, -0.049053, 0.056485,
    -0.032702, 0.001486,  -0.083241, -0.472692, -0.114457, -0.117430, -0.075809,
    0.026756,  0.163510,  0.172428,  0.127835,  -0.199185, -0.218508, -0.057972,
    -0.132294, -0.162023, -0.019324, -0.245265, -0.395396, -0.254183, 0.084728,
    0.248238,  0.191752,  0.221481,  0.173915,  0.173915,  -0.208103, -0.077296,
    0.384991,  -0.313641, -0.313641, -0.147159, -0.090674, 0.035675,  0.059458,
    -0.010405, 0.019324,  0.087701,  0.016351,  0.037161,  0.469719,  -0.074323,
    0.092160,  0.026756,  0.090674,  0.098106,  0.004459,  -0.034188, 0.492016,
    -0.367154, -0.093646, -0.063917, 0.041621,  0.017837,  0.026756,  -0.062431,
    -0.350803, 0.425125,  0.002973,  0.083241,  0.075809,  0.016351,  0.047566,
    -0.185807, -0.107025, -0.098106, -0.144186, 0.255670,  0.020810,  0.105538,
    0.029729,  0.129321,  0.156077,  0.141213,  0.334452,  0.147159,  -0.066890,
    0.035675,  0.115943,  0.240805,  0.328506,  0.162023,  -0.237832, 0.218508,
    0.233373,  0.214049,  0.099592,  0.026756,  -0.322560, -0.236346, -0.166483,
    0.225941,  0.109997,  -0.147159, 0.147159,  -0.266075, 0.111484,  0.078782,
    -0.120403, 0.022297,  -0.075809, -0.148645, -0.251210, -0.176888, -0.044594,
    -0.023783, 0.016351,  0.026756,  -0.013378, -0.069863, -0.112970, 0.013378,
    0.086214,  0.014865,  0.352289,  -0.240805, -0.135267, -0.114457, -0.472692,
    0.334452,  0.095133,  0.047566,  0.130808,  -0.068377, -0.007432, -0.130808,
    -0.121889, -0.053512, -0.245265, -0.371613, -0.083241, 0.000000,  -0.028243,
    0.029729,  -0.093646, -0.004459, -0.038648, -0.108511, -0.475665, -0.169456,
    -0.047566, -0.010405, -0.114457, -0.353776, -0.034188, -0.044594, 0.041621,
    -0.047566, -0.107025, 0.004459,  0.053512,  0.047566,  -0.358235, -0.193239,
    0.040134,  -0.096619, -0.054999, 0.099592,  0.032702,  0.205130,  -0.170942,
    -0.237832, -0.405801, -0.126348, -0.072836, -0.203644, -0.169456, -0.093646,
    -0.074323, 0.078782,  0.607959,  -0.437017, -0.164996, -0.166483, 0.043107,
    -0.016351, 0.258643,  0.065404,  -0.057972, 0.017837,  0.080268,  0.050539,
    -0.013378, -0.215536, -0.524718, 0.260129,  0.040134,  -0.002973, -0.046080,
    0.020810,  0.025270,  0.145672,  0.515799,  0.233373,  0.011892,  0.139727,
    0.126348,  0.065404,  -0.007432, -0.008919, 0.035675,  0.083241,  0.040134,
    -0.005946, 0.503907,  -0.490529, -0.181347, -0.092160, -0.038648, 0.019324,
    0.133781,  -0.011892, 0.041621,  0.062431,  -0.062431, -0.040134, -0.092160,
    -0.111484, -0.133781, -0.130808, -0.484583, -0.248238, 0.037161,  -0.092160,
    -0.056485, -0.041621, 0.112970,  0.248238,  0.438503,  0.258643,  -0.013378,
    0.004459,  0.043107,  0.040134,  0.017837,  0.101079,  0.264589,  0.212563,
    0.014865,  0.285399,  0.153105,  0.170942,  0.358235,  0.334452,  0.086214,
    0.132294,  0.098106,  -0.001486, 0.107025,  0.200671,  -0.026756, 0.344857,
    0.227427,  -0.041621, 0.098106,  0.063917,  -0.093646, 0.130808,  0.285399,
    -0.319587, 0.035675,  -0.017837, -0.319587, 0.016351,  -0.098106, -0.017837,
    0.083241,  0.074323,  -0.054999, 0.276480,  0.316614,  -0.099592, -0.059458,
    0.156077,  -0.043107, 0.035675,  0.056485,  -0.022297, 0.017837,  -0.001486,
    0.340398,  0.492016,  0.004459,  0.057972,  -0.150132, -0.206617, -0.257156,
    -0.248238, -0.080268, -0.164996, 0.352289,  -0.054999, -0.056485, 0.010405,
    -0.049053, -0.041621, -0.099592, 0.013378,  -0.089187, 0.057972,  -0.413234,
    0.217022,  0.013378,  -0.080268, -0.035675, 0.035675,  0.007432,  0.002973,
    -0.469719, 0.141213,  0.136754,  0.153105,  0.130808,  -0.104052, -0.508367,
    -0.291345, -0.072836, -0.019324, -0.252697, -0.214049, -0.214049, 0.130808,
    0.484583};
364 
// Bias of shape {64}.
const float bias_data_16x1x1[] = {
    -0.245395, -0.083545, -0.262522, -0.407912, -0.560898, -0.364789, -0.037964,
    -0.378594, 0.178152,  0.400380,  -0.301349, -0.240913, -0.159454, -0.158757,
    -0.073665, 0.455906,  -0.061232, 0.318907,  -0.226993, -0.344644, 0.140316,
    0.559608,  0.109774,  0.437391,  0.113849,  -0.162068, 0.039572,  0.569472,
    0.460205,  0.113459,  0.370469,  0.176811,  0.203063,  -0.296975, -0.271655,
    0.059862,  -0.159912, -0.077310, -0.338314, -0.195477, -0.256762, 0.233834,
    0.083172,  0.029040,  -0.236288, -0.267054, -0.166627, 0.188319,  -0.271391,
    -0.222920, 0.106463,  0.263614,  0.384986,  -0.125957, -0.095890, 0.363686,
    -0.036990, -0.358884, -0.178254, 0.305596,  0.390088,  -0.189437, 0.613409,
    0.399639};
377 
// Activation state with shape {64, 8} (512 values). These initial values must be
// copied into a mutable activation state tensor.
const float initial_activation_state_data_16x1x1[] = {
    -0.582275, -0.586623, -1.262373, -1.277279, -1.542175, -1.271999, -1.429757,
    -1.184425, -0.462094, -1.443421, 0.230736,  -0.494701, -0.354955, -2.534061,
    -4.277471, -4.218467, 0.403711,  -0.248748, -0.330111, -0.467683, 0.549047,
    0.733511,  -0.230115, 0.793136,  -1.126353, -0.984123, -0.081984, -0.222351,
    0.692830,  0.517060,  1.367958,  2.118860,  -0.116766, -0.826365, -2.402700,
    -2.313884, -2.898954, -2.076005, -2.405185, -2.755481, 0.329490,  0.085400,
    -1.485966, -2.034702, -2.161405, -1.269515, -1.151818, -1.823841, 0.561469,
    1.109273,  1.693411,  -0.082605, -0.069252, -1.225107, -1.330693, -1.411435,
    0.253406,  -0.357439, -1.593415, -0.879779, -1.111136, 1.821357,  2.471952,
    1.236908,  -4.014127, -2.810448, -2.944604, -1.930980, -1.566398, -0.838166,
    -0.319242, 0.749349,  1.156476,  0.658670,  1.997437,  2.080663,  2.912618,
    2.677224,  2.642442,  2.796163,  -0.272349, -0.473273, 3.120063,  2.747097,
    3.595510,  1.874150,  2.049919,  2.093396,  -1.049959, 0.277939,  -1.255541,
    -1.052443, -1.810177, -0.883505, -0.538178, 0.524203,  -1.017662, -0.269244,
    0.039129,  -0.227941, -0.114592, -2.018243, -2.548968, -0.706804, 0.890959,
    0.102480,  0.349986,  0.405885,  1.287216,  0.756181,  0.319242,  -0.641590,
    -3.841774, -2.716042, -4.342065, -3.826557, -2.924729, -1.643724, -1.237839,
    -0.597492, -1.954892, -1.215169, -1.528201, -1.018904, -0.863941, -0.293467,
    0.039439,  0.672023,  1.408019,  1.362679,  1.467644,  1.006171,  0.310236,
    -0.249990, -1.048406, -0.752144, -1.831605, -1.058033, -1.096541, -0.293467,
    0.051551,  0.232600,  0.088816,  2.570395,  0.704009,  2.465120,  3.010751,
    2.139357,  0.630410,  1.006171,  1.545281,  1.486898,  -1.162998, -2.344317,
    -4.593918, -3.522842, -2.872247, -1.416714, -0.642521, -0.230115, 0.315205,
    -0.368930, -0.162726, 0.396879,  0.505570,  0.534451,  0.554947,  1.270447,
    0.388805,  0.531967,  -1.243119, -0.671713, -1.214859, -0.238189, 0.016459,
    -1.164550, 0.609603,  3.293348,  2.600208,  1.454290,  -1.034121, -1.760179,
    -1.192500, -0.613951, 3.449553,  2.912618,  1.917937,  1.435968,  0.879158,
    1.118279,  0.102791,  -0.502465, -0.239121, -0.092853, 1.786265,  1.943091,
    2.547104,  2.630641,  2.585302,  2.965411,  -0.945615, -2.538720, -2.474126,
    -1.088156, 0.056209,  0.864873,  0.170490,  0.457435,  0.545941,  0.752765,
    1.569503,  1.129459,  0.662086,  -0.527929, -0.810838, -1.662978, 1.285042,
    1.653040,  4.130893,  2.961995,  4.147041,  3.256393,  3.881524,  2.522571,
    -0.875431, -1.112378, 2.105817,  2.180970,  3.121926,  1.577577,  1.639376,
    2.906407,  -0.142230, 0.421101,  2.212335,  2.311399,  3.993321,  3.651719,
    4.206666,  4.678387,  -1.304917, -1.130701, -2.543067, -2.500212, -2.197118,
    -1.197158, -0.949652, -0.282908, 0.320795,  -1.543728, 1.290322,  1.788128,
    3.957297,  3.205774,  2.892432,  2.297114,  0.138814,  -0.139435, 0.936920,
    0.344707,  0.723263,  -1.772290, -3.138385, -2.287177, -2.405806, -1.859864,
    -4.572801, -3.410424, -3.855748, -2.239663, -2.269786, -1.582857, 4.238342,
    3.858543,  2.499901,  1.087535,  0.290051,  -0.026086, -0.880400, -2.602692,
    -1.404292, 0.253096,  -0.665502, -1.443421, -0.925119, -0.096580, 1.115484,
    1.846200,  -1.604284, -1.244671, -0.464888, 0.326385,  0.168006,  -0.262723,
    -0.744691, 0.953379,  -0.407127, -0.349986, -1.154302, 0.831023,  1.590931,
    2.538720,  2.063583,  3.697680,  -0.752455, -1.293117, -1.330693, -1.869802,
    -0.592523, 0.631652,  1.198089,  -0.481347, 3.738983,  4.153252,  2.782499,
    2.244321,  0.709289,  1.650245,  1.700865,  0.385078,  2.192460,  2.610456,
    4.009780,  3.492719,  2.574743,  2.116687,  1.856138,  1.205853,  2.722563,
    4.075305,  5.415935,  3.009198,  2.715421,  1.571056,  0.897170,  -2.430339,
    0.749970,  0.425760,  -0.302783, 0.817359,  1.031636,  1.913589,  2.686229,
    1.631923,  -1.459259, -1.793097, -1.187531, -1.553355, -0.844998, -1.296843,
    -1.805519, -0.486627, 0.909591,  2.082837,  -1.473855, -2.456735, -3.851401,
    -2.760139, -3.060438, -2.605487, -2.138735, -2.441519, -1.333177, -1.353984,
    -0.245642, -0.588486, 0.033850,  2.084700,  0.076084,  0.690035,  0.747797,
    0.594697,  -1.016109, -1.348083, -1.201195, -1.088466, 2.045571,  2.460772,
    0.717984,  0.041613,  -0.721711, 1.134738,  2.322269,  1.112378,  -0.307441,
    -0.581033, -0.868599, -0.018633, 0.856488,  0.919839,  0.303094,  -0.433213,
    0.811148,  -0.508986, -1.060828, -1.227591, -1.566087, -1.117968, -1.385038,
    -2.011101, -0.490353, -1.849616, -0.594697, -1.055859, 1.110205,  0.622646,
    0.145957,  0.359303,  1.012072,  0.774814,  -0.400295, -1.484103, -2.007374,
    -1.441247, -0.997787, -0.581033, -0.545941, -0.306510, 0.693451,  0.087264,
    -0.227320, -1.211753, -1.532859, -1.688753, 0.065215,  0.134777,  0.608051,
    -0.393152, -0.214588, -0.635689, -1.499320, 0.069562,  -1.555839, -2.633126,
    -2.966032, -1.550870, -0.101549, 0.874189,  0.436318,  0.299367,  2.289972,
    2.339659,  2.602071,  1.564535,  0.019254,  -0.583207, -1.295912, -2.424749,
    -1.221070, -1.175109, -0.577306, -0.102791, 1.877876,  2.568222,  2.173827,
    3.131243,  2.637784,  2.088737,  3.679047,  3.218506,  2.483442,  1.650556,
    1.363611,  -0.027328, 1.486898,  -0.721711, -3.684327, -3.006093, -3.777491,
    -2.327548, -2.737470, -4.549510, -0.060867, 0.127635,  0.680408,  0.581344,
    0.320174,  -0.403090, -0.838166, 0.293777,  -0.995613, -0.165521, -0.419859,
    1.110515,  1.203679,  1.749931,  2.467294,  4.276539,  0.031055,  -0.967664,
    1.167035,  1.865144,  3.221923,  3.248630,  4.121266,  4.187723,  0.749039,
    -1.571056, 0.785994,  1.568572,  3.759479,  3.588678,  4.116608,  3.864444,
    -0.290051, -0.271107, 0.375140,  0.537556,  0.536314,  0.095959,  0.054656,
    0.088816};
455 
// Expected output: one output with shape {1, 64}.
const float golden_output_16x1x1[] = {
    -0.087914, 1.145864,  -0.418088, -1.556392, -0.925298, 0.205252,  0.289119,
    1.331180,  -0.218010, 0.963057,  -2.225886, 1.248478,  1.448983,  0.355467,
    1.682174,  0.803739,  0.449738,  0.543566,  1.916269,  -2.975136, 0.222774,
    0.241589,  -0.104216, 1.561748,  0.936818,  -0.089907, -0.520117, -0.870353,
    1.606074,  0.895770,  0.521297,  -0.369994, -0.889351, -2.809309, 2.404628,
    1.069754,  -0.195456, -1.105652, 1.272715,  -1.233177, 1.271416,  -1.691805,
    -1.058125, -0.716227, 0.052540,  1.262483,  0.540555,  1.735760,  -0.539197,
    -0.014367, -0.243002, 1.072254,  0.528985,  -0.731151, -1.262649, 2.338702,
    -0.603093, 0.970736,  -3.567897, 0.035085,  -0.201711, -0.550400, 1.545573,
    -1.805005};
468 
// Expected float output of the 16-input, 64-unit, rank-1 SVDF after a single
// input step with a fused ReLU (used with kTfLiteActRelu). Element-wise this is
// golden_output_16x1x1 with every negative entry clamped to 0.
// Shape {1, 64}; listed four units per row.
const float golden_output_relu_16x1x1[] = {
    0.000000, 1.145864, 0.000000, 0.000000,  // units  0-3
    0.000000, 0.205252, 0.289119, 1.331180,  // units  4-7
    0.000000, 0.963057, 0.000000, 1.248478,  // units  8-11
    1.448983, 0.355467, 1.682174, 0.803739,  // units 12-15
    0.449738, 0.543566, 1.916269, 0.000000,  // units 16-19
    0.222774, 0.241589, 0.000000, 1.561748,  // units 20-23
    0.936818, 0.000000, 0.000000, 0.000000,  // units 24-27
    1.606074, 0.895770, 0.521297, 0.000000,  // units 28-31
    0.000000, 0.000000, 2.404628, 1.069754,  // units 32-35
    0.000000, 0.000000, 1.272715, 0.000000,  // units 36-39
    1.271416, 0.000000, 0.000000, 0.000000,  // units 40-43
    0.052540, 1.262483, 0.540555, 1.735760,  // units 44-47
    0.000000, 0.000000, 0.000000, 1.072254,  // units 48-51
    0.528985, 0.000000, 0.000000, 2.338702,  // units 52-55
    0.000000, 0.970736, 0.000000, 0.035085,  // units 56-59
    0.000000, 0.000000, 1.545573, 0.000000,  // units 60-63
};
481 
482 template <typename T>
ValidateSVDFGoldens(const int batch_size,const int num_units,const int input_size,const int rank,TfLiteTensor * tensors,const int tensor_count,TfLiteFusedActivation activaiton,const T * input_sequences_data,const int input_sequences_len,T * output_data,const T * expected_output,float tolerance=1e-5f)483 void ValidateSVDFGoldens(const int batch_size, const int num_units,
484                          const int input_size, const int rank,
485                          TfLiteTensor* tensors, const int tensor_count,
486                          TfLiteFusedActivation activaiton,
487                          const T* input_sequences_data,
488                          const int input_sequences_len, T* output_data,
489                          const T* expected_output, float tolerance = 1e-5f) {
490   TfLiteSVDFParams params;
491   params.rank = rank;
492   params.activation = activaiton;
493 
494   int inputs_array_data[] = {5, 0, 1, 2, 3, 4};
495   TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
496 
497   int outputs_array_data[] = {1, 5};
498   TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
499 
500   const TfLiteRegistration registration = Register_SVDF();
501   micro::KernelRunner runner(registration, tensors, tensor_count, inputs_array,
502                              outputs_array, &params);
503 
504   TfLiteStatus init_and_prepare_status = runner.InitAndPrepare();
505   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, init_and_prepare_status);
506 
507   // Abort early to make it clear init and prepare failed.
508   if (init_and_prepare_status != kTfLiteOk) {
509     return;
510   }
511 
512   int num_inputs = input_sequences_len / (input_size * batch_size);
513 
514   for (int i = 0; i < num_inputs; ++i) {
515     const T* input_batch_start =
516         input_sequences_data + i * input_size * batch_size;
517 
518     memcpy(tensors[0].data.raw, input_batch_start, tensors[0].bytes);
519     TfLiteStatus status = runner.Invoke();
520     TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, status);
521 
522     // Only validate outputs when invoke has succeeded.
523     if (status == kTfLiteOk) {
524       int output_idx = 0;
525       int golden_idx = i * batch_size * num_units;
526       for (int j = golden_idx; j < golden_idx + batch_size * num_units; ++j) {
527         TF_LITE_MICRO_EXPECT_NEAR(expected_output[j], output_data[output_idx],
528                                   tolerance);
529         output_idx++;
530       }
531     }
532   }
533 }
534 
535 #if !defined(XTENSA)  // Needed to avoid build errors from unused functions.
TestSVDF(const int batch_size,const int num_units,const int input_size,const int memory_size,const int rank,TfLiteFusedActivation activation,float * input_data,const float * feature_weights_data,const float * time_weights_data,float * activation_state_data,const float * bias_data,float * scratch_data,float * output_data,const float * input_sequences_data,int input_sequences_len,const float * expected_output,float tolerance=1e-5f)536 void TestSVDF(const int batch_size, const int num_units, const int input_size,
537               const int memory_size, const int rank,
538               TfLiteFusedActivation activation, float* input_data,
539               const float* feature_weights_data, const float* time_weights_data,
540               float* activation_state_data, const float* bias_data,
541               float* scratch_data, float* output_data,
542               const float* input_sequences_data, int input_sequences_len,
543               const float* expected_output, float tolerance = 1e-5f) {
544   const int num_filters = num_units * rank;
545 
546   const int input_dims_arg[] = {2, batch_size, input_size};
547   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_arg);
548 
549   const int feature_weights_dims_args[] = {2, num_filters, input_size};
550   TfLiteIntArray* feature_weights_dims =
551       IntArrayFromInts(feature_weights_dims_args);
552 
553   const int time_weights_dims_args[] = {2, num_filters, memory_size};
554   TfLiteIntArray* time_weights_dims = IntArrayFromInts(time_weights_dims_args);
555 
556   const int activation_state_dims_args[] = {2, batch_size,
557                                             memory_size * num_filters};
558   TfLiteIntArray* activation_state_dims =
559       IntArrayFromInts(activation_state_dims_args);
560 
561   const int bias_dims_args[] = {1, num_units};
562   TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_args);
563 
564   const int output_dims_args[] = {2, batch_size, num_units};
565   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_args);
566 
567   const int tensor_count = 6;  // 5 inputs, 1 output
568   TfLiteTensor tensors[] = {
569       CreateTensor(input_data, input_dims),
570       CreateTensor(feature_weights_data, feature_weights_dims),
571       CreateTensor(time_weights_data, time_weights_dims),
572       CreateTensor(bias_data, bias_dims),
573       CreateTensor(activation_state_data, activation_state_dims,
574                    /*is_variable=*/true),
575       CreateTensor(output_data, output_dims),
576   };
577 
578   ValidateSVDFGoldens(batch_size, num_units, input_size, rank, tensors,
579                       tensor_count, activation, input_sequences_data,
580                       input_sequences_len, output_data, expected_output,
581                       tolerance);
582 }
583 #endif
584 
585 // The pattern to this method's arguemnts is:
586 // <kernel metadata>
587 // for each tensor in
588 //     {input, feature weights, time weights, bias, activation state, output}:
589 //   <tensor float values> <tensor quantized buffer> <tensor quantization data>
TestIntegerSVDF(const int batch_size,const int num_units,const int input_size,const int memory_size,const int rank,TfLiteFusedActivation activation,int8_t * input_quantized,float input_scale,int input_zero_point,const float * feature_weights_data,int8_t * feature_weights_quantized,const float feature_weights_scale,const float * time_weights_data,int16_t * time_weights_quantized,float time_weights_scale,const float * bias_data,int32_t * bias_quantized,const float * initial_activation_state_data,int16_t * activation_state_quantized,float activation_state_scale,int8_t * output_data,float output_scale,int output_zero_point,const float * input_sequences_data,int8_t * input_sequences_quantized,const int input_sequences_len,const float * golden_output,int8_t * golden_output_quantized,int golden_output_len)590 inline void TestIntegerSVDF(
591     const int batch_size, const int num_units, const int input_size,
592     const int memory_size, const int rank, TfLiteFusedActivation activation,
593     int8_t* input_quantized, float input_scale, int input_zero_point,
594     const float* feature_weights_data, int8_t* feature_weights_quantized,
595     const float feature_weights_scale, const float* time_weights_data,
596     int16_t* time_weights_quantized, float time_weights_scale,
597     const float* bias_data, int32_t* bias_quantized,
598     const float* initial_activation_state_data,
599     int16_t* activation_state_quantized, float activation_state_scale,
600     int8_t* output_data, float output_scale, int output_zero_point,
601     const float* input_sequences_data, int8_t* input_sequences_quantized,
602     const int input_sequences_len, const float* golden_output,
603     int8_t* golden_output_quantized, int golden_output_len) {
604   const int num_filters = num_units * rank;
605 
606   const int input_dims_arg[] = {2, batch_size, input_size};
607   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_arg);
608 
609   const int feature_weights_dims_args[] = {2, num_filters, input_size};
610   TfLiteIntArray* feature_weights_dims =
611       IntArrayFromInts(feature_weights_dims_args);
612 
613   const int time_weights_dims_args[] = {2, num_filters, memory_size};
614   TfLiteIntArray* time_weights_dims = IntArrayFromInts(time_weights_dims_args);
615 
616   const int bias_dims_data[] = {1, num_units};
617   TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
618 
619   const int activation_state_dims_args[] = {2, batch_size,
620                                             memory_size * num_filters};
621   TfLiteIntArray* activation_state_dims =
622       IntArrayFromInts(activation_state_dims_args);
623 
624   const int output_dims_args[] = {2, batch_size, num_units};
625   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_args);
626 
627   const int tensor_count = 6;  // 5 inputs, 1 output
628 
629   TfLiteTensor tensors[] = {
630       CreateQuantizedTensor(input_quantized, input_dims, input_scale,
631                             input_zero_point),
632       CreateQuantizedTensor(feature_weights_data, feature_weights_quantized,
633                             feature_weights_dims, feature_weights_scale, 0),
634       CreateQuantizedTensor(time_weights_data, time_weights_quantized,
635                             time_weights_dims, time_weights_scale, 0),
636       CreateQuantizedBiasTensor(bias_data, bias_quantized, bias_dims,
637                                 time_weights_scale, activation_state_scale),
638       CreateQuantizedTensor(initial_activation_state_data,
639                             activation_state_quantized, activation_state_dims,
640                             activation_state_scale, 0,
641                             /*is_variable=*/true),
642       CreateQuantizedTensor(output_data, output_dims, output_scale,
643                             output_zero_point)};
644 
645   tflite::Quantize(golden_output, golden_output_quantized, golden_output_len,
646                    output_scale, output_zero_point);
647   tflite::Quantize(input_sequences_data, input_sequences_quantized,
648                    input_sequences_len, input_scale, input_zero_point);
649 
650   ValidateSVDFGoldens(batch_size, num_units, input_size, rank, tensors,
651                       tensor_count, activation, input_sequences_quantized,
652                       input_sequences_len, output_data, golden_output_quantized,
653                       /*tolerance*/ 1);
654 }
655 
656 }  // namespace
657 }  // namespace testing
658 }  // namespace tflite
659 
660 TF_LITE_MICRO_TESTS_BEGIN
661 
662 #if !defined(XTENSA)  // TODO(b/170332589): xtensa kernels are less general than
663                       // reference kernels and we ifdef out test cases that are
664                       // currently known to fail.
TF_LITE_MICRO_TEST(SvdfFloat2x2Input2x4OutputShouldMatchGolden) {
  constexpr int batch_size = 2;
  constexpr int num_units = 4;
  constexpr int input_size = 2;
  constexpr int memory_size = 10;
  constexpr int rank = 2;
  constexpr int num_filters = num_units * rank;

  constexpr int input_size_dims_count = batch_size * input_size;
  float input_data[input_size_dims_count];

  constexpr int activation_state_dims_count =
      batch_size * memory_size * num_filters;
  float activation_state_data[activation_state_dims_count];

  // Seed the mutable activation state from the shared initial table. A size
  // mismatch between table and buffer is rejected at compile time instead of
  // overrunning either array in memcpy.
  static_assert(
      sizeof(tflite::testing::initial_activation_state_data_2x2x10) ==
          sizeof(activation_state_data),
      "initial_activation_state_data_2x2x10 must have "
      "batch_size * memory_size * num_filters elements");
  memcpy(activation_state_data,
         tflite::testing::initial_activation_state_data_2x2x10,
         sizeof(activation_state_data));

  // Not read by TestSVDF; required by its signature.
  constexpr int scratch_dims_count = batch_size * num_filters;
  float scratch_data[scratch_dims_count];

  constexpr int output_dims_count = batch_size * num_units;
  float output_data[output_dims_count];

  tflite::testing::TestSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActNone,
      input_data, tflite::testing::feature_weights_data_2x2x10,
      tflite::testing::time_weights_data_2x2x10, activation_state_data,
      tflite::testing::bias_data_2x2x10, scratch_data, output_data,
      tflite::testing::input_data_2x2x10,
      sizeof(tflite::testing::input_data_2x2x10) / sizeof(float),
      tflite::testing::golden_output_2x2x10);
}
699 #endif
700 
TF_LITE_MICRO_TEST(SvdfQuantized2x2Input2x4OutputShouldMatchGolden) {
  // SVDF geometry: 2 batches of 2 inputs, 4 units of rank 2, memory of 10.
  constexpr int batch_size = 2;
  constexpr int num_units = 4;
  constexpr int input_size = 2;
  constexpr int memory_size = 10;
  constexpr int rank = 2;
  constexpr int num_filters = num_units * rank;

  // Element counts of the float fixtures quantized for this test.
  constexpr int input_sequences_len =
      sizeof(tflite::testing::input_data_2x2x10) / sizeof(float);
  constexpr int feature_weights_len =
      sizeof(tflite::testing::feature_weights_data_2x2x10) / sizeof(float);
  constexpr int time_weights_len =
      sizeof(tflite::testing::time_weights_data_2x2x10) / sizeof(float);
  constexpr int bias_len =
      sizeof(tflite::testing::bias_data_2x2x10) / sizeof(float);
  constexpr int golden_len =
      sizeof(tflite::testing::golden_output_2x2x10) / sizeof(float);

  // Symmetric quantization: each scale maps the stated float range onto the
  // full integer range, and the input/output zero points are 0.
  const float input_scale = 2.5f / INT8_MAX;              // Range is [-2.5, 2.5]
  const float feature_weights_scale = 1.f / INT8_MAX;     // Range is [-1, 1]
  const float time_weights_scale = 1.f / INT16_MAX;       // Range is [-1, 1]
  const float activation_state_scale = 16.f / INT16_MAX;  // Range is [-16, 16]
  const float output_scale = 1.f / INT8_MAX;              // Range is [-1, 1]
  const int input_zero_point = 0;
  const int output_zero_point = 0;

  // Buffers backing the quantized tensors and the quantized fixture copies.
  int8_t input_quantized[batch_size * input_size];
  int8_t input_sequences_quantized[input_sequences_len];
  int8_t feature_weights_quantized[feature_weights_len];
  int16_t time_weights_quantized[time_weights_len];
  int32_t bias_quantized[bias_len];
  int16_t activation_state_quantized[batch_size * memory_size * num_filters];
  int8_t output_data[batch_size * num_units];
  int8_t golden_quantized[golden_len];

  // NOTE(review): this runs with kTfLiteActRelu against golden_output_2x2x10,
  // which SvdfFloat2x2Input2x4OutputShouldMatchGolden uses with
  // kTfLiteActNone. Both can only hold if those golden values are
  // (near-)non-negative -- confirm.
  tflite::testing::TestIntegerSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu,
      input_quantized, input_scale, input_zero_point,
      tflite::testing::feature_weights_data_2x2x10, feature_weights_quantized,
      feature_weights_scale, tflite::testing::time_weights_data_2x2x10,
      time_weights_quantized, time_weights_scale,
      tflite::testing::bias_data_2x2x10, bias_quantized,
      tflite::testing::initial_activation_state_data_2x2x10,
      activation_state_quantized, activation_state_scale, output_data,
      output_scale, output_zero_point, tflite::testing::input_data_2x2x10,
      input_sequences_quantized, input_sequences_len,
      tflite::testing::golden_output_2x2x10, golden_quantized, golden_len);
}
755 
756 #if !defined(XTENSA)  // TODO(b/170332589): xtensa kernels are less general than
757                       // reference kernels and we ifdef out test cases that are
758                       // currently known to fail.
TF_LITE_MICRO_TEST(SvdfFloat1x16Input64x1OutputShouldMatchGolden) {
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;
  constexpr int activation_state_dims_count =
      batch_size * memory_size * num_filters;
  constexpr int output_dims_count = batch_size * num_units;
  constexpr int input_dims_count = batch_size * input_size;

  float input_data[input_dims_count];
  float output_data[output_dims_count];
  // Not read by TestSVDF; required by its signature.
  float scratch_buffer[batch_size * num_filters];
  float activation_state_data_mutable[activation_state_dims_count];

  // Initialize activation state to starting values. A size mismatch between
  // table and buffer is rejected at compile time instead of overrunning either
  // array in memcpy.
  static_assert(
      sizeof(tflite::testing::initial_activation_state_data_16x1x1) ==
          sizeof(activation_state_data_mutable),
      "initial_activation_state_data_16x1x1 must have "
      "batch_size * memory_size * num_filters elements");
  memcpy(activation_state_data_mutable,
         tflite::testing::initial_activation_state_data_16x1x1,
         sizeof(activation_state_data_mutable));

  // Passing input_size as the sequence length feeds exactly one step, which
  // matches the single {1, 64} golden output.
  tflite::testing::TestSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActNone,
      input_data, tflite::testing::feature_weights_data_16x1x1,
      tflite::testing::time_weights_data_16x1x1, activation_state_data_mutable,
      tflite::testing::bias_data_16x1x1, scratch_buffer, output_data,
      tflite::testing::input_data_16x1x1, input_size,
      tflite::testing::golden_output_16x1x1);
}
789 
TF_LITE_MICRO_TEST(SvdfFloat1x16Input64x1OutputReluShouldMatchGolden) {
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;
  constexpr int activation_state_dims_count =
      batch_size * memory_size * num_filters;
  constexpr int output_dims_count = batch_size * num_units;
  constexpr int input_dims_count = batch_size * input_size;

  float input_data[input_dims_count];
  float output_data[output_dims_count];
  // Not read by TestSVDF; required by its signature.
  float scratch_buffer[batch_size * num_filters];
  float activation_state_data_mutable[activation_state_dims_count];

  // Initialize activation state to starting values. A size mismatch between
  // table and buffer is rejected at compile time instead of overrunning either
  // array in memcpy.
  static_assert(
      sizeof(tflite::testing::initial_activation_state_data_16x1x1) ==
          sizeof(activation_state_data_mutable),
      "initial_activation_state_data_16x1x1 must have "
      "batch_size * memory_size * num_filters elements");
  memcpy(activation_state_data_mutable,
         tflite::testing::initial_activation_state_data_16x1x1,
         sizeof(activation_state_data_mutable));

  // Passing input_size as the sequence length feeds exactly one step, which
  // matches the single {1, 64} golden output.
  tflite::testing::TestSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu,
      input_data, tflite::testing::feature_weights_data_16x1x1,
      tflite::testing::time_weights_data_16x1x1, activation_state_data_mutable,
      tflite::testing::bias_data_16x1x1, scratch_buffer, output_data,
      tflite::testing::input_data_16x1x1, input_size,
      tflite::testing::golden_output_relu_16x1x1);
}
820 #endif
821 
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputShouldMatchGolden) {
  // SVDF geometry: one batch of 16 inputs feeding 64 rank-1 units, memory 8.
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;

  // Element counts of the float fixtures quantized for this test.
  constexpr int input_sequences_len =
      sizeof(tflite::testing::input_data_16x1x1) / sizeof(float);
  constexpr int feature_weights_len =
      sizeof(tflite::testing::feature_weights_data_16x1x1) / sizeof(float);
  constexpr int time_weights_len =
      sizeof(tflite::testing::time_weights_data_16x1x1) / sizeof(float);
  constexpr int bias_len =
      sizeof(tflite::testing::bias_data_16x1x1) / sizeof(float);
  constexpr int golden_len =
      sizeof(tflite::testing::golden_output_16x1x1) / sizeof(float);

  // Quantization parameters for the tensors built by TestIntegerSVDF.
  const float input_scale = 0.10075444;
  const float feature_weights_scale = 0.00649388;
  const float time_weights_scale = 0.001571355;
  const float activation_state_scale = 0.00045896982;
  const float output_scale = 0.051445257;
  const int input_zero_point = 2;
  const int output_zero_point = 0;

  // Buffers backing the quantized tensors and the quantized fixture copies.
  int8_t input_quantized[batch_size * input_size];
  int8_t input_sequences_quantized[input_sequences_len];
  int8_t feature_weights_quantized[feature_weights_len];
  int16_t time_weights_quantized[time_weights_len];
  int32_t bias_quantized[bias_len];
  int16_t activation_state_quantized[batch_size * memory_size * num_filters];
  int8_t output_data[batch_size * num_units];
  int8_t golden_quantized[golden_len];

  tflite::testing::TestIntegerSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActNone,
      input_quantized, input_scale, input_zero_point,
      tflite::testing::feature_weights_data_16x1x1, feature_weights_quantized,
      feature_weights_scale, tflite::testing::time_weights_data_16x1x1,
      time_weights_quantized, time_weights_scale,
      tflite::testing::bias_data_16x1x1, bias_quantized,
      tflite::testing::initial_activation_state_data_16x1x1,
      activation_state_quantized, activation_state_scale, output_data,
      output_scale, output_zero_point, tflite::testing::input_data_16x1x1,
      input_sequences_quantized, input_sequences_len,
      tflite::testing::golden_output_16x1x1, golden_quantized, golden_len);
}
874 
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputReluShouldMatchGolden) {
  // SVDF geometry: one batch of 16 inputs feeding 64 rank-1 units, memory 8.
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;

  // Element counts of the float fixtures quantized for this test.
  constexpr int input_sequences_len =
      sizeof(tflite::testing::input_data_16x1x1) / sizeof(float);
  constexpr int feature_weights_len =
      sizeof(tflite::testing::feature_weights_data_16x1x1) / sizeof(float);
  constexpr int time_weights_len =
      sizeof(tflite::testing::time_weights_data_16x1x1) / sizeof(float);
  constexpr int bias_len =
      sizeof(tflite::testing::bias_data_16x1x1) / sizeof(float);
  constexpr int golden_len =
      sizeof(tflite::testing::golden_output_relu_16x1x1) / sizeof(float);

  // Quantization parameters. The output zero point sits at the bottom of the
  // int8 range because a ReLU output is never negative.
  const float input_scale = 0.10075444;
  const float feature_weights_scale = 0.00649388;
  const float time_weights_scale = 0.001571355;
  const float activation_state_scale = 0.00045896982;
  const float output_scale = 0.051445257;
  const int input_zero_point = 2;
  const int output_zero_point = -128;

  // Buffers backing the quantized tensors and the quantized fixture copies.
  int8_t input_quantized[batch_size * input_size];
  int8_t input_sequences_quantized[input_sequences_len];
  int8_t feature_weights_quantized[feature_weights_len];
  int16_t time_weights_quantized[time_weights_len];
  int32_t bias_quantized[bias_len];
  int16_t activation_state_quantized[batch_size * memory_size * num_filters];
  int8_t output_data[batch_size * num_units];
  int8_t golden_quantized[golden_len];

  tflite::testing::TestIntegerSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu,
      input_quantized, input_scale, input_zero_point,
      tflite::testing::feature_weights_data_16x1x1, feature_weights_quantized,
      feature_weights_scale, tflite::testing::time_weights_data_16x1x1,
      time_weights_quantized, time_weights_scale,
      tflite::testing::bias_data_16x1x1, bias_quantized,
      tflite::testing::initial_activation_state_data_16x1x1,
      activation_state_quantized, activation_state_scale, output_data,
      output_scale, output_zero_point, tflite::testing::input_data_16x1x1,
      input_sequences_quantized, input_sequences_len,
      tflite::testing::golden_output_relu_16x1x1, golden_quantized, golden_len);
}
927 
928 TF_LITE_MICRO_TESTS_END
929