• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "RNN.h"
18 
19 #include "NeuralNetworksWrapper.h"
20 #include "gmock/gmock-matchers.h"
21 #include "gtest/gtest.h"
22 
23 namespace android {
24 namespace nn {
25 namespace wrapper {
26 
27 using ::testing::Each;
28 using ::testing::FloatNear;
29 using ::testing::Matcher;
30 
31 namespace {
32 
ArrayFloatNear(const std::vector<float> & values,float max_abs_error=1.e-5)33 std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
34                                            float max_abs_error = 1.e-5) {
35   std::vector<Matcher<float>> matchers;
36   matchers.reserve(values.size());
37   for (const float& v : values) {
38     matchers.emplace_back(FloatNear(v, max_abs_error));
39   }
40   return matchers;
41 }
42 
// Input sequence for BlackBoxTest: 128 floats, read as consecutive time steps
// of rnn.input_size() (8) values each, i.e. 16 steps. The test indexes step i
// as rnn_input + i * input_size() and copies that step into every batch of the
// Input tensor.
43 static float rnn_input[] = {
44     0.23689353,   0.285385,     0.037029743, -0.19858193,  -0.27569133,
45     0.43773448,   0.60379338,   0.35562468,  -0.69424844,  -0.93421471,
46     -0.87287879,  0.37144363,   -0.62476718, 0.23791671,   0.40060222,
47     0.1356622,    -0.99774903,  -0.98858172, -0.38952237,  -0.47685933,
48     0.31073618,   0.71511042,   -0.63767755, -0.31729108,  0.33468103,
49     0.75801885,   0.30660987,   -0.37354088, 0.77002847,   -0.62747043,
50     -0.68572164,  0.0069220066, 0.65791464,  0.35130811,   0.80834007,
51     -0.61777675,  -0.21095741,  0.41213346,  0.73784804,   0.094794154,
52     0.47791874,   0.86496925,   -0.53376222, 0.85315156,   0.10288584,
53     0.86684,      -0.011186242, 0.10513687,  0.87825835,   0.59929144,
54     0.62827742,   0.18899453,   0.31440187,  0.99059987,   0.87170351,
55     -0.35091716,  0.74861872,   0.17831337,  0.2755419,    0.51864719,
56     0.55084288,   0.58982027,   -0.47443086, 0.20875752,   -0.058871567,
57     -0.66609079,  0.59098077,   0.73017097,  0.74604273,   0.32882881,
58     -0.17503482,  0.22396147,   0.19379807,  0.29120302,   0.077113032,
59     -0.70331609,  0.15804303,   -0.93407321, 0.40182066,   0.036301374,
60     0.66521823,   0.0300982,    -0.7747041,  -0.02038002,  0.020698071,
61     -0.90300065,  0.62870288,   -0.23068321, 0.27531278,   -0.095755219,
62     -0.712036,    -0.17384434,  -0.50593495, -0.18646687,  -0.96508682,
63     0.43519354,   0.14744234,   0.62589407,  0.1653645,    -0.10651493,
64     -0.045277178, 0.99032974,   -0.88255352, -0.85147917,  0.28153265,
65     0.19455957,   -0.55479527,  -0.56042433, 0.26048636,   0.84702539,
66     0.47587705,   -0.074295521, -0.12287641, 0.70117295,   0.90532446,
67     0.89782166,   0.79817224,   0.53402734,  -0.33286154,  0.073485017,
68     -0.56172788,  -0.044897556, 0.89964068,  -0.067662835, 0.76863563,
69     0.93455386,   -0.6324693,   -0.083922029};
70 
// Expected Output tensor contents per time step of rnn_input: 16 groups
// (separated by blank lines), one per step, each holding 16 values, one per
// unit. BlackBoxTest reads group i as rnn_golden_output + i * num_units() and
// expects it for every batch, because every batch is fed the same input. All
// values are >= 0, consistent with the ReLU activation the test model uses.
71 static float rnn_golden_output[] = {
72     0.496726,   0,          0.965996,  0,         0.0584254, 0,
73     0,          0.12315,    0,         0,         0.612266,  0.456601,
74     0,          0.52286,    1.16099,   0.0291232,
75
76     0,          0,          0.524901,  0,         0,         0,
77     0,          1.02116,    0,         1.35762,   0,         0.356909,
78     0.436415,   0.0355727,  0,         0,
79
80     0,          0,          0,         0.262335,  0,         0,
81     0,          1.33992,    0,         2.9739,    0,         0,
82     1.31914,    2.66147,    0,         0,
83
84     0.942568,   0,          0,         0,         0.025507,  0,
85     0,          0,          0.321429,  0.569141,  1.25274,   1.57719,
86     0.8158,     1.21805,    0.586239,  0.25427,
87
88     1.04436,    0,          0.630725,  0,         0.133801,  0.210693,
89     0.363026,   0,          0.533426,  0,         1.25926,   0.722707,
90     0,          1.22031,    1.30117,   0.495867,
91
92     0.222187,   0,          0.72725,   0,         0.767003,  0,
93     0,          0.147835,   0,         0,         0,         0.608758,
94     0.469394,   0.00720298, 0.927537,  0,
95
96     0.856974,   0.424257,   0,         0,         0.937329,  0,
97     0,          0,          0.476425,  0,         0.566017,  0.418462,
98     0.141911,   0.996214,   1.13063,   0,
99
100     0.967899,   0,          0,         0,         0.0831304, 0,
101     0,          1.00378,    0,         0,         0,         1.44818,
102     1.01768,    0.943891,   0.502745,  0,
103
104     0.940135,   0,          0,         0,         0,         0,
105     0,          2.13243,    0,         0.71208,   0.123918,  1.53907,
106     1.30225,    1.59644,    0.70222,   0,
107
108     0.804329,   0,          0.430576,  0,         0.505872,  0.509603,
109     0.343448,   0,          0.107756,  0.614544,  1.44549,   1.52311,
110     0.0454298,  0.300267,   0.562784,  0.395095,
111
112     0.228154,   0,          0.675323,  0,         1.70536,   0.766217,
113     0,          0,          0,         0.735363,  0.0759267, 1.91017,
114     0.941888,   0,          0,         0,
115
116     0,          0,          1.5909,    0,         0,         0,
117     0,          0.5755,     0,         0.184687,  0,         1.56296,
118     0.625285,   0,          0,         0,
119
120     0,          0,          0.0857888, 0,         0,         0,
121     0,          0.488383,   0.252786,  0,         0,         0,
122     1.02817,    1.85665,    0,         0,
123
124     0.00981836, 0,          1.06371,   0,         0,         0,
125     0,          0,          0,         0.290445,  0.316406,  0,
126     0.304161,   1.25079,    0.0707152, 0,
127
128     0.986264,   0.309201,   0,         0,         0,         0,
129     0,          1.64896,    0.346248,  0,         0.918175,  0.78884,
130     0.524981,   1.92076,    2.07013,   0.333244,
131
132     0.415153,   0.210318,   0,         0,         0,         0,
133     0,          2.02616,    0,         0.728256,  0.84183,   0.0907453,
134     0.628881,   3.58099,    1.49974,   0};
135 
136 }  // anonymous namespace
137 
// X-macro over every tensor the RNN model reads, in the same order in which
// BasicRNNOpModel adds the corresponding operands. APPLY(Name) is expanded
// once per tensor.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(APPLY) \
  APPLY(Input) APPLY(Weights) APPLY(RecurrentWeights) APPLY(Bias) \
  APPLY(HiddenStateIn)

// X-macro over every tensor the RNN model writes: the updated hidden state
// followed by the step output. APPLY(Name) is expanded once per tensor.
#define FOR_ALL_OUTPUT_TENSORS(APPLY) APPLY(HiddenStateOut) APPLY(Output)
149 
150 class BasicRNNOpModel {
151  public:
BasicRNNOpModel(uint32_t batches,uint32_t units,uint32_t size)152   BasicRNNOpModel(uint32_t batches, uint32_t units, uint32_t size)
153       : batches_(batches),
154         units_(units),
155         input_size_(size),
156         activation_(kActivationRelu) {
157     std::vector<uint32_t> inputs;
158 
159     OperandType InputTy(Type::TENSOR_FLOAT32, {batches_, input_size_});
160     inputs.push_back(model_.addOperand(&InputTy));
161     OperandType WeightTy(Type::TENSOR_FLOAT32, {units_, input_size_});
162     inputs.push_back(model_.addOperand(&WeightTy));
163     OperandType RecurrentWeightTy(Type::TENSOR_FLOAT32, {units_, units_});
164     inputs.push_back(model_.addOperand(&RecurrentWeightTy));
165     OperandType BiasTy(Type::TENSOR_FLOAT32, {units_});
166     inputs.push_back(model_.addOperand(&BiasTy));
167     OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_});
168     inputs.push_back(model_.addOperand(&HiddenStateTy));
169     OperandType ActionParamTy(Type::INT32, {1});
170     inputs.push_back(model_.addOperand(&ActionParamTy));
171 
172     std::vector<uint32_t> outputs;
173 
174     outputs.push_back(model_.addOperand(&HiddenStateTy));
175     OperandType OutputTy(Type::TENSOR_FLOAT32, {batches_, units_});
176     outputs.push_back(model_.addOperand(&OutputTy));
177 
178     Input_.insert(Input_.end(), batches_ * input_size_, 0.f);
179     HiddenStateIn_.insert(HiddenStateIn_.end(), batches_ * units_, 0.f);
180     HiddenStateOut_.insert(HiddenStateOut_.end(), batches_ * units_, 0.f);
181     Output_.insert(Output_.end(), batches_ * units_, 0.f);
182 
183     model_.addOperation(ANEURALNETWORKS_RNN, inputs, outputs);
184     model_.identifyInputsAndOutputs(inputs, outputs);
185 
186     model_.finish();
187   }
188 
189 #define DefineSetter(X)                          \
190   void Set##X(const std::vector<float>& f) {     \
191     X##_.insert(X##_.end(), f.begin(), f.end()); \
192   }
193 
194   FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
195 
196 #undef DefineSetter
197 
SetInput(int offset,float * begin,float * end)198   void SetInput(int offset, float* begin, float* end) {
199     for (; begin != end; begin++, offset++) {
200       Input_[offset] = *begin;
201     }
202   }
203 
ResetHiddenState()204   void ResetHiddenState() {
205     std::fill(HiddenStateIn_.begin(), HiddenStateIn_.end(), 0.f);
206     std::fill(HiddenStateOut_.begin(), HiddenStateOut_.end(), 0.f);
207   }
208 
GetOutput() const209   const std::vector<float>& GetOutput() const { return Output_; }
210 
input_size() const211   uint32_t input_size() const { return input_size_; }
num_units() const212   uint32_t num_units() const { return units_; }
num_batches() const213   uint32_t num_batches() const { return batches_; }
214 
Invoke()215   void Invoke() {
216     ASSERT_TRUE(model_.isValid());
217 
218     HiddenStateIn_.swap(HiddenStateOut_);
219 
220     Compilation compilation(&model_);
221     compilation.finish();
222     Execution execution(&compilation);
223 #define SetInputOrWeight(X)                                                   \
224   ASSERT_EQ(execution.setInput(RNN::k##X##Tensor, X##_.data(),                \
225                                sizeof(float) * X##_.size()),                  \
226             Result::NO_ERROR);
227 
228     FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
229 
230 #undef SetInputOrWeight
231 
232 #define SetOutput(X)                                                           \
233   ASSERT_EQ(execution.setOutput(RNN::k##X##Tensor, X##_.data(),                \
234                                 sizeof(float) * X##_.size()),                  \
235             Result::NO_ERROR);
236 
237     FOR_ALL_OUTPUT_TENSORS(SetOutput);
238 
239 #undef SetOutput
240 
241     ASSERT_EQ(execution.setInput(RNN::kActivationParam, &activation_,
242                                  sizeof(activation_)),
243               Result::NO_ERROR);
244 
245     ASSERT_EQ(execution.compute(), Result::NO_ERROR);
246   }
247 
248  private:
249   Model model_;
250 
251   const uint32_t batches_;
252   const uint32_t units_;
253   const uint32_t input_size_;
254 
255   const int activation_;
256 
257 #define DefineTensor(X) std::vector<float> X##_;
258 
259   FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
260   FOR_ALL_OUTPUT_TENSORS(DefineTensor);
261 
262 #undef DefineTensor
263 };
264 
TEST(RNNOpTest,BlackBoxTest)265 TEST(RNNOpTest, BlackBoxTest) {
266   BasicRNNOpModel rnn(2, 16, 8);
267   rnn.SetWeights(
268       {0.461459,    0.153381,   0.529743,    -0.00371218, 0.676267,   -0.211346,
269        0.317493,    0.969689,   -0.343251,   0.186423,    0.398151,   0.152399,
270        0.448504,    0.317662,   0.523556,    -0.323514,   0.480877,   0.333113,
271        -0.757714,   -0.674487,  -0.643585,   0.217766,    -0.0251462, 0.79512,
272        -0.595574,   -0.422444,  0.371572,    -0.452178,   -0.556069,  -0.482188,
273        -0.685456,   -0.727851,  0.841829,    0.551535,    -0.232336,  0.729158,
274        -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,   -0.423241,
275        0.548547,    -0.0152023, -0.757482,   -0.85491,    0.251331,   -0.989183,
276        0.306261,    -0.340716,  0.886103,    -0.0726757,  -0.723523,  -0.784303,
277        0.0354295,   0.566564,   -0.485469,   -0.620498,   0.832546,   0.697884,
278        -0.279115,   0.294415,   -0.584313,   0.548772,    0.0648819,  0.968726,
279        0.723834,    -0.0080452, -0.350386,   -0.272803,   0.115121,   -0.412644,
280        -0.824713,   -0.992843,  -0.592904,   -0.417893,   0.863791,   -0.423461,
281        -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,   -0.639158,
282        0.816969,    -0.337228,  0.659878,    0.73107,     0.754768,   -0.337042,
283        0.0960841,   0.368357,   0.244191,    -0.817703,   -0.211223,  0.442012,
284        0.37225,     -0.623598,  -0.405423,   0.455101,    0.673656,   -0.145345,
285        -0.511346,   -0.901675,  -0.81252,    -0.127006,   0.809865,   -0.721884,
286        0.636255,    0.868989,   -0.347973,   -0.10179,    -0.777449,  0.917274,
287        0.819286,    0.206218,   -0.00785118, 0.167141,    0.45872,    0.972934,
288        -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057,  -0.469077,
289        0.277308,    0.415818});
290 
291   rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
292                -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
293                0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
294                -0.37609905});
295 
296   rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
297                            0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
298                            0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
299                            0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
300                            0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
301                            0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
302                            0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
303                            0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
304                            0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
305                            0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
306                            0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
307                            0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
308                            0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
309                            0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
310                            0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
311                            0.1});
312 
313   rnn.ResetHiddenState();
314   const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
315                                   (rnn.input_size() * rnn.num_batches());
316 
317   for (int i = 0; i < input_sequence_size; i++) {
318     float* batch_start = rnn_input + i * rnn.input_size();
319     float* batch_end = batch_start + rnn.input_size();
320     rnn.SetInput(0, batch_start, batch_end);
321     rnn.SetInput(rnn.input_size(), batch_start, batch_end);
322 
323     rnn.Invoke();
324 
325     float* golden_start = rnn_golden_output + i * rnn.num_units();
326     float* golden_end = golden_start + rnn.num_units();
327     std::vector<float> expected;
328     expected.insert(expected.end(), golden_start, golden_end);
329     expected.insert(expected.end(), golden_start, golden_end);
330 
331     EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
332   }
333 }
334 
335 }  // namespace wrapper
336 }  // namespace nn
337 }  // namespace android
338