1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
17 
18 #include <stdlib.h>
19 
20 #include <atomic>
21 #include <functional>
22 #include <string>
23 #include <utility>
24 
25 #include "absl/memory/memory.h"
26 #include "absl/strings/numbers.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/types/variant.h"
29 #include "llvm/AsmParser/Parser.h"
30 #include "llvm/Bitcode/BitcodeReader.h"
31 #include "llvm/Bitcode/BitcodeWriter.h"
32 #include "llvm/IR/DiagnosticInfo.h"
33 #include "llvm/IR/DiagnosticPrinter.h"
34 #include "llvm/IR/LLVMContext.h"
35 #include "llvm/IR/Module.h"
36 #include "llvm/IR/Verifier.h"
37 #include "llvm/Transforms/Utils/SplitModule.h"
38 #include "mlir/Dialect/GPU/Passes.h"  // from @llvm-project
39 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
40 #include "mlir/InitAllDialects.h"  // from @llvm-project
41 #include "mlir/Pass/PassManager.h"  // from @llvm-project
42 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
43 #include "tensorflow/compiler/mlir/utils/name_utils.h"
44 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
45 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
46 #include "tensorflow/compiler/xla/protobuf_util.h"
47 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
48 #include "tensorflow/compiler/xla/service/all_gather_broadcast_reorder.h"
49 #include "tensorflow/compiler/xla/service/all_gather_combiner.h"
50 #include "tensorflow/compiler/xla/service/all_gather_decomposer.h"
51 #include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
52 #include "tensorflow/compiler/xla/service/all_reduce_reassociate.h"
53 #include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
54 #include "tensorflow/compiler/xla/service/async_collective_creator.h"
55 #include "tensorflow/compiler/xla/service/batchnorm_expander.h"
56 #include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
57 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
58 #include "tensorflow/compiler/xla/service/call_inliner.h"
59 #include "tensorflow/compiler/xla/service/collectives_schedule_linearizer.h"
60 #include "tensorflow/compiler/xla/service/comparison_expander.h"
61 #include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
62 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
63 #include "tensorflow/compiler/xla/service/convolution_4d_expander.h"
64 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
65 #include "tensorflow/compiler/xla/service/dump.h"
66 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
67 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
68 #include "tensorflow/compiler/xla/service/eigh_expander.h"
69 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
70 #include "tensorflow/compiler/xla/service/gather_expander.h"
71 #include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h"
72 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
73 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
74 #include "tensorflow/compiler/xla/service/gpu/gemm_broadcast_folding_rewriter.h"
75 #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
76 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
77 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h"
78 #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
79 #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
80 #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
81 #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
82 #include "tensorflow/compiler/xla/service/gpu/gpu_reduce_scatter_creator.h"
83 #include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h"
84 #include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h"
85 #include "tensorflow/compiler/xla/service/gpu/gpu_spmd_partitioner.h"
86 #include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"
87 #include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
88 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
89 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
90 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
91 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
92 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
93 #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
94 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
95 #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
96 #include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h"
97 #include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h"
98 #include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h"
99 #include "tensorflow/compiler/xla/service/gpu/reduction_splitter.h"
100 #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
101 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
102 #include "tensorflow/compiler/xla/service/gpu/target_constants.h"
103 #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
104 #include "tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.h"
105 #include "tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h"
106 #include "tensorflow/compiler/xla/service/hlo_computation.h"
107 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
108 #include "tensorflow/compiler/xla/service/hlo_cse.h"
109 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
110 #include "tensorflow/compiler/xla/service/hlo_dce.h"
111 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
112 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
113 #include "tensorflow/compiler/xla/service/hlo_parser.h"
114 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
115 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
116 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
117 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
118 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
119 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
120 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
121 #include "tensorflow/compiler/xla/service/logistic_expander.h"
122 #include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"
123 #include "tensorflow/compiler/xla/service/operand_upcaster.h"
124 #include "tensorflow/compiler/xla/service/qr_expander.h"
125 #include "tensorflow/compiler/xla/service/real_imag_expander.h"
126 #include "tensorflow/compiler/xla/service/reduce_scatter_combiner.h"
127 #include "tensorflow/compiler/xla/service/reshape_mover.h"
128 #include "tensorflow/compiler/xla/service/result_caster.h"
129 #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
130 #include "tensorflow/compiler/xla/service/rng_expander.h"
131 #include "tensorflow/compiler/xla/service/sharding_propagation.h"
132 #include "tensorflow/compiler/xla/service/sharding_remover.h"
133 #include "tensorflow/compiler/xla/service/slice_sinker.h"
134 #include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
135 #include "tensorflow/compiler/xla/service/sort_simplifier.h"
136 #include "tensorflow/compiler/xla/service/stable_sort_expander.h"
137 #include "tensorflow/compiler/xla/service/transpose_folding.h"
138 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
139 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
140 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
141 #include "tensorflow/compiler/xla/service/while_loop_trip_count_annotator.h"
142 #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
143 #include "tensorflow/compiler/xla/status_macros.h"
144 #include "tensorflow/compiler/xla/types.h"
145 #include "tensorflow/compiler/xla/util.h"
146 #include "tensorflow/core/lib/core/status.h"
147 #include "tensorflow/core/lib/gtl/cleanup.h"
148 #include "tensorflow/core/lib/io/path.h"
149 #include "tensorflow/core/platform/blocking_counter.h"
150 #include "tensorflow/core/platform/env.h"
151 #include "tensorflow/core/platform/logging.h"
152 #include "tensorflow/core/platform/regexp.h"
153 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
154 #include "tensorflow/core/platform/subprocess.h"
155 #include "tensorflow/core/platform/threadpool.h"
156 #include "tensorflow/core/platform/tracing.h"
157 #include "tensorflow/core/profiler/lib/traceme.h"
158 #include "tensorflow/core/util/env_var.h"
159 #include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
160 #include "tfrt/bef_converter/mlir_to_bef_translate.h"  // from @tf_runtime
161 
162 #if BEF_EXECUTABLE
163 #include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.h"
164 #endif  // BEF_EXECUTABLE
165 
166 namespace xla {
167 namespace gpu {
168 namespace {
169 
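// Tells BFloat16Normalization which HLO ops this GPU backend can execute
// natively in BF16: collectives and bitcasts unconditionally, matrix
// multiplications when enabled by the caller, and convolutions when
// cuDNN >= 8.2 is running on an Ampere-or-newer GPU.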
170 class GpuBfloat16Support : public BFloat16Support {
171  public:
172   explicit GpuBfloat16Support(bool supports_matrix_multiplication,
173                               se::StreamExecutor* stream_exec)
174       : supports_matrix_multiplication_(supports_matrix_multiplication),
175         stream_exec_(stream_exec) {}
176 
177   bool SupportsBF16Operand(const HloInstruction& hlo,
178                            int64_t operand_index) const override {
179     return BFloat16Support::SupportsBF16Operand(hlo, operand_index) ||
180            IsSupported(hlo);
181   }
182 
183   // Returns whether the backend supports BF16 output for the HLO instruction.
184   bool SupportsBF16Output(const HloInstruction& hlo) const override {
185     return BFloat16Support::SupportsBF16Output(hlo) || IsSupported(hlo);
186   }
187 
188  private:
189   bool IsSupported(const HloInstruction& hlo) const {
190     switch (hlo.opcode()) {
191       case HloOpcode::kAllGather:
192       case HloOpcode::kAllReduce:
193       case HloOpcode::kAllReduceStart:
194       case HloOpcode::kAllReduceDone:
195       case HloOpcode::kReduceScatter:
196       case HloOpcode::kAllToAll:
197       case HloOpcode::kBitcast:
198       case HloOpcode::kCollectivePermute:
199         return true;
200       case HloOpcode::kConvolution:
201         return IsConvBF16Supported();
202       default:
203         return supports_matrix_multiplication_ &&
204                gpu::IsMatrixMultiplication(hlo);
205     }
206   }
207 
208   bool IsConvBF16Supported() const {
209     if (se::dnn::DnnSupport* dnn = stream_exec_->AsDnn()) {
210       se::port::StatusOr<se::dnn::VersionInfo> cudnn_version =
211           dnn->GetVersion();
212       return cudnn_version.ok() &&
213              (cudnn_version->major_version() > 8 ||
214               (cudnn_version->major_version() == 8 &&
215                cudnn_version->minor_version() >= 2)) &&
216              stream_exec_->GetDeviceDescription()
217                  .cuda_compute_capability()
218                  .IsAtLeast(se::CudaComputeCapability::AMPERE);
219     }
220     return false;
221   }
222 
223   bool supports_matrix_multiplication_;
224   se::StreamExecutor* stream_exec_;
225 };
226 
227 }  // end anonymous namespace
228 
229 using OwnedThunkSchedule = GpuExecutable::OwnedThunkSchedule;
230 
231 GpuCompiler::GpuCompiler(se::Platform::Id platform_id,
232                          const char* target_triple, const char* data_layout)
233     : platform_id_(platform_id),
234       target_triple_(target_triple),
235       data_layout_(data_layout),
236       pointer_size_(llvm::DataLayout(data_layout)
237                         .getPointerSize(0 /* default address space */)) {}
238 
239 // Runs optimization passes on the given HLO module.
240 Status GpuCompiler::OptimizeHloModule(
241     HloModule* hlo_module, se::StreamExecutor* stream_exec,
242     se::DeviceMemoryAllocator* device_allocator) {
243   {
244     HloPassPipeline pipeline("optimization");
245     pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
246                                               /*allow_mixed_precision=*/false);
247     pipeline.AddPass<AllToAllDecomposer>();
248 
249     OpExpanderPass::PatternExtraFilter upcaster_filter =
250         [&](const HloInstruction* instr) {
251           return !stream_exec->GetDeviceDescription()
252                       .cuda_compute_capability()
253                       .IsAtLeast(se::CudaComputeCapability::VOLTA) ||
254                  !gpu::IsMatrixMultiplication(*instr);
255         };
256 
257     pipeline.AddPass<OperandUpcaster>(upcaster_filter);
258     pipeline.AddPass<ResultCaster>(upcaster_filter);
259 
260     // Expand random number generation.
261     pipeline.AddPass<RngExpander>();
262     pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
263 
264     // Comparison total order expander
265     pipeline.AddPass<ComparisonExpander>();
266 
267     // Remove zero-sized HLO from the input so that other passes don't have to
268     // handle it.
269     pipeline.AddPass<ZeroSizedHloElimination>();
270 
271     pipeline.AddPass<GpuScatterExpander>();
272     // TODO(phawkins): replace QR and Eigh decompositions with calls to
273     // cuSOLVER.
274     pipeline.AddPass<QrExpander>();
275     pipeline.AddPass<EighExpander>();
276 
277     pipeline.AddPass<DynamicIndexSplitter>();
278 
279     // TODO(b/64094172): make Call work on GPU instead of inlining.
280     pipeline.AddPass<CallInliner>();
281 
282     pipeline.AddPass<DotDecomposer>();
283 
284     pipeline.AddPass<Convolution4DExpander>();
285 
286     // Expand the sort op to support stable sorting if required.
287     pipeline.AddPass<StableSortExpander>();
288 
289     GpuBfloat16Support bf16(/*supports_matrix_multiplication=*/true,
290                             stream_exec);
291     pipeline.AddPass<BFloat16Normalization>(&bf16);
292 
293     // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
294     // where possible.  Not every batchnorm op can be implemented as a call to
295     // cudnn, so decompose any remaining batchnorm ops into a soup of HLOs.
296     if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) {
297       // Since BatchNorm inference is essentially a pointwise operation, it is
298       // always advantageous to use kernel fusion rather than cudnn.
299       pipeline.AddPass<BatchNormExpander>(
300           /*rewrite_training_op=*/false,
301           /*rewrite_inference_op=*/true,
302           /*rewrite_grad_op=*/false);
303       pipeline.AddPass<CudnnBatchNormRewriter>();
304     }
305     pipeline.AddPass<BatchNormExpander>(
306         /*rewrite_training_op=*/true,
307         /*rewrite_inference_op=*/true,
308         /*rewrite_grad_op=*/true);
309 
310     pipeline.AddPass<LogisticExpander>(
311         /*expansion_type=*/LogisticExpansionType::kExp);
312     pipeline.AddPass<ConditionalCanonicalizer>();
313     pipeline.AddPass<DynamicPadder>();
314 
315     // Build simplification pipeline.  The passes in here are run to a fixed
316     // point.
317     [&pipeline =
318          pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification")] {
319       pipeline.AddInvariantCheckerDebug<HloVerifier>(
320           /*layout_sensitive=*/false,
321           /*allow_mixed_precision=*/false);
322 
323       // BatchNormExpander can create zero-sized ops, so zero-sized HLO
324       // elimination has to come after that pass.
325       pipeline.AddPass<ZeroSizedHloElimination>();
326 
327       pipeline.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
328       pipeline.AddPass<ScatterExpander>(
329           ScatterExpander::kEliminateSimpleScatters);
330 
331       AlgebraicSimplifierOptions options;
332       // When transposes appear in a fusion node, we can easily adjust the
333       // multi-dimensional index to create the one needed for the operand.
334       // This is not as easy with bitcasts, because we don't have the
335       // information readily available which dimensions are permuted. In
336       // addition to that, if we have a transpose and a reshape next to each
337       // other, they will both be replaced by a bitcast, and we replace
338       // bitcast(bitcast) with one bitcast. This leads to having to
339       // linearize and then delinearize the index.
340       options.set_replace_transpose_with_bitcast(false);
341       options.set_enable_conv_operand_swap(false);
342       pipeline.AddPass<AlgebraicSimplifier>(options);
343       // AlgebraicSimplifier may add contracting dimensions to a dot.
344       pipeline.AddPass<DotDecomposer>();
345       pipeline.AddPass<SortSimplifier>();
346       pipeline.AddPass<TupleSimplifier>();
347       pipeline.AddPass<WhileLoopConstantSinking>();
348       pipeline.AddPass<WhileLoopSimplifier>();
349 
350       // TODO(b/134075051): Re-enable after b/134075051 is fixed.
351       // pipeline.AddPass<SliceSinker>();
352 
353       pipeline.AddPass<HloDCE>();
354       pipeline.AddPass<ReshapeMover>();
355       pipeline.AddPass<HloConstantFolding>();
356       pipeline.AddPass<ConditionalSimplifier>();
357       pipeline.AddPass<RealImagExpander>();
358     }();
359 
360     pipeline.AddPass<TransposeFolding>(
361         [](const HloInstruction& dot,
362            const TransposeFolding::OperandIndices& candidate_operands) {
363           return IsMatrixMultiplication(dot)
364                      ? candidate_operands
365                      : TransposeFolding::OperandIndices{};
366         });
367     pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
368     pipeline.AddPass<HloDCE>();
369 
370     // Run WhileLoopTripCountAnnotator at the end of the simplification
371     // pipeline, before layout assignment and fusion.  This pass does some
372     // pattern-matching on while bodies/conditions, and this is where the HLO is
373     // "nicest".
374     //
375     // It's important that we don't make semantic changes (e.g. unrolling) to
376     // any `while` loops after this point, because otherwise the trip-count
377     // annotations added by this pass may not be correct after the
378     // modifications.
379     pipeline.AddPass<WhileLoopTripCountAnnotator>();
380     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
381   }
382 
383   if (hlo_module->config().use_spmd_partitioning()) {
384     HloPassPipeline spmd_pipeline("spmd-partitioner");
385     const int64_t num_partitions = hlo_module->config().num_partitions();
386     if (num_partitions > 1) {
387       spmd_pipeline.AddPass<ShardingPropagation>(/*is_spmd=*/true);
388       spmd_pipeline.AddPass<GpuSpmdPartitioner>(
389           num_partitions, hlo_module->config().replica_count());
390     } else {
391       // Remove redundant sharding ops when partition_count == 1.
392       spmd_pipeline.AddPass<ShardingRemover>();
393       spmd_pipeline.AddPass<HloDCE>();
394     }
395     TF_RETURN_IF_ERROR(spmd_pipeline.Run(hlo_module).status());
396   }
397 
398   // Optimize collectives generated by SPMD partitioning. These passes also
399   // run when SPMD is not used, so that all collectives get these optimizations.
400   {
401     HloPassPipeline collectives_pipeline("collective-optimizations");
402     collectives_pipeline.AddPass<ReduceScatterCreator>();
403     collectives_pipeline.AddPass<AllReduceReassociate>();
404 
405     // Run algebraic simplifier to reshape(broadcast) into a broadcast when
406     // the reshape is just adding a unit dimension. This will help with the
407     // AllGatherBroadcastReorder pass.
408     AlgebraicSimplifierOptions options;
409     options.set_replace_transpose_with_bitcast(false);
410     options.set_enable_conv_operand_swap(false);
411     collectives_pipeline.AddPass<AlgebraicSimplifier>(options);
412 
413     collectives_pipeline.AddPass<AllGatherBroadcastReorder>();
414     TF_RETURN_IF_ERROR(collectives_pipeline.Run(hlo_module).status());
415   }
416 
417   // Run target-specific HLO optimization passes for convolution
418   // canonicalization.
419   TF_RETURN_IF_ERROR(OptimizeHloConvolutionCanonicalization(
420       hlo_module, stream_exec, device_allocator));
421 
422   {
423     // Run layout assignment in a separate pipeline from
424     // "post-layout-assignment" because we want everything after layout
425     // assignment to have a layout-sensitive invariant-checker, but
426     // HloPassPipeline also runs its invariant checker before any passes are
427     // run, meaning the pipeline that contains layout assignment cannot contain
428     // a layout-sensitive verifier!
429     HloPassPipeline pipeline("layout assignment");
430     // Layout assignment uses alias analysis, which requires the call graph to
431     // be flattened.
432     pipeline.AddPass<FlattenCallGraph>();
433     ChannelLayoutConstraints layout_constraints;
434     pipeline.AddPass<GpuLayoutAssignment>(
435         hlo_module->mutable_entry_computation_layout(), stream_exec,
436         &layout_constraints);
437     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
438   }
439 
440   // Run target-specific HLO optimization passes after layout assignment.
441   TF_RETURN_IF_ERROR(OptimizeHloPostLayoutAssignment(hlo_module, stream_exec,
442                                                      device_allocator));
443 
444   {
445     HloPassFix<HloPassPipeline> fusion("fusion");
446     // We try to split variadic ops with many parameters into several such ops
447     // to avoid exceeding the parameter space.
448     fusion.AddPass<VariadicOpSplitter>();
449     fusion.AddInvariantCheckerDebug<HloVerifier>(
450         /*layout_sensitive=*/true,
451         /*allow_mixed_precision=*/false,
452         LayoutAssignment::InstructionCanChangeLayout);
453     fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
454     fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
455     fusion.AddPass<FusionMerger>();
456     fusion.AddPass<GpuMultiOutputFusion>();
457     fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
458                            /*only_fusion_computations=*/true);
459     fusion.AddPass<HloDCE>();
460     TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
461   }
462 
463   {
464     HloPassFix<HloPassPipeline> horizontal_fusion("horizontal fusion");
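    // Horizontal fusion combines independent sibling fusions and inputs into
    // larger kernels, mainly to reduce the number of kernel launches.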
465     horizontal_fusion.AddPass<GpuHorizontalLoopFusion>();
466     horizontal_fusion.AddPass<GpuHorizontalInputFusion>();
467     horizontal_fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
468                                       /*only_fusion_computations=*/true);
469     horizontal_fusion.AddPass<HloDCE>();
470     TF_RETURN_IF_ERROR(horizontal_fusion.Run(hlo_module).status());
471   }
472 
473   {
474     HloPassPipeline pipeline("post-fusion optimization");
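    // Combine many small collectives into fewer, larger ones; the thresholds
    // below cap both the combined size in bytes and the number of combined ops.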
475     pipeline.AddPass<AllGatherCombiner>(
476         /*combine_threshold_in_bytes=*/1024 * 1024 * 1024,
477         /*combine_threshold_count=*/256);
478     pipeline.AddPass<AllReduceCombiner>(
479         /*combine_threshold_in_bytes=*/30 * 1024 * 1024,
480         /*combine_threshold_count=*/256);
481     pipeline.AddPass<ReduceScatterCombiner>(
482         /*combine_threshold_in_bytes=*/30 * 1024 * 1024,
483         /*combine_threshold_count=*/256);
484 
485     if (hlo_module->config()
486             .debug_options()
487             .xla_gpu_enable_async_all_reduce()) {
488       pipeline.AddPass<AsyncCollectiveCreator>(/*convert_all_reduce=*/true,
489                                                /*convert_all_gather=*/false);
490     }
491 
492     pipeline.AddPass<CollectivesScheduleLinearizer>();
493 
494     // Now we allow replacing any transposes outside of fusions with bitcasts.
495     AlgebraicSimplifierOptions options;
496     options.set_is_layout_sensitive(true);
497     options.set_enable_conv_operand_swap(false);
498     pipeline.AddPass<AlgebraicSimplifier>(options);
499 
500     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
501   }
502 
503   return Status::OK();
504 }
505 
506 // Modifies the given HLO module so that it will be accepted by IrEmitter.
507 // Unlike optimization passes, these passes are necessary for correctness.
508 Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
509   // In some cases, we have to place the result of an instruction in a temporary
510   // buffer. For instance, the buffer that holds an external parameter is
511   // assumed immutable at this point, and should not be reused for output
512   // (b/27180329). Therefore, in that case, we set the output to be a copy of
513   // the parameter.
514   HloPassPipeline pipeline("GPU-ir-emit-prepare");
515   pipeline.AddInvariantCheckerDebug<HloVerifier>(
516       /*layout_sensitive=*/true,
517       /*allow_mixed_precision=*/false,
518       LayoutAssignment::InstructionCanChangeLayout);
519 
520   // Copy insertion should be performed immediately before IR emission to avoid
521   // inserting unnecessary copies (later pass adds an instruction which
522   // materializes the value) or missing a necessary copy (later pass removes an
523   // instruction which materializes a value). DCE must be run immediately before
524   // (and sometime after) copy insertion, to avoid dead code from interfering
525   // with the rewrites.
526   pipeline.AddPass<HloDCE>();
527   if (hlo_module->config().alias_passthrough_params()) {
528     pipeline.AddPass<AliasPassthroughParams>();
529   }
530   pipeline.AddPass<LoopScheduleLinearizer>(GetCanShareBuffer());
531   pipeline.AddPass<GpuCopyInsertion>(GetCanShareBuffer());
532   pipeline.AddPass<GpuSanitizeConstantNames>();
533   return pipeline.Run(hlo_module).status();
534 }
535 
536 Status GpuCompiler::OptimizeHloPostLayoutAssignment(
537     HloModule* hlo_module, se::StreamExecutor* stream_exec,
538     se::DeviceMemoryAllocator* device_allocator) {
539   HloPassPipeline pipeline("post-layout_assignment");
540   pipeline.AddInvariantCheckerDebug<HloVerifier>(
541       /*layout_sensitive=*/true,
542       /*allow_mixed_precision=*/false,
543       LayoutAssignment::InstructionCanChangeLayout);
544 
545   pipeline.AddPass<ReductionDegenerateDimRemover>();
546   pipeline.AddPass<ReductionLayoutNormalizer>();
547   pipeline.AddPass<ReductionDimensionGrouper>();
548   pipeline.AddPass<HloPassFix<ReductionSplitter>>();
549 
550   // The LayoutAssignment pass may leave behind kCopy instructions which are
551   // duplicate or NOPs, so remove them with algebraic simplification and CSE.
552   AlgebraicSimplifierOptions options;
553   options.set_is_layout_sensitive(true);
554   // When transposes appear in a fusion node, we can easily adjust the
555   // multi-dimensional index to create the one needed for the operand. This
556   // is not as easy with bitcasts, because we don't have the information
557   // readily available which dimensions are permuted. In addition to that,
558   // if we have a transpose and a reshape next to each other, they will both
559   // be replaced by a bitcast, and we replace bitcast(bitcast) with one
560   // bitcast. This leads to having to linearize and then delinearize the
561   // index.
562   options.set_replace_transpose_with_bitcast(false);
563   options.set_enable_conv_operand_swap(false);
564   pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
565 
566   if (RequireDeterminism(hlo_module->config()) ||
567       hlo_module->config().debug_options().xla_gpu_deterministic_reductions()) {
568     pipeline.AddPass<HloPassFix<GpuTreeReductionRewriter>>(
569         stream_exec->GetDeviceDescription().cuda_compute_capability());
570   }
571 
572   // GemmRewriter assumes that all transposes are folded into gemms, but,
573   // since commit 7d529df, this is not always true at this point.
574   // Therefore, rerun transpose folding.
575   pipeline.AddPass<TransposeFolding>(
576       [](const HloInstruction& dot,
577          const TransposeFolding::OperandIndices& candidate_operands) {
578         return IsMatrixMultiplication(dot) ? candidate_operands
579                                            : TransposeFolding::OperandIndices{};
580       },
581       TransposeFolding::NeverFoldTranspose);
582   // Rewrite GEMMs into custom calls.
583   pipeline.AddPass<GemmRewriter>();
584 
585   // Rewrite GEMMs with broadcasted inputs as strided GEMMs.
586   pipeline.AddPass<GemmBroadcastFoldingRewriter>();
587 
588   // Run conversion again, to catch those matrix multiplications which were not
589   // rewritten into cuBLAS calls.
590   GpuBfloat16Support bf16(/*supports_matrix_multiplication=*/false,
591                           stream_exec);
592   pipeline.AddPass<BFloat16Normalization>(&bf16);
593 
594   // Choose the fastest algorithm for each conv.
595   //
596   // We pick the algorithm before fusion so we can generate better HLO. After
597   // GpuConvRewriter, our convolutions are CustomCalls which return a
598   // tuple (conv_result, scratch_memory), and each conv uses 0 bytes of
599   // scratch:
600   //
601   //   customcall = (f32[...], f32[0])
602   //   return gte(customcall, 0)
603   //
604   // The algorithm picker then chooses the best algorithm, and potentially
605   // increases the scratch space.  It replaces customcall with new_tuple,
606   // giving us the following:
607   //
608   //   new_customcall = (f32[...], f32[N])
609   //   new_tuple = tuple(gte(new_customcall, 0), constant f32[0])
610   //   return gte(new_tuple, 0)
611   //
612   // The new tuple and gte instructions can then be simplified away, because
613   // nobody is expected to use the scratch value.
614   //
615   // However, if we were to run GpuConvAlgorithmPicker after fusion, the
616   // gte(customcall, 0) would probably already be inside a fusion node.  We
617   // can't simplify across HloComputation boundaries, so in this case we
618   // wouldn't be able to simplify away the new_tuple bits.
619   pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator);
620 
621   // Clean up new_tuple described above.
622   pipeline.AddPass<TupleSimplifier>();
623 
624   pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
625   TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
626 
627   return Status::OK();
628 }
629 
630 StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
631     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
632     const CompileOptions& options) {
633   // We dump the post-optimization HLO in RunBackend so no need to dump it here.
634   XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
635   tensorflow::profiler::TraceMe activity(
636       [&] { return absl::StrCat("HLO Transforms:", module->name()); },
637       tensorflow::profiler::TraceMeLevel::kInfo);
638   TF_RETURN_IF_ERROR(
639       OptimizeHloModule(module.get(), stream_exec, options.device_allocator));
640 
641   TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));
642 
643   return std::move(module);
644 }
645 
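// Buffer-sharing callback that never takes a position: returning absl::nullopt
// defers the decision to HloDataflowAnalysis' default rules.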
646 static absl::optional<bool> DummyCanShareBufferFunction(const HloInstruction*,
647                                                         const HloInstruction*,
648                                                         const ShapeIndex&) {
649   return absl::nullopt;
650 }
651 
652 StatusOr<
653     std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
654 GpuCompiler::RunHloPassesAndBufferAssignement(
655     std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* executor,
656     bool optimize, const CompileOptions& options) {
657   if (optimize) {
658     TF_ASSIGN_OR_RETURN(hlo_module,
659                         RunHloPasses(std::move(hlo_module), executor, options));
660   }
661 
662   std::unique_ptr<StreamAssignment> stream_assignment =
663       AssignStreams(*hlo_module);
664   TF_ASSIGN_OR_RETURN(std::unique_ptr<GpuHloSchedule> hlo_schedule,
665                       GpuHloSchedule::Build(hlo_module.get(),
666                                             *stream_assignment, pointer_size_));
667 
668   auto buffer_size_bytes_function =
669       [this](const BufferValue& buffer_value) -> int64 {
670     return GpuCompiler::GetSizeOfShape(buffer_value.shape(), pointer_size_);
671   };
672 
673   TF_ASSIGN_OR_RETURN(
674       std::unique_ptr<BufferAssignment> assignment,
675       BufferAssigner::Run(
676           hlo_module.get(), hlo_schedule->ConsumeHloOrdering(),
677           buffer_size_bytes_function,
678           /*color_alignment=*/
679           [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
680           /*allocate_buffers_for_constants=*/true,
681           /*colorer=*/BufferAssigner::DefaultColorer(),
682           /*must_not_live_out=*/{}, GetCanShareBuffer()));
683 
684   return std::make_tuple(std::move(hlo_module), std::move(assignment));
685 }
686 
687 #if BEF_EXECUTABLE
688 static StatusOr<tfrt::BefBuffer> LowerToBef(mlir::ModuleOp mlir_module) {
689   if (!mlir_module) {
690     return tensorflow::errors::FailedPrecondition(
691         "No mlir module to lower to BEF.");
692   }
693 
694   // LHLO -> TFRT Dialect (gpu kernels)
695   mlir::PassManager pm(mlir_module.getContext(),
696                        mlir::PassManager::Nesting::Implicit);
697   pm.addPass(tensorflow::createLmhloGpuAsyncConversionPass());
698   pm.addPass(mlir::createGpuAsyncRegionPass());
699   pm.addPass(tensorflow::createAsyncGpuTfrtConversionPass());
700   if (pm.run(mlir_module).failed()) {
701     return InternalError(
702         "Failed to lower LHLO to TFRT Dialect with gpu kernels.");
703   }
704 
705   // Perform DCE with empty pattern set.
706   if (failed(mlir::applyPatternsAndFoldGreedily(
707           mlir_module, mlir::RewritePatternSet(mlir_module.getContext())))) {
708     return InternalError("Failed to remove dead ops.");
709   }
710 
711   // TFRT Dialect -> BEF
712   std::string bef;
713   llvm::raw_string_ostream bef_ostream(bef);
714   if (tfrt::MLIRToBEFTranslate(mlir_module, bef_ostream).failed()) {
715     return InternalError("Failed to lower TFRT Dialect to BEF.");
716   }
717 
718   return tfrt::BefBuffer(bef.data(), bef.data() + bef.size());
719 }
720 #endif  // BEF_EXECUTABLE
721 
722 using OutputInfoMap =
723     absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>;
724 static Status GetMlirAllocationInfo(mlir::FuncOp func,
725                                     std::vector<BufferAllocation>* allocations,
726                                     OutputInfoMap* output_info,
727                                     Shape* output_shape);
728 
729 struct CompileModuleResults {
730   std::unique_ptr<llvm::Module> llvm_module;
731   std::unique_ptr<BufferAssignment> buffer_assignment;
732   std::vector<BufferAllocation> allocations;
733   absl::variant<OwnedThunkSchedule, tfrt::BefBuffer> thunks_or_bef;
734   std::vector<GpuExecutable::ConstantInfo> constants;
735   OutputInfoMap output_info;
736   Shape output_shape;
737   std::string module_name;
738 };
739 // The order of `thunk_sequence` corresponds to
740 // `hlo_schedule->ThunkLaunchOrder()`.
741 static Status CompileModuleToLlvmIrImpl(
742     HloModule* hlo_module, llvm::LLVMContext* llvm_context,
743     const std::string& target_triple, const std::string& data_layout,
744     const std::string& platform_name, GpuDeviceInfo gpu_device_info,
745     se::CudaComputeCapability cuda_compute_capability,
746     const HloDataflowAnalysis::CanShareBuffer& can_share_buffer_function,
747     int pointer_size, const HloProfileIndexMap* profile_index_map,
748     CompileModuleResults* results) {
749   results->llvm_module = absl::make_unique<llvm::Module>("", *llvm_context);
750   results->llvm_module->setTargetTriple(target_triple);
751   results->llvm_module->setDataLayout(data_layout);
752 
753   std::unique_ptr<StreamAssignment> stream_assignment =
754       AssignStreams(*hlo_module);
755   TF_ASSIGN_OR_RETURN(
756       std::unique_ptr<GpuHloSchedule> hlo_schedule,
757       GpuHloSchedule::Build(hlo_module, *stream_assignment, pointer_size));
758 
759   auto buffer_size_bytes_function =
760       [pointer_size](const BufferValue& buffer_value) -> int64 {
761     return GpuCompiler::GetSizeOfShape(buffer_value.shape(), pointer_size);
762   };
763 
764   TF_ASSIGN_OR_RETURN(
765       results->buffer_assignment,
766       BufferAssigner::Run(
767           hlo_module, hlo_schedule->ConsumeHloOrdering(),
768           buffer_size_bytes_function,
769           /*color_alignment=*/
770           [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
771           /*allocate_buffers_for_constants=*/true,
772           /*colorer=*/BufferAssigner::DefaultColorer(),
773           /*must_not_live_out=*/{}, can_share_buffer_function));
774 
775   VLOG(1) << "Buffer Assignment Stats "
776           << results->buffer_assignment->GetStats().ToString();
777   DumpHloModuleIfEnabled(*hlo_module, *results->buffer_assignment,
778                          absl::StrCat("sm_", cuda_compute_capability.ToString(),
779                                       "_gpu_after_optimizations"));
780 
781   mlir::MLIRContext mlir_context;
782   mlir_context.loadDialect<mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect,
783                            mlir::StandardOpsDialect,
784                            mlir::lmhlo_gpu::LmhloGpuDialect>();
785   mlir::OwningModuleRef mlir_module =
786       mlir::ModuleOp::create(mlir::Builder(&mlir_context).getUnknownLoc());
787 
788   TF_RETURN_IF_ERROR(
789       HloToLhloModule(*results->buffer_assignment, *hlo_module, *mlir_module));
790 
791   results->module_name = mlir::GetNameFromLoc(mlir_module->getLoc());
792 
793   llvm_ir::DumpIrIfEnabled(mlir_module.get(), hlo_module->unique_id(),
794                            hlo_module->config().debug_options());
795 
796   auto entry_function = mlir::cast<mlir::FuncOp>(
797       mlir_module->lookupSymbol(hlo_module->entry_computation()->name()));
798 
799   TF_RETURN_IF_ERROR(
800       GetMlirAllocationInfo(entry_function, &results->allocations,
801                             &results->output_info, &results->output_shape));
802 
803   IrEmitterContext ir_emitter_context(
804       /*hlo_module=*/nullptr, /*buffer_assignment=*/nullptr, platform_name,
805       gpu_device_info, cuda_compute_capability, profile_index_map,
806       &mlir_context, results->llvm_module.get());
807 
808   ir_emitter_context.set_allocations(results->allocations);
809 
810   TF_ASSIGN_OR_RETURN(
811       auto ir_emitter,
812       IrEmitterUnnested::Create(hlo_module->config(), &ir_emitter_context));
813 
814   {
815     XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
816 
817     TF_RETURN_IF_ERROR(ir_emitter->EmitLmhloRegion(&entry_function.body()));
818 
819     results->constants = std::move(ir_emitter_context.constants());
820   }
821 
822 #if BEF_EXECUTABLE
823   TF_ASSIGN_OR_RETURN(results->thunks_or_bef, LowerToBef(*mlir_module));
824 #else   // BEF_EXECUTABLE
825   results->thunks_or_bef =
826       absl::make_unique<ThunkSchedule>(ir_emitter->ConsumeThunkSequence());
827 #endif  // BEF_EXECUTABLE
828 
829   return Status::OK();
830 }
831 
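// Routes LLVM diagnostics to VLOG(1) instead of printing them to stderr.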
832 static void NullDiagnosticHandler(const llvm::DiagnosticInfo& diag_info,
833                                   void* context) {
834   std::string error_string;
835   llvm::raw_string_ostream string_printer(error_string);
836   llvm::DiagnosticPrinterRawOStream diagnostic_printer(string_printer);
837   diag_info.print(diagnostic_printer);
838 
839   VLOG(1) << error_string;
840 }
841 
842 StatusOr<std::pair<std::string, std::vector<uint8>>>
843 GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
844                                    std::unique_ptr<llvm::Module> llvm_module,
845                                    se::StreamExecutor* stream_exec,
846                                    const CompileOptions& options,
847                                    const HloModule* debug_module) {
848   using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
849 
850   const auto compile_single_module =
851       [this, stream_exec, &module_config, debug_module](
852           llvm::Module* llvm_module, bool relocatable,
853           absl::optional<int> shard_number) -> StatusOr<BackendCompileResult> {
854     {
855       XLA_SCOPED_LOGGING_TIMER(
856           "GpuCompiler::RunBackend - Running LLVM verifier");
857 
858       llvm_module->getContext().setDiagnosticHandlerCallBack(
859           NullDiagnosticHandler, nullptr);
860 
861       std::string err;
862       llvm::raw_string_ostream err_stream(err);
863 
864       // verifyModule() returns true if the module is broken.
865       TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
866           << "Invalid LLVM IR before optimizations:\n"
867           << err_stream.str()
868           << "\nThis probably indicates a bug in the HLO -> LLVM IR "
869              "lowering. Rerun with --xla_dump_to to get the IR"
870           << (debug_module
871                   ? absl::StrCat(" and looks for files with name containing: *",
872                                  FilenameFor(*debug_module, "", ""), "*")
873                   : ".");
874     }
875     GpuVersion gpu_version = GetGpuVersion(stream_exec);
876     StatusOr<std::pair<std::string, std::vector<uint8>>> result =
877         CompileTargetBinary(module_config, llvm_module, gpu_version,
878                             stream_exec, relocatable, debug_module);
879 
880     if (!result.ok()) {
881       return result;
882     }
883 
884     const bool should_dump =
885         DumpingEnabledForHloModule(debug_module ? debug_module->name() : "",
886                                    module_config.debug_options());
887 
888     if (should_dump) {
889       if (debug_module) {
890         if (shard_number.has_value()) {
891           llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
892                                    /*optimized=*/true,
893                                    std::to_string(*shard_number));
894         } else {
895           llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
896                                    /*optimized=*/true);
897         }
898       } else {
899         LOG(ERROR)
900             << "Dumping is not implemented since the file name cannot be "
901                "inferred. Please implement (potentially MLIR) module -> "
902                "filename heuristic.";
903       }
904     }
905 
906     if (user_post_optimization_hook_) {
907       user_post_optimization_hook_(*llvm_module);
908     }
909 
910     // Write PTX to IR dump directory, if IR dumping was requested.
911     if (should_dump) {
912       absl::string_view ptx = result->first;
913       if (debug_module) {
914         if (shard_number.has_value()) {
915           DumpToFileInDirOrStdout(*debug_module, "",
916                                   std::to_string(*shard_number) + ".ptx", ptx);
917         } else {
918           DumpToFileInDirOrStdout(*debug_module, "", "ptx", ptx);
919         }
920       } else {
921         LOG(ERROR)
922             << "Dumping is not implemented since the file name cannot be "
923                "inferred. Please implement (potentially MLIR) module -> "
924                "filename heuristic.";
925       }
926     }
927 
928     return result;
929   };
930 
931   tensorflow::thread::ThreadPool* thread_pool;
932   absl::optional<tensorflow::thread::ThreadPool> overriding_thread_pool;
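  // xla_gpu_force_compilation_parallelism: 0 keeps the caller-provided thread
  // pool, 1 disables parallel compilation, and any other value creates a
  // dedicated pool of that size.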
933   switch (
934       module_config.debug_options().xla_gpu_force_compilation_parallelism()) {
935     case 0:
936       thread_pool = options.thread_pool;
937       break;
938     case 1:
939       thread_pool = nullptr;
940       break;
941     default:
942       overriding_thread_pool.emplace(
943           tensorflow::Env::Default(), "",
944           module_config.debug_options()
945               .xla_gpu_force_compilation_parallelism());
946       thread_pool = &*overriding_thread_pool;
947       break;
948   }
949 
950   if (!thread_pool) {
951     return compile_single_module(llvm_module.get(), /*relocatable=*/false,
952                                  /*shard_number=*/absl::nullopt);
953   }
954 
955   // Test whether LinkModules is supported.
956   if (this->LinkModules(stream_exec, {}).status().code() ==
957       tensorflow::error::Code::UNIMPLEMENTED) {
958     return compile_single_module(llvm_module.get(), /*relocatable=*/false,
959                                  /*shard_number=*/absl::nullopt);
960   }
961 
962   std::vector<std::unique_ptr<llvm::Module>> llvm_modules;
963   int num_functions = 0;
964   for (llvm::Function& func : llvm_module->functions()) {
965     if (!func.isDeclaration() &&
966         func.getLinkage() == llvm::GlobalValue::LinkageTypes::ExternalLinkage) {
967       num_functions++;
968     }
969   }
970 
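  // Split the LLVM module into at most one shard per defined, externally
  // visible function (capped by the thread-pool size) so the shards can be
  // compiled in parallel and linked afterwards.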
971   llvm::SplitModule(
972       *llvm_module.get(),
973       std::max<unsigned>(
974           1, std::min<unsigned>(thread_pool->NumThreads(), num_functions)),
975       [&](std::unique_ptr<llvm::Module> module) {
976         llvm_modules.push_back(std::move(module));
977       },
978       /*PreserveLocals=*/true);
979 
980   std::vector<StatusOr<BackendCompileResult>> compile_results(
981       llvm_modules.size());
982   tensorflow::BlockingCounter counter(llvm_modules.size());
983   for (int i = 0; i < llvm_modules.size(); i++) {
984     thread_pool->Schedule(
985         [&compile_results, compile_single_module, i, &llvm_modules, &counter] {
986           llvm::Module* original_module = llvm_modules[i].get();
987           llvm::LLVMContext context;
988           std::string buffer;
989           llvm::raw_string_ostream error(buffer);
990 
991           std::unique_ptr<llvm::Module> new_llvm_module;
992           // Switch to a new context by dumping and re-parsing LLVM IR. Each
993           // thread has its own context to avoid race conditions.
994           {
995             std::string ir;
996             {
997               llvm::raw_string_ostream os(ir);
998               original_module->print(os, nullptr);
999             }
1000             llvm::SMDiagnostic err;
1001             new_llvm_module = llvm::parseAssemblyString(ir, err, context);
1002             if (!new_llvm_module) {
1003               std::string err_string;
1004               llvm::raw_string_ostream os(err_string);
1005               err.print(/*ProgName=*/nullptr, os, /*ShowColors=*/false);
1006               LOG(FATAL) << "Failed to parse IR: " << err_string;
1007             }
1008           }
1009 
1010           compile_results[i] = compile_single_module(
1011               new_llvm_module.get(), /*relocatable=*/true, /*shard_number=*/i);
1012           counter.DecrementCount();
1013         });
1014   }
1015   counter.Wait();
1016 
1017   std::string ptx_snippets;
1018   std::vector<std::vector<uint8>> submodule_compile_results;
1019   for (auto& maybe_result : compile_results) {
1020     TF_ASSIGN_OR_RETURN(auto result, maybe_result);
1021     if (result.second.empty()) {
1022       continue;
1023     }
1024     ptx_snippets += result.first;
1025     ptx_snippets += "\n";
1026     submodule_compile_results.push_back(result.second);
1027   }
1028 
1029   auto maybe_backend_result =
1030       this->LinkModules(stream_exec, std::move(submodule_compile_results));
1031   if (!maybe_backend_result.ok()) {
1032     LOG(ERROR) << "The CUDA linking API did not work. Please use "
1033                   "XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 to "
1034                   "bypass it, but expect to get longer compilation time due to "
1035                   "the lack of multi-threading.";
1036     return maybe_backend_result.status();
1037   }
1038 
1039   return std::make_pair(ptx_snippets, std::move(*maybe_backend_result));
1040 }
1041 
1042 StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
1043     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
1044     const CompileOptions& options) {
1045   XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend");
1046   std::string slow_compilation_msg =
1047       absl::StrCat("Compiling module ", module->name());
1048   auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg);
1049 
1050   TF_RET_CHECK(stream_exec != nullptr);
1051 
1052   llvm::LLVMContext llvm_context;
1053 
1054   GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);
1055 
1056   std::unique_ptr<HloProfileIndexMap> profile_index_map;
1057   std::unique_ptr<HloProfilePrinterData> profile_printer;
1058 
1059   if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) {
1060     HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
1061     cost_analysis.set_bytes_per_second(
1062         stream_exec->GetDeviceDescription().memory_bandwidth());
1063     TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
1064     VLOG(1) << "HLO memory read+written: "
1065             << tensorflow::strings::HumanReadableNumBytes(
1066                    cost_analysis.bytes_accessed());
1067     if (module->config().hlo_profiling_enabled()) {
1068       LOG(ERROR) << "--xla_hlo_profile for GPU is unsupported.";
1069     }
1070   }
1071 
1072   CompileModuleResults compile_module_results;
1073   TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
1074       module.get(), &llvm_context, target_triple_, data_layout_,
1075       stream_exec->platform()->Name(), gpu_device_info,
1076       stream_exec->GetDeviceDescription().cuda_compute_capability(),
1077       GetCanShareBuffer(), pointer_size_, profile_index_map.get(),
1078       &compile_module_results));
1079 
1080   if (user_pre_optimization_hook_) {
1081     user_pre_optimization_hook_(*compile_module_results.llvm_module);
1082   }
1083   string ir_module_string_before_opt;
1084   const bool embed_ir_in_executable =
1085       module->config().debug_options().xla_embed_ir_in_executable();
1086   if (embed_ir_in_executable) {
1087     ir_module_string_before_opt =
1088         llvm_ir::DumpModuleToString(*compile_module_results.llvm_module);
1089   }
1090 
1091   llvm_ir::DumpIrIfEnabled(*module, *compile_module_results.llvm_module,
1092                            /*optimized=*/false);
1093 
1094   using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
1095   TF_ASSIGN_OR_RETURN(
1096       BackendCompileResult backend_result,
1097       CompileToTargetBinary(module->config(),
1098                             std::move(compile_module_results.llvm_module),
1099                             stream_exec, options, module.get()));
1100   if (DumpingEnabledForHloModule(*module) &&
1101       absl::holds_alternative<OwnedThunkSchedule>(
1102           compile_module_results.thunks_or_bef)) {
1103     const ThunkSchedule& thunk_schedule =
1104         *absl::get<OwnedThunkSchedule>(compile_module_results.thunks_or_bef);
1105     DumpToFileInDirOrStdout(*module, "", "thunk_schedule",
1106                             thunk_schedule.ToString());
1107   }
1108 
1109   auto buffer_assignment_proto = std::make_unique<BufferAssignmentProto>(
1110       compile_module_results.buffer_assignment->ToProto());
1111 
1112   size_t profile_index = 0;
1113   if (profile_index_map) {
1114     profile_index =
1115         profile_index_map->GetProfileIndexFor(*module->entry_computation());
1116   }
1117 
1118   GpuVersion gpu_version = GetGpuVersion(stream_exec);
1119   auto* gpu_executable = new GpuExecutable(
1120       {std::move(backend_result.first), std::move(backend_result.second),
1121        gpu_version, std::move(compile_module_results.thunks_or_bef),
1122        std::move(compile_module_results.constants),
1123        std::move(compile_module_results.output_info),
1124        compile_module_results.module_name, compile_module_results.output_shape,
1125        std::move(compile_module_results.allocations),
1126        std::move(buffer_assignment_proto), std::move(module), profile_index,
1127        std::move(profile_printer), std::move(profile_index_map)});
1128   if (embed_ir_in_executable) {
1129     DCHECK_NE("", ir_module_string_before_opt);
1130     gpu_executable->set_ir_module_string(ir_module_string_before_opt);
1131   }
1132 
1133   // Dump computation proto state and buffer assignment for debug and test, if
1134   // dump is enabled.
1135   if (DumpingEnabledForHloModule(gpu_executable->module())) {
1136     auto hlo_proto = absl::make_unique<HloProto>();
1137     *hlo_proto->mutable_hlo_module() = gpu_executable->module().ToProto();
1138     *hlo_proto->mutable_buffer_assignment() =
1139         compile_module_results.buffer_assignment->ToProto();
1140     gpu_executable->set_hlo_proto(std::move(hlo_proto));
1141   }
1142   gpu_executable->set_debug_info(
1143       compile_module_results.buffer_assignment->GetStats().ToString());
1144   return std::unique_ptr<Executable>(gpu_executable);
1145 }
1146 
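// Collects the device limits that kernel code generation depends on (thread,
// block, and shared-memory bounds) from the StreamExecutor's device description.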
1147 GpuDeviceInfo GetGpuDeviceInfo(se::StreamExecutor* stream_exec) {
1148   GpuDeviceInfo gpu_device_info;
1149   gpu_device_info.threads_per_block_limit =
1150       stream_exec->GetDeviceDescription().threads_per_block_limit();
1151   gpu_device_info.threads_per_warp =
1152       stream_exec->GetDeviceDescription().threads_per_warp();
1153   gpu_device_info.shared_memory_per_block =
1154       stream_exec->GetDeviceDescription().shared_memory_per_block();
1155   gpu_device_info.threads_per_core_limit =
1156       stream_exec->GetDeviceDescription().threads_per_core_limit();
1157   gpu_device_info.core_count = stream_exec->GetDeviceDescription().core_count();
1158   gpu_device_info.block_dim_limit_x =
1159       stream_exec->GetDeviceDescription().block_dim_limit().x;
1160   gpu_device_info.block_dim_limit_y =
1161       stream_exec->GetDeviceDescription().block_dim_limit().y;
1162   gpu_device_info.block_dim_limit_z =
1163       stream_exec->GetDeviceDescription().block_dim_limit().z;
1164   return gpu_device_info;
1165 }
1166 
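// Ahead-of-time compilation is not supported by the GPU backend yet; this
// always returns an Unimplemented error.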
1167 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
1168 GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
1169                                 const AotCompilationOptions& options) {
1170   return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime");
1171 }
1172 
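// Lowers an HloModule to LLVM IR without building a full GpuExecutable. This
// is a thin wrapper around CompileModuleToLlvmIrImpl that passes
// DummyCanShareBufferFunction and no profile index map, and returns only the
// resulting llvm::Module.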
1173 StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
1174     HloModule* hlo_module, llvm::LLVMContext* llvm_context,
1175     const std::string& target_triple, const std::string& data_layout,
1176     const std::string& platform_name, GpuDeviceInfo gpu_device_info,
1177     se::CudaComputeCapability cuda_compute_capability, int pointer_size) {
1178   CompileModuleResults results;
1179   TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
1180       hlo_module, llvm_context, target_triple, data_layout, platform_name,
1181       gpu_device_info, cuda_compute_capability, DummyCanShareBufferFunction,
1182       pointer_size, /*profile_index_map=*/nullptr, &results));
1183   return std::move(results.llvm_module);
1184 }
1185 
1186 // Analyze the function signature to reconstruct a vector of BufferAllocation
1187 // objects, as well as other output information.
1188 //
1189 // This function also serves as a half-baked verifier for function arg
1190 // attributes, since a full verifier doesn't exist yet.
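//
// A hypothetical sketch of the argument attributes this function consumes
// (attribute names come from the checks below; values are illustrative only):
//   %arg0: memref<16xf32> {lmhlo.params = 0 : index}
//   %arg1: memref<16xf32> {lmhlo.output_index = dense<[0]> : tensor<1xindex>}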
1191 static Status GetMlirAllocationInfo(mlir::FuncOp func,
1192                                     std::vector<BufferAllocation>* allocations,
1193                                     OutputInfoMap* output_info,
1194                                     Shape* output_shape) {
1195   CHECK(allocations->empty());
1196   allocations->reserve(func.getNumArguments());
1197 
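  // First pass: create one BufferAllocation per entry-function argument,
  // sized as the argument's element count times its element byte width.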
1198   for (int i = 0; i < func.getNumArguments(); i++) {
1199     mlir::BlockArgument arg = func.getArgument(i);
1200 
1201     TF_RET_CHECK(arg.getType().isa<mlir::ShapedType>());
1202     mlir::ShapedType type = arg.getType().cast<mlir::ShapedType>();
1203     TF_ASSIGN_OR_RETURN(auto element_type_bytes,
1204                         GetElementTypeBytes(type.getElementType()));
1205     size_t size = type.getNumElements() * element_type_bytes;
1206     allocations->emplace_back(i, size, 0);
1207   }
1208 
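  // Second pass: verify that every argument attribute is one of the lmhlo.*
  // attributes interpreted below.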
1209   for (int i = 0; i < func.getNumArguments(); i++) {
1210     for (const mlir::NamedAttribute& attr : func.getArgAttrs(i)) {
1211       TF_RET_CHECK(attr.first == "lmhlo.params" ||
1212                    attr.first == "lmhlo.param_shape_index" ||
1213                    attr.first == "lmhlo.constant_name" ||
1214                    attr.first == "lmhlo.must_alias" ||
1215                    attr.first == "lmhlo.output_index");
1216     }
1217   }
1218 
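  // Third pass: translate the lmhlo.* attributes into entry-parameter,
  // constant, and output/aliasing metadata on the corresponding allocations,
  // collecting the per-output sub-shapes along the way.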
1219   std::vector<std::pair<ShapeIndex, Shape>> sub_shapes;
1220   for (int i = 0; i < func.getNumArguments(); i++) {
1221     if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
1222       xla::ShapeIndex shape_index;
1223       if (auto shape_index_attr =
1224               func.getArgAttrOfType<mlir::DenseIntElementsAttr>(
1225                   i, "lmhlo.param_shape_index")) {
1226         for (const llvm::APInt& element : shape_index_attr) {
1227           shape_index.push_back(element.getSExtValue());
1228         }
1229       }
1230       allocations->at(i).set_entry_computation_parameter(
1231           param_attr.cast<mlir::IntegerAttr>().getInt(), shape_index,
1232           static_cast<bool>(func.getArgAttr(i, "lmhlo.output_index")));
1233     }
1234     // TODO(timshen): this information is redundant. This is here only for
1235     // a smooth migration to LMHLO. Remove it.
1236     if (func.getArgAttr(i, "lmhlo.constant_name")) {
1237       allocations->at(i).set_constant(true);
1238     }
1239     if (auto output_index_attr = func.getArgAttr(i, "lmhlo.output_index")) {
1240       allocations->at(i).set_maybe_live_out(true);
1241 
1242       // Reconstruct a shape index from output_index.
1243       ShapeIndex shape_index;
1244       for (const llvm::APInt& element :
1245            output_index_attr.cast<mlir::DenseIntElementsAttr>()) {
1246         shape_index.push_back(element.getSExtValue());
1247       }
1248       auto& o = (*output_info)[shape_index];
1249       o.allocation_index = i;
1250       if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
1251         HloInputOutputAliasConfig::AliasKind kind =
1252             HloInputOutputAliasConfig::kMayAlias;
1253         if (func.getArgAttr(i, "lmhlo.must_alias")) {
1254           kind = HloInputOutputAliasConfig::kMustAlias;
1255         }
1256         o.alias_config.emplace(param_attr.cast<mlir::IntegerAttr>().getInt(),
1257                                ShapeIndex{}, kind);
1258       }
1259       if (func.getArgument(i).use_empty()) {
1260         o.passthrough = true;
1261       }
1262 
1263       mlir::BlockArgument arg = func.getArgument(i);
1264       sub_shapes.push_back(std::make_pair(shape_index, GetShape(arg)));
1265     }
1266   }
1267   // Expects result_xla_shape to be an XLA shape in string form.
1268   //
1269   // The attribute is necessary, because GpuExecutable/ExecutionOutput supports
1270   // tuples / tree-like shapes, while the LMHLO argument list loses the tree
1271   // form.
1272   //
1273   // The string format is necessary since MLIR doesn't support XLA shapes with
1274   // dynamic dimensions.
1275   //
1276   // TODO(timshen): now this field is mandatory. Make it optional for
1277   // non-GpuExecutable outputs.
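  // For illustration only (hypothetical value), the entry function could carry
  //   result_xla_shape = "(f32[10]{0}, s32[])"
  // which ParseShape converts back into a tuple Shape.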
1278   TF_ASSIGN_OR_RETURN(
1279       *output_shape,
1280       ParseShape(func->getAttrOfType<mlir::StringAttr>("result_xla_shape")
1281                      .getValue()
1282                      .str()));
1283 
1284   return Status::OK();
1285 }
1286 
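// Compiles an already-lowered LMHLO module into a GpuExecutable: reconstructs
// buffer allocations from the entry function's argument attributes, emits
// thunks with IrEmitterUnnested, compiles the LLVM module to a target binary,
// and assembles the pieces into a GpuExecutable.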
1287 StatusOr<std::unique_ptr<Executable>> CompileLmhloToExecutable(
1288     GpuCompiler* compiler, mlir::ModuleOp module, std::string module_name,
1289     const HloModuleConfig& module_config,
1290     const Compiler::CompileOptions& options,
1291     absl::string_view entry_function_name, se::StreamExecutor* stream_exec,
1292     std::unique_ptr<llvm::Module> llvm_module,
1293     IrEmitterContext* ir_emitter_context) {
1294   mlir::FuncOp entry_function = mlir::cast<mlir::FuncOp>(module.lookupSymbol(
1295       llvm::StringRef(entry_function_name.data(), entry_function_name.size())));
1296 
1297   std::vector<BufferAllocation> allocations;
1298   OutputInfoMap output_info;
1299   Shape output_shape;
1300   TF_RETURN_IF_ERROR(GetMlirAllocationInfo(entry_function, &allocations,
1301                                            &output_info, &output_shape));
1302 
1303   TF_RET_CHECK(!allocations.empty());
1304 
1305   ir_emitter_context->set_allocations(allocations);
1306 
1307   TF_ASSIGN_OR_RETURN(auto ir_emitter, IrEmitterUnnested::Create(
1308                                            module_config, ir_emitter_context));
1309   TF_RETURN_IF_ERROR(ir_emitter->EmitLmhloRegion(&entry_function.body()));
1310 
1311   auto thunk_schedule =
1312       absl::make_unique<ThunkSchedule>(ir_emitter->ConsumeThunkSequence());
1313 
1314   using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
1315   TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
1316                       compiler->CompileToTargetBinary(
1317                           module_config, std::move(llvm_module), stream_exec,
1318                           options, /*debug_module=*/nullptr));
1319 
1320   GpuVersion gpu_version = compiler->GetGpuVersion(stream_exec);
1321   auto* gpu_executable = new GpuExecutable(
1322       {std::move(backend_result.first), std::move(backend_result.second),
1323        gpu_version, std::move(thunk_schedule),
1324        std::move(ir_emitter_context->constants()), std::move(output_info),
1325        module_name, output_shape, std::move(allocations)});
1326   return std::unique_ptr<Executable>(gpu_executable);
1327 }
1328 
1329 }  // namespace gpu
1330 }  // namespace xla
1331