1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include "OperationResolver.h"
20 #include "RNN.h"
21
22 namespace android {
23 namespace nn {
24 namespace bidirectional_sequence_rnn {
25
// Number of input operands expected by BIDIRECTIONAL_SEQUENCE_RNN.
constexpr uint32_t kNumInputs = 15;
// Primary input sequence tensor (3-D; layout depends on the timeMajor flag).
constexpr uint32_t kInputTensor = 0;
// Forward cell tensors
constexpr uint32_t kFwWeightsTensor = 1;
constexpr uint32_t kFwRecurrentWeightsTensor = 2;
constexpr uint32_t kFwBiasTensor = 3;
constexpr uint32_t kFwHiddenStateTensor = 4;
// Backward cell tensors
constexpr uint32_t kBwWeightsTensor = 5;
constexpr uint32_t kBwRecurrentWeightsTensor = 6;
constexpr uint32_t kBwBiasTensor = 7;
constexpr uint32_t kBwHiddenStateTensor = 8;
// Auxiliary inputs (all three must be provided together or all omitted;
// this all-or-none rule is enforced in prepare()).
constexpr uint32_t kAuxInputTensor = 9;       // optional
constexpr uint32_t kFwAuxWeightsTensor = 10;  // optional
constexpr uint32_t kBwAuxWeightsTensor = 11;  // optional
// Cell parameters
constexpr uint32_t kActivationParam = 12;    // INT32 activation function
constexpr uint32_t kTimeMajorParam = 13;     // BOOL: input/output layout
constexpr uint32_t kMergeOutputsParam = 14;  // BOOL: concatenate fw/bw outputs

constexpr uint32_t kFwOutputTensor = 0;
constexpr uint32_t kBwOutputTensor = 1;  // Only if mergeOutputs parameter is false
49
50 namespace {
51
52 template <typename T>
transposeFirstTwoDims(const T * input,const Shape & inputShape,T * output)53 void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
54 const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
55 const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
56 const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
57 for (int f = 0; f < firstDimSize; ++f) {
58 for (int s = 0; s < secondDimSize; ++s) {
59 for (int i = 0; i < inputSize; ++i) {
60 const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
61 const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
62 output[outputIndex] = input[inputIndex];
63 }
64 }
65 }
66 }
67
// Returns a copy of |input| with its leading (time) dimension dropped, so a
// [maxTime, batchSize, inputSize] shape becomes [batchSize, inputSize]. All
// other Shape fields (type, scale, offset) are preserved by the copy.
Shape removeFirstDim(const Shape& input) {
    Shape output = input;
    // Range-assign replaces the manual element-by-element copy loop (which
    // also mixed a signed index with the unsigned dimensions size).
    output.dimensions.assign(input.dimensions.begin() + 1, input.dimensions.end());
    return output;
}
76
// Runs the bidirectional RNN for element type T (float or _Float16).
//
// Reads all operand buffers and shapes from |context|; if the tensors are
// batch-major (timeMajor == false) they are first transposed into time-major
// temporaries. A forward pass then walks time steps [0, maxTime) and a
// backward pass walks [maxTime - 1, 0], each delegating a single step to
// RNN::RNNStep. When mergeOutputs is true both passes write into the forward
// output buffer with stride fwNumUnits + bwNumUnits (the backward results at
// offset fwNumUnits); otherwise each pass writes its own output tensor.
// Finally, batch-major callers get the temporaries transposed back into the
// real output buffers. Always returns true.
template <typename T>
bool executeTyped(IOperationExecutionContext* context) {
    const T* input = context->getInputBuffer<T>(kInputTensor);
    Shape inputShape = context->getInputShape(kInputTensor);

    // Forward cell operands.
    const T* fwWeights = context->getInputBuffer<T>(kFwWeightsTensor);
    Shape fwWeightsShape = context->getInputShape(kFwWeightsTensor);
    const T* fwRecurrentWeights = context->getInputBuffer<T>(kFwRecurrentWeightsTensor);
    Shape fwRecurrentWeightsShape = context->getInputShape(kFwRecurrentWeightsTensor);
    const T* fwBias = context->getInputBuffer<T>(kFwBiasTensor);
    const T* fwHiddenState = context->getInputBuffer<T>(kFwHiddenStateTensor);

    // Backward cell operands.
    const T* bwWeights = context->getInputBuffer<T>(kBwWeightsTensor);
    Shape bwWeightsShape = context->getInputShape(kBwWeightsTensor);
    const T* bwRecurrentWeights = context->getInputBuffer<T>(kBwRecurrentWeightsTensor);
    Shape bwRecurrentWeightsShape = context->getInputShape(kBwRecurrentWeightsTensor);
    const T* bwBias = context->getInputBuffer<T>(kBwBiasTensor);
    const T* bwHiddenState = context->getInputBuffer<T>(kBwHiddenStateTensor);

    // Auxiliary operands are optional; buffers stay null when omitted.
    const T* auxInput = nullptr;
    const T* fwAuxWeights = nullptr;
    const T* bwAuxWeights = nullptr;
    const bool hasAuxInputs = !context->isOmittedInput(kAuxInputTensor);
    if (hasAuxInputs) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
        fwAuxWeights = context->getInputBuffer<T>(kFwAuxWeightsTensor);
        bwAuxWeights = context->getInputBuffer<T>(kBwAuxWeightsTensor);
    }
    // Shapes are fetched unconditionally; when the aux operands are omitted
    // these shapes are passed through to RNNStep alongside null buffers
    // (presumably ignored there — confirm against RNN.h).
    Shape auxInputShape = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeightsShape = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeightsShape = context->getInputShape(kBwAuxWeightsTensor);

    int32_t activation = context->getInputValue<int32_t>(kActivationParam);
    // BOOL scalar values widened into int32_t; any nonzero value means true.
    int32_t timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    int32_t mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);

    T* fwOutput = context->getOutputBuffer<T>(kFwOutputTensor);
    Shape fwOutputShape = context->getOutputShape(kFwOutputTensor);
    T* bwOutput = nullptr;
    Shape bwOutputShape;
    if (!mergeOutputs) {
        bwOutputShape = context->getOutputShape(kBwOutputTensor);
        bwOutput = context->getOutputBuffer<T>(kBwOutputTensor);
    }

    // If the input tensors are not in time major format, we transpose the first
    // two dimensions, and set input and output pointers to temporary vectors
    // which are transposed back after the RNN is applied.
    std::vector<T> inputTransposed;
    std::vector<T> auxInputTransposed;
    std::vector<T> fwOutputTransposed;
    std::vector<T> bwOutputTransposed;
    if (!timeMajor) {
        // First, resize temporary buffers to accommodate for transposed tensors.
        inputTransposed.resize(getNumberOfElements(inputShape));
        if (hasAuxInputs) {
            auxInputTransposed.resize(getNumberOfElements(auxInputShape));
        }
        fwOutputTransposed.resize(getNumberOfElements(fwOutputShape));
        if (!mergeOutputs) {
            bwOutputTransposed.resize(getNumberOfElements(bwOutputShape));
        }

        // Transpose the input tensors.
        transposeFirstTwoDims(input, inputShape, inputTransposed.data());
        if (hasAuxInputs) {
            transposeFirstTwoDims(auxInput, auxInputShape, auxInputTransposed.data());
        }

        // Change input and output pointers to the temporary buffers.
        input = inputTransposed.data();
        if (hasAuxInputs) {
            auxInput = auxInputTransposed.data();
        }
        fwOutput = fwOutputTransposed.data();
        if (!mergeOutputs) {
            bwOutput = bwOutputTransposed.data();
        }

        // Swap the first two dimensions in the Shapes to reflect the
        // transposition.
        std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
        if (hasAuxInputs) {
            std::swap(auxInputShape.dimensions[0], auxInputShape.dimensions[1]);
        }
        std::swap(fwOutputShape.dimensions[0], fwOutputShape.dimensions[1]);
        if (!mergeOutputs) {
            std::swap(bwOutputShape.dimensions[0], bwOutputShape.dimensions[1]);
        }
    }

    // From here on all tensors are time-major: [maxTime, batchSize, ...].
    const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
    const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    uint32_t auxInputSize = 0;
    if (hasAuxInputs) {
        auxInputSize = getSizeOfDimension(auxInputShape, 2);
    }
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeightsShape, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeightsShape, 0);

    // Per-step shapes with the time dimension stripped: [batchSize, inputSize].
    Shape fixedTimeInputShape = removeFirstDim(inputShape);
    Shape fixedTimeAuxInputShape = auxInputShape;
    if (hasAuxInputs) {
        fixedTimeAuxInputShape = removeFirstDim(auxInputShape);
    }

    // Create an additional buffer to store a hidden state between steps.
    std::vector<T> tempHiddenState(batchSize * fwNumUnits);
    // Forward pass
    for (int i = 0; i < maxTime; ++i) {
        const T* inputBatchPtr = input + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxInputs) {
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        // When outputs are merged the forward results are interleaved with the
        // backward ones, so each batch row spans fwNumUnits + bwNumUnits.
        const uint32_t fwOutputBatchStride = mergeOutputs ? (fwNumUnits + bwNumUnits) : fwNumUnits;
        T* fwOutputBatchPtr = fwOutput + i * batchSize * fwOutputBatchStride;

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, fwHiddenState, fwBias, fwWeights, fwWeightsShape,
                        fwAuxWeights, fwAuxWeightsShape, fwRecurrentWeights,
                        fwRecurrentWeightsShape, activation, fwOutputBatchStride,
                        /*outputBatchOffset=*/0, fwOutputBatchPtr, tempHiddenState.data());

        // After the first step the hidden state is read from the same buffer
        // RNNStep writes into — NOTE(review): this relies on RNNStep tolerating
        // hiddenState == hiddenStateOutput; confirm against RNN.h.
        fwHiddenState = tempHiddenState.data();
    }

    // Reuse the scratch buffer for the backward hidden state. Safe even if
    // resize() reallocates: bwHiddenState still points at the operand buffer
    // here, and fwHiddenState is not used again.
    tempHiddenState.resize(batchSize * bwNumUnits);
    // Backward pass
    for (int i = maxTime - 1; i >= 0; --i) {
        const T* inputBatchPtr = input + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxInputs) {
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        T* bwOutputBatchPtr;
        uint32_t bwOutputBatchOffset = 0;
        uint32_t bwOutputBatchStride;
        if (mergeOutputs) {
            // Backward results land in the forward output buffer, offset past
            // the forward units within each batch row.
            bwOutputBatchStride = fwNumUnits + bwNumUnits;
            bwOutputBatchOffset = fwNumUnits;
            bwOutputBatchPtr = fwOutput + i * batchSize * bwOutputBatchStride;
        } else {
            bwOutputBatchStride = bwNumUnits;
            bwOutputBatchPtr = bwOutput + i * batchSize * bwOutputBatchStride;
        }

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, bwHiddenState, bwBias, bwWeights, bwWeightsShape,
                        bwAuxWeights, bwAuxWeightsShape, bwRecurrentWeights,
                        bwRecurrentWeightsShape, activation, bwOutputBatchStride,
                        bwOutputBatchOffset, bwOutputBatchPtr, tempHiddenState.data());

        bwHiddenState = tempHiddenState.data();
    }

    // If the inputs were in batch major format, transpose data in temporary
    // buffers and write to the output(s).
    if (!timeMajor) {
        transposeFirstTwoDims(fwOutputTransposed.data(), fwOutputShape,
                              context->getOutputBuffer<T>(kFwOutputTensor));
        if (!mergeOutputs) {
            transposeFirstTwoDims(bwOutputTransposed.data(), bwOutputShape,
                                  context->getOutputBuffer<T>(kBwOutputTensor));
        }
    }
    return true;
}
246
247 } // namespace
248
validate(const IOperationValidationContext * context)249 bool validate(const IOperationValidationContext* context) {
250 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
251 // Exact number is dependent on the mergeOutputs parameter and checked
252 // during preparation.
253 NN_RET_CHECK(context->getNumOutputs() == 1 || context->getNumOutputs() == 2);
254 OperandType inputType = context->getInputType(kInputTensor);
255 if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) {
256 LOG(ERROR) << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: "
257 << toString(inputType);
258 return false;
259 }
260 NN_RET_CHECK(validateInputTypes(
261 context, {inputType, inputType, inputType, inputType, inputType, inputType, inputType,
262 inputType, inputType, inputType, inputType, inputType, OperandType::INT32,
263 OperandType::BOOL, OperandType::BOOL}));
264 if (context->getNumOutputs() == 1) {
265 NN_RET_CHECK(validateOutputTypes(context, {inputType}));
266 } else {
267 NN_RET_CHECK(validateOutputTypes(context, {inputType, inputType}));
268 }
269 return validateHalVersion(context, HalVersion::V1_2);
270 }
271
// Checks operand presence and shape consistency, then computes and sets the
// output shape(s).
//
// Verifies that the output count matches the mergeOutputs parameter, that no
// required input is omitted, that the auxiliary operands are provided
// all-or-none, and that every tensor's rank and dimensions agree with the
// derived batchSize / maxTime / numUnits / inputSize values. On success the
// forward (and, if not merged, backward) output shapes are set on the context.
bool prepare(IOperationExecutionContext* context) {
    int32_t mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);
    if (mergeOutputs) {
        NN_RET_CHECK_EQ(context->getNumOutputs(), 1);
    } else {
        NN_RET_CHECK_EQ(context->getNumOutputs(), 2);
    }

    // Check that none of the required inputs are omitted.
    const std::vector<int> requiredInputs = {
            kInputTensor,         kFwWeightsTensor, kFwRecurrentWeightsTensor, kFwBiasTensor,
            kFwHiddenStateTensor, kBwWeightsTensor, kBwRecurrentWeightsTensor, kBwBiasTensor,
            kBwHiddenStateTensor, kActivationParam, kTimeMajorParam,           kMergeOutputsParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    Shape input = context->getInputShape(kInputTensor);
    Shape fwWeights = context->getInputShape(kFwWeightsTensor);
    Shape fwRecurrentWeights = context->getInputShape(kFwRecurrentWeightsTensor);
    Shape fwBias = context->getInputShape(kFwBiasTensor);
    Shape fwHiddenState = context->getInputShape(kFwHiddenStateTensor);
    Shape bwWeights = context->getInputShape(kBwWeightsTensor);
    Shape bwRecurrentWeights = context->getInputShape(kBwRecurrentWeightsTensor);
    Shape bwBias = context->getInputShape(kBwBiasTensor);
    Shape bwHiddenState = context->getInputShape(kBwHiddenStateTensor);

    Shape auxInput = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeights = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeights = context->getInputShape(kBwAuxWeightsTensor);

    // The three auxiliary operands must be provided together or all omitted.
    const bool auxInputsAllOrNone = (context->isOmittedInput(kAuxInputTensor) &&
                                     context->isOmittedInput(kFwAuxWeightsTensor) &&
                                     context->isOmittedInput(kBwAuxWeightsTensor)) ||
                                    (!context->isOmittedInput(kAuxInputTensor) &&
                                     !context->isOmittedInput(kFwAuxWeightsTensor) &&
                                     !context->isOmittedInput(kBwAuxWeightsTensor));
    NN_RET_CHECK(auxInputsAllOrNone);
    const bool hasAuxInputs = !context->isOmittedInput(kAuxInputTensor);

    // Derive the key sizes; the first two input dimensions swap meaning
    // depending on the timeMajor layout flag.
    int32_t timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    const uint32_t batchSize =
            timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
    const uint32_t maxTime =
            timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeights, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeights, 0);
    const uint32_t inputSize = getSizeOfDimension(input, 2);

    // Rank checks.
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwRecurrentWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwBias), 1);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwHiddenState), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwRecurrentWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwBias), 1);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwHiddenState), 2);

    // Forward cell dimension consistency.
    NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(fwWeights, 1));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwBias, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(fwHiddenState, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwHiddenState, 1));

    // Backward cell dimension consistency.
    NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(bwWeights, 1));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwBias, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(bwHiddenState, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwHiddenState, 1));

    if (hasAuxInputs) {
        NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3);
        NN_RET_CHECK_EQ(getNumberOfDimensions(fwAuxWeights), 2);
        NN_RET_CHECK_EQ(getNumberOfDimensions(bwAuxWeights), 2);

        // The aux input must match the primary input's leading dimensions.
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 0), fwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 0), bwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
    }

    // Output: leading dims mirror the input layout; the last dim is the
    // concatenated fw+bw width when merging, otherwise just the fw width.
    Shape fwOutput = context->getOutputShape(kFwOutputTensor);
    fwOutput.dimensions.resize(3);
    fwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
    fwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
    fwOutput.dimensions[2] = mergeOutputs ? fwNumUnits + bwNumUnits : fwNumUnits;
    NN_RET_CHECK(context->setOutputShape(kFwOutputTensor, fwOutput));
    if (!mergeOutputs) {
        Shape bwOutput = context->getOutputShape(kBwOutputTensor);
        bwOutput.dimensions.resize(3);
        bwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
        bwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
        bwOutput.dimensions[2] = bwNumUnits;
        NN_RET_CHECK(context->setOutputShape(kBwOutputTensor, bwOutput));
    }

    return true;
}
377
execute(IOperationExecutionContext * context)378 bool execute(IOperationExecutionContext* context) {
379 if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
380 executeTyped<_Float16>(context);
381 } else {
382 executeTyped<float>(context);
383 }
384 return true;
385 }
386
387 } // namespace bidirectional_sequence_rnn
388
// allowOmittedOperand is set because the auxiliary operands (inputs 9-11)
// are optional; prepare() enforces that they are supplied all-or-none.
NN_REGISTER_OPERATION(BIDIRECTIONAL_SEQUENCE_RNN, "BIDIRECTIONAL_SEQUENCE_RNN",
                      bidirectional_sequence_rnn::validate, bidirectional_sequence_rnn::prepare,
                      bidirectional_sequence_rnn::execute, .allowOmittedOperand = true);
392
393 } // namespace nn
394 } // namespace android
395