/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Provides helper routines for obtaining GPU target information useful for
// LLVM IR construction.

#include "tensorflow/compiler/xla/service/gpu/target_util.h"

#include "absl/strings/str_cat.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/MDBuilder.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace gpu {
namespace {
// Utility functions to obtain NVPTX/AMDGPU specific information.
using absl::StrCat;

// Wrapper structure for carrying llvm intrinsic ids for NVPTX/AMDGPU
// platforms. On AMDGPU, some of these operations are provided as device
// functions rather than intrinsics, so a variant type is used to wrap either
// the intrinsic id or a lambda that calls the device function.
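// For example, kBlockDimx below pairs the NVPTX intrinsic
// llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x with a lambda that emits a call
// to the AMDGPU device function __ockl_get_local_size.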
struct TargetIntrinsics {
  llvm::Intrinsic::ID nvptx_intrinsic;
  absl::variant<llvm::Intrinsic::ID,
                std::function<llvm::CallInst*(llvm::IRBuilder<>*)>>
      amdgpu_intrinsic_or_function;
};

// Gets the llvm intrinsic ids on different platforms (NVPTX, AMDGPU)
// corresponding to the given TargetIntrinsicID.
struct TargetIntrinsics GetIntrinsic(TargetIntrinsicID intrin) {
  switch (intrin) {
    case TargetIntrinsicID::kThreadIdx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x,
              llvm::Intrinsic::amdgcn_workitem_id_x};
    }
    case TargetIntrinsicID::kThreadIdy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y,
              llvm::Intrinsic::amdgcn_workitem_id_y};
    }
    case TargetIntrinsicID::kThreadIdz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z,
              llvm::Intrinsic::amdgcn_workitem_id_z};
    }
    case TargetIntrinsicID::kBlockIdx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
              llvm::Intrinsic::amdgcn_workgroup_id_x};
    }
    case TargetIntrinsicID::kBlockIdy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
              llvm::Intrinsic::amdgcn_workgroup_id_y};
    }
    case TargetIntrinsicID::kBlockIdz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z,
              llvm::Intrinsic::amdgcn_workgroup_id_z};
    }
    case TargetIntrinsicID::kBarrierId: {
      return {llvm::Intrinsic::nvvm_barrier0,
              llvm::Intrinsic::amdgcn_s_barrier};
    }
    case TargetIntrinsicID::kBlockDimx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(0)}, {U32}, U64,
                                              {}, b_);
              }};
    }
    case TargetIntrinsicID::kBlockDimy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_y,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(1)}, {U32}, U64,
                                              {}, b_);
              }};
    }
    case TargetIntrinsicID::kBlockDimz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_z,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(2)}, {U32}, U64,
                                              {}, b_);
              }};
    }
  }
}

// Wrapper structure for carrying the root names of math functions for the
// NVPTX/AMDGPU platforms.
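// For example, GetDeviceFunctionRoot(TargetDeviceFunctionID::kPow) below
// returns {"__nv_pow", "__ocml_pow"}; ObtainDeviceFunctionName() then appends
// a type-specific suffix to the selected root.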
struct TargetDeviceFunction {
  const string nvptx_root;
  const string amdgpu_root;
};

// Gets the device function name on different platforms (NVPTX, AMDGPU)
// corresponding to the given TargetDeviceFunctionID.
struct TargetDeviceFunction GetDeviceFunctionRoot(
    TargetDeviceFunctionID func_id) {
  switch (func_id) {
    case TargetDeviceFunctionID::kPow: {
      return {"__nv_pow", "__ocml_pow"};
    }
    case TargetDeviceFunctionID::kErfcinv: {
      return {"__nv_erfcinv", "__ocml_erfcinv"};
    }
    case TargetDeviceFunctionID::kLog: {
      return {"__nv_log", "__ocml_log"};
    }
    case TargetDeviceFunctionID::kLog1p: {
      return {"__nv_log1p", "__ocml_log1p"};
    }
    case TargetDeviceFunctionID::kSin: {
      return {"__nv_sin", "__ocml_sin"};
    }
    case TargetDeviceFunctionID::kCos: {
      return {"__nv_cos", "__ocml_cos"};
    }
    case TargetDeviceFunctionID::kExp: {
      return {"__nv_exp", "__ocml_exp"};
    }
    case TargetDeviceFunctionID::kExpm1: {
      return {"__nv_expm1", "__ocml_expm1"};
    }
    case TargetDeviceFunctionID::kSqrt: {
      return {"__nv_sqrt", "__ocml_sqrt"};
    }
    case TargetDeviceFunctionID::kRsqrt: {
      return {"__nv_rsqrt", "__ocml_rsqrt"};
    }
    case TargetDeviceFunctionID::kAtan2: {
      return {"__nv_atan2", "__ocml_atan2"};
    }
    case TargetDeviceFunctionID::kFmod: {
      return {"__nv_fmod", "__ocml_fmod"};
    }
    case TargetDeviceFunctionID::kRound: {
      return {"__nv_round", "__ocml_round"};
    }
    case TargetDeviceFunctionID::kHypot: {
      return {"__nv_hypot", "__ocml_hypot"};
    }
  }
}
}  // namespace

string ObtainDeviceFunctionName(TargetDeviceFunctionID func_id,
                                PrimitiveType output_type,
                                llvm::IRBuilder<>* b) {
  // The device math functions differentiate between "double" and "float" by
  // appending a double or float specific suffix to a root name. The suffix and
  // the root name are specific to the target.
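  // For example, kPow with an F32 output type resolves to "__nv_powf" on
  // NVPTX and to "__ocml_pow_f32" on AMDGPU.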
  llvm::Triple target_triple =
      llvm::Triple(b->GetInsertBlock()->getModule()->getTargetTriple());
  struct TargetDeviceFunction gpu_root_names = GetDeviceFunctionRoot(func_id);
  if (target_triple.isNVPTX()) {
    if (output_type == F32) {
      return StrCat(gpu_root_names.nvptx_root, "f");
    } else if (output_type == F64) {
      return gpu_root_names.nvptx_root;
    } else {
      LOG(FATAL) << "Unexpected type while getting device function name.";
    }
  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
    if (output_type == F32) {
      return StrCat(gpu_root_names.amdgpu_root, "_f32");
    } else if (output_type == F64) {
      return StrCat(gpu_root_names.amdgpu_root, "_f64");
    } else {
      LOG(FATAL) << "Unexpected type while getting device function name.";
    }
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }
}

llvm::CallInst* EmitDeviceFunctionCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::Span<const llvm::Attribute::AttrKind> attributes,
    llvm::IRBuilder<>* b) {
  std::vector<llvm::Type*> ir_input_types;
  llvm::Module* module = b->GetInsertBlock()->getModule();
  for (PrimitiveType input_type : input_types) {
    ir_input_types.push_back(
        llvm_ir::PrimitiveTypeToIrType(input_type, module));
  }
  llvm::FunctionType* callee_type = llvm::FunctionType::get(
      llvm_ir::PrimitiveTypeToIrType(output_type, module),  // Return type.
      ir_input_types,                                       // Parameter types.
      false);  // No variadic arguments.

  // Declares the callee if it is not declared already.
  llvm::Function* callee = llvm::dyn_cast<llvm::Function>(
      b->GetInsertBlock()
          ->getModule()
          ->getOrInsertFunction(callee_name, callee_type)
          .getCallee());

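  // Attach any requested function attributes (e.g. llvm::Attribute::ReadNone)
  // to the callee declaration.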
  for (auto attribute : attributes) {
    callee->addFnAttr(attribute);
  }

  return b->CreateCall(callee, llvm_ir::AsArrayRef(operands));
}

llvm::CallInst* EmitCallToTargetIntrinsic(
    TargetIntrinsicID intrinsic_id, absl::Span<llvm::Value* const> operands,
    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  struct TargetIntrinsics gpu_intrinsic_id = GetIntrinsic(intrinsic_id);
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());
  llvm::Intrinsic::ID llvm_intrinsic_id = llvm::Intrinsic::not_intrinsic;
  if (target_triple.isNVPTX()) {
    llvm_intrinsic_id = gpu_intrinsic_id.nvptx_intrinsic;
  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
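    // On AMDGPU the mapping is either a true intrinsic or a lambda that emits
    // a device function call (see TargetIntrinsics above); dispatch on
    // whichever alternative the variant currently holds.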
    llvm::Intrinsic::ID* llvm_intrinsic_id_ptr =
        absl::get_if<llvm::Intrinsic::ID>(
            &gpu_intrinsic_id.amdgpu_intrinsic_or_function);
    if (llvm_intrinsic_id_ptr) {
      llvm_intrinsic_id = *llvm_intrinsic_id_ptr;
    } else {
      std::function<llvm::CallInst*(llvm::IRBuilder<>*)>* builder_func =
          absl::get_if<std::function<llvm::CallInst*(llvm::IRBuilder<>*)>>(
              &gpu_intrinsic_id.amdgpu_intrinsic_or_function);
      return (*builder_func)(b);
    }
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }

  llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
      module, llvm_intrinsic_id, llvm_ir::AsArrayRef(overloaded_types));
  return b->CreateCall(intrinsic, llvm_ir::AsArrayRef(operands));
}

void AnnotateFunctionAsGpuKernel(llvm::Module* module, llvm::Function* func,
                                 llvm::IRBuilder<>* b) {
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());
  if (target_triple.isNVPTX()) {
    // Add the declaration of this kernel to nvvm.annotations so that the
    // NVPTX backend treats the function as a CUDA kernel.
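    // The operand added below has roughly the form !{<func>, !"kernel", i32 1}
    // and is collected under the !nvvm.annotations named metadata node.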
    llvm::LLVMContext& context = module->getContext();
    llvm::NamedMDNode* nvvm_annotations_node =
        module->getOrInsertNamedMetadata("nvvm.annotations");
    nvvm_annotations_node->addOperand(llvm::MDNode::get(
        context, {llvm::ConstantAsMetadata::get(func),
                  llvm::MDString::get(context, "kernel"),
                  llvm::ConstantAsMetadata::get(b->getInt32(1))}));

  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
    // Attach information so AMDGPU can recognize the function as an AMDGPU
    // kernel.
    func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
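    // "amdgpu-flat-work-group-size" bounds the flat work-group size the kernel
    // may be launched with: here a minimum of 1 and a maximum of 1024
    // work-items per work-group.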
    func->addFnAttr("amdgpu-flat-work-group-size", "1, 1024");
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }
}

}  // namespace gpu
}  // namespace xla