/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "ruy/block_map.h"

#include <algorithm>
#include <cstdint>
#include <limits>

#ifdef RUY_MAKEBLOCKMAP_DEBUG
#include <cstdio>
#include <cstdlib>
#include <string>
#endif

#include "ruy/check_macros.h"
#include "ruy/opt_set.h"
#include "ruy/profiler/instrumentation.h"
#include "ruy/size_util.h"
#include "ruy/trace.h"

namespace ruy {

namespace {

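// The helpers below decode a 1D block index into a 2D position
// (LHS block, RHS block) inside one square sub-grid of
// 2^size_log2 x 2^size_log2 blocks, one helper per traversal order.
// Linear order simply takes the LHS coordinate from the low size_log2 bits
// and the RHS coordinate from the remaining high bits.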
void DecodeTraversalLinear(int size_log2, std::uint32_t square_index,
                           SidePair<int>* local_pos) {
  (*local_pos)[Side::kLhs] = square_index & ((1 << size_log2) - 1);
  (*local_pos)[Side::kRhs] = square_index >> size_log2;
}

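// Decodes a position on a Z-order (Morton) curve: the even-numbered bits of
// square_index form one coordinate and the odd-numbered bits form the other.
// The masked shift-and-or steps below de-interleave the bits in parallel,
// regrouping them in chunks of 1, 2, 4 and then 8 bits, so that the low 16
// bits of n16 hold the LHS coordinate and the high 16 bits hold the RHS
// coordinate.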
void DecodeTraversalFractalZ(std::uint32_t square_index,
                             SidePair<int>* local_pos) {
  const std::uint32_t n1 = square_index;
  const std::uint32_t n2 = (n1 & 0x99999999u) | ((n1 & 0x44444444u) >> 1) |
                           ((n1 & 0x22222222u) << 1);
  const std::uint32_t n4 = (n2 & 0xc3c3c3c3u) | ((n2 & 0x30303030u) >> 2) |
                           ((n2 & 0x0c0c0c0cu) << 2);
  const std::uint32_t n8 = (n4 & 0xf00ff00fu) | ((n4 & 0x0f000f00u) >> 4) |
                           ((n4 & 0x00f000f0u) << 4);
  const std::uint32_t n16 = (n8 & 0xff0000ffu) | ((n8 & 0x00ff0000u) >> 8) |
                            ((n8 & 0x0000ff00u) << 8);
  (*local_pos)[Side::kLhs] = n16 & 0xffff;
  (*local_pos)[Side::kRhs] = n16 >> 16;
}

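// Same as DecodeTraversalFractalZ except that the final XOR folds each
// Z-shaped visit of a 2x2 sub-square into a U shape, so consecutive block
// indices tend to land on adjacent blocks, which improves locality compared
// to plain Z-order traversal.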
void DecodeTraversalFractalU(std::uint32_t square_index,
                             SidePair<int>* local_pos) {
  DecodeTraversalFractalZ(square_index, local_pos);
  // Change fractal z-order to u-order
  (*local_pos)[Side::kLhs] ^= (*local_pos)[Side::kRhs];
}

// Code inspired by the sample code in
//   https://en.wikipedia.org/wiki/Hilbert_curve
// The main optimization is to avoid hard-to-predict conditional branches
// based on the bits of the square_index parameter.
void DecodeTraversalFractalHilbert(int size_log2, std::uint32_t square_index,
                                   SidePair<int>* local_pos) {
  std::uint32_t t = square_index;
  std::uint32_t x = 0;
  std::uint32_t y = 0;
  // Easy-to-predict for loop: the number of iterations is the same for
  // an entire GEMM.
  for (int sb = 0; sb < size_log2; sb++) {
    std::uint32_t s = 1 << sb;
    bool rx = t & 2;
    bool ry = (t & 1) ^ rx;
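    // rx and ry are the two bits of the current base-4 digit of t; they select
    // which quadrant of the current sub-square this step descends into. The
    // ternary expressions below perform the corresponding rotation/reflection
    // and offset (the role played by the rot() helper in the Wikipedia sample
    // code), written as conditional selects that the compiler can turn into
    // branch-free code.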
    std::uint32_t tmp = rx ? (s - 1 - x) : x;
    x = ry ? x : rx ? (s - 1 - y) : y;
    y = ry ? (y + s) : tmp;
    x = rx ? (x + s) : x;
    t >>= 2;
  }
  (*local_pos)[Side::kLhs] = y;
  (*local_pos)[Side::kRhs] = x;
}

}  // end anonymous namespace

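// The low 2 * num_blocks_base_log2 bits of `index` select a position on the
// chosen traversal curve within one square sub-grid of blocks; the remaining
// high bits select which square sub-grid along the 'rectangularness'
// subdivision, and translate the local position by whole multiples of the
// sub-grid size.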
void GetBlockByIndex(const BlockMap& block_map, int index,
                     SidePair<int>* block) {
  profiler::ScopeLabel label("GetBlockByIndex");
  const std::uint32_t index_u32 = index;

  const std::uint32_t num_blocks_per_local_curve =
      1u << (2 * block_map.num_blocks_base_log2);
  const std::uint32_t square_index =
      index_u32 & (num_blocks_per_local_curve - 1);

  const int size_log2 = block_map.num_blocks_base_log2;
  SidePair<int> local_pos;
  switch (block_map.traversal_order) {
    case BlockMapTraversalOrder::kFractalZ:
      DecodeTraversalFractalZ(square_index, &local_pos);
      break;
    case BlockMapTraversalOrder::kFractalU:
      DecodeTraversalFractalU(square_index, &local_pos);
      break;
    case BlockMapTraversalOrder::kFractalHilbert:
      DecodeTraversalFractalHilbert(size_log2, square_index, &local_pos);
      break;
    default:
      RUY_DCHECK(block_map.traversal_order == BlockMapTraversalOrder::kLinear);
      DecodeTraversalLinear(size_log2, square_index, &local_pos);
      break;
  }

  const std::uint32_t rectangular_index =
      index_u32 >> 2 * block_map.num_blocks_base_log2;
  for (Side side : {Side::kLhs, Side::kRhs}) {
    const std::uint32_t mask = (1u << block_map.rectangularness_log2[side]) - 1;
    const int rectangular_offset = (rectangular_index & mask)
                                   << block_map.num_blocks_base_log2;
    (*block)[side] = local_pos[side] + rectangular_offset;
  }
}

namespace {

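// working_set_size below estimates how many bytes of packed LHS and RHS data
// one pass over a (sub)grid of blocks touches. When that fits in the local
// (per-core) cache, plain linear traversal is already cache-friendly, so a
// fractal order is only used beyond that size, and the Hilbert order is
// reserved for working sets that also exceed the shared last-level cache.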
BlockMapTraversalOrder GetTraversalOrder(
    int rows_after_rectangularness_division,
    int cols_after_rectangularness_division, int depth, int lhs_scalar_size,
    int rhs_scalar_size, const CpuCacheParams& cpu_cache_params) {
  static constexpr bool kAnyFractal =
      RUY_OPT(FRACTAL_Z) | RUY_OPT(FRACTAL_U) | RUY_OPT(FRACTAL_HILBERT);
  const int working_set_size =
      (lhs_scalar_size * rows_after_rectangularness_division +
       rhs_scalar_size * cols_after_rectangularness_division) *
      depth;
  if (kAnyFractal && (working_set_size > cpu_cache_params.local_cache_size)) {
    if (RUY_OPT(FRACTAL_HILBERT) &&
        (working_set_size > cpu_cache_params.last_level_cache_size)) {
      return BlockMapTraversalOrder::kFractalHilbert;
    } else if (RUY_OPT(FRACTAL_U)) {
      return BlockMapTraversalOrder::kFractalU;
    } else {
      return BlockMapTraversalOrder::kFractalZ;
    }
  } else {
    return BlockMapTraversalOrder::kLinear;
  }
}

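// Returns floor(log2(num / denom)) for num > denom, and 0 otherwise, without
// performing a division: the initial estimate floor_log2(num) -
// ceil_log2(denom) can be at most one too small, and the single correction
// step below fixes that case. For example, floor_log2_quotient(1000, 3) == 8,
// since 3 * 2^8 = 768 <= 1000 < 1536 = 3 * 2^9.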
int floor_log2_quotient(int num, int denom) {
  if (num <= denom) {
    return 0;
  }
  int log2_quotient = floor_log2(num) - ceil_log2(denom);
  if ((denom << (log2_quotient + 1)) <= num) {
    log2_quotient++;
  }
  return log2_quotient;
}

// Computes the rectangularness of the matrix shape (rows, cols). This is
// essentially just the log2 of the quotient (rows / cols). The kernel_rows and
// kernel_cols only get into the picture for clamping bounds but don't affect
// the generic computation.
void GetRectangularness(int rows, int cols, int kernel_rows, int kernel_cols,
                        int* rows_rectangularness_log2,
                        int* cols_rectangularness_log2) {
  *rows_rectangularness_log2 = 0;
  *cols_rectangularness_log2 = 0;

  // In GEMV-ish cases, that is, when kernel blocks are as narrow as the kernel
  // itself, we risk having too small kernel blocks for good kernel
  // amortization. We avoid that by limiting rectangularness so that kernel
  // blocks are not too tiny at least in that dimension. Specifically, we try to
  // have at least (2^min_kernel_inner_loop_runs_log2) kernels fitting in each
  // kernel block along the large dimension.
  const int min_kernel_inner_loop_runs_log2 = 3;
  if (rows > cols) {
    int cols_of_kernel_inner_loop_runs_log2 =
        ceil_log2(cols) - pot_log2(kernel_cols);
    int min_rows_of_kernel_inner_loop_runs_log2 =
        std::max(0, min_kernel_inner_loop_runs_log2 -
                        cols_of_kernel_inner_loop_runs_log2);
    *rows_rectangularness_log2 =
        std::min(floor_log2_quotient(rows, cols),
                 std::max(0, floor_log2(rows) - pot_log2(kernel_rows) -
                                 min_rows_of_kernel_inner_loop_runs_log2));
    // Sanity check that we did not over-estimate rows_rectangularness_log2.
    RUY_DCHECK_GE(rows >> *rows_rectangularness_log2, cols);
  } else if (cols > rows) {
    int rows_of_kernel_inner_loop_runs_log2 =
        ceil_log2(rows) - pot_log2(kernel_rows);
    int min_cols_of_kernel_inner_loop_runs_log2 =
        std::max(0, min_kernel_inner_loop_runs_log2 -
                        rows_of_kernel_inner_loop_runs_log2);
    *cols_rectangularness_log2 =
        std::min(floor_log2_quotient(cols, rows),
                 std::max(0, floor_log2(cols) - pot_log2(kernel_cols) -
                                 min_cols_of_kernel_inner_loop_runs_log2));
    // Sanity check that we did not over-estimate cols_rectangularness_log2.
    RUY_DCHECK_GE(cols >> *cols_rectangularness_log2, rows);
  }
  RUY_DCHECK(!*rows_rectangularness_log2 || !*cols_rectangularness_log2);
}

// Computes a 'multithreading score'. When multithreading, we need there to
// be at least as many tiles as there are threads, and hopefully
// substantially more than that, so we benefit from ruy's ability to
// dispatch fine-grained workloads to threads.
int GetMultithreadingScore(int block_size_log2, int rows, int cols,
                           int tentative_thread_count) {
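  // With blocks of size (1 << block_size_log2), the grid contains roughly
  // num_full_blocks_of_rows * num_full_blocks_of_cols blocks; comparing its
  // log2 against the log2 of the thread count (below) estimates how many
  // blocks each thread would get, which is what this score is based on.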
  const int num_full_blocks_of_rows = rows >> block_size_log2;
  const int num_full_blocks_of_cols = cols >> block_size_log2;
  const int candidate_num_full_blocks_log2 = floor_log2(
      std::max(1, num_full_blocks_of_rows * num_full_blocks_of_cols));

  // The values here have been tuned on ARM Cortex-A55.
  // We expect this to have to be tuned differently for other CPUs.
  if (tentative_thread_count == 1) {
    return 0;
  } else {
    const int blocks_per_thread_log2 =
        candidate_num_full_blocks_log2 - ceil_log2(tentative_thread_count);
    if (blocks_per_thread_log2 < 0) {
      return -64;
    } else if (blocks_per_thread_log2 == 0) {
      return -16;
    } else if (blocks_per_thread_log2 == 1) {
      return -8;
    } else if (blocks_per_thread_log2 == 2) {
      return 0;
    } else if (blocks_per_thread_log2 == 3) {
      return 8;
    } else {
      return 16;
    }
  }
}

// Computes a 'cache locality score'.
int GetCacheLocalityScore(int block_size_log2, int rows, int cols, int depth,
                          int kernel_rows_log2, int kernel_cols_log2,
                          int lhs_scalar_size, int rhs_scalar_size,
                          const CpuCacheParams& cpu_cache_params) {
  // In the narrow case (e.g. matrix*vector), each byte of the big operand
  // matrix (either LHS or RHS) is traversed only once, so any notion of data
  // locality is irrelevant. Ignore the 'cache locality score' by forcing it to
  // be 0 in that case.
  if (rows <= (1 << kernel_rows_log2) || cols <= (1 << kernel_cols_log2)) {
    return 0;
  }
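  // total_read_bytes estimates the packed LHS and RHS data traversed while
  // computing one block. nonlocality_log2 measures, in powers of two, how much
  // that exceeds the local (per-core) cache size; negative values mean the
  // block's working set fits in that cache.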
  const int block_rows = std::min(1 << block_size_log2, rows);
  const int block_cols = std::min(1 << block_size_log2, cols);
  const int total_read_bytes =
      (lhs_scalar_size * block_rows + rhs_scalar_size * block_cols) * depth;
  const int total_read_bytes_log2 = ceil_log2(total_read_bytes);
  const int nonlocality_log2 =
      total_read_bytes_log2 - floor_log2(cpu_cache_params.local_cache_size);
  // The values here have been tuned on ARM Cortex-A55.
  // We expect this to have to be tuned differently for other CPUs.
  if (nonlocality_log2 < -1) {
    return 64;
  } else if (nonlocality_log2 == -1) {
    return 56;
  } else if (nonlocality_log2 == 0) {
    return 48;
  } else if (nonlocality_log2 == 1) {
    return 32;
  } else if (nonlocality_log2 == 2) {
    return 16;
  } else if (nonlocality_log2 == 3) {
    return 0;
  } else {
    return -64;
  }
}

// Computes a 'kernel amortization score'. This is the notion that very small
// tiles result in more overhead outside of kernels, more complex memory
// access patterns and less benefit from ruy's fat kernels, so we reward
// larger blocks more than smaller ones.
int GetKernelAmortizationScore(int block_size_log2, int rows, int cols,
                               int kernel_rows_log2, int kernel_cols_log2) {
  const int block_rows = std::min(1 << block_size_log2, rows);
  const int block_cols = std::min(1 << block_size_log2, cols);
  const int kernels_per_block_log2 =
      floor_log2(block_rows * block_cols) - kernel_rows_log2 - kernel_cols_log2;
  RUY_DCHECK_GE(kernels_per_block_log2, 0);
  // The values here have been tuned on ARM Cortex-A55.
  // We expect this to have to be tuned differently for other CPUs.
  if (kernels_per_block_log2 == 0) {
    return 0;
  } else if (kernels_per_block_log2 == 1) {
    return 8;
  } else if (kernels_per_block_log2 == 2) {
    return 16;
  } else if (kernels_per_block_log2 == 3) {
    return 24;
  } else if (kernels_per_block_log2 == 4) {
    return 32;
  } else if (kernels_per_block_log2 == 5) {
    return 40;
  } else if (kernels_per_block_log2 == 6) {
    return 48;
  } else if (kernels_per_block_log2 == 7) {
    return 56;
  } else {
    return 64;
  }
}

}  // namespace

bool IsObviouslyLinearTraversal(int rows, int cols, int depth,
                                int lhs_scalar_size, int rhs_scalar_size,
                                const CpuCacheParams& cpu_cache_params) {
  if (rows == 1 || cols == 1) {
    return true;
  }
  // Normally, GetTraversalOrder wants the dimensions (rows x cols) divided
  // by the rectangularness factors, since any non-linear traversal order will
  // be local to each subdivision. In the present function, we don't know the
  // rectangularness factors yet, and we can't just call GetRectangularness
  // as that requires knowing the kernel block layout. Since we just want
  // a coarse estimate with only the guarantee that if we return `true` then
  // linear traversal will be used, it is OK here to over-estimate `rows` and
  // `cols` by not dividing them by the rectangularness factors.
  return GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size,
                           cpu_cache_params) == BlockMapTraversalOrder::kLinear;
}

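// MakeBlockMap proceeds in three steps: (1) estimate the 'rectangularness'
// subdivision bringing the shape to within 2x of a square, (2) pick
// block_size_log2 by maximizing the sum of the empirical multithreading,
// cache-locality and kernel-amortization scores, and (3) derive from that the
// exact (not necessarily power-of-two) block dimensions and traversal order.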
void MakeBlockMap(int rows, int cols, int depth, int kernel_rows,
                  int kernel_cols, int lhs_scalar_size, int rhs_scalar_size,
                  int tentative_thread_count,
                  const CpuCacheParams& cpu_cache_params, BlockMap* block_map) {
  RUY_TRACE_SCOPE;
  profiler::ScopeLabel label("MakeBlockMap");

  RUY_DCHECK_GE(rows, kernel_rows);
  RUY_DCHECK_GE(cols, kernel_cols);
  RUY_DCHECK_EQ(rows % kernel_rows, 0);
  RUY_DCHECK_EQ(cols % kernel_cols, 0);

  // Estimate the 'rectangularness', the first level of subdivision bringing
  // the shape to within 2x of a square shape.
  int rows_rectangularness_log2 = 0;
  int cols_rectangularness_log2 = 0;
  GetRectangularness(rows, cols, kernel_rows, kernel_cols,
                     &rows_rectangularness_log2, &cols_rectangularness_log2);

  const int kernel_rows_log2 = pot_log2(kernel_rows);
  const int kernel_cols_log2 = pot_log2(kernel_cols);
  const int kernel_size_log2 = std::max(kernel_cols_log2, kernel_rows_log2);

  const int size = std::min(rows, cols);
  const int size_log2 = std::max(kernel_size_log2, floor_log2(size));

  RUY_DCHECK_GE(size_log2, kernel_size_log2);

  // Heuristic selecting the power-of-two grid subdivision inside of each
  // square-ish region (past the above subdivision by 'rectangularness').
  // Note that it is the number of subdivisions, not the resulting block size,
  // that will be a power of two. But inside of that heuristic, it simplifies
  // code to talk in terms of 'block_size_log2', as if it were the block size
  // that were a power of two. This 'block_size_log2' is to be interpreted as
  // "log2 rounded below", e.g. when block_size_log2=8 we might have a block
  // size in [256, 511]. When the shape is non-square, rows!=cols, this
  // refers to the smaller of the two, so the other might be as large as
  // 1021 (can't be 1022 because following the above 'rectangularness'
  // subdivision, the aspect ratio is already < 2).

  // We are going to try candidate values for block_size_log2 ranging from
  // kernel_size_log2 to (kernel_size_log2 + kMaxKernelsPerBlockLog2).
  // For each of them we will compute a 'score' by adding individual scores
  // for a few different considerations, all of which are entirely empirical.
  // The values (and possibly the logic) around here are all subject to tuning
  // based on benchmarks on different hardware. The current values are based
  // on benchmarking on Qualcomm S855 (big and little cores), arm64,
  // kNeonDotprod, 8bit quantized path. Don't read too much into it, go ahead
  // and tune this as needed to achieve good performance elsewhere. Use
  // the unit test, block_map_test, to encode values that should be preserved
  // on specific architectures. Use RUY_TRACE to debug the current heuristics
  // and RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2 to test the impact of a
  // different block_size_log2 choice, to empirically find the optimal value
  // before updating the heuristic so that it produces that value.
  static constexpr int kMaxKernelsPerBlockLog2 = 6;
  const int max_block_size_log2 =
      std::min(size_log2, kernel_size_log2 + kMaxKernelsPerBlockLog2);
  int best_score = std::numeric_limits<int>::min();
  int best_score_block_size_log2 = -1;
  RUY_TRACE_INFO(MAKE_BLOCK_MAP_START);
  for (int block_size_log2 = kernel_size_log2;
       block_size_log2 <= max_block_size_log2; block_size_log2++) {
    const int multithreading_score = GetMultithreadingScore(
        block_size_log2, rows, cols, tentative_thread_count);
    const int cache_locality_score = GetCacheLocalityScore(
        block_size_log2, rows, cols, depth, kernel_rows_log2, kernel_cols_log2,
        lhs_scalar_size, rhs_scalar_size, cpu_cache_params);
    const int kernel_amortization_score = GetKernelAmortizationScore(
        block_size_log2, rows, cols, kernel_rows_log2, kernel_cols_log2);
    const int score =
        multithreading_score + cache_locality_score + kernel_amortization_score;
    if (score >= best_score) {
      best_score = score;
      best_score_block_size_log2 = block_size_log2;
    }
    RUY_TRACE_INFO(MAKE_BLOCK_MAP_EACH_TENTATIVE_BLOCK_SIZE);
  }

#ifdef RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2
  // Useful for tuning.
  best_score_block_size_log2 = RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2;
#endif

  // As explained in the above comment, phrasing the above code in terms of
  // block_size_log2 was only a convenience inside of that heuristic. Now we
  // revert to talking in terms of grid subdivision. That is what will actually
  // be powers of two.
  int num_blocks_base_log2 = size_log2 - best_score_block_size_log2;
  RUY_DCHECK_GE(num_blocks_base_log2, 0);
  const int num_blocks_of_rows_log2 =
      num_blocks_base_log2 + rows_rectangularness_log2;
  const int num_blocks_of_cols_log2 =
      num_blocks_base_log2 + cols_rectangularness_log2;

  // Now that we know the grid subdivision, we can pinpoint the exact block
  // sizes. They can't be powers of two in general; they can't even be all
  // equal in general; so the following few parameters will govern how blocks
  // of slightly different shapes are put together in the block map.
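  // Along each dimension, every block spans small_block_* rows/cols, and the
  // first *_of_large_blocks blocks additionally get one extra kernel-sized
  // slice, so that all block sizes are multiples of the kernel size and add up
  // exactly to the full matrix dimension.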
  const int small_block_rows =
      round_down_pot(rows >> num_blocks_of_rows_log2, kernel_rows);
  const int small_block_cols =
      round_down_pot(cols >> num_blocks_of_cols_log2, kernel_cols);
  const int rows_of_large_blocks =
      round_up_pot(rows - (small_block_rows << num_blocks_of_rows_log2),
                   kernel_rows) >>
      pot_log2(kernel_rows);
  const int cols_of_large_blocks =
      round_up_pot(cols - (small_block_cols << num_blocks_of_cols_log2),
                   kernel_cols) >>
      pot_log2(kernel_cols);

  // We have everything! Write out to the destination block_map.
  block_map->dims[Side::kLhs] = rows;
  block_map->dims[Side::kRhs] = cols;
  block_map->kernel_dims[Side::kLhs] = kernel_rows;
  block_map->kernel_dims[Side::kRhs] = kernel_cols;
  block_map->num_blocks_base_log2 = num_blocks_base_log2;
  block_map->rectangularness_log2[Side::kLhs] = rows_rectangularness_log2;
  block_map->rectangularness_log2[Side::kRhs] = cols_rectangularness_log2;
  block_map->small_block_dims[Side::kLhs] = small_block_rows;
  block_map->small_block_dims[Side::kRhs] = small_block_cols;
  block_map->large_blocks[Side::kLhs] = rows_of_large_blocks;
  block_map->large_blocks[Side::kRhs] = cols_of_large_blocks;
  // See the comment on GetTraversalOrder for why we are dividing `rows` and
  // `cols` by the rectangularness subdivision parameters here.
  block_map->traversal_order = GetTraversalOrder(
      rows >> rows_rectangularness_log2, cols >> cols_rectangularness_log2,
      depth, lhs_scalar_size, rhs_scalar_size, cpu_cache_params);
  // Done last: NumBlocks needs some of the block_map fields to be already set.
  block_map->thread_count =
      std::min(tentative_thread_count, NumBlocks(*block_map));
  RUY_TRACE_INFO(MAKE_BLOCK_MAP_END);
}

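// Maps a block index along one side to the [start, end) range of matrix rows
// or columns that it covers: the first large_blocks[side] blocks are one
// kernel_dims[side] larger than the remaining blocks, whose size is
// small_block_dims[side].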
void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block,
                          int* start, int* end) {
  profiler::ScopeLabel label("GetBlockMatrixCoords");
  *start = block * block_map.small_block_dims[side] +
           std::min(block, block_map.large_blocks[side]) *
               block_map.kernel_dims[side];
  *end =
      *start + block_map.small_block_dims[side] +
      (block < block_map.large_blocks[side] ? block_map.kernel_dims[side] : 0);

  RUY_DCHECK_EQ(0, *start % block_map.kernel_dims[side]);
  RUY_DCHECK_EQ(0, *end % block_map.kernel_dims[side]);
  RUY_DCHECK_LE(*end, block_map.dims[side]);
  RUY_DCHECK_LT(*start, *end);
  RUY_DCHECK_GE(*start, 0);
}

void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair<int>& block,
                          SidePair<int>* start, SidePair<int>* end) {
  for (Side side : {Side::kLhs, Side::kRhs}) {
    GetBlockMatrixCoords(side, block_map, block[side], &(*start)[side],
                         &(*end)[side]);
  }
}

}  // namespace ruy