/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Provides helper routines for obtaining GPU target information useful
// for LLVM IR construction.

#include "tensorflow/compiler/xla/service/gpu/target_util.h"

#include "absl/strings/str_cat.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/MDBuilder.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace gpu {
namespace {
// Utility functions to obtain NVPTX/AMDGPU specific information.
using absl::StrCat;

// Wrapper structure for carrying llvm intrinsic ids for NVPTX/AMDGPU
// platforms. On AMDGPU, some of these operations are implemented as device
// functions instead of intrinsics, so a variant type is used to wrap the
// lambda that calls those device functions.
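// For example, kBlockDimx maps to the nvvm_read_ptx_sreg_ntid_x intrinsic on
// NVPTX, but on AMDGPU it is emitted as a call to the __ockl_get_local_size
// device function (see GetIntrinsic below).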
struct TargetIntrinsics {
  llvm::Intrinsic::ID nvptx_intrinsic;
  absl::variant<llvm::Intrinsic::ID,
                std::function<llvm::CallInst*(llvm::IRBuilder<>*)>>
      amdgpu_intrinsic_or_function;
};

// Gets the llvm intrinsic ids on different platforms (NVPTX, AMDGPU)
// corresponding to the given TargetIntrinsicID.
struct TargetIntrinsics GetIntrinsic(TargetIntrinsicID intrin) {
  switch (intrin) {
    case TargetIntrinsicID::kThreadIdx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x,
              llvm::Intrinsic::amdgcn_workitem_id_x};
    }
    case TargetIntrinsicID::kThreadIdy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y,
              llvm::Intrinsic::amdgcn_workitem_id_y};
    }
    case TargetIntrinsicID::kThreadIdz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z,
              llvm::Intrinsic::amdgcn_workitem_id_z};
    }
    case TargetIntrinsicID::kBlockIdx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
              llvm::Intrinsic::amdgcn_workgroup_id_x};
    }
    case TargetIntrinsicID::kBlockIdy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
              llvm::Intrinsic::amdgcn_workgroup_id_y};
    }
    case TargetIntrinsicID::kBlockIdz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z,
              llvm::Intrinsic::amdgcn_workgroup_id_z};
    }
    case TargetIntrinsicID::kBarrierId: {
      return {llvm::Intrinsic::nvvm_barrier0,
              llvm::Intrinsic::amdgcn_s_barrier};
    }
    case TargetIntrinsicID::kBlockDimx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(0)}, {U32}, U64, {},
                                              b_);
              }};
    }
    case TargetIntrinsicID::kBlockDimy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_y,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(1)}, {U32}, U64, {},
                                              b_);
              }};
    }
    case TargetIntrinsicID::kBlockDimz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_z,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(2)}, {U32}, U64, {},
                                              b_);
              }};
    }
  }
}

// Wrapper structure for carrying the root names of the math device functions
// for the NVPTX/AMDGPU platforms.
struct TargetDeviceFunction {
  const string nvptx_root;
  const string amdgpu_root;
};

// Gets the device function root names on different platforms (NVPTX, AMDGPU)
// corresponding to the given TargetDeviceFunctionID.
struct TargetDeviceFunction GetDeviceFunctionRoot(
    TargetDeviceFunctionID func_id) {
  switch (func_id) {
    case TargetDeviceFunctionID::kAtan2: {
      return {"__nv_atan2", "__ocml_atan2"};
    }
    case TargetDeviceFunctionID::kCos: {
      return {"__nv_cos", "__ocml_cos"};
    }
    case TargetDeviceFunctionID::kErfcinv: {
      return {"__nv_erfcinv", "__ocml_erfcinv"};
    }
    case TargetDeviceFunctionID::kExp: {
      return {"__nv_exp", "__ocml_exp"};
    }
    case TargetDeviceFunctionID::kExpm1: {
      return {"__nv_expm1", "__ocml_expm1"};
    }
    case TargetDeviceFunctionID::kFmod: {
      return {"__nv_fmod", "__ocml_fmod"};
    }
    case TargetDeviceFunctionID::kHypot: {
      return {"__nv_hypot", "__ocml_hypot"};
    }
    case TargetDeviceFunctionID::kLog: {
      return {"__nv_log", "__ocml_log"};
    }
    case TargetDeviceFunctionID::kLog1p: {
      return {"__nv_log1p", "__ocml_log1p"};
    }
    case TargetDeviceFunctionID::kPow: {
      return {"__nv_pow", "__ocml_pow"};
    }
    case TargetDeviceFunctionID::kRound: {
      return {"__nv_round", "__ocml_round"};
    }
    case TargetDeviceFunctionID::kRsqrt: {
      return {"__nv_rsqrt", "__ocml_rsqrt"};
    }
    case TargetDeviceFunctionID::kSin: {
      return {"__nv_sin", "__ocml_sin"};
    }
    case TargetDeviceFunctionID::kSqrt: {
      return {"__nv_sqrt", "__ocml_sqrt"};
    }
    case TargetDeviceFunctionID::kTanh: {
      return {"__nv_tanh", "__ocml_tanh"};
    }
  }
}
}  // namespace

string ObtainDeviceFunctionName(TargetDeviceFunctionID func_id,
                                PrimitiveType output_type,
                                llvm::IRBuilder<>* b) {
  // The device math functions differentiate between "double" and "float" by
  // appending a double or float specific suffix to a root name. The suffix and
  // the root name are specific to the target.
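  // For example, kCos with an F32 output type resolves to "__nv_cosf" on
  // NVPTX and to "__ocml_cos_f32" on AMDGPU.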
  llvm::Triple target_triple =
      llvm::Triple(b->GetInsertBlock()->getModule()->getTargetTriple());
  struct TargetDeviceFunction gpu_root_names = GetDeviceFunctionRoot(func_id);
  if (target_triple.isNVPTX()) {
    if (output_type == F32) {
      return StrCat(gpu_root_names.nvptx_root, "f");
    } else if (output_type == F64) {
      return gpu_root_names.nvptx_root;
    } else {
      LOG(FATAL) << "Unexpected type while getting device function name.";
    }
  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
    if (output_type == F32) {
      return StrCat(gpu_root_names.amdgpu_root, "_f32");
    } else if (output_type == F64) {
      return StrCat(gpu_root_names.amdgpu_root, "_f64");
    } else {
      LOG(FATAL) << "Unexpected type while getting device function name.";
    }
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }
}
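
// Example usage (a sketch, not part of the emitter itself): the name obtained
// above is typically passed to EmitDeviceFunctionCall below, e.g. for an f32
// cosine, where `value` is a hypothetical f32 operand:
//   llvm::CallInst* cos_call = EmitDeviceFunctionCall(
//       ObtainDeviceFunctionName(TargetDeviceFunctionID::kCos, F32, b),
//       {value}, {F32}, F32, {}, b);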

llvm::CallInst* EmitDeviceFunctionCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::Span<const llvm::Attribute::AttrKind> attributes,
    llvm::IRBuilder<>* b, absl::string_view name) {
  std::vector<llvm::Type*> ir_input_types;
  llvm::Module* module = b->GetInsertBlock()->getModule();
  for (PrimitiveType input_type : input_types) {
    ir_input_types.push_back(
        llvm_ir::PrimitiveTypeToIrType(input_type, module));
  }
  llvm::FunctionType* callee_type = llvm::FunctionType::get(
      llvm_ir::PrimitiveTypeToIrType(output_type, module),  // Return type.
      ir_input_types,                                       // Parameter types.
      false);  // No variadic arguments.

  // Declares the callee if it is not declared already.
  llvm::Function* callee = llvm::dyn_cast<llvm::Function>(
      b->GetInsertBlock()
          ->getModule()
          ->getOrInsertFunction(callee_name, callee_type)
          .getCallee());

  for (auto attribute : attributes) {
    callee->addFnAttr(attribute);
  }

  return b->CreateCall(callee, llvm_ir::AsArrayRef(operands), name.data());
}

llvm::CallInst* EmitCallToTargetIntrinsic(
    TargetIntrinsicID intrinsic_id, absl::Span<llvm::Value* const> operands,
    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  struct TargetIntrinsics gpu_intrinsic_id = GetIntrinsic(intrinsic_id);
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());
  llvm::Intrinsic::ID llvm_intrinsic_id = llvm::Intrinsic::not_intrinsic;
  if (target_triple.isNVPTX()) {
    llvm_intrinsic_id = gpu_intrinsic_id.nvptx_intrinsic;
  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
    llvm::Intrinsic::ID* llvm_intrinsic_id_ptr =
        absl::get_if<llvm::Intrinsic::ID>(
            &gpu_intrinsic_id.amdgpu_intrinsic_or_function);
    if (llvm_intrinsic_id_ptr) {
      llvm_intrinsic_id = *llvm_intrinsic_id_ptr;
    } else {
      std::function<llvm::CallInst*(llvm::IRBuilder<>*)>* builder_func =
          absl::get_if<std::function<llvm::CallInst*(llvm::IRBuilder<>*)>>(
              &gpu_intrinsic_id.amdgpu_intrinsic_or_function);
      return (*builder_func)(b);
    }
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }

  llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
      module, llvm_intrinsic_id, llvm_ir::AsArrayRef(overloaded_types));
  return b->CreateCall(intrinsic, llvm_ir::AsArrayRef(operands));
}
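
// Example usage (a sketch; `b` is an llvm::IRBuilder<> positioned at the
// emission point): emitting the x thread index, which lowers to
// llvm.nvvm.read.ptx.sreg.tid.x on NVPTX and to llvm.amdgcn.workitem.id.x on
// AMDGPU:
//   llvm::CallInst* thread_id_x =
//       EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b);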

void AnnotateFunctionAsGpuKernel(llvm::Module* module, llvm::Function* func,
                                 llvm::IRBuilder<>* b) {
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());
  if (target_triple.isNVPTX()) {
    // Add the declaration of this kernel to llvm.nvvm.annotations so that
    // NVPTX treats the function as a CUDA kernel.
    llvm::LLVMContext& context = module->getContext();
    llvm::NamedMDNode* nvvm_annotations_node =
        module->getOrInsertNamedMetadata("nvvm.annotations");
    nvvm_annotations_node->addOperand(llvm::MDNode::get(
        context, {llvm::ConstantAsMetadata::get(func),
                  llvm::MDString::get(context, "kernel"),
                  llvm::ConstantAsMetadata::get(b->getInt32(1))}));
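    // The emitted module-level metadata looks roughly like the following
    // (the exact function pointer type depends on the kernel signature, and
    // @the_kernel is a placeholder name):
    //   !nvvm.annotations = !{!0}
    //   !0 = !{void (i8*)* @the_kernel, !"kernel", i32 1}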

  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
    // Attach information so AMDGPU can recognize the function as an AMDGPU
    // kernel.
    func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
    func->addFnAttr("amdgpu-flat-work-group-size", "1, 1024");
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }
}

}  // namespace gpu
}  // namespace xla