• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <gmock/gmock.h>
18 #include <gtest/gtest.h>
19 
20 #include <vector>
21 
22 #include "HashtableLookup.h"
23 #include "NeuralNetworksWrapper.h"
24 
25 using ::testing::FloatNear;
26 using ::testing::Matcher;
27 
28 namespace android {
29 namespace nn {
30 namespace wrapper {
31 
32 namespace {
33 
ArrayFloatNear(const std::vector<float> & values,float max_abs_error=1.e-6)34 std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
35                                            float max_abs_error = 1.e-6) {
36     std::vector<Matcher<float>> matchers;
37     matchers.reserve(values.size());
38     for (const float& v : values) {
39         matchers.emplace_back(FloatNear(v, max_abs_error));
40     }
41     return matchers;
42 }
43 
44 }  // namespace
45 
46 using ::testing::ElementsAreArray;
47 
// X-macro enumerating the input operands of the op as (Name, element type).
// Consumers paste Name into the member buffer `Name_`, the setter `SetName`
// and the operand index `HashtableLookup::kNameTensor`.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(APPLY) \
    APPLY(Lookup, int)                          \
    APPLY(Key, int)                             \
    APPLY(Value, float)

// X-macro enumerating the output operands of the op as (Name, element type).
#define FOR_ALL_OUTPUT_TENSORS(APPLY) \
    APPLY(Output, float)              \
    APPLY(Hits, uint8_t)
57 
58 class HashtableLookupOpModel {
59    public:
HashtableLookupOpModel(std::initializer_list<uint32_t> lookup_shape,std::initializer_list<uint32_t> key_shape,std::initializer_list<uint32_t> value_shape)60     HashtableLookupOpModel(std::initializer_list<uint32_t> lookup_shape,
61                            std::initializer_list<uint32_t> key_shape,
62                            std::initializer_list<uint32_t> value_shape) {
63         auto it_vs = value_shape.begin();
64         rows_ = *it_vs++;
65         features_ = *it_vs;
66 
67         std::vector<uint32_t> inputs;
68 
69         // Input and weights
70         OperandType LookupTy(Type::TENSOR_INT32, lookup_shape);
71         inputs.push_back(model_.addOperand(&LookupTy));
72 
73         OperandType KeyTy(Type::TENSOR_INT32, key_shape);
74         inputs.push_back(model_.addOperand(&KeyTy));
75 
76         OperandType ValueTy(Type::TENSOR_FLOAT32, value_shape);
77         inputs.push_back(model_.addOperand(&ValueTy));
78 
79         // Output and other intermediate state
80         std::vector<uint32_t> outputs;
81 
82         std::vector<uint32_t> out_dim(lookup_shape.begin(), lookup_shape.end());
83         out_dim.push_back(features_);
84 
85         OperandType OutputOpndTy(Type::TENSOR_FLOAT32, out_dim);
86         outputs.push_back(model_.addOperand(&OutputOpndTy));
87 
88         OperandType HitsOpndTy(Type::TENSOR_QUANT8_ASYMM, lookup_shape, 1.f, 0);
89         outputs.push_back(model_.addOperand(&HitsOpndTy));
90 
91         auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
92             uint32_t sz = 1;
93             for (uint32_t d : dims) {
94                 sz *= d;
95             }
96             return sz;
97         };
98 
99         Value_.insert(Value_.end(), multiAll(value_shape), 0.f);
100         Output_.insert(Output_.end(), multiAll(out_dim), 0.f);
101         Hits_.insert(Hits_.end(), multiAll(lookup_shape), 0);
102 
103         model_.addOperation(ANEURALNETWORKS_HASHTABLE_LOOKUP, inputs, outputs);
104         model_.identifyInputsAndOutputs(inputs, outputs);
105 
106         model_.finish();
107     }
108 
Invoke()109     void Invoke() {
110         ASSERT_TRUE(model_.isValid());
111 
112         Compilation compilation(&model_);
113         compilation.finish();
114         Execution execution(&compilation);
115 
116 #define SetInputOrWeight(X, T)                                               \
117     ASSERT_EQ(execution.setInput(HashtableLookup::k##X##Tensor, X##_.data(), \
118                                  sizeof(T) * X##_.size()),                   \
119               Result::NO_ERROR);
120 
121         FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
122 
123 #undef SetInputOrWeight
124 
125 #define SetOutput(X, T)                                                       \
126     ASSERT_EQ(execution.setOutput(HashtableLookup::k##X##Tensor, X##_.data(), \
127                                   sizeof(T) * X##_.size()),                   \
128               Result::NO_ERROR);
129 
130         FOR_ALL_OUTPUT_TENSORS(SetOutput);
131 
132 #undef SetOutput
133 
134         ASSERT_EQ(execution.compute(), Result::NO_ERROR);
135     }
136 
137 #define DefineSetter(X, T) \
138     void Set##X(const std::vector<T>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
139 
140     FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
141 
142 #undef DefineSetter
143 
SetHashtableValue(const std::function<float (uint32_t,uint32_t)> & function)144     void SetHashtableValue(const std::function<float(uint32_t, uint32_t)>& function) {
145         for (uint32_t i = 0; i < rows_; i++) {
146             for (uint32_t j = 0; j < features_; j++) {
147                 Value_[i * features_ + j] = function(i, j);
148             }
149         }
150     }
151 
GetOutput() const152     const std::vector<float>& GetOutput() const { return Output_; }
GetHits() const153     const std::vector<uint8_t>& GetHits() const { return Hits_; }
154 
155    private:
156     Model model_;
157     uint32_t rows_;
158     uint32_t features_;
159 
160 #define DefineTensor(X, T) std::vector<T> X##_;
161 
162     FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
163     FOR_ALL_OUTPUT_TENSORS(DefineTensor);
164 
165 #undef DefineTensor
166 };
167 
// End-to-end check of HASHTABLE_LOOKUP. Three of the four lookup keys are
// present in the key table and one (-292) is absent. The test verifies the
// gathered rows and the per-lookup hit flags.
TEST(HashtableLookupOpTest, BlackBoxTest) {
    // 4 lookups, a key table of 3 entries, and a 3x2 value table.
    HashtableLookupOpModel m({4}, {3}, {3, 2});

    m.SetLookup({1234, -292, -11, 0});
    // Key k at position r selects value row r: -11 -> 0, 0 -> 1, 1234 -> 2.
    m.SetKey({-11, 0, 1234});
    // Row r, column c holds r + c / 10, so every row is recognizable in the output.
    m.SetHashtableValue([](uint32_t row, uint32_t col) { return row + col / 10.0f; });

    m.Invoke();

    EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
                                   2.0f, 2.1f,  // 1234 -> 2nd row
                                   0.0f, 0.0f,  // -292 not found
                                   0.0f, 0.1f,  // -11  -> 0th row
                                   1.0f, 1.1f,  // 0    -> 1st row
                               })));
    EXPECT_EQ(m.GetHits(), std::vector<uint8_t>({
                               1,  // 1234 found
                               0,  // -292 not found
                               1,  // -11 found
                               1,  // 0 found
                           }));
}
190 
191 }  // namespace wrapper
192 }  // namespace nn
193 }  // namespace android
194