/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Provides helper routines for obtaining GPU target information useful
// for LLVM IR construction.

#include "tensorflow/compiler/xla/service/gpu/target_util.h"

#include "absl/strings/str_cat.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/MDBuilder.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace gpu {
namespace {
// Utility functions to obtain NVPTX/AMDGPU specific information.
using absl::StrCat;

// Wrapper structure for carrying llvm intrinsic ids for NVPTX/AMDGPU
// platforms. On AMDGPU, some of these operations are implemented as device
// functions instead of intrinsics, so a variant type is used to wrap the
// lambda that calls those device functions.
struct TargetIntrinsics {
  llvm::Intrinsic::ID nvptx_intrinsic;
  absl::variant<llvm::Intrinsic::ID,
                std::function<llvm::CallInst*(llvm::IRBuilder<>*)>>
      amdgpu_intrinsic_or_function;
};
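// Callers must check which alternative amdgpu_intrinsic_or_function holds
// before using it; see EmitCallToTargetIntrinsic below for how both cases
// are dispatched.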

// Gets the llvm intrinsic ids on different platforms (NVPTX, AMDGPU)
// corresponding to the given TargetIntrinsicID.
struct TargetIntrinsics GetIntrinsic(TargetIntrinsicID intrin) {
  switch (intrin) {
    case TargetIntrinsicID::kThreadIdx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x,
              llvm::Intrinsic::amdgcn_workitem_id_x};
    }
    case TargetIntrinsicID::kThreadIdy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y,
              llvm::Intrinsic::amdgcn_workitem_id_y};
    }
    case TargetIntrinsicID::kThreadIdz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z,
              llvm::Intrinsic::amdgcn_workitem_id_z};
    }
    case TargetIntrinsicID::kBlockIdx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
              llvm::Intrinsic::amdgcn_workgroup_id_x};
    }
    case TargetIntrinsicID::kBlockIdy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
              llvm::Intrinsic::amdgcn_workgroup_id_y};
    }
    case TargetIntrinsicID::kBlockIdz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z,
              llvm::Intrinsic::amdgcn_workgroup_id_z};
    }
    case TargetIntrinsicID::kBarrierId: {
      return {llvm::Intrinsic::nvvm_barrier0,
              llvm::Intrinsic::amdgcn_s_barrier};
    }
    case TargetIntrinsicID::kBlockDimx: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(0)}, {U32}, U64, {},
                                              b_);
              }};
    }
    case TargetIntrinsicID::kBlockDimy: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_y,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(1)}, {U32}, U64, {},
                                              b_);
              }};
    }
    case TargetIntrinsicID::kBlockDimz: {
      return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_z,
              [](llvm::IRBuilder<>* b_) -> llvm::CallInst* {
                return EmitDeviceFunctionCall("__ockl_get_local_size",
                                              {b_->getInt32(2)}, {U32}, U64, {},
                                              b_);
              }};
    }
  }
}

// Wrapper structure for carrying math functions for NVPTX/AMDGPU platforms.
struct TargetDeviceFunction {
  const string nvptx_root;
  const string amdgpu_root;
};

// Gets the device function name on different platforms (NVPTX, AMDGPU)
// corresponding to the given TargetDeviceFunctionID.
struct TargetDeviceFunction GetDeviceFunctionRoot(
    TargetDeviceFunctionID func_id) {
  switch (func_id) {
    case TargetDeviceFunctionID::kPow: {
      return {"__nv_pow", "__ocml_pow"};
    }
    case TargetDeviceFunctionID::kErfcinv: {
      return {"__nv_erfcinv", "__ocml_erfcinv"};
    }
    case TargetDeviceFunctionID::kLog: {
      return {"__nv_log", "__ocml_log"};
    }
    case TargetDeviceFunctionID::kLog1p: {
      return {"__nv_log1p", "__ocml_log1p"};
    }
    case TargetDeviceFunctionID::kSin: {
      return {"__nv_sin", "__ocml_sin"};
    }
    case TargetDeviceFunctionID::kCos: {
      return {"__nv_cos", "__ocml_cos"};
    }
    case TargetDeviceFunctionID::kExp: {
      return {"__nv_exp", "__ocml_exp"};
    }
    case TargetDeviceFunctionID::kExpm1: {
      return {"__nv_expm1", "__ocml_expm1"};
    }
    case TargetDeviceFunctionID::kSqrt: {
      return {"__nv_sqrt", "__ocml_sqrt"};
    }
    case TargetDeviceFunctionID::kRsqrt: {
      return {"__nv_rsqrt", "__ocml_rsqrt"};
    }
    case TargetDeviceFunctionID::kAtan2: {
      return {"__nv_atan2", "__ocml_atan2"};
    }
    case TargetDeviceFunctionID::kFmod: {
      return {"__nv_fmod", "__ocml_fmod"};
    }
    case TargetDeviceFunctionID::kRound: {
      return {"__nv_round", "__ocml_round"};
    }
    case TargetDeviceFunctionID::kHypot: {
      return {"__nv_hypot", "__ocml_hypot"};
    }
  }
}
}  // namespace

string ObtainDeviceFunctionName(TargetDeviceFunctionID func_id,
                                PrimitiveType output_type,
                                llvm::IRBuilder<>* b) {
  // The device math functions differentiate between "double" and "float" by
  // appending a double- or float-specific suffix to a root name. The suffix
  // and the root name are specific to the target.
  llvm::Triple target_triple =
      llvm::Triple(b->GetInsertBlock()->getModule()->getTargetTriple());
  struct TargetDeviceFunction gpu_root_names = GetDeviceFunctionRoot(func_id);
  if (target_triple.isNVPTX()) {
    if (output_type == F32) {
      return StrCat(gpu_root_names.nvptx_root, "f");
    } else if (output_type == F64) {
      return gpu_root_names.nvptx_root;
    } else {
      LOG(FATAL) << "Unexpected type while getting device function name.";
    }
  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
    if (output_type == F32) {
      return StrCat(gpu_root_names.amdgpu_root, "_f32");
    } else if (output_type == F64) {
      return StrCat(gpu_root_names.amdgpu_root, "_f64");
    } else {
      LOG(FATAL) << "Unexpected type while getting device function name.";
    }
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }
}
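// Example (illustrative): for TargetDeviceFunctionID::kSqrt, this returns
// "__nv_sqrtf" (F32) or "__nv_sqrt" (F64) on NVPTX, and "__ocml_sqrt_f32" or
// "__ocml_sqrt_f64" on AMDGPU.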

llvm::CallInst* EmitDeviceFunctionCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::Span<const llvm::Attribute::AttrKind> attributes,
    llvm::IRBuilder<>* b) {
  std::vector<llvm::Type*> ir_input_types;
  llvm::Module* module = b->GetInsertBlock()->getModule();
  for (PrimitiveType input_type : input_types) {
    ir_input_types.push_back(
        llvm_ir::PrimitiveTypeToIrType(input_type, module));
  }
  llvm::FunctionType* callee_type = llvm::FunctionType::get(
      llvm_ir::PrimitiveTypeToIrType(output_type, module),  // Return type.
      ir_input_types,                                       // Parameter types.
      false);  // No variadic arguments.

  // Declares the callee if it is not declared already.
  llvm::Function* callee = llvm::dyn_cast<llvm::Function>(
      module->getOrInsertFunction(callee_name, callee_type).getCallee());

  for (auto attribute : attributes) {
    callee->addFnAttr(attribute);
  }

  return b->CreateCall(callee, llvm_ir::AsArrayRef(operands));
}
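// Example (illustrative): assuming `value` is an F32 llvm::Value* in scope,
// the following would emit a call to the float sqrt function from the NVPTX
// device library:
//   EmitDeviceFunctionCall("__nv_sqrtf", {value}, {F32}, F32, {}, b);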

llvm::CallInst* EmitCallToTargetIntrinsic(
    TargetIntrinsicID intrinsic_id, absl::Span<llvm::Value* const> operands,
    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  struct TargetIntrinsics gpu_intrinsic_id = GetIntrinsic(intrinsic_id);
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());
  llvm::Intrinsic::ID llvm_intrinsic_id = llvm::Intrinsic::not_intrinsic;
  if (target_triple.isNVPTX()) {
    llvm_intrinsic_id = gpu_intrinsic_id.nvptx_intrinsic;
  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
    llvm::Intrinsic::ID* llvm_intrinsic_id_ptr =
        absl::get_if<llvm::Intrinsic::ID>(
            &gpu_intrinsic_id.amdgpu_intrinsic_or_function);
    if (llvm_intrinsic_id_ptr) {
      llvm_intrinsic_id = *llvm_intrinsic_id_ptr;
    } else {
      std::function<llvm::CallInst*(llvm::IRBuilder<>*)>* builder_func =
          absl::get_if<std::function<llvm::CallInst*(llvm::IRBuilder<>*)>>(
              &gpu_intrinsic_id.amdgpu_intrinsic_or_function);
      return (*builder_func)(b);
    }
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }

  llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
      module, llvm_intrinsic_id, llvm_ir::AsArrayRef(overloaded_types));
  return b->CreateCall(intrinsic, llvm_ir::AsArrayRef(operands));
}
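// Example (illustrative): reading the x thread index is target-neutral at
// this level; it lowers to llvm.nvvm.read.ptx.sreg.tid.x on NVPTX and
// llvm.amdgcn.workitem.id.x on AMDGPU:
//   llvm::CallInst* tid_x =
//       EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b);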

void AnnotateFunctionAsGpuKernel(llvm::Module* module, llvm::Function* func,
                                 llvm::IRBuilder<>* b) {
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());
  if (target_triple.isNVPTX()) {
    // Add the declaration of this kernel to llvm.nvvm.annotations so that
    // NVPTX treats the function as a CUDA kernel.
    llvm::LLVMContext& context = module->getContext();
    llvm::NamedMDNode* nvvm_annotations_node =
        module->getOrInsertNamedMetadata("nvvm.annotations");
    nvvm_annotations_node->addOperand(llvm::MDNode::get(
        context, {llvm::ConstantAsMetadata::get(func),
                  llvm::MDString::get(context, "kernel"),
                  llvm::ConstantAsMetadata::get(b->getInt32(1))}));

  } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
    // Attach information so AMDGPU can recognize the function as an AMDGPU
    // kernel.
    func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
    func->addFnAttr("amdgpu-flat-work-group-size", "1, 1024");
  } else {
    LOG(FATAL) << "Invalid triple " << target_triple.str();
  }
}
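// Example (illustrative): on NVPTX the annotation above yields module-level
// metadata along the lines of:
//   !nvvm.annotations = !{!0}
//   !0 = !{void ()* @some_kernel, !"kernel", i32 1}
// On AMDGPU the kernel is identified by its AMDGPU_KERNEL calling convention
// and the "amdgpu-flat-work-group-size" attribute instead.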

}  // namespace gpu
}  // namespace xla