/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "ruy/block_map.h"

#include <algorithm>
#include <cstdint>
#include <limits>

#ifdef RUY_MAKEBLOCKMAP_DEBUG
#include <cstdio>
#include <cstdlib>
#include <string>
#endif

#include "ruy/check_macros.h"
#include "ruy/opt_set.h"
#include "ruy/profiler/instrumentation.h"
#include "ruy/size_util.h"
#include "ruy/trace.h"

namespace ruy {

namespace {

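// Decodes a linear traversal: the low `size_log2` bits of square_index give
// the position along the LHS side and the remaining high bits give the
// position along the RHS side. Worked example (illustrative values): with
// size_log2 = 2 and square_index = 0b0110, this yields
// local_pos = {kLhs: 0b10 = 2, kRhs: 0b01 = 1}.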
void DecodeTraversalLinear(int size_log2, std::uint32_t square_index,
                           SidePair<int>* local_pos) {
  (*local_pos)[Side::kLhs] = square_index & ((1 << size_log2) - 1);
  (*local_pos)[Side::kRhs] = square_index >> size_log2;
}

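// Decodes a fractal Z-order (Morton) traversal by de-interleaving the bits of
// square_index: even-position bits go to the LHS coordinate and odd-position
// bits go to the RHS coordinate, using branch-free mask-and-shift passes.
// Worked example (illustrative value): square_index = 13 = 0b1101 has even
// bits {1, 1} and odd bits {0, 1}, so local_pos = {kLhs: 3, kRhs: 2}.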
void DecodeTraversalFractalZ(std::uint32_t square_index,
                             SidePair<int>* local_pos) {
  const std::uint32_t n1 = square_index;
  const std::uint32_t n2 = (n1 & 0x99999999u) | ((n1 & 0x44444444u) >> 1) |
                           ((n1 & 0x22222222u) << 1);
  const std::uint32_t n4 = (n2 & 0xc3c3c3c3u) | ((n2 & 0x30303030u) >> 2) |
                           ((n2 & 0x0c0c0c0cu) << 2);
  const std::uint32_t n8 = (n4 & 0xf00ff00fu) | ((n4 & 0x0f000f00u) >> 4) |
                           ((n4 & 0x00f000f0u) << 4);
  const std::uint32_t n16 = (n8 & 0xff0000ffu) | ((n8 & 0x00ff0000u) >> 8) |
                            ((n8 & 0x0000ff00u) << 8);
  (*local_pos)[Side::kLhs] = n16 & 0xffff;
  (*local_pos)[Side::kRhs] = n16 >> 16;
}

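// Same as DecodeTraversalFractalZ, except that XOR-ing the LHS coordinate
// with the RHS coordinate turns each 2x2 'Z' into a 'U', so consecutive
// indices within each 2x2 square map to adjacent positions. For example,
// indices 0..3 map to (lhs, rhs) = (0, 0), (1, 0), (1, 1), (0, 1) instead of
// the Z-order (0, 0), (1, 0), (0, 1), (1, 1).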
void DecodeTraversalFractalU(std::uint32_t square_index,
                             SidePair<int>* local_pos) {
  DecodeTraversalFractalZ(square_index, local_pos);
  // Change fractal z-order to u-order
  (*local_pos)[Side::kLhs] ^= (*local_pos)[Side::kRhs];
}

// Code inspired by the sample code in
//   https://en.wikipedia.org/wiki/Hilbert_curve
// The main optimization is to avoid hard-to-predict conditional branches
// based on the bits of the square_index parameter.
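// For the base case of a 2x2 square (size_log2 = 1), square_index = 0..3 maps
// to (lhs, rhs) = (0, 0), (1, 0), (1, 1), (0, 1); larger squares recursively
// rotate and reflect this pattern so that consecutive blocks stay adjacent.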
void DecodeTraversalFractalHilbert(int size_log2, std::uint32_t square_index,
                                   SidePair<int>* local_pos) {
  std::uint32_t t = square_index;
  std::uint32_t x = 0;
  std::uint32_t y = 0;
  // Easy-to-predict for loop, the number of iterations is the same for
  // an entire GEMM.
  for (int sb = 0; sb < size_log2; sb++) {
    std::uint32_t s = 1 << sb;
    bool rx = t & 2;
    bool ry = (t & 1) ^ rx;
    std::uint32_t tmp = rx ? (s - 1 - x) : x;
    x = ry ? x : rx ? (s - 1 - y) : y;
    y = ry ? (y + s) : tmp;
    x = rx ? (x + s) : x;
    t >>= 2;
  }
  (*local_pos)[Side::kLhs] = y;
  (*local_pos)[Side::kRhs] = x;
}

}  // end anonymous namespace

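// Maps a flat block index to a (LHS, RHS) block position. The low
// 2 * num_blocks_base_log2 bits index into the square-ish local fractal (or
// linear) curve; the remaining high bits select which repetition along the
// 'rectangularness' subdivision of the longer dimension the block falls in.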
void GetBlockByIndex(const BlockMap& block_map, int index,
                     SidePair<int>* block) {
  profiler::ScopeLabel label("GetBlockByIndex");
  const std::uint32_t index_u32 = index;

  const std::uint32_t num_blocks_per_local_curve =
      1u << (2 * block_map.num_blocks_base_log2);
  const std::uint32_t square_index =
      index_u32 & (num_blocks_per_local_curve - 1);

  const int size_log2 = block_map.num_blocks_base_log2;
  SidePair<int> local_pos;
  switch (block_map.traversal_order) {
    case BlockMapTraversalOrder::kFractalZ:
      DecodeTraversalFractalZ(square_index, &local_pos);
      break;
    case BlockMapTraversalOrder::kFractalU:
      DecodeTraversalFractalU(square_index, &local_pos);
      break;
    case BlockMapTraversalOrder::kFractalHilbert:
      DecodeTraversalFractalHilbert(size_log2, square_index, &local_pos);
      break;
    default:
      RUY_DCHECK(block_map.traversal_order == BlockMapTraversalOrder::kLinear);
      DecodeTraversalLinear(size_log2, square_index, &local_pos);
      break;
  }

  const std::uint32_t rectangular_index =
      index_u32 >> 2 * block_map.num_blocks_base_log2;
  for (Side side : {Side::kLhs, Side::kRhs}) {
    const std::uint32_t mask = (1u << block_map.rectangularness_log2[side]) - 1;
    const int rectangular_offset = (rectangular_index & mask)
                                   << block_map.num_blocks_base_log2;
    (*block)[side] = local_pos[side] + rectangular_offset;
  }
}

namespace {

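// Selects the traversal order for one square-ish region: fractal orders are
// only worthwhile when the data traversed by that region (the LHS rows plus
// RHS cols, times depth) exceeds the local CPU cache, and the Hilbert order is
// reserved for working sets that also exceed the last-level cache. Rough
// example (hypothetical cache sizes): with 1-byte scalars, 1024 rows, 1024
// cols and depth 1024, the working set is (1024 + 1024) * 1024 = 2 MiB, so a
// 128 KiB local cache would already trigger a fractal order here (provided
// the corresponding RUY_OPT fractal flags are enabled).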
BlockMapTraversalOrder GetTraversalOrder(
    int rows_after_rectangularness_division,
    int cols_after_rectangularness_division, int depth, int lhs_scalar_size,
    int rhs_scalar_size, const CpuCacheParams& cpu_cache_params) {
  static constexpr bool kAnyFractal =
      RUY_OPT(FRACTAL_Z) | RUY_OPT(FRACTAL_U) | RUY_OPT(FRACTAL_HILBERT);
  const int working_set_size =
      (lhs_scalar_size * rows_after_rectangularness_division +
       rhs_scalar_size * cols_after_rectangularness_division) *
      depth;
  if (kAnyFractal && (working_set_size > cpu_cache_params.local_cache_size)) {
    if (RUY_OPT(FRACTAL_HILBERT) &&
        (working_set_size > cpu_cache_params.last_level_cache_size)) {
      return BlockMapTraversalOrder::kFractalHilbert;
    } else if (RUY_OPT(FRACTAL_U)) {
      return BlockMapTraversalOrder::kFractalU;
    } else {
      return BlockMapTraversalOrder::kFractalZ;
    }
  } else {
    return BlockMapTraversalOrder::kLinear;
  }
}

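// Returns floor(log2(num / denom)) clamped below at 0, without doing an
// actual division. For example, floor_log2_quotient(20, 3) == 2, since
// 3 * 4 <= 20 < 3 * 8.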
int floor_log2_quotient(int num, int denom) {
  if (num <= denom) {
    return 0;
  }
  int log2_quotient = floor_log2(num) - ceil_log2(denom);
  if ((denom << (log2_quotient + 1)) <= num) {
    log2_quotient++;
  }
  return log2_quotient;
}

// Computes the rectangularness of the matrix shape (rows, cols). This is
// essentially just the log2 of the quotient (rows / cols). The kernel_rows and
// kernel_cols only get into the picture for clamping bounds but don't affect
// the generic computation.
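// For example (illustrative shapes), with rows = 1024, cols = 64 and an 8x8
// kernel, floor_log2_quotient(1024, 64) == 4, and the clamp
// floor_log2(1024) - pot_log2(8) - 0 == 7 does not bind, so
// rows_rectangularness_log2 == 4: the rows are first split into 2^4 = 16
// chunks of 64 rows, each chunk forming a square-ish 64x64 region.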
void GetRectangularness(int rows, int cols, int kernel_rows, int kernel_cols,
                        int* rows_rectangularness_log2,
                        int* cols_rectangularness_log2) {
  *rows_rectangularness_log2 = 0;
  *cols_rectangularness_log2 = 0;

  // In GEMV-ish cases, that is when kernel blocks are as narrow as the kernel
  // itself, we risk having too small kernel blocks for good kernel
  // amortization. We avoid that by limiting rectangularness so that kernel
  // blocks are not too tiny at least in that dimension. Specifically, we try to
  // have at least (2^min_kernel_inner_loop_runs_log2) kernels fitting in each
  // kernel block along the large dimension.
  const int min_kernel_inner_loop_runs_log2 = 3;
  if (rows > cols) {
    int cols_of_kernel_inner_loop_runs_log2 =
        ceil_log2(cols) - pot_log2(kernel_cols);
    int min_rows_of_kernel_inner_loop_runs_log2 =
        std::max(0, min_kernel_inner_loop_runs_log2 -
                        cols_of_kernel_inner_loop_runs_log2);
    *rows_rectangularness_log2 =
        std::min(floor_log2_quotient(rows, cols),
                 std::max(0, floor_log2(rows) - pot_log2(kernel_rows) -
                                 min_rows_of_kernel_inner_loop_runs_log2));
    // Sanity check that we did not over-estimate rows_rectangularness_log2.
    RUY_DCHECK_GE(rows >> *rows_rectangularness_log2, cols);
  } else if (cols > rows) {
    int rows_of_kernel_inner_loop_runs_log2 =
        ceil_log2(rows) - pot_log2(kernel_rows);
    int min_cols_of_kernel_inner_loop_runs_log2 =
        std::max(0, min_kernel_inner_loop_runs_log2 -
                        rows_of_kernel_inner_loop_runs_log2);
    *cols_rectangularness_log2 =
        std::min(floor_log2_quotient(cols, rows),
                 std::max(0, floor_log2(cols) - pot_log2(kernel_cols) -
                                 min_cols_of_kernel_inner_loop_runs_log2));
    // Sanity check that we did not over-estimate cols_rectangularness_log2.
    RUY_DCHECK_GE(cols >> *cols_rectangularness_log2, rows);
  }
  RUY_DCHECK(!*rows_rectangularness_log2 || !*cols_rectangularness_log2);
}

// Computes a 'multithreading score'. When multithreading, we need there to
// be at least as many tiles as there are threads, and hopefully
// substantially more than that, so we benefit from ruy's ability to
// dispatch fine-grained workloads to threads.
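// For example (illustrative values), with rows = cols = 512 and
// block_size_log2 = 6 there are 8 * 8 = 64 full blocks; with
// tentative_thread_count = 4 that is 2^4 blocks per thread
// (blocks_per_thread_log2 = 4), scoring +16.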
int GetMultithreadingScore(int block_size_log2, int rows, int cols,
                           int tentative_thread_count) {
  const int num_full_blocks_of_rows = rows >> block_size_log2;
  const int num_full_blocks_of_cols = cols >> block_size_log2;
  const int candidate_num_full_blocks_log2 = floor_log2(
      std::max(1, num_full_blocks_of_rows * num_full_blocks_of_cols));

  // The values here have been tuned on ARM Cortex-A55.
  // We expect this to have to be tuned differently for other CPUs.
  if (tentative_thread_count == 1) {
    return 0;
  } else {
    const int blocks_per_thread_log2 =
        candidate_num_full_blocks_log2 - ceil_log2(tentative_thread_count);
    if (blocks_per_thread_log2 < 0) {
      return -64;
    } else if (blocks_per_thread_log2 == 0) {
      return -16;
    } else if (blocks_per_thread_log2 == 1) {
      return -8;
    } else if (blocks_per_thread_log2 == 2) {
      return 0;
    } else if (blocks_per_thread_log2 == 3) {
      return 8;
    } else {
      return 16;
    }
  }
}

// Computes a 'cache locality score'.
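// For example (illustrative values), with 1-byte scalars, block_size_log2 = 7,
// depth = 512 and rows, cols larger than the block, a block traverses
// (128 + 128) * 512 = 128 KiB of operand data; against a hypothetical 256 KiB
// local cache this gives nonlocality_log2 = -1 and a score of 56.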
int GetCacheLocalityScore(int block_size_log2, int rows, int cols, int depth,
                          int kernel_rows_log2, int kernel_cols_log2,
                          int lhs_scalar_size, int rhs_scalar_size,
                          const CpuCacheParams& cpu_cache_params) {
  // In the narrow case (e.g. matrix*vector), each byte of the big operand
  // matrix (either LHS or RHS) is traversed only once, so any notion of data
  // locality is irrelevant. Ignore the 'cache locality score' by forcing it to
  // be 0 in that case.
  if (rows <= (1 << kernel_rows_log2) || cols <= (1 << kernel_cols_log2)) {
    return 0;
  }
  const int block_rows = std::min(1 << block_size_log2, rows);
  const int block_cols = std::min(1 << block_size_log2, cols);
  const int total_read_bytes =
      (lhs_scalar_size * block_rows + rhs_scalar_size * block_cols) * depth;
  const int total_read_bytes_log2 = ceil_log2(total_read_bytes);
  const int nonlocality_log2 =
      total_read_bytes_log2 - floor_log2(cpu_cache_params.local_cache_size);
  // The values here have been tuned on ARM Cortex-A55.
  // We expect this to have to be tuned differently for other CPUs.
  if (nonlocality_log2 < -1) {
    return 64;
  } else if (nonlocality_log2 == -1) {
    return 56;
  } else if (nonlocality_log2 == 0) {
    return 48;
  } else if (nonlocality_log2 == 1) {
    return 32;
  } else if (nonlocality_log2 == 2) {
    return 16;
  } else if (nonlocality_log2 == 3) {
    return 0;
  } else {
    return -64;
  }
}

// Compute a 'kernel amortization score'. This is the notion that very small
// tiles result in more overhead outside of kernels, more complex memory
// access patterns and less benefits from ruy's fat kernels, so we reward
// larger blocks more than smaller ones.
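// For example (illustrative values), a 64x64 block (block_size_log2 = 6 with
// rows, cols >= 64) covers 64 * 64 = 2^12 destination entries; with an 8x8
// kernel (kernel_rows_log2 = kernel_cols_log2 = 3) that is 2^6 kernel
// invocations per block, scoring +48.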
int GetKernelAmortizationScore(int block_size_log2, int rows, int cols,
                               int kernel_rows_log2, int kernel_cols_log2) {
  const int block_rows = std::min(1 << block_size_log2, rows);
  const int block_cols = std::min(1 << block_size_log2, cols);
  const int kernels_per_block_log2 =
      floor_log2(block_rows * block_cols) - kernel_rows_log2 - kernel_cols_log2;
  RUY_DCHECK_GE(kernels_per_block_log2, 0);
  // The values here have been tuned on ARM Cortex-A55.
  // We expect this to have to be tuned differently for other CPUs.
  if (kernels_per_block_log2 == 0) {
    return 0;
  } else if (kernels_per_block_log2 == 1) {
    return 8;
  } else if (kernels_per_block_log2 == 2) {
    return 16;
  } else if (kernels_per_block_log2 == 3) {
    return 24;
  } else if (kernels_per_block_log2 == 4) {
    return 32;
  } else if (kernels_per_block_log2 == 5) {
    return 40;
  } else if (kernels_per_block_log2 == 6) {
    return 48;
  } else if (kernels_per_block_log2 == 7) {
    return 56;
  } else {
    return 64;
  }
}

}  // namespace

bool IsObviouslyLinearTraversal(int rows, int cols, int depth,
                                int lhs_scalar_size, int rhs_scalar_size,
                                const CpuCacheParams& cpu_cache_params) {
  if (rows == 1 || cols == 1) {
    return true;
  }
  // Normally, GetTraversalOrder wants the dimensions (rows x cols) divided
  // by the rectangularness factors, since any non-linear traversal order will
  // be local to each subdivision. In the present function, we don't know the
  // rectangularness factors yet, and we can't just call GetRectangularness
  // as that requires knowing the kernel block layout. Since we just want
  // a coarse estimate with only the guarantee that if we return `true` then
  // linear traversal will be used, it is OK here to over-estimate `rows` and
  // `cols`, by omitting to divide them by the rectangularness factors.
  return GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size,
                           cpu_cache_params) == BlockMapTraversalOrder::kLinear;
}

void MakeBlockMap(int rows, int cols, int depth, int kernel_rows,
                  int kernel_cols, int lhs_scalar_size, int rhs_scalar_size,
                  int tentative_thread_count,
                  const CpuCacheParams& cpu_cache_params, BlockMap* block_map) {
  RUY_TRACE_SCOPE;
  profiler::ScopeLabel label("MakeBlockMap");

  RUY_DCHECK_GE(rows, kernel_rows);
  RUY_DCHECK_GE(cols, kernel_cols);
  RUY_DCHECK_EQ(rows % kernel_rows, 0);
  RUY_DCHECK_EQ(cols % kernel_cols, 0);

  // Estimate the 'rectangularness', the first level of subdivision bringing
  // the shape to within 2x of a square shape.
  int rows_rectangularness_log2 = 0;
  int cols_rectangularness_log2 = 0;
  GetRectangularness(rows, cols, kernel_rows, kernel_cols,
                     &rows_rectangularness_log2, &cols_rectangularness_log2);

  const int kernel_rows_log2 = pot_log2(kernel_rows);
  const int kernel_cols_log2 = pot_log2(kernel_cols);
  const int kernel_size_log2 = std::max(kernel_cols_log2, kernel_rows_log2);

  const int size = std::min(rows, cols);
  const int size_log2 = std::max(kernel_size_log2, floor_log2(size));

  RUY_DCHECK_GE(size_log2, kernel_size_log2);

  // Heuristic selecting the power-of-two grid subdivision inside of each
  // square-ish region (past the above subdivision by 'rectangularness').
  // Note that it is the number of subdivisions, not the resulting block size,
  // that will be a power of two. But inside of that heuristic, it simplifies
  // code to talk in terms of 'block_size_log2', as if it were the block size
  // that were a power of two. This 'block_size_log2' is to be interpreted as
  // "log2 rounded below", e.g. when block_size_log2=8 we might have a block
  // size in [256, 511]. When the shape is non-square, rows!=cols, this
  // refers to the smaller of the two, so the other might be as large as
  // 1021 (can't be 1022 because following the above 'rectangularness'
  // subdivision, the aspect ratio is already < 2).

  // We are going to try candidate values for block_size_log2 ranging from
  // kernel_size_log2 to (kernel_size_log2 + kMaxKernelsPerBlockLog2).
  // For each of them we will compute a 'score' by adding individual scores
  // for a few different considerations, all of which are entirely empirical.
  // The values (and possibly the logic) around here are all subject to tuning
  // based on benchmarks on different hardware. The current values are based
  // on benchmarking on Qualcomm S855 (big and little cores), arm64,
  // kNeonDotprod, 8bit quantized path. Don't read too much into it, go ahead
  // and tune this as needed to achieve good performance elsewhere. Use
  // the unit test, block_map_test, to encode values that should be preserved
  // on specific architectures. Use RUY_TRACE to debug the current heuristics
  // and RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2 to test the impact of a
  // different block_size_log2 choice, to empirically find the optimal value
  // before getting to updating the heuristic so that it produces that value.
  static constexpr int kMaxKernelsPerBlockLog2 = 6;
  const int max_block_size_log2 =
      std::min(size_log2, kernel_size_log2 + kMaxKernelsPerBlockLog2);
  int best_score = std::numeric_limits<int>::min();
  int best_score_block_size_log2 = -1;
  RUY_TRACE_INFO(MAKE_BLOCK_MAP_START);
  for (int block_size_log2 = kernel_size_log2;
       block_size_log2 <= max_block_size_log2; block_size_log2++) {
    const int multithreading_score = GetMultithreadingScore(
        block_size_log2, rows, cols, tentative_thread_count);
    const int cache_locality_score = GetCacheLocalityScore(
        block_size_log2, rows, cols, depth, kernel_rows_log2, kernel_cols_log2,
        lhs_scalar_size, rhs_scalar_size, cpu_cache_params);
    const int kernel_amortization_score = GetKernelAmortizationScore(
        block_size_log2, rows, cols, kernel_rows_log2, kernel_cols_log2);
    const int score =
        multithreading_score + cache_locality_score + kernel_amortization_score;
    if (score >= best_score) {
      best_score = score;
      best_score_block_size_log2 = block_size_log2;
    }
    RUY_TRACE_INFO(MAKE_BLOCK_MAP_EACH_TENTATIVE_BLOCK_SIZE);
  }

#ifdef RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2
  // Useful for tuning.
  best_score_block_size_log2 = RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2;
#endif

  // As explained in the above comment, phrasing the above code in terms of
  // block_size_log2 was only convenience inside of that heuristic. Now we
  // revert to talking in terms of grid subdivision. That is what will actually
  // be powers of two.
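  // For example (illustrative values), if size_log2 = 9 (a 512-wide square-ish
  // region) and the heuristic above settled on best_score_block_size_log2 = 7,
  // then num_blocks_base_log2 = 2 below, i.e. a 4x4 grid of blocks per
  // square-ish region.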
  int num_blocks_base_log2 = size_log2 - best_score_block_size_log2;
  RUY_DCHECK_GE(num_blocks_base_log2, 0);
  const int num_blocks_of_rows_log2 =
      num_blocks_base_log2 + rows_rectangularness_log2;
  const int num_blocks_of_cols_log2 =
      num_blocks_base_log2 + cols_rectangularness_log2;

  // Now that we know the grid subdivision, we can pinpoint the exact block
  // sizes. They can't be powers of two in general; they can't even be all
  // equal in general; so the following few parameters will govern how blocks
  // of slightly different shapes are put together in the block map.
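  // For example (illustrative values), with rows = 1000, kernel_rows = 8 and
  // num_blocks_of_rows_log2 = 3: small_block_rows = round_down_pot(125, 8) =
  // 120, and the leftover 1000 - 8 * 120 = 40 rows are spread as one extra
  // kernel row-block over rows_of_large_blocks = 5 of the 8 blocks
  // (5 * 128 + 3 * 120 = 1000).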
  const int small_block_rows =
      round_down_pot(rows >> num_blocks_of_rows_log2, kernel_rows);
  const int small_block_cols =
      round_down_pot(cols >> num_blocks_of_cols_log2, kernel_cols);
  const int rows_of_large_blocks =
      round_up_pot(rows - (small_block_rows << num_blocks_of_rows_log2),
                   kernel_rows) >>
      pot_log2(kernel_rows);
  const int cols_of_large_blocks =
      round_up_pot(cols - (small_block_cols << num_blocks_of_cols_log2),
                   kernel_cols) >>
      pot_log2(kernel_cols);

  // We have everything! Write out to the destination block_map.
  block_map->dims[Side::kLhs] = rows;
  block_map->dims[Side::kRhs] = cols;
  block_map->kernel_dims[Side::kLhs] = kernel_rows;
  block_map->kernel_dims[Side::kRhs] = kernel_cols;
  block_map->num_blocks_base_log2 = num_blocks_base_log2;
  block_map->rectangularness_log2[Side::kLhs] = rows_rectangularness_log2;
  block_map->rectangularness_log2[Side::kRhs] = cols_rectangularness_log2;
  block_map->small_block_dims[Side::kLhs] = small_block_rows;
  block_map->small_block_dims[Side::kRhs] = small_block_cols;
  block_map->large_blocks[Side::kLhs] = rows_of_large_blocks;
  block_map->large_blocks[Side::kRhs] = cols_of_large_blocks;
  // See the comment on GetTraversalOrder for why we are dividing `rows` and
  // `cols` by the rectangularness subdivision parameters here.
  block_map->traversal_order = GetTraversalOrder(
      rows >> rows_rectangularness_log2, cols >> cols_rectangularness_log2,
      depth, lhs_scalar_size, rhs_scalar_size, cpu_cache_params);
  // Done last: NumBlocks needs some of the block_map fields to be already set.
  block_map->thread_count =
      std::min(tentative_thread_count, NumBlocks(*block_map));
  RUY_TRACE_INFO(MAKE_BLOCK_MAP_END);
}

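// Computes the [start, end) range of matrix rows (LHS) or columns (RHS)
// covered by the given block: the first `large_blocks[side]` blocks are one
// kernel dimension larger than the rest. For example (illustrative values),
// with small_block_dims[side] = 120, large_blocks[side] = 5 and
// kernel_dims[side] = 8, block 0 spans [0, 128) and block 5 spans
// [5 * 120 + 5 * 8, 640 + 120) = [640, 760).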
void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block,
                          int* start, int* end) {
  profiler::ScopeLabel label("GetBlockMatrixCoords");
  *start = block * block_map.small_block_dims[side] +
           std::min(block, block_map.large_blocks[side]) *
               block_map.kernel_dims[side];
  *end =
      *start + block_map.small_block_dims[side] +
      (block < block_map.large_blocks[side] ? block_map.kernel_dims[side] : 0);

  RUY_DCHECK_EQ(0, *start % block_map.kernel_dims[side]);
  RUY_DCHECK_EQ(0, *end % block_map.kernel_dims[side]);
  RUY_DCHECK_LE(*end, block_map.dims[side]);
  RUY_DCHECK_LT(*start, *end);
  RUY_DCHECK_GE(*start, 0);
}

void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair<int>& block,
                          SidePair<int>* start, SidePair<int>* end) {
  for (Side side : {Side::kLhs, Side::kRhs}) {
    GetBlockMatrixCoords(side, block_map, block[side], &(*start)[side],
                         &(*end)[side]);
  }
}

}  // namespace ruy