/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "ruy/block_map.h"

#include <algorithm>
#include <cstdint>
#include <limits>

#ifdef RUY_MAKEBLOCKMAP_DEBUG
#include <cstdio>
#include <cstdlib>
#include <string>
#endif

#include "ruy/check_macros.h"
#include "ruy/opt_set.h"
#include "ruy/profiler/instrumentation.h"
#include "ruy/size_util.h"
#include "ruy/trace.h"

namespace ruy {

namespace {

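// Worked example for the linear decoding below: with size_log2 = 2 and
// square_index = 6 = 0b0110, kLhs = 6 & 0b11 = 2 and kRhs = 6 >> 2 = 1, so
// linear traversal advances fastest along the LHS dimension.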
void DecodeTraversalLinear(int size_log2, std::uint32_t square_index,
                           SidePair<int>* local_pos) {
  (*local_pos)[Side::kLhs] = square_index & ((1 << size_log2) - 1);
  (*local_pos)[Side::kRhs] = square_index >> size_log2;
}

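// The bit-swizzles below de-interleave a Morton (z-order) code: the even bits
// of square_index become the kLhs coordinate and the odd bits the kRhs
// coordinate. Worked example: square_index = 13 = 0b1101 decodes to
// kLhs = 0b11 = 3 and kRhs = 0b10 = 2.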
void DecodeTraversalFractalZ(std::uint32_t square_index,
                             SidePair<int>* local_pos) {
  const std::uint32_t n1 = square_index;
  const std::uint32_t n2 = (n1 & 0x99999999u) | ((n1 & 0x44444444u) >> 1) |
                           ((n1 & 0x22222222u) << 1);
  const std::uint32_t n4 = (n2 & 0xc3c3c3c3u) | ((n2 & 0x30303030u) >> 2) |
                           ((n2 & 0x0c0c0c0cu) << 2);
  const std::uint32_t n8 = (n4 & 0xf00ff00fu) | ((n4 & 0x0f000f00u) >> 4) |
                           ((n4 & 0x00f000f0u) << 4);
  const std::uint32_t n16 = (n8 & 0xff0000ffu) | ((n8 & 0x00ff0000u) >> 8) |
                            ((n8 & 0x0000ff00u) << 8);
  (*local_pos)[Side::kLhs] = n16 & 0xffff;
  (*local_pos)[Side::kRhs] = n16 >> 16;
}

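// Worked example for the z-to-u fixup below: on a 2x2 square, z-order visits
// (kLhs, kRhs) = (0,0), (1,0), (0,1), (1,1); XORing kLhs with kRhs yields the
// u-order (0,0), (1,0), (1,1), (0,1), in which consecutive blocks are always
// adjacent.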
void DecodeTraversalFractalU(std::uint32_t square_index,
                             SidePair<int>* local_pos) {
  DecodeTraversalFractalZ(square_index, local_pos);
  // Change fractal z-order to u-order
  (*local_pos)[Side::kLhs] ^= (*local_pos)[Side::kRhs];
}

// Code inspired by the sample code in
// https://en.wikipedia.org/wiki/Hilbert_curve
// The main optimization is to avoid hard-to-predict conditional branches
// based on the bits of the square_index parameter.
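// Each loop iteration consumes two bits of t, selecting a quadrant at scale s
// and applying the corresponding rotation/reflection to (x, y). Worked
// example: with size_log2 = 1, square_index = 0, 1, 2, 3 decodes to
// (kLhs, kRhs) = (0,0), (1,0), (1,1), (0,1), the first-order Hilbert curve.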
void DecodeTraversalFractalHilbert(int size_log2, std::uint32_t square_index,
                                   SidePair<int>* local_pos) {
  std::uint32_t t = square_index;
  std::uint32_t x = 0;
  std::uint32_t y = 0;
  // Easy-to-predict for loop, the number of iterations is the same for
  // an entire GEMM.
  for (int sb = 0; sb < size_log2; sb++) {
    std::uint32_t s = 1 << sb;
    bool rx = t & 2;
    bool ry = (t & 1) ^ rx;
    std::uint32_t tmp = rx ? (s - 1 - x) : x;
    x = ry ? x : rx ? (s - 1 - y) : y;
    y = ry ? (y + s) : tmp;
    x = rx ? (x + s) : x;
    t >>= 2;
  }
  (*local_pos)[Side::kLhs] = y;
  (*local_pos)[Side::kRhs] = x;
}

}  // end anonymous namespace

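// The low 2 * num_blocks_base_log2 bits of `index` select a position on the
// fractal curve within one square-ish sub-grid; the remaining high bits
// select which sub-grid (one per rectangularness subdivision) that curve is
// instantiated in.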
void GetBlockByIndex(const BlockMap& block_map, int index,
                     SidePair<int>* block) {
  profiler::ScopeLabel label("GetBlockByIndex");
  const std::uint32_t index_u32 = index;

  const std::uint32_t num_blocks_per_local_curve =
      1u << (2 * block_map.num_blocks_base_log2);
  const std::uint32_t square_index =
      index_u32 & (num_blocks_per_local_curve - 1);

  const int size_log2 = block_map.num_blocks_base_log2;
  SidePair<int> local_pos;
  switch (block_map.traversal_order) {
    case BlockMapTraversalOrder::kFractalZ:
      DecodeTraversalFractalZ(square_index, &local_pos);
      break;
    case BlockMapTraversalOrder::kFractalU:
      DecodeTraversalFractalU(square_index, &local_pos);
      break;
    case BlockMapTraversalOrder::kFractalHilbert:
      DecodeTraversalFractalHilbert(size_log2, square_index, &local_pos);
      break;
    default:
      RUY_DCHECK(block_map.traversal_order == BlockMapTraversalOrder::kLinear);
      DecodeTraversalLinear(size_log2, square_index, &local_pos);
      break;
  }

  const std::uint32_t rectangular_index =
      index_u32 >> 2 * block_map.num_blocks_base_log2;
  for (Side side : {Side::kLhs, Side::kRhs}) {
    const std::uint32_t mask = (1u << block_map.rectangularness_log2[side]) - 1;
    const int rectangular_offset = (rectangular_index & mask)
                                   << block_map.num_blocks_base_log2;
    (*block)[side] = local_pos[side] + rectangular_offset;
  }
}

namespace {

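// Illustrative sizing (cache sizes hypothetical): for a 512x512x512 int8 GEMM
// (lhs_scalar_size = rhs_scalar_size = 1), working_set_size =
// (512 + 512) * 512 = 512 KiB. With a 128 KiB local cache that exceeds
// local_cache_size, so a fractal order is selected when the corresponding
// RUY_OPT flags are enabled.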
BlockMapTraversalOrder GetTraversalOrder(
    int rows_after_rectangularness_division,
    int cols_after_rectangularness_division, int depth, int lhs_scalar_size,
    int rhs_scalar_size, const CpuCacheParams& cpu_cache_params) {
  static constexpr bool kAnyFractal =
      RUY_OPT(FRACTAL_Z) | RUY_OPT(FRACTAL_U) | RUY_OPT(FRACTAL_HILBERT);
  const int working_set_size =
      (lhs_scalar_size * rows_after_rectangularness_division +
       rhs_scalar_size * cols_after_rectangularness_division) *
      depth;
  if (kAnyFractal && (working_set_size > cpu_cache_params.local_cache_size)) {
    if (RUY_OPT(FRACTAL_HILBERT) &&
        (working_set_size > cpu_cache_params.last_level_cache_size)) {
      return BlockMapTraversalOrder::kFractalHilbert;
    } else if (RUY_OPT(FRACTAL_U)) {
      return BlockMapTraversalOrder::kFractalU;
    } else {
      return BlockMapTraversalOrder::kFractalZ;
    }
  } else {
    return BlockMapTraversalOrder::kLinear;
  }
}

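// Worked example: floor_log2_quotient(1000, 3) starts from
// floor_log2(1000) - ceil_log2(3) = 9 - 2 = 7; since (3 << 8) = 768 <= 1000,
// it is bumped to 8, matching floor(log2(1000 / 3)) = floor(log2(333)) = 8.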
int floor_log2_quotient(int num, int denom) {
  if (num <= denom) {
    return 0;
  }
  int log2_quotient = floor_log2(num) - ceil_log2(denom);
  if ((denom << (log2_quotient + 1)) <= num) {
    log2_quotient++;
  }
  return log2_quotient;
}

// Computes the rectangularness of the matrix shape (rows, cols). This is
// essentially just the log2 of the quotient (rows / cols). The kernel_rows and
// kernel_cols only get into the picture for clamping bounds but don't affect
// the generic computation.
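// Worked example: rows = 1024, cols = 256, kernel_rows = kernel_cols = 8
// gives floor_log2_quotient(1024, 256) = 2, and the clamping term
// floor_log2(1024) - pot_log2(8) = 7 is not binding, so
// rows_rectangularness_log2 = 2: the shape is pre-split into 4 chunks of
// 256 rows, each exactly square.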
void GetRectangularness(int rows, int cols, int kernel_rows, int kernel_cols,
                        int* rows_rectangularness_log2,
                        int* cols_rectangularness_log2) {
  *rows_rectangularness_log2 = 0;
  *cols_rectangularness_log2 = 0;

  // In GEMV-ish cases, that is when kernel blocks are as narrow as the kernel
  // itself, we risk having too small kernel blocks for good kernel
  // amortization. We avoid that by limiting rectangularness so that kernel
  // blocks are not too tiny at least in that dimension. Specifically, we try to
  // have at least (2^min_kernel_inner_loop_runs_log2) kernels fitting in each
  // kernel block along the large dimension.
  const int min_kernel_inner_loop_runs_log2 = 3;
  if (rows > cols) {
    int cols_of_kernel_inner_loop_runs_log2 =
        ceil_log2(cols) - pot_log2(kernel_cols);
    int min_rows_of_kernel_inner_loop_runs_log2 =
        std::max(0, min_kernel_inner_loop_runs_log2 -
                        cols_of_kernel_inner_loop_runs_log2);
    *rows_rectangularness_log2 =
        std::min(floor_log2_quotient(rows, cols),
                 std::max(0, floor_log2(rows) - pot_log2(kernel_rows) -
                                 min_rows_of_kernel_inner_loop_runs_log2));
    // Sanity check that we did not over-estimate rows_rectangularness_log2.
    RUY_DCHECK_GE(rows >> *rows_rectangularness_log2, cols);
  } else if (cols > rows) {
    int rows_of_kernel_inner_loop_runs_log2 =
        ceil_log2(rows) - pot_log2(kernel_rows);
    int min_cols_of_kernel_inner_loop_runs_log2 =
        std::max(0, min_kernel_inner_loop_runs_log2 -
                        rows_of_kernel_inner_loop_runs_log2);
    *cols_rectangularness_log2 =
        std::min(floor_log2_quotient(cols, rows),
                 std::max(0, floor_log2(cols) - pot_log2(kernel_cols) -
                                 min_cols_of_kernel_inner_loop_runs_log2));
    // Sanity check that we did not over-estimate cols_rectangularness_log2.
    RUY_DCHECK_GE(cols >> *cols_rectangularness_log2, rows);
  }
  RUY_DCHECK(!*rows_rectangularness_log2 || !*cols_rectangularness_log2);
}

// Computes a 'multithreading score'. When multithreading, we need there to
// be at least as many tiles as there are threads, and hopefully
// substantially more than that, so we benefit from ruy's ability to
// dispatch fine-grained workloads to threads.
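// Illustrative values: with rows = cols = 1024, block_size_log2 = 7 (blocks
// of roughly 128x128) and tentative_thread_count = 4, there are 8 * 8 = 64
// full blocks, so blocks_per_thread_log2 = 6 - 2 = 4 and the score is +16.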
int GetMultithreadingScore(int block_size_log2, int rows, int cols,
                           int tentative_thread_count) {
  const int num_full_blocks_of_rows = rows >> block_size_log2;
  const int num_full_blocks_of_cols = cols >> block_size_log2;
  const int candidate_num_full_blocks_log2 = floor_log2(
      std::max(1, num_full_blocks_of_rows * num_full_blocks_of_cols));

  // The values here have been tuned on ARM Cortex-A55.
  // We expect this to have to be tuned differently for other CPUs.
  if (tentative_thread_count == 1) {
    return 0;
  } else {
    const int blocks_per_thread_log2 =
        candidate_num_full_blocks_log2 - ceil_log2(tentative_thread_count);
    if (blocks_per_thread_log2 < 0) {
      return -64;
    } else if (blocks_per_thread_log2 == 0) {
      return -16;
    } else if (blocks_per_thread_log2 == 1) {
      return -8;
    } else if (blocks_per_thread_log2 == 2) {
      return 0;
    } else if (blocks_per_thread_log2 == 3) {
      return 8;
    } else {
      return 16;
    }
  }
}

// Computes a 'cache locality score'.
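// Illustrative values (cache size hypothetical): with block_size_log2 = 7,
// rows = cols = 1024, depth = 512, 1-byte scalars and a 128 KiB local cache,
// total_read_bytes = (128 + 128) * 512 = 2^17 bytes, so
// nonlocality_log2 = 17 - 17 = 0 and the score is +48.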
int GetCacheLocalityScore(int block_size_log2, int rows, int cols, int depth,
                          int kernel_rows_log2, int kernel_cols_log2,
                          int lhs_scalar_size, int rhs_scalar_size,
                          const CpuCacheParams& cpu_cache_params) {
  // In the narrow case (e.g. matrix*vector), each byte of the big operand
  // matrix (either LHS or RHS) is traversed only once, so any notion of data
  // locality is irrelevant. Ignore the 'cache locality score' by forcing it to
  // be 0 in that case.
  if (rows <= (1 << kernel_rows_log2) || cols <= (1 << kernel_cols_log2)) {
    return 0;
  }
  const int block_rows = std::min(1 << block_size_log2, rows);
  const int block_cols = std::min(1 << block_size_log2, cols);
  const int total_read_bytes =
      (lhs_scalar_size * block_rows + rhs_scalar_size * block_cols) * depth;
  const int total_read_bytes_log2 = ceil_log2(total_read_bytes);
  const int nonlocality_log2 =
      total_read_bytes_log2 - floor_log2(cpu_cache_params.local_cache_size);
  // The values here have been tuned on ARM Cortex-A55.
  // We expect this to have to be tuned differently for other CPUs.
  if (nonlocality_log2 < -1) {
    return 64;
  } else if (nonlocality_log2 == -1) {
    return 56;
  } else if (nonlocality_log2 == 0) {
    return 48;
  } else if (nonlocality_log2 == 1) {
    return 32;
  } else if (nonlocality_log2 == 2) {
    return 16;
  } else if (nonlocality_log2 == 3) {
    return 0;
  } else {
    return -64;
  }
}

// Compute a 'kernel amortization score'. This is the notion that very small
// tiles result in more overhead outside of kernels, more complex memory
// access patterns and less benefits from ruy's fat kernels, so we reward
// larger blocks more than smaller ones.
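// Illustrative values: a 128x128 block with an 8x8 kernel gives
// kernels_per_block_log2 = floor_log2(128 * 128) - 3 - 3 = 8, so the score
// saturates at +64.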
int GetKernelAmortizationScore(int block_size_log2, int rows, int cols,
                               int kernel_rows_log2, int kernel_cols_log2) {
  const int block_rows = std::min(1 << block_size_log2, rows);
  const int block_cols = std::min(1 << block_size_log2, cols);
  const int kernels_per_block_log2 =
      floor_log2(block_rows * block_cols) - kernel_rows_log2 - kernel_cols_log2;
  RUY_DCHECK_GE(kernels_per_block_log2, 0);
  // The values here have been tuned on ARM Cortex-A55.
  // We expect this to have to be tuned differently for other CPUs.
  if (kernels_per_block_log2 == 0) {
    return 0;
  } else if (kernels_per_block_log2 == 1) {
    return 8;
  } else if (kernels_per_block_log2 == 2) {
    return 16;
  } else if (kernels_per_block_log2 == 3) {
    return 24;
  } else if (kernels_per_block_log2 == 4) {
    return 32;
  } else if (kernels_per_block_log2 == 5) {
    return 40;
  } else if (kernels_per_block_log2 == 6) {
    return 48;
  } else if (kernels_per_block_log2 == 7) {
    return 56;
  } else {
    return 64;
  }
}

}  // namespace

bool IsObviouslyLinearTraversal(int rows, int cols, int depth,
                                int lhs_scalar_size, int rhs_scalar_size,
                                const CpuCacheParams& cpu_cache_params) {
  if (rows == 1 || cols == 1) {
    return true;
  }
  // Normally, GetTraversalOrder wants the dimensions (rows x cols) divided
  // by the rectangularness factors, since any non-linear traversal order will
  // be local to each subdivision. In the present function, we don't know the
  // rectangularness factors yet, and we can't just call GetRectangularness
  // as that requires knowing the kernel block layout. Since we just want
  // a coarse estimate with only the guarantee that if we return `true` then
  // linear traversal will be used, it is OK here to over-estimate `rows` and
  // `cols`, by omitting to divide them by the rectangularness factors.
  return GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size,
                           cpu_cache_params) == BlockMapTraversalOrder::kLinear;
}

void MakeBlockMap(int rows, int cols, int depth, int kernel_rows,
                  int kernel_cols, int lhs_scalar_size, int rhs_scalar_size,
                  int tentative_thread_count,
                  const CpuCacheParams& cpu_cache_params, BlockMap* block_map) {
  RUY_TRACE_SCOPE;
  profiler::ScopeLabel label("MakeBlockMap");

  RUY_DCHECK_GE(rows, kernel_rows);
  RUY_DCHECK_GE(cols, kernel_cols);
  RUY_DCHECK_EQ(rows % kernel_rows, 0);
  RUY_DCHECK_EQ(cols % kernel_cols, 0);

  // Estimate the 'rectangularness', the first level of subdivision bringing
  // the shape to within 2x of a square shape.
  int rows_rectangularness_log2 = 0;
  int cols_rectangularness_log2 = 0;
  GetRectangularness(rows, cols, kernel_rows, kernel_cols,
                     &rows_rectangularness_log2, &cols_rectangularness_log2);

  const int kernel_rows_log2 = pot_log2(kernel_rows);
  const int kernel_cols_log2 = pot_log2(kernel_cols);
  const int kernel_size_log2 = std::max(kernel_cols_log2, kernel_rows_log2);

  const int size = std::min(rows, cols);
  const int size_log2 = std::max(kernel_size_log2, floor_log2(size));

  RUY_DCHECK_GE(size_log2, kernel_size_log2);

  // Heuristic selecting the power-of-two grid subdivision inside of each
  // square-ish region (past the above subdivision by 'rectangularness').
  // Note that it is the number of subdivisions, not the resulting block size,
  // that will be a power of two. But inside of that heuristic, it simplifies
  // code to talk in terms of 'block_size_log2', as if it were the block size
  // that were a power of two. This 'block_size_log2' is to be interpreted as
  // "log2 rounded below", e.g. when block_size_log2=8 we might have a block
  // size in [256, 511]. When the shape is non-square, rows!=cols, this
  // refers to the smaller of the two, so the other might be as large as
  // 1021 (can't be 1022 because following the above 'rectangularness'
  // subdivision, the aspect ratio is already < 2).

  // We are going to try candidate values for block_size_log2 ranging from
  // kernel_size_log2 to (kernel_size_log2 + kMaxKernelsPerBlockLog2).
  // For each of them we will compute a 'score' by adding individual scores
  // for a few different considerations, all of which are entirely empirical.
  // The values (and possibly the logic) around here are all subject to tuning
  // based on benchmarks on different hardware. The current values are based
  // on benchmarking on Qualcomm S855 (big and little cores), arm64,
  // kNeonDotprod, 8bit quantized path. Don't read too much into it, go ahead
  // and tune this as needed to achieve good performance elsewhere. Use
  // the unit test, block_map_test, to encode values that should be preserved
  // on specific architectures. Use RUY_TRACE to debug the current heuristics
  // and RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2 to test the impact of a
  // different block_size_log2 choice, to empirically find the optimal value
  // before getting to updating the heuristic so that it produces that value.
  static constexpr int kMaxKernelsPerBlockLog2 = 6;
  const int max_block_size_log2 =
      std::min(size_log2, kernel_size_log2 + kMaxKernelsPerBlockLog2);
  int best_score = std::numeric_limits<int>::min();
  int best_score_block_size_log2 = -1;
  RUY_TRACE_INFO(MAKE_BLOCK_MAP_START);
  for (int block_size_log2 = kernel_size_log2;
       block_size_log2 <= max_block_size_log2; block_size_log2++) {
    const int multithreading_score = GetMultithreadingScore(
        block_size_log2, rows, cols, tentative_thread_count);
    const int cache_locality_score = GetCacheLocalityScore(
        block_size_log2, rows, cols, depth, kernel_rows_log2, kernel_cols_log2,
        lhs_scalar_size, rhs_scalar_size, cpu_cache_params);
    const int kernel_amortization_score = GetKernelAmortizationScore(
        block_size_log2, rows, cols, kernel_rows_log2, kernel_cols_log2);
    const int score =
        multithreading_score + cache_locality_score + kernel_amortization_score;
    if (score >= best_score) {
      best_score = score;
      best_score_block_size_log2 = block_size_log2;
    }
    RUY_TRACE_INFO(MAKE_BLOCK_MAP_EACH_TENTATIVE_BLOCK_SIZE);
  }

#ifdef RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2
  // Useful for tuning.
  best_score_block_size_log2 = RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2;
#endif

  // As explained in the above comment, phrasing the above code in terms of
  // block_size_log2 was only convenience inside of that heuristic. Now we
  // revert to talking in terms of grid subdivision. That is what will actually
  // be powers of two.
  int num_blocks_base_log2 = size_log2 - best_score_block_size_log2;
  RUY_DCHECK_GE(num_blocks_base_log2, 0);
  const int num_blocks_of_rows_log2 =
      num_blocks_base_log2 + rows_rectangularness_log2;
  const int num_blocks_of_cols_log2 =
      num_blocks_base_log2 + cols_rectangularness_log2;

  // Now that we know the grid subdivision, we can pinpoint the exact block
  // sizes. They can't be powers of two in general; they can't even be all
  // equal in general; so the following few parameters will govern how blocks
  // of slightly different shapes are put together in the block map.
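  // Worked example: rows = 104, kernel_rows = 8, num_blocks_of_rows_log2 = 2
  // gives small_block_rows = round_down_pot(104 >> 2, 8) = 24 and
  // rows_of_large_blocks = round_up_pot(104 - 24 * 4, 8) / 8 = 1: one block
  // of 32 rows followed by three blocks of 24 rows, totaling 104.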
  const int small_block_rows =
      round_down_pot(rows >> num_blocks_of_rows_log2, kernel_rows);
  const int small_block_cols =
      round_down_pot(cols >> num_blocks_of_cols_log2, kernel_cols);
  const int rows_of_large_blocks =
      round_up_pot(rows - (small_block_rows << num_blocks_of_rows_log2),
                   kernel_rows) >>
      pot_log2(kernel_rows);
  const int cols_of_large_blocks =
      round_up_pot(cols - (small_block_cols << num_blocks_of_cols_log2),
                   kernel_cols) >>
      pot_log2(kernel_cols);

  // We have everything! Write out to the destination block_map.
  block_map->dims[Side::kLhs] = rows;
  block_map->dims[Side::kRhs] = cols;
  block_map->kernel_dims[Side::kLhs] = kernel_rows;
  block_map->kernel_dims[Side::kRhs] = kernel_cols;
  block_map->num_blocks_base_log2 = num_blocks_base_log2;
  block_map->rectangularness_log2[Side::kLhs] = rows_rectangularness_log2;
  block_map->rectangularness_log2[Side::kRhs] = cols_rectangularness_log2;
  block_map->small_block_dims[Side::kLhs] = small_block_rows;
  block_map->small_block_dims[Side::kRhs] = small_block_cols;
  block_map->large_blocks[Side::kLhs] = rows_of_large_blocks;
  block_map->large_blocks[Side::kRhs] = cols_of_large_blocks;
  // See the comment on GetTraversalOrder for why we are dividing `rows` and
  // `cols` by the rectangularness subdivision parameters here.
  block_map->traversal_order = GetTraversalOrder(
      rows >> rows_rectangularness_log2, cols >> cols_rectangularness_log2,
      depth, lhs_scalar_size, rhs_scalar_size, cpu_cache_params);
  // Done last: NumBlocks needs some of the block_map fields to be already set.
  block_map->thread_count =
      std::min(tentative_thread_count, NumBlocks(*block_map));
  RUY_TRACE_INFO(MAKE_BLOCK_MAP_END);
}

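// Worked example: with small_block_dims[side] = 24, kernel_dims[side] = 8,
// large_blocks[side] = 1 and 4 blocks along `side`, blocks 0..3 map to the
// ranges [0, 32), [32, 56), [56, 80), [80, 104).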
void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block,
                          int* start, int* end) {
  profiler::ScopeLabel label("GetBlockMatrixCoords");
  *start = block * block_map.small_block_dims[side] +
           std::min(block, block_map.large_blocks[side]) *
               block_map.kernel_dims[side];
  *end =
      *start + block_map.small_block_dims[side] +
      (block < block_map.large_blocks[side] ? block_map.kernel_dims[side] : 0);

  RUY_DCHECK_EQ(0, *start % block_map.kernel_dims[side]);
  RUY_DCHECK_EQ(0, *end % block_map.kernel_dims[side]);
  RUY_DCHECK_LE(*end, block_map.dims[side]);
  RUY_DCHECK_LT(*start, *end);
  RUY_DCHECK_GE(*start, 0);
}

void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair<int>& block,
                          SidePair<int>* start, SidePair<int>* end) {
  for (Side side : {Side::kLhs, Side::kRhs}) {
    GetBlockMatrixCoords(side, block_map, block[side], &(*start)[side],
                         &(*end)[side]);
  }
}

}  // namespace ruy