1 //===--- SemaCoroutines.cpp - Semantic Analysis for Coroutines ------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements semantic analysis for C++ Coroutines.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "clang/Sema/SemaInternal.h"
15 #include "clang/AST/Decl.h"
16 #include "clang/AST/ExprCXX.h"
17 #include "clang/AST/StmtCXX.h"
18 #include "clang/Lex/Preprocessor.h"
19 #include "clang/Sema/Initialization.h"
20 #include "clang/Sema/Overload.h"
21 using namespace clang;
22 using namespace sema;
23
24 /// Look up the std::coroutine_traits<...>::promise_type for the given
25 /// function type.
lookupPromiseType(Sema & S,const FunctionProtoType * FnType,SourceLocation Loc)26 static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType,
27 SourceLocation Loc) {
28 // FIXME: Cache std::coroutine_traits once we've found it.
29 NamespaceDecl *Std = S.getStdNamespace();
30 if (!Std) {
31 S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
32 return QualType();
33 }
34
35 LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"),
36 Loc, Sema::LookupOrdinaryName);
37 if (!S.LookupQualifiedName(Result, Std)) {
38 S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
39 return QualType();
40 }
41
42 ClassTemplateDecl *CoroTraits = Result.getAsSingle<ClassTemplateDecl>();
43 if (!CoroTraits) {
44 Result.suppressDiagnostics();
45 // We found something weird. Complain about the first thing we found.
46 NamedDecl *Found = *Result.begin();
47 S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits);
48 return QualType();
49 }
50
51 // Form template argument list for coroutine_traits<R, P1, P2, ...>.
52 TemplateArgumentListInfo Args(Loc, Loc);
53 Args.addArgument(TemplateArgumentLoc(
54 TemplateArgument(FnType->getReturnType()),
55 S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc)));
56 // FIXME: If the function is a non-static member function, add the type
57 // of the implicit object parameter before the formal parameters.
58 for (QualType T : FnType->getParamTypes())
59 Args.addArgument(TemplateArgumentLoc(
60 TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc)));
61
62 // Build the template-id.
63 QualType CoroTrait =
64 S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args);
65 if (CoroTrait.isNull())
66 return QualType();
67 if (S.RequireCompleteType(Loc, CoroTrait,
68 diag::err_coroutine_traits_missing_specialization))
69 return QualType();
70
71 CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl();
72 assert(RD && "specialization of class template is not a class?");
73
74 // Look up the ::promise_type member.
75 LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc,
76 Sema::LookupOrdinaryName);
77 S.LookupQualifiedName(R, RD);
78 auto *Promise = R.getAsSingle<TypeDecl>();
79 if (!Promise) {
80 S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found)
81 << RD;
82 return QualType();
83 }
84
85 // The promise type is required to be a class type.
86 QualType PromiseType = S.Context.getTypeDeclType(Promise);
87 if (!PromiseType->getAsCXXRecordDecl()) {
88 // Use the fully-qualified name of the type.
89 auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, Std);
90 NNS = NestedNameSpecifier::Create(S.Context, NNS, false,
91 CoroTrait.getTypePtr());
92 PromiseType = S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
93
94 S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class)
95 << PromiseType;
96 return QualType();
97 }
98
99 return PromiseType;
100 }
101
102 /// Check that this is a context in which a coroutine suspension can appear.
103 static FunctionScopeInfo *
checkCoroutineContext(Sema & S,SourceLocation Loc,StringRef Keyword)104 checkCoroutineContext(Sema &S, SourceLocation Loc, StringRef Keyword) {
105 // 'co_await' and 'co_yield' are not permitted in unevaluated operands.
106 if (S.isUnevaluatedContext()) {
107 S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword;
108 return nullptr;
109 }
110
111 // Any other usage must be within a function.
112 // FIXME: Reject a coroutine with a deduced return type.
113 auto *FD = dyn_cast<FunctionDecl>(S.CurContext);
114 if (!FD) {
115 S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext)
116 ? diag::err_coroutine_objc_method
117 : diag::err_coroutine_outside_function) << Keyword;
118 } else if (isa<CXXConstructorDecl>(FD) || isa<CXXDestructorDecl>(FD)) {
119 // Coroutines TS [special]/6:
120 // A special member function shall not be a coroutine.
121 //
122 // FIXME: We assume that this really means that a coroutine cannot
123 // be a constructor or destructor.
124 S.Diag(Loc, diag::err_coroutine_ctor_dtor)
125 << isa<CXXDestructorDecl>(FD) << Keyword;
126 } else if (FD->isConstexpr()) {
127 S.Diag(Loc, diag::err_coroutine_constexpr) << Keyword;
128 } else if (FD->isVariadic()) {
129 S.Diag(Loc, diag::err_coroutine_varargs) << Keyword;
130 } else {
131 auto *ScopeInfo = S.getCurFunction();
132 assert(ScopeInfo && "missing function scope for function");
133
134 // If we don't have a promise variable, build one now.
135 if (!ScopeInfo->CoroutinePromise) {
136 QualType T =
137 FD->getType()->isDependentType()
138 ? S.Context.DependentTy
139 : lookupPromiseType(S, FD->getType()->castAs<FunctionProtoType>(),
140 Loc);
141 if (T.isNull())
142 return nullptr;
143
144 // Create and default-initialize the promise.
145 ScopeInfo->CoroutinePromise =
146 VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(),
147 &S.PP.getIdentifierTable().get("__promise"), T,
148 S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
149 S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise);
150 if (!ScopeInfo->CoroutinePromise->isInvalidDecl())
151 S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise, false);
152 }
153
154 return ScopeInfo;
155 }
156
157 return nullptr;
158 }
159
160 /// Build a call to 'operator co_await' if there is a suitable operator for
161 /// the given expression.
buildOperatorCoawaitCall(Sema & SemaRef,Scope * S,SourceLocation Loc,Expr * E)162 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
163 SourceLocation Loc, Expr *E) {
164 UnresolvedSet<16> Functions;
165 SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(),
166 Functions);
167 return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
168 }
169
170 struct ReadySuspendResumeResult {
171 bool IsInvalid;
172 Expr *Results[3];
173 };
174
buildMemberCall(Sema & S,Expr * Base,SourceLocation Loc,StringRef Name,MutableArrayRef<Expr * > Args)175 static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
176 StringRef Name,
177 MutableArrayRef<Expr *> Args) {
178 DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc);
179
180 // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
181 CXXScopeSpec SS;
182 ExprResult Result = S.BuildMemberReferenceExpr(
183 Base, Base->getType(), Loc, /*IsPtr=*/false, SS,
184 SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr,
185 /*Scope=*/nullptr);
186 if (Result.isInvalid())
187 return ExprError();
188
189 return S.ActOnCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr);
190 }
191
192 /// Build calls to await_ready, await_suspend, and await_resume for a co_await
193 /// expression.
buildCoawaitCalls(Sema & S,SourceLocation Loc,Expr * E)194 static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, SourceLocation Loc,
195 Expr *E) {
196 // Assume invalid until we see otherwise.
197 ReadySuspendResumeResult Calls = {true, {}};
198
199 const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"};
200 for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) {
201 Expr *Operand = new (S.Context) OpaqueValueExpr(
202 Loc, E->getType(), VK_LValue, E->getObjectKind(), E);
203
204 // FIXME: Pass coroutine handle to await_suspend.
205 ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], None);
206 if (Result.isInvalid())
207 return Calls;
208 Calls.Results[I] = Result.get();
209 }
210
211 Calls.IsInvalid = false;
212 return Calls;
213 }
214
ActOnCoawaitExpr(Scope * S,SourceLocation Loc,Expr * E)215 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
216 if (E->getType()->isPlaceholderType()) {
217 ExprResult R = CheckPlaceholderExpr(E);
218 if (R.isInvalid()) return ExprError();
219 E = R.get();
220 }
221
222 ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E);
223 if (Awaitable.isInvalid())
224 return ExprError();
225 return BuildCoawaitExpr(Loc, Awaitable.get());
226 }
BuildCoawaitExpr(SourceLocation Loc,Expr * E)227 ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) {
228 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
229 if (!Coroutine)
230 return ExprError();
231
232 if (E->getType()->isPlaceholderType()) {
233 ExprResult R = CheckPlaceholderExpr(E);
234 if (R.isInvalid()) return ExprError();
235 E = R.get();
236 }
237
238 if (E->getType()->isDependentType()) {
239 Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E);
240 Coroutine->CoroutineStmts.push_back(Res);
241 return Res;
242 }
243
244 // If the expression is a temporary, materialize it as an lvalue so that we
245 // can use it multiple times.
246 if (E->getValueKind() == VK_RValue)
247 E = CreateMaterializeTemporaryExpr(E->getType(), E, true);
248
249 // Build the await_ready, await_suspend, await_resume calls.
250 ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E);
251 if (RSS.IsInvalid)
252 return ExprError();
253
254 Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1],
255 RSS.Results[2]);
256 Coroutine->CoroutineStmts.push_back(Res);
257 return Res;
258 }
259
buildPromiseCall(Sema & S,FunctionScopeInfo * Coroutine,SourceLocation Loc,StringRef Name,MutableArrayRef<Expr * > Args)260 static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine,
261 SourceLocation Loc, StringRef Name,
262 MutableArrayRef<Expr *> Args) {
263 assert(Coroutine->CoroutinePromise && "no promise for coroutine");
264
265 // Form a reference to the promise.
266 auto *Promise = Coroutine->CoroutinePromise;
267 ExprResult PromiseRef = S.BuildDeclRefExpr(
268 Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
269 if (PromiseRef.isInvalid())
270 return ExprError();
271
272 // Call 'yield_value', passing in E.
273 return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
274 }
275
ActOnCoyieldExpr(Scope * S,SourceLocation Loc,Expr * E)276 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
277 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
278 if (!Coroutine)
279 return ExprError();
280
281 // Build yield_value call.
282 ExprResult Awaitable =
283 buildPromiseCall(*this, Coroutine, Loc, "yield_value", E);
284 if (Awaitable.isInvalid())
285 return ExprError();
286
287 // Build 'operator co_await' call.
288 Awaitable = buildOperatorCoawaitCall(*this, S, Loc, Awaitable.get());
289 if (Awaitable.isInvalid())
290 return ExprError();
291
292 return BuildCoyieldExpr(Loc, Awaitable.get());
293 }
BuildCoyieldExpr(SourceLocation Loc,Expr * E)294 ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
295 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
296 if (!Coroutine)
297 return ExprError();
298
299 if (E->getType()->isPlaceholderType()) {
300 ExprResult R = CheckPlaceholderExpr(E);
301 if (R.isInvalid()) return ExprError();
302 E = R.get();
303 }
304
305 if (E->getType()->isDependentType()) {
306 Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E);
307 Coroutine->CoroutineStmts.push_back(Res);
308 return Res;
309 }
310
311 // If the expression is a temporary, materialize it as an lvalue so that we
312 // can use it multiple times.
313 if (E->getValueKind() == VK_RValue)
314 E = CreateMaterializeTemporaryExpr(E->getType(), E, true);
315
316 // Build the await_ready, await_suspend, await_resume calls.
317 ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E);
318 if (RSS.IsInvalid)
319 return ExprError();
320
321 Expr *Res = new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1],
322 RSS.Results[2]);
323 Coroutine->CoroutineStmts.push_back(Res);
324 return Res;
325 }
326
ActOnCoreturnStmt(SourceLocation Loc,Expr * E)327 StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) {
328 return BuildCoreturnStmt(Loc, E);
329 }
BuildCoreturnStmt(SourceLocation Loc,Expr * E)330 StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) {
331 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return");
332 if (!Coroutine)
333 return StmtError();
334
335 if (E && E->getType()->isPlaceholderType() &&
336 !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) {
337 ExprResult R = CheckPlaceholderExpr(E);
338 if (R.isInvalid()) return StmtError();
339 E = R.get();
340 }
341
342 // FIXME: If the operand is a reference to a variable that's about to go out
343 // of scope, we should treat the operand as an xvalue for this overload
344 // resolution.
345 ExprResult PC;
346 if (E && !E->getType()->isVoidType()) {
347 PC = buildPromiseCall(*this, Coroutine, Loc, "return_value", E);
348 } else {
349 E = MakeFullDiscardedValueExpr(E).get();
350 PC = buildPromiseCall(*this, Coroutine, Loc, "return_void", None);
351 }
352 if (PC.isInvalid())
353 return StmtError();
354
355 Expr *PCE = ActOnFinishFullExpr(PC.get()).get();
356
357 Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE);
358 Coroutine->CoroutineStmts.push_back(Res);
359 return Res;
360 }
361
CheckCompletedCoroutineBody(FunctionDecl * FD,Stmt * & Body)362 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
363 FunctionScopeInfo *Fn = getCurFunction();
364 assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine");
365
366 // Coroutines [stmt.return]p1:
367 // A return statement shall not appear in a coroutine.
368 if (Fn->FirstReturnLoc.isValid()) {
369 Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
370 auto *First = Fn->CoroutineStmts[0];
371 Diag(First->getLocStart(), diag::note_declared_coroutine_here)
372 << (isa<CoawaitExpr>(First) ? 0 :
373 isa<CoyieldExpr>(First) ? 1 : 2);
374 }
375
376 bool AnyCoawaits = false;
377 bool AnyCoyields = false;
378 for (auto *CoroutineStmt : Fn->CoroutineStmts) {
379 AnyCoawaits |= isa<CoawaitExpr>(CoroutineStmt);
380 AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt);
381 }
382
383 if (!AnyCoawaits && !AnyCoyields)
384 Diag(Fn->CoroutineStmts.front()->getLocStart(),
385 diag::ext_coroutine_without_co_await_co_yield);
386
387 SourceLocation Loc = FD->getLocation();
388
389 // Form a declaration statement for the promise declaration, so that AST
390 // visitors can more easily find it.
391 StmtResult PromiseStmt =
392 ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc);
393 if (PromiseStmt.isInvalid())
394 return FD->setInvalidDecl();
395
396 // Form and check implicit 'co_await p.initial_suspend();' statement.
397 ExprResult InitialSuspend =
398 buildPromiseCall(*this, Fn, Loc, "initial_suspend", None);
399 // FIXME: Support operator co_await here.
400 if (!InitialSuspend.isInvalid())
401 InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get());
402 InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get());
403 if (InitialSuspend.isInvalid())
404 return FD->setInvalidDecl();
405
406 // Form and check implicit 'co_await p.final_suspend();' statement.
407 ExprResult FinalSuspend =
408 buildPromiseCall(*this, Fn, Loc, "final_suspend", None);
409 // FIXME: Support operator co_await here.
410 if (!FinalSuspend.isInvalid())
411 FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get());
412 FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get());
413 if (FinalSuspend.isInvalid())
414 return FD->setInvalidDecl();
415
416 // FIXME: Perform analysis of set_exception call.
417
418 // FIXME: Try to form 'p.return_void();' expression statement to handle
419 // control flowing off the end of the coroutine.
420
421 // Build implicit 'p.get_return_object()' expression and form initialization
422 // of return type from it.
423 ExprResult ReturnObject =
424 buildPromiseCall(*this, Fn, Loc, "get_return_object", None);
425 if (ReturnObject.isInvalid())
426 return FD->setInvalidDecl();
427 QualType RetType = FD->getReturnType();
428 if (!RetType->isDependentType()) {
429 InitializedEntity Entity =
430 InitializedEntity::InitializeResult(Loc, RetType, false);
431 ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
432 ReturnObject.get());
433 if (ReturnObject.isInvalid())
434 return FD->setInvalidDecl();
435 }
436 ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc);
437 if (ReturnObject.isInvalid())
438 return FD->setInvalidDecl();
439
440 // FIXME: Perform move-initialization of parameters into frame-local copies.
441 SmallVector<Expr*, 16> ParamMoves;
442
443 // Build body for the coroutine wrapper statement.
444 Body = new (Context) CoroutineBodyStmt(
445 Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(),
446 /*SetException*/nullptr, /*Fallthrough*/nullptr,
447 ReturnObject.get(), ParamMoves);
448 }
449