1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This file implements a out-of-place multidimensional array transpose 17 // inspired by the paper: 18 // 19 // Springer, P., Su, T. and Bientinesi, P., 2017, June. HPTT: A high-performance 20 // tensor transposition C++ library. In Proceedings of the 4th ACM SIGPLAN 21 // International Workshop on Libraries, Languages, and Compilers for Array 22 // Programming (pp. 56-62). 23 // https://arxiv.org/abs/1704.04374 24 // 25 26 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_TRANSPOSE_H_ 27 #define TENSORFLOW_COMPILER_XLA_PJRT_TRANSPOSE_H_ 28 29 #include <cstdint> 30 #include <functional> 31 #include <memory> 32 #include <string> 33 #include <vector> 34 35 #include "absl/container/inlined_vector.h" 36 #include "absl/types/variant.h" 37 #include "tensorflow/compiler/xla/pjrt/lru_cache.h" 38 #include "tensorflow/compiler/xla/statusor.h" 39 40 namespace xla { 41 42 class TransposePlan { 43 public: 44 // elem_size_in_bytes: size of each element in bytes. 45 // dims: the input shape, in elements. 46 // permutation: for each output dimension, gives the number of the 47 // corresponding input dimension. Must be a permutation of [0..dims.size()) 48 // input_layout: either byte strides or an input tiling. 49 // 50 // A Striding represents the strides of the input array in bytes. (N.B. not 51 // elements). 52 // 53 // A Tiling is a tiling specification for the input or output array. May 54 // have fewer dimensions that `dims`, in which case the tiling applies to the 55 // minormost dimensions and any remaining dimensions are left untiled (i.e., 56 // tile size 1). An empty tiling corresponds to an untiled dense 57 // major-to-minor layout. 58 // 59 // For more information about tiling, see 60 // https://www.tensorflow.org/xla/tiled_layout 61 // This class supports a single level of tiling. In addition, the same 62 // dimension currently cannot have different non-trivial tiling values in 63 // both the input and output. 64 // 65 // The size of the plan may be exponential in the number of non-trivial 66 // tiled dimensions. This is acceptable because in the intended use case for 67 // this code we expect at most 2 tiled dimensions on input and output. 68 // 69 // The input may have either a striding or a tiling but not both. 70 // 71 // num_threads: is the number of threads requested. The actual number of 72 // threads used may be smaller if there isn't enough work per thread. 73 struct Tiling { 74 absl::Span<int64_t const> tiling; 75 }; 76 struct Striding { 77 absl::Span<int64_t const> strides_in_bytes; 78 }; 79 enum class Transformation { 80 // Apply no transformations to the data. 81 kNone = 0, 82 83 // Convert doubles into the ef57 extended precision pair-of-floats 84 // representation used on TPU. 85 kF64ToEf57 = 1, 86 }; 87 88 static StatusOr<std::unique_ptr<TransposePlan>> Create( 89 size_t elem_size_in_bytes, absl::Span<int64_t const> dims, 90 absl::Span<int64_t const> permutation, 91 std::variant<Tiling, Striding> input_layout = Tiling{}, 92 Tiling output_tiling = Tiling{}, 93 Transformation transformation = Transformation::kNone, 94 int num_threads = 1); 95 96 TransposePlan(); 97 ~TransposePlan(); 98 99 // Executes the transposition. 100 // `a` is the input array and `b` is the output array. The input and output 101 // arrays must not overlap. 102 // Currently there are no alignment requirements on either `a` or `b`. However 103 // performance may be better if either or both are aligned. 104 void Execute(const void* a, void* b, 105 const std::function<void(std::function<void(void)>)>& 106 schedule_work = {}) const; 107 108 // Returns a human-readable description of the plan. 109 std::string ToString() const; 110 ElemSizeInBytes()111 size_t ElemSizeInBytes() const { return elem_size_in_bytes_; } 112 113 // Input and output size, in number of elements. Ignores any input striding, 114 // but accounts for tiling. 115 int64_t InputNumElems() const; 116 int64_t OutputNumElems() const; 117 InputDims()118 absl::Span<int64_t const> InputDims() const { return original_a_dims_; } OutputDims()119 absl::Span<int64_t const> OutputDims() const { return original_b_dims_; } 120 InputStrides()121 absl::Span<int64_t const> InputStrides() const { return original_a_strides_; } 122 123 // Returns the number of items of parallel work in the plan. Parallelism()124 int Parallelism() const { return nodes_.size(); } 125 126 struct Node; 127 128 protected: 129 // Methods protected so they can be accessed by tests. 130 131 // Removes any size-1 dimensions. 132 static void RemoveTrivialDimensions( 133 absl::InlinedVector<int64_t, 4>& a_dims, 134 absl::InlinedVector<int64_t, 4>& permutation, 135 absl::InlinedVector<int64_t, 4>& lda, 136 absl::InlinedVector<int64_t, 4>& lda_tile, 137 absl::InlinedVector<int64_t, 4>& a_tiling, 138 absl::InlinedVector<int64_t, 4>& b_tiling); 139 140 // Collapses together dimensions that are adjacent both in `dims` and 141 // `permutation`. 142 static void CoalesceDimensions(absl::InlinedVector<int64_t, 4>& a_dims, 143 absl::InlinedVector<int64_t, 4>& permutation, 144 absl::InlinedVector<int64_t, 4>& lda, 145 absl::InlinedVector<int64_t, 4>& lda_tile, 146 absl::InlinedVector<int64_t, 4>& a_tiling, 147 absl::InlinedVector<int64_t, 4>& b_tiling); 148 149 private: 150 // Performs plan initialization that cannot fail. 151 void Initialize(); 152 153 void BuildPlanNodes(absl::Span<int64_t const> inverse_permutation, 154 int thread_id, std::vector<Node>& output_nodes); 155 156 std::vector<int> ChooseParallelizationStrategy( 157 absl::Span<int64_t const> inverse_permutation); 158 159 // The signature of ExecuteTyped uses char* pointers because we perform 160 // address calculations with strides in bytes; the strides need not be 161 // multiples of the element size. 162 template <typename T, Transformation transformation> 163 void ExecuteTyped(const char* a, char* b, absl::Span<Node const> nodes) const; 164 165 // Number of threads requested. 166 int num_threads_requested_ = 1; 167 168 // Size of each element in bytes. 169 int64_t elem_size_in_bytes_; 170 171 // Number of elements in the input array. 172 int64_t num_elems_; 173 174 // Description of the transpose, before any optimizations such as coalescing 175 // dimensions have been applied. 176 absl::InlinedVector<int64_t, 4> original_a_dims_; 177 absl::InlinedVector<int64_t, 4> original_a_strides_; 178 std::vector<int64_t> original_b_dims_; 179 180 // Dimensions of the input array A. 181 absl::InlinedVector<int64_t, 4> a_dims_; 182 absl::InlinedVector<int64_t, 4> a_strides_; 183 184 // Dimensions of the output array B. 185 std::vector<int64_t> b_dims_; 186 187 // Dimension permutation to apply to form B. For each dimension of B, what is 188 // the corresponding dimension of A? 189 absl::InlinedVector<int64_t, 4> permutation_; 190 191 // Leading-dimension sizes (byte strides) of each dimension. 192 absl::InlinedVector<int64_t, 4> lda_; 193 absl::InlinedVector<int64_t, 4> lda_tile_; 194 absl::InlinedVector<int64_t, 4> ldb_; 195 absl::InlinedVector<int64_t, 4> ldb_tile_; 196 197 // Tile sizes in each dimension. Has size equal to the number of dimensions. 198 // A 1 entry means that dimension is not tiled. 199 absl::InlinedVector<int64_t, 4> a_tiling_; 200 absl::InlinedVector<int64_t, 4> b_tiling_; 201 bool a_is_tiled_; 202 bool b_is_tiled_; 203 204 // Order to traverse dimensions, from slowest-varying to fastest-varying. 205 struct Loop { 206 // The integers are dimension numbers in A. 207 int dim_in_a; 208 // If true, the loop iterates over the interior of a tile. 209 bool tile_interior; 210 }; 211 std::vector<Loop> loop_order_; 212 std::vector<int> loop_parallelism_; 213 214 // Root nodes of the plan, i.e., pointing to the outermost loops in the loop 215 // nest. The outer vector is indexed on the thread ID. 216 absl::InlinedVector<std::vector<Node>, 1> nodes_; 217 218 // Are the innermost (stride-1) dimensions the same dimension? This determines 219 // whether the inner kernel is a transpose or a memcpy. 220 bool inner_kernel_is_memcpy_; 221 222 // Size of the inner (microkernel) block size. This is the unit of work for 223 // our vectorized kernels. 224 int inner_block_elems_ = 1; 225 // Size of the outer (macrokernel) block size. This is the unit of work for 226 // cache blocking and need not be equal between input and output. 227 int outer_block_elems_a_ = 4; 228 int outer_block_elems_b_ = 4; 229 230 // Transformations to apply to the input before transposition. 231 // Currently the only supported transformation is EF57 conversion, which is 232 // a pair-of-floats extended precision representation used on TPU. We 233 // support fusing transformations with the transpose for two reasons: 234 // (a) it makes sense to fuse cheap computations with a memory-bandwidth 235 // bound transformation, and 236 // (b) it allows us to support non-trivial striding. 237 Transformation transformation_; 238 239 // Size of the per-thread scratch buffer. 0 means "no scratch buffer required" 240 int64_t scratch_size_ = 0; 241 }; 242 243 struct TransposePlanCacheKey; 244 245 template <typename H> 246 H AbslHashValue(H h, const TransposePlanCacheKey& key); 247 248 // An LRU cache for transpose plans. Not thread-safe. 249 // Transpose plans aren't cheap to build, but once computed for a particular set 250 // of inputs can be cached and reused for arrays. TransposePlanCache implements 251 // such a cache. 252 class TransposePlanCache { 253 public: 254 explicit TransposePlanCache(int capacity); 255 ~TransposePlanCache(); 256 257 TransposePlanCache(const TransposePlanCache&) = delete; 258 TransposePlanCache(TransposePlanCache&&) = delete; 259 TransposePlanCache& operator=(const TransposePlanCache&) = delete; 260 TransposePlanCache& operator=(TransposePlanCache&&) = delete; 261 262 // Creates or returns a cached copy of a transpose plan. 263 StatusOr<std::shared_ptr<TransposePlan>> GetOrCreate( 264 size_t elem_size_in_bytes, absl::Span<int64_t const> dims, 265 absl::Span<int64_t const> permutation, 266 std::variant<TransposePlan::Tiling, TransposePlan::Striding> 267 input_layout = TransposePlan::Tiling{}, 268 TransposePlan::Tiling output_tiling = TransposePlan::Tiling{}, 269 TransposePlan::Transformation transformation = 270 TransposePlan::Transformation::kNone, 271 int num_threads = 1); 272 273 private: 274 LRUCache<TransposePlanCacheKey, 275 StatusOr<std::shared_ptr<TransposePlan>>>::LRUList lru_list_; 276 LRUCache<TransposePlanCacheKey, StatusOr<std::shared_ptr<TransposePlan>>> 277 cache_; 278 }; 279 280 } // namespace xla 281 282 #endif // TENSORFLOW_COMPILER_XLA_PJRT_TRANSPOSE_H_ 283