• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <gmock/gmock-matchers.h>
17 #include <gtest/gtest.h>
18 
19 #include <limits>
20 #include <vector>
21 
22 #include "OperationsUtils.cpp"
23 #include "QuantUtils.h"
24 
25 namespace android {
26 namespace nn {
27 namespace wrapper {
28 
29 namespace {
30 using ::testing::ElementsAreArray;
31 }  // namespace
32 
// Broadcasting {4, 3, 2, 1} with {3, 1, 5} yields {4, 3, 2, 5}; the result must be
// the same whichever shape is passed first.
TEST(CalculateBroadcastedShapeTest, Basic) {
    Shape lhs;
    lhs.dimensions = {4, 3, 2, 1};
    Shape rhs;
    rhs.dimensions = {3, 1, 5};

    Shape expected;
    expected.dimensions = {4, 3, 2, 5};

    // A single output object is reused for both operand orders.
    Shape broadcasted;

    EXPECT_TRUE(calculateBroadcastedShape(lhs, rhs, &broadcasted));
    EXPECT_THAT(broadcasted.dimensions, ElementsAreArray(expected.dimensions));

    EXPECT_TRUE(calculateBroadcastedShape(rhs, lhs, &broadcasted));
    EXPECT_THAT(broadcasted.dimensions, ElementsAreArray(expected.dimensions));
}
49 
// Rank-1 shapes of sizes 5 and 3 cannot be broadcast together, in either order.
TEST(CalculateBroadcastedShapeTest, FailsOnIncompatible) {
    Shape sizeFive;
    sizeFive.dimensions = {5};
    Shape sizeThree;
    sizeThree.dimensions = {3};

    // The output is irrelevant here; only the failure result is checked.
    Shape ignored;
    EXPECT_FALSE(calculateBroadcastedShape(sizeFive, sizeThree, &ignored));
    EXPECT_FALSE(calculateBroadcastedShape(sizeThree, sizeFive, &ignored));
}
60 
getExtensionType(uint16_t extensionPrefix,uint16_t typeWithinExtension)61 static int32_t getExtensionType(uint16_t extensionPrefix, uint16_t typeWithinExtension) {
62     constexpr uint8_t kLowBitsType = static_cast<uint8_t>(ExtensionTypeEncoding::LOW_BITS_TYPE);
63     int32_t type = (extensionPrefix << kLowBitsType) | typeWithinExtension;
64     EXPECT_TRUE(isExtensionOperandType(static_cast<OperandType>(type)));
65     return type;
66 }
67 
TEST(TensorHasUnspecifiedDimensionsTest, ExtensionTensorWithUnspecifiedRank) {
    // Regression test for b/124285861: an extension type described with no dimension
    // array and a dimension count of 0 must be reported as having unspecified dimensions.
    const int32_t extensionTensorType =
            getExtensionType(/*extensionPrefix=*/1, /*typeWithinExtension=*/0);
    const bool unspecified =
            tensorHasUnspecifiedDimensions(extensionTensorType, /*dim=*/nullptr, /*dimCount=*/0);
    EXPECT_TRUE(unspecified);
}
73 
TEST(ValidateOperandTypeTest, ExtensionTensorWithUnspecifiedRank) {
    // Regression test for b/124104123: an extension tensor operand with no dimensions
    // (unspecified rank) is valid only when partially specified operands are allowed.
    constexpr uint16_t kPrefix = 1;
    constexpr uint16_t kLocalTypeCode = 0;

    ANeuralNetworksOperandType operand = {
            .type = getExtensionType(kPrefix, kLocalTypeCode),
            .dimensionCount = 0,
            .dimensions = nullptr,
    };
    // Extension-side description of the operand type: a 4-byte-element tensor.
    Extension::OperandTypeInformation extensionInfo = {
            .type = kLocalTypeCode,
            .isTensor = true,
            .byteSize = 4,
    };

    const auto withPartial =
            validateOperandType(operand, &extensionInfo, /*tag=*/"test", /*allowPartial=*/true);
    EXPECT_EQ(withPartial, ANEURALNETWORKS_NO_ERROR);

    const auto withoutPartial =
            validateOperandType(operand, &extensionInfo, /*tag=*/"test", /*allowPartial=*/false);
    EXPECT_EQ(withoutPartial, ANEURALNETWORKS_BAD_DATA);
}
94 
TEST(ValidateOperandTypeTest, ExtensionTypeDimensionProductOverflow) {
    // Regression test for b/146044137: the product of these dimensions is far beyond
    // 2^32 (786433 * 16777216 alone exceeds it). Even at one byte per element, the
    // operand must be rejected, whether or not partial operands are allowed.
    constexpr uint16_t kPrefix = 1;
    constexpr uint16_t kLocalTypeCode = 0;
    const int32_t operandTypeCode = getExtensionType(kPrefix, kLocalTypeCode);

    uint32_t hugeDims[] = {5, 4, 4, 786433, 5, 3, 16777216, 4, 5};
    ANeuralNetworksOperandType operand = {
            .type = operandTypeCode,
            .dimensionCount = std::size(hugeDims),
            .dimensions = hugeDims,
    };
    Extension::OperandTypeInformation extensionInfo = {
            .type = kLocalTypeCode,
            .isTensor = true,
            .byteSize = 1,
    };

    const auto status =
            validateOperandType(operand, &extensionInfo, /*tag=*/"test", /*allowPartial=*/true);
    EXPECT_EQ(status, ANEURALNETWORKS_BAD_DATA);
}
114 
TEST(ValidateOperandTypeTest, TensorSizeDimensionProductOverflow) {
    // Regression test for b/146044137: 256^4 = 2^32 elements does not fit in 32 bits,
    // so this FLOAT32 tensor's size overflows and validation must fail.
    uint32_t dims[] = {256, 256, 256, 256};
    ANeuralNetworksOperandType float32Tensor = {
            .type = ANEURALNETWORKS_TENSOR_FLOAT32,
            .dimensionCount = std::size(dims),
            .dimensions = dims,
    };

    // Built-in operand type, so no extension type information is supplied.
    const auto status =
            validateOperandType(float32Tensor, nullptr, /*tag=*/"test", /*allowPartial=*/true);
    EXPECT_EQ(status, ANEURALNETWORKS_BAD_DATA);
}
126 
127 class CombineDimensionsTest : public ::testing::Test {
128    protected:
testCompatible(const std::vector<uint32_t> & lhs,const std::vector<uint32_t> & rhs,const std::vector<uint32_t> & expected)129     void testCompatible(const std::vector<uint32_t>& lhs, const std::vector<uint32_t>& rhs,
130                         const std::vector<uint32_t>& expected) {
131         SCOPED_TRACE("lhs = " + toString(lhs) + ", rhs = " + toString(rhs));
132         const auto res = combineDimensions(lhs, rhs);
133         ASSERT_TRUE(res.has_value());
134         EXPECT_EQ(res.value(), expected);
135     }
136 
testIncompatible(const std::vector<uint32_t> & lhs,const std::vector<uint32_t> & rhs)137     void testIncompatible(const std::vector<uint32_t>& lhs, const std::vector<uint32_t>& rhs) {
138         SCOPED_TRACE("lhs = " + toString(lhs) + ", rhs = " + toString(rhs));
139         const auto res = combineDimensions(lhs, rhs);
140         EXPECT_FALSE(res.has_value());
141     }
142 };
143 
// An empty dimension vector (presumably "rank unknown") combines with any rank.
// Two non-empty vectors of different rank are incompatible.
TEST_F(CombineDimensionsTest, Rank) {
    const std::vector<uint32_t> noDims;
    const std::vector<uint32_t> rank3 = {1, 2, 3};
    const std::vector<uint32_t> rank4 = {1, 2, 3, 4};

    testCompatible(noDims, rank4, rank4);
    testCompatible(rank4, noDims, rank4);
    testCompatible(noDims, noDims, noDims);
    testIncompatible(rank3, rank4);
    testIncompatible(rank4, rank3);
}
151 
// A 0 dimension combines with any value in the same position. Two differing
// non-zero values in the same position are incompatible.
TEST_F(CombineDimensionsTest, Dimensions) {
    const std::vector<uint32_t> allZero = {0, 0, 0, 0};
    const std::vector<uint32_t> known = {1, 2, 3, 4};

    testCompatible(allZero, known, known);
    testCompatible(known, allZero, known);
    testCompatible(allZero, allZero, allZero);
    // A mismatch in the first or in the last position is enough to fail.
    testIncompatible(known, {2, 2, 3, 4});
    testIncompatible(known, {1, 2, 3, 3});
}
159 
// QuantizeMultiplierSmallerThanOneExp must reject values outside the open interval
// (0, 1). Accepted values are checked against golden (quantized multiplier, shift) pairs.
TEST(QuantizationUtilsTest, QuantizeMultiplierSmallerThanOneExp) {
    auto checkInvalidQuantization = [](double value) {
        SCOPED_TRACE(value);
        int32_t q = 0;
        int s = 0;
        EXPECT_FALSE(QuantizeMultiplierSmallerThanOneExp(value, &q, &s));
    };

    checkInvalidQuantization(-0.1);
    checkInvalidQuantization(0.0);
    // Values this close to 1.0 would otherwise either make the shift negative or
    // violate the less-than-one requirement, so they must be rejected.
    checkInvalidQuantization(1 - 1e-15);
    checkInvalidQuantization(1 - 1e-17);
    checkInvalidQuantization(1.0);

    auto checkQuantization = [](double value, int32_t goldenQuantized, int goldenShift) {
        SCOPED_TRACE(value);
        int32_t q = 0;
        int s = 0;
        // ASSERT (not EXPECT) so that q and s are only compared after a successful
        // call. On failure they may not have been written.
        ASSERT_TRUE(QuantizeMultiplierSmallerThanOneExp(value, &q, &s));
        EXPECT_EQ(q, goldenQuantized);
        EXPECT_EQ(s, goldenShift);
    };

    checkQuantization(0.25, 1073741824, -1);
    checkQuantization(0.50 - 5e-9, 2147483627, -1);
    // Close enough to 0.5 that rounding carries into the next power of two.
    checkQuantization(0.50 - 1e-10, 1073741824, 0);
    checkQuantization(0.50, 1073741824, 0);
    checkQuantization(0.75, 1610612736, 0);
    checkQuantization(1 - 1e-9, 2147483646, 0);
}
190 
// QuantizeMultiplierGreaterThanOne must reject a value that is not strictly greater
// than one. Accepted values are checked against golden (quantized multiplier, shift) pairs.
TEST(QuantizationUtilsTest, QuantizeMultiplierGreaterThanOne) {
    auto checkInvalidQuantization = [](double value) {
        SCOPED_TRACE(value);
        int32_t q = 0;
        int s = 0;
        EXPECT_FALSE(QuantizeMultiplierGreaterThanOne(value, &q, &s));
    };

    // 1 + 1e-16 rounds to exactly 1.0 in double precision, which is not greater than one.
    checkInvalidQuantization(1 + 1e-16);

    auto checkQuantization = [](double value, int32_t goldenQuantized, int goldenShift) {
        SCOPED_TRACE(value);
        int32_t q = 0;
        int s = 0;
        // ASSERT (not EXPECT) so that q and s are only compared after a successful
        // call. On failure they may not have been written.
        ASSERT_TRUE(QuantizeMultiplierGreaterThanOne(value, &q, &s));
        EXPECT_EQ(q, goldenQuantized);
        EXPECT_EQ(s, goldenShift);
    };

    checkQuantization(1 + 1e-11, 1073741824, 1);
    checkQuantization(1.25, 1342177280, 1);
    checkQuantization(1.50, 1610612736, 1);
    checkQuantization(1.75, 1879048192, 1);
    checkQuantization(2 - 1e-9, 2147483647, 1);
    // Close enough to 2.0 that rounding carries into the next power of two.
    checkQuantization(2 - 1e-11, 1073741824, 2);
    checkQuantization(2, 1073741824, 2);
}
217 
// QuantizeMultiplier on powers of two of both signs. Negative inputs give the negated
// multiplier with the same shift as their positive counterparts, and 0 gives (0, 0).
TEST(QuantizationUtilTest, QuantizeMultiplier) {
    auto checkQuantization = [](double value, int32_t goldenQuantized, int goldenShift) {
        SCOPED_TRACE(value);
        int32_t q = 0;
        int s = 0;
        // ASSERT (not EXPECT) so that q and s are only compared after a successful
        // call. On failure they may not have been written.
        ASSERT_TRUE(QuantizeMultiplier(value, &q, &s));
        EXPECT_EQ(q, goldenQuantized);
        EXPECT_EQ(s, goldenShift);
    };

    checkQuantization(-4, -1073741824, 3);
    checkQuantization(-2, -1073741824, 2);
    checkQuantization(-1, -1073741824, 1);
    checkQuantization(-0.5, -1073741824, 0);
    checkQuantization(-0.25, -1073741824, -1);
    checkQuantization(-0.125, -1073741824, -2);
    checkQuantization(0, 0, 0);
    checkQuantization(0.125, 1073741824, -2);
    checkQuantization(0.25, 1073741824, -1);
    checkQuantization(0.5, 1073741824, 0);
    checkQuantization(1, 1073741824, 1);
    checkQuantization(2, 1073741824, 2);
    checkQuantization(4, 1073741824, 3);
}
241 
// QuantizeMultiplier on very small values. 2^-31 and 2^-32 are still representable;
// anything below 2^-32 is expected to flush to a zero multiplier with zero shift.
TEST(QuantizationUtilTest, QuantizeMultiplierUnderflow) {
    auto checkQuantization = [](double value, int32_t goldenQuantized, int goldenShift) {
        SCOPED_TRACE(value);
        int32_t q = 0;
        int s = 0;
        // ASSERT (not EXPECT) so that q and s are only compared after a successful
        // call. On failure they may not have been written.
        ASSERT_TRUE(QuantizeMultiplier(value, &q, &s));
        EXPECT_EQ(q, goldenQuantized);
        EXPECT_EQ(s, goldenShift);
    };

    checkQuantization(std::ldexp(1.0f, -31), 1073741824, -30);
    checkQuantization(std::ldexp(1.0f, -32), 1073741824, -31);
    checkQuantization(std::ldexp(0.99f, -32), 0, 0);
    checkQuantization(std::ldexp(1.0f, -33), 0, 0);
}
256 
// GetInvSqrtQuantizedMultiplierExp against golden (multiplier, shift) pairs across the
// int32 range. Inputs 0 and 1 both saturate to INT32_MAX with shift 0.
TEST(QuantizationUtilTest, GetInvSqrtQuantizedMultiplierExp) {
    auto checkInvSqrtQuantization = [](int32_t input, int32_t goldenInvSqrt, int goldenShift) {
        SCOPED_TRACE(input);
        int32_t q = 0;
        int s = 0;
        // The second argument is fixed at 1 for every case here (presumably a
        // reverse-shift selector; TODO confirm against QuantUtils.h).
        // ASSERT (not EXPECT) so that q and s are only compared after a successful
        // call. On failure they may not have been written.
        ASSERT_TRUE(GetInvSqrtQuantizedMultiplierExp(input, 1, &q, &s));
        EXPECT_EQ(q, goldenInvSqrt);
        EXPECT_EQ(s, goldenShift);
    };

    const auto kInt32Max = std::numeric_limits<std::int32_t>::max();
    checkInvSqrtQuantization(0, kInt32Max, 0);
    checkInvSqrtQuantization(1, kInt32Max, 0);
    checkInvSqrtQuantization(2, 1518498372, 0);
    checkInvSqrtQuantization(3, 1239850284, 0);
    checkInvSqrtQuantization(4, 1073741828, 0);
    checkInvSqrtQuantization(100, 214748363, 0);
    checkInvSqrtQuantization(10000, 343597361, 4);
    checkInvSqrtQuantization(1000000, 274877901, 7);
    checkInvSqrtQuantization(100000000, 219902323, 10);
    checkInvSqrtQuantization((1 << 30), 268435457, 12);
    checkInvSqrtQuantization(kInt32Max, 189812531, 12);
}
279 
280 }  // namespace wrapper
281 }  // namespace nn
282 }  // namespace android
283