/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "RNN.h"

#include "CpuExecutor.h"
#include "CpuOperationUtils.h"
#include "HalInterfaces.h"

#include "Tracing.h"

namespace android {
namespace nn {

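// Binds the operation's runtime operands to the members consumed by Prepare()
// and Eval().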
RNN::RNN(const Operation& operation,
         std::vector<RunTimeOperandInfo>& operands) {
  NNTRACE_TRANS("RNN::RNN");
  input_ = GetInput(operation, operands, kInputTensor);
  weights_ = GetInput(operation, operands, kWeightsTensor);
  recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor);
  hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor);
  bias_ = GetInput(operation, operands, kBiasTensor);

  activation_ = static_cast<ActivationFn>(
      getScalarData<int32_t>(operands[operation.inputs[kActivationParam]]));

  hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor);
  output_ = GetOutput(operation, operands, kOutputTensor);
}

bool RNN::Prepare(const Operation &operation,
                  std::vector<RunTimeOperandInfo> &operands,
                  Shape *hiddenStateShape,
                  Shape *outputShape) {
  NNTRACE_TRANS("RNN::Prepare");
  // Check we have all the inputs and outputs we need.
  const int num_inputs = NumInputsWithValues(operation, operands);
  NN_CHECK(num_inputs == 5 || num_inputs == 6);
  NN_CHECK_EQ(NumOutputs(operation), 2);

  const RunTimeOperandInfo *input =
      GetInput(operation, operands, kInputTensor);
  const RunTimeOperandInfo *input_weights =
      GetInput(operation, operands, kWeightsTensor);
  const RunTimeOperandInfo *recurrent_weights =
      GetInput(operation, operands, kRecurrentWeightsTensor);
  const RunTimeOperandInfo *bias =
      GetInput(operation, operands, kBiasTensor);

  // Check that all the tensor parameters are consistent with each other and
  // with the input configuration.
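  // The checks below imply the expected shapes: input is
  // [batch_size, input_size], input_weights is [num_units, input_size],
  // recurrent_weights is [num_units, num_units], and bias is [num_units].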
  const uint32_t batch_size = SizeOfDimension(input, 0);
  const uint32_t num_units = SizeOfDimension(input_weights, 0);
  NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1));
  NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0));
  NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0));
  NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0));

  const Shape &inputShape = input->shape();

  // Resize state.
  hiddenStateShape->type = inputShape.type;
  hiddenStateShape->dimensions = { batch_size, num_units };

  // Resize output.
  outputShape->type = inputShape.type;
  outputShape->dimensions = { batch_size, num_units };

  return true;
}

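// Dispatches to the RNNStep kernel matching the input tensor's element type,
// then mirrors the produced output into the hidden-state-out tensor.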
bool RNN::Eval() {
    switch (input_->type) {
        case OperandType::TENSOR_FLOAT16: {
            RNNStep<_Float16>(reinterpret_cast<_Float16*>(input_->buffer), input_->shape(),
                              reinterpret_cast<_Float16*>(hidden_state_in_->buffer),
                              reinterpret_cast<_Float16*>(bias_->buffer),
                              reinterpret_cast<_Float16*>(weights_->buffer), weights_->shape(),
                              reinterpret_cast<_Float16*>(recurrent_weights_->buffer),
                              recurrent_weights_->shape(), activation_,
                              reinterpret_cast<_Float16*>(output_->buffer));
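            // The basic RNN cell's new hidden state equals its output, so
            // copy the result into the hidden-state-out tensor.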
            memcpy(hidden_state_out_->buffer, output_->buffer,
                   sizeof(_Float16) * getNumberOfElements(output_->shape()));
            break;
        }
        case OperandType::TENSOR_FLOAT32: {
            RNNStep<float>(reinterpret_cast<float*>(input_->buffer), input_->shape(),
                           reinterpret_cast<float*>(hidden_state_in_->buffer),
                           reinterpret_cast<float*>(bias_->buffer),
                           reinterpret_cast<float*>(weights_->buffer), weights_->shape(),
                           reinterpret_cast<float*>(recurrent_weights_->buffer),
                           recurrent_weights_->shape(), activation_,
                           reinterpret_cast<float*>(output_->buffer));
            memcpy(hidden_state_out_->buffer, output_->buffer,
                   sizeof(float) * getNumberOfElements(output_->shape()));
            break;
        }
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
            return false;
        }
    }
    return true;
}

template <typename T>
bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* hiddenStateInputData,
                  const T* biasData, const T* weightsData, const Shape& weightsShape,
                  const T* recurrentWeightsData, const Shape& recurrentWeightsShape,
                  const int32_t activation, T* outputData) {
    NNTRACE_COMP("RNN::Eval");

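    // Forward to the general version below: no auxiliary input, and a densely
    // packed output (one row of numUnits values per batch).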
    Shape dummyShape;
    uint32_t numUnits = weightsShape.dimensions[0];
    return RNNStep<T>(inputData, inputShape, /*auxInputData=*/nullptr, /*auxInputShape=*/dummyShape,
                      hiddenStateInputData, biasData, weightsData, weightsShape,
                      /*auxWeightsData=*/nullptr, /*auxWeightsShape=*/dummyShape,
                      recurrentWeightsData, recurrentWeightsShape, activation,
                      /*outputBatchStride=*/numUnits, /*outputBatchOffset=*/0, outputData);
}

// A more general version of the RNNStep function.
// The auxiliary input is treated as if it were concatenated to the regular
// input, with the result multiplied by a weights matrix that was likewise
// concatenated with the auxiliary weights.
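// Concretely, the per-batch update computed below is, in math notation
// (W = weights, W_aux = auxiliary weights, R = recurrent weights):
//   output = activation(W * input + W_aux * aux_input + R * hidden_state + bias)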
template <typename T>
bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* auxInputData,
                  const Shape& auxInputShape, const T* hiddenStateInputData, const T* biasData,
                  const T* weightsData, const Shape& weightsShape, const T* auxWeightsData,
                  const Shape& auxWeightsShape, const T* recurrentWeightsData,
                  const Shape& recurrentWeightsShape, const int32_t activation,
                  const uint32_t outputBatchStride, const uint32_t outputBatchOffset, T* outputData,
                  T* hiddenStateOutput) {
    NNTRACE_COMP("RNN::Eval");

    const uint32_t batch_size = inputShape.dimensions[0];
    const uint32_t num_units = weightsShape.dimensions[0];
    const uint32_t input_size = inputShape.dimensions[1];
    const uint32_t input_weights_stride = weightsShape.dimensions[1];
    const uint32_t recurrent_weights_stride = recurrentWeightsShape.dimensions[1];
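    // dimensions[1] is the row length of each weight matrix; advancing a row
    // pointer by it steps to the weights of the next output unit.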

    uint32_t aux_input_size = 0;
    uint32_t aux_input_weights_stride = 0;
    bool hasAuxInput = (auxInputData != nullptr);
    if (hasAuxInput) {
        aux_input_size = auxInputShape.dimensions[1];
        aux_input_weights_stride = auxWeightsShape.dimensions[1];
    }

    // For each batch
    for (uint32_t b = 0; b < batch_size; b++) {
        // Set up the per-batch input, hidden state, and output pointers.
        const T* input_ptr_batch = inputData + b * input_size;
        const T* hidden_state_in_ptr_batch = hiddenStateInputData + b * num_units;
        const T* aux_input_ptr_batch = nullptr;
        if (hasAuxInput) {
            aux_input_ptr_batch = auxInputData + b * aux_input_size;
        }
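        // outputBatchStride and outputBatchOffset let the caller write each
        // batch's results into a strided slice of a larger buffer (e.g. when
        // interleaving the outputs of two cells).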
        T* output_ptr_batch = outputData + b * outputBatchStride + outputBatchOffset;

        // Initialize input_weights and recurrent_weights.
        const T* input_weights_ptr = weightsData;
        const T* recurrent_weights_ptr = recurrentWeightsData;
        const T* aux_input_weights_ptr = nullptr;
        if (hasAuxInput) {
            aux_input_weights_ptr = auxWeightsData;
        }

        // Output = bias
        for (uint32_t o = 0; o < num_units; o++) {
            output_ptr_batch[o] = biasData[o];
        }

        // Output += input * input_weights
        for (uint32_t o = 0; o < num_units; o++) {
            for (uint32_t i = 0; i < input_size; i++) {
                output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
            }
            input_weights_ptr += input_weights_stride;
        }

        if (hasAuxInput) {
            // Output += aux_input * aux_input_weights
            for (uint32_t o = 0; o < num_units; o++) {
                for (uint32_t i = 0; i < aux_input_size; i++) {
                    output_ptr_batch[o] += aux_input_ptr_batch[i] * aux_input_weights_ptr[i];
                }
                aux_input_weights_ptr += aux_input_weights_stride;
            }
        }

        // Output += recurrent_weights * hidden_state
        for (uint32_t o = 0; o < num_units; o++) {
            for (uint32_t h = 0; h < num_units; h++) {
                output_ptr_batch[o] += hidden_state_in_ptr_batch[h] * recurrent_weights_ptr[h];
            }
            recurrent_weights_ptr += recurrent_weights_stride;
        }

        // Output = activation(Output)
        for (uint32_t o = 0; o < num_units; o++) {
            output_ptr_batch[o] =
                    (ActivationFunctor(static_cast<ActivationFn>(activation)))(output_ptr_batch[o]);
            if (hiddenStateOutput != nullptr) {
                *hiddenStateOutput = output_ptr_batch[o];
                ++hiddenStateOutput;
            }
        }
    }

    return true;
}

}  // namespace nn
}  // namespace android