1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h"
17
18 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
19 #include "tensorflow/compiler/xla/service/call_inliner.h"
20 #include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h"
21 #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
22 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h"
23 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h"
24 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
25 #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
26 #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
27 #include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h"
28 #include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h"
29 #include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h"
30 #include "tensorflow/compiler/xla/service/gpu/target_constants.h"
31 #include "tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.h"
32 #include "tensorflow/compiler/xla/service/gpu/triangular_solve_rewriter.h"
33 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
34 #include "tensorflow/compiler/xla/service/hlo_cse.h"
35 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
36 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
37 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
38 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
39 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
40 #include "tensorflow/core/platform/rocm_rocdl_path.h"
41
42 namespace xla {
43 namespace gpu {
44
45 namespace {
46
47 // Returns the directory containing ROCm-Device-Libs files. This function is
48 // called in AMDGPUCompiler's constructor, so can't return an error. But
49 // AMDGPUCompiler::Compile will return an error when the wanted rocdl file
50 // doesn't exist in the folder this function returns.
GetROCDLDir(const HloModuleConfig & config)51 std::string GetROCDLDir(const HloModuleConfig& config) {
52 std::vector<std::string> potential_rocdl_dirs;
53 const std::string datadir = config.debug_options().xla_gpu_cuda_data_dir();
54 if (!datadir.empty()) {
55 potential_rocdl_dirs.push_back(datadir);
56 }
57 potential_rocdl_dirs.push_back(tensorflow::RocdlRoot());
58
59 // Tries all potential ROCDL directories in the order they are inserted.
60 // Returns the first directory that exists in the file system.
61 for (const std::string& potential_rocdl_dir : potential_rocdl_dirs) {
62 if (tensorflow::Env::Default()->IsDirectory(potential_rocdl_dir).ok()) {
63 VLOG(2) << "Found ROCm-Device-Libs dir " << potential_rocdl_dir;
64 return potential_rocdl_dir;
65 }
66 VLOG(2) << "Unable to find potential ROCm-Device-Libs dir "
67 << potential_rocdl_dir;
68 }
69
70 // Last resort: maybe in the current folder.
71 return ".";
72 }
73
74 } // namespace
75
OptimizeHloConvolutionCanonicalization(HloModule * hlo_module,se::StreamExecutor * stream_exec,se::DeviceMemoryAllocator * device_allocator)76 Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
77 HloModule* hlo_module, se::StreamExecutor* stream_exec,
78 se::DeviceMemoryAllocator* device_allocator) {
79 // Convert convolutions into CustomCalls to MIOpen, then canonicalize them
80 // (PadInsertion).
81 HloPassPipeline pipeline("conv_canonicalization");
82 pipeline.AddInvariantCheckerDebug<HloVerifier>(
83 /*layout_sensitive=*/false,
84 /*allow_mixed_precision=*/false);
85 pipeline.AddPass<GpusolverRewriter>();
86 pipeline.AddPass<GpuConvRewriter>();
87 pipeline.AddPass<GpuConvPaddingLegalization>();
88
89 // The conv padding/vectorization passes which we need to get rid of. They
90 // also leave behind unnecessary tuple/get-tuple-element pairs that
91 // TupleSimplifier fixes.
92 pipeline.AddPass<CallInliner>();
93 pipeline.AddPass<TupleSimplifier>();
94
95 // tf2xla bridge, DepthwiseConvolutionConverter and GpuConvRewriter
96 // introduces reshapes and transposes that can be eliminated using
97 // AlgebraicSimplifier We run algsimp to a fixed point.
98 AlgebraicSimplifierOptions options;
99 options.set_enable_conv_operand_swap(false);
100 pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
101
102 pipeline.AddPass<HloConstantFolding>();
103 TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
104
105 return Status::OK();
106 }
107
OptimizeHloPostLayoutAssignment(HloModule * hlo_module,se::StreamExecutor * stream_exec,se::DeviceMemoryAllocator * device_allocator)108 Status AMDGPUCompiler::OptimizeHloPostLayoutAssignment(
109 HloModule* hlo_module, se::StreamExecutor* stream_exec,
110 se::DeviceMemoryAllocator* device_allocator) {
111 TF_RETURN_IF_ERROR(GpuCompiler::OptimizeHloPostLayoutAssignment(
112 hlo_module, stream_exec, device_allocator));
113
114 HloPassPipeline post_pipeline("AMDGPU post-layout_assignment");
115
116 // Transform TriangularSolve ops into custom-calls, so we can add temp
117 // memory.
118 post_pipeline.AddPass<TriangularSolveRewriter>();
119
120 TF_RETURN_IF_ERROR(post_pipeline.Run(hlo_module).status());
121
122 return Status::OK();
123 }
124
AMDGPUCompiler()125 AMDGPUCompiler::AMDGPUCompiler()
126 : GpuCompiler(stream_executor::rocm::kROCmPlatformId,
127 amdgpu::TargetTriple(), amdgpu::DataLayout()) {}
128
GetGpuVersion(se::StreamExecutor * stream_exec)129 GpuVersion AMDGPUCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
130 return stream_exec->GetDeviceDescription().rocm_compute_capability();
131 }
132
133 StatusOr<std::pair<std::string, std::vector<uint8_t>>>
CompileTargetBinary(const HloModuleConfig & module_config,llvm::Module * llvm_module,GpuVersion gpu_version,se::StreamExecutor * stream_exec,bool relocatable,const HloModule * debug_module)134 AMDGPUCompiler::CompileTargetBinary(const HloModuleConfig& module_config,
135 llvm::Module* llvm_module,
136 GpuVersion gpu_version,
137 se::StreamExecutor* stream_exec,
138 bool relocatable,
139 const HloModule* debug_module) {
140 if (rocdl_dir_.empty()) {
141 // Compute rocdl_dir_ just once and cache it in this member.
142 rocdl_dir_ = GetROCDLDir(module_config);
143 }
144
145 if (relocatable) {
146 return Unimplemented("relocatable target binary is not implemented");
147 }
148
149 std::vector<uint8_t> hsaco;
150 {
151 XLA_SCOPED_LOGGING_TIMER(
152 "AMDGPUCompiler::CompileTargetBinary - CompileToHsaco");
153 TF_ASSIGN_OR_RETURN(
154 hsaco, amdgpu::CompileToHsaco(llvm_module, gpu_version, module_config,
155 rocdl_dir_));
156 }
157
158 return std::pair<std::string, std::vector<uint8_t>>("", std::move(hsaco));
159 }
160
161 } // namespace gpu
162 } // namespace xla
163