1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "LSTM.h"
18
19 #include "CpuExecutor.h"
20 #include "CpuOperationUtils.h"
21 #include "HalInterfaces.h"
22 #include "OperationsUtils.h"
23
24 #include "Tracing.h"
25 #include "Utils.h"
26
27 namespace android {
28 namespace nn {
29
30 namespace {
31
32 template <typename T>
GetBuffer(RunTimeOperandInfo * operand)33 inline T* GetBuffer(RunTimeOperandInfo* operand) {
34 return reinterpret_cast<T*>(operand->buffer);
35 }
36
37 template <typename T>
GetBuffer(const RunTimeOperandInfo * operand)38 inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
39 return reinterpret_cast<const T*>(operand->buffer);
40 }
41
42 template <typename T>
GetOptionalBuffer(const RunTimeOperandInfo * operand)43 inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
44 return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
45 }
46
47 } // anonymous namespace
48
// Wires an LSTMCell up to the operation's operands: caches pointers to every
// input/output tensor and reads the scalar parameters (activation, cell clip,
// projection clip) into params_. No shape validation happens here -- that is
// deferred to Prepare()/CheckInputTensorDimensions().
LSTMCell::LSTMCell(const Operation& operation, std::vector<RunTimeOperandInfo>& operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    input_to_input_weights_ =
            GetInput(operation, operands, kInputToInputWeightsTensor);  // optional
    input_to_forget_weights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
    input_to_cell_weights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
    input_to_output_weights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);

    recurrent_to_input_weights_ =
            GetInput(operation, operands, kRecurrentToInputWeightsTensor);  // optional
    recurrent_to_forget_weights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    recurrent_to_cell_weights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    recurrent_to_output_weights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);

    cell_to_input_weights_ = GetInput(operation, operands, kCellToInputWeightsTensor);  // optional
    cell_to_forget_weights_ =
            GetInput(operation, operands, kCellToForgetWeightsTensor);  // optional
    cell_to_output_weights_ =
            GetInput(operation, operands, kCellToOutputWeightsTensor);  // optional

    input_gate_bias_ = GetInput(operation, operands, kInputGateBiasTensor);
    forget_gate_bias_ = GetInput(operation, operands, kForgetGateBiasTensor);
    cell_bias_ = GetInput(operation, operands, kCellGateBiasTensor);
    output_gate_bias_ = GetInput(operation, operands, kOutputGateBiasTensor);

    projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor);  // optional
    projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor);        // optional

    output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
    cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);

    params_.activation = static_cast<TfLiteFusedActivation>(
            getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam)));
    // The clip scalars share the element type of the input tensor: float for
    // TENSOR_FLOAT32 models, _Float16 (widened to float) otherwise.
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        params_.cell_clip = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
        params_.proj_clip = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
    } else {
        params_.cell_clip = static_cast<float>(
                getScalarData<_Float16>(*GetInput(operation, operands, kCellClipParam)));
        params_.proj_clip = static_cast<float>(
                getScalarData<_Float16>(*GetInput(operation, operands, kProjClipParam)));
    }

    // We check the version of LSTM by checking the number of the inputs to the
    // op. For LSTM version 1.0 there were 23 inputs and for 1.2 there are 27
    // (the extra four are the layer-normalization weights).
    if (operation.inputs.size() == 27) {
        input_layer_norm_weights_ =
                GetInput(operation, operands, kInputLayerNormWeightsTensor);  // optional
        forget_layer_norm_weights_ =
                GetInput(operation, operands, kForgetLayerNormWeightsTensor);  // optional
        cell_layer_norm_weights_ =
                GetInput(operation, operands, kCellLayerNormWeightsTensor);  // optional
        output_layer_norm_weights_ =
                GetInput(operation, operands, kOutputLayerNormWeightsTensor);  // optional
    } else {
        // For LSTM from HAL v1.0 assign operands with no values.
        // NOTE(review): `no_value` is a function-local static shared by every
        // LSTMCell constructed on this path; it is written with the same value
        // each time and only ever read afterwards, so the sharing is benign as
        // long as nothing mutates these placeholder operands.
        static RunTimeOperandInfo no_value;
        no_value.lifetime = OperandLifeTime::NO_VALUE;

        input_layer_norm_weights_ = &no_value;
        forget_layer_norm_weights_ = &no_value;
        cell_layer_norm_weights_ = &no_value;
        output_layer_norm_weights_ = &no_value;
    }

    output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
    cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
    output_ = GetOutput(operation, operands, kOutputTensor);

    scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
}
121
// static
// Validates the shapes and presence/absence of every LSTM weight, bias and
// layer-norm tensor against n_input/n_output/n_cell, and derives the derived
// flags (use_cifg, use_peephole, use_layer_norm, use_projection_*) into
// `params`. Returns false (via the NN_CHECK macros) on the first violation.
bool LSTMCell::CheckInputTensorDimensions(
        const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights,
        const RunTimeOperandInfo* input_to_forget_weights,
        const RunTimeOperandInfo* input_to_cell_weights,
        const RunTimeOperandInfo* input_to_output_weights,
        const RunTimeOperandInfo* recurrent_to_input_weights,
        const RunTimeOperandInfo* recurrent_to_forget_weights,
        const RunTimeOperandInfo* recurrent_to_cell_weights,
        const RunTimeOperandInfo* recurrent_to_output_weights,
        const RunTimeOperandInfo* cell_to_input_weights,
        const RunTimeOperandInfo* cell_to_forget_weights,
        const RunTimeOperandInfo* cell_to_output_weights, const RunTimeOperandInfo* input_gate_bias,
        const RunTimeOperandInfo* forget_gate_bias, const RunTimeOperandInfo* cell_bias,
        const RunTimeOperandInfo* output_gate_bias, const RunTimeOperandInfo* projection_weights,
        const RunTimeOperandInfo* projection_bias,
        const RunTimeOperandInfo* input_layer_norm_weights,
        const RunTimeOperandInfo* forget_layer_norm_weights,
        const RunTimeOperandInfo* cell_layer_norm_weights,
        const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input, uint32_t n_output,
        uint32_t n_cell, LSTMParams* params) {
    // Making sure clipping parameters have valid values.
    // == 0 means no clipping
    //  > 0 means clipping
    NN_CHECK(params->cell_clip >= 0);
    NN_CHECK(params->proj_clip >= 0);

    // Input-to-gate weights are all {n_cell, n_input}; the input-gate weights
    // are optional (absent in CIFG mode).
    if (!IsNullInput(input_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
    }

    NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);

    NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);

    // Recurrent weights are all {n_cell, n_output}.
    if (!IsNullInput(recurrent_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
    }

    NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);

    NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);

    // We make sure the input-gate's parameters are either both present (regular
    // LSTM) or not at all (CIFG-LSTM).
    const bool cifg_weights_all_or_none =
            (!IsNullInput(input_to_input_weights) && !IsNullInput(recurrent_to_input_weights)) ||
            (IsNullInput(input_to_input_weights) && IsNullInput(recurrent_to_input_weights));
    NN_CHECK(cifg_weights_all_or_none);

    // Peephole weights are 1-D vectors of length n_cell.
    if (!IsNullInput(cell_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
    }

    if (!IsNullInput(cell_to_forget_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
    }

    if (!IsNullInput(cell_to_output_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
    }

    // Making sure the peephole weights are there all or none (the input-gate
    // peephole is additionally allowed to be absent in CIFG mode).
    params->use_cifg = IsNullInput(input_to_input_weights);
    const bool peephole_weights_all_or_none =
            ((!IsNullInput(cell_to_input_weights) || params->use_cifg) &&
             !IsNullInput(cell_to_forget_weights) && !IsNullInput(cell_to_output_weights)) ||
            (IsNullInput(cell_to_input_weights) && IsNullInput(cell_to_forget_weights) &&
             IsNullInput(cell_to_output_weights));
    NN_CHECK(peephole_weights_all_or_none);

    // Since we have already checked that weights are all there or none, we can
    // check the existence of only one to get the condition.
    params->use_peephole = !IsNullInput(cell_to_output_weights);
    // Checking output instead of input layer norm weights because the input
    // ones can be omitted in case CIFG LSTM is used.
    params->use_layer_norm = !IsNullInput(output_layer_norm_weights);

    params->use_projection_weight = (projection_weights->lifetime != OperandLifeTime::NO_VALUE);
    params->use_projection_bias = (projection_bias->lifetime != OperandLifeTime::NO_VALUE);

    // Make sure the input gate bias is present only when not a CIFG-LSTM.
    if (params->use_cifg) {
        NN_CHECK(IsNullInput(input_gate_bias));
    } else {
        NN_CHECK_EQ(NumDimensions(input_gate_bias), 1);
        NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
    }

    NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);

    NN_CHECK_EQ(NumDimensions(cell_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);

    NN_CHECK_EQ(NumDimensions(output_gate_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);

    // Projection maps {n_cell} -> {n_output}.
    if (!IsNullInput(projection_weights)) {
        NN_CHECK_EQ(NumDimensions(projection_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
    }

    if (!IsNullInput(projection_bias)) {
        NN_CHECK_EQ(NumDimensions(projection_bias), 1);
        NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
    }

    // Making sure the projection tensors are consistent:
    // 1) If projection weight is not present, then projection bias should not
    //    be present.
    // 2) If projection weight is present, then projection bias is optional.
    // TODO: make sure this is correct.
    const bool projecton_tensors_consistent =
            (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
    NN_CHECK(projecton_tensors_consistent == true);

    // Layer-norm weights are 1-D vectors of length n_cell.
    if (!IsNullInput(input_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(input_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(forget_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(cell_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(output_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(output_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights, 0), n_cell);
    }

    // Layer-norm weights must be provided all-or-none; in CIFG mode the
    // input-gate layer-norm weights must additionally be absent.
    if (params->use_cifg) {
        NN_RET_CHECK(IsNullInput(input_layer_norm_weights))
                << "input_layer_norm_weights are provided while CIFG is used";
        const bool layer_norm_weights_all_or_none_cifg =
                (IsNullInput(forget_layer_norm_weights) && IsNullInput(cell_layer_norm_weights) &&
                 IsNullInput(output_layer_norm_weights)) ||
                (!IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
                 !IsNullInput(output_layer_norm_weights));
        NN_RET_CHECK(layer_norm_weights_all_or_none_cifg);
    } else {
        const bool layer_norm_weights_all_or_none =
                (IsNullInput(input_layer_norm_weights) && IsNullInput(forget_layer_norm_weights) &&
                 IsNullInput(cell_layer_norm_weights) && IsNullInput(output_layer_norm_weights)) ||
                (!IsNullInput(input_layer_norm_weights) &&
                 !IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
                 !IsNullInput(output_layer_norm_weights));
        NN_RET_CHECK(layer_norm_weights_all_or_none);
    }

    return true;
}
293
Prepare(const Operation & operation,std::vector<RunTimeOperandInfo> & operands,Shape * scratchShape,Shape * outputStateShape,Shape * cellStateShape,Shape * outputShape)294 bool LSTMCell::Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
295 Shape* scratchShape, Shape* outputStateShape, Shape* cellStateShape,
296 Shape* outputShape) {
297 // Check we have all the inputs and outputs we need.
298 NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
299 NumInputsWithValues(operation, operands) <= 27);
300 NN_CHECK_EQ(NumOutputs(operation), 4);
301
302 // Inferring batch size, number of outputs and number of cells from the
303 // input tensors.
304 NN_CHECK(NumDimensions(input_) > 1);
305 const uint32_t n_batch = SizeOfDimension(input_, 0);
306 const uint32_t n_input = SizeOfDimension(input_, 1);
307
308 const uint32_t n_cell = SizeOfDimension(input_to_output_weights_, 0);
309 NN_CHECK_EQ(NumDimensions(input_to_output_weights_), 2);
310 NN_CHECK_EQ(SizeOfDimension(input_to_output_weights_, 1), n_input);
311
312 NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights_), 2);
313 NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights_, 0), n_cell);
314 const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights_, 1);
315
316 // Check that input tensor dimensions matches with each other.
317 if (!CheckInputTensorDimensions(
318 input_, input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_,
319 input_to_output_weights_, recurrent_to_input_weights_, recurrent_to_forget_weights_,
320 recurrent_to_cell_weights_, recurrent_to_output_weights_, cell_to_input_weights_,
321 cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_,
322 forget_gate_bias_, cell_bias_, output_gate_bias_, projection_weights_,
323 projection_bias_, input_layer_norm_weights_, forget_layer_norm_weights_,
324 cell_layer_norm_weights_, output_layer_norm_weights_, n_input, n_output, n_cell,
325 ¶ms_)) {
326 return false;
327 }
328
329 // Resize the output and output_state tensors.
330 const Shape& inputShape = input_->shape();
331
332 outputShape->type = inputShape.type;
333 outputShape->dimensions = {n_batch, n_output};
334 outputShape->offset = inputShape.offset;
335 outputShape->scale = inputShape.scale;
336
337 outputStateShape->type = inputShape.type;
338 outputStateShape->dimensions = {n_batch, n_output};
339 outputStateShape->offset = inputShape.offset;
340 outputStateShape->scale = inputShape.scale;
341
342 cellStateShape->type = inputShape.type;
343 cellStateShape->dimensions = {n_batch, n_cell};
344 cellStateShape->offset = inputShape.offset;
345 cellStateShape->scale = inputShape.scale;
346
347 if (params_.use_cifg) {
348 // Reserving space for Cell, Forget, Output gates
349 scratchShape->dimensions = {n_batch, n_cell * 3};
350 } else {
351 // Reserving space for Input, Cell, Forget, Output gates
352 scratchShape->dimensions = {n_batch, n_cell * 4};
353 }
354 scratchShape->type = inputShape.type;
355 scratchShape->offset = inputShape.offset;
356 scratchShape->scale = inputShape.scale;
357
358 return true;
359 }
360
361 // static
LSTMEvalFloat32(const LSTMParams & params,const float * input_buffer,const Shape & input_shape,const float * input_to_input_weights_buffer,const float * input_to_forget_weights_buffer,const float * input_to_cell_weights_buffer,const float * input_to_output_weights_buffer,const Shape & input_to_output_weights_shape,const float * recurrent_to_input_weights_buffer,const float * recurrent_to_forget_weights_buffer,const float * recurrent_to_cell_weights_buffer,const float * recurrent_to_output_weights_buffer,const Shape & recurrent_to_output_weights_shape,const float * cell_to_input_weights_buffer,const float * cell_to_forget_weights_buffer,const float * cell_to_output_weights_buffer,const float * aux_input_buffer,const float * aux_input_to_input_weights_buffer,const float * aux_input_to_forget_weights_buffer,const float * aux_input_to_cell_weights_buffer,const float * aux_input_to_output_weights_buffer,const float * input_gate_bias_buffer,const float * forget_gate_bias_buffer,const float * cell_bias_buffer,const float * output_gate_bias_buffer,const float * projection_weights_buffer,const float * projection_bias_buffer,const float * output_state_in_buffer,const float * cell_state_in_buffer,const float * input_layer_norm_weights_buffer,const float * forget_layer_norm_weights_buffer,const float * cell_layer_norm_weights_buffer,const float * output_layer_norm_weights_buffer,float * output_state_out_buffer,float * cell_state_out_buffer,float * output_buffer,float * scratch_buffer_buffer,bool timeMajor,bool forwardSequence)362 bool LSTMCell::LSTMEvalFloat32(
363 const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
364 const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
365 const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
366 const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
367 const float* recurrent_to_forget_weights_buffer,
368 const float* recurrent_to_cell_weights_buffer,
369 const float* recurrent_to_output_weights_buffer,
370 const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
371 const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
372 const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
373 const float* aux_input_to_forget_weights_buffer,
374 const float* aux_input_to_cell_weights_buffer,
375 const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
376 const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
377 const float* output_gate_bias_buffer, const float* projection_weights_buffer,
378 const float* projection_bias_buffer, const float* output_state_in_buffer,
379 const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
380 const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
381 const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
382 float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
383 bool timeMajor, bool forwardSequence) {
384 NNTRACE_COMP("LSTMCell::LSTMEvalFloat32");
385
386 const uint32_t inputRank = getNumberOfDimensions(input_shape);
387 NN_CHECK(inputRank == 2 || inputRank == 3);
388
389 const uint32_t maxTime =
390 (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
391 const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
392 : getSizeOfDimension(input_shape, 0);
393 const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
394 const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
395 const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);
396
397 Shape batchInputShape = input_shape;
398 batchInputShape.dimensions = {batchSize, inputSize};
399 const uint32_t batchInputSize = batchSize * inputSize;
400 const uint32_t batchOutputSize = batchSize * outputSize;
401
402 std::vector<float> transposedInput;
403 const bool hasAuxInput = (aux_input_buffer != nullptr);
404 std::vector<float> transposedAuxInput;
405 std::vector<float> transposedOutput;
406 Shape transposedInputShape;
407 Shape transposedOutputShape;
408 if (!timeMajor) {
409 transposedInput.resize(maxTime * batchInputSize);
410 transposeFirstTwoDimensions<float>(input_buffer, input_shape, transposedInput.data());
411 if (hasAuxInput) {
412 transposedAuxInput.resize(maxTime * batchInputSize);
413 transposeFirstTwoDimensions<float>(aux_input_buffer, input_shape,
414 transposedAuxInput.data());
415 }
416 transposeFirstTwoDimensions(input_shape, &transposedInputShape);
417 transposedOutput.resize(maxTime * batchOutputSize);
418 transposedOutputShape = transposedInputShape;
419 transposedOutputShape.dimensions[2] = outputSize;
420 }
421 const float* inputData = timeMajor ? input_buffer : transposedInput.data();
422 const float* auxInputData =
423 hasAuxInput ? (timeMajor ? aux_input_buffer : transposedAuxInput.data()) : nullptr;
424 float* outputData = timeMajor ? output_buffer : transposedOutput.data();
425
426 std::vector<float> outputStateInCurrentTimeStep(
427 output_state_in_buffer, output_state_in_buffer + batchSize * outputSize);
428 std::vector<float> cellStateInCurrentTimeStep(cell_state_in_buffer,
429 cell_state_in_buffer + batchSize * numCells);
430 const float* inputCurrentTimeStep =
431 inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
432 const float* auxInputCurrentTimeStep =
433 hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
434 : nullptr;
435 float* outputCurrentTimeStep =
436 outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
437 const int batchInputDelta = forwardSequence ? batchInputSize : -batchInputSize;
438 const int batchOutputDelta = forwardSequence ? batchOutputSize : -batchOutputSize;
439
440 for (int t = 0; t < maxTime; ++t) {
441 LSTMStep(params, inputCurrentTimeStep, batchInputShape, input_to_input_weights_buffer,
442 input_to_forget_weights_buffer, input_to_cell_weights_buffer,
443 input_to_output_weights_buffer, input_to_output_weights_shape,
444 recurrent_to_input_weights_buffer, recurrent_to_forget_weights_buffer,
445 recurrent_to_cell_weights_buffer, recurrent_to_output_weights_buffer,
446 recurrent_to_output_weights_shape, cell_to_input_weights_buffer,
447 cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
448 auxInputCurrentTimeStep, aux_input_to_input_weights_buffer,
449 aux_input_to_forget_weights_buffer, aux_input_to_cell_weights_buffer,
450 aux_input_to_output_weights_buffer, input_gate_bias_buffer,
451 forget_gate_bias_buffer, cell_bias_buffer, output_gate_bias_buffer,
452 projection_weights_buffer, projection_bias_buffer,
453 outputStateInCurrentTimeStep.data(), cellStateInCurrentTimeStep.data(),
454 input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
455 cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
456 output_state_out_buffer, cell_state_out_buffer, outputCurrentTimeStep,
457 scratch_buffer_buffer);
458 inputCurrentTimeStep += batchInputDelta;
459 if (hasAuxInput) {
460 auxInputCurrentTimeStep += batchInputDelta;
461 }
462 outputCurrentTimeStep += batchOutputDelta;
463 outputStateInCurrentTimeStep.assign(output_state_out_buffer,
464 output_state_out_buffer + batchSize * outputSize);
465 cellStateInCurrentTimeStep.assign(cell_state_out_buffer,
466 cell_state_out_buffer + batchSize * numCells);
467 }
468
469 if (!timeMajor) {
470 transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
471 output_buffer);
472 }
473
474 return true;
475 }
476
477 // static
LSTMEvalFloat16(const LSTMParams & params,const _Float16 * input_buffer,const Shape & input_shape,const _Float16 * input_to_input_weights_buffer,const _Float16 * input_to_forget_weights_buffer,const _Float16 * input_to_cell_weights_buffer,const _Float16 * input_to_output_weights_buffer,const Shape & input_to_output_weights_shape,const _Float16 * recurrent_to_input_weights_buffer,const _Float16 * recurrent_to_forget_weights_buffer,const _Float16 * recurrent_to_cell_weights_buffer,const _Float16 * recurrent_to_output_weights_buffer,const Shape & recurrent_to_output_weights_shape,const _Float16 * cell_to_input_weights_buffer,const _Float16 * cell_to_forget_weights_buffer,const _Float16 * cell_to_output_weights_buffer,const _Float16 * aux_input_buffer,const _Float16 * aux_input_to_input_weights_buffer,const _Float16 * aux_input_to_forget_weights_buffer,const _Float16 * aux_input_to_cell_weights_buffer,const _Float16 * aux_input_to_output_weights_buffer,const _Float16 * input_gate_bias_buffer,const _Float16 * forget_gate_bias_buffer,const _Float16 * cell_bias_buffer,const _Float16 * output_gate_bias_buffer,const _Float16 * projection_weights_buffer,const _Float16 * projection_bias_buffer,const _Float16 * output_state_in_buffer,const _Float16 * cell_state_in_buffer,const _Float16 * input_layer_norm_weights_buffer,const _Float16 * forget_layer_norm_weights_buffer,const _Float16 * cell_layer_norm_weights_buffer,const _Float16 * output_layer_norm_weights_buffer,_Float16 * output_state_out_buffer,_Float16 * cell_state_out_buffer,_Float16 * output_buffer,_Float16 * scratch_buffer_buffer,bool timeMajor,bool forwardSequence)478 bool LSTMCell::LSTMEvalFloat16(
479 const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
480 const _Float16* input_to_input_weights_buffer,
481 const _Float16* input_to_forget_weights_buffer,
482 const _Float16* input_to_cell_weights_buffer,
483 const _Float16* input_to_output_weights_buffer, const Shape& input_to_output_weights_shape,
484 const _Float16* recurrent_to_input_weights_buffer,
485 const _Float16* recurrent_to_forget_weights_buffer,
486 const _Float16* recurrent_to_cell_weights_buffer,
487 const _Float16* recurrent_to_output_weights_buffer,
488 const Shape& recurrent_to_output_weights_shape,
489 const _Float16* cell_to_input_weights_buffer, const _Float16* cell_to_forget_weights_buffer,
490 const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
491 const _Float16* aux_input_to_input_weights_buffer,
492 const _Float16* aux_input_to_forget_weights_buffer,
493 const _Float16* aux_input_to_cell_weights_buffer,
494 const _Float16* aux_input_to_output_weights_buffer, const _Float16* input_gate_bias_buffer,
495 const _Float16* forget_gate_bias_buffer, const _Float16* cell_bias_buffer,
496 const _Float16* output_gate_bias_buffer, const _Float16* projection_weights_buffer,
497 const _Float16* projection_bias_buffer, const _Float16* output_state_in_buffer,
498 const _Float16* cell_state_in_buffer, const _Float16* input_layer_norm_weights_buffer,
499 const _Float16* forget_layer_norm_weights_buffer,
500 const _Float16* cell_layer_norm_weights_buffer,
501 const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
502 _Float16* cell_state_out_buffer, _Float16* output_buffer, _Float16* scratch_buffer_buffer,
503 bool timeMajor, bool forwardSequence) {
504 NNTRACE_COMP("LSTMCell::LSTMEvalFloat16");
505
506 const uint32_t inputRank = getNumberOfDimensions(input_shape);
507 NN_CHECK(inputRank == 2 || inputRank == 3);
508
509 const uint32_t maxTime =
510 (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
511 const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
512 : getSizeOfDimension(input_shape, 0);
513 const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
514 const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
515 const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);
516
517 Shape batchInputShape = input_shape;
518 batchInputShape.dimensions = {batchSize, inputSize};
519 const uint32_t batchInputSize = batchSize * inputSize;
520 const uint32_t batchOutputSize = batchSize * outputSize;
521
522 std::vector<float> input_float32(maxTime * batchInputSize);
523 convertFloat16ToFloat32(input_buffer, &input_float32);
524 std::vector<float> input_to_input_weights_float32(numCells * inputSize);
525 if (input_to_input_weights_buffer != nullptr) {
526 convertFloat16ToFloat32(input_to_input_weights_buffer, &input_to_input_weights_float32);
527 }
528 std::vector<float> input_to_forget_weights_float32(numCells * inputSize);
529 convertFloat16ToFloat32(input_to_forget_weights_buffer, &input_to_forget_weights_float32);
530 std::vector<float> input_to_cell_weights_float32(numCells * inputSize);
531 convertFloat16ToFloat32(input_to_cell_weights_buffer, &input_to_cell_weights_float32);
532 std::vector<float> input_to_output_weights_float32(numCells * inputSize);
533 convertFloat16ToFloat32(input_to_output_weights_buffer, &input_to_output_weights_float32);
534
535 std::vector<float> recurrent_to_input_weights_float32(numCells * outputSize);
536 if (recurrent_to_input_weights_buffer != nullptr) {
537 convertFloat16ToFloat32(recurrent_to_input_weights_buffer,
538 &recurrent_to_input_weights_float32);
539 }
540 std::vector<float> recurrent_to_forget_weights_float32(numCells * outputSize);
541 convertFloat16ToFloat32(recurrent_to_forget_weights_buffer,
542 &recurrent_to_forget_weights_float32);
543 std::vector<float> recurrent_to_cell_weights_float32(numCells * outputSize);
544 convertFloat16ToFloat32(recurrent_to_cell_weights_buffer, &recurrent_to_cell_weights_float32);
545 std::vector<float> recurrent_to_output_weights_float32(numCells * outputSize);
546 convertFloat16ToFloat32(recurrent_to_output_weights_buffer,
547 &recurrent_to_output_weights_float32);
548
549 std::vector<float> cell_to_input_weights_float32(numCells);
550 if (cell_to_input_weights_buffer != nullptr) {
551 convertFloat16ToFloat32(cell_to_input_weights_buffer, &cell_to_input_weights_float32);
552 }
553 std::vector<float> cell_to_forget_weights_float32(numCells);
554 if (cell_to_forget_weights_buffer != nullptr) {
555 convertFloat16ToFloat32(cell_to_forget_weights_buffer, &cell_to_forget_weights_float32);
556 }
557 std::vector<float> cell_to_output_weights_float32(numCells);
558 if (cell_to_output_weights_buffer != nullptr) {
559 convertFloat16ToFloat32(cell_to_output_weights_buffer, &cell_to_output_weights_float32);
560 }
561
562 std::vector<float> aux_input_float32(maxTime * batchInputSize);
563 if (aux_input_buffer != nullptr) {
564 convertFloat16ToFloat32(aux_input_buffer, &aux_input_float32);
565 }
566 std::vector<float> aux_input_to_input_weights_float32(numCells * inputSize);
567 if (aux_input_to_input_weights_buffer != nullptr) {
568 convertFloat16ToFloat32(aux_input_to_input_weights_buffer,
569 &aux_input_to_input_weights_float32);
570 }
571 std::vector<float> aux_input_to_forget_weights_float32(numCells * inputSize);
572 if (aux_input_to_forget_weights_buffer != nullptr) {
573 convertFloat16ToFloat32(aux_input_to_forget_weights_buffer,
574 &aux_input_to_forget_weights_float32);
575 }
576 std::vector<float> aux_input_to_cell_weights_float32(numCells * inputSize);
577 if (aux_input_to_cell_weights_buffer != nullptr) {
578 convertFloat16ToFloat32(aux_input_to_cell_weights_buffer,
579 &aux_input_to_cell_weights_float32);
580 }
581 std::vector<float> aux_input_to_output_weights_float32(numCells * inputSize);
582 if (aux_input_to_output_weights_buffer != nullptr) {
583 convertFloat16ToFloat32(aux_input_to_output_weights_buffer,
584 &aux_input_to_output_weights_float32);
585 }
586
587 std::vector<float> input_gate_bias_float32(numCells);
588 if (input_gate_bias_buffer != nullptr) {
589 convertFloat16ToFloat32(input_gate_bias_buffer, &input_gate_bias_float32);
590 }
591 std::vector<float> forget_gate_bias_float32(numCells);
592 convertFloat16ToFloat32(forget_gate_bias_buffer, &forget_gate_bias_float32);
593 std::vector<float> cell_bias_float32(numCells);
594 convertFloat16ToFloat32(cell_bias_buffer, &cell_bias_float32);
595 std::vector<float> output_gate_bias_float32(numCells);
596 convertFloat16ToFloat32(output_gate_bias_buffer, &output_gate_bias_float32);
597
598 std::vector<float> projection_weights_float32(numCells * outputSize);
599 if (projection_weights_buffer != nullptr) {
600 convertFloat16ToFloat32(projection_weights_buffer, &projection_weights_float32);
601 }
602 std::vector<float> projection_bias_float32(outputSize);
603 if (projection_bias_buffer != nullptr) {
604 convertFloat16ToFloat32(projection_bias_buffer, &projection_bias_float32);
605 }
606
607 std::vector<float> input_layer_norm_weights_float32(numCells);
608 if (input_layer_norm_weights_buffer != nullptr) {
609 convertFloat16ToFloat32(input_layer_norm_weights_buffer, &input_layer_norm_weights_float32);
610 }
611 std::vector<float> forget_layer_norm_weights_float32(numCells);
612 if (forget_layer_norm_weights_buffer != nullptr) {
613 convertFloat16ToFloat32(forget_layer_norm_weights_buffer,
614 &forget_layer_norm_weights_float32);
615 }
616 std::vector<float> cell_layer_norm_weights_float32(numCells);
617 if (cell_layer_norm_weights_buffer != nullptr) {
618 convertFloat16ToFloat32(cell_layer_norm_weights_buffer, &cell_layer_norm_weights_float32);
619 }
620 std::vector<float> output_layer_norm_weights_float32(numCells);
621 if (output_layer_norm_weights_buffer != nullptr) {
622 convertFloat16ToFloat32(output_layer_norm_weights_buffer,
623 &output_layer_norm_weights_float32);
624 }
625
626 std::vector<float> output_state_out_float32(batchOutputSize);
627 convertFloat16ToFloat32(output_state_out_buffer, &output_state_out_float32);
628 std::vector<float> cell_state_out_float32(batchSize * numCells);
629 convertFloat16ToFloat32(cell_state_out_buffer, &cell_state_out_float32);
630
631 std::vector<float> output_float32(maxTime * batchOutputSize);
632 convertFloat16ToFloat32(output_buffer, &output_float32);
633 std::vector<float> scratch_buffer_float32(params.use_cifg ? 3 * batchSize * numCells
634 : 4 * batchSize * numCells);
635 convertFloat16ToFloat32(scratch_buffer_buffer, &scratch_buffer_float32);
636
637 std::vector<float> transposedInput;
638 const bool hasAuxInput = (aux_input_buffer != nullptr);
639 std::vector<float> transposedAuxInput;
640 std::vector<float> transposedOutput;
641 Shape transposedInputShape;
642 Shape transposedOutputShape;
643 if (!timeMajor) {
644 transposedInput.resize(maxTime * batchInputSize);
645 transposeFirstTwoDimensions<float>(input_float32.data(), input_shape,
646 transposedInput.data());
647 if (hasAuxInput) {
648 transposedAuxInput.resize(maxTime * batchInputSize);
649 transposeFirstTwoDimensions<float>(aux_input_float32.data(), input_shape,
650 transposedAuxInput.data());
651 }
652 transposeFirstTwoDimensions(input_shape, &transposedInputShape);
653 transposedOutput.resize(maxTime * batchOutputSize);
654 transposedOutputShape = transposedInputShape;
655 transposedOutputShape.dimensions[2] = outputSize;
656 }
657 const float* inputData = timeMajor ? input_float32.data() : transposedInput.data();
658 const float* auxInputData =
659 hasAuxInput ? (timeMajor ? aux_input_float32.data() : transposedAuxInput.data())
660 : nullptr;
661 float* outputData = timeMajor ? output_float32.data() : transposedOutput.data();
662
663 std::vector<float> outputStateInCurrentTimeStep(batchSize * outputSize);
664 convertFloat16ToFloat32(output_state_in_buffer, &outputStateInCurrentTimeStep);
665 std::vector<float> cellStateInCurrentTimeStep(batchSize * numCells);
666 convertFloat16ToFloat32(cell_state_in_buffer, &cellStateInCurrentTimeStep);
667
668 const float* inputCurrentTimeStep =
669 inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
670 const float* auxInputCurrentTimeStep =
671 hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
672 : nullptr;
673 float* outputCurrentTimeStep =
674 outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
675 const int batchInputDelta = forwardSequence ? batchInputSize : -batchInputSize;
676 const int batchOutputDelta = forwardSequence ? batchOutputSize : -batchOutputSize;
677
678 for (int t = 0; t < maxTime; ++t) {
679 LSTMStep(params, inputCurrentTimeStep, batchInputShape,
680 input_to_input_weights_float32.data(), input_to_forget_weights_float32.data(),
681 input_to_cell_weights_float32.data(), input_to_output_weights_float32.data(),
682 input_to_output_weights_shape, recurrent_to_input_weights_float32.data(),
683 recurrent_to_forget_weights_float32.data(),
684 recurrent_to_cell_weights_float32.data(),
685 recurrent_to_output_weights_float32.data(), recurrent_to_output_weights_shape,
686 cell_to_input_weights_float32.data(), cell_to_forget_weights_float32.data(),
687 cell_to_output_weights_float32.data(), auxInputCurrentTimeStep,
688 aux_input_to_input_weights_float32.data(),
689 aux_input_to_forget_weights_float32.data(),
690 aux_input_to_cell_weights_float32.data(),
691 aux_input_to_output_weights_float32.data(), input_gate_bias_float32.data(),
692 forget_gate_bias_float32.data(), cell_bias_float32.data(),
693 output_gate_bias_float32.data(), projection_weights_float32.data(),
694 projection_bias_float32.data(), outputStateInCurrentTimeStep.data(),
695 cellStateInCurrentTimeStep.data(), input_layer_norm_weights_float32.data(),
696 forget_layer_norm_weights_float32.data(), cell_layer_norm_weights_float32.data(),
697 output_layer_norm_weights_float32.data(), output_state_out_float32.data(),
698 cell_state_out_float32.data(), outputCurrentTimeStep,
699 scratch_buffer_float32.data());
700 inputCurrentTimeStep += batchInputDelta;
701 if (hasAuxInput) {
702 auxInputCurrentTimeStep += batchInputDelta;
703 }
704 outputCurrentTimeStep += batchOutputDelta;
705 outputStateInCurrentTimeStep = output_state_out_float32;
706 cellStateInCurrentTimeStep = cell_state_out_float32;
707 }
708
709 if (!timeMajor) {
710 transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
711 output_float32.data());
712 }
713
714 convertFloat32ToFloat16(output_state_out_float32, output_state_out_buffer);
715 convertFloat32ToFloat16(cell_state_out_float32, cell_state_out_buffer);
716 convertFloat32ToFloat16(output_float32, output_buffer);
717 convertFloat32ToFloat16(scratch_buffer_float32, scratch_buffer_buffer);
718 return true;
719 }
720
// static
//
// Computes one LSTM time step for a whole batch of inputs.
//
// All "_buffer" arguments are raw float arrays. The optional ones may be
// nullptr when the corresponding feature is disabled in |params|:
//  - input gate weights/bias when params.use_cifg is set (input gate is
//    derived as 1 - forget gate instead),
//  - cell_to_* peephole weights unless params.use_peephole,
//  - *_layer_norm_weights unless params.use_layer_norm,
//  - projection weights/bias unless params.use_projection_weight /
//    params.use_projection_bias,
//  - aux_input_buffer and the aux_input_to_* weights when there is no
//    auxiliary input.
//
// scratch_buffer_buffer provides working storage for the gate activations:
// 4 * n_cell * n_batch floats normally, 3 * n_cell * n_batch under CIFG
// (no input gate scratch is carved out in that case).
//
// Outputs: cell_state_out_buffer (new cell state), output_buffer (the cell
// output, optionally projected), and output_state_out_buffer (a copy of the
// output, fed back as the recurrent state of the next step).
bool LSTMCell::LSTMStep(
        const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
        const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
        const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
        const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
        const float* recurrent_to_forget_weights_buffer,
        const float* recurrent_to_cell_weights_buffer,
        const float* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
        const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
        const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
        const float* aux_input_to_forget_weights_buffer,
        const float* aux_input_to_cell_weights_buffer,
        const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
        const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
        const float* output_gate_bias_buffer, const float* projection_weights_buffer,
        const float* projection_bias_buffer, const float* output_state_in_buffer,
        const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
        const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
        const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
        float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer) {
    NNTRACE_COMP("LSTMCell::LSTMStep");

    // Derive the problem dimensions from the tensor shapes.
    const uint32_t n_batch = input_shape.dimensions[0];
    const uint32_t n_input = input_shape.dimensions[1];
    // n_cell and n_output will be the same size when there is no projection.
    const uint32_t n_cell = input_to_output_weights_shape.dimensions[0];
    const uint32_t n_output = recurrent_to_output_weights_shape.dimensions[1];
    // The aux input, when present, has the same feature size as the main input.
    const uint32_t n_aux_input = aux_input_buffer == nullptr ? 0 : n_input;

    // Index the scratch buffers pointers to the global scratch buffer.
    // Under CIFG there is no input gate, so only three n_cell*n_batch slices
    // are carved out; otherwise four.
    float* input_gate_scratch = nullptr;
    float* cell_scratch = nullptr;
    float* forget_gate_scratch = nullptr;
    float* output_gate_scratch = nullptr;
    if (params.use_cifg) {
        cell_scratch = scratch_buffer_buffer;
        forget_gate_scratch = cell_scratch + n_cell * n_batch;
        output_gate_scratch = cell_scratch + 2 * n_cell * n_batch;
    } else {
        input_gate_scratch = scratch_buffer_buffer;
        cell_scratch = input_gate_scratch + n_cell * n_batch;
        forget_gate_scratch = input_gate_scratch + 2 * n_cell * n_batch;
        output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
    }

    if (!params.use_layer_norm) {
        // Initialize scratch buffers with bias.
        if (!params.use_cifg) {
            tflite::tensor_utils::VectorBatchVectorAssign(input_gate_bias_buffer, n_cell, n_batch,
                                                          input_gate_scratch);
        }
        tflite::tensor_utils::VectorBatchVectorAssign(forget_gate_bias_buffer, n_cell, n_batch,
                                                      forget_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAssign(cell_bias_buffer, n_cell, n_batch,
                                                      cell_scratch);
        tflite::tensor_utils::VectorBatchVectorAssign(output_gate_bias_buffer, n_cell, n_batch,
                                                      output_gate_scratch);
    } else {
        // Initialize scratch buffers with zeroes. With layer normalization the
        // bias is added after normalization (see the gate update phases below),
        // not accumulated into the matmuls.
        if (!params.use_cifg) {
            tflite::tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
        }
        tflite::tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
        tflite::tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
        tflite::tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
    }

    // For each batch and cell: compute input_weight * input.
    if (!params.use_cifg) {
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                input_to_input_weights_buffer, n_cell, n_input, input_buffer, n_batch,
                input_gate_scratch, /*result_stride=*/1);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            input_to_forget_weights_buffer, n_cell, n_input, input_buffer, n_batch,
            forget_gate_scratch, /*result_stride=*/1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_cell_weights_buffer, n_cell,
                                                              n_input, input_buffer, n_batch,
                                                              cell_scratch, /*result_stride=*/1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            input_to_output_weights_buffer, n_cell, n_input, input_buffer, n_batch,
            output_gate_scratch, /*result_stride=*/1);

    // If auxiliary input is available then compute aux_input_weight * aux_input
    if (aux_input_buffer != nullptr) {
        if (!params.use_cifg) {
            tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                    aux_input_to_input_weights_buffer, n_cell, n_aux_input, aux_input_buffer,
                    n_batch, input_gate_scratch,
                    /*result_stride=*/1);
        }

        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_forget_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                forget_gate_scratch, /*result_stride=*/1);
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_cell_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                cell_scratch, /*result_stride=*/1);
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_output_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                output_gate_scratch, /*result_stride=*/1);
    }

    // For each batch and cell: compute recurrent_weight * output_state.
    if (!params.use_cifg) {
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                recurrent_to_input_weights_buffer, n_cell, n_output, output_state_in_buffer,
                n_batch, input_gate_scratch,
                /*result_stride=*/1);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_forget_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            forget_gate_scratch, /*result_stride=*/1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_cell_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            cell_scratch, /*result_stride=*/1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_output_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            output_gate_scratch, /*result_stride=*/1);

    // For each batch and cell: update input gate. Optional peephole term uses
    // the *previous* cell state; layer norm (if enabled) normalizes the
    // accumulated pre-activation, scales it, then adds the bias.
    if (!params.use_cifg) {
        if (params.use_peephole) {
            tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
                    cell_to_input_weights_buffer, n_cell, cell_state_in_buffer, n_batch,
                    input_gate_scratch);
        }
        if (params.use_layer_norm) {
            tflite::tensor_utils::MeanStddevNormalization(input_gate_scratch, input_gate_scratch,
                                                          n_cell, n_batch, kLayerNormEpsilon);
            tflite::tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weights_buffer,
                                                                n_cell, input_gate_scratch, n_batch,
                                                                input_gate_scratch);
            tflite::tensor_utils::VectorBatchVectorAdd(input_gate_bias_buffer, n_cell, n_batch,
                                                       input_gate_scratch);
        }
        tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                                   input_gate_scratch);
    }

    // For each batch and cell: update forget gate.
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_forget_weights_buffer,
                                                                      n_cell, cell_state_in_buffer,
                                                                      n_batch, forget_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(forget_gate_scratch, forget_gate_scratch,
                                                      n_cell, n_batch, kLayerNormEpsilon);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weights_buffer,
                                                            n_cell, forget_gate_scratch, n_batch,
                                                            forget_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(forget_gate_bias_buffer, n_cell, n_batch,
                                                   forget_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                               forget_gate_scratch);

    // For each batch and cell: update the cell.
    // cell_state_out = forget_gate .* cell_state_in
    //                + gate_in .* activation(cell_candidate)
    // where gate_in is (1 - forget_gate) under CIFG, else the input gate.
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, n_batch,
                                                      kLayerNormEpsilon);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(cell_layer_norm_weights_buffer, n_cell,
                                                            cell_scratch, n_batch, cell_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(cell_bias_buffer, n_cell, n_batch, cell_scratch);
    }
    tflite::tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_in_buffer,
                                                   n_batch * n_cell, cell_state_out_buffer);
    tflite::tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, params.activation,
                                                  cell_scratch);
    if (params.use_cifg) {
        // CIFG: coupled input gate = 1 - forget gate (computed in place).
        tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                                         forget_gate_scratch);
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    } else {
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    }
    if (params.cell_clip > 0.0) {
        tflite::tensor_utils::ClipVector(cell_state_out_buffer, n_batch * n_cell, params.cell_clip,
                                         cell_state_out_buffer);
    }

    // For each batch and cell: update the output gate. Note the peephole term
    // here uses the freshly computed cell state (cell_state_out_buffer), unlike
    // the input/forget gates which peek at the previous cell state.
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_output_weights_buffer,
                                                                      n_cell, cell_state_out_buffer,
                                                                      n_batch, output_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(output_gate_scratch, output_gate_scratch,
                                                      n_cell, n_batch, kLayerNormEpsilon);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weights_buffer,
                                                            n_cell, output_gate_scratch, n_batch,
                                                            output_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(output_gate_bias_buffer, n_cell, n_batch,
                                                   output_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                               output_gate_scratch);
    // output = output_gate .* activation(cell_state_out); cell_scratch is
    // reused here as a temporary for the activated cell state.
    tflite::tensor_utils::ApplyActivationToVector(cell_state_out_buffer, n_batch * n_cell,
                                                  params.activation, cell_scratch);
    tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                                   n_batch * n_cell, output_gate_scratch);

    // For each batch: update the projection and output_state.
    if (params.use_projection_weight) {
        if (params.use_projection_bias) {
            tflite::tensor_utils::VectorBatchVectorAssign(projection_bias_buffer, n_output, n_batch,
                                                          output_buffer);
        } else {
            tflite::tensor_utils::ZeroVector(output_buffer, n_batch * n_output);
        }
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                projection_weights_buffer, n_output, n_cell, output_gate_scratch, n_batch,
                output_buffer,
                /*result_stride=*/1);
        if (params.proj_clip > 0.0) {
            tflite::tensor_utils::ClipVector(output_buffer, n_batch * n_output, params.proj_clip,
                                             output_buffer);
        }
    } else {
        // No projection: n_output == n_cell, the gated output is used directly.
        tflite::tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, output_buffer);
    }
    // The output doubles as the recurrent state for the next time step.
    tflite::tensor_utils::CopyVector(output_buffer, n_batch * n_output, output_state_out_buffer);
    return true;
}
951
Eval()952 bool LSTMCell::Eval() {
953 switch (input_->type) {
954 case OperandType::TENSOR_FLOAT32: {
955 LSTMEvalFloat32(params_, GetBuffer<const float>(input_), input_->shape(),
956 GetBuffer<const float>(input_to_input_weights_),
957 GetBuffer<const float>(input_to_forget_weights_),
958 GetBuffer<const float>(input_to_cell_weights_),
959 GetBuffer<const float>(input_to_output_weights_),
960 input_to_output_weights_->shape(),
961 GetBuffer<const float>(recurrent_to_input_weights_),
962 GetBuffer<const float>(recurrent_to_forget_weights_),
963 GetBuffer<const float>(recurrent_to_cell_weights_),
964 GetBuffer<const float>(recurrent_to_output_weights_),
965 recurrent_to_output_weights_->shape(),
966 GetBuffer<const float>(cell_to_input_weights_),
967 GetBuffer<const float>(cell_to_forget_weights_),
968 GetBuffer<const float>(cell_to_output_weights_),
969 /*aux_input_buffer=*/nullptr,
970 /*aux_input_to_input_weights_buffer=*/nullptr,
971 /*aux_input_to_forget_weights_buffer=*/nullptr,
972 /*aux_input_to_cell_weights_buffer=*/nullptr,
973 /*aux_input_to_output_weights_buffer=*/nullptr,
974 GetBuffer<const float>(input_gate_bias_),
975 GetBuffer<const float>(forget_gate_bias_),
976 GetBuffer<const float>(cell_bias_),
977 GetBuffer<const float>(output_gate_bias_),
978 GetBuffer<const float>(projection_weights_),
979 GetBuffer<const float>(projection_bias_),
980 GetBuffer<const float>(output_state_in_),
981 GetBuffer<const float>(cell_state_in_),
982 GetBuffer<const float>(input_layer_norm_weights_),
983 GetBuffer<const float>(forget_layer_norm_weights_),
984 GetBuffer<const float>(cell_layer_norm_weights_),
985 GetBuffer<const float>(output_layer_norm_weights_),
986 GetBuffer<float>(output_state_out_), GetBuffer<float>(cell_state_out_),
987 GetBuffer<float>(output_), GetBuffer<float>(scratch_buffer_));
988 } break;
989 case OperandType::TENSOR_FLOAT16: {
990 LSTMEvalFloat16(params_, GetBuffer<const _Float16>(input_), input_->shape(),
991 GetOptionalBuffer<const _Float16>(input_to_input_weights_),
992 GetBuffer<const _Float16>(input_to_forget_weights_),
993 GetBuffer<const _Float16>(input_to_cell_weights_),
994 GetBuffer<const _Float16>(input_to_output_weights_),
995 input_to_output_weights_->shape(),
996 GetOptionalBuffer<const _Float16>(recurrent_to_input_weights_),
997 GetBuffer<const _Float16>(recurrent_to_forget_weights_),
998 GetBuffer<const _Float16>(recurrent_to_cell_weights_),
999 GetBuffer<const _Float16>(recurrent_to_output_weights_),
1000 recurrent_to_output_weights_->shape(),
1001 GetOptionalBuffer<const _Float16>(cell_to_input_weights_),
1002 GetOptionalBuffer<const _Float16>(cell_to_forget_weights_),
1003 GetOptionalBuffer<const _Float16>(cell_to_output_weights_),
1004 /*aux_input_buffer=*/nullptr,
1005 /*aux_input_to_input_weights_buffer=*/nullptr,
1006 /*aux_input_to_forget_weights_buffer=*/nullptr,
1007 /*aux_input_to_cell_weights_buffer=*/nullptr,
1008 /*aux_input_to_output_weights_buffer=*/nullptr,
1009 GetOptionalBuffer<const _Float16>(input_gate_bias_),
1010 GetBuffer<const _Float16>(forget_gate_bias_),
1011 GetBuffer<const _Float16>(cell_bias_),
1012 GetBuffer<const _Float16>(output_gate_bias_),
1013 GetOptionalBuffer<const _Float16>(projection_weights_),
1014 GetOptionalBuffer<const _Float16>(projection_bias_),
1015 GetBuffer<const _Float16>(output_state_in_),
1016 GetBuffer<const _Float16>(cell_state_in_),
1017 GetOptionalBuffer<const _Float16>(input_layer_norm_weights_),
1018 GetOptionalBuffer<const _Float16>(forget_layer_norm_weights_),
1019 GetOptionalBuffer<const _Float16>(cell_layer_norm_weights_),
1020 GetOptionalBuffer<const _Float16>(output_layer_norm_weights_),
1021 GetBuffer<_Float16>(output_state_out_),
1022 GetBuffer<_Float16>(cell_state_out_), GetBuffer<_Float16>(output_),
1023 GetBuffer<_Float16>(scratch_buffer_));
1024 } break;
1025 default: {
1026 LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
1027 return false;
1028 }
1029 }
1030 return true;
1031 }
1032
1033 } // namespace nn
1034 } // namespace android
1035