1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #pragma once
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <tuple>
#include <vector>
14
15 #ifndef _WIN32
16 #include <qnnpack/AlignedAllocator.h>
17 #endif
18 #include <pytorch_qnnpack.h>
19 #include <qnnpack/common.h>
20 #include <qnnpack/math.h>
21
22 #ifdef QNNPACK_BCSRMATRIX_DEBUG
23 #include <iostream>
24 #endif // QNNPACK_BCSRMATRIX_DEBUG
25
26 namespace qnnpack {
27
28 template <typename T>
29 struct OwnedOrBorrowedVector {
30 using VECTOR_T =
31 #ifndef _WIN32
32 std::vector<T, AlignedAllocator<T, 16>>;
33 #else
34 std::vector<T>;
35 #endif
36
37 // Only one of owned_vec_data_ or borrowed_tuple_data_ will be meaningfully
38 // populated.
39 // A union could potentially be used here to reduce memory usage.
40 // std::variant is not used here because it causes internal build errors
41 // due to incompatibility.
42 VECTOR_T owned_vec_data_;
43 std::tuple<T*, uint32_t> borrowed_tuple_data_;
44 bool owned;
45
vectorOwnedOrBorrowedVector46 VECTOR_T& vector() {
47 assert(owned);
48 return owned_vec_data_;
49 }
50
sizeOwnedOrBorrowedVector51 uint32_t size() const {
52 if (owned) {
53 return owned_vec_data_.size();
54 } else {
55 return std::get<1>(borrowed_tuple_data_);
56 }
57 }
58
dataOwnedOrBorrowedVector59 const T* data() const {
60 if (owned) {
61 return owned_vec_data_.data();
62 } else {
63 return std::get<0>(borrowed_tuple_data_);
64 }
65 }
66
67 const T& operator[](int i) const {
68 return data()[i];
69 }
70
OwnedOrBorrowedVectorOwnedOrBorrowedVector71 OwnedOrBorrowedVector() : owned(true) {}
72
OwnedOrBorrowedVectorOwnedOrBorrowedVector73 OwnedOrBorrowedVector(T* data_ptr, const uint32_t size)
74 : borrowed_tuple_data_(std::tuple<T*, uint32_t>(data_ptr, size)),
75 owned(false) {}
76 };
77
// Type-erased base for a Block Compressed Sparse Row (BCSR) weight matrix.
// The packed nonzero block values live here; the index arrays live in the
// dtype-specific TypedBCSRMatrix subclass and are reached through the
// virtual data-ptr accessors below.
struct BCSRMatrix {
  // Nonzero blocks, stored contiguously block-by-block.
  OwnedOrBorrowedVector<uint8_t> values;
  uint32_t col_block_size; // input features block size
  uint32_t row_block_size; // output features block size
  // Runtime tag identifying the integer width of the subclass's indices.
  enum pytorch_qnnp_sparse_matrix_indices_dtype indices_dtype;
  virtual ~BCSRMatrix() = default;
  // Return void for the data ptrs because it doesn't require knowing the
  // underlying TypedBCSRMatrix indices dtype and that's how it's passed
  // into the qnnpack fully connected sparse op
  virtual const void* col_indices_data_ptr() const = 0;
  virtual const void* row_values_data_ptr() const = 0;
#ifdef QNNPACK_BCSRMATRIX_DEBUG
  virtual void print() const = 0;
#endif // QNNPACK_BCSRMATRIX_DEBUG
  /*
   * Unpack from BCSR to Dense
   * - Each value and zero point converted to int8_t by subtracting 128
   * - num_rows and num_cols are dimensions of dense weight tensor
   * - dst should be able to hold num_rows * num_cols elements
   * - zero_points should hold num_rows zero points
   */
  virtual void unpack(
      int8_t* dst,
      const int64_t num_rows,
      const int64_t num_cols,
      const uint8_t* zero_points) const = 0;
  // Largest value appearing in either index array; lets callers check that
  // the chosen indices dtype is wide enough.
  virtual uint32_t max_index() const = 0;
};
106
// BCSRMatrix whose index arrays use a concrete integer width
// (uint8_t / uint16_t / uint32_t), selected at compile time.
template <typename INDICES_DTYPE>
struct TypedBCSRMatrix : BCSRMatrix {
  // Column-block index of each stored nonzero block.
  OwnedOrBorrowedVector<INDICES_DTYPE> col_indices;
  // CSR-style row pointers: a leading 0 followed by the running count of
  // nonzero blocks after each block row (size = num block rows + 1).
  OwnedOrBorrowedVector<INDICES_DTYPE> row_values;
  TypedBCSRMatrix();
  const void* col_indices_data_ptr() const override;
  const void* row_values_data_ptr() const override;
#ifdef QNNPACK_BCSRMATRIX_DEBUG
  void print() const override;
#endif // QNNPACK_BCSRMATRIX_DEBUG
  void unpack(
      int8_t* dst,
      const int64_t num_rows,
      const int64_t num_cols,
      const uint8_t* zero_points) const override;
  uint32_t max_index() const override;

  ~TypedBCSRMatrix() override = default;
};
126
127 template <typename INDICES_DTYPE>
generateBlockCSRMatrix(const uint8_t * a,const size_t N,const size_t K,const uint32_t row_block_size,const uint32_t col_block_size,const uint8_t * zero_points)128 std::unique_ptr<BCSRMatrix> generateBlockCSRMatrix(
129 const uint8_t* a,
130 const size_t N,
131 const size_t K,
132 const uint32_t row_block_size,
133 const uint32_t col_block_size,
134 const uint8_t* zero_points) {
135 assert(K > 0);
136 std::unique_ptr<TypedBCSRMatrix<INDICES_DTYPE>> bcsr_mat =
137 std::make_unique<TypedBCSRMatrix<INDICES_DTYPE>>();
138 auto& row_values = bcsr_mat->row_values.vector();
139 auto& col_indices = bcsr_mat->col_indices.vector();
140 auto& values = bcsr_mat->values.vector();
141
142 const uint32_t num_row_blocks = (N + row_block_size - 1) / row_block_size;
143 // K must be > 0
144 const uint32_t num_col_blocks = (K + col_block_size - 1) / col_block_size;
145
146 row_values.reserve(num_row_blocks);
147 uint32_t num_nnz_blocks{0};
148 row_values.push_back(num_nnz_blocks);
149 for (uint32_t i = 0; i < num_row_blocks; ++i) {
150 for (uint32_t j = 0; j < num_col_blocks; ++j) {
151 bool block_zero{true};
152 for (uint32_t ib = 0; ib < row_block_size; ++ib) {
153 uint32_t row_index = i * row_block_size + ib;
154 if PYTORCH_QNNP_UNLIKELY(row_index >= N) {
155 break;
156 }
157 for (uint32_t jb = 0; jb < col_block_size; ++jb) {
158 uint32_t col_index = j * col_block_size + jb;
159 if PYTORCH_QNNP_UNLIKELY(col_index >= K) {
160 goto block_scanned;
161 }
162 if (*(a + row_index * K + col_index) != zero_points[row_index]) {
163 block_zero = false;
164 goto block_scanned;
165 }
166 }
167 }
168 block_scanned:
169 if (!block_zero) {
170 col_indices.push_back(j);
171 num_nnz_blocks++;
172 for (uint32_t ib = 0; ib < row_block_size; ++ib) {
173 uint32_t row_index = i * row_block_size + ib;
174 if PYTORCH_QNNP_UNLIKELY(row_index >= N) {
175 for (; row_index < (num_row_blocks * row_block_size); row_index++) {
176 for (uint32_t jb = 0; jb < col_block_size; ++jb) {
177 values.push_back(zero_points[N-1]);
178 }
179 }
180 break;
181 }
182 for (uint32_t jb = 0; jb < col_block_size; ++jb) {
183 uint32_t col_index = j * col_block_size + jb;
184 if PYTORCH_QNNP_UNLIKELY(col_index >= K) {
185 values.push_back(zero_points[row_index]);
186 } else {
187 uint8_t val = *(a + row_index * K + col_index);
188 values.push_back(val);
189 }
190 }
191 }
192 }
193 }
194 row_values.push_back(num_nnz_blocks);
195 }
196 bcsr_mat->row_block_size = row_block_size;
197 bcsr_mat->col_block_size = col_block_size;
198 return bcsr_mat;
199 }
200
201 template <typename INDICES_DTYPE>
generateBlockCSRMatrix(INDICES_DTYPE * col_indices,INDICES_DTYPE * row_values,uint8_t * values,const int64_t col_indices_size,const int64_t row_values_size,const int64_t values_size,const int64_t row_block_size,const int64_t col_block_size)202 std::unique_ptr<BCSRMatrix> generateBlockCSRMatrix(
203 INDICES_DTYPE* col_indices,
204 INDICES_DTYPE* row_values,
205 uint8_t* values,
206 const int64_t col_indices_size,
207 const int64_t row_values_size,
208 const int64_t values_size,
209 const int64_t row_block_size,
210 const int64_t col_block_size) {
211 std::unique_ptr<TypedBCSRMatrix<INDICES_DTYPE>> bcsr_mat =
212 std::make_unique<TypedBCSRMatrix<INDICES_DTYPE>>();
213 bcsr_mat->col_indices =
214 OwnedOrBorrowedVector<INDICES_DTYPE>(col_indices, col_indices_size);
215 bcsr_mat->row_values =
216 OwnedOrBorrowedVector<INDICES_DTYPE>(row_values, row_values_size);
217 bcsr_mat->values = OwnedOrBorrowedVector<uint8_t>(values, values_size);
218 bcsr_mat->row_block_size = row_block_size;
219 bcsr_mat->col_block_size = col_block_size;
220 return bcsr_mat;
221 }
222
223 template <typename INDICES_DTYPE>
224 struct IndicesDtypeEnumTrait {
225 static_assert(
226 sizeof(INDICES_DTYPE) == 0,
227 "Invalid dtype for IndicesDtypeEnumTrait");
228 };
229
230 template <>
231 struct IndicesDtypeEnumTrait<uint32_t> {
232 const static pytorch_qnnp_sparse_matrix_indices_dtype dtype =
233 pytorch_qnnp_sparse_matrix_indices_dtype_uint32_t;
234 };
235
236 template <>
237 struct IndicesDtypeEnumTrait<uint16_t> {
238 const static pytorch_qnnp_sparse_matrix_indices_dtype dtype =
239 pytorch_qnnp_sparse_matrix_indices_dtype_uint16_t;
240 };
241
242 template <>
243 struct IndicesDtypeEnumTrait<uint8_t> {
244 const static pytorch_qnnp_sparse_matrix_indices_dtype dtype =
245 pytorch_qnnp_sparse_matrix_indices_dtype_uint8_t;
246 };
247
// Tags the base-class indices_dtype with the enum value matching the
// compile-time INDICES_DTYPE, so type-erased consumers can dispatch.
template <typename INDICES_DTYPE>
TypedBCSRMatrix<INDICES_DTYPE>::TypedBCSRMatrix() {
  indices_dtype = IndicesDtypeEnumTrait<INDICES_DTYPE>::dtype;
}
252
253 template <typename INDICES_DTYPE>
254 const void* TypedBCSRMatrix<INDICES_DTYPE>::col_indices_data_ptr() const {
255 return static_cast<const void*>(col_indices.data());
256 }
257
258 template <typename INDICES_DTYPE>
259 const void* TypedBCSRMatrix<INDICES_DTYPE>::row_values_data_ptr() const {
260 return static_cast<const void*>(row_values.data());
261 }
262
#ifdef QNNPACK_BCSRMATRIX_DEBUG
// Dumps the full BCSR contents (block sizes, indices dtype, row pointers,
// column indices, and packed values) to stdout for debugging.
template <typename INDICES_DTYPE>
void TypedBCSRMatrix<INDICES_DTYPE>::print() const {
  std::cout << "row block size:" << row_block_size << std::endl;
  std::cout << "col block size:" << col_block_size << std::endl;
  std::cout << "row ptr\n";
  // The enum's numeric value is the bit width, so this prints e.g.
  // "indices dtype: uint32_t".
  std::cout
      << "indices dtype: uint"
      << static_cast<
             std::underlying_type_t<pytorch_qnnp_sparse_matrix_indices_dtype>>(
             indices_dtype)
      << "_t" << std::endl;
  // Widen to uint32_t so uint8_t entries print as numbers, not characters.
  for (uint32_t i = 0; i < row_values.size(); i++) {
    std::cout << (uint32_t)row_values[i] << ", ";
  }
  std::cout << std::endl;
  std::cout << "col indices\n";
  for (uint32_t i = 0; i < col_indices.size(); i++) {
    std::cout << (uint32_t)col_indices[i] << ", ";
  }
  std::cout << std::endl;
  std::cout << "Actual values\n";
  for (uint32_t i = 0; i < values.size(); i++) {
    std::cout << (uint32_t)values[i] << ", ";
  }
  std::cout << std::endl;
}
#endif // QNNPACK_BCSRMATRIX_DEBUG
291
// Unpacks this BCSR matrix back into a dense, row-major int8 matrix.
// Every stored uint8 value (and each row's zero point) is shifted to int8
// by subtracting 128. dst must hold num_rows * num_cols elements and
// zero_points must hold num_rows entries.
template <typename INDICES_DTYPE>
void TypedBCSRMatrix<INDICES_DTYPE>::unpack(
    int8_t* dst,
    const int64_t num_rows,
    const int64_t num_cols,
    const uint8_t* zero_points) const {
  // Pre-fill each dense row with its shifted zero point; the block loop
  // below only overwrites the positions covered by stored blocks.
  for (int64_t i = 0; i < num_rows; i++) {
    memset(
        dst + i * num_cols,
        static_cast<int8_t>(static_cast<int16_t>(zero_points[i]) - 128),
        num_cols * sizeof(int8_t));
  }

  // row_values holds a leading 0 plus one running block count per block row.
  const int64_t num_block_rows = static_cast<int64_t>(row_values.size()) - 1;
  const int64_t block_size = (int64_t)row_block_size * col_block_size;
  // Linear cursor into `values`; blocks are stored contiguously in the same
  // order as col_indices, so cursor / block_size is the block's index.
  int64_t weight_values_num = 0;
  for (int64_t block_row_num = 0; block_row_num < num_block_rows;
       block_row_num++) {
    // Difference of adjacent row pointers = blocks in this block row.
    const int64_t num_blocks_in_current_block_row =
        row_values[block_row_num + 1] - row_values[block_row_num];
    for (int64_t k = 0; k < num_blocks_in_current_block_row;
         k++) { // iterate over each block in the row
      const int64_t block_start_row_num = block_row_num * row_block_size;
      const int64_t block_start_col_num =
          (int64_t)(col_indices[weight_values_num / block_size]) *
          col_block_size;
      for (int64_t l = 0; l < block_size;
           l++) { // iterate over each value in the block
        const int64_t row_num = block_start_row_num + l / col_block_size;
        const int64_t col_num = block_start_col_num + l % col_block_size;
        // Edge blocks may extend past the dense bounds; their padded values
        // are skipped, but the cursor still advances past them.
        if (row_num < num_rows && col_num < num_cols) {
          dst[row_num * num_cols + col_num] = static_cast<int8_t>(
              static_cast<int16_t>(values[weight_values_num]) - 128);
        }
        weight_values_num++;
      }
    }
  }
}
331
332 template <typename INDICES_DTYPE>
333 uint32_t TypedBCSRMatrix<INDICES_DTYPE>::max_index() const {
334 return static_cast<uint32_t>(std::max(
335 *std::max_element(
336 row_values.data(), row_values.data() + row_values.size()),
337 *std::max_element(
338 col_indices.data(), col_indices.data() + col_indices.size())));
339 }
340
/**
 * Given a BCSRMatrix (bcsr_) and a block of code enclosed in { }
 * (dispatch_body), run the block of code with the following in scope
 * 1) The BCSRMatrix's underlying TypedBCSRMatrix, called typed_bcsr
 * 2) The TypedBCSRMatrix's indices data type, called INDICES_DTYPE
 *
 * Implemented as an immediately-invoked lambda so the macro forwards
 * dispatch_body's return value to the caller. An invalid dtype asserts in
 * debug builds and otherwise falls through to the throw at the end.
 */
#define QNNPACK_BCSRMATRIX_DISPATCH_INDICES_DTYPE(bcsr_, dispatch_body)       \
  [&bcsr = bcsr_]() {                                                         \
    switch (bcsr->indices_dtype) {                                            \
      case pytorch_qnnp_sparse_matrix_indices_dtype_uint32_t: {               \
        using INDICES_DTYPE = uint32_t;                                       \
        const qnnpack::TypedBCSRMatrix<INDICES_DTYPE>* typed_bcsr =           \
            static_cast<const qnnpack::TypedBCSRMatrix<INDICES_DTYPE>*>(      \
                bcsr.get());                                                  \
        return [&typed_bcsr]() dispatch_body();                               \
      }                                                                       \
      case pytorch_qnnp_sparse_matrix_indices_dtype_uint16_t: {               \
        using INDICES_DTYPE = uint16_t;                                       \
        const qnnpack::TypedBCSRMatrix<INDICES_DTYPE>* typed_bcsr =           \
            static_cast<const qnnpack::TypedBCSRMatrix<INDICES_DTYPE>*>(      \
                bcsr.get());                                                  \
        return [&typed_bcsr]() dispatch_body();                               \
      }                                                                       \
      case pytorch_qnnp_sparse_matrix_indices_dtype_uint8_t: {                \
        using INDICES_DTYPE = uint8_t;                                        \
        const qnnpack::TypedBCSRMatrix<INDICES_DTYPE>* typed_bcsr =           \
            static_cast<const qnnpack::TypedBCSRMatrix<INDICES_DTYPE>*>(      \
                bcsr.get());                                                  \
        return [&typed_bcsr]() dispatch_body();                               \
      }                                                                       \
      case pytorch_qnnp_sparse_matrix_indices_dtype_invalid: {                \
        assert(false);                                                        \
      }                                                                       \
    }                                                                         \
    /* Throw exception to avoid the following errors: */                      \
    /* - "non-void lambda does not return a value in all control paths" */    \
    /* - "control reaches end of non-void function" */                        \
    /* Throwing exception from within invalid case alone does not fix these */ \
    throw std::invalid_argument(                                              \
        "Invalid indices dtype in QNNPACK_BCSRMATRIX_DISPATCH_INDICES_DTYPE"); \
  }()
382
383 } // namespace qnnpack
384