/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "LSTM.h"

#include "CpuExecutor.h"
#include "CpuOperationUtils.h"
#include "HalInterfaces.h"
#include "OperationsUtils.h"

#include "Tracing.h"
#include "Utils.h"

namespace android {
namespace nn {

namespace {

template <typename T>
inline T* GetBuffer(RunTimeOperandInfo* operand) {
    return reinterpret_cast<T*>(operand->buffer);
}

template <typename T>
inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
    return reinterpret_cast<const T*>(operand->buffer);
}

template <typename T>
inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
    return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
}

}  // anonymous namespace

LSTMCell::LSTMCell(const Operation& operation, std::vector<RunTimeOperandInfo>& operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    input_to_input_weights_ =
            GetInput(operation, operands, kInputToInputWeightsTensor);  // optional
    input_to_forget_weights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
    input_to_cell_weights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
    input_to_output_weights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);

    recurrent_to_input_weights_ =
            GetInput(operation, operands, kRecurrentToInputWeightsTensor);  // optional
    recurrent_to_forget_weights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    recurrent_to_cell_weights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    recurrent_to_output_weights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);

    cell_to_input_weights_ = GetInput(operation, operands, kCellToInputWeightsTensor);  // optional
    cell_to_forget_weights_ =
            GetInput(operation, operands, kCellToForgetWeightsTensor);  // optional
    cell_to_output_weights_ =
            GetInput(operation, operands, kCellToOutputWeightsTensor);  // optional

    input_gate_bias_ = GetInput(operation, operands, kInputGateBiasTensor);
    forget_gate_bias_ = GetInput(operation, operands, kForgetGateBiasTensor);
    cell_bias_ = GetInput(operation, operands, kCellGateBiasTensor);
    output_gate_bias_ = GetInput(operation, operands, kOutputGateBiasTensor);

    projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor);  // optional
    projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor);        // optional

    output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
    cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);

    params_.activation = static_cast<TfLiteFusedActivation>(
            getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam)));
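    // The clip thresholds arrive in the same floating-point type as the input tensor,
    // so the half-precision branch below widens them to float before storing them.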
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        params_.cell_clip = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
        params_.proj_clip = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
    } else {
        params_.cell_clip = static_cast<float>(
                getScalarData<_Float16>(*GetInput(operation, operands, kCellClipParam)));
        params_.proj_clip = static_cast<float>(
                getScalarData<_Float16>(*GetInput(operation, operands, kProjClipParam)));
    }

    // We check the version of LSTM by checking the number of the inputs to the
    // op. For LSTM version 1.0 there were 23 inputs and for 1.2 there are 27.
    if (operation.inputs.size() == 27) {
        input_layer_norm_weights_ =
                GetInput(operation, operands, kInputLayerNormWeightsTensor);  // optional
        forget_layer_norm_weights_ =
                GetInput(operation, operands, kForgetLayerNormWeightsTensor);  // optional
        cell_layer_norm_weights_ =
                GetInput(operation, operands, kCellLayerNormWeightsTensor);  // optional
        output_layer_norm_weights_ =
                GetInput(operation, operands, kOutputLayerNormWeightsTensor);  // optional
    } else {
        // For LSTM from HAL v1.0 assign operands with no values
        static RunTimeOperandInfo no_value;
        no_value.lifetime = OperandLifeTime::NO_VALUE;

        input_layer_norm_weights_ = &no_value;
        forget_layer_norm_weights_ = &no_value;
        cell_layer_norm_weights_ = &no_value;
        output_layer_norm_weights_ = &no_value;
    }

    output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
    cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
    output_ = GetOutput(operation, operands, kOutputTensor);

    scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
}

// static
bool LSTMCell::CheckInputTensorDimensions(
        const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights,
        const RunTimeOperandInfo* input_to_forget_weights,
        const RunTimeOperandInfo* input_to_cell_weights,
        const RunTimeOperandInfo* input_to_output_weights,
        const RunTimeOperandInfo* recurrent_to_input_weights,
        const RunTimeOperandInfo* recurrent_to_forget_weights,
        const RunTimeOperandInfo* recurrent_to_cell_weights,
        const RunTimeOperandInfo* recurrent_to_output_weights,
        const RunTimeOperandInfo* cell_to_input_weights,
        const RunTimeOperandInfo* cell_to_forget_weights,
        const RunTimeOperandInfo* cell_to_output_weights, const RunTimeOperandInfo* input_gate_bias,
        const RunTimeOperandInfo* forget_gate_bias, const RunTimeOperandInfo* cell_bias,
        const RunTimeOperandInfo* output_gate_bias, const RunTimeOperandInfo* projection_weights,
        const RunTimeOperandInfo* projection_bias,
        const RunTimeOperandInfo* input_layer_norm_weights,
        const RunTimeOperandInfo* forget_layer_norm_weights,
        const RunTimeOperandInfo* cell_layer_norm_weights,
        const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input, uint32_t n_output,
        uint32_t n_cell, LSTMParams* params) {
    // Making sure clipping parameters have valid values.
    // == 0 means no clipping
    //  > 0 means clipping
    NN_CHECK(params->cell_clip >= 0);
    NN_CHECK(params->proj_clip >= 0);

    if (!IsNullInput(input_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
    }

    NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);

    NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);

    if (!IsNullInput(recurrent_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
    }

    NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);

    NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);

    // We make sure the input-gate's parameters are either both present (regular
    // LSTM) or not at all (CIFG-LSTM).
    const bool cifg_weights_all_or_none =
            (!IsNullInput(input_to_input_weights) && !IsNullInput(recurrent_to_input_weights)) ||
            (IsNullInput(input_to_input_weights) && IsNullInput(recurrent_to_input_weights));
    NN_CHECK(cifg_weights_all_or_none);

    if (!IsNullInput(cell_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
    }

    if (!IsNullInput(cell_to_forget_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
    }

    if (!IsNullInput(cell_to_output_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
    }

    // Making sure the peephole weights are there all or none.
    params->use_cifg = IsNullInput(input_to_input_weights);
    const bool peephole_weights_all_or_none =
            ((!IsNullInput(cell_to_input_weights) || params->use_cifg) &&
             !IsNullInput(cell_to_forget_weights) && !IsNullInput(cell_to_output_weights)) ||
            (IsNullInput(cell_to_input_weights) && IsNullInput(cell_to_forget_weights) &&
             IsNullInput(cell_to_output_weights));
    NN_CHECK(peephole_weights_all_or_none);

    // Since we have already checked that the weights are all present or all absent,
    // we can check the existence of only one of them to get the condition.
    params->use_peephole = !IsNullInput(cell_to_output_weights);
    // Checking output instead of input layer norm weights because the input ones
    // can be omitted when CIFG LSTM is used.
    params->use_layer_norm = !IsNullInput(output_layer_norm_weights);

    params->use_projection_weight = (projection_weights->lifetime != OperandLifeTime::NO_VALUE);
    params->use_projection_bias = (projection_bias->lifetime != OperandLifeTime::NO_VALUE);

    // Make sure the input gate bias is present only when not a CIFG-LSTM.
    if (params->use_cifg) {
        NN_CHECK(IsNullInput(input_gate_bias));
    } else {
        NN_CHECK_EQ(NumDimensions(input_gate_bias), 1);
        NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
    }

    NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);

    NN_CHECK_EQ(NumDimensions(cell_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);

    NN_CHECK_EQ(NumDimensions(output_gate_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);

    if (!IsNullInput(projection_weights)) {
        NN_CHECK_EQ(NumDimensions(projection_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
    }

    if (!IsNullInput(projection_bias)) {
        NN_CHECK_EQ(NumDimensions(projection_bias), 1);
        NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
    }

    // Making sure the projection tensors are consistent:
    // 1) If projection weight is not present, then projection bias should not be
    // present.
    // 2) If projection weight is present, then projection bias is optional.
    // TODO: make sure this is correct.
    const bool projection_tensors_consistent =
            (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
    NN_CHECK(projection_tensors_consistent);

    if (!IsNullInput(input_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(input_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(forget_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(cell_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(output_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(output_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights, 0), n_cell);
    }

    if (params->use_cifg) {
        NN_RET_CHECK(IsNullInput(input_layer_norm_weights))
                << "input_layer_norm_weights are provided while CIFG is used";
        const bool layer_norm_weights_all_or_none_cifg =
                (IsNullInput(forget_layer_norm_weights) && IsNullInput(cell_layer_norm_weights) &&
                 IsNullInput(output_layer_norm_weights)) ||
                (!IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
                 !IsNullInput(output_layer_norm_weights));
        NN_RET_CHECK(layer_norm_weights_all_or_none_cifg);
    } else {
        const bool layer_norm_weights_all_or_none =
                (IsNullInput(input_layer_norm_weights) && IsNullInput(forget_layer_norm_weights) &&
                 IsNullInput(cell_layer_norm_weights) && IsNullInput(output_layer_norm_weights)) ||
                (!IsNullInput(input_layer_norm_weights) &&
                 !IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
                 !IsNullInput(output_layer_norm_weights));
        NN_RET_CHECK(layer_norm_weights_all_or_none);
    }

    return true;
}

bool LSTMCell::Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
                       Shape* scratchShape, Shape* outputStateShape, Shape* cellStateShape,
                       Shape* outputShape) {
    // Check we have all the inputs and outputs we need.
    NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
             NumInputsWithValues(operation, operands) <= 27);
    NN_CHECK_EQ(NumOutputs(operation), 4);

    // Inferring batch size, number of outputs and number of cells from the
    // input tensors.
    NN_CHECK(NumDimensions(input_) > 1);
    const uint32_t n_batch = SizeOfDimension(input_, 0);
    const uint32_t n_input = SizeOfDimension(input_, 1);

    const uint32_t n_cell = SizeOfDimension(input_to_output_weights_, 0);
    NN_CHECK_EQ(NumDimensions(input_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_output_weights_, 1), n_input);

    NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights_, 0), n_cell);
    const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights_, 1);

    // Check that the input tensor dimensions are consistent with each other.
    if (!CheckInputTensorDimensions(
                input_, input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_,
                input_to_output_weights_, recurrent_to_input_weights_, recurrent_to_forget_weights_,
                recurrent_to_cell_weights_, recurrent_to_output_weights_, cell_to_input_weights_,
                cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_,
                forget_gate_bias_, cell_bias_, output_gate_bias_, projection_weights_,
                projection_bias_, input_layer_norm_weights_, forget_layer_norm_weights_,
                cell_layer_norm_weights_, output_layer_norm_weights_, n_input, n_output, n_cell,
                &params_)) {
        return false;
    }

    // Resize the output and output_state tensors.
    const Shape& inputShape = input_->shape();

    outputShape->type = inputShape.type;
    outputShape->dimensions = {n_batch, n_output};
    outputShape->offset = inputShape.offset;
    outputShape->scale = inputShape.scale;

    outputStateShape->type = inputShape.type;
    outputStateShape->dimensions = {n_batch, n_output};
    outputStateShape->offset = inputShape.offset;
    outputStateShape->scale = inputShape.scale;

    cellStateShape->type = inputShape.type;
    cellStateShape->dimensions = {n_batch, n_cell};
    cellStateShape->offset = inputShape.offset;
    cellStateShape->scale = inputShape.scale;

    if (params_.use_cifg) {
        // Reserving space for Cell, Forget, Output gates
        scratchShape->dimensions = {n_batch, n_cell * 3};
    } else {
        // Reserving space for Input, Cell, Forget, Output gates
        scratchShape->dimensions = {n_batch, n_cell * 4};
    }
    scratchShape->type = inputShape.type;
    scratchShape->offset = inputShape.offset;
    scratchShape->scale = inputShape.scale;

    return true;
}

// static
bool LSTMCell::LSTMEvalFloat32(
        const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
        const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
        const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
        const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
        const float* recurrent_to_forget_weights_buffer,
        const float* recurrent_to_cell_weights_buffer,
        const float* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
        const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
        const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
        const float* aux_input_to_forget_weights_buffer,
        const float* aux_input_to_cell_weights_buffer,
        const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
        const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
        const float* output_gate_bias_buffer, const float* projection_weights_buffer,
        const float* projection_bias_buffer, const float* output_state_in_buffer,
        const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
        const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
        const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
        float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
        bool timeMajor, bool forwardSequence) {
    NNTRACE_COMP("LSTMCell::LSTMEvalFloat32");

    const uint32_t inputRank = getNumberOfDimensions(input_shape);
    NN_CHECK(inputRank == 2 || inputRank == 3);

    const uint32_t maxTime =
            (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
    const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
                                                : getSizeOfDimension(input_shape, 0);
    const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
    const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
    const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);

    Shape batchInputShape = input_shape;
    batchInputShape.dimensions = {batchSize, inputSize};
    const uint32_t batchInputSize = batchSize * inputSize;
    const uint32_t batchOutputSize = batchSize * outputSize;

    std::vector<float> transposedInput;
    const bool hasAuxInput = (aux_input_buffer != nullptr);
    std::vector<float> transposedAuxInput;
    std::vector<float> transposedOutput;
    Shape transposedInputShape;
    Shape transposedOutputShape;
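    // When the input is batch-major (timeMajor == false), transpose the first two
    // dimensions so that each time step occupies a contiguous [batch, input] slab;
    // the output is transposed back after the time-step loop below.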
    if (!timeMajor) {
        transposedInput.resize(maxTime * batchInputSize);
        transposeFirstTwoDimensions<float>(input_buffer, input_shape, transposedInput.data());
        if (hasAuxInput) {
            transposedAuxInput.resize(maxTime * batchInputSize);
            transposeFirstTwoDimensions<float>(aux_input_buffer, input_shape,
                                               transposedAuxInput.data());
        }
        transposeFirstTwoDimensions(input_shape, &transposedInputShape);
        transposedOutput.resize(maxTime * batchOutputSize);
        transposedOutputShape = transposedInputShape;
        transposedOutputShape.dimensions[2] = outputSize;
    }
    const float* inputData = timeMajor ? input_buffer : transposedInput.data();
    const float* auxInputData =
            hasAuxInput ? (timeMajor ? aux_input_buffer : transposedAuxInput.data()) : nullptr;
    float* outputData = timeMajor ? output_buffer : transposedOutput.data();

    std::vector<float> outputStateInCurrentTimeStep(
            output_state_in_buffer, output_state_in_buffer + batchSize * outputSize);
    std::vector<float> cellStateInCurrentTimeStep(cell_state_in_buffer,
                                                  cell_state_in_buffer + batchSize * numCells);
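    // Walk the sequence forward or backward: the per-time-step pointers start at the
    // first or last time step and advance by a positive or negative stride on every
    // iteration, while the output and cell state of one step feed the next step.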
    const float* inputCurrentTimeStep =
            inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
    const float* auxInputCurrentTimeStep =
            hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
                        : nullptr;
    float* outputCurrentTimeStep =
            outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
    const int batchInputDelta = forwardSequence ? batchInputSize : -batchInputSize;
    const int batchOutputDelta = forwardSequence ? batchOutputSize : -batchOutputSize;

    for (int t = 0; t < maxTime; ++t) {
        LSTMStep(params, inputCurrentTimeStep, batchInputShape, input_to_input_weights_buffer,
                 input_to_forget_weights_buffer, input_to_cell_weights_buffer,
                 input_to_output_weights_buffer, input_to_output_weights_shape,
                 recurrent_to_input_weights_buffer, recurrent_to_forget_weights_buffer,
                 recurrent_to_cell_weights_buffer, recurrent_to_output_weights_buffer,
                 recurrent_to_output_weights_shape, cell_to_input_weights_buffer,
                 cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
                 auxInputCurrentTimeStep, aux_input_to_input_weights_buffer,
                 aux_input_to_forget_weights_buffer, aux_input_to_cell_weights_buffer,
                 aux_input_to_output_weights_buffer, input_gate_bias_buffer,
                 forget_gate_bias_buffer, cell_bias_buffer, output_gate_bias_buffer,
                 projection_weights_buffer, projection_bias_buffer,
                 outputStateInCurrentTimeStep.data(), cellStateInCurrentTimeStep.data(),
                 input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
                 cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
                 output_state_out_buffer, cell_state_out_buffer, outputCurrentTimeStep,
                 scratch_buffer_buffer);
        inputCurrentTimeStep += batchInputDelta;
        if (hasAuxInput) {
            auxInputCurrentTimeStep += batchInputDelta;
        }
        outputCurrentTimeStep += batchOutputDelta;
        outputStateInCurrentTimeStep.assign(output_state_out_buffer,
                                            output_state_out_buffer + batchSize * outputSize);
        cellStateInCurrentTimeStep.assign(cell_state_out_buffer,
                                          cell_state_out_buffer + batchSize * numCells);
    }

    if (!timeMajor) {
        transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
                                           output_buffer);
    }

    return true;
}

// static
bool LSTMCell::LSTMEvalFloat16(
        const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
        const _Float16* input_to_input_weights_buffer,
        const _Float16* input_to_forget_weights_buffer,
        const _Float16* input_to_cell_weights_buffer,
        const _Float16* input_to_output_weights_buffer, const Shape& input_to_output_weights_shape,
        const _Float16* recurrent_to_input_weights_buffer,
        const _Float16* recurrent_to_forget_weights_buffer,
        const _Float16* recurrent_to_cell_weights_buffer,
        const _Float16* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape,
        const _Float16* cell_to_input_weights_buffer, const _Float16* cell_to_forget_weights_buffer,
        const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
        const _Float16* aux_input_to_input_weights_buffer,
        const _Float16* aux_input_to_forget_weights_buffer,
        const _Float16* aux_input_to_cell_weights_buffer,
        const _Float16* aux_input_to_output_weights_buffer, const _Float16* input_gate_bias_buffer,
        const _Float16* forget_gate_bias_buffer, const _Float16* cell_bias_buffer,
        const _Float16* output_gate_bias_buffer, const _Float16* projection_weights_buffer,
        const _Float16* projection_bias_buffer, const _Float16* output_state_in_buffer,
        const _Float16* cell_state_in_buffer, const _Float16* input_layer_norm_weights_buffer,
        const _Float16* forget_layer_norm_weights_buffer,
        const _Float16* cell_layer_norm_weights_buffer,
        const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
        _Float16* cell_state_out_buffer, _Float16* output_buffer, _Float16* scratch_buffer_buffer,
        bool timeMajor, bool forwardSequence) {
    NNTRACE_COMP("LSTMCell::LSTMEvalFloat16");

    const uint32_t inputRank = getNumberOfDimensions(input_shape);
    NN_CHECK(inputRank == 2 || inputRank == 3);

    const uint32_t maxTime =
            (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
    const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
                                                : getSizeOfDimension(input_shape, 0);
    const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
    const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
    const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);

    Shape batchInputShape = input_shape;
    batchInputShape.dimensions = {batchSize, inputSize};
    const uint32_t batchInputSize = batchSize * inputSize;
    const uint32_t batchOutputSize = batchSize * outputSize;

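    // The half-precision path widens every fp16 tensor to float32, runs the same
    // float32 LSTMStep as LSTMEvalFloat32, and converts the results back to fp16
    // at the end.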
    std::vector<float> input_float32(maxTime * batchInputSize);
    convertFloat16ToFloat32(input_buffer, &input_float32);
    std::vector<float> input_to_input_weights_float32(numCells * inputSize);
    if (input_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(input_to_input_weights_buffer, &input_to_input_weights_float32);
    }
    std::vector<float> input_to_forget_weights_float32(numCells * inputSize);
    convertFloat16ToFloat32(input_to_forget_weights_buffer, &input_to_forget_weights_float32);
    std::vector<float> input_to_cell_weights_float32(numCells * inputSize);
    convertFloat16ToFloat32(input_to_cell_weights_buffer, &input_to_cell_weights_float32);
    std::vector<float> input_to_output_weights_float32(numCells * inputSize);
    convertFloat16ToFloat32(input_to_output_weights_buffer, &input_to_output_weights_float32);

    std::vector<float> recurrent_to_input_weights_float32(numCells * outputSize);
    if (recurrent_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(recurrent_to_input_weights_buffer,
                                &recurrent_to_input_weights_float32);
    }
    std::vector<float> recurrent_to_forget_weights_float32(numCells * outputSize);
    convertFloat16ToFloat32(recurrent_to_forget_weights_buffer,
                            &recurrent_to_forget_weights_float32);
    std::vector<float> recurrent_to_cell_weights_float32(numCells * outputSize);
    convertFloat16ToFloat32(recurrent_to_cell_weights_buffer, &recurrent_to_cell_weights_float32);
    std::vector<float> recurrent_to_output_weights_float32(numCells * outputSize);
    convertFloat16ToFloat32(recurrent_to_output_weights_buffer,
                            &recurrent_to_output_weights_float32);

    std::vector<float> cell_to_input_weights_float32(numCells);
    if (cell_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_to_input_weights_buffer, &cell_to_input_weights_float32);
    }
    std::vector<float> cell_to_forget_weights_float32(numCells);
    if (cell_to_forget_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_to_forget_weights_buffer, &cell_to_forget_weights_float32);
    }
    std::vector<float> cell_to_output_weights_float32(numCells);
    if (cell_to_output_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_to_output_weights_buffer, &cell_to_output_weights_float32);
    }

    std::vector<float> aux_input_float32(maxTime * batchInputSize);
    if (aux_input_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_buffer, &aux_input_float32);
    }
    std::vector<float> aux_input_to_input_weights_float32(numCells * inputSize);
    if (aux_input_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_input_weights_buffer,
                                &aux_input_to_input_weights_float32);
    }
    std::vector<float> aux_input_to_forget_weights_float32(numCells * inputSize);
    if (aux_input_to_forget_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_forget_weights_buffer,
                                &aux_input_to_forget_weights_float32);
    }
    std::vector<float> aux_input_to_cell_weights_float32(numCells * inputSize);
    if (aux_input_to_cell_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_cell_weights_buffer,
                                &aux_input_to_cell_weights_float32);
    }
    std::vector<float> aux_input_to_output_weights_float32(numCells * inputSize);
    if (aux_input_to_output_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_output_weights_buffer,
                                &aux_input_to_output_weights_float32);
    }

    std::vector<float> input_gate_bias_float32(numCells);
    if (input_gate_bias_buffer != nullptr) {
        convertFloat16ToFloat32(input_gate_bias_buffer, &input_gate_bias_float32);
    }
    std::vector<float> forget_gate_bias_float32(numCells);
    convertFloat16ToFloat32(forget_gate_bias_buffer, &forget_gate_bias_float32);
    std::vector<float> cell_bias_float32(numCells);
    convertFloat16ToFloat32(cell_bias_buffer, &cell_bias_float32);
    std::vector<float> output_gate_bias_float32(numCells);
    convertFloat16ToFloat32(output_gate_bias_buffer, &output_gate_bias_float32);

    std::vector<float> projection_weights_float32(numCells * outputSize);
    if (projection_weights_buffer != nullptr) {
        convertFloat16ToFloat32(projection_weights_buffer, &projection_weights_float32);
    }
    std::vector<float> projection_bias_float32(outputSize);
    if (projection_bias_buffer != nullptr) {
        convertFloat16ToFloat32(projection_bias_buffer, &projection_bias_float32);
    }

    std::vector<float> input_layer_norm_weights_float32(numCells);
    if (input_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(input_layer_norm_weights_buffer, &input_layer_norm_weights_float32);
    }
    std::vector<float> forget_layer_norm_weights_float32(numCells);
    if (forget_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(forget_layer_norm_weights_buffer,
                                &forget_layer_norm_weights_float32);
    }
    std::vector<float> cell_layer_norm_weights_float32(numCells);
    if (cell_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_layer_norm_weights_buffer, &cell_layer_norm_weights_float32);
    }
    std::vector<float> output_layer_norm_weights_float32(numCells);
    if (output_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(output_layer_norm_weights_buffer,
                                &output_layer_norm_weights_float32);
    }

    std::vector<float> output_state_out_float32(batchOutputSize);
    convertFloat16ToFloat32(output_state_out_buffer, &output_state_out_float32);
    std::vector<float> cell_state_out_float32(batchSize * numCells);
    convertFloat16ToFloat32(cell_state_out_buffer, &cell_state_out_float32);

    std::vector<float> output_float32(maxTime * batchOutputSize);
    convertFloat16ToFloat32(output_buffer, &output_float32);
    std::vector<float> scratch_buffer_float32(params.use_cifg ? 3 * batchSize * numCells
                                                              : 4 * batchSize * numCells);
    convertFloat16ToFloat32(scratch_buffer_buffer, &scratch_buffer_float32);

    std::vector<float> transposedInput;
    const bool hasAuxInput = (aux_input_buffer != nullptr);
    std::vector<float> transposedAuxInput;
    std::vector<float> transposedOutput;
    Shape transposedInputShape;
    Shape transposedOutputShape;
    if (!timeMajor) {
        transposedInput.resize(maxTime * batchInputSize);
        transposeFirstTwoDimensions<float>(input_float32.data(), input_shape,
                                           transposedInput.data());
        if (hasAuxInput) {
            transposedAuxInput.resize(maxTime * batchInputSize);
            transposeFirstTwoDimensions<float>(aux_input_float32.data(), input_shape,
                                               transposedAuxInput.data());
        }
        transposeFirstTwoDimensions(input_shape, &transposedInputShape);
        transposedOutput.resize(maxTime * batchOutputSize);
        transposedOutputShape = transposedInputShape;
        transposedOutputShape.dimensions[2] = outputSize;
    }
    const float* inputData = timeMajor ? input_float32.data() : transposedInput.data();
    const float* auxInputData =
            hasAuxInput ? (timeMajor ? aux_input_float32.data() : transposedAuxInput.data())
                        : nullptr;
    float* outputData = timeMajor ? output_float32.data() : transposedOutput.data();

    std::vector<float> outputStateInCurrentTimeStep(batchSize * outputSize);
    convertFloat16ToFloat32(output_state_in_buffer, &outputStateInCurrentTimeStep);
    std::vector<float> cellStateInCurrentTimeStep(batchSize * numCells);
    convertFloat16ToFloat32(cell_state_in_buffer, &cellStateInCurrentTimeStep);

    const float* inputCurrentTimeStep =
            inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
    const float* auxInputCurrentTimeStep =
            hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
                        : nullptr;
    float* outputCurrentTimeStep =
            outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
    const int batchInputDelta = forwardSequence ? batchInputSize : -batchInputSize;
    const int batchOutputDelta = forwardSequence ? batchOutputSize : -batchOutputSize;

    for (int t = 0; t < maxTime; ++t) {
        LSTMStep(params, inputCurrentTimeStep, batchInputShape,
                 input_to_input_weights_float32.data(), input_to_forget_weights_float32.data(),
                 input_to_cell_weights_float32.data(), input_to_output_weights_float32.data(),
                 input_to_output_weights_shape, recurrent_to_input_weights_float32.data(),
                 recurrent_to_forget_weights_float32.data(),
                 recurrent_to_cell_weights_float32.data(),
                 recurrent_to_output_weights_float32.data(), recurrent_to_output_weights_shape,
                 cell_to_input_weights_float32.data(), cell_to_forget_weights_float32.data(),
                 cell_to_output_weights_float32.data(), auxInputCurrentTimeStep,
                 aux_input_to_input_weights_float32.data(),
                 aux_input_to_forget_weights_float32.data(),
                 aux_input_to_cell_weights_float32.data(),
                 aux_input_to_output_weights_float32.data(), input_gate_bias_float32.data(),
                 forget_gate_bias_float32.data(), cell_bias_float32.data(),
                 output_gate_bias_float32.data(), projection_weights_float32.data(),
                 projection_bias_float32.data(), outputStateInCurrentTimeStep.data(),
                 cellStateInCurrentTimeStep.data(), input_layer_norm_weights_float32.data(),
                 forget_layer_norm_weights_float32.data(), cell_layer_norm_weights_float32.data(),
                 output_layer_norm_weights_float32.data(), output_state_out_float32.data(),
                 cell_state_out_float32.data(), outputCurrentTimeStep,
                 scratch_buffer_float32.data());
        inputCurrentTimeStep += batchInputDelta;
        if (hasAuxInput) {
            auxInputCurrentTimeStep += batchInputDelta;
        }
        outputCurrentTimeStep += batchOutputDelta;
        outputStateInCurrentTimeStep = output_state_out_float32;
        cellStateInCurrentTimeStep = cell_state_out_float32;
    }

    if (!timeMajor) {
        transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
                                           output_float32.data());
    }

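    // Copy the float32 results back into the caller's half-precision output buffers.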
    convertFloat32ToFloat16(output_state_out_float32, output_state_out_buffer);
    convertFloat32ToFloat16(cell_state_out_float32, cell_state_out_buffer);
    convertFloat32ToFloat16(output_float32, output_buffer);
    convertFloat32ToFloat16(scratch_buffer_float32, scratch_buffer_buffer);
    return true;
}

// static
bool LSTMCell::LSTMStep(
        const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
        const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
        const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
        const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
        const float* recurrent_to_forget_weights_buffer,
        const float* recurrent_to_cell_weights_buffer,
        const float* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
        const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
        const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
        const float* aux_input_to_forget_weights_buffer,
        const float* aux_input_to_cell_weights_buffer,
        const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
        const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
        const float* output_gate_bias_buffer, const float* projection_weights_buffer,
        const float* projection_bias_buffer, const float* output_state_in_buffer,
        const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
        const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
        const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
        float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer) {
    NNTRACE_COMP("LSTMCell::LSTMStep");

    const uint32_t n_batch = input_shape.dimensions[0];
    const uint32_t n_input = input_shape.dimensions[1];
    // n_cell and n_output will be the same size when there is no projection.
    const uint32_t n_cell = input_to_output_weights_shape.dimensions[0];
    const uint32_t n_output = recurrent_to_output_weights_shape.dimensions[1];
    const uint32_t n_aux_input = aux_input_buffer == nullptr ? 0 : n_input;

    // Index the scratch buffer pointers into the global scratch buffer.
    float* input_gate_scratch = nullptr;
    float* cell_scratch = nullptr;
    float* forget_gate_scratch = nullptr;
    float* output_gate_scratch = nullptr;
    if (params.use_cifg) {
        cell_scratch = scratch_buffer_buffer;
        forget_gate_scratch = cell_scratch + n_cell * n_batch;
        output_gate_scratch = cell_scratch + 2 * n_cell * n_batch;
    } else {
        input_gate_scratch = scratch_buffer_buffer;
        cell_scratch = input_gate_scratch + n_cell * n_batch;
        forget_gate_scratch = input_gate_scratch + 2 * n_cell * n_batch;
        output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
    }

    if (!params.use_layer_norm) {
        // Initialize scratch buffers with bias.
        if (!params.use_cifg) {
            tflite::tensor_utils::VectorBatchVectorAssign(input_gate_bias_buffer, n_cell, n_batch,
                                                          input_gate_scratch);
        }
        tflite::tensor_utils::VectorBatchVectorAssign(forget_gate_bias_buffer, n_cell, n_batch,
                                                      forget_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAssign(cell_bias_buffer, n_cell, n_batch,
                                                      cell_scratch);
        tflite::tensor_utils::VectorBatchVectorAssign(output_gate_bias_buffer, n_cell, n_batch,
                                                      output_gate_scratch);
    } else {
        // Initialize scratch buffers with zeroes.
        if (!params.use_cifg) {
            tflite::tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
        }
        tflite::tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
        tflite::tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
        tflite::tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
    }

    // For each batch and cell: compute input_weight * input.
    if (!params.use_cifg) {
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                input_to_input_weights_buffer, n_cell, n_input, input_buffer, n_batch,
                input_gate_scratch, /*result_stride*/ 1);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            input_to_forget_weights_buffer, n_cell, n_input, input_buffer, n_batch,
            forget_gate_scratch, /*result_stride*/ 1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_cell_weights_buffer, n_cell,
                                                              n_input, input_buffer, n_batch,
                                                              cell_scratch, /*result_stride*/ 1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            input_to_output_weights_buffer, n_cell, n_input, input_buffer, n_batch,
            output_gate_scratch, /*result_stride*/ 1);

    // If auxiliary input is available then compute aux_input_weight * aux_input
    if (aux_input_buffer != nullptr) {
        if (!params.use_cifg) {
            tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                    aux_input_to_input_weights_buffer, n_cell, n_aux_input, aux_input_buffer,
                    n_batch, input_gate_scratch,
                    /*result_stride=*/1);
        }

        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_forget_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                forget_gate_scratch, /*result_stride=*/1);
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_cell_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                cell_scratch, /*result_stride=*/1);
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_output_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                output_gate_scratch, /*result_stride=*/1);
    }

    // For each batch and cell: compute recurrent_weight * output_state.
    if (!params.use_cifg) {
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                recurrent_to_input_weights_buffer, n_cell, n_output, output_state_in_buffer,
                n_batch, input_gate_scratch,
                /*result_stride*/ 1);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_forget_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            forget_gate_scratch, /*result_stride*/ 1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_cell_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            cell_scratch, /*result_stride*/ 1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_output_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            output_gate_scratch, /*result_stride*/ 1);

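    // Each gate scratch now holds its bias (or zero when layer norm is used) plus the
    // input, auxiliary-input, and recurrent contributions; the gate nonlinearities and
    // the optional peephole and layer-norm terms are applied below.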
843     // For each batch and cell: update input gate.
844     if (!params.use_cifg) {
845         if (params.use_peephole) {
846             tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
847                     cell_to_input_weights_buffer, n_cell, cell_state_in_buffer, n_batch,
848                     input_gate_scratch);
849         }
850         if (params.use_layer_norm) {
851             tflite::tensor_utils::MeanStddevNormalization(input_gate_scratch, input_gate_scratch,
852                                                           n_cell, n_batch, kLayerNormEpsilon);
853             tflite::tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weights_buffer,
854                                                                 n_cell, input_gate_scratch, n_batch,
855                                                                 input_gate_scratch);
856             tflite::tensor_utils::VectorBatchVectorAdd(input_gate_bias_buffer, n_cell, n_batch,
857                                                        input_gate_scratch);
858         }
859         tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
860                                                    input_gate_scratch);
861     }

    // For each batch and cell: update forget gate.
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_forget_weights_buffer,
                                                                      n_cell, cell_state_in_buffer,
                                                                      n_batch, forget_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(forget_gate_scratch, forget_gate_scratch,
                                                      n_cell, n_batch, kLayerNormEpsilon);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weights_buffer,
                                                            n_cell, forget_gate_scratch, n_batch,
                                                            forget_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(forget_gate_bias_buffer, n_cell, n_batch,
                                                   forget_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                               forget_gate_scratch);
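
    // Symmetrically, f_t = sigmoid(scratch), with the peephole contribution taken
    // from cell_to_forget_weights and c_{t-1} and the same optional layer-norm
    // treatment as the input gate.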

    // For each batch and cell: update the cell.
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, n_batch,
                                                      kLayerNormEpsilon);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(cell_layer_norm_weights_buffer, n_cell,
                                                            cell_scratch, n_batch, cell_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(cell_bias_buffer, n_cell, n_batch, cell_scratch);
    }
    tflite::tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_in_buffer,
                                                   n_batch * n_cell, cell_state_out_buffer);
    tflite::tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, params.activation,
                                                  cell_scratch);
    if (params.use_cifg) {
        tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                                         forget_gate_scratch);
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    } else {
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    }
    if (params.cell_clip > 0.0) {
        tflite::tensor_utils::ClipVector(cell_state_out_buffer, n_batch * n_cell, params.cell_clip,
                                         cell_state_out_buffer);
    }
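
    // New cell state: c_t = f_t * c_{t-1} + i_t * g_t (elementwise), where
    // g_t = activation(cell_scratch); under CIFG the coupled gate (1 - f_t) stands in
    // for i_t. When cell_clip > 0 the result is clamped to [-cell_clip, cell_clip].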

    // For each batch and cell: update the output gate.
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_output_weights_buffer,
                                                                      n_cell, cell_state_out_buffer,
                                                                      n_batch, output_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(output_gate_scratch, output_gate_scratch,
                                                      n_cell, n_batch, kLayerNormEpsilon);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weights_buffer,
                                                            n_cell, output_gate_scratch, n_batch,
                                                            output_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(output_gate_bias_buffer, n_cell, n_batch,
                                                   output_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                               output_gate_scratch);
    tflite::tensor_utils::ApplyActivationToVector(cell_state_out_buffer, n_batch * n_cell,
                                                  params.activation, cell_scratch);
    tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                                   n_batch * n_cell, output_gate_scratch);
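
    // o_t = sigmoid(scratch); note the peephole term here uses the updated cell
    // state c_t rather than c_{t-1}. The pre-projection output o_t * activation(c_t)
    // is left in output_gate_scratch for the projection step below.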

    // For each batch: update the projection and output_state.
    if (params.use_projection_weight) {
        if (params.use_projection_bias) {
            tflite::tensor_utils::VectorBatchVectorAssign(projection_bias_buffer, n_output, n_batch,
                                                          output_buffer);
        } else {
            tflite::tensor_utils::ZeroVector(output_buffer, n_batch * n_output);
        }
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                projection_weights_buffer, n_output, n_cell, output_gate_scratch, n_batch,
                output_buffer,
                /*result_stride*/ 1);
        if (params.proj_clip > 0.0) {
            tflite::tensor_utils::ClipVector(output_buffer, n_batch * n_output, params.proj_clip,
                                             output_buffer);
        }
    } else {
        tflite::tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, output_buffer);
    }
    tflite::tensor_utils::CopyVector(output_buffer, n_batch * n_output, output_state_out_buffer);
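    // With projection, h_t = clip(projection_weights * (o_t * activation(c_t)) +
    // projection_bias, proj_clip); without it, h_t is the gated cell activation
    // directly (n_output is expected to equal n_cell in that case). output_state_out
    // keeps a copy of h_t so it can be fed back as h_{t-1} on the next step.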
    return true;
}

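// Dispatches the cell evaluation on the input operand's data type. Optional operands
// that were not provided come through as null buffers, which are only dereferenced
// when the corresponding feature (CIFG, peephole, projection, layer norm) is in use.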
bool LSTMCell::Eval() {
    switch (input_->type) {
        case OperandType::TENSOR_FLOAT32: {
            LSTMEvalFloat32(params_, GetBuffer<const float>(input_), input_->shape(),
                            GetBuffer<const float>(input_to_input_weights_),
                            GetBuffer<const float>(input_to_forget_weights_),
                            GetBuffer<const float>(input_to_cell_weights_),
                            GetBuffer<const float>(input_to_output_weights_),
                            input_to_output_weights_->shape(),
                            GetBuffer<const float>(recurrent_to_input_weights_),
                            GetBuffer<const float>(recurrent_to_forget_weights_),
                            GetBuffer<const float>(recurrent_to_cell_weights_),
                            GetBuffer<const float>(recurrent_to_output_weights_),
                            recurrent_to_output_weights_->shape(),
                            GetBuffer<const float>(cell_to_input_weights_),
                            GetBuffer<const float>(cell_to_forget_weights_),
                            GetBuffer<const float>(cell_to_output_weights_),
                            /*aux_input_buffer=*/nullptr,
                            /*aux_input_to_input_weights_buffer=*/nullptr,
                            /*aux_input_to_forget_weights_buffer=*/nullptr,
                            /*aux_input_to_cell_weights_buffer=*/nullptr,
                            /*aux_input_to_output_weights_buffer=*/nullptr,
                            GetBuffer<const float>(input_gate_bias_),
                            GetBuffer<const float>(forget_gate_bias_),
                            GetBuffer<const float>(cell_bias_),
                            GetBuffer<const float>(output_gate_bias_),
                            GetBuffer<const float>(projection_weights_),
                            GetBuffer<const float>(projection_bias_),
                            GetBuffer<const float>(output_state_in_),
                            GetBuffer<const float>(cell_state_in_),
                            GetBuffer<const float>(input_layer_norm_weights_),
                            GetBuffer<const float>(forget_layer_norm_weights_),
                            GetBuffer<const float>(cell_layer_norm_weights_),
                            GetBuffer<const float>(output_layer_norm_weights_),
                            GetBuffer<float>(output_state_out_), GetBuffer<float>(cell_state_out_),
                            GetBuffer<float>(output_), GetBuffer<float>(scratch_buffer_));
        } break;
        case OperandType::TENSOR_FLOAT16: {
            LSTMEvalFloat16(params_, GetBuffer<const _Float16>(input_), input_->shape(),
                            GetOptionalBuffer<const _Float16>(input_to_input_weights_),
                            GetBuffer<const _Float16>(input_to_forget_weights_),
                            GetBuffer<const _Float16>(input_to_cell_weights_),
                            GetBuffer<const _Float16>(input_to_output_weights_),
                            input_to_output_weights_->shape(),
                            GetOptionalBuffer<const _Float16>(recurrent_to_input_weights_),
                            GetBuffer<const _Float16>(recurrent_to_forget_weights_),
                            GetBuffer<const _Float16>(recurrent_to_cell_weights_),
                            GetBuffer<const _Float16>(recurrent_to_output_weights_),
                            recurrent_to_output_weights_->shape(),
                            GetOptionalBuffer<const _Float16>(cell_to_input_weights_),
                            GetOptionalBuffer<const _Float16>(cell_to_forget_weights_),
                            GetOptionalBuffer<const _Float16>(cell_to_output_weights_),
                            /*aux_input_buffer=*/nullptr,
                            /*aux_input_to_input_weights_buffer=*/nullptr,
                            /*aux_input_to_forget_weights_buffer=*/nullptr,
                            /*aux_input_to_cell_weights_buffer=*/nullptr,
                            /*aux_input_to_output_weights_buffer=*/nullptr,
                            GetOptionalBuffer<const _Float16>(input_gate_bias_),
                            GetBuffer<const _Float16>(forget_gate_bias_),
                            GetBuffer<const _Float16>(cell_bias_),
                            GetBuffer<const _Float16>(output_gate_bias_),
                            GetOptionalBuffer<const _Float16>(projection_weights_),
                            GetOptionalBuffer<const _Float16>(projection_bias_),
                            GetBuffer<const _Float16>(output_state_in_),
                            GetBuffer<const _Float16>(cell_state_in_),
                            GetOptionalBuffer<const _Float16>(input_layer_norm_weights_),
                            GetOptionalBuffer<const _Float16>(forget_layer_norm_weights_),
                            GetOptionalBuffer<const _Float16>(cell_layer_norm_weights_),
                            GetOptionalBuffer<const _Float16>(output_layer_norm_weights_),
                            GetBuffer<_Float16>(output_state_out_),
                            GetBuffer<_Float16>(cell_state_out_), GetBuffer<_Float16>(output_),
                            GetBuffer<_Float16>(scratch_buffer_));
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
            return false;
        }
    }
    return true;
}

}  // namespace nn
}  // namespace android