1 /* Copyright 2019 Google LLC. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "ruy/block_map.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <limits>
21
22 #ifdef RUY_MAKEBLOCKMAP_DEBUG
23 #include <cstdio>
24 #include <cstdlib>
25 #include <string>
26 #endif
27
28 #include "ruy/check_macros.h"
29 #include "ruy/opt_set.h"
30 #include "ruy/profiler/instrumentation.h"
31 #include "ruy/size_util.h"
32 #include "ruy/trace.h"
33
34 namespace ruy {
35
36 namespace {
37
DecodeTraversalLinear(int size_log2,std::uint32_t square_index,SidePair<int> * local_pos)38 void DecodeTraversalLinear(int size_log2, std::uint32_t square_index,
39 SidePair<int>* local_pos) {
40 (*local_pos)[Side::kLhs] = square_index & ((1 << size_log2) - 1);
41 (*local_pos)[Side::kRhs] = square_index >> size_log2;
42 }
43
// Decodes a Z-order (Morton order) index into a position within the square:
// the even-numbered bits of square_index form the LHS coordinate and the
// odd-numbered bits form the RHS coordinate, so each side supports up to 16
// bits.
//
// The four steps de-interleave the bits without a loop. Each step keeps the
// outer sub-groups of every group of 4, 8, 16 and then 32 bits in place and
// swaps its two middle sub-groups (of 1, 2, 4 and then 8 bits). After the last
// step the even bits are packed into the low 16 bits and the odd bits into the
// high 16 bits.
void DecodeTraversalFractalZ(std::uint32_t square_index,
                             SidePair<int>* local_pos) {
  const std::uint32_t n1 = square_index;
  // Within each nibble: swap bit 1 and bit 2.
  const std::uint32_t n2 = (n1 & 0x99999999u) | ((n1 & 0x44444444u) >> 1) |
                           ((n1 & 0x22222222u) << 1);
  // Within each byte: swap bits [2,3] and bits [4,5].
  const std::uint32_t n4 = (n2 & 0xc3c3c3c3u) | ((n2 & 0x30303030u) >> 2) |
                           ((n2 & 0x0c0c0c0cu) << 2);
  // Within each 16-bit half: swap bits [4,7] and bits [8,11].
  const std::uint32_t n8 = (n4 & 0xf00ff00fu) | ((n4 & 0x0f000f00u) >> 4) |
                           ((n4 & 0x00f000f0u) << 4);
  // Within the 32-bit word: swap bits [8,15] and bits [16,23].
  const std::uint32_t n16 = (n8 & 0xff0000ffu) | ((n8 & 0x00ff0000u) >> 8) |
                            ((n8 & 0x0000ff00u) << 8);
  // Low half: former even bits. High half: former odd bits.
  (*local_pos)[Side::kLhs] = n16 & 0xffff;
  (*local_pos)[Side::kRhs] = n16 >> 16;
}
58
DecodeTraversalFractalU(std::uint32_t square_index,SidePair<int> * local_pos)59 void DecodeTraversalFractalU(std::uint32_t square_index,
60 SidePair<int>* local_pos) {
61 DecodeTraversalFractalZ(square_index, local_pos);
62 // Change fractal z-order to u-order
63 (*local_pos)[Side::kLhs] ^= (*local_pos)[Side::kRhs];
64 }
65
// Code inspired by the sample code in
// https://en.wikipedia.org/wiki/Hilbert_curve
// The main optimization is to avoid hard-to-predict conditional branches
// based on the bits of the square_index parameter.
//
// Decodes square_index into a position along a Hilbert curve covering a
// 2^size_log2 x 2^size_log2 square. The LHS coordinate is taken from `y` and
// the RHS coordinate from `x`. Each loop iteration consumes the two lowest
// remaining bits of the index. It builds the position up from the smallest
// sub-square (side s = 2^sb) to the largest.
void DecodeTraversalFractalHilbert(int size_log2, std::uint32_t square_index,
                                   SidePair<int>* local_pos) {
  std::uint32_t t = square_index;
  std::uint32_t x = 0;
  std::uint32_t y = 0;
  // Easy-to-predict for loop, the number of iterations is the same for
  // an entire GEMM.
  for (int sb = 0; sb < size_log2; sb++) {
    std::uint32_t s = 1 << sb;
    // Quadrant selected by the current 2-bit digit, as in the reference d2xy
    // code: rx = bit 1 of the digit, ry = bit 0 XOR rx.
    bool rx = t & 2;
    bool ry = (t & 1) ^ rx;
    // When ry == 0 the sub-curve built so far is rotated, i.e. x and y are
    // swapped, and additionally reflected (v -> s - 1 - v) when rx == 1. The
    // rotation and the following "+ s" offsets are written as conditional
    // expressions (instead of the reference code's if/swap) so they can
    // compile to branch-free selects.
    std::uint32_t tmp = rx ? (s - 1 - x) : x;
    x = ry ? x : rx ? (s - 1 - y) : y;
    y = ry ? (y + s) : tmp;
    x = rx ? (x + s) : x;
    t >>= 2;
  }
  (*local_pos)[Side::kLhs] = y;
  (*local_pos)[Side::kRhs] = x;
}
90
91 } // end anonymous namespace
92
GetBlockByIndex(const BlockMap & block_map,int index,SidePair<int> * block)93 void GetBlockByIndex(const BlockMap& block_map, int index,
94 SidePair<int>* block) {
95 profiler::ScopeLabel label("GetBlockByIndex");
96 const std::uint32_t index_u32 = index;
97
98 const std::uint32_t num_blocks_per_local_curve =
99 1u << (2 * block_map.num_blocks_base_log2);
100 const std::uint32_t square_index =
101 index_u32 & (num_blocks_per_local_curve - 1);
102
103 const int size_log2 = block_map.num_blocks_base_log2;
104 SidePair<int> local_pos;
105 switch (block_map.traversal_order) {
106 case BlockMapTraversalOrder::kFractalZ:
107 DecodeTraversalFractalZ(square_index, &local_pos);
108 break;
109 case BlockMapTraversalOrder::kFractalU:
110 DecodeTraversalFractalU(square_index, &local_pos);
111 break;
112 case BlockMapTraversalOrder::kFractalHilbert:
113 DecodeTraversalFractalHilbert(size_log2, square_index, &local_pos);
114 break;
115 default:
116 RUY_DCHECK(block_map.traversal_order == BlockMapTraversalOrder::kLinear);
117 DecodeTraversalLinear(size_log2, square_index, &local_pos);
118 break;
119 }
120
121 const std::uint32_t rectangular_index =
122 index_u32 >> 2 * block_map.num_blocks_base_log2;
123 for (Side side : {Side::kLhs, Side::kRhs}) {
124 const std::uint32_t mask = (1u << block_map.rectangularness_log2[side]) - 1;
125 const int rectangular_offset = (rectangular_index & mask)
126 << block_map.num_blocks_base_log2;
127 (*block)[side] = local_pos[side] + rectangular_offset;
128 }
129 }
130
131 namespace {
132
GetTraversalOrder(int rows_after_rectangularness_division,int cols_after_rectangularness_division,int depth,int lhs_scalar_size,int rhs_scalar_size,const CpuCacheParams & cpu_cache_params)133 BlockMapTraversalOrder GetTraversalOrder(
134 int rows_after_rectangularness_division,
135 int cols_after_rectangularness_division, int depth, int lhs_scalar_size,
136 int rhs_scalar_size, const CpuCacheParams& cpu_cache_params) {
137 static constexpr bool kAnyFractal =
138 RUY_OPT(FRACTAL_Z) | RUY_OPT(FRACTAL_U) | RUY_OPT(FRACTAL_HILBERT);
139 const int working_set_size =
140 (lhs_scalar_size * rows_after_rectangularness_division +
141 rhs_scalar_size * cols_after_rectangularness_division) *
142 depth;
143 if (kAnyFractal && (working_set_size > cpu_cache_params.local_cache_size)) {
144 if (RUY_OPT(FRACTAL_HILBERT) &&
145 (working_set_size > cpu_cache_params.last_level_cache_size)) {
146 return BlockMapTraversalOrder::kFractalHilbert;
147 } else if (RUY_OPT(FRACTAL_U)) {
148 return BlockMapTraversalOrder::kFractalU;
149 } else {
150 return BlockMapTraversalOrder::kFractalZ;
151 }
152 } else {
153 return BlockMapTraversalOrder::kLinear;
154 }
155 }
156
// Returns floor(log2(num / denom)), computed on the exact rational quotient,
// or 0 when num <= denom. Requires num >= 1 and denom >= 1; callers pass
// matrix dimensions.
//
// Finds the largest k with denom * 2^k <= num. The comparison is done in 64
// bits. With 32-bit arithmetic, denom << (k + 1) can reach 2^31 (e.g. num >=
// 2^30 and denom == 1), which wraps to INT_MIN. The previous implementation
// then returned one too many (31 instead of 30 for (2^30, 1)).
int floor_log2_quotient(int num, int denom) {
  if (num <= denom) {
    return 0;
  }
  // No quotient of positive ints exceeds INT_MAX, whose floor(log2) is 30.
  // Bounding the loop by it also guarantees termination.
  constexpr int kMaxLog2Quotient = std::numeric_limits<int>::digits - 1;
  const std::int64_t num_wide = num;
  const std::int64_t denom_wide = denom;
  int log2_quotient = 0;
  while (log2_quotient < kMaxLog2Quotient &&
         (denom_wide << (log2_quotient + 1)) <= num_wide) {
    ++log2_quotient;
  }
  return log2_quotient;
}
167
168 // Computes the rectangularness of the matrix shape (rows, cols). This is
169 // essentially just the log2 of the quotient (rows / cols). The kernel_rows and
170 // kernel_cols only get into the picture for clamping bounds but don't affect
171 // the generic computation.
GetRectangularness(int rows,int cols,int kernel_rows,int kernel_cols,int * rows_rectangularness_log2,int * cols_rectangularness_log2)172 void GetRectangularness(int rows, int cols, int kernel_rows, int kernel_cols,
173 int* rows_rectangularness_log2,
174 int* cols_rectangularness_log2) {
175 *rows_rectangularness_log2 = 0;
176 *cols_rectangularness_log2 = 0;
177
178 // In GEMV-ish cases, that is when kernel blocks are as narrow as the kernel
179 // itself, we risk having too small kernel blocks for good kernel
180 // amortization. We avoid that by limiting recangularness so that kernel
181 // blocks are not too tiny at least in that dimension. Specifically, we try to
182 // have at least (2^min_kernel_inner_loop_runs_log2) kernels fitting in each
183 // kernel block along the large dimension.
184 const int min_kernel_inner_loop_runs_log2 = 3;
185 if (rows > cols) {
186 int cols_of_kernel_inner_loop_runs_log2 =
187 ceil_log2(cols) - pot_log2(kernel_cols);
188 int min_rows_of_kernel_inner_loop_runs_log2 =
189 std::max(0, min_kernel_inner_loop_runs_log2 -
190 cols_of_kernel_inner_loop_runs_log2);
191 *rows_rectangularness_log2 =
192 std::min(floor_log2_quotient(rows, cols),
193 std::max(0, floor_log2(rows) - pot_log2(kernel_rows) -
194 min_rows_of_kernel_inner_loop_runs_log2));
195 // Sanity check that we did not over-estimate rows_rectangularness_log2.
196 RUY_DCHECK_GE(rows >> *rows_rectangularness_log2, cols);
197 } else if (cols > rows) {
198 int rows_of_kernel_inner_loop_runs_log2 =
199 ceil_log2(rows) - pot_log2(kernel_rows);
200 int min_cols_of_kernel_inner_loop_runs_log2 =
201 std::max(0, min_kernel_inner_loop_runs_log2 -
202 rows_of_kernel_inner_loop_runs_log2);
203 *cols_rectangularness_log2 =
204 std::min(floor_log2_quotient(cols, rows),
205 std::max(0, floor_log2(cols) - pot_log2(kernel_cols) -
206 min_cols_of_kernel_inner_loop_runs_log2));
207 // Sanity check that we did not over-estimate cols_rectangularness_log2.
208 RUY_DCHECK_GE(cols >> *cols_rectangularness_log2, rows);
209 }
210 RUY_DCHECK(!*rows_rectangularness_log2 || !*cols_rectangularness_log2);
211 }
212
213 // Computes a 'multithreading score'. When multithreading, we need there to
214 // be at least as many tiles as there are threads, and hopefully
215 // substantially more than that, so we benefit from ruy's ability to
216 // dispatch fine-grained workloads to threads.
GetMultithreadingScore(int block_size_log2,int rows,int cols,int tentative_thread_count)217 int GetMultithreadingScore(int block_size_log2, int rows, int cols,
218 int tentative_thread_count) {
219 const int num_full_blocks_of_rows = rows >> block_size_log2;
220 const int num_full_blocks_of_cols = cols >> block_size_log2;
221 const int candidate_num_full_blocks_log2 = floor_log2(
222 std::max(1, num_full_blocks_of_rows * num_full_blocks_of_cols));
223
224 // The values here have been tuned on ARM Cortex-A55.
225 // We expect this to have to be tuned differently for other CPUs.
226 if (tentative_thread_count == 1) {
227 return 0;
228 } else {
229 const int blocks_per_thread_log2 =
230 candidate_num_full_blocks_log2 - ceil_log2(tentative_thread_count);
231 if (blocks_per_thread_log2 < 0) {
232 return -64;
233 } else if (blocks_per_thread_log2 == 0) {
234 return -16;
235 } else if (blocks_per_thread_log2 == 1) {
236 return -8;
237 } else if (blocks_per_thread_log2 == 2) {
238 return 0;
239 } else if (blocks_per_thread_log2 == 3) {
240 return 8;
241 } else {
242 return 16;
243 }
244 }
245 }
246
247 // Computes a 'cache locality score'.
GetCacheLocalityScore(int block_size_log2,int rows,int cols,int depth,int kernel_rows_log2,int kernel_cols_log2,int lhs_scalar_size,int rhs_scalar_size,const CpuCacheParams & cpu_cache_params)248 int GetCacheLocalityScore(int block_size_log2, int rows, int cols, int depth,
249 int kernel_rows_log2, int kernel_cols_log2,
250 int lhs_scalar_size, int rhs_scalar_size,
251 const CpuCacheParams& cpu_cache_params) {
252 // In the narrow case (e.g. matrix*vector), each byte of the big operand
253 // matrix (either LHS or RHS) is traversed only once, so any notion of data
254 // locality is irrelevant. Ignore the 'cache locality score' by forcing it to
255 // be 0 in that case.
256 if (rows <= (1 << kernel_rows_log2) || cols <= (1 << kernel_cols_log2)) {
257 return 0;
258 }
259 const int block_rows = std::min(1 << block_size_log2, rows);
260 const int block_cols = std::min(1 << block_size_log2, cols);
261 const int total_read_bytes =
262 (lhs_scalar_size * block_rows + rhs_scalar_size * block_cols) * depth;
263 const int total_read_bytes_log2 = ceil_log2(total_read_bytes);
264 const int nonlocality_log2 =
265 total_read_bytes_log2 - floor_log2(cpu_cache_params.local_cache_size);
266 // The values here have been tuned on ARM Cortex-A55.
267 // We expect this to have to be tuned differently for other CPUs.
268 if (nonlocality_log2 < -1) {
269 return 64;
270 } else if (nonlocality_log2 == -1) {
271 return 56;
272 } else if (nonlocality_log2 == 0) {
273 return 48;
274 } else if (nonlocality_log2 == 1) {
275 return 32;
276 } else if (nonlocality_log2 == 2) {
277 return 16;
278 } else if (nonlocality_log2 == 3) {
279 return 0;
280 } else {
281 return -64;
282 }
283 }
284
285 // Compute a 'kernel amortization score'. This is the notion that very small
286 // tiles result in more overhead outside of kernels, more complex memory
287 // access patterns and less benefits from ruy's fat kernels, so we reward
288 // larger blocks more than smaller ones.
GetKernelAmortizationScore(int block_size_log2,int rows,int cols,int kernel_rows_log2,int kernel_cols_log2)289 int GetKernelAmortizationScore(int block_size_log2, int rows, int cols,
290 int kernel_rows_log2, int kernel_cols_log2) {
291 const int block_rows = std::min(1 << block_size_log2, rows);
292 const int block_cols = std::min(1 << block_size_log2, cols);
293 const int kernels_per_block_log2 =
294 floor_log2(block_rows * block_cols) - kernel_rows_log2 - kernel_cols_log2;
295 RUY_DCHECK_GE(kernels_per_block_log2, 0);
296 // The values here have been tuned on ARM Cortex-A55.
297 // We expect this to have to be tuned differently for other CPUs.
298 if (kernels_per_block_log2 == 0) {
299 return 0;
300 } else if (kernels_per_block_log2 == 1) {
301 return 8;
302 } else if (kernels_per_block_log2 == 2) {
303 return 16;
304 } else if (kernels_per_block_log2 == 3) {
305 return 24;
306 } else if (kernels_per_block_log2 == 4) {
307 return 32;
308 } else if (kernels_per_block_log2 == 5) {
309 return 40;
310 } else if (kernels_per_block_log2 == 6) {
311 return 48;
312 } else if (kernels_per_block_log2 == 7) {
313 return 56;
314 } else {
315 return 64;
316 }
317 }
318
319 } // namespace
320
IsObviouslyLinearTraversal(int rows,int cols,int depth,int lhs_scalar_size,int rhs_scalar_size,const CpuCacheParams & cpu_cache_params)321 bool IsObviouslyLinearTraversal(int rows, int cols, int depth,
322 int lhs_scalar_size, int rhs_scalar_size,
323 const CpuCacheParams& cpu_cache_params) {
324 if (rows == 1 || cols == 1) {
325 return true;
326 }
327 // Normally, GetTraversalOrder wants the dimensions (rows x cols) divided
328 // by the rectangularness factors, since any non-linear traversal order will
329 // be local to each subdivision. In the present function, we don't know the
330 // rectangularness factors yet, and we can't just call GetRectangularness
331 // as that requires knowing the kernel block layout. Since we just want
332 // a coarse estimate with only the guarantee that if we return `true` then
333 // linear traversal will be used, it is OK here to over-estimate `rows` and
334 // `cols`, by omitting to divide them by the rectangularness factors.
335 return GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size,
336 cpu_cache_params) == BlockMapTraversalOrder::kLinear;
337 }
338
// Computes the block decomposition of a rows x cols destination matrix and
// writes it to *block_map. The written fields are:
//   - dims and kernel_dims;
//   - the rectangularness subdivision and num_blocks_base_log2;
//   - the small block sizes and the number of large blocks per side;
//   - the traversal order;
//   - a thread_count no larger than tentative_thread_count or NumBlocks().
//
// Preconditions (DCHECKed below): rows >= kernel_rows and cols >= kernel_cols,
// and rows/cols are multiples of kernel_rows/kernel_cols. kernel_rows and
// kernel_cols are passed to pot_log2 and round_*_pot, so presumably they must
// be powers of two. NOTE(review): confirm against size_util.h.
void MakeBlockMap(int rows, int cols, int depth, int kernel_rows,
                  int kernel_cols, int lhs_scalar_size, int rhs_scalar_size,
                  int tentative_thread_count,
                  const CpuCacheParams& cpu_cache_params, BlockMap* block_map) {
  RUY_TRACE_SCOPE;
  profiler::ScopeLabel label("MakeBlockMap");

  RUY_DCHECK_GE(rows, kernel_rows);
  RUY_DCHECK_GE(cols, kernel_cols);
  RUY_DCHECK_EQ(rows % kernel_rows, 0);
  RUY_DCHECK_EQ(cols % kernel_cols, 0);

  // Estimate the 'rectangularness', the first level of subdivision bringing
  // the shape to within 2x of a square shape.
  int rows_rectangularness_log2 = 0;
  int cols_rectangularness_log2 = 0;
  GetRectangularness(rows, cols, kernel_rows, kernel_cols,
                     &rows_rectangularness_log2, &cols_rectangularness_log2);

  const int kernel_rows_log2 = pot_log2(kernel_rows);
  const int kernel_cols_log2 = pot_log2(kernel_cols);
  const int kernel_size_log2 = std::max(kernel_cols_log2, kernel_rows_log2);

  const int size = std::min(rows, cols);
  const int size_log2 = std::max(kernel_size_log2, floor_log2(size));

  RUY_DCHECK_GE(size_log2, kernel_size_log2);

  // Heuristic selecting the power-of-two grid subdivision inside of each
  // square-ish region (past the above subdivision by 'rectangularness').
  // Note that it is the number of subdivisions, not the resulting block size,
  // that will be a power of two. But inside of that heuristic, it simplifies
  // code to talk in terms of 'block_size_log2', as if it were the block size
  // that were a power of two. This 'block_size_log2' is to be interpreted as
  // "log2 rounded below", e.g. when block_size_log2=8 we might have a block
  // size in [256, 511]. When the shape is non-square, rows!=cols, this
  // refers to the smaller of the two, so the other might be as large as
  // 1021 (can't be 1022 because following the above 'rectangularness'
  // subdivision, the aspect ratio is already < 2).

  // We are going to try candidate values for block_size_log2 ranging from
  // kernel_size_log2 to (kernel_size_log2 + kMaxKernelsPerBlockLog2), capped
  // at size_log2.
  // For each of them we will compute a 'score' by adding individual scores
  // for a few different considerations, all of which is entirely empirical.
  // The values (and possibly the logic) around here are all subject to tuning
  // based on benchmarks on different hardware. The current values are based
  // on benchmarking on Qualcomm S855 (big and little cores), arm64,
  // kNeonDotprod, 8bit quantized path. Don't read too much into it, go ahead
  // and tune this as needed to achieve good performance elsewhere. Use
  // the unit test, block_map_test, to encode values that should be preserved
  // on specific architectures. Use RUY_TRACE to debug the current heuristics
  // and RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2 to test the impact of a
  // different block_size_log2 choice, to empirically find the optimal value
  // before getting to updating the heuristic so that it produces that value.
  static constexpr int kMaxKernelsPerBlockLog2 = 6;
  const int max_block_size_log2 =
      std::min(size_log2, kernel_size_log2 + kMaxKernelsPerBlockLog2);
  int best_score = std::numeric_limits<int>::min();
  int best_score_block_size_log2 = -1;
  RUY_TRACE_INFO(MAKE_BLOCK_MAP_START);
  for (int block_size_log2 = kernel_size_log2;
       block_size_log2 <= max_block_size_log2; block_size_log2++) {
    const int multithreading_score = GetMultithreadingScore(
        block_size_log2, rows, cols, tentative_thread_count);
    const int cache_locality_score = GetCacheLocalityScore(
        block_size_log2, rows, cols, depth, kernel_rows_log2, kernel_cols_log2,
        lhs_scalar_size, rhs_scalar_size, cpu_cache_params);
    const int kernel_amortization_score = GetKernelAmortizationScore(
        block_size_log2, rows, cols, kernel_rows_log2, kernel_cols_log2);
    const int score =
        multithreading_score + cache_locality_score + kernel_amortization_score;
    // '>=' so that ties are won by the larger candidate block size.
    if (score >= best_score) {
      best_score = score;
      best_score_block_size_log2 = block_size_log2;
    }
    RUY_TRACE_INFO(MAKE_BLOCK_MAP_EACH_TENTATIVE_BLOCK_SIZE);
  }

#ifdef RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2
  // Useful for tuning.
  best_score_block_size_log2 = RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2;
#endif

  // As explained in the above comment, phrasing the above code in terms of
  // block_size_log2 was only convenience inside of that heuristic. Now we
  // revert to talking in terms of grid subdivision. That is what will actually
  // be powers of two.
  int num_blocks_base_log2 = size_log2 - best_score_block_size_log2;
  RUY_DCHECK_GE(num_blocks_base_log2, 0);
  const int num_blocks_of_rows_log2 =
      num_blocks_base_log2 + rows_rectangularness_log2;
  const int num_blocks_of_cols_log2 =
      num_blocks_base_log2 + cols_rectangularness_log2;

  // Now that we know the grid subdivision, we can pinpoint the exact block
  // sizes. They can't be powers of two in general; they can't even be all
  // equal in general; so the following few parameters will govern how blocks
  // of slightly different shapes are put together in the block map.
  // Each side has 2^num_blocks_of_*_log2 blocks: every block is at least
  // small_block_* long, and the first *_of_large_blocks of them are one
  // kernel longer, so that together they cover the whole dimension.
  const int small_block_rows =
      round_down_pot(rows >> num_blocks_of_rows_log2, kernel_rows);
  const int small_block_cols =
      round_down_pot(cols >> num_blocks_of_cols_log2, kernel_cols);
  const int rows_of_large_blocks =
      round_up_pot(rows - (small_block_rows << num_blocks_of_rows_log2),
                   kernel_rows) >>
      pot_log2(kernel_rows);
  const int cols_of_large_blocks =
      round_up_pot(cols - (small_block_cols << num_blocks_of_cols_log2),
                   kernel_cols) >>
      pot_log2(kernel_cols);

  // We have everything! Write out to the destination block_map.
  block_map->dims[Side::kLhs] = rows;
  block_map->dims[Side::kRhs] = cols;
  block_map->kernel_dims[Side::kLhs] = kernel_rows;
  block_map->kernel_dims[Side::kRhs] = kernel_cols;
  block_map->num_blocks_base_log2 = num_blocks_base_log2;
  block_map->rectangularness_log2[Side::kLhs] = rows_rectangularness_log2;
  block_map->rectangularness_log2[Side::kRhs] = cols_rectangularness_log2;
  block_map->small_block_dims[Side::kLhs] = small_block_rows;
  block_map->small_block_dims[Side::kRhs] = small_block_cols;
  block_map->large_blocks[Side::kLhs] = rows_of_large_blocks;
  block_map->large_blocks[Side::kRhs] = cols_of_large_blocks;
  // The traversal order is chosen from the dimensions of a single
  // rectangularness subdivision. GetBlockByIndex applies the non-linear
  // traversal orders only within each such square-ish region, so the
  // relevant working set is that of one region.
  block_map->traversal_order = GetTraversalOrder(
      rows >> rows_rectangularness_log2, cols >> cols_rectangularness_log2,
      depth, lhs_scalar_size, rhs_scalar_size, cpu_cache_params);
  // Done last: NumBlocks needs some of the block_map fields to be already set.
  block_map->thread_count =
      std::min(tentative_thread_count, NumBlocks(*block_map));
  RUY_TRACE_INFO(MAKE_BLOCK_MAP_END);
}
472
GetBlockMatrixCoords(Side side,const BlockMap & block_map,int block,int * start,int * end)473 void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block,
474 int* start, int* end) {
475 profiler::ScopeLabel label("GetBlockMatrixCoords");
476 *start = block * block_map.small_block_dims[side] +
477 std::min(block, block_map.large_blocks[side]) *
478 block_map.kernel_dims[side];
479 *end =
480 *start + block_map.small_block_dims[side] +
481 (block < block_map.large_blocks[side] ? block_map.kernel_dims[side] : 0);
482
483 RUY_DCHECK_EQ(0, *start % block_map.kernel_dims[side]);
484 RUY_DCHECK_EQ(0, *end % block_map.kernel_dims[side]);
485 RUY_DCHECK_LE(*end, block_map.dims[side]);
486 RUY_DCHECK_LT(*start, *end);
487 RUY_DCHECK_GE(*start, 0);
488 }
489
GetBlockMatrixCoords(const BlockMap & block_map,const SidePair<int> & block,SidePair<int> * start,SidePair<int> * end)490 void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair<int>& block,
491 SidePair<int>* start, SidePair<int>* end) {
492 for (Side side : {Side::kLhs, Side::kRhs}) {
493 GetBlockMatrixCoords(side, block_map, block[side], &(*start)[side],
494 &(*end)[side]);
495 }
496 }
497
498 } // namespace ruy
499