/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_

#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"

namespace xla {
namespace gpu {

struct BufferSlice {
  // The root buffer to look at.
  BufferAllocation::Slice buffer_slice;

  // Describes how to dereference starting at that buffer to get to the buffer
  // in question.
  ShapeIndex gte_index;
};

// Describes how to access a particular subshape for an HLO.  For instance, if
// `.hlo_index` is {1} and `.gte_index` is {3, 4}, then the buffer for `.instr`
// at ShapeIndex {1} (i.e. the buffer for the second tuple element of the HLO)
// is found at `.buffer_slice`[3][4].  That is, `.buffer_slice` is a void***,
// which we dereference twice -- first at index 3, and then at index 4 -- to
// get the address of our buffer.
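//
// An illustrative sketch of that double dereference (not part of the emitter;
// `base` is a hypothetical pointer holding the address named by
// `.buffer_slice`, with `.gte_index` == {3, 4}):
//   ```
//   void*** base = /* address described by .buffer_slice */;
//   void** tuple_elems = base[3];    // first dereference, gte_index[0] == 3
//   void* buffer = tuple_elems[4];   // second dereference, gte_index[1] == 4
//   ```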
struct HloBufferSlice : public BufferSlice {
  const HloInstruction* instr;
  ShapeIndex hlo_index;
};

struct MlirBufferSlice : public BufferSlice {
  // The buffer is modified by the kernel.
  bool written = false;

  Shape shape;
};

struct MlirEmitterInput {
  mlir::Operation* op;
  Thunk::ThunkInfo thunk_info;

  // A field to allow plumbing extra information that BufferAssignment has, but
  // LMHLO/MLIR representation does not have. Specifically, this is for passing
  // the allocated buffer for tuple outputs (an array of pointers to tuple
  // elements).
  //
  // TODO(timshen): We need a corresponding construct in LMHLO to represent
  // this, aka an array of pointers to different memrefs. Once we have that, we
  // can merge this information back to LMHLO graph and remove this field.
  absl::optional<MlirBufferSlice> extra_slice;
};

// Convenience struct that contains useful data structures in the MLIR emitter.
// Not all fields may be filled; it depends entirely on the use case.
struct MlirEmitterContext {
  void SetOperation(mlir::Operation* op);

  std::string name;
  std::vector<Shape> operand_shapes;
  std::vector<Shape> output_shapes;
};

// Emits LLVM IR for an "unnested computation".
//
// An unnested computation is an HloComputation which you run by executing one
// or more kernels for each HloInstruction it contains.  Examples of unnested
// computations:
//
//  - An HloModule's root computation,
//  - The body of an HLO while loop,
//  - The true/false computation of an HLO conditional.
//
// Note the opportunity for confusion -- the while loop's computation is nested
// within the root computation, but it's emitted using IrEmitterUnnested!  Don't
// think about it too hard.
//
// Examples of things that are not unnested computations:
//
//  - The reducer of a kReduce HLO.  This is emitted using IrEmitterNested.
//  - The body of a fusion node.  IrEmitterUnnested emits the relevant code
//    within a kernel function using FusedIrEmitter.  (FusedIrEmitter is not
//    really an IrEmitter, but is more an "IR generator generator".)
//
class IrEmitterUnnested : public IrEmitter,
                          private ThunkEmitter::EmissionContext {
 public:
  struct ThreadIdInfo {
    // Raw thread id.
    llvm::Value* thread_id;

    // X-coordinate calculated from thread id: `thread_id % num_threads_x`
    llvm::Value* thread_id_x;

    // Y-coordinate calculated from thread id: `thread_id / num_threads_x`
    llvm::Value* thread_id_y;

    // Lane id: `thread_id % kWarpSize`
    llvm::Value* lane_id;
  };
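  // A worked example of the decomposition above (hypothetical values, assuming
  // num_threads_x = 32 and kWarpSize = 32): for thread_id = 70 we get
  // thread_id_x = 70 % 32 = 6, thread_id_y = 70 / 32 = 2, and
  // lane_id = 70 % 32 = 6.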

  absl::string_view platform_name() const override {
    return ir_emitter_context_->platform_name();
  }

  // A function object to generate code to process one element in a tile.
  //
  // index: the index for the first output element of the current thread.
  // y_loc: The y coordinate within a tile.
  // x_loc: The x coordinate within a tile.
  // x_iter_num: When a thread processes N elements in the X dimension,
  //             x_iter_num has a value of 0..N-1 to identify the element being
  //             processed.
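  //
  // A minimal sketch of an EmitElementFunction (for illustration only; `in`,
  // `out`, and `b_` are hypothetical -- two llvm_ir::IrArrays bound to kernel
  // arguments and the emitter's IRBuilder) that copies a single element:
  //   ```
  //   EmitElementFunction copy_one_element =
  //       [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
  //           llvm::Value* x_loc, int64 x_iter_num) {
  //         llvm::Value* value = in.EmitReadArrayElement(index, &b_);
  //         out.EmitWriteArrayElement(index, value, &b_);
  //       };
  //   ```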
  using EmitElementFunction = std::function<void(
      const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
      llvm::Value* x_loc, int64 x_iter_num)>;

  using ConstantGenerator = std::function<llvm::Value*(int64)>;

  // A function to generate the code to emit the entire tile.
  using TileElementGenerator = std::function<void(
      const ThreadIdInfo& thread_id_info, const llvm_ir::IrArray::Index& index,
      const string& loop_name, llvm::Value* tile_height,
      llvm::Value* tile_width, KernelSupportLibrary* ksl)>;

  IrEmitterUnnested(const IrEmitterUnnested&) = delete;
  IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;

  static StatusOr<std::unique_ptr<IrEmitterUnnested>> Create(
      const HloModuleConfig& hlo_module_config,
      const HloComputation* hlo_computation,
      IrEmitterContext* ir_emitter_context);

  // Transfers the ownership of thunk_sequence_ out.
  std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
    return std::make_unique<ThunkSequence>(std::move(thunk_sequence_));
  }

  Status DefaultAction(HloInstruction* hlo) override;
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status EmitUsingElementalIrEmitter(MlirEmitterInput input);

  // IrEmitterUnnested handles the following instructions differently from
  // IrEmitter. It also mixes in some special handling for custom kernels
  // via the ThunkEmitter.
  Status HandleConstant(HloInstruction* constant) override;
  Status EmitConstant(MlirEmitterInput mlir_input);

  Status HandleCopy(HloInstruction* copy) override;
  Status EmitCopyForMlir(MlirEmitterInput input);

  Status HandleConditional(HloInstruction* conditional) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status EmitCustomCallFromMlir(MlirEmitterInput input);
  Status EmitConvolutionThunkFromMlir(MlirEmitterInput input);
  Status EmitGemmThunkFromMlir(MlirEmitterInput input);
  Status EmitBatchNormThunkFromMlir(MlirEmitterInput input);
#if GOOGLE_CUDA
  Status EmitCholeskyThunkFromMlir(MlirEmitterInput input);
#endif  // GOOGLE_CUDA
  Status EmitCustomCallThunkFromMlir(MlirEmitterInput input);
  Status HandleFft(HloInstruction* fft) override;
  Status EmitFftThunkFromMlir(MlirEmitterInput input);
  Status HandleFusion(HloInstruction* fusion) override;
  Status EmitLoopFusionFromMlir(
      MlirEmitterInput input, const Shape& output_shape,
      absl::optional<int> unroll_factor_override = {});
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status EmitReduceFromMlir(MlirEmitterInput mlir_input);
  Status HandleSelectAndScatter(HloInstruction* instruction) override;
  Status EmitSelectAndScatterFromMlir(MlirEmitterInput mlir_input);
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleWhile(HloInstruction* xla_while) override;
  Status HandleInfeed(HloInstruction* xla_infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleRng(HloInstruction* random) override;
  Status HandleRngGetAndUpdateState(HloInstruction* rng_state) override;
  Status EmitRngGetAndUpdateState(MlirEmitterInput mlir_input);
  Status HandleScatter(HloInstruction* scatter) override;
  Status EmitScatterFromMlir(MlirEmitterInput mlir_input);
  Status HandleSort(HloInstruction* sort) override;
  Status EmitSortFromMlir(MlirEmitterInput mlir_input);
  Status HandleTriangularSolve(HloInstruction* hlo) override;
  Status EmitTriangularSolveFromMlir(MlirEmitterInput mlir_input);

  template <typename NcclThunkType, typename OpTy>
  Status EmitNcclThunkFromMlir(MlirEmitterInput mlir_input);
  Status HandleAllGather(HloInstruction* hlo) override;
  Status HandleAllReduce(HloInstruction* hlo) override;
  Status HandleAllToAll(HloInstruction* hlo) override;

  Status HandleAfterAll(HloInstruction* after_all) override;

  template <typename ThunkType, typename OpT>
  Status EmitReplicaOrPartitionIdFromMlir(MlirEmitterInput input);
  Status HandleReplicaId(HloInstruction* hlo) override;
  Status HandlePartitionId(HloInstruction* hlo) override;

  Status HandleCollectivePermute(HloInstruction* hlo) override;
  Status EmitCollectivePermuteFromMlir(MlirEmitterInput input);

  Status EmitOp(MlirEmitterInput mlir_input);

  Status EmitTargetElementLoop(
      const HloInstruction& hlo,
      const llvm_ir::ElementGenerator& body_emitter) override;

  // Same as `EmitTargetElementLoop`, but in the given `thunk` rather than
  // `LastThunk()`. The kernel implementation will be unrolled if
  // `unroll_factor` is greater than one.
  Status EmitTargetElementLoopInThunk(
      const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
      KernelThunk* thunk, int unroll_factor, bool few_waves = false);

  Status Postprocess(HloInstruction* hlo) override;

 private:
  IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
                    const HloComputation* hlo_computation,
                    IrEmitterContext* ir_emitter_context);

  // Adds an owning Thunk object to the thunk sequence.
  void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) override {
    thunk_sequence_.emplace_back(std::move(thunk));
  }

  // Input = {static array, dynamic_dim0, dynamic_dim1}
  // Output = {dynamic array(with dynamic dimension meta data at the end)}
  // For a tensor with static dimension [2][<=5] and dynamic dimension [2][3]
  // (`_` stands for padding)
  // Input = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}
  // Output = {{1,2,3,4,5,6,_,_,_,_,2,3}}

  // Pseudocode for padToStatic on a 2d array:
  //   ```
  // void padToStatic(int** input, int** output, int threads_per_block,
  //                  int meta_data_offset, int max_num_element,
  //                  int static_dim0_size, int static_dim1_size) {
  //   int* source_array = input[0];
  //   int* dest_array = output[0];

  //   // extract the dynamic dimensions from the source array's metadata
  //   int* dyn_dim0_size = source_array + meta_data_offset;
  //   int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);

  //   // only one thread needs to store the dynamic dimension sizes
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   if (thread_id == 0 && block_id == 0) {
  //     *output[1] = *dyn_dim0_size;
  //     *output[2] = *dyn_dim1_size;
  //   }

  //   int dyn_element_total = 1;
  //   dyn_element_total *= *dyn_dim0_size;
  //   dyn_element_total *= *dyn_dim1_size;
  //   int linear_index = block_id * threads_per_block + thread_id;
  //   if (linear_index < max_num_element) {
  //     Index static_index =
  //         delinearized(linear_index, static_dim0_size, static_dim1_size);
  //     if (linear_index < dyn_element_total) {
  //       Index dyn_index =
  //           delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
  //       dest_array[dyn_index.dim0][dyn_index.dim1] =
  //           source_array[static_index.dim0][static_index.dim1];
  //     }
  //   }
  //   return;
  // }
  //   ```
  Status EmitPadToStaticFromMlir(MlirEmitterInput mlir_input);

  // Input = {dynamic array(with dynamic dimension meta data at the end)}
  // Output = {static array, dynamic_dim0, dynamic_dim1}
  // For a tensor with static dimension [2][<=5] and dynamic dimension [2][3]
  // (`_` stands for padding)
  // Input = {{1,2,3,4,5,6,_,_,_,_,2,3}}
  // Output = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}

  // Pseudocode for sliceToDynamic on a 2d array:
  //   ```
  // void sliceToDynamic(int** input, int** output, int threads_per_block,
  //                  int meta_data_offset, int max_num_element,
  //                  int static_dim0_size, int static_dim1_size) {
  //   int* source_array = input[0];
  //   int* dest_array = output[0];

  //   // calculate the location where metadata needs to be inserted
  //   int* dyn_dim0_size = dest_array + meta_data_offset;
  //   int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);

  //   // only one thread needs to store the dynamic dimension sizes
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   if (thread_id == 0 && block_id == 0) {
  //     *dyn_dim0_size = *output[1];
  //     *dyn_dim1_size = *output[2];
  //   }

  //   int dyn_element_total = 1;
  //   dyn_element_total *= *dyn_dim0_size;
  //   dyn_element_total *= *dyn_dim1_size;
  //   int linear_index = block_id * threads_per_block + thread_id;
  //   if (linear_index < max_num_element) {
  //     Index static_index =
  //         delinearized(linear_index, static_dim0_size, static_dim1_size);
  //     if (linear_index < dyn_element_total) {
  //       Index dyn_index =
  //           delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
  //       dest_array[static_index.dim0][static_index.dim1] =
  //           source_array[dyn_index.dim0][dyn_index.dim1];
  //     }
  //   }
  //   return;
  // }
  //   ```
  Status EmitSliceToDynamicFromMlir(MlirEmitterInput mlir_input);

  // A convenient helper for calling BufferAssignment::GetUniqueSlice.
  StatusOr<BufferAllocation::Slice> MaybeGetAllocationSlice(
      const HloInstruction& hlo, const ShapeIndex& index) const override {
    return ir_emitter_context_->buffer_assignment().GetUniqueSlice(&hlo, index);
  }

  BufferAllocation::Slice GetAllocationSlice(
      const HloInstruction& hlo, const ShapeIndex& index = {}) const {
    return MaybeGetAllocationSlice(hlo, index).ConsumeValueOrDie();
  }

  StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(mlir::Value v);

  int64 ByteSizeOf(const Shape& shape) const override {
    return llvm_ir::ByteSizeOf(
        shape, ir_emitter_context_->llvm_module()->getDataLayout());
  }

  // Builds the prototype of the IR kernel for `inst` and adds it to the module.
  // This kernel takes as arguments pointers to the given buffer allocations.
  llvm::Function* BuildKernelPrototype(
      absl::string_view name, absl::Span<const BufferAllocation* const> args);

  // Helper for writing extra outputs from inside a reduce kernel.
  Status EmitExtraOutputsForReduce(
      absl::Span<const llvm_ir::IrArray> result_ir_arrays,
      const llvm_ir::IrArray::Index& index, bool use_linear_index,
      absl::Span<const std::pair<llvm_ir::ElementGenerator, int>>
          extra_output_gens);

  // Generates code for reduction to contiguous dimensions.
  //
  // Row reduction uses the following algorithm described in CUDA-like
  // pseudocode:
  //
  // ```
  //  __global__ void reduce(int num_rows, float *in, float* out) {
  //    __shared__ float cache[32];
  //    int offset = blockDim.x * blockIdx.x + threadIdx.x;
  //    if (offset >= num_rows) return;
  //    int tile_bound = std::min(offset + kTileSizeX, num_rows);
  //    float accum = 0;
  //    for (int i = offset; i < tile_bound; i += blockDim.x) {
  //      accum += in[i];
  //    }
  //    accum = warp_reduce(accum);
  //    if (threadIdx.x % kWarpSize == 0) {
  //      cache[threadIdx.x / kWarpSize] = accum;
  //    }
  //    __syncthreads();
  //    if (threadIdx.x / kWarpSize == 0) {
  //      bool warp_exists = threadIdx.x < (blockDim.x / kWarpSize);
  //      float block_accum = warp_exists ? cache[threadIdx.x % kWarpSize] : 0;
  //      block_accum = warp_reduce(block_accum);
  //      if (threadIdx.x == 0) {
  //        *out += block_accum;
  //      }
  //    }
  //  }
  // ```
  //
  // Column reduction uses the following algorithm:
  //
  // ```
  // void reduce(float** in, float* out) {
  //   __shared__ float cache[32][33];
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   int tile_size = 128;
  //
  //   float accum = 0;
  //   for (int i = 0; i < tile_size; i++) {
  //     accum += in[thread_id.y * tile_size + i][block_id * 32 + thread_id.x];
  //   }
  //   cache[thread_id.x][thread_id.y] = accum;
  //
  //   __syncthreads();
  //   accum = cache[thread_id.y][thread_id.x];
  //   accum = warp_reduce(accum); // Sum all the values of `accum` in the same
  //                               // warp.
  //
  //   if (thread_id.y % 32 == 0) {
  //     out[block_id * 32 + thread_id.x] = accum;
  //   }
  // }
  // ```
  //
  // Moreover, a heuristic is implemented to divide the reduce instructions
  // into groups for parallelization (see `DivideOutputInstructionsIntoGroups`
  // for details about the heuristic). Reduce instructions in the same group
  // will run sequentially, while different groups will run in parallel.
  //
  // We use the raw block_id_y to select the reduce group to execute without
  // complicating the index calculation in the code generation of the reduce
  // instructions. In other words, each block_id_y is assigned to a group, so
  // different groups can be run in parallel.
  Status EmitReductionFromOrToContiguousDimensions(MlirEmitterInput mlir_input);

  // Computes the KernelMappingScheme for the reduce HLO and indicates whether
  // the reduction is a row reduction. For an un-fused reduce op, unnested_hlo
  // and first_reduce are the same instruction. For a kInput fusion,
  // unnested_hlo is the fusion instruction while first_reduce is the first
  // reduce op.
  ReductionCodegenInfo ComputeReductionCodegenInfo(
      mlir::Operation* unnested_hlo, mlir::Operation* first_reduce);

  // Generates code for input-fusible slices.
  //
  // Prerequisite: ROOT is either a slice or a tuple of slices. The input shapes
  // of all ROOT slices need to be the same, while their output shapes can be
  // different. On the other hand, the input ranges of the slices can overlap.
  // Further generalization/specialization can be added when the need arises.
  Status EmitInputFusibleNonStridedSlices(MlirEmitterInput mlir_input);

  Status EmitElementForInputFusibleSlices(
      mlir::lmhlo::FusionOp fusion,
      absl::Span<const llvm_ir::IrArray> ir_arrays,
      const llvm_ir::IrArray::Index& index);

  // Emits code for an in-place scatter, modifying `thunk`'s launch dimensions
  // in the process. Scatter indices are taken from `scatter_indices_gen`, and
  // updates from `updates_gen`. The output buffer is expected to have the
  // operand values in it already. If unique_indices is false, we will use an
  // atomic update. Setting unique_indices to true behaves properly only when
  // it is guaranteed that the indices to be updated do not overlap. The caller
  // is responsible for ensuring this is the case.
  Status EmitScatter(Thunk* thunk, mlir::lmhlo::ScatterOp scatter,
                     const llvm_ir::IrArray& output,
                     const llvm_ir::ElementGenerator& scatter_indices_gen,
                     const llvm_ir::ElementGenerator& updates_gen,
                     std::function<llvm::Type*(int64)> get_index_type);

  // Structure describing a scatter operation for IR emission.
  // TODO(jurahul): Migrate element generators to use MLIR.
  //                Migrate update_computation to be an MLIR Region.
  struct ScatterDescriptor {
    std::string name;
    Shape operand_shape;
    Shape scatter_indices_shape;
    Shape updates_shape;
    mlir::mhlo::ScatterDimensionNumbers dim_numbers;
    bool unique_indices;
    const HloComputation* update_computation;
    llvm_ir::IrArray output;
    llvm_ir::ElementGenerator scatter_indices_gen;
    llvm_ir::ElementGenerator updates_gen;
    std::function<llvm::Type*(int64)> get_index_type;
  };

  // Emits code for an in-place scatter using the provided scatter operation
  // description.
  Status EmitScatter(const ScatterDescriptor& desc, Thunk* thunk);

  // Returns true if a 0-2-1 tiling algorithm was used to emit the kernel for
  // the HLO instruction.
  StatusOr<bool> CheckAndEmitHloWithTile021(MlirEmitterInput input);

  // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and
  // sets the corresponding launch dimensions. This is a helper to support
  // the implementation of CheckAndEmitHloWithTile021.
  void EmitHlo021Tile(mlir::Operation* op, Thunk* kernel_thunk,
                      const MlirEmitterContext& context,
                      absl::Span<const llvm_ir::IrArray> operand_arrays,
                      absl::Span<const llvm_ir::IrArray> output_arrays,
                      absl::Span<const int64> reduced_output_dims,
                      absl::Span<const int64> tiled_param_ids);

  struct TilingKernelInfo {
    // Tiling bounds.
    std::array<llvm::Value*, 3> output_tile_bounds;

    // Starting tile, as calculated from block id only.
    llvm_ir::IrArray::Index tile_origin;
  };

  // Emits a kernel for the hlo instruction using the given kernel mapping
  // scheme.
  TilingKernelInfo EmitTilingKernel(
      const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
      const TileElementGenerator& tile_element_generator);

  // Emits code to process up to
  // (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a
  // tile, where `emit_elem_function` is the function to emit code to process
  // one element, `thread_id_y` and `thread_id_x` are the intra-tile
  // coordinates for the first element to process, and `index` is the index for
  // the origin of the tile. Information about tile_size_x/y and
  // num_threads_x/y is stored in `mapping_scheme`. Emits a bounds check to
  // ensure that each processed element is within the boundary defined by
  // `tile_width` and `tile_height`.
  //
  // Pseudocode:
  //
  // for (y_loc = 0; y_loc < tile_height; y_loc += num_threads_y) {
  //   for (j = 0; j < tile_size_x / num_threads_x; j++) { // unrolled
  //     if (dilated) {
  //       x_loc = x + j * num_threads_x;
  //     } else {
  //       x_loc = x * (tile_size_x / num_threads_x) + j;
  //     }
  //
  //     if (x_loc < tile_width) {
  //       emit_elem_function(y + y_loc, x_loc);
  //     }
  //   }
  // }
  //
  void EmitTile(
      const KernelMappingScheme& mapping_scheme,
      const llvm_ir::IrArray::Index& tile_origin_index, const string& loop_name,
      KernelSupportLibrary* ksl, const ThreadIdInfo& thread_id_info,
      llvm::Value* tile_height, llvm::Value* tile_width,
      const IrEmitterUnnested::EmitElementFunction& emit_elem_function);

  // Emits code to process a tensor element in a tile for the given kCopy HLO
  // that performs a 0-2-1 transpose.
  // y_loc: The y coordinate within a tile.
  // x_loc: The x coordinate within a tile.
  void EmitTileElementForCopy(
      const Shape& output_shape, const llvm_ir::IrArray& ir_array,
      const llvm_ir::IrArray::Index& index,
      const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
      llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);

  // Emits code to process a tensor element in a tile for the given kLoop
  // fusion HLO containing parameters that are a 0-2-1 transpose of its
  // outputs.
  // y_loc: The y coordinate within a tile.
  // x_loc: The x coordinate within a tile.
  void EmitTileElementForFusion(
      mlir::lmhlo::FusionOp fusion,
      absl::Span<const llvm_ir::IrArray> operand_arrays,
      absl::Span<const llvm_ir::IrArray> output_arrays,
      const llvm_ir::IrArray::Index& index,
      const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
      llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);

  // Emits code to process a tensor element in a tile for the given input HLO
  // that is either an unnested kReduce or a kInput fusion.
  //
  // Calculates and stores the temporary reduction value in the corresponding
  // alloca.
  //
  // `instr_index_group` indicates the set of reductions this call needs to
  // emit; each index i refers to the i-th output of unnested_hlo. Notice that
  // if unnested_hlo is not a multi-output fusion, instr_index_group is always
  // {0}.
  void EmitTileElementForReduction(
      mlir::Operation* unnested_hlo, const Shape& reduction_operand_shape,
      absl::Span<const int> instr_index_group,
      HloComputation* fused_computation, FusedIrEmitter* fused_emitter,
      absl::Span<const llvm_ir::IrArray> operand_ir_arrays,
      absl::Span<const llvm_ir::IrArray> result_ir_arrays,
      absl::Span<HloComputation* const> reducers,
      const llvm_ir::IrArray::Index& index,
      const ReductionCodegenInfo& reduction_info, int64 x_iter_num);

  // Prepares for the code generation for a tile block of a reduction kernel.
  //
  // Creates accumulator allocas, populates them with initial values, and
  // stores them inside reduction_info.
  void EmitPrologueForReduction(
      mlir::Operation* unnested_hlo, absl::Span<const int> instr_index_group,
      HloComputation* fused_computation, FusedIrEmitter* fused_emitter,
      absl::Span<const llvm_ir::IrArray> operand_ir_arrays,
      absl::Span<const llvm_ir::IrArray> result_ir_arrays,
      ReductionCodegenInfo* reduction_info);

  // Wraps up the code generation for a tile block of a reduction kernel:
  // writes the calculated output into the output tensor.
  void EmitEpilogueForReduction(
      llvm::Type* index_ty, mlir::Operation* unnested_hlo,
      absl::Span<const int> instr_index_group,
      absl::Span<const llvm_ir::IrArray> result_ir_arrays,
      absl::Span<HloComputation* const> reducers,
      const ReductionCodegenInfo& reduction_info,
      const TilingKernelInfo& tiling_kernel_info);

  // Emits code for reductions in the output_instructions.
  void EmitIRForReduction(mlir::Operation* unnested_hlo,
                          absl::Span<const int> instr_index_group,
                          HloComputation* fused_computation,
                          FusedIrEmitter* fused_emitter,
                          absl::Span<const llvm_ir::IrArray> operand_ir_arrays,
                          absl::Span<const llvm_ir::IrArray> result_ir_arrays,
                          ReductionCodegenInfo* reduction_info,
                          const Shape& input_shape);

  // For each reducer, emits the shuffle-down loop to accumulate the partial
  // result into the global result.
  void EmitFullWarpShuffleDownLoopForAllReduces(
      absl::Span<HloComputation* const> reducers,
      absl::Span<llvm::AllocaInst* const> partial_result_addresses,
      int threads_per_block);

  // Emits a shuffle-down reduction for the `partial_result_address` using the
  // reduction computation `reducer`, for elements of type `element_type`.
  void EmitFullWarpShuffleDownLoopForReduce(HloComputation* reducer,
                                            llvm::Type* element_type,
                                            llvm::Value* partial_result_address,
                                            int threads_per_block);

  std::unique_ptr<KernelThunk> BuildKernelThunkFromBufferSlices(
      absl::string_view name, Thunk::ThunkInfo thunk_info,
      absl::Span<const BufferSlice* const> slices,
      std::function<void(const BufferSlice*, llvm::Value*)>
          bind_slice_to_ir_value);

  // Returns a KernelThunk that invokes the kernel emitted for `inst`. The
  // caller needs to make sure `inst` outlives the lifetime of the returned
  // Thunk object. 'implements_whole_instruction' specifies whether this
  // KernelThunk implements the whole 'inst' HloInstruction. In some cases
  // 'inst' will be implemented by a sequence of Thunks.
  std::unique_ptr<KernelThunk> BuildKernelThunk(
      const HloInstruction* inst, bool implements_whole_instruction);

  std::unique_ptr<KernelThunk> BuildKernelThunkForMlirImpl(
      absl::string_view name, Thunk::ThunkInfo thunk_info,
      absl::Span<const MlirBufferSlice> slices,
      std::vector<llvm_ir::IrArray>* ir_arrays);

  StatusOr<std::unique_ptr<KernelThunk>> BuildKernelThunkForMlir(
      mlir::Operation* op, mlir::ValueRange operands,
      Thunk::ThunkInfo thunk_info, absl::optional<MlirBufferSlice> extra_slice,
      std::vector<llvm_ir::IrArray>* ir_arrays);

  StatusOr<std::unique_ptr<KernelThunk>> BuildKernelThunkForMlir(
      mlir::Operation* op, Thunk::ThunkInfo thunk_info,
      absl::optional<MlirBufferSlice> extra_slice,
      std::vector<llvm_ir::IrArray>* ir_arrays);

  // Returns a thunk that, given a reduce or select-and-scatter op,
  // initializes its memory to the appropriate initial value.
  std::unique_ptr<Thunk> BuildConstantInitializerThunk(
      absl::Span<const uint8> init_value, const BufferAllocation::Slice& dest,
      const Shape& output_shape);

  StatusOr<std::unique_ptr<Thunk>> TryBuildConstantInitializerThunk(
      mlir::Value init_value, mlir::Value dest);

  StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunkForMlir(
      mlir::Operation* op, mlir::Value init_value, mlir::Value dest);

  StatusOr<std::unique_ptr<Thunk>> BuildFusedInitializerThunkForMlir(
      mlir::lmhlo::FusionOp fusion, int output_index);

  // Returns a WhileThunk that invokes thunk sequences for 'condition' and
  // 'body' sub-computations of while instruction 'hlo'.
  StatusOr<std::unique_ptr<Thunk>> BuildWhileThunk(const HloInstruction* hlo);

  // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
  // sequence from the 'body' sub-computation of the while instruction 'hlo'.
  StatusOr<std::unique_ptr<Thunk>> BuildForThunk(const HloInstruction* hlo,
                                                 const int64 loop_limit);

  // Returns a ConditionalThunk which executes the thunk sequence for the
  // 'branch_computation' corresponding to the predicate/branch_index of the
  // given conditional instruction.
  StatusOr<std::unique_ptr<Thunk>> BuildConditionalThunk(
      const HloInstruction* hlo);

  // Emits the current thread id with the given type.
  //
  // Sets the return value range to [0, threads_per_block).
  llvm::Value* EmitThreadId(int64 threads_per_block, llvm::Type* index_ty);

  // Emits the LLVM values for thread_id, thread_id.x, thread_id.y and lane
  // id.
  //
  // Returns a struct containing these values.
  ThreadIdInfo EmitThreadIdInfo(int64 threads_per_block, llvm::Type* index_ty,
                                int64 num_threads_x);

  // Emits __syncthreads(), a synchronization barrier for all threads in a
  // block.
  llvm::CallInst* EmitSyncThreads();

  // Emits the current block id.
  llvm::Value* EmitBlockId();

  // Prints a given format string with the given arguments, prefixed with
  // thread id and block id, and postfixed with a newline.
  //
  // `thread_id_filter` and `block_id_filter`: if provided, restrict printing
  // to only the given thread and/or block id.
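  //
  // A usage sketch (for illustration only; `x_loc` and `accum` are
  // hypothetical llvm::Values, and the filter values are arbitrary):
  //   ```
  //   EmitPrintfWithThreadId("x_loc=%d accum=%f", {x_loc, accum},
  //                          /*thread_id_filter=*/0, /*block_id_filter=*/0);
  //   ```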
  void EmitPrintfWithThreadId(
      absl::string_view fmt, absl::Span<llvm::Value* const> arguments,
      absl::optional<int64> thread_id_filter = absl::nullopt,
      absl::optional<int64> block_id_filter = absl::nullopt);

  StatusOr<HloComputation*> GetOrCreateSubComputationFromRegion(
      mlir::Region* region, bool is_fusion);

  StatusOr<MlirEmitterInput> GetMlirEmitterInput(HloInstruction* hlo);

  // Returns the last generated thunk.
  Thunk* LastThunk() const { return thunk_sequence_.back().get(); }

  Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const override;

  Status AssertNonDeterminismIsOkay(const string& op_name);

  // The thunk sequence this IrEmitter generates for the input computation.
  ThunkSequence thunk_sequence_;

  // Begin optional members for XLA HLO -> LMHLO:
  // TODO(timshen): Once XLA HLO -> LMHLO converter is complete,
  // IrEmitterUnnested should take LMHLO only, and won't require a scratch
  // module.
  absl::optional<mlir::OwningModuleRef> mlir_scratch_module_;

  // This is for caching purposes only. It has no significant semantics.
  absl::optional<mlir::LhloDialectEmitter> lhlo_scratch_emitter_;

  absl::flat_hash_map<const mlir::Region*, std::unique_ptr<HloModule>>
      scratch_nested_computations_;
  // End optional members for XLA HLO -> LMHLO.
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_