1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "LSTM.h"
20 
21 #pragma clang diagnostic push
22 #pragma clang diagnostic ignored "-Wunused-parameter"
23 #include <tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h>
24 #pragma clang diagnostic pop
25 
26 #include <tensorflow/lite/kernels/internal/tensor_utils.h>
27 
28 #include <vector>
29 
30 #include "CpuExecutor.h"
31 #include "CpuOperationUtils.h"
32 #include "LegacyUtils.h"
33 #include "OperationsExecutionUtils.h"
34 #include "Tracing.h"
35 #include "nnapi/Types.h"
36 
37 namespace android {
38 namespace nn {
39 
40 namespace {
41 
42 template <typename T>
43 inline T* GetBuffer(RunTimeOperandInfo* operand) {
44     return reinterpret_cast<T*>(operand->buffer);
45 }
46 
47 template <typename T>
48 inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
49     return reinterpret_cast<const T*>(operand->buffer);
50 }
51 
52 template <typename T>
53 inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
54     return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
55 }
56 
57 }  // anonymous namespace
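// Illustrative usage (a sketch, not code taken from this file): the evaluation
// path typically materializes typed pointers from the runtime operands, e.g.
//   const float* cell_to_input = GetOptionalBuffer<float>(cell_to_input_weights_);
// so an omitted optional operand surfaces as nullptr instead of a stale buffer.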
58 
59 LSTMCell::LSTMCell(const Operation& operation, RunTimeOperandInfo* operands) {
60     input_ = GetInput(operation, operands, kInputTensor);
61 
62     input_to_input_weights_ =
63             GetInput(operation, operands, kInputToInputWeightsTensor);  // optional
64     input_to_forget_weights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
65     input_to_cell_weights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
66     input_to_output_weights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);
67 
68     recurrent_to_input_weights_ =
69             GetInput(operation, operands, kRecurrentToInputWeightsTensor);  // optional
70     recurrent_to_forget_weights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
71     recurrent_to_cell_weights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
72     recurrent_to_output_weights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);
73 
74     cell_to_input_weights_ = GetInput(operation, operands, kCellToInputWeightsTensor);  // optional
75     cell_to_forget_weights_ =
76             GetInput(operation, operands, kCellToForgetWeightsTensor);  // optional
77     cell_to_output_weights_ =
78             GetInput(operation, operands, kCellToOutputWeightsTensor);  // optional
79 
80     input_gate_bias_ = GetInput(operation, operands, kInputGateBiasTensor);
81     forget_gate_bias_ = GetInput(operation, operands, kForgetGateBiasTensor);
82     cell_bias_ = GetInput(operation, operands, kCellGateBiasTensor);
83     output_gate_bias_ = GetInput(operation, operands, kOutputGateBiasTensor);
84 
85     projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor);  // optional
86     projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor);        // optional
87 
88     output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
89     cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);
90 
91     const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
92     params_.activation = static_cast<ActivationFn>(getScalarDataWithDefault<int32_t>(
93             activationOperand, TfLiteFusedActivation::kTfLiteActNone));
94 
95     const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam);
96     const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam);
97     if (input_->type == OperandType::TENSOR_FLOAT32) {
98         params_.cell_clip = getScalarDataWithDefault<float>(cellClipOperand, 0.0f);
99         params_.proj_clip = getScalarDataWithDefault<float>(projClipOperand, 0.0f);
100     } else {
101         params_.cell_clip =
102                 static_cast<float>(getScalarDataWithDefault<_Float16>(cellClipOperand, 0.0f));
103         params_.proj_clip =
104                 static_cast<float>(getScalarDataWithDefault<_Float16>(projClipOperand, 0.0f));
105     }
106 
107     // We check the version of LSTM by checking the number of inputs to the
108     // op. For LSTM version 1.0 there were 23 inputs and for 1.2 there are 27.
109     if (operation.inputs.size() == 27) {
110         input_layer_norm_weights_ =
111                 GetInput(operation, operands, kInputLayerNormWeightsTensor);  // optional
112         forget_layer_norm_weights_ =
113                 GetInput(operation, operands, kForgetLayerNormWeightsTensor);  // optional
114         cell_layer_norm_weights_ =
115                 GetInput(operation, operands, kCellLayerNormWeightsTensor);  // optional
116         output_layer_norm_weights_ =
117                 GetInput(operation, operands, kOutputLayerNormWeightsTensor);  // optional
118     } else {
119         // For LSTM from HAL v1.0, assign operands with no values.
120         static RunTimeOperandInfo no_value;
121         no_value.lifetime = Operand::LifeTime::NO_VALUE;
122 
123         input_layer_norm_weights_ = &no_value;
124         forget_layer_norm_weights_ = &no_value;
125         cell_layer_norm_weights_ = &no_value;
126         output_layer_norm_weights_ = &no_value;
127     }
128 
129     output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
130     cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
131     output_ = GetOutput(operation, operands, kOutputTensor);
132 
133     scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
134 }
135 
136 // static
137 bool LSTMCell::CheckInputTensorDimensions(
138         const RunTimeOperandInfo* /*input_*/, const RunTimeOperandInfo* input_to_input_weights,
139         const RunTimeOperandInfo* input_to_forget_weights,
140         const RunTimeOperandInfo* input_to_cell_weights,
141         const RunTimeOperandInfo* /*input_to_output_weights*/,
142         const RunTimeOperandInfo* recurrent_to_input_weights,
143         const RunTimeOperandInfo* recurrent_to_forget_weights,
144         const RunTimeOperandInfo* recurrent_to_cell_weights,
145         const RunTimeOperandInfo* /*recurrent_to_output_weights*/,
146         const RunTimeOperandInfo* cell_to_input_weights,
147         const RunTimeOperandInfo* cell_to_forget_weights,
148         const RunTimeOperandInfo* cell_to_output_weights, const RunTimeOperandInfo* input_gate_bias,
149         const RunTimeOperandInfo* forget_gate_bias, const RunTimeOperandInfo* cell_bias,
150         const RunTimeOperandInfo* output_gate_bias, const RunTimeOperandInfo* projection_weights,
151         const RunTimeOperandInfo* projection_bias,
152         const RunTimeOperandInfo* input_layer_norm_weights,
153         const RunTimeOperandInfo* forget_layer_norm_weights,
154         const RunTimeOperandInfo* cell_layer_norm_weights,
155         const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input, uint32_t n_output,
156         uint32_t n_cell, LSTMParams* params) {
157     // Making sure clipping parameters have valid values.
158     // == 0 means no clipping
159     //  > 0 means clipping
160     NN_CHECK(params->cell_clip >= 0);
161     NN_CHECK(params->proj_clip >= 0);
162 
163     if (!IsNullInput(input_to_input_weights)) {
164         NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2u);
165         NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
166         NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
167     }
168 
169     NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2u);
170     NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
171     NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);
172 
173     NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2u);
174     NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
175     NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);
176 
177     if (!IsNullInput(recurrent_to_input_weights)) {
178         NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2u);
179         NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
180         NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
181     }
182 
183     NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2u);
184     NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
185     NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);
186 
187     NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2u);
188     NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
189     NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);
190 
191     // We make sure the input-gate's parameters are either both present (regular
192     // LSTM) or not at all (CIFG-LSTM).
193     const bool cifg_weights_all_or_none =
194             (!IsNullInput(input_to_input_weights) && !IsNullInput(recurrent_to_input_weights)) ||
195             (IsNullInput(input_to_input_weights) && IsNullInput(recurrent_to_input_weights));
196     NN_CHECK(cifg_weights_all_or_none);
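    // In a CIFG (Coupled Input and Forget Gate) LSTM the dedicated input-gate
    // weights are dropped and the input gate is later derived from the forget gate
    // (see the Sub1Vector call in LSTMStep), which is why these weights must be
    // present or absent as a pair.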
197 
198     if (!IsNullInput(cell_to_input_weights)) {
199         NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1u);
200         NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
201     }
202 
203     if (!IsNullInput(cell_to_forget_weights)) {
204         NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1u);
205         NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
206     }
207 
208     if (!IsNullInput(cell_to_output_weights)) {
209         NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1u);
210         NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
211     }
212 
213     // Making sure the peephole weights are either all present or all absent.
214     params->use_cifg = IsNullInput(input_to_input_weights);
215     const bool peephole_weights_all_or_none =
216             ((!IsNullInput(cell_to_input_weights) || params->use_cifg) &&
217              !IsNullInput(cell_to_forget_weights) && !IsNullInput(cell_to_output_weights)) ||
218             (IsNullInput(cell_to_input_weights) && IsNullInput(cell_to_forget_weights) &&
219              IsNullInput(cell_to_output_weights));
220     NN_CHECK(peephole_weights_all_or_none);
221 
222     // Since we have already checked that the weights are all there or none, we
223     // can check the existence of only one to get the condition.
224     params->use_peephole = !IsNullInput(cell_to_output_weights);
225     // Checking the output instead of the input layer norm weights because the
226     // input ones can be omitted when CIFG LSTM is used.
227     params->use_layer_norm = !IsNullInput(output_layer_norm_weights);
228 
229     params->use_projection_weight = (projection_weights->lifetime != Operand::LifeTime::NO_VALUE);
230     params->use_projection_bias = (projection_bias->lifetime != Operand::LifeTime::NO_VALUE);
231 
232     // Make sure the input gate bias is present only when not a CIFG-LSTM.
233     if (params->use_cifg) {
234         NN_CHECK(IsNullInput(input_gate_bias));
235     } else {
236         NN_CHECK_EQ(NumDimensions(input_gate_bias), 1u);
237         NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
238     }
239 
240     NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1u);
241     NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);
242 
243     NN_CHECK_EQ(NumDimensions(cell_bias), 1u);
244     NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);
245 
246     NN_CHECK_EQ(NumDimensions(output_gate_bias), 1u);
247     NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);
248 
249     if (!IsNullInput(projection_weights)) {
250         NN_CHECK_EQ(NumDimensions(projection_weights), 2u);
251         NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
252         NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
253     }
254 
255     if (!IsNullInput(projection_bias)) {
256         NN_CHECK_EQ(NumDimensions(projection_bias), 1u);
257         NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
258     }
259 
260     // Making sure the projection tensors are consistent:
261     // 1) If projection weight is not present, then projection bias should not be
262     // present.
263     // 2) If projection weight is present, then projection bias is optional.
264     // TODO: make sure this is correct.
265     const bool projection_tensors_consistent =
266             (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
267     NN_CHECK(projection_tensors_consistent);
268 
269     if (!IsNullInput(input_layer_norm_weights)) {
270         NN_CHECK_EQ(NumDimensions(input_layer_norm_weights), 1u);
271         NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights, 0), n_cell);
272     }
273     if (!IsNullInput(forget_layer_norm_weights)) {
274         NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights), 1u);
275         NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights, 0), n_cell);
276     }
277     if (!IsNullInput(cell_layer_norm_weights)) {
278         NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights), 1u);
279         NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights, 0), n_cell);
280     }
281     if (!IsNullInput(output_layer_norm_weights)) {
282         NN_CHECK_EQ(NumDimensions(output_layer_norm_weights), 1u);
283         NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights, 0), n_cell);
284     }
285 
286     if (params->use_cifg) {
287         NN_RET_CHECK(IsNullInput(input_layer_norm_weights))
288                 << "input_layer_norm_weights are provided while CIFG is used";
289         const bool layer_norm_weights_all_or_none_cifg =
290                 (IsNullInput(forget_layer_norm_weights) && IsNullInput(cell_layer_norm_weights) &&
291                  IsNullInput(output_layer_norm_weights)) ||
292                 (!IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
293                  !IsNullInput(output_layer_norm_weights));
294         NN_RET_CHECK(layer_norm_weights_all_or_none_cifg);
295     } else {
296         const bool layer_norm_weights_all_or_none =
297                 (IsNullInput(input_layer_norm_weights) && IsNullInput(forget_layer_norm_weights) &&
298                  IsNullInput(cell_layer_norm_weights) && IsNullInput(output_layer_norm_weights)) ||
299                 (!IsNullInput(input_layer_norm_weights) &&
300                  !IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
301                  !IsNullInput(output_layer_norm_weights));
302         NN_RET_CHECK(layer_norm_weights_all_or_none);
303     }
304 
305     return true;
306 }
307 
308 bool LSTMCell::Prepare(const Operation& operation, RunTimeOperandInfo* operands,
309                        Shape* scratchShape, Shape* outputStateShape, Shape* cellStateShape,
310                        Shape* outputShape) {
311     // Check we have all the inputs and outputs we need.
312     NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
313              NumInputsWithValues(operation, operands) <= 27);
314     constexpr int requiredInputs[] = {
315             kInputTensor,
316             kInputToForgetWeightsTensor,
317             kInputToCellWeightsTensor,
318             kInputToOutputWeightsTensor,
319             kRecurrentToForgetWeightsTensor,
320             kRecurrentToCellWeightsTensor,
321             kRecurrentToOutputWeightsTensor,
322             kForgetGateBiasTensor,
323             kCellGateBiasTensor,
324             kOutputGateBiasTensor,
325             kOutputStateInTensor,
326             kCellStateInTensor,
327             kActivationParam,
328             kCellClipParam,
329             kProjClipParam,
330     };
331     for (const int requiredInput : requiredInputs) {
332         NN_RET_CHECK(!IsNullInput(GetInput(operation, operands, requiredInput)))
333                 << "required input " << requiredInput << " is omitted";
334     }
335     NN_CHECK_EQ(NumOutputs(operation), 4);
336 
337     // Check that the scalar operands' buffers are large enough.
338     const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
339     NN_RET_CHECK(activationOperand.length >= sizeof(int32_t));
340     const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam);
341     const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam);
342     if (input_->type == OperandType::TENSOR_FLOAT32) {
343         NN_RET_CHECK(cellClipOperand.length >= sizeof(float));
344         NN_RET_CHECK(projClipOperand.length >= sizeof(float));
345     } else {
346         NN_RET_CHECK(cellClipOperand.length >= sizeof(_Float16));
347         NN_RET_CHECK(projClipOperand.length >= sizeof(_Float16));
348     }
349 
350     // Inferring batch size, number of outputs and number of cells from the
351     // input tensors.
352     NN_CHECK(NumDimensions(input_) > 1);
353     const uint32_t n_batch = SizeOfDimension(input_, 0);
354     const uint32_t n_input = SizeOfDimension(input_, 1);
355 
356     const uint32_t n_cell = SizeOfDimension(input_to_output_weights_, 0);
357     NN_CHECK_EQ(NumDimensions(input_to_output_weights_), 2u);
358     NN_CHECK_EQ(SizeOfDimension(input_to_output_weights_, 1), n_input);
359 
360     NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights_), 2u);
361     NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights_, 0), n_cell);
362     const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights_, 1);
363 
364     // Check that the input tensor dimensions match each other.
365     if (!CheckInputTensorDimensions(
366                 input_, input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_,
367                 input_to_output_weights_, recurrent_to_input_weights_, recurrent_to_forget_weights_,
368                 recurrent_to_cell_weights_, recurrent_to_output_weights_, cell_to_input_weights_,
369                 cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_,
370                 forget_gate_bias_, cell_bias_, output_gate_bias_, projection_weights_,
371                 projection_bias_, input_layer_norm_weights_, forget_layer_norm_weights_,
372                 cell_layer_norm_weights_, output_layer_norm_weights_, n_input, n_output, n_cell,
373                 &params_)) {
374         return false;
375     }
376 
377     // Resize the output and output_state tensors.
378     const Shape& inputShape = input_->shape();
379 
380     outputShape->type = inputShape.type;
381     outputShape->dimensions = {n_batch, n_output};
382     outputShape->offset = inputShape.offset;
383     outputShape->scale = inputShape.scale;
384 
385     outputStateShape->type = inputShape.type;
386     outputStateShape->dimensions = {n_batch, n_output};
387     outputStateShape->offset = inputShape.offset;
388     outputStateShape->scale = inputShape.scale;
389 
390     cellStateShape->type = inputShape.type;
391     cellStateShape->dimensions = {n_batch, n_cell};
392     cellStateShape->offset = inputShape.offset;
393     cellStateShape->scale = inputShape.scale;
394 
395     if (params_.use_cifg) {
396         // Reserving space for Cell, Forget, Output gates
397         scratchShape->dimensions = {n_batch, n_cell * 3};
398     } else {
399         // Reserving space for Input, Cell, Forget, Output gates
400         scratchShape->dimensions = {n_batch, n_cell * 4};
401     }
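    // For example (illustrative numbers only): with n_batch = 2 and n_cell = 20,
    // the scratch buffer is {2, 80} for a regular LSTM and {2, 60} when CIFG drops
    // the input gate.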
402     scratchShape->type = inputShape.type;
403     scratchShape->offset = inputShape.offset;
404     scratchShape->scale = inputShape.scale;
405 
406     return true;
407 }
408 
409 // static
410 bool LSTMCell::LSTMEvalFloat32(
411         const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
412         const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
413         const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
414         const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
415         const float* recurrent_to_forget_weights_buffer,
416         const float* recurrent_to_cell_weights_buffer,
417         const float* recurrent_to_output_weights_buffer,
418         const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
419         const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
420         const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
421         const float* aux_input_to_forget_weights_buffer,
422         const float* aux_input_to_cell_weights_buffer,
423         const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
424         const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
425         const float* output_gate_bias_buffer, const float* projection_weights_buffer,
426         const float* projection_bias_buffer, const float* output_state_in_buffer,
427         const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
428         const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
429         const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
430         float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
431         bool timeMajor, bool forwardSequence) {
432     NNTRACE_COMP("LSTMCell::LSTMEvalFloat32");
433 
434     const uint32_t inputRank = getNumberOfDimensions(input_shape);
435     NN_CHECK(inputRank == 2 || inputRank == 3);
436 
437     const uint32_t maxTime =
438             (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
439     const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
440                                                 : getSizeOfDimension(input_shape, 0);
441     const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
442     const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
443     const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);
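    // Rank-3 input is a sequence laid out as {maxTime, batchSize, inputSize} when
    // timeMajor is set, or {batchSize, maxTime, inputSize} otherwise; rank-2 input
    // is a single time step of shape {batchSize, inputSize}.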
444 
445     Shape batchInputShape = input_shape;
446     batchInputShape.dimensions = {batchSize, inputSize};
447     const uint32_t batchInputSize = batchSize * inputSize;
448     const uint32_t batchOutputSize = batchSize * outputSize;
449 
450     std::vector<float> transposedInput;
451     const bool hasAuxInput = (aux_input_buffer != nullptr);
452     std::vector<float> transposedAuxInput;
453     std::vector<float> transposedOutput;
454     Shape transposedInputShape;
455     Shape transposedOutputShape;
456     if (!timeMajor) {
457         transposedInput.resize(maxTime * batchInputSize);
458         transposeFirstTwoDimensions<float>(input_buffer, input_shape, transposedInput.data());
459         if (hasAuxInput) {
460             transposedAuxInput.resize(maxTime * batchInputSize);
461             transposeFirstTwoDimensions<float>(aux_input_buffer, input_shape,
462                                                transposedAuxInput.data());
463         }
464         transposeFirstTwoDimensions(input_shape, &transposedInputShape);
465         transposedOutput.resize(maxTime * batchOutputSize);
466         transposedOutputShape = transposedInputShape;
467         transposedOutputShape.dimensions[2] = outputSize;
468     }
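    // From here on the data is processed in time-major order: batch-major input
    // was transposed above, so each time step below is a contiguous
    // {batchSize, inputSize} slab.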
469     const float* inputData = timeMajor ? input_buffer : transposedInput.data();
470     const float* auxInputData =
471             hasAuxInput ? (timeMajor ? aux_input_buffer : transposedAuxInput.data()) : nullptr;
472     float* outputData = timeMajor ? output_buffer : transposedOutput.data();
473 
474     std::vector<float> outputStateInCurrentTimeStep(
475             output_state_in_buffer, output_state_in_buffer + batchSize * outputSize);
476     std::vector<float> cellStateInCurrentTimeStep(cell_state_in_buffer,
477                                                   cell_state_in_buffer + batchSize * numCells);
478     const float* inputCurrentTimeStep =
479             inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
480     const float* auxInputCurrentTimeStep =
481             hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
482                         : nullptr;
483     float* outputCurrentTimeStep =
484             outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
485     const int batchInputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchInputSize);
486     const int batchOutputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchOutputSize);
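    // When forwardSequence is false, the time-step pointers start at the last
    // step and the deltas are negative, so the loop below walks the sequence in
    // either direction with the same code.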
487 
488     for (uint32_t t = 0; t < maxTime; ++t) {
489         LSTMStep(params, inputCurrentTimeStep, batchInputShape, input_to_input_weights_buffer,
490                  input_to_forget_weights_buffer, input_to_cell_weights_buffer,
491                  input_to_output_weights_buffer, input_to_output_weights_shape,
492                  recurrent_to_input_weights_buffer, recurrent_to_forget_weights_buffer,
493                  recurrent_to_cell_weights_buffer, recurrent_to_output_weights_buffer,
494                  recurrent_to_output_weights_shape, cell_to_input_weights_buffer,
495                  cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
496                  auxInputCurrentTimeStep, aux_input_to_input_weights_buffer,
497                  aux_input_to_forget_weights_buffer, aux_input_to_cell_weights_buffer,
498                  aux_input_to_output_weights_buffer, input_gate_bias_buffer,
499                  forget_gate_bias_buffer, cell_bias_buffer, output_gate_bias_buffer,
500                  projection_weights_buffer, projection_bias_buffer,
501                  outputStateInCurrentTimeStep.data(), cellStateInCurrentTimeStep.data(),
502                  input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
503                  cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
504                  output_state_out_buffer, cell_state_out_buffer, outputCurrentTimeStep,
505                  scratch_buffer_buffer);
506         inputCurrentTimeStep += batchInputDelta;
507         if (hasAuxInput) {
508             auxInputCurrentTimeStep += batchInputDelta;
509         }
510         outputCurrentTimeStep += batchOutputDelta;
511         outputStateInCurrentTimeStep.assign(output_state_out_buffer,
512                                             output_state_out_buffer + batchSize * outputSize);
513         cellStateInCurrentTimeStep.assign(cell_state_out_buffer,
514                                           cell_state_out_buffer + batchSize * numCells);
515     }
516 
517     if (!timeMajor) {
518         transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
519                                            output_buffer);
520     }
521 
522     return true;
523 }
524 
525 // static
526 bool LSTMCell::LSTMEvalFloat16(
527         const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
528         const _Float16* input_to_input_weights_buffer,
529         const _Float16* input_to_forget_weights_buffer,
530         const _Float16* input_to_cell_weights_buffer,
531         const _Float16* input_to_output_weights_buffer, const Shape& input_to_output_weights_shape,
532         const _Float16* recurrent_to_input_weights_buffer,
533         const _Float16* recurrent_to_forget_weights_buffer,
534         const _Float16* recurrent_to_cell_weights_buffer,
535         const _Float16* recurrent_to_output_weights_buffer,
536         const Shape& recurrent_to_output_weights_shape,
537         const _Float16* cell_to_input_weights_buffer, const _Float16* cell_to_forget_weights_buffer,
538         const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
539         const _Float16* aux_input_to_input_weights_buffer,
540         const _Float16* aux_input_to_forget_weights_buffer,
541         const _Float16* aux_input_to_cell_weights_buffer,
542         const _Float16* aux_input_to_output_weights_buffer, const _Float16* input_gate_bias_buffer,
543         const _Float16* forget_gate_bias_buffer, const _Float16* cell_bias_buffer,
544         const _Float16* output_gate_bias_buffer, const _Float16* projection_weights_buffer,
545         const _Float16* projection_bias_buffer, const _Float16* output_state_in_buffer,
546         const _Float16* cell_state_in_buffer, const _Float16* input_layer_norm_weights_buffer,
547         const _Float16* forget_layer_norm_weights_buffer,
548         const _Float16* cell_layer_norm_weights_buffer,
549         const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
550         _Float16* cell_state_out_buffer, _Float16* output_buffer, _Float16* scratch_buffer_buffer,
551         bool timeMajor, bool forwardSequence) {
552     NNTRACE_COMP("LSTMCell::LSTMEvalFloat16");
553 
554     const uint32_t inputRank = getNumberOfDimensions(input_shape);
555     NN_CHECK(inputRank == 2 || inputRank == 3);
556 
557     const uint32_t maxTime =
558             (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
559     const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
560                                                 : getSizeOfDimension(input_shape, 0);
561     const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
562     const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
563     const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);
564 
565     Shape batchInputShape = input_shape;
566     batchInputShape.dimensions = {batchSize, inputSize};
567     const uint32_t batchInputSize = batchSize * inputSize;
568     const uint32_t batchOutputSize = batchSize * outputSize;
569 
570     std::vector<float> input_float32(maxTime * batchInputSize);
571     convertFloat16ToFloat32(input_buffer, &input_float32);
572     std::vector<float> input_to_input_weights_float32(numCells * inputSize);
573     if (input_to_input_weights_buffer != nullptr) {
574         convertFloat16ToFloat32(input_to_input_weights_buffer, &input_to_input_weights_float32);
575     }
576     std::vector<float> input_to_forget_weights_float32(numCells * inputSize);
577     convertFloat16ToFloat32(input_to_forget_weights_buffer, &input_to_forget_weights_float32);
578     std::vector<float> input_to_cell_weights_float32(numCells * inputSize);
579     convertFloat16ToFloat32(input_to_cell_weights_buffer, &input_to_cell_weights_float32);
580     std::vector<float> input_to_output_weights_float32(numCells * inputSize);
581     convertFloat16ToFloat32(input_to_output_weights_buffer, &input_to_output_weights_float32);
582 
583     std::vector<float> recurrent_to_input_weights_float32(numCells * outputSize);
584     if (recurrent_to_input_weights_buffer != nullptr) {
585         convertFloat16ToFloat32(recurrent_to_input_weights_buffer,
586                                 &recurrent_to_input_weights_float32);
587     }
588     std::vector<float> recurrent_to_forget_weights_float32(numCells * outputSize);
589     convertFloat16ToFloat32(recurrent_to_forget_weights_buffer,
590                             &recurrent_to_forget_weights_float32);
591     std::vector<float> recurrent_to_cell_weights_float32(numCells * outputSize);
592     convertFloat16ToFloat32(recurrent_to_cell_weights_buffer, &recurrent_to_cell_weights_float32);
593     std::vector<float> recurrent_to_output_weights_float32(numCells * outputSize);
594     convertFloat16ToFloat32(recurrent_to_output_weights_buffer,
595                             &recurrent_to_output_weights_float32);
596 
597     std::vector<float> cell_to_input_weights_float32(numCells);
598     if (cell_to_input_weights_buffer != nullptr) {
599         convertFloat16ToFloat32(cell_to_input_weights_buffer, &cell_to_input_weights_float32);
600     }
601     std::vector<float> cell_to_forget_weights_float32(numCells);
602     if (cell_to_forget_weights_buffer != nullptr) {
603         convertFloat16ToFloat32(cell_to_forget_weights_buffer, &cell_to_forget_weights_float32);
604     }
605     std::vector<float> cell_to_output_weights_float32(numCells);
606     if (cell_to_output_weights_buffer != nullptr) {
607         convertFloat16ToFloat32(cell_to_output_weights_buffer, &cell_to_output_weights_float32);
608     }
609 
610     std::vector<float> aux_input_float32(maxTime * batchInputSize);
611     if (aux_input_buffer != nullptr) {
612         convertFloat16ToFloat32(aux_input_buffer, &aux_input_float32);
613     }
614     std::vector<float> aux_input_to_input_weights_float32(numCells * inputSize);
615     if (aux_input_to_input_weights_buffer != nullptr) {
616         convertFloat16ToFloat32(aux_input_to_input_weights_buffer,
617                                 &aux_input_to_input_weights_float32);
618     }
619     std::vector<float> aux_input_to_forget_weights_float32(numCells * inputSize);
620     if (aux_input_to_forget_weights_buffer != nullptr) {
621         convertFloat16ToFloat32(aux_input_to_forget_weights_buffer,
622                                 &aux_input_to_forget_weights_float32);
623     }
624     std::vector<float> aux_input_to_cell_weights_float32(numCells * inputSize);
625     if (aux_input_to_cell_weights_buffer != nullptr) {
626         convertFloat16ToFloat32(aux_input_to_cell_weights_buffer,
627                                 &aux_input_to_cell_weights_float32);
628     }
629     std::vector<float> aux_input_to_output_weights_float32(numCells * inputSize);
630     if (aux_input_to_output_weights_buffer != nullptr) {
631         convertFloat16ToFloat32(aux_input_to_output_weights_buffer,
632                                 &aux_input_to_output_weights_float32);
633     }
634 
635     std::vector<float> input_gate_bias_float32(numCells);
636     if (input_gate_bias_buffer != nullptr) {
637         convertFloat16ToFloat32(input_gate_bias_buffer, &input_gate_bias_float32);
638     }
639     std::vector<float> forget_gate_bias_float32(numCells);
640     convertFloat16ToFloat32(forget_gate_bias_buffer, &forget_gate_bias_float32);
641     std::vector<float> cell_bias_float32(numCells);
642     convertFloat16ToFloat32(cell_bias_buffer, &cell_bias_float32);
643     std::vector<float> output_gate_bias_float32(numCells);
644     convertFloat16ToFloat32(output_gate_bias_buffer, &output_gate_bias_float32);
645 
646     std::vector<float> projection_weights_float32(numCells * outputSize);
647     if (projection_weights_buffer != nullptr) {
648         convertFloat16ToFloat32(projection_weights_buffer, &projection_weights_float32);
649     }
650     std::vector<float> projection_bias_float32(outputSize);
651     if (projection_bias_buffer != nullptr) {
652         convertFloat16ToFloat32(projection_bias_buffer, &projection_bias_float32);
653     }
654 
655     std::vector<float> input_layer_norm_weights_float32(numCells);
656     if (input_layer_norm_weights_buffer != nullptr) {
657         convertFloat16ToFloat32(input_layer_norm_weights_buffer, &input_layer_norm_weights_float32);
658     }
659     std::vector<float> forget_layer_norm_weights_float32(numCells);
660     if (forget_layer_norm_weights_buffer != nullptr) {
661         convertFloat16ToFloat32(forget_layer_norm_weights_buffer,
662                                 &forget_layer_norm_weights_float32);
663     }
664     std::vector<float> cell_layer_norm_weights_float32(numCells);
665     if (cell_layer_norm_weights_buffer != nullptr) {
666         convertFloat16ToFloat32(cell_layer_norm_weights_buffer, &cell_layer_norm_weights_float32);
667     }
668     std::vector<float> output_layer_norm_weights_float32(numCells);
669     if (output_layer_norm_weights_buffer != nullptr) {
670         convertFloat16ToFloat32(output_layer_norm_weights_buffer,
671                                 &output_layer_norm_weights_float32);
672     }
673 
674     std::vector<float> output_state_out_float32(batchOutputSize);
675     convertFloat16ToFloat32(output_state_out_buffer, &output_state_out_float32);
676     std::vector<float> cell_state_out_float32(batchSize * numCells);
677     convertFloat16ToFloat32(cell_state_out_buffer, &cell_state_out_float32);
678 
679     std::vector<float> output_float32(maxTime * batchOutputSize);
680     convertFloat16ToFloat32(output_buffer, &output_float32);
681     std::vector<float> scratch_buffer_float32(params.use_cifg ? 3 * batchSize * numCells
682                                                               : 4 * batchSize * numCells);
683     convertFloat16ToFloat32(scratch_buffer_buffer, &scratch_buffer_float32);
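    // The half-precision path simply widens every tensor to float32, reuses the
    // float32 LSTMStep below, and narrows the results back to _Float16 at the end
    // of this function; there are no dedicated fp16 kernels here.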
684 
685     std::vector<float> transposedInput;
686     const bool hasAuxInput = (aux_input_buffer != nullptr);
687     std::vector<float> transposedAuxInput;
688     std::vector<float> transposedOutput;
689     Shape transposedInputShape;
690     Shape transposedOutputShape;
691     if (!timeMajor) {
692         transposedInput.resize(maxTime * batchInputSize);
693         transposeFirstTwoDimensions<float>(input_float32.data(), input_shape,
694                                            transposedInput.data());
695         if (hasAuxInput) {
696             transposedAuxInput.resize(maxTime * batchInputSize);
697             transposeFirstTwoDimensions<float>(aux_input_float32.data(), input_shape,
698                                                transposedAuxInput.data());
699         }
700         transposeFirstTwoDimensions(input_shape, &transposedInputShape);
701         transposedOutput.resize(maxTime * batchOutputSize);
702         transposedOutputShape = transposedInputShape;
703         transposedOutputShape.dimensions[2] = outputSize;
704     }
705     const float* inputData = timeMajor ? input_float32.data() : transposedInput.data();
706     const float* auxInputData =
707             hasAuxInput ? (timeMajor ? aux_input_float32.data() : transposedAuxInput.data())
708                         : nullptr;
709     float* outputData = timeMajor ? output_float32.data() : transposedOutput.data();
710 
711     std::vector<float> outputStateInCurrentTimeStep(batchSize * outputSize);
712     convertFloat16ToFloat32(output_state_in_buffer, &outputStateInCurrentTimeStep);
713     std::vector<float> cellStateInCurrentTimeStep(batchSize * numCells);
714     convertFloat16ToFloat32(cell_state_in_buffer, &cellStateInCurrentTimeStep);
715 
716     const float* inputCurrentTimeStep =
717             inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
718     const float* auxInputCurrentTimeStep =
719             hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
720                         : nullptr;
721     float* outputCurrentTimeStep =
722             outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
723     const int batchInputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchInputSize);
724     const int batchOutputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchOutputSize);
725 
726     for (uint32_t t = 0; t < maxTime; ++t) {
727         LSTMStep(params, inputCurrentTimeStep, batchInputShape,
728                  input_to_input_weights_float32.data(), input_to_forget_weights_float32.data(),
729                  input_to_cell_weights_float32.data(), input_to_output_weights_float32.data(),
730                  input_to_output_weights_shape, recurrent_to_input_weights_float32.data(),
731                  recurrent_to_forget_weights_float32.data(),
732                  recurrent_to_cell_weights_float32.data(),
733                  recurrent_to_output_weights_float32.data(), recurrent_to_output_weights_shape,
734                  cell_to_input_weights_float32.data(), cell_to_forget_weights_float32.data(),
735                  cell_to_output_weights_float32.data(), auxInputCurrentTimeStep,
736                  aux_input_to_input_weights_float32.data(),
737                  aux_input_to_forget_weights_float32.data(),
738                  aux_input_to_cell_weights_float32.data(),
739                  aux_input_to_output_weights_float32.data(), input_gate_bias_float32.data(),
740                  forget_gate_bias_float32.data(), cell_bias_float32.data(),
741                  output_gate_bias_float32.data(), projection_weights_float32.data(),
742                  projection_bias_float32.data(), outputStateInCurrentTimeStep.data(),
743                  cellStateInCurrentTimeStep.data(), input_layer_norm_weights_float32.data(),
744                  forget_layer_norm_weights_float32.data(), cell_layer_norm_weights_float32.data(),
745                  output_layer_norm_weights_float32.data(), output_state_out_float32.data(),
746                  cell_state_out_float32.data(), outputCurrentTimeStep,
747                  scratch_buffer_float32.data());
748         inputCurrentTimeStep += batchInputDelta;
749         if (hasAuxInput) {
750             auxInputCurrentTimeStep += batchInputDelta;
751         }
752         outputCurrentTimeStep += batchOutputDelta;
753         outputStateInCurrentTimeStep = output_state_out_float32;
754         cellStateInCurrentTimeStep = cell_state_out_float32;
755     }
756 
757     if (!timeMajor) {
758         transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
759                                            output_float32.data());
760     }
761 
762     convertFloat32ToFloat16(output_state_out_float32, output_state_out_buffer);
763     convertFloat32ToFloat16(cell_state_out_float32, cell_state_out_buffer);
764     convertFloat32ToFloat16(output_float32, output_buffer);
765     convertFloat32ToFloat16(scratch_buffer_float32, scratch_buffer_buffer);
766     return true;
767 }
768 
769 // static
770 bool LSTMCell::LSTMStep(
771         const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
772         const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
773         const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
774         const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
775         const float* recurrent_to_forget_weights_buffer,
776         const float* recurrent_to_cell_weights_buffer,
777         const float* recurrent_to_output_weights_buffer,
778         const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
779         const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
780         const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
781         const float* aux_input_to_forget_weights_buffer,
782         const float* aux_input_to_cell_weights_buffer,
783         const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
784         const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
785         const float* output_gate_bias_buffer, const float* projection_weights_buffer,
786         const float* projection_bias_buffer, const float* output_state_in_buffer,
787         const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
788         const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
789         const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
790         float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer) {
791     NNTRACE_COMP("LSTMCell::LSTMStep");
792 
793     const uint32_t n_batch = input_shape.dimensions[0];
794     const uint32_t n_input = input_shape.dimensions[1];
795     // n_cell and n_output will be the same size when there is no projection.
796     const uint32_t n_cell = input_to_output_weights_shape.dimensions[0];
797     const uint32_t n_output = recurrent_to_output_weights_shape.dimensions[1];
798     const uint32_t n_aux_input = aux_input_buffer == nullptr ? 0 : n_input;
799 
800     // Index the scratch buffer pointers into the global scratch buffer.
801     float* input_gate_scratch = nullptr;
802     float* cell_scratch = nullptr;
803     float* forget_gate_scratch = nullptr;
804     float* output_gate_scratch = nullptr;
805     if (params.use_cifg) {
806         cell_scratch = scratch_buffer_buffer;
807         forget_gate_scratch = cell_scratch + n_cell * n_batch;
808         output_gate_scratch = cell_scratch + 2 * n_cell * n_batch;
809     } else {
810         input_gate_scratch = scratch_buffer_buffer;
811         cell_scratch = input_gate_scratch + n_cell * n_batch;
812         forget_gate_scratch = input_gate_scratch + 2 * n_cell * n_batch;
813         output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
814     }
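    // The scratch buffer is thus partitioned into consecutive n_batch * n_cell
    // slabs: [input | cell | forget | output] gates for a regular LSTM, or
    // [cell | forget | output] when CIFG omits the input gate.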
815 
816     if (!params.use_layer_norm) {
817         // Initialize scratch buffers with bias.
818         if (!params.use_cifg) {
819             tflite::tensor_utils::VectorBatchVectorAssign(input_gate_bias_buffer, n_cell, n_batch,
820                                                           input_gate_scratch);
821         }
822         tflite::tensor_utils::VectorBatchVectorAssign(forget_gate_bias_buffer, n_cell, n_batch,
823                                                       forget_gate_scratch);
824         tflite::tensor_utils::VectorBatchVectorAssign(cell_bias_buffer, n_cell, n_batch,
825                                                       cell_scratch);
826         tflite::tensor_utils::VectorBatchVectorAssign(output_gate_bias_buffer, n_cell, n_batch,
827                                                       output_gate_scratch);
828     } else {
829         // Initialize scratch buffers with zeroes.
830         if (!params.use_cifg) {
831             std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
832         }
833         std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
834         std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
835         std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
836     }
837 
838     // For each batch and cell: compute input_weight * input.
839     if (!params.use_cifg) {
840         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_input_weights_buffer,
841                                                                   n_cell, n_input, input_buffer,
842                                                                   n_batch, input_gate_scratch);
843     }
844     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_forget_weights_buffer,
845                                                               n_cell, n_input, input_buffer,
846                                                               n_batch, forget_gate_scratch);
847     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
848             input_to_cell_weights_buffer, n_cell, n_input, input_buffer, n_batch, cell_scratch);
849     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_output_weights_buffer,
850                                                               n_cell, n_input, input_buffer,
851                                                               n_batch, output_gate_scratch);
852 
853     // If an auxiliary input is available, compute aux_input_weight * aux_input.
854     if (aux_input_buffer != nullptr) {
855         if (!params.use_cifg) {
856             tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
857                     aux_input_to_input_weights_buffer, n_cell, n_aux_input, aux_input_buffer,
858                     n_batch, input_gate_scratch);
859         }
860 
861         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
862                 aux_input_to_forget_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
863                 forget_gate_scratch);
864         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
865                 aux_input_to_cell_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
866                 cell_scratch);
867         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
868                 aux_input_to_output_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
869                 output_gate_scratch);
870     }
871 
872     // For each batch and cell: compute recurrent_weight * output_state.
873     if (!params.use_cifg) {
874         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
875                 recurrent_to_input_weights_buffer, n_cell, n_output, output_state_in_buffer,
876                 n_batch, input_gate_scratch);
877     }
878     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
879             recurrent_to_forget_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
880             forget_gate_scratch);
881     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
882             recurrent_to_cell_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
883             cell_scratch);
884     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
885             recurrent_to_output_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
886             output_gate_scratch);
887 
888     // For each batch and cell: update input gate.
    if (!params.use_cifg) {
        if (params.use_peephole) {
            tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
                    cell_to_input_weights_buffer, n_cell, cell_state_in_buffer, n_batch,
                    input_gate_scratch);
        }
        if (params.use_layer_norm) {
            tflite::tensor_utils::MeanStddevNormalization(input_gate_scratch, input_gate_scratch,
                                                          n_cell, n_batch);
            tflite::tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weights_buffer,
                                                                n_cell, input_gate_scratch, n_batch,
                                                                input_gate_scratch);
            tflite::tensor_utils::VectorBatchVectorAdd(input_gate_bias_buffer, n_cell, n_batch,
                                                       input_gate_scratch);
        }
        tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                                   input_gate_scratch);
    }

    // For each batch and cell: update forget gate.
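    // f_t = sigmoid(forget_gate_scratch), with the same optional peephole
    // (cell_to_forget_weights (*) c_{t-1}) and layer-norm/bias treatment as the input gate.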
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_forget_weights_buffer,
                                                                      n_cell, cell_state_in_buffer,
                                                                      n_batch, forget_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(forget_gate_scratch, forget_gate_scratch,
                                                      n_cell, n_batch);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weights_buffer,
                                                            n_cell, forget_gate_scratch, n_batch,
                                                            forget_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(forget_gate_bias_buffer, n_cell, n_batch,
                                                   forget_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                               forget_gate_scratch);

    // For each batch and cell: update the cell.
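    // c_t = f_t (*) c_{t-1} + i_t (*) g(cell_scratch), where g is the fused activation;
    // under CIFG the input gate is replaced by (1 - f_t).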
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, n_batch);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(cell_layer_norm_weights_buffer, n_cell,
                                                            cell_scratch, n_batch, cell_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(cell_bias_buffer, n_cell, n_batch, cell_scratch);
    }
    tflite::tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_in_buffer,
                                                   n_batch * n_cell, cell_state_out_buffer);
    tflite::tensor_utils::ApplyActivationToVector(
            cell_scratch, n_batch * n_cell, static_cast<TfLiteFusedActivation>(params.activation),
            cell_scratch);
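    // With CIFG the input gate is coupled to the forget gate: Sub1Vector turns f_t into
    // (1 - f_t), which then gates the candidate cell values in place of i_t.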
    if (params.use_cifg) {
        tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                                         forget_gate_scratch);
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    } else {
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    }
    if (params.cell_clip > 0.0) {
        tflite::tensor_utils::CwiseClipping(cell_state_out_buffer, n_batch * n_cell,
                                            params.cell_clip);
    }

    // For each batch and cell: update the output gate.
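    // o_t = sigmoid(output_gate_scratch); note that the peephole term here uses the *new*
    // cell state c_t (cell_state_out_buffer), unlike the input and forget gates, which peek
    // at c_{t-1}.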
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_output_weights_buffer,
                                                                      n_cell, cell_state_out_buffer,
                                                                      n_batch, output_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(output_gate_scratch, output_gate_scratch,
                                                      n_cell, n_batch);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weights_buffer,
                                                            n_cell, output_gate_scratch, n_batch,
                                                            output_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(output_gate_bias_buffer, n_cell, n_batch,
                                                   output_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                               output_gate_scratch);
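    // h_t = o_t (*) g(c_t): apply the cell activation to the new cell state (reusing
    // cell_scratch as temporary storage) and gate it with the output gate.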
    tflite::tensor_utils::ApplyActivationToVector(
            cell_state_out_buffer, n_batch * n_cell,
            static_cast<TfLiteFusedActivation>(params.activation), cell_scratch);
    tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                                   n_batch * n_cell, output_gate_scratch);

    // For each batch: update the projection and output_state.
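    // With projection: output = projection_weights * h_t + projection_bias, optionally clipped
    // to +/-proj_clip; without projection: output = h_t. output_state_out mirrors the output.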
    if (params.use_projection_weight) {
        if (params.use_projection_bias) {
            tflite::tensor_utils::VectorBatchVectorAssign(projection_bias_buffer, n_output, n_batch,
                                                          output_buffer);
        } else {
            std::fill_n(output_buffer, n_batch * n_output, 0.0f);
        }
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                projection_weights_buffer, n_output, n_cell, output_gate_scratch, n_batch,
                output_buffer);
        if (params.proj_clip > 0.0) {
            tflite::tensor_utils::CwiseClipping(output_buffer, n_batch * n_output,
                                                params.proj_clip);
        }
    } else {
        std::copy_n(output_gate_scratch, n_batch * n_output, output_buffer);
    }
    std::copy_n(output_buffer, n_batch * n_output, output_state_out_buffer);
    return true;
}

bool LSTMCell::Eval() {
    switch (input_->type) {
        case OperandType::TENSOR_FLOAT32: {
            LSTMEvalFloat32(params_, GetBuffer<const float>(input_), input_->shape(),
                            GetBuffer<const float>(input_to_input_weights_),
                            GetBuffer<const float>(input_to_forget_weights_),
                            GetBuffer<const float>(input_to_cell_weights_),
                            GetBuffer<const float>(input_to_output_weights_),
                            input_to_output_weights_->shape(),
                            GetBuffer<const float>(recurrent_to_input_weights_),
                            GetBuffer<const float>(recurrent_to_forget_weights_),
                            GetBuffer<const float>(recurrent_to_cell_weights_),
                            GetBuffer<const float>(recurrent_to_output_weights_),
                            recurrent_to_output_weights_->shape(),
                            GetBuffer<const float>(cell_to_input_weights_),
                            GetBuffer<const float>(cell_to_forget_weights_),
                            GetBuffer<const float>(cell_to_output_weights_),
                            /*aux_input_buffer=*/nullptr,
                            /*aux_input_to_input_weights_buffer=*/nullptr,
                            /*aux_input_to_forget_weights_buffer=*/nullptr,
                            /*aux_input_to_cell_weights_buffer=*/nullptr,
                            /*aux_input_to_output_weights_buffer=*/nullptr,
                            GetBuffer<const float>(input_gate_bias_),
                            GetBuffer<const float>(forget_gate_bias_),
                            GetBuffer<const float>(cell_bias_),
                            GetBuffer<const float>(output_gate_bias_),
                            GetBuffer<const float>(projection_weights_),
                            GetBuffer<const float>(projection_bias_),
                            GetBuffer<const float>(output_state_in_),
                            GetBuffer<const float>(cell_state_in_),
                            GetBuffer<const float>(input_layer_norm_weights_),
                            GetBuffer<const float>(forget_layer_norm_weights_),
                            GetBuffer<const float>(cell_layer_norm_weights_),
                            GetBuffer<const float>(output_layer_norm_weights_),
                            GetBuffer<float>(output_state_out_), GetBuffer<float>(cell_state_out_),
                            GetBuffer<float>(output_), GetBuffer<float>(scratch_buffer_));
        } break;
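        // The FLOAT16 path resolves optional operands through GetOptionalBuffer, so omitted
        // tensors are forwarded as nullptr.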
        case OperandType::TENSOR_FLOAT16: {
            LSTMEvalFloat16(params_, GetBuffer<const _Float16>(input_), input_->shape(),
                            GetOptionalBuffer<const _Float16>(input_to_input_weights_),
                            GetBuffer<const _Float16>(input_to_forget_weights_),
                            GetBuffer<const _Float16>(input_to_cell_weights_),
                            GetBuffer<const _Float16>(input_to_output_weights_),
                            input_to_output_weights_->shape(),
                            GetOptionalBuffer<const _Float16>(recurrent_to_input_weights_),
                            GetBuffer<const _Float16>(recurrent_to_forget_weights_),
                            GetBuffer<const _Float16>(recurrent_to_cell_weights_),
                            GetBuffer<const _Float16>(recurrent_to_output_weights_),
                            recurrent_to_output_weights_->shape(),
                            GetOptionalBuffer<const _Float16>(cell_to_input_weights_),
                            GetOptionalBuffer<const _Float16>(cell_to_forget_weights_),
                            GetOptionalBuffer<const _Float16>(cell_to_output_weights_),
                            /*aux_input_buffer=*/nullptr,
                            /*aux_input_to_input_weights_buffer=*/nullptr,
                            /*aux_input_to_forget_weights_buffer=*/nullptr,
                            /*aux_input_to_cell_weights_buffer=*/nullptr,
                            /*aux_input_to_output_weights_buffer=*/nullptr,
                            GetOptionalBuffer<const _Float16>(input_gate_bias_),
                            GetBuffer<const _Float16>(forget_gate_bias_),
                            GetBuffer<const _Float16>(cell_bias_),
                            GetBuffer<const _Float16>(output_gate_bias_),
                            GetOptionalBuffer<const _Float16>(projection_weights_),
                            GetOptionalBuffer<const _Float16>(projection_bias_),
                            GetBuffer<const _Float16>(output_state_in_),
                            GetBuffer<const _Float16>(cell_state_in_),
                            GetOptionalBuffer<const _Float16>(input_layer_norm_weights_),
                            GetOptionalBuffer<const _Float16>(forget_layer_norm_weights_),
                            GetOptionalBuffer<const _Float16>(cell_layer_norm_weights_),
                            GetOptionalBuffer<const _Float16>(output_layer_norm_weights_),
                            GetBuffer<_Float16>(output_state_out_),
                            GetBuffer<_Float16>(cell_state_out_), GetBuffer<_Float16>(output_),
                            GetBuffer<_Float16>(scratch_buffer_));
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
            return false;
        }
    }
    return true;
}

}  // namespace nn
}  // namespace android