1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <gmock/gmock-matchers.h>
17 #include <gtest/gtest.h>
18
19 #include <limits>
20 #include <vector>
21
22 #include "OperationsUtils.cpp"
23 #include "QuantUtils.h"
24
25 namespace android {
26 namespace nn {
27 namespace wrapper {
28
29 namespace {
30 using ::testing::ElementsAreArray;
31 } // namespace
32
// Broadcasting {4, 3, 2, 1} against {3, 1, 5} must succeed and yield {4, 3, 2, 5}, in either
// argument order.
TEST(CalculateBroadcastedShapeTest, Basic) {
    Shape rank4Input;
    rank4Input.dimensions = {4, 3, 2, 1};
    Shape rank3Input;
    rank3Input.dimensions = {3, 1, 5};

    Shape expected;
    expected.dimensions = {4, 3, 2, 5};
    const auto& expectedDims = expected.dimensions;

    // The same output object is reused for both calls.
    Shape result;

    // Rank-4 shape passed first.
    EXPECT_TRUE(calculateBroadcastedShape(rank4Input, rank3Input, &result));
    EXPECT_THAT(result.dimensions, ElementsAreArray(expectedDims));

    // Arguments swapped.
    EXPECT_TRUE(calculateBroadcastedShape(rank3Input, rank4Input, &result));
    EXPECT_THAT(result.dimensions, ElementsAreArray(expectedDims));
}
49
// Shapes {5} and {3} cannot be broadcast together; the call must report failure in either
// argument order.
TEST(CalculateBroadcastedShapeTest, FailsOnIncompatible) {
    Shape fiveWide;
    fiveWide.dimensions = {5};
    Shape threeWide;
    threeWide.dimensions = {3};

    // Only the return value is checked; the output shape content is not inspected.
    Shape result;
    EXPECT_FALSE(calculateBroadcastedShape(fiveWide, threeWide, &result));
    EXPECT_FALSE(calculateBroadcastedShape(threeWide, fiveWide, &result));
}
60
getExtensionType(uint16_t extensionPrefix,uint16_t typeWithinExtension)61 static int32_t getExtensionType(uint16_t extensionPrefix, uint16_t typeWithinExtension) {
62 constexpr uint8_t kLowBitsType = static_cast<uint8_t>(ExtensionTypeEncoding::LOW_BITS_TYPE);
63 int32_t type = (extensionPrefix << kLowBitsType) | typeWithinExtension;
64 EXPECT_TRUE(isExtensionOperandType(static_cast<OperandType>(type)));
65 return type;
66 }
67
// Regression test for b/124285861.
// An extension type queried with a null dimension array and a dimension count of zero must be
// reported as having unspecified dimensions.
TEST(TensorHasUnspecifiedDimensionsTest, ExtensionTensorWithUnspecifiedRank) {
    const int32_t extensionType = getExtensionType(1, 0);
    EXPECT_TRUE(tensorHasUnspecifiedDimensions(extensionType, /*dim=*/nullptr,
                                               /*dimCount=*/0));
}
73
// Regression test for b/124104123.
// An extension tensor operand with no dimensions validates successfully when partial types are
// allowed, and is rejected as bad data when they are not.
TEST(ValidateOperandTypeTest, ExtensionTensorWithUnspecifiedRank) {
    constexpr uint16_t kPrefix = 1;
    constexpr uint16_t kTypeIndex = 0;

    // Operand descriptor with zero dimensions and no dimension array.
    ANeuralNetworksOperandType operandType = {
            .type = getExtensionType(kPrefix, kTypeIndex),
            .dimensionCount = 0,
            .dimensions = nullptr,
    };
    // Extension metadata describing the operand as a tensor type.
    Extension::OperandTypeInformation extensionInfo = {
            .type = kTypeIndex,
            .isTensor = true,
            .byteSize = 4,
    };

    const auto partialAllowed =
            validateOperandType(operandType, &extensionInfo, /*tag=*/"test", /*allowPartial=*/true);
    EXPECT_EQ(partialAllowed, ANEURALNETWORKS_NO_ERROR);

    const auto partialForbidden = validateOperandType(operandType, &extensionInfo, /*tag=*/"test",
                                                      /*allowPartial=*/false);
    EXPECT_EQ(partialForbidden, ANEURALNETWORKS_BAD_DATA);
}
94
// Regression test for b/146044137.
// The product of these dimensions is far larger than 2^32; validation of the extension tensor
// must report bad data.
TEST(ValidateOperandTypeTest, ExtensionTypeDimensionProductOverflow) {
    constexpr uint16_t kPrefix = 1;
    constexpr uint16_t kTypeIndex = 0;

    uint32_t hugeDims[] = {5, 4, 4, 786433, 5, 3, 16777216, 4, 5};
    ANeuralNetworksOperandType operandType = {
            .type = getExtensionType(kPrefix, kTypeIndex),
            .dimensionCount = std::size(hugeDims),
            .dimensions = hugeDims,
    };
    // Extension metadata describing a tensor type with one byte per element.
    Extension::OperandTypeInformation extensionInfo = {
            .type = kTypeIndex,
            .isTensor = true,
            .byteSize = 1,
    };

    const auto status =
            validateOperandType(operandType, &extensionInfo, /*tag=*/"test", /*allowPartial=*/true);
    EXPECT_EQ(status, ANEURALNETWORKS_BAD_DATA);
}
114
// Regression test for b/146044137.
// A TENSOR_FLOAT32 operand with dimensions 256 x 256 x 256 x 256 (2^32 elements) must be
// rejected as bad data. No extension information is supplied.
TEST(ValidateOperandTypeTest, TensorSizeDimensionProductOverflow) {
    uint32_t hugeDims[] = {256, 256, 256, 256};
    ANeuralNetworksOperandType operandType = {
            .type = ANEURALNETWORKS_TENSOR_FLOAT32,
            .dimensionCount = std::size(hugeDims),
            .dimensions = hugeDims,
    };

    const auto status =
            validateOperandType(operandType, nullptr, /*tag=*/"test", /*allowPartial=*/true);
    EXPECT_EQ(status, ANEURALNETWORKS_BAD_DATA);
}
126
127 class CombineDimensionsTest : public ::testing::Test {
128 protected:
testCompatible(const std::vector<uint32_t> & lhs,const std::vector<uint32_t> & rhs,const std::vector<uint32_t> & expected)129 void testCompatible(const std::vector<uint32_t>& lhs, const std::vector<uint32_t>& rhs,
130 const std::vector<uint32_t>& expected) {
131 SCOPED_TRACE("lhs = " + toString(lhs) + ", rhs = " + toString(rhs));
132 const auto res = combineDimensions(lhs, rhs);
133 ASSERT_TRUE(res.has_value());
134 EXPECT_EQ(res.value(), expected);
135 }
136
testIncompatible(const std::vector<uint32_t> & lhs,const std::vector<uint32_t> & rhs)137 void testIncompatible(const std::vector<uint32_t>& lhs, const std::vector<uint32_t>& rhs) {
138 SCOPED_TRACE("lhs = " + toString(lhs) + ", rhs = " + toString(rhs));
139 const auto res = combineDimensions(lhs, rhs);
140 EXPECT_FALSE(res.has_value());
141 }
142 };
143
// Rank handling: an empty dimension vector combines with a rank-4 vector (yielding the rank-4
// vector) and with another empty vector; vectors of length 3 and 4 are incompatible.
TEST_F(CombineDimensionsTest, Rank) {
    const std::vector<uint32_t> noDims = {};
    const std::vector<uint32_t> rank3 = {1, 2, 3};
    const std::vector<uint32_t> rank4 = {1, 2, 3, 4};

    testCompatible(noDims, rank4, rank4);
    testCompatible(rank4, noDims, rank4);
    testCompatible(noDims, noDims, noDims);

    testIncompatible(rank3, rank4);
    testIncompatible(rank4, rank3);
}
151
// Per-dimension handling: an all-zero vector combines with {1, 2, 3, 4} (yielding the latter)
// and with another all-zero vector; vectors that differ in a non-zero entry are incompatible.
TEST_F(CombineDimensionsTest, Dimensions) {
    const std::vector<uint32_t> allZero = {0, 0, 0, 0};
    const std::vector<uint32_t> concrete = {1, 2, 3, 4};

    testCompatible(allZero, concrete, concrete);
    testCompatible(concrete, allZero, concrete);
    testCompatible(allZero, allZero, allZero);

    // Mismatch in the first entry, then in the last entry.
    testIncompatible(concrete, {2, 2, 3, 4});
    testIncompatible(concrete, {1, 2, 3, 3});
}
159
// Checks QuantizeMultiplierSmallerThanOneExp(): a set of inputs that must be rejected (the call
// returns false), followed by inputs that must be accepted with specific golden outputs.
TEST(QuantizationUtilsTest, QuantizeMultiplierSmallerThanOneExp) {
    // Expects the call to fail for `value`. The outputs are not inspected.
    auto checkInvalidQuantization = [](double value) {
        int32_t q;
        int s;
        EXPECT_FALSE(QuantizeMultiplierSmallerThanOneExp(value, &q, &s));
    };

    checkInvalidQuantization(-0.1);
    checkInvalidQuantization(0.0);
    // Inputs equal to 1.0, or close enough to it, must be rejected as well. The function is
    // expected to report this through its return value rather than by crashing.
    checkInvalidQuantization(1 - 1e-15);
    checkInvalidQuantization(1 - 1e-17);
    checkInvalidQuantization(1.0);

    // Expects the call to succeed for `value` and to produce the golden multiplier and shift.
    auto checkQuantization = [](double value, int32_t goldenQuantized, int goldenShift) {
        int32_t q;
        int s;
        // Fatal (returns from this lambda only): q and s must not be read if the call failed,
        // since they may never have been written.
        ASSERT_TRUE(QuantizeMultiplierSmallerThanOneExp(value, &q, &s));
        EXPECT_EQ(q, goldenQuantized);
        EXPECT_EQ(s, goldenShift);
    };

    checkQuantization(0.25, 1073741824, -1);
    checkQuantization(0.50 - 5e-9, 2147483627, -1);
    checkQuantization(0.50 - 1e-10, 1073741824, 0);
    checkQuantization(0.50, 1073741824, 0);
    checkQuantization(0.75, 1610612736, 0);
    checkQuantization(1 - 1e-9, 2147483646, 0);
}
190
// Checks QuantizeMultiplierGreaterThanOne(): one input that must be rejected (the call returns
// false), followed by inputs that must be accepted with specific golden outputs.
TEST(QuantizationUtilsTest, QuantizeMultiplierGreaterThanOne) {
    // Expects the call to fail for `value`. The outputs are not inspected.
    auto checkInvalidQuantization = [](double value) {
        int32_t q;
        int s;
        EXPECT_FALSE(QuantizeMultiplierGreaterThanOne(value, &q, &s));
    };

    // 1e-16 is below double precision at 1.0, so this literal evaluates to exactly 1.0.
    checkInvalidQuantization(1 + 1e-16);

    // Expects the call to succeed for `value` and to produce the golden multiplier and shift.
    auto checkQuantization = [](double value, int32_t goldenQuantized, int goldenShift) {
        int32_t q;
        int s;
        // Fatal (returns from this lambda only): q and s must not be read if the call failed,
        // since they may never have been written.
        ASSERT_TRUE(QuantizeMultiplierGreaterThanOne(value, &q, &s));
        EXPECT_EQ(q, goldenQuantized);
        EXPECT_EQ(s, goldenShift);
    };

    checkQuantization(1 + 1e-11, 1073741824, 1);
    checkQuantization(1.25, 1342177280, 1);
    checkQuantization(1.50, 1610612736, 1);
    checkQuantization(1.75, 1879048192, 1);
    checkQuantization(2 - 1e-9, 2147483647, 1);
    checkQuantization(2 - 1e-11, 1073741824, 2);
    checkQuantization(2, 1073741824, 2);
}
217
// Checks QuantizeMultiplier() against golden (multiplier, shift) pairs for positive and negative
// powers of two from 0.125 to 4 in magnitude, plus zero.
// NOTE(review): the suite name here is "QuantizationUtilTest" while other tests in this file use
// "QuantizationUtilsTest"; it is left unchanged so existing test filters keep matching.
TEST(QuantizationUtilTest, QuantizeMultiplier) {
    // Expects the call to succeed for `value` and to produce the golden multiplier and shift.
    auto checkQuantization = [](double value, int32_t goldenQuantized, int goldenShift) {
        int32_t q;
        int s;
        // Fatal (returns from this lambda only): q and s must not be read if the call failed,
        // since they may never have been written.
        ASSERT_TRUE(QuantizeMultiplier(value, &q, &s));
        EXPECT_EQ(q, goldenQuantized);
        EXPECT_EQ(s, goldenShift);
    };

    // Negative inputs.
    checkQuantization(-4, -1073741824, 3);
    checkQuantization(-2, -1073741824, 2);
    checkQuantization(-1, -1073741824, 1);
    checkQuantization(-0.5, -1073741824, 0);
    checkQuantization(-0.25, -1073741824, -1);
    checkQuantization(-0.125, -1073741824, -2);
    // Zero.
    checkQuantization(0, 0, 0);
    // Positive inputs.
    checkQuantization(0.125, 1073741824, -2);
    checkQuantization(0.25, 1073741824, -1);
    checkQuantization(0.5, 1073741824, 0);
    checkQuantization(1, 1073741824, 1);
    checkQuantization(2, 1073741824, 2);
    checkQuantization(4, 1073741824, 3);
}
241
// Checks QuantizeMultiplier() on very small inputs: 2^-31 and 2^-32 still produce a non-zero
// multiplier (2^30) with shifts -30 and -31, while 0.99 * 2^-32 and 2^-33 produce (0, 0).
TEST(QuantizationUtilTest, QuantizeMultiplierUnderflow) {
    // Expects the call to succeed for `value` and to produce the golden multiplier and shift.
    auto checkQuantization = [](double value, int32_t goldenQuantized, int goldenShift) {
        int32_t q;
        int s;
        // Fatal (returns from this lambda only): q and s must not be read if the call failed,
        // since they may never have been written.
        ASSERT_TRUE(QuantizeMultiplier(value, &q, &s));
        EXPECT_EQ(q, goldenQuantized);
        EXPECT_EQ(s, goldenShift);
    };

    // Smallest inputs in this test that still quantize to a non-zero multiplier.
    checkQuantization(std::ldexp(1.0f, -31), 1073741824, -30);
    checkQuantization(std::ldexp(1.0f, -32), 1073741824, -31);
    // Anything smaller is expected to come back as (0, 0).
    checkQuantization(std::ldexp(0.99f, -32), 0, 0);
    checkQuantization(std::ldexp(1.0f, -33), 0, 0);
}
256
// Checks GetInvSqrtQuantizedMultiplierExp() against golden (multiplier, shift) pairs for inputs
// ranging from 0 up to the maximum int32_t value. The second argument is always 1 here.
TEST(QuantizationUtilTest, GetInvSqrtQuantizedMultiplierExp) {
    // Expects the call to succeed for `input` and to produce the golden multiplier and shift.
    auto checkInvSqrtQuantization = [](int32_t input, int32_t goldenInvSqrt, int goldenShift) {
        int32_t q;
        int s;
        // Fatal (returns from this lambda only): q and s must not be read if the call failed,
        // since they may never have been written.
        ASSERT_TRUE(GetInvSqrtQuantizedMultiplierExp(input, 1, &q, &s));
        EXPECT_EQ(q, goldenInvSqrt);
        EXPECT_EQ(s, goldenShift);
    };

    const auto kInt32Max = std::numeric_limits<std::int32_t>::max();
    // Inputs 0 and 1 both yield the largest representable multiplier.
    checkInvSqrtQuantization(0, kInt32Max, 0);
    checkInvSqrtQuantization(1, kInt32Max, 0);
    checkInvSqrtQuantization(2, 1518498372, 0);
    checkInvSqrtQuantization(3, 1239850284, 0);
    checkInvSqrtQuantization(4, 1073741828, 0);
    checkInvSqrtQuantization(100, 214748363, 0);
    checkInvSqrtQuantization(10000, 343597361, 4);
    checkInvSqrtQuantization(1000000, 274877901, 7);
    checkInvSqrtQuantization(100000000, 219902323, 10);
    checkInvSqrtQuantization((1 << 30), 268435457, 12);
    checkInvSqrtQuantization(kInt32Max, 189812531, 12);
}
279
280 } // namespace wrapper
281 } // namespace nn
282 } // namespace android
283