/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "LSTM.h"

#include "CpuExecutor.h"
#include "HalInterfaces.h"

namespace android {
namespace nn {

namespace {

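// Reinterprets an operand's raw byte buffer as a typed array; the caller must
// ensure T matches the operand's element type.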
template <typename T>
inline T *GetBuffer(RunTimeOperandInfo* operand) {
  return reinterpret_cast<T*>(operand->buffer);
}

template <typename T>
inline const T *GetBuffer(const RunTimeOperandInfo* operand) {
  return reinterpret_cast<const T*>(operand->buffer);
}

}  // anonymous namespace

LSTMCell::LSTMCell(const Operation& operation,
                   std::vector<RunTimeOperandInfo>& operands) {
  input_ = GetInput(operation, operands, kInputTensor);

  input_to_input_weights_ = GetInput(operation, operands, kInputToInputWeightsTensor);  // optional
  input_to_forget_weights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
  input_to_cell_weights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
  input_to_output_weights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);

  recurrent_to_input_weights_ =
      GetInput(operation, operands, kRecurrentToInputWeightsTensor);  // optional
  recurrent_to_forget_weights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
  recurrent_to_cell_weights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
  recurrent_to_output_weights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);

  cell_to_input_weights_ = GetInput(operation, operands, kCellToInputWeightsTensor);    // optional
  cell_to_forget_weights_ = GetInput(operation, operands, kCellToForgetWeightsTensor);  // optional
  cell_to_output_weights_ = GetInput(operation, operands, kCellToOutputWeightsTensor);  // optional

  input_gate_bias_ = GetInput(operation, operands, kInputGateBiasTensor);
  forget_gate_bias_ = GetInput(operation, operands, kForgetGateBiasTensor);
  cell_bias_ = GetInput(operation, operands, kCellGateBiasTensor);
  output_gate_bias_ = GetInput(operation, operands, kOutputGateBiasTensor);

  projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor);  // optional
  projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor);        // optional

  output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
  cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);

  params_.activation_ = static_cast<TfLiteFusedActivation>(getScalarData<int32_t>(
      *GetInput(operation, operands, kActivationParam)));
  params_.cell_clip_ = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
  params_.proj_clip_ = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));

  output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
  cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
  output_ = GetOutput(operation, operands, kOutputTensor);

  scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
}

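// Checks that every weight and bias tensor has the rank and dimensions
// implied by n_input, n_output and n_cell, and that the optional CIFG,
// peephole and projection tensors follow their all-or-none rules.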
bool LSTMCell::CheckInputTensorDimensions(
    const Operation &operation, std::vector<RunTimeOperandInfo> &operands,
    uint32_t n_input, uint32_t n_output, uint32_t n_cell) {
  LSTMParams params = {
    .activation_ = static_cast<TfLiteFusedActivation>(getScalarData<int32_t>(
        *GetInput(operation, operands, LSTMCell::kActivationParam))),
    .cell_clip_  = getScalarData<float>(*GetInput(operation, operands, LSTMCell::kCellClipParam)),
    .proj_clip_  = getScalarData<float>(*GetInput(operation, operands, LSTMCell::kProjClipParam))
  };

  // Making sure clipping parameters have valid values.
  // == 0 means no clipping
  //  > 0 means clipping
  NN_CHECK(params.cell_clip_ >= 0);
  NN_CHECK(params.proj_clip_ >= 0);

  const RunTimeOperandInfo *input_to_input_weights =
      GetInput(operation, operands, LSTMCell::kInputToInputWeightsTensor);
  if (!IsNullInput(input_to_input_weights)) {
    NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
  }

  const RunTimeOperandInfo *input_to_forget_weights =
      GetInput(operation, operands, LSTMCell::kInputToForgetWeightsTensor);
  NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2);
  NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
  NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);

  const RunTimeOperandInfo *input_to_cell_weights =
      GetInput(operation, operands, LSTMCell::kInputToCellWeightsTensor);
  NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2);
  NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
  NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);

  const RunTimeOperandInfo *recurrent_to_input_weights =
      GetInput(operation, operands, LSTMCell::kRecurrentToInputWeightsTensor);
  if (!IsNullInput(recurrent_to_input_weights)) {
    NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
  }

  const RunTimeOperandInfo *recurrent_to_forget_weights =
      GetInput(operation, operands, LSTMCell::kRecurrentToForgetWeightsTensor);
  NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2);
  NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
  NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);

  const RunTimeOperandInfo *recurrent_to_cell_weights =
      GetInput(operation, operands, LSTMCell::kRecurrentToCellWeightsTensor);
  NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2);
  NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
  NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);

  // We make sure the input-gate's parameters are either both present (regular
  // LSTM) or not at all (CIFG-LSTM).
  const bool cifg_weights_all_or_none =
      (!IsNullInput(input_to_input_weights) &&
       !IsNullInput(recurrent_to_input_weights)) ||
      (IsNullInput(input_to_input_weights) &&
       IsNullInput(recurrent_to_input_weights));
  NN_CHECK(cifg_weights_all_or_none);

  const RunTimeOperandInfo *cell_to_input_weights =
      GetInput(operation, operands, LSTMCell::kCellToInputWeightsTensor);
  if (!IsNullInput(cell_to_input_weights)) {
    NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1);
    NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
  }

  const RunTimeOperandInfo *cell_to_forget_weights =
      GetInput(operation, operands, LSTMCell::kCellToForgetWeightsTensor);
  if (!IsNullInput(cell_to_forget_weights)) {
    NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1);
    NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
  }

  const RunTimeOperandInfo *cell_to_output_weights =
      GetInput(operation, operands, LSTMCell::kCellToOutputWeightsTensor);
  if (!IsNullInput(cell_to_output_weights)) {
    NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1);
    NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
  }

  // Make sure the peephole weights are either all present or all absent.
  const bool use_cifg = IsNullInput(input_to_input_weights);
  const bool peephole_weights_all_or_none =
      ((!IsNullInput(cell_to_input_weights) || use_cifg) &&
       !IsNullInput(cell_to_forget_weights) &&
       !IsNullInput(cell_to_output_weights)) ||
      (IsNullInput(cell_to_input_weights) &&
       IsNullInput(cell_to_forget_weights) &&
       IsNullInput(cell_to_output_weights));
  NN_CHECK(peephole_weights_all_or_none);

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  const RunTimeOperandInfo* input_gate_bias =
      GetInput(operation, operands, LSTMCell::kInputGateBiasTensor);
  if (use_cifg) {
    NN_CHECK(IsNullInput(input_gate_bias));
  } else {
    NN_CHECK_EQ(NumDimensions(input_gate_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
  }

  const RunTimeOperandInfo *forget_gate_bias =
      GetInput(operation, operands, LSTMCell::kForgetGateBiasTensor);
  NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1);
  NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);

  const RunTimeOperandInfo *cell_bias =
      GetInput(operation, operands, LSTMCell::kCellGateBiasTensor);
  NN_CHECK_EQ(NumDimensions(cell_bias), 1);
  NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);

  const RunTimeOperandInfo *output_gate_bias =
      GetInput(operation, operands, LSTMCell::kOutputGateBiasTensor);
  NN_CHECK_EQ(NumDimensions(output_gate_bias), 1);
  NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);

  const RunTimeOperandInfo *projection_weights =
      GetInput(operation, operands, LSTMCell::kProjectionWeightsTensor);
  if (!IsNullInput(projection_weights)) {
    NN_CHECK_EQ(NumDimensions(projection_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
    NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
  }

  const RunTimeOperandInfo *projection_bias =
      GetInput(operation, operands, LSTMCell::kProjectionBiasTensor);
  if (!IsNullInput(projection_bias)) {
    NN_CHECK_EQ(NumDimensions(projection_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
  }

  // Make sure the projection tensors are consistent:
  // 1) If projection weight is not present, then projection bias should not be
  // present.
  // 2) If projection weight is present, then projection bias is optional.
  // TODO: make sure this is correct.
  const bool projection_tensors_consistent =
      (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
  NN_CHECK(projection_tensors_consistent);

  return true;
}

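// Validates operand counts and dimensions, then derives the shapes of the
// four outputs (scratch buffer, output state, cell state, output) from the
// input shapes.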
bool LSTMCell::Prepare(const Operation &operation,
                       std::vector<RunTimeOperandInfo> &operands,
                       Shape *scratchShape,
                       Shape *outputStateShape,
                       Shape *cellStateShape,
                       Shape *outputShape) {
  // Check we have all the inputs and outputs we need.
  NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
           NumInputsWithValues(operation, operands) <= 23);
  NN_CHECK_EQ(NumOutputs(operation), 4);
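  // Of the 23 declared inputs, 8 are optional (the CIFG input weights and
  // bias, the three peephole weights, and the projection weights and bias),
  // leaving at least 15 with values.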

  // Inferring batch size, number of outputs and number of cells from the
  // input tensors.
  const RunTimeOperandInfo *input =
      GetInput(operation, operands, LSTMCell::kInputTensor);
  NN_CHECK(NumDimensions(input) > 1);
  const uint32_t n_batch = SizeOfDimension(input, 0);
  const uint32_t n_input = SizeOfDimension(input, 1);

  const RunTimeOperandInfo *input_to_output_weights =
      GetInput(operation, operands, LSTMCell::kInputToOutputWeightsTensor);
  const uint32_t n_cell = SizeOfDimension(input_to_output_weights, 0);
  NN_CHECK_EQ(NumDimensions(input_to_output_weights), 2);
  NN_CHECK_EQ(SizeOfDimension(input_to_output_weights, 1), n_input);

  const RunTimeOperandInfo *recurrent_to_output_weights =
      GetInput(operation, operands, LSTMCell::kRecurrentToOutputWeightsTensor);
  NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights), 2);
  NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights, 0), n_cell);
  const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights, 1);

  // Check that the input tensor dimensions are consistent with each other.
  if (!CheckInputTensorDimensions(operation, operands, n_input, n_output, n_cell)) {
    return false;
  }

  // Resize the output and output_state tensors.
  const Shape &inputShape = input->shape();

  outputShape->type = inputShape.type;
  outputShape->dimensions = { n_batch, n_output };
  outputShape->offset = inputShape.offset;
  outputShape->scale = inputShape.scale;

  outputStateShape->type = inputShape.type;
  outputStateShape->dimensions = { n_batch, n_output };
  outputStateShape->offset = inputShape.offset;
  outputStateShape->scale = inputShape.scale;

  cellStateShape->type = inputShape.type;
  cellStateShape->dimensions = { n_batch, n_cell };
  cellStateShape->offset = inputShape.offset;
  cellStateShape->scale = inputShape.scale;

  const RunTimeOperandInfo *input_to_input_weights =
      GetInput(operation, operands, LSTMCell::kInputToInputWeightsTensor);
  const bool use_cifg = IsNullInput(input_to_input_weights);
  if (use_cifg) {
    // Reserving space for Cell, Forget, Output gates
    scratchShape->dimensions = { n_batch, n_cell * 3 };
  } else {
    // Reserving space for Input, Cell, Forget, Output gates
    scratchShape->dimensions = { n_batch, n_cell * 4 };
  }
  scratchShape->type = inputShape.type;
  scratchShape->offset = inputShape.offset;
  scratchShape->scale = inputShape.scale;

  return true;
}

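// Runs a single float32 LSTM step: bias-initialized gate pre-activations,
// the cell state update, the output gate, and the optional projection.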
bool LSTMCell::Eval() {
  const uint32_t n_batch = input_->shape().dimensions[0];
  const uint32_t n_input = input_->shape().dimensions[1];
  // n_cell and n_output will be the same size when there is no projection.
  const uint32_t n_cell = input_to_output_weights_->shape().dimensions[0];
  const uint32_t n_output = recurrent_to_output_weights_->shape().dimensions[1];

  // Since we have already checked that the weights are either all present or
  // all absent, checking the existence of just one of them is enough.
  const bool use_cifg = (input_to_input_weights_->lifetime == OperandLifeTime::NO_VALUE);
  const bool use_peephole = (cell_to_output_weights_->lifetime != OperandLifeTime::NO_VALUE);

  // Point the per-gate scratch pointers into the shared scratch buffer.
  float* input_gate_scratch = nullptr;
  float* cell_scratch = nullptr;
  float* forget_gate_scratch = nullptr;
  float* output_gate_scratch = nullptr;
  if (use_cifg) {
    cell_scratch = reinterpret_cast<float*>(scratch_buffer_->buffer);
    forget_gate_scratch = cell_scratch + n_cell * n_batch;
    output_gate_scratch = cell_scratch + 2 * n_cell * n_batch;
  } else {
    input_gate_scratch = reinterpret_cast<float*>(scratch_buffer_->buffer);
    cell_scratch = input_gate_scratch + n_cell * n_batch;
    forget_gate_scratch = input_gate_scratch + 2 * n_cell * n_batch;
    output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
  }
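  // Resulting scratch layout (each gate slab is n_batch * n_cell floats):
  //   CIFG:     [ cell | forget | output ]
  //   non-CIFG: [ input | cell | forget | output ]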

  // Initialize scratch buffers with bias.
  if (!use_cifg) {
    tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(input_gate_bias_),
                                                  n_cell, n_batch, input_gate_scratch);
  }
  tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(forget_gate_bias_),
                                                n_cell, n_batch, forget_gate_scratch);
  tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(cell_bias_),
                                                n_cell, n_batch, cell_scratch);
  tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(output_gate_bias_),
                                                n_cell, n_batch, output_gate_scratch);

  // For each batch and cell: compute input_weight * input.
  if (!use_cifg) {
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(input_to_input_weights_), n_cell, n_input,
        GetBuffer<float>(input_), n_batch, input_gate_scratch, /*result_stride*/1);
  }
  tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      GetBuffer<float>(input_to_forget_weights_), n_cell, n_input,
      GetBuffer<float>(input_), n_batch, forget_gate_scratch, /*result_stride*/1);
  tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      GetBuffer<float>(input_to_cell_weights_), n_cell, n_input,
      GetBuffer<float>(input_), n_batch, cell_scratch, /*result_stride*/1);
  tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      GetBuffer<float>(input_to_output_weights_), n_cell, n_input,
      GetBuffer<float>(input_), n_batch, output_gate_scratch, /*result_stride*/1);

  // For each batch and cell: compute recurrent_weight * output_state.
  if (!use_cifg) {
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(recurrent_to_input_weights_), n_cell, n_output,
        GetBuffer<float>(output_state_in_), n_batch, input_gate_scratch, /*result_stride*/1);
  }
  tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      GetBuffer<float>(recurrent_to_forget_weights_), n_cell, n_output,
      GetBuffer<float>(output_state_in_), n_batch, forget_gate_scratch, /*result_stride*/1);
  tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      GetBuffer<float>(recurrent_to_cell_weights_), n_cell, n_output,
      GetBuffer<float>(output_state_in_), n_batch, cell_scratch, /*result_stride*/1);
  tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      GetBuffer<float>(recurrent_to_output_weights_), n_cell, n_output,
      GetBuffer<float>(output_state_in_), n_batch, output_gate_scratch, /*result_stride*/1);

  // For each batch and cell: update input gate.
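  // i_t = sigmoid(W_i x_t + R_i h_{t-1} + w_i * c_{t-1} + b_i); the
  // elementwise peephole term w_i * c_{t-1} is added only when present.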
  if (!use_cifg) {
    if (use_peephole) {
      tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
          GetBuffer<float>(cell_to_input_weights_), n_cell,
          GetBuffer<float>(cell_state_in_), n_batch, input_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch,
                                               n_cell * n_batch,
                                               input_gate_scratch);
  }

  // For each batch and cell: update forget gate.
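  // f_t = sigmoid(W_f x_t + R_f h_{t-1} + w_f * c_{t-1} + b_f), with the
  // elementwise peephole term added only when present.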
  if (use_peephole) {
    tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        GetBuffer<float>(cell_to_forget_weights_), n_cell,
        GetBuffer<float>(cell_state_in_), n_batch, forget_gate_scratch);
  }
  tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch,
                                             n_cell * n_batch,
                                             forget_gate_scratch);

  // For each batch and cell: update the cell.
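  // Computes c_t = f_t * c_{t-1} + i_t * g_t (elementwise), where
  // g_t = activation(cell_scratch); under CIFG, i_t = 1 - f_t.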
  tflite::tensor_utils::VectorVectorCwiseProduct(
      forget_gate_scratch, GetBuffer<float>(cell_state_in_), n_batch * n_cell,
      GetBuffer<float>(cell_state_out_));
  tflite::tensor_utils::ApplyActivationToVector(
      cell_scratch, n_batch * n_cell,
      params_.activation_, cell_scratch);
  if (use_cifg) {
    tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                                     forget_gate_scratch);
    tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_scratch, forget_gate_scratch, n_batch * n_cell,
        GetBuffer<float>(cell_state_out_));
  } else {
    tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_scratch, input_gate_scratch, n_batch * n_cell,
        GetBuffer<float>(cell_state_out_));
  }
  if (params_.cell_clip_ > 0.0) {
    tflite::tensor_utils::ClipVector(GetBuffer<float>(cell_state_out_), n_batch * n_cell,
                                     params_.cell_clip_, GetBuffer<float>(cell_state_out_));
  }

  // For each batch and cell: update the output gate.
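  // o_t = sigmoid(W_o x_t + R_o h_{t-1} + w_o * c_t + b_o), then
  // output_gate_scratch = o_t * activation(c_t) (elementwise).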
  if (use_peephole) {
    tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        GetBuffer<float>(cell_to_output_weights_), n_cell,
        GetBuffer<float>(cell_state_out_), n_batch, output_gate_scratch);
  }
  tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                             output_gate_scratch);
  tflite::tensor_utils::ApplyActivationToVector(GetBuffer<float>(cell_state_out_),
                                                n_batch * n_cell,
                                                params_.activation_,
                                                cell_scratch);
  tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch,
                                                 cell_scratch, n_batch * n_cell,
                                                 output_gate_scratch);

  // For each batch: update the projection and output_state.
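  // With projection: output = clip(W_proj * output_gate_scratch + b_proj,
  // proj_clip); without it, output is output_gate_scratch itself. Either way,
  // output_state_out then receives a copy of output.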
  const bool use_projection_weight =
          (projection_weights_->lifetime != OperandLifeTime::NO_VALUE);
  const bool use_projection_bias = (projection_bias_->lifetime != OperandLifeTime::NO_VALUE);
  if (use_projection_weight) {
    if (use_projection_bias) {
      tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(projection_bias_), n_output,
                                                    n_batch, GetBuffer<float>(output_));
    } else {
      tflite::tensor_utils::ZeroVector(GetBuffer<float>(output_), n_batch * n_output);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(projection_weights_), n_output, n_cell,
        output_gate_scratch, n_batch, GetBuffer<float>(output_),
        /*result_stride*/1);
    if (params_.proj_clip_ > 0.0) {
      tflite::tensor_utils::ClipVector(GetBuffer<float>(output_), n_batch * n_output,
                               params_.proj_clip_, GetBuffer<float>(output_));
    }
  } else {
    tflite::tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
                             GetBuffer<float>(output_));
  }
  tflite::tensor_utils::CopyVector(GetBuffer<float>(output_), n_batch * n_output,
                           GetBuffer<float>(output_state_out_));

  return true;
}

}  // namespace nn
}  // namespace android