/*
 * Copyright (c) 2018-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/runtime/CL/functions/CLLSTMLayer.h"

#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/InfoHelpers.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "src/core/CL/kernels/CLCopyKernel.h"
#include "src/core/CL/kernels/CLDepthConvertLayerKernel.h"
#include "src/core/CL/kernels/CLFillBorderKernel.h"
#include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.h"
#include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h"
#include "src/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.h"
#include "src/core/CL/kernels/CLGEMMLowpOffsetContributionOutputStageKernel.h"
#include "src/core/CL/kernels/CLGEMMLowpReductionKernel.h"
#include "src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h"
#include "src/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
#include "src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h"
#include "src/core/CL/kernels/CLMemsetKernel.h"
#include "src/core/CL/kernels/CLTransposeKernel.h"
#include "support/MemorySupport.h"

namespace arm_compute
{
using namespace arm_compute::misc::shape_calculator;
using namespace arm_compute::utils::info_helpers;

CLLSTMLayer::CLLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
    : _memory_group(std::move(memory_manager)), _fully_connected_input_gate(), _accum_input_gate1(), _subtract_input_gate(), _pixelwise_mul_input_gate(), _activation_input_gate(),
      _fully_connected_forget_gate(), _accum_forget_gate1(), _pixelwise_mul_forget_gate(), _activation_forget_gate(), _fully_connected_cell_state(), _gemm_cell_state1(),
      _transpose_cell_state(support::cpp14::make_unique<CLTransposeKernel>()), _accum_cell_state1(), _accum_cell_state2(), _pixelwise_mul_cell_state1(), _activation_cell_state(), _cell_clip(),
      _pixelwise_mul_cell_state2(), _fully_connected_output(), _pixelwise_mul_output_state1(), _accum_output1(), _activation_output(), _activation_output_state(), _pixelwise_mul_output_state2(),
      _fully_connected_output_state(), _projection_clip(), _copy_cell_state(support::cpp14::make_unique<CLCopyKernel>()), _copy_output(support::cpp14::make_unique<CLCopyKernel>()), _concat_scratch_buffer(),
      _concat_inputs_forget_gate(), _concat_weights_forget_gate(), _concat_weights_input_gate(), _concat_weights_output(), _ones_memset_kernel(support::cpp14::make_unique<CLMemsetKernel>()),
      _mean_std_norm_input_gate(), _pixelwise_mul_input_gate_coeff(), _accum_input_gate_bias(), _mean_std_norm_forget_gate(), _pixelwise_mul_forget_gate_coeff(), _accum_forget_gate_bias(),
      _mean_std_norm_cell_gate(), _pixelwise_mul_cell_gate_coeff(), _accum_cell_gate_bias(), _mean_std_norm_output_gate(), _pixelwise_mul_output_gate_coeff(), _accum_output_gate_bias(), _input_gate_out1(),
      _input_gate_out2(), _input_gate_out3(), _input_gate_out4(), _forget_gate_out1(), _forget_gate_out2(), _forget_gate_out3(), _forget_gate_out4(), _forget_gate_out5(), _forget_gate_out6(),
      _cell_state_out1(), _cell_state_out2(), _cell_state_out3(), _cell_state_out4(), _cell_state_out5(), _output1(), _output2(), _output3(), _output4(), _cell_state_activation(), _output_state1(), _ones(),
      _input_layer_norm_out1(), _input_layer_norm_out2(), _forget_layer_norm_out1(), _forget_layer_norm_out2(), _cell_layer_norm_out1(), _cell_layer_norm_out2(), _output_layer_norm_out1(),
      _output_layer_norm_out2(), _run_peephole_opt(false), _run_cifg_opt(false), _perform_cell_clipping(false), _has_projection_weights(false), _perform_projection_clipping(false), _is_prepared(false),
      _is_layer_norm_lstm(false)
{
}

CLLSTMLayer::~CLLSTMLayer() = default;

void CLLSTMLayer::configure(const ICLTensor *input,
                            const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
                            const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
                            const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
                            const ICLTensor *output_state_in, ICLTensor *cell_state_in,
                            ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
                            const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
{
    configure(CLKernelLibrary::get().get_compile_context(), input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
              recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, output_state_in, cell_state_in, scratch_buffer, output_state_out, cell_state_out, output, lstm_params, activation_info,
              cell_threshold, projection_threshold);
}

void CLLSTMLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input,
                            const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
                            const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
                            const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
                            const ICLTensor *output_state_in, ICLTensor *cell_state_in,
                            ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
                            const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input,
                                 input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                                 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
                                 forget_gate_bias, cell_bias, output_gate_bias,
                                 output_state_in, cell_state_in,
                                 scratch_buffer, output_state_out, cell_state_out, output);

    _is_layer_norm_lstm = lstm_params.use_layer_norm();

    // Set lstm parameters
    LSTMParams<ITensorInfo> lstm_params_info{};
    build_lstm_params_tensor_info(lstm_params, &lstm_params_info);

    // Validate
    ARM_COMPUTE_ERROR_THROW_ON(CLLSTMLayer::validate(input->info(), input_to_forget_weights->info(),
                                                     input_to_cell_weights->info(), input_to_output_weights->info(),
                                                     recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
                                                     forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
                                                     output_state_in->info(), cell_state_in->info(),
                                                     scratch_buffer->info(), output_state_out->info(), cell_state_out->info(), output->info(),
                                                     lstm_params_info, activation_info, cell_threshold, projection_threshold));

    const TensorShape cell_state_shape = cell_state_in->info()->tensor_shape();
    // Configure block that calculates the forget gate
    // forget_gate = Activation(input * input_to_forget_weights + output_state_in * recurrent_to_forget_weights + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias)
    // We optimize this as follows:
    // forget_gate = Activation((input,output_state_in) * (input_to_forget_weights,recurrent_to_forget_weights) + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias)
    _forget_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
    _forget_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
    _forget_gate_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));

    std::vector<const ICLTensor *> inputs_vector;
    inputs_vector.emplace_back(input);
    inputs_vector.emplace_back(output_state_in);
    const TensorShape concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, 0);
    _forget_gate_out2.allocator()->init(TensorInfo(concat_shape, 1, input->info()->data_type()));

    _memory_group.manage(&_forget_gate_out2);
    _concat_inputs_forget_gate.configure(compile_context, inputs_vector, &_forget_gate_out2, Window::DimX);
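    // The concatenated (input, output_state_in) tensor is reused as the input of the
    // input-gate and output-gate fully connected layers configured further below.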

    std::vector<const ICLTensor *> weights_vector;

    weights_vector.emplace_back(input_to_forget_weights);
    weights_vector.emplace_back(recurrent_to_forget_weights);
    const TensorShape weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(weights_vector, 0);
    _forget_gate_out6.allocator()->init(TensorInfo(weights_concat_shape, 1, input->info()->data_type()));

    _concat_weights_forget_gate.configure(compile_context, weights_vector, &_forget_gate_out6, Window::DimX);
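    // Note: the concatenated weight tensors are not managed by the memory group; they are
    // populated once in prepare() and kept alive for every subsequent run().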

    _memory_group.manage(&_forget_gate_out5);
    _fully_connected_forget_gate.configure(compile_context, &_forget_gate_out2, &_forget_gate_out6, (_is_layer_norm_lstm) ? nullptr : forget_gate_bias, &_forget_gate_out5);
    _memory_group.manage(&_forget_gate_out1);
    _memory_group.manage(&_forget_gate_out3);
    _forget_gate_out6.allocator()->allocate();

    CLTensor *forget_gate_out = &_forget_gate_out5;
    if(lstm_params.has_peephole_opt())
    {
        _forget_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));

        _run_peephole_opt = true;
        _memory_group.manage(&_forget_gate_out4);
        _pixelwise_mul_forget_gate.configure(compile_context, cell_state_in, lstm_params.cell_to_forget_weights(), &_forget_gate_out4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
        _accum_forget_gate1.configure(compile_context, &_forget_gate_out5, &_forget_gate_out4, &_forget_gate_out3, ConvertPolicy::SATURATE);
        _forget_gate_out4.allocator()->allocate();
        _forget_gate_out5.allocator()->allocate();
        forget_gate_out = &_forget_gate_out3;
    }
    else
    {
        _forget_gate_out3.allocator()->allocate();
    }
    if(_is_layer_norm_lstm)
    {
        _forget_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
        _forget_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
        _memory_group.manage(&_forget_layer_norm_out1);
        _memory_group.manage(&_forget_layer_norm_out2);
        _mean_std_norm_forget_gate.configure(compile_context, forget_gate_out);
        _pixelwise_mul_forget_gate_coeff.configure(compile_context, forget_gate_out, lstm_params.forget_layer_norm_weights(), &_forget_layer_norm_out1, 1, ConvertPolicy::SATURATE,
                                                   RoundingPolicy::TO_NEAREST_EVEN);
        // forget_gate_out is going to be reassigned, so allocate the tensor it currently points to
        forget_gate_out->allocator()->allocate();
        _accum_forget_gate_bias.configure(compile_context, &_forget_layer_norm_out1, forget_gate_bias, &_forget_layer_norm_out2, ConvertPolicy::SATURATE);
        _forget_layer_norm_out1.allocator()->allocate();
        forget_gate_out = &_forget_layer_norm_out2;
    }
    _activation_forget_gate.configure(compile_context, forget_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));

    // Configure block that calculates the input gate
    // input_gate = Activation(input * input_to_input_weights + output_state * recurrent_to_input_weights + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
    // input_gate = 1 - forget_gate, with CIFG
    // We optimize this as follows:
    // input_gate = Activation((input,output_state) * (input_to_input_weights,recurrent_to_input_weights) + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
    _input_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
    CLTensor *input_gate_out = &_input_gate_out1;
    if(lstm_params.has_cifg_opt())
    {
        _memory_group.manage(&_input_gate_out1);
        _ones.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
        _ones_memset_kernel->configure(compile_context, &_ones, PixelValue(1, _ones.info()->data_type()));
        _subtract_input_gate.configure(compile_context, &_ones, forget_gate_out, &_input_gate_out1, ConvertPolicy::SATURATE);
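        // Realizes the CIFG identity above: input_gate = 1 - forget_gate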
        _ones.allocator()->allocate();
        _run_cifg_opt = true;
    }
    else
    {
        _input_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
        _input_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));

        std::vector<const ICLTensor *> lstm_weights;
        lstm_weights.emplace_back(lstm_params.input_to_input_weights());
        lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
        TensorShape lstm_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(lstm_weights, 0);
        _input_gate_out2.allocator()->init(TensorInfo(lstm_weights_concat_shape, 1, input->info()->data_type()));

        _concat_weights_input_gate.configure(compile_context, lstm_weights, &_input_gate_out2, Window::DimX);

        _memory_group.manage(&_input_gate_out1);

        _memory_group.manage(&_input_gate_out3);
        _fully_connected_input_gate.configure(compile_context, &_forget_gate_out2, &_input_gate_out2, (_is_layer_norm_lstm) ? nullptr : lstm_params.input_gate_bias(), &_input_gate_out3);
        _input_gate_out2.allocator()->allocate();

        input_gate_out = &_input_gate_out3;
        if(_run_peephole_opt)
        {
            _memory_group.manage(&_input_gate_out4);
            _pixelwise_mul_input_gate.configure(compile_context, cell_state_in, lstm_params.cell_to_input_weights(), &_input_gate_out4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
            _accum_input_gate1.configure(compile_context, &_input_gate_out3, &_input_gate_out4, &_input_gate_out1, ConvertPolicy::SATURATE);
            _input_gate_out3.allocator()->allocate();
            _input_gate_out4.allocator()->allocate();
            input_gate_out = &_input_gate_out1;
        }
        else
        {
            _input_gate_out1.allocator()->allocate();
        }

        if(_is_layer_norm_lstm)
        {
            _input_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
            _input_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
            _memory_group.manage(&_input_layer_norm_out1);
            _memory_group.manage(&_input_layer_norm_out2);
            _mean_std_norm_input_gate.configure(compile_context, input_gate_out);
            _pixelwise_mul_input_gate_coeff.configure(compile_context, input_gate_out, lstm_params.input_layer_norm_weights(), &_input_layer_norm_out1, 1, ConvertPolicy::SATURATE,
                                                      RoundingPolicy::TO_NEAREST_EVEN);
            // input_gate_out is going to be reassigned, so allocate the tensor it currently points to
            input_gate_out->allocator()->allocate();
            _accum_input_gate_bias.configure(compile_context, &_input_layer_norm_out1, lstm_params.input_gate_bias(), &_input_layer_norm_out2, ConvertPolicy::SATURATE);
            _input_layer_norm_out1.allocator()->allocate();
            input_gate_out = &_input_layer_norm_out2;
        }
        _activation_input_gate.configure(compile_context, input_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
    }

    // Configure block that calculates the cell state
    // cell_state = Clip((PixelwiseMul(input_gate, Activation(input * input_to_cell_weights + output_state_in * recurrent_to_cell_weights + cell_bias)) + PixelwiseMul(forget_gate, cell_state)), cell_threshold)
    TensorShape cell_state1_shape = compute_transposed_shape(*recurrent_to_output_weights->info());
    _cell_state_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
    _cell_state_out2.allocator()->init(TensorInfo(cell_state1_shape, 1, input->info()->data_type()));
    _cell_state_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
    _cell_state_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
    _cell_state_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));

    _memory_group.manage(&_cell_state_out1);
    _fully_connected_cell_state.configure(compile_context, input, input_to_cell_weights, (_is_layer_norm_lstm) ? nullptr : cell_bias, &_cell_state_out1);
    _memory_group.manage(&_cell_state_out2);
    _transpose_cell_state->configure(compile_context, recurrent_to_cell_weights, &_cell_state_out2);
    _memory_group.manage(&_cell_state_out3);
    _gemm_cell_state1.configure(compile_context, output_state_in, &_cell_state_out2, nullptr, &_cell_state_out3, 1.f, 0.f);
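    // Together with the transpose above, this GEMM computes output_state_in * recurrent_to_cell_weights^T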
    _cell_state_out2.allocator()->allocate();
    _memory_group.manage(&_cell_state_out4);
    _accum_cell_state1.configure(compile_context, &_cell_state_out1, &_cell_state_out3, &_cell_state_out4, ConvertPolicy::SATURATE);
    CLTensor *cell_state_out_ptr = &_cell_state_out4;
    if(_is_layer_norm_lstm)
    {
        _cell_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
        _cell_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
        _memory_group.manage(&_cell_layer_norm_out1);
        _memory_group.manage(&_cell_layer_norm_out2);
        _mean_std_norm_cell_gate.configure(compile_context, cell_state_out_ptr);
        _pixelwise_mul_cell_gate_coeff.configure(compile_context, cell_state_out_ptr, lstm_params.cell_layer_norm_weights(), &_cell_layer_norm_out1, 1, ConvertPolicy::SATURATE,
                                                 RoundingPolicy::TO_NEAREST_EVEN);
        // cell_state_out_ptr is going to be reassigned, so allocate the tensor it currently points to
        cell_state_out_ptr->allocator()->allocate();
        _accum_cell_gate_bias.configure(compile_context, &_cell_layer_norm_out1, cell_bias, &_cell_layer_norm_out2, ConvertPolicy::SATURATE);
        _cell_layer_norm_out1.allocator()->allocate();
        cell_state_out_ptr = &_cell_layer_norm_out2;
    }
    _activation_cell_state.configure(compile_context, cell_state_out_ptr, nullptr, activation_info);
    _memory_group.manage(&_cell_state_out5);
    _pixelwise_mul_cell_state1.configure(compile_context, cell_state_out_ptr, input_gate_out, &_cell_state_out5, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
    cell_state_out_ptr->allocator()->allocate();
    _pixelwise_mul_cell_state2.configure(compile_context, forget_gate_out, cell_state_in, &_cell_state_out3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
    _accum_cell_state2.configure(compile_context, &_cell_state_out5, &_cell_state_out3, &_cell_state_out1, ConvertPolicy::SATURATE);
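    // Final cell state = input_gate * candidate_cell_state + forget_gate * cell_state_in, accumulated into _cell_state_out1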
    _cell_state_out3.allocator()->allocate();
    _cell_state_out5.allocator()->allocate();
    // Perform clipping
    if(cell_threshold != 0.f)
    {
        _perform_cell_clipping = true;
        _cell_clip.configure(compile_context, &_cell_state_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold));
    }

    // Configure block that calculates the output
    // output_state_out = Activation(input * input_to_output_weights + output_state_in * recurrent_to_output_weights + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
    // We optimize this as follows:
    // output_state_out = Activation((input,output_state_in) * (input_to_output_weights, recurrent_to_output_weights) + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
    _output1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
    _output4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
    std::vector<const ICLTensor *> in_out_weights;
    in_out_weights.emplace_back(input_to_output_weights);
    in_out_weights.emplace_back(recurrent_to_output_weights);
    TensorShape in_out_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(in_out_weights, 0);
    _output2.allocator()->init(TensorInfo(in_out_weights_concat_shape, 1, input->info()->data_type()));

    _concat_weights_output.configure(compile_context, in_out_weights, &_output2, Window::DimX);

    _memory_group.manage(&_output1);
    _memory_group.manage(&_output4);

    _fully_connected_output.configure(compile_context, &_forget_gate_out2, &_output2, (_is_layer_norm_lstm) ? nullptr : output_gate_bias, &_output4);

    _output2.allocator()->allocate();
    _forget_gate_out2.allocator()->allocate();

    CLTensor *output_gate_out = &_output4;
    if(lstm_params.has_peephole_opt())
    {
        _output3.allocator()->init(TensorInfo(_cell_state_out1.info()->tensor_shape(), 1, input->info()->data_type()));

        _memory_group.manage(&_output3);
        _pixelwise_mul_output_state1.configure(compile_context, &_cell_state_out1, lstm_params.cell_to_output_weights(), &_output3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
        _accum_output1.configure(compile_context, &_output4, &_output3, &_output1, ConvertPolicy::SATURATE);
        _output4.allocator()->allocate();
        output_gate_out = &_output1;

        // Allocate intermediate buffers
        _output3.allocator()->allocate();
    }
    else
    {
        _output1.allocator()->allocate();
    }
    if(_is_layer_norm_lstm)
    {
        _output_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
        _output_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
        _memory_group.manage(&_output_layer_norm_out1);
        _memory_group.manage(&_output_layer_norm_out2);
        _mean_std_norm_output_gate.configure(compile_context, output_gate_out);
        _pixelwise_mul_output_gate_coeff.configure(compile_context, output_gate_out, lstm_params.output_layer_norm_weights(), &_output_layer_norm_out1, 1, ConvertPolicy::SATURATE,
                                                   RoundingPolicy::TO_NEAREST_EVEN);
        // output_gate_out is going to be reassigned, so allocate the tensor it currently points to
        output_gate_out->allocator()->allocate();
        _accum_output_gate_bias.configure(compile_context, &_output_layer_norm_out1, output_gate_bias, &_output_layer_norm_out2, ConvertPolicy::SATURATE);
        _output_layer_norm_out1.allocator()->allocate();
        output_gate_out = &_output_layer_norm_out2;
    }
    _activation_output.configure(compile_context, output_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));

    // Configure block that calculates the output state
    /** lstm_res = PixelwiseMul(output, Activation(cell_state))
     *
     *                      -- Clip(lstm_res * projection_weights + projection_bias, projection_threshold) , if there is a projection
     *                     /
     *  output_state =  --
     *                     \
     *                      -- lstm_res , otherwise
     */
    ICLTensor *output_state_out_tmp = lstm_params.has_projection() ? &_output_state1 : output_state_out;
    _cell_state_activation.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
    _output_state1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));

    _memory_group.manage(&_cell_state_activation);
    _activation_output_state.configure(compile_context, &_cell_state_out1, &_cell_state_activation, activation_info);
    _pixelwise_mul_output_state2.configure(compile_context, &_cell_state_activation, output_gate_out, output_state_out_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
    _cell_state_activation.allocator()->allocate();

    if(lstm_params.has_projection())
    {
        _has_projection_weights = true;
        _fully_connected_output_state.configure(compile_context, output_state_out_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out);
        _output_state1.allocator()->allocate();
        // Perform clipping
        if(projection_threshold != 0.f)
        {
            _perform_projection_clipping = true;
            _projection_clip.configure(compile_context, output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold));
        }
    }

    // Copy cell state and output
    _copy_cell_state->configure(compile_context, &_cell_state_out1, cell_state_out);
    _copy_output->configure(compile_context, output_state_out, output);

    // Vector for holding the tensors to store in scratch buffer
    std::vector<const ICLTensor *> scratch_inputs;
    if(!lstm_params.has_cifg_opt())
    {
        scratch_inputs.emplace_back(input_gate_out);
    }
    scratch_inputs.emplace_back(&_cell_state_out1);
    scratch_inputs.emplace_back(forget_gate_out);
    scratch_inputs.emplace_back(output_gate_out);
    _concat_scratch_buffer.configure(compile_context, scratch_inputs, scratch_buffer, Window::DimX);
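    // Scratch buffer layout along the X dimension: [input_gate (omitted with CIFG), cell_state, forget_gate, output_gate]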
    input_gate_out->allocator()->allocate();
    _cell_state_out1.allocator()->allocate();
    forget_gate_out->allocator()->allocate();
    output_gate_out->allocator()->allocate();
}

Status CLLSTMLayer::validate(const ITensorInfo *input,
                             const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
                             const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
                             const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
                             const ITensorInfo *output_state_in, const ITensorInfo *cell_state_in,
                             const ITensorInfo *scratch_buffer, const ITensorInfo *output_state_out, const ITensorInfo *cell_state_out, const ITensorInfo *output,
                             const LSTMParams<ITensorInfo> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input,
                                        input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                                        recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
                                        forget_gate_bias, cell_bias, output_gate_bias,
                                        output_state_in, cell_state_in,
                                        scratch_buffer, output_state_out, cell_state_out, output);

    // Check data types
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input,
                                                       input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                                                       recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
                                                       forget_gate_bias, cell_bias, output_gate_bias,
                                                       output_state_in, cell_state_in,
                                                       scratch_buffer, output_state_out, cell_state_out, output);

    // Check dimensions
    ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() > 1);
    ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() > 1);
    ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() > 1);
    ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(output_state_out->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(cell_state_out->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->dimension(0) * 4 != scratch_buffer->dimension(0)
                                && cell_bias->dimension(0) * 3 != scratch_buffer->dimension(0));
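    // The scratch buffer packs 4 gate-sized tensors, or 3 when CIFG omits the input gate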

    const unsigned int num_batches = input->dimension(1);
    const unsigned int num_cells   = input_to_output_weights->dimension(1);

    if(lstm_params.use_layer_norm())
    {
        // If CIFG is used, the input layer normalization weights tensor is omitted
        if(lstm_params.has_cifg_opt())
        {
            ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights() != nullptr);
        }
        else
        {
            ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_layer_norm_weights());
            ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->num_dimensions() > 1);
            ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->dimension(0) != num_cells);
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.input_layer_norm_weights());
        }

        ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.forget_layer_norm_weights(), lstm_params.cell_layer_norm_weights(), lstm_params.output_layer_norm_weights());
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.forget_layer_norm_weights(), lstm_params.cell_layer_norm_weights(), lstm_params.output_layer_norm_weights());
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->num_dimensions() > 1);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->num_dimensions() > 1);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->num_dimensions() > 1);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->dimension(0) != num_cells);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->dimension(0) != num_cells);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->dimension(0) != num_cells);
    }

    // Check peephole optimization
    if(lstm_params.has_peephole_opt())
    {
        ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_output_weights(), lstm_params.cell_to_forget_weights());
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() > 1);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() > 1);
    }

    TensorShape      units_out_transposed_shape = compute_transposed_shape(*recurrent_to_output_weights);
    TensorShape      num_units_transposed_shape = compute_transposed_shape(*forget_gate_bias);
    const TensorInfo units_out_transposed_info  = TensorInfo(units_out_transposed_shape, 1, input->data_type());
    const TensorInfo num_units_transposed_info  = TensorInfo(num_units_transposed_shape, 1, input->data_type());

    TensorInfo input_gate      = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
    TensorInfo forget_gate     = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
    TensorInfo output_gate_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
    TensorInfo cell_state_tmp  = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());

    // Validate forget gate
    ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_forget_weights, (lstm_params.use_layer_norm()) ? nullptr : forget_gate_bias, &forget_gate));

    std::vector<const ITensorInfo *> inputs_vector;
    inputs_vector.emplace_back(input);
    inputs_vector.emplace_back(output_state_in);
    const TensorShape concat_shape       = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, 0);
    TensorInfo        forget_gate_concat = TensorInfo(concat_shape, 1, input->data_type());

    ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(inputs_vector, &forget_gate_concat, Window::DimX));

    if(lstm_params.has_peephole_opt())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
        ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
    }
    if(lstm_params.use_layer_norm())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&forget_gate));
        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&forget_gate, lstm_params.forget_layer_norm_weights(), &forget_gate, 1, ConvertPolicy::SATURATE,
                                                                        RoundingPolicy::TO_NEAREST_EVEN));
        ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, forget_gate_bias, &forget_gate, ConvertPolicy::SATURATE));
    }
    ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));

    // Validate input gate
    if(!lstm_params.has_cifg_opt())
    {
        ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(),
                                            lstm_params.recurrent_to_input_weights(),
                                            lstm_params.input_gate_bias());
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() > 2);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1);

        std::vector<const ITensorInfo *> lstm_weights;
        lstm_weights.emplace_back(lstm_params.input_to_input_weights());
        lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
        TensorShape lstm_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(lstm_weights, 0);
        TensorInfo  lstm_gate_concat          = TensorInfo(lstm_weights_concat_shape, 1, input->data_type());
        ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(lstm_weights, &lstm_gate_concat, Window::DimX));

        ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), (lstm_params.use_layer_norm()) ? nullptr : lstm_params.input_gate_bias(), &input_gate));

        if(lstm_params.has_peephole_opt())
        {
            ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
            ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
            ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
            ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, &input_gate, &input_gate, ConvertPolicy::SATURATE));
        }

        if(lstm_params.use_layer_norm())
        {
            ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&input_gate));
            ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
            ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, lstm_params.input_gate_bias(), &input_gate, ConvertPolicy::SATURATE));
        }
        ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
    }
    else
    {
        ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticSubtraction::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
    }

    // Validate cell state
    ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_cell_weights, (lstm_params.use_layer_norm()) ? nullptr : cell_bias, &cell_state_tmp));
    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(output_state_in, &units_out_transposed_info, nullptr, &cell_state_tmp, 1.f, 0.f, GEMMInfo()));
    ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
    if(lstm_params.use_layer_norm())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&cell_state_tmp));
        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_layer_norm_weights(), &cell_state_tmp, 1, ConvertPolicy::SATURATE,
                                                                        RoundingPolicy::TO_NEAREST_EVEN));
        ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, cell_bias, &cell_state_tmp, ConvertPolicy::SATURATE));
    }
    ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, nullptr, activation_info));
    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
    ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
    if(cell_threshold != 0.f)
    {
        ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold,
                                                                                                              cell_threshold)));
    }

    std::vector<const ITensorInfo *> in_out_weights;
    in_out_weights.emplace_back(input_to_output_weights);
    in_out_weights.emplace_back(recurrent_to_output_weights);
    TensorShape in_out_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(in_out_weights, 0);
    TensorInfo  in_out_gate_concat          = TensorInfo(in_out_weights_concat_shape, 1, input->data_type());
    ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(in_out_weights, &in_out_gate_concat, Window::DimX));
    // Validate output gate
    ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_output_weights, (lstm_params.use_layer_norm()) ? nullptr : output_gate_bias, &output_gate_tmp));

    if(lstm_params.has_peephole_opt())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
                                                                        RoundingPolicy::TO_NEAREST_EVEN));
        ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, &output_gate_tmp, &output_gate_tmp, ConvertPolicy::SATURATE));
    }
    if(lstm_params.use_layer_norm())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&output_gate_tmp));
        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&output_gate_tmp, lstm_params.output_layer_norm_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
                                                                        RoundingPolicy::TO_NEAREST_EVEN));
        ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, output_gate_bias, &output_gate_tmp, ConvertPolicy::SATURATE));
    }
    ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));

    // Validate output state
    ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, &cell_state_tmp, activation_info));
    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
    if(lstm_params.has_projection())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(&output_gate_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out));
        if(projection_threshold != 0.f)
        {
            ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(output_state_out, output_state_out,
                                                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold)));
        }
    }

    // Validate copy kernel
    ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(&cell_state_tmp, cell_state_out));
    ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(output_state_out, output));

    // Validate scratch concatenation
    std::vector<const ITensorInfo *> inputs_vector_info_raw;
    if(!lstm_params.has_cifg_opt())
    {
        inputs_vector_info_raw.push_back(&input_gate);
    }
    inputs_vector_info_raw.push_back(&cell_state_tmp);
    inputs_vector_info_raw.push_back(&forget_gate);
    inputs_vector_info_raw.push_back(&output_gate_tmp);

    ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(inputs_vector_info_raw, scratch_buffer, Window::DimX));
    return Status{};
}

void CLLSTMLayer::run()
{
    prepare();

    MemoryGroupResourceScope scope_mg(_memory_group);

    _concat_inputs_forget_gate.run();

    _fully_connected_forget_gate.run();

    if(_run_peephole_opt)
    {
        _pixelwise_mul_forget_gate.run();
        _accum_forget_gate1.run();
    }
    if(_is_layer_norm_lstm)
    {
        _mean_std_norm_forget_gate.run();
        _pixelwise_mul_forget_gate_coeff.run();
        _accum_forget_gate_bias.run();
    }
    _activation_forget_gate.run();

    if(_run_cifg_opt)
    {
        CLScheduler::get().enqueue(*_ones_memset_kernel);
        _subtract_input_gate.run();
    }
    else
    {
        _fully_connected_input_gate.run();

        if(_run_peephole_opt)
        {
            _pixelwise_mul_input_gate.run();
            _accum_input_gate1.run();
        }

        if(_is_layer_norm_lstm)
        {
            _mean_std_norm_input_gate.run();
            _pixelwise_mul_input_gate_coeff.run();
            _accum_input_gate_bias.run();
        }
        _activation_input_gate.run();
    }

    _fully_connected_cell_state.run();
    CLScheduler::get().enqueue(*_transpose_cell_state);
    _gemm_cell_state1.run();
    _accum_cell_state1.run();
    if(_is_layer_norm_lstm)
    {
        _mean_std_norm_cell_gate.run();
        _pixelwise_mul_cell_gate_coeff.run();
        _accum_cell_gate_bias.run();
    }
    _activation_cell_state.run();
    _pixelwise_mul_cell_state1.run();
    _pixelwise_mul_cell_state2.run();
    _accum_cell_state2.run();

    if(_perform_cell_clipping)
    {
        _cell_clip.run();
    }

    _fully_connected_output.run();

    if(_run_peephole_opt)
    {
        _pixelwise_mul_output_state1.run();
        _accum_output1.run();
    }
    if(_is_layer_norm_lstm)
    {
        _mean_std_norm_output_gate.run();
        _pixelwise_mul_output_gate_coeff.run();
        _accum_output_gate_bias.run();
    }
    _activation_output.run();

    _activation_output_state.run();
    _pixelwise_mul_output_state2.run();

    if(_has_projection_weights)
    {
        _fully_connected_output_state.run();
        if(_perform_projection_clipping)
        {
            _projection_clip.run();
        }
    }

    CLScheduler::get().enqueue(*_copy_cell_state);
    CLScheduler::get().enqueue(*_copy_output);

    _concat_scratch_buffer.run();
}

void CLLSTMLayer::prepare()
{
    if(!_is_prepared)
    {
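        // One-off weight concatenation: the concatenated weight tensors persist and are reused by every run()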
        _concat_weights_forget_gate.run();
        if(!_run_cifg_opt)
        {
            _concat_weights_input_gate.run();
        }
        _concat_weights_output.run();
        _is_prepared = true;
    }
}
} // namespace arm_compute