/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_

#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"

namespace xla {
namespace llvm_ir {

// About 0-2-1 transpose:
//
// If a shape can be viewed as three logical components 0-1-2 in the order of
// major to minor, a 0-2-1 transpose changes the order of such logical
// components to 0-2-1. We call the shape being transposed the input shape and
// the transposed shape the output shape. The logical views of the input and
// output shapes for the transpose are called the 0-1-2 and 0-2-1 shapes, or
// the normalized shapes. The original input and output shapes are called the
// unnormalized shapes.
//
// If `b` is a 0-2-1 transpose of `a` in 0-1-2, returns the dimensions for the
// normalized shape of `b` (the 0-2-1 shape).
absl::optional<std::vector<int64>> FindTranspose021(const Shape& a,
                                                    const Shape& b);

// A tile is a spatial subdivision of a tensor. We group tensor elements into
// tiles so that we can launch kernels to process the tensor elements in blocks
// of tiles.
//
// A kernel mapping scheme describes a method to partition the tensors accessed
// by an unnested HLO instruction into tiles and blocks of tiles, and the
// associated information needed to use hardware threads to process the tensor
// elements in blocks of tiles.
//
// Currently, there are two main use cases for a tiling scheme. First, we
// implement kernels with 0-2-1 memory transpose using shared memory to improve
// the memory access pattern. Second, we implement reduction to contiguous
// dimensions in the layout, with or without memory transpose, both to achieve
// a better memory access pattern and to reduce the number of expensive
// instructions executed, such as thread-synchronization instructions and
// atomic operations. For both use cases, we can apply a normalization to the
// original tensors that collapses contiguous dimensions and produces
// normalized three-dimensional tensors. For this reason, the tiling scheme
// class only needs to handle normalized three-dimensional tensors and
// two-dimensional tiles.
//
// The current implementation of the class is somewhat NVIDIA-GPU oriented.
// This situation can be improved when there is a need. The idea of 0-2-1
// transpose using shared memory can be found in the following CUDA algorithm
// in TensorFlow: https://goo.gl/MStRV6.
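//
// For example (illustrative shapes only): for a 0-2-1 transpose whose
// normalized input shape is [8, 16, 32], the normalized output shape is
// [8, 32, 16]; this is also what FindTranspose021 above returns for such a
// pair of shapes. The tiling scheme then only has to deal with these two
// normalized three-dimensional tensors and two-dimensional tiles in their
// minor-most two dimensions.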
//
// We use a thread block to process a tile because we want to use the HW thread
// block synchronization primitives to synchronize the processing of all the
// elements in the same tile. A thread block can be viewed as a two-dimensional
// array of threads, described by the number of threads for the Y and X
// dimensions. A thread block (num_threads_y, num_threads_x) processes a tile
// of (tile_size_y, tile_size_x) as follows: each thread in the thread block
// processes one element in the tile, so that all the threads in the thread
// block together process a subdivision of the tile that has the same
// dimensions as the thread block array. Then the thread block moves on to
// process the next subdivision of the tile until the whole tile is processed.
// Therefore, each thread in the thread block processes
// tile_size_x/num_threads_x * tile_size_y/num_threads_y elements in a tile.
//
// There are situations where we want a thread block to process multiple
// tiles. We can't group those tiles into a bigger tile because we limit a tile
// to a two-dimensional spatial subdivision of a tensor. For example, when we
// use tiling to implement reduction with transpose, we want the partial sum
// produced by each thread to accumulate values for more elements before using
// shfl_down and atomic_add instructions for further reduction, to amortize the
// cost of such expensive instructions. The concept of a tile block is
// introduced for this purpose. A tile block is a three-dimensional array of
// tiles, in which some dimensions may be degenerate, i.e. contain only one
// tile.
class KernelMappingScheme {
 public:
  enum { DimZ = 0, DimY, DimX, DimTot };

  KernelMappingScheme() {}
  // dims_in_elems: the normalized tensor dimensions.
  // req_block_sizes: the requested block size in number of tiles for each
  // dimension. The actual block size is set to min(req_block_size,
  // dims_in_number_of_blocks).
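  //
  // Example (hypothetical values, for illustration only):
  //
  //   KernelMappingScheme scheme(/*dims_in_elems=*/{1, 1024, 1024},
  //                              /*tile_size_y=*/32, /*tile_size_x=*/32,
  //                              /*req_block_sizes=*/{1, 1, 1},
  //                              /*num_threads_y=*/8, /*num_threads_x=*/32,
  //                              /*b=*/builder);
  //
  // Here the normalized [1, 1024, 1024] tensor is covered by 1 x 32 x 32 tiles
  // of 32 x 32 elements each, every tile block contains a single tile, and
  // each of the 8 x 32 threads of a thread block processes
  // (32/32) * (32/8) = 4 elements of its tile.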
  KernelMappingScheme(absl::Span<const int64> dims_in_elems, int64 tile_size_y,
                      int64 tile_size_x,
                      absl::Span<const int64> req_block_sizes,
                      int64 num_threads_y, int64 num_threads_x,
                      llvm::IRBuilder<>* b);

  absl::Span<const int64> GetDimensionsInElements() const {
    return dims_in_elems_;
  }
  absl::Span<const int64> GetDimensionsInTiles() const {
    return dims_in_tiles_;
  }
  absl::Span<const int64> GetDimensionsInBlocks() const {
    return dims_in_blocks_;
  }

  int64 GetNumberOfTilesInTotal() const {
    return absl::c_accumulate(dims_in_tiles_, 1LL, std::multiplies<int64>());
  }
  int64 GetNumberOfTilesInOneBlock() const {
    return absl::c_accumulate(block_sizes_, 1, std::multiplies<int64>());
  }
  int64 GetNumberOfTilesInOneBlockForDimension(int d) const {
    DCHECK(d >= DimZ && d <= DimX);
    return block_sizes_[d];
  }
  int64 GetNumberOfBlocks() const {
    return absl::c_accumulate(dims_in_blocks_, 1, std::multiplies<int64>());
  }

  int64 GetTileSizeForDimension(int d) const {
    DCHECK(d >= DimZ && d <= DimX);
    return tile_sizes_[d];
  }
  int64 GetTileSizeForDimensionX() const {
    return GetTileSizeForDimension(DimX);
  }
  int64 GetTileSizeForDimensionY() const {
    return GetTileSizeForDimension(DimY);
  }

  absl::Span<const int64> GetBlockSizes() const { return block_sizes_; }
  int64 GetTileBlockSizeForDimension(int d) const {
    DCHECK(d >= DimZ && d <= DimX);
    return dims_in_blocks_[d];
  }

  int64 GetNumberOfThreadsForDimensionX() const { return num_threads_x_; }
  int64 GetNumberOfThreadsForDimensionY() const { return num_threads_y_; }

  int64 GetThreadsPerBlock() const {
    return GetNumberOfThreadsForDimensionX() *
           GetNumberOfThreadsForDimensionY();
  }

  bool DilatedX() const { return dilated_x_; }
  void SetDilatedX(bool v) {
    dilated_x_ = v;
    if (!dilated_x_) {
      // dilated_x_ = false is used for vectorization, which requires
      // GetTileSizeForDimension(DimX) to be a multiple of num_threads_x_.
      CHECK_EQ(GetTileSizeForDimension(DimX) % num_threads_x_, 0);
    }
  }
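
  // An illustration of the effect of dilated_x_ (illustrative values): with
  // tile_size_x = 64 and num_threads_x = 32, each thread processes
  // n = 64/32 = 2 elements along X. With dilated_x_ = true, thread t processes
  // the elements x = t and x = t + 32, i.e. the elements it touches are
  // dilated by a factor of num_threads_x. With dilated_x_ = false, thread t
  // processes the contiguous pair x = 2*t and x = 2*t + 1, which is the
  // layout that vectorized accesses need.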
  IrArray::Index EmitBlockIndex(llvm::Type* index_ty);
  // Returns the index for the first tile in the block with the given block
  // index.
  IrArray::Index GetTileIndexForBlockOrigin(const IrArray::Index& block_index);
  // Returns the index for the first element in the tile with the given tile
  // index.
  IrArray::Index GetElementIndexForTileOrigin(const IrArray::Index& tile_index);

  std::tuple<llvm::Value*, llvm::Value*> EmitThreadYXCoordinate(
      llvm::Type* index_ty);

  IrArray::Index GetUnnormalizedIndex(
      const IrArray::Index& normalized_shape_index,
      const Shape& unnormalized_shape);

  llvm::GlobalVariable* GetSharedMemoryBufferForElementType(
      llvm::Type* elem_ty, absl::string_view buffer_name);

 private:
  llvm::IRBuilder<>* b_;
  // The number of elements in each dimension.
  std::vector<int64> dims_in_elems_;

  // The number of elements for each dimension of a tile.
  std::vector<int64> tile_sizes_;
  // The number of tiles in each dimension. It is computed from dims_in_elems_
  // and tile_sizes_.
  std::vector<int64> dims_in_tiles_;

  // The number of tiles for each dimension of a tile block.
  std::vector<int64> block_sizes_;
  // The number of blocks in each dimension of a tile block. It is computed
  // from dims_in_tiles_ and block_sizes_.
  std::vector<int64> dims_in_blocks_;

  // Number of threads used to process elements in the X direction of a tile.
  int64 num_threads_x_;
  // Number of threads used to process elements in the Y direction of a tile.
  int64 num_threads_y_;

  // When num_threads_x threads process a total of tile_size_x elements in the
  // X dimension of a tile, each thread processes n = tile_size_x/num_threads_x
  // elements. When dilated_x_ is false, the n elements processed by a thread
  // are contiguous. When dilated_x_ is true, the n elements are dilated by a
  // factor of num_threads_x.
  bool dilated_x_;
};

// A class to represent information for tiled parameters to support IR emission
// for 0-2-1 transpose.
class TiledParameterInfo {
 public:
  TiledParameterInfo(absl::Span<llvm::Value* const> param_buffers,
                     llvm::Value* y, llvm::Value* x)
      : param_buffers_(param_buffers), y_(y), x_(x) {}

  llvm::Value* x() const { return x_; }
  llvm::Value* y() const { return y_; }

  void set_x(llvm::Value* x) { x_ = x; }
  void set_y(llvm::Value* y) { y_ = y; }

  llvm::Value* GetBufferForParameter(int64 index) const {
    return param_buffers_[index];
  }

 private:
  // param_buffers_[i] stores the tile buffer for the ith parameter or nullptr
  // if the parameter is not tiled.
  absl::Span<llvm::Value* const> param_buffers_;
  // The y coordinate within a tile.
  llvm::Value* y_;
  // The x coordinate within a tile.
  llvm::Value* x_;
};

}  // namespace llvm_ir
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_