// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// unpack.h: unpacking the result blocks computed by compute.h,
// storing them into the destination matrix.

#ifndef GEMMLOWP_INTERNAL_UNPACK_H_
#define GEMMLOWP_INTERNAL_UNPACK_H_

#include "allocator.h"
#include "block_params.h"
#include "output.h"
#include "pack.h"

#include <cmath>

namespace gemmlowp {

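// A PackedResult is the block of int32 accumulators produced by the compute
// stage for one (l2_rows x l2_cols) block. It reserves its storage from the
// shared Allocator and exposes it as a column-major MatrixMap.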
class PackedResult {
 public:
  PackedResult(Allocator* _allocator, const BlockParams& _block_params)
      : allocator_(_allocator), block_params_(_block_params) {
    matrix_handle_ = allocator_->Reserve<std::int32_t>(block_params_.l2_rows *
                                                       block_params_.l2_cols);
  }

  ~PackedResult() {}

  MatrixMap<std::int32_t, MapOrder::ColMajor> Map() {
    return MatrixMap<std::int32_t, MapOrder::ColMajor>(
        allocator_->GetPointer<std::int32_t>(matrix_handle_),
        block_params_.l2_rows, block_params_.l2_cols, block_params_.l2_rows);
  }

  MatrixMap<const std::int32_t, MapOrder::ColMajor> Map() const {
    return MatrixMap<const std::int32_t, MapOrder::ColMajor>(
        allocator_->GetPointer<const std::int32_t>(matrix_handle_),
        block_params_.l2_rows, block_params_.l2_cols, block_params_.l2_rows);
  }

 private:
  Allocator* allocator_;
  Allocator::Handle matrix_handle_;
  const BlockParams& block_params_;
};

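// MatrixBlockBounds describes a rectangular sub-block of the destination
// matrix: where it starts (start_row, start_col) and its size (rows, cols).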
struct MatrixBlockBounds {
  int start_row;
  int start_col;
  int rows;
  int cols;

  MatrixBlockBounds(int start_row_, int start_col_, int rows_, int cols_)
      : start_row(start_row_),
        start_col(start_col_),
        rows(rows_),
        cols(cols_) {}
};

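// Prefetches a (Rows x Cols) block of the packed result, together with the
// corresponding entries of lhs_sums_of_each_slice, so that this data is
// already in cache by the time the unpacking code below reads it.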
template <int Rows, int Cols, typename SrcMapType>
void PrefetchResultBlock(const SrcMapType& src,
                         const VectorMap<const std::int32_t, VectorShape::Col>&
                             lhs_sums_of_each_slice,
                         int src_row, int src_col) {
  const std::int32_t* src_data = src.data(src_row, src_col);
  const int src_stride = src.stride();
  const std::int32_t* lhs_sums_data = lhs_sums_of_each_slice.data(src_row);
  for (int r = 0; r < Rows; r += 4) {
    Prefetch(lhs_sums_data + r);
  }
  for (int c = 0; c < Cols; c++) {
    for (int r = 0; r < Rows; r += 4) {
      Prefetch(src_data + r + c * src_stride);
    }
  }
}

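// Unpacks one register block of int32 accumulators: loads it from the packed
// result, adds the terms accounting for the lhs/rhs offsets, and runs the
// output pipeline to store it into the destination. Writing s1(r) for
// lhs_sums_of_each_slice, s2(c) for rhs_sums_of_each_slice, o1 for lhs_offset
// (with the kernel's lhs zero-point input folded in) and o2 likewise for the
// rhs, the value handed to the output pipeline is
//
//   acc(r, c) + o2 * s1(r) + o1 * s2(c) + depth * o1 * o2,
//
// which is the expansion of sum_k (lhs(r, k) + o1) * (rhs(k, c) + o2)
// in terms of the raw accumulators acc and the per-slice sums.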
template <typename KernelFormat, typename RegisterBlockType,
          typename SrcMapType, typename LhsOffset, typename RhsOffset,
          typename OutputPipelineExecutorType, typename DstType>
void UnpackResultBlock(const SrcMapType& src,
                       const OutputPipelineExecutorType& executor, DstType* dst,
                       const VectorMap<const std::int32_t, VectorShape::Col>&
                           lhs_sums_of_each_slice,
                       const VectorMap<const std::int32_t, VectorShape::Row>&
                           rhs_sums_of_each_slice,
                       const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
                       int depth, int src_row, int src_col, int src_global_row,
                       int src_global_col, int dst_row, int dst_col) {
  using KernelLhsInputScalar = typename KernelFormat::Lhs::InputScalar;
  using KernelLhsScalar = typename KernelFormat::Lhs::Scalar;
  using KernelRhsInputScalar = typename KernelFormat::Rhs::InputScalar;
  using KernelRhsScalar = typename KernelFormat::Rhs::Scalar;
  static constexpr int KernelLhsZeroPointInput =
      ZeroPointInputValue<KernelLhsInputScalar, KernelLhsScalar>::kValue;
  static constexpr int KernelRhsZeroPointInput =
      ZeroPointInputValue<KernelRhsInputScalar, KernelRhsScalar>::kValue;
  auto acc = Load<RegisterBlockType>(src, src_row, src_col);
  const auto& lhs_sums_of_each_slice_block =
      LoadForBroadcasting<RegisterBlockType>(lhs_sums_of_each_slice, src_row);
  const auto& rhs_sums_of_each_slice_block =
      LoadForBroadcasting<RegisterBlockType>(rhs_sums_of_each_slice, src_col);
  auto lhs_offset_block =
      LoadForBroadcasting<RegisterBlockType>(lhs_offset, src_row);
  auto rhs_offset_block =
      LoadForBroadcasting<RegisterBlockType>(rhs_offset, src_col);
  // Fold the kernel's zero-point input values into the offsets.
  AddConstant<KernelLhsZeroPointInput>(&lhs_offset_block);
  AddConstant<KernelRhsZeroPointInput>(&rhs_offset_block);
  // acc += rhs_offset * lhs_sums_of_each_slice
  BroadcastMulAdd(lhs_sums_of_each_slice_block, rhs_offset_block, &acc);
  for (int i = 0; i < decltype(rhs_offset_block)::kRegisterCount; i++) {
    rhs_offset_block.buf.reg[i] = Mul(rhs_offset_block.buf.reg[i], depth);
  }
  // acc += lhs_offset * (rhs_sums_of_each_slice + depth * rhs_offset)
  BroadcastMulAdd(BroadcastAdd(rhs_sums_of_each_slice_block, rhs_offset_block),
                  lhs_offset_block, &acc);
  executor.Execute(acc, dst, src_global_row, src_global_col, dst_row, dst_col);
}

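// Unpacks a whole packed result block into the destination sub-block
// described by dst_block. It walks the block in register-sized tiles,
// prefetching ahead and dispatching each tile to UnpackResultBlock with the
// matching output pipeline executor.
//
// A rough sketch of how this is called from the gemm driver (not the exact
// call site; the local variable names here are illustrative only):
//
//   UnpackResult<KernelFormat>(
//       &result, MatrixBlockBounds(row, col, block_rows, block_cols),
//       packed_result, depth, packed_lhs.sums_of_each_slice(),
//       packed_rhs.sums_of_each_slice(), lhs_offset_block, rhs_offset_block,
//       output_pipeline);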
template <typename KernelFormat, typename ResultBlockType,
          typename PackedResultType, typename LhsOffset, typename RhsOffset,
          typename OutputPipelineType>
void UnpackResult(ResultBlockType* dst, const MatrixBlockBounds& dst_block,
                  const PackedResultType& src, int depth,
                  const std::int32_t* lhs_sums_of_each_slice_ptr,
                  const std::int32_t* rhs_sums_of_each_slice_ptr,
                  const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
                  const OutputPipelineType& output_pipeline) {
  ScopedProfilingLabel label(ResultBlockType::kOrder == MapOrder::ColMajor
                                 ? "unpack to column-major"
                                 : "unpack to row-major");
  assert(dst_block.start_row >= 0);
  assert(dst_block.start_row + dst_block.rows <= dst->rows());
  assert(dst_block.start_col >= 0);
  assert(dst_block.start_col + dst_block.cols <= dst->cols());
  const auto src_map = src.Map();
  const VectorMap<const std::int32_t, VectorShape::Col> lhs_sums_of_each_slice(
      lhs_sums_of_each_slice_ptr, dst_block.rows);
  const VectorMap<const std::int32_t, VectorShape::Row> rhs_sums_of_each_slice(
      rhs_sums_of_each_slice_ptr, dst_block.cols);
  using Int32x1x1 = RegisterBlock<std::int32_t, 1, 1>;
  using Int32x4x1 = RegisterBlock<std::int32_t, 4, 1>;
  using Int32x8x1 = RegisterBlock<std::int32_t, 8, 1>;
  using Int32x1x4 = RegisterBlock<std::int32_t, 1, 4>;
  using Int32x4x4 = RegisterBlock<std::int32_t, 4, 4>;
  using Int32x8x4 = RegisterBlock<std::int32_t, 8, 4>;

  using DstScalarType = typename ResultBlockType::Scalar;
  using DstScalarx8x8 = RegisterBlock<DstScalarType, 8, 8>;

  OutputPipelineExecutor<OutputPipelineType, Int32x1x1>
      output_pipeline_executor_1x1(output_pipeline);
  OutputPipelineExecutor<OutputPipelineType, Int32x4x1>
      output_pipeline_executor_4x1(output_pipeline);
  OutputPipelineExecutor<OutputPipelineType, Int32x8x1>
      output_pipeline_executor_8x1(output_pipeline);
  OutputPipelineExecutor<OutputPipelineType, Int32x1x4>
      output_pipeline_executor_1x4(output_pipeline);
  OutputPipelineExecutor<OutputPipelineType, Int32x4x4>
      output_pipeline_executor_4x4(output_pipeline);
  OutputPipelineExecutor<OutputPipelineType, Int32x8x4>
      output_pipeline_executor_8x4(output_pipeline);

  int c8 = 0;
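  // Fast path for row-major destinations: handle 8-wide column strips by
  // unpacking two 8x4 register blocks into a small column-major staging
  // buffer, then storing the complete 8x8 block to the destination in one go.
  // Leftover rows within these strips go directly to the destination using
  // 4x4 and 1x4 blocks.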
  if (ResultBlockType::kOrder == MapOrder::RowMajor) {
    for (; c8 <= dst_block.cols - 8; c8 += 8) {
      PrefetchResultBlock<8, 8>(src_map, lhs_sums_of_each_slice, 0, c8);
      int r = 0;
      for (; r <= dst_block.rows - 8; r += 8) {
        const int global_row = r + dst_block.start_row;
        PrefetchResultBlock<8, 8>(src_map, lhs_sums_of_each_slice, r + 8, c8);
        DstScalarType dst_colmajor_buf[64];
        MatrixMap<DstScalarType, MapOrder::ColMajor> dst_colmajor_map(
            dst_colmajor_buf, 8, 8);
        for (int cx = 0; cx < 8; cx += 4) {
          const int c = c8 + cx;
          const int global_col = c + dst_block.start_col;
          UnpackResultBlock<KernelFormat, Int32x8x4>(
              src_map, output_pipeline_executor_8x4, &dst_colmajor_map,
              lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
              rhs_offset, depth, r, c, global_row, global_col, 0, cx);
        }
        StoreFinalOutput(LoadContiguous<DstScalarx8x8>(dst_colmajor_buf), dst,
                         r + dst_block.start_row, c8 + dst_block.start_col);
      }
      for (; r <= dst_block.rows - 4; r += 4) {
        const int global_row = r + dst_block.start_row;
        for (int cx = 0; cx < 8; cx += 4) {
          const int c = c8 + cx;
          const int global_col = c + dst_block.start_col;
          UnpackResultBlock<KernelFormat, Int32x4x4>(
              src_map, output_pipeline_executor_4x4, dst,
              lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
              rhs_offset, depth, r, c, global_row, global_col, global_row,
              global_col);
        }
      }
      for (; r < dst_block.rows; r++) {
        const int global_row = r + dst_block.start_row;
        for (int cx = 0; cx < 8; cx += 4) {
          const int c = c8 + cx;
          const int global_col = c + dst_block.start_col;
          UnpackResultBlock<KernelFormat, Int32x1x4>(
              src_map, output_pipeline_executor_1x4, dst,
              lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
              rhs_offset, depth, r, c, global_row, global_col, global_row,
              global_col);
        }
      }
    }
  }
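  // Remaining columns, four at a time (all of them when the destination is
  // column-major): walk down the rows with 8x4, then 4x4, then 1x4 blocks.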
  int c = c8;
  for (; c <= dst_block.cols - 4; c += 4) {
    const int global_col = c + dst_block.start_col;
    PrefetchResultBlock<8, 4>(src_map, lhs_sums_of_each_slice, 0, c);
    int r = 0;
    for (; r <= dst_block.rows - 8; r += 8) {
      const int global_row = r + dst_block.start_row;
      PrefetchResultBlock<8, 4>(src_map, lhs_sums_of_each_slice, r + 8, c);
      UnpackResultBlock<KernelFormat, Int32x8x4>(
          src_map, output_pipeline_executor_8x4, dst, lhs_sums_of_each_slice,
          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
          global_row, global_col, global_row, global_col);
    }
    for (; r <= dst_block.rows - 4; r += 4) {
      const int global_row = r + dst_block.start_row;
      UnpackResultBlock<KernelFormat, Int32x4x4>(
          src_map, output_pipeline_executor_4x4, dst, lhs_sums_of_each_slice,
          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
          global_row, global_col, global_row, global_col);
    }
    for (; r < dst_block.rows; r++) {
      const int global_row = r + dst_block.start_row;
      UnpackResultBlock<KernelFormat, Int32x1x4>(
          src_map, output_pipeline_executor_1x4, dst, lhs_sums_of_each_slice,
          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
          global_row, global_col, global_row, global_col);
    }
  }
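  // Leftover columns, one at a time: 8x1, then 4x1, then 1x1 blocks.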
  for (; c < dst_block.cols; c++) {
    const int global_col = c + dst_block.start_col;
    PrefetchResultBlock<8, 1>(src_map, lhs_sums_of_each_slice, 0, c);
    int r = 0;
    for (; r <= dst_block.rows - 8; r += 8) {
      const int global_row = r + dst_block.start_row;
      PrefetchResultBlock<8, 1>(src_map, lhs_sums_of_each_slice, r + 8, c);
      UnpackResultBlock<KernelFormat, Int32x8x1>(
          src_map, output_pipeline_executor_8x1, dst, lhs_sums_of_each_slice,
          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
          global_row, global_col, global_row, global_col);
    }
    for (; r <= dst_block.rows - 4; r += 4) {
      const int global_row = r + dst_block.start_row;
      UnpackResultBlock<KernelFormat, Int32x4x1>(
          src_map, output_pipeline_executor_4x1, dst, lhs_sums_of_each_slice,
          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
          global_row, global_col, global_row, global_col);
    }
    for (; r < dst_block.rows; r++) {
      const int global_row = r + dst_block.start_row;
      UnpackResultBlock<KernelFormat, Int32x1x1>(
          src_map, output_pipeline_executor_1x1, dst, lhs_sums_of_each_slice,
          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
          global_row, global_col, global_row, global_col);
    }
  }
}

}  // end namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_UNPACK_H_