1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_KERNEL_MAPPING_SCHEME_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_KERNEL_MAPPING_SCHEME_H_ 18 19 #include "absl/container/inlined_vector.h" 20 #include "absl/types/span.h" 21 #include "llvm/IR/Value.h" 22 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 23 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" 24 25 namespace xla { 26 namespace gpu { 27 28 // A tile is a spatial subdivision of a tensor. We group tensor elements into 29 // tiles so that we can launch kernels to process the tensor elements in blocks 30 // of tiles. 31 // 32 // A kernel mapping scheme describes a method to partition the tensors accessed 33 // by an unnested HLO instruction into tiles and blocks of tiles, and the 34 // associated information to use hardware threads to process the tensor elements 35 // in blocks of tiles. 36 // 37 // Currently, there are two main use cases for a tiling scheme. First, we 38 // implement kernels with 0-2-1 memory transpose using shared memory to improve 39 // memory access pattern. 
// Second, we implement reduction to contiguous
// dimensions in layout, with or without memory transpose, to achieve better
// memory access patterns as well as to reduce the number of executed
// expensive instructions, such as thread synchronization related instructions
// and atomic operations. For both use cases, we can apply a normalization to
// the original tensors, to collapse contiguous dimensions for the same purpose
// and produce normalized three dimensional tensors. For this reason, the tiling
// scheme class only needs to handle normalized three dimensional tensors and
// two dimensional tiles.
//
// The current implementation of the class is somewhat NVIDIA GPU oriented. This
// situation can be improved when there is a need, though. The idea of 0-2-1
// transpose using shared memory can be found in the following CUDA algorithm in
// TensorFlow: https://goo.gl/MStRV6.
//
// We use a thread block to process a tile because we want to use the HW thread
// block synchronization primitives to synchronize the processing of all the
// elements in the same tile. A thread block can be viewed as a two dimensional
// array of threads, described by the number of threads for the Y and X
// dimensions. A thread block (num_threads_y, num_threads_x) processes a tile of
// (tile_size_y, tile_size_x) as follows: each thread in the thread block
// processes one element in the tile so that all the threads in the thread block
// together process a subdivision of the tile that has the same dimensions as
// the thread block array. Then the thread block moves on to process the next
// subdivision of the tile until the whole tile is processed. Therefore, each
// thread in the thread block processes
// tile_size_x/num_threads_x * tile_size_y/num_threads_y elements in a tile.
//
// There are situations where we want a thread block to process multiple
// tiles.
We can't group those tiles into a bigger tiles because we limit a tile 69 // to a two dimensional spatial subdivision of a tensor. For example, when we 70 // use tiling to implement reduction with tranpose, we want the partial sum 71 // produced by each thread to accumulate values for more elements before using 72 // shlf_down and atomic_add instructions for further reduction, to amortize the 73 // cost of such expensive instructions. The concept of tile block is introduced 74 // for this purpose. A tile block is a three dimensional array of tiles, of 75 // which some dimensions may be degenerated to only one tile. 76 class KernelMappingScheme { 77 public: 78 enum { DimZ = 0, DimY, DimX, DimTot }; 79 enum IndexingOrder { 80 // Thread reads consecutive elements. 81 LinearIndexingX, 82 // Thread reads strided elements while keeping memory coalescing. 83 StridedIndexingX, 84 // Thread reads a few consecutive elements then take a strided 85 // step. This can trigger vectorized reads and keep memory 86 // coalescing. 87 StridedLinearIndexingX 88 }; 89 90 KernelMappingScheme(absl::Span<const int64> dims_in_elems, 91 absl::Span<const int64> tile_sizes, int64 num_threads_y, 92 int64 num_threads_x, IndexingOrder indexing_order, 93 int vector_size, bool is_row_contiguous = false) 94 : dims_in_elems_{dims_in_elems[0], dims_in_elems[1], dims_in_elems[2]}, 95 tile_sizes_{tile_sizes[0], tile_sizes[1], tile_sizes[2]}, 96 num_threads_x_(num_threads_x), 97 num_threads_y_(num_threads_y), 98 indexing_order_(indexing_order), 99 vector_size_(vector_size), 100 is_row_contiguous_(is_row_contiguous) { 101 CHECK_EQ(tile_sizes[1] % num_threads_y_, 0); 102 CHECK_EQ(tile_sizes[2] % num_threads_x_, 0); 103 VLOG(10) << "dims_in_elems_ = " << absl::StrJoin(dims_in_elems_, ","); 104 if (indexing_order != LinearIndexingX) { 105 // StridedIndexingX, and StridedLinearIndexingX 106 // is for the purpose of vectorization, which requires 107 // GetTileSizeFor(DimX) to be a multiplier of num_threads_x_. 
108 CHECK_EQ(GetTileSizeFor(DimX) % num_threads_x_, 0); 109 } 110 } 111 112 // Number of elements in each dimension (Z/Y/X respectively). GetDimsInElems()113 absl::Span<const int64> GetDimsInElems() const { return dims_in_elems_; } 114 GetNumberOfBlocks()115 int64 GetNumberOfBlocks() const { 116 return CeilOfRatio(dims_in_elems_[0], GetTileSizeZ()) * 117 CeilOfRatio(dims_in_elems_[1], GetTileSizeY()) * 118 CeilOfRatio(dims_in_elems_[2], GetTileSizeX()); 119 } 120 121 // Tile size for a given dimensions. Tiles are assigned per thread block, 122 // and are processed by all threads in the block. GetTileSizeFor(int d)123 int64 GetTileSizeFor(int d) const { return tile_sizes_.at(d); } 124 GetTileSizeZ()125 int64 GetTileSizeZ() const { return GetTileSizeFor(DimZ); } GetTileSizeX()126 int64 GetTileSizeX() const { return GetTileSizeFor(DimX); } GetTileSizeY()127 int64 GetTileSizeY() const { return GetTileSizeFor(DimY); } 128 GetNumThreadsX()129 int64 GetNumThreadsX() const { return num_threads_x_; } GetNumThreadsY()130 int64 GetNumThreadsY() const { return num_threads_y_; } 131 GetThreadsPerBlock()132 int64 GetThreadsPerBlock() const { 133 return GetNumThreadsX() * GetNumThreadsY(); 134 } 135 GetIndexingOrder()136 IndexingOrder GetIndexingOrder() const { return indexing_order_; } GetVectorSize()137 int GetVectorSize() const { return vector_size_; } GetRowContiguous()138 bool GetRowContiguous() const { return is_row_contiguous_; } 139 140 private: 141 // The number of elements in each dimension. 142 const std::array<int64, 3> dims_in_elems_; 143 144 // The number of elements for each dimension of a tile. 145 const std::array<int64, 3> tile_sizes_; 146 147 // Number of threads used to process elements in the X direction of a tile. 148 const int64 num_threads_x_; 149 150 // Number of threads used to process elements in the Y direction of a tile. 
151 const int64 num_threads_y_; 152 153 // When num_threads_x threads process a total of tile_size_x 154 // elements in the X dimension of a tile, each threads process 155 // n=tile_size_x/num_threads_x elements. 156 // indexing_order defines which tile's elements each thread reads. 157 const IndexingOrder indexing_order_; 158 159 // vector_size_ only supported for row reduction and must be a divisor 160 // of tile_sizes_[2]/num_threads_x. Interesting values are 2 and 4 161 // to trigger vectorized loads on GPUs while keeping memory 162 // coalescing. 163 const int vector_size_; 164 const bool is_row_contiguous_; 165 }; 166 167 // Information to support the code generation for a tiled reduction kernel. 168 using AddressVector = absl::InlinedVector<llvm::AllocaInst*, 1>; 169 class ReductionCodegenInfo { 170 public: ReductionCodegenInfo(KernelMappingScheme mapping_scheme,int num_partial_results,bool is_row_reduction)171 explicit ReductionCodegenInfo(KernelMappingScheme mapping_scheme, 172 int num_partial_results, bool is_row_reduction) 173 : mapping_scheme_(mapping_scheme), 174 num_partial_results_(num_partial_results), 175 is_row_reduction_(is_row_reduction) { 176 if (num_partial_results > 1) { 177 CHECK_EQ(num_partial_results, (mapping_scheme.GetTileSizeX() / 178 mapping_scheme.GetNumThreadsX())); 179 } 180 } 181 GetKernelMappingScheme()182 const KernelMappingScheme& GetKernelMappingScheme() const { 183 return mapping_scheme_; 184 } 185 186 // Gets writeable pointer to the address (or addresses) used to store 187 // reduction accumulators. GetMutablePartialResultAddresses()188 AddressVector* GetMutablePartialResultAddresses() { 189 return &partial_result_addresses_; 190 } 191 192 // Returns the address (addresses) of the reduction accumulators. 
GetPartialResultAddresses()193 absl::Span<llvm::AllocaInst* const> GetPartialResultAddresses() const { 194 return partial_result_addresses_; 195 } 196 197 // Mutable pointer to the address of the input element to perform the 198 // reduction with. GetMutableReductionInputAddresses()199 AddressVector* GetMutableReductionInputAddresses() { 200 return &reduction_input_addresses_; 201 } 202 GetMutableInitialValues()203 std::vector<llvm::Value*>* GetMutableInitialValues() { 204 return &initial_values_; 205 } 206 GetInitialValues()207 absl::Span<llvm::Value* const> GetInitialValues() const { 208 return initial_values_; 209 } 210 211 // Returns the address of the input element to perform the reduction with. GetReductionInputAddresses()212 absl::Span<llvm::AllocaInst* const> GetReductionInputAddresses() const { 213 return reduction_input_addresses_; 214 } 215 GetNumPartialResults()216 int GetNumPartialResults() const { return num_partial_results_; } IsRowReduction()217 bool IsRowReduction() const { return is_row_reduction_; } 218 219 // Gets a pointer to a mutable shared cache used by reduction. GetMutableSharedCache()220 std::vector<llvm::GlobalVariable*>* GetMutableSharedCache() { 221 return &shared_cache_; 222 } 223 224 // Shared cache used for reduction. GetSharedCache()225 absl::Span<llvm::GlobalVariable* const> GetSharedCache() const { 226 return shared_cache_; 227 } 228 229 private: 230 std::vector<llvm::GlobalVariable*> shared_cache_; 231 std::vector<llvm::Value*> initial_values_; 232 const KernelMappingScheme mapping_scheme_; 233 AddressVector partial_result_addresses_; 234 AddressVector reduction_input_addresses_; 235 int num_partial_results_; 236 bool is_row_reduction_; 237 }; 238 239 } // end namespace gpu 240 } // end namespace xla 241 242 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_KERNEL_MAPPING_SCHEME_H_ 243