/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_

#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"

namespace xla {
namespace gpu {

struct BufferSlice {
  // The root buffer to look at.
  BufferAllocation::Slice buffer_slice;

  // The global constant name of the buffer, if it's a constant.
  std::string constant_name;

  // The buffer is modified by the kernel.
  bool written = false;

  Shape shape;
};

// Convenience struct that contains useful data structures for the MLIR
// emitter. Not all fields may be filled; it depends entirely on the use.
struct MlirEmitterContext {
  void SetOperation(mlir::Operation* op);

  std::string name;
  std::vector<Shape> operand_shapes;
  std::vector<Shape> output_shapes;
};

// Emits LLVM IR for an "unnested computation".
//
// An unnested computation is an HloComputation which you run by executing one
// or more kernels for each HloInstruction it contains.  Examples of unnested
// computations:
//
//  - An HloModule's root computation,
//  - The body of an HLO while loop,
//  - The true/false computation of an HLO conditional.
//
// Note the opportunity for confusion -- the while loop's computation is nested
// within the root computation, but it's emitted using IrEmitterUnnested!  Don't
// think about it too hard.
//
// Examples of things that are not unnested computations:
//
//  - The reducer of a kReduce HLO.  This is emitted using IrEmitterNested.
//  - The body of a fusion node.  IrEmitterUnnested emits the relevant code
//    within a kernel function using FusedIrEmitter.  (FusedIrEmitter is not
//    really an IrEmitter, but is more an "IR generator generator".)
//
class IrEmitterUnnested : public IrEmitter {
 public:
  struct ThreadIdInfo {
    // Raw thread id.
    llvm::Value* thread_id;

    // X-coordinate calculated from thread id: `thread_id % num_threads_x`
    llvm::Value* thread_id_x;

    // Y-coordinate calculated from thread id: `thread_id / num_threads_x`
    llvm::Value* thread_id_y;

    // Lane id: `thread_id % kWarpSize`
    llvm::Value* lane_id;
  };
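
  // Worked example for ThreadIdInfo (illustrative numbers only, not tied to
  // any particular kernel): with num_threads_x = 32 and kWarpSize = 32, a
  // thread with thread_id = 70 gets thread_id_x = 70 % 32 = 6,
  // thread_id_y = 70 / 32 = 2, and lane_id = 70 % 32 = 6.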

  absl::string_view platform_name() const {
    return ir_emitter_context_->platform_name();
  }

  // A function object to generate code to process one element in a tile.
  //
  // index: the index for the first output element of the current thread.
  // y_loc: The y coordinate within a tile.
  // x_loc: The x coordinate within a tile.
  // x_iter_num: When a thread processes N elements in the X dimension,
  //             x_iter_num has a value of 0..N-1 to identify the element
  //             being processed.
  using EmitElementFunction = std::function<void(
      const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
      llvm::Value* x_loc, int64_t x_iter_num)>;
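
  // A minimal sketch of an EmitElementFunction, assuming hypothetical
  // `input_array` and `output_array` llvm_ir::IrArray objects and the
  // emitter's IR builder `b_` (the element-wise copy below is illustrative
  // only, not part of this interface):
  //
  //   EmitElementFunction emit_elem =
  //       [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
  //           llvm::Value* x_loc, int64_t x_iter_num) {
  //         llvm::Value* v = input_array.EmitReadArrayElement(index, &b_);
  //         output_array.EmitWriteArrayElement(index, v, &b_);
  //       };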

  using ConstantGenerator = std::function<llvm::Value*(int64_t)>;

  // A function to generate the code to emit the entire tile.
  using TileElementGenerator = std::function<void(
      const ThreadIdInfo& thread_id_info, const llvm_ir::IrArray::Index& index,
      const string& loop_name, llvm::Value* tile_height,
      llvm::Value* tile_width, KernelSupportLibrary* ksl)>;

  IrEmitterUnnested(const IrEmitterUnnested&) = delete;
  IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;

  static StatusOr<std::unique_ptr<IrEmitterUnnested>> Create(
      const HloModuleConfig& hlo_module_config,
      IrEmitterContext* ir_emitter_context);

  // Transfers ownership of thunk_sequence_ out.
  std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
    return std::make_unique<ThunkSequence>(std::move(thunk_sequence_));
  }

  Status EmitLmhloRegion(mlir::Region* region);
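
  // A rough caller-side sketch of how this emitter is typically driven
  // (illustrative only; `lmhlo_region` stands for whatever LMHLO region the
  // caller wants to lower):
  //
  //   TF_ASSIGN_OR_RETURN(
  //       auto ir_emitter,
  //       IrEmitterUnnested::Create(hlo_module_config, ir_emitter_context));
  //   TF_RETURN_IF_ERROR(ir_emitter->EmitLmhloRegion(lmhlo_region));
  //   std::unique_ptr<ThunkSequence> thunks =
  //       ir_emitter->ConsumeThunkSequence();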

 private:
  IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
                    IrEmitterContext* ir_emitter_context);

  // IrEmitterUnnested handles the following instructions differently from
  // IrEmitter. It also mixes in some special handling for custom kernels
  // via the ThunkEmitter.
  Status EmitConstant(mlir::Operation* op);

  Status EmitCopy(mlir::Operation* op);

  Status EmitConditional(mlir::Operation* op);
  Status EmitConvolutionThunk(mlir::Operation* op);
  Status EmitGemmThunk(mlir::Operation* op);
  Status EmitBatchNormThunk(mlir::Operation* op);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  Status EmitCholeskyThunk(mlir::Operation* op);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  Status EmitCustomCallThunk(mlir::Operation* op);
  Status EmitFftThunk(mlir::Operation* op);
  Status EmitFusion(mlir::Operation* op);
  Status EmitLoopFusion(mlir::Operation* op);
  Status EmitReduce(mlir::Operation* op);
  Status EmitSelectAndScatter(mlir::Operation* op);
  Status EmitWhile(mlir::Operation* op);
  Status EmitInfeed(mlir::Operation* op);
  Status EmitOutfeed(mlir::Operation* op);
  Status EmitRngGetAndUpdateState(mlir::Operation* op);
  Status EmitScatter(mlir::Operation* op);
  Status EmitSort(mlir::Operation* op);
  Status EmitTriangularSolve(mlir::Operation* op);

  template <typename NcclThunkType, typename OpTy>
  Status EmitNcclThunk(mlir::Operation* op);
  Status EmitAllReduceDone(mlir::Operation* op);

  template <typename ThunkType, typename OpT>
  Status EmitReplicaOrPartitionId(mlir::Operation* op);

  Status EmitCollectivePermute(mlir::Operation* op);

  Status EmitOp(mlir::Operation* op);

  static Thunk::ThunkInfo GetThunkInfo(mlir::Operation* op);

  Status EmitTargetElementLoop(
      const HloInstruction& hlo,
      const llvm_ir::ElementGenerator& body_emitter) override;

  // Adds an owning Thunk object to the thunk sequence.
  void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) {
    thunk_sequence_.emplace_back(std::move(thunk));
  }

  // Input = {dynamic array(with dynamic dimension meta data at the end)}
  // Output = {static array, dynamic_dim0, dynamic_dim1}
  // For a tensor with static dimension [2][<=5] and dynamic dimension [2][3]
  // (`_` stands for padding)
  // Input = {{1,2,3,4,5,6,_,_,_,_,2,3}}
  // Output = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}

  // Pseudocode for padToStatic on a 2d array
  //   ```
  // void padToStatic(int** input, int** output, int threads_per_block,
  //                  int meta_data_offset, int max_num_element,
  //                  int static_dim0_size, int static_dim1_size) {
  //   int* source_array = input[0];
  //   int* dest_array = output[0];

  //   // extract the dynamic dimension from the source array's metadata
  //   int* dyn_dim0_size = source_array + meta_data_offset;
  //   int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);

  //   // only one thread needs to store the dynamic dimension sizes
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   if (thread_id == 0 && block_id == 0) {
  //     *output[1] = *dyn_dim0_size;
  //     *output[2] = *dyn_dim1_size;
  //   }

  //   int dyn_element_total = 1;
  //   dyn_element_total *= *dyn_dim0_size;
  //   dyn_element_total *= *dyn_dim1_size;
  //   int linear_index = block_id * threads_per_block + thread_id;
  //   if (linear_index < max_num_element) {
  //     Index static_index =
  //         delinearized(linear_index, static_dim0_size, static_dim1_size);
  //     if (linear_index < dyn_element_total) {
  //       Index dyn_index =
  //           delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
  //       dest_array[dyn_index.dim0][dyn_index.dim1] =
  //           source_array[static_index.dim0][static_index.dim1];
  //     }
  //   }
  //   return;
  // }
  //   ```
  Status EmitPadToStatic(mlir::Operation* op);

  // Input = {static array, dynamic_dim0, dynamic_dim1}
  // Output = {dynamic array(with dynamic dimension meta data at the end)}
  // For a tensor with static dimension [2][<=5] and dynamic dimension [2][3]
  // (`_` stands for padding)
  // Input = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}
  // Output = {{1,2,3,4,5,6,_,_,_,_,2,3}}

  // Pseudocode for sliceToDynamic on a 2d array
  //   ```
  // void sliceToDynamic(int** input, int** output, int threads_per_block,
  //                  int meta_data_offset, int max_num_element,
  //                  int static_dim0_size, int static_dim1_size) {
  //   int* source_array = input[0];
  //   int* dest_array = output[0];

  //   // calculate the location where metadata needs to be inserted
  //   int* dyn_dim0_size = dest_array + meta_data_offset;
  //   int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);

  //   // only one thread needs to store the dynamic dimension sizes
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   if (thread_id == 0 && block_id == 0) {
  //     *dyn_dim0_size = *input[1];
  //     *dyn_dim1_size = *input[2];
  //   }

  //   int dyn_element_total = 1;
  //   dyn_element_total *= *dyn_dim0_size;
  //   dyn_element_total *= *dyn_dim1_size;
  //   int linear_index = block_id * threads_per_block + thread_id;
  //   if (linear_index < max_num_element) {
  //     Index static_index =
  //         delinearized(linear_index, static_dim0_size, static_dim1_size);
  //     if (linear_index < dyn_element_total) {
  //       Index dyn_index =
  //           delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
  //       dest_array[static_index.dim0][static_index.dim1] =
  //           source_array[dyn_index.dim0][dyn_index.dim1];
  //     }
  //   }
  //   return;
  // }
  //   ```
  Status EmitSliceToDynamic(mlir::Operation* op);

  StatusOr<BufferAllocation::Slice> GetAllocationSlice(
      mlir::Value v, std::string* constant_name = nullptr);

  int64 ByteSizeOf(const Shape& shape) const {
    return llvm_ir::ByteSizeOf(
        shape, ir_emitter_context_->llvm_module()->getDataLayout());
  }

  // Builds the prototype of the IR kernel with the given name and adds it to
  // the module. This kernel takes as arguments pointers to the given buffer
  // allocations.
  llvm::Function* BuildKernelPrototype(
      absl::string_view name, absl::Span<const BufferAllocation* const> args);

  // Helper for writing extra outputs from inside a reduce kernel.
  Status EmitExtraOutputsForReduce(
      absl::Span<const llvm_ir::IrArray> result_ir_arrays,
      const llvm_ir::IrArray::Index& index, bool use_linear_index,
      absl::Span<const std::pair<llvm_ir::ElementGenerator, int>>
          extra_output_gens);

  // Generates code for reduction to contiguous dimensions.
  //
  // Row reduction uses the following algorithm described in CUDA-like
  // pseudocode:
  //
  // ```
  //  __global__ void reduce(int num_rows, float *in, float *out) {
  //    __shared__ float cache[32];
  //    int offset = blockDim.x * blockIdx.x + threadIdx.x;
  //    if (offset >= num_rows) return;
  //    int tile_bound = std::min(offset + kTileSizeX, num_rows);
  //    float accum = 0;
  //    for (int i = offset; i < num_rows; i += blockDim.x) {
  //      accum += in[i];
  //    }
  //    accum = warp_reduce(accum);
  //    if (threadIdx.x % kWarpSize == 0) {
  //      cache[threadIdx.x / kWarpSize] = accum;
  //    }
  //    __syncthreads();
  //    if (threadIdx.x / kWarpSize == 0) {
  //      bool warp_exists = threadIdx.x < (blockDim.x / kWarpSize);
  //      float block_accum = warp_exists ? cache[threadIdx.x % kWarpSize] : 0;
  //      block_accum = warp_reduce(block_accum);
  //      if (threadIdx.x == 0) {
  //        *out += block_accum;
  //      }
  //    }
  //  }
  // ```
  //
  // Column reduction uses the following algorithm:
  //
  // ```
  // void reduce(float** in, float* out) {
  //   __shared__ float cache[32][33];
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   int tile_size = 128;
  //
  //   float accum = 0;
  //   for (int i = 0; i < tile_size; i++) {
  //     accum += in[thread_id.y * tile_size + i][block_id * 32 + thread_id.x];
  //   }
  //   cache[thread_id.x][thread_id.y] = accum;
  //
  //   __syncthreads();
  //   accum = cache[thread_id.y][thread_id.x];
  //   accum = warp_reduce(accum); // Sum all the values of `accum` in the same
  //                               // warp.
  //
  //   if (thread_id.y % 32 == 0) {
  //     out[block_id * 32 + thread_id.x] = accum;
  //   }
  // }
  // ```
  //
  // Moreover, a heuristic is implemented to divide the reduce instructions
  // into groups for parallelization (see `DivideOutputInstructionsIntoGroups`
  // for details about the heuristic). Reduce instructions in the same group
  // will run sequentially while different groups will run in parallel.
  //
  // We use the raw block_id_y to select the reduce groups for execution
  // without complicating the index calculation in the code generation of the
  // reduce instructions. In other words, a block_id_y is assigned to a group,
  // so different groups can run in parallel.
  Status EmitUnnestedReduction(mlir::Operation* unnested_hlo,
                               const FusionLayoutAnalysis& layout_analysis);

  // Computes the KernelMappingScheme for the reduce HLO and indicates whether
  // the reduction is a row reduction. For an un-fused reduce op, unnested_hlo
  // and first_reduce are the same instruction. For a kInput fusion,
  // unnested_hlo is the fusion instruction while first_reduce is the first
  // reduce op.
  ReductionCodegenInfo ComputeReductionCodegenInfo(
      mlir::Operation* unnested_hlo, mlir::Operation* first_reduce,
      const FusionLayoutAnalysis& layout_analysis);

  // Generates code for input-fusible slices.
  //
  // Prerequisite: ROOT is either a slice or a tuple of slices. The input shapes
  // of all ROOT slices need to be the same while their output shapes can be
  // different. On the other hand, the input ranges of the slices may overlap.
  // Further generalization/specialization can be added when the need arises.
  Status EmitInputFusibleNonStridedSlices(mlir::Operation* op);

  Status EmitElementForInputFusibleSlices(
      const HloComputation* fused_computation,
      absl::Span<const llvm_ir::IrArray> ir_arrays,
      const llvm_ir::IrArray::Index& index);

  // Emits code for an in-place scatter, modifying `thunk`'s launch dimensions
  // in the process. Scatter indices are taken from `scatter_indices_gen`,
  // updates from `updates_gen`. The output buffer is expected to have the
  // operand values in it already. If unique_indices is false, we will use an
  // atomic update. Using true for unique_indices behaves properly only when
  // it is guaranteed that the indices to be updated do not overlap. The caller
  // is responsible for ensuring this is the case.
  Status EmitScatter(Thunk* thunk, mlir::lmhlo::ScatterOp scatter,
                     const LaunchDimensions& launch_dimensions,
                     const llvm_ir::IrArray& output,
                     const llvm_ir::ElementGenerator& scatter_indices_gen,
                     const llvm_ir::ElementGenerator& updates_gen,
                     std::function<llvm::Type*(int64_t)> get_index_type);

  // Structure describing a scatter operation for IR emission.
  // TODO(jurahul): Migrate element generators to use MLIR.
  //                Migrate update_computation to be an MLIR Region.
  struct ScatterDescriptor {
    std::string name;
    Shape operand_shape;
    Shape scatter_indices_shape;
    Shape updates_shape;
    mlir::mhlo::ScatterDimensionNumbers dim_numbers;
    bool unique_indices;
    const HloComputation* update_computation;
    llvm_ir::IrArray output;
    llvm_ir::ElementGenerator scatter_indices_gen;
    llvm_ir::ElementGenerator updates_gen;
    std::function<llvm::Type*(int64_t)> get_index_type;
  };

  // Emits code for an in-place scatter using the provided scatter operation
  // description.
  Status EmitScatter(const ScatterDescriptor& desc, Thunk* thunk,
                     const LaunchDimensions& launch_dimensions);

  // Returns true if a 0-2-1 tiling algorithm is already used to emit the
  // kernel for the hlo instruction.
  StatusOr<bool> CheckAndEmitHloWithTile021(mlir::Operation* op);

  // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm.
  // This is a helper to support the implementation of
  // CheckAndEmitHloWithTile021.
  void EmitHlo021Tile(mlir::Operation* op, Thunk* kernel_thunk,
                      const MlirEmitterContext& context,
                      absl::Span<const llvm_ir::IrArray> operand_arrays,
                      absl::Span<const llvm_ir::IrArray> output_arrays,
                      absl::Span<const int64> reduced_output_dims,
                      absl::Span<const int64> tiled_param_ids,
                      const KernelMappingScheme& mapping_scheme,
                      const LaunchDimensions& launch_dimensions);

  struct TilingKernelInfo {
    // Tiling bounds.
    std::array<llvm::Value*, 3> output_tile_bounds;

    // Starting tile, as calculated from block id only.
    llvm_ir::IrArray::Index tile_origin;
  };

  // Emits a kernel for the hlo instruction using the given kernel mapping
  // scheme.
  TilingKernelInfo EmitTilingKernel(
      const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
      const TileElementGenerator& tile_element_generator);

  // Emits code to process up to
  // (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a
  // tile, where `emit_elem_function` is the function to emit code to process
  // one element, `thread_id_y` and `thread_id_x` are the intra-tile
  // coordinates of the first element to process, and `index` is the index for
  // the origin of the tile. Information about tile_size_x/y and
  // num_threads_x/y is stored in `mapping_scheme`. Emits a bounds check to
  // ensure that each processed element is within the boundary defined by
  // `tile_width` and `tile_height`.
  //
  // Pseudocode:
  //
  // for (y_loc = 0; y_loc < tile_height; y_loc += num_threads_y) {
  //   for (j = 0; j < tile_size_x / num_threads_x; j++) { // unrolled
  //     if (dilated) {
  //       x_loc = x + j * num_threads_x;
  //     } else {
  //       x_loc = x * (tile_size_x / num_threads_x) + j;
  //     }
  //
  //     if (x_loc < tile_width) {
  //       emit_elem_function(y + y_loc, x_loc);
  //     }
  //   }
  // }
  //
  void EmitTile(
      const KernelMappingScheme& mapping_scheme,
      const llvm_ir::IrArray::Index& tile_origin_index, const string& loop_name,
      KernelSupportLibrary* ksl, const ThreadIdInfo& thread_id_info,
      llvm::Value* tile_height, llvm::Value* tile_width,
      const IrEmitterUnnested::EmitElementFunction& emit_elem_function);

  // Emits code to process a tensor element in a tile for the given kLoop
  // fusion HLO containing parameters that are a 0-2-1 transpose of its
  // outputs.
  // y_loc: The y coordinate within a tile.
  // x_loc: The x coordinate within a tile.
  void EmitTileElementForFusion(
      mlir::lmhlo::FusionOp fusion,
      absl::Span<const llvm_ir::IrArray> operand_arrays,
      absl::Span<const llvm_ir::IrArray> output_arrays,
      const llvm_ir::IrArray::Index& index,
      const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
      llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);

  // Emits code to process a tensor element in a tile for the given input hlo
  // that is either an unnested kReduce or a kInput fusion.
  //
  // Calculates and stores the temporary reduction value in the corresponding
  // alloca.
  //
  // `instr_index_group` indicates the set of reductions this call needs to
  // emit, where each index i points to the i-th output of unnested_hlo. Note
  // that if unnested_hlo is not a multi-output fusion, instr_index_group is
  // always {0}.
  void EmitTileElementForReduction(
      mlir::Operation* unnested_hlo, const Shape& reduction_operand_shape,
      absl::Span<const int> instr_index_group,
      HloComputation* fused_computation, FusedIrEmitter* fused_emitter,
      absl::Span<const llvm_ir::IrArray> result_ir_arrays,
      absl::Span<HloComputation* const> reducers,
      const llvm_ir::IrArray::Index& index,
      const ReductionCodegenState& reduction_info, int64_t x_iter_num,
      const FusionLayoutAnalysis& layout_analysis);

  // Prepares for the code generation for a tile block of a reduction kernel.
  //
  // Creates accumulator allocas, populates them with initial values, and
  // stores them inside reduction_info.
  void EmitPrologueForReduction(
      mlir::Operation* unnested_hlo, absl::Span<const int> instr_index_group,
      HloComputation* fused_computation, FusedIrEmitter* fused_emitter,
      absl::Span<const llvm_ir::IrArray> result_ir_arrays,
      ReductionCodegenState* reduction_info,
      const FusionLayoutAnalysis& layout_analysis);

  // Wraps up the code generation for a tile block of a reduction kernel:
  // writes the calculated output into the output tensor.
  void EmitEpilogueForReduction(
      llvm::Type* index_ty, mlir::Operation* unnested_hlo,
      absl::Span<const int> instr_index_group,
      absl::Span<const llvm_ir::IrArray> result_ir_arrays,
      absl::Span<HloComputation* const> reducers,
      const ReductionCodegenState& reduction_info,
      const TilingKernelInfo& tiling_kernel_info,
      const FusionLayoutAnalysis& layout_analysis);

  // `current_output`: the value the tile has calculated.
  // `output_address`: address where the output value has to be written.
  void EmitEpilogueForRowReduction(
      HloComputation* reducer,
      const IrEmitterUnnested::ThreadIdInfo& thread_id_info,
      const ReductionCodegenState& reduction_info, llvm::Type* element_type,
      llvm::Type* index_ty, llvm::Value* current_output,
      llvm::Value* output_address, int reduction_idx, int partial_result_idx);

  // Same arguments as EmitEpilogueForRowReduction.
  void EmitEpilogueForColumnReduction(
      HloComputation* reducer,
      const IrEmitterUnnested::ThreadIdInfo& thread_id_info,
      const ReductionCodegenState& reduction_info, llvm::Type* element_type,
      llvm::Type* index_ty, llvm::Value* current_output,
      llvm::Value* output_address, int reduction_idx, int partial_result_idx,
      const TilingKernelInfo& tiling_kernel_info);

  // Emits code for the reductions identified by `instr_index_group`.
  void EmitIRForReduction(mlir::Operation* unnested_hlo,
                          absl::Span<const int> instr_index_group,
                          HloComputation* fused_computation,
                          FusedIrEmitter* fused_emitter,
                          absl::Span<const llvm_ir::IrArray> result_ir_arrays,
                          ReductionCodegenState* reduction_info,
                          const Shape& input_shape,
                          const FusionLayoutAnalysis& layout_analysis);

  // For each reducer, emits the shuffle-down loop to accumulate the partial
  // result to the global result.
  void EmitFullWarpShuffleDownLoopForAllReduces(
      absl::Span<HloComputation* const> reducers,
      absl::Span<llvm::AllocaInst* const> partial_result_addresses,
      int threads_per_block);

  // Emits a shuffle-down reduction for the `partial_result_address` using the
  // reduction computation `reducer` over type `element_type`.
  void EmitFullWarpShuffleDownLoopForReduce(HloComputation* reducer,
                                            llvm::Type* element_type,
                                            llvm::Value* partial_result_address,
                                            int threads_per_block);

  std::unique_ptr<KernelThunk> BuildKernelThunkImpl(
      absl::string_view name, Thunk::ThunkInfo thunk_info,
      absl::Span<const BufferSlice> slices,
      std::vector<llvm_ir::IrArray>* ir_arrays,
      const LaunchDimensions& launch_dimensions);

  StatusOr<std::unique_ptr<KernelThunk>> BuildKernelThunk(
      mlir::Operation* op, mlir::ValueRange operands,
      Thunk::ThunkInfo thunk_info, std::vector<llvm_ir::IrArray>* ir_arrays,
      const LaunchDimensions& launch_dimensions);

  StatusOr<std::unique_ptr<KernelThunk>> BuildKernelThunk(
      mlir::Operation* op, Thunk::ThunkInfo thunk_info,
      std::vector<llvm_ir::IrArray>* ir_arrays,
      const LaunchDimensions& launch_dimensions);

  // Returns a thunk that, given a reduce or select-and-scatter op,
  // initializes its memory to the appropriate initial value.
  std::unique_ptr<Thunk> BuildConstantInitializerThunk(
      absl::Span<const uint8> init_value, const BufferAllocation::Slice& dest,
      const Shape& output_shape);

  StatusOr<std::unique_ptr<Thunk>> TryBuildConstantInitializerThunk(
      mlir::Value init_value, mlir::Value dest);

  StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk(mlir::Operation* op,
                                                         mlir::Value init_value,
                                                         mlir::Value dest);

  StatusOr<std::unique_ptr<Thunk>> BuildFusedInitializerThunk(
      mlir::lmhlo::FusionOp fusion, int output_index);

  // Returns a WhileThunk that invokes thunk sequences for the 'condition' and
  // 'body' sub-computations of the given while op.
  StatusOr<std::unique_ptr<Thunk>> BuildWhileThunk(
      mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info);

  // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
  // sequence from the 'body' sub-computation of the given while op.
  StatusOr<std::unique_ptr<Thunk>> BuildForThunk(
      mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info,
      const int64_t loop_limit);

  // Returns a ConditionalThunk which executes the thunk sequence for the
  // 'branch_computation' corresponding to the predicate/branch_index of the
  // given conditional instruction.
  StatusOr<std::unique_ptr<Thunk>> BuildConditionalThunk(
      const HloInstruction* conditional);

  // Emits the current thread id with the given type.
  //
  // Sets the return value range to [0, threads_per_block).
  llvm::Value* EmitThreadId(int64_t threads_per_block, llvm::Type* index_ty);

  // Emits the LLVM values for thread_id, thread_id.x, thread_id.y and lane
  // id.
  //
  // Returns a struct containing these values.
  ThreadIdInfo EmitThreadIdInfo(int64_t threads_per_block, llvm::Type* index_ty,
                                int64_t num_threads_x);

  // Emits __syncthreads(), the synchronization barrier for all threads in a
  // block.
  llvm::CallInst* EmitSyncThreads();

  // Emits the current block id.
  llvm::Value* EmitBlockId();

  // Prints a given format string with the given arguments, prefixed with the
  // thread id and block id, and postfixed with a newline.
  //
  // `thread_id_filter` and `block_id_filter`: if provided, restrict printing
  // to only the given thread and/or block id.
  void EmitPrintfWithThreadId(
      absl::string_view fmt, absl::Span<llvm::Value* const> arguments,
      absl::optional<int64> thread_id_filter = absl::nullopt,
      absl::optional<int64> block_id_filter = absl::nullopt);
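
  // Example (illustrative only; `val` stands for some llvm::Value* available
  // at the call site): the call below prints the value, prefixed with the
  // thread and block id, from thread 0 of block 0 only.
  //
  //   EmitPrintfWithThreadId("val = %d", {val},
  //                          /*thread_id_filter=*/0, /*block_id_filter=*/0);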

  // __shared__ memory uses a different address space, so we cast it to the
  // global address space before writing or reading.
  llvm::Value* CastSharedToGlobal(llvm::Value* input, llvm::Twine name = "");

  StatusOr<HloComputation*> GetOrCreateSubComputationFromRegion(
      mlir::Region* region, bool is_fusion);

  // Returns the last generated thunk.
  Thunk* LastThunk() const { return thunk_sequence_.back().get(); }

  Status AssertNonDeterminismIsOkay(const string& op_name);

  // The thunk sequence this IrEmitter generates for the input computation.
  ThunkSequence thunk_sequence_;

  // Maps all-reduce-start ops to their thunk so that the corresponding
  // all-reduce-done op can access the thunk.
  absl::flat_hash_map<mlir::Operation*, NcclAllReduceStartThunk*>
      all_reduce_start_thunks_;

  // Begin optional members for XLA HLO -> LMHLO:
  absl::flat_hash_map<const mlir::Region*, std::unique_ptr<HloModule>>
      scratch_nested_computations_;
  // End optional members for XLA HLO -> LMHLO.
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_