1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Sketch of algorithm:
17 // A good description can be found in the HPTT paper, which is linked from the
18 // header file.
19 //
20 // The algorithm is divided into two parts: plan creation, that chooses an
21 // execution plan for the transpose, and execution. Plans may be cached and
22 // reused multiple times.
23 //
24 // We use a two-level blocking scheme:
25 //
26 // The inner "microkernel" level is the unit of vectorization. A microkernel
27 // transposes a vector-unit sized tile, e.g., an 8x8 tile of floats for AVX2.
28 // The microkernels require one of the input dimensions and one of the
29 // output dimensions to have stride 1, while the other dimension in each case
30 // may have a non-trivial element stride. The kernels load N vectors of N
31 // stride-1 elements, and perform an NxN transpose, swapping the role of the
32 // stride-1 dimension. The N vectors are then written to the output. To perform
33 // a complete tensor transpose, we simply need to apply the microkernel over
34 // all blocks of the matrix.
35 //
36 // In the event that the stride-1 dimensions of the input and output are the
37 // same, we use a simpler kernel which is a memcpy().
38 //
39 // To improve cache locality, we use another level of blocking, namely
40 // "macrokernels". The outer "macrokernel" level is a block of, for example,
41 // 4x4 microkernels. Macrokernels are the basic unit of work of the loop nest
42 // plan. For dimensions that aren't exactly divisible by the macrokernel size,
43 // we repeatedly halve the kernel size for trailing elements. For dimensions
44 // that aren't exactly divisible by the microkernel size, we use a scalar
45 // transpose for the trailing elements.
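//
// As an illustrative sketch (hypothetical sizes; the real plan chooses kernel
// sizes and the loop order at plan-creation time), a 2D row-major float
// transpose with an 8x8 microkernel and a 4x4-microkernel macrokernel
// corresponds roughly to:
//
//   for (int64_t i = 0; i < n0; i += 4 * 8)     // macrokernel rows
//     for (int64_t j = 0; j < n1; j += 4 * 8)   // macrokernel columns
//       for (int bi = 0; bi < 4; ++bi)          // microkernels within a block
//         for (int bj = 0; bj < 4; ++bj)
//           Transpose8x8Tile(/*in=*/a + (i + 8 * bi) * n1 + (j + 8 * bj),
//                            /*out=*/b + (j + 8 * bj) * n0 + (i + 8 * bi));
//
// where Transpose8x8Tile stands in for the vectorized microkernel and partial
// blocks at the edges are handled as described above.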
46 //
47 // A transpose plan iterates over the array's index space, applying
48 // macrokernel-sized blocks. Any iteration order is possible; some orders give
49 // better locality than others. Currently we always use a default iteration
50 // order.
51 //
52 // A plan contains the data structures that describe how to perform a transpose.
53 // Plan creation chooses such things as a loop iteration order and kernel sizes.
54 // Plan creation also performs a handful of optimizations, such as
55 // coalescing adjacent dimensions that do not change their order and removing
56 // trivial dimensions.
57 //
58 // TODO(phawkins):
59 // * we don't incorporate a number of optimizations from HPTT, notably explicit
60 // prefetching and manual loop unrolling.
61 // * we could use vector-aligned stores for some arrays, which might improve
62 //   performance. We could also use nontemporal stores in the aligned
63 //   case.
64 // * we don't yet search for a good loop ordering. This probably matters less
65 // for arrays that fit entirely in cache.
66 // * we could do a better job of vectorizing where the stride-1 dimensions are
67 // small (e.g., inner dimensions of size [..., 3] are not uncommon in some
68 //   use cases).
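//
// Illustrative usage sketch (not a verbatim call site; the authoritative API
// is declared in transpose.h). Transpose a row-major [100, 200] float array
// into a [200, 100] array on the calling thread:
//
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<TransposePlan> plan,
//       TransposePlan::Create(sizeof(float), /*dims=*/{100, 200},
//                             /*permutation=*/{1, 0},
//                             /*input_layout=*/TransposePlan::Tiling{},
//                             /*output_tiling=*/TransposePlan::Tiling{},
//                             TransposePlan::Transformation::kNone,
//                             /*num_threads=*/1));
//   plan->Execute(input, output, /*schedule_work=*/{});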
69
70 #include "tensorflow/compiler/xla/pjrt/transpose.h"
71
72 #include <algorithm>
73 #include <functional>
74 #include <numeric>
75 #include <stack>
76 #include <string>
77 #include <utility>
78
79 #include "absl/algorithm/container.h"
80 #include "absl/strings/str_format.h"
81 #include "absl/strings/str_join.h"
82 #include "absl/synchronization/blocking_counter.h"
83 #include "absl/types/span.h"
84 #include "absl/types/variant.h"
85 #include "tensorflow/compiler/xla/permutation_util.h"
86 #include "tensorflow/compiler/xla/pjrt/transpose_kernels.h"
87 #include "tensorflow/compiler/xla/status.h"
88 #include "tensorflow/compiler/xla/util.h"
89 #include "tensorflow/core/platform/logging.h"
90 #include "tensorflow/core/profiler/lib/traceme.h"
91
92 namespace xla {
93
94 // A plan is a data structure that describes a loop nest.
95 // TODO(phawkins): consider shrinking Node so it fits in a cache line.
96 struct TransposePlan::Node {
97 // The loop should iterate over the index space range(start, end, inc).
98 // These fields are ignored by the macrokernel.
99 int64_t start;
100 int64_t end;
101 int64_t inc; // The transpose sentinel node has inc < 0.
102
103 // Strides of this dimension in A and B.
104 int64_t lda;
105 int64_t ldb;
106
107 // If > 0, this loop is a loop over tile exteriors and has a trailing partial
108 // tile. To handle the trailing partial tile, skip to the plan node this many
109 // steps ahead in the vector of plan nodes.
110 int trailing_tile_next_node_inc = 0;
111
112 // Is this dimension the innermost dimension in either A or B, and hence may
113 // have non-trivial blocking?
114 bool is_inner_dim_in_a = false;
115 bool is_inner_dim_in_b = false;
116 };
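
// Illustrative example (schematic; the real nodes are built by BuildPlanNodes()
// and the loop order is chosen heuristically): a single-threaded plan for an
// untiled [M, N] -> [N, M] float transpose with inner_block_elems = 8 and 1x1
// outer blocks contains one loop node per dimension plus a trailing sentinel
// node, e.g.:
//
//   Node{start=0,  end=N,  inc=8,  lda=4,   ldb=4*M}   // stride-1 dim of A
//   Node{start=0,  end=M,  inc=8,  lda=4*N, ldb=4}     // stride-1 dim of B
//   Node{start=-1, end=-1, inc=-1, lda=4*N, ldb=4*M}   // macrokernel sentinel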
117
118 void ConvertF64ToEf57(const double* input, float* output, int n) {
119 // TODO(phawkins): vectorize this transformation.
120 for (int i = 0; i < n; ++i) {
121 std::tie(output[0], output[1]) = SplitF64ToF32(*input);
122 ++input;
123 output += 2;
124 }
125 }
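
// A brief sketch of the ef57 layout produced above (assuming SplitF64ToF32(x)
// returns the pair hi = static_cast<float>(x), lo = static_cast<float>(x - hi)):
// each double is expanded into two consecutive floats, so n input doubles
// become 2*n output floats laid out as
//   {hi(input[0]), lo(input[0]), hi(input[1]), lo(input[1]), ...}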
126
127 template <typename T, int inner_bs,
128 TransposePlan::Transformation transformation>
129 void MacroKernel(const char* __restrict a, int64_t lda, int outer_bs_a,
130 char* __restrict b, int64_t ldb, int outer_bs_b,
131 void* __restrict scratch) {
132 DVLOG(10) << "MacroKernel lda=" << lda << " ldb=" << ldb
133 << " outer_bs_a=" << outer_bs_a << " outer_bs_b=" << outer_bs_b
134 << " inner_bs=" << inner_bs;
135
136 // TODO(phawkins): consider adding prefetching and streaming stores.
137
138 if (transformation == TransposePlan::Transformation::kF64ToEf57) {
139 DCHECK_EQ(outer_bs_a * inner_bs % 2, 0);
140 float* p = reinterpret_cast<float*>(scratch);
141 for (int i = 0; i < outer_bs_b * inner_bs; ++i) {
142 ConvertF64ToEf57(reinterpret_cast<const double*>(a + lda * i),
143 p + outer_bs_a * inner_bs * i,
144 outer_bs_a * inner_bs / 2);
145 }
146 a = reinterpret_cast<const char*>(scratch);
147 lda = outer_bs_a * inner_bs * sizeof(float);
148 }
149
150 for (int i = 0; i < outer_bs_a; ++i) {
151 for (int j = 0; j < outer_bs_b; ++j) {
152 TransposeMicroKernel<T, inner_bs>::Apply(
153 a + inner_bs * j * lda + i * inner_bs * sizeof(T), lda,
154 b + inner_bs * i * ldb + j * inner_bs * sizeof(T), ldb);
155 }
156 }
157 }
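
// For example, following the addressing above with inner_bs = 8 and
// outer_bs_a = outer_bs_b = 2, MacroKernel applies the 8x8 microkernel to a
// 2x2 grid of tiles: input tile (i, j) starts at
//   a + j * 8 * lda + i * 8 * sizeof(T)
// and its transposed output tile starts at
//   b + i * 8 * ldb + j * 8 * sizeof(T),
// so one macrokernel invocation transposes a 16x16 block of elements.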
158
159 // Transpose() is a driver function that implements a multidimensional loop
160 // nest by iterating over the linked Node data structure.
161 template <typename T, int inner_bs,
162 TransposePlan::Transformation transformation>
163 void Transpose(const char* __restrict a, int outer_bs_a, char* __restrict b,
164 int outer_bs_b, TransposePlan::Node const* __restrict node,
165 void* __restrict scratch) {
166 DVLOG(10) << "Transpose " << outer_bs_a << " " << outer_bs_b;
167 DCHECK_GT(outer_bs_a, 0);
168 DCHECK_GT(outer_bs_b, 0);
169 const int64_t start = node->start;
170 const int64_t end = node->end;
171 const int64_t stop = node->end - (node->inc - 1);
172 const int64_t lda = node->lda;
173 const int64_t ldb = node->ldb;
174 const int64_t inc = node->inc;
175 TransposePlan::Node const* next_node = node + 1;
176 if (next_node->inc < 0) {
177 // This is the last loop in the nested loops. The next node is a sentinel
178 // plan node that describes how to invoke the macrokernels.
179
180 const int64_t lda_block = next_node->lda;
181 const int64_t ldb_block = next_node->ldb;
182 int64_t i;
183 for (i = start; i < stop; i += inc) {
184 MacroKernel<T, inner_bs, transformation>(a + i * lda, lda_block,
185 outer_bs_a, b + i * ldb,
186 ldb_block, outer_bs_b, scratch);
187 }
188 // Handle trailing elements that didn't fit in a complete macrokernel.
189 // Only the innermost dimensions have non-trivial outer_bs blocking.
190 if (i < end) {
191 DCHECK_EQ(node->trailing_tile_next_node_inc, 0);
192 DCHECK(node->is_inner_dim_in_a || node->is_inner_dim_in_b);
193 if (node->is_inner_dim_in_a) {
194 outer_bs_a = (end - i) / inner_bs;
195 if (outer_bs_a > 0) {
196 MacroKernel<T, inner_bs, transformation>(
197 a + i * lda, lda_block, outer_bs_a, b + i * ldb, ldb_block,
198 outer_bs_b, scratch);
199 i += outer_bs_a * inner_bs;
200 }
201 // If there are still trailing elements left over that don't fit in the
202 // inner block size, handle them via an unvectorized transpose.
203 if (i < end) {
204 MacroKernel<T, 1, transformation>(a + i * lda, lda_block, end - i,
205 b + i * ldb, ldb_block,
206 outer_bs_b * inner_bs, scratch);
207 }
208 } else if (node->is_inner_dim_in_b) {
209 outer_bs_b = (end - i) / inner_bs;
210 if (outer_bs_b > 0) {
211 MacroKernel<T, inner_bs, transformation>(
212 a + i * lda, lda_block, outer_bs_a, b + i * ldb, ldb_block,
213 outer_bs_b, scratch);
214 i += outer_bs_b * inner_bs;
215 }
216 if (i < end) {
217 MacroKernel<T, 1, transformation>(a + i * lda, lda_block,
218 outer_bs_a * inner_bs, b + i * ldb,
219 ldb_block, end - i, scratch);
220 }
221 }
222 } else if (node->trailing_tile_next_node_inc) {
223 // Handle the case where there is a trailing partial tile. We know
224 // inc == 1 for this case, so the loop above has already left `a` and `b`
225 // pointing to the start of the tile. We just need to use the alternate
226 // trailing_next_node to process the interior of the tile.
227 DCHECK_EQ(inc, 1);
228 TransposePlan::Node const* trailing_next_node =
229 node + node->trailing_tile_next_node_inc;
230 if (trailing_next_node->inc < 0) {
231 const int64_t lda_block = trailing_next_node->lda;
232 const int64_t ldb_block = trailing_next_node->ldb;
233 MacroKernel<T, inner_bs, transformation>(
234 a + i * lda, lda_block, outer_bs_a, b + i * ldb, ldb_block,
235 outer_bs_b, scratch);
236 } else {
237 Transpose<T, inner_bs, transformation>(a + i * lda, outer_bs_a,
238 b + i * ldb, outer_bs_b,
239 trailing_next_node, scratch);
240 }
241 }
242 } else {
243 // This is not the last loop in the nested loops. Recursively visit the
244 // inner loops. Structurally this code is identical to the previous case,
245 // but we call Transpose() recursively instead of MacroKernel().
246 int64_t i;
247 for (i = start; i < stop; i += inc) {
248 Transpose<T, inner_bs, transformation>(
249 a + i * lda, outer_bs_a, b + i * ldb, outer_bs_b, next_node, scratch);
250 }
251 if (i < end) {
252 DCHECK_EQ(node->trailing_tile_next_node_inc, 0);
253 DCHECK(node->is_inner_dim_in_a || node->is_inner_dim_in_b);
254 if (node->is_inner_dim_in_a) {
255 outer_bs_a = (end - i) / inner_bs;
256 if (outer_bs_a > 0) {
257 Transpose<T, inner_bs, transformation>(a + i * lda, outer_bs_a,
258 b + i * ldb, outer_bs_b,
259 next_node, scratch);
260 i += outer_bs_a * inner_bs;
261 }
262 if (i < end) {
263 Transpose<T, 1, transformation>(a + i * lda, end - i, b + i * ldb,
264 outer_bs_b * inner_bs, next_node,
265 scratch);
266 }
267 } else if (node->is_inner_dim_in_b) {
268 outer_bs_b = (end - i) / inner_bs;
269 if (outer_bs_b > 0) {
270 Transpose<T, inner_bs, transformation>(a + i * lda, outer_bs_a,
271 b + i * ldb, outer_bs_b,
272 next_node, scratch);
273 i += outer_bs_b * inner_bs;
274 }
275 if (i < end) {
276 Transpose<T, 1, transformation>(a + i * lda, outer_bs_a * inner_bs,
277 b + i * ldb, end - i, next_node,
278 scratch);
279 }
280 }
281 } else if (node->trailing_tile_next_node_inc) {
282 TransposePlan::Node const* trailing_next_node =
283 node + node->trailing_tile_next_node_inc;
284 if (trailing_next_node->inc < 0) {
285 const int64_t lda_block = trailing_next_node->lda;
286 const int64_t ldb_block = trailing_next_node->ldb;
287 MacroKernel<T, inner_bs, transformation>(
288 a + i * lda, lda_block, outer_bs_a, b + i * ldb, ldb_block,
289 outer_bs_b, scratch);
290 } else {
291 Transpose<T, inner_bs, transformation>(a + i * lda, outer_bs_a,
292 b + i * ldb, outer_bs_b,
293 trailing_next_node, scratch);
294 }
295 }
296 }
297 }
298
299 template <typename T>
300 void TransposeConstStride1(const char* __restrict a, char* __restrict b,
301 TransposePlan::Node const* __restrict node) {
302 a += node[0].start * node[0].lda;
303 b += node[0].start * node[0].ldb;
304 if (node[0].is_inner_dim_in_a) {
305 int64_t num_bytes = (node->end - node->start) * sizeof(T);
306 std::memcpy(b, a, num_bytes);
307 } else if (node[1].is_inner_dim_in_a) {
308 int64_t offset_a = node[1].start * node[1].lda;
309 int64_t offset_b = node[1].start * node[1].ldb;
310 int64_t num_bytes = (node[1].end - node[1].start) * sizeof(T);
311 a += offset_a;
312 b += offset_b;
313 for (int64_t i = node[0].start; i < node[0].end; ++i) {
314 std::memcpy(b, a, num_bytes);
315 a += node[0].lda;
316 b += node[0].ldb;
317 }
318 if (node[0].trailing_tile_next_node_inc) {
319 TransposeConstStride1<T>(a - offset_a, b - offset_b,
320 node + node[0].trailing_tile_next_node_inc);
321 }
322 } else if (node[2].is_inner_dim_in_a) {
323 int64_t num_bytes = (node[2].end - node[2].start) * sizeof(T);
324 int64_t offset_a1 = node[1].start * node[1].lda;
325 int64_t offset_b1 = node[1].start * node[1].ldb;
326 int64_t offset_a2 = node[2].start * node[2].lda;
327 int64_t offset_b2 = node[2].start * node[2].ldb;
328 a += offset_a1 + offset_a2;
329 b += offset_b1 + offset_b2;
330 for (int64_t i = node[0].start; i < node[0].end; ++i) {
331 const char* a1 = a;
332 char* b1 = b;
333 for (int64_t j = node[1].start; j < node[1].end; ++j) {
334 std::memcpy(b1, a1, num_bytes);
335 a1 += node[1].lda;
336 b1 += node[1].ldb;
337 }
338 if (node[1].trailing_tile_next_node_inc) {
339 TransposeConstStride1<T>(
340 a1 - offset_a2, b1 - offset_b2,
341 &node[1] + node[1].trailing_tile_next_node_inc);
342 }
343 a += node[0].lda;
344 b += node[0].ldb;
345 }
346 if (node[0].trailing_tile_next_node_inc) {
347 TransposeConstStride1<T>(a - offset_a1 - offset_a2,
348 b - offset_b1 - offset_b2,
349 node + node[0].trailing_tile_next_node_inc);
350 }
351 } else {
352 for (int64_t i = node[0].start; i < node[0].end; ++i) {
353 const char* a1 = a + node[1].start * node[1].lda;
354 char* b1 = b + node[1].start * node[1].ldb;
355 for (int64_t j = node[1].start; j < node[1].end; ++j) {
356 TransposeConstStride1<T>(a1, b1, node + 2);
357 a1 += node[1].lda;
358 b1 += node[1].ldb;
359 }
360 if (node[1].trailing_tile_next_node_inc) {
361 TransposeConstStride1<T>(
362 a1, b1, &node[1] + node[1].trailing_tile_next_node_inc);
363 }
364 a += node[0].lda;
365 b += node[0].ldb;
366 }
367 if (node[0].trailing_tile_next_node_inc) {
368 TransposeConstStride1<T>(a, b,
369 node + node[0].trailing_tile_next_node_inc);
370 }
371 }
372 }
373
374 template <typename T, TransposePlan::Transformation transformation>
375 void TransposePlan::ExecuteTyped(const char* a, char* b,
376 absl::Span<Node const> nodes) const {
377 if (inner_kernel_is_memcpy_) {
378 DCHECK(transformation_ == Transformation::kNone);
379 TransposeConstStride1<T>(a, b, nodes.data());
380 } else {
381 std::unique_ptr<char[]> scratch;
382 if (scratch_size_ > 0) {
383 scratch.reset(new char[scratch_size_]);
384 }
385 switch (inner_block_elems_) {
386 case 1:
387 if (nodes.size() > 1) {
388 Transpose<T, 1, transformation>(a, outer_block_elems_a_, b,
389 outer_block_elems_b_, nodes.data(),
390 scratch.get());
391 } else {
392 MacroKernel<T, 1, transformation>(
393 a, nodes.back().lda, outer_block_elems_a_, b, nodes.back().ldb,
394 outer_block_elems_b_, scratch.get());
395 }
396 break;
397 case 2:
398 if (nodes.size() > 1) {
399 Transpose<T, 2, transformation>(a, outer_block_elems_a_, b,
400 outer_block_elems_b_, nodes.data(),
401 scratch.get());
402 } else {
403 MacroKernel<T, 2, transformation>(
404 a, nodes.back().lda, outer_block_elems_a_, b, nodes.back().ldb,
405 outer_block_elems_b_, scratch.get());
406 }
407 break;
408 case 4:
409
410 if (nodes.size() > 1) {
411 Transpose<T, 4, transformation>(a, outer_block_elems_a_, b,
412 outer_block_elems_b_, nodes.data(),
413 scratch.get());
414 } else {
415 MacroKernel<T, 4, transformation>(
416 a, nodes.back().lda, outer_block_elems_a_, b, nodes.back().ldb,
417 outer_block_elems_b_, scratch.get());
418 }
419 break;
420 case 8:
421 if (nodes.size() > 1) {
422 Transpose<T, 8, transformation>(a, outer_block_elems_a_, b,
423 outer_block_elems_b_, nodes.data(),
424 scratch.get());
425 } else {
426 MacroKernel<T, 8, transformation>(
427 a, nodes.back().lda, outer_block_elems_a_, b, nodes.back().ldb,
428 outer_block_elems_b_, scratch.get());
429 }
430 break;
431 case 16:
432 if (nodes.size() > 1) {
433 Transpose<T, 16, transformation>(a, outer_block_elems_a_, b,
434 outer_block_elems_b_, nodes.data(),
435 scratch.get());
436 } else {
437 MacroKernel<T, 16, transformation>(
438 a, nodes.back().lda, outer_block_elems_a_, b, nodes.back().ldb,
439 outer_block_elems_b_, scratch.get());
440 }
441 break;
442 default:
443 LOG(FATAL) << "Invalid inner_block_size " << inner_block_elems_;
444 }
445 }
446 }
447
448 struct uint128 {
449 uint64_t lo;
450 uint64_t hi;
451 };
452 static_assert(sizeof(uint128) == 16, "uint128 should be 16 bytes in size");
453
454 void TransposePlan::Execute(
455 const void* a, void* b,
456 const std::function<void(std::function<void(void)>)>& schedule_work) const {
457 if (num_elems_ == 0) {
458 return;
459 }
460
461 const char* ac = static_cast<const char*>(a);
462 char* bc = static_cast<char*>(b);
463
464 auto execute_by_type = [&](absl::Span<Node const> nodes) {
465 switch (elem_size_in_bytes_) {
466 case 1:
467 ExecuteTyped<uint8_t, Transformation::kNone>(ac, bc, nodes);
468 break;
469 case 2:
470 ExecuteTyped<uint16_t, Transformation::kNone>(ac, bc, nodes);
471 break;
472 case 4:
473 if (transformation_ == Transformation::kNone) {
474 ExecuteTyped<uint32_t, Transformation::kNone>(ac, bc, nodes);
475 } else {
476 DCHECK(transformation_ == Transformation::kF64ToEf57);
477 ExecuteTyped<uint32_t, Transformation::kF64ToEf57>(ac, bc, nodes);
478 }
479 break;
480 case 8:
481 ExecuteTyped<uint64_t, Transformation::kNone>(ac, bc, nodes);
482 break;
483 case 16:
484 ExecuteTyped<uint128, Transformation::kNone>(ac, bc, nodes);
485 break;
486 default:
487 LOG(FATAL) << "Unimplemented element size " << elem_size_in_bytes_;
488 }
489 };
490
491 if (!schedule_work || nodes_.size() <= 1) {
492 for (const auto& nodes : nodes_) {
493 execute_by_type(nodes);
494 }
495 } else {
496 absl::BlockingCounter counter(nodes_.size());
497 for (absl::Span<Node const> nodes : nodes_) {
498 schedule_work([&, nodes]() {
499 tensorflow::profiler::TraceMe traceme("Transpose::Execute",
500 /*level=*/2);
501 execute_by_type(nodes);
502 counter.DecrementCount();
503 });
504 }
505 counter.Wait();
506 }
507 }
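
// Example of a multi-threaded Execute() call (an illustrative sketch assuming
// some thread pool `pool` whose Schedule() accepts a std::function<void()>):
//
//   plan->Execute(src, dst, [&pool](std::function<void()> work) {
//     pool.Schedule(std::move(work));
//   });
//
// Passing a null std::function (or using a single-threaded plan) runs all of
// the work on the calling thread.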
508
509 // Everything above this point pertains to executing plans.
510 // Everything below this point pertains to building plans.
511
512 TransposePlan::TransposePlan() = default;
513 TransposePlan::~TransposePlan() = default;
514
515 static void ComputeStrides(
516 int64_t elem_size_in_bytes, absl::Span<const int64_t> dims,
517 absl::Span<const int64_t> tiling,
518 absl::InlinedVector<int64_t, 4>& outer_tile_strides,
519 absl::InlinedVector<int64_t, 4>& inner_tile_strides) {
520 inner_tile_strides.resize(dims.size());
521 int64_t acc = elem_size_in_bytes;
522 for (int d = static_cast<int>(dims.size()) - 1; d >= 0; --d) {
523 inner_tile_strides[d] = acc;
524 acc *= tiling[d];
525 }
526 outer_tile_strides.resize(dims.size());
527 for (int d = static_cast<int>(dims.size()) - 1; d >= 0; --d) {
528 outer_tile_strides[d] = acc;
529 acc *= CeilOfRatio(dims[d], tiling[d]);
530 }
531 }
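
// Worked example for ComputeStrides() above: with elem_size_in_bytes = 4,
// dims = {4, 5} and tiling = {1, 2}, each tile holds 1x2 elements, so
// inner_tile_strides = {8, 4} (byte strides within a tile) and
// outer_tile_strides = {24, 8} (byte strides between tiles; a row of tiles
// spans ceil(5/2) = 3 tiles of 8 bytes each).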
532
533 void TransposePlan::RemoveTrivialDimensions(
534 absl::InlinedVector<int64_t, 4>& a_dims,
535 absl::InlinedVector<int64_t, 4>& permutation,
536 absl::InlinedVector<int64_t, 4>& lda,
537 absl::InlinedVector<int64_t, 4>& lda_tile,
538 absl::InlinedVector<int64_t, 4>& a_tiling,
539 absl::InlinedVector<int64_t, 4>& b_tiling) {
540 int ndim = a_dims.size();
541 // How many positions has the i-th dimension of 'a' been moved to the left?
542 // -1 if the dimension is to be removed.
543 std::vector<int> shift(ndim);
544 absl::InlinedVector<int64_t, 4> updated_a_dims;
545 absl::InlinedVector<int64_t, 4> updated_lda;
546 absl::InlinedVector<int64_t, 4> updated_lda_tile;
547 absl::InlinedVector<int64_t, 4> updated_a_tiling;
548 updated_a_dims.reserve(ndim);
549 updated_lda.reserve(ndim);
550 updated_lda_tile.reserve(ndim);
551 updated_a_tiling.reserve(ndim);
552 std::vector<int64_t> inv_permutation = InversePermutation(permutation);
553 for (int a_dim = 0; a_dim < ndim; ++a_dim) {
554 int b_dim = inv_permutation[a_dim];
555 // A dimension is trivial if it has size 1 and is not tiled.
556 if (a_dims[a_dim] == 1 && a_tiling[a_dim] == 1 && b_tiling[b_dim] == 1) {
557 shift[a_dim] = -1;
558 } else {
559 updated_a_dims.push_back(a_dims[a_dim]);
560 updated_lda.push_back(lda[a_dim]);
561 updated_lda_tile.push_back(lda_tile[a_dim]);
562 updated_a_tiling.push_back(a_tiling[a_dim]);
563 shift[a_dim] = a_dim + 1 - updated_a_dims.size();
564 }
565 }
566
567 // Updates the permutation and tiling of b.
568 absl::InlinedVector<int64_t, 4> updated_permutation;
569 absl::InlinedVector<int64_t, 4> updated_b_tiling;
570 updated_permutation.reserve(updated_a_dims.size());
571 updated_b_tiling.reserve(updated_a_dims.size());
572 for (int b_dim = 0; b_dim < ndim; ++b_dim) {
573 int a_dim = permutation[b_dim];
574 if (shift[a_dim] >= 0) {
575 updated_permutation.push_back(a_dim - shift[a_dim]);
576 updated_b_tiling.push_back(b_tiling[b_dim]);
577 }
578 }
579
580 DCHECK(IsPermutation(updated_permutation));
581 a_dims = std::move(updated_a_dims);
582 permutation = std::move(updated_permutation);
583 lda = std::move(updated_lda);
584 lda_tile = std::move(updated_lda_tile);
585 a_tiling = std::move(updated_a_tiling);
586 b_tiling = std::move(updated_b_tiling);
587 }
588
589 void TransposePlan::CoalesceDimensions(
590 absl::InlinedVector<int64_t, 4>& a_dims,
591 absl::InlinedVector<int64_t, 4>& permutation,
592 absl::InlinedVector<int64_t, 4>& lda,
593 absl::InlinedVector<int64_t, 4>& lda_tile,
594 absl::InlinedVector<int64_t, 4>& a_tiling,
595 absl::InlinedVector<int64_t, 4>& b_tiling) {
596 int ndim = a_dims.size();
597 // How many positions has the i-th dimension of 'a' been moved to the left?
598 // -1 if the dimension has been coalesced into the previous dimension.
599 std::vector<int> shift(ndim, 0);
600 absl::InlinedVector<int64_t, 4> updated_a_dims;
601 absl::InlinedVector<int64_t, 4> updated_lda;
602 absl::InlinedVector<int64_t, 4> updated_lda_tile;
603 absl::InlinedVector<int64_t, 4> updated_a_tiling;
604 updated_a_dims.reserve(ndim);
605 updated_lda.reserve(ndim);
606 updated_lda_tile.reserve(ndim);
607 updated_a_tiling.reserve(ndim);
608 std::vector<int64_t> inv_permutation = InversePermutation(permutation);
609 for (int a_dim = 0; a_dim < ndim; ++a_dim) {
610 // We can coalesce two dimensions if they appear consecutively
611 // in both the input dimensions and the output dimensions, and the stride
612 // of the outer dimension is the usual multiple of the inner dimension.
613 if (a_dim > 0 && inv_permutation[a_dim - 1] + 1 == inv_permutation[a_dim] &&
614 lda[a_dim - 1] == lda[a_dim] * a_dims[a_dim] &&
615 a_tiling[a_dim - 1] == 1 && a_tiling[a_dim] == 1 &&
616 b_tiling[inv_permutation[a_dim]] == 1 &&
617 b_tiling[inv_permutation[a_dim - 1]] == 1) {
618 updated_a_dims.back() *= a_dims[a_dim];
619 updated_lda.back() = lda[a_dim];
620 shift[a_dim] = -1;
621 } else {
622 updated_a_dims.push_back(a_dims[a_dim]);
623 updated_lda.push_back(lda[a_dim]);
624 updated_lda_tile.push_back(lda_tile[a_dim]);
625 updated_a_tiling.push_back(a_tiling[a_dim]);
626 shift[a_dim] = a_dim + 1 - updated_a_dims.size();
627 }
628 }
629
630 // Updates the permutation.
631 absl::InlinedVector<int64_t, 4> updated_permutation;
632 absl::InlinedVector<int64_t, 4> updated_b_tiling;
633 updated_permutation.reserve(updated_a_dims.size());
634 updated_b_tiling.reserve(updated_a_dims.size());
635 for (int b_dim = 0; b_dim < ndim; ++b_dim) {
636 int a_dim = permutation[b_dim];
637 if (shift[a_dim] >= 0) {
638 updated_permutation.push_back(a_dim - shift[a_dim]);
639 updated_b_tiling.push_back(b_tiling[b_dim]);
640 }
641 }
642 DCHECK(IsPermutation(updated_permutation));
643 a_dims = std::move(updated_a_dims);
644 permutation = std::move(updated_permutation);
645 lda = std::move(updated_lda);
646 lda_tile = std::move(updated_lda_tile);
647 a_tiling = std::move(updated_a_tiling);
648 b_tiling = std::move(updated_b_tiling);
649 }
650
651 int64_t TransposePlan::InputNumElems() const {
652 int64_t size = 1;
653 for (size_t i = 0; i < a_dims_.size(); ++i) {
654 size *= RoundUpTo(a_dims_[i], a_tiling_[i]);
655 }
656 return size;
657 }
658
659 int64_t TransposePlan::OutputNumElems() const {
660 int64_t size = 1;
661 for (size_t i = 0; i < a_dims_.size(); ++i) {
662 size *= RoundUpTo(a_dims_[permutation_[i]], b_tiling_[i]);
663 }
664 return size;
665 }
666
667 // Parses and validates a tiling specification, and populates `tiling`.
668 static Status ParseTilingSpecification(
669 int ndim, absl::Span<int64_t const> tiling_spec,
670 absl::InlinedVector<int64_t, 4>& tiling) {
671 tiling.resize(ndim, 1);
672 if (tiling_spec.size() > ndim) {
673 return InvalidArgument(
674 "Tiling (%s) must have at as many dimensions as the array (%d)",
675 absl::StrJoin(tiling_spec, ","), ndim);
676 }
677 if (absl::c_find_if(tiling_spec, [](int64_t d) { return d < 1; }) !=
678 tiling_spec.end()) {
679 return InvalidArgument("Tiling sizes (%s) must be >= 1",
680 absl::StrJoin(tiling_spec, ","));
681 }
682 int offset = ndim;
683 offset -= tiling_spec.size();
684 absl::c_copy(tiling_spec, tiling.begin() + offset);
685 return OkStatus();
686 }
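
// For example, with ndim = 3 a tiling_spec of {2, 4} applies to the trailing
// (minormost) dimensions and yields tiling = {1, 2, 4}; an empty tiling_spec
// yields the trivial tiling {1, 1, 1}.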
687
688 // Helper function that builds a plan.
689 void TransposePlan::BuildPlanNodes(
690 absl::Span<int64_t const> inverse_permutation, int thread_id,
691 std::vector<TransposePlan::Node>& nodes) {
692 VLOG(8) << "Before plan build: " << ToString();
693 const int ndim = a_dims_.size();
694 DCHECK_GT(ndim, 0);
695 const int pos_stride1a = ndim - 1;
696 const int pos_stride1b_in_a = permutation_.back();
697 const int pos_stride1a_in_b = inverse_permutation[pos_stride1a];
698
699 // We build plans in a depth-first order, visiting loops from outermost to
700 // innermost. We use a stack (depth-first) order to handle trailing partial
701 // tiles, which we "come back to" after handling the non-trailing case.
702 struct Agendum {
703 // The ID of the loop to visit in loop_order_.
704 int loop_id;
705 // The parent node ID whose trailing tile should be made to point to this
706 // node.
707 int parent_node_id;
708
709 // The number of parallel tasks available to run this loop and its
710 // successors.
711 int num_tasks_at_loop;
712
713 // The ID number of the current thread in the tasks at this loop.
714 int task_id_at_loop;
715
716 // For which dimensions of `a` should we visit the partial trailing tile via
717 // a loop that visits that tile's interior?
718 absl::InlinedVector<bool, 4> partial_tiles;
719 };
720 std::stack<Agendum> agenda;
721
722 int total_tasks =
723 absl::c_accumulate(loop_parallelism_, int{1}, std::multiplies<int>());
724
725 agenda.push(Agendum{/*loop_id=*/0, /*parent_node_id=*/-1,
726 /*num_tasks_at_loop=*/total_tasks,
727 /*task_id_at_loop=*/thread_id,
728 absl::InlinedVector<bool, 4>(ndim, false)});
729
730 auto loop_has_trivial_iteration_space = [](const Node& node) {
731 return node.start == 0 && node.start + node.inc == node.end;
732 };
733
734 while (!agenda.empty()) {
735 Agendum agendum = std::move(agenda.top());
736 agenda.pop();
737
738 int node_id = static_cast<int>(nodes.size());
739 if (agendum.parent_node_id >= 0) {
740 // This is a trailing partial tile node; update the parent node to
741 // point to it.
742 nodes[agendum.parent_node_id].trailing_tile_next_node_inc =
743 node_id - agendum.parent_node_id;
744 }
745
746 if (agendum.loop_id == loop_order_.size()) {
747 // We've reached the end of the loop nest.
748 DCHECK_EQ(agendum.num_tasks_at_loop, 1);
749 // Transpose loops have a sentinel node, indicated by a negative `inc`
750 // value, that describes the striding of the inner transpose kernel.
751 if (!inner_kernel_is_memcpy_) {
752 Node node;
753 node.start = node.end = node.inc = -1;
754 node.lda = a_tiling_[pos_stride1b_in_a] > 1
755 ? lda_tile_[pos_stride1b_in_a]
756 : lda_[pos_stride1b_in_a];
757 node.ldb = b_tiling_[pos_stride1a_in_b] > 1
758 ? ldb_tile_[pos_stride1a_in_b]
759 : ldb_[pos_stride1a_in_b];
760 nodes.push_back(node);
761 }
762 DCHECK(!(inner_kernel_is_memcpy_ && agendum.parent_node_id >= 0));
763 continue;
764 }
765
766 const Loop& loop = loop_order_[agendum.loop_id];
767 int a_dim = loop.dim_in_a;
768 int b_dim = inverse_permutation[a_dim];
769 DCHECK(a_tiling_[a_dim] == 1 || b_tiling_[b_dim] == 1 ||
770 a_tiling_[a_dim] == b_tiling_[b_dim]);
771 int64_t tile_size = std::max(a_tiling_[a_dim], b_tiling_[b_dim]);
772
773 // Compute the number of tasks for the next loop iteration.
774 int task_id_at_loop = agendum.task_id_at_loop;
775 int num_tasks_at_loop =
776 agendum.num_tasks_at_loop / loop_parallelism_[agendum.loop_id];
777 int task_id_at_next_loop = task_id_at_loop % num_tasks_at_loop;
778
779 if (loop.tile_interior) {
780 // We are visiting the tile interior of a tiled dimension.
781 bool partial = agendum.partial_tiles[a_dim];
782
783 Node node;
784 node.lda = a_tiling_[a_dim] > 1 ? lda_tile_[a_dim] : lda_[a_dim];
785 node.ldb = b_tiling_[b_dim] > 1 ? ldb_tile_[b_dim] : ldb_[b_dim];
786 node.inc = 1;
787 node.is_inner_dim_in_a = (a_dim == pos_stride1a);
788 node.is_inner_dim_in_b = (a_dim == pos_stride1b_in_a);
789 if (node.is_inner_dim_in_a) {
790 node.inc = inner_block_elems_ * outer_block_elems_a_;
791 } else if (node.is_inner_dim_in_b) {
792 node.inc = inner_block_elems_ * outer_block_elems_b_;
793 }
794
795 int task_id = task_id_at_loop / num_tasks_at_loop;
796 int64_t size = partial ? a_dims_[a_dim] % tile_size : tile_size;
797 int64_t num_iterations = CeilOfRatio(size, node.inc);
798 int64_t num_iterations_per_task = CeilOfRatio<int64_t>(
799 num_iterations, loop_parallelism_[agendum.loop_id]);
800 node.start = std::min(size, task_id * num_iterations_per_task * node.inc);
801 node.end =
802 std::min(size, (task_id + 1) * num_iterations_per_task * node.inc);
803 if (!loop_has_trivial_iteration_space(node) ||
804 (inner_kernel_is_memcpy_ && node.is_inner_dim_in_a)) {
805 nodes.push_back(node);
806 }
807 Agendum new_agendum;
808 new_agendum.loop_id = agendum.loop_id + 1;
809 new_agendum.parent_node_id = -1;
810 new_agendum.task_id_at_loop = task_id_at_next_loop;
811 new_agendum.num_tasks_at_loop = num_tasks_at_loop;
812 new_agendum.partial_tiles = agendum.partial_tiles;
813 agenda.push(std::move(new_agendum));
814 } else {
815 // We are either visiting an untiled dimension, or the loop that iterates
816 // over tile exteriors.
817 int task_id = task_id_at_loop / num_tasks_at_loop;
818 int64_t num_complete_tiles = a_dims_[a_dim] / tile_size;
819 bool has_partial_tile = (a_dims_[a_dim] % tile_size != 0);
820
821 // If there is a trailing partial tile as well as complete tiles, handle
822 // it as a trailer on the loop over complete tiles.
823 bool has_trailing_plan_node = false;
824 if (num_complete_tiles > 0 && has_partial_tile &&
825 task_id == loop_parallelism_[agendum.loop_id] - 1) {
826 Agendum new_agendum;
827 new_agendum.loop_id = agendum.loop_id + 1;
828 new_agendum.parent_node_id = node_id;
829 new_agendum.task_id_at_loop = task_id_at_next_loop;
830 new_agendum.num_tasks_at_loop = num_tasks_at_loop;
831 new_agendum.partial_tiles = agendum.partial_tiles;
832 new_agendum.partial_tiles[a_dim] = true;
833 agenda.push(std::move(new_agendum));
834 has_trailing_plan_node = true;
835 }
836 Node node;
837 node.lda = lda_[a_dim] * tile_size / a_tiling_[a_dim];
838 node.ldb = ldb_[b_dim] * tile_size / b_tiling_[b_dim];
839 node.inc = 1;
840 node.is_inner_dim_in_a = (tile_size == 1 && a_dim == ndim - 1);
841 node.is_inner_dim_in_b = (tile_size == 1 && a_dim == pos_stride1b_in_a);
842 if (node.is_inner_dim_in_a) {
843 node.inc = inner_block_elems_ * outer_block_elems_a_;
844 } else if (node.is_inner_dim_in_b) {
845 node.inc = inner_block_elems_ * outer_block_elems_b_;
846 }
847
848 // If this tiled dimension consists only of a single partial tile, handle
849 // it here; there's no point emitting a degenerate loop and a separate
850 // path to handle the trailing tile.
851 bool partial = num_complete_tiles == 0 && has_partial_tile;
852
853 // Evenly divide the loop iterations amongst the threads.
854 int64_t num_tiles = partial ? 1 : num_complete_tiles;
855 int64_t num_iterations = CeilOfRatio(num_tiles, node.inc);
856 int64_t num_iterations_per_task = CeilOfRatio<int64_t>(
857 num_iterations, loop_parallelism_[agendum.loop_id]);
858 node.start =
859 std::min(num_tiles, task_id * num_iterations_per_task * node.inc);
860 node.end = std::min(num_tiles,
861 (task_id + 1) * num_iterations_per_task * node.inc);
862 // If this loop has a trivial iteration space, drop it.
863 if (!loop_has_trivial_iteration_space(node) ||
864 (inner_kernel_is_memcpy_ && node.is_inner_dim_in_a) ||
865 has_trailing_plan_node) {
866 nodes.push_back(node);
867 }
868 Agendum new_agendum;
869 new_agendum.loop_id = agendum.loop_id + 1;
870 new_agendum.parent_node_id = -1;
871 new_agendum.task_id_at_loop = task_id_at_next_loop;
872 new_agendum.num_tasks_at_loop = num_tasks_at_loop;
873 new_agendum.partial_tiles = agendum.partial_tiles;
874 new_agendum.partial_tiles[a_dim] = partial;
875 agenda.push(std::move(new_agendum));
876 }
877 }
878 }
879
880 StatusOr<std::unique_ptr<TransposePlan>> TransposePlan::Create(
881 size_t elem_size_in_bytes, absl::Span<int64_t const> dims,
882 absl::Span<int64_t const> permutation,
883 std::variant<Tiling, Striding> input_layout, Tiling output_tiling,
884 Transformation transformation, int num_threads) {
885 auto is_negative = [](int64_t d) { return d < 0; };
886 if (absl::c_find_if(dims, is_negative) != dims.end()) {
887 return InvalidArgument("dims must be non-negative, got %s",
888 absl::StrJoin(dims, ","));
889 }
890 if (permutation.size() != dims.size()) {
891 return InvalidArgument(
892 "dims and permutation must have equal sizes, got %d and %d",
893 dims.size(), permutation.size());
894 }
895 if (!IsPermutation(permutation)) {
896 return InvalidArgument("permutation argument is not valid, got: %s",
897 absl::StrJoin(permutation, ","));
898 }
899 if (num_threads < 1) {
900 return InvalidArgument("num_threads argument must be >= 1, got: %d",
901 num_threads);
902 }
903
904 int ndim = dims.size();
905
906 auto plan = std::make_unique<TransposePlan>();
907 plan->num_threads_requested_ = num_threads;
908 plan->elem_size_in_bytes_ = elem_size_in_bytes;
909 switch (elem_size_in_bytes) {
910 case 1:
911 case 2:
912 case 4:
913 case 8:
914 case 16:
915 break;
916 default:
917 return InvalidArgument("Unsupported elem_size_in_bytes=%d",
918 elem_size_in_bytes);
919 }
920 plan->num_elems_ = std::accumulate(dims.begin(), dims.end(), int64_t{1},
921 std::multiplies<int64_t>());
922 plan->original_a_dims_.resize(ndim);
923 absl::c_copy(dims, plan->original_a_dims_.begin());
924 plan->original_b_dims_ = Permute(dims, permutation);
925
926 TF_RETURN_IF_ERROR(
927 ParseTilingSpecification(ndim, output_tiling.tiling, plan->b_tiling_));
928
929 // Handles strides.
930 if (std::holds_alternative<Striding>(input_layout)) {
931 absl::Span<int64_t const> input_strides_in_bytes =
932 std::get<Striding>(input_layout).strides_in_bytes;
933 if (input_strides_in_bytes.size() != dims.size()) {
934 return InvalidArgument(
935 "dims and input_strides_in_bytes must have equal sizes, got %d "
936 "and %d",
937 dims.size(), input_strides_in_bytes.size());
938 }
939 plan->original_a_strides_.resize(ndim);
940 absl::c_copy(input_strides_in_bytes, plan->original_a_strides_.begin());
941 // Sort the dimensions from slowest-varying (largest strides) to
942 // fastest-varying (smallest strides).
943 std::vector<int64_t> dim_order(ndim);
944 absl::c_iota(dim_order, 0);
945
946 auto cost = [&](int k) {
947 int64_t stride = input_strides_in_bytes.at(k);
948 // If there is a dimension with stride equal to the element size, sort it
949 // last. This ensures that we place any stride-1 dimension last.
950 bool is_stride1 = stride == elem_size_in_bytes;
951 // If there are multiple stride-1 dimensions, we'd prefer the one that
952 // matches the stride-1 dimension of the output.
953 // Failing that, we'd just prefer the largest stride-1 dimension last.
954 bool is_trailing_dim_in_b = permutation.back() == k;
955
956 // If we are applying ef57 conversion, we want a size-2 stride-1
957 // dimension last.
958 bool ef57_even =
959 (is_stride1 && transformation == Transformation::kF64ToEf57 &&
960 dims[k] == 2);
961
962 return std::make_tuple(is_stride1, -std::abs(stride), ef57_even,
963 is_trailing_dim_in_b, dims[k]);
964 };
965 absl::c_stable_sort(dim_order,
966 [&cost](int i, int j) { return cost(i) < cost(j); });
967 // dim_order maps new input dim -> old input dim, we need its inverse to
968 // compute the new permutation.
969 auto inv_dim_order = InversePermutation(dim_order);
970 plan->lda_.reserve(ndim);
971 plan->a_dims_.reserve(ndim);
972 plan->permutation_.reserve(ndim);
973 for (int i = 0; i < ndim; ++i) {
974 plan->lda_.push_back(input_strides_in_bytes.at(dim_order[i]));
975 plan->a_dims_.push_back(dims[dim_order[i]]);
976 plan->permutation_.push_back(inv_dim_order[permutation[i]]);
977 }
978 plan->lda_tile_.resize(ndim, 1);
979 plan->a_tiling_.resize(ndim, 1);
980 } else {
981 TF_RETURN_IF_ERROR(ParseTilingSpecification(
982 ndim, std::get<Tiling>(input_layout).tiling, plan->a_tiling_));
983
984 plan->a_dims_ = plan->original_a_dims_;
985 plan->permutation_.resize(ndim);
986 absl::c_copy(permutation, plan->permutation_.begin());
987 ComputeStrides(plan->elem_size_in_bytes_, plan->a_dims_, plan->a_tiling_,
988 plan->lda_, plan->lda_tile_);
989 }
990
991 auto is_not_one = [](int64_t x) { return x != 1; };
992 plan->a_is_tiled_ =
993 (absl::c_find_if(plan->a_tiling_, is_not_one) != plan->a_tiling_.end());
994 plan->b_is_tiled_ =
995 (absl::c_find_if(plan->b_tiling_, is_not_one) != plan->b_tiling_.end());
996 if (plan->a_is_tiled_ && plan->b_is_tiled_) {
997 return Unimplemented(
998 "Only one of the input and output may have a non-trivial tiling, "
999 "got tilings: %s and %s",
1000 absl::StrJoin(plan->a_tiling_, ","),
1001 absl::StrJoin(plan->b_tiling_, ","));
1002 }
1003
1004 plan->transformation_ = transformation;
1005 switch (transformation) {
1006 case Transformation::kNone:
1007 break;
1008 case Transformation::kF64ToEf57:
1009 if (elem_size_in_bytes != sizeof(float)) {
1010 return InvalidArgument(
1011 "EF57 conversion requires a element size of %d bytes, got %d",
1012 sizeof(float), elem_size_in_bytes);
1013 }
1014 if (plan->a_dims_.empty() || plan->a_dims_.back() % 2 != 0 ||
1015 plan->lda_.back() != sizeof(float)) {
1016 return InvalidArgument(
1017 "EF57 conversion requires a stride-%d dimension whose size is a "
1018 "multiple of 2",
1019 sizeof(float));
1020 }
1021 }
1022
1023 plan->Initialize();
1024 VLOG(5) << plan->ToString();
1025 return plan;
1026 }
1027
1028 void TransposePlan::Initialize() {
1029 if (num_elems_ == 0) {
1030 return;
1031 }
1032 RemoveTrivialDimensions(a_dims_, permutation_, lda_, lda_tile_, a_tiling_,
1033 b_tiling_);
1034 CoalesceDimensions(a_dims_, permutation_, lda_, lda_tile_, a_tiling_,
1035 b_tiling_);
1036
1037 // permutation maps dimensions of b to a
1038 // inverse_permutation maps dimensions of a to b
1039 std::vector<int64_t> inverse_permutation = InversePermutation(permutation_);
1040
1041 int ndim = a_dims_.size();
1042
1043 int64_t stride_pos1a =
1044 lda_.empty()
1045 ? -1
1046 : (a_tiling_[ndim - 1] > 1 ? lda_tile_[ndim - 1] : lda_[ndim - 1]);
1047 // We don't accept arbitrary stridings for B, so we know B always has a
1048 // stride 1 dimension innermost.
1049
1050 // If the plan is 0-dimensional, or the innermost dimension of A is not of
1051 // stride 1, add a trivial size-1 dimension. The transpose kernels rely on
1052 // the presence of a stride-1 innermost dimension in the input.
1053 if (lda_.empty() || stride_pos1a != elem_size_in_bytes_) {
1054 int dim = static_cast<int>(a_dims_.size());
1055 permutation_.push_back(dim);
1056 inverse_permutation.push_back(dim);
1057 a_dims_.push_back(1);
1058 lda_.push_back(elem_size_in_bytes_);
1059 lda_tile_.push_back(1);
1060 a_tiling_.push_back(1);
1061 b_tiling_.push_back(1);
1062 ++ndim;
1063 }
1064 b_dims_ = Permute(a_dims_, permutation_);
1065 ComputeStrides(elem_size_in_bytes_, b_dims_, b_tiling_, ldb_, ldb_tile_);
1066
1067 const int pos_stride1a = ndim - 1;
1068 const int pos_stride1b_in_a = permutation_.back();
1069 inner_kernel_is_memcpy_ = (pos_stride1b_in_a == pos_stride1a);
1070
1071 loop_order_.reserve(ndim);
1072 for (int i = 0; i < ndim; ++i) {
1073 loop_order_.push_back(Loop{i, /*tile_interior=*/false});
1074 if (a_tiling_[i] != 1 || b_tiling_[inverse_permutation[i]] != 1) {
1075 loop_order_.push_back(Loop{i, /*tile_interior=*/true});
1076 }
1077 }
1078
1079 // Bound the block sizes so they are smaller than the stride-1 dimension
1080 // size.
1081 int64_t a_stride1_size = std::max(
1082 a_tiling_[pos_stride1a], b_tiling_[inverse_permutation[pos_stride1a]]);
1083 if (a_stride1_size == 1) {
1084 a_stride1_size = a_dims_[pos_stride1a];
1085 } else {
1086 // If there's only one tile, we should use the dimension size.
1087 a_stride1_size = std::min(a_dims_[pos_stride1a], a_stride1_size);
1088 }
1089 int64_t b_stride1_size =
1090 std::max(a_tiling_[permutation_.back()], b_tiling_.back());
1091 if (b_stride1_size == 1) {
1092 b_stride1_size = b_dims_.back();
1093 } else {
1094 b_stride1_size = std::min(b_stride1_size, b_dims_.back());
1095 }
1096
1097 if (inner_kernel_is_memcpy_) {
1098 inner_block_elems_ = -1;
1099 outer_block_elems_a_ = -1;
1100 outer_block_elems_b_ = -1;
1101 } else {
1102 // What are the smallest and largest block sizes for which we have a
1103 // vectorized kernel for this element size?
1104 int min_inner_block_elems;
1105 int max_inner_block_elems;
1106 switch (elem_size_in_bytes_) {
1107 case 1:
1108 min_inner_block_elems = 4;
1109 max_inner_block_elems = 16;
1110 break;
1111 case 2:
1112 min_inner_block_elems = 8;
1113 max_inner_block_elems = 8;
1114 break;
1115 case 4:
1116 min_inner_block_elems = 4;
1117 max_inner_block_elems = 8;
1118 break;
1119 case 8:
1120 min_inner_block_elems = 2;
1121 max_inner_block_elems = 4;
1122 break;
1123 case 16:
1124 min_inner_block_elems = 1;
1125 max_inner_block_elems = 1;
1126 break;
1127 default:
1128 LOG(FATAL) << "Unreachable: element size " << elem_size_in_bytes_;
1129 }
1130 inner_block_elems_ = max_inner_block_elems;
1131 while (inner_block_elems_ > std::min(a_stride1_size, b_stride1_size)) {
1132 inner_block_elems_ /= 2;
1133 }
1134 if (inner_block_elems_ < min_inner_block_elems) {
1135 // Size is smaller than our smallest vectorized kernel. Use the scalar
1136 // path.
1137 inner_block_elems_ = 1;
1138 }
1139 outer_block_elems_a_ = FloorOfRatio<int64_t>(
1140 std::min<int64_t>(16, a_stride1_size), inner_block_elems_);
1141 outer_block_elems_b_ = FloorOfRatio<int64_t>(
1142 std::min<int64_t>(16, b_stride1_size), inner_block_elems_);
1143 }
1144
1145 // Loop order heuristic: try to make loops with small strides innermost.
1146 auto cost = [&](const Loop& l) {
1147 int64_t a_stride =
1148 std::abs((l.tile_interior && a_is_tiled_) ? lda_tile_[l.dim_in_a]
1149 : lda_[l.dim_in_a]);
1150 bool is_inner_dim_in_a =
1151 (!a_is_tiled_ || l.tile_interior) && (l.dim_in_a == pos_stride1a);
1152
1153 if (!inner_kernel_is_memcpy_ && is_inner_dim_in_a) {
1154 a_stride *= inner_block_elems_ * outer_block_elems_a_;
1155 }
1156 int b_dim = inverse_permutation[l.dim_in_a];
1157 int64_t b_stride =
1158 (l.tile_interior && b_is_tiled_) ? ldb_tile_[b_dim] : ldb_[b_dim];
1159 bool is_inner_dim_in_b =
1160 (!b_is_tiled_ || l.tile_interior) && (l.dim_in_a == pos_stride1b_in_a);
1161 if (!inner_kernel_is_memcpy_ && is_inner_dim_in_b) {
1162 b_stride *= inner_block_elems_ * outer_block_elems_b_;
1163 }
1164 // Add a small penalty to the input strides: given the choice between
1165 // consecutive writes and consecutive reads, we would prefer consecutive
1166 // writes.
1167 double penalty = 1.01;
1168
1169 // If the inner kernel is a memcpy make sure the innermost loop is the
1170 // stride-1 dimension. This is a requirement of the memcpy kernel.
1171 bool dim_must_go_last =
1172 inner_kernel_is_memcpy_ && l.dim_in_a == pos_stride1a &&
1173 (l.tile_interior ||
1174 (a_tiling_[l.dim_in_a] == 1 && b_tiling_[b_dim] == 1));
1175 return std::make_tuple(dim_must_go_last,
1176 inner_kernel_is_memcpy_ && l.tile_interior,
1177 -std::min<double>(a_stride * penalty, b_stride));
1178 };
1179 absl::c_stable_sort(loop_order_, [&](const Loop& a, const Loop& b) {
1180 return cost(a) < cost(b);
1181 });
1182 // It is a required invariant of the loop order that tile interiors always
1183 // appear after the corresponding tile exterior. This is a consequence of the
1184 // heuristic above, because the tile interior must have smaller strides in
1185 // both input and output.
1186
1187 // The stride-1 loop must be innermost for a memcpy loop.
1188 DCHECK(!inner_kernel_is_memcpy_ || loop_order_.back().dim_in_a == ndim - 1)
1189 << ToString();
1190
1191 loop_parallelism_ = ChooseParallelizationStrategy(inverse_permutation);
1192 int num_threads =
1193 absl::c_accumulate(loop_parallelism_, int{1}, std::multiplies<int>());
1194 nodes_.resize(num_threads);
1195 for (int thread_id = 0; thread_id < num_threads; ++thread_id) {
1196 BuildPlanNodes(inverse_permutation, thread_id, nodes_[thread_id]);
1197 }
1198
1199 switch (transformation_) {
1200 case Transformation::kNone:
1201 scratch_size_ = 0;
1202 break;
1203 case Transformation::kF64ToEf57:
1204 scratch_size_ = sizeof(float) * inner_block_elems_ * inner_block_elems_ *
1205 outer_block_elems_a_ * outer_block_elems_b_;
1206 DCHECK(!inner_kernel_is_memcpy_);
1207 break;
1208 }
1209 }
1210
1211 std::vector<int> TransposePlan::ChooseParallelizationStrategy(
1212 absl::Span<int64_t const> inverse_permutation) {
1213 std::vector<int> parallelism;
1214 int available_parallelism = num_threads_requested_;
1215 parallelism.reserve(loop_order_.size());
1216
1217 int ndim = permutation_.size();
1218 const int pos_stride1a = ndim - 1;
1219 const int pos_stride1b_in_a = permutation_.back();
1220 // Compute the number of iterations in `loop`.
1221 auto loop_iterations = [&](const Loop& loop) {
1222 int a_dim = loop.dim_in_a;
1223 int b_dim = inverse_permutation[a_dim];
1224 int64_t tile_size = std::max(a_tiling_[a_dim], b_tiling_[b_dim]);
1225 int64_t size = loop.tile_interior
1226 ? tile_size
1227 : (CeilOfRatio(a_dims_[loop.dim_in_a], tile_size));
1228 if (!inner_kernel_is_memcpy_ && (loop.tile_interior || tile_size == 1)) {
1229 if (loop.dim_in_a == pos_stride1a) {
1230 size = CeilOfRatio<int64_t>(size,
1231 inner_block_elems_ * outer_block_elems_a_);
1232 } else if (loop.dim_in_a == pos_stride1b_in_a) {
1233 size = CeilOfRatio<int64_t>(size,
1234 inner_block_elems_ * outer_block_elems_b_);
1235 }
1236 }
1237 return size;
1238 };
1239
1240 // Estimate the number of bytes each iteration of each loop processes.
1241 absl::InlinedVector<int64_t, 4> work_in_bytes(loop_order_.size());
1242 int64_t acc = elem_size_in_bytes_;
1243 if (!inner_kernel_is_memcpy_) {
1244 acc *= inner_block_elems_ * inner_block_elems_ * outer_block_elems_a_ *
1245 outer_block_elems_b_;
1246 }
1247 auto work_it = work_in_bytes.rbegin();
1248 for (auto it = loop_order_.rbegin(); it != loop_order_.rend(); ++it) {
1249 *work_it++ = acc;
1250 acc *= loop_iterations(*it);
1251 }
1252 VLOG(7) << "Per-loop iteration work in bytes: "
1253 << absl::StrJoin(work_in_bytes, ",");
1254
1255 // Heuristic that attempts to parallelize the outermost loops, down to a
1256 // minimum per-thread number of bytes processed.
1257 for (size_t i = 0; i < loop_order_.size(); ++i) {
1258 const Loop& loop = loop_order_[i];
1259 CHECK_GE(available_parallelism, 1);
1260 int64_t iterations = loop_iterations(loop);
1261 int kMinBytesPerThread = inner_kernel_is_memcpy_ ? (1 << 20) : (1 << 26);
1262 int64_t min_iterations_per_thread =
1263 CeilOfRatio<int64_t>(kMinBytesPerThread, work_in_bytes[i]);
1264 int64_t parallel_work = CeilOfRatio(iterations, min_iterations_per_thread);
1265
1266 VLOG(8) << "iterations=" << iterations << " parallel_work=" << parallel_work
1267 << " available_parallelism=" << available_parallelism;
1268 if (parallel_work >= available_parallelism) {
1269 parallelism.push_back(available_parallelism);
1270 available_parallelism = 1;
1271 } else {
1272 parallelism.push_back(parallel_work);
1273 available_parallelism /= parallel_work;
1274 }
1275 }
1276 return parallelism;
1277 }
1278
1279 std::string TransposePlan::ToString() const {
1280 std::string nodes_str = absl::StrJoin(
1281 nodes_, "\n", [](std::string* out, absl::Span<Node const> thread_nodes) {
1282 absl::StrAppend(
1283 out, "thread:\n",
1284 absl::StrJoin(
1285 thread_nodes, "\n", [](std::string* out, const Node& node) {
1286 absl::StrAppendFormat(
1287 out,
1288 " "
1289 "Node(start=%d,end=%d,inc=%d,lda=%"
1290 "d,ldb=%d,next_trailing=%d,inner_a=%s,inner_b=%s)",
1291 node.start, node.end, node.inc, node.lda, node.ldb,
1292 node.trailing_tile_next_node_inc,
1293 node.is_inner_dim_in_a ? "y" : "n",
1294 node.is_inner_dim_in_b ? "y" : "n");
1295 }));
1296 });
1297 auto format_loop_order = [](std::string* out, const Loop& loop) {
1298 return absl::StrAppend(out, loop.dim_in_a,
1299 loop.tile_interior ? "[tile]" : "");
1300 };
1301 std::string transformation_str;
1302 switch (transformation_) {
1303 case Transformation::kNone:
1304 transformation_str = "none";
1305 break;
1306 case Transformation::kF64ToEf57:
1307 transformation_str = "ef57";
1308 break;
1309 }
1310 return absl::StrFormat(
1311 "elem_size=%d a_dims=%s b_dims=%s permutation=%s a_tiling=%s b_tiling=%s "
1312 "lda=%s lda_tile=%s ldb=%s ldb_tile=%s loop_order=%s "
1313 "loop_parallelism=%s outer_bs=[%d,%d] inner_bs=%d "
1314 "transformation=%s scratch_size=%d\n"
1315 "nodes:\n%s",
1316 elem_size_in_bytes_, absl::StrJoin(a_dims_, ","),
1317 absl::StrJoin(Permute(a_dims_, permutation_), ","),
1318 absl::StrJoin(permutation_, ","), absl::StrJoin(a_tiling_, ","),
1319 absl::StrJoin(b_tiling_, ","), absl::StrJoin(lda_, ","),
1320 absl::StrJoin(lda_tile_, ","), absl::StrJoin(ldb_, ","),
1321 absl::StrJoin(ldb_tile_, ","),
1322 absl::StrJoin(loop_order_, ",", format_loop_order),
1323 absl::StrJoin(loop_parallelism_, ","), outer_block_elems_a_,
1324 outer_block_elems_b_, inner_block_elems_, transformation_str,
1325 scratch_size_, nodes_str);
1326 }
1327
1328 struct TransposePlanCacheKey {
1329 size_t elem_size_in_bytes;
1330 absl::InlinedVector<int64_t, 4> dims;
1331 absl::InlinedVector<int64_t, 4> permutation;
1332 bool input_layout_is_tiling;
1333 absl::InlinedVector<int64_t, 4> input_layout;
1334 absl::InlinedVector<int64_t, 4> output_tiling;
1335 TransposePlan::Transformation transformation;
1336 int num_threads;
1337
1338 bool operator==(const TransposePlanCacheKey& other) const;
1339 };
1340
1341 bool TransposePlanCacheKey::operator==(
1342 const TransposePlanCacheKey& other) const {
1343 return elem_size_in_bytes == other.elem_size_in_bytes && dims == other.dims &&
1344 permutation == other.permutation &&
1345 input_layout_is_tiling == other.input_layout_is_tiling &&
1346 input_layout == other.input_layout &&
1347 output_tiling == other.output_tiling &&
1348 transformation == other.transformation &&
1349 num_threads == other.num_threads;
1350 }
1351
1352 template <typename H>
1353 H AbslHashValue(H h, const TransposePlanCacheKey& key) {
1354 return H::combine(std::move(h), key.elem_size_in_bytes,
1355 key.input_layout_is_tiling, key.num_threads,
1356 key.transformation, key.dims, key.permutation,
1357 key.input_layout, key.output_tiling);
1358 }
1359
1360 TransposePlanCache::TransposePlanCache(int capacity)
1361 : lru_list_(capacity), cache_(&lru_list_) {}
1362
1363 TransposePlanCache::~TransposePlanCache() = default;
1364
1365 StatusOr<std::shared_ptr<TransposePlan>> TransposePlanCache::GetOrCreate(
1366 size_t elem_size_in_bytes, absl::Span<int64_t const> dims,
1367 absl::Span<int64_t const> permutation,
1368 std::variant<TransposePlan::Tiling, TransposePlan::Striding> input_layout,
1369 TransposePlan::Tiling output_tiling,
1370 TransposePlan::Transformation transformation, int num_threads) {
1371 TransposePlanCacheKey key;
1372 key.elem_size_in_bytes = elem_size_in_bytes;
1373 key.dims.resize(dims.size());
1374 absl::c_copy(dims, key.dims.begin());
1375 key.permutation.resize(permutation.size());
1376 absl::c_copy(permutation, key.permutation.begin());
1377 if (std::holds_alternative<TransposePlan::Striding>(input_layout)) {
1378 absl::Span<int64_t const> input_strides_in_bytes =
1379 std::get<TransposePlan::Striding>(input_layout).strides_in_bytes;
1380 key.input_layout = absl::InlinedVector<int64_t, 4>(
1381 input_strides_in_bytes.begin(), input_strides_in_bytes.end());
1382 key.input_layout_is_tiling = false;
1383 } else {
1384 absl::Span<int64_t const> input_tiling =
1385 std::get<TransposePlan::Tiling>(input_layout).tiling;
1386 key.input_layout = absl::InlinedVector<int64_t, 4>(input_tiling.begin(),
1387 input_tiling.end());
1388 key.input_layout_is_tiling = true;
1389 }
1390 key.output_tiling.resize(output_tiling.tiling.size());
1391 absl::c_copy(output_tiling.tiling, key.output_tiling.begin());
1392 key.transformation = transformation;
1393 key.num_threads = num_threads;
1394 return cache_.GetOrCreateIfAbsent(
1395 key,
1396 [&](const TransposePlanCacheKey& key)
1397 -> StatusOr<std::shared_ptr<TransposePlan>> {
1398 TF_ASSIGN_OR_RETURN(
1399 std::unique_ptr<TransposePlan> plan,
1400 TransposePlan::Create(elem_size_in_bytes, dims, permutation,
1401 input_layout, output_tiling, transformation,
1402 num_threads));
1403 return std::shared_ptr<TransposePlan>(std::move(plan));
1404 });
1405 }
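
// Illustrative cache usage (a sketch; the argument values are hypothetical):
//
//   TransposePlanCache cache(/*capacity=*/16);
//   TF_ASSIGN_OR_RETURN(
//       std::shared_ptr<TransposePlan> plan,
//       cache.GetOrCreate(sizeof(float), /*dims=*/{100, 200},
//                         /*permutation=*/{1, 0},
//                         /*input_layout=*/TransposePlan::Tiling{},
//                         /*output_tiling=*/TransposePlan::Tiling{},
//                         TransposePlan::Transformation::kNone,
//                         /*num_threads=*/1));
//
// Repeated calls with equal arguments return the cached plan; once `capacity`
// plans are cached, the least-recently-used entry is evicted.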
1406
1407 } // namespace xla
1408