/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "BidirectionalSequenceLSTM.h"

#include "CpuExecutor.h"
#include "CpuOperationUtils.h"
#include "HalInterfaces.h"
#include "OperationsUtils.h"

#include "Tracing.h"

namespace android {
namespace nn {

namespace {

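// Helpers for reading operand buffers; GetOptionalBuffer returns nullptr when
// an optional operand was omitted by the caller.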
template <typename T>
inline T* GetBuffer(RunTimeOperandInfo* operand) {
    return reinterpret_cast<T*>(operand->buffer);
}

template <typename T>
inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
    return reinterpret_cast<const T*>(operand->buffer);
}

template <typename T>
inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
    return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
}

}  // anonymous namespace

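// Caches pointers to all of the operation's operands and reads the scalar
// parameters into params_; shape validation is deferred to Prepare().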
BidirectionalSequenceLSTM::BidirectionalSequenceLSTM(const Operation& operation,
                                                     std::vector<RunTimeOperandInfo>& operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    fw_input_to_input_weights_ =
            GetInput(operation, operands, kFwInputToInputWeightsTensor);  // optional
    fw_input_to_forget_weights_ = GetInput(operation, operands, kFwInputToForgetWeightsTensor);
    fw_input_to_cell_weights_ = GetInput(operation, operands, kFwInputToCellWeightsTensor);
    fw_input_to_output_weights_ = GetInput(operation, operands, kFwInputToOutputWeightsTensor);

    fw_recurrent_to_input_weights_ =
            GetInput(operation, operands, kFwRecurrentToInputWeightsTensor);  // optional
    fw_recurrent_to_forget_weights_ =
            GetInput(operation, operands, kFwRecurrentToForgetWeightsTensor);
    fw_recurrent_to_cell_weights_ = GetInput(operation, operands, kFwRecurrentToCellWeightsTensor);
    fw_recurrent_to_output_weights_ =
            GetInput(operation, operands, kFwRecurrentToOutputWeightsTensor);

    fw_cell_to_input_weights_ =
            GetInput(operation, operands, kFwCellToInputWeightsTensor);  // optional
    fw_cell_to_forget_weights_ =
            GetInput(operation, operands, kFwCellToForgetWeightsTensor);  // optional
    fw_cell_to_output_weights_ =
            GetInput(operation, operands, kFwCellToOutputWeightsTensor);  // optional

    fw_input_gate_bias_ = GetInput(operation, operands, kFwInputGateBiasTensor);
    fw_forget_gate_bias_ = GetInput(operation, operands, kFwForgetGateBiasTensor);
    fw_cell_bias_ = GetInput(operation, operands, kFwCellGateBiasTensor);
    fw_output_gate_bias_ = GetInput(operation, operands, kFwOutputGateBiasTensor);

    fw_projection_weights_ = GetInput(operation, operands, kFwProjectionWeightsTensor);  // optional
    fw_projection_bias_ = GetInput(operation, operands, kFwProjectionBiasTensor);  // optional

    fw_activation_state_ = GetInput(operation, operands, kFwInputActivationStateTensor);
    fw_cell_state_ = GetInput(operation, operands, kFwInputCellStateTensor);

    bw_input_to_input_weights_ =
            GetInput(operation, operands, kBwInputToInputWeightsTensor);  // optional
    bw_input_to_forget_weights_ = GetInput(operation, operands, kBwInputToForgetWeightsTensor);
    bw_input_to_cell_weights_ = GetInput(operation, operands, kBwInputToCellWeightsTensor);
    bw_input_to_output_weights_ = GetInput(operation, operands, kBwInputToOutputWeightsTensor);

    bw_recurrent_to_input_weights_ =
            GetInput(operation, operands, kBwRecurrentToInputWeightsTensor);  // optional
    bw_recurrent_to_forget_weights_ =
            GetInput(operation, operands, kBwRecurrentToForgetWeightsTensor);
    bw_recurrent_to_cell_weights_ = GetInput(operation, operands, kBwRecurrentToCellWeightsTensor);
    bw_recurrent_to_output_weights_ =
            GetInput(operation, operands, kBwRecurrentToOutputWeightsTensor);

    bw_cell_to_input_weights_ =
            GetInput(operation, operands, kBwCellToInputWeightsTensor);  // optional
    bw_cell_to_forget_weights_ =
            GetInput(operation, operands, kBwCellToForgetWeightsTensor);  // optional
    bw_cell_to_output_weights_ =
            GetInput(operation, operands, kBwCellToOutputWeightsTensor);  // optional

    bw_input_gate_bias_ = GetInput(operation, operands, kBwInputGateBiasTensor);
    bw_forget_gate_bias_ = GetInput(operation, operands, kBwForgetGateBiasTensor);
    bw_cell_bias_ = GetInput(operation, operands, kBwCellGateBiasTensor);
    bw_output_gate_bias_ = GetInput(operation, operands, kBwOutputGateBiasTensor);

    bw_projection_weights_ = GetInput(operation, operands, kBwProjectionWeightsTensor);  // optional
    bw_projection_bias_ = GetInput(operation, operands, kBwProjectionBiasTensor);  // optional

    bw_activation_state_ = GetInput(operation, operands, kBwInputActivationStateTensor);
    bw_cell_state_ = GetInput(operation, operands, kBwInputCellStateTensor);

    aux_input_ = GetInput(operation, operands, kAuxInputTensor);
    fw_aux_input_to_input_weights_ = GetInput(operation, operands, kFwAuxInputToInputWeightsTensor);
    fw_aux_input_to_forget_weights_ =
            GetInput(operation, operands, kFwAuxInputToForgetWeightsTensor);
    fw_aux_input_to_cell_weights_ = GetInput(operation, operands, kFwAuxInputToCellWeightsTensor);
    fw_aux_input_to_output_weights_ =
            GetInput(operation, operands, kFwAuxInputToOutputWeightsTensor);
    bw_aux_input_to_input_weights_ = GetInput(operation, operands, kBwAuxInputToInputWeightsTensor);
    bw_aux_input_to_forget_weights_ =
            GetInput(operation, operands, kBwAuxInputToForgetWeightsTensor);
    bw_aux_input_to_cell_weights_ = GetInput(operation, operands, kBwAuxInputToCellWeightsTensor);
    bw_aux_input_to_output_weights_ =
            GetInput(operation, operands, kBwAuxInputToOutputWeightsTensor);

    fw_input_layer_norm_weights_ = GetInput(operation, operands, kFwInputLayerNormWeightsTensor);
    fw_forget_layer_norm_weights_ = GetInput(operation, operands, kFwForgetLayerNormWeightsTensor);
    fw_cell_layer_norm_weights_ = GetInput(operation, operands, kFwCellLayerNormWeightsTensor);
    fw_output_layer_norm_weights_ = GetInput(operation, operands, kFwOutputLayerNormWeightsTensor);
    bw_input_layer_norm_weights_ = GetInput(operation, operands, kBwInputLayerNormWeightsTensor);
    bw_forget_layer_norm_weights_ = GetInput(operation, operands, kBwForgetLayerNormWeightsTensor);
    bw_cell_layer_norm_weights_ = GetInput(operation, operands, kBwCellLayerNormWeightsTensor);
    bw_output_layer_norm_weights_ = GetInput(operation, operands, kBwOutputLayerNormWeightsTensor);

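    // Scalar parameters. The clipping thresholds are provided as FLOAT32 or
    // FLOAT16 scalars to match the input tensor type, so they are read accordingly.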
    params_.activation = static_cast<TfLiteFusedActivation>(
            getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam)));
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        params_.cell_clip = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
        params_.proj_clip = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
    } else {
        params_.cell_clip = static_cast<float>(
                getScalarData<_Float16>(*GetInput(operation, operands, kCellClipParam)));
        params_.proj_clip = static_cast<float>(
                getScalarData<_Float16>(*GetInput(operation, operands, kProjClipParam)));
    }
    params_.merge_outputs = getScalarData<bool>(*GetInput(operation, operands, kMergeOutputsParam));
    params_.time_major = getScalarData<bool>(*GetInput(operation, operands, kTimeMajorParam));
    params_.use_layer_norm = !IsNullInput(fw_input_layer_norm_weights_);

    fw_output_ = GetOutput(operation, operands, kFwOutputTensor);
    if (!params_.merge_outputs) {
        bw_output_ = GetOutput(operation, operands, kBwOutputTensor);
    }
}

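// Validates the operand shapes for both directions and computes the shape of
// the forward output (and of the backward output unless outputs are merged).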
bool BidirectionalSequenceLSTM::Prepare(const Operation& operation,
                                        std::vector<RunTimeOperandInfo>& operands,
                                        Shape* fwOutputShape, Shape* bwOutputShape) {
    // Inferring batch size, number of outputs and number of cells from the
    // input tensors.
    NN_CHECK(NumDimensions(input_) == 3);
    const uint32_t max_time = SizeOfDimension(input_, params_.time_major ? 0 : 1);
    const uint32_t n_batch = SizeOfDimension(input_, params_.time_major ? 1 : 0);
    const uint32_t n_input = SizeOfDimension(input_, 2);

    const uint32_t n_fw_cell = SizeOfDimension(fw_input_to_output_weights_, 0);
    NN_CHECK_EQ(NumDimensions(fw_input_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(fw_input_to_output_weights_, 1), n_input);

    NN_CHECK_EQ(NumDimensions(fw_recurrent_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(fw_recurrent_to_output_weights_, 0), n_fw_cell);
    const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);

    // Check that the forward input tensor dimensions match each other.
    if (!LSTMCell::CheckInputTensorDimensions(
                input_, fw_input_to_input_weights_, fw_input_to_forget_weights_,
                fw_input_to_cell_weights_, fw_input_to_output_weights_,
                fw_recurrent_to_input_weights_, fw_recurrent_to_forget_weights_,
                fw_recurrent_to_cell_weights_, fw_recurrent_to_output_weights_,
                fw_cell_to_input_weights_, fw_cell_to_forget_weights_, fw_cell_to_output_weights_,
                fw_input_gate_bias_, fw_forget_gate_bias_, fw_cell_bias_, fw_output_gate_bias_,
                fw_projection_weights_, fw_projection_bias_, fw_input_layer_norm_weights_,
                fw_forget_layer_norm_weights_, fw_cell_layer_norm_weights_,
                fw_output_layer_norm_weights_, n_input, n_fw_output, n_fw_cell, &params_)) {
        return false;
    }

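    // The auxiliary input and its weights must be provided either for both
    // directions or not at all.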
    const bool aux_inputs_all_or_none =
            (!IsNullInput(aux_input_) && !IsNullInput(fw_aux_input_to_cell_weights_) &&
             !IsNullInput(fw_aux_input_to_forget_weights_) &&
             !IsNullInput(fw_aux_input_to_output_weights_) &&
             !IsNullInput(bw_aux_input_to_cell_weights_) &&
             !IsNullInput(bw_aux_input_to_forget_weights_) &&
             !IsNullInput(bw_aux_input_to_output_weights_)) ||
            (IsNullInput(fw_aux_input_to_cell_weights_) &&
             IsNullInput(fw_aux_input_to_forget_weights_) &&
             IsNullInput(fw_aux_input_to_output_weights_) &&
             IsNullInput(bw_aux_input_to_cell_weights_) &&
             IsNullInput(bw_aux_input_to_forget_weights_) &&
             IsNullInput(bw_aux_input_to_output_weights_));
    NN_CHECK(aux_inputs_all_or_none);
    if (!IsNullInput(aux_input_)) {
        // Check that aux_input has the same dimensions (except last) as the input.
        NN_CHECK_EQ(aux_input_->shape().dimensions[0], input_->shape().dimensions[0]);
        NN_CHECK_EQ(aux_input_->shape().dimensions[1], input_->shape().dimensions[1]);
    }

    const uint32_t n_bw_cell = SizeOfDimension(bw_input_to_output_weights_, 0);
    NN_CHECK_EQ(NumDimensions(bw_input_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(bw_input_to_output_weights_, 1), n_input);

    NN_CHECK_EQ(NumDimensions(bw_recurrent_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(bw_recurrent_to_output_weights_, 0), n_bw_cell);
    const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);

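    // Forward output: {max_time, n_batch, ...} in time-major layout, or
    // {n_batch, max_time, ...} otherwise; the last dimension holds both
    // directions when outputs are merged.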
    const Shape& inputShape = input_->shape();
    fwOutputShape->type = inputShape.type;
    fwOutputShape->offset = inputShape.offset;
    fwOutputShape->scale = inputShape.scale;
    fwOutputShape->dimensions.resize(3);
    fwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
    fwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
    fwOutputShape->dimensions[2] = params_.merge_outputs ? n_fw_output + n_bw_output : n_fw_output;

    // Check that the backward input tensor dimensions match each other.
    if (!LSTMCell::CheckInputTensorDimensions(
                input_, bw_input_to_input_weights_, bw_input_to_forget_weights_,
                bw_input_to_cell_weights_, bw_input_to_output_weights_,
                bw_recurrent_to_input_weights_, bw_recurrent_to_forget_weights_,
                bw_recurrent_to_cell_weights_, bw_recurrent_to_output_weights_,
                bw_cell_to_input_weights_, bw_cell_to_forget_weights_, bw_cell_to_output_weights_,
                bw_input_gate_bias_, bw_forget_gate_bias_, bw_cell_bias_, bw_output_gate_bias_,
                bw_projection_weights_, bw_projection_bias_, bw_input_layer_norm_weights_,
                bw_forget_layer_norm_weights_, bw_cell_layer_norm_weights_,
                bw_output_layer_norm_weights_, n_input, n_bw_output, n_bw_cell, &params_)) {
        return false;
    }

    if (!params_.merge_outputs) {
        bwOutputShape->type = inputShape.type;
        bwOutputShape->offset = inputShape.offset;
        bwOutputShape->scale = inputShape.scale;
        bwOutputShape->dimensions.resize(3);
        bwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
        bwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
        bwOutputShape->dimensions[2] = n_bw_output;
    }

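    // Scratch buffers hold the intermediate gate activations: three gates when
    // CIFG couples the input and forget gates, four otherwise.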
    if (params_.use_cifg) {
        fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 3};
        bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 3};
    } else {
        fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 4};
        bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 4};
    }
    fw_scratch_shape_.type = bw_scratch_shape_.type = inputShape.type;
    fw_scratch_shape_.offset = bw_scratch_shape_.offset = inputShape.offset;
    fw_scratch_shape_.scale = bw_scratch_shape_.scale = inputShape.scale;

    return true;
}

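// Runs the forward cell over the input sequence, then the backward cell in
// reverse time order; with merge_outputs the two results are interleaved along
// the last dimension of the forward output tensor.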
bool BidirectionalSequenceLSTM::Eval() {
    const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);
    const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);
    std::vector<uint32_t> fw_output_dims = input_->shape().dimensions;
    fw_output_dims[2] = n_fw_output;
    std::vector<uint32_t> bw_output_dims = fw_output_dims;
    bw_output_dims[2] = n_bw_output;
    const uint32_t n_fw_output_elements = fw_output_dims[0] * fw_output_dims[1] * fw_output_dims[2];
    const uint32_t n_output_elements =
            fw_output_dims[0] * fw_output_dims[1] * (fw_output_dims[2] + bw_output_dims[2]);

    switch (input_->type) {
        case OperandType::TENSOR_FLOAT32: {
            std::vector<float> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
            const bool kForwardSequence = true;
            LSTMCell::LSTMEvalFloat32(
                    params_, GetBuffer<const float>(input_), input_->shape(),
                    GetBuffer<const float>(fw_input_to_input_weights_),
                    GetBuffer<const float>(fw_input_to_forget_weights_),
                    GetBuffer<const float>(fw_input_to_cell_weights_),
                    GetBuffer<const float>(fw_input_to_output_weights_),
                    fw_input_to_output_weights_->shape(),
                    GetBuffer<const float>(fw_recurrent_to_input_weights_),
                    GetBuffer<const float>(fw_recurrent_to_forget_weights_),
                    GetBuffer<const float>(fw_recurrent_to_cell_weights_),
                    GetBuffer<const float>(fw_recurrent_to_output_weights_),
                    fw_recurrent_to_output_weights_->shape(),
                    GetBuffer<const float>(fw_cell_to_input_weights_),
                    GetBuffer<const float>(fw_cell_to_forget_weights_),
                    GetBuffer<const float>(fw_cell_to_output_weights_),
                    GetOptionalBuffer<const float>(aux_input_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_output_weights_),
                    GetBuffer<const float>(fw_input_gate_bias_),
                    GetBuffer<const float>(fw_forget_gate_bias_),
                    GetBuffer<const float>(fw_cell_bias_),
                    GetBuffer<const float>(fw_output_gate_bias_),
                    GetBuffer<const float>(fw_projection_weights_),
                    GetBuffer<const float>(fw_projection_bias_),
                    GetBuffer<const float>(fw_activation_state_),
                    GetBuffer<const float>(fw_cell_state_),
                    GetOptionalBuffer<const float>(fw_input_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_output_layer_norm_weights_),
                    GetBuffer<float>(fw_activation_state_), GetBuffer<float>(fw_cell_state_),
                    GetBuffer<float>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
                    kForwardSequence);

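            // Backward pass: the same cell computation over the same input, but
            // processed in reverse time order.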
            std::vector<float> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
            const bool kBackwardSequence = false;
            LSTMCell::LSTMEvalFloat32(
                    params_, GetBuffer<const float>(input_), input_->shape(),
                    GetBuffer<const float>(bw_input_to_input_weights_),
                    GetBuffer<const float>(bw_input_to_forget_weights_),
                    GetBuffer<const float>(bw_input_to_cell_weights_),
                    GetBuffer<const float>(bw_input_to_output_weights_),
                    bw_input_to_output_weights_->shape(),
                    GetBuffer<const float>(bw_recurrent_to_input_weights_),
                    GetBuffer<const float>(bw_recurrent_to_forget_weights_),
                    GetBuffer<const float>(bw_recurrent_to_cell_weights_),
                    GetBuffer<const float>(bw_recurrent_to_output_weights_),
                    bw_recurrent_to_output_weights_->shape(),
                    GetBuffer<const float>(bw_cell_to_input_weights_),
                    GetBuffer<const float>(bw_cell_to_forget_weights_),
                    GetBuffer<const float>(bw_cell_to_output_weights_),
                    GetOptionalBuffer<const float>(aux_input_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_output_weights_),
                    GetBuffer<const float>(bw_input_gate_bias_),
                    GetBuffer<const float>(bw_forget_gate_bias_),
                    GetBuffer<const float>(bw_cell_bias_),
                    GetBuffer<const float>(bw_output_gate_bias_),
                    GetBuffer<const float>(bw_projection_weights_),
                    GetBuffer<const float>(bw_projection_bias_),
                    GetBuffer<const float>(bw_activation_state_),
                    GetBuffer<const float>(bw_cell_state_),
                    GetOptionalBuffer<const float>(bw_input_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_output_layer_norm_weights_),
                    GetBuffer<float>(bw_activation_state_), GetBuffer<float>(bw_cell_state_),
                    params_.merge_outputs ? GetBuffer<float>(fw_output_) + n_fw_output_elements
                                          : GetBuffer<float>(bw_output_),
                    bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
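            // With merged outputs the backward results were written contiguously after
            // the forward results in fw_output_, so interleave the two sequences along
            // the last dimension.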
            if (params_.merge_outputs) {
                std::vector<float> temp(n_output_elements);
                mergeThirdDimension(GetBuffer<float>(fw_output_), fw_output_dims,
                                    GetBuffer<float>(fw_output_) + n_fw_output_elements,
                                    bw_output_dims, temp.data());
                std::copy(temp.data(), temp.data() + n_output_elements,
                          GetBuffer<float>(fw_output_));
            }
        } break;
        case OperandType::TENSOR_FLOAT16: {
            std::vector<_Float16> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
            const bool kForwardSequence = true;
            LSTMCell::LSTMEvalFloat16(
                    params_, GetBuffer<const _Float16>(input_), input_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_input_to_input_weights_),
                    GetBuffer<const _Float16>(fw_input_to_forget_weights_),
                    GetBuffer<const _Float16>(fw_input_to_cell_weights_),
                    GetBuffer<const _Float16>(fw_input_to_output_weights_),
                    fw_input_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_recurrent_to_input_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_forget_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_cell_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_output_weights_),
                    fw_recurrent_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(aux_input_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(fw_input_gate_bias_),
                    GetBuffer<const _Float16>(fw_forget_gate_bias_),
                    GetBuffer<const _Float16>(fw_cell_bias_),
                    GetBuffer<const _Float16>(fw_output_gate_bias_),
                    GetOptionalBuffer<const _Float16>(fw_projection_weights_),
                    GetOptionalBuffer<const _Float16>(fw_projection_bias_),
                    GetBuffer<const _Float16>(fw_activation_state_),
                    GetBuffer<const _Float16>(fw_cell_state_),
                    GetOptionalBuffer<const _Float16>(fw_input_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_output_layer_norm_weights_),
                    GetBuffer<_Float16>(fw_activation_state_), GetBuffer<_Float16>(fw_cell_state_),
                    GetBuffer<_Float16>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
                    kForwardSequence);

            std::vector<_Float16> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
            const bool kBackwardSequence = false;
            LSTMCell::LSTMEvalFloat16(
                    params_, GetBuffer<const _Float16>(input_), input_->shape(),
                    GetOptionalBuffer<const _Float16>(bw_input_to_input_weights_),
                    GetBuffer<const _Float16>(bw_input_to_forget_weights_),
                    GetBuffer<const _Float16>(bw_input_to_cell_weights_),
                    GetBuffer<const _Float16>(bw_input_to_output_weights_),
                    bw_input_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(bw_recurrent_to_input_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_forget_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_cell_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_output_weights_),
                    bw_recurrent_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(aux_input_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(bw_input_gate_bias_),
                    GetBuffer<const _Float16>(bw_forget_gate_bias_),
                    GetBuffer<const _Float16>(bw_cell_bias_),
                    GetBuffer<const _Float16>(bw_output_gate_bias_),
                    GetOptionalBuffer<const _Float16>(bw_projection_weights_),
                    GetOptionalBuffer<const _Float16>(bw_projection_bias_),
                    GetBuffer<const _Float16>(bw_activation_state_),
                    GetBuffer<const _Float16>(bw_cell_state_),
                    GetOptionalBuffer<const _Float16>(bw_input_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_output_layer_norm_weights_),
                    GetBuffer<_Float16>(bw_activation_state_), GetBuffer<_Float16>(bw_cell_state_),
                    params_.merge_outputs ? GetBuffer<_Float16>(fw_output_) + n_fw_output_elements
                                          : GetBuffer<_Float16>(bw_output_),
                    bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
            if (params_.merge_outputs) {
                std::vector<_Float16> temp(n_output_elements);
                mergeThirdDimension(GetBuffer<_Float16>(fw_output_), fw_output_dims,
                                    GetBuffer<_Float16>(fw_output_) + n_fw_output_elements,
                                    bw_output_dims, temp.data());
                std::copy(temp.data(), temp.data() + n_output_elements,
                          GetBuffer<_Float16>(fw_output_));
            }
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
            return false;
        }
    }
    return true;
}

}  // namespace nn
}  // namespace android