/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Sketch of algorithm:
// A good description can be found in the HPTT paper, which is linked from the
// header file.
//
// The algorithm is divided into two parts: plan creation, which chooses an
// execution plan for the transpose, and execution. Plans may be cached and
// reused multiple times.
//
// We use a two-level blocking scheme:
//
// The inner "microkernel" level is the unit of vectorization. A microkernel
// transposes a vector-unit sized tile, e.g., an 8x8 tile of floats for AVX2.
// The microkernels require one of the input dimensions and one of the
// output dimensions to have stride 1, while the other dimension in each case
// may have a non-trivial element stride. The kernels load N vectors of N
// stride-1 elements, and perform an NxN transpose, swapping the role of the
// stride-1 dimension. The N vectors are then written to the output. To perform
// a complete tensor transpose, we simply need to apply the microkernel over
// all blocks of the matrix.
//
// In the event that the stride-1 dimensions of the input and output are the
// same, we use a simpler kernel which is a memcpy().
//
// To improve cache locality, we use another level of blocking, namely
// "macrokernels". The outer "macrokernel" level is a block of, for example,
// 4x4 microkernels. Macrokernels are the basic unit of work of the loop nest
// plan. For dimensions that aren't exactly divisible by the macrokernel size,
// we repeatedly halve the kernel size for trailing elements. For dimensions
// that aren't exactly divisible by the microkernel size, we use a scalar
// transpose for the trailing elements.
//
// A transpose plan iterates over the array's index space, applying
// macrokernel-sized blocks. Any iteration order is possible, although some
// orders give better locality than others. Currently we always use a default
// iteration order.
//
// A plan contains the data structures that describe how to perform a
// transpose. Plan creation chooses such things as a loop iteration order and
// kernel sizes. Plan creation also performs a handful of optimizations, such
// as coalescing adjacent dimensions that do not change their order and
// removing trivial dimensions.
//
// TODO(phawkins):
// * we don't incorporate a number of optimizations from HPTT, notably explicit
//   prefetching and manual loop unrolling.
// * we could use vector-aligned stores for some arrays, which might
//   be worth something. We could also use nontemporal stores in the aligned
//   case.
// * we don't yet search for a good loop ordering. This probably matters less
//   for arrays that fit entirely in cache.
// * we could do a better job of vectorizing where the stride-1 dimensions are
//   small (e.g., inner dimensions of size [..., 3] are not uncommon in some
//   use cases.)

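// For intuition, a worked example of the two-level blocking (illustrative
// only; the actual sizes are chosen by plan creation): transposing a 64x64
// float array with an 8x8 microkernel and 2x2 macrokernels means each
// macrokernel call moves a 16x16 block, roughly:
//
//   for (i = 0; i < 64; i += 16)        // loop nest from the plan
//     for (j = 0; j < 64; j += 16)
//       for (bi = 0; bi < 2; ++bi)      // macrokernel: 2x2 microkernels
//         for (bj = 0; bj < 2; ++bj)
//           transpose8x8(&a[i + 8*bi][j + 8*bj], &b[j + 8*bj][i + 8*bi]);
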
#include "tensorflow/compiler/xla/pjrt/transpose.h"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>
#include <stack>
#include <string>
#include <tuple>
#include <utility>
#include <variant>

#include "absl/algorithm/container.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/pjrt/transpose_kernels.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace xla {

// A plan is a data structure that describes a loop nest.
// TODO(phawkins): consider shrinking Node so it fits in a cache line.
struct TransposePlan::Node {
  // The loop should iterate over the index space range(start, end, inc).
  // These fields are ignored by the macrokernel.
  int64_t start;
  int64_t end;
  int64_t inc;  // The transpose sentinel node has inc < 0.

  // Strides of this dimension in A and B.
  int64_t lda;
  int64_t ldb;

  // If > 0, this loop is a loop over tile exteriors and has a trailing partial
  // tile. To handle the trailing partial tile, skip to the plan node this many
  // steps ahead in the vector of plan nodes.
  int trailing_tile_next_node_inc = 0;

  // Is this dimension the innermost dimension in either A or B, and hence may
  // have non-trivial blocking?
  bool is_inner_dim_in_a = false;
  bool is_inner_dim_in_b = false;
};

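// As an illustrative (not normative) sketch of what plan nodes look like, a
// simple untiled 2D transpose might produce a node list such as:
//
//   Node(start=0, end=<rows>, inc=<block>, ...)  // loop over rows of A
//   Node(start=0, end=<cols>, inc=<block>, ...)  // loop over columns of A
//   Node(start=-1, end=-1, inc=-1, lda=..., ldb=...)  // sentinel
//
// where the trailing sentinel node (negative inc) carries the strides used
// inside the macrokernel. The exact values depend on the kernel sizes and
// parallelization chosen by plan creation.
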
void ConvertF64ToEf57(const double* input, float* output, int n) {
  // TODO(phawkins): vectorize this transformation.
  for (int i = 0; i < n; ++i) {
    std::tie(output[0], output[1]) = SplitF64ToF32(*input);
    ++input;
    output += 2;
  }
}

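// For reference, each double is split into a (hi, lo) pair of floats whose
// sum approximates the original value (the "ef57" representation), so n input
// doubles produce 2*n output floats. A hypothetical call, for illustration:
//
//   double in[2] = {1.0000000001, 3.14159265358979};
//   float out[4];
//   ConvertF64ToEf57(in, out, 2);  // out = {hi0, lo0, hi1, lo1}
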
template <typename T, int inner_bs,
          TransposePlan::Transformation transformation>
void MacroKernel(const char* __restrict a, int64_t lda, int outer_bs_a,
                 char* __restrict b, int64_t ldb, int outer_bs_b,
                 void* __restrict scratch) {
  DVLOG(10) << "MacroKernel lda=" << lda << " ldb=" << ldb
            << " outer_bs_a=" << outer_bs_a << " outer_bs_b=" << outer_bs_b
            << " inner_bs=" << inner_bs;

  // TODO(phawkins): consider adding prefetching and streaming stores.

  if (transformation == TransposePlan::Transformation::kF64ToEf57) {
    DCHECK_EQ(outer_bs_a * inner_bs % 2, 0);
    float* p = reinterpret_cast<float*>(scratch);
    for (int i = 0; i < outer_bs_b * inner_bs; ++i) {
      ConvertF64ToEf57(reinterpret_cast<const double*>(a + lda * i),
                       p + outer_bs_a * inner_bs * i,
                       outer_bs_a * inner_bs / 2);
    }
    a = reinterpret_cast<const char*>(scratch);
    lda = outer_bs_a * inner_bs * sizeof(float);
  }

  for (int i = 0; i < outer_bs_a; ++i) {
    for (int j = 0; j < outer_bs_b; ++j) {
      TransposeMicroKernel<T, inner_bs>::Apply(
          a + inner_bs * j * lda + i * inner_bs * sizeof(T), lda,
          b + inner_bs * i * ldb + j * inner_bs * sizeof(T), ldb);
    }
  }
}

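// To see the macrokernel's addressing at work: with T=float, inner_bs=8, and
// outer_bs_a=outer_bs_b=2, one MacroKernel call performs four 8x8 microkernel
// transposes covering a 16x16 element block. The microkernel at block indices
// (i, j) reads its tile from a + 8*j*lda + 8*i*sizeof(float) and writes the
// transposed tile to b + 8*i*ldb + 8*j*sizeof(float), swapping the roles of
// the two block indices.
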
// Transpose() is a driver function that implements a multidimensional loop
// nest by iterating over the linked Node data structure.
template <typename T, int inner_bs,
          TransposePlan::Transformation transformation>
void Transpose(const char* __restrict a, int outer_bs_a, char* __restrict b,
               int outer_bs_b, TransposePlan::Node const* __restrict node,
               void* __restrict scratch) {
  DVLOG(10) << "Transpose " << outer_bs_a << " " << outer_bs_b;
  DCHECK_GT(outer_bs_a, 0);
  DCHECK_GT(outer_bs_b, 0);
  const int64_t start = node->start;
  const int64_t end = node->end;
  const int64_t stop = node->end - (node->inc - 1);
  const int64_t lda = node->lda;
  const int64_t ldb = node->ldb;
  const int64_t inc = node->inc;
  TransposePlan::Node const* next_node = node + 1;
  if (next_node->inc < 0) {
    // This is the last loop in the nested loops. The next node is a sentinel
    // plan node that describes how to invoke the macrokernels.

    const int64_t lda_block = next_node->lda;
    const int64_t ldb_block = next_node->ldb;
    int64_t i;
    for (i = start; i < stop; i += inc) {
      MacroKernel<T, inner_bs, transformation>(a + i * lda, lda_block,
                                               outer_bs_a, b + i * ldb,
                                               ldb_block, outer_bs_b, scratch);
    }
    // Handle trailing elements that didn't fit in a complete macrokernel.
    // Only the innermost dimensions have non-trivial outer_bs blocking.
    if (i < end) {
      DCHECK_EQ(node->trailing_tile_next_node_inc, 0);
      DCHECK(node->is_inner_dim_in_a || node->is_inner_dim_in_b);
      if (node->is_inner_dim_in_a) {
        outer_bs_a = (end - i) / inner_bs;
        if (outer_bs_a > 0) {
          MacroKernel<T, inner_bs, transformation>(
              a + i * lda, lda_block, outer_bs_a, b + i * ldb, ldb_block,
              outer_bs_b, scratch);
          i += outer_bs_a * inner_bs;
        }
        // If there are still trailing elements left over that don't fit in the
        // inner block size, handle them via an unvectorized transpose.
        if (i < end) {
          MacroKernel<T, 1, transformation>(a + i * lda, lda_block, end - i,
                                            b + i * ldb, ldb_block,
                                            outer_bs_b * inner_bs, scratch);
        }
      } else if (node->is_inner_dim_in_b) {
        outer_bs_b = (end - i) / inner_bs;
        if (outer_bs_b > 0) {
          MacroKernel<T, inner_bs, transformation>(
              a + i * lda, lda_block, outer_bs_a, b + i * ldb, ldb_block,
              outer_bs_b, scratch);
          i += outer_bs_b * inner_bs;
        }
        if (i < end) {
          MacroKernel<T, 1, transformation>(a + i * lda, lda_block,
                                            outer_bs_a * inner_bs, b + i * ldb,
                                            ldb_block, end - i, scratch);
        }
      }
    } else if (node->trailing_tile_next_node_inc) {
      // Handle the case where there is a trailing partial tile. We know
      // inc == 1 for this case, so the loop above has already left `a` and `b`
      // pointing to the start of the tile. We just need to use the alternate
      // trailing_next_node to process the interior of the tile.
      DCHECK_EQ(inc, 1);
      TransposePlan::Node const* trailing_next_node =
          node + node->trailing_tile_next_node_inc;
      if (trailing_next_node->inc < 0) {
        const int64_t lda_block = trailing_next_node->lda;
        const int64_t ldb_block = trailing_next_node->ldb;
        MacroKernel<T, inner_bs, transformation>(
            a + i * lda, lda_block, outer_bs_a, b + i * ldb, ldb_block,
            outer_bs_b, scratch);
      } else {
        Transpose<T, inner_bs, transformation>(a + i * lda, outer_bs_a,
                                               b + i * ldb, outer_bs_b,
                                               trailing_next_node, scratch);
      }
    }
  } else {
    // This is not the last loop in the nested loops. Recursively visit the
    // inner loops. Structurally this code is identical to the previous case,
    // but we call Transpose() recursively instead of MacroKernel().
    int64_t i;
    for (i = start; i < stop; i += inc) {
      Transpose<T, inner_bs, transformation>(
          a + i * lda, outer_bs_a, b + i * ldb, outer_bs_b, next_node, scratch);
    }
    if (i < end) {
      DCHECK_EQ(node->trailing_tile_next_node_inc, 0);
      DCHECK(node->is_inner_dim_in_a || node->is_inner_dim_in_b);
      if (node->is_inner_dim_in_a) {
        outer_bs_a = (end - i) / inner_bs;
        if (outer_bs_a > 0) {
          Transpose<T, inner_bs, transformation>(a + i * lda, outer_bs_a,
                                                 b + i * ldb, outer_bs_b,
                                                 next_node, scratch);
          i += outer_bs_a * inner_bs;
        }
        if (i < end) {
          Transpose<T, 1, transformation>(a + i * lda, end - i, b + i * ldb,
                                          outer_bs_b * inner_bs, next_node,
                                          scratch);
        }
      } else if (node->is_inner_dim_in_b) {
        outer_bs_b = (end - i) / inner_bs;
        if (outer_bs_b > 0) {
          Transpose<T, inner_bs, transformation>(a + i * lda, outer_bs_a,
                                                 b + i * ldb, outer_bs_b,
                                                 next_node, scratch);
          i += outer_bs_b * inner_bs;
        }
        if (i < end) {
          Transpose<T, 1, transformation>(a + i * lda, outer_bs_a * inner_bs,
                                          b + i * ldb, end - i, next_node,
                                          scratch);
        }
      }
    } else if (node->trailing_tile_next_node_inc) {
      TransposePlan::Node const* trailing_next_node =
          node + node->trailing_tile_next_node_inc;
      if (trailing_next_node->inc < 0) {
        const int64_t lda_block = trailing_next_node->lda;
        const int64_t ldb_block = trailing_next_node->ldb;
        MacroKernel<T, inner_bs, transformation>(
            a + i * lda, lda_block, outer_bs_a, b + i * ldb, ldb_block,
            outer_bs_b, scratch);
      } else {
        Transpose<T, inner_bs, transformation>(a + i * lda, outer_bs_a,
                                               b + i * ldb, outer_bs_b,
                                               trailing_next_node, scratch);
      }
    }
  }
}

template <typename T>
void TransposeConstStride1(const char* __restrict a, char* __restrict b,
                           TransposePlan::Node const* __restrict node) {
  a += node[0].start * node[0].lda;
  b += node[0].start * node[0].ldb;
  if (node[0].is_inner_dim_in_a) {
    int64_t num_bytes = (node->end - node->start) * sizeof(T);
    std::memcpy(b, a, num_bytes);
  } else if (node[1].is_inner_dim_in_a) {
    int64_t offset_a = node[1].start * node[1].lda;
    int64_t offset_b = node[1].start * node[1].ldb;
    int64_t num_bytes = (node[1].end - node[1].start) * sizeof(T);
    a += offset_a;
    b += offset_b;
    for (int64_t i = node[0].start; i < node[0].end; ++i) {
      std::memcpy(b, a, num_bytes);
      a += node[0].lda;
      b += node[0].ldb;
    }
    if (node[0].trailing_tile_next_node_inc) {
      TransposeConstStride1<T>(a - offset_a, b - offset_b,
                               node + node[0].trailing_tile_next_node_inc);
    }
  } else if (node[2].is_inner_dim_in_a) {
    int64_t num_bytes = (node[2].end - node[2].start) * sizeof(T);
    int64_t offset_a1 = node[1].start * node[1].lda;
    int64_t offset_b1 = node[1].start * node[1].ldb;
    int64_t offset_a2 = node[2].start * node[2].lda;
    int64_t offset_b2 = node[2].start * node[2].ldb;
    a += offset_a1 + offset_a2;
    b += offset_b1 + offset_b2;
    for (int64_t i = node[0].start; i < node[0].end; ++i) {
      const char* a1 = a;
      char* b1 = b;
      for (int64_t j = node[1].start; j < node[1].end; ++j) {
        std::memcpy(b1, a1, num_bytes);
        a1 += node[1].lda;
        b1 += node[1].ldb;
      }
      if (node[1].trailing_tile_next_node_inc) {
        TransposeConstStride1<T>(
            a1 - offset_a2, b1 - offset_b2,
            &node[1] + node[1].trailing_tile_next_node_inc);
      }
      a += node[0].lda;
      b += node[0].ldb;
    }
    if (node[0].trailing_tile_next_node_inc) {
      TransposeConstStride1<T>(a - offset_a1 - offset_a2,
                               b - offset_b1 - offset_b2,
                               node + node[0].trailing_tile_next_node_inc);
    }
  } else {
    for (int64_t i = node[0].start; i < node[0].end; ++i) {
      const char* a1 = a + node[1].start * node[1].lda;
      char* b1 = b + node[1].start * node[1].ldb;
      for (int64_t j = node[1].start; j < node[1].end; ++j) {
        TransposeConstStride1<T>(a1, b1, node + 2);
        a1 += node[1].lda;
        b1 += node[1].ldb;
      }
      if (node[1].trailing_tile_next_node_inc) {
        TransposeConstStride1<T>(
            a1, b1, &node[1] + node[1].trailing_tile_next_node_inc);
      }
      a += node[0].lda;
      b += node[0].ldb;
    }
    if (node[0].trailing_tile_next_node_inc) {
      TransposeConstStride1<T>(a, b,
                               node + node[0].trailing_tile_next_node_inc);
    }
  }
}

template <typename T, TransposePlan::Transformation transformation>
void TransposePlan::ExecuteTyped(const char* a, char* b,
                                 absl::Span<Node const> nodes) const {
  if (inner_kernel_is_memcpy_) {
    DCHECK(transformation_ == Transformation::kNone);
    TransposeConstStride1<T>(a, b, nodes.data());
  } else {
    std::unique_ptr<char[]> scratch;
    if (scratch_size_ > 0) {
      scratch.reset(new char[scratch_size_]);
    }
    switch (inner_block_elems_) {
      case 1:
        if (nodes.size() > 1) {
          Transpose<T, 1, transformation>(a, outer_block_elems_a_, b,
                                          outer_block_elems_b_, nodes.data(),
                                          scratch.get());
        } else {
          MacroKernel<T, 1, transformation>(
              a, nodes.back().lda, outer_block_elems_a_, b, nodes.back().ldb,
              outer_block_elems_b_, scratch.get());
        }
        break;
      case 2:
        if (nodes.size() > 1) {
          Transpose<T, 2, transformation>(a, outer_block_elems_a_, b,
                                          outer_block_elems_b_, nodes.data(),
                                          scratch.get());
        } else {
          MacroKernel<T, 2, transformation>(
              a, nodes.back().lda, outer_block_elems_a_, b, nodes.back().ldb,
              outer_block_elems_b_, scratch.get());
        }
        break;
      case 4:
        if (nodes.size() > 1) {
          Transpose<T, 4, transformation>(a, outer_block_elems_a_, b,
                                          outer_block_elems_b_, nodes.data(),
                                          scratch.get());
        } else {
          MacroKernel<T, 4, transformation>(
              a, nodes.back().lda, outer_block_elems_a_, b, nodes.back().ldb,
              outer_block_elems_b_, scratch.get());
        }
        break;
      case 8:
        if (nodes.size() > 1) {
          Transpose<T, 8, transformation>(a, outer_block_elems_a_, b,
                                          outer_block_elems_b_, nodes.data(),
                                          scratch.get());
        } else {
          MacroKernel<T, 8, transformation>(
              a, nodes.back().lda, outer_block_elems_a_, b, nodes.back().ldb,
              outer_block_elems_b_, scratch.get());
        }
        break;
      case 16:
        if (nodes.size() > 1) {
          Transpose<T, 16, transformation>(a, outer_block_elems_a_, b,
                                           outer_block_elems_b_, nodes.data(),
                                           scratch.get());
        } else {
          MacroKernel<T, 16, transformation>(
              a, nodes.back().lda, outer_block_elems_a_, b, nodes.back().ldb,
              outer_block_elems_b_, scratch.get());
        }
        break;
      default:
        LOG(FATAL) << "Invalid inner_block_size " << inner_block_elems_;
    }
  }
}

struct uint128 {
  uint64_t lo;
  uint64_t hi;
};
static_assert(sizeof(uint128) == 16, "uint128 should be 16 bytes in size");

void TransposePlan::Execute(
    const void* a, void* b,
    const std::function<void(std::function<void(void)>)>& schedule_work) const {
  if (num_elems_ == 0) {
    return;
  }

  const char* ac = static_cast<const char*>(a);
  char* bc = static_cast<char*>(b);

  auto execute_by_type = [&](absl::Span<Node const> nodes) {
    switch (elem_size_in_bytes_) {
      case 1:
        ExecuteTyped<uint8_t, Transformation::kNone>(ac, bc, nodes);
        break;
      case 2:
        ExecuteTyped<uint16_t, Transformation::kNone>(ac, bc, nodes);
        break;
      case 4:
        if (transformation_ == Transformation::kNone) {
          ExecuteTyped<uint32_t, Transformation::kNone>(ac, bc, nodes);
        } else {
          DCHECK(transformation_ == Transformation::kF64ToEf57);
          ExecuteTyped<uint32_t, Transformation::kF64ToEf57>(ac, bc, nodes);
        }
        break;
      case 8:
        ExecuteTyped<uint64_t, Transformation::kNone>(ac, bc, nodes);
        break;
      case 16:
        ExecuteTyped<uint128, Transformation::kNone>(ac, bc, nodes);
        break;
      default:
        LOG(FATAL) << "Unimplemented element size " << elem_size_in_bytes_;
    }
  };

  if (!schedule_work || nodes_.size() <= 1) {
    for (const auto& nodes : nodes_) {
      execute_by_type(nodes);
    }
  } else {
    absl::BlockingCounter counter(nodes_.size());
    for (absl::Span<Node const> nodes : nodes_) {
      schedule_work([&, nodes]() {
        tensorflow::profiler::TraceMe traceme("Transpose::Execute",
                                              /*level=*/2);
        execute_by_type(nodes);
        counter.DecrementCount();
      });
    }
    counter.Wait();
  }
}

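// For illustration, a caller with a thread pool might run a plan like the
// sketch below ('pool' is a hypothetical thread pool with a Schedule method;
// it is not part of this file):
//
//   plan->Execute(input, output, [&](std::function<void(void)> fn) {
//     pool.Schedule(std::move(fn));
//   });
//
// Passing a null schedule_work function instead runs each per-thread node
// list sequentially on the calling thread.
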
// Everything above this point pertains to executing plans.
// Everything below this point pertains to building plans.

TransposePlan::TransposePlan() = default;
TransposePlan::~TransposePlan() = default;

static void ComputeStrides(
    int64_t elem_size_in_bytes, absl::Span<const int64_t> dims,
    absl::Span<const int64_t> tiling,
    absl::InlinedVector<int64_t, 4>& outer_tile_strides,
    absl::InlinedVector<int64_t, 4>& inner_tile_strides) {
  inner_tile_strides.resize(dims.size());
  int64_t acc = elem_size_in_bytes;
  for (int d = static_cast<int>(dims.size()) - 1; d >= 0; --d) {
    inner_tile_strides[d] = acc;
    acc *= tiling[d];
  }
  outer_tile_strides.resize(dims.size());
  for (int d = static_cast<int>(dims.size()) - 1; d >= 0; --d) {
    outer_tile_strides[d] = acc;
    acc *= CeilOfRatio(dims[d], tiling[d]);
  }
}

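// A worked example of ComputeStrides: with elem_size_in_bytes=4, dims={4, 5},
// and tiling={1, 2}, the inner (within-tile) strides come out to {8, 4} (the
// tiled minor dimension packs 2 elements per tile), and the outer
// (tile-to-tile) strides to {24, 8}, since the minor dimension has
// CeilOfRatio(5, 2) = 3 tiles of 8 bytes each. With the trivial tiling {1, 1}
// this reduces to the ordinary row-major byte strides {20, 4}.
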
void TransposePlan::RemoveTrivialDimensions(
    absl::InlinedVector<int64_t, 4>& a_dims,
    absl::InlinedVector<int64_t, 4>& permutation,
    absl::InlinedVector<int64_t, 4>& lda,
    absl::InlinedVector<int64_t, 4>& lda_tile,
    absl::InlinedVector<int64_t, 4>& a_tiling,
    absl::InlinedVector<int64_t, 4>& b_tiling) {
  int ndim = a_dims.size();
  // How many positions has the i-th dimension of 'a' been moved to the left?
  // -1 if the dimension is to be removed.
  std::vector<int> shift(ndim);
  absl::InlinedVector<int64_t, 4> updated_a_dims;
  absl::InlinedVector<int64_t, 4> updated_lda;
  absl::InlinedVector<int64_t, 4> updated_lda_tile;
  absl::InlinedVector<int64_t, 4> updated_a_tiling;
  updated_a_dims.reserve(ndim);
  updated_lda.reserve(ndim);
  updated_lda_tile.reserve(ndim);
  updated_a_tiling.reserve(ndim);
  std::vector<int64_t> inv_permutation = InversePermutation(permutation);
  for (int a_dim = 0; a_dim < ndim; ++a_dim) {
    int b_dim = inv_permutation[a_dim];
    // A dimension is trivial if it has size 1 and is not tiled.
    if (a_dims[a_dim] == 1 && a_tiling[a_dim] == 1 && b_tiling[b_dim] == 1) {
      shift[a_dim] = -1;
    } else {
      updated_a_dims.push_back(a_dims[a_dim]);
      updated_lda.push_back(lda[a_dim]);
      updated_lda_tile.push_back(lda_tile[a_dim]);
      updated_a_tiling.push_back(a_tiling[a_dim]);
      shift[a_dim] = a_dim + 1 - updated_a_dims.size();
    }
  }

  // Updates the permutation and tiling of b.
  absl::InlinedVector<int64_t, 4> updated_permutation;
  absl::InlinedVector<int64_t, 4> updated_b_tiling;
  updated_permutation.reserve(updated_a_dims.size());
  updated_b_tiling.reserve(updated_a_dims.size());
  for (int b_dim = 0; b_dim < ndim; ++b_dim) {
    int a_dim = permutation[b_dim];
    if (shift[a_dim] >= 0) {
      updated_permutation.push_back(a_dim - shift[a_dim]);
      updated_b_tiling.push_back(b_tiling[b_dim]);
    }
  }

  DCHECK(IsPermutation(updated_permutation));
  a_dims = std::move(updated_a_dims);
  permutation = std::move(updated_permutation);
  lda = std::move(updated_lda);
  lda_tile = std::move(updated_lda_tile);
  a_tiling = std::move(updated_a_tiling);
  b_tiling = std::move(updated_b_tiling);
}

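// A worked example of RemoveTrivialDimensions: with a_dims={1, 4, 5} (all
// untiled) and permutation={2, 0, 1} (mapping dimensions of b to a, so
// b_dims={5, 1, 4}), dimension 0 of 'a' is trivial. After removal,
// a_dims={4, 5}, the surviving dimensions of 'a' shift left by one, and the
// permutation becomes {1, 0} (b_dims={5, 4}).
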
void TransposePlan::CoalesceDimensions(
    absl::InlinedVector<int64_t, 4>& a_dims,
    absl::InlinedVector<int64_t, 4>& permutation,
    absl::InlinedVector<int64_t, 4>& lda,
    absl::InlinedVector<int64_t, 4>& lda_tile,
    absl::InlinedVector<int64_t, 4>& a_tiling,
    absl::InlinedVector<int64_t, 4>& b_tiling) {
  int ndim = a_dims.size();
  // How many positions has the i-th dimension of 'a' been moved to the left?
  // -1 if the dimension is to be removed.
  std::vector<int> shift(ndim, 0);
  absl::InlinedVector<int64_t, 4> updated_a_dims;
  absl::InlinedVector<int64_t, 4> updated_lda;
  absl::InlinedVector<int64_t, 4> updated_lda_tile;
  absl::InlinedVector<int64_t, 4> updated_a_tiling;
  updated_a_dims.reserve(ndim);
  updated_lda.reserve(ndim);
  updated_lda_tile.reserve(ndim);
  updated_a_tiling.reserve(ndim);
  std::vector<int64_t> inv_permutation = InversePermutation(permutation);
  for (int a_dim = 0; a_dim < ndim; ++a_dim) {
    // We can coalesce two dimensions if they appear consecutively
    // in both the input dimensions and the output dimensions, and the stride
    // of the outer dimension is the usual multiple of the inner dimension.
    if (a_dim > 0 && inv_permutation[a_dim - 1] + 1 == inv_permutation[a_dim] &&
        lda[a_dim - 1] == lda[a_dim] * a_dims[a_dim] &&
        a_tiling[a_dim - 1] == 1 && a_tiling[a_dim] == 1 &&
        b_tiling[inv_permutation[a_dim]] == 1 &&
        b_tiling[inv_permutation[a_dim - 1]] == 1) {
      updated_a_dims.back() *= a_dims[a_dim];
      updated_lda.back() = lda[a_dim];
      shift[a_dim] = -1;
    } else {
      updated_a_dims.push_back(a_dims[a_dim]);
      updated_lda.push_back(lda[a_dim]);
      updated_lda_tile.push_back(lda_tile[a_dim]);
      updated_a_tiling.push_back(a_tiling[a_dim]);
      shift[a_dim] = a_dim + 1 - updated_a_dims.size();
    }
  }

  // Updates the permutation.
  absl::InlinedVector<int64_t, 4> updated_permutation;
  absl::InlinedVector<int64_t, 4> updated_b_tiling;
  updated_permutation.reserve(updated_a_dims.size());
  updated_b_tiling.reserve(updated_a_dims.size());
  for (int b_dim = 0; b_dim < ndim; ++b_dim) {
    int a_dim = permutation[b_dim];
    if (shift[a_dim] >= 0) {
      updated_permutation.push_back(a_dim - shift[a_dim]);
      updated_b_tiling.push_back(b_tiling[b_dim]);
    }
  }
  DCHECK(IsPermutation(updated_permutation));
  a_dims = std::move(updated_a_dims);
  permutation = std::move(updated_permutation);
  lda = std::move(updated_lda);
  lda_tile = std::move(updated_lda_tile);
  a_tiling = std::move(updated_a_tiling);
  b_tiling = std::move(updated_b_tiling);
}

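// A worked example of CoalesceDimensions: take a_dims={2, 3, 5}, contiguous
// strides lda={60, 20, 4} (4-byte elements), and permutation={2, 0, 1}, so
// that dimensions 0 and 1 of 'a' are also adjacent in 'b'. Since
// lda[0] == lda[1] * a_dims[1], dimensions 0 and 1 coalesce into a single
// dimension of size 6, giving a_dims={6, 5}, lda={20, 4}, and
// permutation={1, 0}.
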
int64_t TransposePlan::InputNumElems() const {
  int64_t size = 1;
  for (size_t i = 0; i < a_dims_.size(); ++i) {
    size *= RoundUpTo(a_dims_[i], a_tiling_[i]);
  }
  return size;
}

int64_t TransposePlan::OutputNumElems() const {
  int64_t size = 1;
  for (size_t i = 0; i < a_dims_.size(); ++i) {
    size *= RoundUpTo(a_dims_[permutation_[i]], b_tiling_[i]);
  }
  return size;
}

// Parses and validates a tiling specification, and populates `tiling`.
static Status ParseTilingSpecification(
    int ndim, absl::Span<int64_t const> tiling_spec,
    absl::InlinedVector<int64_t, 4>& tiling) {
  tiling.resize(ndim, 1);
  if (tiling_spec.size() > ndim) {
    return InvalidArgument(
        "Tiling (%s) must have at most as many dimensions as the array (%d)",
        absl::StrJoin(tiling_spec, ","), ndim);
  }
  if (absl::c_find_if(tiling_spec, [](int64_t d) { return d < 1; }) !=
      tiling_spec.end()) {
    return InvalidArgument("Tiling sizes (%s) must be >= 1",
                           absl::StrJoin(tiling_spec, ","));
  }
  int offset = ndim;
  offset -= tiling_spec.size();
  absl::c_copy(tiling_spec, tiling.begin() + offset);
  return OkStatus();
}

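// For example, a tiling specification applies to the trailing dimensions of
// the array: with ndim=3 and tiling_spec={2, 4}, the populated tiling is
// {1, 2, 4}; an empty specification yields the trivial tiling {1, 1, 1}.
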
// Helper function that builds a plan.
void TransposePlan::BuildPlanNodes(
    absl::Span<int64_t const> inverse_permutation, int thread_id,
    std::vector<TransposePlan::Node>& nodes) {
  VLOG(8) << "Before plan build: " << ToString();
  const int ndim = a_dims_.size();
  DCHECK_GT(ndim, 0);
  const int pos_stride1a = ndim - 1;
  const int pos_stride1b_in_a = permutation_.back();
  const int pos_stride1a_in_b = inverse_permutation[pos_stride1a];

  // We build plans in a depth-first order, visiting loops from outermost to
  // innermost. We use a stack (depth-first) order to handle trailing partial
  // tiles, which we "come back to" after handling the non-trailing case.
  struct Agendum {
    // The ID of the loop to visit in loop_order_.
    int loop_id;
    // The parent node ID whose trailing tile should be made to point to this
    // node.
    int parent_node_id;

    // The number of parallel tasks available to run this loop and its
    // successors.
    int num_tasks_at_loop;

    // The ID number of the current thread in the tasks at this loop.
    int task_id_at_loop;

    // For which dimensions of `a` is this loop visiting the interior of a
    // trailing partial tile?
    absl::InlinedVector<bool, 4> partial_tiles;
  };
  std::stack<Agendum> agenda;

  int total_tasks =
      absl::c_accumulate(loop_parallelism_, int{1}, std::multiplies<int>());

  agenda.push(Agendum{/*loop_id=*/0, /*parent_node_id=*/-1,
                      /*num_tasks_at_loop=*/total_tasks,
                      /*task_id_at_loop=*/thread_id,
                      absl::InlinedVector<bool, 4>(ndim, false)});

  auto loop_has_trivial_iteration_space = [](const Node& node) {
    return node.start == 0 && node.start + node.inc == node.end;
  };

  while (!agenda.empty()) {
    Agendum agendum = std::move(agenda.top());
    agenda.pop();

    int node_id = static_cast<int>(nodes.size());
    if (agendum.parent_node_id >= 0) {
      // This is a trailing partial tile node; update the parent node to
      // point to it.
      nodes[agendum.parent_node_id].trailing_tile_next_node_inc =
          node_id - agendum.parent_node_id;
    }

    if (agendum.loop_id == loop_order_.size()) {
      // We've reached the end of the loop nest.
      DCHECK_EQ(agendum.num_tasks_at_loop, 1);
      // Transpose loops have a sentinel node, indicated by a negative `inc`
      // value, that describes the striding of the inner transpose kernel.
      if (!inner_kernel_is_memcpy_) {
        Node node;
        node.start = node.end = node.inc = -1;
        node.lda = a_tiling_[pos_stride1b_in_a] > 1
                       ? lda_tile_[pos_stride1b_in_a]
                       : lda_[pos_stride1b_in_a];
        node.ldb = b_tiling_[pos_stride1a_in_b] > 1
                       ? ldb_tile_[pos_stride1a_in_b]
                       : ldb_[pos_stride1a_in_b];
        nodes.push_back(node);
      }
      DCHECK(!(inner_kernel_is_memcpy_ && agendum.parent_node_id >= 0));
      continue;
    }

    const Loop& loop = loop_order_[agendum.loop_id];
    int a_dim = loop.dim_in_a;
    int b_dim = inverse_permutation[a_dim];
    DCHECK(a_tiling_[a_dim] == 1 || b_tiling_[b_dim] == 1 ||
           a_tiling_[a_dim] == b_tiling_[b_dim]);
    int64_t tile_size = std::max(a_tiling_[a_dim], b_tiling_[b_dim]);

    // Compute the number of tasks for the next loop iteration.
    int task_id_at_loop = agendum.task_id_at_loop;
    int num_tasks_at_loop =
        agendum.num_tasks_at_loop / loop_parallelism_[agendum.loop_id];
    int task_id_at_next_loop = task_id_at_loop % num_tasks_at_loop;

    if (loop.tile_interior) {
      // We are visiting the tile interior of a tiled dimension.
      bool partial = agendum.partial_tiles[a_dim];

      Node node;
      node.lda = a_tiling_[a_dim] > 1 ? lda_tile_[a_dim] : lda_[a_dim];
      node.ldb = b_tiling_[b_dim] > 1 ? ldb_tile_[b_dim] : ldb_[b_dim];
      node.inc = 1;
      node.is_inner_dim_in_a = (a_dim == pos_stride1a);
      node.is_inner_dim_in_b = (a_dim == pos_stride1b_in_a);
      if (node.is_inner_dim_in_a) {
        node.inc = inner_block_elems_ * outer_block_elems_a_;
      } else if (node.is_inner_dim_in_b) {
        node.inc = inner_block_elems_ * outer_block_elems_b_;
      }

      int task_id = task_id_at_loop / num_tasks_at_loop;
      int64_t size = partial ? a_dims_[a_dim] % tile_size : tile_size;
      int64_t num_iterations = CeilOfRatio(size, node.inc);
      int64_t num_iterations_per_task = CeilOfRatio<int64_t>(
          num_iterations, loop_parallelism_[agendum.loop_id]);
      node.start = std::min(size, task_id * num_iterations_per_task * node.inc);
      node.end =
          std::min(size, (task_id + 1) * num_iterations_per_task * node.inc);
      if (!loop_has_trivial_iteration_space(node) ||
          (inner_kernel_is_memcpy_ && node.is_inner_dim_in_a)) {
        nodes.push_back(node);
      }
      Agendum new_agendum;
      new_agendum.loop_id = agendum.loop_id + 1;
      new_agendum.parent_node_id = -1;
      new_agendum.task_id_at_loop = task_id_at_next_loop;
      new_agendum.num_tasks_at_loop = num_tasks_at_loop;
      new_agendum.partial_tiles = agendum.partial_tiles;
      agenda.push(std::move(new_agendum));
    } else {
      // We are either visiting an untiled dimension, or the loop that iterates
      // over tile exteriors.
      int task_id = task_id_at_loop / num_tasks_at_loop;
      int64_t num_complete_tiles = a_dims_[a_dim] / tile_size;
      bool has_partial_tile = (a_dims_[a_dim] % tile_size != 0);

      // If there is a trailing partial tile as well as complete tiles, handle
      // it as a trailer on the loop over complete tiles.
      bool has_trailing_plan_node = false;
      if (num_complete_tiles > 0 && has_partial_tile &&
          task_id == loop_parallelism_[agendum.loop_id] - 1) {
        Agendum new_agendum;
        new_agendum.loop_id = agendum.loop_id + 1;
        new_agendum.parent_node_id = node_id;
        new_agendum.task_id_at_loop = task_id_at_next_loop;
        new_agendum.num_tasks_at_loop = num_tasks_at_loop;
        new_agendum.partial_tiles = agendum.partial_tiles;
        new_agendum.partial_tiles[a_dim] = true;
        agenda.push(std::move(new_agendum));
        has_trailing_plan_node = true;
      }
      Node node;
      node.lda = lda_[a_dim] * tile_size / a_tiling_[a_dim];
      node.ldb = ldb_[b_dim] * tile_size / b_tiling_[b_dim];
      node.inc = 1;
      node.is_inner_dim_in_a = (tile_size == 1 && a_dim == ndim - 1);
      node.is_inner_dim_in_b = (tile_size == 1 && a_dim == pos_stride1b_in_a);
      if (node.is_inner_dim_in_a) {
        node.inc = inner_block_elems_ * outer_block_elems_a_;
      } else if (node.is_inner_dim_in_b) {
        node.inc = inner_block_elems_ * outer_block_elems_b_;
      }

      // If this tiled dimension consists only of a single partial tile, handle
      // it here; there's no point emitting a degenerate loop and a separate
      // path to handle the trailing tile.
      bool partial = num_complete_tiles == 0 && has_partial_tile;

      // Evenly divide the loop iterations amongst the threads.
      int64_t num_tiles = partial ? 1 : num_complete_tiles;
      int64_t num_iterations = CeilOfRatio(num_tiles, node.inc);
      int64_t num_iterations_per_task = CeilOfRatio<int64_t>(
          num_iterations, loop_parallelism_[agendum.loop_id]);
      node.start =
          std::min(num_tiles, task_id * num_iterations_per_task * node.inc);
      node.end = std::min(num_tiles,
                          (task_id + 1) * num_iterations_per_task * node.inc);
      // If this loop has a trivial iteration space, drop it.
      if (!loop_has_trivial_iteration_space(node) ||
          (inner_kernel_is_memcpy_ && node.is_inner_dim_in_a) ||
          has_trailing_plan_node) {
        nodes.push_back(node);
      }
      Agendum new_agendum;
      new_agendum.loop_id = agendum.loop_id + 1;
      new_agendum.parent_node_id = -1;
      new_agendum.task_id_at_loop = task_id_at_next_loop;
      new_agendum.num_tasks_at_loop = num_tasks_at_loop;
      new_agendum.partial_tiles = agendum.partial_tiles;
      new_agendum.partial_tiles[a_dim] = partial;
      agenda.push(std::move(new_agendum));
    }
  }
}

StatusOr<std::unique_ptr<TransposePlan>> TransposePlan::Create(
    size_t elem_size_in_bytes, absl::Span<int64_t const> dims,
    absl::Span<int64_t const> permutation,
    std::variant<Tiling, Striding> input_layout, Tiling output_tiling,
    Transformation transformation, int num_threads) {
  auto is_negative = [](int d) { return d < 0; };
  if (absl::c_find_if(dims, is_negative) != dims.end()) {
    return InvalidArgument("dims must be non-negative, got %s",
                           absl::StrJoin(dims, ","));
  }
  if (permutation.size() != dims.size()) {
    return InvalidArgument(
        "dims and permutation must have equal sizes, got %d and %d",
        dims.size(), permutation.size());
  }
  if (!IsPermutation(permutation)) {
    return InvalidArgument("permutation argument is not valid, got: %s",
                           absl::StrJoin(permutation, ","));
  }
  if (num_threads < 1) {
    return InvalidArgument("num_threads argument must be >= 1, got: %d",
                           num_threads);
  }

  int ndim = dims.size();

  auto plan = std::make_unique<TransposePlan>();
  plan->num_threads_requested_ = num_threads;
  plan->elem_size_in_bytes_ = elem_size_in_bytes;
  switch (elem_size_in_bytes) {
    case 1:
    case 2:
    case 4:
    case 8:
    case 16:
      break;
    default:
      return InvalidArgument("Unsupported elem_size_in_bytes=%d",
                             elem_size_in_bytes);
  }
  plan->num_elems_ = std::accumulate(dims.begin(), dims.end(), int64_t{1},
                                     std::multiplies<int64_t>());
  plan->original_a_dims_.resize(ndim);
  absl::c_copy(dims, plan->original_a_dims_.begin());
  plan->original_b_dims_ = Permute(dims, permutation);

  TF_RETURN_IF_ERROR(
      ParseTilingSpecification(ndim, output_tiling.tiling, plan->b_tiling_));

  // Handles strides.
  if (std::holds_alternative<Striding>(input_layout)) {
    absl::Span<int64_t const> input_strides_in_bytes =
        std::get<Striding>(input_layout).strides_in_bytes;
    if (input_strides_in_bytes.size() != dims.size()) {
      return InvalidArgument(
          "dims and input_strides_in_bytes must have equal sizes, got %d "
          "and %d",
          dims.size(), input_strides_in_bytes.size());
    }
    plan->original_a_strides_.resize(ndim);
    absl::c_copy(input_strides_in_bytes, plan->original_a_strides_.begin());
    // Sort the dimensions from slowest-varying (largest strides) to
    // fastest-varying (smallest strides).
    std::vector<int64_t> dim_order(ndim);
    absl::c_iota(dim_order, 0);

    auto cost = [&](int k) {
      int64_t stride = input_strides_in_bytes.at(k);
      // If there is a dimension with size equal to the element size, sort it
      // last. This ensures that we place any stride-1 dimension last.
      bool is_stride1 = stride == elem_size_in_bytes;
      // If there are multiple stride-1 dimensions, we'd prefer the one that
      // matches the stride-1 dimension of the output.
      // Failing that, we'd just prefer the largest stride-1 dimension last.
      bool is_trailing_dim_in_b = permutation.back() == k;

      // If we are applying ef57 conversion, we want a size-2 stride-1
      // dimension last.
      bool ef57_even =
          (is_stride1 && transformation == Transformation::kF64ToEf57 &&
           dims[k] == 2);

      return std::make_tuple(is_stride1, -std::abs(stride), ef57_even,
                             is_trailing_dim_in_b, dims[k]);
    };
    absl::c_stable_sort(dim_order,
                        [&cost](int i, int j) { return cost(i) < cost(j); });
    // dim_order maps new input dim -> old input dim, we need its inverse to
    // compute the new permutation.
    auto inv_dim_order = InversePermutation(dim_order);
    plan->lda_.reserve(ndim);
    plan->a_dims_.reserve(ndim);
    plan->permutation_.reserve(ndim);
    for (int i = 0; i < ndim; ++i) {
      plan->lda_.push_back(input_strides_in_bytes.at(dim_order[i]));
      plan->a_dims_.push_back(dims[dim_order[i]]);
      plan->permutation_.push_back(inv_dim_order[permutation[i]]);
    }
    plan->lda_tile_.resize(ndim, 1);
    plan->a_tiling_.resize(ndim, 1);
  } else {
    TF_RETURN_IF_ERROR(ParseTilingSpecification(
        ndim, std::get<Tiling>(input_layout).tiling, plan->a_tiling_));

    plan->a_dims_ = plan->original_a_dims_;
    plan->permutation_.resize(ndim);
    absl::c_copy(permutation, plan->permutation_.begin());
    ComputeStrides(plan->elem_size_in_bytes_, plan->a_dims_, plan->a_tiling_,
                   plan->lda_, plan->lda_tile_);
  }

  auto is_not_one = [](int64_t x) { return x != 1; };
  plan->a_is_tiled_ =
      (absl::c_find_if(plan->a_tiling_, is_not_one) != plan->a_tiling_.end());
  plan->b_is_tiled_ =
      (absl::c_find_if(plan->b_tiling_, is_not_one) != plan->b_tiling_.end());
  if (plan->a_is_tiled_ && plan->b_is_tiled_) {
    return Unimplemented(
        "Only one of the input and output may have a non-trivial tiling, "
        "got tilings: %s and %s",
        absl::StrJoin(plan->a_tiling_, ","),
        absl::StrJoin(plan->b_tiling_, ","));
  }

  plan->transformation_ = transformation;
  switch (transformation) {
    case Transformation::kNone:
      break;
    case Transformation::kF64ToEf57:
      if (elem_size_in_bytes != sizeof(float)) {
        return InvalidArgument(
            "EF57 conversion requires an element size of %d bytes, got %d",
            sizeof(float), elem_size_in_bytes);
      }
      if (plan->a_dims_.empty() || plan->a_dims_.back() % 2 != 0 ||
          plan->lda_.back() != sizeof(float)) {
        return InvalidArgument(
            "EF57 conversion requires a stride-%d dimension whose size is a "
            "multiple of 2",
            sizeof(float));
      }
  }

  plan->Initialize();
  VLOG(5) << plan->ToString();
  return plan;
}

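// For illustration, a minimal sketch of creating and running a plan for a 2D
// transpose of a contiguous float array (error handling elided; assumes
// default-constructed Tiling values, as declared in transpose.h, denote the
// trivial tiling):
//
//   auto plan_or = TransposePlan::Create(
//       sizeof(float), /*dims=*/{1024, 768}, /*permutation=*/{1, 0},
//       TransposePlan::Tiling{}, TransposePlan::Tiling{},
//       TransposePlan::Transformation::kNone, /*num_threads=*/1);
//   (*plan_or)->Execute(input, output, /*schedule_work=*/nullptr);
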
void TransposePlan::Initialize() {
  if (num_elems_ == 0) {
    return;
  }
  RemoveTrivialDimensions(a_dims_, permutation_, lda_, lda_tile_, a_tiling_,
                          b_tiling_);
  CoalesceDimensions(a_dims_, permutation_, lda_, lda_tile_, a_tiling_,
                     b_tiling_);

  // permutation maps dimensions of b to a
  // inverse_permutation maps dimensions of a to b
  std::vector<int64_t> inverse_permutation = InversePermutation(permutation_);

  int ndim = a_dims_.size();

  int64_t stride_pos1a =
      lda_.empty()
          ? -1
          : (a_tiling_[ndim - 1] > 1 ? lda_tile_[ndim - 1] : lda_[ndim - 1]);
  // We don't accept arbitrary stridings for B, so we know B always has a
  // stride-1 dimension innermost.

  // If the plan is 0-dimensional, or the innermost dimension of A does not
  // have stride 1, add a trivial size-1 dimension. The transpose kernels rely
  // on the presence of a stride-1 innermost dimension in the input.
  if (lda_.empty() || stride_pos1a != elem_size_in_bytes_) {
    int dim = static_cast<int>(a_dims_.size());
    permutation_.push_back(dim);
    inverse_permutation.push_back(dim);
    a_dims_.push_back(1);
    lda_.push_back(elem_size_in_bytes_);
    lda_tile_.push_back(1);
    a_tiling_.push_back(1);
    b_tiling_.push_back(1);
    ++ndim;
  }
  b_dims_ = Permute(a_dims_, permutation_);
  ComputeStrides(elem_size_in_bytes_, b_dims_, b_tiling_, ldb_, ldb_tile_);

  const int pos_stride1a = ndim - 1;
  const int pos_stride1b_in_a = permutation_.back();
  inner_kernel_is_memcpy_ = (pos_stride1b_in_a == pos_stride1a);

  loop_order_.reserve(ndim);
  for (int i = 0; i < ndim; ++i) {
    loop_order_.push_back(Loop{i, /*tile_interior=*/false});
    if (a_tiling_[i] != 1 || b_tiling_[inverse_permutation[i]] != 1) {
      loop_order_.push_back(Loop{i, /*tile_interior=*/true});
    }
  }

  // Bound the block sizes so they are smaller than the stride-1 dimension
  // size.
  int64_t a_stride1_size = std::max(
      a_tiling_[pos_stride1a], b_tiling_[inverse_permutation[pos_stride1a]]);
  if (a_stride1_size == 1) {
    a_stride1_size = a_dims_[pos_stride1a];
  } else {
    // If there's only one tile, we should use the dimension size.
    a_stride1_size = std::min(a_dims_[pos_stride1a], a_stride1_size);
  }
  int64_t b_stride1_size =
      std::max(a_tiling_[permutation_.back()], b_tiling_.back());
  if (b_stride1_size == 1) {
    b_stride1_size = b_dims_.back();
  } else {
    b_stride1_size = std::min(b_stride1_size, b_dims_.back());
  }

  if (inner_kernel_is_memcpy_) {
    inner_block_elems_ = -1;
    outer_block_elems_a_ = -1;
    outer_block_elems_b_ = -1;
  } else {
    // What are the smallest and largest block sizes for which we have a
    // vectorized kernel for this element size?
    int min_inner_block_elems;
    int max_inner_block_elems;
    switch (elem_size_in_bytes_) {
      case 1:
        min_inner_block_elems = 4;
        max_inner_block_elems = 16;
        break;
      case 2:
        min_inner_block_elems = 8;
        max_inner_block_elems = 8;
        break;
      case 4:
        min_inner_block_elems = 4;
        max_inner_block_elems = 8;
        break;
      case 8:
        min_inner_block_elems = 2;
        max_inner_block_elems = 4;
        break;
      case 16:
        min_inner_block_elems = 1;
        max_inner_block_elems = 1;
        break;
      default:
        LOG(FATAL) << "Unreachable: element size " << elem_size_in_bytes_;
    }
    inner_block_elems_ = max_inner_block_elems;
    while (inner_block_elems_ > std::min(a_stride1_size, b_stride1_size)) {
      inner_block_elems_ /= 2;
    }
    if (inner_block_elems_ < min_inner_block_elems) {
      // Size is smaller than our smallest vectorized kernel. Use the scalar
      // path.
      inner_block_elems_ = 1;
    }
    outer_block_elems_a_ = FloorOfRatio<int64_t>(
        std::min<int64_t>(16, a_stride1_size), inner_block_elems_);
    outer_block_elems_b_ = FloorOfRatio<int64_t>(
        std::min<int64_t>(16, b_stride1_size), inner_block_elems_);
  }

  // Loop order heuristic: try to make loops with small strides innermost.
  auto cost = [&](const Loop& l) {
    int64_t a_stride =
        std::abs((l.tile_interior && a_is_tiled_) ? lda_tile_[l.dim_in_a]
                                                  : lda_[l.dim_in_a]);
    bool is_inner_dim_in_a =
        (!a_is_tiled_ || l.tile_interior) && (l.dim_in_a == pos_stride1a);

    if (!inner_kernel_is_memcpy_ && is_inner_dim_in_a) {
      a_stride *= inner_block_elems_ * outer_block_elems_a_;
    }
    int b_dim = inverse_permutation[l.dim_in_a];
    int64_t b_stride =
        (l.tile_interior && b_is_tiled_) ? ldb_tile_[b_dim] : ldb_[b_dim];
    bool is_inner_dim_in_b =
        (!b_is_tiled_ || l.tile_interior) && (l.dim_in_a == pos_stride1b_in_a);
    if (!inner_kernel_is_memcpy_ && is_inner_dim_in_b) {
      b_stride *= inner_block_elems_ * outer_block_elems_b_;
    }
    // Add a small penalty to the input strides: given the choice between
    // consecutive writes and consecutive reads, we would prefer consecutive
    // writes.
    double penalty = 1.01;

    // If the inner kernel is a memcpy make sure the innermost loop is the
    // stride-1 dimension. This is a requirement of the memcpy kernel.
    bool dim_must_go_last =
        inner_kernel_is_memcpy_ && l.dim_in_a == pos_stride1a &&
        (l.tile_interior ||
         (a_tiling_[l.dim_in_a] == 1 && b_tiling_[b_dim] == 1));
    return std::make_tuple(dim_must_go_last,
                           inner_kernel_is_memcpy_ && l.tile_interior,
                           -std::min<double>(a_stride * penalty, b_stride));
  };
  absl::c_stable_sort(loop_order_, [&](const Loop& a, const Loop& b) {
    return cost(a) < cost(b);
  });
  // It is a required invariant of the loop order that tile interiors always
  // appear after the corresponding tile exterior. This is a consequence of the
  // heuristic above, because the tile interior must have smaller strides in
  // both input and output.

  // The stride-1 loop must be innermost for a memcpy loop.
  DCHECK(!inner_kernel_is_memcpy_ || loop_order_.back().dim_in_a == ndim - 1)
      << ToString();

  loop_parallelism_ = ChooseParallelizationStrategy(inverse_permutation);
  int num_threads =
      absl::c_accumulate(loop_parallelism_, int{1}, std::multiplies<int>());
  nodes_.resize(num_threads);
  for (int thread_id = 0; thread_id < num_threads; ++thread_id) {
    BuildPlanNodes(inverse_permutation, thread_id, nodes_[thread_id]);
  }

  switch (transformation_) {
    case Transformation::kNone:
      scratch_size_ = 0;
      break;
    case Transformation::kF64ToEf57:
      scratch_size_ = sizeof(float) * inner_block_elems_ * inner_block_elems_ *
                      outer_block_elems_a_ * outer_block_elems_b_;
      DCHECK(!inner_kernel_is_memcpy_);
      break;
  }
}

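// A worked example of the block-size selection above: for 4-byte elements
// with stride-1 dimension sizes of at least 16 in both input and output,
// inner_block_elems_ is 8 (the largest vectorized kernel for this element
// size) and outer_block_elems_a_ = outer_block_elems_b_ = 16 / 8 = 2, so each
// macrokernel moves a 16x16 element block. If the stride-1 dimension had size
// 3 instead, halving would push the inner block below min_inner_block_elems
// and the plan would fall back to the scalar (inner_block_elems_ = 1) path.
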
ChooseParallelizationStrategy(absl::Span<int64_t const> inverse_permutation)1211 std::vector<int> TransposePlan::ChooseParallelizationStrategy(
1212     absl::Span<int64_t const> inverse_permutation) {
  std::vector<int> parallelism;
  int available_parallelism = num_threads_requested_;
  parallelism.reserve(loop_order_.size());

  int ndim = permutation_.size();
  const int pos_stride1a = ndim - 1;
  const int pos_stride1b_in_a = permutation_.back();
  // Compute the number of iterations in `loop`.
  auto loop_iterations = [&](const Loop& loop) {
    int a_dim = loop.dim_in_a;
    int b_dim = inverse_permutation[a_dim];
    int64_t tile_size = std::max(a_tiling_[a_dim], b_tiling_[b_dim]);
    int64_t size = loop.tile_interior
                       ? tile_size
                       : (CeilOfRatio(a_dims_[loop.dim_in_a], tile_size));
    if (!inner_kernel_is_memcpy_ && (loop.tile_interior || tile_size == 1)) {
      if (loop.dim_in_a == pos_stride1a) {
        size = CeilOfRatio<int64_t>(size,
                                    inner_block_elems_ * outer_block_elems_a_);
      } else if (loop.dim_in_a == pos_stride1b_in_a) {
        size = CeilOfRatio<int64_t>(size,
                                    inner_block_elems_ * outer_block_elems_b_);
      }
    }
    return size;
  };
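  // E.g., for a non-memcpy plan, an untiled stride-1 dimension of size 1024
  // with inner_block_elems_ = 8 and outer_block_elems_a_ = 4 is walked in
  // macrokernel-sized steps of 32 elements, giving
  // CeilOfRatio(1024, 32) = 32 iterations.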

  // Estimate the number of bytes each iteration of each loop processes.
  absl::InlinedVector<int64_t, 4> work_in_bytes(loop_order_.size());
  int64_t acc = elem_size_in_bytes_;
  if (!inner_kernel_is_memcpy_) {
    acc *= inner_block_elems_ * inner_block_elems_ * outer_block_elems_a_ *
           outer_block_elems_b_;
  }
  auto work_it = work_in_bytes.rbegin();
  for (auto it = loop_order_.rbegin(); it != loop_order_.rend(); ++it) {
    *work_it++ = acc;
    acc *= loop_iterations(*it);
  }
  VLOG(7) << "Per-loop iteration work in bytes: "
          << absl::StrJoin(work_in_bytes, ",");
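  // E.g., with two loops of 16 and 32 iterations over a 4096-byte block, the
  // reverse scan yields work_in_bytes = {32 * 4096, 4096}: one iteration of a
  // loop performs all of the work of the loops nested inside it.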

  // Heuristic that attempts to parallelize the outermost loops, down to a
  // minimum per-thread number of bytes processed.
  for (size_t i = 0; i < loop_order_.size(); ++i) {
    const Loop& loop = loop_order_[i];
    CHECK_GE(available_parallelism, 1);
    int64_t iterations = loop_iterations(loop);
    const int kMinBytesPerThread =
        inner_kernel_is_memcpy_ ? (1 << 20) : (1 << 26);
    int64_t min_iterations_per_thread =
        CeilOfRatio<int64_t>(kMinBytesPerThread, work_in_bytes[i]);
    int64_t parallel_work = CeilOfRatio(iterations, min_iterations_per_thread);
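    // E.g., for a memcpy kernel (kMinBytesPerThread = 1 MiB) where each
    // iteration of this loop moves 64 KiB, min_iterations_per_thread = 16, so
    // a loop with 128 iterations can usefully occupy at most 8 threads.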

    VLOG(8) << "iterations=" << iterations << " parallel_work=" << parallel_work
            << " available_parallelism=" << available_parallelism;
    if (parallel_work >= available_parallelism) {
      parallelism.push_back(available_parallelism);
      available_parallelism = 1;
    } else {
      parallelism.push_back(parallel_work);
      available_parallelism /= parallel_work;
    }
  }
  return parallelism;
}

std::string TransposePlan::ToString() const {
  std::string nodes_str = absl::StrJoin(
      nodes_, "\n", [](std::string* out, absl::Span<Node const> thread_nodes) {
        absl::StrAppend(
            out, "thread:\n",
            absl::StrJoin(
                thread_nodes, "\n", [](std::string* out, const Node& node) {
                  absl::StrAppendFormat(
                      out,
                      "    "
                      "Node(start=%d,end=%d,inc=%d,lda=%"
                      "d,ldb=%d,next_trailing=%d,inner_a=%s,inner_b=%s)",
                      node.start, node.end, node.inc, node.lda, node.ldb,
                      node.trailing_tile_next_node_inc,
                      node.is_inner_dim_in_a ? "y" : "n",
                      node.is_inner_dim_in_b ? "y" : "n");
                }));
      });
  auto format_loop_order = [](std::string* out, const Loop& loop) {
    return absl::StrAppend(out, loop.dim_in_a,
                           loop.tile_interior ? "[tile]" : "");
  };
  std::string transformation_str;
  switch (transformation_) {
    case Transformation::kNone:
      transformation_str = "none";
      break;
    case Transformation::kF64ToEf57:
      transformation_str = "ef57";
      break;
  }
  return absl::StrFormat(
      "elem_size=%d a_dims=%s b_dims=%s permutation=%s a_tiling=%s b_tiling=%s "
      "lda=%s lda_tile=%s ldb=%s ldb_tile=%s loop_order=%s "
      "loop_parallelism=%s outer_bs=[%d,%d] inner_bs=%d "
      "transformation=%s scratch_size=%d\n"
      "nodes:\n%s",
      elem_size_in_bytes_, absl::StrJoin(a_dims_, ","),
      absl::StrJoin(Permute(a_dims_, permutation_), ","),
      absl::StrJoin(permutation_, ","), absl::StrJoin(a_tiling_, ","),
      absl::StrJoin(b_tiling_, ","), absl::StrJoin(lda_, ","),
      absl::StrJoin(lda_tile_, ","), absl::StrJoin(ldb_, ","),
      absl::StrJoin(ldb_tile_, ","),
      absl::StrJoin(loop_order_, ",", format_loop_order),
      absl::StrJoin(loop_parallelism_, ","), outer_block_elems_a_,
      outer_block_elems_b_, inner_block_elems_, transformation_str,
      scratch_size_, nodes_str);
}

struct TransposePlanCacheKey {
  size_t elem_size_in_bytes;
  absl::InlinedVector<int64_t, 4> dims;
  absl::InlinedVector<int64_t, 4> permutation;
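  // `input_layout` holds tile sizes when `input_layout_is_tiling` is true,
  // and strides in bytes otherwise.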
  bool input_layout_is_tiling;
  absl::InlinedVector<int64_t, 4> input_layout;
  absl::InlinedVector<int64_t, 4> output_tiling;
  TransposePlan::Transformation transformation;
  int num_threads;

  bool operator==(const TransposePlanCacheKey& other) const;
};

bool TransposePlanCacheKey::operator==(
    const TransposePlanCacheKey& other) const {
  return elem_size_in_bytes == other.elem_size_in_bytes && dims == other.dims &&
         permutation == other.permutation &&
         input_layout_is_tiling == other.input_layout_is_tiling &&
         input_layout == other.input_layout &&
         output_tiling == other.output_tiling &&
         transformation == other.transformation &&
         num_threads == other.num_threads;
}

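// Hashes every field that participates in operator==, so that equal keys hash
// equally, as absl::Hash requires.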
template <typename H>
H AbslHashValue(H h, const TransposePlanCacheKey& key) {
  return H::combine(std::move(h), key.elem_size_in_bytes,
                    key.input_layout_is_tiling, key.num_threads,
                    key.transformation, key.dims, key.permutation,
                    key.input_layout, key.output_tiling);
}

TransposePlanCache::TransposePlanCache(int capacity)
    : lru_list_(capacity), cache_(&lru_list_) {}

TransposePlanCache::~TransposePlanCache() = default;

StatusOr<std::shared_ptr<TransposePlan>> TransposePlanCache::GetOrCreate(
    size_t elem_size_in_bytes, absl::Span<int64_t const> dims,
    absl::Span<int64_t const> permutation,
    std::variant<TransposePlan::Tiling, TransposePlan::Striding> input_layout,
    TransposePlan::Tiling output_tiling,
    TransposePlan::Transformation transformation, int num_threads) {
  TransposePlanCacheKey key;
  key.elem_size_in_bytes = elem_size_in_bytes;
  key.dims.resize(dims.size());
  absl::c_copy(dims, key.dims.begin());
  key.permutation.resize(permutation.size());
  absl::c_copy(permutation, key.permutation.begin());
  if (std::holds_alternative<TransposePlan::Striding>(input_layout)) {
    absl::Span<int64_t const> input_strides_in_bytes =
        std::get<TransposePlan::Striding>(input_layout).strides_in_bytes;
    key.input_layout = absl::InlinedVector<int64_t, 4>(
        input_strides_in_bytes.begin(), input_strides_in_bytes.end());
    key.input_layout_is_tiling = false;
  } else {
    absl::Span<int64_t const> input_tiling =
        std::get<TransposePlan::Tiling>(input_layout).tiling;
    key.input_layout = absl::InlinedVector<int64_t, 4>(input_tiling.begin(),
                                                       input_tiling.end());
    key.input_layout_is_tiling = true;
  }
  key.output_tiling.resize(output_tiling.tiling.size());
  absl::c_copy(output_tiling.tiling, key.output_tiling.begin());
  key.transformation = transformation;
  key.num_threads = num_threads;
  return cache_.GetOrCreateIfAbsent(
      key,
      [&](const TransposePlanCacheKey& key)
          -> StatusOr<std::shared_ptr<TransposePlan>> {
        TF_ASSIGN_OR_RETURN(
            std::unique_ptr<TransposePlan> plan,
            TransposePlan::Create(elem_size_in_bytes, dims, permutation,
                                  input_layout, output_tiling, transformation,
                                  num_threads));
        return std::shared_ptr<TransposePlan>(std::move(plan));
      });
}
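
// Illustrative usage sketch only; `dims`, `permutation`, and the buffers are
// caller-supplied, and the exact Execute() signature is declared in the
// header, so treat this as an assumption rather than a reproduction of it:
//
//   TransposePlanCache cache(/*capacity=*/16);
//   TF_ASSIGN_OR_RETURN(
//       std::shared_ptr<TransposePlan> plan,
//       cache.GetOrCreate(sizeof(float), dims, permutation,
//                         TransposePlan::Tiling{}, TransposePlan::Tiling{},
//                         TransposePlan::Transformation::kNone,
//                         /*num_threads=*/1));
//   plan->Execute(input_buffer, output_buffer);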

}  // namespace xla