/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements an out-of-place multidimensional array transpose
// inspired by the paper:
//
// Springer, P., Su, T. and Bientinesi, P., 2017, June. HPTT: A high-performance
// tensor transposition C++ library. In Proceedings of the 4th ACM SIGPLAN
// International Workshop on Libraries, Languages, and Compilers for Array
// Programming (pp. 56-62).
// https://arxiv.org/abs/1704.04374
//

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_TRANSPOSE_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_TRANSPOSE_H_

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <variant>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/pjrt/lru_cache.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

class TransposePlan {
 public:
  // elem_size_in_bytes: size of each element in bytes.
  // dims: the input shape, in elements.
  // permutation: for each output dimension, gives the number of the
  //   corresponding input dimension. Must be a permutation of [0..dims.size()).
  // input_layout: either byte strides or an input tiling.
  //
  // A Striding represents the strides of the input array in bytes
  // (N.B. bytes, not elements).
  //
  // A Tiling is a tiling specification for the input or output array. May
  // have fewer dimensions than `dims`, in which case the tiling applies to the
  // minormost dimensions and any remaining dimensions are left untiled (i.e.,
  // tile size 1). An empty tiling corresponds to an untiled dense
  // major-to-minor layout.
  //
  // For more information about tiling, see
  // https://www.tensorflow.org/xla/tiled_layout
  // This class supports a single level of tiling. In addition, the same
  // dimension currently cannot have different non-trivial tiling values in
  // both the input and output.
  //
  // The size of the plan may be exponential in the number of non-trivially
  // tiled dimensions. This is acceptable because in the intended use case for
  // this code we expect at most 2 tiled dimensions on input and output.
  //
  // The input may have either a striding or a tiling, but not both.
  //
  // num_threads: the number of threads requested. The actual number of
  //   threads used may be smaller if there isn't enough work per thread.
  struct Tiling {
    absl::Span<int64_t const> tiling;
  };
  struct Striding {
    absl::Span<int64_t const> strides_in_bytes;
  };
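
  // For illustration (an assumed example, not part of the API): a dense
  // major-to-minor double array of shape [1024, 16] can be described either
  // by the default Tiling{} or, equivalently, by explicit byte strides:
  //
  //   const int64_t strides[] = {16 * sizeof(double), sizeof(double)};
  //   TransposePlan::Striding layout{strides};
  //
  // A Tiling of {8, 128} on the same array would instead mean the two
  // minormost dimensions are stored as [8, 128] tiles, as described in the
  // tiled-layout documentation linked above.
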
  enum class Transformation {
    // Apply no transformations to the data.
    kNone = 0,

    // Convert doubles into the ef57 extended precision pair-of-floats
    // representation used on TPU.
    kF64ToEf57 = 1,
  };
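
  // A sketch of the ef57 idea, for intuition only (the implementation
  // defines the exact conversion): each double x is represented as a pair of
  // floats (hi, lo), where hi is x rounded to float and lo captures the
  // residual, so that hi + lo approximates x:
  //
  //   float hi = static_cast<float>(x);
  //   float lo = static_cast<float>(x - static_cast<double>(hi));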

  static StatusOr<std::unique_ptr<TransposePlan>> Create(
      size_t elem_size_in_bytes, absl::Span<int64_t const> dims,
      absl::Span<int64_t const> permutation,
      std::variant<Tiling, Striding> input_layout = Tiling{},
      Tiling output_tiling = Tiling{},
      Transformation transformation = Transformation::kNone,
      int num_threads = 1);
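
  // Example (an illustrative sketch; `input` and `output` stand for caller
  // buffers): build a plan that transposes a [100, 50] float array into a
  // [50, 100] array, using the default untiled layouts.
  //
  //   auto plan = TransposePlan::Create(
  //       sizeof(float), /*dims=*/{100, 50}, /*permutation=*/{1, 0});
  //   if (plan.ok()) (*plan)->Execute(input, output);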

  TransposePlan();
  ~TransposePlan();

  // Executes the transposition.
  // `a` is the input array and `b` is the output array. The input and output
  // arrays must not overlap.
  // Currently there are no alignment requirements on either `a` or `b`.
  // However, performance may be better if either or both are aligned.
  void Execute(const void* a, void* b,
               const std::function<void(std::function<void(void)>)>&
                   schedule_work = {}) const;
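
  // The optional `schedule_work` callback lets the caller run the plan's
  // parallel work items on its own threads; if omitted, the work runs on the
  // calling thread. An assumed sketch using a thread pool (`pool` is
  // hypothetical and stands in for whatever scheduler the caller has):
  //
  //   plan->Execute(a, b, [&](std::function<void(void)> work) {
  //     pool.Schedule(std::move(work));
  //   });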

  // Returns a human-readable description of the plan.
  std::string ToString() const;

  size_t ElemSizeInBytes() const { return elem_size_in_bytes_; }

  // Input and output size, in number of elements. Ignores any input striding,
  // but accounts for tiling.
  int64_t InputNumElems() const;
  int64_t OutputNumElems() const;

  absl::Span<int64_t const> InputDims() const { return original_a_dims_; }
  absl::Span<int64_t const> OutputDims() const { return original_b_dims_; }

  absl::Span<int64_t const> InputStrides() const { return original_a_strides_; }

  // Returns the number of items of parallel work in the plan.
  int Parallelism() const { return nodes_.size(); }

  struct Node;

 protected:
  // Methods protected so they can be accessed by tests.

  // Removes any size-1 dimensions.
  static void RemoveTrivialDimensions(
      absl::InlinedVector<int64_t, 4>& a_dims,
      absl::InlinedVector<int64_t, 4>& permutation,
      absl::InlinedVector<int64_t, 4>& lda,
      absl::InlinedVector<int64_t, 4>& lda_tile,
      absl::InlinedVector<int64_t, 4>& a_tiling,
      absl::InlinedVector<int64_t, 4>& b_tiling);

  // Collapses together dimensions that are adjacent both in `dims` and
  // `permutation`.
  static void CoalesceDimensions(absl::InlinedVector<int64_t, 4>& a_dims,
                                 absl::InlinedVector<int64_t, 4>& permutation,
                                 absl::InlinedVector<int64_t, 4>& lda,
                                 absl::InlinedVector<int64_t, 4>& lda_tile,
                                 absl::InlinedVector<int64_t, 4>& a_tiling,
                                 absl::InlinedVector<int64_t, 4>& b_tiling);
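
  // For example (illustrative): given a_dims = {4, 5, 6} and permutation =
  // {2, 0, 1}, input dimensions 0 and 1 are adjacent both in the input and
  // in the permutation, so CoalesceDimensions collapses them into a single
  // dimension of size 20, leaving a_dims = {20, 6}, permutation = {1, 0}.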

 private:
  // Performs plan initialization that cannot fail.
  void Initialize();

  void BuildPlanNodes(absl::Span<int64_t const> inverse_permutation,
                      int thread_id, std::vector<Node>& output_nodes);

  std::vector<int> ChooseParallelizationStrategy(
      absl::Span<int64_t const> inverse_permutation);

  // The signature of ExecuteTyped uses char* pointers because we perform
  // address calculations with strides in bytes; the strides need not be
  // multiples of the element size.
  template <typename T, Transformation transformation>
  void ExecuteTyped(const char* a, char* b, absl::Span<Node const> nodes) const;

  // Number of threads requested.
  int num_threads_requested_ = 1;

  // Size of each element in bytes.
  int64_t elem_size_in_bytes_;

  // Number of elements in the input array.
  int64_t num_elems_;

  // Description of the transpose, before any optimizations such as coalescing
  // dimensions have been applied.
  absl::InlinedVector<int64_t, 4> original_a_dims_;
  absl::InlinedVector<int64_t, 4> original_a_strides_;
  std::vector<int64_t> original_b_dims_;

  // Dimensions of the input array A.
  absl::InlinedVector<int64_t, 4> a_dims_;
  absl::InlinedVector<int64_t, 4> a_strides_;

  // Dimensions of the output array B.
  std::vector<int64_t> b_dims_;

  // Dimension permutation to apply to form B. For each dimension of B, what is
  // the corresponding dimension of A?
  absl::InlinedVector<int64_t, 4> permutation_;

  // Leading-dimension sizes (byte strides) of each dimension.
  absl::InlinedVector<int64_t, 4> lda_;
  absl::InlinedVector<int64_t, 4> lda_tile_;
  absl::InlinedVector<int64_t, 4> ldb_;
  absl::InlinedVector<int64_t, 4> ldb_tile_;

  // Tile sizes in each dimension. Has size equal to the number of dimensions.
  // An entry of 1 means that dimension is not tiled.
  absl::InlinedVector<int64_t, 4> a_tiling_;
  absl::InlinedVector<int64_t, 4> b_tiling_;
  bool a_is_tiled_;
  bool b_is_tiled_;

  // Order to traverse dimensions, from slowest-varying to fastest-varying.
  struct Loop {
    // The integers are dimension numbers in A.
    int dim_in_a;
    // If true, the loop iterates over the interior of a tile.
    bool tile_interior;
  };
  std::vector<Loop> loop_order_;
  std::vector<int> loop_parallelism_;

  // Root nodes of the plan, i.e., pointing to the outermost loops in the loop
  // nest. The outer vector is indexed on the thread ID.
  absl::InlinedVector<std::vector<Node>, 1> nodes_;

  // Are the innermost (stride-1) dimensions the same dimension? This determines
  // whether the inner kernel is a transpose or a memcpy.
  bool inner_kernel_is_memcpy_;

  // Size of the inner (microkernel) block. This is the unit of work for
  // our vectorized kernels.
  int inner_block_elems_ = 1;
  // Size of the outer (macrokernel) block. This is the unit of work for
  // cache blocking and need not be equal between input and output.
  int outer_block_elems_a_ = 4;
  int outer_block_elems_b_ = 4;

  // Transformation to apply to the input before transposition.
  // Currently the only supported transformation is ef57 conversion, which is
  // a pair-of-floats extended precision representation used on TPU. We
  // support fusing transformations with the transpose for two reasons:
  // (a) it makes sense to fuse cheap computations with a memory-bandwidth-
  //     bound transformation, and
  // (b) it allows us to support non-trivial striding.
  Transformation transformation_;

  // Size of the per-thread scratch buffer; 0 means no scratch buffer is
  // required.
  int64_t scratch_size_ = 0;
};

struct TransposePlanCacheKey;

template <typename H>
H AbslHashValue(H h, const TransposePlanCacheKey& key);

// An LRU cache for transpose plans. Not thread-safe.
// Transpose plans aren't cheap to build, but once computed for a particular
// set of inputs they can be cached and reused for subsequent arrays with the
// same shape, layout, and permutation. TransposePlanCache implements such a
// cache.
class TransposePlanCache {
 public:
  explicit TransposePlanCache(int capacity);
  ~TransposePlanCache();

  TransposePlanCache(const TransposePlanCache&) = delete;
  TransposePlanCache(TransposePlanCache&&) = delete;
  TransposePlanCache& operator=(const TransposePlanCache&) = delete;
  TransposePlanCache& operator=(TransposePlanCache&&) = delete;

  // Creates or returns a cached copy of a transpose plan.
  StatusOr<std::shared_ptr<TransposePlan>> GetOrCreate(
      size_t elem_size_in_bytes, absl::Span<int64_t const> dims,
      absl::Span<int64_t const> permutation,
      std::variant<TransposePlan::Tiling, TransposePlan::Striding>
          input_layout = TransposePlan::Tiling{},
      TransposePlan::Tiling output_tiling = TransposePlan::Tiling{},
      TransposePlan::Transformation transformation =
          TransposePlan::Transformation::kNone,
      int num_threads = 1);
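
  // Usage sketch (illustrative; `input` and `output` stand for caller
  // buffers):
  //
  //   TransposePlanCache cache(/*capacity=*/16);
  //   StatusOr<std::shared_ptr<TransposePlan>> plan =
  //       cache.GetOrCreate(sizeof(float), /*dims=*/{100, 50},
  //                         /*permutation=*/{1, 0});
  //   if (plan.ok()) (*plan)->Execute(input, output);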

 private:
  LRUCache<TransposePlanCacheKey,
           StatusOr<std::shared_ptr<TransposePlan>>>::LRUList lru_list_;
  LRUCache<TransposePlanCacheKey, StatusOr<std::shared_ptr<TransposePlan>>>
      cache_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_TRANSPOSE_H_