1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "SVDF.h"
18
19 #include "CpuExecutor.h"
20 #include "HalInterfaces.h"
21
22 namespace android {
23 namespace nn {
24
25 namespace {
26
27 // TODO: Implement this using circular buffer instead.
28 // This is here temporarily only to show the logic.
// Advances a state window by one step: drops the oldest entry and appends the
// newest value at the end.
//
// Copies state_in[1 .. state_len-1] into state_out[0 .. state_len-2], then
// stores shift_value in state_out[state_len-1].
//
//   state_in:    source window holding state_len floats.
//   state_len:   number of floats in each window. When it is zero or
//                negative there is no slot to write and the call is a no-op.
//   shift_value: value stored in the last slot of state_out.
//   state_out:   destination window holding state_len floats.
void svdf_right_shift_state(const float* state_in, int state_len, float shift_value,
                            float* state_out) {
    // An empty window has no last slot; storing to state_out[state_len - 1]
    // would write one float before the start of the destination buffer.
    if (state_len <= 0) {
        return;
    }

    const int last = state_len - 1;
    for (int i = 0; i < last; i++) {
        state_out[i] = state_in[i + 1];
    }
    state_out[last] = shift_value;
}
36
getInt32ScalarData(RunTimeOperandInfo & info)37 int32_t getInt32ScalarData(RunTimeOperandInfo& info) {
38 int32_t * data = reinterpret_cast<int32_t*>(info.buffer);
39 return data[0];
40 }
41
42 }
43
SVDF(const Operation & operation,std::vector<RunTimeOperandInfo> & operands)44 SVDF::SVDF(const Operation& operation,
45 std::vector<RunTimeOperandInfo>& operands) {
46 input_ = GetInput(operation, operands, kInputTensor);
47 weights_feature_ = GetInput(operation, operands, kWeightsFeatureTensor);
48 weights_time_ = GetInput(operation, operands, kWeightsTimeTensor);
49 bias_ = GetInput(operation, operands, kBiasTensor);
50 state_in_ = GetInput(operation, operands, kStateInTensor);
51
52 params_.rank_ = getInt32ScalarData(*GetInput(operation, operands, kRankParam));
53 params_.activation_ = static_cast<ActivationFn>(getInt32ScalarData(
54 *GetInput(operation, operands, kActivationParam)));
55
56 state_out_ = GetOutput(operation, operands, kStateOutTensor);
57 output_ = GetOutput(operation, operands, kOutputTensor);
58 }
59
Prepare(const Operation & operation,std::vector<RunTimeOperandInfo> & operands,Shape * stateShape,Shape * outputShape)60 bool SVDF::Prepare(const Operation &operation,
61 std::vector<RunTimeOperandInfo> &operands,
62 Shape *stateShape,
63 Shape *outputShape) {
64 // Check we have all the inputs and outputs we need.
65 const int num_inputs = NumInputsWithValues(operation, operands);
66 NN_CHECK(num_inputs == 6 || num_inputs == 7);
67 NN_CHECK_EQ(NumOutputs(operation), 2);
68
69 const RunTimeOperandInfo *input =
70 GetInput(operation, operands, SVDF::kInputTensor);
71 const RunTimeOperandInfo *weights_feature =
72 GetInput(operation, operands, SVDF::kWeightsFeatureTensor);
73 const RunTimeOperandInfo *weights_time =
74 GetInput(operation, operands, SVDF::kWeightsTimeTensor);
75
76 // Check all the parameters of tensor match within themselves and match the
77 // input configuration.
78 const uint32_t batch_size = SizeOfDimension(input, 0);
79 const uint32_t num_units = SizeOfDimension(weights_feature, 0);
80 const uint32_t memory_size = SizeOfDimension(weights_time, 1);
81 NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(weights_feature, 1));
82 NN_CHECK_EQ(SizeOfDimension(weights_time, 0), num_units);
83
84 const RunTimeOperandInfo *bias =
85 GetInput(operation, operands, kBiasTensor);
86 if (!IsNullInput(bias)) {
87 NN_CHECK_EQ(SizeOfDimension(bias, 0), num_units);
88 }
89
90 // Resize state.
91 const Shape &inputShape = input->shape();
92 stateShape->type = inputShape.type;
93 stateShape->dimensions = { batch_size, memory_size * num_units };
94 stateShape->offset = inputShape.offset;
95 stateShape->scale = inputShape.scale;
96
97 // Resize output.
98 outputShape->type = inputShape.type;
99 outputShape->dimensions = { batch_size, num_units };
100 outputShape->offset = inputShape.offset;
101 outputShape->scale = inputShape.scale;
102
103 return true;
104 }
105
Eval()106 bool SVDF::Eval() {
107 const int batch_size = input_->shape().dimensions[0];
108 const int input_size = input_->shape().dimensions[1];
109 const int num_units = weights_feature_->shape().dimensions[0];
110 const int memory_size = weights_time_->shape().dimensions[1];
111 const int weights_feature_stride = weights_feature_->shape().dimensions[1];
112 const int weights_time_stride = weights_time_->shape().dimensions[1];
113
114 // Initialize weights_feature and weights_time pointers.
115 const float* weights_feature_ptr = reinterpret_cast<float *>(weights_feature_->buffer);
116 const float* weights_time_ptr = reinterpret_cast<float *>(weights_time_->buffer);
117
118 // For each batch
119 for (int b = 0; b < batch_size; b++) {
120 // Initialize the pointer to input, output and bias.
121 const float* input_ptr_batch = reinterpret_cast<float *>(input_->buffer) + b * input_size;
122 float* output_ptr_batch = reinterpret_cast<float*>(output_->buffer) + b * num_units;
123 const float* state_in_ptr_batch = reinterpret_cast<const float*>(state_in_->buffer) + b * (memory_size - 1) * num_units;
124 float* state_out_ptr_batch = reinterpret_cast<float*>(state_out_->buffer) + b * (memory_size - 1) * num_units;
125
126 // For each unit
127 for (int c = 0; c < num_units; c++) {
128 float activation = 0.0;
129
130 // tf.nn.conv1d(inputs, weights_feature, feature_dim, "VALID")
131 for (int j = 0; j < input_size; j++) {
132 activation += input_ptr_batch[j] * weights_feature_ptr[j];
133 }
134
135 // Initialize state pointer for unit 'c'.
136 const float* state_in_ptr = state_in_ptr_batch + c * (memory_size - 1);
137 float* state_out_ptr = state_out_ptr_batch + c * (memory_size - 1);
138
139 // Apply bias if bias tensor exists.
140 output_ptr_batch[c] = bias_->buffer ? reinterpret_cast<float *>(bias_->buffer)[c] : 0.f;
141
142 // output = tf.matmul(state, weights_time)
143 output_ptr_batch[c] += weights_time_ptr[memory_size - 1] * activation;
144 for (int j = 0; j < memory_size - 1; j++) {
145 output_ptr_batch[c] += weights_time_ptr[j] * state_in_ptr[j];
146 }
147
148 // Apply activation.
149 output_ptr_batch[c] =
150 (ActivationFunctor(params_.activation_))(output_ptr_batch[c]);
151
152 // Right shift the state and concatenate with activation.
153 svdf_right_shift_state(state_in_ptr, memory_size - 1, activation,
154 state_out_ptr);
155
156 // Update weight pointers.
157 weights_feature_ptr += weights_feature_stride;
158 weights_time_ptr += weights_time_stride;
159 }
160 // Reset weight pointers for next batch.
161 weights_feature_ptr = reinterpret_cast<float*>(weights_feature_->buffer);
162 weights_time_ptr = reinterpret_cast<float*>(weights_time_->buffer);
163 }
164 return true;
165 }
166
167 } // namespace nn
168 } // namespace android
169