/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_

#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"

namespace xla {
namespace gpu {

struct BufferSlice {
  // The root buffer to look at.
  BufferAllocation::Slice buffer_slice;

  // The global constant name of the buffer, if it's a constant.
  std::string constant_name;

  // Whether the buffer is modified by the kernel.
  bool written = false;

  Shape shape;
};

// Convenience struct that contains useful data structures in the MLIR emitter.
// Not all fields may be filled; which ones are depends entirely on the use.
struct MlirEmitterContext {
  void SetOperation(mlir::Operation* op);

  std::string name;
  std::vector<Shape> operand_shapes;
  std::vector<Shape> output_shapes;
};

// Emits LLVM IR for an "unnested computation".
//
// An unnested computation is an HloComputation which you run by executing one
// or more kernels for each HloInstruction it contains. Examples of unnested
// computations:
//
//  - An HloModule's root computation,
//  - The body of an HLO while loop,
//  - The true/false computation of an HLO conditional.
//
// Note the opportunity for confusion -- the while loop's computation is nested
// within the root computation, but it's emitted using IrEmitterUnnested! Don't
// think about it too hard.
//
// Examples of things that are not unnested computations:
//
//  - The reducer of a kReduce HLO. This is emitted using IrEmitterNested.
//  - The body of a fusion node. IrEmitterUnnested emits the relevant code
//    within a kernel function using FusedIrEmitter. (FusedIrEmitter is not
//    really an IrEmitter, but is more an "IR generator generator".)
//
class IrEmitterUnnested : public IrEmitter {
 public:
  struct ThreadIdInfo {
    // Raw thread id.
    llvm::Value* thread_id;

    // X-coordinate calculated from thread id: `thread_id % num_threads_x`.
    llvm::Value* thread_id_x;

    // Y-coordinate calculated from thread id: `thread_id / num_threads_x`.
    llvm::Value* thread_id_y;

    // Lane id: `thread_id % kWarpSize`.
    llvm::Value* lane_id;
  };

  absl::string_view platform_name() const {
    return ir_emitter_context_->platform_name();
  }

  // A function object to generate code to process one element in a tile.
  //
  // index: the index of the first output element of the current thread.
  // y_loc: the y coordinate within a tile.
  // x_loc: the x coordinate within a tile.
  // x_iter_num: when a thread processes N elements in the X dimension,
  //   x_iter_num has a value of 0..N-1 identifying the element being
  //   processed.
  using EmitElementFunction = std::function<void(
      const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
      llvm::Value* x_loc, int64_t x_iter_num)>;

  using ConstantGenerator = std::function<llvm::Value*(int64_t)>;

  // A function to generate the code to emit the entire tile.
  using TileElementGenerator = std::function<void(
      const ThreadIdInfo& thread_id_info, const llvm_ir::IrArray::Index& index,
      const string& loop_name, llvm::Value* tile_height,
      llvm::Value* tile_width, KernelSupportLibrary* ksl)>;

  IrEmitterUnnested(const IrEmitterUnnested&) = delete;
  IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;

  static StatusOr<std::unique_ptr<IrEmitterUnnested>> Create(
      const HloModuleConfig& hlo_module_config,
      IrEmitterContext* ir_emitter_context);

  // Transfers ownership of thunk_sequence_ out.
  std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
    return std::make_unique<ThunkSequence>(std::move(thunk_sequence_));
  }

  Status EmitLmhloRegion(mlir::Region* region);
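
  // A rough usage sketch for the entry points above (hypothetical caller
  // code; the surrounding variable names are assumptions, not part of this
  // API):
  //
  // ```
  // TF_ASSIGN_OR_RETURN(
  //     std::unique_ptr<IrEmitterUnnested> ir_emitter,
  //     IrEmitterUnnested::Create(hlo_module_config, ir_emitter_context));
  // TF_RETURN_IF_ERROR(ir_emitter->EmitLmhloRegion(&lmhlo_entry_region));
  // std::unique_ptr<ThunkSequence> thunks = ir_emitter->ConsumeThunkSequence();
  // ```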

 private:
  IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
                    IrEmitterContext* ir_emitter_context);

  // IrEmitterUnnested handles the following instructions differently from
  // IrEmitter. It also mixes in some special handling for custom kernels
  // via the ThunkEmitter.
  Status EmitConstant(mlir::Operation* op);

  Status EmitCopy(mlir::Operation* op);

  Status EmitConditional(mlir::Operation* op);
  Status EmitConvolutionThunk(mlir::Operation* op);
  Status EmitGemmThunk(mlir::Operation* op);
  Status EmitBatchNormThunk(mlir::Operation* op);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  Status EmitCholeskyThunk(mlir::Operation* op);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  Status EmitCustomCallThunk(mlir::Operation* op);
  Status EmitFftThunk(mlir::Operation* op);
  Status EmitFusion(mlir::Operation* op);
  Status EmitLoopFusion(mlir::Operation* op);
  Status EmitReduce(mlir::Operation* op);
  Status EmitSelectAndScatter(mlir::Operation* op);
  Status EmitWhile(mlir::Operation* op);
  Status EmitInfeed(mlir::Operation* op);
  Status EmitOutfeed(mlir::Operation* op);
  Status EmitRngGetAndUpdateState(mlir::Operation* op);
  Status EmitScatter(mlir::Operation* op);
  Status EmitSort(mlir::Operation* op);
  Status EmitTriangularSolve(mlir::Operation* op);

  template <typename NcclThunkType, typename OpTy>
  Status EmitNcclThunk(mlir::Operation* op);
  Status EmitAllReduceDone(mlir::Operation* op);

  template <typename ThunkType, typename OpT>
  Status EmitReplicaOrPartitionId(mlir::Operation* op);

  Status EmitCollectivePermute(mlir::Operation* op);

  Status EmitOp(mlir::Operation* op);

  static Thunk::ThunkInfo GetThunkInfo(mlir::Operation* op);

  Status EmitTargetElementLoop(
      const HloInstruction& hlo,
      const llvm_ir::ElementGenerator& body_emitter) override;

  // Adds an owning Thunk object to the thunk sequence.
  void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) {
    thunk_sequence_.emplace_back(std::move(thunk));
  }

  // Input  = {dynamic array (with dynamic dimension metadata at the end)}
  // Output = {static array, dynamic_dim0, dynamic_dim1}
  // For a tensor with static dimensions [2][<=5] and dynamic dimensions [2][3]
  // (`_` stands for padding):
  //   Input  = {{1,2,3,4,5,6,_,_,_,_,2,3}}
  //   Output = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}
  //
  // Pseudo code for padToStatic on a 2d array:
  // ```
  // void padToStatic(int** input, int** output, int threads_per_block,
  //                  int meta_data_offset, int max_num_element,
  //                  int static_dim0_size, int static_dim1_size) {
  //   int* source_array = input[0];
  //   int* dest_array = output[0];
  //
  //   // Extract the dynamic dimensions from the source array's metadata.
  //   int* dyn_dim0_size = source_array + meta_data_offset;
  //   int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);
  //
  //   // Only one thread needs to store the dynamic dimension sizes.
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   if (thread_id == 0 && block_id == 0) {
  //     *output[1] = *dyn_dim0_size;
  //     *output[2] = *dyn_dim1_size;
  //   }
  //
  //   int dyn_element_total = 1;
  //   dyn_element_total *= *dyn_dim0_size;
  //   dyn_element_total *= *dyn_dim1_size;
  //   linear_index = block_id * threads_per_block + thread_id;
  //   if (linear_index < max_num_element) {
  //     Index static_index =
  //         delinearized(linear_index, static_dim0_size, static_dim1_size);
  //     if (linear_index < dyn_element_total) {
  //       Index dyn_index =
  //           delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
  //       dest_array[dyn_index.dim0][dyn_index.dim1] =
  //           source_array[static_index.dim0][static_index.dim1];
  //     }
  //   }
  //   return;
  // }
  // ```
  Status EmitPadToStatic(mlir::Operation* op);

  // Input  = {static array, dynamic_dim0, dynamic_dim1}
  // Output = {dynamic array (with dynamic dimension metadata at the end)}
  // For a tensor with static dimensions [2][<=5] and dynamic dimensions [2][3]
  // (`_` stands for padding):
  //   Input  = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}
  //   Output = {{1,2,3,4,5,6,_,_,_,_,2,3}}
  //
  // Pseudo code for sliceToDynamic on a 2d array:
  // ```
  // void sliceToDynamic(int** input, int** output, int threads_per_block,
  //                     int meta_data_offset, int max_num_element,
  //                     int static_dim0_size, int static_dim1_size) {
  //   int* source_array = input[0];
  //   int* dest_array = output[0];
  //
  //   // Calculate the location where the metadata needs to be inserted.
  //   int* dyn_dim0_size = dest_array + meta_data_offset;
  //   int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);
  //
  //   // Only one thread needs to store the dynamic dimension sizes.
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   if (thread_id == 0 && block_id == 0) {
  //     *dyn_dim0_size = *input[1];
  //     *dyn_dim1_size = *input[2];
  //   }
  //
  //   int dyn_element_total = 1;
  //   dyn_element_total *= *dyn_dim0_size;
  //   dyn_element_total *= *dyn_dim1_size;
  //   linear_index = block_id * threads_per_block + thread_id;
  //   if (linear_index < max_num_element) {
  //     Index static_index =
  //         delinearized(linear_index, static_dim0_size, static_dim1_size);
  //     if (linear_index < dyn_element_total) {
  //       Index dyn_index =
  //           delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
  //       dest_array[static_index.dim0][static_index.dim1] =
  //           source_array[dyn_index.dim0][dyn_index.dim1];
  //     }
  //   }
  //   return;
  // }
  // ```
  Status EmitSliceToDynamic(mlir::Operation* op);

  StatusOr<BufferAllocation::Slice> GetAllocationSlice(
      mlir::Value v, std::string* constant_name = nullptr);

  int64 ByteSizeOf(const Shape& shape) const {
    return llvm_ir::ByteSizeOf(
        shape, ir_emitter_context_->llvm_module()->getDataLayout());
  }

  // Builds the prototype of the IR kernel `name` and adds it to the module.
  // This kernel takes as arguments pointers to the given buffer allocations.
  llvm::Function* BuildKernelPrototype(
      absl::string_view name, absl::Span<const BufferAllocation* const> args);

  // Helper for writing extra outputs from inside a reduce kernel.
  Status EmitExtraOutputsForReduce(
      absl::Span<const llvm_ir::IrArray> result_ir_arrays,
      const llvm_ir::IrArray::Index& index, bool use_linear_index,
      absl::Span<const std::pair<llvm_ir::ElementGenerator, int>>
          extra_output_gens);

  // Generates code for reduction to contiguous dimensions.
  //
  // Row reduction uses the following algorithm, described in CUDA-like
  // pseudocode:
  //
  // ```
  // __global__ void reduce(int num_rows, float* in, float* out) {
  //   __shared__ float cache[32];
  //   int offset = blockDim.x * blockIdx.x + threadIdx.x;
  //   if (offset >= num_rows) return;
  //   int tile_bound = std::min(offset + kTileSizeX, num_rows);
  //   float accum = 0;
  //   for (int i = offset; i < num_rows; i += blockDim.x) {
  //     accum += in[i];
  //   }
  //   accum = warp_reduce(accum);
  //   if (threadIdx.x % kWarpSize == 0) {
  //     cache[threadIdx.x / kWarpSize] = accum;
  //   }
  //   __syncthreads();
  //   if (threadIdx.x / kWarpSize == 0) {
  //     bool warp_exists = threadIdx.x < (blockDim.x / kWarpSize);
  //     float block_accum = warp_exists ? cache[threadIdx.x % kWarpSize] : 0;
  //     block_accum = warp_reduce(block_accum);
  //     if (threadIdx.x == 0) {
  //       *out += block_accum;
  //     }
  //   }
  // }
  // ```
  //
  // Column reduction uses the following algorithm:
  //
  // ```
  // void reduce(float** in, float* out) {
  //   __shared__ float cache[32][33];
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   int tile_size = 128;
  //
  //   float accum = 0;
  //   for (int i = 0; i < tile_size; i++) {
  //     accum += in[thread_id.y * tile_size + i][block_id * 32 + thread_id.x];
  //   }
  //   cache[thread_id.x][thread_id.y] = accum;
  //
  //   __syncthreads();
  //   accum = cache[thread_id.y][thread_id.x];
  //   accum = warp_reduce(accum);  // Sum all the values of `accum` in the
  //                                // same warp.
  //
  //   if (thread_id.y % 32 == 0) {
  //     out[block_id * 32 + thread_id.x] = accum;
  //   }
  // }
  // ```
  //
  // Moreover, a heuristic is implemented to divide the reduce instructions
  // into groups for parallelization (see `DivideOutputInstructionsIntoGroups`
  // for details about the heuristic). Reduce instructions in the same group
  // run sequentially, while different groups run in parallel.
  //
  // We use the raw block_id_y to select the reduce group to execute, so that
  // the index calculation in the code generated for the reduce instructions
  // stays simple. In other words, each block_id_y is assigned to a group, and
  // different groups can run in parallel.
  Status EmitUnnestedReduction(mlir::Operation* unnested_hlo,
                               const FusionLayoutAnalysis& layout_analysis);
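
  // For example, with two reduction groups, the kernel emitted by
  // EmitUnnestedReduction above conceptually looks like the following
  // CUDA-like sketch (an illustration of the grouping only, not the exact
  // generated code):
  //
  // ```
  // if (blockIdx.y == 0) {
  //   // ... code for the reduce instructions in group 0 ...
  // }
  // if (blockIdx.y == 1) {
  //   // ... code for the reduce instructions in group 1 ...
  // }
  // ```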

  // Computes the KernelMappingScheme for the reduce HLO and indicates whether
  // the reduction is a row reduction. For an un-fused reduce op, unnested_hlo
  // and first_reduce are the same instruction. For a kInput fusion,
  // unnested_hlo is the fusion instruction while first_reduce is the first
  // reduce op.
  ReductionCodegenInfo ComputeReductionCodegenInfo(
      mlir::Operation* unnested_hlo, mlir::Operation* first_reduce,
      const FusionLayoutAnalysis& layout_analysis);

  // Generates code for input-fusible slices.
  //
  // Prerequisite: ROOT is either a slice or a tuple of slices. The input
  // shapes of all ROOT slices need to be the same, while their output shapes
  // can be different. On the other hand, the input ranges of slices can be
  // overlapping. Further generalization/specialization can be added when the
  // need arises.
  Status EmitInputFusibleNonStridedSlices(mlir::Operation* op);

  Status EmitElementForInputFusibleSlices(
      const HloComputation* fused_computation,
      absl::Span<const llvm_ir::IrArray> ir_arrays,
      const llvm_ir::IrArray::Index& index);

  // Emits code for an in-place scatter, modifying `thunk`'s launch dimensions
  // in the process. Scatter indices are taken from `scatter_indices_gen`,
  // updates from `updates_gen`. The output buffer is expected to have the
  // operand values in it already. If unique_indices is false, we will use an
  // atomic update. Using true for unique_indices behaves properly only when
  // it is guaranteed that the indices to be updated do not overlap. The caller
  // is responsible for ensuring this is the case.
  Status EmitScatter(Thunk* thunk, mlir::lmhlo::ScatterOp scatter,
                     const LaunchDimensions& launch_dimensions,
                     const llvm_ir::IrArray& output,
                     const llvm_ir::ElementGenerator& scatter_indices_gen,
                     const llvm_ir::ElementGenerator& updates_gen,
                     std::function<llvm::Type*(int64_t)> get_index_type);

  // Structure describing a scatter operation for IR emission.
  // TODO(jurahul): Migrate element generators to use MLIR.
  //                Migrate update_computation to be an MLIR Region.
  struct ScatterDescriptor {
    std::string name;
    Shape operand_shape;
    Shape scatter_indices_shape;
    Shape updates_shape;
    mlir::mhlo::ScatterDimensionNumbers dim_numbers;
    bool unique_indices;
    const HloComputation* update_computation;
    llvm_ir::IrArray output;
    llvm_ir::ElementGenerator scatter_indices_gen;
    llvm_ir::ElementGenerator updates_gen;
    std::function<llvm::Type*(int64_t)> get_index_type;
  };

  // Emits code for an in-place scatter using the provided scatter operation
  // description.
  Status EmitScatter(const ScatterDescriptor& desc, Thunk* thunk,
                     const LaunchDimensions& launch_dimensions);

  // Checks whether a 0-2-1 tiling algorithm applies to the given op and, if
  // so, emits the kernel with it and returns true. (A 0-2-1 transpose permutes
  // a rank-3 shape [d0, d1, d2] into [d0, d2, d1].)
  StatusOr<bool> CheckAndEmitHloWithTile021(mlir::Operation* op);

  // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm.
  // This is a helper to support the implementation of
  // CheckAndEmitHloWithTile021.
  void EmitHlo021Tile(mlir::Operation* op, Thunk* kernel_thunk,
                      const MlirEmitterContext& context,
                      absl::Span<const llvm_ir::IrArray> operand_arrays,
                      absl::Span<const llvm_ir::IrArray> output_arrays,
                      absl::Span<const int64> reduced_output_dims,
                      absl::Span<const int64> tiled_param_ids,
                      const KernelMappingScheme& mapping_scheme,
                      const LaunchDimensions& launch_dimensions);
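
  // The 0-2-1 tile emission in EmitHlo021Tile above follows the classic
  // shared-memory transpose pattern, sketched below in CUDA-like pseudocode
  // (an illustration of the technique, not the exact code this method emits;
  // `tile_x` and `tile_y` stand for the tile origin):
  //
  // ```
  // __shared__ float tile[32][33];  // Extra column avoids bank conflicts.
  // // Coalesced read of one input tile.
  // tile[threadIdx.y][threadIdx.x] =
  //     in[tile_y + threadIdx.y][tile_x + threadIdx.x];
  // __syncthreads();
  // // Coalesced write of the transposed tile.
  // out[tile_x + threadIdx.y][tile_y + threadIdx.x] =
  //     tile[threadIdx.x][threadIdx.y];
  // ```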

  struct TilingKernelInfo {
    // Tiling bounds.
    std::array<llvm::Value*, 3> output_tile_bounds;

    // Starting tile, as calculated from block id only.
    llvm_ir::IrArray::Index tile_origin;
  };

  // Emits a kernel for the hlo instruction using the given kernel mapping
  // scheme.
  TilingKernelInfo EmitTilingKernel(
      const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
      const TileElementGenerator& tile_element_generator);

  // Emits code to process up to
  // (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a
  // tile. `emit_elem_function` emits the code to process one element,
  // `thread_id_y` and `thread_id_x` are the intra-tile coordinates of the
  // first element to process, and `index` is the index of the tile origin.
  // Information about tile_size_x/y and num_threads_x/y is stored in
  // `mapping_scheme`. Emits bounds checks to ensure that each processed
  // element lies within the boundary defined by `tile_width` and
  // `tile_height`.
  //
  // Pseudocode:
  //
  // for (y_loc = 0; y_loc < tile_height; y_loc += num_threads_y) {
  //   for (j = 0; j < tile_size_x / num_threads_x; j++) {  // unrolled
  //     if (dilated) {
  //       x_loc = x + j * num_threads_x;
  //     } else {
  //       x_loc = x * (tile_size_x / num_threads_x) + j;
  //     }
  //
  //     if (x_loc < tile_width) {
  //       emit_elem_function(y + y_loc, x_loc);
  //     }
  //   }
  // }
  //
  void EmitTile(
      const KernelMappingScheme& mapping_scheme,
      const llvm_ir::IrArray::Index& tile_origin_index,
      const string& loop_name, KernelSupportLibrary* ksl,
      const ThreadIdInfo& thread_id_info, llvm::Value* tile_height,
      llvm::Value* tile_width,
      const IrEmitterUnnested::EmitElementFunction& emit_elem_function);

  // Emits code to process a tensor element in a tile for the given kLoop
  // fusion HLO containing parameters that are a 0-2-1 transpose of its
  // outputs.
  //
  // y_loc: The y coordinate within a tile.
  // x_loc: The x coordinate within a tile.
  void EmitTileElementForFusion(
      mlir::lmhlo::FusionOp fusion,
      absl::Span<const llvm_ir::IrArray> operand_arrays,
      absl::Span<const llvm_ir::IrArray> output_arrays,
      const llvm_ir::IrArray::Index& index,
      const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
      llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);

  // Emits code to process a tensor element in a tile for the given input hlo
  // that is either an unnested kReduce or a kInput fusion.
  //
  // Calculates and stores the temporary reduction value in the corresponding
  // alloca.
  //
  // `instr_index_group` indicates the set of reductions this call needs to
  // emit; each index i points to the i-th output of unnested_hlo. Note that
  // if unnested_hlo is not a multi-output fusion, instr_index_group is always
  // {0}.
  void EmitTileElementForReduction(
      mlir::Operation* unnested_hlo, const Shape& reduction_operand_shape,
      absl::Span<const int> instr_index_group,
      HloComputation* fused_computation, FusedIrEmitter* fused_emitter,
      absl::Span<const llvm_ir::IrArray> result_ir_arrays,
      absl::Span<HloComputation* const> reducers,
      const llvm_ir::IrArray::Index& index,
      const ReductionCodegenState& reduction_info, int64_t x_iter_num,
      const FusionLayoutAnalysis& layout_analysis);

  // Prepares for the code generation for a tile block of a reduction kernel:
  // creates accumulator allocas, populates them with initial values, and
  // stores them inside reduction_info.
  void EmitPrologueForReduction(
      mlir::Operation* unnested_hlo, absl::Span<const int> instr_index_group,
      HloComputation* fused_computation, FusedIrEmitter* fused_emitter,
      absl::Span<const llvm_ir::IrArray> result_ir_arrays,
      ReductionCodegenState* reduction_info,
      const FusionLayoutAnalysis& layout_analysis);

  // Wraps up the code generation for a tile block of a reduction kernel:
  // writes the calculated output into the output tensor.
  void EmitEpilogueForReduction(
      llvm::Type* index_ty, mlir::Operation* unnested_hlo,
      absl::Span<const int> instr_index_group,
      absl::Span<const llvm_ir::IrArray> result_ir_arrays,
      absl::Span<HloComputation* const> reducers,
      const ReductionCodegenState& reduction_info,
      const TilingKernelInfo& tiling_kernel_info,
      const FusionLayoutAnalysis& layout_analysis);

  // `current_output`: the value the tile has calculated.
  // `output_address`: address where the output value has to be written.
  void EmitEpilogueForRowReduction(
      HloComputation* reducer,
      const IrEmitterUnnested::ThreadIdInfo& thread_id_info,
      const ReductionCodegenState& reduction_info, llvm::Type* element_type,
      llvm::Type* index_ty, llvm::Value* current_output,
      llvm::Value* output_address, int reduction_idx, int partial_result_idx);

  // Same arguments as EmitEpilogueForRowReduction.
  void EmitEpilogueForColumnReduction(
      HloComputation* reducer,
      const IrEmitterUnnested::ThreadIdInfo& thread_id_info,
      const ReductionCodegenState& reduction_info, llvm::Type* element_type,
      llvm::Type* index_ty, llvm::Value* current_output,
      llvm::Value* output_address, int reduction_idx, int partial_result_idx,
      const TilingKernelInfo& tiling_kernel_info);

  // Emits code for the reductions in the given output instructions.
  void EmitIRForReduction(mlir::Operation* unnested_hlo,
                          absl::Span<const int> instr_index_group,
                          HloComputation* fused_computation,
                          FusedIrEmitter* fused_emitter,
                          absl::Span<const llvm_ir::IrArray> result_ir_arrays,
                          ReductionCodegenState* reduction_info,
                          const Shape& input_shape,
                          const FusionLayoutAnalysis& layout_analysis);

  // For each reducer, emits the shuffle-down loop to accumulate the partial
  // result to the global result.
  void EmitFullWarpShuffleDownLoopForAllReduces(
      absl::Span<HloComputation* const> reducers,
      absl::Span<llvm::AllocaInst* const> partial_result_addresses,
      int threads_per_block);

  // Emits a shuffle-down reduction for `partial_result_address`, using the
  // reduction computation `reducer` over the type `element_type`.
  void EmitFullWarpShuffleDownLoopForReduce(
      HloComputation* reducer, llvm::Type* element_type,
      llvm::Value* partial_result_address, int threads_per_block);
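
  // A minimal sketch of the shuffle-down pattern emitted here, in CUDA-like
  // pseudocode (`shfl_down` stands for the target's warp shuffle intrinsic,
  // e.g. __shfl_down_sync on CUDA, and `reducer` for the reduction
  // computation; this is an illustration, not the exact emitted IR):
  //
  // ```
  // float warp_reduce(float partial) {
  //   for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
  //     partial = reducer(partial, shfl_down(partial, offset));
  //   }
  //   return partial;  // Lane 0 now holds the warp-wide result.
  // }
  // ```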

  std::unique_ptr<KernelThunk> BuildKernelThunkImpl(
      absl::string_view name, Thunk::ThunkInfo thunk_info,
      absl::Span<const BufferSlice> slices,
      std::vector<llvm_ir::IrArray>* ir_arrays,
      const LaunchDimensions& launch_dimensions);

  StatusOr<std::unique_ptr<KernelThunk>> BuildKernelThunk(
      mlir::Operation* op, mlir::ValueRange operands,
      Thunk::ThunkInfo thunk_info, std::vector<llvm_ir::IrArray>* ir_arrays,
      const LaunchDimensions& launch_dimensions);

  StatusOr<std::unique_ptr<KernelThunk>> BuildKernelThunk(
      mlir::Operation* op, Thunk::ThunkInfo thunk_info,
      std::vector<llvm_ir::IrArray>* ir_arrays,
      const LaunchDimensions& launch_dimensions);

  // Returns a thunk that, given a reduce or select-and-scatter op,
  // initializes its memory to the appropriate initial value.
  std::unique_ptr<Thunk> BuildConstantInitializerThunk(
      absl::Span<const uint8> init_value, const BufferAllocation::Slice& dest,
      const Shape& output_shape);

  StatusOr<std::unique_ptr<Thunk>> TryBuildConstantInitializerThunk(
      mlir::Value init_value, mlir::Value dest);

  StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk(mlir::Operation* op,
                                                         mlir::Value init_value,
                                                         mlir::Value dest);

  StatusOr<std::unique_ptr<Thunk>> BuildFusedInitializerThunk(
      mlir::lmhlo::FusionOp fusion, int output_index);

  // Returns a WhileThunk that invokes thunk sequences for the 'condition' and
  // 'body' sub-computations of the given while op.
  StatusOr<std::unique_ptr<Thunk>> BuildWhileThunk(
      mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info);

  // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
  // sequence from the 'body' sub-computation of the given while op.
  StatusOr<std::unique_ptr<Thunk>> BuildForThunk(
      mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info,
      const int64_t loop_limit);

  // Returns a ConditionalThunk which executes the thunk sequence for the
  // 'branch_computation' corresponding to the predicate/branch_index of the
  // given conditional instruction.
  StatusOr<std::unique_ptr<Thunk>> BuildConditionalThunk(
      const HloInstruction* conditional);

  // Emits the current thread id with the given type.
  //
  // Sets the return value range to [0, threads_per_block).
  llvm::Value* EmitThreadId(int64_t threads_per_block, llvm::Type* index_ty);

  // Emits the LLVM values for thread_id, thread_id.x, thread_id.y and lane
  // id.
  //
  // Returns a struct containing these values.
  ThreadIdInfo EmitThreadIdInfo(int64_t threads_per_block, llvm::Type* index_ty,
                                int64_t num_threads_x);

  // Emits __syncthreads(), a synchronization barrier for all threads in a
  // block.
  llvm::CallInst* EmitSyncThreads();

  // Emits the current block id.
  llvm::Value* EmitBlockId();

  // Prints a given format string with the given arguments, prefixed with the
  // thread id and block id, and postfixed with a newline.
  //
  // `thread_id_filter` and `block_id_filter`: if provided, restrict printing
  // to only the given thread and/or block id.
  void EmitPrintfWithThreadId(
      absl::string_view fmt, absl::Span<llvm::Value* const> arguments,
      absl::optional<int64> thread_id_filter = absl::nullopt,
      absl::optional<int64> block_id_filter = absl::nullopt);

  // __shared__ memory uses a different address space, so we cast it to the
  // global address space before writing or reading.
  llvm::Value* CastSharedToGlobal(llvm::Value* input, llvm::Twine name = "");

  StatusOr<HloComputation*> GetOrCreateSubComputationFromRegion(
      mlir::Region* region, bool is_fusion);

  // Returns the last generated thunk.
  Thunk* LastThunk() const { return thunk_sequence_.back().get(); }

  Status AssertNonDeterminismIsOkay(const string& op_name);

  // The thunk sequence this IrEmitter generates for the input computation.
  ThunkSequence thunk_sequence_;

  // Maps all-reduce-start ops to their thunks so that the corresponding
  // all-reduce-done op can access them.
  absl::flat_hash_map<mlir::Operation*, NcclAllReduceStartThunk*>
      all_reduce_start_thunks_;

  // Begin optional members for XLA HLO -> LMHLO:
  absl::flat_hash_map<const mlir::Region*, std::unique_ptr<HloModule>>
      scratch_nested_computations_;
  // End optional members for XLA HLO -> LMHLO.
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_