//===---- CGOpenMPRuntimeNVPTX.cpp - Interface to OpenMP NVPTX Runtimes ---===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This provides a class for OpenMP runtime code generation specialized to NVPTX
// targets.
//
//===----------------------------------------------------------------------===//

#include "CGOpenMPRuntimeNVPTX.h"
#include "CodeGenFunction.h"
#include "clang/AST/DeclOpenMP.h"
#include "clang/AST/StmtOpenMP.h"

using namespace clang;
using namespace CodeGen;
/// \brief Get the GPU warp size.
llvm::Value *CGOpenMPRuntimeNVPTX::getNVPTXWarpSize(CodeGenFunction &CGF) {
  CGBuilderTy &Bld = CGF.Builder;
  return Bld.CreateCall(
      llvm::Intrinsic::getDeclaration(
          &CGM.getModule(), llvm::Intrinsic::nvvm_read_ptx_sreg_warpsize),
      llvm::None, "nvptx_warp_size");
}

/// \brief Get the id of the current thread on the GPU.
llvm::Value *CGOpenMPRuntimeNVPTX::getNVPTXThreadID(CodeGenFunction &CGF) {
  CGBuilderTy &Bld = CGF.Builder;
  return Bld.CreateCall(
      llvm::Intrinsic::getDeclaration(
          &CGM.getModule(), llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x),
      llvm::None, "nvptx_tid");
}

/// \brief Get the maximum number of threads in a block of the GPU.
llvm::Value *CGOpenMPRuntimeNVPTX::getNVPTXNumThreads(CodeGenFunction &CGF) {
  CGBuilderTy &Bld = CGF.Builder;
  return Bld.CreateCall(
      llvm::Intrinsic::getDeclaration(
          &CGM.getModule(), llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x),
      llvm::None, "nvptx_num_threads");
}

/// \brief Get barrier to synchronize all threads in a block.
void CGOpenMPRuntimeNVPTX::getNVPTXCTABarrier(CodeGenFunction &CGF) {
  CGBuilderTy &Bld = CGF.Builder;
  Bld.CreateCall(llvm::Intrinsic::getDeclaration(
      &CGM.getModule(), llvm::Intrinsic::nvvm_barrier0));
}

/// \brief Synchronize all GPU threads in a block.
void CGOpenMPRuntimeNVPTX::syncCTAThreads(CodeGenFunction &CGF) {
  getNVPTXCTABarrier(CGF);
}

/// \brief Get the thread id of the OMP master thread.
/// The master thread id is the first thread (lane) of the last warp in the
/// GPU block. Warp size is assumed to be some power of 2.
/// Thread id is 0 indexed.
/// E.g: If NumThreads is 33, master id is 32.
///      If NumThreads is 64, master id is 32.
///      If NumThreads is 1024, master id is 992.
llvm::Value *CGOpenMPRuntimeNVPTX::getMasterThreadID(CodeGenFunction &CGF) {
  CGBuilderTy &Bld = CGF.Builder;
  llvm::Value *NumThreads = getNVPTXNumThreads(CGF);

  // We assume that the warp size is a power of 2.
  llvm::Value *Mask = Bld.CreateSub(getNVPTXWarpSize(CGF), Bld.getInt32(1));

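  // Rounding (NumThreads - 1) down to a multiple of the warp size gives the
  // id of the first lane of the last warp, i.e. the master thread id.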
  return Bld.CreateAnd(Bld.CreateSub(NumThreads, Bld.getInt32(1)),
                       Bld.CreateNot(Mask), "master_tid");
}

namespace {
enum OpenMPRTLFunctionNVPTX {
  /// \brief Call to void __kmpc_kernel_init(kmp_int32 omp_handle,
  /// kmp_int32 thread_limit);
  OMPRTL_NVPTX__kmpc_kernel_init,
};

// NVPTX Address space
enum ADDRESS_SPACE {
  ADDRESS_SPACE_SHARED = 3,
};
} // namespace

CGOpenMPRuntimeNVPTX::WorkerFunctionState::WorkerFunctionState(
    CodeGenModule &CGM)
    : WorkerFn(nullptr), CGFI(nullptr) {
  createWorkerFunction(CGM);
}

void CGOpenMPRuntimeNVPTX::WorkerFunctionState::createWorkerFunction(
    CodeGenModule &CGM) {
  // Create a worker function with no arguments.
  CGFI = &CGM.getTypes().arrangeNullaryFunction();

  WorkerFn = llvm::Function::Create(
      CGM.getTypes().GetFunctionType(*CGFI), llvm::GlobalValue::InternalLinkage,
      /* placeholder */ "_worker", &CGM.getModule());
  CGM.SetInternalFunctionAttributes(/*D=*/nullptr, WorkerFn, *CGFI);
  WorkerFn->setLinkage(llvm::GlobalValue::InternalLinkage);
  WorkerFn->addFnAttr(llvm::Attribute::NoInline);
}

void CGOpenMPRuntimeNVPTX::initializeEnvironment() {
  //
  // Initialize master-worker control state in shared memory.
  //

  auto DL = CGM.getDataLayout();
  ActiveWorkers = new llvm::GlobalVariable(
      CGM.getModule(), CGM.Int32Ty, /*isConstant=*/false,
      llvm::GlobalValue::CommonLinkage,
      llvm::Constant::getNullValue(CGM.Int32Ty), "__omp_num_threads", 0,
      llvm::GlobalVariable::NotThreadLocal, ADDRESS_SPACE_SHARED);
  ActiveWorkers->setAlignment(DL.getPrefTypeAlignment(CGM.Int32Ty));

  WorkID = new llvm::GlobalVariable(
      CGM.getModule(), CGM.Int64Ty, /*isConstant=*/false,
      llvm::GlobalValue::CommonLinkage,
      llvm::Constant::getNullValue(CGM.Int64Ty), "__tgt_work_id", 0,
      llvm::GlobalVariable::NotThreadLocal, ADDRESS_SPACE_SHARED);
  WorkID->setAlignment(DL.getPrefTypeAlignment(CGM.Int64Ty));
}

void CGOpenMPRuntimeNVPTX::emitWorkerFunction(WorkerFunctionState &WST) {
  auto &Ctx = CGM.getContext();

  CodeGenFunction CGF(CGM, /*suppressNewContext=*/true);
  CGF.StartFunction(GlobalDecl(), Ctx.VoidTy, WST.WorkerFn, *WST.CGFI, {});
  emitWorkerLoop(CGF, WST);
  CGF.FinishFunction();
}

void CGOpenMPRuntimeNVPTX::emitWorkerLoop(CodeGenFunction &CGF,
                                          WorkerFunctionState &WST) {
  //
  // The workers enter this loop and wait for parallel work from the master.
  // When the master encounters a parallel region it sets up the work + variable
  // arguments, and wakes up the workers. The workers first check to see if
  // they are required for the parallel region, i.e., within the # of requested
  // parallel threads. The activated workers load the variable arguments and
  // execute the parallel work.
  //
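  // The generated control flow loops through .await.work, .select.workers,
  // .execute.parallel, .terminate.parallel and .barrier.parallel, and jumps
  // to .exit once the master publishes a null work id.
  //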

  CGBuilderTy &Bld = CGF.Builder;

  llvm::BasicBlock *AwaitBB = CGF.createBasicBlock(".await.work");
  llvm::BasicBlock *SelectWorkersBB = CGF.createBasicBlock(".select.workers");
  llvm::BasicBlock *ExecuteBB = CGF.createBasicBlock(".execute.parallel");
  llvm::BasicBlock *TerminateBB = CGF.createBasicBlock(".terminate.parallel");
  llvm::BasicBlock *BarrierBB = CGF.createBasicBlock(".barrier.parallel");
  llvm::BasicBlock *ExitBB = CGF.createBasicBlock(".exit");

  CGF.EmitBranch(AwaitBB);

  // Workers wait for work from master.
  CGF.EmitBlock(AwaitBB);
  // Wait for parallel work
  syncCTAThreads(CGF);
  // On termination condition (workid == 0), exit loop.
  llvm::Value *ShouldTerminate = Bld.CreateICmpEQ(
      Bld.CreateAlignedLoad(WorkID, WorkID->getAlignment()),
      llvm::Constant::getNullValue(WorkID->getType()->getElementType()),
      "should_terminate");
  Bld.CreateCondBr(ShouldTerminate, ExitBB, SelectWorkersBB);

  // Activate requested workers.
  CGF.EmitBlock(SelectWorkersBB);
  llvm::Value *ThreadID = getNVPTXThreadID(CGF);
  llvm::Value *ActiveThread = Bld.CreateICmpSLT(
      ThreadID,
      Bld.CreateAlignedLoad(ActiveWorkers, ActiveWorkers->getAlignment()),
      "active_thread");
  Bld.CreateCondBr(ActiveThread, ExecuteBB, BarrierBB);

  // Signal start of parallel region.
  CGF.EmitBlock(ExecuteBB);
  // TODO: Add parallel work.

  // Signal end of parallel region.
  CGF.EmitBlock(TerminateBB);
  CGF.EmitBranch(BarrierBB);

  // All active and inactive workers wait at a barrier after parallel region.
  CGF.EmitBlock(BarrierBB);
  // Barrier after parallel region.
  syncCTAThreads(CGF);
  CGF.EmitBranch(AwaitBB);

  // Exit target region.
  CGF.EmitBlock(ExitBB);
}

// Set up NVPTX threads for the master-worker OpenMP scheme.
void CGOpenMPRuntimeNVPTX::emitEntryHeader(CodeGenFunction &CGF,
                                           EntryFunctionState &EST,
                                           WorkerFunctionState &WST) {
  CGBuilderTy &Bld = CGF.Builder;

  // Get the master thread id.
  llvm::Value *MasterID = getMasterThreadID(CGF);
  // Current thread's identifier.
  llvm::Value *ThreadID = getNVPTXThreadID(CGF);

  // Set up BBs in the entry function.
  llvm::BasicBlock *WorkerCheckBB = CGF.createBasicBlock(".check.for.worker");
  llvm::BasicBlock *WorkerBB = CGF.createBasicBlock(".worker");
  llvm::BasicBlock *MasterBB = CGF.createBasicBlock(".master");
  EST.ExitBB = CGF.createBasicBlock(".exit");

  // The head (master thread) marches on while its body of companion threads in
  // the warp go to sleep.
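  // Threads with an id greater than the master's are surplus lanes in the
  // master warp and exit at once; threads with a smaller id become workers;
  // the thread whose id equals the master's continues as the master.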
  llvm::Value *ShouldDie =
      Bld.CreateICmpUGT(ThreadID, MasterID, "excess_in_master_warp");
  Bld.CreateCondBr(ShouldDie, EST.ExitBB, WorkerCheckBB);

  // Select worker threads...
  CGF.EmitBlock(WorkerCheckBB);
  llvm::Value *IsWorker = Bld.CreateICmpULT(ThreadID, MasterID, "is_worker");
  Bld.CreateCondBr(IsWorker, WorkerBB, MasterBB);

  // ... and send to worker loop, awaiting parallel invocation.
  CGF.EmitBlock(WorkerBB);
  CGF.EmitCallOrInvoke(WST.WorkerFn, llvm::None);
  CGF.EmitBranch(EST.ExitBB);

  // Only master thread executes subsequent serial code.
  CGF.EmitBlock(MasterBB);

  // First action in sequential region:
  // Initialize the state of the OpenMP runtime library on the GPU.
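  // Note: in the master region getNVPTXThreadID() evaluates to the master
  // thread id, i.e. the number of worker threads in the block, which is
  // passed as the thread_limit argument.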
  llvm::Value *Args[] = {Bld.getInt32(/*OmpHandle=*/0), getNVPTXThreadID(CGF)};
  CGF.EmitRuntimeCall(
      createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_init), Args);
}

void CGOpenMPRuntimeNVPTX::emitEntryFooter(CodeGenFunction &CGF,
                                           EntryFunctionState &EST) {
  CGBuilderTy &Bld = CGF.Builder;
  llvm::BasicBlock *TerminateBB = CGF.createBasicBlock(".termination.notifier");
  CGF.EmitBranch(TerminateBB);

  CGF.EmitBlock(TerminateBB);
  // Signal termination condition.
  Bld.CreateAlignedStore(
      llvm::Constant::getNullValue(WorkID->getType()->getElementType()), WorkID,
      WorkID->getAlignment());
  // Barrier to terminate worker threads.
  syncCTAThreads(CGF);
  // Master thread jumps to exit point.
  CGF.EmitBranch(EST.ExitBB);

  CGF.EmitBlock(EST.ExitBB);
}

/// \brief Returns specified OpenMP runtime function for the current OpenMP
/// implementation. Specialized for the NVPTX device.
/// \param Function OpenMP runtime function.
/// \return Specified function.
llvm::Constant *
CGOpenMPRuntimeNVPTX::createNVPTXRuntimeFunction(unsigned Function) {
  llvm::Constant *RTLFn = nullptr;
  switch (static_cast<OpenMPRTLFunctionNVPTX>(Function)) {
  case OMPRTL_NVPTX__kmpc_kernel_init: {
    // Build void __kmpc_kernel_init(kmp_int32 omp_handle,
    // kmp_int32 thread_limit);
    llvm::Type *TypeParams[] = {CGM.Int32Ty, CGM.Int32Ty};
    llvm::FunctionType *FnTy =
        llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
    RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_kernel_init");
    break;
  }
  }
  return RTLFn;
}

void CGOpenMPRuntimeNVPTX::createOffloadEntry(llvm::Constant *ID,
                                              llvm::Constant *Addr,
                                              uint64_t Size) {
  auto *F = dyn_cast<llvm::Function>(Addr);
  // TODO: Add support for global variables on the device after declare target
  // support.
  if (!F)
    return;
  llvm::Module *M = F->getParent();
  llvm::LLVMContext &Ctx = M->getContext();

  // Get "nvvm.annotations" metadata node
  llvm::NamedMDNode *MD = M->getOrInsertNamedMetadata("nvvm.annotations");

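  // Mark the function as a kernel ("kernel", 1) so that the NVPTX backend
  // emits it as an entry function.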
  llvm::Metadata *MDVals[] = {
      llvm::ConstantAsMetadata::get(F), llvm::MDString::get(Ctx, "kernel"),
      llvm::ConstantAsMetadata::get(
          llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
  // Append metadata to nvvm.annotations
  MD->addOperand(llvm::MDNode::get(Ctx, MDVals));
}

void CGOpenMPRuntimeNVPTX::emitTargetOutlinedFunction(
    const OMPExecutableDirective &D, StringRef ParentName,
    llvm::Function *&OutlinedFn, llvm::Constant *&OutlinedFnID,
    bool IsOffloadEntry, const RegionCodeGenTy &CodeGen) {
  if (!IsOffloadEntry) // Nothing to do.
    return;

  assert(!ParentName.empty() && "Invalid target region parent name!");

  EntryFunctionState EST;
  WorkerFunctionState WST(CGM);

  // Emit target region as a standalone region.
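  // The pre- and post-actions below bracket the target region body with the
  // entry header and footer that set up and tear down the master-worker
  // scheme.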
  class NVPTXPrePostActionTy : public PrePostActionTy {
    CGOpenMPRuntimeNVPTX &RT;
    CGOpenMPRuntimeNVPTX::EntryFunctionState &EST;
    CGOpenMPRuntimeNVPTX::WorkerFunctionState &WST;

  public:
    NVPTXPrePostActionTy(CGOpenMPRuntimeNVPTX &RT,
                         CGOpenMPRuntimeNVPTX::EntryFunctionState &EST,
                         CGOpenMPRuntimeNVPTX::WorkerFunctionState &WST)
        : RT(RT), EST(EST), WST(WST) {}
    void Enter(CodeGenFunction &CGF) override {
      RT.emitEntryHeader(CGF, EST, WST);
    }
    void Exit(CodeGenFunction &CGF) override { RT.emitEntryFooter(CGF, EST); }
  } Action(*this, EST, WST);
  CodeGen.setAction(Action);
  emitTargetOutlinedFunctionHelper(D, ParentName, OutlinedFn, OutlinedFnID,
                                   IsOffloadEntry, CodeGen);

  // Create the worker function.
  emitWorkerFunction(WST);

  // Now change the name of the worker function to correspond to this target
  // region's entry function.
  WST.WorkerFn->setName(OutlinedFn->getName() + "_worker");
}

CGOpenMPRuntimeNVPTX::CGOpenMPRuntimeNVPTX(CodeGenModule &CGM)
    : CGOpenMPRuntime(CGM), ActiveWorkers(nullptr), WorkID(nullptr) {
  if (!CGM.getLangOpts().OpenMPIsDevice)
    llvm_unreachable("OpenMP NVPTX can only handle device code.");

  // Called once per module during initialization.
  initializeEnvironment();
}

void CGOpenMPRuntimeNVPTX::emitNumTeamsClause(CodeGenFunction &CGF,
                                              const Expr *NumTeams,
                                              const Expr *ThreadLimit,
                                              SourceLocation Loc) {}

llvm::Value *CGOpenMPRuntimeNVPTX::emitParallelOrTeamsOutlinedFunction(
    const OMPExecutableDirective &D, const VarDecl *ThreadIDVar,
    OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen) {

  llvm::Function *OutlinedFun = nullptr;
  if (isa<OMPTeamsDirective>(D)) {
    llvm::Value *OutlinedFunVal =
        CGOpenMPRuntime::emitParallelOrTeamsOutlinedFunction(
            D, ThreadIDVar, InnermostKind, CodeGen);
    OutlinedFun = cast<llvm::Function>(OutlinedFunVal);
    OutlinedFun->addFnAttr(llvm::Attribute::AlwaysInline);
  } else
    llvm_unreachable("parallel directive is not yet supported for nvptx "
                     "backend.");

  return OutlinedFun;
}

void CGOpenMPRuntimeNVPTX::emitTeamsCall(CodeGenFunction &CGF,
                                         const OMPExecutableDirective &D,
                                         SourceLocation Loc,
                                         llvm::Value *OutlinedFn,
                                         ArrayRef<llvm::Value *> CapturedVars) {
  if (!CGF.HaveInsertPoint())
    return;

  Address ZeroAddr =
      CGF.CreateTempAlloca(CGF.Int32Ty, CharUnits::fromQuantity(4),
                           /*Name*/ ".zero.addr");
  CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C*/ 0));
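  // The outlined teams function expects pointers to the global and bound
  // thread ids as its first two arguments; pass the same zeroed temporary for
  // both.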
  llvm::SmallVector<llvm::Value *, 16> OutlinedFnArgs;
  OutlinedFnArgs.push_back(ZeroAddr.getPointer());
  OutlinedFnArgs.push_back(ZeroAddr.getPointer());
  OutlinedFnArgs.append(CapturedVars.begin(), CapturedVars.end());
  CGF.EmitCallOrInvoke(OutlinedFn, OutlinedFnArgs);
}