1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
17
18 #include <stdlib.h>
19
20 #include <atomic>
21 #include <functional>
22 #include <utility>
23
24 #include "absl/memory/memory.h"
25 #include "absl/strings/numbers.h"
26 #include "absl/strings/str_cat.h"
27 #include "llvm/AsmParser/Parser.h"
28 #include "llvm/Bitcode/BitcodeReader.h"
29 #include "llvm/Bitcode/BitcodeWriter.h"
30 #include "llvm/IR/DiagnosticInfo.h"
31 #include "llvm/IR/DiagnosticPrinter.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/Module.h"
34 #include "llvm/IR/Verifier.h"
35 #include "llvm/Transforms/Utils/SplitModule.h"
36 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
37 #include "mlir/InitAllDialects.h" // from @llvm-project
38 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
39 #include "tensorflow/compiler/xla/protobuf_util.h"
40 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
41 #include "tensorflow/compiler/xla/service/all_gather_decomposer.h"
42 #include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
43 #include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
44 #include "tensorflow/compiler/xla/service/batchnorm_expander.h"
45 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
46 #include "tensorflow/compiler/xla/service/call_inliner.h"
47 #include "tensorflow/compiler/xla/service/comparison_expander.h"
48 #include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
49 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
50 #include "tensorflow/compiler/xla/service/convolution_4d_expander.h"
51 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
52 #include "tensorflow/compiler/xla/service/dump.h"
53 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
54 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
55 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
56 #include "tensorflow/compiler/xla/service/gather_expander.h"
57 #include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h"
58 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
59 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
60 #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
61 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
62 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h"
63 #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
64 #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
65 #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
66 #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
67 #include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h"
68 #include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h"
69 #include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"
70 #include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
71 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
72 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
73 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
74 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
75 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
76 #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
77 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
78 #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
79 #include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h"
80 #include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h"
81 #include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h"
82 #include "tensorflow/compiler/xla/service/gpu/reduction_splitter.h"
83 #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
84 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
85 #include "tensorflow/compiler/xla/service/gpu/target_constants.h"
86 #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
87 #include "tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.h"
88 #include "tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h"
89 #include "tensorflow/compiler/xla/service/hlo_computation.h"
90 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
91 #include "tensorflow/compiler/xla/service/hlo_cse.h"
92 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
93 #include "tensorflow/compiler/xla/service/hlo_dce.h"
94 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
95 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
96 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
97 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
98 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
99 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
100 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
101 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
102 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
103 #include "tensorflow/compiler/xla/service/logistic_expander.h"
104 #include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"
105 #include "tensorflow/compiler/xla/service/operand_upcaster.h"
106 #include "tensorflow/compiler/xla/service/qr_expander.h"
107 #include "tensorflow/compiler/xla/service/reshape_mover.h"
108 #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
109 #include "tensorflow/compiler/xla/service/rng_expander.h"
110 #include "tensorflow/compiler/xla/service/slice_sinker.h"
111 #include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
112 #include "tensorflow/compiler/xla/service/sort_simplifier.h"
113 #include "tensorflow/compiler/xla/service/stable_sort_expander.h"
114 #include "tensorflow/compiler/xla/service/transpose_folding.h"
115 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
116 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
117 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
118 #include "tensorflow/compiler/xla/service/while_loop_trip_count_annotator.h"
119 #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
120 #include "tensorflow/compiler/xla/status_macros.h"
121 #include "tensorflow/compiler/xla/types.h"
122 #include "tensorflow/compiler/xla/util.h"
123 #include "tensorflow/core/lib/core/status.h"
124 #include "tensorflow/core/lib/gtl/cleanup.h"
125 #include "tensorflow/core/lib/io/path.h"
126 #include "tensorflow/core/platform/blocking_counter.h"
127 #include "tensorflow/core/platform/env.h"
128 #include "tensorflow/core/platform/logging.h"
129 #include "tensorflow/core/platform/regexp.h"
130 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
131 #include "tensorflow/core/platform/subprocess.h"
132 #include "tensorflow/core/platform/threadpool.h"
133 #include "tensorflow/core/platform/tracing.h"
134 #include "tensorflow/core/profiler/lib/traceme.h"
135 #include "tensorflow/core/util/env_var.h"
136
137 namespace xla {
138 namespace gpu {
139
GpuCompiler::GpuCompiler(se::Platform::Id platform_id,
                         const char* target_triple, const char* data_layout)
    : platform_id_(platform_id),
      target_triple_(target_triple),
      data_layout_(data_layout),
      pointer_size_(llvm::DataLayout(data_layout)
                        .getPointerSize(0 /* default address space */)) {}
147
148 // Runs optimization passes on the given HLO module.
Status GpuCompiler::OptimizeHloModule(
    HloModule* hlo_module, se::StreamExecutor* stream_exec,
    se::DeviceMemoryAllocator* device_allocator) {
152 {
153 HloPassPipeline pipeline("optimization");
154 pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
155 /*allow_mixed_precision=*/false);
156
157 pipeline.AddPass<AllGatherDecomposer>(
158 [](const HloAllGatherInstruction& ag) {
159 return !NcclAllGatherThunk::CanImplement(&ag);
160 });
161 pipeline.AddPass<AllToAllDecomposer>();
162
163 pipeline.AddPass<OperandUpcaster>();
164
165 // Expand random number generation.
166 pipeline.AddPass<RngExpander>();
167 pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
168
169 // Comparison total order expander
170 pipeline.AddPass<ComparisonExpander>();
171
172 // Remove zero-sized HLO from the input so that other passes don't have to
173 // handle it.
174 pipeline.AddPass<ZeroSizedHloElimination>();
175
176 pipeline.AddPass<GpuScatterExpander>();
177 // TODO(phawkins): replace QR decompositions with calls to cuSOLVER.
178 pipeline.AddPass<QrExpander>();
179
180 pipeline.AddPass<DynamicIndexSplitter>();
181
182 // TODO(b/64094172): make Call work on GPU instead of inlining.
183 pipeline.AddPass<CallInliner>();
184
185 pipeline.AddPass<DotDecomposer>();
186
187 pipeline.AddPass<Convolution4DExpander>();
188
189 // Expand the sort op to support stable sorting if required.
190 pipeline.AddPass<StableSortExpander>();
191 // Convert BF16 operations to F32 operations so that the GPU backend can
192 // support BF16 operations without directly implementing a BF16 lowering for
193 // most ops.
194 pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
195
196 // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
197 // where possible. Not every batchnorm op can be implemented as a call to
198 // cudnn, so decompose any remaining batchnorm ops into a soup of HLOs.
199 if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) {
      // Since BatchNorm inference consists essentially of pointwise
      // operations, it is always advantageous to use kernel fusion rather
      // than cudnn.
202 pipeline.AddPass<BatchNormExpander>(
203 /*rewrite_training_op=*/false,
204 /*rewrite_inference_op=*/true,
205 /*rewrite_grad_op=*/false);
206 pipeline.AddPass<CudnnBatchNormRewriter>();
207 }
208 pipeline.AddPass<BatchNormExpander>(
209 /*rewrite_training_op=*/true,
210 /*rewrite_inference_op=*/true,
211 /*rewrite_grad_op=*/true);
212
213 pipeline.AddPass<LogisticExpander>(
214 /*expansion_type=*/LogisticExpansionType::kExp);
215 pipeline.AddPass<ConditionalCanonicalizer>();
216 pipeline.AddPass<DynamicPadder>();
217
218 {
219 auto& pass =
220 pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
221 pass.AddInvariantCheckerDebug<HloVerifier>(
222 /*layout_sensitive=*/false,
223 /*allow_mixed_precision=*/false);
224
225 // BatchNormExpander can create zero-sized ops, so zero-sized HLO
226 // elimination has to come after that pass.
227 pass.AddPass<ZeroSizedHloElimination>();
228
229 pass.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
230 pass.AddPass<ScatterExpander>(ScatterExpander::kEliminateSimpleScatters);
231
232 AlgebraicSimplifierOptions options;
233 // When transposes appear in a fusion node, we can easily adjust the
234 // multi-dimensional index to create the one needed for the operand. This
235 // is not as easy with bitcasts, because we don't have the information
236 // readily available which dimensions are permuted. In addition to that,
237 // if we have a transpose and a reshape next to each other, they will both
238 // be replaced by a bitcast, and we replace bitcast(bitcast) with one
239 // bitcast. This leads to having to linearize and then delinearize the
240 // index.
241 options.set_replace_transpose_with_bitcast(false);
242 options.set_enable_conv_operand_swap(false);
243 pass.AddPass<AlgebraicSimplifier>(options);
244 // AlgebraicSimplifier may add contracting dimensions to a dot.
245 pass.AddPass<DotDecomposer>();
246 pass.AddPass<SortSimplifier>();
247 pass.AddPass<TupleSimplifier>();
248 pass.AddPass<WhileLoopConstantSinking>();
249 pass.AddPass<WhileLoopSimplifier>();
250
251 // TODO(b/134075051): Re-enable after b/134075051 is fixed.
252 // pass.AddPass<SliceSinker>();
253
254 pass.AddPass<HloDCE>();
255 pass.AddPass<ReshapeMover>();
256 pass.AddPass<HloConstantFolding>();
257 pass.AddPass<ConditionalSimplifier>();
258 }
259
260 pipeline.AddPass<TransposeFolding>(
261 [](const HloInstruction& dot,
262 const TransposeFolding::OperandIndices& candidate_operands) {
263 return IsMatrixMultiplication(dot)
264 ? candidate_operands
265 : TransposeFolding::OperandIndices{};
266 });
267 pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
268 pipeline.AddPass<HloDCE>();
269
270 // Run WhileLoopTripCountAnnotator at the end of the simplification
271 // pipeline, before layout assignment and fusion. This pass does some
272 // pattern-matching on while bodies/conditions, and this is where the HLO is
273 // "nicest".
274 //
275 // It's important that we don't make semantic changes (e.g. unrolling) to
276 // any `while` loops after this point, because otherwise the trip-count
277 // annotations added by this pass may not be correct after the
278 // modifications.
279 pipeline.AddPass<WhileLoopTripCountAnnotator>();
280 TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
281 }
282
283 // Run target-specific HLO optimization passes for convolution
284 // canonicalization.
285 TF_RETURN_IF_ERROR(OptimizeHloConvolutionCanonicalization(
286 hlo_module, stream_exec, device_allocator));
287
288 {
    // Run layout assignment in a separate pipeline from
    // "post-layout-assignment" because we want everything after layout
    // assignment to have a layout-sensitive invariant checker, but
    // HloPassPipeline also runs its invariant checker before any passes are
    // run, meaning the pipeline that contains layout assignment cannot also
    // contain a layout-sensitive verifier.
295 HloPassPipeline pipeline("layout assignment");
296 // Layout assignment uses alias analysis, which requires the call graph to
297 // be flattened.
298 pipeline.AddPass<FlattenCallGraph>();
299 pipeline.AddPass<GpuLayoutAssignment>(
300 hlo_module->mutable_entry_computation_layout(),
301 LayoutAssignment::InstructionCanChangeLayout, stream_exec);
302 TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
303 }
304
305 // Run target-specific HLO optimization passes after layout assignment.
306 TF_RETURN_IF_ERROR(OptimizeHloPostLayoutAssignment(hlo_module, stream_exec,
307 device_allocator));
308
309 {
310 HloPassFix<HloPassPipeline> fusion("fusion");
311 // We try to split variadic ops with many parameters into several such ops
312 // to avoid exceeding the parameter space.
313 fusion.AddPass<VariadicOpSplitter>();
314 /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
315 * fixing the ticket. */
316 fusion.AddInvariantCheckerDebug<HloVerifier>(
317 /*layout_sensitive=*/true,
318 /*allow_mixed_precision=*/false,
319 LayoutAssignment::InstructionCanChangeLayout);
320 fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
321 fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
322 fusion.AddPass<FusionMerger>();
323 fusion.AddPass<GpuMultiOutputFusion>();
324 fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
325 /*only_fusion_computations=*/true);
326 fusion.AddPass<HloDCE>();
327 TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
328
329 HloPassPipeline horizontal_fusion("horizontal_fusion");
330 horizontal_fusion.AddPass<GpuHorizontalLoopFusion>();
331 horizontal_fusion.AddPass<GpuHorizontalInputFusion>();
332 horizontal_fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
333 /*only_fusion_computations=*/true);
334 horizontal_fusion.AddPass<HloDCE>();
335 TF_RETURN_IF_ERROR(horizontal_fusion.Run(hlo_module).status());
336 }
337
338 {
339 HloPassPipeline pipeline("all_reduce_combiner");
340 pipeline.AddPass<AllReduceCombiner>(
341 /*combine_threshold_in_bytes=*/30 * 1024 * 1024,
342 /*combine_threshold_count=*/256);
343 TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
344 }
345 {
    // Now we allow replacing any transposes outside of fusions with
    // bitcasts.
347 HloPassPipeline pipeline("final_algebraic_simplifier");
348 AlgebraicSimplifierOptions options;
349 options.set_is_layout_sensitive(true);
350 options.set_enable_conv_operand_swap(false);
351 pipeline.AddPass<AlgebraicSimplifier>(options);
352 TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
353 }
354 return Status::OK();
355 }
356
// Modifies the given HLO module so that it will be accepted by IrEmitter.
// Unlike optimization passes, these passes are necessary for correctness.
Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
360 // In some cases, we have to place the result of an instruction in a temporary
361 // buffer. For instance, the buffer that holds an external parameter is
362 // assumed immutable at this point, and should not be reused for output
363 // (b/27180329). Therefore, in that case, we set the output to be a copy of
364 // the parameter.
365 HloPassPipeline pipeline("GPU-ir-emit-prepare");
366 /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
367 * fixing the ticket. */
368 pipeline.AddInvariantCheckerDebug<HloVerifier>(
369 /*layout_sensitive=*/true,
370 /*allow_mixed_precision=*/false,
371 LayoutAssignment::InstructionCanChangeLayout);
372
  // Copy insertion should be performed immediately before IR emission to
  // avoid inserting unnecessary copies (a later pass adds an instruction
  // which materializes the value) or missing a necessary copy (a later pass
  // removes an instruction which materializes a value). DCE must be run
  // immediately before (and sometimes after) copy insertion, to keep dead
  // code from interfering with the rewrites.
379 pipeline.AddPass<HloDCE>();
380 if (hlo_module->config().alias_passthrough_params()) {
381 pipeline.AddPass<AliasPassthroughParams>();
382 }
383 pipeline.AddPass<LoopScheduleLinearizer>(GetCanShareBuffer());
384 pipeline.AddPass<GpuCopyInsertion>(GetCanShareBuffer());
385 pipeline.AddPass<GpuSanitizeConstantNames>();
386 return pipeline.Run(hlo_module).status();
387 }
388
// TODO(cheshire): Duplication with gpu_conv_algorithm_picker; figure out the
// right way to share this.
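// Returns true if the TF_DETERMINISTIC_OPS environment variable requests
// deterministic execution; the value is read once and cached.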
static bool RequireDeterminism() {
392 static bool require_determinism = [] {
393 bool deterministic_ops = false;
394 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
395 /*default_val=*/false,
396 &deterministic_ops));
397 return deterministic_ops;
398 }();
399 return require_determinism;
400 }
401
Status GpuCompiler::OptimizeHloPostLayoutAssignment(
    HloModule* hlo_module, se::StreamExecutor* stream_exec,
    se::DeviceMemoryAllocator* device_allocator) {
405 HloPassPipeline pipeline("post-layout_assignment");
406 /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
407 * fixing the ticket. */
408 pipeline.AddInvariantCheckerDebug<HloVerifier>(
409 /*layout_sensitive=*/true,
410 /*allow_mixed_precision=*/false,
411 LayoutAssignment::InstructionCanChangeLayout);
412
413 pipeline.AddPass<ReductionDegenerateDimRemover>();
414 pipeline.AddPass<ReductionLayoutNormalizer>();
415 pipeline.AddPass<ReductionDimensionGrouper>();
416 pipeline.AddPass<HloPassFix<ReductionSplitter>>();
417
418 // The LayoutAssignment pass may leave behind kCopy instructions which are
419 // duplicate or NOPs, so remove them with algebraic simplification and CSE.
420 AlgebraicSimplifierOptions options;
421 options.set_is_layout_sensitive(true);
422 // When transposes appear in a fusion node, we can easily adjust the
423 // multi-dimensional index to create the one needed for the operand. This
424 // is not as easy with bitcasts, because we don't have the information
425 // readily available which dimensions are permuted. In addition to that,
426 // if we have a transpose and a reshape next to each other, they will both
427 // be replaced by a bitcast, and we replace bitcast(bitcast) with one
428 // bitcast. This leads to having to linearize and then delinearize the
429 // index.
430 options.set_replace_transpose_with_bitcast(false);
431 options.set_enable_conv_operand_swap(false);
432 pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
433
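  // When determinism is requested, rewrite large reductions into a tree of
  // smaller reductions so the result does not depend on the order of
  // atomic updates.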
434 if (RequireDeterminism() ||
435 hlo_module->config().debug_options().xla_gpu_deterministic_reductions() ||
436 hlo_module->config().debug_options().xla_gpu_deterministic_ops()) {
437 pipeline.AddPass<HloPassFix<GpuTreeReductionRewriter>>();
438 }
439
440 // GemmRewriter assumes that all transposes are folded into gemms, but,
441 // since commit 7d529df, this is not always true at this point.
442 // Therefore, rerun transpose folding.
443 pipeline.AddPass<TransposeFolding>(
444 [](const HloInstruction& dot,
445 const TransposeFolding::OperandIndices& candidate_operands) {
446 return IsMatrixMultiplication(dot) ? candidate_operands
447 : TransposeFolding::OperandIndices{};
448 },
449 TransposeFolding::NeverFoldTranspose);
450 // Rewrite GEMMs into custom calls.
451 pipeline.AddPass<GemmRewriter>();
452
  // Choose the fastest algorithm for each conv.
  //
  // We pick the algorithm before fusion so we can generate better HLO. After
  // GpuConvRewriter, our convolutions are CustomCalls which return a
  // tuple (conv_result, scratch_memory), and each conv uses 0 bytes of
  // scratch:
  //
  //   customcall = (f32[...], f32[0])
  //   return gte(customcall, 0)
  //
  // The algorithm picker then chooses the best algorithm, and potentially
  // increases the scratch space. It replaces customcall with new_tuple,
  // giving us the following:
  //
  //   new_customcall = (f32[...], f32[N])
  //   new_tuple = tuple(gte(new_customcall, 0), constant f32[0])
  //   return gte(new_tuple, 0)
  //
  // The new tuple and gte instructions can then be simplified away, because
  // nobody is expected to use the scratch value.
  //
  // However, if we were to run GpuConvAlgorithmPicker after fusion, the
  // gte(customcall, 0) would probably already be inside a fusion node. We
  // can't simplify across HloComputation boundaries, so in this case we
  // wouldn't be able to simplify away the new_tuple bits.
478 pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator);
479
480 // Clean up new_tuple described above.
481 pipeline.AddPass<TupleSimplifier>();
482
483 pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
484 TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
485
486 return Status::OK();
487 }
488
StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
    const CompileOptions& options) {
492 // We dump the post-optimization HLO in RunBackend so no need to dump it here.
493 XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
494 tensorflow::profiler::TraceMe activity(
495 [&] { return absl::StrCat("HLO Transforms:", module->name()); },
496 tensorflow::profiler::TraceMeLevel::kInfo);
497 TF_RETURN_IF_ERROR(
498 OptimizeHloModule(module.get(), stream_exec, options.device_allocator));
499
500 TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));
501
502 return std::move(module);
503 }
504
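// A CanShareBuffer callback that never gives a definitive answer, deferring
// entirely to the default buffer-sharing analysis.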
static absl::optional<bool> DummyCanShareBufferFunction(const HloInstruction*,
                                                        const HloInstruction*,
                                                        const ShapeIndex&) {
  return absl::nullopt;
}
510
511 StatusOr<
512 std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
GpuCompiler::RunHloPassesAndBufferAssignement(
    std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* executor,
    bool optimize, const CompileOptions& options) {
516 if (optimize) {
517 TF_ASSIGN_OR_RETURN(hlo_module,
518 RunHloPasses(std::move(hlo_module), executor, options));
519 }
520
521 std::unique_ptr<StreamAssignment> stream_assignment =
522 AssignStreams(*hlo_module);
523 TF_ASSIGN_OR_RETURN(
524 std::unique_ptr<GpuHloSchedule> hlo_schedule,
525 GpuHloSchedule::Build(*hlo_module, *stream_assignment, pointer_size_));
526
527 auto buffer_size_bytes_function =
528 [this](const BufferValue& buffer_value) -> int64 {
529 return GpuCompiler::GetSizeOfShape(buffer_value.shape(), pointer_size_);
530 };
531
532 TF_ASSIGN_OR_RETURN(
533 std::unique_ptr<BufferAssignment> assignment,
534 BufferAssigner::Run(
535 hlo_module.get(), hlo_schedule->ConsumeHloOrdering(),
536 buffer_size_bytes_function,
537 /*color_alignment=*/
538 [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
539 /*allocate_buffers_for_constants=*/true,
540 /*colorer=*/BufferAssigner::DefaultColorer(),
541 /*must_not_live_out=*/{}, DummyCanShareBufferFunction));
542
543 return std::make_tuple(std::move(hlo_module), std::move(assignment));
544 }
545
546 // The order of `thunk_sequence` corresponds to
547 // `hlo_schedule->ThunkLaunchOrder()`.
static Status CompileModuleToLlvmIrImpl(
    HloModule* hlo_module, llvm::LLVMContext* llvm_context,
    const std::string& target_triple, const std::string& data_layout,
    const std::string& platform_name, GpuDeviceInfo gpu_device_info,
    absl::optional<CudaComputeCapability> cuda_compute_capability,
    const HloDataflowAnalysis::CanShareBuffer& can_share_buffer_function,
    int pointer_size, const HloProfileIndexMap* profile_index_map,
    std::unique_ptr<llvm::Module>* llvm_module,
    std::unique_ptr<BufferAssignment>* buffer_assignment,
    std::unique_ptr<ThunkSchedule>* thunk_schedule,
    std::vector<GpuExecutable::ConstantInfo>* constants) {
559 *llvm_module = absl::make_unique<llvm::Module>("", *llvm_context);
560
561 (*llvm_module)->setTargetTriple(target_triple);
562 (*llvm_module)->setDataLayout(data_layout);
563
564 std::unique_ptr<StreamAssignment> stream_assignment =
565 AssignStreams(*hlo_module);
566 TF_ASSIGN_OR_RETURN(
567 std::unique_ptr<GpuHloSchedule> hlo_schedule,
568 GpuHloSchedule::Build(*hlo_module, *stream_assignment, pointer_size));
569
570 auto buffer_size_bytes_function =
571 [pointer_size](const BufferValue& buffer_value) -> int64 {
572 return GpuCompiler::GetSizeOfShape(buffer_value.shape(), pointer_size);
573 };
574
575 TF_ASSIGN_OR_RETURN(
576 *buffer_assignment,
577 BufferAssigner::Run(
578 hlo_module, hlo_schedule->ConsumeHloOrdering(),
579 buffer_size_bytes_function,
580 /*color_alignment=*/
581 [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
582 /*allocate_buffers_for_constants=*/true,
583 /*colorer=*/BufferAssigner::DefaultColorer(),
584 /*must_not_live_out=*/{}, can_share_buffer_function));
585
586 VLOG(1) << "Buffer Assignment Stats "
587 << (*buffer_assignment)->GetStats().ToString();
588 DumpHloModuleIfEnabled(*hlo_module, **buffer_assignment,
589 "after_optimizations");
590
591 mlir::MLIRContext mlir_context;
592 mlir_context.loadDialect<mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect,
593 mlir::StandardOpsDialect,
594 mlir::lmhlo_gpu::LmhloGpuDialect>();
595
596 IrEmitterContext ir_emitter_context(
597 hlo_module, buffer_assignment->get(), platform_name, gpu_device_info,
598 cuda_compute_capability, profile_index_map, &mlir_context,
599 llvm_module->get());
600
601 HloComputation* entry_computation = hlo_module->entry_computation();
602
603 TF_ASSIGN_OR_RETURN(
604 auto ir_emitter,
605 IrEmitterUnnested::Create(hlo_module->config(), entry_computation,
606 &ir_emitter_context));
607
608 {
609 XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
610
611 absl::flat_hash_map<const Thunk*, const HloInstruction*> thunk_to_hlo;
612 ThunkSequence thunk_sequence;
613 absl::Span<HloInstruction* const> order = hlo_schedule->ThunkLaunchOrder();
614 for (HloInstruction* instruction : order) {
615 TF_RETURN_IF_ERROR(instruction->Visit(ir_emitter.get()));
616 TF_RETURN_IF_ERROR(ir_emitter->Postprocess(instruction));
617 std::unique_ptr<ThunkSequence> thunks =
618 ir_emitter->ConsumeThunkSequence();
619
620 // The invariants between each input HloInstruction* and output Thunk* are
621 // not all explicitly checked, but at least we can document them here:
622 // * The entry HloComputation shall not have dead code (all reachable from
623 // ROOT).
624 // * The visited instructions are all instructions in the entry
625 // computation.
626 // * For each visit of these HloInstructions, either none or one Thunk
627 // will be returned.
628 // * If there is a thunk returned, thunk->hlo_instruction_ equals the
629 // input HloInstruction*.
630 // * A returned thunk may contain other sub-thunks. A sub-thunk may or may
631 // not have an associated hlo_instruction_.
632 TF_RET_CHECK(thunks->size() <= 1) << instruction->ToString();
633 if (!thunks->empty()) {
634 auto thunk = std::move(thunks->front());
635 InsertOrDie(&thunk_to_hlo, thunk.get(), instruction);
636 thunk_sequence.push_back(std::move(thunk));
637 }
638 }
639 // TODO(timshen): ThunkSchedule taking thunk_to_hlo is a bit awkward. To fix
640 // that, we can turn it into a proper pass, from:
641 // map<Thunk, HloInstruction> -> (ThunkSchedule, [Thunk...])
642 // to:
643 // map<Thunk, HloInstruction> -> GenerateMultiStreamDepInfo() -> [(Thunk,
644 // DepInfo)...]
645 //
646 // where "DepInfo" is
647 // struct {
648 // int stream_number;
649 // std::vector<Thunk*> dependencies;
650 // std::vector<Thunk*> users;
651 // };
652 // We might want to do this after MLIR migration.
653 *thunk_schedule = absl::make_unique<ThunkSchedule>(
654 std::make_unique<ThunkSequence>(std::move(thunk_sequence)),
655 std::move(stream_assignment), std::move(thunk_to_hlo));
656
657 if (constants) {
658 *constants = std::move(ir_emitter_context.constants());
659 }
660 }
661
662 return Status::OK();
663 }
664
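// Swallows LLVM backend diagnostics: they are rendered to a string and only
// surfaced at VLOG level 1 rather than treated as errors.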
static void NullDiagnosticHandler(const llvm::DiagnosticInfo& diag_info,
                                  void* context) {
667 std::string error_string;
668 llvm::raw_string_ostream string_printer(error_string);
669 llvm::DiagnosticPrinterRawOStream diagnostic_printer(string_printer);
670 diag_info.print(diagnostic_printer);
671
672 VLOG(1) << error_string;
673 }
674
675 StatusOr<std::pair<std::string, std::vector<uint8>>>
GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
                                   std::unique_ptr<llvm::Module> llvm_module,
                                   se::StreamExecutor* stream_exec,
                                   const CompileOptions& options,
                                   const HloModule* debug_module) {
681 using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
682
683 const auto compile_single_module =
684 [this, stream_exec, &module_config, debug_module](
685 llvm::Module* llvm_module, bool relocatable,
686 absl::optional<int> shard_number) -> StatusOr<BackendCompileResult> {
687 {
688 XLA_SCOPED_LOGGING_TIMER(
689 "GpuCompiler::RunBackend - Running LLVM verifier");
690
691 llvm_module->getContext().setDiagnosticHandlerCallBack(
692 NullDiagnosticHandler, nullptr);
693
694 std::string err;
695 llvm::raw_string_ostream err_stream(err);
696
697 // verifyModule() returns true if the module is broken.
698 TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
699 << "Invalid LLVM IR before optimizations:\n"
700 << err_stream.str()
701 << "\nThis probably indicates a bug in the HLO -> LLVM IR "
702 "lowering. Rerun with --xla_dump_to to get the IR"
          << (debug_module
                  ? absl::StrCat(" and look for files with name containing: *",
                                 FilenameFor(*debug_module, "", ""), "*")
706 : ".");
707 }
708 GpuVersion gpu_version = GetGpuVersion(stream_exec);
709 StatusOr<std::pair<std::string, std::vector<uint8>>> result =
710 CompileTargetBinary(module_config, llvm_module, gpu_version,
711 stream_exec, relocatable, debug_module);
712
713 if (!result.ok()) {
714 return result;
715 }
716
717 const bool should_dump =
718 DumpingEnabledForHloModule(debug_module ? debug_module->name() : "",
719 module_config.debug_options());
720
721 if (should_dump) {
722 if (debug_module) {
723 if (shard_number.has_value()) {
724 llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
725 /*optimized=*/true,
726 std::to_string(*shard_number));
727 } else {
728 llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
729 /*optimized=*/true);
730 }
731 } else {
732 LOG(ERROR)
733 << "Dumping is not implemented since the file name cannot be "
734 "inferred. Please implement (potentially MLIR) module -> "
735 "filename heuristic.";
736 }
737 }
738
739 if (user_post_optimization_hook_) {
740 user_post_optimization_hook_(*llvm_module);
741 }
742
743 // Write PTX to IR dump directory, if IR dumping was requested.
744 if (should_dump) {
745 absl::string_view ptx = result->first;
746 if (debug_module) {
747 if (shard_number.has_value()) {
748 DumpToFileInDirOrStdout(*debug_module, "",
749 std::to_string(*shard_number) + ".ptx", ptx);
750 } else {
751 DumpToFileInDirOrStdout(*debug_module, "", "ptx", ptx);
752 }
753 } else {
754 LOG(ERROR)
755 << "Dumping is not implemented since the file name cannot be "
756 "inferred. Please implement (potentially MLIR) module -> "
757 "filename heuristic.";
758 }
759 }
760
761 return result;
762 };
763
764 tensorflow::thread::ThreadPool* thread_pool = options.thread_pool;
765
766 absl::optional<tensorflow::thread::ThreadPool> overriding_thread_pool;
767 if (module_config.debug_options().xla_gpu_force_compilation_parallelism() !=
768 0) {
769 overriding_thread_pool.emplace(
770 tensorflow::Env::Default(), "",
771 module_config.debug_options().xla_gpu_force_compilation_parallelism());
772 thread_pool = &*overriding_thread_pool;
773 }
774
775 if (!thread_pool) {
776 return compile_single_module(llvm_module.get(), /*relocatable=*/false,
777 /*shard_number=*/absl::nullopt);
778 }
779
780 // Test whether LinkModules is supported.
781 if (this->LinkModules(stream_exec, {}).status().code() ==
782 tensorflow::error::Code::UNIMPLEMENTED) {
783 return compile_single_module(llvm_module.get(), /*relocatable=*/false,
784 /*shard_number=*/absl::nullopt);
785 }
786
787 std::vector<std::unique_ptr<llvm::Module>> llvm_modules;
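  // Count the externally visible function definitions; SplitModule cannot
  // produce more useful shards than there are such functions.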
788 int num_functions = 0;
789 for (llvm::Function& func : llvm_module->functions()) {
790 if (!func.isDeclaration() &&
791 func.getLinkage() == llvm::GlobalValue::LinkageTypes::ExternalLinkage) {
792 num_functions++;
793 }
794 }
795
796 llvm::SplitModule(
797 *llvm_module.get(),
798 std::max<unsigned>(
799 1, std::min<unsigned>(thread_pool->NumThreads(), num_functions)),
800 [&](std::unique_ptr<llvm::Module> module) {
801 llvm_modules.push_back(std::move(module));
802 },
803 /*PreserveLocals=*/true);
804
805 std::vector<StatusOr<BackendCompileResult>> compile_results(
806 llvm_modules.size());
807 tensorflow::BlockingCounter counter(llvm_modules.size());
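  // Compile each shard on the thread pool; the blocking counter lets us wait
  // until every shard has finished.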
808 for (int i = 0; i < llvm_modules.size(); i++) {
809 thread_pool->Schedule(
810 [&compile_results, compile_single_module, i, &llvm_modules, &counter] {
811 llvm::Module* original_module = llvm_modules[i].get();
812 llvm::LLVMContext context;
813 std::string buffer;
814 llvm::raw_string_ostream error(buffer);
815
816 std::unique_ptr<llvm::Module> new_llvm_module;
817 // Switch to a new context by dumping and re-parsing LLVM IR. Each
818 // thread has its own context to avoid race conditions.
819 {
820 std::string ir;
821 {
822 llvm::raw_string_ostream os(ir);
823 original_module->print(os, nullptr);
824 }
825 llvm::SMDiagnostic err;
826 new_llvm_module = llvm::parseAssemblyString(ir, err, context);
827 }
828
829 compile_results[i] = compile_single_module(
830 new_llvm_module.get(), /*relocatable=*/true, /*shard_number=*/i);
831 counter.DecrementCount();
832 });
833 }
834 counter.Wait();
835
836 std::string ptx_snippets;
837 std::vector<std::vector<uint8>> submodule_compile_results;
838 for (auto& maybe_result : compile_results) {
839 TF_ASSIGN_OR_RETURN(auto result, maybe_result);
840 if (result.second.empty()) {
841 continue;
842 }
843 ptx_snippets += result.first;
844 ptx_snippets += "\n";
845 submodule_compile_results.push_back(result.second);
846 }
847
848 TF_ASSIGN_OR_RETURN(
849 std::vector<uint8> backend_result,
850 this->LinkModules(stream_exec, std::move(submodule_compile_results)));
851
852 return std::make_pair(ptx_snippets, backend_result);
853 }
854
StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
    const CompileOptions& options) {
858 XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend");
859 std::string slow_compilation_msg =
860 absl::StrCat("Compiling module ", module->name());
861 auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg);
862
863 TF_RET_CHECK(stream_exec != nullptr);
864
865 llvm::LLVMContext llvm_context;
866
867 GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);
868
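  // A CUDA compute capability with major version -1 is treated as "not
  // available" (e.g. on non-CUDA platforms) and mapped to nullopt.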
869 absl::optional<CudaComputeCapability> cuda_compute_capability =
870 [&]() -> absl::optional<CudaComputeCapability> {
871 CudaComputeCapability cuda_compute_capability;
872 stream_exec->GetDeviceDescription().cuda_compute_capability(
873 &cuda_compute_capability.cc_major, &cuda_compute_capability.cc_minor);
874 if (cuda_compute_capability.cc_major == -1) {
875 return absl::nullopt;
876 }
877 return cuda_compute_capability;
878 }();
879
880 std::unique_ptr<HloProfileIndexMap> profile_index_map;
881 std::unique_ptr<HloProfilePrinterData> profile_printer;
882
883 if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) {
884 HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
885 cost_analysis.set_bytes_per_second(
886 stream_exec->GetDeviceDescription().memory_bandwidth());
887 TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
888 VLOG(1) << "HLO memory read+written: "
889 << tensorflow::strings::HumanReadableNumBytes(
890 cost_analysis.bytes_accessed());
891 if (module->config().hlo_profiling_enabled()) {
892 profile_index_map = absl::make_unique<HloProfileIndexMap>(*module);
893 profile_printer =
894 CreateHloProfilePrinterData(*profile_index_map, cost_analysis,
895 module->entry_computation()->name());
896 }
897 }
898
899 std::unique_ptr<llvm::Module> llvm_module;
900 std::unique_ptr<BufferAssignment> buffer_assignment;
901 std::unique_ptr<ThunkSchedule> thunk_schedule;
902 std::vector<GpuExecutable::ConstantInfo> constants;
903
904 TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
905 module.get(), &llvm_context, target_triple_, data_layout_,
906 stream_exec->platform()->Name(), gpu_device_info, cuda_compute_capability,
907 GetCanShareBuffer(), pointer_size_, profile_index_map.get(), &llvm_module,
908 &buffer_assignment, &thunk_schedule, &constants));
909
910 if (user_pre_optimization_hook_) {
911 user_pre_optimization_hook_(*llvm_module);
912 }
913 string ir_module_string_before_opt;
914 const bool embed_ir_in_executable =
915 module->config().debug_options().xla_embed_ir_in_executable();
916 if (embed_ir_in_executable) {
917 ir_module_string_before_opt = llvm_ir::DumpModuleToString(*llvm_module);
918 }
919
920 llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/false);
921
922 using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
923 TF_ASSIGN_OR_RETURN(
924 BackendCompileResult backend_result,
925 CompileToTargetBinary(module->config(), std::move(llvm_module),
926 stream_exec, options, module.get()));
927 if (DumpingEnabledForHloModule(*module)) {
928 DumpToFileInDirOrStdout(*module, "", "thunk_schedule",
929 thunk_schedule->ToString());
930 }
931
932 using OutputInfoMap =
933 absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>;
934 TF_ASSIGN_OR_RETURN(OutputInfoMap output_info,
935 GetOutputInfo(*module, *buffer_assignment));
936 auto buffer_assignment_proto =
937 std::make_unique<BufferAssignmentProto>(buffer_assignment->ToProto());
938 std::vector<BufferAllocation> allocations =
939 buffer_assignment->ReleaseAllocations();
940 std::string module_name = module->name();
941 Shape output_shape = module->entry_computation()->root_instruction()->shape();
942 size_t profile_index = 0;
943 if (profile_index_map) {
944 profile_index =
945 profile_index_map->GetProfileIndexFor(*module->entry_computation());
946 }
947
948 GpuVersion gpu_version = GetGpuVersion(stream_exec);
949 auto* gpu_executable = new GpuExecutable(
950 {std::move(backend_result.first), std::move(backend_result.second),
951 gpu_version, std::move(thunk_schedule), std::move(constants),
952 std::move(output_info), module_name, output_shape,
953 std::move(allocations), std::move(buffer_assignment_proto),
954 std::move(module), profile_index, std::move(profile_printer),
955 std::move(profile_index_map)});
956 if (embed_ir_in_executable) {
957 DCHECK_NE("", ir_module_string_before_opt);
958 gpu_executable->set_ir_module_string(ir_module_string_before_opt);
959 }
960 return std::unique_ptr<Executable>(gpu_executable);
961 }
962
GpuDeviceInfo GetGpuDeviceInfo(se::StreamExecutor* stream_exec) {
964 GpuDeviceInfo gpu_device_info;
965 gpu_device_info.threads_per_block_limit =
966 stream_exec->GetDeviceDescription().threads_per_block_limit();
967 gpu_device_info.threads_per_warp =
968 stream_exec->GetDeviceDescription().threads_per_warp();
969 gpu_device_info.shared_memory_per_block =
970 stream_exec->GetDeviceDescription().shared_memory_per_block();
971 gpu_device_info.threads_per_core_limit =
972 stream_exec->GetDeviceDescription().threads_per_core_limit();
973 gpu_device_info.core_count = stream_exec->GetDeviceDescription().core_count();
974 return gpu_device_info;
975 }
976
977 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                                const AotCompilationOptions& options) {
980 return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime");
981 }
982
StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
    HloModule* hlo_module, llvm::LLVMContext* llvm_context,
    const std::string& target_triple, const std::string& data_layout,
    const std::string& platform_name, GpuDeviceInfo gpu_device_info,
    absl::optional<CudaComputeCapability> cuda_compute_capability,
    int pointer_size) {
989 std::unique_ptr<llvm::Module> llvm_module;
990 std::unique_ptr<BufferAssignment> buffer_assignment;
991 std::unique_ptr<ThunkSchedule> thunk_schedule;
992
993 TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
994 hlo_module, llvm_context, target_triple, data_layout, platform_name,
995 gpu_device_info, cuda_compute_capability, DummyCanShareBufferFunction,
996 pointer_size, /*profile_index_map=*/nullptr, &llvm_module,
997 &buffer_assignment, &thunk_schedule, nullptr));
998 return llvm_module;
999 }
1000
// Analyze the function signature to reconstruct a vector of BufferAllocation
// objects, as well as other output information.
//
// This function also serves as a half-baked verifier for function arg
// attributes, since a full verifier doesn't exist yet.
static Status GetMlirAllocationInfo(
    mlir::FuncOp func, std::vector<BufferAllocation>* allocations,
    absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>* output_info,
    Shape* output_shape) {
1010 std::vector<absl::optional<BufferAllocation>> maybe_allocations;
1011
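  // Every argument must carry an "lmhlo.alloc" attribute naming its
  // allocation index; the allocation size is derived from the argument's
  // shaped type.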
1012 for (int i = 0; i < func.getNumArguments(); i++) {
1013 auto allocation_index_attr =
1014 func.getArgAttr(i, "lmhlo.alloc").dyn_cast_or_null<mlir::IntegerAttr>();
1015 TF_RET_CHECK(allocation_index_attr);
1016 int index = allocation_index_attr.getInt();
1017 if (index >= maybe_allocations.size()) {
1018 maybe_allocations.resize(index + 1);
1019 }
1020 mlir::BlockArgument arg = func.getArgument(i);
1021 TF_RET_CHECK(arg.getType().isa<mlir::ShapedType>());
1022 size_t size = arg.getType().cast<mlir::ShapedType>().getSizeInBits() / 8;
1023 maybe_allocations[index].emplace(index, size, 0);
1024 }
1025
1026 allocations->reserve(maybe_allocations.size());
1027 for (auto& maybe_alloc : maybe_allocations) {
1028 if (maybe_alloc.has_value()) {
1029 allocations->push_back(*maybe_alloc);
1030 } else {
1031 return InvalidArgument("Allocation indices should range in [0, n)");
1032 }
1033 }
1034
1035 for (int i = 0; i < func.getNumArguments(); i++) {
1036 for (const mlir::NamedAttribute& attr : func.getArgAttrs(i)) {
1037 TF_RET_CHECK(attr.first == "lmhlo.alloc" ||
1038 attr.first == "lmhlo.params" ||
1039 attr.first == "lmhlo.output_index");
1040 }
1041 }
1042
1043 std::vector<Shape> output_shapes;
1044 absl::optional<int> rank;
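  // Walk the arguments again to reconstruct entry parameter aliasing, the
  // output ShapeIndex of each live-out allocation, and the output shape.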
1045 for (int i = 0; i < func.getNumArguments(); i++) {
1046 auto index =
1047 func.getArgAttr(i, "lmhlo.alloc").cast<mlir::IntegerAttr>().getInt();
1048 if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
1049 allocations->at(index).set_entry_computation_parameter(
1050 param_attr.cast<mlir::IntegerAttr>().getInt(), {},
1051 static_cast<bool>(func.getArgAttr(i, "lmhlo.output_index")));
1052 }
1053 if (auto output_index_attr = func.getArgAttr(i, "lmhlo.output_index")) {
1054 allocations->at(index).set_maybe_live_out(true);
1055
1056 // Reconstruct a shape index from output_index.
1057 ShapeIndex shape_index;
1058 for (const llvm::APInt& i :
1059 output_index_attr.cast<mlir::DenseIntElementsAttr>()) {
1060 shape_index.push_back(i.getSExtValue());
1061 }
1062 if (rank.has_value()) {
1063 if (*rank != shape_index.size()) {
1064 return InvalidArgument("Expect output_index to have the same ranks");
1065 }
1066 } else {
1067 rank.emplace(shape_index.size());
1068 }
1069 auto& o = (*output_info)[shape_index];
1070 o.allocation_index = index;
1071 if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
1072 o.alias_config.emplace(param_attr.cast<mlir::IntegerAttr>().getInt(),
1073 ShapeIndex{});
1074 }
1075
1076 if (shape_index.size() > 1) {
1077 return Unimplemented("Expect array type or 1-level tuple type");
1078 }
1079
1080 mlir::BlockArgument arg = func.getArgument(i);
1081 if (shape_index.empty()) {
1082 output_shapes.push_back(TypeToShape(arg.getType()));
1083 } else {
1084 if (shape_index[0] >= output_shapes.size()) {
1085 output_shapes.resize(shape_index[0] + 1);
1086 }
1087 output_shapes[shape_index[0]] = TypeToShape(arg.getType());
1088 }
1089 }
1090 }
1091 *output_shape = ShapeUtil::MakeTupleShape(output_shapes);
1092
1093 return Status::OK();
1094 }
1095
StatusOr<std::unique_ptr<Executable>> CompileLmhloToExecutable(
    GpuCompiler* compiler, mlir::ModuleOp module, std::string module_name,
    const HloModuleConfig& module_config,
    const Compiler::CompileOptions& options,
    absl::string_view entry_function_name, se::StreamExecutor* stream_exec,
    std::unique_ptr<llvm::Module> llvm_module,
    IrEmitterContext* ir_emitter_context) {
1103 mlir::FuncOp entry_function = mlir::cast<mlir::FuncOp>(module.lookupSymbol(
1104 llvm::StringRef(entry_function_name.data(), entry_function_name.size())));
1105
1106 std::vector<BufferAllocation> allocations;
1107 absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo> output_info;
1108 Shape output_shape;
1109 absl::flat_hash_map<ShapeIndex, int> output_to_argnum_map;
1110 TF_RETURN_IF_ERROR(GetMlirAllocationInfo(entry_function, &allocations,
1111 &output_info, &output_shape));
1112
1113 CHECK(!allocations.empty());
1114
1115 ir_emitter_context->set_allocations(allocations);
1116
1117 TF_ASSIGN_OR_RETURN(
1118 auto ir_emitter,
1119 IrEmitterUnnested::Create(module_config, /*hlo_computation=*/nullptr,
1120 ir_emitter_context));
1121 ThunkSequence thunk_sequence;
1122 for (mlir::Operation& op :
1123 entry_function.getBody().front().without_terminator()) {
1124 MlirEmitterInput input;
1125 input.op = &op;
1126 TF_RETURN_IF_ERROR(ir_emitter->EmitOp(input));
1127 std::unique_ptr<ThunkSequence> thunks = ir_emitter->ConsumeThunkSequence();
1128 TF_RET_CHECK(thunks->size() <= 1);
1129 if (!thunks->empty()) {
1130 auto thunk = std::move(thunks->front());
1131 thunk_sequence.push_back(std::move(thunk));
1132 }
1133 }
1134 auto thunk_schedule = absl::make_unique<ThunkSchedule>(
1135 std::make_unique<ThunkSequence>(std::move(thunk_sequence)));
1136
1137 using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
1138 TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
1139 compiler->CompileToTargetBinary(
1140 module_config, std::move(llvm_module), stream_exec,
1141 options, /*debug_module=*/nullptr));
1142
1143 GpuVersion gpu_version = compiler->GetGpuVersion(stream_exec);
1144 auto* gpu_executable = new GpuExecutable(
1145 {std::move(backend_result.first), std::move(backend_result.second),
1146 gpu_version, std::move(thunk_schedule),
1147 std::move(ir_emitter_context->constants()), std::move(output_info),
1148 module_name, output_shape, std::move(allocations)});
1149 return std::unique_ptr<Executable>(gpu_executable);
1150 }
1151
1152 } // namespace gpu
1153 } // namespace xla
1154