• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/resolver/resolver.h"
16 
17 #include <algorithm>
18 #include <cmath>
19 #include <iomanip>
20 #include <limits>
21 #include <utility>
22 
23 #include "src/ast/alias.h"
24 #include "src/ast/array.h"
25 #include "src/ast/assignment_statement.h"
26 #include "src/ast/bitcast_expression.h"
27 #include "src/ast/break_statement.h"
28 #include "src/ast/call_statement.h"
29 #include "src/ast/continue_statement.h"
30 #include "src/ast/depth_texture.h"
31 #include "src/ast/disable_validation_decoration.h"
32 #include "src/ast/discard_statement.h"
33 #include "src/ast/fallthrough_statement.h"
34 #include "src/ast/for_loop_statement.h"
35 #include "src/ast/if_statement.h"
36 #include "src/ast/internal_decoration.h"
37 #include "src/ast/interpolate_decoration.h"
38 #include "src/ast/loop_statement.h"
39 #include "src/ast/matrix.h"
40 #include "src/ast/override_decoration.h"
41 #include "src/ast/pointer.h"
42 #include "src/ast/return_statement.h"
43 #include "src/ast/sampled_texture.h"
44 #include "src/ast/sampler.h"
45 #include "src/ast/storage_texture.h"
46 #include "src/ast/struct_block_decoration.h"
47 #include "src/ast/switch_statement.h"
48 #include "src/ast/traverse_expressions.h"
49 #include "src/ast/type_name.h"
50 #include "src/ast/unary_op_expression.h"
51 #include "src/ast/variable_decl_statement.h"
52 #include "src/ast/vector.h"
53 #include "src/ast/workgroup_decoration.h"
54 #include "src/sem/array.h"
55 #include "src/sem/atomic_type.h"
56 #include "src/sem/call.h"
57 #include "src/sem/depth_multisampled_texture_type.h"
58 #include "src/sem/depth_texture_type.h"
59 #include "src/sem/for_loop_statement.h"
60 #include "src/sem/function.h"
61 #include "src/sem/if_statement.h"
62 #include "src/sem/loop_statement.h"
63 #include "src/sem/member_accessor_expression.h"
64 #include "src/sem/multisampled_texture_type.h"
65 #include "src/sem/pointer_type.h"
66 #include "src/sem/reference_type.h"
67 #include "src/sem/sampled_texture_type.h"
68 #include "src/sem/sampler_type.h"
69 #include "src/sem/statement.h"
70 #include "src/sem/storage_texture_type.h"
71 #include "src/sem/struct.h"
72 #include "src/sem/switch_statement.h"
73 #include "src/sem/type_constructor.h"
74 #include "src/sem/type_conversion.h"
75 #include "src/sem/variable.h"
76 #include "src/utils/defer.h"
77 #include "src/utils/math.h"
78 #include "src/utils/reverse.h"
79 #include "src/utils/scoped_assignment.h"
80 #include "src/utils/transform.h"
81 
82 namespace tint {
83 namespace resolver {
84 
Resolver(ProgramBuilder * builder)85 Resolver::Resolver(ProgramBuilder* builder)
86     : builder_(builder),
87       diagnostics_(builder->Diagnostics()),
88       intrinsic_table_(IntrinsicTable::Create(*builder)) {}
89 
90 Resolver::~Resolver() = default;
91 
Resolve()92 bool Resolver::Resolve() {
93   if (builder_->Diagnostics().contains_errors()) {
94     return false;
95   }
96 
97   if (!DependencyGraph::Build(builder_->AST(), builder_->Symbols(),
98                               builder_->Diagnostics(), dependencies_,
99                               /* allow_out_of_order_decls*/ false)) {
100     return false;
101   }
102 
103   bool result = ResolveInternal();
104 
105   if (!result && !diagnostics_.contains_errors()) {
106     TINT_ICE(Resolver, diagnostics_)
107         << "resolving failed, but no error was raised";
108     return false;
109   }
110 
111   return result;
112 }
113 
ResolveInternal()114 bool Resolver::ResolveInternal() {
115   Mark(&builder_->AST());
116 
117   // Process everything else in the order they appear in the module. This is
118   // necessary for validation of use-before-declaration.
119   for (auto* decl : builder_->AST().GlobalDeclarations()) {
120     if (auto* td = decl->As<ast::TypeDecl>()) {
121       Mark(td);
122       if (!TypeDecl(td)) {
123         return false;
124       }
125     } else if (auto* func = decl->As<ast::Function>()) {
126       Mark(func);
127       if (!Function(func)) {
128         return false;
129       }
130     } else if (auto* var = decl->As<ast::Variable>()) {
131       Mark(var);
132       if (!GlobalVariable(var)) {
133         return false;
134       }
135     } else {
136       TINT_UNREACHABLE(Resolver, diagnostics_)
137           << "unhandled global declaration: " << decl->TypeInfo().name;
138       return false;
139     }
140   }
141 
142   AllocateOverridableConstantIds();
143 
144   SetShadows();
145 
146   if (!ValidatePipelineStages()) {
147     return false;
148   }
149 
150   bool result = true;
151   for (auto* node : builder_->ASTNodes().Objects()) {
152     if (marked_.count(node) == 0) {
153       TINT_ICE(Resolver, diagnostics_) << "AST node '" << node->TypeInfo().name
154                                        << "' was not reached by the resolver\n"
155                                        << "At: " << node->source << "\n"
156                                        << "Pointer: " << node;
157       result = false;
158     }
159   }
160 
161   return result;
162 }
163 
Type(const ast::Type * ty)164 sem::Type* Resolver::Type(const ast::Type* ty) {
165   Mark(ty);
166   auto* s = [&]() -> sem::Type* {
167     if (ty->Is<ast::Void>()) {
168       return builder_->create<sem::Void>();
169     }
170     if (ty->Is<ast::Bool>()) {
171       return builder_->create<sem::Bool>();
172     }
173     if (ty->Is<ast::I32>()) {
174       return builder_->create<sem::I32>();
175     }
176     if (ty->Is<ast::U32>()) {
177       return builder_->create<sem::U32>();
178     }
179     if (ty->Is<ast::F32>()) {
180       return builder_->create<sem::F32>();
181     }
182     if (auto* t = ty->As<ast::Vector>()) {
183       if (auto* el = Type(t->type)) {
184         if (auto* vector = builder_->create<sem::Vector>(el, t->width)) {
185           if (ValidateVector(vector, t->source)) {
186             return vector;
187           }
188         }
189       }
190       return nullptr;
191     }
192     if (auto* t = ty->As<ast::Matrix>()) {
193       if (auto* el = Type(t->type)) {
194         if (auto* column_type = builder_->create<sem::Vector>(el, t->rows)) {
195           if (auto* matrix =
196                   builder_->create<sem::Matrix>(column_type, t->columns)) {
197             if (ValidateMatrix(matrix, t->source)) {
198               return matrix;
199             }
200           }
201         }
202       }
203       return nullptr;
204     }
205     if (auto* t = ty->As<ast::Array>()) {
206       return Array(t);
207     }
208     if (auto* t = ty->As<ast::Atomic>()) {
209       if (auto* el = Type(t->type)) {
210         auto* a = builder_->create<sem::Atomic>(el);
211         if (!ValidateAtomic(t, a)) {
212           return nullptr;
213         }
214         return a;
215       }
216       return nullptr;
217     }
218     if (auto* t = ty->As<ast::Pointer>()) {
219       if (auto* el = Type(t->type)) {
220         auto access = t->access;
221         if (access == ast::kUndefined) {
222           access = DefaultAccessForStorageClass(t->storage_class);
223         }
224         return builder_->create<sem::Pointer>(el, t->storage_class, access);
225       }
226       return nullptr;
227     }
228     if (auto* t = ty->As<ast::Sampler>()) {
229       return builder_->create<sem::Sampler>(t->kind);
230     }
231     if (auto* t = ty->As<ast::SampledTexture>()) {
232       if (auto* el = Type(t->type)) {
233         return builder_->create<sem::SampledTexture>(t->dim, el);
234       }
235       return nullptr;
236     }
237     if (auto* t = ty->As<ast::MultisampledTexture>()) {
238       if (auto* el = Type(t->type)) {
239         return builder_->create<sem::MultisampledTexture>(t->dim, el);
240       }
241       return nullptr;
242     }
243     if (auto* t = ty->As<ast::DepthTexture>()) {
244       return builder_->create<sem::DepthTexture>(t->dim);
245     }
246     if (auto* t = ty->As<ast::DepthMultisampledTexture>()) {
247       return builder_->create<sem::DepthMultisampledTexture>(t->dim);
248     }
249     if (auto* t = ty->As<ast::StorageTexture>()) {
250       if (auto* el = Type(t->type)) {
251         if (!ValidateStorageTexture(t)) {
252           return nullptr;
253         }
254         return builder_->create<sem::StorageTexture>(t->dim, t->format,
255                                                      t->access, el);
256       }
257       return nullptr;
258     }
259     if (ty->As<ast::ExternalTexture>()) {
260       return builder_->create<sem::ExternalTexture>();
261     }
262     if (auto* type = ResolvedSymbol<sem::Type>(ty)) {
263       return type;
264     }
265     TINT_UNREACHABLE(Resolver, diagnostics_)
266         << "Unhandled ast::Type: " << ty->TypeInfo().name;
267     return nullptr;
268   }();
269 
270   if (s) {
271     builder_->Sem().Add(ty, s);
272   }
273   return s;
274 }
275 
Variable(const ast::Variable * var,VariableKind kind,uint32_t index)276 sem::Variable* Resolver::Variable(const ast::Variable* var,
277                                   VariableKind kind,
278                                   uint32_t index /* = 0 */) {
279   const sem::Type* storage_ty = nullptr;
280 
281   // If the variable has a declared type, resolve it.
282   if (auto* ty = var->type) {
283     storage_ty = Type(ty);
284     if (!storage_ty) {
285       return nullptr;
286     }
287   }
288 
289   const sem::Expression* rhs = nullptr;
290 
291   // Does the variable have a constructor?
292   if (var->constructor) {
293     rhs = Expression(var->constructor);
294     if (!rhs) {
295       return nullptr;
296     }
297 
298     // If the variable has no declared type, infer it from the RHS
299     if (!storage_ty) {
300       if (!var->is_const && kind == VariableKind::kGlobal) {
301         AddError("global var declaration must specify a type", var->source);
302         return nullptr;
303       }
304 
305       storage_ty = rhs->Type()->UnwrapRef();  // Implicit load of RHS
306     }
307   } else if (var->is_const && kind != VariableKind::kParameter &&
308              !ast::HasDecoration<ast::OverrideDecoration>(var->decorations)) {
309     AddError("let declaration must have an initializer", var->source);
310     return nullptr;
311   } else if (!var->type) {
312     AddError(
313         (kind == VariableKind::kGlobal)
314             ? "module scope var declaration requires a type and initializer"
315             : "function scope var declaration requires a type or initializer",
316         var->source);
317     return nullptr;
318   }
319 
320   if (!storage_ty) {
321     TINT_ICE(Resolver, diagnostics_)
322         << "failed to determine storage type for variable '" +
323                builder_->Symbols().NameFor(var->symbol) + "'\n"
324         << "Source: " << var->source;
325     return nullptr;
326   }
327 
328   auto storage_class = var->declared_storage_class;
329   if (storage_class == ast::StorageClass::kNone && !var->is_const) {
330     // No declared storage class. Infer from usage / type.
331     if (kind == VariableKind::kLocal) {
332       storage_class = ast::StorageClass::kFunction;
333     } else if (storage_ty->UnwrapRef()->is_handle()) {
334       // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
335       // If the store type is a texture type or a sampler type, then the
336       // variable declaration must not have a storage class decoration. The
337       // storage class will always be handle.
338       storage_class = ast::StorageClass::kUniformConstant;
339     }
340   }
341 
342   if (kind == VariableKind::kLocal && !var->is_const &&
343       storage_class != ast::StorageClass::kFunction &&
344       IsValidationEnabled(var->decorations,
345                           ast::DisabledValidation::kIgnoreStorageClass)) {
346     AddError("function variable has a non-function storage class", var->source);
347     return nullptr;
348   }
349 
350   auto access = var->declared_access;
351   if (access == ast::Access::kUndefined) {
352     access = DefaultAccessForStorageClass(storage_class);
353   }
354 
355   auto* var_ty = storage_ty;
356   if (!var->is_const) {
357     // Variable declaration. Unlike `let`, `var` has storage.
358     // Variables are always of a reference type to the declared storage type.
359     var_ty =
360         builder_->create<sem::Reference>(storage_ty, storage_class, access);
361   }
362 
363   if (rhs && !ValidateVariableConstructorOrCast(var, storage_class, storage_ty,
364                                                 rhs->Type())) {
365     return nullptr;
366   }
367 
368   if (!ApplyStorageClassUsageToType(
369           storage_class, const_cast<sem::Type*>(var_ty), var->source)) {
370     AddNote(
371         std::string("while instantiating ") +
372             ((kind == VariableKind::kParameter) ? "parameter " : "variable ") +
373             builder_->Symbols().NameFor(var->symbol),
374         var->source);
375     return nullptr;
376   }
377 
378   if (kind == VariableKind::kParameter) {
379     if (auto* ptr = var_ty->As<sem::Pointer>()) {
380       // For MSL, we push module-scope variables into the entry point as pointer
381       // parameters, so we also need to handle their store type.
382       if (!ApplyStorageClassUsageToType(
383               ptr->StorageClass(), const_cast<sem::Type*>(ptr->StoreType()),
384               var->source)) {
385         AddNote("while instantiating parameter " +
386                     builder_->Symbols().NameFor(var->symbol),
387                 var->source);
388         return nullptr;
389       }
390     }
391   }
392 
393   switch (kind) {
394     case VariableKind::kGlobal: {
395       sem::BindingPoint binding_point;
396       if (auto bp = var->BindingPoint()) {
397         binding_point = {bp.group->value, bp.binding->value};
398       }
399 
400       auto* override =
401           ast::GetDecoration<ast::OverrideDecoration>(var->decorations);
402       bool has_const_val = rhs && var->is_const && !override;
403 
404       auto* global = builder_->create<sem::GlobalVariable>(
405           var, var_ty, storage_class, access,
406           has_const_val ? rhs->ConstantValue() : sem::Constant{},
407           binding_point);
408 
409       if (override) {
410         global->SetIsOverridable();
411         if (override->has_value) {
412           global->SetConstantId(static_cast<uint16_t>(override->value));
413         }
414       }
415 
416       global->SetConstructor(rhs);
417 
418       builder_->Sem().Add(var, global);
419       return global;
420     }
421     case VariableKind::kLocal: {
422       auto* local = builder_->create<sem::LocalVariable>(
423           var, var_ty, storage_class, access, current_statement_,
424           (rhs && var->is_const) ? rhs->ConstantValue() : sem::Constant{});
425       builder_->Sem().Add(var, local);
426       local->SetConstructor(rhs);
427       return local;
428     }
429     case VariableKind::kParameter: {
430       auto* param = builder_->create<sem::Parameter>(var, index, var_ty,
431                                                      storage_class, access);
432       builder_->Sem().Add(var, param);
433       return param;
434     }
435   }
436 
437   TINT_UNREACHABLE(Resolver, diagnostics_)
438       << "unhandled VariableKind " << static_cast<int>(kind);
439   return nullptr;
440 }
441 
DefaultAccessForStorageClass(ast::StorageClass storage_class)442 ast::Access Resolver::DefaultAccessForStorageClass(
443     ast::StorageClass storage_class) {
444   // https://gpuweb.github.io/gpuweb/wgsl/#storage-class
445   switch (storage_class) {
446     case ast::StorageClass::kStorage:
447     case ast::StorageClass::kUniform:
448     case ast::StorageClass::kUniformConstant:
449       return ast::Access::kRead;
450     default:
451       break;
452   }
453   return ast::Access::kReadWrite;
454 }
455 
AllocateOverridableConstantIds()456 void Resolver::AllocateOverridableConstantIds() {
457   // The next pipeline constant ID to try to allocate.
458   uint16_t next_constant_id = 0;
459 
460   // Allocate constant IDs in global declaration order, so that they are
461   // deterministic.
462   // TODO(crbug.com/tint/1192): If a transform changes the order or removes an
463   // unused constant, the allocation may change on the next Resolver pass.
464   for (auto* decl : builder_->AST().GlobalDeclarations()) {
465     auto* var = decl->As<ast::Variable>();
466     if (!var) {
467       continue;
468     }
469     auto* override_deco =
470         ast::GetDecoration<ast::OverrideDecoration>(var->decorations);
471     if (!override_deco) {
472       continue;
473     }
474 
475     uint16_t constant_id;
476     if (override_deco->has_value) {
477       constant_id = static_cast<uint16_t>(override_deco->value);
478     } else {
479       // No ID was specified, so allocate the next available ID.
480       constant_id = next_constant_id;
481       while (constant_ids_.count(constant_id)) {
482         if (constant_id == UINT16_MAX) {
483           TINT_ICE(Resolver, builder_->Diagnostics())
484               << "no more pipeline constant IDs available";
485           return;
486         }
487         constant_id++;
488       }
489       next_constant_id = constant_id + 1;
490     }
491 
492     auto* sem = Sem<sem::GlobalVariable>(var);
493     const_cast<sem::GlobalVariable*>(sem)->SetConstantId(constant_id);
494   }
495 }
496 
SetShadows()497 void Resolver::SetShadows() {
498   for (auto it : dependencies_.shadows) {
499     auto* var = Sem(it.first);
500     if (auto* local = var->As<sem::LocalVariable>()) {
501       local->SetShadows(Sem(it.second));
502     }
503     if (auto* param = var->As<sem::Parameter>()) {
504       param->SetShadows(Sem(it.second));
505     }
506   }
507 }  // namespace resolver
508 
GlobalVariable(const ast::Variable * var)509 bool Resolver::GlobalVariable(const ast::Variable* var) {
510   auto* sem = Variable(var, VariableKind::kGlobal);
511   if (!sem) {
512     return false;
513   }
514 
515   auto storage_class = sem->StorageClass();
516   if (!var->is_const && storage_class == ast::StorageClass::kNone) {
517     AddError("global variables must have a storage class", var->source);
518     return false;
519   }
520   if (var->is_const && storage_class != ast::StorageClass::kNone) {
521     AddError("global constants shouldn't have a storage class", var->source);
522     return false;
523   }
524 
525   for (auto* deco : var->decorations) {
526     Mark(deco);
527 
528     if (auto* override_deco = deco->As<ast::OverrideDecoration>()) {
529       // Track the constant IDs that are specified in the shader.
530       if (override_deco->has_value) {
531         constant_ids_.emplace(override_deco->value, sem);
532       }
533     }
534   }
535 
536   if (!ValidateNoDuplicateDecorations(var->decorations)) {
537     return false;
538   }
539 
540   if (!ValidateGlobalVariable(sem)) {
541     return false;
542   }
543 
544   // TODO(bclayton): Call this at the end of resolve on all uniform and storage
545   // referenced structs
546   if (!ValidateStorageClassLayout(sem)) {
547     return false;
548   }
549 
550   return true;
551 }
552 
Function(const ast::Function * decl)553 sem::Function* Resolver::Function(const ast::Function* decl) {
554   uint32_t parameter_index = 0;
555   std::unordered_map<Symbol, Source> parameter_names;
556   std::vector<sem::Parameter*> parameters;
557 
558   // Resolve all the parameters
559   for (auto* param : decl->params) {
560     Mark(param);
561 
562     {  // Check the parameter name is unique for the function
563       auto emplaced = parameter_names.emplace(param->symbol, param->source);
564       if (!emplaced.second) {
565         auto name = builder_->Symbols().NameFor(param->symbol);
566         AddError("redefinition of parameter '" + name + "'", param->source);
567         AddNote("previous definition is here", emplaced.first->second);
568         return nullptr;
569       }
570     }
571 
572     auto* var = As<sem::Parameter>(
573         Variable(param, VariableKind::kParameter, parameter_index++));
574     if (!var) {
575       return nullptr;
576     }
577 
578     for (auto* deco : param->decorations) {
579       Mark(deco);
580     }
581     if (!ValidateNoDuplicateDecorations(param->decorations)) {
582       return nullptr;
583     }
584 
585     parameters.emplace_back(var);
586 
587     auto* var_ty = const_cast<sem::Type*>(var->Type());
588     if (auto* str = var_ty->As<sem::Struct>()) {
589       switch (decl->PipelineStage()) {
590         case ast::PipelineStage::kVertex:
591           str->AddUsage(sem::PipelineStageUsage::kVertexInput);
592           break;
593         case ast::PipelineStage::kFragment:
594           str->AddUsage(sem::PipelineStageUsage::kFragmentInput);
595           break;
596         case ast::PipelineStage::kCompute:
597           str->AddUsage(sem::PipelineStageUsage::kComputeInput);
598           break;
599         case ast::PipelineStage::kNone:
600           break;
601       }
602     }
603   }
604 
605   // Resolve the return type
606   sem::Type* return_type = nullptr;
607   if (auto* ty = decl->return_type) {
608     return_type = Type(ty);
609     if (!return_type) {
610       return nullptr;
611     }
612   } else {
613     return_type = builder_->create<sem::Void>();
614   }
615 
616   if (auto* str = return_type->As<sem::Struct>()) {
617     if (!ApplyStorageClassUsageToType(ast::StorageClass::kNone, str,
618                                       decl->source)) {
619       AddNote("while instantiating return type for " +
620                   builder_->Symbols().NameFor(decl->symbol),
621               decl->source);
622       return nullptr;
623     }
624 
625     switch (decl->PipelineStage()) {
626       case ast::PipelineStage::kVertex:
627         str->AddUsage(sem::PipelineStageUsage::kVertexOutput);
628         break;
629       case ast::PipelineStage::kFragment:
630         str->AddUsage(sem::PipelineStageUsage::kFragmentOutput);
631         break;
632       case ast::PipelineStage::kCompute:
633         str->AddUsage(sem::PipelineStageUsage::kComputeOutput);
634         break;
635       case ast::PipelineStage::kNone:
636         break;
637     }
638   }
639 
640   auto* func = builder_->create<sem::Function>(decl, return_type, parameters);
641   builder_->Sem().Add(decl, func);
642 
643   TINT_SCOPED_ASSIGNMENT(current_function_, func);
644 
645   if (!WorkgroupSize(decl)) {
646     return nullptr;
647   }
648 
649   if (decl->IsEntryPoint()) {
650     entry_points_.emplace_back(func);
651   }
652 
653   if (decl->body) {
654     Mark(decl->body);
655     if (current_compound_statement_) {
656       TINT_ICE(Resolver, diagnostics_)
657           << "Resolver::Function() called with a current compound statement";
658       return nullptr;
659     }
660     auto* body = StatementScope(
661         decl->body, builder_->create<sem::FunctionBlockStatement>(func),
662         [&] { return Statements(decl->body->statements); });
663     if (!body) {
664       return nullptr;
665     }
666     func->Behaviors() = body->Behaviors();
667     if (func->Behaviors().Contains(sem::Behavior::kReturn)) {
668       // https://www.w3.org/TR/WGSL/#behaviors-rules
669       // We assign a behavior to each function: it is its body’s behavior
670       // (treating the body as a regular statement), with any "Return" replaced
671       // by "Next".
672       func->Behaviors().Remove(sem::Behavior::kReturn);
673       func->Behaviors().Add(sem::Behavior::kNext);
674     }
675   }
676 
677   for (auto* deco : decl->decorations) {
678     Mark(deco);
679   }
680   if (!ValidateNoDuplicateDecorations(decl->decorations)) {
681     return nullptr;
682   }
683 
684   for (auto* deco : decl->return_type_decorations) {
685     Mark(deco);
686   }
687   if (!ValidateNoDuplicateDecorations(decl->return_type_decorations)) {
688     return nullptr;
689   }
690 
691   if (!ValidateFunction(func)) {
692     return nullptr;
693   }
694 
695   // If this is an entry point, mark all transitively called functions as being
696   // used by this entry point.
697   if (decl->IsEntryPoint()) {
698     for (auto* f : func->TransitivelyCalledFunctions()) {
699       const_cast<sem::Function*>(f)->AddAncestorEntryPoint(func);
700     }
701   }
702 
703   return func;
704 }
705 
WorkgroupSize(const ast::Function * func)706 bool Resolver::WorkgroupSize(const ast::Function* func) {
707   // Set work-group size defaults.
708   sem::WorkgroupSize ws;
709   for (int i = 0; i < 3; i++) {
710     ws[i].value = 1;
711     ws[i].overridable_const = nullptr;
712   }
713 
714   auto* deco = ast::GetDecoration<ast::WorkgroupDecoration>(func->decorations);
715   if (!deco) {
716     return true;
717   }
718 
719   auto values = deco->Values();
720   auto any_i32 = false;
721   auto any_u32 = false;
722   for (int i = 0; i < 3; i++) {
723     // Each argument to this decoration can either be a literal, an
724     // identifier for a module-scope constants, or nullptr if not specified.
725 
726     auto* expr = values[i];
727     if (!expr) {
728       // Not specified, just use the default.
729       continue;
730     }
731 
732     auto* expr_sem = Expression(expr);
733     if (!expr_sem) {
734       return false;
735     }
736 
737     constexpr const char* kErrBadType =
738         "workgroup_size argument must be either literal or module-scope "
739         "constant of type i32 or u32";
740     constexpr const char* kErrInconsistentType =
741         "workgroup_size arguments must be of the same type, either i32 "
742         "or u32";
743 
744     auto* ty = TypeOf(expr);
745     bool is_i32 = ty->UnwrapRef()->Is<sem::I32>();
746     bool is_u32 = ty->UnwrapRef()->Is<sem::U32>();
747     if (!is_i32 && !is_u32) {
748       AddError(kErrBadType, expr->source);
749       return false;
750     }
751 
752     any_i32 = any_i32 || is_i32;
753     any_u32 = any_u32 || is_u32;
754     if (any_i32 && any_u32) {
755       AddError(kErrInconsistentType, expr->source);
756       return false;
757     }
758 
759     sem::Constant value;
760 
761     if (auto* user = Sem(expr)->As<sem::VariableUser>()) {
762       // We have an variable of a module-scope constant.
763       auto* decl = user->Variable()->Declaration();
764       if (!decl->is_const) {
765         AddError(kErrBadType, expr->source);
766         return false;
767       }
768       // Capture the constant if an [[override]] attribute is present.
769       if (ast::HasDecoration<ast::OverrideDecoration>(decl->decorations)) {
770         ws[i].overridable_const = decl;
771       }
772 
773       if (decl->constructor) {
774         value = Sem(decl->constructor)->ConstantValue();
775       } else {
776         // No constructor means this value must be overriden by the user.
777         ws[i].value = 0;
778         continue;
779       }
780     } else if (expr->Is<ast::LiteralExpression>()) {
781       value = Sem(expr)->ConstantValue();
782     } else {
783       AddError(
784           "workgroup_size argument must be either a literal or a "
785           "module-scope constant",
786           values[i]->source);
787       return false;
788     }
789 
790     if (!value) {
791       TINT_ICE(Resolver, diagnostics_)
792           << "could not resolve constant workgroup_size constant value";
793       continue;
794     }
795     // Validate and set the default value for this dimension.
796     if (is_i32 ? value.Elements()[0].i32 < 1 : value.Elements()[0].u32 < 1) {
797       AddError("workgroup_size argument must be at least 1", values[i]->source);
798       return false;
799     }
800 
801     ws[i].value = is_i32 ? static_cast<uint32_t>(value.Elements()[0].i32)
802                          : value.Elements()[0].u32;
803   }
804 
805   current_function_->SetWorkgroupSize(std::move(ws));
806   return true;
807 }
808 
Statements(const ast::StatementList & stmts)809 bool Resolver::Statements(const ast::StatementList& stmts) {
810   sem::Behaviors behaviors{sem::Behavior::kNext};
811 
812   bool reachable = true;
813   for (auto* stmt : stmts) {
814     Mark(stmt);
815     auto* sem = Statement(stmt);
816     if (!sem) {
817       return false;
818     }
819     // s1 s2:(B1∖{Next}) ∪ B2
820     sem->SetIsReachable(reachable);
821     if (reachable) {
822       behaviors = (behaviors - sem::Behavior::kNext) + sem->Behaviors();
823     }
824     reachable = reachable && sem->Behaviors().Contains(sem::Behavior::kNext);
825   }
826 
827   current_statement_->Behaviors() = behaviors;
828 
829   if (!ValidateStatements(stmts)) {
830     return false;
831   }
832 
833   return true;
834 }
835 
Statement(const ast::Statement * stmt)836 sem::Statement* Resolver::Statement(const ast::Statement* stmt) {
837   if (stmt->Is<ast::CaseStatement>()) {
838     AddError("case statement can only be used inside a switch statement",
839              stmt->source);
840     return nullptr;
841   }
842   if (stmt->Is<ast::ElseStatement>()) {
843     TINT_ICE(Resolver, diagnostics_)
844         << "Resolver::Statement() encountered an Else statement. Else "
845            "statements are embedded in If statements, so should never be "
846            "encountered as top-level statements";
847     return nullptr;
848   }
849 
850   // Compound statements. These create their own sem::CompoundStatement
851   // bindings.
852   if (auto* b = stmt->As<ast::BlockStatement>()) {
853     return BlockStatement(b);
854   }
855   if (auto* l = stmt->As<ast::ForLoopStatement>()) {
856     return ForLoopStatement(l);
857   }
858   if (auto* l = stmt->As<ast::LoopStatement>()) {
859     return LoopStatement(l);
860   }
861   if (auto* i = stmt->As<ast::IfStatement>()) {
862     return IfStatement(i);
863   }
864   if (auto* s = stmt->As<ast::SwitchStatement>()) {
865     return SwitchStatement(s);
866   }
867 
868   // Non-Compound statements
869   if (auto* a = stmt->As<ast::AssignmentStatement>()) {
870     return AssignmentStatement(a);
871   }
872   if (auto* b = stmt->As<ast::BreakStatement>()) {
873     return BreakStatement(b);
874   }
875   if (auto* c = stmt->As<ast::CallStatement>()) {
876     return CallStatement(c);
877   }
878   if (auto* c = stmt->As<ast::ContinueStatement>()) {
879     return ContinueStatement(c);
880   }
881   if (auto* d = stmt->As<ast::DiscardStatement>()) {
882     return DiscardStatement(d);
883   }
884   if (auto* f = stmt->As<ast::FallthroughStatement>()) {
885     return FallthroughStatement(f);
886   }
887   if (auto* r = stmt->As<ast::ReturnStatement>()) {
888     return ReturnStatement(r);
889   }
890   if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
891     return VariableDeclStatement(v);
892   }
893 
894   AddError("unknown statement type: " + std::string(stmt->TypeInfo().name),
895            stmt->source);
896   return nullptr;
897 }
898 
CaseStatement(const ast::CaseStatement * stmt)899 sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt) {
900   auto* sem = builder_->create<sem::CaseStatement>(
901       stmt, current_compound_statement_, current_function_);
902   return StatementScope(stmt, sem, [&] {
903     for (auto* sel : stmt->selectors) {
904       Mark(sel);
905     }
906     Mark(stmt->body);
907     auto* body = BlockStatement(stmt->body);
908     if (!body) {
909       return false;
910     }
911     sem->SetBlock(body);
912     sem->Behaviors() = body->Behaviors();
913     return true;
914   });
915 }
916 
IfStatement(const ast::IfStatement * stmt)917 sem::IfStatement* Resolver::IfStatement(const ast::IfStatement* stmt) {
918   auto* sem = builder_->create<sem::IfStatement>(
919       stmt, current_compound_statement_, current_function_);
920   return StatementScope(stmt, sem, [&] {
921     auto* cond = Expression(stmt->condition);
922     if (!cond) {
923       return false;
924     }
925     sem->SetCondition(cond);
926     sem->Behaviors() = cond->Behaviors();
927     sem->Behaviors().Remove(sem::Behavior::kNext);
928 
929     Mark(stmt->body);
930     auto* body = builder_->create<sem::BlockStatement>(
931         stmt->body, current_compound_statement_, current_function_);
932     if (!StatementScope(stmt->body, body,
933                         [&] { return Statements(stmt->body->statements); })) {
934       return false;
935     }
936     sem->Behaviors().Add(body->Behaviors());
937 
938     for (auto* else_stmt : stmt->else_statements) {
939       Mark(else_stmt);
940       auto* else_sem = ElseStatement(else_stmt);
941       if (!else_sem) {
942         return false;
943       }
944       sem->Behaviors().Add(else_sem->Behaviors());
945     }
946 
947     if (stmt->else_statements.empty() ||
948         stmt->else_statements.back()->condition != nullptr) {
949       // https://www.w3.org/TR/WGSL/#behaviors-rules
950       // if statements without an else branch are treated as if they had an
951       // empty else branch (which adds Next to their behavior)
952       sem->Behaviors().Add(sem::Behavior::kNext);
953     }
954 
955     return ValidateIfStatement(sem);
956   });
957 }
958 
ElseStatement(const ast::ElseStatement * stmt)959 sem::ElseStatement* Resolver::ElseStatement(const ast::ElseStatement* stmt) {
960   auto* sem = builder_->create<sem::ElseStatement>(
961       stmt, current_compound_statement_, current_function_);
962   return StatementScope(stmt, sem, [&] {
963     if (auto* cond_expr = stmt->condition) {
964       auto* cond = Expression(cond_expr);
965       if (!cond) {
966         return false;
967       }
968       sem->SetCondition(cond);
969       // https://www.w3.org/TR/WGSL/#behaviors-rules
970       // if statements with else if branches are treated as if they were nested
971       // simple if/else statements
972       sem->Behaviors() = cond->Behaviors();
973     }
974     sem->Behaviors().Remove(sem::Behavior::kNext);
975 
976     Mark(stmt->body);
977     auto* body = builder_->create<sem::BlockStatement>(
978         stmt->body, current_compound_statement_, current_function_);
979     if (!StatementScope(stmt->body, body,
980                         [&] { return Statements(stmt->body->statements); })) {
981       return false;
982     }
983     sem->Behaviors().Add(body->Behaviors());
984 
985     return ValidateElseStatement(sem);
986   });
987 }
988 
BlockStatement(const ast::BlockStatement * stmt)989 sem::BlockStatement* Resolver::BlockStatement(const ast::BlockStatement* stmt) {
990   auto* sem = builder_->create<sem::BlockStatement>(
991       stmt->As<ast::BlockStatement>(), current_compound_statement_,
992       current_function_);
993   return StatementScope(stmt, sem,
994                         [&] { return Statements(stmt->statements); });
995 }
996 
LoopStatement(const ast::LoopStatement * stmt)997 sem::LoopStatement* Resolver::LoopStatement(const ast::LoopStatement* stmt) {
998   auto* sem = builder_->create<sem::LoopStatement>(
999       stmt, current_compound_statement_, current_function_);
1000   return StatementScope(stmt, sem, [&] {
1001     Mark(stmt->body);
1002 
1003     auto* body = builder_->create<sem::LoopBlockStatement>(
1004         stmt->body, current_compound_statement_, current_function_);
1005     return StatementScope(stmt->body, body, [&] {
1006       if (!Statements(stmt->body->statements)) {
1007         return false;
1008       }
1009       auto& behaviors = sem->Behaviors();
1010       behaviors = body->Behaviors();
1011 
1012       if (stmt->continuing) {
1013         Mark(stmt->continuing);
1014         if (!stmt->continuing->Empty()) {
1015           auto* continuing = StatementScope(
1016               stmt->continuing,
1017               builder_->create<sem::LoopContinuingBlockStatement>(
1018                   stmt->continuing, current_compound_statement_,
1019                   current_function_),
1020               [&] { return Statements(stmt->continuing->statements); });
1021           if (!continuing) {
1022             return false;
1023           }
1024           behaviors.Add(continuing->Behaviors());
1025         }
1026       }
1027 
1028       if (behaviors.Contains(sem::Behavior::kBreak)) {  // Does the loop exit?
1029         behaviors.Add(sem::Behavior::kNext);
1030       } else {
1031         behaviors.Remove(sem::Behavior::kNext);
1032       }
1033       behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
1034 
1035       return true;
1036     });
1037   });
1038 }
1039 
ForLoopStatement(const ast::ForLoopStatement * stmt)1040 sem::ForLoopStatement* Resolver::ForLoopStatement(
1041     const ast::ForLoopStatement* stmt) {
1042   auto* sem = builder_->create<sem::ForLoopStatement>(
1043       stmt, current_compound_statement_, current_function_);
1044   return StatementScope(stmt, sem, [&] {
1045     auto& behaviors = sem->Behaviors();
1046     if (auto* initializer = stmt->initializer) {
1047       Mark(initializer);
1048       auto* init = Statement(initializer);
1049       if (!init) {
1050         return false;
1051       }
1052       behaviors.Add(init->Behaviors());
1053     }
1054 
1055     if (auto* cond_expr = stmt->condition) {
1056       auto* cond = Expression(cond_expr);
1057       if (!cond) {
1058         return false;
1059       }
1060       sem->SetCondition(cond);
1061       behaviors.Add(cond->Behaviors());
1062     }
1063 
1064     if (auto* continuing = stmt->continuing) {
1065       Mark(continuing);
1066       auto* cont = Statement(continuing);
1067       if (!cont) {
1068         return false;
1069       }
1070       behaviors.Add(cont->Behaviors());
1071     }
1072 
1073     Mark(stmt->body);
1074 
1075     auto* body = builder_->create<sem::LoopBlockStatement>(
1076         stmt->body, current_compound_statement_, current_function_);
1077     if (!StatementScope(stmt->body, body,
1078                         [&] { return Statements(stmt->body->statements); })) {
1079       return false;
1080     }
1081 
1082     behaviors.Add(body->Behaviors());
1083     if (stmt->condition ||
1084         behaviors.Contains(sem::Behavior::kBreak)) {  // Does the loop exit?
1085       behaviors.Add(sem::Behavior::kNext);
1086     } else {
1087       behaviors.Remove(sem::Behavior::kNext);
1088     }
1089     behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
1090 
1091     return ValidateForLoopStatement(sem);
1092   });
1093 }
1094 
Expression(const ast::Expression * root)1095 sem::Expression* Resolver::Expression(const ast::Expression* root) {
1096   std::vector<const ast::Expression*> sorted;
1097   bool mark_failed = false;
1098   if (!ast::TraverseExpressions<ast::TraverseOrder::RightToLeft>(
1099           root, diagnostics_, [&](const ast::Expression* expr) {
1100             if (!Mark(expr)) {
1101               mark_failed = true;
1102               return ast::TraverseAction::Stop;
1103             }
1104             sorted.emplace_back(expr);
1105             return ast::TraverseAction::Descend;
1106           })) {
1107     return nullptr;
1108   }
1109 
1110   if (mark_failed) {
1111     return nullptr;
1112   }
1113 
1114   for (auto* expr : utils::Reverse(sorted)) {
1115     sem::Expression* sem_expr = nullptr;
1116     if (auto* array = expr->As<ast::IndexAccessorExpression>()) {
1117       sem_expr = IndexAccessor(array);
1118     } else if (auto* bin_op = expr->As<ast::BinaryExpression>()) {
1119       sem_expr = Binary(bin_op);
1120     } else if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
1121       sem_expr = Bitcast(bitcast);
1122     } else if (auto* call = expr->As<ast::CallExpression>()) {
1123       sem_expr = Call(call);
1124     } else if (auto* ident = expr->As<ast::IdentifierExpression>()) {
1125       sem_expr = Identifier(ident);
1126     } else if (auto* literal = expr->As<ast::LiteralExpression>()) {
1127       sem_expr = Literal(literal);
1128     } else if (auto* member = expr->As<ast::MemberAccessorExpression>()) {
1129       sem_expr = MemberAccessor(member);
1130     } else if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
1131       sem_expr = UnaryOp(unary);
1132     } else if (expr->Is<ast::PhonyExpression>()) {
1133       sem_expr = builder_->create<sem::Expression>(
1134           expr, builder_->create<sem::Void>(), current_statement_,
1135           sem::Constant{});
1136     } else {
1137       TINT_ICE(Resolver, diagnostics_)
1138           << "unhandled expression type: " << expr->TypeInfo().name;
1139       return nullptr;
1140     }
1141     if (!sem_expr) {
1142       return nullptr;
1143     }
1144 
1145     // https://www.w3.org/TR/WGSL/#behaviors-rules
1146     // an expression behavior is always either {Next} or {Next, Discard}
1147     if (sem_expr->Behaviors() != sem::Behavior::kNext &&
1148         sem_expr->Behaviors() != sem::Behaviors{sem::Behavior::kNext,  // NOLINT
1149                                                 sem::Behavior::kDiscard} &&
1150         !IsCallStatement(expr)) {
1151       TINT_ICE(Resolver, diagnostics_)
1152           << expr->TypeInfo().name
1153           << " behaviors are: " << sem_expr->Behaviors();
1154       return nullptr;
1155     }
1156 
1157     builder_->Sem().Add(expr, sem_expr);
1158     if (expr == root) {
1159       return sem_expr;
1160     }
1161   }
1162 
1163   TINT_ICE(Resolver, diagnostics_) << "Expression() did not find root node";
1164   return nullptr;
1165 }
1166 
IndexAccessor(const ast::IndexAccessorExpression * expr)1167 sem::Expression* Resolver::IndexAccessor(
1168     const ast::IndexAccessorExpression* expr) {
1169   auto* idx = Sem(expr->index);
1170   auto* obj = Sem(expr->object);
1171   auto* obj_raw_ty = obj->Type();
1172   auto* obj_ty = obj_raw_ty->UnwrapRef();
1173   const sem::Type* ty = nullptr;
1174   if (auto* arr = obj_ty->As<sem::Array>()) {
1175     ty = arr->ElemType();
1176   } else if (auto* vec = obj_ty->As<sem::Vector>()) {
1177     ty = vec->type();
1178   } else if (auto* mat = obj_ty->As<sem::Matrix>()) {
1179     ty = builder_->create<sem::Vector>(mat->type(), mat->rows());
1180   } else {
1181     AddError("cannot index type '" + TypeNameOf(obj_ty) + "'", expr->source);
1182     return nullptr;
1183   }
1184 
1185   auto* idx_ty = idx->Type()->UnwrapRef();
1186   if (!idx_ty->IsAnyOf<sem::I32, sem::U32>()) {
1187     AddError("index must be of type 'i32' or 'u32', found: '" +
1188                  TypeNameOf(idx_ty) + "'",
1189              idx->Declaration()->source);
1190     return nullptr;
1191   }
1192 
1193   if (obj_ty->IsAnyOf<sem::Array, sem::Matrix>()) {
1194     if (!obj_raw_ty->Is<sem::Reference>()) {
1195       // TODO(bclayton): expand this to allow any const_expr expression
1196       // https://github.com/gpuweb/gpuweb/issues/1272
1197       if (!idx->Declaration()->As<ast::IntLiteralExpression>()) {
1198         AddError("index must be signed or unsigned integer literal",
1199                  idx->Declaration()->source);
1200         return nullptr;
1201       }
1202     }
1203   }
1204 
1205   // If we're extracting from a reference, we return a reference.
1206   if (auto* ref = obj_raw_ty->As<sem::Reference>()) {
1207     ty = builder_->create<sem::Reference>(ty, ref->StorageClass(),
1208                                           ref->Access());
1209   }
1210 
1211   auto val = EvaluateConstantValue(expr, ty);
1212   auto* sem =
1213       builder_->create<sem::Expression>(expr, ty, current_statement_, val);
1214   sem->Behaviors() = idx->Behaviors() + obj->Behaviors();
1215   return sem;
1216 }
1217 
Bitcast(const ast::BitcastExpression * expr)1218 sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
1219   auto* inner = Sem(expr->expr);
1220   auto* ty = Type(expr->type);
1221   if (!ty) {
1222     return nullptr;
1223   }
1224 
1225   auto val = EvaluateConstantValue(expr, ty);
1226   auto* sem =
1227       builder_->create<sem::Expression>(expr, ty, current_statement_, val);
1228 
1229   sem->Behaviors() = inner->Behaviors();
1230 
1231   if (!ValidateBitcast(expr, ty)) {
1232     return nullptr;
1233   }
1234 
1235   return sem;
1236 }
1237 
Call(const ast::CallExpression * expr)1238 sem::Call* Resolver::Call(const ast::CallExpression* expr) {
1239   std::vector<const sem::Expression*> args(expr->args.size());
1240   std::vector<const sem::Type*> arg_tys(args.size());
1241   sem::Behaviors arg_behaviors;
1242 
1243   for (size_t i = 0; i < expr->args.size(); i++) {
1244     auto* arg = Sem(expr->args[i]);
1245     if (!arg) {
1246       return nullptr;
1247     }
1248     args[i] = arg;
1249     arg_tys[i] = args[i]->Type();
1250     arg_behaviors.Add(arg->Behaviors());
1251   }
1252 
1253   arg_behaviors.Remove(sem::Behavior::kNext);
1254 
1255   auto type_ctor_or_conv = [&](const sem::Type* ty) -> sem::Call* {
1256     // The call has resolved to a type constructor or cast.
1257     if (args.size() == 1) {
1258       auto* target = ty;
1259       auto* source = args[0]->Type()->UnwrapRef();
1260       if ((source != target) &&  //
1261           ((source->is_scalar() && target->is_scalar()) ||
1262            (source->Is<sem::Vector>() && target->Is<sem::Vector>()) ||
1263            (source->Is<sem::Matrix>() && target->Is<sem::Matrix>()))) {
1264         // Note: Matrix types currently cannot be converted (the element type
1265         // must only be f32). We implement this for the day we support other
1266         // matrix element types.
1267         return TypeConversion(expr, ty, args[0], arg_tys[0]);
1268       }
1269     }
1270     return TypeConstructor(expr, ty, std::move(args), std::move(arg_tys));
1271   };
1272 
1273   // Resolve the target of the CallExpression to determine whether this is a
1274   // function call, cast or type constructor expression.
1275   if (expr->target.type) {
1276     auto* ty = Type(expr->target.type);
1277     if (!ty) {
1278       return nullptr;
1279     }
1280     return type_ctor_or_conv(ty);
1281   }
1282 
1283   auto* ident = expr->target.name;
1284   Mark(ident);
1285 
1286   auto* resolved = ResolvedSymbol(ident);
1287   if (auto* ty = As<sem::Type>(resolved)) {
1288     return type_ctor_or_conv(ty);
1289   }
1290 
1291   if (auto* fn = As<sem::Function>(resolved)) {
1292     return FunctionCall(expr, fn, std::move(args), arg_behaviors);
1293   }
1294 
1295   auto name = builder_->Symbols().NameFor(ident->symbol);
1296   auto intrinsic_type = sem::ParseIntrinsicType(name);
1297   if (intrinsic_type != sem::IntrinsicType::kNone) {
1298     return IntrinsicCall(expr, intrinsic_type, std::move(args),
1299                          std::move(arg_tys));
1300   }
1301 
1302   TINT_ICE(Resolver, diagnostics_)
1303       << expr->source << " unresolved CallExpression target:\n"
1304       << "resolved: " << (resolved ? resolved->TypeInfo().name : "<null>")
1305       << "\n"
1306       << "name: " << builder_->Symbols().NameFor(ident->symbol);
1307   return nullptr;
1308 }
1309 
IntrinsicCall(const ast::CallExpression * expr,sem::IntrinsicType intrinsic_type,const std::vector<const sem::Expression * > args,const std::vector<const sem::Type * > arg_tys)1310 sem::Call* Resolver::IntrinsicCall(
1311     const ast::CallExpression* expr,
1312     sem::IntrinsicType intrinsic_type,
1313     const std::vector<const sem::Expression*> args,
1314     const std::vector<const sem::Type*> arg_tys) {
1315   auto* intrinsic = intrinsic_table_->Lookup(intrinsic_type, std::move(arg_tys),
1316                                              expr->source);
1317   if (!intrinsic) {
1318     return nullptr;
1319   }
1320 
1321   if (intrinsic->IsDeprecated()) {
1322     AddWarning("use of deprecated intrinsic", expr->source);
1323   }
1324 
1325   auto* call = builder_->create<sem::Call>(expr, intrinsic, std::move(args),
1326                                            current_statement_, sem::Constant{});
1327 
1328   current_function_->AddDirectlyCalledIntrinsic(intrinsic);
1329 
1330   if (IsTextureIntrinsic(intrinsic_type) &&
1331       !ValidateTextureIntrinsicFunction(call)) {
1332     return nullptr;
1333   }
1334 
1335   if (!ValidateIntrinsicCall(call)) {
1336     return nullptr;
1337   }
1338 
1339   current_function_->AddDirectCall(call);
1340 
1341   return call;
1342 }
1343 
FunctionCall(const ast::CallExpression * expr,sem::Function * target,const std::vector<const sem::Expression * > args,sem::Behaviors arg_behaviors)1344 sem::Call* Resolver::FunctionCall(
1345     const ast::CallExpression* expr,
1346     sem::Function* target,
1347     const std::vector<const sem::Expression*> args,
1348     sem::Behaviors arg_behaviors) {
1349   auto sym = expr->target.name->symbol;
1350   auto name = builder_->Symbols().NameFor(sym);
1351 
1352   auto* call = builder_->create<sem::Call>(expr, target, std::move(args),
1353                                            current_statement_, sem::Constant{});
1354 
1355   if (current_function_) {
1356     // Note: Requires called functions to be resolved first.
1357     // This is currently guaranteed as functions must be declared before
1358     // use.
1359     current_function_->AddTransitivelyCalledFunction(target);
1360     current_function_->AddDirectCall(call);
1361     for (auto* transitive_call : target->TransitivelyCalledFunctions()) {
1362       current_function_->AddTransitivelyCalledFunction(transitive_call);
1363     }
1364 
1365     // We inherit any referenced variables from the callee.
1366     for (auto* var : target->TransitivelyReferencedGlobals()) {
1367       current_function_->AddTransitivelyReferencedGlobal(var);
1368     }
1369   }
1370 
1371   target->AddCallSite(call);
1372 
1373   call->Behaviors() = arg_behaviors + target->Behaviors();
1374 
1375   if (!ValidateFunctionCall(call)) {
1376     return nullptr;
1377   }
1378 
1379   return call;
1380 }
1381 
TypeConversion(const ast::CallExpression * expr,const sem::Type * target,const sem::Expression * arg,const sem::Type * source)1382 sem::Call* Resolver::TypeConversion(const ast::CallExpression* expr,
1383                                     const sem::Type* target,
1384                                     const sem::Expression* arg,
1385                                     const sem::Type* source) {
1386   // It is not valid to have a type-cast call expression inside a call
1387   // statement.
1388   if (IsCallStatement(expr)) {
1389     AddError("type cast evaluated but not used", expr->source);
1390     return nullptr;
1391   }
1392 
1393   auto* call_target = utils::GetOrCreate(
1394       type_conversions_, TypeConversionSig{target, source},
1395       [&]() -> sem::TypeConversion* {
1396         // Now that the argument types have been determined, make sure that they
1397         // obey the conversion rules laid out in
1398         // https://gpuweb.github.io/gpuweb/wgsl/#conversion-expr.
1399         bool ok = true;
1400         if (auto* vec_type = target->As<sem::Vector>()) {
1401           ok = ValidateVectorConstructorOrCast(expr, vec_type);
1402         } else if (auto* mat_type = target->As<sem::Matrix>()) {
1403           // Note: Matrix types currently cannot be converted (the element type
1404           // must only be f32). We implement this for the day we support other
1405           // matrix element types.
1406           ok = ValidateMatrixConstructorOrCast(expr, mat_type);
1407         } else if (target->is_scalar()) {
1408           ok = ValidateScalarConstructorOrCast(expr, target);
1409         } else if (auto* arr_type = target->As<sem::Array>()) {
1410           ok = ValidateArrayConstructorOrCast(expr, arr_type);
1411         } else if (auto* struct_type = target->As<sem::Struct>()) {
1412           ok = ValidateStructureConstructorOrCast(expr, struct_type);
1413         } else {
1414           AddError("type is not constructible", expr->source);
1415           return nullptr;
1416         }
1417         if (!ok) {
1418           return nullptr;
1419         }
1420 
1421         auto* param = builder_->create<sem::Parameter>(
1422             nullptr,                   // declaration
1423             0,                         // index
1424             source->UnwrapRef(),       // type
1425             ast::StorageClass::kNone,  // storage_class
1426             ast::Access::kUndefined);  // access
1427         return builder_->create<sem::TypeConversion>(target, param);
1428       });
1429 
1430   if (!call_target) {
1431     return nullptr;
1432   }
1433 
1434   auto val = EvaluateConstantValue(expr, target);
1435   return builder_->create<sem::Call>(expr, call_target,
1436                                      std::vector<const sem::Expression*>{arg},
1437                                      current_statement_, val);
1438 }
1439 
TypeConstructor(const ast::CallExpression * expr,const sem::Type * ty,const std::vector<const sem::Expression * > args,const std::vector<const sem::Type * > arg_tys)1440 sem::Call* Resolver::TypeConstructor(
1441     const ast::CallExpression* expr,
1442     const sem::Type* ty,
1443     const std::vector<const sem::Expression*> args,
1444     const std::vector<const sem::Type*> arg_tys) {
1445   // It is not valid to have a type-constructor call expression as a call
1446   // statement.
1447   if (IsCallStatement(expr)) {
1448     AddError("type constructor evaluated but not used", expr->source);
1449     return nullptr;
1450   }
1451 
1452   auto* call_target = utils::GetOrCreate(
1453       type_ctors_, TypeConstructorSig{ty, arg_tys},
1454       [&]() -> sem::TypeConstructor* {
1455         // Now that the argument types have been determined, make sure that they
1456         // obey the constructor type rules laid out in
1457         // https://gpuweb.github.io/gpuweb/wgsl/#type-constructor-expr.
1458         bool ok = true;
1459         if (auto* vec_type = ty->As<sem::Vector>()) {
1460           ok = ValidateVectorConstructorOrCast(expr, vec_type);
1461         } else if (auto* mat_type = ty->As<sem::Matrix>()) {
1462           ok = ValidateMatrixConstructorOrCast(expr, mat_type);
1463         } else if (ty->is_scalar()) {
1464           ok = ValidateScalarConstructorOrCast(expr, ty);
1465         } else if (auto* arr_type = ty->As<sem::Array>()) {
1466           ok = ValidateArrayConstructorOrCast(expr, arr_type);
1467         } else if (auto* struct_type = ty->As<sem::Struct>()) {
1468           ok = ValidateStructureConstructorOrCast(expr, struct_type);
1469         } else {
1470           AddError("type is not constructible", expr->source);
1471           return nullptr;
1472         }
1473         if (!ok) {
1474           return nullptr;
1475         }
1476 
1477         return builder_->create<sem::TypeConstructor>(
1478             ty, utils::Transform(
1479                     arg_tys,
1480                     [&](const sem::Type* t, size_t i) -> const sem::Parameter* {
1481                       return builder_->create<sem::Parameter>(
1482                           nullptr,                   // declaration
1483                           i,                         // index
1484                           t->UnwrapRef(),            // type
1485                           ast::StorageClass::kNone,  // storage_class
1486                           ast::Access::kUndefined);  // access
1487                     }));
1488       });
1489 
1490   if (!call_target) {
1491     return nullptr;
1492   }
1493 
1494   auto val = EvaluateConstantValue(expr, ty);
1495   return builder_->create<sem::Call>(expr, call_target, std::move(args),
1496                                      current_statement_, val);
1497 }
1498 
Literal(const ast::LiteralExpression * literal)1499 sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
1500   auto* ty = TypeOf(literal);
1501   if (!ty) {
1502     return nullptr;
1503   }
1504 
1505   auto val = EvaluateConstantValue(literal, ty);
1506   return builder_->create<sem::Expression>(literal, ty, current_statement_,
1507                                            val);
1508 }
1509 
Identifier(const ast::IdentifierExpression * expr)1510 sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) {
1511   auto symbol = expr->symbol;
1512   auto* resolved = ResolvedSymbol(expr);
1513   if (auto* var = As<sem::Variable>(resolved)) {
1514     auto* user =
1515         builder_->create<sem::VariableUser>(expr, current_statement_, var);
1516 
1517     if (current_statement_) {
1518       // If identifier is part of a loop continuing block, make sure it
1519       // doesn't refer to a variable that is bypassed by a continue statement
1520       // in the loop's body block.
1521       if (auto* continuing_block =
1522               current_statement_
1523                   ->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
1524         auto* loop_block =
1525             continuing_block->FindFirstParent<sem::LoopBlockStatement>();
1526         if (loop_block->FirstContinue()) {
1527           auto& decls = loop_block->Decls();
1528           // If our identifier is in loop_block->decls, make sure its index is
1529           // less than first_continue
1530           auto iter =
1531               std::find_if(decls.begin(), decls.end(),
1532                            [&symbol](auto* v) { return v->symbol == symbol; });
1533           if (iter != decls.end()) {
1534             auto var_decl_index =
1535                 static_cast<size_t>(std::distance(decls.begin(), iter));
1536             if (var_decl_index >= loop_block->NumDeclsAtFirstContinue()) {
1537               AddError("continue statement bypasses declaration of '" +
1538                            builder_->Symbols().NameFor(symbol) + "'",
1539                        loop_block->FirstContinue()->source);
1540               AddNote("identifier '" + builder_->Symbols().NameFor(symbol) +
1541                           "' declared here",
1542                       (*iter)->source);
1543               AddNote("identifier '" + builder_->Symbols().NameFor(symbol) +
1544                           "' referenced in continuing block here",
1545                       expr->source);
1546               return nullptr;
1547             }
1548           }
1549         }
1550       }
1551     }
1552 
1553     if (current_function_) {
1554       if (auto* global = var->As<sem::GlobalVariable>()) {
1555         current_function_->AddDirectlyReferencedGlobal(global);
1556       }
1557     }
1558 
1559     var->AddUser(user);
1560     return user;
1561   }
1562 
1563   if (Is<sem::Function>(resolved)) {
1564     AddError("missing '(' for function call", expr->source.End());
1565     return nullptr;
1566   }
1567 
1568   if (IsIntrinsic(symbol)) {
1569     AddError("missing '(' for intrinsic call", expr->source.End());
1570     return nullptr;
1571   }
1572 
1573   if (resolved->Is<sem::Type>()) {
1574     AddError("missing '(' for type constructor or cast", expr->source.End());
1575     return nullptr;
1576   }
1577 
1578   TINT_ICE(Resolver, diagnostics_)
1579       << expr->source << " unresolved identifier:\n"
1580       << "resolved: " << (resolved ? resolved->TypeInfo().name : "<null>")
1581       << "\n"
1582       << "name: " << builder_->Symbols().NameFor(symbol);
1583   return nullptr;
1584 }
1585 
MemberAccessor(const ast::MemberAccessorExpression * expr)1586 sem::Expression* Resolver::MemberAccessor(
1587     const ast::MemberAccessorExpression* expr) {
1588   auto* structure = TypeOf(expr->structure);
1589   auto* storage_ty = structure->UnwrapRef();
1590 
1591   const sem::Type* ret = nullptr;
1592   std::vector<uint32_t> swizzle;
1593 
1594   if (auto* str = storage_ty->As<sem::Struct>()) {
1595     Mark(expr->member);
1596     auto symbol = expr->member->symbol;
1597 
1598     const sem::StructMember* member = nullptr;
1599     for (auto* m : str->Members()) {
1600       if (m->Name() == symbol) {
1601         ret = m->Type();
1602         member = m;
1603         break;
1604       }
1605     }
1606 
1607     if (ret == nullptr) {
1608       AddError(
1609           "struct member " + builder_->Symbols().NameFor(symbol) + " not found",
1610           expr->source);
1611       return nullptr;
1612     }
1613 
1614     // If we're extracting from a reference, we return a reference.
1615     if (auto* ref = structure->As<sem::Reference>()) {
1616       ret = builder_->create<sem::Reference>(ret, ref->StorageClass(),
1617                                              ref->Access());
1618     }
1619 
1620     return builder_->create<sem::StructMemberAccess>(
1621         expr, ret, current_statement_, member);
1622   }
1623 
1624   if (auto* vec = storage_ty->As<sem::Vector>()) {
1625     Mark(expr->member);
1626     std::string s = builder_->Symbols().NameFor(expr->member->symbol);
1627     auto size = s.size();
1628     swizzle.reserve(s.size());
1629 
1630     for (auto c : s) {
1631       switch (c) {
1632         case 'x':
1633         case 'r':
1634           swizzle.emplace_back(0);
1635           break;
1636         case 'y':
1637         case 'g':
1638           swizzle.emplace_back(1);
1639           break;
1640         case 'z':
1641         case 'b':
1642           swizzle.emplace_back(2);
1643           break;
1644         case 'w':
1645         case 'a':
1646           swizzle.emplace_back(3);
1647           break;
1648         default:
1649           AddError("invalid vector swizzle character",
1650                    expr->member->source.Begin() + swizzle.size());
1651           return nullptr;
1652       }
1653 
1654       if (swizzle.back() >= vec->Width()) {
1655         AddError("invalid vector swizzle member", expr->member->source);
1656         return nullptr;
1657       }
1658     }
1659 
1660     if (size < 1 || size > 4) {
1661       AddError("invalid vector swizzle size", expr->member->source);
1662       return nullptr;
1663     }
1664 
1665     // All characters are valid, check if they're being mixed
1666     auto is_rgba = [](char c) {
1667       return c == 'r' || c == 'g' || c == 'b' || c == 'a';
1668     };
1669     auto is_xyzw = [](char c) {
1670       return c == 'x' || c == 'y' || c == 'z' || c == 'w';
1671     };
1672     if (!std::all_of(s.begin(), s.end(), is_rgba) &&
1673         !std::all_of(s.begin(), s.end(), is_xyzw)) {
1674       AddError("invalid mixing of vector swizzle characters rgba with xyzw",
1675                expr->member->source);
1676       return nullptr;
1677     }
1678 
1679     if (size == 1) {
1680       // A single element swizzle is just the type of the vector.
1681       ret = vec->type();
1682       // If we're extracting from a reference, we return a reference.
1683       if (auto* ref = structure->As<sem::Reference>()) {
1684         ret = builder_->create<sem::Reference>(ret, ref->StorageClass(),
1685                                                ref->Access());
1686       }
1687     } else {
1688       // The vector will have a number of components equal to the length of
1689       // the swizzle.
1690       ret = builder_->create<sem::Vector>(vec->type(),
1691                                           static_cast<uint32_t>(size));
1692     }
1693     return builder_->create<sem::Swizzle>(expr, ret, current_statement_,
1694                                           std::move(swizzle));
1695   }
1696 
1697   AddError(
1698       "invalid member accessor expression. Expected vector or struct, got '" +
1699           TypeNameOf(storage_ty) + "'",
1700       expr->structure->source);
1701   return nullptr;
1702 }
1703 
Binary(const ast::BinaryExpression * expr)1704 sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
1705   using Bool = sem::Bool;
1706   using F32 = sem::F32;
1707   using I32 = sem::I32;
1708   using U32 = sem::U32;
1709   using Matrix = sem::Matrix;
1710   using Vector = sem::Vector;
1711 
1712   auto* lhs = Sem(expr->lhs);
1713   auto* rhs = Sem(expr->rhs);
1714 
1715   auto* lhs_ty = lhs->Type()->UnwrapRef();
1716   auto* rhs_ty = rhs->Type()->UnwrapRef();
1717 
1718   auto* lhs_vec = lhs_ty->As<Vector>();
1719   auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
1720   auto* rhs_vec = rhs_ty->As<Vector>();
1721   auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr;
1722 
1723   const bool matching_vec_elem_types =
1724       lhs_vec_elem_type && rhs_vec_elem_type &&
1725       (lhs_vec_elem_type == rhs_vec_elem_type) &&
1726       (lhs_vec->Width() == rhs_vec->Width());
1727 
1728   const bool matching_types = matching_vec_elem_types || (lhs_ty == rhs_ty);
1729 
1730   auto build = [&](const sem::Type* ty) {
1731     auto val = EvaluateConstantValue(expr, ty);
1732     auto* sem =
1733         builder_->create<sem::Expression>(expr, ty, current_statement_, val);
1734     sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
1735     return sem;
1736   };
1737 
1738   // Binary logical expressions
1739   if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
1740     if (matching_types && lhs_ty->Is<Bool>()) {
1741       return build(lhs_ty);
1742     }
1743   }
1744   if (expr->IsOr() || expr->IsAnd()) {
1745     if (matching_types && lhs_ty->Is<Bool>()) {
1746       return build(lhs_ty);
1747     }
1748     if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
1749       return build(lhs_ty);
1750     }
1751   }
1752 
1753   // Arithmetic expressions
1754   if (expr->IsArithmetic()) {
1755     // Binary arithmetic expressions over scalars
1756     if (matching_types && lhs_ty->is_numeric_scalar()) {
1757       return build(lhs_ty);
1758     }
1759 
1760     // Binary arithmetic expressions over vectors
1761     if (matching_types && lhs_vec_elem_type &&
1762         lhs_vec_elem_type->is_numeric_scalar()) {
1763       return build(lhs_ty);
1764     }
1765 
1766     // Binary arithmetic expressions with mixed scalar and vector operands
1767     if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty)) {
1768       if (expr->IsModulo()) {
1769         if (rhs_ty->is_integer_scalar()) {
1770           return build(lhs_ty);
1771         }
1772       } else if (rhs_ty->is_numeric_scalar()) {
1773         return build(lhs_ty);
1774       }
1775     }
1776     if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty)) {
1777       if (expr->IsModulo()) {
1778         if (lhs_ty->is_integer_scalar()) {
1779           return build(rhs_ty);
1780         }
1781       } else if (lhs_ty->is_numeric_scalar()) {
1782         return build(rhs_ty);
1783       }
1784     }
1785   }
1786 
1787   // Matrix arithmetic
1788   auto* lhs_mat = lhs_ty->As<Matrix>();
1789   auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
1790   auto* rhs_mat = rhs_ty->As<Matrix>();
1791   auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
1792   // Addition and subtraction of float matrices
1793   if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type &&
1794       lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
1795       rhs_mat_elem_type->Is<F32>() &&
1796       (lhs_mat->columns() == rhs_mat->columns()) &&
1797       (lhs_mat->rows() == rhs_mat->rows())) {
1798     return build(rhs_ty);
1799   }
1800   if (expr->IsMultiply()) {
1801     // Multiplication of a matrix and a scalar
1802     if (lhs_ty->Is<F32>() && rhs_mat_elem_type &&
1803         rhs_mat_elem_type->Is<F32>()) {
1804       return build(rhs_ty);
1805     }
1806     if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
1807         rhs_ty->Is<F32>()) {
1808       return build(lhs_ty);
1809     }
1810 
1811     // Vector times matrix
1812     if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
1813         rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
1814         (lhs_vec->Width() == rhs_mat->rows())) {
1815       return build(
1816           builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns()));
1817     }
1818 
1819     // Matrix times vector
1820     if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
1821         rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() &&
1822         (lhs_mat->columns() == rhs_vec->Width())) {
1823       return build(
1824           builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows()));
1825     }
1826 
1827     // Matrix times matrix
1828     if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
1829         rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
1830         (lhs_mat->columns() == rhs_mat->rows())) {
1831       return build(builder_->create<sem::Matrix>(
1832           builder_->create<sem::Vector>(lhs_mat_elem_type, lhs_mat->rows()),
1833           rhs_mat->columns()));
1834     }
1835   }
1836 
1837   // Comparison expressions
1838   if (expr->IsComparison()) {
1839     if (matching_types) {
1840       // Special case for bools: only == and !=
1841       if (lhs_ty->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) {
1842         return build(builder_->create<sem::Bool>());
1843       }
1844 
1845       // For the rest, we can compare i32, u32, and f32
1846       if (lhs_ty->IsAnyOf<I32, U32, F32>()) {
1847         return build(builder_->create<sem::Bool>());
1848       }
1849     }
1850 
1851     // Same for vectors
1852     if (matching_vec_elem_types) {
1853       if (lhs_vec_elem_type->Is<Bool>() &&
1854           (expr->IsEqual() || expr->IsNotEqual())) {
1855         return build(builder_->create<sem::Vector>(
1856             builder_->create<sem::Bool>(), lhs_vec->Width()));
1857       }
1858 
1859       if (lhs_vec_elem_type->is_numeric_scalar()) {
1860         return build(builder_->create<sem::Vector>(
1861             builder_->create<sem::Bool>(), lhs_vec->Width()));
1862       }
1863     }
1864   }
1865 
1866   // Binary bitwise operations
1867   if (expr->IsBitwise()) {
1868     if (matching_types && lhs_ty->is_integer_scalar_or_vector()) {
1869       return build(lhs_ty);
1870     }
1871   }
1872 
1873   // Bit shift expressions
1874   if (expr->IsBitshift()) {
1875     // Type validation rules are the same for left or right shift, despite
1876     // differences in computation rules (i.e. right shift can be arithmetic or
1877     // logical depending on lhs type).
1878 
1879     if (lhs_ty->IsAnyOf<I32, U32>() && rhs_ty->Is<U32>()) {
1880       return build(lhs_ty);
1881     }
1882 
1883     if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() &&
1884         rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) {
1885       return build(lhs_ty);
1886     }
1887   }
1888 
1889   AddError("Binary expression operand types are invalid for this operation: " +
1890                TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " +
1891                TypeNameOf(rhs_ty),
1892            expr->source);
1893   return nullptr;
1894 }
1895 
UnaryOp(const ast::UnaryOpExpression * unary)1896 sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
1897   auto* expr = Sem(unary->expr);
1898   auto* expr_ty = expr->Type();
1899   if (!expr_ty) {
1900     return nullptr;
1901   }
1902 
1903   const sem::Type* ty = nullptr;
1904 
1905   switch (unary->op) {
1906     case ast::UnaryOp::kNot:
1907       // Result type matches the deref'd inner type.
1908       ty = expr_ty->UnwrapRef();
1909       if (!ty->Is<sem::Bool>() && !ty->is_bool_vector()) {
1910         AddError(
1911             "cannot logical negate expression of type '" + TypeNameOf(expr_ty),
1912             unary->expr->source);
1913         return nullptr;
1914       }
1915       break;
1916 
1917     case ast::UnaryOp::kComplement:
1918       // Result type matches the deref'd inner type.
1919       ty = expr_ty->UnwrapRef();
1920       if (!ty->is_integer_scalar_or_vector()) {
1921         AddError("cannot bitwise complement expression of type '" +
1922                      TypeNameOf(expr_ty),
1923                  unary->expr->source);
1924         return nullptr;
1925       }
1926       break;
1927 
1928     case ast::UnaryOp::kNegation:
1929       // Result type matches the deref'd inner type.
1930       ty = expr_ty->UnwrapRef();
1931       if (!(ty->IsAnyOf<sem::F32, sem::I32>() ||
1932             ty->is_signed_integer_vector() || ty->is_float_vector())) {
1933         AddError("cannot negate expression of type '" + TypeNameOf(expr_ty),
1934                  unary->expr->source);
1935         return nullptr;
1936       }
1937       break;
1938 
1939     case ast::UnaryOp::kAddressOf:
1940       if (auto* ref = expr_ty->As<sem::Reference>()) {
1941         if (ref->StoreType()->UnwrapRef()->is_handle()) {
1942           AddError(
1943               "cannot take the address of expression in handle storage class",
1944               unary->expr->source);
1945           return nullptr;
1946         }
1947 
1948         auto* array = unary->expr->As<ast::IndexAccessorExpression>();
1949         auto* member = unary->expr->As<ast::MemberAccessorExpression>();
1950         if ((array && TypeOf(array->object)->UnwrapRef()->Is<sem::Vector>()) ||
1951             (member &&
1952              TypeOf(member->structure)->UnwrapRef()->Is<sem::Vector>())) {
1953           AddError("cannot take the address of a vector component",
1954                    unary->expr->source);
1955           return nullptr;
1956         }
1957 
1958         ty = builder_->create<sem::Pointer>(ref->StoreType(),
1959                                             ref->StorageClass(), ref->Access());
1960       } else {
1961         AddError("cannot take the address of expression", unary->expr->source);
1962         return nullptr;
1963       }
1964       break;
1965 
1966     case ast::UnaryOp::kIndirection:
1967       if (auto* ptr = expr_ty->As<sem::Pointer>()) {
1968         ty = builder_->create<sem::Reference>(
1969             ptr->StoreType(), ptr->StorageClass(), ptr->Access());
1970       } else {
1971         AddError("cannot dereference expression of type '" +
1972                      TypeNameOf(expr_ty) + "'",
1973                  unary->expr->source);
1974         return nullptr;
1975       }
1976       break;
1977   }
1978 
1979   auto val = EvaluateConstantValue(unary, ty);
1980   auto* sem =
1981       builder_->create<sem::Expression>(unary, ty, current_statement_, val);
1982   sem->Behaviors() = expr->Behaviors();
1983   return sem;
1984 }
1985 
TypeDecl(const ast::TypeDecl * named_type)1986 sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) {
1987   sem::Type* result = nullptr;
1988   if (auto* alias = named_type->As<ast::Alias>()) {
1989     result = Alias(alias);
1990   } else if (auto* str = named_type->As<ast::Struct>()) {
1991     result = Structure(str);
1992   } else {
1993     TINT_UNREACHABLE(Resolver, diagnostics_) << "Unhandled TypeDecl";
1994   }
1995 
1996   if (!result) {
1997     return nullptr;
1998   }
1999 
2000   builder_->Sem().Add(named_type, result);
2001   return result;
2002 }
2003 
TypeOf(const ast::Expression * expr)2004 sem::Type* Resolver::TypeOf(const ast::Expression* expr) {
2005   auto* sem = Sem(expr);
2006   return sem ? const_cast<sem::Type*>(sem->Type()) : nullptr;
2007 }
2008 
TypeNameOf(const sem::Type * ty)2009 std::string Resolver::TypeNameOf(const sem::Type* ty) {
2010   return RawTypeNameOf(ty->UnwrapRef());
2011 }
2012 
RawTypeNameOf(const sem::Type * ty)2013 std::string Resolver::RawTypeNameOf(const sem::Type* ty) {
2014   return ty->FriendlyName(builder_->Symbols());
2015 }
2016 
TypeOf(const ast::LiteralExpression * lit)2017 sem::Type* Resolver::TypeOf(const ast::LiteralExpression* lit) {
2018   if (lit->Is<ast::SintLiteralExpression>()) {
2019     return builder_->create<sem::I32>();
2020   }
2021   if (lit->Is<ast::UintLiteralExpression>()) {
2022     return builder_->create<sem::U32>();
2023   }
2024   if (lit->Is<ast::FloatLiteralExpression>()) {
2025     return builder_->create<sem::F32>();
2026   }
2027   if (lit->Is<ast::BoolLiteralExpression>()) {
2028     return builder_->create<sem::Bool>();
2029   }
2030   TINT_UNREACHABLE(Resolver, diagnostics_)
2031       << "Unhandled literal type: " << lit->TypeInfo().name;
2032   return nullptr;
2033 }
2034 
Array(const ast::Array * arr)2035 sem::Array* Resolver::Array(const ast::Array* arr) {
2036   auto source = arr->source;
2037 
2038   auto* elem_type = Type(arr->type);
2039   if (!elem_type) {
2040     return nullptr;
2041   }
2042 
2043   if (!IsPlain(elem_type)) {  // Check must come before GetDefaultAlignAndSize()
2044     AddError(TypeNameOf(elem_type) +
2045                  " cannot be used as an element type of an array",
2046              source);
2047     return nullptr;
2048   }
2049 
2050   uint32_t el_align = elem_type->Align();
2051   uint32_t el_size = elem_type->Size();
2052 
2053   if (!ValidateNoDuplicateDecorations(arr->decorations)) {
2054     return nullptr;
2055   }
2056 
2057   // Look for explicit stride via [[stride(n)]] decoration
2058   uint32_t explicit_stride = 0;
2059   for (auto* deco : arr->decorations) {
2060     Mark(deco);
2061     if (auto* sd = deco->As<ast::StrideDecoration>()) {
2062       explicit_stride = sd->stride;
2063       if (!ValidateArrayStrideDecoration(sd, el_size, el_align, source)) {
2064         return nullptr;
2065       }
2066       continue;
2067     }
2068 
2069     AddError("decoration is not valid for array types", deco->source);
2070     return nullptr;
2071   }
2072 
2073   // Calculate implicit stride
2074   uint64_t implicit_stride = utils::RoundUp<uint64_t>(el_align, el_size);
2075 
2076   uint64_t stride = explicit_stride ? explicit_stride : implicit_stride;
2077 
2078   // Evaluate the constant array size expression.
2079   // sem::Array uses a size of 0 for a runtime-sized array.
2080   uint32_t count = 0;
2081   if (auto* count_expr = arr->count) {
2082     auto* count_sem = Expression(count_expr);
2083     if (!count_sem) {
2084       return nullptr;
2085     }
2086 
2087     auto size_source = count_expr->source;
2088 
2089     auto* ty = count_sem->Type()->UnwrapRef();
2090     if (!ty->is_integer_scalar()) {
2091       AddError("array size must be integer scalar", size_source);
2092       return nullptr;
2093     }
2094 
2095     if (auto* ident = count_expr->As<ast::IdentifierExpression>()) {
2096       // Make sure the identifier is a non-overridable module-scope constant.
2097       auto* var = ResolvedSymbol<sem::Variable>(ident);
2098       if (!var || !var->Is<sem::GlobalVariable>() ||
2099           !var->Declaration()->is_const) {
2100         AddError("array size identifier must be a module-scope constant",
2101                  size_source);
2102         return nullptr;
2103       }
2104       if (ast::HasDecoration<ast::OverrideDecoration>(
2105               var->Declaration()->decorations)) {
2106         AddError("array size expression must not be pipeline-overridable",
2107                  size_source);
2108         return nullptr;
2109       }
2110 
2111       count_expr = var->Declaration()->constructor;
2112     } else if (!count_expr->Is<ast::LiteralExpression>()) {
2113       AddError(
2114           "array size expression must be either a literal or a module-scope "
2115           "constant",
2116           size_source);
2117       return nullptr;
2118     }
2119 
2120     auto count_val = count_sem->ConstantValue();
2121     if (!count_val) {
2122       TINT_ICE(Resolver, diagnostics_)
2123           << "could not resolve array size expression";
2124       return nullptr;
2125     }
2126 
2127     if (ty->is_signed_integer_scalar() ? count_val.Elements()[0].i32 < 1
2128                                        : count_val.Elements()[0].u32 < 1u) {
2129       AddError("array size must be at least 1", size_source);
2130       return nullptr;
2131     }
2132 
2133     count = count_val.Elements()[0].u32;
2134   }
2135 
2136   auto size = std::max<uint64_t>(count, 1) * stride;
2137   if (size > std::numeric_limits<uint32_t>::max()) {
2138     std::stringstream msg;
2139     msg << "array size in bytes must not exceed 0x" << std::hex
2140         << std::numeric_limits<uint32_t>::max() << ", but is 0x" << std::hex
2141         << size;
2142     AddError(msg.str(), arr->source);
2143     return nullptr;
2144   }
2145   if (stride > std::numeric_limits<uint32_t>::max() ||
2146       implicit_stride > std::numeric_limits<uint32_t>::max()) {
2147     TINT_ICE(Resolver, diagnostics_)
2148         << "calculated array stride exceeds uint32";
2149     return nullptr;
2150   }
2151   auto* out = builder_->create<sem::Array>(
2152       elem_type, count, el_align, static_cast<uint32_t>(size),
2153       static_cast<uint32_t>(stride), static_cast<uint32_t>(implicit_stride));
2154 
2155   if (!ValidateArray(out, source)) {
2156     return nullptr;
2157   }
2158 
2159   if (elem_type->Is<sem::Atomic>()) {
2160     atomic_composite_info_.emplace(out, arr->type->source);
2161   } else {
2162     auto found = atomic_composite_info_.find(elem_type);
2163     if (found != atomic_composite_info_.end()) {
2164       atomic_composite_info_.emplace(out, found->second);
2165     }
2166   }
2167 
2168   return out;
2169 }
2170 
Alias(const ast::Alias * alias)2171 sem::Type* Resolver::Alias(const ast::Alias* alias) {
2172   auto* ty = Type(alias->type);
2173   if (!ty) {
2174     return nullptr;
2175   }
2176   if (!ValidateAlias(alias)) {
2177     return nullptr;
2178   }
2179   return ty;
2180 }
2181 
Structure(const ast::Struct * str)2182 sem::Struct* Resolver::Structure(const ast::Struct* str) {
2183   if (!ValidateNoDuplicateDecorations(str->decorations)) {
2184     return nullptr;
2185   }
2186   for (auto* deco : str->decorations) {
2187     Mark(deco);
2188   }
2189 
2190   sem::StructMemberList sem_members;
2191   sem_members.reserve(str->members.size());
2192 
2193   // Calculate the effective size and alignment of each field, and the overall
2194   // size of the structure.
2195   // For size, use the size attribute if provided, otherwise use the default
2196   // size for the type.
2197   // For alignment, use the alignment attribute if provided, otherwise use the
2198   // default alignment for the member type.
2199   // Diagnostic errors are raised if a basic rule is violated.
2200   // Validation of storage-class rules requires analysing the actual variable
2201   // usage of the structure, and so is performed as part of the variable
2202   // validation.
2203   uint64_t struct_size = 0;
2204   uint64_t struct_align = 1;
2205   std::unordered_map<Symbol, const ast::StructMember*> member_map;
2206 
2207   for (auto* member : str->members) {
2208     Mark(member);
2209     auto result = member_map.emplace(member->symbol, member);
2210     if (!result.second) {
2211       AddError("redefinition of '" +
2212                    builder_->Symbols().NameFor(member->symbol) + "'",
2213                member->source);
2214       AddNote("previous definition is here", result.first->second->source);
2215       return nullptr;
2216     }
2217 
2218     // Resolve member type
2219     auto* type = Type(member->type);
2220     if (!type) {
2221       return nullptr;
2222     }
2223 
2224     // Validate member type
2225     if (!IsPlain(type)) {
2226       AddError(TypeNameOf(type) +
2227                    " cannot be used as the type of a structure member",
2228                member->source);
2229       return nullptr;
2230     }
2231 
2232     uint64_t offset = struct_size;
2233     uint64_t align = type->Align();
2234     uint64_t size = type->Size();
2235 
2236     if (!ValidateNoDuplicateDecorations(member->decorations)) {
2237       return nullptr;
2238     }
2239 
2240     bool has_offset_deco = false;
2241     bool has_align_deco = false;
2242     bool has_size_deco = false;
2243     for (auto* deco : member->decorations) {
2244       Mark(deco);
2245       if (auto* o = deco->As<ast::StructMemberOffsetDecoration>()) {
2246         // Offset decorations are not part of the WGSL spec, but are emitted
2247         // by the SPIR-V reader.
2248         if (o->offset < struct_size) {
2249           AddError("offsets must be in ascending order", o->source);
2250           return nullptr;
2251         }
2252         offset = o->offset;
2253         align = 1;
2254         has_offset_deco = true;
2255       } else if (auto* a = deco->As<ast::StructMemberAlignDecoration>()) {
2256         if (a->align <= 0 || !utils::IsPowerOfTwo(a->align)) {
2257           AddError("align value must be a positive, power-of-two integer",
2258                    a->source);
2259           return nullptr;
2260         }
2261         align = a->align;
2262         has_align_deco = true;
2263       } else if (auto* s = deco->As<ast::StructMemberSizeDecoration>()) {
2264         if (s->size < size) {
2265           AddError("size must be at least as big as the type's size (" +
2266                        std::to_string(size) + ")",
2267                    s->source);
2268           return nullptr;
2269         }
2270         size = s->size;
2271         has_size_deco = true;
2272       }
2273     }
2274 
2275     if (has_offset_deco && (has_align_deco || has_size_deco)) {
2276       AddError(
2277           "offset decorations cannot be used with align or size decorations",
2278           member->source);
2279       return nullptr;
2280     }
2281 
2282     offset = utils::RoundUp(align, offset);
2283     if (offset > std::numeric_limits<uint32_t>::max()) {
2284       std::stringstream msg;
2285       msg << "struct member has byte offset 0x" << std::hex << offset
2286           << ", but must not exceed 0x" << std::hex
2287           << std::numeric_limits<uint32_t>::max();
2288       AddError(msg.str(), member->source);
2289       return nullptr;
2290     }
2291 
2292     auto* sem_member = builder_->create<sem::StructMember>(
2293         member, member->symbol, type, static_cast<uint32_t>(sem_members.size()),
2294         static_cast<uint32_t>(offset), static_cast<uint32_t>(align),
2295         static_cast<uint32_t>(size));
2296     builder_->Sem().Add(member, sem_member);
2297     sem_members.emplace_back(sem_member);
2298 
2299     struct_size = offset + size;
2300     struct_align = std::max(struct_align, align);
2301   }
2302 
2303   uint64_t size_no_padding = struct_size;
2304   struct_size = utils::RoundUp(struct_align, struct_size);
2305 
2306   if (struct_size > std::numeric_limits<uint32_t>::max()) {
2307     std::stringstream msg;
2308     msg << "struct size in bytes must not exceed 0x" << std::hex
2309         << std::numeric_limits<uint32_t>::max() << ", but is 0x" << std::hex
2310         << struct_size;
2311     AddError(msg.str(), str->source);
2312     return nullptr;
2313   }
2314   if (struct_align > std::numeric_limits<uint32_t>::max()) {
2315     TINT_ICE(Resolver, diagnostics_)
2316         << "calculated struct stride exceeds uint32";
2317     return nullptr;
2318   }
2319 
2320   auto* out = builder_->create<sem::Struct>(
2321       str, str->name, sem_members, static_cast<uint32_t>(struct_align),
2322       static_cast<uint32_t>(struct_size),
2323       static_cast<uint32_t>(size_no_padding));
2324 
2325   for (size_t i = 0; i < sem_members.size(); i++) {
2326     auto* mem_type = sem_members[i]->Type();
2327     if (mem_type->Is<sem::Atomic>()) {
2328       atomic_composite_info_.emplace(out,
2329                                      sem_members[i]->Declaration()->source);
2330       break;
2331     } else {
2332       auto found = atomic_composite_info_.find(mem_type);
2333       if (found != atomic_composite_info_.end()) {
2334         atomic_composite_info_.emplace(out, found->second);
2335         break;
2336       }
2337     }
2338   }
2339 
2340   if (!ValidateStructure(out)) {
2341     return nullptr;
2342   }
2343 
2344   return out;
2345 }
2346 
ReturnStatement(const ast::ReturnStatement * stmt)2347 sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) {
2348   auto* sem = builder_->create<sem::Statement>(
2349       stmt, current_compound_statement_, current_function_);
2350   return StatementScope(stmt, sem, [&] {
2351     auto& behaviors = current_statement_->Behaviors();
2352     behaviors = sem::Behavior::kReturn;
2353 
2354     if (auto* value = stmt->value) {
2355       auto* expr = Expression(value);
2356       if (!expr) {
2357         return false;
2358       }
2359       behaviors.Add(expr->Behaviors() - sem::Behavior::kNext);
2360     }
2361 
2362     // Validate after processing the return value expression so that its type is
2363     // available for validation.
2364     return ValidateReturn(stmt);
2365   });
2366 }
2367 
SwitchStatement(const ast::SwitchStatement * stmt)2368 sem::SwitchStatement* Resolver::SwitchStatement(
2369     const ast::SwitchStatement* stmt) {
2370   auto* sem = builder_->create<sem::SwitchStatement>(
2371       stmt, current_compound_statement_, current_function_);
2372   return StatementScope(stmt, sem, [&] {
2373     auto& behaviors = sem->Behaviors();
2374 
2375     auto* cond = Expression(stmt->condition);
2376     if (!cond) {
2377       return false;
2378     }
2379     behaviors = cond->Behaviors() - sem::Behavior::kNext;
2380 
2381     for (auto* case_stmt : stmt->body) {
2382       Mark(case_stmt);
2383       auto* c = CaseStatement(case_stmt);
2384       if (!c) {
2385         return false;
2386       }
2387       behaviors.Add(c->Behaviors());
2388     }
2389 
2390     if (behaviors.Contains(sem::Behavior::kBreak)) {
2391       behaviors.Add(sem::Behavior::kNext);
2392     }
2393     behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kFallthrough);
2394 
2395     return ValidateSwitch(stmt);
2396   });
2397 }
2398 
VariableDeclStatement(const ast::VariableDeclStatement * stmt)2399 sem::Statement* Resolver::VariableDeclStatement(
2400     const ast::VariableDeclStatement* stmt) {
2401   auto* sem = builder_->create<sem::Statement>(
2402       stmt, current_compound_statement_, current_function_);
2403   return StatementScope(stmt, sem, [&] {
2404     Mark(stmt->variable);
2405 
2406     auto* var = Variable(stmt->variable, VariableKind::kLocal);
2407     if (!var) {
2408       return false;
2409     }
2410 
2411     for (auto* deco : stmt->variable->decorations) {
2412       Mark(deco);
2413       if (!deco->Is<ast::InternalDecoration>()) {
2414         AddError("decorations are not valid on local variables", deco->source);
2415         return false;
2416       }
2417     }
2418 
2419     if (current_block_) {  // Not all statements are inside a block
2420       current_block_->AddDecl(stmt->variable);
2421     }
2422 
2423     if (auto* ctor = var->Constructor()) {
2424       sem->Behaviors() = ctor->Behaviors();
2425     }
2426 
2427     return ValidateVariable(var);
2428   });
2429 }
2430 
AssignmentStatement(const ast::AssignmentStatement * stmt)2431 sem::Statement* Resolver::AssignmentStatement(
2432     const ast::AssignmentStatement* stmt) {
2433   auto* sem = builder_->create<sem::Statement>(
2434       stmt, current_compound_statement_, current_function_);
2435   return StatementScope(stmt, sem, [&] {
2436     auto* lhs = Expression(stmt->lhs);
2437     if (!lhs) {
2438       return false;
2439     }
2440 
2441     auto* rhs = Expression(stmt->rhs);
2442     if (!rhs) {
2443       return false;
2444     }
2445 
2446     auto& behaviors = sem->Behaviors();
2447     behaviors = rhs->Behaviors();
2448     if (!stmt->lhs->Is<ast::PhonyExpression>()) {
2449       behaviors.Add(lhs->Behaviors());
2450     }
2451 
2452     return ValidateAssignment(stmt);
2453   });
2454 }
2455 
BreakStatement(const ast::BreakStatement * stmt)2456 sem::Statement* Resolver::BreakStatement(const ast::BreakStatement* stmt) {
2457   auto* sem = builder_->create<sem::Statement>(
2458       stmt, current_compound_statement_, current_function_);
2459   return StatementScope(stmt, sem, [&] {
2460     sem->Behaviors() = sem::Behavior::kBreak;
2461 
2462     return ValidateBreakStatement(sem);
2463   });
2464 }
2465 
CallStatement(const ast::CallStatement * stmt)2466 sem::Statement* Resolver::CallStatement(const ast::CallStatement* stmt) {
2467   auto* sem = builder_->create<sem::Statement>(
2468       stmt, current_compound_statement_, current_function_);
2469   return StatementScope(stmt, sem, [&] {
2470     if (auto* expr = Expression(stmt->expr)) {
2471       sem->Behaviors() = expr->Behaviors();
2472       return true;
2473     }
2474     return false;
2475   });
2476 }
2477 
ContinueStatement(const ast::ContinueStatement * stmt)2478 sem::Statement* Resolver::ContinueStatement(
2479     const ast::ContinueStatement* stmt) {
2480   auto* sem = builder_->create<sem::Statement>(
2481       stmt, current_compound_statement_, current_function_);
2482   return StatementScope(stmt, sem, [&] {
2483     sem->Behaviors() = sem::Behavior::kContinue;
2484 
2485     // Set if we've hit the first continue statement in our parent loop
2486     if (auto* block = sem->FindFirstParent<sem::LoopBlockStatement>()) {
2487       if (!block->FirstContinue()) {
2488         const_cast<sem::LoopBlockStatement*>(block)->SetFirstContinue(
2489             stmt, block->Decls().size());
2490       }
2491     }
2492 
2493     return ValidateContinueStatement(sem);
2494   });
2495 }
2496 
DiscardStatement(const ast::DiscardStatement * stmt)2497 sem::Statement* Resolver::DiscardStatement(const ast::DiscardStatement* stmt) {
2498   auto* sem = builder_->create<sem::Statement>(
2499       stmt, current_compound_statement_, current_function_);
2500   return StatementScope(stmt, sem, [&] {
2501     sem->Behaviors() = sem::Behavior::kDiscard;
2502     current_function_->SetHasDiscard();
2503 
2504     return ValidateDiscardStatement(sem);
2505   });
2506 }
2507 
FallthroughStatement(const ast::FallthroughStatement * stmt)2508 sem::Statement* Resolver::FallthroughStatement(
2509     const ast::FallthroughStatement* stmt) {
2510   auto* sem = builder_->create<sem::Statement>(
2511       stmt, current_compound_statement_, current_function_);
2512   return StatementScope(stmt, sem, [&] {
2513     sem->Behaviors() = sem::Behavior::kFallthrough;
2514 
2515     return ValidateFallthroughStatement(sem);
2516   });
2517 }
2518 
ApplyStorageClassUsageToType(ast::StorageClass sc,sem::Type * ty,const Source & usage)2519 bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
2520                                             sem::Type* ty,
2521                                             const Source& usage) {
2522   ty = const_cast<sem::Type*>(ty->UnwrapRef());
2523 
2524   if (auto* str = ty->As<sem::Struct>()) {
2525     if (str->StorageClassUsage().count(sc)) {
2526       return true;  // Already applied
2527     }
2528 
2529     str->AddUsage(sc);
2530 
2531     for (auto* member : str->Members()) {
2532       if (!ApplyStorageClassUsageToType(sc, member->Type(), usage)) {
2533         std::stringstream err;
2534         err << "while analysing structure member " << TypeNameOf(str) << "."
2535             << builder_->Symbols().NameFor(member->Declaration()->symbol);
2536         AddNote(err.str(), member->Declaration()->source);
2537         return false;
2538       }
2539     }
2540     return true;
2541   }
2542 
2543   if (auto* arr = ty->As<sem::Array>()) {
2544     return ApplyStorageClassUsageToType(
2545         sc, const_cast<sem::Type*>(arr->ElemType()), usage);
2546   }
2547 
2548   if (ast::IsHostShareable(sc) && !IsHostShareable(ty)) {
2549     std::stringstream err;
2550     err << "Type '" << TypeNameOf(ty) << "' cannot be used in storage class '"
2551         << sc << "' as it is non-host-shareable";
2552     AddError(err.str(), usage);
2553     return false;
2554   }
2555 
2556   return true;
2557 }
2558 
2559 template <typename SEM, typename F>
StatementScope(const ast::Statement * ast,SEM * sem,F && callback)2560 SEM* Resolver::StatementScope(const ast::Statement* ast,
2561                               SEM* sem,
2562                               F&& callback) {
2563   builder_->Sem().Add(ast, sem);
2564 
2565   auto* as_compound =
2566       As<sem::CompoundStatement, CastFlags::kDontErrorOnImpossibleCast>(sem);
2567   auto* as_block =
2568       As<sem::BlockStatement, CastFlags::kDontErrorOnImpossibleCast>(sem);
2569 
2570   TINT_SCOPED_ASSIGNMENT(current_statement_, sem);
2571   TINT_SCOPED_ASSIGNMENT(
2572       current_compound_statement_,
2573       as_compound ? as_compound : current_compound_statement_);
2574   TINT_SCOPED_ASSIGNMENT(current_block_, as_block ? as_block : current_block_);
2575 
2576   if (!callback()) {
2577     return nullptr;
2578   }
2579 
2580   return sem;
2581 }
2582 
VectorPretty(uint32_t size,const sem::Type * element_type)2583 std::string Resolver::VectorPretty(uint32_t size,
2584                                    const sem::Type* element_type) {
2585   sem::Vector vec_type(element_type, size);
2586   return vec_type.FriendlyName(builder_->Symbols());
2587 }
2588 
Mark(const ast::Node * node)2589 bool Resolver::Mark(const ast::Node* node) {
2590   if (node == nullptr) {
2591     TINT_ICE(Resolver, diagnostics_) << "Resolver::Mark() called with nullptr";
2592     return false;
2593   }
2594   if (marked_.emplace(node).second) {
2595     return true;
2596   }
2597   TINT_ICE(Resolver, diagnostics_)
2598       << "AST node '" << node->TypeInfo().name
2599       << "' was encountered twice in the same AST of a Program\n"
2600       << "At: " << node->source << "\n"
2601       << "Pointer: " << node;
2602   return false;
2603 }
2604 
AddError(const std::string & msg,const Source & source) const2605 void Resolver::AddError(const std::string& msg, const Source& source) const {
2606   diagnostics_.add_error(diag::System::Resolver, msg, source);
2607 }
2608 
AddWarning(const std::string & msg,const Source & source) const2609 void Resolver::AddWarning(const std::string& msg, const Source& source) const {
2610   diagnostics_.add_warning(diag::System::Resolver, msg, source);
2611 }
2612 
AddNote(const std::string & msg,const Source & source) const2613 void Resolver::AddNote(const std::string& msg, const Source& source) const {
2614   diagnostics_.add_note(diag::System::Resolver, msg, source);
2615 }
2616 
2617 // https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section
IsPlain(const sem::Type * type) const2618 bool Resolver::IsPlain(const sem::Type* type) const {
2619   return type->is_scalar() ||
2620          type->IsAnyOf<sem::Atomic, sem::Vector, sem::Matrix, sem::Array,
2621                        sem::Struct>();
2622 }
2623 
2624 // https://gpuweb.github.io/gpuweb/wgsl.html#storable-types
IsStorable(const sem::Type * type) const2625 bool Resolver::IsStorable(const sem::Type* type) const {
2626   return IsPlain(type) || type->IsAnyOf<sem::Texture, sem::Sampler>();
2627 }
2628 
2629 // https://gpuweb.github.io/gpuweb/wgsl.html#host-shareable-types
IsHostShareable(const sem::Type * type) const2630 bool Resolver::IsHostShareable(const sem::Type* type) const {
2631   if (type->IsAnyOf<sem::I32, sem::U32, sem::F32>()) {
2632     return true;
2633   }
2634   if (auto* vec = type->As<sem::Vector>()) {
2635     return IsHostShareable(vec->type());
2636   }
2637   if (auto* mat = type->As<sem::Matrix>()) {
2638     return IsHostShareable(mat->type());
2639   }
2640   if (auto* arr = type->As<sem::Array>()) {
2641     return IsHostShareable(arr->ElemType());
2642   }
2643   if (auto* str = type->As<sem::Struct>()) {
2644     for (auto* member : str->Members()) {
2645       if (!IsHostShareable(member->Type())) {
2646         return false;
2647       }
2648     }
2649     return true;
2650   }
2651   if (auto* atomic = type->As<sem::Atomic>()) {
2652     return IsHostShareable(atomic->Type());
2653   }
2654   return false;
2655 }
2656 
IsIntrinsic(Symbol symbol) const2657 bool Resolver::IsIntrinsic(Symbol symbol) const {
2658   std::string name = builder_->Symbols().NameFor(symbol);
2659   return sem::ParseIntrinsicType(name) != sem::IntrinsicType::kNone;
2660 }
2661 
IsCallStatement(const ast::Expression * expr) const2662 bool Resolver::IsCallStatement(const ast::Expression* expr) const {
2663   return current_statement_ &&
2664          Is<ast::CallStatement>(current_statement_->Declaration(),
2665                                 [&](auto* stmt) { return stmt->expr == expr; });
2666 }
2667 
ClosestContinuing(bool stop_at_loop) const2668 const ast::Statement* Resolver::ClosestContinuing(bool stop_at_loop) const {
2669   for (const auto* s = current_statement_; s != nullptr; s = s->Parent()) {
2670     if (stop_at_loop && s->Is<sem::LoopStatement>()) {
2671       break;
2672     }
2673     if (s->Is<sem::LoopContinuingBlockStatement>()) {
2674       return s->Declaration();
2675     }
2676     if (auto* f = As<sem::ForLoopStatement>(s->Parent())) {
2677       if (f->Declaration()->continuing == s->Declaration()) {
2678         return s->Declaration();
2679       }
2680       if (stop_at_loop) {
2681         break;
2682       }
2683     }
2684   }
2685   return nullptr;
2686 }
2687 
2688 ////////////////////////////////////////////////////////////////////////////////
2689 // Resolver::TypeConversionSig
2690 ////////////////////////////////////////////////////////////////////////////////
operator ==(const TypeConversionSig & rhs) const2691 bool Resolver::TypeConversionSig::operator==(
2692     const TypeConversionSig& rhs) const {
2693   return target == rhs.target && source == rhs.source;
2694 }
operator ()(const TypeConversionSig & sig) const2695 std::size_t Resolver::TypeConversionSig::Hasher::operator()(
2696     const TypeConversionSig& sig) const {
2697   return utils::Hash(sig.target, sig.source);
2698 }
2699 
2700 ////////////////////////////////////////////////////////////////////////////////
2701 // Resolver::TypeConstructorSig
2702 ////////////////////////////////////////////////////////////////////////////////
TypeConstructorSig(const sem::Type * ty,const std::vector<const sem::Type * > params)2703 Resolver::TypeConstructorSig::TypeConstructorSig(
2704     const sem::Type* ty,
2705     const std::vector<const sem::Type*> params)
2706     : type(ty), parameters(params) {}
2707 Resolver::TypeConstructorSig::TypeConstructorSig(const TypeConstructorSig&) =
2708     default;
2709 Resolver::TypeConstructorSig::~TypeConstructorSig() = default;
2710 
operator ==(const TypeConstructorSig & rhs) const2711 bool Resolver::TypeConstructorSig::operator==(
2712     const TypeConstructorSig& rhs) const {
2713   return type == rhs.type && parameters == rhs.parameters;
2714 }
operator ()(const TypeConstructorSig & sig) const2715 std::size_t Resolver::TypeConstructorSig::Hasher::operator()(
2716     const TypeConstructorSig& sig) const {
2717   return utils::Hash(sig.type, sig.parameters);
2718 }
2719 
2720 }  // namespace resolver
2721 }  // namespace tint
2722