1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/resolver/resolver.h"
16
17 #include <algorithm>
18 #include <cmath>
19 #include <iomanip>
20 #include <limits>
21 #include <utility>
22
23 #include "src/ast/alias.h"
24 #include "src/ast/array.h"
25 #include "src/ast/assignment_statement.h"
26 #include "src/ast/bitcast_expression.h"
27 #include "src/ast/break_statement.h"
28 #include "src/ast/call_statement.h"
29 #include "src/ast/continue_statement.h"
30 #include "src/ast/depth_texture.h"
31 #include "src/ast/disable_validation_decoration.h"
32 #include "src/ast/discard_statement.h"
33 #include "src/ast/fallthrough_statement.h"
34 #include "src/ast/for_loop_statement.h"
35 #include "src/ast/if_statement.h"
36 #include "src/ast/internal_decoration.h"
37 #include "src/ast/interpolate_decoration.h"
38 #include "src/ast/loop_statement.h"
39 #include "src/ast/matrix.h"
40 #include "src/ast/override_decoration.h"
41 #include "src/ast/pointer.h"
42 #include "src/ast/return_statement.h"
43 #include "src/ast/sampled_texture.h"
44 #include "src/ast/sampler.h"
45 #include "src/ast/storage_texture.h"
46 #include "src/ast/struct_block_decoration.h"
47 #include "src/ast/switch_statement.h"
48 #include "src/ast/traverse_expressions.h"
49 #include "src/ast/type_name.h"
50 #include "src/ast/unary_op_expression.h"
51 #include "src/ast/variable_decl_statement.h"
52 #include "src/ast/vector.h"
53 #include "src/ast/workgroup_decoration.h"
54 #include "src/sem/array.h"
55 #include "src/sem/atomic_type.h"
56 #include "src/sem/call.h"
57 #include "src/sem/depth_multisampled_texture_type.h"
58 #include "src/sem/depth_texture_type.h"
59 #include "src/sem/for_loop_statement.h"
60 #include "src/sem/function.h"
61 #include "src/sem/if_statement.h"
62 #include "src/sem/loop_statement.h"
63 #include "src/sem/member_accessor_expression.h"
64 #include "src/sem/multisampled_texture_type.h"
65 #include "src/sem/pointer_type.h"
66 #include "src/sem/reference_type.h"
67 #include "src/sem/sampled_texture_type.h"
68 #include "src/sem/sampler_type.h"
69 #include "src/sem/statement.h"
70 #include "src/sem/storage_texture_type.h"
71 #include "src/sem/struct.h"
72 #include "src/sem/switch_statement.h"
73 #include "src/sem/type_constructor.h"
74 #include "src/sem/type_conversion.h"
75 #include "src/sem/variable.h"
76 #include "src/utils/defer.h"
77 #include "src/utils/math.h"
78 #include "src/utils/reverse.h"
79 #include "src/utils/scoped_assignment.h"
80 #include "src/utils/transform.h"
81
82 namespace tint {
83 namespace resolver {
84
Resolver(ProgramBuilder * builder)85 Resolver::Resolver(ProgramBuilder* builder)
86 : builder_(builder),
87 diagnostics_(builder->Diagnostics()),
88 intrinsic_table_(IntrinsicTable::Create(*builder)) {}
89
90 Resolver::~Resolver() = default;
91
Resolve()92 bool Resolver::Resolve() {
93 if (builder_->Diagnostics().contains_errors()) {
94 return false;
95 }
96
97 if (!DependencyGraph::Build(builder_->AST(), builder_->Symbols(),
98 builder_->Diagnostics(), dependencies_,
99 /* allow_out_of_order_decls*/ false)) {
100 return false;
101 }
102
103 bool result = ResolveInternal();
104
105 if (!result && !diagnostics_.contains_errors()) {
106 TINT_ICE(Resolver, diagnostics_)
107 << "resolving failed, but no error was raised";
108 return false;
109 }
110
111 return result;
112 }
113
ResolveInternal()114 bool Resolver::ResolveInternal() {
115 Mark(&builder_->AST());
116
117 // Process everything else in the order they appear in the module. This is
118 // necessary for validation of use-before-declaration.
119 for (auto* decl : builder_->AST().GlobalDeclarations()) {
120 if (auto* td = decl->As<ast::TypeDecl>()) {
121 Mark(td);
122 if (!TypeDecl(td)) {
123 return false;
124 }
125 } else if (auto* func = decl->As<ast::Function>()) {
126 Mark(func);
127 if (!Function(func)) {
128 return false;
129 }
130 } else if (auto* var = decl->As<ast::Variable>()) {
131 Mark(var);
132 if (!GlobalVariable(var)) {
133 return false;
134 }
135 } else {
136 TINT_UNREACHABLE(Resolver, diagnostics_)
137 << "unhandled global declaration: " << decl->TypeInfo().name;
138 return false;
139 }
140 }
141
142 AllocateOverridableConstantIds();
143
144 SetShadows();
145
146 if (!ValidatePipelineStages()) {
147 return false;
148 }
149
150 bool result = true;
151 for (auto* node : builder_->ASTNodes().Objects()) {
152 if (marked_.count(node) == 0) {
153 TINT_ICE(Resolver, diagnostics_) << "AST node '" << node->TypeInfo().name
154 << "' was not reached by the resolver\n"
155 << "At: " << node->source << "\n"
156 << "Pointer: " << node;
157 result = false;
158 }
159 }
160
161 return result;
162 }
163
Type(const ast::Type * ty)164 sem::Type* Resolver::Type(const ast::Type* ty) {
165 Mark(ty);
166 auto* s = [&]() -> sem::Type* {
167 if (ty->Is<ast::Void>()) {
168 return builder_->create<sem::Void>();
169 }
170 if (ty->Is<ast::Bool>()) {
171 return builder_->create<sem::Bool>();
172 }
173 if (ty->Is<ast::I32>()) {
174 return builder_->create<sem::I32>();
175 }
176 if (ty->Is<ast::U32>()) {
177 return builder_->create<sem::U32>();
178 }
179 if (ty->Is<ast::F32>()) {
180 return builder_->create<sem::F32>();
181 }
182 if (auto* t = ty->As<ast::Vector>()) {
183 if (auto* el = Type(t->type)) {
184 if (auto* vector = builder_->create<sem::Vector>(el, t->width)) {
185 if (ValidateVector(vector, t->source)) {
186 return vector;
187 }
188 }
189 }
190 return nullptr;
191 }
192 if (auto* t = ty->As<ast::Matrix>()) {
193 if (auto* el = Type(t->type)) {
194 if (auto* column_type = builder_->create<sem::Vector>(el, t->rows)) {
195 if (auto* matrix =
196 builder_->create<sem::Matrix>(column_type, t->columns)) {
197 if (ValidateMatrix(matrix, t->source)) {
198 return matrix;
199 }
200 }
201 }
202 }
203 return nullptr;
204 }
205 if (auto* t = ty->As<ast::Array>()) {
206 return Array(t);
207 }
208 if (auto* t = ty->As<ast::Atomic>()) {
209 if (auto* el = Type(t->type)) {
210 auto* a = builder_->create<sem::Atomic>(el);
211 if (!ValidateAtomic(t, a)) {
212 return nullptr;
213 }
214 return a;
215 }
216 return nullptr;
217 }
218 if (auto* t = ty->As<ast::Pointer>()) {
219 if (auto* el = Type(t->type)) {
220 auto access = t->access;
221 if (access == ast::kUndefined) {
222 access = DefaultAccessForStorageClass(t->storage_class);
223 }
224 return builder_->create<sem::Pointer>(el, t->storage_class, access);
225 }
226 return nullptr;
227 }
228 if (auto* t = ty->As<ast::Sampler>()) {
229 return builder_->create<sem::Sampler>(t->kind);
230 }
231 if (auto* t = ty->As<ast::SampledTexture>()) {
232 if (auto* el = Type(t->type)) {
233 return builder_->create<sem::SampledTexture>(t->dim, el);
234 }
235 return nullptr;
236 }
237 if (auto* t = ty->As<ast::MultisampledTexture>()) {
238 if (auto* el = Type(t->type)) {
239 return builder_->create<sem::MultisampledTexture>(t->dim, el);
240 }
241 return nullptr;
242 }
243 if (auto* t = ty->As<ast::DepthTexture>()) {
244 return builder_->create<sem::DepthTexture>(t->dim);
245 }
246 if (auto* t = ty->As<ast::DepthMultisampledTexture>()) {
247 return builder_->create<sem::DepthMultisampledTexture>(t->dim);
248 }
249 if (auto* t = ty->As<ast::StorageTexture>()) {
250 if (auto* el = Type(t->type)) {
251 if (!ValidateStorageTexture(t)) {
252 return nullptr;
253 }
254 return builder_->create<sem::StorageTexture>(t->dim, t->format,
255 t->access, el);
256 }
257 return nullptr;
258 }
259 if (ty->As<ast::ExternalTexture>()) {
260 return builder_->create<sem::ExternalTexture>();
261 }
262 if (auto* type = ResolvedSymbol<sem::Type>(ty)) {
263 return type;
264 }
265 TINT_UNREACHABLE(Resolver, diagnostics_)
266 << "Unhandled ast::Type: " << ty->TypeInfo().name;
267 return nullptr;
268 }();
269
270 if (s) {
271 builder_->Sem().Add(ty, s);
272 }
273 return s;
274 }
275
Variable(const ast::Variable * var,VariableKind kind,uint32_t index)276 sem::Variable* Resolver::Variable(const ast::Variable* var,
277 VariableKind kind,
278 uint32_t index /* = 0 */) {
279 const sem::Type* storage_ty = nullptr;
280
281 // If the variable has a declared type, resolve it.
282 if (auto* ty = var->type) {
283 storage_ty = Type(ty);
284 if (!storage_ty) {
285 return nullptr;
286 }
287 }
288
289 const sem::Expression* rhs = nullptr;
290
291 // Does the variable have a constructor?
292 if (var->constructor) {
293 rhs = Expression(var->constructor);
294 if (!rhs) {
295 return nullptr;
296 }
297
298 // If the variable has no declared type, infer it from the RHS
299 if (!storage_ty) {
300 if (!var->is_const && kind == VariableKind::kGlobal) {
301 AddError("global var declaration must specify a type", var->source);
302 return nullptr;
303 }
304
305 storage_ty = rhs->Type()->UnwrapRef(); // Implicit load of RHS
306 }
307 } else if (var->is_const && kind != VariableKind::kParameter &&
308 !ast::HasDecoration<ast::OverrideDecoration>(var->decorations)) {
309 AddError("let declaration must have an initializer", var->source);
310 return nullptr;
311 } else if (!var->type) {
312 AddError(
313 (kind == VariableKind::kGlobal)
314 ? "module scope var declaration requires a type and initializer"
315 : "function scope var declaration requires a type or initializer",
316 var->source);
317 return nullptr;
318 }
319
320 if (!storage_ty) {
321 TINT_ICE(Resolver, diagnostics_)
322 << "failed to determine storage type for variable '" +
323 builder_->Symbols().NameFor(var->symbol) + "'\n"
324 << "Source: " << var->source;
325 return nullptr;
326 }
327
328 auto storage_class = var->declared_storage_class;
329 if (storage_class == ast::StorageClass::kNone && !var->is_const) {
330 // No declared storage class. Infer from usage / type.
331 if (kind == VariableKind::kLocal) {
332 storage_class = ast::StorageClass::kFunction;
333 } else if (storage_ty->UnwrapRef()->is_handle()) {
334 // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
335 // If the store type is a texture type or a sampler type, then the
336 // variable declaration must not have a storage class decoration. The
337 // storage class will always be handle.
338 storage_class = ast::StorageClass::kUniformConstant;
339 }
340 }
341
342 if (kind == VariableKind::kLocal && !var->is_const &&
343 storage_class != ast::StorageClass::kFunction &&
344 IsValidationEnabled(var->decorations,
345 ast::DisabledValidation::kIgnoreStorageClass)) {
346 AddError("function variable has a non-function storage class", var->source);
347 return nullptr;
348 }
349
350 auto access = var->declared_access;
351 if (access == ast::Access::kUndefined) {
352 access = DefaultAccessForStorageClass(storage_class);
353 }
354
355 auto* var_ty = storage_ty;
356 if (!var->is_const) {
357 // Variable declaration. Unlike `let`, `var` has storage.
358 // Variables are always of a reference type to the declared storage type.
359 var_ty =
360 builder_->create<sem::Reference>(storage_ty, storage_class, access);
361 }
362
363 if (rhs && !ValidateVariableConstructorOrCast(var, storage_class, storage_ty,
364 rhs->Type())) {
365 return nullptr;
366 }
367
368 if (!ApplyStorageClassUsageToType(
369 storage_class, const_cast<sem::Type*>(var_ty), var->source)) {
370 AddNote(
371 std::string("while instantiating ") +
372 ((kind == VariableKind::kParameter) ? "parameter " : "variable ") +
373 builder_->Symbols().NameFor(var->symbol),
374 var->source);
375 return nullptr;
376 }
377
378 if (kind == VariableKind::kParameter) {
379 if (auto* ptr = var_ty->As<sem::Pointer>()) {
380 // For MSL, we push module-scope variables into the entry point as pointer
381 // parameters, so we also need to handle their store type.
382 if (!ApplyStorageClassUsageToType(
383 ptr->StorageClass(), const_cast<sem::Type*>(ptr->StoreType()),
384 var->source)) {
385 AddNote("while instantiating parameter " +
386 builder_->Symbols().NameFor(var->symbol),
387 var->source);
388 return nullptr;
389 }
390 }
391 }
392
393 switch (kind) {
394 case VariableKind::kGlobal: {
395 sem::BindingPoint binding_point;
396 if (auto bp = var->BindingPoint()) {
397 binding_point = {bp.group->value, bp.binding->value};
398 }
399
400 auto* override =
401 ast::GetDecoration<ast::OverrideDecoration>(var->decorations);
402 bool has_const_val = rhs && var->is_const && !override;
403
404 auto* global = builder_->create<sem::GlobalVariable>(
405 var, var_ty, storage_class, access,
406 has_const_val ? rhs->ConstantValue() : sem::Constant{},
407 binding_point);
408
409 if (override) {
410 global->SetIsOverridable();
411 if (override->has_value) {
412 global->SetConstantId(static_cast<uint16_t>(override->value));
413 }
414 }
415
416 global->SetConstructor(rhs);
417
418 builder_->Sem().Add(var, global);
419 return global;
420 }
421 case VariableKind::kLocal: {
422 auto* local = builder_->create<sem::LocalVariable>(
423 var, var_ty, storage_class, access, current_statement_,
424 (rhs && var->is_const) ? rhs->ConstantValue() : sem::Constant{});
425 builder_->Sem().Add(var, local);
426 local->SetConstructor(rhs);
427 return local;
428 }
429 case VariableKind::kParameter: {
430 auto* param = builder_->create<sem::Parameter>(var, index, var_ty,
431 storage_class, access);
432 builder_->Sem().Add(var, param);
433 return param;
434 }
435 }
436
437 TINT_UNREACHABLE(Resolver, diagnostics_)
438 << "unhandled VariableKind " << static_cast<int>(kind);
439 return nullptr;
440 }
441
DefaultAccessForStorageClass(ast::StorageClass storage_class)442 ast::Access Resolver::DefaultAccessForStorageClass(
443 ast::StorageClass storage_class) {
444 // https://gpuweb.github.io/gpuweb/wgsl/#storage-class
445 switch (storage_class) {
446 case ast::StorageClass::kStorage:
447 case ast::StorageClass::kUniform:
448 case ast::StorageClass::kUniformConstant:
449 return ast::Access::kRead;
450 default:
451 break;
452 }
453 return ast::Access::kReadWrite;
454 }
455
AllocateOverridableConstantIds()456 void Resolver::AllocateOverridableConstantIds() {
457 // The next pipeline constant ID to try to allocate.
458 uint16_t next_constant_id = 0;
459
460 // Allocate constant IDs in global declaration order, so that they are
461 // deterministic.
462 // TODO(crbug.com/tint/1192): If a transform changes the order or removes an
463 // unused constant, the allocation may change on the next Resolver pass.
464 for (auto* decl : builder_->AST().GlobalDeclarations()) {
465 auto* var = decl->As<ast::Variable>();
466 if (!var) {
467 continue;
468 }
469 auto* override_deco =
470 ast::GetDecoration<ast::OverrideDecoration>(var->decorations);
471 if (!override_deco) {
472 continue;
473 }
474
475 uint16_t constant_id;
476 if (override_deco->has_value) {
477 constant_id = static_cast<uint16_t>(override_deco->value);
478 } else {
479 // No ID was specified, so allocate the next available ID.
480 constant_id = next_constant_id;
481 while (constant_ids_.count(constant_id)) {
482 if (constant_id == UINT16_MAX) {
483 TINT_ICE(Resolver, builder_->Diagnostics())
484 << "no more pipeline constant IDs available";
485 return;
486 }
487 constant_id++;
488 }
489 next_constant_id = constant_id + 1;
490 }
491
492 auto* sem = Sem<sem::GlobalVariable>(var);
493 const_cast<sem::GlobalVariable*>(sem)->SetConstantId(constant_id);
494 }
495 }
496
SetShadows()497 void Resolver::SetShadows() {
498 for (auto it : dependencies_.shadows) {
499 auto* var = Sem(it.first);
500 if (auto* local = var->As<sem::LocalVariable>()) {
501 local->SetShadows(Sem(it.second));
502 }
503 if (auto* param = var->As<sem::Parameter>()) {
504 param->SetShadows(Sem(it.second));
505 }
506 }
507 } // namespace resolver
508
GlobalVariable(const ast::Variable * var)509 bool Resolver::GlobalVariable(const ast::Variable* var) {
510 auto* sem = Variable(var, VariableKind::kGlobal);
511 if (!sem) {
512 return false;
513 }
514
515 auto storage_class = sem->StorageClass();
516 if (!var->is_const && storage_class == ast::StorageClass::kNone) {
517 AddError("global variables must have a storage class", var->source);
518 return false;
519 }
520 if (var->is_const && storage_class != ast::StorageClass::kNone) {
521 AddError("global constants shouldn't have a storage class", var->source);
522 return false;
523 }
524
525 for (auto* deco : var->decorations) {
526 Mark(deco);
527
528 if (auto* override_deco = deco->As<ast::OverrideDecoration>()) {
529 // Track the constant IDs that are specified in the shader.
530 if (override_deco->has_value) {
531 constant_ids_.emplace(override_deco->value, sem);
532 }
533 }
534 }
535
536 if (!ValidateNoDuplicateDecorations(var->decorations)) {
537 return false;
538 }
539
540 if (!ValidateGlobalVariable(sem)) {
541 return false;
542 }
543
544 // TODO(bclayton): Call this at the end of resolve on all uniform and storage
545 // referenced structs
546 if (!ValidateStorageClassLayout(sem)) {
547 return false;
548 }
549
550 return true;
551 }
552
Function(const ast::Function * decl)553 sem::Function* Resolver::Function(const ast::Function* decl) {
554 uint32_t parameter_index = 0;
555 std::unordered_map<Symbol, Source> parameter_names;
556 std::vector<sem::Parameter*> parameters;
557
558 // Resolve all the parameters
559 for (auto* param : decl->params) {
560 Mark(param);
561
562 { // Check the parameter name is unique for the function
563 auto emplaced = parameter_names.emplace(param->symbol, param->source);
564 if (!emplaced.second) {
565 auto name = builder_->Symbols().NameFor(param->symbol);
566 AddError("redefinition of parameter '" + name + "'", param->source);
567 AddNote("previous definition is here", emplaced.first->second);
568 return nullptr;
569 }
570 }
571
572 auto* var = As<sem::Parameter>(
573 Variable(param, VariableKind::kParameter, parameter_index++));
574 if (!var) {
575 return nullptr;
576 }
577
578 for (auto* deco : param->decorations) {
579 Mark(deco);
580 }
581 if (!ValidateNoDuplicateDecorations(param->decorations)) {
582 return nullptr;
583 }
584
585 parameters.emplace_back(var);
586
587 auto* var_ty = const_cast<sem::Type*>(var->Type());
588 if (auto* str = var_ty->As<sem::Struct>()) {
589 switch (decl->PipelineStage()) {
590 case ast::PipelineStage::kVertex:
591 str->AddUsage(sem::PipelineStageUsage::kVertexInput);
592 break;
593 case ast::PipelineStage::kFragment:
594 str->AddUsage(sem::PipelineStageUsage::kFragmentInput);
595 break;
596 case ast::PipelineStage::kCompute:
597 str->AddUsage(sem::PipelineStageUsage::kComputeInput);
598 break;
599 case ast::PipelineStage::kNone:
600 break;
601 }
602 }
603 }
604
605 // Resolve the return type
606 sem::Type* return_type = nullptr;
607 if (auto* ty = decl->return_type) {
608 return_type = Type(ty);
609 if (!return_type) {
610 return nullptr;
611 }
612 } else {
613 return_type = builder_->create<sem::Void>();
614 }
615
616 if (auto* str = return_type->As<sem::Struct>()) {
617 if (!ApplyStorageClassUsageToType(ast::StorageClass::kNone, str,
618 decl->source)) {
619 AddNote("while instantiating return type for " +
620 builder_->Symbols().NameFor(decl->symbol),
621 decl->source);
622 return nullptr;
623 }
624
625 switch (decl->PipelineStage()) {
626 case ast::PipelineStage::kVertex:
627 str->AddUsage(sem::PipelineStageUsage::kVertexOutput);
628 break;
629 case ast::PipelineStage::kFragment:
630 str->AddUsage(sem::PipelineStageUsage::kFragmentOutput);
631 break;
632 case ast::PipelineStage::kCompute:
633 str->AddUsage(sem::PipelineStageUsage::kComputeOutput);
634 break;
635 case ast::PipelineStage::kNone:
636 break;
637 }
638 }
639
640 auto* func = builder_->create<sem::Function>(decl, return_type, parameters);
641 builder_->Sem().Add(decl, func);
642
643 TINT_SCOPED_ASSIGNMENT(current_function_, func);
644
645 if (!WorkgroupSize(decl)) {
646 return nullptr;
647 }
648
649 if (decl->IsEntryPoint()) {
650 entry_points_.emplace_back(func);
651 }
652
653 if (decl->body) {
654 Mark(decl->body);
655 if (current_compound_statement_) {
656 TINT_ICE(Resolver, diagnostics_)
657 << "Resolver::Function() called with a current compound statement";
658 return nullptr;
659 }
660 auto* body = StatementScope(
661 decl->body, builder_->create<sem::FunctionBlockStatement>(func),
662 [&] { return Statements(decl->body->statements); });
663 if (!body) {
664 return nullptr;
665 }
666 func->Behaviors() = body->Behaviors();
667 if (func->Behaviors().Contains(sem::Behavior::kReturn)) {
668 // https://www.w3.org/TR/WGSL/#behaviors-rules
669 // We assign a behavior to each function: it is its body’s behavior
670 // (treating the body as a regular statement), with any "Return" replaced
671 // by "Next".
672 func->Behaviors().Remove(sem::Behavior::kReturn);
673 func->Behaviors().Add(sem::Behavior::kNext);
674 }
675 }
676
677 for (auto* deco : decl->decorations) {
678 Mark(deco);
679 }
680 if (!ValidateNoDuplicateDecorations(decl->decorations)) {
681 return nullptr;
682 }
683
684 for (auto* deco : decl->return_type_decorations) {
685 Mark(deco);
686 }
687 if (!ValidateNoDuplicateDecorations(decl->return_type_decorations)) {
688 return nullptr;
689 }
690
691 if (!ValidateFunction(func)) {
692 return nullptr;
693 }
694
695 // If this is an entry point, mark all transitively called functions as being
696 // used by this entry point.
697 if (decl->IsEntryPoint()) {
698 for (auto* f : func->TransitivelyCalledFunctions()) {
699 const_cast<sem::Function*>(f)->AddAncestorEntryPoint(func);
700 }
701 }
702
703 return func;
704 }
705
WorkgroupSize(const ast::Function * func)706 bool Resolver::WorkgroupSize(const ast::Function* func) {
707 // Set work-group size defaults.
708 sem::WorkgroupSize ws;
709 for (int i = 0; i < 3; i++) {
710 ws[i].value = 1;
711 ws[i].overridable_const = nullptr;
712 }
713
714 auto* deco = ast::GetDecoration<ast::WorkgroupDecoration>(func->decorations);
715 if (!deco) {
716 return true;
717 }
718
719 auto values = deco->Values();
720 auto any_i32 = false;
721 auto any_u32 = false;
722 for (int i = 0; i < 3; i++) {
723 // Each argument to this decoration can either be a literal, an
724 // identifier for a module-scope constants, or nullptr if not specified.
725
726 auto* expr = values[i];
727 if (!expr) {
728 // Not specified, just use the default.
729 continue;
730 }
731
732 auto* expr_sem = Expression(expr);
733 if (!expr_sem) {
734 return false;
735 }
736
737 constexpr const char* kErrBadType =
738 "workgroup_size argument must be either literal or module-scope "
739 "constant of type i32 or u32";
740 constexpr const char* kErrInconsistentType =
741 "workgroup_size arguments must be of the same type, either i32 "
742 "or u32";
743
744 auto* ty = TypeOf(expr);
745 bool is_i32 = ty->UnwrapRef()->Is<sem::I32>();
746 bool is_u32 = ty->UnwrapRef()->Is<sem::U32>();
747 if (!is_i32 && !is_u32) {
748 AddError(kErrBadType, expr->source);
749 return false;
750 }
751
752 any_i32 = any_i32 || is_i32;
753 any_u32 = any_u32 || is_u32;
754 if (any_i32 && any_u32) {
755 AddError(kErrInconsistentType, expr->source);
756 return false;
757 }
758
759 sem::Constant value;
760
761 if (auto* user = Sem(expr)->As<sem::VariableUser>()) {
762 // We have an variable of a module-scope constant.
763 auto* decl = user->Variable()->Declaration();
764 if (!decl->is_const) {
765 AddError(kErrBadType, expr->source);
766 return false;
767 }
768 // Capture the constant if an [[override]] attribute is present.
769 if (ast::HasDecoration<ast::OverrideDecoration>(decl->decorations)) {
770 ws[i].overridable_const = decl;
771 }
772
773 if (decl->constructor) {
774 value = Sem(decl->constructor)->ConstantValue();
775 } else {
776 // No constructor means this value must be overriden by the user.
777 ws[i].value = 0;
778 continue;
779 }
780 } else if (expr->Is<ast::LiteralExpression>()) {
781 value = Sem(expr)->ConstantValue();
782 } else {
783 AddError(
784 "workgroup_size argument must be either a literal or a "
785 "module-scope constant",
786 values[i]->source);
787 return false;
788 }
789
790 if (!value) {
791 TINT_ICE(Resolver, diagnostics_)
792 << "could not resolve constant workgroup_size constant value";
793 continue;
794 }
795 // Validate and set the default value for this dimension.
796 if (is_i32 ? value.Elements()[0].i32 < 1 : value.Elements()[0].u32 < 1) {
797 AddError("workgroup_size argument must be at least 1", values[i]->source);
798 return false;
799 }
800
801 ws[i].value = is_i32 ? static_cast<uint32_t>(value.Elements()[0].i32)
802 : value.Elements()[0].u32;
803 }
804
805 current_function_->SetWorkgroupSize(std::move(ws));
806 return true;
807 }
808
Statements(const ast::StatementList & stmts)809 bool Resolver::Statements(const ast::StatementList& stmts) {
810 sem::Behaviors behaviors{sem::Behavior::kNext};
811
812 bool reachable = true;
813 for (auto* stmt : stmts) {
814 Mark(stmt);
815 auto* sem = Statement(stmt);
816 if (!sem) {
817 return false;
818 }
819 // s1 s2:(B1∖{Next}) ∪ B2
820 sem->SetIsReachable(reachable);
821 if (reachable) {
822 behaviors = (behaviors - sem::Behavior::kNext) + sem->Behaviors();
823 }
824 reachable = reachable && sem->Behaviors().Contains(sem::Behavior::kNext);
825 }
826
827 current_statement_->Behaviors() = behaviors;
828
829 if (!ValidateStatements(stmts)) {
830 return false;
831 }
832
833 return true;
834 }
835
Statement(const ast::Statement * stmt)836 sem::Statement* Resolver::Statement(const ast::Statement* stmt) {
837 if (stmt->Is<ast::CaseStatement>()) {
838 AddError("case statement can only be used inside a switch statement",
839 stmt->source);
840 return nullptr;
841 }
842 if (stmt->Is<ast::ElseStatement>()) {
843 TINT_ICE(Resolver, diagnostics_)
844 << "Resolver::Statement() encountered an Else statement. Else "
845 "statements are embedded in If statements, so should never be "
846 "encountered as top-level statements";
847 return nullptr;
848 }
849
850 // Compound statements. These create their own sem::CompoundStatement
851 // bindings.
852 if (auto* b = stmt->As<ast::BlockStatement>()) {
853 return BlockStatement(b);
854 }
855 if (auto* l = stmt->As<ast::ForLoopStatement>()) {
856 return ForLoopStatement(l);
857 }
858 if (auto* l = stmt->As<ast::LoopStatement>()) {
859 return LoopStatement(l);
860 }
861 if (auto* i = stmt->As<ast::IfStatement>()) {
862 return IfStatement(i);
863 }
864 if (auto* s = stmt->As<ast::SwitchStatement>()) {
865 return SwitchStatement(s);
866 }
867
868 // Non-Compound statements
869 if (auto* a = stmt->As<ast::AssignmentStatement>()) {
870 return AssignmentStatement(a);
871 }
872 if (auto* b = stmt->As<ast::BreakStatement>()) {
873 return BreakStatement(b);
874 }
875 if (auto* c = stmt->As<ast::CallStatement>()) {
876 return CallStatement(c);
877 }
878 if (auto* c = stmt->As<ast::ContinueStatement>()) {
879 return ContinueStatement(c);
880 }
881 if (auto* d = stmt->As<ast::DiscardStatement>()) {
882 return DiscardStatement(d);
883 }
884 if (auto* f = stmt->As<ast::FallthroughStatement>()) {
885 return FallthroughStatement(f);
886 }
887 if (auto* r = stmt->As<ast::ReturnStatement>()) {
888 return ReturnStatement(r);
889 }
890 if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
891 return VariableDeclStatement(v);
892 }
893
894 AddError("unknown statement type: " + std::string(stmt->TypeInfo().name),
895 stmt->source);
896 return nullptr;
897 }
898
CaseStatement(const ast::CaseStatement * stmt)899 sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt) {
900 auto* sem = builder_->create<sem::CaseStatement>(
901 stmt, current_compound_statement_, current_function_);
902 return StatementScope(stmt, sem, [&] {
903 for (auto* sel : stmt->selectors) {
904 Mark(sel);
905 }
906 Mark(stmt->body);
907 auto* body = BlockStatement(stmt->body);
908 if (!body) {
909 return false;
910 }
911 sem->SetBlock(body);
912 sem->Behaviors() = body->Behaviors();
913 return true;
914 });
915 }
916
IfStatement(const ast::IfStatement * stmt)917 sem::IfStatement* Resolver::IfStatement(const ast::IfStatement* stmt) {
918 auto* sem = builder_->create<sem::IfStatement>(
919 stmt, current_compound_statement_, current_function_);
920 return StatementScope(stmt, sem, [&] {
921 auto* cond = Expression(stmt->condition);
922 if (!cond) {
923 return false;
924 }
925 sem->SetCondition(cond);
926 sem->Behaviors() = cond->Behaviors();
927 sem->Behaviors().Remove(sem::Behavior::kNext);
928
929 Mark(stmt->body);
930 auto* body = builder_->create<sem::BlockStatement>(
931 stmt->body, current_compound_statement_, current_function_);
932 if (!StatementScope(stmt->body, body,
933 [&] { return Statements(stmt->body->statements); })) {
934 return false;
935 }
936 sem->Behaviors().Add(body->Behaviors());
937
938 for (auto* else_stmt : stmt->else_statements) {
939 Mark(else_stmt);
940 auto* else_sem = ElseStatement(else_stmt);
941 if (!else_sem) {
942 return false;
943 }
944 sem->Behaviors().Add(else_sem->Behaviors());
945 }
946
947 if (stmt->else_statements.empty() ||
948 stmt->else_statements.back()->condition != nullptr) {
949 // https://www.w3.org/TR/WGSL/#behaviors-rules
950 // if statements without an else branch are treated as if they had an
951 // empty else branch (which adds Next to their behavior)
952 sem->Behaviors().Add(sem::Behavior::kNext);
953 }
954
955 return ValidateIfStatement(sem);
956 });
957 }
958
ElseStatement(const ast::ElseStatement * stmt)959 sem::ElseStatement* Resolver::ElseStatement(const ast::ElseStatement* stmt) {
960 auto* sem = builder_->create<sem::ElseStatement>(
961 stmt, current_compound_statement_, current_function_);
962 return StatementScope(stmt, sem, [&] {
963 if (auto* cond_expr = stmt->condition) {
964 auto* cond = Expression(cond_expr);
965 if (!cond) {
966 return false;
967 }
968 sem->SetCondition(cond);
969 // https://www.w3.org/TR/WGSL/#behaviors-rules
970 // if statements with else if branches are treated as if they were nested
971 // simple if/else statements
972 sem->Behaviors() = cond->Behaviors();
973 }
974 sem->Behaviors().Remove(sem::Behavior::kNext);
975
976 Mark(stmt->body);
977 auto* body = builder_->create<sem::BlockStatement>(
978 stmt->body, current_compound_statement_, current_function_);
979 if (!StatementScope(stmt->body, body,
980 [&] { return Statements(stmt->body->statements); })) {
981 return false;
982 }
983 sem->Behaviors().Add(body->Behaviors());
984
985 return ValidateElseStatement(sem);
986 });
987 }
988
BlockStatement(const ast::BlockStatement * stmt)989 sem::BlockStatement* Resolver::BlockStatement(const ast::BlockStatement* stmt) {
990 auto* sem = builder_->create<sem::BlockStatement>(
991 stmt->As<ast::BlockStatement>(), current_compound_statement_,
992 current_function_);
993 return StatementScope(stmt, sem,
994 [&] { return Statements(stmt->statements); });
995 }
996
LoopStatement(const ast::LoopStatement * stmt)997 sem::LoopStatement* Resolver::LoopStatement(const ast::LoopStatement* stmt) {
998 auto* sem = builder_->create<sem::LoopStatement>(
999 stmt, current_compound_statement_, current_function_);
1000 return StatementScope(stmt, sem, [&] {
1001 Mark(stmt->body);
1002
1003 auto* body = builder_->create<sem::LoopBlockStatement>(
1004 stmt->body, current_compound_statement_, current_function_);
1005 return StatementScope(stmt->body, body, [&] {
1006 if (!Statements(stmt->body->statements)) {
1007 return false;
1008 }
1009 auto& behaviors = sem->Behaviors();
1010 behaviors = body->Behaviors();
1011
1012 if (stmt->continuing) {
1013 Mark(stmt->continuing);
1014 if (!stmt->continuing->Empty()) {
1015 auto* continuing = StatementScope(
1016 stmt->continuing,
1017 builder_->create<sem::LoopContinuingBlockStatement>(
1018 stmt->continuing, current_compound_statement_,
1019 current_function_),
1020 [&] { return Statements(stmt->continuing->statements); });
1021 if (!continuing) {
1022 return false;
1023 }
1024 behaviors.Add(continuing->Behaviors());
1025 }
1026 }
1027
1028 if (behaviors.Contains(sem::Behavior::kBreak)) { // Does the loop exit?
1029 behaviors.Add(sem::Behavior::kNext);
1030 } else {
1031 behaviors.Remove(sem::Behavior::kNext);
1032 }
1033 behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
1034
1035 return true;
1036 });
1037 });
1038 }
1039
ForLoopStatement(const ast::ForLoopStatement * stmt)1040 sem::ForLoopStatement* Resolver::ForLoopStatement(
1041 const ast::ForLoopStatement* stmt) {
1042 auto* sem = builder_->create<sem::ForLoopStatement>(
1043 stmt, current_compound_statement_, current_function_);
1044 return StatementScope(stmt, sem, [&] {
1045 auto& behaviors = sem->Behaviors();
1046 if (auto* initializer = stmt->initializer) {
1047 Mark(initializer);
1048 auto* init = Statement(initializer);
1049 if (!init) {
1050 return false;
1051 }
1052 behaviors.Add(init->Behaviors());
1053 }
1054
1055 if (auto* cond_expr = stmt->condition) {
1056 auto* cond = Expression(cond_expr);
1057 if (!cond) {
1058 return false;
1059 }
1060 sem->SetCondition(cond);
1061 behaviors.Add(cond->Behaviors());
1062 }
1063
1064 if (auto* continuing = stmt->continuing) {
1065 Mark(continuing);
1066 auto* cont = Statement(continuing);
1067 if (!cont) {
1068 return false;
1069 }
1070 behaviors.Add(cont->Behaviors());
1071 }
1072
1073 Mark(stmt->body);
1074
1075 auto* body = builder_->create<sem::LoopBlockStatement>(
1076 stmt->body, current_compound_statement_, current_function_);
1077 if (!StatementScope(stmt->body, body,
1078 [&] { return Statements(stmt->body->statements); })) {
1079 return false;
1080 }
1081
1082 behaviors.Add(body->Behaviors());
1083 if (stmt->condition ||
1084 behaviors.Contains(sem::Behavior::kBreak)) { // Does the loop exit?
1085 behaviors.Add(sem::Behavior::kNext);
1086 } else {
1087 behaviors.Remove(sem::Behavior::kNext);
1088 }
1089 behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
1090
1091 return ValidateForLoopStatement(sem);
1092 });
1093 }
1094
Expression(const ast::Expression * root)1095 sem::Expression* Resolver::Expression(const ast::Expression* root) {
1096 std::vector<const ast::Expression*> sorted;
1097 bool mark_failed = false;
1098 if (!ast::TraverseExpressions<ast::TraverseOrder::RightToLeft>(
1099 root, diagnostics_, [&](const ast::Expression* expr) {
1100 if (!Mark(expr)) {
1101 mark_failed = true;
1102 return ast::TraverseAction::Stop;
1103 }
1104 sorted.emplace_back(expr);
1105 return ast::TraverseAction::Descend;
1106 })) {
1107 return nullptr;
1108 }
1109
1110 if (mark_failed) {
1111 return nullptr;
1112 }
1113
1114 for (auto* expr : utils::Reverse(sorted)) {
1115 sem::Expression* sem_expr = nullptr;
1116 if (auto* array = expr->As<ast::IndexAccessorExpression>()) {
1117 sem_expr = IndexAccessor(array);
1118 } else if (auto* bin_op = expr->As<ast::BinaryExpression>()) {
1119 sem_expr = Binary(bin_op);
1120 } else if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
1121 sem_expr = Bitcast(bitcast);
1122 } else if (auto* call = expr->As<ast::CallExpression>()) {
1123 sem_expr = Call(call);
1124 } else if (auto* ident = expr->As<ast::IdentifierExpression>()) {
1125 sem_expr = Identifier(ident);
1126 } else if (auto* literal = expr->As<ast::LiteralExpression>()) {
1127 sem_expr = Literal(literal);
1128 } else if (auto* member = expr->As<ast::MemberAccessorExpression>()) {
1129 sem_expr = MemberAccessor(member);
1130 } else if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
1131 sem_expr = UnaryOp(unary);
1132 } else if (expr->Is<ast::PhonyExpression>()) {
1133 sem_expr = builder_->create<sem::Expression>(
1134 expr, builder_->create<sem::Void>(), current_statement_,
1135 sem::Constant{});
1136 } else {
1137 TINT_ICE(Resolver, diagnostics_)
1138 << "unhandled expression type: " << expr->TypeInfo().name;
1139 return nullptr;
1140 }
1141 if (!sem_expr) {
1142 return nullptr;
1143 }
1144
1145 // https://www.w3.org/TR/WGSL/#behaviors-rules
1146 // an expression behavior is always either {Next} or {Next, Discard}
1147 if (sem_expr->Behaviors() != sem::Behavior::kNext &&
1148 sem_expr->Behaviors() != sem::Behaviors{sem::Behavior::kNext, // NOLINT
1149 sem::Behavior::kDiscard} &&
1150 !IsCallStatement(expr)) {
1151 TINT_ICE(Resolver, diagnostics_)
1152 << expr->TypeInfo().name
1153 << " behaviors are: " << sem_expr->Behaviors();
1154 return nullptr;
1155 }
1156
1157 builder_->Sem().Add(expr, sem_expr);
1158 if (expr == root) {
1159 return sem_expr;
1160 }
1161 }
1162
1163 TINT_ICE(Resolver, diagnostics_) << "Expression() did not find root node";
1164 return nullptr;
1165 }
1166
IndexAccessor(const ast::IndexAccessorExpression * expr)1167 sem::Expression* Resolver::IndexAccessor(
1168 const ast::IndexAccessorExpression* expr) {
1169 auto* idx = Sem(expr->index);
1170 auto* obj = Sem(expr->object);
1171 auto* obj_raw_ty = obj->Type();
1172 auto* obj_ty = obj_raw_ty->UnwrapRef();
1173 const sem::Type* ty = nullptr;
1174 if (auto* arr = obj_ty->As<sem::Array>()) {
1175 ty = arr->ElemType();
1176 } else if (auto* vec = obj_ty->As<sem::Vector>()) {
1177 ty = vec->type();
1178 } else if (auto* mat = obj_ty->As<sem::Matrix>()) {
1179 ty = builder_->create<sem::Vector>(mat->type(), mat->rows());
1180 } else {
1181 AddError("cannot index type '" + TypeNameOf(obj_ty) + "'", expr->source);
1182 return nullptr;
1183 }
1184
1185 auto* idx_ty = idx->Type()->UnwrapRef();
1186 if (!idx_ty->IsAnyOf<sem::I32, sem::U32>()) {
1187 AddError("index must be of type 'i32' or 'u32', found: '" +
1188 TypeNameOf(idx_ty) + "'",
1189 idx->Declaration()->source);
1190 return nullptr;
1191 }
1192
1193 if (obj_ty->IsAnyOf<sem::Array, sem::Matrix>()) {
1194 if (!obj_raw_ty->Is<sem::Reference>()) {
1195 // TODO(bclayton): expand this to allow any const_expr expression
1196 // https://github.com/gpuweb/gpuweb/issues/1272
1197 if (!idx->Declaration()->As<ast::IntLiteralExpression>()) {
1198 AddError("index must be signed or unsigned integer literal",
1199 idx->Declaration()->source);
1200 return nullptr;
1201 }
1202 }
1203 }
1204
1205 // If we're extracting from a reference, we return a reference.
1206 if (auto* ref = obj_raw_ty->As<sem::Reference>()) {
1207 ty = builder_->create<sem::Reference>(ty, ref->StorageClass(),
1208 ref->Access());
1209 }
1210
1211 auto val = EvaluateConstantValue(expr, ty);
1212 auto* sem =
1213 builder_->create<sem::Expression>(expr, ty, current_statement_, val);
1214 sem->Behaviors() = idx->Behaviors() + obj->Behaviors();
1215 return sem;
1216 }
1217
Bitcast(const ast::BitcastExpression * expr)1218 sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
1219 auto* inner = Sem(expr->expr);
1220 auto* ty = Type(expr->type);
1221 if (!ty) {
1222 return nullptr;
1223 }
1224
1225 auto val = EvaluateConstantValue(expr, ty);
1226 auto* sem =
1227 builder_->create<sem::Expression>(expr, ty, current_statement_, val);
1228
1229 sem->Behaviors() = inner->Behaviors();
1230
1231 if (!ValidateBitcast(expr, ty)) {
1232 return nullptr;
1233 }
1234
1235 return sem;
1236 }
1237
Call(const ast::CallExpression * expr)1238 sem::Call* Resolver::Call(const ast::CallExpression* expr) {
1239 std::vector<const sem::Expression*> args(expr->args.size());
1240 std::vector<const sem::Type*> arg_tys(args.size());
1241 sem::Behaviors arg_behaviors;
1242
1243 for (size_t i = 0; i < expr->args.size(); i++) {
1244 auto* arg = Sem(expr->args[i]);
1245 if (!arg) {
1246 return nullptr;
1247 }
1248 args[i] = arg;
1249 arg_tys[i] = args[i]->Type();
1250 arg_behaviors.Add(arg->Behaviors());
1251 }
1252
1253 arg_behaviors.Remove(sem::Behavior::kNext);
1254
1255 auto type_ctor_or_conv = [&](const sem::Type* ty) -> sem::Call* {
1256 // The call has resolved to a type constructor or cast.
1257 if (args.size() == 1) {
1258 auto* target = ty;
1259 auto* source = args[0]->Type()->UnwrapRef();
1260 if ((source != target) && //
1261 ((source->is_scalar() && target->is_scalar()) ||
1262 (source->Is<sem::Vector>() && target->Is<sem::Vector>()) ||
1263 (source->Is<sem::Matrix>() && target->Is<sem::Matrix>()))) {
1264 // Note: Matrix types currently cannot be converted (the element type
1265 // must only be f32). We implement this for the day we support other
1266 // matrix element types.
1267 return TypeConversion(expr, ty, args[0], arg_tys[0]);
1268 }
1269 }
1270 return TypeConstructor(expr, ty, std::move(args), std::move(arg_tys));
1271 };
1272
1273 // Resolve the target of the CallExpression to determine whether this is a
1274 // function call, cast or type constructor expression.
1275 if (expr->target.type) {
1276 auto* ty = Type(expr->target.type);
1277 if (!ty) {
1278 return nullptr;
1279 }
1280 return type_ctor_or_conv(ty);
1281 }
1282
1283 auto* ident = expr->target.name;
1284 Mark(ident);
1285
1286 auto* resolved = ResolvedSymbol(ident);
1287 if (auto* ty = As<sem::Type>(resolved)) {
1288 return type_ctor_or_conv(ty);
1289 }
1290
1291 if (auto* fn = As<sem::Function>(resolved)) {
1292 return FunctionCall(expr, fn, std::move(args), arg_behaviors);
1293 }
1294
1295 auto name = builder_->Symbols().NameFor(ident->symbol);
1296 auto intrinsic_type = sem::ParseIntrinsicType(name);
1297 if (intrinsic_type != sem::IntrinsicType::kNone) {
1298 return IntrinsicCall(expr, intrinsic_type, std::move(args),
1299 std::move(arg_tys));
1300 }
1301
1302 TINT_ICE(Resolver, diagnostics_)
1303 << expr->source << " unresolved CallExpression target:\n"
1304 << "resolved: " << (resolved ? resolved->TypeInfo().name : "<null>")
1305 << "\n"
1306 << "name: " << builder_->Symbols().NameFor(ident->symbol);
1307 return nullptr;
1308 }
1309
IntrinsicCall(const ast::CallExpression * expr,sem::IntrinsicType intrinsic_type,const std::vector<const sem::Expression * > args,const std::vector<const sem::Type * > arg_tys)1310 sem::Call* Resolver::IntrinsicCall(
1311 const ast::CallExpression* expr,
1312 sem::IntrinsicType intrinsic_type,
1313 const std::vector<const sem::Expression*> args,
1314 const std::vector<const sem::Type*> arg_tys) {
1315 auto* intrinsic = intrinsic_table_->Lookup(intrinsic_type, std::move(arg_tys),
1316 expr->source);
1317 if (!intrinsic) {
1318 return nullptr;
1319 }
1320
1321 if (intrinsic->IsDeprecated()) {
1322 AddWarning("use of deprecated intrinsic", expr->source);
1323 }
1324
1325 auto* call = builder_->create<sem::Call>(expr, intrinsic, std::move(args),
1326 current_statement_, sem::Constant{});
1327
1328 current_function_->AddDirectlyCalledIntrinsic(intrinsic);
1329
1330 if (IsTextureIntrinsic(intrinsic_type) &&
1331 !ValidateTextureIntrinsicFunction(call)) {
1332 return nullptr;
1333 }
1334
1335 if (!ValidateIntrinsicCall(call)) {
1336 return nullptr;
1337 }
1338
1339 current_function_->AddDirectCall(call);
1340
1341 return call;
1342 }
1343
FunctionCall(const ast::CallExpression * expr,sem::Function * target,const std::vector<const sem::Expression * > args,sem::Behaviors arg_behaviors)1344 sem::Call* Resolver::FunctionCall(
1345 const ast::CallExpression* expr,
1346 sem::Function* target,
1347 const std::vector<const sem::Expression*> args,
1348 sem::Behaviors arg_behaviors) {
1349 auto sym = expr->target.name->symbol;
1350 auto name = builder_->Symbols().NameFor(sym);
1351
1352 auto* call = builder_->create<sem::Call>(expr, target, std::move(args),
1353 current_statement_, sem::Constant{});
1354
1355 if (current_function_) {
1356 // Note: Requires called functions to be resolved first.
1357 // This is currently guaranteed as functions must be declared before
1358 // use.
1359 current_function_->AddTransitivelyCalledFunction(target);
1360 current_function_->AddDirectCall(call);
1361 for (auto* transitive_call : target->TransitivelyCalledFunctions()) {
1362 current_function_->AddTransitivelyCalledFunction(transitive_call);
1363 }
1364
1365 // We inherit any referenced variables from the callee.
1366 for (auto* var : target->TransitivelyReferencedGlobals()) {
1367 current_function_->AddTransitivelyReferencedGlobal(var);
1368 }
1369 }
1370
1371 target->AddCallSite(call);
1372
1373 call->Behaviors() = arg_behaviors + target->Behaviors();
1374
1375 if (!ValidateFunctionCall(call)) {
1376 return nullptr;
1377 }
1378
1379 return call;
1380 }
1381
TypeConversion(const ast::CallExpression * expr,const sem::Type * target,const sem::Expression * arg,const sem::Type * source)1382 sem::Call* Resolver::TypeConversion(const ast::CallExpression* expr,
1383 const sem::Type* target,
1384 const sem::Expression* arg,
1385 const sem::Type* source) {
1386 // It is not valid to have a type-cast call expression inside a call
1387 // statement.
1388 if (IsCallStatement(expr)) {
1389 AddError("type cast evaluated but not used", expr->source);
1390 return nullptr;
1391 }
1392
1393 auto* call_target = utils::GetOrCreate(
1394 type_conversions_, TypeConversionSig{target, source},
1395 [&]() -> sem::TypeConversion* {
1396 // Now that the argument types have been determined, make sure that they
1397 // obey the conversion rules laid out in
1398 // https://gpuweb.github.io/gpuweb/wgsl/#conversion-expr.
1399 bool ok = true;
1400 if (auto* vec_type = target->As<sem::Vector>()) {
1401 ok = ValidateVectorConstructorOrCast(expr, vec_type);
1402 } else if (auto* mat_type = target->As<sem::Matrix>()) {
1403 // Note: Matrix types currently cannot be converted (the element type
1404 // must only be f32). We implement this for the day we support other
1405 // matrix element types.
1406 ok = ValidateMatrixConstructorOrCast(expr, mat_type);
1407 } else if (target->is_scalar()) {
1408 ok = ValidateScalarConstructorOrCast(expr, target);
1409 } else if (auto* arr_type = target->As<sem::Array>()) {
1410 ok = ValidateArrayConstructorOrCast(expr, arr_type);
1411 } else if (auto* struct_type = target->As<sem::Struct>()) {
1412 ok = ValidateStructureConstructorOrCast(expr, struct_type);
1413 } else {
1414 AddError("type is not constructible", expr->source);
1415 return nullptr;
1416 }
1417 if (!ok) {
1418 return nullptr;
1419 }
1420
1421 auto* param = builder_->create<sem::Parameter>(
1422 nullptr, // declaration
1423 0, // index
1424 source->UnwrapRef(), // type
1425 ast::StorageClass::kNone, // storage_class
1426 ast::Access::kUndefined); // access
1427 return builder_->create<sem::TypeConversion>(target, param);
1428 });
1429
1430 if (!call_target) {
1431 return nullptr;
1432 }
1433
1434 auto val = EvaluateConstantValue(expr, target);
1435 return builder_->create<sem::Call>(expr, call_target,
1436 std::vector<const sem::Expression*>{arg},
1437 current_statement_, val);
1438 }
1439
TypeConstructor(const ast::CallExpression * expr,const sem::Type * ty,const std::vector<const sem::Expression * > args,const std::vector<const sem::Type * > arg_tys)1440 sem::Call* Resolver::TypeConstructor(
1441 const ast::CallExpression* expr,
1442 const sem::Type* ty,
1443 const std::vector<const sem::Expression*> args,
1444 const std::vector<const sem::Type*> arg_tys) {
1445 // It is not valid to have a type-constructor call expression as a call
1446 // statement.
1447 if (IsCallStatement(expr)) {
1448 AddError("type constructor evaluated but not used", expr->source);
1449 return nullptr;
1450 }
1451
1452 auto* call_target = utils::GetOrCreate(
1453 type_ctors_, TypeConstructorSig{ty, arg_tys},
1454 [&]() -> sem::TypeConstructor* {
1455 // Now that the argument types have been determined, make sure that they
1456 // obey the constructor type rules laid out in
1457 // https://gpuweb.github.io/gpuweb/wgsl/#type-constructor-expr.
1458 bool ok = true;
1459 if (auto* vec_type = ty->As<sem::Vector>()) {
1460 ok = ValidateVectorConstructorOrCast(expr, vec_type);
1461 } else if (auto* mat_type = ty->As<sem::Matrix>()) {
1462 ok = ValidateMatrixConstructorOrCast(expr, mat_type);
1463 } else if (ty->is_scalar()) {
1464 ok = ValidateScalarConstructorOrCast(expr, ty);
1465 } else if (auto* arr_type = ty->As<sem::Array>()) {
1466 ok = ValidateArrayConstructorOrCast(expr, arr_type);
1467 } else if (auto* struct_type = ty->As<sem::Struct>()) {
1468 ok = ValidateStructureConstructorOrCast(expr, struct_type);
1469 } else {
1470 AddError("type is not constructible", expr->source);
1471 return nullptr;
1472 }
1473 if (!ok) {
1474 return nullptr;
1475 }
1476
1477 return builder_->create<sem::TypeConstructor>(
1478 ty, utils::Transform(
1479 arg_tys,
1480 [&](const sem::Type* t, size_t i) -> const sem::Parameter* {
1481 return builder_->create<sem::Parameter>(
1482 nullptr, // declaration
1483 i, // index
1484 t->UnwrapRef(), // type
1485 ast::StorageClass::kNone, // storage_class
1486 ast::Access::kUndefined); // access
1487 }));
1488 });
1489
1490 if (!call_target) {
1491 return nullptr;
1492 }
1493
1494 auto val = EvaluateConstantValue(expr, ty);
1495 return builder_->create<sem::Call>(expr, call_target, std::move(args),
1496 current_statement_, val);
1497 }
1498
Literal(const ast::LiteralExpression * literal)1499 sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
1500 auto* ty = TypeOf(literal);
1501 if (!ty) {
1502 return nullptr;
1503 }
1504
1505 auto val = EvaluateConstantValue(literal, ty);
1506 return builder_->create<sem::Expression>(literal, ty, current_statement_,
1507 val);
1508 }
1509
Identifier(const ast::IdentifierExpression * expr)1510 sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) {
1511 auto symbol = expr->symbol;
1512 auto* resolved = ResolvedSymbol(expr);
1513 if (auto* var = As<sem::Variable>(resolved)) {
1514 auto* user =
1515 builder_->create<sem::VariableUser>(expr, current_statement_, var);
1516
1517 if (current_statement_) {
1518 // If identifier is part of a loop continuing block, make sure it
1519 // doesn't refer to a variable that is bypassed by a continue statement
1520 // in the loop's body block.
1521 if (auto* continuing_block =
1522 current_statement_
1523 ->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
1524 auto* loop_block =
1525 continuing_block->FindFirstParent<sem::LoopBlockStatement>();
1526 if (loop_block->FirstContinue()) {
1527 auto& decls = loop_block->Decls();
1528 // If our identifier is in loop_block->decls, make sure its index is
1529 // less than first_continue
1530 auto iter =
1531 std::find_if(decls.begin(), decls.end(),
1532 [&symbol](auto* v) { return v->symbol == symbol; });
1533 if (iter != decls.end()) {
1534 auto var_decl_index =
1535 static_cast<size_t>(std::distance(decls.begin(), iter));
1536 if (var_decl_index >= loop_block->NumDeclsAtFirstContinue()) {
1537 AddError("continue statement bypasses declaration of '" +
1538 builder_->Symbols().NameFor(symbol) + "'",
1539 loop_block->FirstContinue()->source);
1540 AddNote("identifier '" + builder_->Symbols().NameFor(symbol) +
1541 "' declared here",
1542 (*iter)->source);
1543 AddNote("identifier '" + builder_->Symbols().NameFor(symbol) +
1544 "' referenced in continuing block here",
1545 expr->source);
1546 return nullptr;
1547 }
1548 }
1549 }
1550 }
1551 }
1552
1553 if (current_function_) {
1554 if (auto* global = var->As<sem::GlobalVariable>()) {
1555 current_function_->AddDirectlyReferencedGlobal(global);
1556 }
1557 }
1558
1559 var->AddUser(user);
1560 return user;
1561 }
1562
1563 if (Is<sem::Function>(resolved)) {
1564 AddError("missing '(' for function call", expr->source.End());
1565 return nullptr;
1566 }
1567
1568 if (IsIntrinsic(symbol)) {
1569 AddError("missing '(' for intrinsic call", expr->source.End());
1570 return nullptr;
1571 }
1572
1573 if (resolved->Is<sem::Type>()) {
1574 AddError("missing '(' for type constructor or cast", expr->source.End());
1575 return nullptr;
1576 }
1577
1578 TINT_ICE(Resolver, diagnostics_)
1579 << expr->source << " unresolved identifier:\n"
1580 << "resolved: " << (resolved ? resolved->TypeInfo().name : "<null>")
1581 << "\n"
1582 << "name: " << builder_->Symbols().NameFor(symbol);
1583 return nullptr;
1584 }
1585
MemberAccessor(const ast::MemberAccessorExpression * expr)1586 sem::Expression* Resolver::MemberAccessor(
1587 const ast::MemberAccessorExpression* expr) {
1588 auto* structure = TypeOf(expr->structure);
1589 auto* storage_ty = structure->UnwrapRef();
1590
1591 const sem::Type* ret = nullptr;
1592 std::vector<uint32_t> swizzle;
1593
1594 if (auto* str = storage_ty->As<sem::Struct>()) {
1595 Mark(expr->member);
1596 auto symbol = expr->member->symbol;
1597
1598 const sem::StructMember* member = nullptr;
1599 for (auto* m : str->Members()) {
1600 if (m->Name() == symbol) {
1601 ret = m->Type();
1602 member = m;
1603 break;
1604 }
1605 }
1606
1607 if (ret == nullptr) {
1608 AddError(
1609 "struct member " + builder_->Symbols().NameFor(symbol) + " not found",
1610 expr->source);
1611 return nullptr;
1612 }
1613
1614 // If we're extracting from a reference, we return a reference.
1615 if (auto* ref = structure->As<sem::Reference>()) {
1616 ret = builder_->create<sem::Reference>(ret, ref->StorageClass(),
1617 ref->Access());
1618 }
1619
1620 return builder_->create<sem::StructMemberAccess>(
1621 expr, ret, current_statement_, member);
1622 }
1623
1624 if (auto* vec = storage_ty->As<sem::Vector>()) {
1625 Mark(expr->member);
1626 std::string s = builder_->Symbols().NameFor(expr->member->symbol);
1627 auto size = s.size();
1628 swizzle.reserve(s.size());
1629
1630 for (auto c : s) {
1631 switch (c) {
1632 case 'x':
1633 case 'r':
1634 swizzle.emplace_back(0);
1635 break;
1636 case 'y':
1637 case 'g':
1638 swizzle.emplace_back(1);
1639 break;
1640 case 'z':
1641 case 'b':
1642 swizzle.emplace_back(2);
1643 break;
1644 case 'w':
1645 case 'a':
1646 swizzle.emplace_back(3);
1647 break;
1648 default:
1649 AddError("invalid vector swizzle character",
1650 expr->member->source.Begin() + swizzle.size());
1651 return nullptr;
1652 }
1653
1654 if (swizzle.back() >= vec->Width()) {
1655 AddError("invalid vector swizzle member", expr->member->source);
1656 return nullptr;
1657 }
1658 }
1659
1660 if (size < 1 || size > 4) {
1661 AddError("invalid vector swizzle size", expr->member->source);
1662 return nullptr;
1663 }
1664
1665 // All characters are valid, check if they're being mixed
1666 auto is_rgba = [](char c) {
1667 return c == 'r' || c == 'g' || c == 'b' || c == 'a';
1668 };
1669 auto is_xyzw = [](char c) {
1670 return c == 'x' || c == 'y' || c == 'z' || c == 'w';
1671 };
1672 if (!std::all_of(s.begin(), s.end(), is_rgba) &&
1673 !std::all_of(s.begin(), s.end(), is_xyzw)) {
1674 AddError("invalid mixing of vector swizzle characters rgba with xyzw",
1675 expr->member->source);
1676 return nullptr;
1677 }
1678
1679 if (size == 1) {
1680 // A single element swizzle is just the type of the vector.
1681 ret = vec->type();
1682 // If we're extracting from a reference, we return a reference.
1683 if (auto* ref = structure->As<sem::Reference>()) {
1684 ret = builder_->create<sem::Reference>(ret, ref->StorageClass(),
1685 ref->Access());
1686 }
1687 } else {
1688 // The vector will have a number of components equal to the length of
1689 // the swizzle.
1690 ret = builder_->create<sem::Vector>(vec->type(),
1691 static_cast<uint32_t>(size));
1692 }
1693 return builder_->create<sem::Swizzle>(expr, ret, current_statement_,
1694 std::move(swizzle));
1695 }
1696
1697 AddError(
1698 "invalid member accessor expression. Expected vector or struct, got '" +
1699 TypeNameOf(storage_ty) + "'",
1700 expr->structure->source);
1701 return nullptr;
1702 }
1703
Binary(const ast::BinaryExpression * expr)1704 sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
1705 using Bool = sem::Bool;
1706 using F32 = sem::F32;
1707 using I32 = sem::I32;
1708 using U32 = sem::U32;
1709 using Matrix = sem::Matrix;
1710 using Vector = sem::Vector;
1711
1712 auto* lhs = Sem(expr->lhs);
1713 auto* rhs = Sem(expr->rhs);
1714
1715 auto* lhs_ty = lhs->Type()->UnwrapRef();
1716 auto* rhs_ty = rhs->Type()->UnwrapRef();
1717
1718 auto* lhs_vec = lhs_ty->As<Vector>();
1719 auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
1720 auto* rhs_vec = rhs_ty->As<Vector>();
1721 auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr;
1722
1723 const bool matching_vec_elem_types =
1724 lhs_vec_elem_type && rhs_vec_elem_type &&
1725 (lhs_vec_elem_type == rhs_vec_elem_type) &&
1726 (lhs_vec->Width() == rhs_vec->Width());
1727
1728 const bool matching_types = matching_vec_elem_types || (lhs_ty == rhs_ty);
1729
1730 auto build = [&](const sem::Type* ty) {
1731 auto val = EvaluateConstantValue(expr, ty);
1732 auto* sem =
1733 builder_->create<sem::Expression>(expr, ty, current_statement_, val);
1734 sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
1735 return sem;
1736 };
1737
1738 // Binary logical expressions
1739 if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
1740 if (matching_types && lhs_ty->Is<Bool>()) {
1741 return build(lhs_ty);
1742 }
1743 }
1744 if (expr->IsOr() || expr->IsAnd()) {
1745 if (matching_types && lhs_ty->Is<Bool>()) {
1746 return build(lhs_ty);
1747 }
1748 if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
1749 return build(lhs_ty);
1750 }
1751 }
1752
1753 // Arithmetic expressions
1754 if (expr->IsArithmetic()) {
1755 // Binary arithmetic expressions over scalars
1756 if (matching_types && lhs_ty->is_numeric_scalar()) {
1757 return build(lhs_ty);
1758 }
1759
1760 // Binary arithmetic expressions over vectors
1761 if (matching_types && lhs_vec_elem_type &&
1762 lhs_vec_elem_type->is_numeric_scalar()) {
1763 return build(lhs_ty);
1764 }
1765
1766 // Binary arithmetic expressions with mixed scalar and vector operands
1767 if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty)) {
1768 if (expr->IsModulo()) {
1769 if (rhs_ty->is_integer_scalar()) {
1770 return build(lhs_ty);
1771 }
1772 } else if (rhs_ty->is_numeric_scalar()) {
1773 return build(lhs_ty);
1774 }
1775 }
1776 if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty)) {
1777 if (expr->IsModulo()) {
1778 if (lhs_ty->is_integer_scalar()) {
1779 return build(rhs_ty);
1780 }
1781 } else if (lhs_ty->is_numeric_scalar()) {
1782 return build(rhs_ty);
1783 }
1784 }
1785 }
1786
1787 // Matrix arithmetic
1788 auto* lhs_mat = lhs_ty->As<Matrix>();
1789 auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
1790 auto* rhs_mat = rhs_ty->As<Matrix>();
1791 auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
1792 // Addition and subtraction of float matrices
1793 if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type &&
1794 lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
1795 rhs_mat_elem_type->Is<F32>() &&
1796 (lhs_mat->columns() == rhs_mat->columns()) &&
1797 (lhs_mat->rows() == rhs_mat->rows())) {
1798 return build(rhs_ty);
1799 }
1800 if (expr->IsMultiply()) {
1801 // Multiplication of a matrix and a scalar
1802 if (lhs_ty->Is<F32>() && rhs_mat_elem_type &&
1803 rhs_mat_elem_type->Is<F32>()) {
1804 return build(rhs_ty);
1805 }
1806 if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
1807 rhs_ty->Is<F32>()) {
1808 return build(lhs_ty);
1809 }
1810
1811 // Vector times matrix
1812 if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
1813 rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
1814 (lhs_vec->Width() == rhs_mat->rows())) {
1815 return build(
1816 builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns()));
1817 }
1818
1819 // Matrix times vector
1820 if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
1821 rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() &&
1822 (lhs_mat->columns() == rhs_vec->Width())) {
1823 return build(
1824 builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows()));
1825 }
1826
1827 // Matrix times matrix
1828 if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
1829 rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
1830 (lhs_mat->columns() == rhs_mat->rows())) {
1831 return build(builder_->create<sem::Matrix>(
1832 builder_->create<sem::Vector>(lhs_mat_elem_type, lhs_mat->rows()),
1833 rhs_mat->columns()));
1834 }
1835 }
1836
1837 // Comparison expressions
1838 if (expr->IsComparison()) {
1839 if (matching_types) {
1840 // Special case for bools: only == and !=
1841 if (lhs_ty->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) {
1842 return build(builder_->create<sem::Bool>());
1843 }
1844
1845 // For the rest, we can compare i32, u32, and f32
1846 if (lhs_ty->IsAnyOf<I32, U32, F32>()) {
1847 return build(builder_->create<sem::Bool>());
1848 }
1849 }
1850
1851 // Same for vectors
1852 if (matching_vec_elem_types) {
1853 if (lhs_vec_elem_type->Is<Bool>() &&
1854 (expr->IsEqual() || expr->IsNotEqual())) {
1855 return build(builder_->create<sem::Vector>(
1856 builder_->create<sem::Bool>(), lhs_vec->Width()));
1857 }
1858
1859 if (lhs_vec_elem_type->is_numeric_scalar()) {
1860 return build(builder_->create<sem::Vector>(
1861 builder_->create<sem::Bool>(), lhs_vec->Width()));
1862 }
1863 }
1864 }
1865
1866 // Binary bitwise operations
1867 if (expr->IsBitwise()) {
1868 if (matching_types && lhs_ty->is_integer_scalar_or_vector()) {
1869 return build(lhs_ty);
1870 }
1871 }
1872
1873 // Bit shift expressions
1874 if (expr->IsBitshift()) {
1875 // Type validation rules are the same for left or right shift, despite
1876 // differences in computation rules (i.e. right shift can be arithmetic or
1877 // logical depending on lhs type).
1878
1879 if (lhs_ty->IsAnyOf<I32, U32>() && rhs_ty->Is<U32>()) {
1880 return build(lhs_ty);
1881 }
1882
1883 if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() &&
1884 rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) {
1885 return build(lhs_ty);
1886 }
1887 }
1888
1889 AddError("Binary expression operand types are invalid for this operation: " +
1890 TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " +
1891 TypeNameOf(rhs_ty),
1892 expr->source);
1893 return nullptr;
1894 }
1895
UnaryOp(const ast::UnaryOpExpression * unary)1896 sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
1897 auto* expr = Sem(unary->expr);
1898 auto* expr_ty = expr->Type();
1899 if (!expr_ty) {
1900 return nullptr;
1901 }
1902
1903 const sem::Type* ty = nullptr;
1904
1905 switch (unary->op) {
1906 case ast::UnaryOp::kNot:
1907 // Result type matches the deref'd inner type.
1908 ty = expr_ty->UnwrapRef();
1909 if (!ty->Is<sem::Bool>() && !ty->is_bool_vector()) {
1910 AddError(
1911 "cannot logical negate expression of type '" + TypeNameOf(expr_ty),
1912 unary->expr->source);
1913 return nullptr;
1914 }
1915 break;
1916
1917 case ast::UnaryOp::kComplement:
1918 // Result type matches the deref'd inner type.
1919 ty = expr_ty->UnwrapRef();
1920 if (!ty->is_integer_scalar_or_vector()) {
1921 AddError("cannot bitwise complement expression of type '" +
1922 TypeNameOf(expr_ty),
1923 unary->expr->source);
1924 return nullptr;
1925 }
1926 break;
1927
1928 case ast::UnaryOp::kNegation:
1929 // Result type matches the deref'd inner type.
1930 ty = expr_ty->UnwrapRef();
1931 if (!(ty->IsAnyOf<sem::F32, sem::I32>() ||
1932 ty->is_signed_integer_vector() || ty->is_float_vector())) {
1933 AddError("cannot negate expression of type '" + TypeNameOf(expr_ty),
1934 unary->expr->source);
1935 return nullptr;
1936 }
1937 break;
1938
1939 case ast::UnaryOp::kAddressOf:
1940 if (auto* ref = expr_ty->As<sem::Reference>()) {
1941 if (ref->StoreType()->UnwrapRef()->is_handle()) {
1942 AddError(
1943 "cannot take the address of expression in handle storage class",
1944 unary->expr->source);
1945 return nullptr;
1946 }
1947
1948 auto* array = unary->expr->As<ast::IndexAccessorExpression>();
1949 auto* member = unary->expr->As<ast::MemberAccessorExpression>();
1950 if ((array && TypeOf(array->object)->UnwrapRef()->Is<sem::Vector>()) ||
1951 (member &&
1952 TypeOf(member->structure)->UnwrapRef()->Is<sem::Vector>())) {
1953 AddError("cannot take the address of a vector component",
1954 unary->expr->source);
1955 return nullptr;
1956 }
1957
1958 ty = builder_->create<sem::Pointer>(ref->StoreType(),
1959 ref->StorageClass(), ref->Access());
1960 } else {
1961 AddError("cannot take the address of expression", unary->expr->source);
1962 return nullptr;
1963 }
1964 break;
1965
1966 case ast::UnaryOp::kIndirection:
1967 if (auto* ptr = expr_ty->As<sem::Pointer>()) {
1968 ty = builder_->create<sem::Reference>(
1969 ptr->StoreType(), ptr->StorageClass(), ptr->Access());
1970 } else {
1971 AddError("cannot dereference expression of type '" +
1972 TypeNameOf(expr_ty) + "'",
1973 unary->expr->source);
1974 return nullptr;
1975 }
1976 break;
1977 }
1978
1979 auto val = EvaluateConstantValue(unary, ty);
1980 auto* sem =
1981 builder_->create<sem::Expression>(unary, ty, current_statement_, val);
1982 sem->Behaviors() = expr->Behaviors();
1983 return sem;
1984 }
1985
TypeDecl(const ast::TypeDecl * named_type)1986 sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) {
1987 sem::Type* result = nullptr;
1988 if (auto* alias = named_type->As<ast::Alias>()) {
1989 result = Alias(alias);
1990 } else if (auto* str = named_type->As<ast::Struct>()) {
1991 result = Structure(str);
1992 } else {
1993 TINT_UNREACHABLE(Resolver, diagnostics_) << "Unhandled TypeDecl";
1994 }
1995
1996 if (!result) {
1997 return nullptr;
1998 }
1999
2000 builder_->Sem().Add(named_type, result);
2001 return result;
2002 }
2003
TypeOf(const ast::Expression * expr)2004 sem::Type* Resolver::TypeOf(const ast::Expression* expr) {
2005 auto* sem = Sem(expr);
2006 return sem ? const_cast<sem::Type*>(sem->Type()) : nullptr;
2007 }
2008
TypeNameOf(const sem::Type * ty)2009 std::string Resolver::TypeNameOf(const sem::Type* ty) {
2010 return RawTypeNameOf(ty->UnwrapRef());
2011 }
2012
RawTypeNameOf(const sem::Type * ty)2013 std::string Resolver::RawTypeNameOf(const sem::Type* ty) {
2014 return ty->FriendlyName(builder_->Symbols());
2015 }
2016
TypeOf(const ast::LiteralExpression * lit)2017 sem::Type* Resolver::TypeOf(const ast::LiteralExpression* lit) {
2018 if (lit->Is<ast::SintLiteralExpression>()) {
2019 return builder_->create<sem::I32>();
2020 }
2021 if (lit->Is<ast::UintLiteralExpression>()) {
2022 return builder_->create<sem::U32>();
2023 }
2024 if (lit->Is<ast::FloatLiteralExpression>()) {
2025 return builder_->create<sem::F32>();
2026 }
2027 if (lit->Is<ast::BoolLiteralExpression>()) {
2028 return builder_->create<sem::Bool>();
2029 }
2030 TINT_UNREACHABLE(Resolver, diagnostics_)
2031 << "Unhandled literal type: " << lit->TypeInfo().name;
2032 return nullptr;
2033 }
2034
Array(const ast::Array * arr)2035 sem::Array* Resolver::Array(const ast::Array* arr) {
2036 auto source = arr->source;
2037
2038 auto* elem_type = Type(arr->type);
2039 if (!elem_type) {
2040 return nullptr;
2041 }
2042
2043 if (!IsPlain(elem_type)) { // Check must come before GetDefaultAlignAndSize()
2044 AddError(TypeNameOf(elem_type) +
2045 " cannot be used as an element type of an array",
2046 source);
2047 return nullptr;
2048 }
2049
2050 uint32_t el_align = elem_type->Align();
2051 uint32_t el_size = elem_type->Size();
2052
2053 if (!ValidateNoDuplicateDecorations(arr->decorations)) {
2054 return nullptr;
2055 }
2056
2057 // Look for explicit stride via [[stride(n)]] decoration
2058 uint32_t explicit_stride = 0;
2059 for (auto* deco : arr->decorations) {
2060 Mark(deco);
2061 if (auto* sd = deco->As<ast::StrideDecoration>()) {
2062 explicit_stride = sd->stride;
2063 if (!ValidateArrayStrideDecoration(sd, el_size, el_align, source)) {
2064 return nullptr;
2065 }
2066 continue;
2067 }
2068
2069 AddError("decoration is not valid for array types", deco->source);
2070 return nullptr;
2071 }
2072
2073 // Calculate implicit stride
2074 uint64_t implicit_stride = utils::RoundUp<uint64_t>(el_align, el_size);
2075
2076 uint64_t stride = explicit_stride ? explicit_stride : implicit_stride;
2077
2078 // Evaluate the constant array size expression.
2079 // sem::Array uses a size of 0 for a runtime-sized array.
2080 uint32_t count = 0;
2081 if (auto* count_expr = arr->count) {
2082 auto* count_sem = Expression(count_expr);
2083 if (!count_sem) {
2084 return nullptr;
2085 }
2086
2087 auto size_source = count_expr->source;
2088
2089 auto* ty = count_sem->Type()->UnwrapRef();
2090 if (!ty->is_integer_scalar()) {
2091 AddError("array size must be integer scalar", size_source);
2092 return nullptr;
2093 }
2094
2095 if (auto* ident = count_expr->As<ast::IdentifierExpression>()) {
2096 // Make sure the identifier is a non-overridable module-scope constant.
2097 auto* var = ResolvedSymbol<sem::Variable>(ident);
2098 if (!var || !var->Is<sem::GlobalVariable>() ||
2099 !var->Declaration()->is_const) {
2100 AddError("array size identifier must be a module-scope constant",
2101 size_source);
2102 return nullptr;
2103 }
2104 if (ast::HasDecoration<ast::OverrideDecoration>(
2105 var->Declaration()->decorations)) {
2106 AddError("array size expression must not be pipeline-overridable",
2107 size_source);
2108 return nullptr;
2109 }
2110
2111 count_expr = var->Declaration()->constructor;
2112 } else if (!count_expr->Is<ast::LiteralExpression>()) {
2113 AddError(
2114 "array size expression must be either a literal or a module-scope "
2115 "constant",
2116 size_source);
2117 return nullptr;
2118 }
2119
2120 auto count_val = count_sem->ConstantValue();
2121 if (!count_val) {
2122 TINT_ICE(Resolver, diagnostics_)
2123 << "could not resolve array size expression";
2124 return nullptr;
2125 }
2126
2127 if (ty->is_signed_integer_scalar() ? count_val.Elements()[0].i32 < 1
2128 : count_val.Elements()[0].u32 < 1u) {
2129 AddError("array size must be at least 1", size_source);
2130 return nullptr;
2131 }
2132
2133 count = count_val.Elements()[0].u32;
2134 }
2135
2136 auto size = std::max<uint64_t>(count, 1) * stride;
2137 if (size > std::numeric_limits<uint32_t>::max()) {
2138 std::stringstream msg;
2139 msg << "array size in bytes must not exceed 0x" << std::hex
2140 << std::numeric_limits<uint32_t>::max() << ", but is 0x" << std::hex
2141 << size;
2142 AddError(msg.str(), arr->source);
2143 return nullptr;
2144 }
2145 if (stride > std::numeric_limits<uint32_t>::max() ||
2146 implicit_stride > std::numeric_limits<uint32_t>::max()) {
2147 TINT_ICE(Resolver, diagnostics_)
2148 << "calculated array stride exceeds uint32";
2149 return nullptr;
2150 }
2151 auto* out = builder_->create<sem::Array>(
2152 elem_type, count, el_align, static_cast<uint32_t>(size),
2153 static_cast<uint32_t>(stride), static_cast<uint32_t>(implicit_stride));
2154
2155 if (!ValidateArray(out, source)) {
2156 return nullptr;
2157 }
2158
2159 if (elem_type->Is<sem::Atomic>()) {
2160 atomic_composite_info_.emplace(out, arr->type->source);
2161 } else {
2162 auto found = atomic_composite_info_.find(elem_type);
2163 if (found != atomic_composite_info_.end()) {
2164 atomic_composite_info_.emplace(out, found->second);
2165 }
2166 }
2167
2168 return out;
2169 }
2170
Alias(const ast::Alias * alias)2171 sem::Type* Resolver::Alias(const ast::Alias* alias) {
2172 auto* ty = Type(alias->type);
2173 if (!ty) {
2174 return nullptr;
2175 }
2176 if (!ValidateAlias(alias)) {
2177 return nullptr;
2178 }
2179 return ty;
2180 }
2181
Structure(const ast::Struct * str)2182 sem::Struct* Resolver::Structure(const ast::Struct* str) {
2183 if (!ValidateNoDuplicateDecorations(str->decorations)) {
2184 return nullptr;
2185 }
2186 for (auto* deco : str->decorations) {
2187 Mark(deco);
2188 }
2189
2190 sem::StructMemberList sem_members;
2191 sem_members.reserve(str->members.size());
2192
2193 // Calculate the effective size and alignment of each field, and the overall
2194 // size of the structure.
2195 // For size, use the size attribute if provided, otherwise use the default
2196 // size for the type.
2197 // For alignment, use the alignment attribute if provided, otherwise use the
2198 // default alignment for the member type.
2199 // Diagnostic errors are raised if a basic rule is violated.
2200 // Validation of storage-class rules requires analysing the actual variable
2201 // usage of the structure, and so is performed as part of the variable
2202 // validation.
2203 uint64_t struct_size = 0;
2204 uint64_t struct_align = 1;
2205 std::unordered_map<Symbol, const ast::StructMember*> member_map;
2206
2207 for (auto* member : str->members) {
2208 Mark(member);
2209 auto result = member_map.emplace(member->symbol, member);
2210 if (!result.second) {
2211 AddError("redefinition of '" +
2212 builder_->Symbols().NameFor(member->symbol) + "'",
2213 member->source);
2214 AddNote("previous definition is here", result.first->second->source);
2215 return nullptr;
2216 }
2217
2218 // Resolve member type
2219 auto* type = Type(member->type);
2220 if (!type) {
2221 return nullptr;
2222 }
2223
2224 // Validate member type
2225 if (!IsPlain(type)) {
2226 AddError(TypeNameOf(type) +
2227 " cannot be used as the type of a structure member",
2228 member->source);
2229 return nullptr;
2230 }
2231
2232 uint64_t offset = struct_size;
2233 uint64_t align = type->Align();
2234 uint64_t size = type->Size();
2235
2236 if (!ValidateNoDuplicateDecorations(member->decorations)) {
2237 return nullptr;
2238 }
2239
2240 bool has_offset_deco = false;
2241 bool has_align_deco = false;
2242 bool has_size_deco = false;
2243 for (auto* deco : member->decorations) {
2244 Mark(deco);
2245 if (auto* o = deco->As<ast::StructMemberOffsetDecoration>()) {
2246 // Offset decorations are not part of the WGSL spec, but are emitted
2247 // by the SPIR-V reader.
2248 if (o->offset < struct_size) {
2249 AddError("offsets must be in ascending order", o->source);
2250 return nullptr;
2251 }
2252 offset = o->offset;
2253 align = 1;
2254 has_offset_deco = true;
2255 } else if (auto* a = deco->As<ast::StructMemberAlignDecoration>()) {
2256 if (a->align <= 0 || !utils::IsPowerOfTwo(a->align)) {
2257 AddError("align value must be a positive, power-of-two integer",
2258 a->source);
2259 return nullptr;
2260 }
2261 align = a->align;
2262 has_align_deco = true;
2263 } else if (auto* s = deco->As<ast::StructMemberSizeDecoration>()) {
2264 if (s->size < size) {
2265 AddError("size must be at least as big as the type's size (" +
2266 std::to_string(size) + ")",
2267 s->source);
2268 return nullptr;
2269 }
2270 size = s->size;
2271 has_size_deco = true;
2272 }
2273 }
2274
2275 if (has_offset_deco && (has_align_deco || has_size_deco)) {
2276 AddError(
2277 "offset decorations cannot be used with align or size decorations",
2278 member->source);
2279 return nullptr;
2280 }
2281
2282 offset = utils::RoundUp(align, offset);
2283 if (offset > std::numeric_limits<uint32_t>::max()) {
2284 std::stringstream msg;
2285 msg << "struct member has byte offset 0x" << std::hex << offset
2286 << ", but must not exceed 0x" << std::hex
2287 << std::numeric_limits<uint32_t>::max();
2288 AddError(msg.str(), member->source);
2289 return nullptr;
2290 }
2291
2292 auto* sem_member = builder_->create<sem::StructMember>(
2293 member, member->symbol, type, static_cast<uint32_t>(sem_members.size()),
2294 static_cast<uint32_t>(offset), static_cast<uint32_t>(align),
2295 static_cast<uint32_t>(size));
2296 builder_->Sem().Add(member, sem_member);
2297 sem_members.emplace_back(sem_member);
2298
2299 struct_size = offset + size;
2300 struct_align = std::max(struct_align, align);
2301 }
2302
2303 uint64_t size_no_padding = struct_size;
2304 struct_size = utils::RoundUp(struct_align, struct_size);
2305
2306 if (struct_size > std::numeric_limits<uint32_t>::max()) {
2307 std::stringstream msg;
2308 msg << "struct size in bytes must not exceed 0x" << std::hex
2309 << std::numeric_limits<uint32_t>::max() << ", but is 0x" << std::hex
2310 << struct_size;
2311 AddError(msg.str(), str->source);
2312 return nullptr;
2313 }
2314 if (struct_align > std::numeric_limits<uint32_t>::max()) {
2315 TINT_ICE(Resolver, diagnostics_)
2316 << "calculated struct stride exceeds uint32";
2317 return nullptr;
2318 }
2319
2320 auto* out = builder_->create<sem::Struct>(
2321 str, str->name, sem_members, static_cast<uint32_t>(struct_align),
2322 static_cast<uint32_t>(struct_size),
2323 static_cast<uint32_t>(size_no_padding));
2324
2325 for (size_t i = 0; i < sem_members.size(); i++) {
2326 auto* mem_type = sem_members[i]->Type();
2327 if (mem_type->Is<sem::Atomic>()) {
2328 atomic_composite_info_.emplace(out,
2329 sem_members[i]->Declaration()->source);
2330 break;
2331 } else {
2332 auto found = atomic_composite_info_.find(mem_type);
2333 if (found != atomic_composite_info_.end()) {
2334 atomic_composite_info_.emplace(out, found->second);
2335 break;
2336 }
2337 }
2338 }
2339
2340 if (!ValidateStructure(out)) {
2341 return nullptr;
2342 }
2343
2344 return out;
2345 }
2346
ReturnStatement(const ast::ReturnStatement * stmt)2347 sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) {
2348 auto* sem = builder_->create<sem::Statement>(
2349 stmt, current_compound_statement_, current_function_);
2350 return StatementScope(stmt, sem, [&] {
2351 auto& behaviors = current_statement_->Behaviors();
2352 behaviors = sem::Behavior::kReturn;
2353
2354 if (auto* value = stmt->value) {
2355 auto* expr = Expression(value);
2356 if (!expr) {
2357 return false;
2358 }
2359 behaviors.Add(expr->Behaviors() - sem::Behavior::kNext);
2360 }
2361
2362 // Validate after processing the return value expression so that its type is
2363 // available for validation.
2364 return ValidateReturn(stmt);
2365 });
2366 }
2367
SwitchStatement(const ast::SwitchStatement * stmt)2368 sem::SwitchStatement* Resolver::SwitchStatement(
2369 const ast::SwitchStatement* stmt) {
2370 auto* sem = builder_->create<sem::SwitchStatement>(
2371 stmt, current_compound_statement_, current_function_);
2372 return StatementScope(stmt, sem, [&] {
2373 auto& behaviors = sem->Behaviors();
2374
2375 auto* cond = Expression(stmt->condition);
2376 if (!cond) {
2377 return false;
2378 }
2379 behaviors = cond->Behaviors() - sem::Behavior::kNext;
2380
2381 for (auto* case_stmt : stmt->body) {
2382 Mark(case_stmt);
2383 auto* c = CaseStatement(case_stmt);
2384 if (!c) {
2385 return false;
2386 }
2387 behaviors.Add(c->Behaviors());
2388 }
2389
2390 if (behaviors.Contains(sem::Behavior::kBreak)) {
2391 behaviors.Add(sem::Behavior::kNext);
2392 }
2393 behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kFallthrough);
2394
2395 return ValidateSwitch(stmt);
2396 });
2397 }
2398
VariableDeclStatement(const ast::VariableDeclStatement * stmt)2399 sem::Statement* Resolver::VariableDeclStatement(
2400 const ast::VariableDeclStatement* stmt) {
2401 auto* sem = builder_->create<sem::Statement>(
2402 stmt, current_compound_statement_, current_function_);
2403 return StatementScope(stmt, sem, [&] {
2404 Mark(stmt->variable);
2405
2406 auto* var = Variable(stmt->variable, VariableKind::kLocal);
2407 if (!var) {
2408 return false;
2409 }
2410
2411 for (auto* deco : stmt->variable->decorations) {
2412 Mark(deco);
2413 if (!deco->Is<ast::InternalDecoration>()) {
2414 AddError("decorations are not valid on local variables", deco->source);
2415 return false;
2416 }
2417 }
2418
2419 if (current_block_) { // Not all statements are inside a block
2420 current_block_->AddDecl(stmt->variable);
2421 }
2422
2423 if (auto* ctor = var->Constructor()) {
2424 sem->Behaviors() = ctor->Behaviors();
2425 }
2426
2427 return ValidateVariable(var);
2428 });
2429 }
2430
AssignmentStatement(const ast::AssignmentStatement * stmt)2431 sem::Statement* Resolver::AssignmentStatement(
2432 const ast::AssignmentStatement* stmt) {
2433 auto* sem = builder_->create<sem::Statement>(
2434 stmt, current_compound_statement_, current_function_);
2435 return StatementScope(stmt, sem, [&] {
2436 auto* lhs = Expression(stmt->lhs);
2437 if (!lhs) {
2438 return false;
2439 }
2440
2441 auto* rhs = Expression(stmt->rhs);
2442 if (!rhs) {
2443 return false;
2444 }
2445
2446 auto& behaviors = sem->Behaviors();
2447 behaviors = rhs->Behaviors();
2448 if (!stmt->lhs->Is<ast::PhonyExpression>()) {
2449 behaviors.Add(lhs->Behaviors());
2450 }
2451
2452 return ValidateAssignment(stmt);
2453 });
2454 }
2455
BreakStatement(const ast::BreakStatement * stmt)2456 sem::Statement* Resolver::BreakStatement(const ast::BreakStatement* stmt) {
2457 auto* sem = builder_->create<sem::Statement>(
2458 stmt, current_compound_statement_, current_function_);
2459 return StatementScope(stmt, sem, [&] {
2460 sem->Behaviors() = sem::Behavior::kBreak;
2461
2462 return ValidateBreakStatement(sem);
2463 });
2464 }
2465
CallStatement(const ast::CallStatement * stmt)2466 sem::Statement* Resolver::CallStatement(const ast::CallStatement* stmt) {
2467 auto* sem = builder_->create<sem::Statement>(
2468 stmt, current_compound_statement_, current_function_);
2469 return StatementScope(stmt, sem, [&] {
2470 if (auto* expr = Expression(stmt->expr)) {
2471 sem->Behaviors() = expr->Behaviors();
2472 return true;
2473 }
2474 return false;
2475 });
2476 }
2477
ContinueStatement(const ast::ContinueStatement * stmt)2478 sem::Statement* Resolver::ContinueStatement(
2479 const ast::ContinueStatement* stmt) {
2480 auto* sem = builder_->create<sem::Statement>(
2481 stmt, current_compound_statement_, current_function_);
2482 return StatementScope(stmt, sem, [&] {
2483 sem->Behaviors() = sem::Behavior::kContinue;
2484
2485 // Set if we've hit the first continue statement in our parent loop
2486 if (auto* block = sem->FindFirstParent<sem::LoopBlockStatement>()) {
2487 if (!block->FirstContinue()) {
2488 const_cast<sem::LoopBlockStatement*>(block)->SetFirstContinue(
2489 stmt, block->Decls().size());
2490 }
2491 }
2492
2493 return ValidateContinueStatement(sem);
2494 });
2495 }
2496
DiscardStatement(const ast::DiscardStatement * stmt)2497 sem::Statement* Resolver::DiscardStatement(const ast::DiscardStatement* stmt) {
2498 auto* sem = builder_->create<sem::Statement>(
2499 stmt, current_compound_statement_, current_function_);
2500 return StatementScope(stmt, sem, [&] {
2501 sem->Behaviors() = sem::Behavior::kDiscard;
2502 current_function_->SetHasDiscard();
2503
2504 return ValidateDiscardStatement(sem);
2505 });
2506 }
2507
FallthroughStatement(const ast::FallthroughStatement * stmt)2508 sem::Statement* Resolver::FallthroughStatement(
2509 const ast::FallthroughStatement* stmt) {
2510 auto* sem = builder_->create<sem::Statement>(
2511 stmt, current_compound_statement_, current_function_);
2512 return StatementScope(stmt, sem, [&] {
2513 sem->Behaviors() = sem::Behavior::kFallthrough;
2514
2515 return ValidateFallthroughStatement(sem);
2516 });
2517 }
2518
ApplyStorageClassUsageToType(ast::StorageClass sc,sem::Type * ty,const Source & usage)2519 bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
2520 sem::Type* ty,
2521 const Source& usage) {
2522 ty = const_cast<sem::Type*>(ty->UnwrapRef());
2523
2524 if (auto* str = ty->As<sem::Struct>()) {
2525 if (str->StorageClassUsage().count(sc)) {
2526 return true; // Already applied
2527 }
2528
2529 str->AddUsage(sc);
2530
2531 for (auto* member : str->Members()) {
2532 if (!ApplyStorageClassUsageToType(sc, member->Type(), usage)) {
2533 std::stringstream err;
2534 err << "while analysing structure member " << TypeNameOf(str) << "."
2535 << builder_->Symbols().NameFor(member->Declaration()->symbol);
2536 AddNote(err.str(), member->Declaration()->source);
2537 return false;
2538 }
2539 }
2540 return true;
2541 }
2542
2543 if (auto* arr = ty->As<sem::Array>()) {
2544 return ApplyStorageClassUsageToType(
2545 sc, const_cast<sem::Type*>(arr->ElemType()), usage);
2546 }
2547
2548 if (ast::IsHostShareable(sc) && !IsHostShareable(ty)) {
2549 std::stringstream err;
2550 err << "Type '" << TypeNameOf(ty) << "' cannot be used in storage class '"
2551 << sc << "' as it is non-host-shareable";
2552 AddError(err.str(), usage);
2553 return false;
2554 }
2555
2556 return true;
2557 }
2558
2559 template <typename SEM, typename F>
StatementScope(const ast::Statement * ast,SEM * sem,F && callback)2560 SEM* Resolver::StatementScope(const ast::Statement* ast,
2561 SEM* sem,
2562 F&& callback) {
2563 builder_->Sem().Add(ast, sem);
2564
2565 auto* as_compound =
2566 As<sem::CompoundStatement, CastFlags::kDontErrorOnImpossibleCast>(sem);
2567 auto* as_block =
2568 As<sem::BlockStatement, CastFlags::kDontErrorOnImpossibleCast>(sem);
2569
2570 TINT_SCOPED_ASSIGNMENT(current_statement_, sem);
2571 TINT_SCOPED_ASSIGNMENT(
2572 current_compound_statement_,
2573 as_compound ? as_compound : current_compound_statement_);
2574 TINT_SCOPED_ASSIGNMENT(current_block_, as_block ? as_block : current_block_);
2575
2576 if (!callback()) {
2577 return nullptr;
2578 }
2579
2580 return sem;
2581 }
2582
VectorPretty(uint32_t size,const sem::Type * element_type)2583 std::string Resolver::VectorPretty(uint32_t size,
2584 const sem::Type* element_type) {
2585 sem::Vector vec_type(element_type, size);
2586 return vec_type.FriendlyName(builder_->Symbols());
2587 }
2588
Mark(const ast::Node * node)2589 bool Resolver::Mark(const ast::Node* node) {
2590 if (node == nullptr) {
2591 TINT_ICE(Resolver, diagnostics_) << "Resolver::Mark() called with nullptr";
2592 return false;
2593 }
2594 if (marked_.emplace(node).second) {
2595 return true;
2596 }
2597 TINT_ICE(Resolver, diagnostics_)
2598 << "AST node '" << node->TypeInfo().name
2599 << "' was encountered twice in the same AST of a Program\n"
2600 << "At: " << node->source << "\n"
2601 << "Pointer: " << node;
2602 return false;
2603 }
2604
AddError(const std::string & msg,const Source & source) const2605 void Resolver::AddError(const std::string& msg, const Source& source) const {
2606 diagnostics_.add_error(diag::System::Resolver, msg, source);
2607 }
2608
AddWarning(const std::string & msg,const Source & source) const2609 void Resolver::AddWarning(const std::string& msg, const Source& source) const {
2610 diagnostics_.add_warning(diag::System::Resolver, msg, source);
2611 }
2612
AddNote(const std::string & msg,const Source & source) const2613 void Resolver::AddNote(const std::string& msg, const Source& source) const {
2614 diagnostics_.add_note(diag::System::Resolver, msg, source);
2615 }
2616
2617 // https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section
IsPlain(const sem::Type * type) const2618 bool Resolver::IsPlain(const sem::Type* type) const {
2619 return type->is_scalar() ||
2620 type->IsAnyOf<sem::Atomic, sem::Vector, sem::Matrix, sem::Array,
2621 sem::Struct>();
2622 }
2623
2624 // https://gpuweb.github.io/gpuweb/wgsl.html#storable-types
IsStorable(const sem::Type * type) const2625 bool Resolver::IsStorable(const sem::Type* type) const {
2626 return IsPlain(type) || type->IsAnyOf<sem::Texture, sem::Sampler>();
2627 }
2628
2629 // https://gpuweb.github.io/gpuweb/wgsl.html#host-shareable-types
IsHostShareable(const sem::Type * type) const2630 bool Resolver::IsHostShareable(const sem::Type* type) const {
2631 if (type->IsAnyOf<sem::I32, sem::U32, sem::F32>()) {
2632 return true;
2633 }
2634 if (auto* vec = type->As<sem::Vector>()) {
2635 return IsHostShareable(vec->type());
2636 }
2637 if (auto* mat = type->As<sem::Matrix>()) {
2638 return IsHostShareable(mat->type());
2639 }
2640 if (auto* arr = type->As<sem::Array>()) {
2641 return IsHostShareable(arr->ElemType());
2642 }
2643 if (auto* str = type->As<sem::Struct>()) {
2644 for (auto* member : str->Members()) {
2645 if (!IsHostShareable(member->Type())) {
2646 return false;
2647 }
2648 }
2649 return true;
2650 }
2651 if (auto* atomic = type->As<sem::Atomic>()) {
2652 return IsHostShareable(atomic->Type());
2653 }
2654 return false;
2655 }
2656
IsIntrinsic(Symbol symbol) const2657 bool Resolver::IsIntrinsic(Symbol symbol) const {
2658 std::string name = builder_->Symbols().NameFor(symbol);
2659 return sem::ParseIntrinsicType(name) != sem::IntrinsicType::kNone;
2660 }
2661
IsCallStatement(const ast::Expression * expr) const2662 bool Resolver::IsCallStatement(const ast::Expression* expr) const {
2663 return current_statement_ &&
2664 Is<ast::CallStatement>(current_statement_->Declaration(),
2665 [&](auto* stmt) { return stmt->expr == expr; });
2666 }
2667
ClosestContinuing(bool stop_at_loop) const2668 const ast::Statement* Resolver::ClosestContinuing(bool stop_at_loop) const {
2669 for (const auto* s = current_statement_; s != nullptr; s = s->Parent()) {
2670 if (stop_at_loop && s->Is<sem::LoopStatement>()) {
2671 break;
2672 }
2673 if (s->Is<sem::LoopContinuingBlockStatement>()) {
2674 return s->Declaration();
2675 }
2676 if (auto* f = As<sem::ForLoopStatement>(s->Parent())) {
2677 if (f->Declaration()->continuing == s->Declaration()) {
2678 return s->Declaration();
2679 }
2680 if (stop_at_loop) {
2681 break;
2682 }
2683 }
2684 }
2685 return nullptr;
2686 }
2687
2688 ////////////////////////////////////////////////////////////////////////////////
2689 // Resolver::TypeConversionSig
2690 ////////////////////////////////////////////////////////////////////////////////
operator ==(const TypeConversionSig & rhs) const2691 bool Resolver::TypeConversionSig::operator==(
2692 const TypeConversionSig& rhs) const {
2693 return target == rhs.target && source == rhs.source;
2694 }
operator ()(const TypeConversionSig & sig) const2695 std::size_t Resolver::TypeConversionSig::Hasher::operator()(
2696 const TypeConversionSig& sig) const {
2697 return utils::Hash(sig.target, sig.source);
2698 }
2699
2700 ////////////////////////////////////////////////////////////////////////////////
2701 // Resolver::TypeConstructorSig
2702 ////////////////////////////////////////////////////////////////////////////////
// Constructs a type-constructor signature from the constructed type `ty` and
// the list of parameter types `params` (copied into `parameters`).
Resolver::TypeConstructorSig::TypeConstructorSig(
    const sem::Type* ty,
    const std::vector<const sem::Type*> params)
    : type(ty), parameters(params) {}
// Copy constructor: compiler-generated.
Resolver::TypeConstructorSig::TypeConstructorSig(const TypeConstructorSig&) =
    default;
// Destructor: compiler-generated.
Resolver::TypeConstructorSig::~TypeConstructorSig() = default;
2710
operator ==(const TypeConstructorSig & rhs) const2711 bool Resolver::TypeConstructorSig::operator==(
2712 const TypeConstructorSig& rhs) const {
2713 return type == rhs.type && parameters == rhs.parameters;
2714 }
operator ()(const TypeConstructorSig & sig) const2715 std::size_t Resolver::TypeConstructorSig::Hasher::operator()(
2716 const TypeConstructorSig& sig) const {
2717 return utils::Hash(sig.type, sig.parameters);
2718 }
2719
2720 } // namespace resolver
2721 } // namespace tint
2722