1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include "LSTM.h"
20
21 #pragma clang diagnostic push
22 #pragma clang diagnostic ignored "-Wunused-parameter"
23 #include <tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h>
24 #pragma clang diagnostic pop
25
26 #include <tensorflow/lite/kernels/internal/tensor_utils.h>
27
28 #include <vector>
29
30 #include "CpuExecutor.h"
31 #include "CpuOperationUtils.h"
32 #include "LegacyUtils.h"
33 #include "OperationsExecutionUtils.h"
34 #include "Tracing.h"
35 #include "nnapi/Types.h"
36
37 namespace android {
38 namespace nn {
39
40 namespace {
41
42 template <typename T>
GetBuffer(RunTimeOperandInfo * operand)43 inline T* GetBuffer(RunTimeOperandInfo* operand) {
44 return reinterpret_cast<T*>(operand->buffer);
45 }
46
47 template <typename T>
GetBuffer(const RunTimeOperandInfo * operand)48 inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
49 return reinterpret_cast<const T*>(operand->buffer);
50 }
51
52 template <typename T>
GetOptionalBuffer(const RunTimeOperandInfo * operand)53 inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
54 return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
55 }
56
57 } // anonymous namespace
58
// Caches pointers to the operation's input and output operands and decodes the
// scalar parameters (activation function, cell clip, projection clip).
// No shape validation happens here; that is the job of Prepare().
LSTMCell::LSTMCell(const Operation& operation, RunTimeOperandInfo* operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    input_to_input_weights_ =
            GetInput(operation, operands, kInputToInputWeightsTensor);  // optional
    input_to_forget_weights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
    input_to_cell_weights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
    input_to_output_weights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);

    recurrent_to_input_weights_ =
            GetInput(operation, operands, kRecurrentToInputWeightsTensor);  // optional
    recurrent_to_forget_weights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    recurrent_to_cell_weights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    recurrent_to_output_weights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);

    cell_to_input_weights_ = GetInput(operation, operands, kCellToInputWeightsTensor);  // optional
    cell_to_forget_weights_ =
            GetInput(operation, operands, kCellToForgetWeightsTensor);  // optional
    cell_to_output_weights_ =
            GetInput(operation, operands, kCellToOutputWeightsTensor);  // optional

    input_gate_bias_ = GetInput(operation, operands, kInputGateBiasTensor);
    forget_gate_bias_ = GetInput(operation, operands, kForgetGateBiasTensor);
    cell_bias_ = GetInput(operation, operands, kCellGateBiasTensor);
    output_gate_bias_ = GetInput(operation, operands, kOutputGateBiasTensor);

    projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor);  // optional
    projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor);        // optional

    output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
    cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);

    // If the activation scalar is omitted, default to "no activation".
    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
    params_.activation = static_cast<ActivationFn>(getScalarDataWithDefault<int32_t>(
            activationOperand, TfLiteFusedActivation::kTfLiteActNone));

    // Clip thresholds are stored as float32 internally; for FLOAT16 models the
    // scalar operands hold _Float16 values that are widened here.
    const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam);
    const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam);
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        params_.cell_clip = getScalarDataWithDefault<float>(cellClipOperand, 0.0f);
        params_.proj_clip = getScalarDataWithDefault<float>(projClipOperand, 0.0f);
    } else {
        params_.cell_clip =
                static_cast<float>(getScalarDataWithDefault<_Float16>(cellClipOperand, 0.0f));
        params_.proj_clip =
                static_cast<float>(getScalarDataWithDefault<_Float16>(projClipOperand, 0.0f));
    }

    // We check the version of LSTM by checking the number of the inputs to the
    // op. For LSTM version 1.0 there were 23 inputs and for 1.2 there are 27.
    if (operation.inputs.size() == 27) {
        input_layer_norm_weights_ =
                GetInput(operation, operands, kInputLayerNormWeightsTensor);  // optional
        forget_layer_norm_weights_ =
                GetInput(operation, operands, kForgetLayerNormWeightsTensor);  // optional
        cell_layer_norm_weights_ =
                GetInput(operation, operands, kCellLayerNormWeightsTensor);  // optional
        output_layer_norm_weights_ =
                GetInput(operation, operands, kOutputLayerNormWeightsTensor);  // optional
    } else {
        // For LSTM from HAL v1.0 assign operands with no values
        // NOTE(review): shared mutable function-local static used as a "no
        // value" sentinel; every construction rewrites .lifetime with the same
        // value, which is assumed benign under concurrency — confirm no other
        // field of the sentinel is ever written.
        static RunTimeOperandInfo no_value;
        no_value.lifetime = Operand::LifeTime::NO_VALUE;

        input_layer_norm_weights_ = &no_value;
        forget_layer_norm_weights_ = &no_value;
        cell_layer_norm_weights_ = &no_value;
        output_layer_norm_weights_ = &no_value;
    }

    output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
    cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
    output_ = GetOutput(operation, operands, kOutputTensor);

    scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
}
135
// static
// Validates all weight/bias tensor dimensions against the inferred sizes
// (n_input, n_output, n_cell) and derives the LSTM variant flags
// (use_cifg, use_peephole, use_layer_norm, use_projection_*) into |params|.
// Returns false on the first inconsistency found (the NN_CHECK*/NN_RET_CHECK
// macros return from this function on failure).
bool LSTMCell::CheckInputTensorDimensions(
        const RunTimeOperandInfo* /*input_*/, const RunTimeOperandInfo* input_to_input_weights,
        const RunTimeOperandInfo* input_to_forget_weights,
        const RunTimeOperandInfo* input_to_cell_weights,
        const RunTimeOperandInfo* /*input_to_output_weights*/,
        const RunTimeOperandInfo* recurrent_to_input_weights,
        const RunTimeOperandInfo* recurrent_to_forget_weights,
        const RunTimeOperandInfo* recurrent_to_cell_weights,
        const RunTimeOperandInfo* /*recurrent_to_output_weights*/,
        const RunTimeOperandInfo* cell_to_input_weights,
        const RunTimeOperandInfo* cell_to_forget_weights,
        const RunTimeOperandInfo* cell_to_output_weights, const RunTimeOperandInfo* input_gate_bias,
        const RunTimeOperandInfo* forget_gate_bias, const RunTimeOperandInfo* cell_bias,
        const RunTimeOperandInfo* output_gate_bias, const RunTimeOperandInfo* projection_weights,
        const RunTimeOperandInfo* projection_bias,
        const RunTimeOperandInfo* input_layer_norm_weights,
        const RunTimeOperandInfo* forget_layer_norm_weights,
        const RunTimeOperandInfo* cell_layer_norm_weights,
        const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input, uint32_t n_output,
        uint32_t n_cell, LSTMParams* params) {
    // Making sure clipping parameters have valid values.
    // == 0 means no clipping
    //  > 0 means clipping
    NN_CHECK(params->cell_clip >= 0);
    NN_CHECK(params->proj_clip >= 0);

    // Input-to-gate weights must be [n_cell, n_input]; the input-gate weights
    // are optional (omitted for CIFG).
    if (!IsNullInput(input_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2u);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
    }

    NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2u);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);

    NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2u);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);

    // Recurrent-to-gate weights must be [n_cell, n_output].
    if (!IsNullInput(recurrent_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2u);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
    }

    NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2u);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);

    NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2u);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);

    // We make sure the input-gate's parameters are either both present (regular
    // LSTM) or not at all (CIFG-LSTM).
    const bool cifg_weights_all_or_none =
            (!IsNullInput(input_to_input_weights) && !IsNullInput(recurrent_to_input_weights)) ||
            (IsNullInput(input_to_input_weights) && IsNullInput(recurrent_to_input_weights));
    NN_CHECK(cifg_weights_all_or_none);

    // Peephole (cell-to-gate) weights are 1-D vectors of length n_cell.
    if (!IsNullInput(cell_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1u);
        NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
    }

    if (!IsNullInput(cell_to_forget_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1u);
        NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
    }

    if (!IsNullInput(cell_to_output_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1u);
        NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
    }

    // Making sure the peephole weights are there all or none.
    // (cell_to_input is exempt when CIFG is used, since the input gate itself
    // is absent.)
    params->use_cifg = IsNullInput(input_to_input_weights);
    const bool peephole_weights_all_or_none =
            ((!IsNullInput(cell_to_input_weights) || params->use_cifg) &&
             !IsNullInput(cell_to_forget_weights) && !IsNullInput(cell_to_output_weights)) ||
            (IsNullInput(cell_to_input_weights) && IsNullInput(cell_to_forget_weights) &&
             IsNullInput(cell_to_output_weights));
    NN_CHECK(peephole_weights_all_or_none);

    // Since we have already checked that weights are all there or none, we can
    // check the existence of only one to the get the condition.
    params->use_peephole = !IsNullInput(cell_to_output_weights);
    // Checking output instead of input layer norm weights because the input
    // ones can be omitted in case CIFG LSTM is used.
    params->use_layer_norm = !IsNullInput(output_layer_norm_weights);

    params->use_projection_weight = (projection_weights->lifetime != Operand::LifeTime::NO_VALUE);
    params->use_projection_bias = (projection_bias->lifetime != Operand::LifeTime::NO_VALUE);

    // Make sure the input gate bias is present only when not a CIFG-LSTM.
    if (params->use_cifg) {
        NN_CHECK(IsNullInput(input_gate_bias));
    } else {
        NN_CHECK_EQ(NumDimensions(input_gate_bias), 1u);
        NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
    }

    // Gate biases are 1-D vectors of length n_cell.
    NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1u);
    NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);

    NN_CHECK_EQ(NumDimensions(cell_bias), 1u);
    NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);

    NN_CHECK_EQ(NumDimensions(output_gate_bias), 1u);
    NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);

    // Projection maps the cell output [n_cell] to the final output [n_output].
    if (!IsNullInput(projection_weights)) {
        NN_CHECK_EQ(NumDimensions(projection_weights), 2u);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
    }

    if (!IsNullInput(projection_bias)) {
        NN_CHECK_EQ(NumDimensions(projection_bias), 1u);
        NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
    }

    // Making sure the projection tensors are consistent:
    // 1) If projection weight is not present, then projection bias should not be
    // present.
    // 2) If projection weight is present, then projection bias is optional.
    // TODO: make sure this is correct.
    const bool projecton_tensors_consistent =
            (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
    NN_CHECK(projecton_tensors_consistent == true);

    // Layer-norm weights are 1-D vectors of length n_cell.
    if (!IsNullInput(input_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(input_layer_norm_weights), 1u);
        NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(forget_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights), 1u);
        NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(cell_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights), 1u);
        NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(output_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(output_layer_norm_weights), 1u);
        NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights, 0), n_cell);
    }

    // Layer-norm weights must be provided all-or-none; under CIFG the input
    // layer-norm weights must additionally be absent.
    if (params->use_cifg) {
        NN_RET_CHECK(IsNullInput(input_layer_norm_weights))
                << "input_layer_norm_weights are provided while CIFG is used";
        const bool layer_norm_weights_all_or_none_cifg =
                (IsNullInput(forget_layer_norm_weights) && IsNullInput(cell_layer_norm_weights) &&
                 IsNullInput(output_layer_norm_weights)) ||
                (!IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
                 !IsNullInput(output_layer_norm_weights));
        NN_RET_CHECK(layer_norm_weights_all_or_none_cifg);
    } else {
        const bool layer_norm_weights_all_or_none =
                (IsNullInput(input_layer_norm_weights) && IsNullInput(forget_layer_norm_weights) &&
                 IsNullInput(cell_layer_norm_weights) && IsNullInput(output_layer_norm_weights)) ||
                (!IsNullInput(input_layer_norm_weights) &&
                 !IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
                 !IsNullInput(output_layer_norm_weights));
        NN_RET_CHECK(layer_norm_weights_all_or_none);
    }

    return true;
}
307
// Validates that all required operands are present with consistent shapes and
// computes the shapes of the four outputs: the output tensor, the output
// state, the cell state, and the scratch buffer.
// Returns false if any validation fails.
bool LSTMCell::Prepare(const Operation& operation, RunTimeOperandInfo* operands,
                       Shape* scratchShape, Shape* outputStateShape, Shape* cellStateShape,
                       Shape* outputShape) {
    // Check we have all the inputs and outputs we need.
    NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
             NumInputsWithValues(operation, operands) <= 27);
    // Inputs that may never be omitted, regardless of the LSTM variant.
    constexpr int requiredInputs[] = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kOutputStateInTensor,
            kCellStateInTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!IsNullInput(GetInput(operation, operands, requiredInput)))
                << "required input " << requiredInput << " is omitted";
    }
    NN_CHECK_EQ(NumOutputs(operation), 4);

    // Check that the scalar operands' buffers are large enough.
    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
    NN_RET_CHECK(activationOperand.length >= sizeof(int32_t));
    const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam);
    const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam);
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(cellClipOperand.length >= sizeof(float));
        NN_RET_CHECK(projClipOperand.length >= sizeof(float));
    } else {
        NN_RET_CHECK(cellClipOperand.length >= sizeof(_Float16));
        NN_RET_CHECK(projClipOperand.length >= sizeof(_Float16));
    }

    // Inferring batch size, number of outputs and number of cells from the
    // input tensors.
    NN_CHECK(NumDimensions(input_) > 1);
    const uint32_t n_batch = SizeOfDimension(input_, 0);
    const uint32_t n_input = SizeOfDimension(input_, 1);

    // n_cell comes from the rows of the [n_cell, n_input] input-to-output
    // weights.
    const uint32_t n_cell = SizeOfDimension(input_to_output_weights_, 0);
    NN_CHECK_EQ(NumDimensions(input_to_output_weights_), 2u);
    NN_CHECK_EQ(SizeOfDimension(input_to_output_weights_, 1), n_input);

    // n_output comes from the columns of the [n_cell, n_output]
    // recurrent-to-output weights.
    NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights_), 2u);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights_, 0), n_cell);
    const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights_, 1);

    // Check that input tensor dimensions matches with each other.
    if (!CheckInputTensorDimensions(
                input_, input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_,
                input_to_output_weights_, recurrent_to_input_weights_, recurrent_to_forget_weights_,
                recurrent_to_cell_weights_, recurrent_to_output_weights_, cell_to_input_weights_,
                cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_,
                forget_gate_bias_, cell_bias_, output_gate_bias_, projection_weights_,
                projection_bias_, input_layer_norm_weights_, forget_layer_norm_weights_,
                cell_layer_norm_weights_, output_layer_norm_weights_, n_input, n_output, n_cell,
                &params_)) {
        return false;
    }

    // Resize the output and output_state tensors.
    // All outputs inherit the input's type and quantization parameters.
    const Shape& inputShape = input_->shape();

    outputShape->type = inputShape.type;
    outputShape->dimensions = {n_batch, n_output};
    outputShape->offset = inputShape.offset;
    outputShape->scale = inputShape.scale;

    outputStateShape->type = inputShape.type;
    outputStateShape->dimensions = {n_batch, n_output};
    outputStateShape->offset = inputShape.offset;
    outputStateShape->scale = inputShape.scale;

    cellStateShape->type = inputShape.type;
    cellStateShape->dimensions = {n_batch, n_cell};
    cellStateShape->offset = inputShape.offset;
    cellStateShape->scale = inputShape.scale;

    if (params_.use_cifg) {
        // Reserving space for Cell, Forget, Output gates
        scratchShape->dimensions = {n_batch, n_cell * 3};
    } else {
        // Reserving space for Input, Cell, Forget, Output gates
        scratchShape->dimensions = {n_batch, n_cell * 4};
    }
    scratchShape->type = inputShape.type;
    scratchShape->offset = inputShape.offset;
    scratchShape->scale = inputShape.scale;

    return true;
}
408
// static
// Runs float32 LSTM inference over the input sequence, invoking LSTMStep()
// once per time step. Handles rank-2 (single step) and rank-3 (sequence)
// inputs, batch-major or time-major layout, an optional auxiliary input, and
// both forward and backward traversal of the sequence.
bool LSTMCell::LSTMEvalFloat32(
        const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
        const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
        const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
        const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
        const float* recurrent_to_forget_weights_buffer,
        const float* recurrent_to_cell_weights_buffer,
        const float* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
        const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
        const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
        const float* aux_input_to_forget_weights_buffer,
        const float* aux_input_to_cell_weights_buffer,
        const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
        const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
        const float* output_gate_bias_buffer, const float* projection_weights_buffer,
        const float* projection_bias_buffer, const float* output_state_in_buffer,
        const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
        const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
        const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
        float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
        bool timeMajor, bool forwardSequence) {
    NNTRACE_COMP("LSTMCell::LSTMEvalFloat32");

    // A rank-2 input is a single time step; rank 3 is a full sequence.
    const uint32_t inputRank = getNumberOfDimensions(input_shape);
    NN_CHECK(inputRank == 2 || inputRank == 3);

    const uint32_t maxTime =
            (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
    const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
                                                : getSizeOfDimension(input_shape, 0);
    const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
    const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
    const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);

    // Shape of the slice of the input processed by one LSTMStep() call.
    Shape batchInputShape = input_shape;
    batchInputShape.dimensions = {batchSize, inputSize};
    const uint32_t batchInputSize = batchSize * inputSize;
    const uint32_t batchOutputSize = batchSize * outputSize;

    // The step loop below walks time-major data; batch-major inputs (and the
    // corresponding output) are transposed into temporary buffers first.
    std::vector<float> transposedInput;
    const bool hasAuxInput = (aux_input_buffer != nullptr);
    std::vector<float> transposedAuxInput;
    std::vector<float> transposedOutput;
    Shape transposedInputShape;
    Shape transposedOutputShape;
    if (!timeMajor) {
        transposedInput.resize(maxTime * batchInputSize);
        transposeFirstTwoDimensions<float>(input_buffer, input_shape, transposedInput.data());
        if (hasAuxInput) {
            // NOTE(review): the aux input is transposed using input_shape,
            // i.e. it is assumed to have the same dimensions as the main
            // input — confirm against callers.
            transposedAuxInput.resize(maxTime * batchInputSize);
            transposeFirstTwoDimensions<float>(aux_input_buffer, input_shape,
                                               transposedAuxInput.data());
        }
        transposeFirstTwoDimensions(input_shape, &transposedInputShape);
        transposedOutput.resize(maxTime * batchOutputSize);
        transposedOutputShape = transposedInputShape;
        transposedOutputShape.dimensions[2] = outputSize;
    }
    const float* inputData = timeMajor ? input_buffer : transposedInput.data();
    const float* auxInputData =
            hasAuxInput ? (timeMajor ? aux_input_buffer : transposedAuxInput.data()) : nullptr;
    float* outputData = timeMajor ? output_buffer : transposedOutput.data();

    // Working copies of the state, so the caller's "in" buffers stay intact.
    std::vector<float> outputStateInCurrentTimeStep(
            output_state_in_buffer, output_state_in_buffer + batchSize * outputSize);
    std::vector<float> cellStateInCurrentTimeStep(cell_state_in_buffer,
                                                  cell_state_in_buffer + batchSize * numCells);
    // For a backward pass, start at the last time step and walk the sequence
    // with negative strides.
    const float* inputCurrentTimeStep =
            inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
    const float* auxInputCurrentTimeStep =
            hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
                        : nullptr;
    float* outputCurrentTimeStep =
            outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
    const int batchInputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchInputSize);
    const int batchOutputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchOutputSize);

    for (uint32_t t = 0; t < maxTime; ++t) {
        LSTMStep(params, inputCurrentTimeStep, batchInputShape, input_to_input_weights_buffer,
                 input_to_forget_weights_buffer, input_to_cell_weights_buffer,
                 input_to_output_weights_buffer, input_to_output_weights_shape,
                 recurrent_to_input_weights_buffer, recurrent_to_forget_weights_buffer,
                 recurrent_to_cell_weights_buffer, recurrent_to_output_weights_buffer,
                 recurrent_to_output_weights_shape, cell_to_input_weights_buffer,
                 cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
                 auxInputCurrentTimeStep, aux_input_to_input_weights_buffer,
                 aux_input_to_forget_weights_buffer, aux_input_to_cell_weights_buffer,
                 aux_input_to_output_weights_buffer, input_gate_bias_buffer,
                 forget_gate_bias_buffer, cell_bias_buffer, output_gate_bias_buffer,
                 projection_weights_buffer, projection_bias_buffer,
                 outputStateInCurrentTimeStep.data(), cellStateInCurrentTimeStep.data(),
                 input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
                 cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
                 output_state_out_buffer, cell_state_out_buffer, outputCurrentTimeStep,
                 scratch_buffer_buffer);
        inputCurrentTimeStep += batchInputDelta;
        if (hasAuxInput) {
            auxInputCurrentTimeStep += batchInputDelta;
        }
        outputCurrentTimeStep += batchOutputDelta;
        // LSTMStep wrote the new state into the "out" buffers; copy it into
        // the working buffers to feed the next time step.
        outputStateInCurrentTimeStep.assign(output_state_out_buffer,
                                            output_state_out_buffer + batchSize * outputSize);
        cellStateInCurrentTimeStep.assign(cell_state_out_buffer,
                                          cell_state_out_buffer + batchSize * numCells);
    }

    // Transpose the (time-major) result back into the caller's batch-major
    // output buffer.
    if (!timeMajor) {
        transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
                                           output_buffer);
    }

    return true;
}
524
525 // static
LSTMEvalFloat16(const LSTMParams & params,const _Float16 * input_buffer,const Shape & input_shape,const _Float16 * input_to_input_weights_buffer,const _Float16 * input_to_forget_weights_buffer,const _Float16 * input_to_cell_weights_buffer,const _Float16 * input_to_output_weights_buffer,const Shape & input_to_output_weights_shape,const _Float16 * recurrent_to_input_weights_buffer,const _Float16 * recurrent_to_forget_weights_buffer,const _Float16 * recurrent_to_cell_weights_buffer,const _Float16 * recurrent_to_output_weights_buffer,const Shape & recurrent_to_output_weights_shape,const _Float16 * cell_to_input_weights_buffer,const _Float16 * cell_to_forget_weights_buffer,const _Float16 * cell_to_output_weights_buffer,const _Float16 * aux_input_buffer,const _Float16 * aux_input_to_input_weights_buffer,const _Float16 * aux_input_to_forget_weights_buffer,const _Float16 * aux_input_to_cell_weights_buffer,const _Float16 * aux_input_to_output_weights_buffer,const _Float16 * input_gate_bias_buffer,const _Float16 * forget_gate_bias_buffer,const _Float16 * cell_bias_buffer,const _Float16 * output_gate_bias_buffer,const _Float16 * projection_weights_buffer,const _Float16 * projection_bias_buffer,const _Float16 * output_state_in_buffer,const _Float16 * cell_state_in_buffer,const _Float16 * input_layer_norm_weights_buffer,const _Float16 * forget_layer_norm_weights_buffer,const _Float16 * cell_layer_norm_weights_buffer,const _Float16 * output_layer_norm_weights_buffer,_Float16 * output_state_out_buffer,_Float16 * cell_state_out_buffer,_Float16 * output_buffer,_Float16 * scratch_buffer_buffer,bool timeMajor,bool forwardSequence)526 bool LSTMCell::LSTMEvalFloat16(
527 const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
528 const _Float16* input_to_input_weights_buffer,
529 const _Float16* input_to_forget_weights_buffer,
530 const _Float16* input_to_cell_weights_buffer,
531 const _Float16* input_to_output_weights_buffer, const Shape& input_to_output_weights_shape,
532 const _Float16* recurrent_to_input_weights_buffer,
533 const _Float16* recurrent_to_forget_weights_buffer,
534 const _Float16* recurrent_to_cell_weights_buffer,
535 const _Float16* recurrent_to_output_weights_buffer,
536 const Shape& recurrent_to_output_weights_shape,
537 const _Float16* cell_to_input_weights_buffer, const _Float16* cell_to_forget_weights_buffer,
538 const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
539 const _Float16* aux_input_to_input_weights_buffer,
540 const _Float16* aux_input_to_forget_weights_buffer,
541 const _Float16* aux_input_to_cell_weights_buffer,
542 const _Float16* aux_input_to_output_weights_buffer, const _Float16* input_gate_bias_buffer,
543 const _Float16* forget_gate_bias_buffer, const _Float16* cell_bias_buffer,
544 const _Float16* output_gate_bias_buffer, const _Float16* projection_weights_buffer,
545 const _Float16* projection_bias_buffer, const _Float16* output_state_in_buffer,
546 const _Float16* cell_state_in_buffer, const _Float16* input_layer_norm_weights_buffer,
547 const _Float16* forget_layer_norm_weights_buffer,
548 const _Float16* cell_layer_norm_weights_buffer,
549 const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
550 _Float16* cell_state_out_buffer, _Float16* output_buffer, _Float16* scratch_buffer_buffer,
551 bool timeMajor, bool forwardSequence) {
552 NNTRACE_COMP("LSTMCell::LSTMEvalFloat16");
553
554 const uint32_t inputRank = getNumberOfDimensions(input_shape);
555 NN_CHECK(inputRank == 2 || inputRank == 3);
556
557 const uint32_t maxTime =
558 (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
559 const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
560 : getSizeOfDimension(input_shape, 0);
561 const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
562 const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
563 const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);
564
565 Shape batchInputShape = input_shape;
566 batchInputShape.dimensions = {batchSize, inputSize};
567 const uint32_t batchInputSize = batchSize * inputSize;
568 const uint32_t batchOutputSize = batchSize * outputSize;
569
570 std::vector<float> input_float32(maxTime * batchInputSize);
571 convertFloat16ToFloat32(input_buffer, &input_float32);
572 std::vector<float> input_to_input_weights_float32(numCells * inputSize);
573 if (input_to_input_weights_buffer != nullptr) {
574 convertFloat16ToFloat32(input_to_input_weights_buffer, &input_to_input_weights_float32);
575 }
576 std::vector<float> input_to_forget_weights_float32(numCells * inputSize);
577 convertFloat16ToFloat32(input_to_forget_weights_buffer, &input_to_forget_weights_float32);
578 std::vector<float> input_to_cell_weights_float32(numCells * inputSize);
579 convertFloat16ToFloat32(input_to_cell_weights_buffer, &input_to_cell_weights_float32);
580 std::vector<float> input_to_output_weights_float32(numCells * inputSize);
581 convertFloat16ToFloat32(input_to_output_weights_buffer, &input_to_output_weights_float32);
582
583 std::vector<float> recurrent_to_input_weights_float32(numCells * outputSize);
584 if (recurrent_to_input_weights_buffer != nullptr) {
585 convertFloat16ToFloat32(recurrent_to_input_weights_buffer,
586 &recurrent_to_input_weights_float32);
587 }
588 std::vector<float> recurrent_to_forget_weights_float32(numCells * outputSize);
589 convertFloat16ToFloat32(recurrent_to_forget_weights_buffer,
590 &recurrent_to_forget_weights_float32);
591 std::vector<float> recurrent_to_cell_weights_float32(numCells * outputSize);
592 convertFloat16ToFloat32(recurrent_to_cell_weights_buffer, &recurrent_to_cell_weights_float32);
593 std::vector<float> recurrent_to_output_weights_float32(numCells * outputSize);
594 convertFloat16ToFloat32(recurrent_to_output_weights_buffer,
595 &recurrent_to_output_weights_float32);
596
597 std::vector<float> cell_to_input_weights_float32(numCells);
598 if (cell_to_input_weights_buffer != nullptr) {
599 convertFloat16ToFloat32(cell_to_input_weights_buffer, &cell_to_input_weights_float32);
600 }
601 std::vector<float> cell_to_forget_weights_float32(numCells);
602 if (cell_to_forget_weights_buffer != nullptr) {
603 convertFloat16ToFloat32(cell_to_forget_weights_buffer, &cell_to_forget_weights_float32);
604 }
605 std::vector<float> cell_to_output_weights_float32(numCells);
606 if (cell_to_output_weights_buffer != nullptr) {
607 convertFloat16ToFloat32(cell_to_output_weights_buffer, &cell_to_output_weights_float32);
608 }
609
610 std::vector<float> aux_input_float32(maxTime * batchInputSize);
611 if (aux_input_buffer != nullptr) {
612 convertFloat16ToFloat32(aux_input_buffer, &aux_input_float32);
613 }
614 std::vector<float> aux_input_to_input_weights_float32(numCells * inputSize);
615 if (aux_input_to_input_weights_buffer != nullptr) {
616 convertFloat16ToFloat32(aux_input_to_input_weights_buffer,
617 &aux_input_to_input_weights_float32);
618 }
619 std::vector<float> aux_input_to_forget_weights_float32(numCells * inputSize);
620 if (aux_input_to_forget_weights_buffer != nullptr) {
621 convertFloat16ToFloat32(aux_input_to_forget_weights_buffer,
622 &aux_input_to_forget_weights_float32);
623 }
624 std::vector<float> aux_input_to_cell_weights_float32(numCells * inputSize);
625 if (aux_input_to_cell_weights_buffer != nullptr) {
626 convertFloat16ToFloat32(aux_input_to_cell_weights_buffer,
627 &aux_input_to_cell_weights_float32);
628 }
629 std::vector<float> aux_input_to_output_weights_float32(numCells * inputSize);
630 if (aux_input_to_output_weights_buffer != nullptr) {
631 convertFloat16ToFloat32(aux_input_to_output_weights_buffer,
632 &aux_input_to_output_weights_float32);
633 }
634
635 std::vector<float> input_gate_bias_float32(numCells);
636 if (input_gate_bias_buffer != nullptr) {
637 convertFloat16ToFloat32(input_gate_bias_buffer, &input_gate_bias_float32);
638 }
639 std::vector<float> forget_gate_bias_float32(numCells);
640 convertFloat16ToFloat32(forget_gate_bias_buffer, &forget_gate_bias_float32);
641 std::vector<float> cell_bias_float32(numCells);
642 convertFloat16ToFloat32(cell_bias_buffer, &cell_bias_float32);
643 std::vector<float> output_gate_bias_float32(numCells);
644 convertFloat16ToFloat32(output_gate_bias_buffer, &output_gate_bias_float32);
645
646 std::vector<float> projection_weights_float32(numCells * outputSize);
647 if (projection_weights_buffer != nullptr) {
648 convertFloat16ToFloat32(projection_weights_buffer, &projection_weights_float32);
649 }
650 std::vector<float> projection_bias_float32(outputSize);
651 if (projection_bias_buffer != nullptr) {
652 convertFloat16ToFloat32(projection_bias_buffer, &projection_bias_float32);
653 }
654
655 std::vector<float> input_layer_norm_weights_float32(numCells);
656 if (input_layer_norm_weights_buffer != nullptr) {
657 convertFloat16ToFloat32(input_layer_norm_weights_buffer, &input_layer_norm_weights_float32);
658 }
659 std::vector<float> forget_layer_norm_weights_float32(numCells);
660 if (forget_layer_norm_weights_buffer != nullptr) {
661 convertFloat16ToFloat32(forget_layer_norm_weights_buffer,
662 &forget_layer_norm_weights_float32);
663 }
664 std::vector<float> cell_layer_norm_weights_float32(numCells);
665 if (cell_layer_norm_weights_buffer != nullptr) {
666 convertFloat16ToFloat32(cell_layer_norm_weights_buffer, &cell_layer_norm_weights_float32);
667 }
668 std::vector<float> output_layer_norm_weights_float32(numCells);
669 if (output_layer_norm_weights_buffer != nullptr) {
670 convertFloat16ToFloat32(output_layer_norm_weights_buffer,
671 &output_layer_norm_weights_float32);
672 }
673
674 std::vector<float> output_state_out_float32(batchOutputSize);
675 convertFloat16ToFloat32(output_state_out_buffer, &output_state_out_float32);
676 std::vector<float> cell_state_out_float32(batchSize * numCells);
677 convertFloat16ToFloat32(cell_state_out_buffer, &cell_state_out_float32);
678
679 std::vector<float> output_float32(maxTime * batchOutputSize);
680 convertFloat16ToFloat32(output_buffer, &output_float32);
681 std::vector<float> scratch_buffer_float32(params.use_cifg ? 3 * batchSize * numCells
682 : 4 * batchSize * numCells);
683 convertFloat16ToFloat32(scratch_buffer_buffer, &scratch_buffer_float32);
684
685 std::vector<float> transposedInput;
686 const bool hasAuxInput = (aux_input_buffer != nullptr);
687 std::vector<float> transposedAuxInput;
688 std::vector<float> transposedOutput;
689 Shape transposedInputShape;
690 Shape transposedOutputShape;
691 if (!timeMajor) {
692 transposedInput.resize(maxTime * batchInputSize);
693 transposeFirstTwoDimensions<float>(input_float32.data(), input_shape,
694 transposedInput.data());
695 if (hasAuxInput) {
696 transposedAuxInput.resize(maxTime * batchInputSize);
697 transposeFirstTwoDimensions<float>(aux_input_float32.data(), input_shape,
698 transposedAuxInput.data());
699 }
700 transposeFirstTwoDimensions(input_shape, &transposedInputShape);
701 transposedOutput.resize(maxTime * batchOutputSize);
702 transposedOutputShape = transposedInputShape;
703 transposedOutputShape.dimensions[2] = outputSize;
704 }
705 const float* inputData = timeMajor ? input_float32.data() : transposedInput.data();
706 const float* auxInputData =
707 hasAuxInput ? (timeMajor ? aux_input_float32.data() : transposedAuxInput.data())
708 : nullptr;
709 float* outputData = timeMajor ? output_float32.data() : transposedOutput.data();
710
711 std::vector<float> outputStateInCurrentTimeStep(batchSize * outputSize);
712 convertFloat16ToFloat32(output_state_in_buffer, &outputStateInCurrentTimeStep);
713 std::vector<float> cellStateInCurrentTimeStep(batchSize * numCells);
714 convertFloat16ToFloat32(cell_state_in_buffer, &cellStateInCurrentTimeStep);
715
716 const float* inputCurrentTimeStep =
717 inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
718 const float* auxInputCurrentTimeStep =
719 hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
720 : nullptr;
721 float* outputCurrentTimeStep =
722 outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
723 const int batchInputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchInputSize);
724 const int batchOutputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchOutputSize);
725
726 for (uint32_t t = 0; t < maxTime; ++t) {
727 LSTMStep(params, inputCurrentTimeStep, batchInputShape,
728 input_to_input_weights_float32.data(), input_to_forget_weights_float32.data(),
729 input_to_cell_weights_float32.data(), input_to_output_weights_float32.data(),
730 input_to_output_weights_shape, recurrent_to_input_weights_float32.data(),
731 recurrent_to_forget_weights_float32.data(),
732 recurrent_to_cell_weights_float32.data(),
733 recurrent_to_output_weights_float32.data(), recurrent_to_output_weights_shape,
734 cell_to_input_weights_float32.data(), cell_to_forget_weights_float32.data(),
735 cell_to_output_weights_float32.data(), auxInputCurrentTimeStep,
736 aux_input_to_input_weights_float32.data(),
737 aux_input_to_forget_weights_float32.data(),
738 aux_input_to_cell_weights_float32.data(),
739 aux_input_to_output_weights_float32.data(), input_gate_bias_float32.data(),
740 forget_gate_bias_float32.data(), cell_bias_float32.data(),
741 output_gate_bias_float32.data(), projection_weights_float32.data(),
742 projection_bias_float32.data(), outputStateInCurrentTimeStep.data(),
743 cellStateInCurrentTimeStep.data(), input_layer_norm_weights_float32.data(),
744 forget_layer_norm_weights_float32.data(), cell_layer_norm_weights_float32.data(),
745 output_layer_norm_weights_float32.data(), output_state_out_float32.data(),
746 cell_state_out_float32.data(), outputCurrentTimeStep,
747 scratch_buffer_float32.data());
748 inputCurrentTimeStep += batchInputDelta;
749 if (hasAuxInput) {
750 auxInputCurrentTimeStep += batchInputDelta;
751 }
752 outputCurrentTimeStep += batchOutputDelta;
753 outputStateInCurrentTimeStep = output_state_out_float32;
754 cellStateInCurrentTimeStep = cell_state_out_float32;
755 }
756
757 if (!timeMajor) {
758 transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
759 output_float32.data());
760 }
761
762 convertFloat32ToFloat16(output_state_out_float32, output_state_out_buffer);
763 convertFloat32ToFloat16(cell_state_out_float32, cell_state_out_buffer);
764 convertFloat32ToFloat16(output_float32, output_buffer);
765 convertFloat32ToFloat16(scratch_buffer_float32, scratch_buffer_buffer);
766 return true;
767 }
768
// static
//
// Computes a single LSTM time step for a whole batch, writing the new cell
// state, new output state, and output activations.
//
// Supported variants (all selected via flags in |params|):
//   - CIFG (params.use_cifg): the input gate is coupled to the forget gate
//     (input_gate = 1 - forget_gate), so no input-gate scratch is used.
//   - Peephole (params.use_peephole): cell state feeds the gates through the
//     cell_to_* weight vectors.
//   - Layer normalization (params.use_layer_norm): gate pre-activations are
//     mean/stddev-normalized, scaled by the *_layer_norm_weights, and the gate
//     biases are added AFTER normalization instead of being pre-assigned.
//   - Projection (params.use_projection_weight / use_projection_bias): the
//     gated cell output is projected from n_cell down to n_output.
//   - Auxiliary input (aux_input_buffer != nullptr): a second input stream is
//     accumulated into every gate through the aux_input_to_* weights.
//
// Optional-variant buffers (e.g. input_gate_* under CIFG, cell_to_* without
// peephole, *_layer_norm_* without layer norm) are simply not dereferenced
// when the corresponding flag is off; callers may pass nullptr for them.
//
// scratch_buffer_buffer must hold n_batch * n_cell floats per gate:
// 3 gates' worth under CIFG (cell, forget, output), otherwise 4 (input, cell,
// forget, output) — matching the layout set up below.
//
// Returns true (there is no failure path in this step function).
bool LSTMCell::LSTMStep(
        const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
        const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
        const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
        const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
        const float* recurrent_to_forget_weights_buffer,
        const float* recurrent_to_cell_weights_buffer,
        const float* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
        const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
        const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
        const float* aux_input_to_forget_weights_buffer,
        const float* aux_input_to_cell_weights_buffer,
        const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
        const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
        const float* output_gate_bias_buffer, const float* projection_weights_buffer,
        const float* projection_bias_buffer, const float* output_state_in_buffer,
        const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
        const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
        const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
        float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer) {
    NNTRACE_COMP("LSTMCell::LSTMStep");

    // Problem dimensions, read off the operand shapes:
    //   n_batch x n_input  input, n_cell cells, n_output output/state width.
    const uint32_t n_batch = input_shape.dimensions[0];
    const uint32_t n_input = input_shape.dimensions[1];
    // n_cell and n_output will be the same size when there is no projection.
    const uint32_t n_cell = input_to_output_weights_shape.dimensions[0];
    const uint32_t n_output = recurrent_to_output_weights_shape.dimensions[1];
    // Aux input, when present, has the same per-batch width as the main input.
    const uint32_t n_aux_input = aux_input_buffer == nullptr ? 0 : n_input;

    // Index the scratch buffers pointers to the global scratch buffer.
    // Layout: [input_gate][cell][forget_gate][output_gate], each
    // n_cell * n_batch floats; CIFG drops the input_gate segment.
    float* input_gate_scratch = nullptr;
    float* cell_scratch = nullptr;
    float* forget_gate_scratch = nullptr;
    float* output_gate_scratch = nullptr;
    if (params.use_cifg) {
        cell_scratch = scratch_buffer_buffer;
        forget_gate_scratch = cell_scratch + n_cell * n_batch;
        output_gate_scratch = cell_scratch + 2 * n_cell * n_batch;
    } else {
        input_gate_scratch = scratch_buffer_buffer;
        cell_scratch = input_gate_scratch + n_cell * n_batch;
        forget_gate_scratch = input_gate_scratch + 2 * n_cell * n_batch;
        output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
    }

    if (!params.use_layer_norm) {
        // Initialize scratch buffers with bias.
        if (!params.use_cifg) {
            tflite::tensor_utils::VectorBatchVectorAssign(input_gate_bias_buffer, n_cell, n_batch,
                                                          input_gate_scratch);
        }
        tflite::tensor_utils::VectorBatchVectorAssign(forget_gate_bias_buffer, n_cell, n_batch,
                                                      forget_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAssign(cell_bias_buffer, n_cell, n_batch,
                                                      cell_scratch);
        tflite::tensor_utils::VectorBatchVectorAssign(output_gate_bias_buffer, n_cell, n_batch,
                                                      output_gate_scratch);
    } else {
        // Initialize scratch buffers with zeroes. With layer norm the biases
        // are added after normalization (see VectorBatchVectorAdd calls below).
        if (!params.use_cifg) {
            std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
        }
        std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
        std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
        std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
    }

    // For each batch and cell: compute input_weight * input.
    if (!params.use_cifg) {
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_input_weights_buffer,
                                                                  n_cell, n_input, input_buffer,
                                                                  n_batch, input_gate_scratch);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_forget_weights_buffer,
                                                              n_cell, n_input, input_buffer,
                                                              n_batch, forget_gate_scratch);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            input_to_cell_weights_buffer, n_cell, n_input, input_buffer, n_batch, cell_scratch);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_output_weights_buffer,
                                                              n_cell, n_input, input_buffer,
                                                              n_batch, output_gate_scratch);

    // If auxiliary input is available then compute aux_input_weight * aux_input
    if (aux_input_buffer != nullptr) {
        if (!params.use_cifg) {
            tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                    aux_input_to_input_weights_buffer, n_cell, n_aux_input, aux_input_buffer,
                    n_batch, input_gate_scratch);
        }

        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_forget_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                forget_gate_scratch);
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_cell_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                cell_scratch);
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_output_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                output_gate_scratch);
    }

    // For each batch and cell: compute recurrent_weight * output_state.
    if (!params.use_cifg) {
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                recurrent_to_input_weights_buffer, n_cell, n_output, output_state_in_buffer,
                n_batch, input_gate_scratch);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_forget_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            forget_gate_scratch);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_cell_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            cell_scratch);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_output_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            output_gate_scratch);

    // For each batch and cell: update input gate.
    // Order: peephole contribution, then layer norm + bias, then sigmoid.
    if (!params.use_cifg) {
        if (params.use_peephole) {
            tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
                    cell_to_input_weights_buffer, n_cell, cell_state_in_buffer, n_batch,
                    input_gate_scratch);
        }
        if (params.use_layer_norm) {
            tflite::tensor_utils::MeanStddevNormalization(input_gate_scratch, input_gate_scratch,
                                                          n_cell, n_batch);
            tflite::tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weights_buffer,
                                                                n_cell, input_gate_scratch, n_batch,
                                                                input_gate_scratch);
            tflite::tensor_utils::VectorBatchVectorAdd(input_gate_bias_buffer, n_cell, n_batch,
                                                       input_gate_scratch);
        }
        tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                                   input_gate_scratch);
    }

    // For each batch and cell: update forget gate.
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_forget_weights_buffer,
                                                                      n_cell, cell_state_in_buffer,
                                                                      n_batch, forget_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(forget_gate_scratch, forget_gate_scratch,
                                                      n_cell, n_batch);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weights_buffer,
                                                            n_cell, forget_gate_scratch, n_batch,
                                                            forget_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(forget_gate_bias_buffer, n_cell, n_batch,
                                                   forget_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                               forget_gate_scratch);

    // For each batch and cell: update the cell.
    // cell_state_out = forget_gate .* cell_state_in + gate .* g(cell_input),
    // where gate is the input gate, or (1 - forget_gate) under CIFG, and g is
    // params.activation.
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, n_batch);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(cell_layer_norm_weights_buffer, n_cell,
                                                            cell_scratch, n_batch, cell_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(cell_bias_buffer, n_cell, n_batch, cell_scratch);
    }
    tflite::tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_in_buffer,
                                                   n_batch * n_cell, cell_state_out_buffer);
    tflite::tensor_utils::ApplyActivationToVector(
            cell_scratch, n_batch * n_cell, static_cast<TfLiteFusedActivation>(params.activation),
            cell_scratch);
    if (params.use_cifg) {
        // Coupled input gate: input_gate = 1 - forget_gate.
        tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                                         forget_gate_scratch);
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    } else {
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    }
    // Optional clamp of the new cell state to [-cell_clip, cell_clip].
    if (params.cell_clip > 0.0) {
        tflite::tensor_utils::CwiseClipping(cell_state_out_buffer, n_batch * n_cell,
                                            params.cell_clip);
    }

    // For each batch and cell: update the output gate.
    // Note the peephole here uses the *new* cell state (cell_state_out).
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_output_weights_buffer,
                                                                      n_cell, cell_state_out_buffer,
                                                                      n_batch, output_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(output_gate_scratch, output_gate_scratch,
                                                      n_cell, n_batch);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weights_buffer,
                                                            n_cell, output_gate_scratch, n_batch,
                                                            output_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(output_gate_bias_buffer, n_cell, n_batch,
                                                   output_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                               output_gate_scratch);
    // output = output_gate .* g(cell_state_out); cell_scratch is reused as a
    // temporary for the activated cell state here.
    tflite::tensor_utils::ApplyActivationToVector(
            cell_state_out_buffer, n_batch * n_cell,
            static_cast<TfLiteFusedActivation>(params.activation), cell_scratch);
    tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                                   n_batch * n_cell, output_gate_scratch);

    // For each batch: update the projection and output_state.
    if (params.use_projection_weight) {
        if (params.use_projection_bias) {
            tflite::tensor_utils::VectorBatchVectorAssign(projection_bias_buffer, n_output, n_batch,
                                                          output_buffer);
        } else {
            std::fill_n(output_buffer, n_batch * n_output, 0.0f);
        }
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                projection_weights_buffer, n_output, n_cell, output_gate_scratch, n_batch,
                output_buffer);
        // Optional clamp of the projected output to [-proj_clip, proj_clip].
        if (params.proj_clip > 0.0) {
            tflite::tensor_utils::CwiseClipping(output_buffer, n_batch * n_output,
                                                params.proj_clip);
        }
    } else {
        // No projection: the gated output is the output (n_cell == n_output).
        std::copy_n(output_gate_scratch, n_batch * n_output, output_buffer);
    }
    // The recurrent state for the next step is the (possibly projected) output.
    std::copy_n(output_buffer, n_batch * n_output, output_state_out_buffer);
    return true;
}
996
Eval()997 bool LSTMCell::Eval() {
998 switch (input_->type) {
999 case OperandType::TENSOR_FLOAT32: {
1000 LSTMEvalFloat32(params_, GetBuffer<const float>(input_), input_->shape(),
1001 GetBuffer<const float>(input_to_input_weights_),
1002 GetBuffer<const float>(input_to_forget_weights_),
1003 GetBuffer<const float>(input_to_cell_weights_),
1004 GetBuffer<const float>(input_to_output_weights_),
1005 input_to_output_weights_->shape(),
1006 GetBuffer<const float>(recurrent_to_input_weights_),
1007 GetBuffer<const float>(recurrent_to_forget_weights_),
1008 GetBuffer<const float>(recurrent_to_cell_weights_),
1009 GetBuffer<const float>(recurrent_to_output_weights_),
1010 recurrent_to_output_weights_->shape(),
1011 GetBuffer<const float>(cell_to_input_weights_),
1012 GetBuffer<const float>(cell_to_forget_weights_),
1013 GetBuffer<const float>(cell_to_output_weights_),
1014 /*aux_input_buffer=*/nullptr,
1015 /*aux_input_to_input_weights_buffer=*/nullptr,
1016 /*aux_input_to_forget_weights_buffer=*/nullptr,
1017 /*aux_input_to_cell_weights_buffer=*/nullptr,
1018 /*aux_input_to_output_weights_buffer=*/nullptr,
1019 GetBuffer<const float>(input_gate_bias_),
1020 GetBuffer<const float>(forget_gate_bias_),
1021 GetBuffer<const float>(cell_bias_),
1022 GetBuffer<const float>(output_gate_bias_),
1023 GetBuffer<const float>(projection_weights_),
1024 GetBuffer<const float>(projection_bias_),
1025 GetBuffer<const float>(output_state_in_),
1026 GetBuffer<const float>(cell_state_in_),
1027 GetBuffer<const float>(input_layer_norm_weights_),
1028 GetBuffer<const float>(forget_layer_norm_weights_),
1029 GetBuffer<const float>(cell_layer_norm_weights_),
1030 GetBuffer<const float>(output_layer_norm_weights_),
1031 GetBuffer<float>(output_state_out_), GetBuffer<float>(cell_state_out_),
1032 GetBuffer<float>(output_), GetBuffer<float>(scratch_buffer_));
1033 } break;
1034 case OperandType::TENSOR_FLOAT16: {
1035 LSTMEvalFloat16(params_, GetBuffer<const _Float16>(input_), input_->shape(),
1036 GetOptionalBuffer<const _Float16>(input_to_input_weights_),
1037 GetBuffer<const _Float16>(input_to_forget_weights_),
1038 GetBuffer<const _Float16>(input_to_cell_weights_),
1039 GetBuffer<const _Float16>(input_to_output_weights_),
1040 input_to_output_weights_->shape(),
1041 GetOptionalBuffer<const _Float16>(recurrent_to_input_weights_),
1042 GetBuffer<const _Float16>(recurrent_to_forget_weights_),
1043 GetBuffer<const _Float16>(recurrent_to_cell_weights_),
1044 GetBuffer<const _Float16>(recurrent_to_output_weights_),
1045 recurrent_to_output_weights_->shape(),
1046 GetOptionalBuffer<const _Float16>(cell_to_input_weights_),
1047 GetOptionalBuffer<const _Float16>(cell_to_forget_weights_),
1048 GetOptionalBuffer<const _Float16>(cell_to_output_weights_),
1049 /*aux_input_buffer=*/nullptr,
1050 /*aux_input_to_input_weights_buffer=*/nullptr,
1051 /*aux_input_to_forget_weights_buffer=*/nullptr,
1052 /*aux_input_to_cell_weights_buffer=*/nullptr,
1053 /*aux_input_to_output_weights_buffer=*/nullptr,
1054 GetOptionalBuffer<const _Float16>(input_gate_bias_),
1055 GetBuffer<const _Float16>(forget_gate_bias_),
1056 GetBuffer<const _Float16>(cell_bias_),
1057 GetBuffer<const _Float16>(output_gate_bias_),
1058 GetOptionalBuffer<const _Float16>(projection_weights_),
1059 GetOptionalBuffer<const _Float16>(projection_bias_),
1060 GetBuffer<const _Float16>(output_state_in_),
1061 GetBuffer<const _Float16>(cell_state_in_),
1062 GetOptionalBuffer<const _Float16>(input_layer_norm_weights_),
1063 GetOptionalBuffer<const _Float16>(forget_layer_norm_weights_),
1064 GetOptionalBuffer<const _Float16>(cell_layer_norm_weights_),
1065 GetOptionalBuffer<const _Float16>(output_layer_norm_weights_),
1066 GetBuffer<_Float16>(output_state_out_),
1067 GetBuffer<_Float16>(cell_state_out_), GetBuffer<_Float16>(output_),
1068 GetBuffer<_Float16>(scratch_buffer_));
1069 } break;
1070 default: {
1071 LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
1072 return false;
1073 }
1074 }
1075 return true;
1076 }
1077
1078 } // namespace nn
1079 } // namespace android
1080