• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "OperationResolver.h"
20 #include "RNN.h"
21 
22 namespace android {
23 namespace nn {
24 namespace bidirectional_sequence_rnn {
25 
// ---- Input operand indices ----

// Main input sequence.
constexpr uint32_t kInputTensor{0};

// Forward-direction cell: input weights, recurrent weights, bias and the
// hidden state used for the first forward time step.
constexpr uint32_t kFwWeightsTensor{1};
constexpr uint32_t kFwRecurrentWeightsTensor{2};
constexpr uint32_t kFwBiasTensor{3};
constexpr uint32_t kFwHiddenStateTensor{4};

// Backward-direction cell: same layout as the forward cell.
constexpr uint32_t kBwWeightsTensor{5};
constexpr uint32_t kBwRecurrentWeightsTensor{6};
constexpr uint32_t kBwBiasTensor{7};
constexpr uint32_t kBwHiddenStateTensor{8};

// Optional auxiliary input and its per-direction weights (may be omitted).
constexpr uint32_t kAuxInputTensor{9};
constexpr uint32_t kFwAuxWeightsTensor{10};
constexpr uint32_t kBwAuxWeightsTensor{11};

// Scalar cell parameters.
constexpr uint32_t kActivationParam{12};
constexpr uint32_t kTimeMajorParam{13};
constexpr uint32_t kMergeOutputsParam{14};

// Total number of input operands: one past the last input index.
constexpr uint32_t kNumInputs{kMergeOutputsParam + 1};

// ---- Output operand indices ----
constexpr uint32_t kFwOutputTensor{0};
// Present only when the mergeOutputs parameter is false.
constexpr uint32_t kBwOutputTensor{1};
49 
50 namespace {
51 
52 template <typename T>
transposeFirstTwoDims(const T * input,const Shape & inputShape,T * output)53 void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
54     const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
55     const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
56     const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
57     for (int f = 0; f < firstDimSize; ++f) {
58         for (int s = 0; s < secondDimSize; ++s) {
59             for (int i = 0; i < inputSize; ++i) {
60                 const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
61                 const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
62                 output[outputIndex] = input[inputIndex];
63             }
64         }
65     }
66 }
67 
// Returns a copy of `input` with its leading dimension dropped. For example,
// [t, b, n] becomes [b, n]. Every other Shape field is copied unchanged.
// A shape with no dimensions is returned as-is. The previous code computed
// `size() - 1` on an empty vector, which underflows.
Shape removeFirstDim(const Shape& input) {
    Shape output = input;
    if (!output.dimensions.empty()) {
        output.dimensions.erase(output.dimensions.begin());
    }
    return output;
}
76 
77 template <typename T>
executeTyped(IOperationExecutionContext * context)78 bool executeTyped(IOperationExecutionContext* context) {
79     const T* input = context->getInputBuffer<T>(kInputTensor);
80     Shape inputShape = context->getInputShape(kInputTensor);
81 
82     const T* fwWeights = context->getInputBuffer<T>(kFwWeightsTensor);
83     Shape fwWeightsShape = context->getInputShape(kFwWeightsTensor);
84     const T* fwRecurrentWeights = context->getInputBuffer<T>(kFwRecurrentWeightsTensor);
85     Shape fwRecurrentWeightsShape = context->getInputShape(kFwRecurrentWeightsTensor);
86     const T* fwBias = context->getInputBuffer<T>(kFwBiasTensor);
87     const T* fwHiddenState = context->getInputBuffer<T>(kFwHiddenStateTensor);
88 
89     const T* bwWeights = context->getInputBuffer<T>(kBwWeightsTensor);
90     Shape bwWeightsShape = context->getInputShape(kBwWeightsTensor);
91     const T* bwRecurrentWeights = context->getInputBuffer<T>(kBwRecurrentWeightsTensor);
92     Shape bwRecurrentWeightsShape = context->getInputShape(kBwRecurrentWeightsTensor);
93     const T* bwBias = context->getInputBuffer<T>(kBwBiasTensor);
94     const T* bwHiddenState = context->getInputBuffer<T>(kBwHiddenStateTensor);
95 
96     const T* auxInput = nullptr;
97     const T* fwAuxWeights = nullptr;
98     const T* bwAuxWeights = nullptr;
99     const bool hasAuxInputs = !context->isOmittedInput(kAuxInputTensor);
100     if (hasAuxInputs) {
101         auxInput = context->getInputBuffer<T>(kAuxInputTensor);
102         fwAuxWeights = context->getInputBuffer<T>(kFwAuxWeightsTensor);
103         bwAuxWeights = context->getInputBuffer<T>(kBwAuxWeightsTensor);
104     }
105     Shape auxInputShape = context->getInputShape(kAuxInputTensor);
106     Shape fwAuxWeightsShape = context->getInputShape(kFwAuxWeightsTensor);
107     Shape bwAuxWeightsShape = context->getInputShape(kBwAuxWeightsTensor);
108 
109     int32_t activation = context->getInputValue<int32_t>(kActivationParam);
110     int32_t timeMajor = context->getInputValue<bool>(kTimeMajorParam);
111     int32_t mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);
112 
113     T* fwOutput = context->getOutputBuffer<T>(kFwOutputTensor);
114     Shape fwOutputShape = context->getOutputShape(kFwOutputTensor);
115     T* bwOutput = nullptr;
116     Shape bwOutputShape;
117     if (!mergeOutputs) {
118         bwOutputShape = context->getOutputShape(kBwOutputTensor);
119         bwOutput = context->getOutputBuffer<T>(kBwOutputTensor);
120     }
121 
122     // If the input tensors are not in time major format, we transpose the first
123     // two dimensions, and set input and output pointers to temporary vectors
124     // which are transposed back after the RNN is applied.
125     std::vector<T> inputTransposed;
126     std::vector<T> auxInputTransposed;
127     std::vector<T> fwOutputTransposed;
128     std::vector<T> bwOutputTransposed;
129     if (!timeMajor) {
130         // First, resize temporary buffers to accommodate for transposed tensors.
131         inputTransposed.resize(getNumberOfElements(inputShape));
132         if (hasAuxInputs) {
133             auxInputTransposed.resize(getNumberOfElements(auxInputShape));
134         }
135         fwOutputTransposed.resize(getNumberOfElements(fwOutputShape));
136         if (!mergeOutputs) {
137             bwOutputTransposed.resize(getNumberOfElements(bwOutputShape));
138         }
139 
140         // Transpose the input tensors.
141         transposeFirstTwoDims(input, inputShape, inputTransposed.data());
142         if (hasAuxInputs) {
143             transposeFirstTwoDims(auxInput, auxInputShape, auxInputTransposed.data());
144         }
145 
146         // Change input and output pointers to the temporary buffers.
147         input = inputTransposed.data();
148         if (hasAuxInputs) {
149             auxInput = auxInputTransposed.data();
150         }
151         fwOutput = fwOutputTransposed.data();
152         if (!mergeOutputs) {
153             bwOutput = bwOutputTransposed.data();
154         }
155 
156         // Swap the first two dimensions in the Shapes to reflect the
157         // transposition.
158         std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
159         if (hasAuxInputs) {
160             std::swap(auxInputShape.dimensions[0], auxInputShape.dimensions[1]);
161         }
162         std::swap(fwOutputShape.dimensions[0], fwOutputShape.dimensions[1]);
163         if (!mergeOutputs) {
164             std::swap(bwOutputShape.dimensions[0], bwOutputShape.dimensions[1]);
165         }
166     }
167 
168     const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
169     const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
170     const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
171     uint32_t auxInputSize = 0;
172     if (hasAuxInputs) {
173         auxInputSize = getSizeOfDimension(auxInputShape, 2);
174     }
175     const uint32_t fwNumUnits = getSizeOfDimension(fwWeightsShape, 0);
176     const uint32_t bwNumUnits = getSizeOfDimension(bwWeightsShape, 0);
177 
178     Shape fixedTimeInputShape = removeFirstDim(inputShape);
179     Shape fixedTimeAuxInputShape = auxInputShape;
180     if (hasAuxInputs) {
181         fixedTimeAuxInputShape = removeFirstDim(auxInputShape);
182     }
183 
184     // Create an additional buffer to store a hidden state between steps.
185     std::vector<T> tempHiddenState(batchSize * fwNumUnits);
186     // Forward pass
187     for (int i = 0; i < maxTime; ++i) {
188         const T* inputBatchPtr = input + i * batchSize * inputSize;
189         const T* auxInputBatchPtr = nullptr;
190         if (hasAuxInputs) {
191             auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
192         }
193         const uint32_t fwOutputBatchStride = mergeOutputs ? (fwNumUnits + bwNumUnits) : fwNumUnits;
194         T* fwOutputBatchPtr = fwOutput + i * batchSize * fwOutputBatchStride;
195 
196         RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
197                         fixedTimeAuxInputShape, fwHiddenState, fwBias, fwWeights, fwWeightsShape,
198                         fwAuxWeights, fwAuxWeightsShape, fwRecurrentWeights,
199                         fwRecurrentWeightsShape, activation, fwOutputBatchStride,
200                         /*outputBatchOffset=*/0, fwOutputBatchPtr, tempHiddenState.data());
201 
202         fwHiddenState = tempHiddenState.data();
203     }
204 
205     tempHiddenState.resize(batchSize * bwNumUnits);
206     // Backward pass
207     for (int i = maxTime - 1; i >= 0; --i) {
208         const T* inputBatchPtr = input + i * batchSize * inputSize;
209         const T* auxInputBatchPtr = nullptr;
210         if (hasAuxInputs) {
211             auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
212         }
213         T* bwOutputBatchPtr;
214         uint32_t bwOutputBatchOffset = 0;
215         uint32_t bwOutputBatchStride;
216         if (mergeOutputs) {
217             bwOutputBatchStride = fwNumUnits + bwNumUnits;
218             bwOutputBatchOffset = fwNumUnits;
219             bwOutputBatchPtr = fwOutput + i * batchSize * bwOutputBatchStride;
220         } else {
221             bwOutputBatchStride = bwNumUnits;
222             bwOutputBatchPtr = bwOutput + i * batchSize * bwOutputBatchStride;
223         }
224 
225         RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
226                         fixedTimeAuxInputShape, bwHiddenState, bwBias, bwWeights, bwWeightsShape,
227                         bwAuxWeights, bwAuxWeightsShape, bwRecurrentWeights,
228                         bwRecurrentWeightsShape, activation, bwOutputBatchStride,
229                         bwOutputBatchOffset, bwOutputBatchPtr, tempHiddenState.data());
230 
231         bwHiddenState = tempHiddenState.data();
232     }
233 
234     // If the inputs were in batch major format, transpose data in temporary
235     // buffers and write to the output(s).
236     if (!timeMajor) {
237         transposeFirstTwoDims(fwOutputTransposed.data(), fwOutputShape,
238                               context->getOutputBuffer<T>(kFwOutputTensor));
239         if (!mergeOutputs) {
240             transposeFirstTwoDims(bwOutputTransposed.data(), bwOutputShape,
241                                   context->getOutputBuffer<T>(kBwOutputTensor));
242         }
243     }
244     return true;
245 }
246 
247 }  // namespace
248 
validate(const IOperationValidationContext * context)249 bool validate(const IOperationValidationContext* context) {
250     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
251     // Exact number is dependent on the mergeOutputs parameter and checked
252     // during preparation.
253     NN_RET_CHECK(context->getNumOutputs() == 1 || context->getNumOutputs() == 2);
254     OperandType inputType = context->getInputType(kInputTensor);
255     if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) {
256         LOG(ERROR) << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: "
257                    << toString(inputType);
258         return false;
259     }
260     NN_RET_CHECK(validateInputTypes(
261             context, {inputType, inputType, inputType, inputType, inputType, inputType, inputType,
262                       inputType, inputType, inputType, inputType, inputType, OperandType::INT32,
263                       OperandType::BOOL, OperandType::BOOL}));
264     if (context->getNumOutputs() == 1) {
265         NN_RET_CHECK(validateOutputTypes(context, {inputType}));
266     } else {
267         NN_RET_CHECK(validateOutputTypes(context, {inputType, inputType}));
268     }
269     return validateHalVersion(context, HalVersion::V1_2);
270 }
271 
prepare(IOperationExecutionContext * context)272 bool prepare(IOperationExecutionContext* context) {
273     int32_t mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);
274     if (mergeOutputs) {
275         NN_RET_CHECK_EQ(context->getNumOutputs(), 1);
276     } else {
277         NN_RET_CHECK_EQ(context->getNumOutputs(), 2);
278     }
279 
280     // Check that none of the required inputs are omitted.
281     const std::vector<int> requiredInputs = {
282             kInputTensor,         kFwWeightsTensor, kFwRecurrentWeightsTensor, kFwBiasTensor,
283             kFwHiddenStateTensor, kBwWeightsTensor, kBwRecurrentWeightsTensor, kBwBiasTensor,
284             kBwHiddenStateTensor, kActivationParam, kTimeMajorParam,           kMergeOutputsParam,
285     };
286     for (const int requiredInput : requiredInputs) {
287         NN_RET_CHECK(!context->isOmittedInput(requiredInput))
288                 << "required input " << requiredInput << " is omitted";
289     }
290 
291     Shape input = context->getInputShape(kInputTensor);
292     Shape fwWeights = context->getInputShape(kFwWeightsTensor);
293     Shape fwRecurrentWeights = context->getInputShape(kFwRecurrentWeightsTensor);
294     Shape fwBias = context->getInputShape(kFwBiasTensor);
295     Shape fwHiddenState = context->getInputShape(kFwHiddenStateTensor);
296     Shape bwWeights = context->getInputShape(kBwWeightsTensor);
297     Shape bwRecurrentWeights = context->getInputShape(kBwRecurrentWeightsTensor);
298     Shape bwBias = context->getInputShape(kBwBiasTensor);
299     Shape bwHiddenState = context->getInputShape(kBwHiddenStateTensor);
300 
301     Shape auxInput = context->getInputShape(kAuxInputTensor);
302     Shape fwAuxWeights = context->getInputShape(kFwAuxWeightsTensor);
303     Shape bwAuxWeights = context->getInputShape(kBwAuxWeightsTensor);
304 
305     const bool auxInputsAllOrNone = (context->isOmittedInput(kAuxInputTensor) &&
306                                      context->isOmittedInput(kFwAuxWeightsTensor) &&
307                                      context->isOmittedInput(kBwAuxWeightsTensor)) ||
308                                     (!context->isOmittedInput(kAuxInputTensor) &&
309                                      !context->isOmittedInput(kFwAuxWeightsTensor) &&
310                                      !context->isOmittedInput(kBwAuxWeightsTensor));
311     NN_RET_CHECK(auxInputsAllOrNone);
312     const bool hasAuxInputs = !context->isOmittedInput(kAuxInputTensor);
313 
314     int32_t timeMajor = context->getInputValue<bool>(kTimeMajorParam);
315     const uint32_t batchSize =
316             timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
317     const uint32_t maxTime =
318             timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
319     const uint32_t fwNumUnits = getSizeOfDimension(fwWeights, 0);
320     const uint32_t bwNumUnits = getSizeOfDimension(bwWeights, 0);
321     const uint32_t inputSize = getSizeOfDimension(input, 2);
322 
323     NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3);
324     NN_RET_CHECK_EQ(getNumberOfDimensions(fwWeights), 2);
325     NN_RET_CHECK_EQ(getNumberOfDimensions(fwRecurrentWeights), 2);
326     NN_RET_CHECK_EQ(getNumberOfDimensions(fwBias), 1);
327     NN_RET_CHECK_EQ(getNumberOfDimensions(fwHiddenState), 2);
328     NN_RET_CHECK_EQ(getNumberOfDimensions(bwWeights), 2);
329     NN_RET_CHECK_EQ(getNumberOfDimensions(bwRecurrentWeights), 2);
330     NN_RET_CHECK_EQ(getNumberOfDimensions(bwBias), 1);
331     NN_RET_CHECK_EQ(getNumberOfDimensions(bwHiddenState), 2);
332 
333     NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(fwWeights, 1));
334     NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwBias, 0));
335     NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 0));
336     NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 1));
337     NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(fwHiddenState, 0));
338     NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwHiddenState, 1));
339 
340     NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(bwWeights, 1));
341     NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwBias, 0));
342     NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 0));
343     NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 1));
344     NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(bwHiddenState, 0));
345     NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwHiddenState, 1));
346 
347     if (hasAuxInputs) {
348         NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3);
349         NN_RET_CHECK_EQ(getNumberOfDimensions(fwAuxWeights), 2);
350         NN_RET_CHECK_EQ(getNumberOfDimensions(bwAuxWeights), 2);
351 
352         NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
353         NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
354         NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 0), fwNumUnits);
355         NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
356         NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 0), bwNumUnits);
357         NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
358     }
359 
360     Shape fwOutput = context->getOutputShape(kFwOutputTensor);
361     fwOutput.dimensions.resize(3);
362     fwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
363     fwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
364     fwOutput.dimensions[2] = mergeOutputs ? fwNumUnits + bwNumUnits : fwNumUnits;
365     NN_RET_CHECK(context->setOutputShape(kFwOutputTensor, fwOutput));
366     if (!mergeOutputs) {
367         Shape bwOutput = context->getOutputShape(kBwOutputTensor);
368         bwOutput.dimensions.resize(3);
369         bwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
370         bwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
371         bwOutput.dimensions[2] = bwNumUnits;
372         NN_RET_CHECK(context->setOutputShape(kBwOutputTensor, bwOutput));
373     }
374 
375     return true;
376 }
377 
execute(IOperationExecutionContext * context)378 bool execute(IOperationExecutionContext* context) {
379     if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
380         executeTyped<_Float16>(context);
381     } else {
382         executeTyped<float>(context);
383     }
384     return true;
385 }
386 
387 }  // namespace bidirectional_sequence_rnn
388 
389 NN_REGISTER_OPERATION(BIDIRECTIONAL_SEQUENCE_RNN, "BIDIRECTIONAL_SEQUENCE_RNN",
390                       bidirectional_sequence_rnn::validate, bidirectional_sequence_rnn::prepare,
391                       bidirectional_sequence_rnn::execute, .allowOmittedOperand = true);
392 
393 }  // namespace nn
394 }  // namespace android
395