• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/reader/wgsl/parser_impl.h"
16 
17 #include "src/ast/array.h"
18 #include "src/ast/assignment_statement.h"
19 #include "src/ast/bitcast_expression.h"
20 #include "src/ast/break_statement.h"
21 #include "src/ast/call_statement.h"
22 #include "src/ast/continue_statement.h"
23 #include "src/ast/discard_statement.h"
24 #include "src/ast/external_texture.h"
25 #include "src/ast/fallthrough_statement.h"
26 #include "src/ast/if_statement.h"
27 #include "src/ast/invariant_decoration.h"
28 #include "src/ast/loop_statement.h"
29 #include "src/ast/override_decoration.h"
30 #include "src/ast/return_statement.h"
31 #include "src/ast/stage_decoration.h"
32 #include "src/ast/struct_block_decoration.h"
33 #include "src/ast/switch_statement.h"
34 #include "src/ast/type_name.h"
35 #include "src/ast/unary_op_expression.h"
36 #include "src/ast/variable_decl_statement.h"
37 #include "src/ast/vector.h"
38 #include "src/ast/workgroup_decoration.h"
39 #include "src/reader/wgsl/lexer.h"
40 #include "src/sem/depth_texture_type.h"
41 #include "src/sem/external_texture_type.h"
42 #include "src/sem/multisampled_texture_type.h"
43 #include "src/sem/sampled_texture_type.h"
44 
45 namespace tint {
46 namespace reader {
47 namespace wgsl {
48 namespace {
49 
50 template <typename T>
51 using Expect = ParserImpl::Expect<T>;
52 
53 template <typename T>
54 using Maybe = ParserImpl::Maybe<T>;
55 
56 /// Controls the maximum number of times we'll call into the sync() and
57 /// unary_expression() functions from themselves. This is to guard against stack
58 /// overflow when there is an excessive number of blocks.
59 constexpr uint32_t kMaxParseDepth = 128;
60 
61 /// The maximum number of tokens to look ahead to try and sync the
62 /// parser on error.
63 constexpr size_t const kMaxResynchronizeLookahead = 32;
64 
65 const char kVertexStage[] = "vertex";
66 const char kFragmentStage[] = "fragment";
67 const char kComputeStage[] = "compute";
68 
69 const char kReadAccess[] = "read";
70 const char kWriteAccess[] = "write";
71 const char kReadWriteAccess[] = "read_write";
72 
ident_to_builtin(const std::string & str)73 ast::Builtin ident_to_builtin(const std::string& str) {
74   if (str == "position") {
75     return ast::Builtin::kPosition;
76   }
77   if (str == "vertex_index") {
78     return ast::Builtin::kVertexIndex;
79   }
80   if (str == "instance_index") {
81     return ast::Builtin::kInstanceIndex;
82   }
83   if (str == "front_facing") {
84     return ast::Builtin::kFrontFacing;
85   }
86   if (str == "frag_depth") {
87     return ast::Builtin::kFragDepth;
88   }
89   if (str == "local_invocation_id") {
90     return ast::Builtin::kLocalInvocationId;
91   }
92   if (str == "local_invocation_idx" || str == "local_invocation_index") {
93     return ast::Builtin::kLocalInvocationIndex;
94   }
95   if (str == "global_invocation_id") {
96     return ast::Builtin::kGlobalInvocationId;
97   }
98   if (str == "workgroup_id") {
99     return ast::Builtin::kWorkgroupId;
100   }
101   if (str == "num_workgroups") {
102     return ast::Builtin::kNumWorkgroups;
103   }
104   if (str == "sample_index") {
105     return ast::Builtin::kSampleIndex;
106   }
107   if (str == "sample_mask") {
108     return ast::Builtin::kSampleMask;
109   }
110   return ast::Builtin::kNone;
111 }
112 
113 const char kBindingDecoration[] = "binding";
114 const char kBlockDecoration[] = "block";
115 const char kBuiltinDecoration[] = "builtin";
116 const char kGroupDecoration[] = "group";
117 const char kInterpolateDecoration[] = "interpolate";
118 const char kInvariantDecoration[] = "invariant";
119 const char kLocationDecoration[] = "location";
120 const char kOverrideDecoration[] = "override";
121 const char kSizeDecoration[] = "size";
122 const char kAlignDecoration[] = "align";
123 const char kStageDecoration[] = "stage";
124 const char kStrideDecoration[] = "stride";
125 const char kWorkgroupSizeDecoration[] = "workgroup_size";
126 
is_decoration(Token t)127 bool is_decoration(Token t) {
128   if (!t.IsIdentifier()) {
129     return false;
130   }
131 
132   auto s = t.to_str();
133   return s == kAlignDecoration || s == kBindingDecoration ||
134          s == kBlockDecoration || s == kBuiltinDecoration ||
135          s == kGroupDecoration || s == kInterpolateDecoration ||
136          s == kLocationDecoration || s == kOverrideDecoration ||
137          s == kSizeDecoration || s == kStageDecoration ||
138          s == kStrideDecoration || s == kWorkgroupSizeDecoration;
139 }
140 
141 // https://gpuweb.github.io/gpuweb/wgsl.html#reserved-keywords
is_reserved(Token t)142 bool is_reserved(Token t) {
143   auto s = t.to_str();
144   return s == "asm" || s == "bf16" || s == "const" || s == "do" ||
145          s == "enum" || s == "f16" || s == "f64" || s == "handle" ||
146          s == "i8" || s == "i16" || s == "i64" || s == "mat" ||
147          s == "premerge" || s == "regardless" || s == "typedef" || s == "u8" ||
148          s == "u16" || s == "u64" || s == "unless" || s == "using" ||
149          s == "vec" || s == "void" || s == "while";
150 }
151 
152 /// Enter-exit counters for block token types.
153 /// Used by sync_to() to skip over closing block tokens that were opened during
154 /// the forward scan.
155 struct BlockCounters {
156   int attrs = 0;    // [[ ]]
157   int brace = 0;    // {   }
158   int bracket = 0;  // [   ]
159   int paren = 0;    // (   )
160 
161   /// @return the current enter-exit depth for the given block token type. If
162   /// `t` is not a block token type, then 0 is always returned.
consumetint::reader::wgsl::__anonaf54b81e0111::BlockCounters163   int consume(const Token& t) {
164     if (t.Is(Token::Type::kAttrLeft))
165       return attrs++;
166     if (t.Is(Token::Type::kAttrRight))
167       return attrs--;
168     if (t.Is(Token::Type::kBraceLeft))
169       return brace++;
170     if (t.Is(Token::Type::kBraceRight))
171       return brace--;
172     if (t.Is(Token::Type::kBracketLeft))
173       return bracket++;
174     if (t.Is(Token::Type::kBracketRight))
175       return bracket--;
176     if (t.Is(Token::Type::kParenLeft))
177       return paren++;
178     if (t.Is(Token::Type::kParenRight))
179       return paren--;
180     return 0;
181   }
182 };
183 }  // namespace
184 
185 /// RAII helper that combines a Source on construction with the last token's
186 /// source when implicitly converted to `Source`.
187 class ParserImpl::MultiTokenSource {
188  public:
189   /// Constructor that starts with Source at the current peek position
190   /// @param parser the parser
MultiTokenSource(ParserImpl * parser)191   explicit MultiTokenSource(ParserImpl* parser)
192       : MultiTokenSource(parser, parser->peek().source().Begin()) {}
193 
194   /// Constructor that starts with the input `start` Source
195   /// @param parser the parser
196   /// @param start the start source of the range
MultiTokenSource(ParserImpl * parser,const Source & start)197   MultiTokenSource(ParserImpl* parser, const Source& start)
198       : parser_(parser), start_(start) {}
199 
200   /// Implicit conversion to Source that returns the combined source from start
201   /// to the current last token's source.
operator Source() const202   operator Source() const {
203     Source end = parser_->last_token().source().End();
204     if (end < start_) {
205       end = start_;
206     }
207     return Source::Combine(start_, end);
208   }
209 
210  private:
211   ParserImpl* parser_;
212   Source start_;
213 };
214 
215 ParserImpl::TypedIdentifier::TypedIdentifier() = default;
216 
217 ParserImpl::TypedIdentifier::TypedIdentifier(const TypedIdentifier&) = default;
218 
TypedIdentifier(const ast::Type * type_in,std::string name_in,Source source_in)219 ParserImpl::TypedIdentifier::TypedIdentifier(const ast::Type* type_in,
220                                              std::string name_in,
221                                              Source source_in)
222     : type(type_in), name(std::move(name_in)), source(std::move(source_in)) {}
223 
224 ParserImpl::TypedIdentifier::~TypedIdentifier() = default;
225 
226 ParserImpl::FunctionHeader::FunctionHeader() = default;
227 
228 ParserImpl::FunctionHeader::FunctionHeader(const FunctionHeader&) = default;
229 
FunctionHeader(Source src,std::string n,ast::VariableList p,const ast::Type * ret_ty,ast::DecorationList ret_decos)230 ParserImpl::FunctionHeader::FunctionHeader(Source src,
231                                            std::string n,
232                                            ast::VariableList p,
233                                            const ast::Type* ret_ty,
234                                            ast::DecorationList ret_decos)
235     : source(src),
236       name(n),
237       params(p),
238       return_type(ret_ty),
239       return_type_decorations(ret_decos) {}
240 
241 ParserImpl::FunctionHeader::~FunctionHeader() = default;
242 
243 ParserImpl::FunctionHeader& ParserImpl::FunctionHeader::operator=(
244     const FunctionHeader& rhs) = default;
245 
246 ParserImpl::VarDeclInfo::VarDeclInfo() = default;
247 
248 ParserImpl::VarDeclInfo::VarDeclInfo(const VarDeclInfo&) = default;
249 
VarDeclInfo(Source source_in,std::string name_in,ast::StorageClass storage_class_in,ast::Access access_in,const ast::Type * type_in)250 ParserImpl::VarDeclInfo::VarDeclInfo(Source source_in,
251                                      std::string name_in,
252                                      ast::StorageClass storage_class_in,
253                                      ast::Access access_in,
254                                      const ast::Type* type_in)
255     : source(std::move(source_in)),
256       name(std::move(name_in)),
257       storage_class(storage_class_in),
258       access(access_in),
259       type(type_in) {}
260 
261 ParserImpl::VarDeclInfo::~VarDeclInfo() = default;
262 
ParserImpl(Source::File const * file)263 ParserImpl::ParserImpl(Source::File const* file)
264     : lexer_(std::make_unique<Lexer>(file->path, &file->content)) {}
265 
266 ParserImpl::~ParserImpl() = default;
267 
add_error(const Source & source,const std::string & err,const std::string & use)268 ParserImpl::Failure::Errored ParserImpl::add_error(const Source& source,
269                                                    const std::string& err,
270                                                    const std::string& use) {
271   std::stringstream msg;
272   msg << err;
273   if (!use.empty()) {
274     msg << " for " << use;
275   }
276   add_error(source, msg.str());
277   return Failure::kErrored;
278 }
279 
add_error(const Token & t,const std::string & err)280 ParserImpl::Failure::Errored ParserImpl::add_error(const Token& t,
281                                                    const std::string& err) {
282   add_error(t.source(), err);
283   return Failure::kErrored;
284 }
285 
add_error(const Source & source,const std::string & err)286 ParserImpl::Failure::Errored ParserImpl::add_error(const Source& source,
287                                                    const std::string& err) {
288   if (silence_errors_ == 0) {
289     builder_.Diagnostics().add_error(diag::System::Reader, err, source);
290   }
291   return Failure::kErrored;
292 }
293 
deprecated(const Source & source,const std::string & msg)294 void ParserImpl::deprecated(const Source& source, const std::string& msg) {
295   builder_.Diagnostics().add_warning(
296       diag::System::Reader, "use of deprecated language feature: " + msg,
297       source);
298 }
299 
next()300 Token ParserImpl::next() {
301   if (!token_queue_.empty()) {
302     auto t = token_queue_.front();
303     token_queue_.pop_front();
304     last_token_ = t;
305     return last_token_;
306   }
307   last_token_ = lexer_->next();
308   return last_token_;
309 }
310 
peek(size_t idx)311 Token ParserImpl::peek(size_t idx) {
312   while (token_queue_.size() < (idx + 1))
313     token_queue_.push_back(lexer_->next());
314   return token_queue_[idx];
315 }
316 
peek_is(Token::Type tok,size_t idx)317 bool ParserImpl::peek_is(Token::Type tok, size_t idx) {
318   return peek(idx).Is(tok);
319 }
320 
last_token() const321 Token ParserImpl::last_token() const {
322   return last_token_;
323 }
324 
Parse()325 bool ParserImpl::Parse() {
326   translation_unit();
327   return !has_error();
328 }
329 
330 // translation_unit
331 //  : global_decl* EOF
translation_unit()332 void ParserImpl::translation_unit() {
333   while (continue_parsing()) {
334     auto p = peek();
335     if (p.IsEof()) {
336       break;
337     }
338     expect_global_decl();
339     if (builder_.Diagnostics().error_count() >= max_errors_) {
340       add_error(Source{{}, p.source().file_path},
341                 "stopping after " + std::to_string(max_errors_) + " errors");
342       break;
343     }
344   }
345 }
346 
347 // global_decl
348 //  : SEMICOLON
349 //  | global_variable_decl SEMICLON
350 //  | global_constant_decl SEMICOLON
351 //  | type_alias SEMICOLON
352 //  | struct_decl SEMICOLON
353 //  | function_decl
expect_global_decl()354 Expect<bool> ParserImpl::expect_global_decl() {
355   if (match(Token::Type::kSemicolon) || match(Token::Type::kEOF))
356     return true;
357 
358   bool errored = false;
359 
360   auto decos = decoration_list();
361   if (decos.errored)
362     errored = true;
363   if (!continue_parsing())
364     return Failure::kErrored;
365 
366   auto decl = sync(Token::Type::kSemicolon, [&]() -> Maybe<bool> {
367     auto gv = global_variable_decl(decos.value);
368     if (gv.errored)
369       return Failure::kErrored;
370     if (gv.matched) {
371       if (!expect("variable declaration", Token::Type::kSemicolon))
372         return Failure::kErrored;
373 
374       builder_.AST().AddGlobalVariable(gv.value);
375       return true;
376     }
377 
378     auto gc = global_constant_decl(decos.value);
379     if (gc.errored)
380       return Failure::kErrored;
381 
382     if (gc.matched) {
383       if (!expect("let declaration", Token::Type::kSemicolon))
384         return Failure::kErrored;
385 
386       builder_.AST().AddGlobalVariable(gc.value);
387       return true;
388     }
389 
390     auto ta = type_alias();
391     if (ta.errored)
392       return Failure::kErrored;
393 
394     if (ta.matched) {
395       if (!expect("type alias", Token::Type::kSemicolon))
396         return Failure::kErrored;
397 
398       builder_.AST().AddTypeDecl(ta.value);
399       return true;
400     }
401 
402     auto str = struct_decl(decos.value);
403     if (str.errored)
404       return Failure::kErrored;
405 
406     if (str.matched) {
407       if (!expect("struct declaration", Token::Type::kSemicolon))
408         return Failure::kErrored;
409 
410       builder_.AST().AddTypeDecl(str.value);
411       return true;
412     }
413 
414     return Failure::kNoMatch;
415   });
416 
417   if (decl.errored) {
418     errored = true;
419   }
420   if (decl.matched) {
421     return expect_decorations_consumed(decos.value);
422   }
423 
424   auto func = function_decl(decos.value);
425   if (func.errored) {
426     errored = true;
427   }
428   if (func.matched) {
429     builder_.AST().AddFunction(func.value);
430     return true;
431   }
432 
433   if (errored) {
434     return Failure::kErrored;
435   }
436 
437   // Invalid syntax found - try and determine the best error message
438 
439   // We have decorations parsed, but nothing to consume them?
440   if (decos.value.size() > 0) {
441     return add_error(next(), "expected declaration after decorations");
442   }
443 
444   // We have a statement outside of a function?
445   auto t = peek();
446   auto stat = without_error([&] { return statement(); });
447   if (stat.matched) {
448     // Attempt to jump to the next '}' - the function might have just been
449     // missing an opening line.
450     sync_to(Token::Type::kBraceRight, true);
451     return add_error(t, "statement found outside of function body");
452   }
453   if (!stat.errored) {
454     // No match, no error - the parser might not have progressed.
455     // Ensure we always make _some_ forward progress.
456     next();
457   }
458 
459   // The token might itself be an error.
460   if (t.IsError()) {
461     next();  // Consume it.
462     return add_error(t.source(), t.to_str());
463   }
464 
465   // Exhausted all attempts to make sense of where we're at.
466   // Spew a generic error.
467 
468   return add_error(t, "unexpected token");
469 }
470 
471 // global_variable_decl
472 //  : variable_decoration_list* variable_decl
473 //  | variable_decoration_list* variable_decl EQUAL const_expr
global_variable_decl(ast::DecorationList & decos)474 Maybe<const ast::Variable*> ParserImpl::global_variable_decl(
475     ast::DecorationList& decos) {
476   auto decl = variable_decl();
477   if (decl.errored)
478     return Failure::kErrored;
479   if (!decl.matched)
480     return Failure::kNoMatch;
481 
482   const ast::Expression* constructor = nullptr;
483   if (match(Token::Type::kEqual)) {
484     auto expr = expect_const_expr();
485     if (expr.errored)
486       return Failure::kErrored;
487     constructor = expr.value;
488   }
489 
490   return create<ast::Variable>(
491       decl->source,                             // source
492       builder_.Symbols().Register(decl->name),  // symbol
493       decl->storage_class,                      // storage class
494       decl->access,                             // access control
495       decl->type,                               // type
496       false,                                    // is_const
497       constructor,                              // constructor
498       std::move(decos));                        // decorations
499 }
500 
501 // global_constant_decl
502 //  : attribute_list* LET variable_ident_decl global_const_initializer?
503 // global_const_initializer
504 //  : EQUAL const_expr
global_constant_decl(ast::DecorationList & decos)505 Maybe<const ast::Variable*> ParserImpl::global_constant_decl(
506     ast::DecorationList& decos) {
507   if (!match(Token::Type::kLet)) {
508     return Failure::kNoMatch;
509   }
510 
511   const char* use = "let declaration";
512 
513   auto decl = expect_variable_ident_decl(use, /* allow_inferred = */ true);
514   if (decl.errored)
515     return Failure::kErrored;
516 
517   const ast::Expression* initializer = nullptr;
518   if (match(Token::Type::kEqual)) {
519     auto init = expect_const_expr();
520     if (init.errored) {
521       return Failure::kErrored;
522     }
523     initializer = std::move(init.value);
524   }
525 
526   return create<ast::Variable>(
527       decl->source,                             // source
528       builder_.Symbols().Register(decl->name),  // symbol
529       ast::StorageClass::kNone,                 // storage class
530       ast::Access::kUndefined,                  // access control
531       decl->type,                               // type
532       true,                                     // is_const
533       initializer,                              // constructor
534       std::move(decos));                        // decorations
535 }
536 
537 // variable_decl
538 //   : VAR variable_qualifier? variable_ident_decl
variable_decl(bool allow_inferred)539 Maybe<ParserImpl::VarDeclInfo> ParserImpl::variable_decl(bool allow_inferred) {
540   Source source;
541   if (!match(Token::Type::kVar, &source))
542     return Failure::kNoMatch;
543 
544   VariableQualifier vq;
545   auto explicit_vq = variable_qualifier();
546   if (explicit_vq.errored)
547     return Failure::kErrored;
548   if (explicit_vq.matched) {
549     vq = explicit_vq.value;
550   }
551 
552   auto decl =
553       expect_variable_ident_decl("variable declaration", allow_inferred);
554   if (decl.errored)
555     return Failure::kErrored;
556 
557   return VarDeclInfo{decl->source, decl->name, vq.storage_class, vq.access,
558                      decl->type};
559 }
560 
561 // texture_sampler_types
562 //  : sampler_type
563 //  | depth_texture_type
564 //  | sampled_texture_type LESS_THAN type_decl GREATER_THAN
565 //  | multisampled_texture_type LESS_THAN type_decl GREATER_THAN
566 //  | storage_texture_type LESS_THAN image_storage_type
567 //                         COMMA access GREATER_THAN
texture_sampler_types()568 Maybe<const ast::Type*> ParserImpl::texture_sampler_types() {
569   auto type = sampler_type();
570   if (type.matched)
571     return type;
572 
573   type = depth_texture_type();
574   if (type.matched)
575     return type;
576 
577   type = external_texture_type();
578   if (type.matched)
579     return type.value;
580 
581   auto source_range = make_source_range();
582 
583   auto dim = sampled_texture_type();
584   if (dim.matched) {
585     const char* use = "sampled texture type";
586 
587     auto subtype = expect_lt_gt_block(use, [&] { return expect_type(use); });
588     if (subtype.errored)
589       return Failure::kErrored;
590 
591     return builder_.ty.sampled_texture(source_range, dim.value, subtype.value);
592   }
593 
594   auto ms_dim = multisampled_texture_type();
595   if (ms_dim.matched) {
596     const char* use = "multisampled texture type";
597 
598     auto subtype = expect_lt_gt_block(use, [&] { return expect_type(use); });
599     if (subtype.errored)
600       return Failure::kErrored;
601 
602     return builder_.ty.multisampled_texture(source_range, ms_dim.value,
603                                             subtype.value);
604   }
605 
606   auto storage = storage_texture_type();
607   if (storage.matched) {
608     const char* use = "storage texture type";
609     using StorageTextureInfo =
610         std::pair<tint::ast::ImageFormat, tint::ast::Access>;
611     auto params = expect_lt_gt_block(use, [&]() -> Expect<StorageTextureInfo> {
612       auto format = expect_image_storage_type(use);
613       if (format.errored) {
614         return Failure::kErrored;
615       }
616 
617       if (!expect("access control", Token::Type::kComma)) {
618         return Failure::kErrored;
619       }
620 
621       auto access = expect_access("access control");
622       if (access.errored) {
623         return Failure::kErrored;
624       }
625 
626       return std::make_pair(format.value, access.value);
627     });
628 
629     if (params.errored) {
630       return Failure::kErrored;
631     }
632 
633     return builder_.ty.storage_texture(source_range, storage.value,
634                                        params->first, params->second);
635   }
636 
637   return Failure::kNoMatch;
638 }
639 
640 // sampler_type
641 //  : SAMPLER
642 //  | SAMPLER_COMPARISON
sampler_type()643 Maybe<const ast::Type*> ParserImpl::sampler_type() {
644   Source source;
645   if (match(Token::Type::kSampler, &source))
646     return builder_.ty.sampler(source, ast::SamplerKind::kSampler);
647 
648   if (match(Token::Type::kComparisonSampler, &source))
649     return builder_.ty.sampler(source, ast::SamplerKind::kComparisonSampler);
650 
651   return Failure::kNoMatch;
652 }
653 
654 // sampled_texture_type
655 //  : TEXTURE_SAMPLED_1D
656 //  | TEXTURE_SAMPLED_2D
657 //  | TEXTURE_SAMPLED_2D_ARRAY
658 //  | TEXTURE_SAMPLED_3D
659 //  | TEXTURE_SAMPLED_CUBE
660 //  | TEXTURE_SAMPLED_CUBE_ARRAY
sampled_texture_type()661 Maybe<const ast::TextureDimension> ParserImpl::sampled_texture_type() {
662   if (match(Token::Type::kTextureSampled1d))
663     return ast::TextureDimension::k1d;
664 
665   if (match(Token::Type::kTextureSampled2d))
666     return ast::TextureDimension::k2d;
667 
668   if (match(Token::Type::kTextureSampled2dArray))
669     return ast::TextureDimension::k2dArray;
670 
671   if (match(Token::Type::kTextureSampled3d))
672     return ast::TextureDimension::k3d;
673 
674   if (match(Token::Type::kTextureSampledCube))
675     return ast::TextureDimension::kCube;
676 
677   if (match(Token::Type::kTextureSampledCubeArray))
678     return ast::TextureDimension::kCubeArray;
679 
680   return Failure::kNoMatch;
681 }
682 
683 // external_texture_type
684 //  : TEXTURE_EXTERNAL
external_texture_type()685 Maybe<const ast::Type*> ParserImpl::external_texture_type() {
686   Source source;
687   if (match(Token::Type::kTextureExternal, &source)) {
688     return builder_.ty.external_texture(source);
689   }
690 
691   return Failure::kNoMatch;
692 }
693 
694 // multisampled_texture_type
695 //  : TEXTURE_MULTISAMPLED_2D
multisampled_texture_type()696 Maybe<const ast::TextureDimension> ParserImpl::multisampled_texture_type() {
697   if (match(Token::Type::kTextureMultisampled2d))
698     return ast::TextureDimension::k2d;
699 
700   return Failure::kNoMatch;
701 }
702 
703 // storage_texture_type
704 //  : TEXTURE_STORAGE_1D
705 //  | TEXTURE_STORAGE_2D
706 //  | TEXTURE_STORAGE_2D_ARRAY
707 //  | TEXTURE_STORAGE_3D
storage_texture_type()708 Maybe<const ast::TextureDimension> ParserImpl::storage_texture_type() {
709   if (match(Token::Type::kTextureStorage1d))
710     return ast::TextureDimension::k1d;
711   if (match(Token::Type::kTextureStorage2d))
712     return ast::TextureDimension::k2d;
713   if (match(Token::Type::kTextureStorage2dArray))
714     return ast::TextureDimension::k2dArray;
715   if (match(Token::Type::kTextureStorage3d))
716     return ast::TextureDimension::k3d;
717 
718   return Failure::kNoMatch;
719 }
720 
721 // depth_texture_type
722 //  : TEXTURE_DEPTH_2D
723 //  | TEXTURE_DEPTH_2D_ARRAY
724 //  | TEXTURE_DEPTH_CUBE
725 //  | TEXTURE_DEPTH_CUBE_ARRAY
726 //  | TEXTURE_DEPTH_MULTISAMPLED_2D
depth_texture_type()727 Maybe<const ast::Type*> ParserImpl::depth_texture_type() {
728   Source source;
729   if (match(Token::Type::kTextureDepth2d, &source)) {
730     return builder_.ty.depth_texture(source, ast::TextureDimension::k2d);
731   }
732   if (match(Token::Type::kTextureDepth2dArray, &source)) {
733     return builder_.ty.depth_texture(source, ast::TextureDimension::k2dArray);
734   }
735   if (match(Token::Type::kTextureDepthCube, &source)) {
736     return builder_.ty.depth_texture(source, ast::TextureDimension::kCube);
737   }
738   if (match(Token::Type::kTextureDepthCubeArray, &source)) {
739     return builder_.ty.depth_texture(source, ast::TextureDimension::kCubeArray);
740   }
741   if (match(Token::Type::kTextureDepthMultisampled2d, &source)) {
742     return builder_.ty.depth_multisampled_texture(source,
743                                                   ast::TextureDimension::k2d);
744   }
745   return Failure::kNoMatch;
746 }
747 
748 // image_storage_type
749 //  : R8UNORM
750 //  | R8SNORM
751 //  | R8UINT
752 //  | R8SINT
753 //  | R16UINT
754 //  | R16SINT
755 //  | R16FLOAT
756 //  | RG8UNORM
757 //  | RG8SNORM
758 //  | RG8UINT
759 //  | RG8SINT
760 //  | R32UINT
761 //  | R32SINT
762 //  | R32FLOAT
763 //  | RG16UINT
764 //  | RG16SINT
765 //  | RG16FLOAT
766 //  | RGBA8UNORM
767 /// | RGBA8UNORM-SRGB
768 //  | RGBA8SNORM
769 //  | RGBA8UINT
770 //  | RGBA8SINT
771 //  | BGRA8UNORM
772 //  | BGRA8UNORM-SRGB
773 //  | RGB10A2UNORM
774 //  | RG11B10FLOAT
775 //  | RG32UINT
776 //  | RG32SINT
777 //  | RG32FLOAT
778 //  | RGBA16UINT
779 //  | RGBA16SINT
780 //  | RGBA16FLOAT
781 //  | RGBA32UINT
782 //  | RGBA32SINT
783 //  | RGBA32FLOAT
expect_image_storage_type(const std::string & use)784 Expect<ast::ImageFormat> ParserImpl::expect_image_storage_type(
785     const std::string& use) {
786   if (match(Token::Type::kFormatR8Unorm))
787     return ast::ImageFormat::kR8Unorm;
788 
789   if (match(Token::Type::kFormatR8Snorm))
790     return ast::ImageFormat::kR8Snorm;
791 
792   if (match(Token::Type::kFormatR8Uint))
793     return ast::ImageFormat::kR8Uint;
794 
795   if (match(Token::Type::kFormatR8Sint))
796     return ast::ImageFormat::kR8Sint;
797 
798   if (match(Token::Type::kFormatR16Uint))
799     return ast::ImageFormat::kR16Uint;
800 
801   if (match(Token::Type::kFormatR16Sint))
802     return ast::ImageFormat::kR16Sint;
803 
804   if (match(Token::Type::kFormatR16Float))
805     return ast::ImageFormat::kR16Float;
806 
807   if (match(Token::Type::kFormatRg8Unorm))
808     return ast::ImageFormat::kRg8Unorm;
809 
810   if (match(Token::Type::kFormatRg8Snorm))
811     return ast::ImageFormat::kRg8Snorm;
812 
813   if (match(Token::Type::kFormatRg8Uint))
814     return ast::ImageFormat::kRg8Uint;
815 
816   if (match(Token::Type::kFormatRg8Sint))
817     return ast::ImageFormat::kRg8Sint;
818 
819   if (match(Token::Type::kFormatR32Uint))
820     return ast::ImageFormat::kR32Uint;
821 
822   if (match(Token::Type::kFormatR32Sint))
823     return ast::ImageFormat::kR32Sint;
824 
825   if (match(Token::Type::kFormatR32Float))
826     return ast::ImageFormat::kR32Float;
827 
828   if (match(Token::Type::kFormatRg16Uint))
829     return ast::ImageFormat::kRg16Uint;
830 
831   if (match(Token::Type::kFormatRg16Sint))
832     return ast::ImageFormat::kRg16Sint;
833 
834   if (match(Token::Type::kFormatRg16Float))
835     return ast::ImageFormat::kRg16Float;
836 
837   if (match(Token::Type::kFormatRgba8Unorm))
838     return ast::ImageFormat::kRgba8Unorm;
839 
840   if (match(Token::Type::kFormatRgba8UnormSrgb))
841     return ast::ImageFormat::kRgba8UnormSrgb;
842 
843   if (match(Token::Type::kFormatRgba8Snorm))
844     return ast::ImageFormat::kRgba8Snorm;
845 
846   if (match(Token::Type::kFormatRgba8Uint))
847     return ast::ImageFormat::kRgba8Uint;
848 
849   if (match(Token::Type::kFormatRgba8Sint))
850     return ast::ImageFormat::kRgba8Sint;
851 
852   if (match(Token::Type::kFormatBgra8Unorm))
853     return ast::ImageFormat::kBgra8Unorm;
854 
855   if (match(Token::Type::kFormatBgra8UnormSrgb))
856     return ast::ImageFormat::kBgra8UnormSrgb;
857 
858   if (match(Token::Type::kFormatRgb10A2Unorm))
859     return ast::ImageFormat::kRgb10A2Unorm;
860 
861   if (match(Token::Type::kFormatRg11B10Float))
862     return ast::ImageFormat::kRg11B10Float;
863 
864   if (match(Token::Type::kFormatRg32Uint))
865     return ast::ImageFormat::kRg32Uint;
866 
867   if (match(Token::Type::kFormatRg32Sint))
868     return ast::ImageFormat::kRg32Sint;
869 
870   if (match(Token::Type::kFormatRg32Float))
871     return ast::ImageFormat::kRg32Float;
872 
873   if (match(Token::Type::kFormatRgba16Uint))
874     return ast::ImageFormat::kRgba16Uint;
875 
876   if (match(Token::Type::kFormatRgba16Sint))
877     return ast::ImageFormat::kRgba16Sint;
878 
879   if (match(Token::Type::kFormatRgba16Float))
880     return ast::ImageFormat::kRgba16Float;
881 
882   if (match(Token::Type::kFormatRgba32Uint))
883     return ast::ImageFormat::kRgba32Uint;
884 
885   if (match(Token::Type::kFormatRgba32Sint))
886     return ast::ImageFormat::kRgba32Sint;
887 
888   if (match(Token::Type::kFormatRgba32Float))
889     return ast::ImageFormat::kRgba32Float;
890 
891   return add_error(peek().source(), "invalid format", use);
892 }
893 
894 // variable_ident_decl
895 //   : IDENT COLON variable_decoration_list* type_decl
expect_variable_ident_decl(const std::string & use,bool allow_inferred)896 Expect<ParserImpl::TypedIdentifier> ParserImpl::expect_variable_ident_decl(
897     const std::string& use,
898     bool allow_inferred) {
899   auto ident = expect_ident(use);
900   if (ident.errored)
901     return Failure::kErrored;
902 
903   if (allow_inferred && !peek_is(Token::Type::kColon)) {
904     return TypedIdentifier{nullptr, ident.value, ident.source};
905   }
906 
907   if (!expect(use, Token::Type::kColon))
908     return Failure::kErrored;
909 
910   auto decos = decoration_list();
911   if (decos.errored)
912     return Failure::kErrored;
913 
914   auto t = peek();
915   auto type = type_decl(decos.value);
916   if (type.errored)
917     return Failure::kErrored;
918   if (!type.matched)
919     return add_error(t.source(), "invalid type", use);
920 
921   if (!expect_decorations_consumed(decos.value))
922     return Failure::kErrored;
923 
924   return TypedIdentifier{type.value, ident.value, ident.source};
925 }
926 
expect_access(const std::string & use)927 Expect<ast::Access> ParserImpl::expect_access(const std::string& use) {
928   auto ident = expect_ident(use);
929   if (ident.errored)
930     return Failure::kErrored;
931 
932   if (ident.value == kReadAccess)
933     return {ast::Access::kRead, ident.source};
934   if (ident.value == kWriteAccess)
935     return {ast::Access::kWrite, ident.source};
936   if (ident.value == kReadWriteAccess)
937     return {ast::Access::kReadWrite, ident.source};
938 
939   return add_error(ident.source, "invalid value for access control");
940 }
941 
942 // variable_qualifier
943 //   : LESS_THAN storage_class (COMMA access_mode)? GREATER_THAN
variable_qualifier()944 Maybe<ParserImpl::VariableQualifier> ParserImpl::variable_qualifier() {
945   if (!peek_is(Token::Type::kLessThan)) {
946     return Failure::kNoMatch;
947   }
948 
949   auto* use = "variable declaration";
950   auto vq = expect_lt_gt_block(use, [&]() -> Expect<VariableQualifier> {
951     auto source = make_source_range();
952     auto sc = expect_storage_class(use);
953     if (sc.errored) {
954       return Failure::kErrored;
955     }
956     if (match(Token::Type::kComma)) {
957       auto ac = expect_access(use);
958       if (ac.errored) {
959         return Failure::kErrored;
960       }
961       return VariableQualifier{sc.value, ac.value};
962     }
963     return Expect<VariableQualifier>{
964         VariableQualifier{sc.value, ast::Access::kUndefined}, source};
965   });
966 
967   if (vq.errored) {
968     return Failure::kErrored;
969   }
970 
971   return vq;
972 }
973 
974 // type_alias
975 //   : TYPE IDENT EQUAL type_decl
type_alias()976 Maybe<const ast::Alias*> ParserImpl::type_alias() {
977   if (!peek_is(Token::Type::kType))
978     return Failure::kNoMatch;
979 
980   auto t = next();
981   const char* use = "type alias";
982 
983   auto name = expect_ident(use);
984   if (name.errored)
985     return Failure::kErrored;
986 
987   if (!expect(use, Token::Type::kEqual))
988     return Failure::kErrored;
989 
990   auto type = type_decl();
991   if (type.errored)
992     return Failure::kErrored;
993   if (!type.matched)
994     return add_error(peek(), "invalid type alias");
995 
996   return builder_.ty.alias(make_source_range_from(t.source()), name.value,
997                            type.value);
998 }
999 
1000 // type_decl
1001 //   : IDENTIFIER
1002 //   | BOOL
1003 //   | FLOAT32
1004 //   | INT32
1005 //   | UINT32
1006 //   | VEC2 LESS_THAN type_decl GREATER_THAN
1007 //   | VEC3 LESS_THAN type_decl GREATER_THAN
1008 //   | VEC4 LESS_THAN type_decl GREATER_THAN
1009 //   | PTR LESS_THAN storage_class, type_decl (COMMA access_mode)? GREATER_THAN
1010 //   | array_decoration_list* ARRAY LESS_THAN type_decl COMMA
1011 //          INT_LITERAL GREATER_THAN
1012 //   | array_decoration_list* ARRAY LESS_THAN type_decl
1013 //          GREATER_THAN
1014 //   | MAT2x2 LESS_THAN type_decl GREATER_THAN
1015 //   | MAT2x3 LESS_THAN type_decl GREATER_THAN
1016 //   | MAT2x4 LESS_THAN type_decl GREATER_THAN
1017 //   | MAT3x2 LESS_THAN type_decl GREATER_THAN
1018 //   | MAT3x3 LESS_THAN type_decl GREATER_THAN
1019 //   | MAT3x4 LESS_THAN type_decl GREATER_THAN
1020 //   | MAT4x2 LESS_THAN type_decl GREATER_THAN
1021 //   | MAT4x3 LESS_THAN type_decl GREATER_THAN
1022 //   | MAT4x4 LESS_THAN type_decl GREATER_THAN
1023 //   | texture_sampler_types
type_decl()1024 Maybe<const ast::Type*> ParserImpl::type_decl() {
1025   auto decos = decoration_list();
1026   if (decos.errored)
1027     return Failure::kErrored;
1028 
1029   auto type = type_decl(decos.value);
1030   if (type.errored) {
1031     return Failure::kErrored;
1032   }
1033   if (!expect_decorations_consumed(decos.value)) {
1034     return Failure::kErrored;
1035   }
1036   if (!type.matched) {
1037     return Failure::kNoMatch;
1038   }
1039 
1040   return type;
1041 }
1042 
type_decl(ast::DecorationList & decos)1043 Maybe<const ast::Type*> ParserImpl::type_decl(ast::DecorationList& decos) {
1044   auto t = peek();
1045   Source source;
1046   if (match(Token::Type::kIdentifier, &source)) {
1047     return builder_.create<ast::TypeName>(
1048         source, builder_.Symbols().Register(t.to_str()));
1049   }
1050 
1051   if (match(Token::Type::kBool, &source))
1052     return builder_.ty.bool_(source);
1053 
1054   if (match(Token::Type::kF32, &source))
1055     return builder_.ty.f32(source);
1056 
1057   if (match(Token::Type::kI32, &source))
1058     return builder_.ty.i32(source);
1059 
1060   if (match(Token::Type::kU32, &source))
1061     return builder_.ty.u32(source);
1062 
1063   if (t.IsVector()) {
1064     next();  // Consume the peek
1065     return expect_type_decl_vector(t);
1066   }
1067 
1068   if (match(Token::Type::kPtr)) {
1069     return expect_type_decl_pointer(t);
1070   }
1071 
1072   if (match(Token::Type::kAtomic)) {
1073     return expect_type_decl_atomic(t);
1074   }
1075 
1076   if (match(Token::Type::kArray, &source)) {
1077     return expect_type_decl_array(t, std::move(decos));
1078   }
1079 
1080   if (t.IsMatrix()) {
1081     next();  // Consume the peek
1082     return expect_type_decl_matrix(t);
1083   }
1084 
1085   auto texture_or_sampler = texture_sampler_types();
1086   if (texture_or_sampler.errored)
1087     return Failure::kErrored;
1088   if (texture_or_sampler.matched)
1089     return texture_or_sampler;
1090 
1091   return Failure::kNoMatch;
1092 }
1093 
expect_type(const std::string & use)1094 Expect<const ast::Type*> ParserImpl::expect_type(const std::string& use) {
1095   auto type = type_decl();
1096   if (type.errored)
1097     return Failure::kErrored;
1098   if (!type.matched)
1099     return add_error(peek().source(), "invalid type", use);
1100   return type.value;
1101 }
1102 
expect_type_decl_pointer(Token t)1103 Expect<const ast::Type*> ParserImpl::expect_type_decl_pointer(Token t) {
1104   const char* use = "ptr declaration";
1105 
1106   auto storage_class = ast::StorageClass::kNone;
1107   auto access = ast::Access::kUndefined;
1108 
1109   auto subtype = expect_lt_gt_block(use, [&]() -> Expect<const ast::Type*> {
1110     auto sc = expect_storage_class(use);
1111     if (sc.errored) {
1112       return Failure::kErrored;
1113     }
1114     storage_class = sc.value;
1115 
1116     if (!expect(use, Token::Type::kComma)) {
1117       return Failure::kErrored;
1118     }
1119 
1120     auto type = expect_type(use);
1121     if (type.errored) {
1122       return Failure::kErrored;
1123     }
1124 
1125     if (match(Token::Type::kComma)) {
1126       auto ac = expect_access("access control");
1127       if (ac.errored) {
1128         return Failure::kErrored;
1129       }
1130       access = ac.value;
1131     }
1132 
1133     return type.value;
1134   });
1135 
1136   if (subtype.errored) {
1137     return Failure::kErrored;
1138   }
1139 
1140   return builder_.ty.pointer(make_source_range_from(t.source()), subtype.value,
1141                              storage_class, access);
1142 }
1143 
expect_type_decl_atomic(Token t)1144 Expect<const ast::Type*> ParserImpl::expect_type_decl_atomic(Token t) {
1145   const char* use = "atomic declaration";
1146 
1147   auto subtype = expect_lt_gt_block(use, [&] { return expect_type(use); });
1148   if (subtype.errored) {
1149     return Failure::kErrored;
1150   }
1151 
1152   return builder_.ty.atomic(make_source_range_from(t.source()), subtype.value);
1153 }
1154 
expect_type_decl_vector(Token t)1155 Expect<const ast::Type*> ParserImpl::expect_type_decl_vector(Token t) {
1156   uint32_t count = 2;
1157   if (t.Is(Token::Type::kVec3))
1158     count = 3;
1159   else if (t.Is(Token::Type::kVec4))
1160     count = 4;
1161 
1162   const char* use = "vector";
1163 
1164   auto subtype = expect_lt_gt_block(use, [&] { return expect_type(use); });
1165   if (subtype.errored)
1166     return Failure::kErrored;
1167 
1168   return builder_.ty.vec(make_source_range_from(t.source()), subtype.value,
1169                          count);
1170 }
1171 
expect_type_decl_array(Token t,ast::DecorationList decos)1172 Expect<const ast::Type*> ParserImpl::expect_type_decl_array(
1173     Token t,
1174     ast::DecorationList decos) {
1175   const char* use = "array declaration";
1176 
1177   const ast::Expression* size = nullptr;
1178 
1179   auto subtype = expect_lt_gt_block(use, [&]() -> Expect<const ast::Type*> {
1180     auto type = expect_type(use);
1181     if (type.errored)
1182       return Failure::kErrored;
1183 
1184     if (match(Token::Type::kComma)) {
1185       auto expr = primary_expression();
1186       if (expr.errored) {
1187         return Failure::kErrored;
1188       } else if (!expr.matched) {
1189         return add_error(peek(), "expected array size expression");
1190       }
1191 
1192       size = std::move(expr.value);
1193     }
1194 
1195     return type.value;
1196   });
1197 
1198   if (subtype.errored) {
1199     return Failure::kErrored;
1200   }
1201 
1202   return builder_.ty.array(make_source_range_from(t.source()), subtype.value,
1203                            size, std::move(decos));
1204 }
1205 
expect_type_decl_matrix(Token t)1206 Expect<const ast::Type*> ParserImpl::expect_type_decl_matrix(Token t) {
1207   uint32_t rows = 2;
1208   uint32_t columns = 2;
1209   if (t.IsMat3xN()) {
1210     columns = 3;
1211   } else if (t.IsMat4xN()) {
1212     columns = 4;
1213   }
1214   if (t.IsMatNx3()) {
1215     rows = 3;
1216   } else if (t.IsMatNx4()) {
1217     rows = 4;
1218   }
1219 
1220   const char* use = "matrix";
1221 
1222   auto subtype = expect_lt_gt_block(use, [&] { return expect_type(use); });
1223   if (subtype.errored)
1224     return Failure::kErrored;
1225 
1226   return builder_.ty.mat(make_source_range_from(t.source()), subtype.value,
1227                          columns, rows);
1228 }
1229 
1230 // storage_class
1231 //  : INPUT
1232 //  | OUTPUT
1233 //  | UNIFORM
1234 //  | WORKGROUP
1235 //  | STORAGE
1236 //  | IMAGE
1237 //  | PRIVATE
1238 //  | FUNCTION
expect_storage_class(const std::string & use)1239 Expect<ast::StorageClass> ParserImpl::expect_storage_class(
1240     const std::string& use) {
1241   auto source = peek().source();
1242 
1243   if (match(Token::Type::kUniform))
1244     return {ast::StorageClass::kUniform, source};
1245 
1246   if (match(Token::Type::kWorkgroup))
1247     return {ast::StorageClass::kWorkgroup, source};
1248 
1249   if (match(Token::Type::kStorage))
1250     return {ast::StorageClass::kStorage, source};
1251 
1252   if (match(Token::Type::kImage))
1253     return {ast::StorageClass::kImage, source};
1254 
1255   if (match(Token::Type::kPrivate))
1256     return {ast::StorageClass::kPrivate, source};
1257 
1258   if (match(Token::Type::kFunction))
1259     return {ast::StorageClass::kFunction, source};
1260 
1261   return add_error(source, "invalid storage class", use);
1262 }
1263 
1264 // struct_decl
1265 //   : struct_decoration_decl* STRUCT IDENT struct_body_decl
struct_decl(ast::DecorationList & decos)1266 Maybe<const ast::Struct*> ParserImpl::struct_decl(ast::DecorationList& decos) {
1267   auto t = peek();
1268   auto source = t.source();
1269 
1270   if (!match(Token::Type::kStruct))
1271     return Failure::kNoMatch;
1272 
1273   auto name = expect_ident("struct declaration");
1274   if (name.errored)
1275     return Failure::kErrored;
1276 
1277   auto body = expect_struct_body_decl();
1278   if (body.errored)
1279     return Failure::kErrored;
1280 
1281   auto sym = builder_.Symbols().Register(name.value);
1282   return create<ast::Struct>(source, sym, std::move(body.value),
1283                              std::move(decos));
1284 }
1285 
1286 // struct_body_decl
1287 //   : BRACKET_LEFT struct_member* BRACKET_RIGHT
expect_struct_body_decl()1288 Expect<ast::StructMemberList> ParserImpl::expect_struct_body_decl() {
1289   return expect_brace_block(
1290       "struct declaration", [&]() -> Expect<ast::StructMemberList> {
1291         bool errored = false;
1292 
1293         ast::StructMemberList members;
1294 
1295         while (continue_parsing() && !peek_is(Token::Type::kBraceRight) &&
1296                !peek_is(Token::Type::kEOF)) {
1297           auto member = sync(Token::Type::kSemicolon,
1298                              [&]() -> Expect<ast::StructMember*> {
1299                                auto decos = decoration_list();
1300                                if (decos.errored) {
1301                                  errored = true;
1302                                }
1303                                if (!synchronized_) {
1304                                  return Failure::kErrored;
1305                                }
1306                                return expect_struct_member(decos.value);
1307                              });
1308 
1309           if (member.errored) {
1310             errored = true;
1311           } else {
1312             members.push_back(member.value);
1313           }
1314         }
1315 
1316         if (errored)
1317           return Failure::kErrored;
1318 
1319         return members;
1320       });
1321 }
1322 
1323 // struct_member
1324 //   : struct_member_decoration_decl+ variable_ident_decl SEMICOLON
expect_struct_member(ast::DecorationList & decos)1325 Expect<ast::StructMember*> ParserImpl::expect_struct_member(
1326     ast::DecorationList& decos) {
1327   auto decl = expect_variable_ident_decl("struct member");
1328   if (decl.errored)
1329     return Failure::kErrored;
1330 
1331   if (!expect("struct member", Token::Type::kSemicolon))
1332     return Failure::kErrored;
1333 
1334   return create<ast::StructMember>(decl->source,
1335                                    builder_.Symbols().Register(decl->name),
1336                                    decl->type, std::move(decos));
1337 }
1338 
1339 // function_decl
1340 //   : function_header body_stmt
function_decl(ast::DecorationList & decos)1341 Maybe<const ast::Function*> ParserImpl::function_decl(
1342     ast::DecorationList& decos) {
1343   auto header = function_header();
1344   if (header.errored) {
1345     if (sync_to(Token::Type::kBraceLeft, /* consume: */ false)) {
1346       // There were errors in the function header, but the parser has managed to
1347       // resynchronize with the opening brace. As there's no outer
1348       // synchronization token for function declarations, attempt to parse the
1349       // function body. The AST isn't used as we've already errored, but this
1350       // catches any errors inside the body, and can help keep the parser in
1351       // sync.
1352       expect_body_stmt();
1353     }
1354     return Failure::kErrored;
1355   }
1356   if (!header.matched)
1357     return Failure::kNoMatch;
1358 
1359   bool errored = false;
1360 
1361   auto body = expect_body_stmt();
1362   if (body.errored)
1363     errored = true;
1364 
1365   if (errored)
1366     return Failure::kErrored;
1367 
1368   return create<ast::Function>(
1369       header->source, builder_.Symbols().Register(header->name), header->params,
1370       header->return_type, body.value, decos, header->return_type_decorations);
1371 }
1372 
1373 // function_header
1374 //   : FN IDENT PAREN_LEFT param_list PAREN_RIGHT return_type_decl_optional
1375 // return_type_decl_optional
1376 //   :
1377 //   | ARROW attribute_list* type_decl
function_header()1378 Maybe<ParserImpl::FunctionHeader> ParserImpl::function_header() {
1379   Source source;
1380   if (!match(Token::Type::kFn, &source)) {
1381     return Failure::kNoMatch;
1382   }
1383 
1384   const char* use = "function declaration";
1385   bool errored = false;
1386 
1387   auto name = expect_ident(use);
1388   if (name.errored) {
1389     errored = true;
1390     if (!sync_to(Token::Type::kParenLeft, /* consume: */ false)) {
1391       return Failure::kErrored;
1392     }
1393   }
1394 
1395   auto params = expect_paren_block(use, [&] { return expect_param_list(); });
1396   if (params.errored) {
1397     errored = true;
1398     if (!synchronized_) {
1399       return Failure::kErrored;
1400     }
1401   }
1402 
1403   const ast::Type* return_type = nullptr;
1404   ast::DecorationList return_decorations;
1405 
1406   if (match(Token::Type::kArrow)) {
1407     auto decos = decoration_list();
1408     if (decos.errored) {
1409       return Failure::kErrored;
1410     }
1411     return_decorations = decos.value;
1412 
1413     // Apply stride decorations to the type node instead of the function.
1414     ast::DecorationList type_decorations;
1415     auto itr = std::find_if(
1416         return_decorations.begin(), return_decorations.end(),
1417         [](auto* deco) { return Is<ast::StrideDecoration>(deco); });
1418     if (itr != return_decorations.end()) {
1419       type_decorations.emplace_back(*itr);
1420       return_decorations.erase(itr);
1421     }
1422 
1423     auto tok = peek();
1424 
1425     auto type = type_decl(type_decorations);
1426     if (type.errored) {
1427       errored = true;
1428     } else if (!type.matched) {
1429       return add_error(peek(), "unable to determine function return type");
1430     } else {
1431       return_type = type.value;
1432     }
1433   } else {
1434     return_type = builder_.ty.void_();
1435   }
1436 
1437   if (errored) {
1438     return Failure::kErrored;
1439   }
1440 
1441   return FunctionHeader{source, name.value, std::move(params.value),
1442                         return_type, std::move(return_decorations)};
1443 }
1444 
1445 // param_list
1446 //   :
1447 //   | (param COMMA)* param COMMA?
expect_param_list()1448 Expect<ast::VariableList> ParserImpl::expect_param_list() {
1449   ast::VariableList ret;
1450   while (continue_parsing()) {
1451     // Check for the end of the list.
1452     auto t = peek();
1453     if (!t.IsIdentifier() && !t.Is(Token::Type::kAttrLeft)) {
1454       break;
1455     }
1456 
1457     auto param = expect_param();
1458     if (param.errored)
1459       return Failure::kErrored;
1460     ret.push_back(param.value);
1461 
1462     if (!match(Token::Type::kComma))
1463       break;
1464   }
1465 
1466   return ret;
1467 }
1468 
1469 // param
1470 //   : decoration_list* variable_ident_decl
expect_param()1471 Expect<ast::Variable*> ParserImpl::expect_param() {
1472   auto decos = decoration_list();
1473 
1474   auto decl = expect_variable_ident_decl("parameter");
1475   if (decl.errored)
1476     return Failure::kErrored;
1477 
1478   auto* var =
1479       create<ast::Variable>(decl->source,                             // source
1480                             builder_.Symbols().Register(decl->name),  // symbol
1481                             ast::StorageClass::kNone,  // storage class
1482                             ast::Access::kUndefined,   // access control
1483                             decl->type,                // type
1484                             true,                      // is_const
1485                             nullptr,                   // constructor
1486                             std::move(decos.value));   // decorations
1487   // Formal parameters are treated like a const declaration where the
1488   // initializer value is provided by the call's argument.  The key point is
1489   // that it's not updatable after initially set.  This is unlike C or GLSL
1490   // which treat formal parameters like local variables that can be updated.
1491 
1492   return var;
1493 }
1494 
1495 // pipeline_stage
1496 //   : VERTEX
1497 //   | FRAGMENT
1498 //   | COMPUTE
expect_pipeline_stage()1499 Expect<ast::PipelineStage> ParserImpl::expect_pipeline_stage() {
1500   auto t = peek();
1501   if (!t.IsIdentifier()) {
1502     return add_error(t, "invalid value for stage decoration");
1503   }
1504 
1505   auto s = t.to_str();
1506   if (s == kVertexStage) {
1507     next();  // Consume the peek
1508     return {ast::PipelineStage::kVertex, t.source()};
1509   }
1510   if (s == kFragmentStage) {
1511     next();  // Consume the peek
1512     return {ast::PipelineStage::kFragment, t.source()};
1513   }
1514   if (s == kComputeStage) {
1515     next();  // Consume the peek
1516     return {ast::PipelineStage::kCompute, t.source()};
1517   }
1518 
1519   return add_error(peek(), "invalid value for stage decoration");
1520 }
1521 
expect_builtin()1522 Expect<ast::Builtin> ParserImpl::expect_builtin() {
1523   auto ident = expect_ident("builtin");
1524   if (ident.errored)
1525     return Failure::kErrored;
1526 
1527   ast::Builtin builtin = ident_to_builtin(ident.value);
1528   if (builtin == ast::Builtin::kNone)
1529     return add_error(ident.source, "invalid value for builtin decoration");
1530 
1531   return {builtin, ident.source};
1532 }
1533 
1534 // body_stmt
1535 //   : BRACKET_LEFT statements BRACKET_RIGHT
expect_body_stmt()1536 Expect<ast::BlockStatement*> ParserImpl::expect_body_stmt() {
1537   return expect_brace_block("", [&]() -> Expect<ast::BlockStatement*> {
1538     auto stmts = expect_statements();
1539     if (stmts.errored)
1540       return Failure::kErrored;
1541     return create<ast::BlockStatement>(Source{}, stmts.value);
1542   });
1543 }
1544 
1545 // paren_rhs_stmt
1546 //   : PAREN_LEFT logical_or_expression PAREN_RIGHT
expect_paren_rhs_stmt()1547 Expect<const ast::Expression*> ParserImpl::expect_paren_rhs_stmt() {
1548   return expect_paren_block("", [&]() -> Expect<const ast::Expression*> {
1549     auto expr = logical_or_expression();
1550     if (expr.errored)
1551       return Failure::kErrored;
1552     if (!expr.matched)
1553       return add_error(peek(), "unable to parse expression");
1554 
1555     return expr.value;
1556   });
1557 }
1558 
1559 // statements
1560 //   : statement*
expect_statements()1561 Expect<ast::StatementList> ParserImpl::expect_statements() {
1562   bool errored = false;
1563   ast::StatementList stmts;
1564 
1565   while (continue_parsing()) {
1566     auto stmt = statement();
1567     if (stmt.errored) {
1568       errored = true;
1569     } else if (stmt.matched) {
1570       stmts.emplace_back(stmt.value);
1571     } else {
1572       break;
1573     }
1574   }
1575 
1576   if (errored)
1577     return Failure::kErrored;
1578 
1579   return stmts;
1580 }
1581 
1582 // statement
1583 //   : SEMICOLON
1584 //   | body_stmt?
1585 //   | if_stmt
1586 //   | switch_stmt
1587 //   | loop_stmt
1588 //   | for_stmt
1589 //   | non_block_statement
1590 //      : return_stmt SEMICOLON
1591 //      | func_call_stmt SEMICOLON
1592 //      | variable_stmt SEMICOLON
1593 //      | break_stmt SEMICOLON
1594 //      | continue_stmt SEMICOLON
1595 //      | DISCARD SEMICOLON
1596 //      | assignment_stmt SEMICOLON
statement()1597 Maybe<const ast::Statement*> ParserImpl::statement() {
1598   while (match(Token::Type::kSemicolon)) {
1599     // Skip empty statements
1600   }
1601 
1602   // Non-block statments that error can resynchronize on semicolon.
1603   auto stmt =
1604       sync(Token::Type::kSemicolon, [&] { return non_block_statement(); });
1605 
1606   if (stmt.errored)
1607     return Failure::kErrored;
1608   if (stmt.matched)
1609     return stmt;
1610 
1611   auto stmt_if = if_stmt();
1612   if (stmt_if.errored)
1613     return Failure::kErrored;
1614   if (stmt_if.matched)
1615     return stmt_if.value;
1616 
1617   auto sw = switch_stmt();
1618   if (sw.errored)
1619     return Failure::kErrored;
1620   if (sw.matched)
1621     return sw.value;
1622 
1623   auto loop = loop_stmt();
1624   if (loop.errored)
1625     return Failure::kErrored;
1626   if (loop.matched)
1627     return loop.value;
1628 
1629   auto stmt_for = for_stmt();
1630   if (stmt_for.errored)
1631     return Failure::kErrored;
1632   if (stmt_for.matched)
1633     return stmt_for.value;
1634 
1635   if (peek_is(Token::Type::kBraceLeft)) {
1636     auto body = expect_body_stmt();
1637     if (body.errored)
1638       return Failure::kErrored;
1639     return body.value;
1640   }
1641 
1642   return Failure::kNoMatch;
1643 }
1644 
1645 // statement (continued)
1646 //   : return_stmt SEMICOLON
1647 //   | func_call_stmt SEMICOLON
1648 //   | variable_stmt SEMICOLON
1649 //   | break_stmt SEMICOLON
1650 //   | continue_stmt SEMICOLON
1651 //   | DISCARD SEMICOLON
1652 //   | assignment_stmt SEMICOLON
non_block_statement()1653 Maybe<const ast::Statement*> ParserImpl::non_block_statement() {
1654   auto stmt = [&]() -> Maybe<const ast::Statement*> {
1655     auto ret_stmt = return_stmt();
1656     if (ret_stmt.errored)
1657       return Failure::kErrored;
1658     if (ret_stmt.matched)
1659       return ret_stmt.value;
1660 
1661     auto func = func_call_stmt();
1662     if (func.errored)
1663       return Failure::kErrored;
1664     if (func.matched)
1665       return func.value;
1666 
1667     auto var = variable_stmt();
1668     if (var.errored)
1669       return Failure::kErrored;
1670     if (var.matched)
1671       return var.value;
1672 
1673     auto b = break_stmt();
1674     if (b.errored)
1675       return Failure::kErrored;
1676     if (b.matched)
1677       return b.value;
1678 
1679     auto cont = continue_stmt();
1680     if (cont.errored)
1681       return Failure::kErrored;
1682     if (cont.matched)
1683       return cont.value;
1684 
1685     auto assign = assignment_stmt();
1686     if (assign.errored)
1687       return Failure::kErrored;
1688     if (assign.matched)
1689       return assign.value;
1690 
1691     Source source;
1692     if (match(Token::Type::kDiscard, &source))
1693       return create<ast::DiscardStatement>(source);
1694 
1695     return Failure::kNoMatch;
1696   }();
1697 
1698   if (stmt.matched && !expect(stmt->Name(), Token::Type::kSemicolon))
1699     return Failure::kErrored;
1700 
1701   return stmt;
1702 }
1703 
1704 // return_stmt
1705 //   : RETURN logical_or_expression?
return_stmt()1706 Maybe<const ast::ReturnStatement*> ParserImpl::return_stmt() {
1707   Source source;
1708   if (!match(Token::Type::kReturn, &source))
1709     return Failure::kNoMatch;
1710 
1711   if (peek_is(Token::Type::kSemicolon))
1712     return create<ast::ReturnStatement>(source, nullptr);
1713 
1714   auto expr = logical_or_expression();
1715   if (expr.errored)
1716     return Failure::kErrored;
1717 
1718   // TODO(bclayton): Check matched?
1719   return create<ast::ReturnStatement>(source, expr.value);
1720 }
1721 
1722 // variable_stmt
1723 //   : variable_decl
1724 //   | variable_decl EQUAL logical_or_expression
1725 //   | CONST variable_ident_decl EQUAL logical_or_expression
variable_stmt()1726 Maybe<const ast::VariableDeclStatement*> ParserImpl::variable_stmt() {
1727   if (match(Token::Type::kLet)) {
1728     auto decl = expect_variable_ident_decl("let declaration",
1729                                            /*allow_inferred = */ true);
1730     if (decl.errored)
1731       return Failure::kErrored;
1732 
1733     if (!expect("let declaration", Token::Type::kEqual))
1734       return Failure::kErrored;
1735 
1736     auto constructor = logical_or_expression();
1737     if (constructor.errored)
1738       return Failure::kErrored;
1739     if (!constructor.matched)
1740       return add_error(peek(), "missing constructor for let declaration");
1741 
1742     auto* var = create<ast::Variable>(
1743         decl->source,                             // source
1744         builder_.Symbols().Register(decl->name),  // symbol
1745         ast::StorageClass::kNone,                 // storage class
1746         ast::Access::kUndefined,                  // access control
1747         decl->type,                               // type
1748         true,                                     // is_const
1749         constructor.value,                        // constructor
1750         ast::DecorationList{});                   // decorations
1751 
1752     return create<ast::VariableDeclStatement>(decl->source, var);
1753   }
1754 
1755   auto decl = variable_decl(/*allow_inferred = */ true);
1756   if (decl.errored)
1757     return Failure::kErrored;
1758   if (!decl.matched)
1759     return Failure::kNoMatch;
1760 
1761   const ast::Expression* constructor = nullptr;
1762   if (match(Token::Type::kEqual)) {
1763     auto constructor_expr = logical_or_expression();
1764     if (constructor_expr.errored)
1765       return Failure::kErrored;
1766     if (!constructor_expr.matched)
1767       return add_error(peek(), "missing constructor for variable declaration");
1768 
1769     constructor = constructor_expr.value;
1770   }
1771 
1772   auto* var =
1773       create<ast::Variable>(decl->source,                             // source
1774                             builder_.Symbols().Register(decl->name),  // symbol
1775                             decl->storage_class,     // storage class
1776                             decl->access,            // access control
1777                             decl->type,              // type
1778                             false,                   // is_const
1779                             constructor,             // constructor
1780                             ast::DecorationList{});  // decorations
1781 
1782   return create<ast::VariableDeclStatement>(var->source, var);
1783 }
1784 
1785 // if_stmt
1786 //   : IF paren_rhs_stmt body_stmt elseif_stmt? else_stmt?
if_stmt()1787 Maybe<const ast::IfStatement*> ParserImpl::if_stmt() {
1788   Source source;
1789   if (!match(Token::Type::kIf, &source))
1790     return Failure::kNoMatch;
1791 
1792   auto condition = expect_paren_rhs_stmt();
1793   if (condition.errored)
1794     return Failure::kErrored;
1795 
1796   auto body = expect_body_stmt();
1797   if (body.errored)
1798     return Failure::kErrored;
1799 
1800   auto elseif = elseif_stmt();
1801   if (elseif.errored)
1802     return Failure::kErrored;
1803 
1804   auto el = else_stmt();
1805   if (el.errored)
1806     return Failure::kErrored;
1807   if (el.matched)
1808     elseif.value.push_back(el.value);
1809 
1810   return create<ast::IfStatement>(source, condition.value, body.value,
1811                                   elseif.value);
1812 }
1813 
1814 // elseif_stmt
1815 //   : ELSE_IF paren_rhs_stmt body_stmt elseif_stmt?
elseif_stmt()1816 Maybe<ast::ElseStatementList> ParserImpl::elseif_stmt() {
1817   Source source;
1818   if (!match(Token::Type::kElseIf, &source))
1819     return Failure::kNoMatch;
1820 
1821   ast::ElseStatementList ret;
1822   while (continue_parsing()) {
1823     auto condition = expect_paren_rhs_stmt();
1824     if (condition.errored)
1825       return Failure::kErrored;
1826 
1827     auto body = expect_body_stmt();
1828     if (body.errored)
1829       return Failure::kErrored;
1830 
1831     ret.push_back(
1832         create<ast::ElseStatement>(source, condition.value, body.value));
1833 
1834     if (!match(Token::Type::kElseIf, &source))
1835       break;
1836   }
1837 
1838   return ret;
1839 }
1840 
1841 // else_stmt
1842 //   : ELSE body_stmt
else_stmt()1843 Maybe<const ast::ElseStatement*> ParserImpl::else_stmt() {
1844   Source source;
1845   if (!match(Token::Type::kElse, &source))
1846     return Failure::kNoMatch;
1847 
1848   auto body = expect_body_stmt();
1849   if (body.errored)
1850     return Failure::kErrored;
1851 
1852   return create<ast::ElseStatement>(source, nullptr, body.value);
1853 }
1854 
1855 // switch_stmt
1856 //   : SWITCH paren_rhs_stmt BRACKET_LEFT switch_body+ BRACKET_RIGHT
switch_stmt()1857 Maybe<const ast::SwitchStatement*> ParserImpl::switch_stmt() {
1858   Source source;
1859   if (!match(Token::Type::kSwitch, &source))
1860     return Failure::kNoMatch;
1861 
1862   auto condition = expect_paren_rhs_stmt();
1863   if (condition.errored)
1864     return Failure::kErrored;
1865 
1866   auto body = expect_brace_block("switch statement",
1867                                  [&]() -> Expect<ast::CaseStatementList> {
1868                                    bool errored = false;
1869                                    ast::CaseStatementList list;
1870                                    while (continue_parsing()) {
1871                                      auto stmt = switch_body();
1872                                      if (stmt.errored) {
1873                                        errored = true;
1874                                        continue;
1875                                      }
1876                                      if (!stmt.matched)
1877                                        break;
1878                                      list.push_back(stmt.value);
1879                                    }
1880                                    if (errored)
1881                                      return Failure::kErrored;
1882                                    return list;
1883                                  });
1884 
1885   if (body.errored)
1886     return Failure::kErrored;
1887 
1888   return create<ast::SwitchStatement>(source, condition.value, body.value);
1889 }
1890 
1891 // switch_body
1892 //   : CASE case_selectors COLON BRACKET_LEFT case_body BRACKET_RIGHT
1893 //   | DEFAULT COLON BRACKET_LEFT case_body BRACKET_RIGHT
switch_body()1894 Maybe<const ast::CaseStatement*> ParserImpl::switch_body() {
1895   if (!peek_is(Token::Type::kCase) && !peek_is(Token::Type::kDefault))
1896     return Failure::kNoMatch;
1897 
1898   auto t = next();
1899   auto source = t.source();
1900 
1901   ast::CaseSelectorList selector_list;
1902   if (t.Is(Token::Type::kCase)) {
1903     auto selectors = expect_case_selectors();
1904     if (selectors.errored)
1905       return Failure::kErrored;
1906 
1907     selector_list = std::move(selectors.value);
1908   }
1909 
1910   const char* use = "case statement";
1911 
1912   if (!expect(use, Token::Type::kColon))
1913     return Failure::kErrored;
1914 
1915   auto body = expect_brace_block(use, [&] { return case_body(); });
1916 
1917   if (body.errored)
1918     return Failure::kErrored;
1919   if (!body.matched)
1920     return add_error(body.source, "expected case body");
1921 
1922   return create<ast::CaseStatement>(source, selector_list, body.value);
1923 }
1924 
1925 // case_selectors
1926 //   : const_literal (COMMA const_literal)* COMMA?
expect_case_selectors()1927 Expect<ast::CaseSelectorList> ParserImpl::expect_case_selectors() {
1928   ast::CaseSelectorList selectors;
1929 
1930   while (continue_parsing()) {
1931     auto cond = const_literal();
1932     if (cond.errored) {
1933       return Failure::kErrored;
1934     } else if (!cond.matched) {
1935       break;
1936     } else if (!cond->Is<ast::IntLiteralExpression>()) {
1937       return add_error(cond.value->source,
1938                        "invalid case selector must be an integer value");
1939     }
1940 
1941     selectors.push_back(cond.value->As<ast::IntLiteralExpression>());
1942 
1943     if (!match(Token::Type::kComma)) {
1944       break;
1945     }
1946   }
1947 
1948   if (selectors.empty())
1949     return add_error(peek(), "unable to parse case selectors");
1950 
1951   return selectors;
1952 }
1953 
1954 // case_body
1955 //   :
1956 //   | statement case_body
1957 //   | FALLTHROUGH SEMICOLON
case_body()1958 Maybe<const ast::BlockStatement*> ParserImpl::case_body() {
1959   ast::StatementList stmts;
1960   while (continue_parsing()) {
1961     Source source;
1962     if (match(Token::Type::kFallthrough, &source)) {
1963       if (!expect("fallthrough statement", Token::Type::kSemicolon))
1964         return Failure::kErrored;
1965 
1966       stmts.emplace_back(create<ast::FallthroughStatement>(source));
1967       break;
1968     }
1969 
1970     auto stmt = statement();
1971     if (stmt.errored)
1972       return Failure::kErrored;
1973     if (!stmt.matched)
1974       break;
1975 
1976     stmts.emplace_back(stmt.value);
1977   }
1978 
1979   return create<ast::BlockStatement>(Source{}, stmts);
1980 }
1981 
1982 // loop_stmt
1983 //   : LOOP BRACKET_LEFT statements continuing_stmt? BRACKET_RIGHT
loop_stmt()1984 Maybe<const ast::LoopStatement*> ParserImpl::loop_stmt() {
1985   Source source;
1986   if (!match(Token::Type::kLoop, &source))
1987     return Failure::kNoMatch;
1988 
1989   return expect_brace_block("loop", [&]() -> Maybe<const ast::LoopStatement*> {
1990     auto stmts = expect_statements();
1991     if (stmts.errored)
1992       return Failure::kErrored;
1993 
1994     auto continuing = continuing_stmt();
1995     if (continuing.errored)
1996       return Failure::kErrored;
1997 
1998     auto* body = create<ast::BlockStatement>(source, stmts.value);
1999     return create<ast::LoopStatement>(source, body, continuing.value);
2000   });
2001 }
2002 
ForHeader(const ast::Statement * init,const ast::Expression * cond,const ast::Statement * cont)2003 ForHeader::ForHeader(const ast::Statement* init,
2004                      const ast::Expression* cond,
2005                      const ast::Statement* cont)
2006     : initializer(init), condition(cond), continuing(cont) {}
2007 
2008 ForHeader::~ForHeader() = default;
2009 
2010 // (variable_stmt | assignment_stmt | func_call_stmt)?
for_header_initializer()2011 Maybe<const ast::Statement*> ParserImpl::for_header_initializer() {
2012   auto call = func_call_stmt();
2013   if (call.errored)
2014     return Failure::kErrored;
2015   if (call.matched)
2016     return call.value;
2017 
2018   auto var = variable_stmt();
2019   if (var.errored)
2020     return Failure::kErrored;
2021   if (var.matched)
2022     return var.value;
2023 
2024   auto assign = assignment_stmt();
2025   if (assign.errored)
2026     return Failure::kErrored;
2027   if (assign.matched)
2028     return assign.value;
2029 
2030   return Failure::kNoMatch;
2031 }
2032 
2033 // (assignment_stmt | func_call_stmt)?
for_header_continuing()2034 Maybe<const ast::Statement*> ParserImpl::for_header_continuing() {
2035   auto call_stmt = func_call_stmt();
2036   if (call_stmt.errored)
2037     return Failure::kErrored;
2038   if (call_stmt.matched)
2039     return call_stmt.value;
2040 
2041   auto assign = assignment_stmt();
2042   if (assign.errored)
2043     return Failure::kErrored;
2044   if (assign.matched)
2045     return assign.value;
2046 
2047   return Failure::kNoMatch;
2048 }
2049 
2050 // for_header
2051 //   : (variable_stmt | assignment_stmt | func_call_stmt)?
2052 //   SEMICOLON
2053 //      logical_or_expression? SEMICOLON
2054 //      (assignment_stmt | func_call_stmt)?
expect_for_header()2055 Expect<std::unique_ptr<ForHeader>> ParserImpl::expect_for_header() {
2056   auto initializer = for_header_initializer();
2057   if (initializer.errored)
2058     return Failure::kErrored;
2059 
2060   if (!expect("initializer in for loop", Token::Type::kSemicolon))
2061     return Failure::kErrored;
2062 
2063   auto condition = logical_or_expression();
2064   if (condition.errored)
2065     return Failure::kErrored;
2066 
2067   if (!expect("condition in for loop", Token::Type::kSemicolon))
2068     return Failure::kErrored;
2069 
2070   auto continuing = for_header_continuing();
2071   if (continuing.errored)
2072     return Failure::kErrored;
2073 
2074   return std::make_unique<ForHeader>(initializer.value, condition.value,
2075                                      continuing.value);
2076 }
2077 
2078 // for_statement
2079 //   : FOR PAREN_LEFT for_header PAREN_RIGHT BRACE_LEFT statements BRACE_RIGHT
for_stmt()2080 Maybe<const ast::ForLoopStatement*> ParserImpl::for_stmt() {
2081   Source source;
2082   if (!match(Token::Type::kFor, &source))
2083     return Failure::kNoMatch;
2084 
2085   auto header =
2086       expect_paren_block("for loop", [&] { return expect_for_header(); });
2087   if (header.errored)
2088     return Failure::kErrored;
2089 
2090   auto stmts =
2091       expect_brace_block("for loop", [&] { return expect_statements(); });
2092   if (stmts.errored)
2093     return Failure::kErrored;
2094 
2095   return create<ast::ForLoopStatement>(
2096       source, header->initializer, header->condition, header->continuing,
2097       create<ast::BlockStatement>(stmts.value));
2098 }
2099 
2100 // func_call_stmt
2101 //    : IDENT argument_expression_list
func_call_stmt()2102 Maybe<const ast::CallStatement*> ParserImpl::func_call_stmt() {
2103   auto t = peek();
2104   auto t2 = peek(1);
2105   if (!t.IsIdentifier() || !t2.Is(Token::Type::kParenLeft))
2106     return Failure::kNoMatch;
2107 
2108   next();  // Consume the first peek
2109 
2110   auto source = t.source();
2111   auto name = t.to_str();
2112 
2113   auto params = expect_argument_expression_list("function call");
2114   if (params.errored)
2115     return Failure::kErrored;
2116 
2117   return create<ast::CallStatement>(
2118       source, create<ast::CallExpression>(
2119                   source,
2120                   create<ast::IdentifierExpression>(
2121                       source, builder_.Symbols().Register(name)),
2122                   std::move(params.value)));
2123 }
2124 
2125 // break_stmt
2126 //   : BREAK
break_stmt()2127 Maybe<const ast::BreakStatement*> ParserImpl::break_stmt() {
2128   Source source;
2129   if (!match(Token::Type::kBreak, &source))
2130     return Failure::kNoMatch;
2131 
2132   return create<ast::BreakStatement>(source);
2133 }
2134 
2135 // continue_stmt
2136 //   : CONTINUE
continue_stmt()2137 Maybe<const ast::ContinueStatement*> ParserImpl::continue_stmt() {
2138   Source source;
2139   if (!match(Token::Type::kContinue, &source))
2140     return Failure::kNoMatch;
2141 
2142   return create<ast::ContinueStatement>(source);
2143 }
2144 
2145 // continuing_stmt
2146 //   : CONTINUING body_stmt
continuing_stmt()2147 Maybe<const ast::BlockStatement*> ParserImpl::continuing_stmt() {
2148   if (!match(Token::Type::kContinuing))
2149     return create<ast::BlockStatement>(Source{}, ast::StatementList{});
2150 
2151   return expect_body_stmt();
2152 }
2153 
2154 // primary_expression
2155 //   : IDENT argument_expression_list?
2156 //   | type_decl argument_expression_list
2157 //   | const_literal
2158 //   | paren_rhs_stmt
2159 //   | BITCAST LESS_THAN type_decl GREATER_THAN paren_rhs_stmt
primary_expression()2160 Maybe<const ast::Expression*> ParserImpl::primary_expression() {
2161   auto t = peek();
2162   auto source = t.source();
2163 
2164   auto lit = const_literal();
2165   if (lit.errored) {
2166     return Failure::kErrored;
2167   }
2168   if (lit.matched) {
2169     return lit.value;
2170   }
2171 
2172   if (t.Is(Token::Type::kParenLeft)) {
2173     auto paren = expect_paren_rhs_stmt();
2174     if (paren.errored) {
2175       return Failure::kErrored;
2176     }
2177 
2178     return paren.value;
2179   }
2180 
2181   if (match(Token::Type::kBitcast)) {
2182     const char* use = "bitcast expression";
2183 
2184     auto type = expect_lt_gt_block(use, [&] { return expect_type(use); });
2185     if (type.errored)
2186       return Failure::kErrored;
2187 
2188     auto params = expect_paren_rhs_stmt();
2189     if (params.errored)
2190       return Failure::kErrored;
2191 
2192     return create<ast::BitcastExpression>(source, type.value, params.value);
2193   }
2194 
2195   if (t.IsIdentifier()) {
2196     next();
2197 
2198     auto* ident = create<ast::IdentifierExpression>(
2199         t.source(), builder_.Symbols().Register(t.to_str()));
2200 
2201     if (peek_is(Token::Type::kParenLeft)) {
2202       auto params = expect_argument_expression_list("function call");
2203       if (params.errored)
2204         return Failure::kErrored;
2205 
2206       return create<ast::CallExpression>(source, ident,
2207                                          std::move(params.value));
2208     }
2209 
2210     return ident;
2211   }
2212 
2213   auto type = type_decl();
2214   if (type.errored)
2215     return Failure::kErrored;
2216   if (type.matched) {
2217     auto params = expect_argument_expression_list("type constructor");
2218     if (params.errored)
2219       return Failure::kErrored;
2220 
2221     return builder_.Construct(source, type.value, std::move(params.value));
2222   }
2223 
2224   return Failure::kNoMatch;
2225 }
2226 
2227 // postfix_expression
2228 //   :
2229 //   | BRACE_LEFT logical_or_expression BRACE_RIGHT postfix_expr
2230 //   | PERIOD IDENTIFIER postfix_expr
postfix_expression(const ast::Expression * prefix)2231 Maybe<const ast::Expression*> ParserImpl::postfix_expression(
2232     const ast::Expression* prefix) {
2233   Source source;
2234 
2235   while (continue_parsing()) {
2236     if (match(Token::Type::kPlusPlus, &source) ||
2237         match(Token::Type::kMinusMinus, &source)) {
2238       add_error(source,
2239                 "postfix increment and decrement operators are reserved for a "
2240                 "future WGSL version");
2241       return Failure::kErrored;
2242     }
2243 
2244     if (match(Token::Type::kBracketLeft, &source)) {
2245       auto res = sync(
2246           Token::Type::kBracketRight, [&]() -> Maybe<const ast::Expression*> {
2247             auto param = logical_or_expression();
2248             if (param.errored)
2249               return Failure::kErrored;
2250             if (!param.matched) {
2251               return add_error(peek(), "unable to parse expression inside []");
2252             }
2253 
2254             if (!expect("index accessor", Token::Type::kBracketRight)) {
2255               return Failure::kErrored;
2256             }
2257 
2258             return create<ast::IndexAccessorExpression>(source, prefix,
2259                                                         param.value);
2260           });
2261 
2262       if (res.errored) {
2263         return res;
2264       }
2265       prefix = res.value;
2266       continue;
2267     }
2268 
2269     if (match(Token::Type::kPeriod)) {
2270       auto ident = expect_ident("member accessor");
2271       if (ident.errored) {
2272         return Failure::kErrored;
2273       }
2274 
2275       prefix = create<ast::MemberAccessorExpression>(
2276           ident.source, prefix,
2277           create<ast::IdentifierExpression>(
2278               ident.source, builder_.Symbols().Register(ident.value)));
2279       continue;
2280     }
2281 
2282     return prefix;
2283   }
2284 
2285   return Failure::kErrored;
2286 }
2287 
2288 // singular_expression
2289 //   : primary_expression postfix_expr
singular_expression()2290 Maybe<const ast::Expression*> ParserImpl::singular_expression() {
2291   auto prefix = primary_expression();
2292   if (prefix.errored)
2293     return Failure::kErrored;
2294   if (!prefix.matched)
2295     return Failure::kNoMatch;
2296 
2297   return postfix_expression(prefix.value);
2298 }
2299 
2300 // argument_expression_list
2301 //   : PAREN_LEFT ((logical_or_expression COMMA)* logical_or_expression COMMA?)?
2302 //   PAREN_RIGHT
expect_argument_expression_list(const std::string & use)2303 Expect<ast::ExpressionList> ParserImpl::expect_argument_expression_list(
2304     const std::string& use) {
2305   return expect_paren_block(use, [&]() -> Expect<ast::ExpressionList> {
2306     ast::ExpressionList ret;
2307     while (continue_parsing()) {
2308       auto arg = logical_or_expression();
2309       if (arg.errored) {
2310         return Failure::kErrored;
2311       } else if (!arg.matched) {
2312         break;
2313       }
2314       ret.push_back(arg.value);
2315 
2316       if (!match(Token::Type::kComma)) {
2317         break;
2318       }
2319     }
2320     return ret;
2321   });
2322 }
2323 
2324 // unary_expression
2325 //   : singular_expression
2326 //   | MINUS unary_expression
2327 //   | BANG unary_expression
2328 //   | TILDE unary_expression
2329 //   | STAR unary_expression
2330 //   | AND unary_expression
unary_expression()2331 Maybe<const ast::Expression*> ParserImpl::unary_expression() {
2332   auto t = peek();
2333 
2334   if (match(Token::Type::kPlusPlus) || match(Token::Type::kMinusMinus)) {
2335     add_error(t.source(),
2336               "prefix increment and decrement operators are reserved for a "
2337               "future WGSL version");
2338     return Failure::kErrored;
2339   }
2340 
2341   ast::UnaryOp op;
2342   if (match(Token::Type::kMinus)) {
2343     op = ast::UnaryOp::kNegation;
2344   } else if (match(Token::Type::kBang)) {
2345     op = ast::UnaryOp::kNot;
2346   } else if (match(Token::Type::kTilde)) {
2347     op = ast::UnaryOp::kComplement;
2348   } else if (match(Token::Type::kStar)) {
2349     op = ast::UnaryOp::kIndirection;
2350   } else if (match(Token::Type::kAnd)) {
2351     op = ast::UnaryOp::kAddressOf;
2352   } else {
2353     return singular_expression();
2354   }
2355 
2356   if (parse_depth_ >= kMaxParseDepth) {
2357     // We've hit a maximum parser recursive depth.
2358     // We can't call into unary_expression() as we might stack overflow.
2359     // Instead, report an error
2360     add_error(peek(), "maximum parser recursive depth reached");
2361     return Failure::kErrored;
2362   }
2363 
2364   ++parse_depth_;
2365   auto expr = unary_expression();
2366   --parse_depth_;
2367 
2368   if (expr.errored) {
2369     return Failure::kErrored;
2370   }
2371   if (!expr.matched) {
2372     return add_error(
2373         peek(), "unable to parse right side of " + t.to_name() + " expression");
2374   }
2375 
2376   return create<ast::UnaryOpExpression>(t.source(), op, expr.value);
2377 }
2378 
2379 // multiplicative_expr
2380 //   :
2381 //   | STAR unary_expression multiplicative_expr
2382 //   | FORWARD_SLASH unary_expression multiplicative_expr
2383 //   | MODULO unary_expression multiplicative_expr
expect_multiplicative_expr(const ast::Expression * lhs)2384 Expect<const ast::Expression*> ParserImpl::expect_multiplicative_expr(
2385     const ast::Expression* lhs) {
2386   while (continue_parsing()) {
2387     ast::BinaryOp op = ast::BinaryOp::kNone;
2388     if (peek_is(Token::Type::kStar))
2389       op = ast::BinaryOp::kMultiply;
2390     else if (peek_is(Token::Type::kForwardSlash))
2391       op = ast::BinaryOp::kDivide;
2392     else if (peek_is(Token::Type::kMod))
2393       op = ast::BinaryOp::kModulo;
2394     else
2395       return lhs;
2396 
2397     auto t = next();
2398     auto source = t.source();
2399     auto name = t.to_name();
2400 
2401     auto rhs = unary_expression();
2402     if (rhs.errored)
2403       return Failure::kErrored;
2404     if (!rhs.matched) {
2405       return add_error(peek(),
2406                        "unable to parse right side of " + name + " expression");
2407     }
2408 
2409     lhs = create<ast::BinaryExpression>(source, op, lhs, rhs.value);
2410   }
2411   return Failure::kErrored;
2412 }
2413 
2414 // multiplicative_expression
2415 //   : unary_expression multiplicative_expr
multiplicative_expression()2416 Maybe<const ast::Expression*> ParserImpl::multiplicative_expression() {
2417   auto lhs = unary_expression();
2418   if (lhs.errored)
2419     return Failure::kErrored;
2420   if (!lhs.matched)
2421     return Failure::kNoMatch;
2422 
2423   return expect_multiplicative_expr(lhs.value);
2424 }
2425 
2426 // additive_expr
2427 //   :
2428 //   | PLUS multiplicative_expression additive_expr
2429 //   | MINUS multiplicative_expression additive_expr
expect_additive_expr(const ast::Expression * lhs)2430 Expect<const ast::Expression*> ParserImpl::expect_additive_expr(
2431     const ast::Expression* lhs) {
2432   while (continue_parsing()) {
2433     ast::BinaryOp op = ast::BinaryOp::kNone;
2434     if (peek_is(Token::Type::kPlus))
2435       op = ast::BinaryOp::kAdd;
2436     else if (peek_is(Token::Type::kMinus))
2437       op = ast::BinaryOp::kSubtract;
2438     else
2439       return lhs;
2440 
2441     auto t = next();
2442     auto source = t.source();
2443 
2444     auto rhs = multiplicative_expression();
2445     if (rhs.errored)
2446       return Failure::kErrored;
2447     if (!rhs.matched)
2448       return add_error(peek(), "unable to parse right side of + expression");
2449 
2450     lhs = create<ast::BinaryExpression>(source, op, lhs, rhs.value);
2451   }
2452   return Failure::kErrored;
2453 }
2454 
2455 // additive_expression
2456 //   : multiplicative_expression additive_expr
additive_expression()2457 Maybe<const ast::Expression*> ParserImpl::additive_expression() {
2458   auto lhs = multiplicative_expression();
2459   if (lhs.errored)
2460     return Failure::kErrored;
2461   if (!lhs.matched)
2462     return Failure::kNoMatch;
2463 
2464   return expect_additive_expr(lhs.value);
2465 }
2466 
2467 // shift_expr
2468 //   :
2469 //   | SHIFT_LEFT additive_expression shift_expr
2470 //   | SHIFT_RIGHT additive_expression shift_expr
expect_shift_expr(const ast::Expression * lhs)2471 Expect<const ast::Expression*> ParserImpl::expect_shift_expr(
2472     const ast::Expression* lhs) {
2473   while (continue_parsing()) {
2474     auto* name = "";
2475     ast::BinaryOp op = ast::BinaryOp::kNone;
2476     if (peek_is(Token::Type::kShiftLeft)) {
2477       op = ast::BinaryOp::kShiftLeft;
2478       name = "<<";
2479     } else if (peek_is(Token::Type::kShiftRight)) {
2480       op = ast::BinaryOp::kShiftRight;
2481       name = ">>";
2482     } else {
2483       return lhs;
2484     }
2485 
2486     auto t = next();
2487     auto source = t.source();
2488     auto rhs = additive_expression();
2489     if (rhs.errored)
2490       return Failure::kErrored;
2491     if (!rhs.matched) {
2492       return add_error(peek(), std::string("unable to parse right side of ") +
2493                                    name + " expression");
2494     }
2495 
2496     return lhs = create<ast::BinaryExpression>(source, op, lhs, rhs.value);
2497   }
2498   return Failure::kErrored;
2499 }
2500 
2501 // shift_expression
2502 //   : additive_expression shift_expr
shift_expression()2503 Maybe<const ast::Expression*> ParserImpl::shift_expression() {
2504   auto lhs = additive_expression();
2505   if (lhs.errored)
2506     return Failure::kErrored;
2507   if (!lhs.matched)
2508     return Failure::kNoMatch;
2509 
2510   return expect_shift_expr(lhs.value);
2511 }
2512 
2513 // relational_expr
2514 //   :
2515 //   | LESS_THAN shift_expression relational_expr
2516 //   | GREATER_THAN shift_expression relational_expr
2517 //   | LESS_THAN_EQUAL shift_expression relational_expr
2518 //   | GREATER_THAN_EQUAL shift_expression relational_expr
expect_relational_expr(const ast::Expression * lhs)2519 Expect<const ast::Expression*> ParserImpl::expect_relational_expr(
2520     const ast::Expression* lhs) {
2521   while (continue_parsing()) {
2522     ast::BinaryOp op = ast::BinaryOp::kNone;
2523     if (peek_is(Token::Type::kLessThan))
2524       op = ast::BinaryOp::kLessThan;
2525     else if (peek_is(Token::Type::kGreaterThan))
2526       op = ast::BinaryOp::kGreaterThan;
2527     else if (peek_is(Token::Type::kLessThanEqual))
2528       op = ast::BinaryOp::kLessThanEqual;
2529     else if (peek_is(Token::Type::kGreaterThanEqual))
2530       op = ast::BinaryOp::kGreaterThanEqual;
2531     else
2532       return lhs;
2533 
2534     auto t = next();
2535     auto source = t.source();
2536     auto name = t.to_name();
2537 
2538     auto rhs = shift_expression();
2539     if (rhs.errored)
2540       return Failure::kErrored;
2541     if (!rhs.matched) {
2542       return add_error(peek(),
2543                        "unable to parse right side of " + name + " expression");
2544     }
2545 
2546     lhs = create<ast::BinaryExpression>(source, op, lhs, rhs.value);
2547   }
2548   return Failure::kErrored;
2549 }
2550 
2551 // relational_expression
2552 //   : shift_expression relational_expr
relational_expression()2553 Maybe<const ast::Expression*> ParserImpl::relational_expression() {
2554   auto lhs = shift_expression();
2555   if (lhs.errored)
2556     return Failure::kErrored;
2557   if (!lhs.matched)
2558     return Failure::kNoMatch;
2559 
2560   return expect_relational_expr(lhs.value);
2561 }
2562 
2563 // equality_expr
2564 //   :
2565 //   | EQUAL_EQUAL relational_expression equality_expr
2566 //   | NOT_EQUAL relational_expression equality_expr
expect_equality_expr(const ast::Expression * lhs)2567 Expect<const ast::Expression*> ParserImpl::expect_equality_expr(
2568     const ast::Expression* lhs) {
2569   while (continue_parsing()) {
2570     ast::BinaryOp op = ast::BinaryOp::kNone;
2571     if (peek_is(Token::Type::kEqualEqual))
2572       op = ast::BinaryOp::kEqual;
2573     else if (peek_is(Token::Type::kNotEqual))
2574       op = ast::BinaryOp::kNotEqual;
2575     else
2576       return lhs;
2577 
2578     auto t = next();
2579     auto source = t.source();
2580     auto name = t.to_name();
2581 
2582     auto rhs = relational_expression();
2583     if (rhs.errored)
2584       return Failure::kErrored;
2585     if (!rhs.matched) {
2586       return add_error(peek(),
2587                        "unable to parse right side of " + name + " expression");
2588     }
2589 
2590     lhs = create<ast::BinaryExpression>(source, op, lhs, rhs.value);
2591   }
2592   return Failure::kErrored;
2593 }
2594 
2595 // equality_expression
2596 //   : relational_expression equality_expr
equality_expression()2597 Maybe<const ast::Expression*> ParserImpl::equality_expression() {
2598   auto lhs = relational_expression();
2599   if (lhs.errored)
2600     return Failure::kErrored;
2601   if (!lhs.matched)
2602     return Failure::kNoMatch;
2603 
2604   return expect_equality_expr(lhs.value);
2605 }
2606 
2607 // and_expr
2608 //   :
2609 //   | AND equality_expression and_expr
expect_and_expr(const ast::Expression * lhs)2610 Expect<const ast::Expression*> ParserImpl::expect_and_expr(
2611     const ast::Expression* lhs) {
2612   while (continue_parsing()) {
2613     if (!peek_is(Token::Type::kAnd)) {
2614       return lhs;
2615     }
2616 
2617     auto t = next();
2618     auto source = t.source();
2619 
2620     auto rhs = equality_expression();
2621     if (rhs.errored)
2622       return Failure::kErrored;
2623     if (!rhs.matched)
2624       return add_error(peek(), "unable to parse right side of & expression");
2625 
2626     lhs = create<ast::BinaryExpression>(source, ast::BinaryOp::kAnd, lhs,
2627                                         rhs.value);
2628   }
2629   return Failure::kErrored;
2630 }
2631 
2632 // and_expression
2633 //   : equality_expression and_expr
and_expression()2634 Maybe<const ast::Expression*> ParserImpl::and_expression() {
2635   auto lhs = equality_expression();
2636   if (lhs.errored)
2637     return Failure::kErrored;
2638   if (!lhs.matched)
2639     return Failure::kNoMatch;
2640 
2641   return expect_and_expr(lhs.value);
2642 }
2643 
2644 // exclusive_or_expr
2645 //   :
2646 //   | XOR and_expression exclusive_or_expr
expect_exclusive_or_expr(const ast::Expression * lhs)2647 Expect<const ast::Expression*> ParserImpl::expect_exclusive_or_expr(
2648     const ast::Expression* lhs) {
2649   while (continue_parsing()) {
2650     Source source;
2651     if (!match(Token::Type::kXor, &source))
2652       return lhs;
2653 
2654     auto rhs = and_expression();
2655     if (rhs.errored)
2656       return Failure::kErrored;
2657     if (!rhs.matched)
2658       return add_error(peek(), "unable to parse right side of ^ expression");
2659 
2660     lhs = create<ast::BinaryExpression>(source, ast::BinaryOp::kXor, lhs,
2661                                         rhs.value);
2662   }
2663   return Failure::kErrored;
2664 }
2665 
2666 // exclusive_or_expression
2667 //   : and_expression exclusive_or_expr
exclusive_or_expression()2668 Maybe<const ast::Expression*> ParserImpl::exclusive_or_expression() {
2669   auto lhs = and_expression();
2670   if (lhs.errored)
2671     return Failure::kErrored;
2672   if (!lhs.matched)
2673     return Failure::kNoMatch;
2674 
2675   return expect_exclusive_or_expr(lhs.value);
2676 }
2677 
2678 // inclusive_or_expr
2679 //   :
2680 //   | OR exclusive_or_expression inclusive_or_expr
expect_inclusive_or_expr(const ast::Expression * lhs)2681 Expect<const ast::Expression*> ParserImpl::expect_inclusive_or_expr(
2682     const ast::Expression* lhs) {
2683   while (continue_parsing()) {
2684     Source source;
2685     if (!match(Token::Type::kOr))
2686       return lhs;
2687 
2688     auto rhs = exclusive_or_expression();
2689     if (rhs.errored)
2690       return Failure::kErrored;
2691     if (!rhs.matched)
2692       return add_error(peek(), "unable to parse right side of | expression");
2693 
2694     lhs = create<ast::BinaryExpression>(source, ast::BinaryOp::kOr, lhs,
2695                                         rhs.value);
2696   }
2697   return Failure::kErrored;
2698 }
2699 
2700 // inclusive_or_expression
2701 //   : exclusive_or_expression inclusive_or_expr
inclusive_or_expression()2702 Maybe<const ast::Expression*> ParserImpl::inclusive_or_expression() {
2703   auto lhs = exclusive_or_expression();
2704   if (lhs.errored)
2705     return Failure::kErrored;
2706   if (!lhs.matched)
2707     return Failure::kNoMatch;
2708 
2709   return expect_inclusive_or_expr(lhs.value);
2710 }
2711 
2712 // logical_and_expr
2713 //   :
2714 //   | AND_AND inclusive_or_expression logical_and_expr
expect_logical_and_expr(const ast::Expression * lhs)2715 Expect<const ast::Expression*> ParserImpl::expect_logical_and_expr(
2716     const ast::Expression* lhs) {
2717   while (continue_parsing()) {
2718     if (!peek_is(Token::Type::kAndAnd)) {
2719       return lhs;
2720     }
2721 
2722     auto t = next();
2723     auto source = t.source();
2724 
2725     auto rhs = inclusive_or_expression();
2726     if (rhs.errored)
2727       return Failure::kErrored;
2728     if (!rhs.matched)
2729       return add_error(peek(), "unable to parse right side of && expression");
2730 
2731     lhs = create<ast::BinaryExpression>(source, ast::BinaryOp::kLogicalAnd, lhs,
2732                                         rhs.value);
2733   }
2734   return Failure::kErrored;
2735 }
2736 
2737 // logical_and_expression
2738 //   : inclusive_or_expression logical_and_expr
logical_and_expression()2739 Maybe<const ast::Expression*> ParserImpl::logical_and_expression() {
2740   auto lhs = inclusive_or_expression();
2741   if (lhs.errored)
2742     return Failure::kErrored;
2743   if (!lhs.matched)
2744     return Failure::kNoMatch;
2745 
2746   return expect_logical_and_expr(lhs.value);
2747 }
2748 
2749 // logical_or_expr
2750 //   :
2751 //   | OR_OR logical_and_expression logical_or_expr
expect_logical_or_expr(const ast::Expression * lhs)2752 Expect<const ast::Expression*> ParserImpl::expect_logical_or_expr(
2753     const ast::Expression* lhs) {
2754   while (continue_parsing()) {
2755     Source source;
2756     if (!match(Token::Type::kOrOr))
2757       return lhs;
2758 
2759     auto rhs = logical_and_expression();
2760     if (rhs.errored)
2761       return Failure::kErrored;
2762     if (!rhs.matched)
2763       return add_error(peek(), "unable to parse right side of || expression");
2764 
2765     lhs = create<ast::BinaryExpression>(source, ast::BinaryOp::kLogicalOr, lhs,
2766                                         rhs.value);
2767   }
2768   return Failure::kErrored;
2769 }
2770 
2771 // logical_or_expression
2772 //   : logical_and_expression logical_or_expr
logical_or_expression()2773 Maybe<const ast::Expression*> ParserImpl::logical_or_expression() {
2774   auto lhs = logical_and_expression();
2775   if (lhs.errored)
2776     return Failure::kErrored;
2777   if (!lhs.matched)
2778     return Failure::kNoMatch;
2779 
2780   return expect_logical_or_expr(lhs.value);
2781 }
2782 
2783 // assignment_stmt
2784 //   : (unary_expression | underscore) EQUAL logical_or_expression
assignment_stmt()2785 Maybe<const ast::AssignmentStatement*> ParserImpl::assignment_stmt() {
2786   auto t = peek();
2787   auto source = t.source();
2788 
2789   // tint:295 - Test for `ident COLON` - this is invalid grammar, and without
2790   // special casing will error as "missing = for assignment", which is less
2791   // helpful than this error message:
2792   if (peek_is(Token::Type::kIdentifier) && peek_is(Token::Type::kColon, 1)) {
2793     return add_error(peek(0).source(),
2794                      "expected 'var' for variable declaration");
2795   }
2796 
2797   auto lhs = unary_expression();
2798   if (lhs.errored) {
2799     return Failure::kErrored;
2800   }
2801   if (!lhs.matched) {
2802     if (!match(Token::Type::kUnderscore, &source)) {
2803       return Failure::kNoMatch;
2804     }
2805     lhs = create<ast::PhonyExpression>(source);
2806   }
2807 
2808   if (!expect("assignment", Token::Type::kEqual)) {
2809     return Failure::kErrored;
2810   }
2811 
2812   auto rhs = logical_or_expression();
2813   if (rhs.errored) {
2814     return Failure::kErrored;
2815   }
2816   if (!rhs.matched) {
2817     return add_error(peek(), "unable to parse right side of assignment");
2818   }
2819 
2820   return create<ast::AssignmentStatement>(source, lhs.value, rhs.value);
2821 }
2822 
2823 // const_literal
2824 //   : INT_LITERAL
2825 //   | UINT_LITERAL
2826 //   | FLOAT_LITERAL
2827 //   | TRUE
2828 //   | FALSE
const_literal()2829 Maybe<const ast::LiteralExpression*> ParserImpl::const_literal() {
2830   auto t = peek();
2831   if (t.IsError()) {
2832     return add_error(t.source(), t.to_str());
2833   }
2834   if (match(Token::Type::kTrue)) {
2835     return create<ast::BoolLiteralExpression>(t.source(), true);
2836   }
2837   if (match(Token::Type::kFalse)) {
2838     return create<ast::BoolLiteralExpression>(t.source(), false);
2839   }
2840   if (match(Token::Type::kSintLiteral)) {
2841     return create<ast::SintLiteralExpression>(t.source(), t.to_i32());
2842   }
2843   if (match(Token::Type::kUintLiteral)) {
2844     return create<ast::UintLiteralExpression>(t.source(), t.to_u32());
2845   }
2846   if (match(Token::Type::kFloatLiteral)) {
2847     return create<ast::FloatLiteralExpression>(t.source(), t.to_f32());
2848   }
2849   return Failure::kNoMatch;
2850 }
2851 
2852 // const_expr
2853 //   : type_decl PAREN_LEFT ((const_expr COMMA)? const_expr COMMA?)? PAREN_RIGHT
2854 //   | const_literal
expect_const_expr()2855 Expect<const ast::Expression*> ParserImpl::expect_const_expr() {
2856   auto t = peek();
2857   auto source = t.source();
2858   if (t.IsLiteral()) {
2859     auto lit = const_literal();
2860     if (lit.errored) {
2861       return Failure::kErrored;
2862     }
2863     if (!lit.matched) {
2864       return add_error(peek(), "unable to parse constant literal");
2865     }
2866     return lit.value;
2867   }
2868 
2869   if (peek_is(Token::Type::kParenLeft, 1) ||
2870       peek_is(Token::Type::kLessThan, 1)) {
2871     auto type = expect_type("const_expr");
2872     if (type.errored) {
2873       return Failure::kErrored;
2874     }
2875 
2876     auto params = expect_paren_block(
2877         "type constructor", [&]() -> Expect<ast::ExpressionList> {
2878           ast::ExpressionList list;
2879           while (continue_parsing()) {
2880             if (peek_is(Token::Type::kParenRight)) {
2881               break;
2882             }
2883 
2884             auto arg = expect_const_expr();
2885             if (arg.errored) {
2886               return Failure::kErrored;
2887             }
2888             list.emplace_back(arg.value);
2889 
2890             if (!match(Token::Type::kComma)) {
2891               break;
2892             }
2893           }
2894           return list;
2895         });
2896 
2897     if (params.errored)
2898       return Failure::kErrored;
2899 
2900     return builder_.Construct(source, type.value, params.value);
2901   }
2902   return add_error(peek(), "unable to parse const_expr");
2903 }
2904 
decoration_list()2905 Maybe<ast::DecorationList> ParserImpl::decoration_list() {
2906   bool errored = false;
2907   bool matched = false;
2908   ast::DecorationList decos;
2909 
2910   while (continue_parsing()) {
2911     auto list = decoration_bracketed_list(decos);
2912     if (list.errored)
2913       errored = true;
2914     if (!list.matched)
2915       break;
2916 
2917     matched = true;
2918   }
2919 
2920   if (errored)
2921     return Failure::kErrored;
2922 
2923   if (!matched)
2924     return Failure::kNoMatch;
2925 
2926   return decos;
2927 }
2928 
decoration_bracketed_list(ast::DecorationList & decos)2929 Maybe<bool> ParserImpl::decoration_bracketed_list(ast::DecorationList& decos) {
2930   const char* use = "decoration list";
2931 
2932   if (!match(Token::Type::kAttrLeft)) {
2933     return Failure::kNoMatch;
2934   }
2935 
2936   Source source;
2937   if (match(Token::Type::kAttrRight, &source))
2938     return add_error(source, "empty decoration list");
2939 
2940   return sync(Token::Type::kAttrRight, [&]() -> Expect<bool> {
2941     bool errored = false;
2942 
2943     while (continue_parsing()) {
2944       auto deco = expect_decoration();
2945       if (deco.errored) {
2946         errored = true;
2947       }
2948       decos.emplace_back(deco.value);
2949 
2950       if (match(Token::Type::kComma)) {
2951         continue;
2952       }
2953 
2954       if (is_decoration(peek())) {
2955         // We have two decorations in a bracket without a separating comma.
2956         // e.g. [[location(1) group(2)]]
2957         //                    ^^^ expected comma
2958         expect(use, Token::Type::kComma);
2959         return Failure::kErrored;
2960       }
2961 
2962       break;
2963     }
2964 
2965     if (errored) {
2966       return Failure::kErrored;
2967     }
2968 
2969     if (!expect(use, Token::Type::kAttrRight)) {
2970       return Failure::kErrored;
2971     }
2972 
2973     return true;
2974   });
2975 }
2976 
expect_decoration()2977 Expect<const ast::Decoration*> ParserImpl::expect_decoration() {
2978   auto t = peek();
2979   auto deco = decoration();
2980   if (deco.errored)
2981     return Failure::kErrored;
2982   if (deco.matched)
2983     return deco.value;
2984   return add_error(t, "expected decoration");
2985 }
2986 
decoration()2987 Maybe<const ast::Decoration*> ParserImpl::decoration() {
2988   using Result = Maybe<const ast::Decoration*>;
2989   auto t = next();
2990 
2991   if (!t.IsIdentifier()) {
2992     return Failure::kNoMatch;
2993   }
2994 
2995   auto s = t.to_str();
2996 
2997   if (s == kLocationDecoration) {
2998     const char* use = "location decoration";
2999     return expect_paren_block(use, [&]() -> Result {
3000       auto val = expect_positive_sint(use);
3001       if (val.errored)
3002         return Failure::kErrored;
3003 
3004       return create<ast::LocationDecoration>(t.source(), val.value);
3005     });
3006   }
3007 
3008   if (s == kBindingDecoration) {
3009     const char* use = "binding decoration";
3010     return expect_paren_block(use, [&]() -> Result {
3011       auto val = expect_positive_sint(use);
3012       if (val.errored)
3013         return Failure::kErrored;
3014 
3015       return create<ast::BindingDecoration>(t.source(), val.value);
3016     });
3017   }
3018 
3019   if (s == kGroupDecoration) {
3020     const char* use = "group decoration";
3021     return expect_paren_block(use, [&]() -> Result {
3022       auto val = expect_positive_sint(use);
3023       if (val.errored)
3024         return Failure::kErrored;
3025 
3026       return create<ast::GroupDecoration>(t.source(), val.value);
3027     });
3028   }
3029 
3030   if (s == kInterpolateDecoration) {
3031     return expect_paren_block("interpolate decoration", [&]() -> Result {
3032       ast::InterpolationType type;
3033       ast::InterpolationSampling sampling = ast::InterpolationSampling::kNone;
3034 
3035       auto type_tok = next();
3036       auto type_str = type_tok.to_str();
3037       if (type_str == "perspective") {
3038         type = ast::InterpolationType::kPerspective;
3039       } else if (type_str == "linear") {
3040         type = ast::InterpolationType::kLinear;
3041       } else if (type_str == "flat") {
3042         type = ast::InterpolationType::kFlat;
3043       } else {
3044         return add_error(type_tok, "invalid interpolation type");
3045       }
3046 
3047       if (match(Token::Type::kComma)) {
3048         auto sampling_tok = next();
3049         auto sampling_str = sampling_tok.to_str();
3050         if (sampling_str == "center") {
3051           sampling = ast::InterpolationSampling::kCenter;
3052         } else if (sampling_str == "centroid") {
3053           sampling = ast::InterpolationSampling::kCentroid;
3054         } else if (sampling_str == "sample") {
3055           sampling = ast::InterpolationSampling::kSample;
3056         } else {
3057           return add_error(sampling_tok, "invalid interpolation sampling");
3058         }
3059       }
3060 
3061       return create<ast::InterpolateDecoration>(t.source(), type, sampling);
3062     });
3063   }
3064 
3065   if (s == kInvariantDecoration) {
3066     return create<ast::InvariantDecoration>(t.source());
3067   }
3068 
3069   if (s == kBuiltinDecoration) {
3070     return expect_paren_block("builtin decoration", [&]() -> Result {
3071       auto builtin = expect_builtin();
3072       if (builtin.errored)
3073         return Failure::kErrored;
3074 
3075       return create<ast::BuiltinDecoration>(t.source(), builtin.value);
3076     });
3077   }
3078 
3079   if (s == kWorkgroupSizeDecoration) {
3080     return expect_paren_block("workgroup_size decoration", [&]() -> Result {
3081       const ast::Expression* x = nullptr;
3082       const ast::Expression* y = nullptr;
3083       const ast::Expression* z = nullptr;
3084 
3085       auto expr = primary_expression();
3086       if (expr.errored) {
3087         return Failure::kErrored;
3088       } else if (!expr.matched) {
3089         return add_error(peek(), "expected workgroup_size x parameter");
3090       }
3091       x = std::move(expr.value);
3092 
3093       if (match(Token::Type::kComma)) {
3094         expr = primary_expression();
3095         if (expr.errored) {
3096           return Failure::kErrored;
3097         } else if (!expr.matched) {
3098           return add_error(peek(), "expected workgroup_size y parameter");
3099         }
3100         y = std::move(expr.value);
3101 
3102         if (match(Token::Type::kComma)) {
3103           expr = primary_expression();
3104           if (expr.errored) {
3105             return Failure::kErrored;
3106           } else if (!expr.matched) {
3107             return add_error(peek(), "expected workgroup_size z parameter");
3108           }
3109           z = std::move(expr.value);
3110         }
3111       }
3112 
3113       return create<ast::WorkgroupDecoration>(t.source(), x, y, z);
3114     });
3115   }
3116 
3117   if (s == kStageDecoration) {
3118     return expect_paren_block("stage decoration", [&]() -> Result {
3119       auto stage = expect_pipeline_stage();
3120       if (stage.errored)
3121         return Failure::kErrored;
3122 
3123       return create<ast::StageDecoration>(t.source(), stage.value);
3124     });
3125   }
3126 
3127   if (s == kBlockDecoration) {
3128     return create<ast::StructBlockDecoration>(t.source());
3129   }
3130 
3131   if (s == kStrideDecoration) {
3132     const char* use = "stride decoration";
3133     return expect_paren_block(use, [&]() -> Result {
3134       auto val = expect_nonzero_positive_sint(use);
3135       if (val.errored)
3136         return Failure::kErrored;
3137 
3138       return create<ast::StrideDecoration>(t.source(), val.value);
3139     });
3140   }
3141 
3142   if (s == kSizeDecoration) {
3143     const char* use = "size decoration";
3144     return expect_paren_block(use, [&]() -> Result {
3145       auto val = expect_positive_sint(use);
3146       if (val.errored)
3147         return Failure::kErrored;
3148 
3149       return create<ast::StructMemberSizeDecoration>(t.source(), val.value);
3150     });
3151   }
3152 
3153   if (s == kAlignDecoration) {
3154     const char* use = "align decoration";
3155     return expect_paren_block(use, [&]() -> Result {
3156       auto val = expect_positive_sint(use);
3157       if (val.errored)
3158         return Failure::kErrored;
3159 
3160       return create<ast::StructMemberAlignDecoration>(t.source(), val.value);
3161     });
3162   }
3163 
3164   if (s == kOverrideDecoration) {
3165     const char* use = "override decoration";
3166 
3167     if (peek_is(Token::Type::kParenLeft)) {
3168       // [[override(x)]]
3169       return expect_paren_block(use, [&]() -> Result {
3170         auto val = expect_positive_sint(use);
3171         if (val.errored)
3172           return Failure::kErrored;
3173 
3174         return create<ast::OverrideDecoration>(t.source(), val.value);
3175       });
3176     } else {
3177       // [[override]]
3178       return create<ast::OverrideDecoration>(t.source());
3179     }
3180   }
3181 
3182   return Failure::kNoMatch;
3183 }
3184 
expect_decorations_consumed(ast::DecorationList & in)3185 bool ParserImpl::expect_decorations_consumed(ast::DecorationList& in) {
3186   if (in.empty()) {
3187     return true;
3188   }
3189   add_error(in[0]->source, "unexpected decorations");
3190   return false;
3191 }
3192 
match(Token::Type tok,Source * source)3193 bool ParserImpl::match(Token::Type tok, Source* source /*= nullptr*/) {
3194   auto t = peek();
3195 
3196   if (source != nullptr)
3197     *source = t.source();
3198 
3199   if (t.Is(tok)) {
3200     next();
3201     return true;
3202   }
3203   return false;
3204 }
3205 
expect(const std::string & use,Token::Type tok)3206 bool ParserImpl::expect(const std::string& use, Token::Type tok) {
3207   auto t = peek();
3208   if (t.Is(tok)) {
3209     next();
3210     synchronized_ = true;
3211     return true;
3212   }
3213 
3214   // Special case to split `>>` and `>=` tokens if we are looking for a `>`.
3215   if (tok == Token::Type::kGreaterThan &&
3216       (t.Is(Token::Type::kShiftRight) ||
3217        t.Is(Token::Type::kGreaterThanEqual))) {
3218     next();
3219 
3220     // Push the second character to the token queue.
3221     auto source = t.source();
3222     source.range.begin.column++;
3223     if (t.Is(Token::Type::kShiftRight)) {
3224       token_queue_.push_front(Token(Token::Type::kGreaterThan, source));
3225     } else if (t.Is(Token::Type::kGreaterThanEqual)) {
3226       token_queue_.push_front(Token(Token::Type::kEqual, source));
3227     }
3228 
3229     synchronized_ = true;
3230     return true;
3231   }
3232 
3233   // Handle the case when `]` is expected but the actual token is `]]`.
3234   // For example, in `arr1[arr2[0]]`.
3235   if (tok == Token::Type::kBracketRight && t.Is(Token::Type::kAttrRight)) {
3236     next();
3237     auto source = t.source();
3238     source.range.begin.column++;
3239     token_queue_.push_front({Token::Type::kBracketRight, source});
3240     synchronized_ = true;
3241     return true;
3242   }
3243 
3244   std::stringstream err;
3245   err << "expected '" << Token::TypeToName(tok) << "'";
3246   if (!use.empty()) {
3247     err << " for " << use;
3248   }
3249   add_error(t, err.str());
3250   synchronized_ = false;
3251   return false;
3252 }
3253 
expect_sint(const std::string & use)3254 Expect<int32_t> ParserImpl::expect_sint(const std::string& use) {
3255   auto t = peek();
3256   if (!t.Is(Token::Type::kSintLiteral))
3257     return add_error(t.source(), "expected signed integer literal", use);
3258 
3259   next();
3260   return {t.to_i32(), t.source()};
3261 }
3262 
expect_positive_sint(const std::string & use)3263 Expect<uint32_t> ParserImpl::expect_positive_sint(const std::string& use) {
3264   auto sint = expect_sint(use);
3265   if (sint.errored)
3266     return Failure::kErrored;
3267 
3268   if (sint.value < 0)
3269     return add_error(sint.source, use + " must be positive");
3270 
3271   return {static_cast<uint32_t>(sint.value), sint.source};
3272 }
3273 
expect_nonzero_positive_sint(const std::string & use)3274 Expect<uint32_t> ParserImpl::expect_nonzero_positive_sint(
3275     const std::string& use) {
3276   auto sint = expect_sint(use);
3277   if (sint.errored)
3278     return Failure::kErrored;
3279 
3280   if (sint.value <= 0)
3281     return add_error(sint.source, use + " must be greater than 0");
3282 
3283   return {static_cast<uint32_t>(sint.value), sint.source};
3284 }
3285 
expect_ident(const std::string & use)3286 Expect<std::string> ParserImpl::expect_ident(const std::string& use) {
3287   auto t = peek();
3288   if (t.IsIdentifier()) {
3289     synchronized_ = true;
3290     next();
3291 
3292     if (is_reserved(t)) {
3293       return add_error(t.source(),
3294                        "'" + t.to_str() + "' is a reserved keyword");
3295     }
3296 
3297     return {t.to_str(), t.source()};
3298   }
3299   synchronized_ = false;
3300   return add_error(t.source(), "expected identifier", use);
3301 }
3302 
3303 template <typename F, typename T>
expect_block(Token::Type start,Token::Type end,const std::string & use,F && body)3304 T ParserImpl::expect_block(Token::Type start,
3305                            Token::Type end,
3306                            const std::string& use,
3307                            F&& body) {
3308   if (!expect(use, start)) {
3309     return Failure::kErrored;
3310   }
3311 
3312   return sync(end, [&]() -> T {
3313     auto res = body();
3314 
3315     if (res.errored)
3316       return Failure::kErrored;
3317 
3318     if (!expect(use, end))
3319       return Failure::kErrored;
3320 
3321     return res;
3322   });
3323 }
3324 
3325 template <typename F, typename T>
expect_paren_block(const std::string & use,F && body)3326 T ParserImpl::expect_paren_block(const std::string& use, F&& body) {
3327   return expect_block(Token::Type::kParenLeft, Token::Type::kParenRight, use,
3328                       std::forward<F>(body));
3329 }
3330 
3331 template <typename F, typename T>
expect_brace_block(const std::string & use,F && body)3332 T ParserImpl::expect_brace_block(const std::string& use, F&& body) {
3333   return expect_block(Token::Type::kBraceLeft, Token::Type::kBraceRight, use,
3334                       std::forward<F>(body));
3335 }
3336 
3337 template <typename F, typename T>
expect_lt_gt_block(const std::string & use,F && body)3338 T ParserImpl::expect_lt_gt_block(const std::string& use, F&& body) {
3339   return expect_block(Token::Type::kLessThan, Token::Type::kGreaterThan, use,
3340                       std::forward<F>(body));
3341 }
3342 
3343 template <typename F, typename T>
sync(Token::Type tok,F && body)3344 T ParserImpl::sync(Token::Type tok, F&& body) {
3345   if (parse_depth_ >= kMaxParseDepth) {
3346     // We've hit a maximum parser recursive depth.
3347     // We can't call into body() as we might stack overflow.
3348     // Instead, report an error...
3349     add_error(peek(), "maximum parser recursive depth reached");
3350     // ...and try to resynchronize. If we cannot resynchronize to `tok` then
3351     // synchronized_ is set to false, and the parser knows that forward progress
3352     // is not being made.
3353     sync_to(tok, /* consume: */ true);
3354     return Failure::kErrored;
3355   }
3356 
3357   sync_tokens_.push_back(tok);
3358 
3359   ++parse_depth_;
3360   auto result = body();
3361   --parse_depth_;
3362 
3363   if (sync_tokens_.back() != tok) {
3364     TINT_ICE(Reader, builder_.Diagnostics()) << "sync_tokens is out of sync";
3365   }
3366   sync_tokens_.pop_back();
3367 
3368   if (result.errored) {
3369     sync_to(tok, /* consume: */ true);
3370   }
3371 
3372   return result;
3373 }
3374 
sync_to(Token::Type tok,bool consume)3375 bool ParserImpl::sync_to(Token::Type tok, bool consume) {
3376   // Clear the synchronized state - gets set to true again on success.
3377   synchronized_ = false;
3378 
3379   BlockCounters counters;
3380 
3381   for (size_t i = 0; i < kMaxResynchronizeLookahead; i++) {
3382     auto t = peek(i);
3383     if (counters.consume(t) > 0) {
3384       continue;  // Nested block
3385     }
3386     if (!t.Is(tok) && !is_sync_token(t)) {
3387       continue;  // Not a synchronization point
3388     }
3389 
3390     // Synchronization point found.
3391 
3392     // Skip any tokens we don't understand, bringing us to just before the
3393     // resync point.
3394     while (i-- > 0) {
3395       next();
3396     }
3397 
3398     // Is this synchronization token |tok|?
3399     if (t.Is(tok)) {
3400       if (consume) {
3401         next();
3402       }
3403       synchronized_ = true;
3404       return true;
3405     }
3406     break;
3407   }
3408 
3409   return false;
3410 }
3411 
is_sync_token(const Token & t) const3412 bool ParserImpl::is_sync_token(const Token& t) const {
3413   for (auto r : sync_tokens_) {
3414     if (t.Is(r)) {
3415       return true;
3416     }
3417   }
3418   return false;
3419 }
3420 
3421 template <typename F, typename T>
without_error(F && body)3422 T ParserImpl::without_error(F&& body) {
3423   silence_errors_++;
3424   auto result = body();
3425   silence_errors_--;
3426   return result;
3427 }
3428 
make_source_range()3429 ParserImpl::MultiTokenSource ParserImpl::make_source_range() {
3430   return MultiTokenSource(this);
3431 }
3432 
make_source_range_from(const Source & start)3433 ParserImpl::MultiTokenSource ParserImpl::make_source_range_from(
3434     const Source& start) {
3435   return MultiTokenSource(this, start);
3436 }
3437 
3438 }  // namespace wgsl
3439 }  // namespace reader
3440 }  // namespace tint
3441