/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "LSTM.h"

#include "CpuExecutor.h"
#include "HalInterfaces.h"

namespace android {
namespace nn {

namespace {

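// Reinterpret an operand's raw byte buffer as a typed array. The caller is
// responsible for making sure T matches the operand's actual data type.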
template <typename T>
inline T* GetBuffer(RunTimeOperandInfo* operand) {
    return reinterpret_cast<T*>(operand->buffer);
}

template <typename T>
inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
    return reinterpret_cast<const T*>(operand->buffer);
}

}  // anonymous namespace

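// The constructor only caches pointers into the operand vector; all shape and
// consistency validation is deferred to Prepare(). Inputs marked "optional"
// below may be null, depending on the LSTM variant (CIFG, peephole,
// projection).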
LSTMCell::LSTMCell(const Operation& operation,
                   std::vector<RunTimeOperandInfo>& operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    input_to_input_weights_ = GetInput(operation, operands, kInputToInputWeightsTensor);  // optional
    input_to_forget_weights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
    input_to_cell_weights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
    input_to_output_weights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);

    recurrent_to_input_weights_ =
        GetInput(operation, operands, kRecurrentToInputWeightsTensor);  // optional
    recurrent_to_forget_weights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    recurrent_to_cell_weights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    recurrent_to_output_weights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);

    cell_to_input_weights_ = GetInput(operation, operands, kCellToInputWeightsTensor);    // optional
    cell_to_forget_weights_ = GetInput(operation, operands, kCellToForgetWeightsTensor);  // optional
    cell_to_output_weights_ = GetInput(operation, operands, kCellToOutputWeightsTensor);  // optional

    input_gate_bias_ = GetInput(operation, operands, kInputGateBiasTensor);
    forget_gate_bias_ = GetInput(operation, operands, kForgetGateBiasTensor);
    cell_bias_ = GetInput(operation, operands, kCellGateBiasTensor);
    output_gate_bias_ = GetInput(operation, operands, kOutputGateBiasTensor);

    projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor);  // optional
    projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor);        // optional

    output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
    cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);

    params_.activation_ = static_cast<TfLiteFusedActivation>(
        getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam)));
    params_.cell_clip_ = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
    params_.proj_clip_ = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));

    output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
    cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
    output_ = GetOutput(operation, operands, kOutputTensor);

    scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
}

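// Checks that every supplied weight and bias tensor has the rank and
// dimensions implied by n_input, n_cell and n_output, and that the optional
// tensors (CIFG, peephole, projection) are present or absent in consistent
// groups. Returns true when all checks pass (the NN_CHECK* macros bail out
// of the function on failure).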
bool LSTMCell::CheckInputTensorDimensions(
    const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
    uint32_t n_input, uint32_t n_output, uint32_t n_cell) {
    LSTMParams params = {
        .activation_ = static_cast<TfLiteFusedActivation>(getScalarData<int32_t>(
            *GetInput(operation, operands, LSTMCell::kActivationParam))),
        .cell_clip_ = getScalarData<float>(*GetInput(operation, operands, LSTMCell::kCellClipParam)),
        .proj_clip_ = getScalarData<float>(*GetInput(operation, operands, LSTMCell::kProjClipParam)),
    };

    // Make sure the clipping parameters have valid values:
    // == 0 means no clipping, > 0 means clipping is applied.
    NN_CHECK(params.cell_clip_ >= 0);
    NN_CHECK(params.proj_clip_ >= 0);
    const RunTimeOperandInfo* input_to_input_weights =
        GetInput(operation, operands, LSTMCell::kInputToInputWeightsTensor);
    if (!IsNullInput(input_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
    }

    const RunTimeOperandInfo* input_to_forget_weights =
        GetInput(operation, operands, LSTMCell::kInputToForgetWeightsTensor);
    NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);

    const RunTimeOperandInfo* input_to_cell_weights =
        GetInput(operation, operands, LSTMCell::kInputToCellWeightsTensor);
    NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);

    const RunTimeOperandInfo* recurrent_to_input_weights =
        GetInput(operation, operands, LSTMCell::kRecurrentToInputWeightsTensor);
    if (!IsNullInput(recurrent_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
    }

    const RunTimeOperandInfo* recurrent_to_forget_weights =
        GetInput(operation, operands, LSTMCell::kRecurrentToForgetWeightsTensor);
    NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);

    const RunTimeOperandInfo* recurrent_to_cell_weights =
        GetInput(operation, operands, LSTMCell::kRecurrentToCellWeightsTensor);
    NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);

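    // A CIFG ("coupled input and forget gate") LSTM derives the input gate as
    // 1 - forget_gate instead of computing it, so both input-gate weight
    // tensors are omitted together.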
    // We make sure the input-gate's parameters are either both present
    // (regular LSTM) or both absent (CIFG-LSTM).
    const bool cifg_weights_all_or_none =
        (!IsNullInput(input_to_input_weights) && !IsNullInput(recurrent_to_input_weights)) ||
        (IsNullInput(input_to_input_weights) && IsNullInput(recurrent_to_input_weights));
    NN_CHECK(cifg_weights_all_or_none);

    const RunTimeOperandInfo* cell_to_input_weights =
        GetInput(operation, operands, LSTMCell::kCellToInputWeightsTensor);
    if (!IsNullInput(cell_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
    }

    const RunTimeOperandInfo* cell_to_forget_weights =
        GetInput(operation, operands, LSTMCell::kCellToForgetWeightsTensor);
    if (!IsNullInput(cell_to_forget_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
    }

    const RunTimeOperandInfo* cell_to_output_weights =
        GetInput(operation, operands, LSTMCell::kCellToOutputWeightsTensor);
    if (!IsNullInput(cell_to_output_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
    }

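    // Peephole connections add a diagonal (per-cell) weight from the cell
    // state into each gate. They must be supplied for every gate or for none;
    // under CIFG the input-gate peephole is naturally absent.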
    // Make sure the peephole weights are all present or all absent.
    const bool use_cifg = IsNullInput(input_to_input_weights);
    const bool peephole_weights_all_or_none =
        ((!IsNullInput(cell_to_input_weights) || use_cifg) &&
         !IsNullInput(cell_to_forget_weights) && !IsNullInput(cell_to_output_weights)) ||
        (IsNullInput(cell_to_input_weights) && IsNullInput(cell_to_forget_weights) &&
         IsNullInput(cell_to_output_weights));
    NN_CHECK(peephole_weights_all_or_none);

    // Make sure the input gate bias is present only when this is not a CIFG-LSTM.
    const RunTimeOperandInfo* input_gate_bias =
        GetInput(operation, operands, LSTMCell::kInputGateBiasTensor);
    if (use_cifg) {
        NN_CHECK(IsNullInput(input_gate_bias));
    } else {
        NN_CHECK_EQ(NumDimensions(input_gate_bias), 1);
        NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
    }

    const RunTimeOperandInfo* forget_gate_bias =
        GetInput(operation, operands, LSTMCell::kForgetGateBiasTensor);
    NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);

    const RunTimeOperandInfo* cell_bias =
        GetInput(operation, operands, LSTMCell::kCellGateBiasTensor);
    NN_CHECK_EQ(NumDimensions(cell_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);

    const RunTimeOperandInfo* output_gate_bias =
        GetInput(operation, operands, LSTMCell::kOutputGateBiasTensor);
    NN_CHECK_EQ(NumDimensions(output_gate_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);

    const RunTimeOperandInfo* projection_weights =
        GetInput(operation, operands, LSTMCell::kProjectionWeightsTensor);
    if (!IsNullInput(projection_weights)) {
        NN_CHECK_EQ(NumDimensions(projection_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
    }

    const RunTimeOperandInfo* projection_bias =
        GetInput(operation, operands, LSTMCell::kProjectionBiasTensor);
    if (!IsNullInput(projection_bias)) {
        NN_CHECK_EQ(NumDimensions(projection_bias), 1);
        NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
    }

    // Make sure the projection tensors are consistent:
    // 1) If the projection weights are absent, the projection bias must be
    //    absent as well.
    // 2) If the projection weights are present, the projection bias is optional.
    // TODO: make sure this is correct.
    const bool projection_tensors_consistent =
        (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
    NN_CHECK(projection_tensors_consistent);

    return true;
}

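// Validates the operand counts, infers n_batch/n_input/n_cell/n_output from
// the mandatory tensors, runs the full dimension check, and fills in the
// shapes of the four outputs (scratch buffer, output state, cell state and
// output) so the runtime can allocate them before Eval() runs.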
bool LSTMCell::Prepare(const Operation& operation,
                       std::vector<RunTimeOperandInfo>& operands,
                       Shape* scratchShape,
                       Shape* outputStateShape,
                       Shape* cellStateShape,
                       Shape* outputShape) {
    // Check that we have all the inputs and outputs we need.
    NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
             NumInputsWithValues(operation, operands) <= 23);
    NN_CHECK_EQ(NumOutputs(operation), 4);

    // Infer the batch size, number of outputs and number of cells from the
    // input tensors.
    const RunTimeOperandInfo* input = GetInput(operation, operands, LSTMCell::kInputTensor);
    NN_CHECK(NumDimensions(input) > 1);
    const uint32_t n_batch = SizeOfDimension(input, 0);
    const uint32_t n_input = SizeOfDimension(input, 1);

    const RunTimeOperandInfo* input_to_output_weights =
        GetInput(operation, operands, LSTMCell::kInputToOutputWeightsTensor);
    const uint32_t n_cell = SizeOfDimension(input_to_output_weights, 0);
    NN_CHECK_EQ(NumDimensions(input_to_output_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_output_weights, 1), n_input);

    const RunTimeOperandInfo* recurrent_to_output_weights =
        GetInput(operation, operands, LSTMCell::kRecurrentToOutputWeightsTensor);
    NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights, 0), n_cell);
    const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights, 1);

    // Check that the input tensor dimensions are consistent with each other.
    if (!CheckInputTensorDimensions(operation, operands, n_input, n_output, n_cell)) {
        return false;
    }

    // Resize the output, output_state and cell_state tensors.
    const Shape& inputShape = input->shape();

    outputShape->type = inputShape.type;
    outputShape->dimensions = {n_batch, n_output};
    outputShape->offset = inputShape.offset;
    outputShape->scale = inputShape.scale;

    outputStateShape->type = inputShape.type;
    outputStateShape->dimensions = {n_batch, n_output};
    outputStateShape->offset = inputShape.offset;
    outputStateShape->scale = inputShape.scale;

    cellStateShape->type = inputShape.type;
    cellStateShape->dimensions = {n_batch, n_cell};
    cellStateShape->offset = inputShape.offset;
    cellStateShape->scale = inputShape.scale;

    const RunTimeOperandInfo* input_to_input_weights =
        GetInput(operation, operands, LSTMCell::kInputToInputWeightsTensor);
    const bool use_cifg = IsNullInput(input_to_input_weights);
    if (use_cifg) {
        // Reserve space for the Cell, Forget and Output gates.
        scratchShape->dimensions = {n_batch, n_cell * 3};
    } else {
        // Reserve space for the Input, Cell, Forget and Output gates.
        scratchShape->dimensions = {n_batch, n_cell * 4};
    }
    scratchShape->type = inputShape.type;
    scratchShape->offset = inputShape.offset;
    scratchShape->scale = inputShape.scale;

    return true;
}

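// Runs one time step of the float32 LSTM cell over the whole batch: reads
// input_, output_state_in_ and cell_state_in_, and writes cell_state_out_,
// output_state_out_ and output_ (using scratch_buffer_ for the gate
// activations).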
bool LSTMCell::Eval() {
    const uint32_t n_batch = input_->shape().dimensions[0];
    const uint32_t n_input = input_->shape().dimensions[1];
    // n_cell and n_output will be the same size when there is no projection.
    const uint32_t n_cell = input_to_output_weights_->shape().dimensions[0];
    const uint32_t n_output = recurrent_to_output_weights_->shape().dimensions[1];

    // Since we have already checked that the weights are all there or none, we
    // can check the existence of only one to get the condition.
    const bool use_cifg = (input_to_input_weights_->lifetime == OperandLifeTime::NO_VALUE);
    const bool use_peephole = (cell_to_output_weights_->lifetime != OperandLifeTime::NO_VALUE);

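    // The code below evaluates the standard LSTM cell with optional CIFG,
    // peephole and projection variants (in the notation of e.g. Sak et al.,
    // 2014, with "(.)" denoting element-wise multiplication and g the fused
    // activation):
    //   i_t = sigmoid(W_xi x_t + W_hi h_{t-1} + w_ci (.) c_{t-1} + b_i)
    //   f_t = sigmoid(W_xf x_t + W_hf h_{t-1} + w_cf (.) c_{t-1} + b_f)
    //   c_t = clip(f_t (.) c_{t-1} + i_t (.) g(W_xc x_t + W_hc h_{t-1} + b_c))
    //   o_t = sigmoid(W_xo x_t + W_ho h_{t-1} + w_co (.) c_t + b_o)
    //   m_t = o_t (.) g(c_t)
    //   h_t = clip(W_proj m_t + b_proj)   (identity when there is no projection)
    // The peephole terms w_c* apply only when peephole weights are given, and
    // under CIFG the input gate is replaced by 1 - f_t.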
    // Point the per-gate scratch pointers into the shared scratch buffer.
    float* input_gate_scratch = nullptr;
    float* cell_scratch = nullptr;
    float* forget_gate_scratch = nullptr;
    float* output_gate_scratch = nullptr;
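    // The scratch operand is laid out as consecutive n_batch * n_cell planes,
    // one per gate; under CIFG the input-gate plane is omitted, so only three
    // planes were reserved by Prepare().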
    if (use_cifg) {
        cell_scratch = reinterpret_cast<float*>(scratch_buffer_->buffer);
        forget_gate_scratch = cell_scratch + n_cell * n_batch;
        output_gate_scratch = cell_scratch + 2 * n_cell * n_batch;
    } else {
        input_gate_scratch = reinterpret_cast<float*>(scratch_buffer_->buffer);
        cell_scratch = input_gate_scratch + n_cell * n_batch;
        forget_gate_scratch = input_gate_scratch + 2 * n_cell * n_batch;
        output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
    }

    // Initialize scratch buffers with bias.
    if (!use_cifg) {
        tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(input_gate_bias_), n_cell,
                                                      n_batch, input_gate_scratch);
    }
    tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(forget_gate_bias_), n_cell,
                                                  n_batch, forget_gate_scratch);
    tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(cell_bias_), n_cell, n_batch,
                                                  cell_scratch);
    tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(output_gate_bias_), n_cell,
                                                  n_batch, output_gate_scratch);

    // For each batch and cell: compute input_weight * input.
    if (!use_cifg) {
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            GetBuffer<float>(input_to_input_weights_), n_cell, n_input, GetBuffer<float>(input_),
            n_batch, input_gate_scratch, /*result_stride=*/1);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(input_to_forget_weights_), n_cell, n_input, GetBuffer<float>(input_),
        n_batch, forget_gate_scratch, /*result_stride=*/1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(input_to_cell_weights_), n_cell, n_input, GetBuffer<float>(input_),
        n_batch, cell_scratch, /*result_stride=*/1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(input_to_output_weights_), n_cell, n_input, GetBuffer<float>(input_),
        n_batch, output_gate_scratch, /*result_stride=*/1);

    // For each batch and cell: compute recurrent_weight * output_state.
    if (!use_cifg) {
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            GetBuffer<float>(recurrent_to_input_weights_), n_cell, n_output,
            GetBuffer<float>(output_state_in_), n_batch, input_gate_scratch, /*result_stride=*/1);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(recurrent_to_forget_weights_), n_cell, n_output,
        GetBuffer<float>(output_state_in_), n_batch, forget_gate_scratch, /*result_stride=*/1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(recurrent_to_cell_weights_), n_cell, n_output,
        GetBuffer<float>(output_state_in_), n_batch, cell_scratch, /*result_stride=*/1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(recurrent_to_output_weights_), n_cell, n_output,
        GetBuffer<float>(output_state_in_), n_batch, output_gate_scratch, /*result_stride=*/1);

    // For each batch and cell: update input gate.
    if (!use_cifg) {
        if (use_peephole) {
            tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
                GetBuffer<float>(cell_to_input_weights_), n_cell,
                GetBuffer<float>(cell_state_in_), n_batch, input_gate_scratch);
        }
        tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                                   input_gate_scratch);
    }

    // For each batch and cell: update forget gate.
    if (use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
            GetBuffer<float>(cell_to_forget_weights_), n_cell, GetBuffer<float>(cell_state_in_),
            n_batch, forget_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                               forget_gate_scratch);

    // For each batch and cell: update the cell.
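    // c_t = f_t (.) c_{t-1} + i_t (.) g(cell_scratch), where cell_scratch
    // holds the accumulated cell pre-activation; under CIFG the input gate is
    // (1 - f_t). The result is then optionally clipped to +/-cell_clip.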
    tflite::tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
                                                   GetBuffer<float>(cell_state_in_),
                                                   n_batch * n_cell,
                                                   GetBuffer<float>(cell_state_out_));
    tflite::tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
                                                  params_.activation_, cell_scratch);
    if (use_cifg) {
        tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                                         forget_gate_scratch);
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
            cell_scratch, forget_gate_scratch, n_batch * n_cell,
            GetBuffer<float>(cell_state_out_));
    } else {
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
            cell_scratch, input_gate_scratch, n_batch * n_cell,
            GetBuffer<float>(cell_state_out_));
    }
    if (params_.cell_clip_ > 0.0) {
        tflite::tensor_utils::ClipVector(GetBuffer<float>(cell_state_out_), n_batch * n_cell,
                                         params_.cell_clip_, GetBuffer<float>(cell_state_out_));
    }

    // For each batch and cell: update the output gate.
    if (use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
            GetBuffer<float>(cell_to_output_weights_), n_cell, GetBuffer<float>(cell_state_out_),
            n_batch, output_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                               output_gate_scratch);
    tflite::tensor_utils::ApplyActivationToVector(GetBuffer<float>(cell_state_out_),
                                                  n_batch * n_cell, params_.activation_,
                                                  cell_scratch);
    tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                                   n_batch * n_cell, output_gate_scratch);

    // For each batch: update the projection and output_state.
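    // With a projection layer, h_t = clip(W_proj m_t + b_proj) (the bias is
    // optional); otherwise h_t is m_t directly, which requires
    // n_output == n_cell.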
    const bool use_projection_weight =
        (projection_weights_->lifetime != OperandLifeTime::NO_VALUE);
    const bool use_projection_bias = (projection_bias_->lifetime != OperandLifeTime::NO_VALUE);
    if (use_projection_weight) {
        if (use_projection_bias) {
            tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(projection_bias_),
                                                          n_output, n_batch,
                                                          GetBuffer<float>(output_));
        } else {
            tflite::tensor_utils::ZeroVector(GetBuffer<float>(output_), n_batch * n_output);
        }
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            GetBuffer<float>(projection_weights_), n_output, n_cell, output_gate_scratch,
            n_batch, GetBuffer<float>(output_), /*result_stride=*/1);
        if (params_.proj_clip_ > 0.0) {
            tflite::tensor_utils::ClipVector(GetBuffer<float>(output_), n_batch * n_output,
                                             params_.proj_clip_, GetBuffer<float>(output_));
        }
    } else {
        tflite::tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
                                         GetBuffer<float>(output_));
    }
    tflite::tensor_utils::CopyVector(GetBuffer<float>(output_), n_batch * n_output,
                                     GetBuffer<float>(output_state_out_));

    return true;
}

}  // namespace nn
}  // namespace android