/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <vector>

#include "HashtableLookup.h"
#include "NeuralNetworksWrapper.h"

using ::testing::FloatNear;
using ::testing::Matcher;

namespace android {
namespace nn {
namespace wrapper {

namespace {

std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
                                           float max_abs_error = 1.e-6) {
    std::vector<Matcher<float>> matchers;
    matchers.reserve(values.size());
    for (const float& v : values) {
        matchers.emplace_back(FloatNear(v, max_abs_error));
    }
    return matchers;
}

}  // namespace

using ::testing::ElementsAreArray;

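// The FOR_ALL_* lists below are X-macros: each ACTION(X, T) is invoked with a tensor
// name X and its element type T, so SetInputOrWeight, SetOutput, DefineSetter, and
// DefineTensor can be stamped out once per tensor without repetition.

// For all input and weight tensors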
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Lookup, int)                          \
    ACTION(Key, int)                             \
    ACTION(Value, float)

// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(Output, float)              \
    ACTION(Hits, uint8_t)

class HashtableLookupOpModel {
   public:
    HashtableLookupOpModel(std::initializer_list<uint32_t> lookup_shape,
                           std::initializer_list<uint32_t> key_shape,
                           std::initializer_list<uint32_t> value_shape) {
        auto it_vs = value_shape.begin();
        rows_ = *it_vs++;
        features_ = *it_vs;

        std::vector<uint32_t> inputs;

        // Input and weights
        OperandType LookupTy(Type::TENSOR_INT32, lookup_shape);
        inputs.push_back(model_.addOperand(&LookupTy));

        OperandType KeyTy(Type::TENSOR_INT32, key_shape);
        inputs.push_back(model_.addOperand(&KeyTy));

        OperandType ValueTy(Type::TENSOR_FLOAT32, value_shape);
        inputs.push_back(model_.addOperand(&ValueTy));

        // Output and other intermediate state
        std::vector<uint32_t> outputs;

        std::vector<uint32_t> out_dim(lookup_shape.begin(), lookup_shape.end());
        out_dim.push_back(features_);

        OperandType OutputOpndTy(Type::TENSOR_FLOAT32, out_dim);
        outputs.push_back(model_.addOperand(&OutputOpndTy));

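        // Per-lookup hit indicator: HASHTABLE_LOOKUP reports 1 when the lookup key was
        // found in the key tensor and 0 otherwise, as a TENSOR_QUANT8_ASYMM output.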
        OperandType HitsOpndTy(Type::TENSOR_QUANT8_ASYMM, lookup_shape, 1.f, 0);
        outputs.push_back(model_.addOperand(&HitsOpndTy));

        auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
            uint32_t sz = 1;
            for (uint32_t d : dims) {
                sz *= d;
            }
            return sz;
        };

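        // Pre-size the value, output, and hit buffers to the full element counts so the
        // raw pointers handed to setInput()/setOutput() in Invoke() cover whole tensors.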
        Value_.insert(Value_.end(), multiAll(value_shape), 0.f);
        Output_.insert(Output_.end(), multiAll(out_dim), 0.f);
        Hits_.insert(Hits_.end(), multiAll(lookup_shape), 0);

        model_.addOperation(ANEURALNETWORKS_HASHTABLE_LOOKUP, inputs, outputs);
        model_.identifyInputsAndOutputs(inputs, outputs);

        model_.finish();
    }

    void Invoke() {
        ASSERT_TRUE(model_.isValid());

        Compilation compilation(&model_);
        compilation.finish();
        Execution execution(&compilation);

#define SetInputOrWeight(X, T)                                               \
    ASSERT_EQ(execution.setInput(HashtableLookup::k##X##Tensor, X##_.data(), \
                                 sizeof(T) * X##_.size()),                   \
              Result::NO_ERROR);

        FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);

#undef SetInputOrWeight

#define SetOutput(X, T)                                                       \
    ASSERT_EQ(execution.setOutput(HashtableLookup::k##X##Tensor, X##_.data(), \
                                  sizeof(T) * X##_.size()),                   \
              Result::NO_ERROR);

        FOR_ALL_OUTPUT_TENSORS(SetOutput);

#undef SetOutput

        ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    }

#define DefineSetter(X, T) \
    void Set##X(const std::vector<T>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);

#undef DefineSetter

    void SetHashtableValue(const std::function<float(uint32_t, uint32_t)>& function) {
        for (uint32_t i = 0; i < rows_; i++) {
            for (uint32_t j = 0; j < features_; j++) {
                Value_[i * features_ + j] = function(i, j);
            }
        }
    }

    const std::vector<float>& GetOutput() const { return Output_; }
    const std::vector<uint8_t>& GetHits() const { return Hits_; }

   private:
    Model model_;
    uint32_t rows_;
    uint32_t features_;

#define DefineTensor(X, T) std::vector<T> X##_;

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
    FOR_ALL_OUTPUT_TENSORS(DefineTensor);

#undef DefineTensor
};

TEST(HashtableLookupOpTest, BlackBoxTest) {
    HashtableLookupOpModel m({4}, {3}, {3, 2});

    m.SetLookup({1234, -292, -11, 0});
    m.SetKey({-11, 0, 1234});
    m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; });

    m.Invoke();

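    // Keys {-11, 0, 1234} name rows {0, 1, 2}, and row i of the value table holds
    // {i + 0.0f, i + 0.1f}. The lookups {1234, -292, -11, 0} therefore resolve to
    // {row 2, miss, row 0, row 1}, which fixes the output and hit expectations below.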
    EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
                                       2.0, 2.1,  // 2nd item
                                       0, 0,      // Not found
                                       0.0, 0.1,  // 0th item
                                       1.0, 1.1,  // 1st item
                               })));
    EXPECT_EQ(m.GetHits(), std::vector<uint8_t>({
                                   1,
                                   0,
                                   1,
                                   1,
                           }));
}

}  // namespace wrapper
}  // namespace nn
}  // namespace android