• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "quantization.h"
18 
19 #include <vector>
20 
21 #include "gmock/gmock.h"
22 #include "gtest/gtest.h"
23 
24 using testing::ElementsAreArray;
25 using testing::FloatEq;
26 using testing::Matcher;
27 
28 namespace libtextclassifier2 {
29 namespace {
30 
ElementsAreFloat(const std::vector<float> & values)31 Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
32   std::vector<Matcher<float>> matchers;
33   for (const float value : values) {
34     matchers.push_back(FloatEq(value));
35   }
36   return ElementsAreArray(matchers);
37 }
38 
// 8-bit quantization: each 4-byte embedding holds four codes, one per byte.
// The expected value of a code in bucket `b` is
//   1 / num_sparse_features * scales[b] * (code - 128).
// Every bucket is checked, including bucket 2, whose scale is negative.
TEST(QuantizationTest, DequantizeAdd8bit) {
  std::vector<float> scales{0.1f, 9.0f, -7.0f};
  std::vector<uint8> embeddings{/*0: */ 0x00, 0xFF, 0x09, 0x00,
                                /*1: */ 0xFF, 0x09, 0x00, 0xFF,
                                /*2: */ 0x09, 0x00, 0xFF, 0x09};

  const int quantization_bits = 8;
  const int bytes_per_embedding = 4;
  const int num_sparse_features = 7;
  // Offset subtracted from every 8-bit code before scaling.
  const int quantization_bias = 128;

  // Dequantizes bucket `bucket_id` into a zero-filled buffer that holds one
  // float per byte of the embedding, so the result can be compared directly
  // against the expected values.
  auto dequantize_bucket = [&](int bucket_id) -> std::vector<float> {
    std::vector<float> dest(bytes_per_embedding, 0.0);
    DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
                  num_sparse_features, quantization_bits, bucket_id,
                  dest.data(), dest.size());
    return dest;
  };

  // Expected dequantized value of `code` in a bucket whose scale is `scale`.
  auto expected_value = [&](double scale, int code) -> float {
    return 1.0 / num_sparse_features * scale * (code - quantization_bias);
  };

  EXPECT_THAT(dequantize_bucket(0),
              ElementsAreFloat({expected_value(0.1, 0x00),
                                expected_value(0.1, 0xFF),
                                expected_value(0.1, 0x09),
                                expected_value(0.1, 0x00)}));

  EXPECT_THAT(dequantize_bucket(1),
              ElementsAreFloat({expected_value(9.0, 0xFF),
                                expected_value(9.0, 0x09),
                                expected_value(9.0, 0x00),
                                expected_value(9.0, 0xFF)}));

  // Negative scale: the sign of every dequantized value flips.
  EXPECT_THAT(dequantize_bucket(2),
              ElementsAreFloat({expected_value(-7.0, 0x09),
                                expected_value(-7.0, 0x00),
                                expected_value(-7.0, 0xFF),
                                expected_value(-7.0, 0x09)}));
}
84 
// 1-bit quantization: every bit of the embedding is one code, so a 4-byte
// embedding yields 32 values. An all-zero embedding must dequantize every
// value to 1 / num_sparse_features * (0 - 1) * scale.
TEST(QuantizationTest, DequantizeAdd1bitZeros) {
  const int bytes_per_embedding = 4;
  const int num_buckets = 3;
  const int num_sparse_features = 7;
  const int quantization_bits = 1;
  const int bucket_id = 1;
  // Offset subtracted from each code before scaling (1 for 1-bit codes).
  const int quantization_bias = 1 << (quantization_bits - 1);
  // Number of whole codes that fit in one embedding: 4 * 8 / 1 = 32.
  const int num_values = bytes_per_embedding * 8 / quantization_bits;

  // Fill constructors instead of std::fill, which needs <algorithm>.
  std::vector<float> scales(num_buckets, 1.0);
  std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0);

  std::vector<float> dest(num_values);
  DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
                num_sparse_features, quantization_bits, bucket_id, dest.data(),
                dest.size());

  const std::vector<float> expected(
      num_values, 1.0 / num_sparse_features * (0 - quantization_bias) *
                      scales[bucket_id]);
  EXPECT_THAT(dest, ElementsAreFloat(expected));
}
107 
// 1-bit quantization with every bit set: each code equals the bias, so every
// one of the 32 dequantized values is exactly zero.
TEST(QuantizationTest, DequantizeAdd1bitOnes) {
  const int bytes_per_embedding = 4;
  const int num_buckets = 3;
  const int num_sparse_features = 7;
  const int quantization_bits = 1;
  const int bucket_id = 1;
  // Offset subtracted from each code before scaling (1 for 1-bit codes).
  const int quantization_bias = 1 << (quantization_bits - 1);
  // Number of whole codes that fit in one embedding: 4 * 8 / 1 = 32.
  const int num_values = bytes_per_embedding * 8 / quantization_bits;

  std::vector<float> scales(num_buckets, 1.0);
  std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0xFF);

  std::vector<float> dest(num_values);
  DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
                num_sparse_features, quantization_bits, bucket_id, dest.data(),
                dest.size());

  // Fill constructor instead of std::fill, which needs <algorithm>.
  const std::vector<float> expected(
      num_values, 1.0 / num_sparse_features * (1 - quantization_bias) *
                      scales[bucket_id]);
  EXPECT_THAT(dest, ElementsAreFloat(expected));
}
127 
// 3-bit quantization: a 4-byte (32-bit) embedding holds 10 whole codes, and
// the last 2 bits are unused. Bucket 1 is packed with the codes 1..7 followed
// by zeros, and scale 9 is used so the scale factor is visible in the output.
TEST(QuantizationTest, DequantizeAdd3bit) {
  const int bytes_per_embedding = 4;
  const int num_buckets = 3;
  const int num_sparse_features = 7;
  const int quantization_bits = 3;
  const int bucket_id = 1;
  // Offset subtracted from each code before scaling (4 for 3-bit codes).
  const int quantization_bias = 1 << (quantization_bits - 1);
  // Number of whole codes that fit in one embedding: 4 * 8 / 3 = 10.
  const int num_values = bytes_per_embedding * 8 / quantization_bits;

  std::vector<float> scales(num_buckets, 1.0);
  scales[bucket_id] = 9.0;
  std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0);
  // Bucket 1 occupies bytes 4..7. Codes are packed least-significant bit
  // first, with code i in bits [3i, 3i + 3) of the embedding. These bytes
  // therefore encode the codes 1, 2, ..., 7 at indices 0..6. All higher bits
  // are zero, so indices 7..9 decode to code 0.
  embeddings[4] = (1 << 7) | (1 << 6) | (1 << 4) | 1;             // 0xD1
  embeddings[5] = (1 << 6) | (1 << 4) | (1 << 3);                 // 0x58
  embeddings[6] = (1 << 4) | (1 << 3) | (1 << 2) | (1 << 1) | 1;  // 0x1F

  std::vector<float> dest(num_values);
  DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
                num_sparse_features, quantization_bits, bucket_id, dest.data(),
                dest.size());

  // Raw 3-bit codes expected at indices 0..9 of bucket 1.
  const int expected_codes[] = {1, 2, 3, 4, 5, 6, 7, 0, 0, 0};
  std::vector<float> expected;
  expected.reserve(num_values);
  for (const int code : expected_codes) {
    expected.push_back(1.0 / num_sparse_features * (code - quantization_bias) *
                       scales[bucket_id]);
  }
  EXPECT_THAT(dest, ElementsAreFloat(expected));
}
161 
162 }  // namespace
163 }  // namespace libtextclassifier2
164