1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
17 
18 #include <stdlib.h>
19 
20 #include <atomic>
21 #include <functional>
22 #include <utility>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/strings/numbers.h"
26 #include "absl/strings/str_cat.h"
27 #include "llvm/AsmParser/Parser.h"
28 #include "llvm/Bitcode/BitcodeReader.h"
29 #include "llvm/Bitcode/BitcodeWriter.h"
30 #include "llvm/IR/DiagnosticInfo.h"
31 #include "llvm/IR/DiagnosticPrinter.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/Module.h"
34 #include "llvm/IR/Verifier.h"
35 #include "llvm/Transforms/Utils/SplitModule.h"
36 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
37 #include "mlir/InitAllDialects.h"  // from @llvm-project
38 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
39 #include "tensorflow/compiler/xla/protobuf_util.h"
40 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
41 #include "tensorflow/compiler/xla/service/all_gather_decomposer.h"
42 #include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
43 #include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
44 #include "tensorflow/compiler/xla/service/batchnorm_expander.h"
45 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
46 #include "tensorflow/compiler/xla/service/call_inliner.h"
47 #include "tensorflow/compiler/xla/service/comparison_expander.h"
48 #include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
49 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
50 #include "tensorflow/compiler/xla/service/convolution_4d_expander.h"
51 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
52 #include "tensorflow/compiler/xla/service/dump.h"
53 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
54 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
55 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
56 #include "tensorflow/compiler/xla/service/gather_expander.h"
57 #include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h"
58 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
59 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
60 #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
61 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
62 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h"
63 #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
64 #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
65 #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
66 #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
67 #include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h"
68 #include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h"
69 #include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"
70 #include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
71 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
72 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
73 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
74 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
75 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
76 #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
77 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
78 #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
79 #include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h"
80 #include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h"
81 #include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h"
82 #include "tensorflow/compiler/xla/service/gpu/reduction_splitter.h"
83 #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
84 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
85 #include "tensorflow/compiler/xla/service/gpu/target_constants.h"
86 #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
87 #include "tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.h"
88 #include "tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h"
89 #include "tensorflow/compiler/xla/service/hlo_computation.h"
90 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
91 #include "tensorflow/compiler/xla/service/hlo_cse.h"
92 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
93 #include "tensorflow/compiler/xla/service/hlo_dce.h"
94 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
95 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
96 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
97 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
98 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
99 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
100 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
101 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
102 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
103 #include "tensorflow/compiler/xla/service/logistic_expander.h"
104 #include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"
105 #include "tensorflow/compiler/xla/service/operand_upcaster.h"
106 #include "tensorflow/compiler/xla/service/qr_expander.h"
107 #include "tensorflow/compiler/xla/service/reshape_mover.h"
108 #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
109 #include "tensorflow/compiler/xla/service/rng_expander.h"
110 #include "tensorflow/compiler/xla/service/slice_sinker.h"
111 #include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
112 #include "tensorflow/compiler/xla/service/sort_simplifier.h"
113 #include "tensorflow/compiler/xla/service/stable_sort_expander.h"
114 #include "tensorflow/compiler/xla/service/transpose_folding.h"
115 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
116 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
117 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
118 #include "tensorflow/compiler/xla/service/while_loop_trip_count_annotator.h"
119 #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
120 #include "tensorflow/compiler/xla/status_macros.h"
121 #include "tensorflow/compiler/xla/types.h"
122 #include "tensorflow/compiler/xla/util.h"
123 #include "tensorflow/core/lib/core/status.h"
124 #include "tensorflow/core/lib/gtl/cleanup.h"
125 #include "tensorflow/core/lib/io/path.h"
126 #include "tensorflow/core/platform/blocking_counter.h"
127 #include "tensorflow/core/platform/env.h"
128 #include "tensorflow/core/platform/logging.h"
129 #include "tensorflow/core/platform/regexp.h"
130 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
131 #include "tensorflow/core/platform/subprocess.h"
132 #include "tensorflow/core/platform/threadpool.h"
133 #include "tensorflow/core/platform/tracing.h"
134 #include "tensorflow/core/profiler/lib/traceme.h"
135 #include "tensorflow/core/util/env_var.h"
136 
137 namespace xla {
138 namespace gpu {
139 
140 GpuCompiler::GpuCompiler(se::Platform::Id platform_id,
141                          const char* target_triple, const char* data_layout)
142     : platform_id_(platform_id),
143       target_triple_(target_triple),
144       data_layout_(data_layout),
145       pointer_size_(llvm::DataLayout(data_layout)
146                         .getPointerSize(0 /* default address space */)) {}
147 
148 // Runs optimization passes on the given HLO module.
149 Status GpuCompiler::OptimizeHloModule(
150     HloModule* hlo_module, se::StreamExecutor* stream_exec,
151     se::DeviceMemoryAllocator* device_allocator) {
152   {
153     HloPassPipeline pipeline("optimization");
154     pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
155                                               /*allow_mixed_precision=*/false);
156 
157     pipeline.AddPass<AllGatherDecomposer>(
158         [](const HloAllGatherInstruction& ag) {
159           return !NcclAllGatherThunk::CanImplement(&ag);
160         });
161     pipeline.AddPass<AllToAllDecomposer>();
162 
163     pipeline.AddPass<OperandUpcaster>();
164 
165     // Expand random number generation.
166     pipeline.AddPass<RngExpander>();
167     pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
168 
169     // Comparison total order expander
170     pipeline.AddPass<ComparisonExpander>();
171 
172     // Remove zero-sized HLO from the input so that other passes don't have to
173     // handle it.
174     pipeline.AddPass<ZeroSizedHloElimination>();
175 
176     pipeline.AddPass<GpuScatterExpander>();
177     // TODO(phawkins): replace QR decompositions with calls to cuSOLVER.
178     pipeline.AddPass<QrExpander>();
179 
180     pipeline.AddPass<DynamicIndexSplitter>();
181 
182     // TODO(b/64094172): make Call work on GPU instead of inlining.
183     pipeline.AddPass<CallInliner>();
184 
185     pipeline.AddPass<DotDecomposer>();
186 
187     pipeline.AddPass<Convolution4DExpander>();
188 
189     // Expand the sort op to support stable sorting if required.
190     pipeline.AddPass<StableSortExpander>();
191     // Convert BF16 operations to F32 operations so that the GPU backend can
192     // support BF16 operations without directly implementing a BF16 lowering for
193     // most ops.
194     pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
195 
196     // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
197     // where possible.  Not every batchnorm op can be implemented as a call to
198     // cudnn, so decompose any remaining batchnorm ops into a soup of HLOs.
199     if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) {
200       // Since BatchNorm inference is essentially a pointwise operation, it is
201       // always advantageous to use kernel fusion rather than cudnn.
202       pipeline.AddPass<BatchNormExpander>(
203           /*rewrite_training_op=*/false,
204           /*rewrite_inference_op=*/true,
205           /*rewrite_grad_op=*/false);
206       pipeline.AddPass<CudnnBatchNormRewriter>();
207     }
208     pipeline.AddPass<BatchNormExpander>(
209         /*rewrite_training_op=*/true,
210         /*rewrite_inference_op=*/true,
211         /*rewrite_grad_op=*/true);
212 
213     pipeline.AddPass<LogisticExpander>(
214         /*expansion_type=*/LogisticExpansionType::kExp);
215     pipeline.AddPass<ConditionalCanonicalizer>();
216     pipeline.AddPass<DynamicPadder>();
217 
218     {
219       auto& pass =
220           pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
221       pass.AddInvariantCheckerDebug<HloVerifier>(
222           /*layout_sensitive=*/false,
223           /*allow_mixed_precision=*/false);
224 
225       // BatchNormExpander can create zero-sized ops, so zero-sized HLO
226       // elimination has to come after that pass.
227       pass.AddPass<ZeroSizedHloElimination>();
228 
229       pass.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
230       pass.AddPass<ScatterExpander>(ScatterExpander::kEliminateSimpleScatters);
231 
232       AlgebraicSimplifierOptions options;
233       // When transposes appear in a fusion node, we can easily adjust the
234       // multi-dimensional index to create the one needed for the operand. This
235       // is not as easy with bitcasts, because we don't have readily available
236       // information about which dimensions are permuted. In addition to that,
237       // if we have a transpose and a reshape next to each other, they will both
238       // be replaced by a bitcast, and we replace bitcast(bitcast) with one
239       // bitcast. This leads to having to linearize and then delinearize the
240       // index.
241       options.set_replace_transpose_with_bitcast(false);
242       options.set_enable_conv_operand_swap(false);
243       pass.AddPass<AlgebraicSimplifier>(options);
244       // AlgebraicSimplifier may add contracting dimensions to a dot.
245       pass.AddPass<DotDecomposer>();
246       pass.AddPass<SortSimplifier>();
247       pass.AddPass<TupleSimplifier>();
248       pass.AddPass<WhileLoopConstantSinking>();
249       pass.AddPass<WhileLoopSimplifier>();
250 
251       // TODO(b/134075051): Re-enable after b/134075051 is fixed.
252       // pass.AddPass<SliceSinker>();
253 
254       pass.AddPass<HloDCE>();
255       pass.AddPass<ReshapeMover>();
256       pass.AddPass<HloConstantFolding>();
257       pass.AddPass<ConditionalSimplifier>();
258     }
259 
260     pipeline.AddPass<TransposeFolding>(
261         [](const HloInstruction& dot,
262            const TransposeFolding::OperandIndices& candidate_operands) {
263           return IsMatrixMultiplication(dot)
264                      ? candidate_operands
265                      : TransposeFolding::OperandIndices{};
266         });
267     pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
268     pipeline.AddPass<HloDCE>();
269 
270     // Run WhileLoopTripCountAnnotator at the end of the simplification
271     // pipeline, before layout assignment and fusion.  This pass does some
272     // pattern-matching on while bodies/conditions, and this is where the HLO is
273     // "nicest".
274     //
275     // It's important that we don't make semantic changes (e.g. unrolling) to
276     // any `while` loops after this point, because otherwise the trip-count
277     // annotations added by this pass may not be correct after the
278     // modifications.
279     pipeline.AddPass<WhileLoopTripCountAnnotator>();
280     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
281   }
282 
283   // Run target-specific HLO optimization passes for convolution
284   // canonicalization.
285   TF_RETURN_IF_ERROR(OptimizeHloConvolutionCanonicalization(
286       hlo_module, stream_exec, device_allocator));
287 
288   {
289     // Run layout assignment in a separate pipeline from
290     // "post-layout-assignment" because we want everything after layout
291     // assignment to have a layout-sensitive invariant-checker, but
292     // HloPassPipeline also runs its invariant checker before any passes are
293     // run, meaning the pipeline that contains layout assignment cannot contain
294     // a layout-sensitive verifier!
295     HloPassPipeline pipeline("layout assignment");
296     // Layout assignment uses alias analysis, which requires the call graph to
297     // be flattened.
298     pipeline.AddPass<FlattenCallGraph>();
299     pipeline.AddPass<GpuLayoutAssignment>(
300         hlo_module->mutable_entry_computation_layout(),
301         LayoutAssignment::InstructionCanChangeLayout, stream_exec);
302     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
303   }
304 
305   // Run target-specific HLO optimization passes after layout assignment.
306   TF_RETURN_IF_ERROR(OptimizeHloPostLayoutAssignment(hlo_module, stream_exec,
307                                                      device_allocator));
308 
309   {
310     HloPassFix<HloPassPipeline> fusion("fusion");
311     // We try to split variadic ops with many parameters into several such ops
312     // to avoid exceeding the parameter space.
313     fusion.AddPass<VariadicOpSplitter>();
314     /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
315      * fixing the ticket. */
316     fusion.AddInvariantCheckerDebug<HloVerifier>(
317         /*layout_sensitive=*/true,
318         /*allow_mixed_precision=*/false,
319         LayoutAssignment::InstructionCanChangeLayout);
320     fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
321     fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
322     fusion.AddPass<FusionMerger>();
323     fusion.AddPass<GpuMultiOutputFusion>();
324     fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
325                            /*only_fusion_computations=*/true);
326     fusion.AddPass<HloDCE>();
327     TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
328 
329     HloPassPipeline horizontal_fusion("horizontal_fusion");
330     horizontal_fusion.AddPass<GpuHorizontalLoopFusion>();
331     horizontal_fusion.AddPass<GpuHorizontalInputFusion>();
332     horizontal_fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
333                                       /*only_fusion_computations=*/true);
334     horizontal_fusion.AddPass<HloDCE>();
335     TF_RETURN_IF_ERROR(horizontal_fusion.Run(hlo_module).status());
336   }
337 
338   {
339     HloPassPipeline pipeline("all_reduce_combiner");
340     pipeline.AddPass<AllReduceCombiner>(
341         /*combine_threshold_in_bytes=*/30 * 1024 * 1024,
342         /*combine_threshold_count=*/256);
343     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
344   }
345   {
346     // Now we allow replacing any transposes outside of fusions with bitcasts.
347     HloPassPipeline pipeline("final_algebraic_simplifier");
348     AlgebraicSimplifierOptions options;
349     options.set_is_layout_sensitive(true);
350     options.set_enable_conv_operand_swap(false);
351     pipeline.AddPass<AlgebraicSimplifier>(options);
352     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
353   }
354   return Status::OK();
355 }
356 
357 // Modifies the given HLO module so that it will be accepted by IrEmitter.
358 // Unlike optimization passes, these passes are necessary for correctness.
359 Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
360   // In some cases, we have to place the result of an instruction in a temporary
361   // buffer. For instance, the buffer that holds an external parameter is
362   // assumed immutable at this point, and should not be reused for output
363   // (b/27180329). Therefore, in that case, we set the output to be a copy of
364   // the parameter.
365   HloPassPipeline pipeline("GPU-ir-emit-prepare");
366   /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
367    * fixing the ticket. */
368   pipeline.AddInvariantCheckerDebug<HloVerifier>(
369       /*layout_sensitive=*/true,
370       /*allow_mixed_precision=*/false,
371       LayoutAssignment::InstructionCanChangeLayout);
372 
373   // Copy insertion should be performed immediately before IR emission to avoid
374   // inserting unnecessary copies (a later pass adds an instruction which
375   // materializes the value) or missing a necessary copy (a later pass removes an
376   // instruction which materializes a value). DCE must be run immediately before
377   // (and sometimes after) copy insertion, to prevent dead code from interfering
378   // with the rewrites.
379   pipeline.AddPass<HloDCE>();
380   if (hlo_module->config().alias_passthrough_params()) {
381     pipeline.AddPass<AliasPassthroughParams>();
382   }
383   pipeline.AddPass<LoopScheduleLinearizer>(GetCanShareBuffer());
384   pipeline.AddPass<GpuCopyInsertion>(GetCanShareBuffer());
385   pipeline.AddPass<GpuSanitizeConstantNames>();
386   return pipeline.Run(hlo_module).status();
387 }
388 
389 // TODO(cheshire): Duplication with gpu_conv_algorithm_picker; figure out the
390 // right way to share this.
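// Returns true if the TF_DETERMINISTIC_OPS environment variable requests
// deterministic execution (e.g. TF_DETERMINISTIC_OPS=1); the variable is read
// once and the result is cached in a function-local static.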
391 static bool RequireDeterminism() {
392   static bool require_determinism = [] {
393     bool deterministic_ops = false;
394     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
395                                                /*default_val=*/false,
396                                                &deterministic_ops));
397     return deterministic_ops;
398   }();
399   return require_determinism;
400 }
401 
402 Status GpuCompiler::OptimizeHloPostLayoutAssignment(
403     HloModule* hlo_module, se::StreamExecutor* stream_exec,
404     se::DeviceMemoryAllocator* device_allocator) {
405   HloPassPipeline pipeline("post-layout_assignment");
406   /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
407    * fixing the ticket. */
408   pipeline.AddInvariantCheckerDebug<HloVerifier>(
409       /*layout_sensitive=*/true,
410       /*allow_mixed_precision=*/false,
411       LayoutAssignment::InstructionCanChangeLayout);
412 
413   pipeline.AddPass<ReductionDegenerateDimRemover>();
414   pipeline.AddPass<ReductionLayoutNormalizer>();
415   pipeline.AddPass<ReductionDimensionGrouper>();
416   pipeline.AddPass<HloPassFix<ReductionSplitter>>();
417 
418   // The LayoutAssignment pass may leave behind kCopy instructions which are
419   // duplicate or NOPs, so remove them with algebraic simplification and CSE.
420   AlgebraicSimplifierOptions options;
421   options.set_is_layout_sensitive(true);
422   // When transposes appear in a fusion node, we can easily adjust the
423   // multi-dimensional index to create the one needed for the operand. This
424   // is not as easy with bitcasts, because we don't have readily available
425   // information about which dimensions are permuted. In addition to that,
426   // if we have a transpose and a reshape next to each other, they will both
427   // be replaced by a bitcast, and we replace bitcast(bitcast) with one
428   // bitcast. This leads to having to linearize and then delinearize the
429   // index.
430   options.set_replace_transpose_with_bitcast(false);
431   options.set_enable_conv_operand_swap(false);
432   pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
433 
434   if (RequireDeterminism() ||
435       hlo_module->config().debug_options().xla_gpu_deterministic_reductions() ||
436       hlo_module->config().debug_options().xla_gpu_deterministic_ops()) {
437     pipeline.AddPass<HloPassFix<GpuTreeReductionRewriter>>();
438   }
439 
440   // GemmRewriter assumes that all transposes are folded into gemms, but,
441   // since commit 7d529df, this is not always true at this point.
442   // Therefore, rerun transpose folding.
443   pipeline.AddPass<TransposeFolding>(
444       [](const HloInstruction& dot,
445          const TransposeFolding::OperandIndices& candidate_operands) {
446         return IsMatrixMultiplication(dot) ? candidate_operands
447                                            : TransposeFolding::OperandIndices{};
448       },
449       TransposeFolding::NeverFoldTranspose);
450   // Rewrite GEMMs into custom calls.
451   pipeline.AddPass<GemmRewriter>();
452 
453   // Choose the fastest algorithm for each conv.
454   //
455   // We pick the algorithm before fusion so we can generate better HLO. After
456   // GpuConvRewriter, our convolutions are CustomCalls which return a
457   // tuple (conv_result, scratch_memory), and each conv uses 0 bytes of
458   // scratch:
459   //
460   //   customcall = (f32[...], f32[0])
461   //   return gte(customcall, 0)
462   //
463   // The algorithm picker then chooses the best algorithm, and potentially
464   // increases the scratch space.  It replaces customcall with new_tuple,
465   // giving us the following:
466   //
467   //   new_customcall = (f32[...], f32[N])
468   //   new_tuple = tuple(gte(new_customcall, 0), constant f32[0])
469   //   return gte(new_tuple, 0)
470   //
471   // The new tuple and gte instructions can then be simplified away, because
472   // nobody is expected to use the scratch value.
473   //
474   // However, if we were to run GpuConvAlgorithmPicker after fusion,
475   // the gte(customcall, 0) would probably already be inside a fusion node.  We
476   // can't simplify across HloComputation boundaries, so in this case we
477   // wouldn't be able to simplify away the new_tuple bits.
478   pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator);
479 
480   // Clean up new_tuple described above.
481   pipeline.AddPass<TupleSimplifier>();
482 
483   pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
484   TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
485 
486   return Status::OK();
487 }
488 
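// Runs the HLO optimization pipeline (OptimizeHloModule) followed by the
// IR-emission preparation passes (PrepareHloModuleForIrEmitting) and returns
// the transformed module.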
489 StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
490     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
491     const CompileOptions& options) {
492   // We dump the post-optimization HLO in RunBackend so no need to dump it here.
493   XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
494   tensorflow::profiler::TraceMe activity(
495       [&] { return absl::StrCat("HLO Transforms:", module->name()); },
496       tensorflow::profiler::TraceMeLevel::kInfo);
497   TF_RETURN_IF_ERROR(
498       OptimizeHloModule(module.get(), stream_exec, options.device_allocator));
499 
500   TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));
501 
502   return std::move(module);
503 }
504 
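// A CanShareBuffer callback that always returns absl::nullopt, deferring to
// the default buffer-sharing analysis.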
505 static absl::optional<bool> DummyCanShareBufferFunction(const HloInstruction*,
506                                                         const HloInstruction*,
507                                                         const ShapeIndex&) {
508   return absl::nullopt;
509 }
510 
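// Optionally runs the HLO passes, then builds the stream assignment, GPU HLO
// schedule, and buffer assignment for the module, mirroring the setup in
// CompileModuleToLlvmIrImpl below.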
511 StatusOr<
512     std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
513 GpuCompiler::RunHloPassesAndBufferAssignement(
514     std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* executor,
515     bool optimize, const CompileOptions& options) {
516   if (optimize) {
517     TF_ASSIGN_OR_RETURN(hlo_module,
518                         RunHloPasses(std::move(hlo_module), executor, options));
519   }
520 
521   std::unique_ptr<StreamAssignment> stream_assignment =
522       AssignStreams(*hlo_module);
523   TF_ASSIGN_OR_RETURN(
524       std::unique_ptr<GpuHloSchedule> hlo_schedule,
525       GpuHloSchedule::Build(*hlo_module, *stream_assignment, pointer_size_));
526 
527   auto buffer_size_bytes_function =
528       [this](const BufferValue& buffer_value) -> int64 {
529     return GpuCompiler::GetSizeOfShape(buffer_value.shape(), pointer_size_);
530   };
531 
532   TF_ASSIGN_OR_RETURN(
533       std::unique_ptr<BufferAssignment> assignment,
534       BufferAssigner::Run(
535           hlo_module.get(), hlo_schedule->ConsumeHloOrdering(),
536           buffer_size_bytes_function,
537           /*color_alignment=*/
538           [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
539           /*allocate_buffers_for_constants=*/true,
540           /*colorer=*/BufferAssigner::DefaultColorer(),
541           /*must_not_live_out=*/{}, DummyCanShareBufferFunction));
542 
543   return std::make_tuple(std::move(hlo_module), std::move(assignment));
544 }
545 
546 // The order of `thunk_sequence` corresponds to
547 // `hlo_schedule->ThunkLaunchOrder()`.
548 static Status CompileModuleToLlvmIrImpl(
549     HloModule* hlo_module, llvm::LLVMContext* llvm_context,
550     const std::string& target_triple, const std::string& data_layout,
551     const std::string& platform_name, GpuDeviceInfo gpu_device_info,
552     absl::optional<CudaComputeCapability> cuda_compute_capability,
553     const HloDataflowAnalysis::CanShareBuffer& can_share_buffer_function,
554     int pointer_size, const HloProfileIndexMap* profile_index_map,
555     std::unique_ptr<llvm::Module>* llvm_module,
556     std::unique_ptr<BufferAssignment>* buffer_assignment,
557     std::unique_ptr<ThunkSchedule>* thunk_schedule,
558     std::vector<GpuExecutable::ConstantInfo>* constants) {
559   *llvm_module = absl::make_unique<llvm::Module>("", *llvm_context);
560 
561   (*llvm_module)->setTargetTriple(target_triple);
562   (*llvm_module)->setDataLayout(data_layout);
563 
564   std::unique_ptr<StreamAssignment> stream_assignment =
565       AssignStreams(*hlo_module);
566   TF_ASSIGN_OR_RETURN(
567       std::unique_ptr<GpuHloSchedule> hlo_schedule,
568       GpuHloSchedule::Build(*hlo_module, *stream_assignment, pointer_size));
569 
570   auto buffer_size_bytes_function =
571       [pointer_size](const BufferValue& buffer_value) -> int64 {
572     return GpuCompiler::GetSizeOfShape(buffer_value.shape(), pointer_size);
573   };
574 
575   TF_ASSIGN_OR_RETURN(
576       *buffer_assignment,
577       BufferAssigner::Run(
578           hlo_module, hlo_schedule->ConsumeHloOrdering(),
579           buffer_size_bytes_function,
580           /*color_alignment=*/
581           [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
582           /*allocate_buffers_for_constants=*/true,
583           /*colorer=*/BufferAssigner::DefaultColorer(),
584           /*must_not_live_out=*/{}, can_share_buffer_function));
585 
586   VLOG(1) << "Buffer Assignment Stats "
587           << (*buffer_assignment)->GetStats().ToString();
588   DumpHloModuleIfEnabled(*hlo_module, **buffer_assignment,
589                          "after_optimizations");
590 
591   mlir::MLIRContext mlir_context;
592   mlir_context.loadDialect<mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect,
593                            mlir::StandardOpsDialect,
594                            mlir::lmhlo_gpu::LmhloGpuDialect>();
595 
596   IrEmitterContext ir_emitter_context(
597       hlo_module, buffer_assignment->get(), platform_name, gpu_device_info,
598       cuda_compute_capability, profile_index_map, &mlir_context,
599       llvm_module->get());
600 
601   HloComputation* entry_computation = hlo_module->entry_computation();
602 
603   TF_ASSIGN_OR_RETURN(
604       auto ir_emitter,
605       IrEmitterUnnested::Create(hlo_module->config(), entry_computation,
606                                 &ir_emitter_context));
607 
608   {
609     XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
610 
611     absl::flat_hash_map<const Thunk*, const HloInstruction*> thunk_to_hlo;
612     ThunkSequence thunk_sequence;
613     absl::Span<HloInstruction* const> order = hlo_schedule->ThunkLaunchOrder();
614     for (HloInstruction* instruction : order) {
615       TF_RETURN_IF_ERROR(instruction->Visit(ir_emitter.get()));
616       TF_RETURN_IF_ERROR(ir_emitter->Postprocess(instruction));
617       std::unique_ptr<ThunkSequence> thunks =
618           ir_emitter->ConsumeThunkSequence();
619 
620       // The invariants between each input HloInstruction* and output Thunk* are
621       // not all explicitly checked, but at least we can document them here:
622       // * The entry HloComputation shall not have dead code (all reachable from
623       // ROOT).
624       // * The visited instructions are all instructions in the entry
625       // computation.
626       // * For each visit of these HloInstructions, either none or one Thunk
627       // will be returned.
628       // * If there is a thunk returned, thunk->hlo_instruction_ equals the
629       // input HloInstruction*.
630       // * A returned thunk may contain other sub-thunks. A sub-thunk may or may
631       // not have an associated hlo_instruction_.
632       TF_RET_CHECK(thunks->size() <= 1) << instruction->ToString();
633       if (!thunks->empty()) {
634         auto thunk = std::move(thunks->front());
635         InsertOrDie(&thunk_to_hlo, thunk.get(), instruction);
636         thunk_sequence.push_back(std::move(thunk));
637       }
638     }
639     // TODO(timshen): ThunkSchedule taking thunk_to_hlo is a bit awkward. To fix
640     // that, we can turn it into a proper pass, from:
641     //   map<Thunk, HloInstruction> -> (ThunkSchedule, [Thunk...])
642     // to:
643     //   map<Thunk, HloInstruction> -> GenerateMultiStreamDepInfo() -> [(Thunk,
644     //   DepInfo)...]
645     //
646     //   where "DepInfo" is
647     //   struct {
648     //     int stream_number;
649     //     std::vector<Thunk*> dependencies;
650     //     std::vector<Thunk*> users;
651     //   };
652     // We might want to do this after MLIR migration.
653     *thunk_schedule = absl::make_unique<ThunkSchedule>(
654         std::make_unique<ThunkSequence>(std::move(thunk_sequence)),
655         std::move(stream_assignment), std::move(thunk_to_hlo));
656 
657     if (constants) {
658       *constants = std::move(ir_emitter_context.constants());
659     }
660   }
661 
662   return Status::OK();
663 }
664 
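// LLVM diagnostic handler that renders the diagnostic into a string and logs
// it at VLOG(1) instead of letting LLVM print it to stderr.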
665 static void NullDiagnosticHandler(const llvm::DiagnosticInfo& diag_info,
666                                   void* context) {
667   std::string error_string;
668   llvm::raw_string_ostream string_printer(error_string);
669   llvm::DiagnosticPrinterRawOStream diagnostic_printer(string_printer);
670   diag_info.print(diagnostic_printer);
671 
672   VLOG(1) << error_string;
673 }
674 
675 StatusOr<std::pair<std::string, std::vector<uint8>>>
676 GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
677                                    std::unique_ptr<llvm::Module> llvm_module,
678                                    se::StreamExecutor* stream_exec,
679                                    const CompileOptions& options,
680                                    const HloModule* debug_module) {
681   using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
682 
683   const auto compile_single_module =
684       [this, stream_exec, &module_config, debug_module](
685           llvm::Module* llvm_module, bool relocatable,
686           absl::optional<int> shard_number) -> StatusOr<BackendCompileResult> {
687     {
688       XLA_SCOPED_LOGGING_TIMER(
689           "GpuCompiler::RunBackend - Running LLVM verifier");
690 
691       llvm_module->getContext().setDiagnosticHandlerCallBack(
692           NullDiagnosticHandler, nullptr);
693 
694       std::string err;
695       llvm::raw_string_ostream err_stream(err);
696 
697       // verifyModule() returns true if the module is broken.
698       TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
699           << "Invalid LLVM IR before optimizations:\n"
700           << err_stream.str()
701           << "\nThis probably indicates a bug in the HLO -> LLVM IR "
702              "lowering. Rerun with --xla_dump_to to get the IR"
703           << (debug_module
704                   ? absl::StrCat(" and look for files with names containing: *",
705                                  FilenameFor(*debug_module, "", ""), "*")
706                   : ".");
707     }
708     GpuVersion gpu_version = GetGpuVersion(stream_exec);
709     StatusOr<std::pair<std::string, std::vector<uint8>>> result =
710         CompileTargetBinary(module_config, llvm_module, gpu_version,
711                             stream_exec, relocatable, debug_module);
712 
713     if (!result.ok()) {
714       return result;
715     }
716 
717     const bool should_dump =
718         DumpingEnabledForHloModule(debug_module ? debug_module->name() : "",
719                                    module_config.debug_options());
720 
721     if (should_dump) {
722       if (debug_module) {
723         if (shard_number.has_value()) {
724           llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
725                                    /*optimized=*/true,
726                                    std::to_string(*shard_number));
727         } else {
728           llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
729                                    /*optimized=*/true);
730         }
731       } else {
732         LOG(ERROR)
733             << "Dumping is not implemented since the file name cannot be "
734                "inferred. Please implement a (potentially MLIR) module -> "
735                "filename heuristic.";
736       }
737     }
738 
739     if (user_post_optimization_hook_) {
740       user_post_optimization_hook_(*llvm_module);
741     }
742 
743     // Write PTX to IR dump directory, if IR dumping was requested.
744     if (should_dump) {
745       absl::string_view ptx = result->first;
746       if (debug_module) {
747         if (shard_number.has_value()) {
748           DumpToFileInDirOrStdout(*debug_module, "",
749                                   std::to_string(*shard_number) + ".ptx", ptx);
750         } else {
751           DumpToFileInDirOrStdout(*debug_module, "", "ptx", ptx);
752         }
753       } else {
754         LOG(ERROR)
755             << "Dumping is not implemented since the file name cannot be "
756                "inferred. Please implement a (potentially MLIR) module -> "
757                "filename heuristic.";
758       }
759     }
760 
761     return result;
762   };
763 
764   tensorflow::thread::ThreadPool* thread_pool = options.thread_pool;
765 
766   absl::optional<tensorflow::thread::ThreadPool> overriding_thread_pool;
767   if (module_config.debug_options().xla_gpu_force_compilation_parallelism() !=
768       0) {
769     overriding_thread_pool.emplace(
770         tensorflow::Env::Default(), "",
771         module_config.debug_options().xla_gpu_force_compilation_parallelism());
772     thread_pool = &*overriding_thread_pool;
773   }
774 
775   if (!thread_pool) {
776     return compile_single_module(llvm_module.get(), /*relocatable=*/false,
777                                  /*shard_number=*/absl::nullopt);
778   }
779 
780   // Test whether LinkModules is supported.
781   if (this->LinkModules(stream_exec, {}).status().code() ==
782       tensorflow::error::Code::UNIMPLEMENTED) {
783     return compile_single_module(llvm_module.get(), /*relocatable=*/false,
784                                  /*shard_number=*/absl::nullopt);
785   }
786 
787   std::vector<std::unique_ptr<llvm::Module>> llvm_modules;
788   int num_functions = 0;
789   for (llvm::Function& func : llvm_module->functions()) {
790     if (!func.isDeclaration() &&
791         func.getLinkage() == llvm::GlobalValue::LinkageTypes::ExternalLinkage) {
792       num_functions++;
793     }
794   }
795 
796   llvm::SplitModule(
797       *llvm_module.get(),
798       std::max<unsigned>(
799           1, std::min<unsigned>(thread_pool->NumThreads(), num_functions)),
800       [&](std::unique_ptr<llvm::Module> module) {
801         llvm_modules.push_back(std::move(module));
802       },
803       /*PreserveLocals=*/true);
804 
805   std::vector<StatusOr<BackendCompileResult>> compile_results(
806       llvm_modules.size());
807   tensorflow::BlockingCounter counter(llvm_modules.size());
808   for (int i = 0; i < llvm_modules.size(); i++) {
809     thread_pool->Schedule(
810         [&compile_results, compile_single_module, i, &llvm_modules, &counter] {
811           llvm::Module* original_module = llvm_modules[i].get();
812           llvm::LLVMContext context;
813           std::string buffer;
814           llvm::raw_string_ostream error(buffer);
815 
816           std::unique_ptr<llvm::Module> new_llvm_module;
817           // Switch to a new context by dumping and re-parsing LLVM IR. Each
818           // thread has its own context to avoid race conditions.
819           {
820             std::string ir;
821             {
822               llvm::raw_string_ostream os(ir);
823               original_module->print(os, nullptr);
824             }
825             llvm::SMDiagnostic err;
826             new_llvm_module = llvm::parseAssemblyString(ir, err, context);
827           }
828 
829           compile_results[i] = compile_single_module(
830               new_llvm_module.get(), /*relocatable=*/true, /*shard_number=*/i);
831           counter.DecrementCount();
832         });
833   }
834   counter.Wait();
835 
836   std::string ptx_snippets;
837   std::vector<std::vector<uint8>> submodule_compile_results;
838   for (auto& maybe_result : compile_results) {
839     TF_ASSIGN_OR_RETURN(auto result, maybe_result);
840     if (result.second.empty()) {
841       continue;
842     }
843     ptx_snippets += result.first;
844     ptx_snippets += "\n";
845     submodule_compile_results.push_back(result.second);
846   }
847 
848   TF_ASSIGN_OR_RETURN(
849       std::vector<uint8> backend_result,
850       this->LinkModules(stream_exec, std::move(submodule_compile_results)));
851 
852   return std::make_pair(ptx_snippets, backend_result);
853 }
854 
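// Compiles an already-optimized HLO module into a GpuExecutable: builds the
// buffer assignment and thunk schedule, emits LLVM IR, and lowers it to a
// device binary via CompileToTargetBinary.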
855 StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
856     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
857     const CompileOptions& options) {
858   XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend");
859   std::string slow_compilation_msg =
860       absl::StrCat("Compiling module ", module->name());
861   auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg);
862 
863   TF_RET_CHECK(stream_exec != nullptr);
864 
865   llvm::LLVMContext llvm_context;
866 
867   GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);
868 
869   absl::optional<CudaComputeCapability> cuda_compute_capability =
870       [&]() -> absl::optional<CudaComputeCapability> {
871     CudaComputeCapability cuda_compute_capability;
872     stream_exec->GetDeviceDescription().cuda_compute_capability(
873         &cuda_compute_capability.cc_major, &cuda_compute_capability.cc_minor);
874     if (cuda_compute_capability.cc_major == -1) {
875       return absl::nullopt;
876     }
877     return cuda_compute_capability;
878   }();
879 
880   std::unique_ptr<HloProfileIndexMap> profile_index_map;
881   std::unique_ptr<HloProfilePrinterData> profile_printer;
882 
883   if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) {
884     HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
885     cost_analysis.set_bytes_per_second(
886         stream_exec->GetDeviceDescription().memory_bandwidth());
887     TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
888     VLOG(1) << "HLO memory read+written: "
889             << tensorflow::strings::HumanReadableNumBytes(
890                    cost_analysis.bytes_accessed());
891     if (module->config().hlo_profiling_enabled()) {
892       profile_index_map = absl::make_unique<HloProfileIndexMap>(*module);
893       profile_printer =
894           CreateHloProfilePrinterData(*profile_index_map, cost_analysis,
895                                       module->entry_computation()->name());
896     }
897   }
898 
899   std::unique_ptr<llvm::Module> llvm_module;
900   std::unique_ptr<BufferAssignment> buffer_assignment;
901   std::unique_ptr<ThunkSchedule> thunk_schedule;
902   std::vector<GpuExecutable::ConstantInfo> constants;
903 
904   TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
905       module.get(), &llvm_context, target_triple_, data_layout_,
906       stream_exec->platform()->Name(), gpu_device_info, cuda_compute_capability,
907       GetCanShareBuffer(), pointer_size_, profile_index_map.get(), &llvm_module,
908       &buffer_assignment, &thunk_schedule, &constants));
909 
910   if (user_pre_optimization_hook_) {
911     user_pre_optimization_hook_(*llvm_module);
912   }
913   string ir_module_string_before_opt;
914   const bool embed_ir_in_executable =
915       module->config().debug_options().xla_embed_ir_in_executable();
916   if (embed_ir_in_executable) {
917     ir_module_string_before_opt = llvm_ir::DumpModuleToString(*llvm_module);
918   }
919 
920   llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/false);
921 
922   using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
923   TF_ASSIGN_OR_RETURN(
924       BackendCompileResult backend_result,
925       CompileToTargetBinary(module->config(), std::move(llvm_module),
926                             stream_exec, options, module.get()));
927   if (DumpingEnabledForHloModule(*module)) {
928     DumpToFileInDirOrStdout(*module, "", "thunk_schedule",
929                             thunk_schedule->ToString());
930   }
931 
932   using OutputInfoMap =
933       absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>;
934   TF_ASSIGN_OR_RETURN(OutputInfoMap output_info,
935                       GetOutputInfo(*module, *buffer_assignment));
936   auto buffer_assignment_proto =
937       std::make_unique<BufferAssignmentProto>(buffer_assignment->ToProto());
938   std::vector<BufferAllocation> allocations =
939       buffer_assignment->ReleaseAllocations();
940   std::string module_name = module->name();
941   Shape output_shape = module->entry_computation()->root_instruction()->shape();
942   size_t profile_index = 0;
943   if (profile_index_map) {
944     profile_index =
945         profile_index_map->GetProfileIndexFor(*module->entry_computation());
946   }
947 
948   GpuVersion gpu_version = GetGpuVersion(stream_exec);
949   auto* gpu_executable = new GpuExecutable(
950       {std::move(backend_result.first), std::move(backend_result.second),
951        gpu_version, std::move(thunk_schedule), std::move(constants),
952        std::move(output_info), module_name, output_shape,
953        std::move(allocations), std::move(buffer_assignment_proto),
954        std::move(module), profile_index, std::move(profile_printer),
955        std::move(profile_index_map)});
956   if (embed_ir_in_executable) {
957     DCHECK_NE("", ir_module_string_before_opt);
958     gpu_executable->set_ir_module_string(ir_module_string_before_opt);
959   }
960   return std::unique_ptr<Executable>(gpu_executable);
961 }
962 
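// Collects the device resource limits (threads per block, warp size, shared
// memory per block, threads per core, core count) from the StreamExecutor's
// device description.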
963 GpuDeviceInfo GetGpuDeviceInfo(se::StreamExecutor* stream_exec) {
964   GpuDeviceInfo gpu_device_info;
965   gpu_device_info.threads_per_block_limit =
966       stream_exec->GetDeviceDescription().threads_per_block_limit();
967   gpu_device_info.threads_per_warp =
968       stream_exec->GetDeviceDescription().threads_per_warp();
969   gpu_device_info.shared_memory_per_block =
970       stream_exec->GetDeviceDescription().shared_memory_per_block();
971   gpu_device_info.threads_per_core_limit =
972       stream_exec->GetDeviceDescription().threads_per_core_limit();
973   gpu_device_info.core_count = stream_exec->GetDeviceDescription().core_count();
974   return gpu_device_info;
975 }
976 
977 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
978 GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
979                                 const AotCompilationOptions& options) {
980   return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime");
981 }
982 
983 StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
984     HloModule* hlo_module, llvm::LLVMContext* llvm_context,
985     const std::string& target_triple, const std::string& data_layout,
986     const std::string& platform_name, GpuDeviceInfo gpu_device_info,
987     absl::optional<CudaComputeCapability> cuda_compute_capability,
988     int pointer_size) {
989   std::unique_ptr<llvm::Module> llvm_module;
990   std::unique_ptr<BufferAssignment> buffer_assignment;
991   std::unique_ptr<ThunkSchedule> thunk_schedule;
992 
993   TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
994       hlo_module, llvm_context, target_triple, data_layout, platform_name,
995       gpu_device_info, cuda_compute_capability, DummyCanShareBufferFunction,
996       pointer_size, /*profile_index_map=*/nullptr, &llvm_module,
997       &buffer_assignment, &thunk_schedule, nullptr));
998   return llvm_module;
999 }
1000 
1001 // Analyze the function signature to reconstruct a vector of BufferAllocation
1002 // objects, as well as other output information.
1003 //
1004 // This function also serves as a half-baked verifier for function arg
1005 // attributes, since a full verifier doesn't exist yet.
1006 static Status GetMlirAllocationInfo(
1007     mlir::FuncOp func, std::vector<BufferAllocation>* allocations,
1008     absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>* output_info,
1009     Shape* output_shape) {
1010   std::vector<absl::optional<BufferAllocation>> maybe_allocations;
1011 
1012   for (int i = 0; i < func.getNumArguments(); i++) {
1013     auto allocation_index_attr =
1014         func.getArgAttr(i, "lmhlo.alloc").dyn_cast_or_null<mlir::IntegerAttr>();
1015     TF_RET_CHECK(allocation_index_attr);
1016     int index = allocation_index_attr.getInt();
1017     if (index >= maybe_allocations.size()) {
1018       maybe_allocations.resize(index + 1);
1019     }
1020     mlir::BlockArgument arg = func.getArgument(i);
1021     TF_RET_CHECK(arg.getType().isa<mlir::ShapedType>());
1022     size_t size = arg.getType().cast<mlir::ShapedType>().getSizeInBits() / 8;
1023     maybe_allocations[index].emplace(index, size, 0);
1024   }
1025 
1026   allocations->reserve(maybe_allocations.size());
1027   for (auto& maybe_alloc : maybe_allocations) {
1028     if (maybe_alloc.has_value()) {
1029       allocations->push_back(*maybe_alloc);
1030     } else {
1031       return InvalidArgument("Allocation indices should range in [0, n)");
1032     }
1033   }
1034 
1035   for (int i = 0; i < func.getNumArguments(); i++) {
1036     for (const mlir::NamedAttribute& attr : func.getArgAttrs(i)) {
1037       TF_RET_CHECK(attr.first == "lmhlo.alloc" ||
1038                    attr.first == "lmhlo.params" ||
1039                    attr.first == "lmhlo.output_index");
1040     }
1041   }
1042 
1043   std::vector<Shape> output_shapes;
1044   absl::optional<int> rank;
1045   for (int i = 0; i < func.getNumArguments(); i++) {
1046     auto index =
1047         func.getArgAttr(i, "lmhlo.alloc").cast<mlir::IntegerAttr>().getInt();
1048     if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
1049       allocations->at(index).set_entry_computation_parameter(
1050           param_attr.cast<mlir::IntegerAttr>().getInt(), {},
1051           static_cast<bool>(func.getArgAttr(i, "lmhlo.output_index")));
1052     }
1053     if (auto output_index_attr = func.getArgAttr(i, "lmhlo.output_index")) {
1054       allocations->at(index).set_maybe_live_out(true);
1055 
1056       // Reconstruct a shape index from output_index.
1057       ShapeIndex shape_index;
1058       for (const llvm::APInt& i :
1059            output_index_attr.cast<mlir::DenseIntElementsAttr>()) {
1060         shape_index.push_back(i.getSExtValue());
1061       }
1062       if (rank.has_value()) {
1063         if (*rank != shape_index.size()) {
1064           return InvalidArgument("Expect output_index to have the same rank");
1065         }
1066       } else {
1067         rank.emplace(shape_index.size());
1068       }
1069       auto& o = (*output_info)[shape_index];
1070       o.allocation_index = index;
1071       if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
1072         o.alias_config.emplace(param_attr.cast<mlir::IntegerAttr>().getInt(),
1073                                ShapeIndex{});
1074       }
1075 
1076       if (shape_index.size() > 1) {
1077         return Unimplemented("Expect array type or 1-level tuple type");
1078       }
1079 
1080       mlir::BlockArgument arg = func.getArgument(i);
1081       if (shape_index.empty()) {
1082         output_shapes.push_back(TypeToShape(arg.getType()));
1083       } else {
1084         if (shape_index[0] >= output_shapes.size()) {
1085           output_shapes.resize(shape_index[0] + 1);
1086         }
1087         output_shapes[shape_index[0]] = TypeToShape(arg.getType());
1088       }
1089     }
1090   }
1091   *output_shape = ShapeUtil::MakeTupleShape(output_shapes);
1092 
1093   return Status::OK();
1094 }
1095 
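// Compiles an LMHLO (MLIR) module into a GpuExecutable: reconstructs buffer
// allocations and output info from the entry function's argument attributes,
// emits one thunk per top-level operation in the entry function, and compiles
// the accompanying LLVM module to a device binary.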
1096 StatusOr<std::unique_ptr<Executable>> CompileLmhloToExecutable(
1097     GpuCompiler* compiler, mlir::ModuleOp module, std::string module_name,
1098     const HloModuleConfig& module_config,
1099     const Compiler::CompileOptions& options,
1100     absl::string_view entry_function_name, se::StreamExecutor* stream_exec,
1101     std::unique_ptr<llvm::Module> llvm_module,
1102     IrEmitterContext* ir_emitter_context) {
1103   mlir::FuncOp entry_function = mlir::cast<mlir::FuncOp>(module.lookupSymbol(
1104       llvm::StringRef(entry_function_name.data(), entry_function_name.size())));
1105 
1106   std::vector<BufferAllocation> allocations;
1107   absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo> output_info;
1108   Shape output_shape;
1109   absl::flat_hash_map<ShapeIndex, int> output_to_argnum_map;
1110   TF_RETURN_IF_ERROR(GetMlirAllocationInfo(entry_function, &allocations,
1111                                            &output_info, &output_shape));
1112 
1113   CHECK(!allocations.empty());
1114 
1115   ir_emitter_context->set_allocations(allocations);
1116 
1117   TF_ASSIGN_OR_RETURN(
1118       auto ir_emitter,
1119       IrEmitterUnnested::Create(module_config, /*hlo_computation=*/nullptr,
1120                                 ir_emitter_context));
1121   ThunkSequence thunk_sequence;
1122   for (mlir::Operation& op :
1123        entry_function.getBody().front().without_terminator()) {
1124     MlirEmitterInput input;
1125     input.op = &op;
1126     TF_RETURN_IF_ERROR(ir_emitter->EmitOp(input));
1127     std::unique_ptr<ThunkSequence> thunks = ir_emitter->ConsumeThunkSequence();
1128     TF_RET_CHECK(thunks->size() <= 1);
1129     if (!thunks->empty()) {
1130       auto thunk = std::move(thunks->front());
1131       thunk_sequence.push_back(std::move(thunk));
1132     }
1133   }
1134   auto thunk_schedule = absl::make_unique<ThunkSchedule>(
1135       std::make_unique<ThunkSequence>(std::move(thunk_sequence)));
1136 
1137   using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
1138   TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
1139                       compiler->CompileToTargetBinary(
1140                           module_config, std::move(llvm_module), stream_exec,
1141                           options, /*debug_module=*/nullptr));
1142 
1143   GpuVersion gpu_version = compiler->GetGpuVersion(stream_exec);
1144   auto* gpu_executable = new GpuExecutable(
1145       {std::move(backend_result.first), std::move(backend_result.second),
1146        gpu_version, std::move(thunk_schedule),
1147        std::move(ir_emitter_context->constants()), std::move(output_info),
1148        module_name, output_shape, std::move(allocations)});
1149   return std::unique_ptr<Executable>(gpu_executable);
1150 }
1151 
1152 }  // namespace gpu
1153 }  // namespace xla
1154