/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_

#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"

namespace xla {
namespace gpu {

struct BufferSlice {
  // The root buffer to look at.
  BufferAllocation::Slice buffer_slice;

  // Describes how to dereference starting at that buffer to get to the buffer
  // in question.
  ShapeIndex gte_index;
};

// Describes how to access a particular subshape for an HLO. For instance, if
// `.hlo_index` is {1} and `.gte_index` is {3, 4}, then the buffer for `.instr`
// at ShapeIndex {1} (i.e. the buffer for the second tuple element of the HLO)
// is found at `.buffer_slice`[3][4]. That is, `.buffer_slice` is a void***,
// which we dereference twice -- first at index 3, and then at index 4 -- to
// get the address of our buffer.
struct HloBufferSlice : public BufferSlice {
  const HloInstruction* instr;
  ShapeIndex hlo_index;
};

struct MlirBufferSlice : public BufferSlice {
  // The buffer is modified by the kernel.
  bool written = false;

  Shape shape;
};

struct MlirEmitterInput {
  mlir::Operation* op;
  Thunk::ThunkInfo thunk_info;

  // A field to allow plumbing extra information that BufferAssignment has, but
  // the LMHLO/MLIR representation does not have. Specifically, this is for
  // passing the allocated buffer for tuple outputs (an array of pointers to
  // tuple elements).
  //
  // TODO(timshen): We need a corresponding construct in LMHLO to represent
  // this, aka an array of pointers to different memrefs. Once we have that, we
  // can merge this information back into the LMHLO graph and remove this
  // field.
  absl::optional<MlirBufferSlice> extra_slice;
};

// Convenience struct that contains useful data structures in the MLIR emitter.
// Not all fields may be filled; it is entirely dependent on the uses.
struct MlirEmitterContext {
  void SetOperation(mlir::Operation* op);

  std::string name;
  std::vector<Shape> operand_shapes;
  std::vector<Shape> output_shapes;
};

// Emits LLVM IR for an "unnested computation".
//
// An unnested computation is an HloComputation which you run by executing one
// or more kernels for each HloInstruction it contains. Examples of unnested
// computations:
//
//  - An HloModule's root computation,
//  - The body of an HLO while loop,
//  - The true/false computation of an HLO conditional.
//
// Note the opportunity for confusion -- the while loop's computation is nested
// within the root computation, but it's emitted using IrEmitterUnnested! Don't
// think about it too hard.
//
// Examples of things that are not unnested computations:
//
//  - The reducer of a kReduce HLO. This is emitted using IrEmitterNested.
//  - The body of a fusion node. IrEmitterUnnested emits the relevant code
//    within a kernel function using FusedIrEmitter. (FusedIrEmitter is not
//    really an IrEmitter, but is more of an "IR generator generator".)
//
class IrEmitterUnnested : public IrEmitter,
                          private ThunkEmitter::EmissionContext {
 public:
  struct ThreadIdInfo {
    // Raw thread id.
    llvm::Value* thread_id;

    // X-coordinate calculated from thread id: `thread_id % num_threads_x`
    llvm::Value* thread_id_x;

    // Y-coordinate calculated from thread id: `thread_id / num_threads_x`
    llvm::Value* thread_id_y;

    // Lane id: `thread_id % kWarpSize`
    llvm::Value* lane_id;
  };

  absl::string_view platform_name() const override {
    return ir_emitter_context_->platform_name();
  }

  // A function object to generate code to process one element in a tile.
  //
  // index: the index for the first output element of the current thread.
  // y_loc: The y coordinate within a tile.
  // x_loc: The x coordinate within a tile.
  // x_iter_num: When a thread processes N elements in the X dimension,
  //   x_iter_num has a value of 0..N-1 to identify the element being
  //   processed.
  using EmitElementFunction = std::function<void(
      const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
      llvm::Value* x_loc, int64 x_iter_num)>;

  using ConstantGenerator = std::function<llvm::Value*(int64)>;

  // A function to generate the code to emit the entire tile.
  using TileElementGenerator = std::function<void(
      const ThreadIdInfo& thread_id_info, const llvm_ir::IrArray::Index& index,
      const string& loop_name, llvm::Value* tile_height,
      llvm::Value* tile_width, KernelSupportLibrary* ksl)>;

  IrEmitterUnnested(const IrEmitterUnnested&) = delete;
  IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;

  static StatusOr<std::unique_ptr<IrEmitterUnnested>> Create(
      const HloModuleConfig& hlo_module_config,
      const HloComputation* hlo_computation,
      IrEmitterContext* ir_emitter_context);

  // Transfers the ownership of thunk_sequence_ out.
  std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
    return std::make_unique<ThunkSequence>(std::move(thunk_sequence_));
  }

  Status DefaultAction(HloInstruction* hlo) override;
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status EmitUsingElementalIrEmitter(MlirEmitterInput input);

  // IrEmitterUnnested handles the following instructions differently from
  // IrEmitter. It also mixes in some special handling for custom kernels
  // via the ThunkEmitter.
  Status HandleConstant(HloInstruction* constant) override;
  Status EmitConstant(MlirEmitterInput mlir_input);

  Status HandleCopy(HloInstruction* copy) override;
  Status EmitCopyForMlir(MlirEmitterInput input);

  Status HandleConditional(HloInstruction* conditional) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status EmitCustomCallFromMlir(MlirEmitterInput input);
  Status EmitConvolutionThunkFromMlir(MlirEmitterInput input);
  Status EmitGemmThunkFromMlir(MlirEmitterInput input);
  Status EmitBatchNormThunkFromMlir(MlirEmitterInput input);
#if GOOGLE_CUDA
  Status EmitCholeskyThunkFromMlir(MlirEmitterInput input);
#endif  // GOOGLE_CUDA
  Status EmitCustomCallThunkFromMlir(MlirEmitterInput input);
  Status HandleFft(HloInstruction* fft) override;
  Status EmitFftThunkFromMlir(MlirEmitterInput input);
  Status HandleFusion(HloInstruction* fusion) override;
  Status EmitLoopFusionFromMlir(
      MlirEmitterInput input, const Shape& output_shape,
      absl::optional<int> unroll_factor_override = {});
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status EmitReduceFromMlir(MlirEmitterInput mlir_input);
  Status HandleSelectAndScatter(HloInstruction* instruction) override;
  Status EmitSelectAndScatterFromMlir(MlirEmitterInput mlir_input);
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleWhile(HloInstruction* xla_while) override;
  Status HandleInfeed(HloInstruction* xla_infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleRng(HloInstruction* random) override;
  Status HandleRngGetAndUpdateState(HloInstruction* rng_state) override;
  Status EmitRngGetAndUpdateState(MlirEmitterInput mlir_input);
  Status HandleScatter(HloInstruction* scatter) override;
  Status EmitScatterFromMlir(MlirEmitterInput mlir_input);
  Status HandleSort(HloInstruction* sort) override;
  Status EmitSortFromMlir(MlirEmitterInput mlir_input);
  Status HandleTriangularSolve(HloInstruction* hlo) override;
  Status EmitTriangularSolveFromMlir(MlirEmitterInput mlir_input);

  template <typename NcclThunkType, typename OpTy>
  Status EmitNcclThunkFromMlir(MlirEmitterInput mlir_input);
  Status HandleAllGather(HloInstruction* hlo) override;
  Status HandleAllReduce(HloInstruction* hlo) override;
  Status HandleAllToAll(HloInstruction* hlo) override;

  Status HandleAfterAll(HloInstruction* after_all) override;

  template <typename ThunkType, typename OpT>
  Status EmitReplicaOrPartitionIdFromMlir(MlirEmitterInput input);
  Status HandleReplicaId(HloInstruction* hlo) override;
  Status HandlePartitionId(HloInstruction* hlo) override;

  Status HandleCollectivePermute(HloInstruction* hlo) override;
  Status EmitCollectivePermuteFromMlir(MlirEmitterInput input);

  Status EmitOp(MlirEmitterInput mlir_input);

  Status EmitTargetElementLoop(
      const HloInstruction& hlo,
      const llvm_ir::ElementGenerator& body_emitter) override;

  // Same as `EmitTargetElementLoop`, but emits into the given `thunk` rather
  // than `LastThunk()`. The kernel implementation will be unrolled if
  // `unroll_factor` is greater than one; a rough sketch of the unrolled loop
  // structure follows this comment.
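  //
  // A minimal illustrative sketch of such an unrolled element loop, in
  // CUDA-like pseudocode. The names `linear_index`, `num_elements`, and `out`
  // are illustrative only, and `body_emitter` stands in for the element
  // generator; this is not the emitted IR.
  //
  // ```
  //   int64 linear_index =
  //       (block_id * threads_per_block + thread_id) * unroll_factor;
  //   #pragma unroll
  //   for (int i = 0; i < unroll_factor; ++i) {
  //     if (linear_index + i < num_elements) {
  //       out[linear_index + i] = body_emitter(linear_index + i);
  //     }
  //   }
  // ```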
  Status EmitTargetElementLoopInThunk(
      const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
      KernelThunk* thunk, int unroll_factor, bool few_waves = false);

  Status Postprocess(HloInstruction* hlo) override;

 private:
  IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
                    const HloComputation* hlo_computation,
                    IrEmitterContext* ir_emitter_context);

  // Adds an owning Thunk object to the thunk sequence.
  void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) override {
    thunk_sequence_.emplace_back(std::move(thunk));
  }

  // Input = {static array, dynamic_dim0, dynamic_dim1}
  // Output = {dynamic array (with dynamic dimension metadata at the end)}
  //
  // For a tensor with static dimensions [2][<=5] and dynamic dimensions [2][3]
  // (`_` stands for padding):
  //   Input  = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}
  //   Output = {{1,2,3,4,5,6,_,_,_,_,2,3}}
  //
  // Pseudocode for padToStatic on a 2D array:
  // ```
  // void padToStatic(int** input, int** output, int threads_per_block,
  //                  int meta_data_offset, int max_num_element,
  //                  int static_dim0_size, int static_dim1_size) {
  //   int* source_array = input[0];
  //   int* dest_array = output[0];
  //
  //   // Extract the dynamic dimensions from the source array's metadata.
  //   int* dyn_dim0_size = source_array + meta_data_offset;
  //   int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);
  //
  //   // Only one thread needs to store the dynamic dimensions.
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   if (thread_id == 0 && block_id == 0) {
  //     *output[1] = *dyn_dim0_size;
  //     *output[2] = *dyn_dim1_size;
  //   }
  //
  //   int dyn_element_total = 1;
  //   dyn_element_total *= *dyn_dim0_size;
  //   dyn_element_total *= *dyn_dim1_size;
  //   linear_index = block_id * threads_per_block + thread_id;
  //   if (linear_index < max_num_element) {
  //     Index static_index =
  //         delinearized(linear_index, static_dim0_size, static_dim1_size);
  //     if (linear_index < dyn_element_total) {
  //       Index dyn_index =
  //           delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
  //       dest_array[dyn_index.dim0][dyn_index.dim1] =
  //           source_array[static_index.dim0][static_index.dim1];
  //     }
  //   }
  //   return;
  // }
  // ```
  Status EmitPadToStaticFromMlir(MlirEmitterInput mlir_input);

  // Input = {dynamic array (with dynamic dimension metadata at the end)}
  // Output = {static array, dynamic_dim0, dynamic_dim1}
  //
  // For a tensor with static dimensions [2][<=5] and dynamic dimensions [2][3]
  // (`_` stands for padding):
  //   Input  = {{1,2,3,4,5,6,_,_,_,_,2,3}}
  //   Output = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}
  //
  // Pseudocode for sliceToDynamic on a 2D array:
  // ```
  // void sliceToDynamic(int** input, int** output, int threads_per_block,
  //                     int meta_data_offset, int max_num_element,
  //                     int static_dim0_size, int static_dim1_size) {
  //   int* source_array = input[0];
  //   int* dest_array = output[0];
  //
  //   // Calculate the location where the metadata needs to be inserted.
  //   int* dyn_dim0_size = dest_array + meta_data_offset;
  //   int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);
  //
  //   // Only one thread needs to store the dynamic dimensions.
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   if (thread_id == 0 && block_id == 0) {
  //     *dyn_dim0_size = *output[1];
  //     *dyn_dim1_size = *output[2];
  //   }
  //
  //   int dyn_element_total = 1;
  //   dyn_element_total *= *dyn_dim0_size;
  //   dyn_element_total *= *dyn_dim1_size;
  //   linear_index = block_id * threads_per_block + thread_id;
  //   if (linear_index < max_num_element) {
  //     Index static_index =
  //         delinearized(linear_index, static_dim0_size, static_dim1_size);
  //     if (linear_index < dyn_element_total) {
  //       Index dyn_index =
  //           delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
  //       dest_array[static_index.dim0][static_index.dim1] =
  //           source_array[dyn_index.dim0][dyn_index.dim1];
  //     }
  //   }
  //   return;
  // }
  // ```
  Status EmitSliceToDynamicFromMlir(MlirEmitterInput mlir_input);

  // A convenient helper for calling BufferAssignment::GetUniqueSlice.
  StatusOr<BufferAllocation::Slice> MaybeGetAllocationSlice(
      const HloInstruction& hlo, const ShapeIndex& index) const override {
    return ir_emitter_context_->buffer_assignment().GetUniqueSlice(&hlo,
                                                                   index);
  }

  BufferAllocation::Slice GetAllocationSlice(
      const HloInstruction& hlo, const ShapeIndex& index = {}) const {
    return MaybeGetAllocationSlice(hlo, index).ConsumeValueOrDie();
  }

  StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(mlir::Value v);

  int64 ByteSizeOf(const Shape& shape) const override {
    return llvm_ir::ByteSizeOf(
        shape, ir_emitter_context_->llvm_module()->getDataLayout());
  }

  // Builds the prototype of the IR kernel for `inst` and adds it to the
  // module. This kernel takes as arguments pointers to the given buffer
  // allocations.
  llvm::Function* BuildKernelPrototype(
      absl::string_view name, absl::Span<const BufferAllocation* const> args);

  // Helper for writing extra outputs from inside a reduce kernel.
  Status EmitExtraOutputsForReduce(
      absl::Span<const llvm_ir::IrArray> result_ir_arrays,
      const llvm_ir::IrArray::Index& index, bool use_linear_index,
      absl::Span<const std::pair<llvm_ir::ElementGenerator, int>>
          extra_output_gens);

  // Generates code for reduction to contiguous dimensions.
  //
  // Row reduction uses the following algorithm, described in CUDA-like
  // pseudocode:
  //
  // ```
  // __global__ void reduce(int num_rows, float* in, float* out) {
  //   __shared__ float[32] cache;
  //   int offset = blockDim.x * blockIdx.x + threadIdx.x;
  //   if (offset >= num_rows) return;
  //   int tile_bound = std::min(offset + kTileSizeX, num_rows);
  //   float accum = 0;
  //   for (int i = offset; i < num_rows; i += blockDim.x) {
  //     accum += in[i];
  //   }
  //   accum = warp_reduce(accum);
  //   if (threadIdx.x % kWarpSize == 0) {
  //     cache[threadIdx.x / kWarpSize] = accum;
  //   }
  //   __syncthreads();
  //   if (threadIdx.x / kWarpSize == 0) {
  //     bool warp_exists = threadIdx.x < (blockDim.x / kWarpSize);
  //     float block_accum = warp_exists ? cache[threadIdx.x % kWarpSize] : 0;
  //     block_accum = warp_reduce(block_accum);
  //     if (threadIdx.x == 0) {
  //       *out += block_accum;
  //     }
  //   }
  // }
  // ```
  //
  // Column reduction uses the following algorithm:
  //
  // ```
  // void reduce(float** in, float* out) {
  //   __shared__ float[32][33] cache;  // Padded to 33 columns to avoid
  //                                    // shared-memory bank conflicts.
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   int tile_size = 128;
  //
  //   float accum = 0;
  //   for (int i = 0; i < tile_size; i++) {
  //     accum += in[thread_id.y * tile_size + i][block_id * 32 + thread_id.x];
  //   }
  //   cache[thread_id.x][thread_id.y] = accum;
  //
  //   __syncthreads();
  //   accum = cache[thread_id.y][thread_id.x];
  //   accum = warp_reduce(accum);  // Sum all the values of `accum` in the
  //                                // same warp.
  //
  //   if (thread_id.y % 32 == 0) {
  //     out[block_id * 32 + thread_id.x] = accum;
  //   }
  // }
  // ```
  //
  // Moreover, a heuristic is implemented to divide the reduce instructions
  // into groups for parallelization (see `DivideOutputInstructionsIntoGroups`
  // for details about the heuristic). Reduce instructions in the same group
  // run sequentially, while different groups run in parallel.
  //
  // We use the raw block_id_y to select the reduce group for execution without
  // complicating the index calculation in the code generation of the reduce
  // instructions. In other words, each block_id_y is assigned to a group, so
  // different groups can run in parallel.
  Status EmitReductionFromOrToContiguousDimensions(MlirEmitterInput mlir_input);

  // Computes the KernelMappingScheme for the reduce HLO and indicates whether
  // the reduction is a row reduction. For an unfused reduce op, unnested_hlo
  // and first_reduce are the same instruction. For a kInput fusion,
  // unnested_hlo is the fusion instruction while first_reduce is the first
  // reduce op.
  ReductionCodegenInfo ComputeReductionCodegenInfo(
      mlir::Operation* unnested_hlo, mlir::Operation* first_reduce);

  // Generates code for input-fusible slices.
  //
  // Prerequisite: ROOT is either a slice or a tuple of slices. The input
  // shapes of all ROOT slices need to be the same, while their output shapes
  // can be different. On the other hand, the input ranges of slices can be
  // overlapping. Further generalization/specialization can be added when the
  // need arises.
  Status EmitInputFusibleNonStridedSlices(MlirEmitterInput mlir_input);

  Status EmitElementForInputFusibleSlices(
      mlir::lmhlo::FusionOp fusion,
      absl::Span<const llvm_ir::IrArray> ir_arrays,
      const llvm_ir::IrArray::Index& index);

  // Emits code for an in-place scatter, modifying `thunk`'s launch dimensions
  // in the process. Scatter indices are taken from `scatter_indices_gen`,
  // updates from `updates_gen`. The output buffer is expected to already
  // contain the operand values. If unique_indices is false, we will use an
  // atomic update. Using true for unique_indices behaves properly only when it
  // is guaranteed that the indices to be updated do not overlap. The caller is
  // responsible for ensuring this is the case; see the sketch below.
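  //
  // A minimal sketch of the per-update-element behavior described above. The
  // names `output`, `scatter_index`, `update`, and `atomic_update` are
  // illustrative only; this is not the emitted IR.
  //
  // ```
  //   if (unique_indices) {
  //     // The caller guarantees no overlapping indices: plain
  //     // read-modify-write.
  //     output[scatter_index] =
  //         update_computation(output[scatter_index], update);
  //   } else {
  //     // Indices may overlap: combine with an atomic read-modify-write.
  //     atomic_update(&output[scatter_index], [&](auto old) {
  //       return update_computation(old, update);
  //     });
  //   }
  // ```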
  Status EmitScatter(Thunk* thunk, mlir::lmhlo::ScatterOp scatter,
                     const llvm_ir::IrArray& output,
                     const llvm_ir::ElementGenerator& scatter_indices_gen,
                     const llvm_ir::ElementGenerator& updates_gen,
                     std::function<llvm::Type*(int64)> get_index_type);

  // Structure describing a scatter operation for IR emission.
  // TODO(jurahul): Migrate element generators to use MLIR.
  // Migrate update_computation to be an MLIR Region.
  struct ScatterDescriptor {
    std::string name;
    Shape operand_shape;
    Shape scatter_indices_shape;
    Shape updates_shape;
    mlir::mhlo::ScatterDimensionNumbers dim_numbers;
    bool unique_indices;
    const HloComputation* update_computation;
    llvm_ir::IrArray output;
    llvm_ir::ElementGenerator scatter_indices_gen;
    llvm_ir::ElementGenerator updates_gen;
    std::function<llvm::Type*(int64)> get_index_type;
  };

  // Emits code for an in-place scatter using the provided scatter operation
  // description.
  Status EmitScatter(const ScatterDescriptor& desc, Thunk* thunk);

  // Returns true if a 0-2-1 tiling algorithm is already used to emit the
  // kernel for the hlo instruction.
  StatusOr<bool> CheckAndEmitHloWithTile021(MlirEmitterInput input);

  // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and
  // sets the corresponding launch dimensions. This is a helper to support
  // the implementation of CheckAndEmitHloWithTile021.
  void EmitHlo021Tile(mlir::Operation* op, Thunk* kernel_thunk,
                      const MlirEmitterContext& context,
                      absl::Span<const llvm_ir::IrArray> operand_arrays,
                      absl::Span<const llvm_ir::IrArray> output_arrays,
                      absl::Span<const int64> reduced_output_dims,
                      absl::Span<const int64> tiled_param_ids);

  struct TilingKernelInfo {
    // Tiling bounds.
    std::array<llvm::Value*, 3> output_tile_bounds;

    // Starting tile, as calculated from block id only.
    llvm_ir::IrArray::Index tile_origin;
  };

  // Emits a kernel for the hlo instruction using the given kernel mapping
  // scheme.
  TilingKernelInfo EmitTilingKernel(
      const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
      const TileElementGenerator& tile_element_generator);

  // Emits code to process up to
  // (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a
  // tile, where `emit_elem_function` is the function to emit code to process
  // one element, `thread_id_y` and `thread_id_x` are the intra-tile
  // coordinates for the first element to process, and `index` is the index for
  // the origin of the tile. Information about tile_size_x/y and
  // num_threads_x/y is stored in `mapping_scheme`. Emits bounds checks to
  // ensure that each processed element is within the boundary defined by
  // `tile_width` and `tile_height`.
  //
  // Pseudocode:
  //
  // for (y_loc = 0; y_loc < tile_height; y_loc += num_threads_y) {
  //   for (j = 0; j < tile_size_x / num_threads_x; j++) {  // unrolled
  //     if (dilated) {
  //       x_loc = x + j * num_threads_x;
  //     } else {
  //       x_loc = x * (tile_size_x / num_threads_x) + j;
  //     }
  //
  //     if (x_loc < tile_width) {
  //       emit_elem_function(y + y_loc, x_loc);
  //     }
  //   }
  // }
  //
  void EmitTile(
      const KernelMappingScheme& mapping_scheme,
      const llvm_ir::IrArray::Index& tile_origin_index,
      const string& loop_name, KernelSupportLibrary* ksl,
      const ThreadIdInfo& thread_id_info, llvm::Value* tile_height,
      llvm::Value* tile_width,
      const IrEmitterUnnested::EmitElementFunction& emit_elem_function);

  // Emits code to process a tensor element in a tile for the given kCopy HLO
  // that performs a 0-2-1 transpose.
  // y_loc: The y coordinate within a tile.
  // x_loc: The x coordinate within a tile.
  void EmitTileElementForCopy(
      const Shape& output_shape, const llvm_ir::IrArray& ir_array,
      const llvm_ir::IrArray::Index& index,
      const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
      llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);

  // Emits code to process a tensor element in a tile for the given kLoop
  // fusion HLO containing parameters that are a 0-2-1 transpose of its
  // outputs.
  // y_loc: The y coordinate within a tile.
  // x_loc: The x coordinate within a tile.
  void EmitTileElementForFusion(
      mlir::lmhlo::FusionOp fusion,
      absl::Span<const llvm_ir::IrArray> operand_arrays,
      absl::Span<const llvm_ir::IrArray> output_arrays,
      const llvm_ir::IrArray::Index& index,
      const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
      llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);

  // Emits code to process a tensor element in a tile for the given input hlo
  // that is either an unnested kReduce or a kInput fusion.
  //
  // Calculates and stores the temporary reduction value in the corresponding
  // alloca.
  //
  // `instr_index_group` indicates the set of reductions this call needs to
  // emit; each index i refers to the ith output of unnested_hlo. Note that if
  // unnested_hlo is not a multi-output fusion, instr_index_group is always
  // {0}.
  void EmitTileElementForReduction(
      mlir::Operation* unnested_hlo, const Shape& reduction_operand_shape,
      absl::Span<const int> instr_index_group,
      HloComputation* fused_computation, FusedIrEmitter* fused_emitter,
      absl::Span<const llvm_ir::IrArray> operand_ir_arrays,
      absl::Span<const llvm_ir::IrArray> result_ir_arrays,
      absl::Span<HloComputation* const> reducers,
      const llvm_ir::IrArray::Index& index,
      const ReductionCodegenInfo& reduction_info, int64 x_iter_num);

  // Prepares for the code generation for a tile block of a reduction kernel.
  //
  // Creates accumulator allocas, populates them with initial values, and
  // stores them inside reduction_info.
  void EmitPrologueForReduction(
      mlir::Operation* unnested_hlo, absl::Span<const int> instr_index_group,
      HloComputation* fused_computation, FusedIrEmitter* fused_emitter,
      absl::Span<const llvm_ir::IrArray> operand_ir_arrays,
      absl::Span<const llvm_ir::IrArray> result_ir_arrays,
      ReductionCodegenInfo* reduction_info);

  // Wraps up the code generation for a tile block of a reduction kernel:
  // writes the calculated output into the output tensor.
  void EmitEpilogueForReduction(
      llvm::Type* index_ty, mlir::Operation* unnested_hlo,
      absl::Span<const int> instr_index_group,
      absl::Span<const llvm_ir::IrArray> result_ir_arrays,
      absl::Span<HloComputation* const> reducers,
      const ReductionCodegenInfo& reduction_info,
      const TilingKernelInfo& tiling_kernel_info);

  // Emits code for the reductions in the given instruction index group.
  void EmitIRForReduction(mlir::Operation* unnested_hlo,
                          absl::Span<const int> instr_index_group,
                          HloComputation* fused_computation,
                          FusedIrEmitter* fused_emitter,
                          absl::Span<const llvm_ir::IrArray> operand_ir_arrays,
                          absl::Span<const llvm_ir::IrArray> result_ir_arrays,
                          ReductionCodegenInfo* reduction_info,
                          const Shape& input_shape);

  // For each reducer, emits the shuffle-down loop to accumulate the partial
  // result into the global result.
  void EmitFullWarpShuffleDownLoopForAllReduces(
      absl::Span<HloComputation* const> reducers,
      absl::Span<llvm::AllocaInst* const> partial_result_addresses,
      int threads_per_block);

  // Emits a shuffle-down reduction of `partial_result_address` using the
  // reduction computation `reducer` over elements of type `element_type`.
  void EmitFullWarpShuffleDownLoopForReduce(HloComputation* reducer,
                                            llvm::Type* element_type,
                                            llvm::Value* partial_result_address,
                                            int threads_per_block);

  std::unique_ptr<KernelThunk> BuildKernelThunkFromBufferSlices(
      absl::string_view name, Thunk::ThunkInfo thunk_info,
      absl::Span<const BufferSlice* const> slices,
      std::function<void(const BufferSlice*, llvm::Value*)>
          bind_slice_to_ir_value);

  // Returns a KernelThunk that invokes the kernel emitted for `inst`. The
  // caller needs to make sure `inst` outlives the lifetime of the returned
  // Thunk object. 'implements_whole_instruction' specifies whether this
  // KernelThunk implements the whole 'inst' HloInstruction. In some cases
  // 'inst' will be implemented by a sequence of Thunks.
  std::unique_ptr<KernelThunk> BuildKernelThunk(
      const HloInstruction* inst, bool implements_whole_instruction);

  std::unique_ptr<KernelThunk> BuildKernelThunkForMlirImpl(
      absl::string_view name, Thunk::ThunkInfo thunk_info,
      absl::Span<const MlirBufferSlice> slices,
      std::vector<llvm_ir::IrArray>* ir_arrays);

  StatusOr<std::unique_ptr<KernelThunk>> BuildKernelThunkForMlir(
      mlir::Operation* op, mlir::ValueRange operands,
      Thunk::ThunkInfo thunk_info, absl::optional<MlirBufferSlice> extra_slice,
      std::vector<llvm_ir::IrArray>* ir_arrays);

  StatusOr<std::unique_ptr<KernelThunk>> BuildKernelThunkForMlir(
      mlir::Operation* op, Thunk::ThunkInfo thunk_info,
      absl::optional<MlirBufferSlice> extra_slice,
      std::vector<llvm_ir::IrArray>* ir_arrays);

  // Returns a thunk that, given a reduce or select-and-scatter op,
  // initializes its memory to the appropriate initial value.
  std::unique_ptr<Thunk> BuildConstantInitializerThunk(
      absl::Span<const uint8> init_value, const BufferAllocation::Slice& dest,
      const Shape& output_shape);

  StatusOr<std::unique_ptr<Thunk>> TryBuildConstantInitializerThunk(
      mlir::Value init_value, mlir::Value dest);

  StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunkForMlir(
      mlir::Operation* op, mlir::Value init_value, mlir::Value dest);

  StatusOr<std::unique_ptr<Thunk>> BuildFusedInitializerThunkForMlir(
      mlir::lmhlo::FusionOp fusion, int output_index);

  // Returns a WhileThunk that invokes thunk sequences for the 'condition' and
  // 'body' sub-computations of the while instruction 'hlo'.
  StatusOr<std::unique_ptr<Thunk>> BuildWhileThunk(const HloInstruction* hlo);

  // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
  // sequence from the 'body' sub-computation of the while instruction 'hlo'.
  StatusOr<std::unique_ptr<Thunk>> BuildForThunk(const HloInstruction* hlo,
                                                 const int64 loop_limit);

  // Returns a ConditionalThunk which executes the thunk sequence for the
  // 'branch_computation' corresponding to the predicate/branch_index of the
  // given conditional instruction.
  StatusOr<std::unique_ptr<Thunk>> BuildConditionalThunk(
      const HloInstruction* hlo);

  // Emits the current thread id with the given type.
  //
  // Sets the return value range to [0, threads_per_block).
  llvm::Value* EmitThreadId(int64 threads_per_block, llvm::Type* index_ty);

  // Emits the LLVM values for thread_id, thread_id.x, thread_id.y and lane
  // id.
  //
  // Returns a struct containing these values.
  ThreadIdInfo EmitThreadIdInfo(int64 threads_per_block, llvm::Type* index_ty,
                                int64 num_threads_x);

  // Emits __syncthreads(), a synchronization barrier for all threads in a
  // block.
  llvm::CallInst* EmitSyncThreads();

  // Emits the current block id.
  llvm::Value* EmitBlockId();

  // Prints a given format string with the given arguments, prefixed with the
  // thread id and block id, and postfixed with a newline.
  //
  // `thread_id_filter` and `block_id_filter`: if provided, restrict printing
  // to only the given thread and/or block id.
  void EmitPrintfWithThreadId(
      absl::string_view fmt, absl::Span<llvm::Value* const> arguments,
      absl::optional<int64> thread_id_filter = absl::nullopt,
      absl::optional<int64> block_id_filter = absl::nullopt);

  StatusOr<HloComputation*> GetOrCreateSubComputationFromRegion(
      mlir::Region* region, bool is_fusion);

  StatusOr<MlirEmitterInput> GetMlirEmitterInput(HloInstruction* hlo);

  // Returns the last generated thunk.
  Thunk* LastThunk() const { return thunk_sequence_.back().get(); }

  Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const override;

  Status AssertNonDeterminismIsOkay(const string& op_name);

  // The thunk sequence this IrEmitter generates for the input computation.
  ThunkSequence thunk_sequence_;

  // Begin optional members for XLA HLO -> LMHLO:
  // TODO(timshen): Once the XLA HLO -> LMHLO converter is complete,
  // IrEmitterUnnested should take LMHLO only, and won't require a scratch
  // module.
  absl::optional<mlir::OwningModuleRef> mlir_scratch_module_;

  // This is for caching purposes only. It has no significant semantics.
  absl::optional<mlir::LhloDialectEmitter> lhlo_scratch_emitter_;

  absl::flat_hash_map<const mlir::Region*, std::unique_ptr<HloModule>>
      scratch_nested_computations_;
  // End optional members for XLA HLO -> LMHLO.
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_