/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Provides helper routines for obtaining GPU target information useful
// for LLVM IR construction.

#include "tensorflow/compiler/xla/service/gpu/target_util.h"

#include "absl/strings/str_cat.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/MDBuilder.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace gpu {
namespace {
// Utility functions to obtain NVPTX/AMDGPU-specific information.
using absl::StrCat;

// Wrapper structure for carrying llvm intrinsic ids for NVPTX/AMDGPU platforms.
// On AMDGPU, some of these operations are provided as device functions rather
// than intrinsics, so a variant is used to hold either the intrinsic id or a
// lambda that emits the corresponding device function call.
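// For example, kBlockDimx maps to a plain intrinsic on NVPTX, but on AMDGPU it
// is emitted through the std::function alternative as a call to the
// __ockl_get_local_size device function (see GetIntrinsic below).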
struct TargetIntrinsics {
  llvm::Intrinsic::ID nvptx_intrinsic;
  absl::variant<llvm::Intrinsic::ID,
                std::function<llvm::CallInst*(llvm::IRBuilder<>*)>>
      amdgpu_intrinsic_or_function;
};

// Gets the llvm intrinsic ids on different platforms (NVPTX, AMDGPU)
// corresponding to the given TargetIntrinsicID.
struct TargetIntrinsics GetIntrinsic(TargetIntrinsicID intrin) {
  switch (intrin) {
    case TargetIntrinsicID::kThreadIdx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x,
              llvm::Intrinsic::amdgcn_workitem_id_x};
    }
    case TargetIntrinsicID::kThreadIdy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y,
              llvm::Intrinsic::amdgcn_workitem_id_y};
    }
    case TargetIntrinsicID::kThreadIdz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z,
              llvm::Intrinsic::amdgcn_workitem_id_z};
    }
    case TargetIntrinsicID::kBlockIdx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
              llvm::Intrinsic::amdgcn_workgroup_id_x};
    }
    case TargetIntrinsicID::kBlockIdy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
              llvm::Intrinsic::amdgcn_workgroup_id_y};
    }
    case TargetIntrinsicID::kBlockIdz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z,
              llvm::Intrinsic::amdgcn_workgroup_id_z};
    }
    case TargetIntrinsicID::kBarrierId: {
      return {llvm::Intrinsic::nvvm_barrier0,
              llvm::Intrinsic::amdgcn_s_barrier};
    }
    case TargetIntrinsicID::kBlockDimx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(0)}, {U32}, U64, {},
                                              b_);
              }};
    }
    case TargetIntrinsicID::kBlockDimy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_y,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(1)}, {U32}, U64, {},
                                              b_);
              }};
    }
    case TargetIntrinsicID::kBlockDimz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_z,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(2)}, {U32}, U64, {},
                                              b_);
              }};
    }
  }
}

// Wrapper structure for carrying the root names of math device functions for
// the NVPTX/AMDGPU platforms.
struct TargetDeviceFunction {
  const string nvptx_root;
  const string amdgpu_root;
};

// Gets the device function name on different platforms (NVPTX, AMDGPU)
// corresponding to the given TargetDeviceFunctionID.
struct TargetDeviceFunction GetDeviceFunctionRoot(
    TargetDeviceFunctionID func_id) {
  switch (func_id) {
    case TargetDeviceFunctionID::kAtan2: {
      return {"__nv_atan2", "__ocml_atan2"};
    }
    case TargetDeviceFunctionID::kCos: {
      return {"__nv_cos", "__ocml_cos"};
    }
    case TargetDeviceFunctionID::kErfcinv: {
      return {"__nv_erfcinv", "__ocml_erfcinv"};
    }
    case TargetDeviceFunctionID::kExp: {
      return {"__nv_exp", "__ocml_exp"};
    }
    case TargetDeviceFunctionID::kExpm1: {
      return {"__nv_expm1", "__ocml_expm1"};
    }
    case TargetDeviceFunctionID::kFmod: {
      return {"__nv_fmod", "__ocml_fmod"};
    }
    case TargetDeviceFunctionID::kHypot: {
      return {"__nv_hypot", "__ocml_hypot"};
    }
    case TargetDeviceFunctionID::kLog: {
      return {"__nv_log", "__ocml_log"};
    }
    case TargetDeviceFunctionID::kLog1p: {
      return {"__nv_log1p", "__ocml_log1p"};
    }
    case TargetDeviceFunctionID::kPow: {
      return {"__nv_pow", "__ocml_pow"};
    }
    case TargetDeviceFunctionID::kRound: {
      return {"__nv_round", "__ocml_round"};
    }
    case TargetDeviceFunctionID::kRsqrt: {
      return {"__nv_rsqrt", "__ocml_rsqrt"};
    }
    case TargetDeviceFunctionID::kSin: {
      return {"__nv_sin", "__ocml_sin"};
    }
    case TargetDeviceFunctionID::kSqrt: {
      return {"__nv_sqrt", "__ocml_sqrt"};
    }
    case TargetDeviceFunctionID::kTanh: {
      return {"__nv_tanh", "__ocml_tanh"};
    }
  }
}
}  // namespace

string ObtainDeviceFunctionName(TargetDeviceFunctionID func_id,
                                PrimitiveType output_type,
                                llvm::IRBuilder<>* b) {
  // The device math functions differentiate between "double" and "float" by
  // appending a double- or float-specific suffix to a root name. The suffix
  // and the root name are specific to the target.
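  // For example, TargetDeviceFunctionID::kCos with an F32 output type becomes
  // "__nv_cosf" on NVPTX and "__ocml_cos_f32" on AMDGPU.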
  llvm::Triple target_triple =
      llvm::Triple(b->GetInsertBlock()->getModule()->getTargetTriple());
  struct TargetDeviceFunction gpu_root_names = GetDeviceFunctionRoot(func_id);
  if (target_triple.isNVPTX()) {
    if (output_type == F32) {
      return StrCat(gpu_root_names.nvptx_root, "f");
    } else if (output_type == F64) {
      return gpu_root_names.nvptx_root;
    } else {
      LOG(FATAL) << "Unexpected type while getting device function name.";
    }
  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
    if (output_type == F32) {
      return StrCat(gpu_root_names.amdgpu_root, "_f32");
    } else if (output_type == F64) {
      return StrCat(gpu_root_names.amdgpu_root, "_f64");
    } else {
      LOG(FATAL) << "Unexpected type while getting device function name.";
    }
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }
}

llvm::CallInst* EmitDeviceFunctionCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::Span<const llvm::Attribute::AttrKind> attributes,
    llvm::IRBuilder<>* b, absl::string_view name) {
  std::vector<llvm::Type*> ir_input_types;
  llvm::Module* module = b->GetInsertBlock()->getModule();
  for (PrimitiveType input_type : input_types) {
    ir_input_types.push_back(
        llvm_ir::PrimitiveTypeToIrType(input_type, module));
  }
  llvm::FunctionType* callee_type = llvm::FunctionType::get(
      llvm_ir::PrimitiveTypeToIrType(output_type, module),  // Return type.
      ir_input_types,                                       // Parameter types.
      false);  // No variadic arguments.

  // Declares the callee if it is not declared already.
  llvm::Function* callee = llvm::dyn_cast<llvm::Function>(
      b->GetInsertBlock()
          ->getModule()
          ->getOrInsertFunction(callee_name, callee_type)
          .getCallee());

  for (auto attribute : attributes) {
    callee->addFnAttr(attribute);
  }

  return b->CreateCall(callee, llvm_ir::AsArrayRef(operands), name.data());
}
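
// Example usage (illustrative sketch, not a call site in this file; `b` is the
// caller's IRBuilder and `input_value` is a hypothetical float llvm::Value*):
// emit a call to the F32 "exp" device function for the current target.
//
//   string fn_name =
//       ObtainDeviceFunctionName(TargetDeviceFunctionID::kExp, F32, b);
//   llvm::CallInst* call =
//       EmitDeviceFunctionCall(fn_name, {input_value}, {F32}, F32, {}, b);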

llvm::CallInst* EmitCallToTargetIntrinsic(
    TargetIntrinsicID intrinsic_id, absl::Span<llvm::Value* const> operands,
    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  struct TargetIntrinsics gpu_intrinsic_id = GetIntrinsic(intrinsic_id);
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());
  llvm::Intrinsic::ID llvm_intrinsic_id = llvm::Intrinsic::not_intrinsic;
  if (target_triple.isNVPTX()) {
    llvm_intrinsic_id = gpu_intrinsic_id.nvptx_intrinsic;
  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
    llvm::Intrinsic::ID* llvm_intrinsic_id_ptr =
        absl::get_if<llvm::Intrinsic::ID>(
            &gpu_intrinsic_id.amdgpu_intrinsic_or_function);
    if (llvm_intrinsic_id_ptr) {
      llvm_intrinsic_id = *llvm_intrinsic_id_ptr;
    } else {
      std::function<llvm::CallInst*(llvm::IRBuilder<>*)>* builder_func =
          absl::get_if<std::function<llvm::CallInst*(llvm::IRBuilder<>*)>>(
              &gpu_intrinsic_id.amdgpu_intrinsic_or_function);
      return (*builder_func)(b);
    }
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }

  llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
      module, llvm_intrinsic_id, llvm_ir::AsArrayRef(overloaded_types));
  return b->CreateCall(intrinsic, llvm_ir::AsArrayRef(operands));
}
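
// Example usage (illustrative sketch; `b` is the caller's IRBuilder): read the
// x-dimension thread index for whichever target the module is compiled for.
//
//   llvm::CallInst* thread_id_x =
//       EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b);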

void AnnotateFunctionAsGpuKernel(llvm::Module* module, llvm::Function* func,
                                 llvm::IRBuilder<>* b) {
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());
  if (target_triple.isNVPTX()) {
    // Add the declaration of this kernel to llvm.nvvm.annotations so that
    // NVPTX treats the function as a CUDA kernel.
    llvm::LLVMContext& context = module->getContext();
    llvm::NamedMDNode* nvvm_annotations_node =
        module->getOrInsertNamedMetadata("nvvm.annotations");
    nvvm_annotations_node->addOperand(llvm::MDNode::get(
        context, {llvm::ConstantAsMetadata::get(func),
                  llvm::MDString::get(context, "kernel"),
                  llvm::ConstantAsMetadata::get(b->getInt32(1))}));
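    // The snippet above yields module-level metadata roughly of the form
    // (exact pointer syntax depends on the LLVM version):
    //   !nvvm.annotations = !{!0}
    //   !0 = !{<function type>* @<kernel function>, !"kernel", i32 1}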

  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
    // Attach information so that AMDGPU can recognize the function as an
    // AMDGPU kernel.
    func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
    func->addFnAttr("amdgpu-flat-work-group-size", "1, 1024");
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }
}

}  // namespace gpu
}  // namespace xla