• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
// unpack.h: unpacking the result blocks computed by compute.h,
// storing them into the destination matrix.
17 
18 #ifndef GEMMLOWP_INTERNAL_UNPACK_H_
19 #define GEMMLOWP_INTERNAL_UNPACK_H_
20 
21 #include "allocator.h"
22 #include "block_params.h"
23 #include "output.h"
24 #include "pack.h"
25 
26 #include <cmath>
27 
28 namespace gemmlowp {
29 
30 class PackedResult {
31  public:
PackedResult(Allocator * _allocator,const BlockParams & _block_params)32   PackedResult(Allocator* _allocator, const BlockParams& _block_params)
33       : allocator_(_allocator), block_params_(_block_params) {
34     matrix_handle_ = allocator_->Reserve<std::int32_t>(block_params_.l2_rows *
35                                                        block_params_.l2_cols);
36   }
37 
~PackedResult()38   ~PackedResult() {}
39 
Map()40   MatrixMap<std::int32_t, MapOrder::ColMajor> Map() {
41     return MatrixMap<std::int32_t, MapOrder::ColMajor>(
42         allocator_->GetPointer<std::int32_t>(matrix_handle_),
43         block_params_.l2_rows, block_params_.l2_cols, block_params_.l2_rows);
44   }
45 
Map()46   MatrixMap<const std::int32_t, MapOrder::ColMajor> Map() const {
47     return MatrixMap<const std::int32_t, MapOrder::ColMajor>(
48         allocator_->GetPointer<const std::int32_t>(matrix_handle_),
49         block_params_.l2_rows, block_params_.l2_cols, block_params_.l2_rows);
50   }
51 
52  private:
53   Allocator* allocator_;
54   Allocator::Handle matrix_handle_;
55   const BlockParams& block_params_;
56 };
57 
// Describes a rectangular sub-block of a matrix by its top-left corner
// (start_row, start_col) and its extent (rows, cols).
struct MatrixBlockBounds {
  int start_row;  // First row of the block.
  int start_col;  // First column of the block.
  int rows;       // Number of rows in the block.
  int cols;       // Number of columns in the block.

  // Stores the four values as given; no validation is performed.
  MatrixBlockBounds(int start_row_, int start_col_, int rows_, int cols_)
      : start_row(start_row_),
        start_col(start_col_),
        rows(rows_),
        cols(cols_) {}
};
70 
71 template <int Rows, int Cols, typename SrcMapType>
PrefetchResultBlock(const SrcMapType & src,const VectorMap<const std::int32_t,VectorShape::Col> & lhs_sums_of_each_slice,int src_row,int src_col)72 void PrefetchResultBlock(const SrcMapType& src,
73                          const VectorMap<const std::int32_t, VectorShape::Col>&
74                              lhs_sums_of_each_slice,
75                          int src_row, int src_col) {
76   const std::int32_t* src_data = src.data(src_row, src_col);
77   const int src_stride = src.stride();
78   const std::int32_t* lhs_sums_data = lhs_sums_of_each_slice.data(src_row);
79   for (int r = 0; r < Rows; r += 4) {
80     Prefetch(lhs_sums_data + r);
81   }
82   for (int c = 0; c < Cols; c++) {
83     for (int r = 0; r < Rows; r += 4) {
84       Prefetch(src_data + r + c * src_stride);
85     }
86   }
87 }
88 
89 template <typename KernelFormat, typename RegisterBlockType,
90           typename SrcMapType, typename LhsOffset, typename RhsOffset,
91           typename OutputPipelineExecutorType, typename DstType>
UnpackResultBlock(const SrcMapType & src,const OutputPipelineExecutorType & executor,DstType * dst,const VectorMap<const std::int32_t,VectorShape::Col> & lhs_sums_of_each_slice,const VectorMap<const std::int32_t,VectorShape::Row> & rhs_sums_of_each_slice,const LhsOffset & lhs_offset,const RhsOffset & rhs_offset,int depth,int src_row,int src_col,int src_global_row,int src_global_col,int dst_row,int dst_col)92 void UnpackResultBlock(const SrcMapType& src,
93                        const OutputPipelineExecutorType& executor, DstType* dst,
94                        const VectorMap<const std::int32_t, VectorShape::Col>&
95                            lhs_sums_of_each_slice,
96                        const VectorMap<const std::int32_t, VectorShape::Row>&
97                            rhs_sums_of_each_slice,
98                        const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
99                        int depth, int src_row, int src_col, int src_global_row,
100                        int src_global_col, int dst_row, int dst_col) {
101   using KernelLhsInputScalar = typename KernelFormat::Lhs::InputScalar;
102   using KernelLhsScalar = typename KernelFormat::Lhs::Scalar;
103   using KernelRhsInputScalar = typename KernelFormat::Rhs::InputScalar;
104   using KernelRhsScalar = typename KernelFormat::Rhs::Scalar;
105   static constexpr int KernelLhsZeroPointInput =
106       ZeroPointInputValue<KernelLhsInputScalar, KernelLhsScalar>::kValue;
107   static constexpr int KernelRhsZeroPointInput =
108       ZeroPointInputValue<KernelRhsInputScalar, KernelRhsScalar>::kValue;
109   auto acc = Load<RegisterBlockType>(src, src_row, src_col);
110   const auto& lhs_sums_of_each_slice_block =
111       LoadForBroadcasting<RegisterBlockType>(lhs_sums_of_each_slice, src_row);
112   const auto& rhs_sums_of_each_slice_block =
113       LoadForBroadcasting<RegisterBlockType>(rhs_sums_of_each_slice, src_col);
114   auto lhs_offset_block =
115       LoadForBroadcasting<RegisterBlockType>(lhs_offset, src_row);
116   auto rhs_offset_block =
117       LoadForBroadcasting<RegisterBlockType>(rhs_offset, src_col);
118   AddConstant<KernelLhsZeroPointInput>(&lhs_offset_block);
119   AddConstant<KernelRhsZeroPointInput>(&rhs_offset_block);
120   BroadcastMulAdd(lhs_sums_of_each_slice_block, rhs_offset_block, &acc);
121   for (int i = 0; i < decltype(rhs_offset_block)::kRegisterCount; i++) {
122     rhs_offset_block.buf.reg[i] = Mul(rhs_offset_block.buf.reg[i], depth);
123   }
124   BroadcastMulAdd(BroadcastAdd(rhs_sums_of_each_slice_block, rhs_offset_block),
125                   lhs_offset_block, &acc);
126   executor.Execute(acc, dst, src_global_row, src_global_col, dst_row, dst_col);
127 }
128 
129 template <typename KernelFormat, typename ResultBlockType,
130           typename PackedResultType, typename LhsOffset, typename RhsOffset,
131           typename OutputPipelineType>
UnpackResult(ResultBlockType * dst,const MatrixBlockBounds & dst_block,const PackedResultType & src,int depth,const std::int32_t * lhs_sums_of_each_slice_ptr,const std::int32_t * rhs_sums_of_each_slice_ptr,const LhsOffset & lhs_offset,const RhsOffset & rhs_offset,const OutputPipelineType & output_pipeline)132 void UnpackResult(ResultBlockType* dst, const MatrixBlockBounds& dst_block,
133                   const PackedResultType& src, int depth,
134                   const std::int32_t* lhs_sums_of_each_slice_ptr,
135                   const std::int32_t* rhs_sums_of_each_slice_ptr,
136                   const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
137                   const OutputPipelineType& output_pipeline) {
138   ScopedProfilingLabel label(ResultBlockType::kOrder == MapOrder::ColMajor
139                                  ? "unpack to column-major"
140                                  : "unpack to row-major");
141   assert(dst_block.start_row >= 0);
142   assert(dst_block.start_row + dst_block.rows <= dst->rows());
143   assert(dst_block.start_col >= 0);
144   assert(dst_block.start_col + dst_block.cols <= dst->cols());
145   const auto src_map = src.Map();
146   const VectorMap<const std::int32_t, VectorShape::Col> lhs_sums_of_each_slice(
147       lhs_sums_of_each_slice_ptr, dst_block.rows);
148   const VectorMap<const std::int32_t, VectorShape::Row> rhs_sums_of_each_slice(
149       rhs_sums_of_each_slice_ptr, dst_block.cols);
150   using Int32x1x1 = RegisterBlock<std::int32_t, 1, 1>;
151   using Int32x4x1 = RegisterBlock<std::int32_t, 4, 1>;
152   using Int32x8x1 = RegisterBlock<std::int32_t, 8, 1>;
153   using Int32x1x4 = RegisterBlock<std::int32_t, 1, 4>;
154   using Int32x4x4 = RegisterBlock<std::int32_t, 4, 4>;
155   using Int32x8x4 = RegisterBlock<std::int32_t, 8, 4>;
156 
157   using DstScalarType = typename ResultBlockType::Scalar;
158   using DstScalarx8x8 = RegisterBlock<DstScalarType, 8, 8>;
159 
160   OutputPipelineExecutor<OutputPipelineType, Int32x1x1>
161       output_pipeline_executor_1x1(output_pipeline);
162   OutputPipelineExecutor<OutputPipelineType, Int32x4x1>
163       output_pipeline_executor_4x1(output_pipeline);
164   OutputPipelineExecutor<OutputPipelineType, Int32x8x1>
165       output_pipeline_executor_8x1(output_pipeline);
166   OutputPipelineExecutor<OutputPipelineType, Int32x1x4>
167       output_pipeline_executor_1x4(output_pipeline);
168   OutputPipelineExecutor<OutputPipelineType, Int32x4x4>
169       output_pipeline_executor_4x4(output_pipeline);
170   OutputPipelineExecutor<OutputPipelineType, Int32x8x4>
171       output_pipeline_executor_8x4(output_pipeline);
172 
173   int c8 = 0;
174   if (ResultBlockType::kOrder == MapOrder::RowMajor) {
175     for (; c8 <= dst_block.cols - 8; c8 += 8) {
176       PrefetchResultBlock<8, 8>(src_map, lhs_sums_of_each_slice, 0, c8);
177       int r = 0;
178       for (; r <= dst_block.rows - 8; r += 8) {
179         const int global_row = r + dst_block.start_row;
180         PrefetchResultBlock<8, 8>(src_map, lhs_sums_of_each_slice, r + 8, c8);
181         DstScalarType dst_colmajor_buf[64];
182         MatrixMap<DstScalarType, MapOrder::ColMajor> dst_colmajor_map(
183             dst_colmajor_buf, 8, 8);
184         for (int cx = 0; cx < 8; cx += 4) {
185           const int c = c8 + cx;
186           const int global_col = c + dst_block.start_col;
187           UnpackResultBlock<KernelFormat, Int32x8x4>(
188               src_map, output_pipeline_executor_8x4, &dst_colmajor_map,
189               lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
190               rhs_offset, depth, r, c, global_row, global_col, 0, cx);
191         }
192         StoreFinalOutput(LoadContiguous<DstScalarx8x8>(dst_colmajor_buf), dst,
193                          r + dst_block.start_row, c8 + dst_block.start_col);
194       }
195       for (; r <= dst_block.rows - 4; r += 4) {
196         const int global_row = r + dst_block.start_row;
197         for (int cx = 0; cx < 8; cx += 4) {
198           const int c = c8 + cx;
199           const int global_col = c + dst_block.start_col;
200           UnpackResultBlock<KernelFormat, Int32x4x4>(
201               src_map, output_pipeline_executor_4x4, dst,
202               lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
203               rhs_offset, depth, r, c, global_row, global_col, global_row,
204               global_col);
205         }
206       }
207       for (; r < dst_block.rows; r++) {
208         const int global_row = r + dst_block.start_row;
209         for (int cx = 0; cx < 8; cx += 4) {
210           const int c = c8 + cx;
211           const int global_col = c + dst_block.start_col;
212           UnpackResultBlock<KernelFormat, Int32x1x4>(
213               src_map, output_pipeline_executor_1x4, dst,
214               lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
215               rhs_offset, depth, r, c, global_row, global_col, global_row,
216               global_col);
217         }
218       }
219     }
220   }
221   int c = c8;
222   for (; c <= dst_block.cols - 4; c += 4) {
223     const int global_col = c + dst_block.start_col;
224     PrefetchResultBlock<8, 4>(src_map, lhs_sums_of_each_slice, 0, c);
225     int r = 0;
226     for (; r <= dst_block.rows - 8; r += 8) {
227       const int global_row = r + dst_block.start_row;
228       PrefetchResultBlock<8, 4>(src_map, lhs_sums_of_each_slice, r + 8, c);
229       UnpackResultBlock<KernelFormat, Int32x8x4>(
230           src_map, output_pipeline_executor_8x4, dst, lhs_sums_of_each_slice,
231           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
232           global_row, global_col, global_row, global_col);
233     }
234     for (; r <= dst_block.rows - 4; r += 4) {
235       const int global_row = r + dst_block.start_row;
236       UnpackResultBlock<KernelFormat, Int32x4x4>(
237           src_map, output_pipeline_executor_4x4, dst, lhs_sums_of_each_slice,
238           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
239           global_row, global_col, global_row, global_col);
240     }
241     for (; r < dst_block.rows; r++) {
242       const int global_row = r + dst_block.start_row;
243       UnpackResultBlock<KernelFormat, Int32x1x4>(
244           src_map, output_pipeline_executor_1x4, dst, lhs_sums_of_each_slice,
245           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
246           global_row, global_col, global_row, global_col);
247     }
248   }
249   for (; c < dst_block.cols; c++) {
250     const int global_col = c + dst_block.start_col;
251     PrefetchResultBlock<8, 1>(src_map, lhs_sums_of_each_slice, 0, c);
252     int r = 0;
253     for (; r <= dst_block.rows - 8; r += 8) {
254       const int global_row = r + dst_block.start_row;
255       PrefetchResultBlock<8, 1>(src_map, lhs_sums_of_each_slice, r + 8, c);
256       UnpackResultBlock<KernelFormat, Int32x8x1>(
257           src_map, output_pipeline_executor_8x1, dst, lhs_sums_of_each_slice,
258           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
259           global_row, global_col, global_row, global_col);
260     }
261     for (; r <= dst_block.rows - 4; r += 4) {
262       const int global_row = r + dst_block.start_row;
263       UnpackResultBlock<KernelFormat, Int32x4x1>(
264           src_map, output_pipeline_executor_4x1, dst, lhs_sums_of_each_slice,
265           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
266           global_row, global_col, global_row, global_col);
267     }
268     for (; r < dst_block.rows; r++) {
269       const int global_row = r + dst_block.start_row;
270       UnpackResultBlock<KernelFormat, Int32x1x1>(
271           src_map, output_pipeline_executor_1x1, dst, lhs_sums_of_each_slice,
272           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
273           global_row, global_col, global_row, global_col);
274     }
275   }
276 }
277 
278 }  // end namespace gemmlowp
279 
280 #endif  // GEMMLOWP_INTERNAL_UNPACK_H_
281