/*
 * Copyright (c) 2018-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_TEST_LSTM_LAYER_FIXTURE
#define ARM_COMPUTE_TEST_LSTM_LAYER_FIXTURE

#include "tests/Globals.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
#include "tests/validation/reference/ActivationLayer.h"
#include "tests/validation/reference/ArithmeticOperations.h"
#include "tests/validation/reference/ConcatenateLayer.h"
#include "tests/validation/reference/FullyConnectedLayer.h"
#include "tests/validation/reference/GEMM.h"
#include "tests/validation/reference/MeanStdDevNormalizationLayer.h"
#include "tests/validation/reference/PixelWiseMultiplication.h"
#include "tests/validation/reference/Transpose.h"
39 namespace arm_compute
40 {
41 namespace test
42 {
43 namespace validation
44 {
45 template <typename TensorType, typename AccessorType, typename FunctionType, typename FunctionParams, typename T>
46 class LSTMLayerValidationFixture : public framework::Fixture
47 {
48 public:
49     template <typename...>
setup(TensorShape input_shape,TensorShape input_weights_shape,TensorShape recurrent_weights_shape,TensorShape cell_bias_shape,TensorShape output_cell_shape,TensorShape output_shape,TensorShape scratch_shape,ActivationLayerInfo info,float cell_threshold,float projection_threshold,DataType data_type,bool projection_opt,bool peephole_opt,bool use_layer_norm)50     void setup(TensorShape input_shape, TensorShape input_weights_shape, TensorShape recurrent_weights_shape, TensorShape cell_bias_shape, TensorShape output_cell_shape, TensorShape output_shape,
51                TensorShape scratch_shape, ActivationLayerInfo info, float cell_threshold, float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt,
52                bool use_layer_norm)
53     {
54         _target = compute_target(input_shape, input_weights_shape, recurrent_weights_shape, cell_bias_shape, output_cell_shape, output_shape, scratch_shape, info, cell_threshold, projection_threshold,
55                                  data_type, projection_opt, peephole_opt, use_layer_norm);
56         _reference = compute_reference(input_shape, input_weights_shape, recurrent_weights_shape, cell_bias_shape, output_cell_shape, output_shape, scratch_shape, info, cell_threshold, projection_threshold,
57                                        data_type, projection_opt, peephole_opt, use_layer_norm);
58     }
59 
60 protected:
61     template <typename U>
fill(U && tensor,int i)62     void fill(U &&tensor, int i)
63     {
64         static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
65         using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type;
66 
67         DistributionType distribution{ T(-1.0f), T(1.0f) };
68         library->fill(tensor, distribution, i);
69     }
70     template <typename U>
fill_custom_val(U && tensor,float num,int i)71     void fill_custom_val(U &&tensor, float num, int i)
72     {
73         static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
74         using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type;
75 
76         DistributionType distribution{ T(num), T(num) };
77         library->fill(tensor, distribution, i);
78     }
compute_target(const TensorShape & input_shape,const TensorShape & input_weights_shape,const TensorShape & recurrent_weights_shape,const TensorShape & cell_bias_shape,const TensorShape & output_cell_shape,const TensorShape & output_shape,const TensorShape & scratch_shape,ActivationLayerInfo info,float cell_threshold,float projection_threshold,DataType data_type,bool projection_opt,bool peephole_opt,bool use_layer_norm)79     TensorType compute_target(const TensorShape &input_shape, const TensorShape &input_weights_shape, const TensorShape &recurrent_weights_shape, const TensorShape &cell_bias_shape,
80                               const TensorShape &output_cell_shape, const TensorShape &output_shape, const TensorShape &scratch_shape, ActivationLayerInfo info, float cell_threshold,
81                               float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt, bool use_layer_norm)
82     {
83         const unsigned int num_cells   = input_weights_shape.y();
84         const unsigned int num_outputs = recurrent_weights_shape.x();
85 
86         // Create tensors
87         TensorType input                 = create_tensor<TensorType>(input_shape, data_type);
88         TensorType input_to_forget_w     = create_tensor<TensorType>(input_weights_shape, data_type);
89         TensorType input_to_cell_w       = create_tensor<TensorType>(input_weights_shape, data_type);
90         TensorType input_to_output_w     = create_tensor<TensorType>(input_weights_shape, data_type);
91         TensorType recurrent_to_forget_w = create_tensor<TensorType>(recurrent_weights_shape, data_type);
92         TensorType recurrent_to_cell_w   = create_tensor<TensorType>(recurrent_weights_shape, data_type);
93         TensorType recurrent_to_output_w = create_tensor<TensorType>(recurrent_weights_shape, data_type);
94         TensorType forget_gate_bias      = create_tensor<TensorType>(cell_bias_shape, data_type);
95         TensorType cell_bias             = create_tensor<TensorType>(cell_bias_shape, data_type);
96         TensorType output_gate_bias      = create_tensor<TensorType>(cell_bias_shape, data_type);
97         TensorType output_state_in       = create_tensor<TensorType>(output_shape, data_type);
98         TensorType cell_state_in         = create_tensor<TensorType>(output_cell_shape, data_type);
99         TensorType scratch               = create_tensor<TensorType>(scratch_shape, data_type);
100         TensorType output_state_out      = create_tensor<TensorType>(output_shape, data_type);
101         TensorType cell_state_out        = create_tensor<TensorType>(output_cell_shape, data_type);
102         TensorType output                = create_tensor<TensorType>(output_shape, data_type);
103         TensorType input_to_input_w;
104         TensorType recurrent_to_input_w;
105         TensorType cell_to_input_w;
106         TensorType cell_to_forget_w;
107         TensorType input_gate_bias;
108         TensorType cell_to_output_w;
109         TensorType projection_w;
110         TensorType projection_bias;
111         TensorType input_layer_norm_w;
112         TensorType forget_layer_norm_w;
113         TensorType cell_layer_norm_w;
114         TensorType output_layer_norm_w;
115 
116         bool cifg_opt = scratch_shape.x() == cell_bias_shape.x() * 4 ? false : true;
117 
118         FunctionParams lstm_params;
119 
120         if(!cifg_opt)
121         {
122             input_to_input_w     = create_tensor<TensorType>(input_weights_shape, data_type);
123             recurrent_to_input_w = create_tensor<TensorType>(recurrent_weights_shape, data_type);
124             if(peephole_opt)
125             {
126                 cell_to_input_w = create_tensor<TensorType>(cell_bias_shape, data_type);
127             }
128             input_gate_bias = create_tensor<TensorType>(cell_bias_shape, data_type);
129             lstm_params.set_cifg_params(&input_to_input_w, &recurrent_to_input_w, &cell_to_input_w, &input_gate_bias);
130         }
131 
132         if(peephole_opt)
133         {
134             cell_to_forget_w = create_tensor<TensorType>(cell_bias_shape, data_type);
135             cell_to_output_w = create_tensor<TensorType>(cell_bias_shape, data_type);
136             lstm_params.set_peephole_params(&cell_to_forget_w, &cell_to_output_w);
137         }
138 
139         if(projection_opt)
140         {
141             projection_w    = create_tensor<TensorType>(TensorShape(num_cells, num_outputs), data_type);
142             projection_bias = create_tensor<TensorType>(TensorShape(num_outputs), data_type);
143             lstm_params.set_projection_params(&projection_w, &projection_bias);
144         }
145 
146         if(use_layer_norm)
147         {
148             forget_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type);
149             cell_layer_norm_w   = create_tensor<TensorType>(TensorShape(num_cells), data_type);
150             output_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type);
151             if(!cifg_opt)
152             {
153                 input_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type);
154                 lstm_params.set_layer_normalization_params(&input_layer_norm_w, &forget_layer_norm_w, &cell_layer_norm_w, &output_layer_norm_w);
155             }
156             else
157             {
158                 lstm_params.set_layer_normalization_params(nullptr, &forget_layer_norm_w, &cell_layer_norm_w, &output_layer_norm_w);
159             }
160         }
161 
162         // Create and configure function
163         FunctionType lstm;
164         lstm.configure(&input, &input_to_forget_w, &input_to_cell_w, &input_to_output_w, &recurrent_to_forget_w,
165                        &recurrent_to_cell_w, &recurrent_to_output_w, &forget_gate_bias, &cell_bias, &output_gate_bias,
166                        &output_state_in, &cell_state_in,
167                        &scratch, &output_state_out, &cell_state_out, &output,
168                        lstm_params, info, cell_threshold, projection_threshold);
169 
170         ARM_COMPUTE_ASSERT(input.info()->is_resizable());
171         ARM_COMPUTE_ASSERT(input_to_forget_w.info()->is_resizable());
172         ARM_COMPUTE_ASSERT(input_to_cell_w.info()->is_resizable());
173         ARM_COMPUTE_ASSERT(input_to_output_w.info()->is_resizable());
174         ARM_COMPUTE_ASSERT(recurrent_to_forget_w.info()->is_resizable());
175         ARM_COMPUTE_ASSERT(recurrent_to_cell_w.info()->is_resizable());
176         ARM_COMPUTE_ASSERT(recurrent_to_output_w.info()->is_resizable());
177         ARM_COMPUTE_ASSERT(forget_gate_bias.info()->is_resizable());
178         ARM_COMPUTE_ASSERT(cell_bias.info()->is_resizable());
179         ARM_COMPUTE_ASSERT(output_gate_bias.info()->is_resizable());
180         ARM_COMPUTE_ASSERT(output_state_in.info()->is_resizable());
181         ARM_COMPUTE_ASSERT(cell_state_in.info()->is_resizable());
182         ARM_COMPUTE_ASSERT(scratch.info()->is_resizable());
183         ARM_COMPUTE_ASSERT(output_state_out.info()->is_resizable());
184         ARM_COMPUTE_ASSERT(cell_state_out.info()->is_resizable());
185         ARM_COMPUTE_ASSERT(output.info()->is_resizable());
186 
187         // Allocate tensors
188         input.allocator()->allocate();
189         input_to_forget_w.allocator()->allocate();
190         input_to_cell_w.allocator()->allocate();
191         input_to_output_w.allocator()->allocate();
192         recurrent_to_forget_w.allocator()->allocate();
193         recurrent_to_cell_w.allocator()->allocate();
194         recurrent_to_output_w.allocator()->allocate();
195         forget_gate_bias.allocator()->allocate();
196         cell_bias.allocator()->allocate();
197         output_gate_bias.allocator()->allocate();
198         output_state_in.allocator()->allocate();
199         cell_state_in.allocator()->allocate();
200         scratch.allocator()->allocate();
201         output_state_out.allocator()->allocate();
202         cell_state_out.allocator()->allocate();
203         output.allocator()->allocate();
204 
205         ARM_COMPUTE_ASSERT(!input.info()->is_resizable());
206         ARM_COMPUTE_ASSERT(!input_to_forget_w.info()->is_resizable());
207         ARM_COMPUTE_ASSERT(!input_to_cell_w.info()->is_resizable());
208         ARM_COMPUTE_ASSERT(!input_to_output_w.info()->is_resizable());
209         ARM_COMPUTE_ASSERT(!recurrent_to_forget_w.info()->is_resizable());
210         ARM_COMPUTE_ASSERT(!recurrent_to_cell_w.info()->is_resizable());
211         ARM_COMPUTE_ASSERT(!recurrent_to_output_w.info()->is_resizable());
212         ARM_COMPUTE_ASSERT(!forget_gate_bias.info()->is_resizable());
213         ARM_COMPUTE_ASSERT(!cell_bias.info()->is_resizable());
214         ARM_COMPUTE_ASSERT(!output_gate_bias.info()->is_resizable());
215         ARM_COMPUTE_ASSERT(!output_state_in.info()->is_resizable());
216         ARM_COMPUTE_ASSERT(!cell_state_in.info()->is_resizable());
217         ARM_COMPUTE_ASSERT(!scratch.info()->is_resizable());
218         ARM_COMPUTE_ASSERT(!output_state_out.info()->is_resizable());
219         ARM_COMPUTE_ASSERT(!cell_state_out.info()->is_resizable());
220         ARM_COMPUTE_ASSERT(!output.info()->is_resizable());
221 
222         // Fill tensors
223         fill(AccessorType(input), 0);
224         fill(AccessorType(input_to_forget_w), 1);
225         fill(AccessorType(input_to_cell_w), 2);
226         fill(AccessorType(input_to_output_w), 3);
227         fill(AccessorType(recurrent_to_forget_w), 4);
228         fill(AccessorType(recurrent_to_cell_w), 5);
229         fill(AccessorType(recurrent_to_output_w), 6);
230         fill(AccessorType(forget_gate_bias), 7);
231         fill(AccessorType(cell_bias), 8);
232         fill(AccessorType(output_gate_bias), 9);
233         fill(AccessorType(output_state_in), 10);
234         fill(AccessorType(cell_state_in), 11);
235         fill(AccessorType(scratch), 12);
236 
237         if(!cifg_opt)
238         {
239             ARM_COMPUTE_ASSERT(input_to_input_w.info()->is_resizable());
240             ARM_COMPUTE_ASSERT(recurrent_to_input_w.info()->is_resizable());
241             ARM_COMPUTE_ASSERT(cell_to_input_w.info()->is_resizable());
242             ARM_COMPUTE_ASSERT(input_gate_bias.info()->is_resizable());
243             input_to_input_w.allocator()->allocate();
244             recurrent_to_input_w.allocator()->allocate();
245             cell_to_input_w.allocator()->allocate();
246             input_gate_bias.allocator()->allocate();
247             ARM_COMPUTE_ASSERT(!input_to_input_w.info()->is_resizable());
248             ARM_COMPUTE_ASSERT(!recurrent_to_input_w.info()->is_resizable());
249             ARM_COMPUTE_ASSERT(!cell_to_input_w.info()->is_resizable());
250             ARM_COMPUTE_ASSERT(!input_gate_bias.info()->is_resizable());
251             fill(AccessorType(input_to_input_w), 13);
252             fill(AccessorType(recurrent_to_input_w), 14);
253             if(peephole_opt)
254             {
255                 fill(AccessorType(cell_to_input_w), 15);
256             }
257             fill(AccessorType(recurrent_to_input_w), 16);
258             fill(AccessorType(input_gate_bias), 17);
259         }
260 
261         if(peephole_opt)
262         {
263             ARM_COMPUTE_ASSERT(cell_to_forget_w.info()->is_resizable());
264             ARM_COMPUTE_ASSERT(cell_to_output_w.info()->is_resizable());
265             cell_to_forget_w.allocator()->allocate();
266             cell_to_output_w.allocator()->allocate();
267             ARM_COMPUTE_ASSERT(!cell_to_forget_w.info()->is_resizable());
268             ARM_COMPUTE_ASSERT(!cell_to_output_w.info()->is_resizable());
269             fill(AccessorType(cell_to_forget_w), 18);
270             fill(AccessorType(cell_to_output_w), 19);
271         }
272 
273         if(projection_opt)
274         {
275             ARM_COMPUTE_ASSERT(projection_w.info()->is_resizable());
276             ARM_COMPUTE_ASSERT(projection_bias.info()->is_resizable());
277 
278             projection_w.allocator()->allocate();
279             projection_bias.allocator()->allocate();
280 
281             ARM_COMPUTE_ASSERT(!projection_w.info()->is_resizable());
282             ARM_COMPUTE_ASSERT(!projection_bias.info()->is_resizable());
283 
284             fill(AccessorType(projection_w), 20);
285             fill(AccessorType(projection_bias), 21);
286         }
287 
288         if(use_layer_norm)
289         {
290             if(!cifg_opt)
291             {
292                 ARM_COMPUTE_ASSERT(input_layer_norm_w.info()->is_resizable());
293 
294                 input_layer_norm_w.allocator()->allocate();
295 
296                 ARM_COMPUTE_ASSERT(!input_layer_norm_w.info()->is_resizable());
297 
298                 fill(AccessorType(input_layer_norm_w), 22);
299             }
300             ARM_COMPUTE_ASSERT(forget_layer_norm_w.info()->is_resizable());
301             ARM_COMPUTE_ASSERT(cell_layer_norm_w.info()->is_resizable());
302             ARM_COMPUTE_ASSERT(output_layer_norm_w.info()->is_resizable());
303 
304             forget_layer_norm_w.allocator()->allocate();
305             cell_layer_norm_w.allocator()->allocate();
306             output_layer_norm_w.allocator()->allocate();
307 
308             ARM_COMPUTE_ASSERT(!forget_layer_norm_w.info()->is_resizable());
309             ARM_COMPUTE_ASSERT(!cell_layer_norm_w.info()->is_resizable());
310             ARM_COMPUTE_ASSERT(!output_layer_norm_w.info()->is_resizable());
311 
312             fill(AccessorType(forget_layer_norm_w), 23);
313             fill(AccessorType(cell_layer_norm_w), 24);
314             fill(AccessorType(output_layer_norm_w), 25);
315         }
316 
317         // Compute function
318         lstm.run();
319 
320         _target_scratch = std::move(scratch);
321         return output;
322     }
323 
compute_reference(const TensorShape & input_shape,const TensorShape & input_weights_shape,const TensorShape & recurrent_weights_shape,const TensorShape & cell_bias_shape,const TensorShape & output_cell_shape,const TensorShape & output_shape,const TensorShape & scratch_shape,ActivationLayerInfo info,float cell_threshold,float projection_threshold,DataType data_type,bool projection_opt,bool peephole_opt,bool use_layer_norm)324     SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &input_weights_shape, const TensorShape &recurrent_weights_shape, const TensorShape &cell_bias_shape,
325                                       const TensorShape &output_cell_shape, const TensorShape &output_shape, const TensorShape &scratch_shape, ActivationLayerInfo info, float cell_threshold,
326                                       float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt, bool use_layer_norm)
327     {
328         const unsigned int num_cells   = input_weights_shape.y();
329         const unsigned int num_outputs = recurrent_weights_shape.x();
330 
331         // Create projection weights shape
332         TensorShape projection_weights_shape(num_cells, num_outputs);
333 
334         // Create projection bias shape
335         TensorShape projection_bias_shape(num_outputs);
336 
337         TensorShape     gemm_shape{ 1, output_shape.y() };
338         SimpleTensor<T> gemm_out{ gemm_shape, data_type };
339 
340         // Create reference
341         SimpleTensor<T> input{ input_shape, data_type };
342         SimpleTensor<T> input_to_input_w{ input_weights_shape, data_type };
343         SimpleTensor<T> input_to_forget_w{ input_weights_shape, data_type };
344         SimpleTensor<T> input_to_cell_w{ input_weights_shape, data_type };
345         SimpleTensor<T> input_to_output_w{ input_weights_shape, data_type };
346         SimpleTensor<T> recurrent_to_input_w{ recurrent_weights_shape, data_type };
347         SimpleTensor<T> recurrent_to_forget_w{ recurrent_weights_shape, data_type };
348         SimpleTensor<T> recurrent_to_cell_w{ recurrent_weights_shape, data_type };
349         SimpleTensor<T> recurrent_to_output_w{ recurrent_weights_shape, data_type };
350         SimpleTensor<T> cell_to_input_w{ cell_bias_shape, data_type };
351         SimpleTensor<T> cell_to_forget_w{ cell_bias_shape, data_type };
352         SimpleTensor<T> cell_to_output_w{ cell_bias_shape, data_type };
353         SimpleTensor<T> input_gate_bias{ cell_bias_shape, data_type };
354         SimpleTensor<T> forget_gate_bias{ cell_bias_shape, data_type };
355         SimpleTensor<T> cell_bias{ cell_bias_shape, data_type };
356         SimpleTensor<T> output_gate_bias{ cell_bias_shape, data_type };
357         SimpleTensor<T> projection_w{ projection_weights_shape, data_type };
358         SimpleTensor<T> projection_bias{ projection_bias_shape, data_type };
359         SimpleTensor<T> output_state_in{ output_shape, data_type };
360         SimpleTensor<T> cell_state_in{ output_cell_shape, data_type };
361         SimpleTensor<T> scratch{ scratch_shape, data_type };
362         SimpleTensor<T> output_state_out{ output_shape, data_type };
363         SimpleTensor<T> cell_state_out{ output_cell_shape, data_type };
364         SimpleTensor<T> output{ output_shape, data_type };
365 
366         bool cifg_opt = scratch_shape.x() == cell_bias_shape.x() * 4 ? false : true;
367 
368         // Fill reference
369         fill(input, 0);
370         fill(input_to_forget_w, 1);
371         fill(input_to_cell_w, 2);
372         fill(input_to_output_w, 3);
373         fill(recurrent_to_forget_w, 4);
374         fill(recurrent_to_cell_w, 5);
375         fill(recurrent_to_output_w, 6);
376         if(use_layer_norm)
377         {
378             fill_custom_val(forget_gate_bias, 0.f, 7);
379             fill_custom_val(cell_bias, 0.f, 8);
380             fill_custom_val(output_gate_bias, 0.f, 9);
381         }
382         else
383         {
384             fill(forget_gate_bias, 7);
385             fill(cell_bias, 8);
386             fill(output_gate_bias, 9);
387         }
388         fill(output_state_in, 10);
389         fill(cell_state_in, 11);
390         fill(scratch, 12);
391         fill(input_to_input_w, 13);
392         fill(recurrent_to_input_w, 14);
393         fill(cell_to_input_w, 15);
394         fill(recurrent_to_input_w, 16);
395         if(!cifg_opt && use_layer_norm)
396         {
397             fill_custom_val(input_gate_bias, 0.f, 17);
398         }
399         else
400         {
401             fill(input_gate_bias, 17);
402         }
403         fill(cell_to_forget_w, 18);
404         fill(cell_to_output_w, 19);
405         fill(projection_w, 20);
406         fill(projection_bias, 21);
407 
408         // Compute forget_gate
409         SimpleTensor<T> fully_connected_forget = reference::fully_connected_layer(input, input_to_forget_w, forget_gate_bias, output_cell_shape);
410         SimpleTensor<T> transposed_weights     = reference::transpose(recurrent_to_forget_w);
411         SimpleTensor<T> gemm                   = reference::gemm(output_state_in, transposed_weights, cell_state_in, 1.f, 0.f);
412         SimpleTensor<T> forget_gate            = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_forget, gemm, data_type, ConvertPolicy::SATURATE);
413 
414         if(peephole_opt)
415         {
416             SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO, data_type);
417             forget_gate                               = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE);
418         }
419 
420         if(use_layer_norm)
421         {
422             SimpleTensor<T> forget_layer_norm_w{ cell_bias_shape, data_type };
423             fill(forget_layer_norm_w, 23);
424             forget_gate = reference::mean_std_normalization_layer(forget_gate);
425             forget_gate = reference::pixel_wise_multiplication<T, T, T>(forget_gate, forget_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
426             fill(forget_gate_bias, 7);
427             forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, forget_gate_bias, data_type, ConvertPolicy::SATURATE);
428         }
429         forget_gate = reference::activation_layer(forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
430 
431         // Compute input_gate
432         SimpleTensor<T> input_gate;
433         if(cifg_opt)
434         {
435             SimpleTensor<T> ones{ cell_bias_shape, data_type };
436             fill_custom_val(ones, 1.f, 0);
437             input_gate = reference::arithmetic_operation<T>(reference::ArithmeticOperation::SUB, ones, forget_gate, data_type, ConvertPolicy::SATURATE);
438         }
439         else
440         {
441             SimpleTensor<T> fully_connected_input = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape);
442             transposed_weights                    = reference::transpose(recurrent_to_input_w);
443             gemm                                  = reference::gemm(output_state_in, transposed_weights, cell_state_in, 1.f, 0.f);
444             input_gate                            = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE);
445             if(peephole_opt)
446             {
447                 SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
448                 input_gate                               = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE);
449             }
450             if(use_layer_norm)
451             {
452                 SimpleTensor<T> input_layer_norm_w{ cell_bias_shape, data_type };
453                 fill(input_layer_norm_w, 22);
454                 input_gate = reference::mean_std_normalization_layer(input_gate);
455                 input_gate = reference::pixel_wise_multiplication<T, T, T>(input_gate, input_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
456                 fill(input_gate_bias, 17);
457                 input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, input_gate_bias, data_type, ConvertPolicy::SATURATE);
458             }
459             input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
460         }
461         // Compute cell_state
462         SimpleTensor<T> fully_connected_cell_state = reference::fully_connected_layer(input, input_to_cell_w, cell_bias, output_cell_shape);
463         transposed_weights                         = reference::transpose(recurrent_to_cell_w);
464         gemm                                       = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f);
465         SimpleTensor<T> pixelwise_mul              = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
466         cell_state_out                             = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_cell_state, gemm, data_type, ConvertPolicy::SATURATE);
467         if(use_layer_norm)
468         {
469             SimpleTensor<T> cell_layer_norm_w{ cell_bias_shape, data_type };
470             fill(cell_layer_norm_w, 24);
471             cell_state_out = reference::mean_std_normalization_layer(cell_state_out);
472             cell_state_out = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, cell_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
473             fill(cell_bias, 8);
474             cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, cell_bias, data_type, ConvertPolicy::SATURATE);
475         }
476         cell_state_out = reference::activation_layer(cell_state_out, info);
477         cell_state_out = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
478         cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
479 
480         if(cell_threshold != 0.f)
481         {
482             cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold, -cell_threshold));
483         }
484 
485         // Compute output
486         SimpleTensor<T> fully_connected_output = reference::fully_connected_layer(input, input_to_output_w, output_gate_bias, output_cell_shape);
487         transposed_weights                     = reference::transpose(recurrent_to_output_w);
488         gemm                                   = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f);
489         output                                 = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_output, gemm, data_type, ConvertPolicy::SATURATE);
490         if(peephole_opt)
491         {
492             pixelwise_mul = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
493             output        = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
494         }
495         if(use_layer_norm)
496         {
497             SimpleTensor<T> output_layer_norm_w{ cell_bias_shape, data_type };
498             fill(output_layer_norm_w, 25);
499             output = reference::mean_std_normalization_layer(output);
500             output = reference::pixel_wise_multiplication<T, T, T>(output, output_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
501             fill(output_gate_bias, 9);
502             output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, output_gate_bias, data_type, ConvertPolicy::SATURATE);
503         }
504         output = reference::activation_layer(output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
505 
506         // Compute output state
507         SimpleTensor<T> cell_state_activation = reference::activation_layer(cell_state_out, info);
508         output_state_out                      = reference::pixel_wise_multiplication<T, T, T>(output, cell_state_activation, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
509 
510         if(projection_opt)
511         {
512             SimpleTensor<T> fully_connected_projection = reference::fully_connected_layer(output_state_out, projection_w, projection_bias, output_cell_shape);
513             if(projection_threshold != 0.f)
514             {
515                 output_state_out = reference::activation_layer(fully_connected_projection, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold));
516             }
517         }
518         std::vector<SimpleTensor<T>> scratch_inputs;
519         if(!cifg_opt)
520         {
521             scratch_inputs.emplace_back(std::move(input_gate));
522         }
523         scratch_inputs.emplace_back(std::move(cell_state_out));
524         scratch_inputs.emplace_back(std::move(forget_gate));
525         scratch_inputs.emplace_back(std::move(output));
526         scratch            = reference::concatenate_layer(scratch_inputs, scratch, Window::DimX);
527         _reference_scratch = std::move(scratch);
528         return output_state_out;
529     }
530 
531     TensorType      _target{};
532     TensorType      _target_scratch{};
533     SimpleTensor<T> _reference{};
534     SimpleTensor<T> _reference_scratch{};
535 };
536 } // namespace validation
537 } // namespace test
538 } // namespace arm_compute
#endif /* ARM_COMPUTE_TEST_LSTM_LAYER_FIXTURE */