1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "RNN.h"

#include <algorithm>
#include <cstddef>
#include <vector>

#include "NeuralNetworksWrapper.h"
#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"
22
23 namespace android {
24 namespace nn {
25 namespace wrapper {
26
27 using ::testing::Each;
28 using ::testing::FloatNear;
29 using ::testing::Matcher;
30
31 namespace {
32
ArrayFloatNear(const std::vector<float> & values,float max_abs_error=1.e-5)33 std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
34 float max_abs_error = 1.e-5) {
35 std::vector<Matcher<float>> matchers;
36 matchers.reserve(values.size());
37 for (const float& v : values) {
38 matchers.emplace_back(FloatNear(v, max_abs_error));
39 }
40 return matchers;
41 }
42
// Input sequence for BlackBoxTest: 128 floats, laid out as consecutive time
// steps of rnn.input_size() (8) values each, i.e. 16 steps. The test feeds the
// same step to every batch.
// Kept non-const because the test takes `float*` pointers into this array.
static float rnn_input[] = {
    0.23689353,   0.285385,     0.037029743, -0.19858193,  -0.27569133,
    0.43773448,   0.60379338,   0.35562468,  -0.69424844,  -0.93421471,
    -0.87287879,  0.37144363,   -0.62476718, 0.23791671,   0.40060222,
    0.1356622,    -0.99774903,  -0.98858172, -0.38952237,  -0.47685933,
    0.31073618,   0.71511042,   -0.63767755, -0.31729108,  0.33468103,
    0.75801885,   0.30660987,   -0.37354088, 0.77002847,   -0.62747043,
    -0.68572164,  0.0069220066, 0.65791464,  0.35130811,   0.80834007,
    -0.61777675,  -0.21095741,  0.41213346,  0.73784804,   0.094794154,
    0.47791874,   0.86496925,   -0.53376222, 0.85315156,   0.10288584,
    0.86684,      -0.011186242, 0.10513687,  0.87825835,   0.59929144,
    0.62827742,   0.18899453,   0.31440187,  0.99059987,   0.87170351,
    -0.35091716,  0.74861872,   0.17831337,  0.2755419,    0.51864719,
    0.55084288,   0.58982027,   -0.47443086, 0.20875752,   -0.058871567,
    -0.66609079,  0.59098077,   0.73017097,  0.74604273,   0.32882881,
    -0.17503482,  0.22396147,   0.19379807,  0.29120302,   0.077113032,
    -0.70331609,  0.15804303,   -0.93407321, 0.40182066,   0.036301374,
    0.66521823,   0.0300982,    -0.7747041,  -0.02038002,  0.020698071,
    -0.90300065,  0.62870288,   -0.23068321, 0.27531278,   -0.095755219,
    -0.712036,    -0.17384434,  -0.50593495, -0.18646687,  -0.96508682,
    0.43519354,   0.14744234,   0.62589407,  0.1653645,    -0.10651493,
    -0.045277178, 0.99032974,   -0.88255352, -0.85147917,  0.28153265,
    0.19455957,   -0.55479527,  -0.56042433, 0.26048636,   0.84702539,
    0.47587705,   -0.074295521, -0.12287641, 0.70117295,   0.90532446,
    0.89782166,   0.79817224,   0.53402734,  -0.33286154,  0.073485017,
    -0.56172788,  -0.044897556, 0.89964068,  -0.067662835, 0.76863563,
    0.93455386,   -0.6324693,   -0.083922029};
70
// Expected per-step output for a single batch: 256 floats grouped as 16 rows
// of rnn.num_units() (16) values, one row per time step of rnn_input (rows are
// separated by blank lines below). Every value is >= 0, consistent with the
// ReLU activation BasicRNNOpModel configures.
// Kept non-const because the test takes `float*` pointers into this array.
static float rnn_golden_output[] = {
    0.496726,   0,          0.965996,  0,         0.0584254, 0,
    0,          0.12315,    0,         0,         0.612266,  0.456601,
    0,          0.52286,    1.16099,   0.0291232,

    0,          0,          0.524901,  0,         0,         0,
    0,          1.02116,    0,         1.35762,   0,         0.356909,
    0.436415,   0.0355727,  0,         0,

    0,          0,          0,         0.262335,  0,         0,
    0,          1.33992,    0,         2.9739,    0,         0,
    1.31914,    2.66147,    0,         0,

    0.942568,   0,          0,         0,         0.025507,  0,
    0,          0,          0.321429,  0.569141,  1.25274,   1.57719,
    0.8158,     1.21805,    0.586239,  0.25427,

    1.04436,    0,          0.630725,  0,         0.133801,  0.210693,
    0.363026,   0,          0.533426,  0,         1.25926,   0.722707,
    0,          1.22031,    1.30117,   0.495867,

    0.222187,   0,          0.72725,   0,         0.767003,  0,
    0,          0.147835,   0,         0,         0,         0.608758,
    0.469394,   0.00720298, 0.927537,  0,

    0.856974,   0.424257,   0,         0,         0.937329,  0,
    0,          0,          0.476425,  0,         0.566017,  0.418462,
    0.141911,   0.996214,   1.13063,   0,

    0.967899,   0,          0,         0,         0.0831304, 0,
    0,          1.00378,    0,         0,         0,         1.44818,
    1.01768,    0.943891,   0.502745,  0,

    0.940135,   0,          0,         0,         0,         0,
    0,          2.13243,    0,         0.71208,   0.123918,  1.53907,
    1.30225,    1.59644,    0.70222,   0,

    0.804329,   0,          0.430576,  0,         0.505872,  0.509603,
    0.343448,   0,          0.107756,  0.614544,  1.44549,   1.52311,
    0.0454298,  0.300267,   0.562784,  0.395095,

    0.228154,   0,          0.675323,  0,         1.70536,   0.766217,
    0,          0,          0,         0.735363,  0.0759267, 1.91017,
    0.941888,   0,          0,         0,

    0,          0,          1.5909,    0,         0,         0,
    0,          0.5755,     0,         0.184687,  0,         1.56296,
    0.625285,   0,          0,         0,

    0,          0,          0.0857888, 0,         0,         0,
    0,          0.488383,   0.252786,  0,         0,         0,
    1.02817,    1.85665,    0,         0,

    0.00981836, 0,          1.06371,   0,         0,         0,
    0,          0,          0,         0.290445,  0.316406,  0,
    0.304161,   1.25079,    0.0707152, 0,

    0.986264,   0.309201,   0,         0,         0,         0,
    0,          1.64896,    0.346248,  0,         0.918175,  0.78884,
    0.524981,   1.92076,    2.07013,   0.333244,

    0.415153,   0.210318,   0,         0,         0,         0,
    0,          2.02616,    0,         0.728256,  0.84183,   0.0907453,
    0.628881,   3.58099,    1.49974,   0};
135
136 } // anonymous namespace
137
// X-macro over the float tensors fed into the RNN operation. Each name X is
// expanded as ACTION(X); BasicRNNOpModel uses it to declare an X_ buffer,
// generate a SetX() setter, and bind X_ to RNN::kXTensor in Invoke().
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
  ACTION(Input) \
  ACTION(Weights) \
  ACTION(RecurrentWeights) \
  ACTION(Bias) \
  ACTION(HiddenStateIn)

// X-macro over the tensors the RNN operation writes: the updated hidden state
// and the step output. BasicRNNOpModel uses it to declare the X_ buffers and
// bind them as execution outputs via RNN::kXTensor.
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
  ACTION(HiddenStateOut) \
  ACTION(Output)
149
150 class BasicRNNOpModel {
151 public:
BasicRNNOpModel(uint32_t batches,uint32_t units,uint32_t size)152 BasicRNNOpModel(uint32_t batches, uint32_t units, uint32_t size)
153 : batches_(batches),
154 units_(units),
155 input_size_(size),
156 activation_(kActivationRelu) {
157 std::vector<uint32_t> inputs;
158
159 OperandType InputTy(Type::TENSOR_FLOAT32, {batches_, input_size_});
160 inputs.push_back(model_.addOperand(&InputTy));
161 OperandType WeightTy(Type::TENSOR_FLOAT32, {units_, input_size_});
162 inputs.push_back(model_.addOperand(&WeightTy));
163 OperandType RecurrentWeightTy(Type::TENSOR_FLOAT32, {units_, units_});
164 inputs.push_back(model_.addOperand(&RecurrentWeightTy));
165 OperandType BiasTy(Type::TENSOR_FLOAT32, {units_});
166 inputs.push_back(model_.addOperand(&BiasTy));
167 OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_});
168 inputs.push_back(model_.addOperand(&HiddenStateTy));
169 OperandType ActionParamTy(Type::INT32, {1});
170 inputs.push_back(model_.addOperand(&ActionParamTy));
171
172 std::vector<uint32_t> outputs;
173
174 outputs.push_back(model_.addOperand(&HiddenStateTy));
175 OperandType OutputTy(Type::TENSOR_FLOAT32, {batches_, units_});
176 outputs.push_back(model_.addOperand(&OutputTy));
177
178 Input_.insert(Input_.end(), batches_ * input_size_, 0.f);
179 HiddenStateIn_.insert(HiddenStateIn_.end(), batches_ * units_, 0.f);
180 HiddenStateOut_.insert(HiddenStateOut_.end(), batches_ * units_, 0.f);
181 Output_.insert(Output_.end(), batches_ * units_, 0.f);
182
183 model_.addOperation(ANEURALNETWORKS_RNN, inputs, outputs);
184 model_.identifyInputsAndOutputs(inputs, outputs);
185
186 model_.finish();
187 }
188
189 #define DefineSetter(X) \
190 void Set##X(const std::vector<float>& f) { \
191 X##_.insert(X##_.end(), f.begin(), f.end()); \
192 }
193
194 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
195
196 #undef DefineSetter
197
SetInput(int offset,float * begin,float * end)198 void SetInput(int offset, float* begin, float* end) {
199 for (; begin != end; begin++, offset++) {
200 Input_[offset] = *begin;
201 }
202 }
203
ResetHiddenState()204 void ResetHiddenState() {
205 std::fill(HiddenStateIn_.begin(), HiddenStateIn_.end(), 0.f);
206 std::fill(HiddenStateOut_.begin(), HiddenStateOut_.end(), 0.f);
207 }
208
GetOutput() const209 const std::vector<float>& GetOutput() const { return Output_; }
210
input_size() const211 uint32_t input_size() const { return input_size_; }
num_units() const212 uint32_t num_units() const { return units_; }
num_batches() const213 uint32_t num_batches() const { return batches_; }
214
Invoke()215 void Invoke() {
216 ASSERT_TRUE(model_.isValid());
217
218 HiddenStateIn_.swap(HiddenStateOut_);
219
220 Compilation compilation(&model_);
221 compilation.finish();
222 Execution execution(&compilation);
223 #define SetInputOrWeight(X) \
224 ASSERT_EQ(execution.setInput(RNN::k##X##Tensor, X##_.data(), \
225 sizeof(float) * X##_.size()), \
226 Result::NO_ERROR);
227
228 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
229
230 #undef SetInputOrWeight
231
232 #define SetOutput(X) \
233 ASSERT_EQ(execution.setOutput(RNN::k##X##Tensor, X##_.data(), \
234 sizeof(float) * X##_.size()), \
235 Result::NO_ERROR);
236
237 FOR_ALL_OUTPUT_TENSORS(SetOutput);
238
239 #undef SetOutput
240
241 ASSERT_EQ(execution.setInput(RNN::kActivationParam, &activation_,
242 sizeof(activation_)),
243 Result::NO_ERROR);
244
245 ASSERT_EQ(execution.compute(), Result::NO_ERROR);
246 }
247
248 private:
249 Model model_;
250
251 const uint32_t batches_;
252 const uint32_t units_;
253 const uint32_t input_size_;
254
255 const int activation_;
256
257 #define DefineTensor(X) std::vector<float> X##_;
258
259 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
260 FOR_ALL_OUTPUT_TENSORS(DefineTensor);
261
262 #undef DefineTensor
263 };
264
// End-to-end check of the RNN op. At every time step, both batches receive the
// same input_size-wide slice of rnn_input. One Invoke() runs per step, and the
// output is compared against that step's golden row, repeated once per batch.
TEST(RNNOpTest, BlackBoxTest) {
  BasicRNNOpModel rnn(2, 16, 8);
  rnn.SetWeights(
      {0.461459,    0.153381,   0.529743,    -0.00371218, 0.676267,   -0.211346,
       0.317493,    0.969689,   -0.343251,   0.186423,    0.398151,   0.152399,
       0.448504,    0.317662,   0.523556,    -0.323514,   0.480877,   0.333113,
       -0.757714,   -0.674487,  -0.643585,   0.217766,    -0.0251462, 0.79512,
       -0.595574,   -0.422444,  0.371572,    -0.452178,   -0.556069,  -0.482188,
       -0.685456,   -0.727851,  0.841829,    0.551535,    -0.232336,  0.729158,
       -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,   -0.423241,
       0.548547,    -0.0152023, -0.757482,   -0.85491,    0.251331,   -0.989183,
       0.306261,    -0.340716,  0.886103,    -0.0726757,  -0.723523,  -0.784303,
       0.0354295,   0.566564,   -0.485469,   -0.620498,   0.832546,   0.697884,
       -0.279115,   0.294415,   -0.584313,   0.548772,    0.0648819,  0.968726,
       0.723834,    -0.0080452, -0.350386,   -0.272803,   0.115121,   -0.412644,
       -0.824713,   -0.992843,  -0.592904,   -0.417893,   0.863791,   -0.423461,
       -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,   -0.639158,
       0.816969,    -0.337228,  0.659878,    0.73107,     0.754768,   -0.337042,
       0.0960841,   0.368357,   0.244191,    -0.817703,   -0.211223,  0.442012,
       0.37225,     -0.623598,  -0.405423,   0.455101,    0.673656,   -0.145345,
       -0.511346,   -0.901675,  -0.81252,    -0.127006,   0.809865,   -0.721884,
       0.636255,    0.868989,   -0.347973,   -0.10179,    -0.777449,  0.917274,
       0.819286,    0.206218,   -0.00785118, 0.167141,    0.45872,    0.972934,
       -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057,  -0.469077,
       0.277308,    0.415818});

  rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
               -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
               0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
               -0.37609905});

  // 16x16 recurrent weights: 0.1 on the diagonal, 0 elsewhere.
  rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1});

  rnn.ResetHiddenState();

  // Every batch is fed the same time step, so the number of steps is the input
  // length divided by input_size alone. Dividing by num_batches() as well
  // would leave the second half of rnn_input and rnn_golden_output unchecked.
  const size_t input_sequence_size =
      sizeof(rnn_input) / sizeof(float) / rnn.input_size();
  // The golden table must provide one num_units()-wide row per input step.
  ASSERT_EQ(sizeof(rnn_golden_output) / sizeof(float),
            input_sequence_size * rnn.num_units());

  for (size_t i = 0; i < input_sequence_size; i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    rnn.SetInput(0, batch_start, batch_end);
    rnn.SetInput(rnn.input_size(), batch_start, batch_end);

    rnn.Invoke();

    float* golden_start = rnn_golden_output + i * rnn.num_units();
    float* golden_end = golden_start + rnn.num_units();
    std::vector<float> expected;
    expected.insert(expected.end(), golden_start, golden_end);
    expected.insert(expected.end(), golden_start, golden_end);

    EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
  }
}
334
335 } // namespace wrapper
336 } // namespace nn
337 } // namespace android
338