• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/resolver/resolver.h"
16 
17 #include <algorithm>
18 #include <cmath>
19 #include <iomanip>
20 #include <limits>
21 #include <utility>
22 
23 #include "src/ast/alias.h"
24 #include "src/ast/array.h"
25 #include "src/ast/assignment_statement.h"
26 #include "src/ast/bitcast_expression.h"
27 #include "src/ast/break_statement.h"
28 #include "src/ast/call_statement.h"
29 #include "src/ast/continue_statement.h"
30 #include "src/ast/depth_texture.h"
31 #include "src/ast/disable_validation_decoration.h"
32 #include "src/ast/discard_statement.h"
33 #include "src/ast/fallthrough_statement.h"
34 #include "src/ast/for_loop_statement.h"
35 #include "src/ast/if_statement.h"
36 #include "src/ast/internal_decoration.h"
37 #include "src/ast/interpolate_decoration.h"
38 #include "src/ast/loop_statement.h"
39 #include "src/ast/matrix.h"
40 #include "src/ast/override_decoration.h"
41 #include "src/ast/pointer.h"
42 #include "src/ast/return_statement.h"
43 #include "src/ast/sampled_texture.h"
44 #include "src/ast/sampler.h"
45 #include "src/ast/storage_texture.h"
46 #include "src/ast/struct_block_decoration.h"
47 #include "src/ast/switch_statement.h"
48 #include "src/ast/traverse_expressions.h"
49 #include "src/ast/type_name.h"
50 #include "src/ast/unary_op_expression.h"
51 #include "src/ast/variable_decl_statement.h"
52 #include "src/ast/vector.h"
53 #include "src/ast/workgroup_decoration.h"
54 #include "src/sem/array.h"
55 #include "src/sem/atomic_type.h"
56 #include "src/sem/call.h"
57 #include "src/sem/depth_multisampled_texture_type.h"
58 #include "src/sem/depth_texture_type.h"
59 #include "src/sem/for_loop_statement.h"
60 #include "src/sem/function.h"
61 #include "src/sem/if_statement.h"
62 #include "src/sem/loop_statement.h"
63 #include "src/sem/member_accessor_expression.h"
64 #include "src/sem/multisampled_texture_type.h"
65 #include "src/sem/pointer_type.h"
66 #include "src/sem/reference_type.h"
67 #include "src/sem/sampled_texture_type.h"
68 #include "src/sem/sampler_type.h"
69 #include "src/sem/statement.h"
70 #include "src/sem/storage_texture_type.h"
71 #include "src/sem/struct.h"
72 #include "src/sem/switch_statement.h"
73 #include "src/sem/type_constructor.h"
74 #include "src/sem/type_conversion.h"
75 #include "src/sem/variable.h"
76 #include "src/utils/defer.h"
77 #include "src/utils/map.h"
78 #include "src/utils/math.h"
79 #include "src/utils/reverse.h"
80 #include "src/utils/scoped_assignment.h"
81 #include "src/utils/transform.h"
82 
83 namespace tint {
84 namespace resolver {
85 namespace {
86 
IsValidStorageTextureDimension(ast::TextureDimension dim)87 bool IsValidStorageTextureDimension(ast::TextureDimension dim) {
88   switch (dim) {
89     case ast::TextureDimension::k1d:
90     case ast::TextureDimension::k2d:
91     case ast::TextureDimension::k2dArray:
92     case ast::TextureDimension::k3d:
93       return true;
94     default:
95       return false;
96   }
97 }
98 
IsValidStorageTextureImageFormat(ast::ImageFormat format)99 bool IsValidStorageTextureImageFormat(ast::ImageFormat format) {
100   switch (format) {
101     case ast::ImageFormat::kR32Uint:
102     case ast::ImageFormat::kR32Sint:
103     case ast::ImageFormat::kR32Float:
104     case ast::ImageFormat::kRg32Uint:
105     case ast::ImageFormat::kRg32Sint:
106     case ast::ImageFormat::kRg32Float:
107     case ast::ImageFormat::kRgba8Unorm:
108     case ast::ImageFormat::kRgba8Snorm:
109     case ast::ImageFormat::kRgba8Uint:
110     case ast::ImageFormat::kRgba8Sint:
111     case ast::ImageFormat::kRgba16Uint:
112     case ast::ImageFormat::kRgba16Sint:
113     case ast::ImageFormat::kRgba16Float:
114     case ast::ImageFormat::kRgba32Uint:
115     case ast::ImageFormat::kRgba32Sint:
116     case ast::ImageFormat::kRgba32Float:
117       return true;
118     default:
119       return false;
120   }
121 }
122 
123 // Helper to stringify a pipeline IO decoration.
deco_to_str(const ast::Decoration * deco)124 std::string deco_to_str(const ast::Decoration* deco) {
125   std::stringstream str;
126   if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
127     str << "builtin(" << builtin->builtin << ")";
128   } else if (auto* location = deco->As<ast::LocationDecoration>()) {
129     str << "location(" << location->value << ")";
130   }
131   return str.str();
132 }
133 
134 template <typename CALLBACK>
TraverseCallChain(diag::List & diagnostics,const sem::Function * from,const sem::Function * to,CALLBACK && callback)135 void TraverseCallChain(diag::List& diagnostics,
136                        const sem::Function* from,
137                        const sem::Function* to,
138                        CALLBACK&& callback) {
139   for (auto* f : from->TransitivelyCalledFunctions()) {
140     if (f == to) {
141       callback(f);
142       return;
143     }
144     if (f->TransitivelyCalledFunctions().contains(to)) {
145       TraverseCallChain(diagnostics, f, to, callback);
146       callback(f);
147       return;
148     }
149   }
150   TINT_ICE(Resolver, diagnostics)
151       << "TraverseCallChain() 'from' does not transitively call 'to'";
152 }
153 
154 }  // namespace
155 
ValidateAtomic(const ast::Atomic * a,const sem::Atomic * s)156 bool Resolver::ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s) {
157   // https://gpuweb.github.io/gpuweb/wgsl/#atomic-types
158   // T must be either u32 or i32.
159   if (!s->Type()->IsAnyOf<sem::U32, sem::I32>()) {
160     AddError("atomic only supports i32 or u32 types",
161              a->type ? a->type->source : a->source);
162     return false;
163   }
164   return true;
165 }
166 
ValidateStorageTexture(const ast::StorageTexture * t)167 bool Resolver::ValidateStorageTexture(const ast::StorageTexture* t) {
168   switch (t->access) {
169     case ast::Access::kWrite:
170       break;
171     case ast::Access::kUndefined:
172       AddError("storage texture missing access control", t->source);
173       return false;
174     default:
175       AddError("storage textures currently only support 'write' access control",
176                t->source);
177       return false;
178   }
179 
180   if (!IsValidStorageTextureDimension(t->dim)) {
181     AddError("cube dimensions for storage textures are not supported",
182              t->source);
183     return false;
184   }
185 
186   if (!IsValidStorageTextureImageFormat(t->format)) {
187     AddError(
188         "image format must be one of the texel formats specified for storage "
189         "textues in https://gpuweb.github.io/gpuweb/wgsl/#texel-formats",
190         t->source);
191     return false;
192   }
193   return true;
194 }
195 
ValidateVariableConstructorOrCast(const ast::Variable * var,ast::StorageClass storage_class,const sem::Type * storage_ty,const sem::Type * rhs_ty)196 bool Resolver::ValidateVariableConstructorOrCast(
197     const ast::Variable* var,
198     ast::StorageClass storage_class,
199     const sem::Type* storage_ty,
200     const sem::Type* rhs_ty) {
201   auto* value_type = rhs_ty->UnwrapRef();  // Implicit load of RHS
202 
203   // Value type has to match storage type
204   if (storage_ty != value_type) {
205     std::string decl = var->is_const ? "let" : "var";
206     AddError("cannot initialize " + decl + " of type '" +
207                  TypeNameOf(storage_ty) + "' with value of type '" +
208                  TypeNameOf(rhs_ty) + "'",
209              var->source);
210     return false;
211   }
212 
213   if (!var->is_const) {
214     switch (storage_class) {
215       case ast::StorageClass::kPrivate:
216       case ast::StorageClass::kFunction:
217         break;  // Allowed an initializer
218       default:
219         // https://gpuweb.github.io/gpuweb/wgsl/#var-and-let
220         // Optionally has an initializer expression, if the variable is in the
221         // private or function storage classes.
222         AddError("var of storage class '" +
223                      std::string(ast::ToString(storage_class)) +
224                      "' cannot have an initializer. var initializers are only "
225                      "supported for the storage classes "
226                      "'private' and 'function'",
227                  var->source);
228         return false;
229     }
230   }
231 
232   return true;
233 }
234 
ValidateStorageClassLayout(const sem::Struct * str,ast::StorageClass sc)235 bool Resolver::ValidateStorageClassLayout(const sem::Struct* str,
236                                           ast::StorageClass sc) {
237   // https://gpuweb.github.io/gpuweb/wgsl/#storage-class-layout-constraints
238 
239   auto is_uniform_struct_or_array = [sc](const sem::Type* ty) {
240     return sc == ast::StorageClass::kUniform &&
241            ty->IsAnyOf<sem::Array, sem::Struct>();
242   };
243 
244   auto is_uniform_struct = [sc](const sem::Type* ty) {
245     return sc == ast::StorageClass::kUniform && ty->Is<sem::Struct>();
246   };
247 
248   auto required_alignment_of = [&](const sem::Type* ty) {
249     uint32_t actual_align = ty->Align();
250     uint32_t required_align = actual_align;
251     if (is_uniform_struct_or_array(ty)) {
252       required_align = utils::RoundUp(16u, actual_align);
253     }
254     return required_align;
255   };
256 
257   auto member_name_of = [this](const sem::StructMember* sm) {
258     return builder_->Symbols().NameFor(sm->Declaration()->symbol);
259   };
260 
261   auto type_name_of = [this](const sem::StructMember* sm) {
262     return TypeNameOf(sm->Type());
263   };
264 
265   // TODO(amaiorano): Output struct and member decorations so that this output
266   // can be copied verbatim back into source
267   auto get_struct_layout_string = [&](const sem::Struct* st) -> std::string {
268     std::stringstream ss;
269 
270     if (st->Members().empty()) {
271       TINT_ICE(Resolver, diagnostics_) << "Validation should have ensured that "
272                                           "structs have at least one member";
273       return {};
274     }
275     const auto* const last_member = st->Members().back();
276     const uint32_t last_member_struct_padding_offset =
277         last_member->Offset() + last_member->Size();
278 
279     // Compute max widths to align output
280     const auto offset_w =
281         static_cast<int>(::log10(last_member_struct_padding_offset)) + 1;
282     const auto size_w = static_cast<int>(::log10(st->Size())) + 1;
283     const auto align_w = static_cast<int>(::log10(st->Align())) + 1;
284 
285     auto print_struct_begin_line = [&](size_t align, size_t size,
286                                        std::string struct_name) {
287       ss << "/*          " << std::setw(offset_w) << " "
288          << "align(" << std::setw(align_w) << align << ") size("
289          << std::setw(size_w) << size << ") */ struct " << struct_name
290          << " {\n";
291     };
292 
293     auto print_struct_end_line = [&]() {
294       ss << "/*                         "
295          << std::setw(offset_w + size_w + align_w) << " "
296          << "*/ };";
297     };
298 
299     auto print_member_line = [&](size_t offset, size_t align, size_t size,
300                                  std::string s) {
301       ss << "/* offset(" << std::setw(offset_w) << offset << ") align("
302          << std::setw(align_w) << align << ") size(" << std::setw(size_w)
303          << size << ") */   " << s << ";\n";
304     };
305 
306     print_struct_begin_line(st->Align(), st->Size(), TypeNameOf(st));
307 
308     for (size_t i = 0; i < st->Members().size(); ++i) {
309       auto* const m = st->Members()[i];
310 
311       // Output field alignment padding, if any
312       auto* const prev_member = (i == 0) ? nullptr : str->Members()[i - 1];
313       if (prev_member) {
314         uint32_t padding =
315             m->Offset() - (prev_member->Offset() + prev_member->Size());
316         if (padding > 0) {
317           size_t padding_offset = m->Offset() - padding;
318           print_member_line(padding_offset, 1, padding,
319                             "// -- implicit field alignment padding --");
320         }
321       }
322 
323       // Output member
324       std::string member_name = member_name_of(m);
325       print_member_line(m->Offset(), m->Align(), m->Size(),
326                         member_name_of(m) + " : " + type_name_of(m));
327     }
328 
329     // Output struct size padding, if any
330     uint32_t struct_padding = st->Size() - last_member_struct_padding_offset;
331     if (struct_padding > 0) {
332       print_member_line(last_member_struct_padding_offset, 1, struct_padding,
333                         "// -- implicit struct size padding --");
334     }
335 
336     print_struct_end_line();
337 
338     return ss.str();
339   };
340 
341   if (!ast::IsHostShareable(sc)) {
342     return true;
343   }
344 
345   for (size_t i = 0; i < str->Members().size(); ++i) {
346     auto* const m = str->Members()[i];
347     uint32_t required_align = required_alignment_of(m->Type());
348 
349     // Validate that member is at a valid byte offset
350     if (m->Offset() % required_align != 0) {
351       AddError("the offset of a struct member of type '" + type_name_of(m) +
352                    "' in storage class '" + ast::ToString(sc) +
353                    "' must be a multiple of " + std::to_string(required_align) +
354                    " bytes, but '" + member_name_of(m) +
355                    "' is currently at offset " + std::to_string(m->Offset()) +
356                    ". Consider setting [[align(" +
357                    std::to_string(required_align) + ")]] on this member",
358                m->Declaration()->source);
359 
360       AddNote("see layout of struct:\n" + get_struct_layout_string(str),
361               str->Declaration()->source);
362 
363       if (auto* member_str = m->Type()->As<sem::Struct>()) {
364         AddNote("and layout of struct member:\n" +
365                     get_struct_layout_string(member_str),
366                 member_str->Declaration()->source);
367       }
368 
369       return false;
370     }
371 
372     // For uniform buffers, validate that the number of bytes between the
373     // previous member of type struct and the current is a multiple of 16 bytes.
374     auto* const prev_member = (i == 0) ? nullptr : str->Members()[i - 1];
375     if (prev_member && is_uniform_struct(prev_member->Type())) {
376       const uint32_t prev_to_curr_offset = m->Offset() - prev_member->Offset();
377       if (prev_to_curr_offset % 16 != 0) {
378         AddError(
379             "uniform storage requires that the number of bytes between the "
380             "start of the previous member of type struct and the current "
381             "member be a multiple of 16 bytes, but there are currently " +
382                 std::to_string(prev_to_curr_offset) + " bytes between '" +
383                 member_name_of(prev_member) + "' and '" + member_name_of(m) +
384                 "'. Consider setting [[align(16)]] on this member",
385             m->Declaration()->source);
386 
387         AddNote("see layout of struct:\n" + get_struct_layout_string(str),
388                 str->Declaration()->source);
389 
390         auto* prev_member_str = prev_member->Type()->As<sem::Struct>();
391         AddNote("and layout of previous member struct:\n" +
392                     get_struct_layout_string(prev_member_str),
393                 prev_member_str->Declaration()->source);
394         return false;
395       }
396     }
397 
398     // For uniform buffer array members, validate that array elements are
399     // aligned to 16 bytes
400     if (auto* arr = m->Type()->As<sem::Array>()) {
401       if (sc == ast::StorageClass::kUniform) {
402         // We already validated that this array member is itself aligned to 16
403         // bytes above, so we only need to validate that stride is a multiple of
404         // 16 bytes.
405         if (arr->Stride() % 16 != 0) {
406           AddError(
407               "uniform storage requires that array elements be aligned to 16 "
408               "bytes, but array stride of '" +
409                   member_name_of(m) + "' is currently " +
410                   std::to_string(arr->Stride()) +
411                   ". Consider setting [[stride(" +
412                   std::to_string(
413                       utils::RoundUp(required_align, arr->Stride())) +
414                   ")]] on the array type",
415               m->Declaration()->type->source);
416           AddNote("see layout of struct:\n" + get_struct_layout_string(str),
417                   str->Declaration()->source);
418           return false;
419         }
420       }
421     }
422 
423     // If member is struct, recurse
424     if (auto* str_member = m->Type()->As<sem::Struct>()) {
425       // Cache result of struct + storage class pair
426       if (valid_struct_storage_layouts_.emplace(str_member, sc).second) {
427         if (!ValidateStorageClassLayout(str_member, sc)) {
428           return false;
429         }
430       }
431     }
432   }
433 
434   return true;
435 }
436 
ValidateStorageClassLayout(const sem::Variable * var)437 bool Resolver::ValidateStorageClassLayout(const sem::Variable* var) {
438   if (auto* str = var->Type()->UnwrapRef()->As<sem::Struct>()) {
439     if (!ValidateStorageClassLayout(str, var->StorageClass())) {
440       AddNote("see declaration of variable", var->Declaration()->source);
441       return false;
442     }
443   }
444 
445   return true;
446 }
447 
ValidateGlobalVariable(const sem::Variable * var)448 bool Resolver::ValidateGlobalVariable(const sem::Variable* var) {
449   auto* decl = var->Declaration();
450   if (!ValidateNoDuplicateDecorations(decl->decorations)) {
451     return false;
452   }
453 
454   for (auto* deco : decl->decorations) {
455     if (decl->is_const) {
456       if (auto* override_deco = deco->As<ast::OverrideDecoration>()) {
457         if (override_deco->has_value) {
458           uint32_t id = override_deco->value;
459           auto it = constant_ids_.find(id);
460           if (it != constant_ids_.end() && it->second != var) {
461             AddError("pipeline constant IDs must be unique", deco->source);
462             AddNote("a pipeline constant with an ID of " + std::to_string(id) +
463                         " was previously declared "
464                         "here:",
465                     ast::GetDecoration<ast::OverrideDecoration>(
466                         it->second->Declaration()->decorations)
467                         ->source);
468             return false;
469           }
470           if (id > 65535) {
471             AddError("pipeline constant IDs must be between 0 and 65535",
472                      deco->source);
473             return false;
474           }
475         }
476       } else {
477         AddError("decoration is not valid for constants", deco->source);
478         return false;
479       }
480     } else {
481       bool is_shader_io_decoration =
482           deco->IsAnyOf<ast::BuiltinDecoration, ast::InterpolateDecoration,
483                         ast::InvariantDecoration, ast::LocationDecoration>();
484       bool has_io_storage_class =
485           var->StorageClass() == ast::StorageClass::kInput ||
486           var->StorageClass() == ast::StorageClass::kOutput;
487       if (!(deco->IsAnyOf<ast::BindingDecoration, ast::GroupDecoration,
488                           ast::InternalDecoration>()) &&
489           (!is_shader_io_decoration || !has_io_storage_class)) {
490         AddError("decoration is not valid for variables", deco->source);
491         return false;
492       }
493     }
494   }
495 
496   auto binding_point = decl->BindingPoint();
497   switch (var->StorageClass()) {
498     case ast::StorageClass::kUniform:
499     case ast::StorageClass::kStorage:
500     case ast::StorageClass::kUniformConstant: {
501       // https://gpuweb.github.io/gpuweb/wgsl/#resource-interface
502       // Each resource variable must be declared with both group and binding
503       // attributes.
504       if (!binding_point) {
505         AddError(
506             "resource variables require [[group]] and [[binding]] "
507             "decorations",
508             decl->source);
509         return false;
510       }
511       break;
512     }
513     default:
514       if (binding_point.binding || binding_point.group) {
515         // https://gpuweb.github.io/gpuweb/wgsl/#attribute-binding
516         // Must only be applied to a resource variable
517         AddError(
518             "non-resource variables must not have [[group]] or [[binding]] "
519             "decorations",
520             decl->source);
521         return false;
522       }
523   }
524 
525   // https://gpuweb.github.io/gpuweb/wgsl/#variable-declaration
526   // The access mode always has a default, and except for variables in the
527   // storage storage class, must not be written.
528   if (var->StorageClass() != ast::StorageClass::kStorage &&
529       decl->declared_access != ast::Access::kUndefined) {
530     AddError(
531         "only variables in <storage> storage class may declare an access mode",
532         decl->source);
533     return false;
534   }
535 
536   switch (var->StorageClass()) {
537     case ast::StorageClass::kStorage: {
538       // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
539       // A variable in the storage storage class is a storage buffer variable.
540       // Its store type must be a host-shareable structure type with block
541       // attribute, satisfying the storage class constraints.
542 
543       auto* str = var->Type()->UnwrapRef()->As<sem::Struct>();
544 
545       if (!str) {
546         AddError(
547             "variables declared in the <storage> storage class must be of a "
548             "structure type",
549             decl->source);
550         return false;
551       }
552 
553       if (!str->IsBlockDecorated()) {
554         AddError(
555             "structure used as a storage buffer must be declared with the "
556             "[[block]] decoration",
557             str->Declaration()->source);
558         if (decl->source.range.begin.line) {
559           AddNote("structure used as storage buffer here", decl->source);
560         }
561         return false;
562       }
563       break;
564     }
565     case ast::StorageClass::kUniform: {
566       // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
567       // A variable in the uniform storage class is a uniform buffer variable.
568       // Its store type must be a host-shareable structure type with block
569       // attribute, satisfying the storage class constraints.
570       auto* str = var->Type()->UnwrapRef()->As<sem::Struct>();
571       if (!str) {
572         AddError(
573             "variables declared in the <uniform> storage class must be of a "
574             "structure type",
575             decl->source);
576         return false;
577       }
578 
579       if (!str->IsBlockDecorated()) {
580         AddError(
581             "structure used as a uniform buffer must be declared with the "
582             "[[block]] decoration",
583             str->Declaration()->source);
584         if (decl->source.range.begin.line) {
585           AddNote("structure used as uniform buffer here", decl->source);
586         }
587         return false;
588       }
589 
590       for (auto* member : str->Members()) {
591         if (auto* arr = member->Type()->As<sem::Array>()) {
592           if (arr->IsRuntimeSized()) {
593             AddError(
594                 "structure containing a runtime sized array "
595                 "cannot be used as a uniform buffer",
596                 decl->source);
597             AddNote("structure is declared here", str->Declaration()->source);
598             return false;
599           }
600         }
601       }
602 
603       break;
604     }
605     default:
606       break;
607   }
608 
609   if (!decl->is_const) {
610     if (!ValidateAtomicVariable(var)) {
611       return false;
612     }
613   }
614 
615   return ValidateVariable(var);
616 }
617 
618 // https://gpuweb.github.io/gpuweb/wgsl/#atomic-types
619 // Atomic types may only be instantiated by variables in the workgroup storage
620 // class or by storage buffer variables with a read_write access mode.
ValidateAtomicVariable(const sem::Variable * var)621 bool Resolver::ValidateAtomicVariable(const sem::Variable* var) {
622   auto sc = var->StorageClass();
623   auto* decl = var->Declaration();
624   auto access = var->Access();
625   auto* type = var->Type()->UnwrapRef();
626   auto source = decl->type ? decl->type->source : decl->source;
627 
628   if (type->Is<sem::Atomic>()) {
629     if (sc != ast::StorageClass::kWorkgroup) {
630       AddError(
631           "atomic variables must have <storage> or <workgroup> storage class",
632           source);
633       return false;
634     }
635   } else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
636     auto found = atomic_composite_info_.find(type);
637     if (found != atomic_composite_info_.end()) {
638       if (sc != ast::StorageClass::kStorage &&
639           sc != ast::StorageClass::kWorkgroup) {
640         AddError(
641             "atomic variables must have <storage> or <workgroup> storage class",
642             source);
643         AddNote(
644             "atomic sub-type of '" + TypeNameOf(type) + "' is declared here",
645             found->second);
646         return false;
647       } else if (sc == ast::StorageClass::kStorage &&
648                  access != ast::Access::kReadWrite) {
649         AddError(
650             "atomic variables in <storage> storage class must have read_write "
651             "access mode",
652             source);
653         AddNote(
654             "atomic sub-type of '" + TypeNameOf(type) + "' is declared here",
655             found->second);
656         return false;
657       }
658     }
659   }
660 
661   return true;
662 }
663 
ValidateVariable(const sem::Variable * var)664 bool Resolver::ValidateVariable(const sem::Variable* var) {
665   auto* decl = var->Declaration();
666   auto* storage_ty = var->Type()->UnwrapRef();
667 
668   if (var->Is<sem::GlobalVariable>()) {
669     auto name = builder_->Symbols().NameFor(decl->symbol);
670     if (sem::ParseIntrinsicType(name) != sem::IntrinsicType::kNone) {
671       auto* kind = var->Declaration()->is_const ? "let" : "var";
672       AddError(
673           "'" + name +
674               "' is a builtin and cannot be redeclared as a module-scope " +
675               kind,
676           decl->source);
677       return false;
678     }
679   }
680 
681   if (!decl->is_const && !IsStorable(storage_ty)) {
682     AddError(TypeNameOf(storage_ty) + " cannot be used as the type of a var",
683              decl->source);
684     return false;
685   }
686 
687   if (decl->is_const && !var->Is<sem::Parameter>() &&
688       !(storage_ty->IsConstructible() || storage_ty->Is<sem::Pointer>())) {
689     AddError(TypeNameOf(storage_ty) + " cannot be used as the type of a let",
690              decl->source);
691     return false;
692   }
693 
694   if (auto* r = storage_ty->As<sem::Array>()) {
695     if (r->IsRuntimeSized()) {
696       AddError("runtime arrays may only appear as the last member of a struct",
697                decl->source);
698       return false;
699     }
700   }
701 
702   if (auto* r = storage_ty->As<sem::MultisampledTexture>()) {
703     if (r->dim() != ast::TextureDimension::k2d) {
704       AddError("only 2d multisampled textures are supported", decl->source);
705       return false;
706     }
707 
708     if (!r->type()->UnwrapRef()->is_numeric_scalar()) {
709       AddError("texture_multisampled_2d<type>: type must be f32, i32 or u32",
710                decl->source);
711       return false;
712     }
713   }
714 
715   if (var->Is<sem::LocalVariable>() && !decl->is_const &&
716       IsValidationEnabled(decl->decorations,
717                           ast::DisabledValidation::kIgnoreStorageClass)) {
718     if (!var->Type()->UnwrapRef()->IsConstructible()) {
719       AddError("function variable must have a constructible type",
720                decl->type ? decl->type->source : decl->source);
721       return false;
722     }
723   }
724 
725   if (storage_ty->is_handle() &&
726       decl->declared_storage_class != ast::StorageClass::kNone) {
727     // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
728     // If the store type is a texture type or a sampler type, then the
729     // variable declaration must not have a storage class decoration. The
730     // storage class will always be handle.
731     AddError("variables of type '" + TypeNameOf(storage_ty) +
732                  "' must not have a storage class",
733              decl->source);
734     return false;
735   }
736 
737   if (IsValidationEnabled(decl->decorations,
738                           ast::DisabledValidation::kIgnoreStorageClass) &&
739       (decl->declared_storage_class == ast::StorageClass::kInput ||
740        decl->declared_storage_class == ast::StorageClass::kOutput)) {
741     AddError("invalid use of input/output storage class", decl->source);
742     return false;
743   }
744   return true;
745 }
746 
ValidateFunctionParameter(const ast::Function * func,const sem::Variable * var)747 bool Resolver::ValidateFunctionParameter(const ast::Function* func,
748                                          const sem::Variable* var) {
749   if (!ValidateVariable(var)) {
750     return false;
751   }
752 
753   auto* decl = var->Declaration();
754 
755   for (auto* deco : decl->decorations) {
756     if (!func->IsEntryPoint() && !deco->Is<ast::InternalDecoration>()) {
757       AddError(
758           "decoration is not valid for non-entry point function parameters",
759           deco->source);
760       return false;
761     } else if (!deco->IsAnyOf<ast::BuiltinDecoration, ast::InvariantDecoration,
762                               ast::LocationDecoration,
763                               ast::InterpolateDecoration,
764                               ast::InternalDecoration>() &&
765                (IsValidationEnabled(
766                     decl->decorations,
767                     ast::DisabledValidation::kEntryPointParameter) &&
768                 IsValidationEnabled(
769                     decl->decorations,
770                     ast::DisabledValidation::
771                         kIgnoreConstructibleFunctionParameter))) {
772       AddError("decoration is not valid for function parameters", deco->source);
773       return false;
774     }
775   }
776 
777   if (auto* ref = var->Type()->As<sem::Pointer>()) {
778     auto sc = ref->StorageClass();
779     if (!(sc == ast::StorageClass::kFunction ||
780           sc == ast::StorageClass::kPrivate ||
781           sc == ast::StorageClass::kWorkgroup) &&
782         IsValidationEnabled(decl->decorations,
783                             ast::DisabledValidation::kIgnoreStorageClass)) {
784       std::stringstream ss;
785       ss << "function parameter of pointer type cannot be in '" << sc
786          << "' storage class";
787       AddError(ss.str(), decl->source);
788       return false;
789     }
790   }
791 
792   if (IsPlain(var->Type())) {
793     if (!var->Type()->IsConstructible() &&
794         IsValidationEnabled(
795             decl->decorations,
796             ast::DisabledValidation::kIgnoreConstructibleFunctionParameter)) {
797       AddError("store type of function parameter must be a constructible type",
798                decl->source);
799       return false;
800     }
801   } else if (!var->Type()
802                   ->IsAnyOf<sem::Texture, sem::Sampler, sem::Pointer>()) {
803     AddError(
804         "store type of function parameter cannot be " + TypeNameOf(var->Type()),
805         decl->source);
806     return false;
807   }
808 
809   return true;
810 }
811 
ValidateBuiltinDecoration(const ast::BuiltinDecoration * deco,const sem::Type * storage_ty,const bool is_input)812 bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
813                                          const sem::Type* storage_ty,
814                                          const bool is_input) {
815   auto* type = storage_ty->UnwrapRef();
816   const auto stage = current_function_
817                          ? current_function_->Declaration()->PipelineStage()
818                          : ast::PipelineStage::kNone;
819   std::stringstream stage_name;
820   stage_name << stage;
821   bool is_stage_mismatch = false;
822   bool is_output = !is_input;
823   switch (deco->builtin) {
824     case ast::Builtin::kPosition:
825       if (stage != ast::PipelineStage::kNone &&
826           !((is_input && stage == ast::PipelineStage::kFragment) ||
827             (is_output && stage == ast::PipelineStage::kVertex))) {
828         is_stage_mismatch = true;
829       }
830       if (!(type->is_float_vector() && type->As<sem::Vector>()->Width() == 4)) {
831         AddError("store type of " + deco_to_str(deco) + " must be 'vec4<f32>'",
832                  deco->source);
833         return false;
834       }
835       break;
836     case ast::Builtin::kGlobalInvocationId:
837     case ast::Builtin::kLocalInvocationId:
838     case ast::Builtin::kNumWorkgroups:
839     case ast::Builtin::kWorkgroupId:
840       if (stage != ast::PipelineStage::kNone &&
841           !(stage == ast::PipelineStage::kCompute && is_input)) {
842         is_stage_mismatch = true;
843       }
844       if (!(type->is_unsigned_integer_vector() &&
845             type->As<sem::Vector>()->Width() == 3)) {
846         AddError("store type of " + deco_to_str(deco) + " must be 'vec3<u32>'",
847                  deco->source);
848         return false;
849       }
850       break;
851     case ast::Builtin::kFragDepth:
852       if (stage != ast::PipelineStage::kNone &&
853           !(stage == ast::PipelineStage::kFragment && !is_input)) {
854         is_stage_mismatch = true;
855       }
856       if (!type->Is<sem::F32>()) {
857         AddError("store type of " + deco_to_str(deco) + " must be 'f32'",
858                  deco->source);
859         return false;
860       }
861       break;
862     case ast::Builtin::kFrontFacing:
863       if (stage != ast::PipelineStage::kNone &&
864           !(stage == ast::PipelineStage::kFragment && is_input)) {
865         is_stage_mismatch = true;
866       }
867       if (!type->Is<sem::Bool>()) {
868         AddError("store type of " + deco_to_str(deco) + " must be 'bool'",
869                  deco->source);
870         return false;
871       }
872       break;
873     case ast::Builtin::kLocalInvocationIndex:
874       if (stage != ast::PipelineStage::kNone &&
875           !(stage == ast::PipelineStage::kCompute && is_input)) {
876         is_stage_mismatch = true;
877       }
878       if (!type->Is<sem::U32>()) {
879         AddError("store type of " + deco_to_str(deco) + " must be 'u32'",
880                  deco->source);
881         return false;
882       }
883       break;
884     case ast::Builtin::kVertexIndex:
885     case ast::Builtin::kInstanceIndex:
886       if (stage != ast::PipelineStage::kNone &&
887           !(stage == ast::PipelineStage::kVertex && is_input)) {
888         is_stage_mismatch = true;
889       }
890       if (!type->Is<sem::U32>()) {
891         AddError("store type of " + deco_to_str(deco) + " must be 'u32'",
892                  deco->source);
893         return false;
894       }
895       break;
896     case ast::Builtin::kSampleMask:
897       if (stage != ast::PipelineStage::kNone &&
898           !(stage == ast::PipelineStage::kFragment)) {
899         is_stage_mismatch = true;
900       }
901       if (!type->Is<sem::U32>()) {
902         AddError("store type of " + deco_to_str(deco) + " must be 'u32'",
903                  deco->source);
904         return false;
905       }
906       break;
907     case ast::Builtin::kSampleIndex:
908       if (stage != ast::PipelineStage::kNone &&
909           !(stage == ast::PipelineStage::kFragment && is_input)) {
910         is_stage_mismatch = true;
911       }
912       if (!type->Is<sem::U32>()) {
913         AddError("store type of " + deco_to_str(deco) + " must be 'u32'",
914                  deco->source);
915         return false;
916       }
917       break;
918     default:
919       break;
920   }
921 
922   if (is_stage_mismatch) {
923     AddError(deco_to_str(deco) + " cannot be used in " +
924                  (is_input ? "input of " : "output of ") + stage_name.str() +
925                  " pipeline stage",
926              deco->source);
927     return false;
928   }
929 
930   return true;
931 }
932 
ValidateInterpolateDecoration(const ast::InterpolateDecoration * deco,const sem::Type * storage_ty)933 bool Resolver::ValidateInterpolateDecoration(
934     const ast::InterpolateDecoration* deco,
935     const sem::Type* storage_ty) {
936   auto* type = storage_ty->UnwrapRef();
937 
938   if (type->is_integer_scalar_or_vector() &&
939       deco->type != ast::InterpolationType::kFlat) {
940     AddError(
941         "interpolation type must be 'flat' for integral user-defined IO types",
942         deco->source);
943     return false;
944   }
945 
946   if (deco->type == ast::InterpolationType::kFlat &&
947       deco->sampling != ast::InterpolationSampling::kNone) {
948     AddError("flat interpolation attribute must not have a sampling parameter",
949              deco->source);
950     return false;
951   }
952 
953   return true;
954 }
955 
ValidateFunction(const sem::Function * func)956 bool Resolver::ValidateFunction(const sem::Function* func) {
957   auto* decl = func->Declaration();
958 
959   auto name = builder_->Symbols().NameFor(decl->symbol);
960   if (sem::ParseIntrinsicType(name) != sem::IntrinsicType::kNone) {
961     AddError(
962         "'" + name + "' is a builtin and cannot be redeclared as a function",
963         decl->source);
964     return false;
965   }
966 
967   auto workgroup_deco_count = 0;
968   for (auto* deco : decl->decorations) {
969     if (deco->Is<ast::WorkgroupDecoration>()) {
970       workgroup_deco_count++;
971       if (decl->PipelineStage() != ast::PipelineStage::kCompute) {
972         AddError(
973             "the workgroup_size attribute is only valid for compute stages",
974             deco->source);
975         return false;
976       }
977     } else if (!deco->IsAnyOf<ast::StageDecoration,
978                               ast::InternalDecoration>()) {
979       AddError("decoration is not valid for functions", deco->source);
980       return false;
981     }
982   }
983 
984   if (decl->params.size() > 255) {
985     AddError("functions may declare at most 255 parameters", decl->source);
986     return false;
987   }
988 
989   for (size_t i = 0; i < decl->params.size(); i++) {
990     if (!ValidateFunctionParameter(decl, func->Parameters()[i])) {
991       return false;
992     }
993   }
994 
995   if (!func->ReturnType()->Is<sem::Void>()) {
996     if (!func->ReturnType()->IsConstructible()) {
997       AddError("function return type must be a constructible type",
998                decl->return_type->source);
999       return false;
1000     }
1001 
1002     if (decl->body) {
1003       sem::Behaviors behaviors{sem::Behavior::kNext};
1004       if (auto* last = decl->body->Last()) {
1005         behaviors = Sem(last)->Behaviors();
1006       }
1007       if (behaviors.Contains(sem::Behavior::kNext)) {
1008         AddError("missing return at end of function", decl->source);
1009         return false;
1010       }
1011     } else if (IsValidationEnabled(
1012                    decl->decorations,
1013                    ast::DisabledValidation::kFunctionHasNoBody)) {
1014       TINT_ICE(Resolver, diagnostics_)
1015           << "Function " << builder_->Symbols().NameFor(decl->symbol)
1016           << " has no body";
1017     }
1018 
1019     for (auto* deco : decl->return_type_decorations) {
1020       if (!decl->IsEntryPoint()) {
1021         AddError(
1022             "decoration is not valid for non-entry point function return types",
1023             deco->source);
1024         return false;
1025       }
1026       if (!deco->IsAnyOf<ast::BuiltinDecoration, ast::InternalDecoration,
1027                          ast::LocationDecoration, ast::InterpolateDecoration,
1028                          ast::InvariantDecoration>() &&
1029           (IsValidationEnabled(decl->decorations,
1030                                ast::DisabledValidation::kEntryPointParameter) &&
1031            IsValidationEnabled(decl->decorations,
1032                                ast::DisabledValidation::
1033                                    kIgnoreConstructibleFunctionParameter))) {
1034         AddError("decoration is not valid for entry point return types",
1035                  deco->source);
1036         return false;
1037       }
1038     }
1039   }
1040 
1041   if (decl->IsEntryPoint()) {
1042     if (!ValidateEntryPoint(func)) {
1043       return false;
1044     }
1045   }
1046 
1047   // https://www.w3.org/TR/WGSL/#behaviors-rules
1048   // a function behavior is always one of {}, {Next}, {Discard}, or
1049   // {Next, Discard}.
1050   if (func->Behaviors() != sem::Behaviors{} &&  // NOLINT: bad warning
1051       func->Behaviors() != sem::Behavior::kNext &&
1052       func->Behaviors() != sem::Behavior::kDiscard &&
1053       func->Behaviors() != sem::Behaviors{sem::Behavior::kNext,  //
1054                                           sem::Behavior::kDiscard}) {
1055     TINT_ICE(Resolver, diagnostics_)
1056         << "function '" << name << "' behaviors are: " << func->Behaviors();
1057   }
1058 
1059   return true;
1060 }
1061 
ValidateEntryPoint(const sem::Function * func)1062 bool Resolver::ValidateEntryPoint(const sem::Function* func) {
1063   auto* decl = func->Declaration();
1064 
1065   // Use a lambda to validate the entry point decorations for a type.
1066   // Persistent state is used to track which builtins and locations have
1067   // already been seen, in order to catch conflicts.
1068   // TODO(jrprice): This state could be stored in sem::Function instead, and
1069   // then passed to sem::Function since it would be useful there too.
1070   std::unordered_set<ast::Builtin> builtins;
1071   std::unordered_set<uint32_t> locations;
1072   enum class ParamOrRetType {
1073     kParameter,
1074     kReturnType,
1075   };
1076 
1077   // Inner lambda that is applied to a type and all of its members.
1078   auto validate_entry_point_decorations_inner = [&](const ast::DecorationList&
1079                                                         decos,
1080                                                     const sem::Type* ty,
1081                                                     Source source,
1082                                                     ParamOrRetType param_or_ret,
1083                                                     bool is_struct_member) {
1084     // Scan decorations for pipeline IO attributes.
1085     // Check for overlap with attributes that have been seen previously.
1086     const ast::Decoration* pipeline_io_attribute = nullptr;
1087     const ast::InterpolateDecoration* interpolate_attribute = nullptr;
1088     const ast::InvariantDecoration* invariant_attribute = nullptr;
1089     for (auto* deco : decos) {
1090       auto is_invalid_compute_shader_decoration = false;
1091       if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
1092         if (pipeline_io_attribute) {
1093           AddError("multiple entry point IO attributes", deco->source);
1094           AddNote("previously consumed " + deco_to_str(pipeline_io_attribute),
1095                   pipeline_io_attribute->source);
1096           return false;
1097         }
1098         pipeline_io_attribute = deco;
1099 
1100         if (builtins.count(builtin->builtin)) {
1101           AddError(deco_to_str(builtin) +
1102                        " attribute appears multiple times as pipeline " +
1103                        (param_or_ret == ParamOrRetType::kParameter ? "input"
1104                                                                    : "output"),
1105                    decl->source);
1106           return false;
1107         }
1108 
1109         if (!ValidateBuiltinDecoration(
1110                 builtin, ty,
1111                 /* is_input */ param_or_ret == ParamOrRetType::kParameter)) {
1112           return false;
1113         }
1114         builtins.emplace(builtin->builtin);
1115       } else if (auto* location = deco->As<ast::LocationDecoration>()) {
1116         if (pipeline_io_attribute) {
1117           AddError("multiple entry point IO attributes", deco->source);
1118           AddNote("previously consumed " + deco_to_str(pipeline_io_attribute),
1119                   pipeline_io_attribute->source);
1120           return false;
1121         }
1122         pipeline_io_attribute = deco;
1123 
1124         bool is_input = param_or_ret == ParamOrRetType::kParameter;
1125         if (!ValidateLocationDecoration(location, ty, locations, source,
1126                                         is_input)) {
1127           return false;
1128         }
1129       } else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
1130         if (decl->PipelineStage() == ast::PipelineStage::kCompute) {
1131           is_invalid_compute_shader_decoration = true;
1132         } else if (!ValidateInterpolateDecoration(interpolate, ty)) {
1133           return false;
1134         }
1135         interpolate_attribute = interpolate;
1136       } else if (auto* invariant = deco->As<ast::InvariantDecoration>()) {
1137         if (decl->PipelineStage() == ast::PipelineStage::kCompute) {
1138           is_invalid_compute_shader_decoration = true;
1139         }
1140         invariant_attribute = invariant;
1141       }
1142       if (is_invalid_compute_shader_decoration) {
1143         std::string input_or_output =
1144             param_or_ret == ParamOrRetType::kParameter ? "inputs" : "output";
1145         AddError(
1146             "decoration is not valid for compute shader " + input_or_output,
1147             deco->source);
1148         return false;
1149       }
1150     }
1151 
1152     if (IsValidationEnabled(decos,
1153                             ast::DisabledValidation::kEntryPointParameter)) {
1154       if (is_struct_member && ty->Is<sem::Struct>()) {
1155         AddError("nested structures cannot be used for entry point IO", source);
1156         return false;
1157       }
1158 
1159       if (!ty->Is<sem::Struct>() && !pipeline_io_attribute) {
1160         std::string err = "missing entry point IO attribute";
1161         if (!is_struct_member) {
1162           err +=
1163               (param_or_ret == ParamOrRetType::kParameter ? " on parameter"
1164                                                           : " on return type");
1165         }
1166         AddError(err, source);
1167         return false;
1168       }
1169 
1170       if (pipeline_io_attribute &&
1171           pipeline_io_attribute->Is<ast::LocationDecoration>()) {
1172         if (ty->is_integer_scalar_or_vector() && !interpolate_attribute) {
1173           // TODO(crbug.com/tint/1224): Make these errors once downstream
1174           // usages have caught up (no sooner than M99).
1175           if (decl->PipelineStage() == ast::PipelineStage::kVertex &&
1176               param_or_ret == ParamOrRetType::kReturnType) {
1177             AddWarning(
1178                 "integral user-defined vertex outputs must have a flat "
1179                 "interpolation attribute",
1180                 source);
1181           }
1182           if (decl->PipelineStage() == ast::PipelineStage::kFragment &&
1183               param_or_ret == ParamOrRetType::kParameter) {
1184             AddWarning(
1185                 "integral user-defined fragment inputs must have a flat "
1186                 "interpolation attribute",
1187                 source);
1188           }
1189         }
1190       }
1191 
1192       if (interpolate_attribute) {
1193         if (!pipeline_io_attribute ||
1194             !pipeline_io_attribute->Is<ast::LocationDecoration>()) {
1195           AddError("interpolate attribute must only be used with [[location]]",
1196                    interpolate_attribute->source);
1197           return false;
1198         }
1199       }
1200 
1201       if (invariant_attribute) {
1202         bool has_position = false;
1203         if (pipeline_io_attribute) {
1204           if (auto* builtin =
1205                   pipeline_io_attribute->As<ast::BuiltinDecoration>()) {
1206             has_position = (builtin->builtin == ast::Builtin::kPosition);
1207           }
1208         }
1209         if (!has_position) {
1210           AddError(
1211               "invariant attribute must only be applied to a position "
1212               "builtin",
1213               invariant_attribute->source);
1214           return false;
1215         }
1216       }
1217     }
1218     return true;
1219   };
1220 
1221   // Outer lambda for validating the entry point decorations for a type.
1222   auto validate_entry_point_decorations = [&](const ast::DecorationList& decos,
1223                                               const sem::Type* ty,
1224                                               Source source,
1225                                               ParamOrRetType param_or_ret) {
1226     if (!validate_entry_point_decorations_inner(decos, ty, source, param_or_ret,
1227                                                 /*is_struct_member*/ false)) {
1228       return false;
1229     }
1230 
1231     if (auto* str = ty->As<sem::Struct>()) {
1232       for (auto* member : str->Members()) {
1233         if (!validate_entry_point_decorations_inner(
1234                 member->Declaration()->decorations, member->Type(),
1235                 member->Declaration()->source, param_or_ret,
1236                 /*is_struct_member*/ true)) {
1237           AddNote("while analysing entry point '" +
1238                       builder_->Symbols().NameFor(decl->symbol) + "'",
1239                   decl->source);
1240           return false;
1241         }
1242       }
1243     }
1244 
1245     return true;
1246   };
1247 
1248   for (auto* param : func->Parameters()) {
1249     auto* param_decl = param->Declaration();
1250     if (!validate_entry_point_decorations(param_decl->decorations,
1251                                           param->Type(), param_decl->source,
1252                                           ParamOrRetType::kParameter)) {
1253       return false;
1254     }
1255   }
1256 
1257   // Clear IO sets after parameter validation. Builtin and location attributes
1258   // in return types should be validated independently from those used in
1259   // parameters.
1260   builtins.clear();
1261   locations.clear();
1262 
1263   if (!func->ReturnType()->Is<sem::Void>()) {
1264     if (!validate_entry_point_decorations(decl->return_type_decorations,
1265                                           func->ReturnType(), decl->source,
1266                                           ParamOrRetType::kReturnType)) {
1267       return false;
1268     }
1269   }
1270 
1271   if (decl->PipelineStage() == ast::PipelineStage::kVertex &&
1272       builtins.count(ast::Builtin::kPosition) == 0) {
1273     // Check module-scope variables, as the SPIR-V sanitizer generates these.
1274     bool found = false;
1275     for (auto* global : func->TransitivelyReferencedGlobals()) {
1276       if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
1277               global->Declaration()->decorations)) {
1278         if (builtin->builtin == ast::Builtin::kPosition) {
1279           found = true;
1280           break;
1281         }
1282       }
1283     }
1284     if (!found) {
1285       AddError(
1286           "a vertex shader must include the 'position' builtin in its return "
1287           "type",
1288           decl->source);
1289       return false;
1290     }
1291   }
1292 
1293   if (decl->PipelineStage() == ast::PipelineStage::kCompute) {
1294     if (!ast::HasDecoration<ast::WorkgroupDecoration>(decl->decorations)) {
1295       AddError(
1296           "a compute shader must include 'workgroup_size' in its "
1297           "attributes",
1298           decl->source);
1299       return false;
1300     }
1301   }
1302 
1303   // Validate there are no resource variable binding collisions
1304   std::unordered_map<sem::BindingPoint, const ast::Variable*> binding_points;
1305   for (auto* var : func->TransitivelyReferencedGlobals()) {
1306     auto* var_decl = var->Declaration();
1307     if (!var_decl->BindingPoint()) {
1308       continue;
1309     }
1310     auto bp = var->BindingPoint();
1311     auto res = binding_points.emplace(bp, var_decl);
1312     if (!res.second &&
1313         IsValidationEnabled(decl->decorations,
1314                             ast::DisabledValidation::kBindingPointCollision) &&
1315         IsValidationEnabled(res.first->second->decorations,
1316                             ast::DisabledValidation::kBindingPointCollision)) {
1317       // https://gpuweb.github.io/gpuweb/wgsl/#resource-interface
1318       // Bindings must not alias within a shader stage: two different
1319       // variables in the resource interface of a given shader must not have
1320       // the same group and binding values, when considered as a pair of
1321       // values.
1322       auto func_name = builder_->Symbols().NameFor(decl->symbol);
1323       AddError("entry point '" + func_name +
1324                    "' references multiple variables that use the "
1325                    "same resource binding [[group(" +
1326                    std::to_string(bp.group) + "), binding(" +
1327                    std::to_string(bp.binding) + ")]]",
1328                var_decl->source);
1329       AddNote("first resource binding usage declared here",
1330               res.first->second->source);
1331       return false;
1332     }
1333   }
1334 
1335   return true;
1336 }
1337 
ValidateStatements(const ast::StatementList & stmts)1338 bool Resolver::ValidateStatements(const ast::StatementList& stmts) {
1339   for (auto* stmt : stmts) {
1340     if (!Sem(stmt)->IsReachable()) {
1341       /// TODO(https://github.com/gpuweb/gpuweb/issues/2378): This may need to
1342       /// become an error.
1343       AddWarning("code is unreachable", stmt->source);
1344       break;
1345     }
1346   }
1347   return true;
1348 }
1349 
ValidateBitcast(const ast::BitcastExpression * cast,const sem::Type * to)1350 bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast,
1351                                const sem::Type* to) {
1352   auto* from = TypeOf(cast->expr)->UnwrapRef();
1353   if (!from->is_numeric_scalar_or_vector()) {
1354     AddError("'" + TypeNameOf(from) + "' cannot be bitcast",
1355              cast->expr->source);
1356     return false;
1357   }
1358   if (!to->is_numeric_scalar_or_vector()) {
1359     AddError("cannot bitcast to '" + TypeNameOf(to) + "'", cast->type->source);
1360     return false;
1361   }
1362 
1363   auto width = [&](const sem::Type* ty) {
1364     if (auto* vec = ty->As<sem::Vector>()) {
1365       return vec->Width();
1366     }
1367     return 1u;
1368   };
1369 
1370   if (width(from) != width(to)) {
1371     AddError("cannot bitcast from '" + TypeNameOf(from) + "' to '" +
1372                  TypeNameOf(to) + "'",
1373              cast->source);
1374     return false;
1375   }
1376 
1377   return true;
1378 }
1379 
ValidateBreakStatement(const sem::Statement * stmt)1380 bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) {
1381   if (!stmt->FindFirstParent<sem::LoopBlockStatement, sem::CaseStatement>()) {
1382     AddError("break statement must be in a loop or switch case",
1383              stmt->Declaration()->source);
1384     return false;
1385   }
1386   return true;
1387 }
1388 
ValidateContinueStatement(const sem::Statement * stmt)1389 bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) {
1390   if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ true)) {
1391     AddError("continuing blocks must not contain a continue statement",
1392              stmt->Declaration()->source);
1393     if (continuing != stmt->Declaration() &&
1394         continuing != stmt->Parent()->Declaration()) {
1395       AddNote("see continuing block here", continuing->source);
1396     }
1397     return false;
1398   }
1399 
1400   if (!stmt->FindFirstParent<sem::LoopBlockStatement>()) {
1401     AddError("continue statement must be in a loop",
1402              stmt->Declaration()->source);
1403     return false;
1404   }
1405 
1406   return true;
1407 }
1408 
ValidateDiscardStatement(const sem::Statement * stmt)1409 bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) {
1410   if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) {
1411     AddError("continuing blocks must not contain a discard statement",
1412              stmt->Declaration()->source);
1413     if (continuing != stmt->Declaration() &&
1414         continuing != stmt->Parent()->Declaration()) {
1415       AddNote("see continuing block here", continuing->source);
1416     }
1417     return false;
1418   }
1419   return true;
1420 }
1421 
ValidateFallthroughStatement(const sem::Statement * stmt)1422 bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) {
1423   if (auto* block = As<sem::BlockStatement>(stmt->Parent())) {
1424     if (auto* c = As<sem::CaseStatement>(block->Parent())) {
1425       if (block->Declaration()->Last() == stmt->Declaration()) {
1426         if (auto* s = As<sem::SwitchStatement>(c->Parent())) {
1427           if (c->Declaration() != s->Declaration()->body.back()) {
1428             return true;
1429           }
1430           AddError(
1431               "a fallthrough statement must not be used in the last switch "
1432               "case",
1433               stmt->Declaration()->source);
1434           return false;
1435         }
1436       }
1437     }
1438   }
1439   AddError(
1440       "fallthrough must only be used as the last statement of a case block",
1441       stmt->Declaration()->source);
1442   return false;
1443 }
1444 
ValidateElseStatement(const sem::ElseStatement * stmt)1445 bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) {
1446   if (auto* cond = stmt->Condition()) {
1447     auto* cond_ty = cond->Type()->UnwrapRef();
1448     if (!cond_ty->Is<sem::Bool>()) {
1449       AddError(
1450           "else statement condition must be bool, got " + TypeNameOf(cond_ty),
1451           stmt->Condition()->Declaration()->source);
1452       return false;
1453     }
1454   }
1455   return true;
1456 }
1457 
ValidateForLoopStatement(const sem::ForLoopStatement * stmt)1458 bool Resolver::ValidateForLoopStatement(const sem::ForLoopStatement* stmt) {
1459   if (auto* cond = stmt->Condition()) {
1460     auto* cond_ty = cond->Type()->UnwrapRef();
1461     if (!cond_ty->Is<sem::Bool>()) {
1462       AddError("for-loop condition must be bool, got " + TypeNameOf(cond_ty),
1463                stmt->Condition()->Declaration()->source);
1464       return false;
1465     }
1466   }
1467   return true;
1468 }
1469 
ValidateIfStatement(const sem::IfStatement * stmt)1470 bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) {
1471   auto* cond_ty = stmt->Condition()->Type()->UnwrapRef();
1472   if (!cond_ty->Is<sem::Bool>()) {
1473     AddError("if statement condition must be bool, got " + TypeNameOf(cond_ty),
1474              stmt->Condition()->Declaration()->source);
1475     return false;
1476   }
1477   return true;
1478 }
1479 
ValidateIntrinsicCall(const sem::Call * call)1480 bool Resolver::ValidateIntrinsicCall(const sem::Call* call) {
1481   if (call->Type()->Is<sem::Void>()) {
1482     bool is_call_statement = false;
1483     if (auto* call_stmt = As<ast::CallStatement>(call->Stmt()->Declaration())) {
1484       if (call_stmt->expr == call->Declaration()) {
1485         is_call_statement = true;
1486       }
1487     }
1488     if (!is_call_statement) {
1489       // https://gpuweb.github.io/gpuweb/wgsl/#function-call-expr
1490       // If the called function does not return a value, a function call
1491       // statement should be used instead.
1492       auto* ident = call->Declaration()->target.name;
1493       auto name = builder_->Symbols().NameFor(ident->symbol);
1494       AddError("intrinsic '" + name + "' does not return a value",
1495                call->Declaration()->source);
1496       return false;
1497     }
1498   }
1499 
1500   return true;
1501 }
1502 
ValidateTextureIntrinsicFunction(const sem::Call * call)1503 bool Resolver::ValidateTextureIntrinsicFunction(const sem::Call* call) {
1504   auto* intrinsic = call->Target()->As<sem::Intrinsic>();
1505   if (!intrinsic) {
1506     return false;
1507   }
1508 
1509   std::string func_name = intrinsic->str();
1510   auto& signature = intrinsic->Signature();
1511 
1512   auto check_arg_is_constexpr = [&](sem::ParameterUsage usage, int min,
1513                                     int max) {
1514     auto index = signature.IndexOf(usage);
1515     if (index < 0) {
1516       return true;
1517     }
1518     std::string name = sem::str(usage);
1519     auto* arg = call->Arguments()[index];
1520     if (auto values = arg->ConstantValue()) {
1521       // Assert that the constant values are of the expected type.
1522       if (!values.Type()->IsAnyOf<sem::I32, sem::Vector>() ||
1523           !values.ElementType()->Is<sem::I32>()) {
1524         TINT_ICE(Resolver, diagnostics_)
1525             << "failed to resolve '" + func_name + "' " << name
1526             << " parameter type";
1527         return false;
1528       }
1529 
1530       // Currently const_expr is restricted to literals and type constructors.
1531       // Check that that's all we have for the parameter.
1532       bool is_const_expr = true;
1533       ast::TraverseExpressions(
1534           arg->Declaration(), diagnostics_, [&](const ast::Expression* e) {
1535             if (e->IsAnyOf<ast::LiteralExpression, ast::CallExpression>()) {
1536               return ast::TraverseAction::Descend;
1537             }
1538             is_const_expr = false;
1539             return ast::TraverseAction::Stop;
1540           });
1541       if (is_const_expr) {
1542         auto vector = intrinsic->Parameters()[index]->Type()->Is<sem::Vector>();
1543         for (size_t i = 0; i < values.Elements().size(); i++) {
1544           auto value = values.Elements()[i].i32;
1545           if (value < min || value > max) {
1546             if (vector) {
1547               AddError("each component of the " + name +
1548                            " argument must be at least " + std::to_string(min) +
1549                            " and at most " + std::to_string(max) + ". " + name +
1550                            " component " + std::to_string(i) + " is " +
1551                            std::to_string(value),
1552                        arg->Declaration()->source);
1553             } else {
1554               AddError("the " + name + " argument must be at least " +
1555                            std::to_string(min) + " and at most " +
1556                            std::to_string(max) + ". " + name + " is " +
1557                            std::to_string(value),
1558                        arg->Declaration()->source);
1559             }
1560             return false;
1561           }
1562         }
1563         return true;
1564       }
1565     }
1566     AddError("the " + name + " argument must be a const_expression",
1567              arg->Declaration()->source);
1568     return false;
1569   };
1570 
1571   return check_arg_is_constexpr(sem::ParameterUsage::kOffset, -8, 7) &&
1572          check_arg_is_constexpr(sem::ParameterUsage::kComponent, 0, 3);
1573 }
1574 
ValidateFunctionCall(const sem::Call * call)1575 bool Resolver::ValidateFunctionCall(const sem::Call* call) {
1576   auto* decl = call->Declaration();
1577   auto* target = call->Target()->As<sem::Function>();
1578   auto sym = decl->target.name->symbol;
1579   auto name = builder_->Symbols().NameFor(sym);
1580 
1581   if (target->Declaration()->IsEntryPoint()) {
1582     // https://www.w3.org/TR/WGSL/#function-restriction
1583     // An entry point must never be the target of a function call.
1584     AddError("entry point functions cannot be the target of a function call",
1585              decl->source);
1586     return false;
1587   }
1588 
1589   if (decl->args.size() != target->Parameters().size()) {
1590     bool more = decl->args.size() > target->Parameters().size();
1591     AddError("too " + (more ? std::string("many") : std::string("few")) +
1592                  " arguments in call to '" + name + "', expected " +
1593                  std::to_string(target->Parameters().size()) + ", got " +
1594                  std::to_string(call->Arguments().size()),
1595              decl->source);
1596     return false;
1597   }
1598 
1599   for (size_t i = 0; i < call->Arguments().size(); ++i) {
1600     const sem::Variable* param = target->Parameters()[i];
1601     const ast::Expression* arg_expr = decl->args[i];
1602     auto* param_type = param->Type();
1603     auto* arg_type = TypeOf(arg_expr)->UnwrapRef();
1604 
1605     if (param_type != arg_type) {
1606       AddError("type mismatch for argument " + std::to_string(i + 1) +
1607                    " in call to '" + name + "', expected '" +
1608                    TypeNameOf(param_type) + "', got '" + TypeNameOf(arg_type) +
1609                    "'",
1610                arg_expr->source);
1611       return false;
1612     }
1613 
1614     if (param_type->Is<sem::Pointer>()) {
1615       auto is_valid = false;
1616       if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) {
1617         auto* var = ResolvedSymbol<sem::Variable>(ident_expr);
1618         if (!var) {
1619           TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
1620           return false;
1621         }
1622         if (var->Is<sem::Parameter>()) {
1623           is_valid = true;
1624         }
1625       } else if (auto* unary = arg_expr->As<ast::UnaryOpExpression>()) {
1626         if (unary->op == ast::UnaryOp::kAddressOf) {
1627           if (auto* ident_unary =
1628                   unary->expr->As<ast::IdentifierExpression>()) {
1629             auto* var = ResolvedSymbol<sem::Variable>(ident_unary);
1630             if (!var) {
1631               TINT_ICE(Resolver, diagnostics_)
1632                   << "failed to resolve identifier";
1633               return false;
1634             }
1635             if (var->Declaration()->is_const) {
1636               TINT_ICE(Resolver, diagnostics_)
1637                   << "Resolver::FunctionCall() encountered an address-of "
1638                      "expression of a constant identifier expression";
1639               return false;
1640             }
1641             is_valid = true;
1642           }
1643         }
1644       }
1645 
1646       if (!is_valid &&
1647           IsValidationEnabled(
1648               param->Declaration()->decorations,
1649               ast::DisabledValidation::kIgnoreInvalidPointerArgument)) {
1650         AddError(
1651             "expected an address-of expression of a variable identifier "
1652             "expression or a function parameter",
1653             arg_expr->source);
1654         return false;
1655       }
1656     }
1657   }
1658 
1659   if (call->Type()->Is<sem::Void>()) {
1660     bool is_call_statement = false;
1661     if (auto* call_stmt = As<ast::CallStatement>(call->Stmt()->Declaration())) {
1662       if (call_stmt->expr == call->Declaration()) {
1663         is_call_statement = true;
1664       }
1665     }
1666     if (!is_call_statement) {
1667       // https://gpuweb.github.io/gpuweb/wgsl/#function-call-expr
1668       // If the called function does not return a value, a function call
1669       // statement should be used instead.
1670       AddError("function '" + name + "' does not return a value", decl->source);
1671       return false;
1672     }
1673   }
1674 
1675   if (call->Behaviors().Contains(sem::Behavior::kDiscard)) {
1676     if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) {
1677       AddError(
1678           "cannot call a function that may discard inside a continuing block",
1679           call->Declaration()->source);
1680       if (continuing != call->Stmt()->Declaration() &&
1681           continuing != call->Stmt()->Parent()->Declaration()) {
1682         AddNote("see continuing block here", continuing->source);
1683       }
1684       return false;
1685     }
1686   }
1687 
1688   return true;
1689 }
1690 
ValidateStructureConstructorOrCast(const ast::CallExpression * ctor,const sem::Struct * struct_type)1691 bool Resolver::ValidateStructureConstructorOrCast(
1692     const ast::CallExpression* ctor,
1693     const sem::Struct* struct_type) {
1694   if (!struct_type->IsConstructible()) {
1695     AddError("struct constructor has non-constructible type", ctor->source);
1696     return false;
1697   }
1698 
1699   if (ctor->args.size() > 0) {
1700     if (ctor->args.size() != struct_type->Members().size()) {
1701       std::string fm =
1702           ctor->args.size() < struct_type->Members().size() ? "few" : "many";
1703       AddError("struct constructor has too " + fm + " inputs: expected " +
1704                    std::to_string(struct_type->Members().size()) + ", found " +
1705                    std::to_string(ctor->args.size()),
1706                ctor->source);
1707       return false;
1708     }
1709     for (auto* member : struct_type->Members()) {
1710       auto* value = ctor->args[member->Index()];
1711       auto* value_ty = TypeOf(value);
1712       if (member->Type() != value_ty->UnwrapRef()) {
1713         AddError(
1714             "type in struct constructor does not match struct member type: "
1715             "expected '" +
1716                 TypeNameOf(member->Type()) + "', found '" +
1717                 TypeNameOf(value_ty) + "'",
1718             value->source);
1719         return false;
1720       }
1721     }
1722   }
1723   return true;
1724 }
1725 
ValidateArrayConstructorOrCast(const ast::CallExpression * ctor,const sem::Array * array_type)1726 bool Resolver::ValidateArrayConstructorOrCast(const ast::CallExpression* ctor,
1727                                               const sem::Array* array_type) {
1728   auto& values = ctor->args;
1729   auto* elem_ty = array_type->ElemType();
1730   for (auto* value : values) {
1731     auto* value_ty = TypeOf(value)->UnwrapRef();
1732     if (value_ty != elem_ty) {
1733       AddError(
1734           "type in array constructor does not match array type: "
1735           "expected '" +
1736               TypeNameOf(elem_ty) + "', found '" + TypeNameOf(value_ty) + "'",
1737           value->source);
1738       return false;
1739     }
1740   }
1741 
1742   if (array_type->IsRuntimeSized()) {
1743     AddError("cannot init a runtime-sized array", ctor->source);
1744     return false;
1745   } else if (!elem_ty->IsConstructible()) {
1746     AddError("array constructor has non-constructible element type",
1747              ctor->source);
1748     return false;
1749   } else if (!values.empty() && (values.size() != array_type->Count())) {
1750     std::string fm = values.size() < array_type->Count() ? "few" : "many";
1751     AddError("array constructor has too " + fm + " elements: expected " +
1752                  std::to_string(array_type->Count()) + ", found " +
1753                  std::to_string(values.size()),
1754              ctor->source);
1755     return false;
1756   } else if (values.size() > array_type->Count()) {
1757     AddError("array constructor has too many elements: expected " +
1758                  std::to_string(array_type->Count()) + ", found " +
1759                  std::to_string(values.size()),
1760              ctor->source);
1761     return false;
1762   }
1763   return true;
1764 }
1765 
ValidateVectorConstructorOrCast(const ast::CallExpression * ctor,const sem::Vector * vec_type)1766 bool Resolver::ValidateVectorConstructorOrCast(const ast::CallExpression* ctor,
1767                                                const sem::Vector* vec_type) {
1768   auto& values = ctor->args;
1769   auto* elem_ty = vec_type->type();
1770   size_t value_cardinality_sum = 0;
1771   for (auto* value : values) {
1772     auto* value_ty = TypeOf(value)->UnwrapRef();
1773     if (value_ty->is_scalar()) {
1774       if (elem_ty != value_ty) {
1775         AddError(
1776             "type in vector constructor does not match vector type: "
1777             "expected '" +
1778                 TypeNameOf(elem_ty) + "', found '" + TypeNameOf(value_ty) + "'",
1779             value->source);
1780         return false;
1781       }
1782 
1783       value_cardinality_sum++;
1784     } else if (auto* value_vec = value_ty->As<sem::Vector>()) {
1785       auto* value_elem_ty = value_vec->type();
1786       // A mismatch of vector type parameter T is only an error if multiple
1787       // arguments are present. A single argument constructor constitutes a
1788       // type conversion expression.
1789       if (elem_ty != value_elem_ty && values.size() > 1u) {
1790         AddError(
1791             "type in vector constructor does not match vector type: "
1792             "expected '" +
1793                 TypeNameOf(elem_ty) + "', found '" + TypeNameOf(value_elem_ty) +
1794                 "'",
1795             value->source);
1796         return false;
1797       }
1798 
1799       value_cardinality_sum += value_vec->Width();
1800     } else {
1801       // A vector constructor can only accept vectors and scalars.
1802       AddError("expected vector or scalar type in vector constructor; found: " +
1803                    TypeNameOf(value_ty),
1804                value->source);
1805       return false;
1806     }
1807   }
1808 
1809   // A correct vector constructor must either be a zero-value expression,
1810   // a single-value initializer (splat) expression, or the number of components
1811   // of all constructor arguments must add up to the vector cardinality.
1812   if (value_cardinality_sum > 1 && value_cardinality_sum != vec_type->Width()) {
1813     if (values.empty()) {
1814       TINT_ICE(Resolver, diagnostics_)
1815           << "constructor arguments expected to be non-empty!";
1816     }
1817     const Source& values_start = values[0]->source;
1818     const Source& values_end = values[values.size() - 1]->source;
1819     AddError("attempted to construct '" + TypeNameOf(vec_type) + "' with " +
1820                  std::to_string(value_cardinality_sum) + " component(s)",
1821              Source::Combine(values_start, values_end));
1822     return false;
1823   }
1824   return true;
1825 }
1826 
ValidateVector(const sem::Vector * ty,const Source & source)1827 bool Resolver::ValidateVector(const sem::Vector* ty, const Source& source) {
1828   if (!ty->type()->is_scalar()) {
1829     AddError("vector element type must be 'bool', 'f32', 'i32' or 'u32'",
1830              source);
1831     return false;
1832   }
1833   return true;
1834 }
1835 
ValidateMatrix(const sem::Matrix * ty,const Source & source)1836 bool Resolver::ValidateMatrix(const sem::Matrix* ty, const Source& source) {
1837   if (!ty->is_float_matrix()) {
1838     AddError("matrix element type must be 'f32'", source);
1839     return false;
1840   }
1841   return true;
1842 }
1843 
ValidateMatrixConstructorOrCast(const ast::CallExpression * ctor,const sem::Matrix * matrix_ty)1844 bool Resolver::ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor,
1845                                                const sem::Matrix* matrix_ty) {
1846   auto& values = ctor->args;
1847   // Zero Value expression
1848   if (values.empty()) {
1849     return true;
1850   }
1851 
1852   if (!ValidateMatrix(matrix_ty, ctor->source)) {
1853     return false;
1854   }
1855 
1856   auto* elem_type = matrix_ty->type();
1857   auto num_elements = matrix_ty->columns() * matrix_ty->rows();
1858 
1859   // Print a generic error for an invalid matrix constructor, showing the
1860   // available overloads.
1861   auto print_error = [&]() {
1862     const Source& values_start = values[0]->source;
1863     const Source& values_end = values[values.size() - 1]->source;
1864     auto type_name = TypeNameOf(matrix_ty);
1865     auto elem_type_name = TypeNameOf(elem_type);
1866     std::stringstream ss;
1867     ss << "invalid constructor for " + type_name << std::endl << std::endl;
1868     ss << "3 candidates available:" << std::endl;
1869     ss << "  " << type_name << "()" << std::endl;
1870     ss << "  " << type_name << "(" << elem_type_name << ",...,"
1871        << elem_type_name << ")"
1872        << " // " << std::to_string(num_elements) << " arguments" << std::endl;
1873     ss << "  " << type_name << "(";
1874     for (uint32_t c = 0; c < matrix_ty->columns(); c++) {
1875       if (c > 0) {
1876         ss << ", ";
1877       }
1878       ss << VectorPretty(matrix_ty->rows(), elem_type);
1879     }
1880     ss << ")" << std::endl;
1881     AddError(ss.str(), Source::Combine(values_start, values_end));
1882   };
1883 
1884   const sem::Type* expected_arg_type = nullptr;
1885   if (num_elements == values.size()) {
1886     // Column-major construction from scalar elements.
1887     expected_arg_type = matrix_ty->type();
1888   } else if (matrix_ty->columns() == values.size()) {
1889     // Column-by-column construction from vectors.
1890     expected_arg_type = matrix_ty->ColumnType();
1891   } else {
1892     print_error();
1893     return false;
1894   }
1895 
1896   for (auto* value : values) {
1897     if (TypeOf(value)->UnwrapRef() != expected_arg_type) {
1898       print_error();
1899       return false;
1900     }
1901   }
1902 
1903   return true;
1904 }
1905 
ValidateScalarConstructorOrCast(const ast::CallExpression * ctor,const sem::Type * ty)1906 bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
1907                                                const sem::Type* ty) {
1908   if (ctor->args.size() == 0) {
1909     return true;
1910   }
1911   if (ctor->args.size() > 1) {
1912     AddError("expected zero or one value in constructor, got " +
1913                  std::to_string(ctor->args.size()),
1914              ctor->source);
1915     return false;
1916   }
1917 
1918   // Validate constructor
1919   auto* value = ctor->args[0];
1920   auto* value_ty = TypeOf(value)->UnwrapRef();
1921 
1922   using Bool = sem::Bool;
1923   using I32 = sem::I32;
1924   using U32 = sem::U32;
1925   using F32 = sem::F32;
1926 
1927   const bool is_valid = (ty->Is<Bool>() && value_ty->is_scalar()) ||
1928                         (ty->Is<I32>() && value_ty->is_scalar()) ||
1929                         (ty->Is<U32>() && value_ty->is_scalar()) ||
1930                         (ty->Is<F32>() && value_ty->is_scalar());
1931   if (!is_valid) {
1932     AddError("cannot construct '" + TypeNameOf(ty) +
1933                  "' with a value of type '" + TypeNameOf(value_ty) + "'",
1934              ctor->source);
1935 
1936     return false;
1937   }
1938 
1939   return true;
1940 }
1941 
ValidatePipelineStages()1942 bool Resolver::ValidatePipelineStages() {
1943   auto check_workgroup_storage = [&](const sem::Function* func,
1944                                      const sem::Function* entry_point) {
1945     auto stage = entry_point->Declaration()->PipelineStage();
1946     if (stage != ast::PipelineStage::kCompute) {
1947       for (auto* var : func->DirectlyReferencedGlobals()) {
1948         if (var->StorageClass() == ast::StorageClass::kWorkgroup) {
1949           std::stringstream stage_name;
1950           stage_name << stage;
1951           for (auto* user : var->Users()) {
1952             if (func == user->Stmt()->Function()) {
1953               AddError("workgroup memory cannot be used by " +
1954                            stage_name.str() + " pipeline stage",
1955                        user->Declaration()->source);
1956               break;
1957             }
1958           }
1959           AddNote("variable is declared here", var->Declaration()->source);
1960           if (func != entry_point) {
1961             TraverseCallChain(diagnostics_, entry_point, func,
1962                               [&](const sem::Function* f) {
1963                                 AddNote("called by function '" +
1964                                             builder_->Symbols().NameFor(
1965                                                 f->Declaration()->symbol) +
1966                                             "'",
1967                                         f->Declaration()->source);
1968                               });
1969             AddNote("called by entry point '" +
1970                         builder_->Symbols().NameFor(
1971                             entry_point->Declaration()->symbol) +
1972                         "'",
1973                     entry_point->Declaration()->source);
1974           }
1975           return false;
1976         }
1977       }
1978     }
1979     return true;
1980   };
1981 
1982   for (auto* entry_point : entry_points_) {
1983     if (!check_workgroup_storage(entry_point, entry_point)) {
1984       return false;
1985     }
1986     for (auto* func : entry_point->TransitivelyCalledFunctions()) {
1987       if (!check_workgroup_storage(func, entry_point)) {
1988         return false;
1989       }
1990     }
1991   }
1992 
1993   auto check_intrinsic_calls = [&](const sem::Function* func,
1994                                    const sem::Function* entry_point) {
1995     auto stage = entry_point->Declaration()->PipelineStage();
1996     for (auto* intrinsic : func->DirectlyCalledIntrinsics()) {
1997       if (!intrinsic->SupportedStages().Contains(stage)) {
1998         auto* call = func->FindDirectCallTo(intrinsic);
1999         std::stringstream err;
2000         err << "built-in cannot be used by " << stage << " pipeline stage";
2001         AddError(err.str(), call ? call->Declaration()->source
2002                                  : func->Declaration()->source);
2003         if (func != entry_point) {
2004           TraverseCallChain(
2005               diagnostics_, entry_point, func, [&](const sem::Function* f) {
2006                 AddNote(
2007                     "called by function '" +
2008                         builder_->Symbols().NameFor(f->Declaration()->symbol) +
2009                         "'",
2010                     f->Declaration()->source);
2011               });
2012           AddNote("called by entry point '" +
2013                       builder_->Symbols().NameFor(
2014                           entry_point->Declaration()->symbol) +
2015                       "'",
2016                   entry_point->Declaration()->source);
2017         }
2018         return false;
2019       }
2020     }
2021     return true;
2022   };
2023 
2024   for (auto* entry_point : entry_points_) {
2025     if (!check_intrinsic_calls(entry_point, entry_point)) {
2026       return false;
2027     }
2028     for (auto* func : entry_point->TransitivelyCalledFunctions()) {
2029       if (!check_intrinsic_calls(func, entry_point)) {
2030         return false;
2031       }
2032     }
2033   }
2034   return true;
2035 }
2036 
ValidateArray(const sem::Array * arr,const Source & source)2037 bool Resolver::ValidateArray(const sem::Array* arr, const Source& source) {
2038   auto* el_ty = arr->ElemType();
2039 
2040   if (auto* el_str = el_ty->As<sem::Struct>()) {
2041     if (el_str->IsBlockDecorated()) {
2042       // https://gpuweb.github.io/gpuweb/wgsl/#attributes
2043       // A structure type with the block attribute must not be:
2044       // * the element type of an array type
2045       // * the member type in another structure
2046       AddError(
2047           "A structure type with a [[block]] decoration cannot be used as an "
2048           "element of an array",
2049           source);
2050       return false;
2051     }
2052   }
2053   return true;
2054 }
2055 
ValidateArrayStrideDecoration(const ast::StrideDecoration * deco,uint32_t el_size,uint32_t el_align,const Source & source)2056 bool Resolver::ValidateArrayStrideDecoration(const ast::StrideDecoration* deco,
2057                                              uint32_t el_size,
2058                                              uint32_t el_align,
2059                                              const Source& source) {
2060   auto stride = deco->stride;
2061   bool is_valid_stride =
2062       (stride >= el_size) && (stride >= el_align) && (stride % el_align == 0);
2063   if (!is_valid_stride) {
2064     // https://gpuweb.github.io/gpuweb/wgsl/#array-layout-rules
2065     // Arrays decorated with the stride attribute must have a stride that is
2066     // at least the size of the element type, and be a multiple of the
2067     // element type's alignment value.
2068     AddError(
2069         "arrays decorated with the stride attribute must have a stride "
2070         "that is at least the size of the element type, and be a multiple "
2071         "of the element type's alignment value.",
2072         source);
2073     return false;
2074   }
2075   return true;
2076 }
2077 
ValidateAlias(const ast::Alias * alias)2078 bool Resolver::ValidateAlias(const ast::Alias* alias) {
2079   auto name = builder_->Symbols().NameFor(alias->name);
2080   if (sem::ParseIntrinsicType(name) != sem::IntrinsicType::kNone) {
2081     AddError("'" + name + "' is a builtin and cannot be redeclared as an alias",
2082              alias->source);
2083     return false;
2084   }
2085 
2086   return true;
2087 }
2088 
ValidateStructure(const sem::Struct * str)2089 bool Resolver::ValidateStructure(const sem::Struct* str) {
2090   auto name = builder_->Symbols().NameFor(str->Declaration()->name);
2091   if (sem::ParseIntrinsicType(name) != sem::IntrinsicType::kNone) {
2092     AddError("'" + name + "' is a builtin and cannot be redeclared as a struct",
2093              str->Declaration()->source);
2094     return false;
2095   }
2096 
2097   if (str->Members().empty()) {
2098     AddError("structures must have at least one member",
2099              str->Declaration()->source);
2100     return false;
2101   }
2102 
2103   std::unordered_set<uint32_t> locations;
2104   for (auto* member : str->Members()) {
2105     if (auto* r = member->Type()->As<sem::Array>()) {
2106       if (r->IsRuntimeSized()) {
2107         if (member != str->Members().back()) {
2108           AddError(
2109               "runtime arrays may only appear as the last member of a struct",
2110               member->Declaration()->source);
2111           return false;
2112         }
2113         if (!str->IsBlockDecorated()) {
2114           AddError(
2115               "a struct containing a runtime-sized array "
2116               "requires the [[block]] attribute: '" +
2117                   builder_->Symbols().NameFor(str->Declaration()->name) + "'",
2118               member->Declaration()->source);
2119           return false;
2120         }
2121       }
2122     }
2123 
2124     auto has_location = false;
2125     auto has_position = false;
2126     const ast::InvariantDecoration* invariant_attribute = nullptr;
2127     const ast::InterpolateDecoration* interpolate_attribute = nullptr;
2128     for (auto* deco : member->Declaration()->decorations) {
2129       if (!deco->IsAnyOf<ast::BuiltinDecoration,             //
2130                          ast::InternalDecoration,            //
2131                          ast::InterpolateDecoration,         //
2132                          ast::InvariantDecoration,           //
2133                          ast::LocationDecoration,            //
2134                          ast::StructMemberOffsetDecoration,  //
2135                          ast::StructMemberSizeDecoration,    //
2136                          ast::StructMemberAlignDecoration>()) {
2137         if (deco->Is<ast::StrideDecoration>() &&
2138             IsValidationDisabled(
2139                 member->Declaration()->decorations,
2140                 ast::DisabledValidation::kIgnoreStrideDecoration)) {
2141           continue;
2142         }
2143         AddError("decoration is not valid for structure members", deco->source);
2144         return false;
2145       }
2146 
2147       if (auto* invariant = deco->As<ast::InvariantDecoration>()) {
2148         invariant_attribute = invariant;
2149       } else if (auto* location = deco->As<ast::LocationDecoration>()) {
2150         has_location = true;
2151         if (!ValidateLocationDecoration(location, member->Type(), locations,
2152                                         member->Declaration()->source)) {
2153           return false;
2154         }
2155       } else if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
2156         if (!ValidateBuiltinDecoration(builtin, member->Type(),
2157                                        /* is_input */ false)) {
2158           return false;
2159         }
2160         if (builtin->builtin == ast::Builtin::kPosition) {
2161           has_position = true;
2162         }
2163       } else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
2164         interpolate_attribute = interpolate;
2165         if (!ValidateInterpolateDecoration(interpolate, member->Type())) {
2166           return false;
2167         }
2168       }
2169     }
2170 
2171     if (invariant_attribute && !has_position) {
2172       AddError("invariant attribute must only be applied to a position builtin",
2173                invariant_attribute->source);
2174       return false;
2175     }
2176 
2177     if (interpolate_attribute && !has_location) {
2178       AddError("interpolate attribute must only be used with [[location]]",
2179                interpolate_attribute->source);
2180       return false;
2181     }
2182 
2183     if (auto* member_struct_type = member->Type()->As<sem::Struct>()) {
2184       if (auto* member_struct_type_block_decoration =
2185               ast::GetDecoration<ast::StructBlockDecoration>(
2186                   member_struct_type->Declaration()->decorations)) {
2187         AddError("structs must not contain [[block]] decorated struct members",
2188                  member->Declaration()->source);
2189         AddNote("see member's struct decoration here",
2190                 member_struct_type_block_decoration->source);
2191         return false;
2192       }
2193     }
2194   }
2195 
2196   for (auto* deco : str->Declaration()->decorations) {
2197     if (!(deco->Is<ast::StructBlockDecoration>())) {
2198       AddError("decoration is not valid for struct declarations", deco->source);
2199       return false;
2200     }
2201   }
2202 
2203   return true;
2204 }
2205 
ValidateLocationDecoration(const ast::LocationDecoration * location,const sem::Type * type,std::unordered_set<uint32_t> & locations,const Source & source,const bool is_input)2206 bool Resolver::ValidateLocationDecoration(
2207     const ast::LocationDecoration* location,
2208     const sem::Type* type,
2209     std::unordered_set<uint32_t>& locations,
2210     const Source& source,
2211     const bool is_input) {
2212   std::string inputs_or_output = is_input ? "inputs" : "output";
2213   if (current_function_ && current_function_->Declaration()->PipelineStage() ==
2214                                ast::PipelineStage::kCompute) {
2215     AddError("decoration is not valid for compute shader " + inputs_or_output,
2216              location->source);
2217     return false;
2218   }
2219 
2220   if (!type->is_numeric_scalar_or_vector()) {
2221     std::string invalid_type = TypeNameOf(type);
2222     AddError("cannot apply 'location' attribute to declaration of type '" +
2223                  invalid_type + "'",
2224              source);
2225     AddNote(
2226         "'location' attribute must only be applied to declarations of "
2227         "numeric scalar or numeric vector type",
2228         location->source);
2229     return false;
2230   }
2231 
2232   if (locations.count(location->value)) {
2233     AddError(deco_to_str(location) + " attribute appears multiple times",
2234              location->source);
2235     return false;
2236   }
2237   locations.emplace(location->value);
2238 
2239   return true;
2240 }
2241 
ValidateReturn(const ast::ReturnStatement * ret)2242 bool Resolver::ValidateReturn(const ast::ReturnStatement* ret) {
2243   auto* func_type = current_function_->ReturnType();
2244 
2245   auto* ret_type = ret->value ? TypeOf(ret->value)->UnwrapRef()
2246                               : builder_->create<sem::Void>();
2247 
2248   if (func_type->UnwrapRef() != ret_type) {
2249     AddError(
2250         "return statement type must match its function "
2251         "return type, returned '" +
2252             TypeNameOf(ret_type) + "', expected '" + TypeNameOf(func_type) +
2253             "'",
2254         ret->source);
2255     return false;
2256   }
2257 
2258   auto* sem = Sem(ret);
2259   if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) {
2260     AddError("continuing blocks must not contain a return statement",
2261              ret->source);
2262     if (continuing != sem->Declaration() &&
2263         continuing != sem->Parent()->Declaration()) {
2264       AddNote("see continuing block here", continuing->source);
2265     }
2266     return false;
2267   }
2268 
2269   return true;
2270 }
2271 
ValidateSwitch(const ast::SwitchStatement * s)2272 bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) {
2273   auto* cond_ty = TypeOf(s->condition)->UnwrapRef();
2274   if (!cond_ty->is_integer_scalar()) {
2275     AddError(
2276         "switch statement selector expression must be of a "
2277         "scalar integer type",
2278         s->condition->source);
2279     return false;
2280   }
2281 
2282   bool has_default = false;
2283   std::unordered_map<uint32_t, Source> selectors;
2284 
2285   for (auto* case_stmt : s->body) {
2286     if (case_stmt->IsDefault()) {
2287       if (has_default) {
2288         // More than one default clause
2289         AddError("switch statement must have exactly one default clause",
2290                  case_stmt->source);
2291         return false;
2292       }
2293       has_default = true;
2294     }
2295 
2296     for (auto* selector : case_stmt->selectors) {
2297       if (cond_ty != TypeOf(selector)) {
2298         AddError(
2299             "the case selector values must have the same "
2300             "type as the selector expression.",
2301             case_stmt->source);
2302         return false;
2303       }
2304 
2305       auto v = selector->ValueAsU32();
2306       auto it = selectors.find(v);
2307       if (it != selectors.end()) {
2308         auto val = selector->Is<ast::IntLiteralExpression>()
2309                        ? std::to_string(selector->ValueAsI32())
2310                        : std::to_string(selector->ValueAsU32());
2311         AddError("duplicate switch case '" + val + "'", selector->source);
2312         AddNote("previous case declared here", it->second);
2313         return false;
2314       }
2315       selectors.emplace(v, selector->source);
2316     }
2317   }
2318 
2319   if (!has_default) {
2320     // No default clause
2321     AddError("switch statement must have a default clause", s->source);
2322     return false;
2323   }
2324 
2325   return true;
2326 }
2327 
ValidateAssignment(const ast::AssignmentStatement * a)2328 bool Resolver::ValidateAssignment(const ast::AssignmentStatement* a) {
2329   auto const* rhs_ty = TypeOf(a->rhs);
2330 
2331   if (a->lhs->Is<ast::PhonyExpression>()) {
2332     // https://www.w3.org/TR/WGSL/#phony-assignment-section
2333     auto* ty = rhs_ty->UnwrapRef();
2334     if (!ty->IsConstructible() &&
2335         !ty->IsAnyOf<sem::Pointer, sem::Texture, sem::Sampler>()) {
2336       AddError(
2337           "cannot assign '" + TypeNameOf(rhs_ty) +
2338               "' to '_'. '_' can only be assigned a constructible, pointer, "
2339               "texture or sampler type",
2340           a->rhs->source);
2341       return false;
2342     }
2343     return true;  // RHS can be anything.
2344   }
2345 
2346   // https://gpuweb.github.io/gpuweb/wgsl/#assignment-statement
2347   auto const* lhs_ty = TypeOf(a->lhs);
2348 
2349   if (auto* var = ResolvedSymbol<sem::Variable>(a->lhs)) {
2350     auto* decl = var->Declaration();
2351     if (var->Is<sem::Parameter>()) {
2352       AddError("cannot assign to function parameter", a->lhs->source);
2353       AddNote("'" + builder_->Symbols().NameFor(decl->symbol) +
2354                   "' is declared here:",
2355               decl->source);
2356       return false;
2357     }
2358     if (decl->is_const) {
2359       AddError("cannot assign to const", a->lhs->source);
2360       AddNote("'" + builder_->Symbols().NameFor(decl->symbol) +
2361                   "' is declared here:",
2362               decl->source);
2363       return false;
2364     }
2365   }
2366 
2367   auto* lhs_ref = lhs_ty->As<sem::Reference>();
2368   if (!lhs_ref) {
2369     // LHS is not a reference, so it has no storage.
2370     AddError("cannot assign to value of type '" + TypeNameOf(lhs_ty) + "'",
2371              a->lhs->source);
2372     return false;
2373   }
2374 
2375   auto* storage_ty = lhs_ref->StoreType();
2376   auto* value_type = rhs_ty->UnwrapRef();  // Implicit load of RHS
2377 
2378   // Value type has to match storage type
2379   if (storage_ty != value_type) {
2380     AddError("cannot assign '" + TypeNameOf(rhs_ty) + "' to '" +
2381                  TypeNameOf(lhs_ty) + "'",
2382              a->source);
2383     return false;
2384   }
2385   if (!storage_ty->IsConstructible()) {
2386     AddError("storage type of assignment must be constructible", a->source);
2387     return false;
2388   }
2389   if (lhs_ref->Access() == ast::Access::kRead) {
2390     AddError(
2391         "cannot store into a read-only type '" + RawTypeNameOf(lhs_ty) + "'",
2392         a->source);
2393     return false;
2394   }
2395   return true;
2396 }
2397 
ValidateNoDuplicateDecorations(const ast::DecorationList & decorations)2398 bool Resolver::ValidateNoDuplicateDecorations(
2399     const ast::DecorationList& decorations) {
2400   std::unordered_map<const TypeInfo*, Source> seen;
2401   for (auto* d : decorations) {
2402     auto res = seen.emplace(&d->TypeInfo(), d->source);
2403     if (!res.second && !d->Is<ast::InternalDecoration>()) {
2404       AddError("duplicate " + d->Name() + " decoration", d->source);
2405       AddNote("first decoration declared here", res.first->second);
2406       return false;
2407     }
2408   }
2409   return true;
2410 }
2411 
IsValidationDisabled(const ast::DecorationList & decorations,ast::DisabledValidation validation) const2412 bool Resolver::IsValidationDisabled(const ast::DecorationList& decorations,
2413                                     ast::DisabledValidation validation) const {
2414   for (auto* decoration : decorations) {
2415     if (auto* dv = decoration->As<ast::DisableValidationDecoration>()) {
2416       if (dv->validation == validation) {
2417         return true;
2418       }
2419     }
2420   }
2421   return false;
2422 }
2423 
IsValidationEnabled(const ast::DecorationList & decorations,ast::DisabledValidation validation) const2424 bool Resolver::IsValidationEnabled(const ast::DecorationList& decorations,
2425                                    ast::DisabledValidation validation) const {
2426   return !IsValidationDisabled(decorations, validation);
2427 }
2428 
2429 }  // namespace resolver
2430 }  // namespace tint
2431