• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <tuple>
#include <vector>

#ifndef _WIN32
#include <qnnpack/AlignedAllocator.h>
#endif
#include <pytorch_qnnpack.h>
#include <qnnpack/common.h>
#include <qnnpack/math.h>

#ifdef QNNPACK_BCSRMATRIX_DEBUG
#include <iostream>
#include <type_traits>
#endif // QNNPACK_BCSRMATRIX_DEBUG
25 
26 namespace qnnpack {
27 
28 template <typename T>
29 struct OwnedOrBorrowedVector {
30   using VECTOR_T =
31 #ifndef _WIN32
32       std::vector<T, AlignedAllocator<T, 16>>;
33 #else
34       std::vector<T>;
35 #endif
36 
37   // Only one of owned_vec_data_ or borrowed_tuple_data_ will be meaningfully
38   // populated.
39   // A union could potentially be used here to reduce memory usage.
40   // std::variant is not used here because it causes internal build errors
41   // due to incompatibility.
42   VECTOR_T owned_vec_data_;
43   std::tuple<T*, uint32_t> borrowed_tuple_data_;
44   bool owned;
45 
vectorOwnedOrBorrowedVector46   VECTOR_T& vector() {
47     assert(owned);
48     return owned_vec_data_;
49   }
50 
sizeOwnedOrBorrowedVector51   uint32_t size() const {
52     if (owned) {
53       return owned_vec_data_.size();
54     } else {
55       return std::get<1>(borrowed_tuple_data_);
56     }
57   }
58 
dataOwnedOrBorrowedVector59   const T* data() const {
60     if (owned) {
61       return owned_vec_data_.data();
62     } else {
63       return std::get<0>(borrowed_tuple_data_);
64     }
65   }
66 
67   const T& operator[](int i) const {
68     return data()[i];
69   }
70 
OwnedOrBorrowedVectorOwnedOrBorrowedVector71   OwnedOrBorrowedVector() : owned(true) {}
72 
OwnedOrBorrowedVectorOwnedOrBorrowedVector73   OwnedOrBorrowedVector(T* data_ptr, const uint32_t size)
74       : borrowed_tuple_data_(std::tuple<T*, uint32_t>(data_ptr, size)),
75         owned(false) {}
76 };
77 
// Abstract base for a block-compressed-sparse-row (BCSR) weight matrix.
// Values are stored as uint8_t; the integer width of the index arrays is
// chosen by the concrete TypedBCSRMatrix and recorded in indices_dtype.
struct BCSRMatrix {
  OwnedOrBorrowedVector<uint8_t> values;
  uint32_t col_block_size;  // input features block size
  uint32_t row_block_size;  // output features block size
  // Which integer width the concrete subclass uses for its index arrays.
  enum pytorch_qnnp_sparse_matrix_indices_dtype indices_dtype;
  virtual ~BCSRMatrix() = default;
  // Return void for the data ptrs because it doesn't require knowing the
  // underlying TypedBCSRMatrix indices dtype and that's how it's passed
  // into the qnnpack fully connected sparse op
  virtual const void* col_indices_data_ptr() const = 0;
  virtual const void* row_values_data_ptr() const = 0;
#ifdef QNNPACK_BCSRMATRIX_DEBUG
  // Dumps block sizes, index dtype, and all three arrays (debug builds only).
  virtual void print() const = 0;
#endif // QNNPACK_BCSRMATRIX_DEBUG
  /*
   * Unpack from BCSR to Dense
   * - Each value and zero point converted to int8_t by subtracting 128
   * - num_rows and num_cols are dimensions of dense weight tensor
   * - dst should be able to hold num_rows * num_cols elements
   * - zero_points should hold num_rows zero points
   */
  virtual void unpack(
      int8_t* dst,
      const int64_t num_rows,
      const int64_t num_cols,
      const uint8_t* zero_points) const = 0;
  // Largest value appearing in either index array (row_values/col_indices).
  virtual uint32_t max_index() const = 0;
};
106 
// Concrete BCSRMatrix whose index arrays use INDICES_DTYPE
// (uint8_t / uint16_t / uint32_t — see IndicesDtypeEnumTrait).
template <typename INDICES_DTYPE>
struct TypedBCSRMatrix : BCSRMatrix {
  OwnedOrBorrowedVector<INDICES_DTYPE> col_indices;
  OwnedOrBorrowedVector<INDICES_DTYPE> row_values;
  TypedBCSRMatrix();
  const void* col_indices_data_ptr() const override;
  const void* row_values_data_ptr() const override;
#ifdef QNNPACK_BCSRMATRIX_DEBUG
  void print() const override;
#endif // QNNPACK_BCSRMATRIX_DEBUG
  void unpack(
      int8_t* dst,
      const int64_t num_rows,
      const int64_t num_cols,
      const uint8_t* zero_points) const override;
  uint32_t max_index() const override;

  ~TypedBCSRMatrix() override = default;
};
126 
127 template <typename INDICES_DTYPE>
generateBlockCSRMatrix(const uint8_t * a,const size_t N,const size_t K,const uint32_t row_block_size,const uint32_t col_block_size,const uint8_t * zero_points)128 std::unique_ptr<BCSRMatrix> generateBlockCSRMatrix(
129     const uint8_t* a,
130     const size_t N,
131     const size_t K,
132     const uint32_t row_block_size,
133     const uint32_t col_block_size,
134     const uint8_t* zero_points) {
135   assert(K > 0);
136   std::unique_ptr<TypedBCSRMatrix<INDICES_DTYPE>> bcsr_mat =
137       std::make_unique<TypedBCSRMatrix<INDICES_DTYPE>>();
138   auto& row_values = bcsr_mat->row_values.vector();
139   auto& col_indices = bcsr_mat->col_indices.vector();
140   auto& values = bcsr_mat->values.vector();
141 
142   const uint32_t num_row_blocks = (N + row_block_size - 1) / row_block_size;
143   // K must be > 0
144   const uint32_t num_col_blocks = (K + col_block_size - 1) / col_block_size;
145 
146   row_values.reserve(num_row_blocks);
147   uint32_t num_nnz_blocks{0};
148   row_values.push_back(num_nnz_blocks);
149   for (uint32_t i = 0; i < num_row_blocks; ++i) {
150     for (uint32_t j = 0; j < num_col_blocks; ++j) {
151       bool block_zero{true};
152       for (uint32_t ib = 0; ib < row_block_size; ++ib) {
153         uint32_t row_index = i * row_block_size + ib;
154         if PYTORCH_QNNP_UNLIKELY(row_index >= N) {
155           break;
156         }
157         for (uint32_t jb = 0; jb < col_block_size; ++jb) {
158           uint32_t col_index = j * col_block_size + jb;
159           if PYTORCH_QNNP_UNLIKELY(col_index >= K) {
160             goto block_scanned;
161           }
162           if (*(a + row_index * K + col_index) != zero_points[row_index]) {
163             block_zero = false;
164             goto block_scanned;
165           }
166         }
167       }
168 block_scanned:
169       if (!block_zero) {
170         col_indices.push_back(j);
171         num_nnz_blocks++;
172         for (uint32_t ib = 0; ib < row_block_size; ++ib) {
173           uint32_t row_index = i * row_block_size + ib;
174           if PYTORCH_QNNP_UNLIKELY(row_index >= N) {
175             for (; row_index < (num_row_blocks * row_block_size); row_index++) {
176               for (uint32_t jb = 0; jb < col_block_size; ++jb) {
177                 values.push_back(zero_points[N-1]);
178               }
179             }
180             break;
181           }
182           for (uint32_t jb = 0; jb < col_block_size; ++jb) {
183             uint32_t col_index = j * col_block_size + jb;
184             if PYTORCH_QNNP_UNLIKELY(col_index >= K) {
185               values.push_back(zero_points[row_index]);
186             } else {
187               uint8_t val = *(a + row_index * K + col_index);
188               values.push_back(val);
189             }
190           }
191         }
192       }
193     }
194     row_values.push_back(num_nnz_blocks);
195   }
196   bcsr_mat->row_block_size = row_block_size;
197   bcsr_mat->col_block_size = col_block_size;
198   return bcsr_mat;
199 }
200 
201 template <typename INDICES_DTYPE>
generateBlockCSRMatrix(INDICES_DTYPE * col_indices,INDICES_DTYPE * row_values,uint8_t * values,const int64_t col_indices_size,const int64_t row_values_size,const int64_t values_size,const int64_t row_block_size,const int64_t col_block_size)202 std::unique_ptr<BCSRMatrix> generateBlockCSRMatrix(
203     INDICES_DTYPE* col_indices,
204     INDICES_DTYPE* row_values,
205     uint8_t* values,
206     const int64_t col_indices_size,
207     const int64_t row_values_size,
208     const int64_t values_size,
209     const int64_t row_block_size,
210     const int64_t col_block_size) {
211   std::unique_ptr<TypedBCSRMatrix<INDICES_DTYPE>> bcsr_mat =
212       std::make_unique<TypedBCSRMatrix<INDICES_DTYPE>>();
213   bcsr_mat->col_indices =
214       OwnedOrBorrowedVector<INDICES_DTYPE>(col_indices, col_indices_size);
215   bcsr_mat->row_values =
216       OwnedOrBorrowedVector<INDICES_DTYPE>(row_values, row_values_size);
217   bcsr_mat->values = OwnedOrBorrowedVector<uint8_t>(values, values_size);
218   bcsr_mat->row_block_size = row_block_size;
219   bcsr_mat->col_block_size = col_block_size;
220   return bcsr_mat;
221 }
222 
223 template <typename INDICES_DTYPE>
224 struct IndicesDtypeEnumTrait {
225   static_assert(
226       sizeof(INDICES_DTYPE) == 0,
227       "Invalid dtype for IndicesDtypeEnumTrait");
228 };
229 
230 template <>
231 struct IndicesDtypeEnumTrait<uint32_t> {
232   const static pytorch_qnnp_sparse_matrix_indices_dtype dtype =
233       pytorch_qnnp_sparse_matrix_indices_dtype_uint32_t;
234 };
235 
236 template <>
237 struct IndicesDtypeEnumTrait<uint16_t> {
238   const static pytorch_qnnp_sparse_matrix_indices_dtype dtype =
239       pytorch_qnnp_sparse_matrix_indices_dtype_uint16_t;
240 };
241 
242 template <>
243 struct IndicesDtypeEnumTrait<uint8_t> {
244   const static pytorch_qnnp_sparse_matrix_indices_dtype dtype =
245       pytorch_qnnp_sparse_matrix_indices_dtype_uint8_t;
246 };
247 
// Records this instantiation's index dtype in the base class so code holding
// only a BCSRMatrix* can recover the concrete type (see the dispatch macro).
template <typename INDICES_DTYPE>
TypedBCSRMatrix<INDICES_DTYPE>::TypedBCSRMatrix() {
  indices_dtype = IndicesDtypeEnumTrait<INDICES_DTYPE>::dtype;
}
252 
253 template <typename INDICES_DTYPE>
254 const void* TypedBCSRMatrix<INDICES_DTYPE>::col_indices_data_ptr() const {
255   return static_cast<const void*>(col_indices.data());
256 }
257 
258 template <typename INDICES_DTYPE>
259 const void* TypedBCSRMatrix<INDICES_DTYPE>::row_values_data_ptr() const {
260   return static_cast<const void*>(row_values.data());
261 }
262 
#ifdef QNNPACK_BCSRMATRIX_DEBUG
// Debug dump of block sizes, index dtype, and all three arrays to stdout.
template <typename INDICES_DTYPE>
void TypedBCSRMatrix<INDICES_DTYPE>::print() const {
  std::cout << "row block size:" << row_block_size << std::endl;
  std::cout << "col block size:" << col_block_size << std::endl;
  // Print the dtype header before the array dumps; previously it was emitted
  // between the "row ptr" label and the row values that label announces.
  std::cout
      << "indices dtype: uint"
      << static_cast<
             std::underlying_type_t<pytorch_qnnp_sparse_matrix_indices_dtype>>(
             indices_dtype)
      << "_t" << std::endl;
  std::cout << "row ptr\n";
  for (uint32_t i = 0; i < row_values.size(); i++) {
    std::cout << (uint32_t)row_values[i] << ", ";
  }
  std::cout << std::endl;
  std::cout << "col indices\n";
  for (uint32_t i = 0; i < col_indices.size(); i++) {
    std::cout << (uint32_t)col_indices[i] << ", ";
  }
  std::cout << std::endl;
  std::cout << "Actual values\n";
  for (uint32_t i = 0; i < values.size(); i++) {
    std::cout << (uint32_t)values[i] << ", ";
  }
  std::cout << std::endl;
}
#endif // QNNPACK_BCSRMATRIX_DEBUG
291 
/*
 * Unpacks this BCSR matrix into the dense int8 buffer `dst`
 * (see BCSRMatrix::unpack for the full contract).
 */
template <typename INDICES_DTYPE>
void TypedBCSRMatrix<INDICES_DTYPE>::unpack(
    int8_t* dst,
    const int64_t num_rows,
    const int64_t num_cols,
    const uint8_t* zero_points) const {
  // Start from an all-"zero" matrix: fill each row with its own zero point,
  // shifted from uint8 to int8 by subtracting 128.
  for (int64_t i = 0; i < num_rows; i++) {
    memset(
        dst + i * num_cols,
        static_cast<int8_t>(static_cast<int16_t>(zero_points[i]) - 128),
        num_cols * sizeof(int8_t));
  }

  // row_values is a CSR-style prefix array, so it has (#block rows + 1)
  // entries.
  const int64_t num_block_rows = static_cast<int64_t>(row_values.size()) - 1;
  const int64_t block_size = (int64_t)row_block_size * col_block_size;
  // Running position in the flat `values` array; it advances by exactly one
  // per element visited, so weight_values_num / block_size is the index of
  // the current block in col_indices.
  int64_t weight_values_num = 0;
  for (int64_t block_row_num = 0; block_row_num < num_block_rows;
       block_row_num++) {
    const int64_t num_blocks_in_current_block_row =
        row_values[block_row_num + 1] - row_values[block_row_num];
    for (int64_t k = 0; k < num_blocks_in_current_block_row;
         k++) { // iterate over each block in the row
      const int64_t block_start_row_num = block_row_num * row_block_size;
      const int64_t block_start_col_num =
          (int64_t)(col_indices[weight_values_num / block_size]) *
          col_block_size;
      for (int64_t l = 0; l < block_size;
           l++) { // iterate over each value in the block
        const int64_t row_num = block_start_row_num + l / col_block_size;
        const int64_t col_num = block_start_col_num + l % col_block_size;
        // Block elements that pad past the dense edges are skipped, but
        // weight_values_num must still advance past their stored values.
        if (row_num < num_rows && col_num < num_cols) {
          dst[row_num * num_cols + col_num] = static_cast<int8_t>(
              static_cast<int16_t>(values[weight_values_num]) - 128);
        }
        weight_values_num++;
      }
    }
  }
}
331 
332 template <typename INDICES_DTYPE>
333 uint32_t TypedBCSRMatrix<INDICES_DTYPE>::max_index() const {
334   return static_cast<uint32_t>(std::max(
335       *std::max_element(
336           row_values.data(), row_values.data() + row_values.size()),
337       *std::max_element(
338           col_indices.data(), col_indices.data() + col_indices.size())));
339 }
340 
/**
 * Given a BCSRMatrix (bcsr_) and a block of code enclosed in { }
 * (dispatch_body), run the block of code with the following in scope
 * 1) The BCSRMatrix's underlying TypedBCSRMatrix, called typed_bcsr
 * 2) The TypedBCSRMatrix's indices data type, called INDICES_DTYPE
 * The body runs inside an immediately-invoked lambda, so `return` inside
 * dispatch_body yields the macro's value. Throws std::invalid_argument if
 * indices_dtype does not match any known dtype.
 */
#define QNNPACK_BCSRMATRIX_DISPATCH_INDICES_DTYPE(bcsr_, dispatch_body)        \
  [&bcsr = bcsr_]() {                                                          \
    switch (bcsr->indices_dtype) {                                             \
      case pytorch_qnnp_sparse_matrix_indices_dtype_uint32_t: {                \
        using INDICES_DTYPE = uint32_t;                                        \
        const qnnpack::TypedBCSRMatrix<INDICES_DTYPE>* typed_bcsr =            \
            static_cast<const qnnpack::TypedBCSRMatrix<INDICES_DTYPE>*>(       \
                bcsr.get());                                                   \
        return [&typed_bcsr]() dispatch_body();                                \
      }                                                                        \
      case pytorch_qnnp_sparse_matrix_indices_dtype_uint16_t: {                \
        using INDICES_DTYPE = uint16_t;                                        \
        const qnnpack::TypedBCSRMatrix<INDICES_DTYPE>* typed_bcsr =            \
            static_cast<const qnnpack::TypedBCSRMatrix<INDICES_DTYPE>*>(       \
                bcsr.get());                                                   \
        return [&typed_bcsr]() dispatch_body();                                \
      }                                                                        \
      case pytorch_qnnp_sparse_matrix_indices_dtype_uint8_t: {                 \
        using INDICES_DTYPE = uint8_t;                                         \
        const qnnpack::TypedBCSRMatrix<INDICES_DTYPE>* typed_bcsr =            \
            static_cast<const qnnpack::TypedBCSRMatrix<INDICES_DTYPE>*>(       \
                bcsr.get());                                                   \
        return [&typed_bcsr]() dispatch_body();                                \
      }                                                                        \
      case pytorch_qnnp_sparse_matrix_indices_dtype_invalid: {                 \
        assert(false);                                                         \
      }                                                                        \
    }                                                                          \
    /* Throw exception to avoid the following errors: */                       \
    /* - "non-void lambda does not return a value in all control paths" */     \
    /* - "control reaches end of non-void function" */                         \
    /* Throwing exception from within invalid case alone does not fix these */ \
    throw std::invalid_argument(                                               \
        "Invalid indices dtype in QNNPACK_BCSRMATRIX_DISPATCH_INDICES_DTYPE"); \
  }()
382 
383 } // namespace qnnpack
384