• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/writer/wgsl/generator_impl.h"
16 
17 #include <algorithm>
18 
19 #include "src/ast/access.h"
20 #include "src/ast/alias.h"
21 #include "src/ast/array.h"
22 #include "src/ast/atomic.h"
23 #include "src/ast/bool.h"
24 #include "src/ast/bool_literal_expression.h"
25 #include "src/ast/call_statement.h"
26 #include "src/ast/depth_texture.h"
27 #include "src/ast/external_texture.h"
28 #include "src/ast/f32.h"
29 #include "src/ast/float_literal_expression.h"
30 #include "src/ast/i32.h"
31 #include "src/ast/internal_decoration.h"
32 #include "src/ast/interpolate_decoration.h"
33 #include "src/ast/invariant_decoration.h"
34 #include "src/ast/matrix.h"
35 #include "src/ast/module.h"
36 #include "src/ast/multisampled_texture.h"
37 #include "src/ast/override_decoration.h"
38 #include "src/ast/pointer.h"
39 #include "src/ast/sampled_texture.h"
40 #include "src/ast/sint_literal_expression.h"
41 #include "src/ast/stage_decoration.h"
42 #include "src/ast/storage_texture.h"
43 #include "src/ast/stride_decoration.h"
44 #include "src/ast/struct_block_decoration.h"
45 #include "src/ast/struct_member_align_decoration.h"
46 #include "src/ast/struct_member_offset_decoration.h"
47 #include "src/ast/struct_member_size_decoration.h"
48 #include "src/ast/type_name.h"
49 #include "src/ast/u32.h"
50 #include "src/ast/uint_literal_expression.h"
51 #include "src/ast/variable_decl_statement.h"
52 #include "src/ast/vector.h"
53 #include "src/ast/void.h"
54 #include "src/ast/workgroup_decoration.h"
55 #include "src/sem/struct.h"
56 #include "src/utils/math.h"
57 #include "src/utils/scoped_assignment.h"
58 #include "src/writer/float_to_string.h"
59 
60 namespace tint {
61 namespace writer {
62 namespace wgsl {
63 
GeneratorImpl(const Program * program)64 GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {}
65 
66 GeneratorImpl::~GeneratorImpl() = default;
67 
Generate()68 bool GeneratorImpl::Generate() {
69   // Generate global declarations in the order they appear in the module.
70   for (auto* decl : program_->AST().GlobalDeclarations()) {
71     if (auto* td = decl->As<ast::TypeDecl>()) {
72       if (!EmitTypeDecl(td)) {
73         return false;
74       }
75     } else if (auto* func = decl->As<ast::Function>()) {
76       if (!EmitFunction(func)) {
77         return false;
78       }
79     } else if (auto* var = decl->As<ast::Variable>()) {
80       if (!EmitVariable(line(), var)) {
81         return false;
82       }
83     } else {
84       TINT_UNREACHABLE(Writer, diagnostics_);
85       return false;
86     }
87 
88     if (decl != program_->AST().GlobalDeclarations().back()) {
89       line();
90     }
91   }
92 
93   return true;
94 }
95 
EmitTypeDecl(const ast::TypeDecl * ty)96 bool GeneratorImpl::EmitTypeDecl(const ast::TypeDecl* ty) {
97   if (auto* alias = ty->As<ast::Alias>()) {
98     auto out = line();
99     out << "type " << program_->Symbols().NameFor(alias->name) << " = ";
100     if (!EmitType(out, alias->type)) {
101       return false;
102     }
103     out << ";";
104   } else if (auto* str = ty->As<ast::Struct>()) {
105     if (!EmitStructType(str)) {
106       return false;
107     }
108   } else {
109     diagnostics_.add_error(
110         diag::System::Writer,
111         "unknown declared type: " + std::string(ty->TypeInfo().name));
112     return false;
113   }
114   return true;
115 }
116 
EmitExpression(std::ostream & out,const ast::Expression * expr)117 bool GeneratorImpl::EmitExpression(std::ostream& out,
118                                    const ast::Expression* expr) {
119   if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
120     return EmitIndexAccessor(out, a);
121   }
122   if (auto* b = expr->As<ast::BinaryExpression>()) {
123     return EmitBinary(out, b);
124   }
125   if (auto* b = expr->As<ast::BitcastExpression>()) {
126     return EmitBitcast(out, b);
127   }
128   if (auto* c = expr->As<ast::CallExpression>()) {
129     return EmitCall(out, c);
130   }
131   if (auto* i = expr->As<ast::IdentifierExpression>()) {
132     return EmitIdentifier(out, i);
133   }
134   if (auto* l = expr->As<ast::LiteralExpression>()) {
135     return EmitLiteral(out, l);
136   }
137   if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
138     return EmitMemberAccessor(out, m);
139   }
140   if (expr->Is<ast::PhonyExpression>()) {
141     out << "_";
142     return true;
143   }
144   if (auto* u = expr->As<ast::UnaryOpExpression>()) {
145     return EmitUnaryOp(out, u);
146   }
147 
148   diagnostics_.add_error(diag::System::Writer, "unknown expression type");
149   return false;
150 }
151 
EmitIndexAccessor(std::ostream & out,const ast::IndexAccessorExpression * expr)152 bool GeneratorImpl::EmitIndexAccessor(
153     std::ostream& out,
154     const ast::IndexAccessorExpression* expr) {
155   bool paren_lhs =
156       !expr->object->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression,
157                              ast::IdentifierExpression,
158                              ast::MemberAccessorExpression>();
159   if (paren_lhs) {
160     out << "(";
161   }
162   if (!EmitExpression(out, expr->object)) {
163     return false;
164   }
165   if (paren_lhs) {
166     out << ")";
167   }
168   out << "[";
169 
170   if (!EmitExpression(out, expr->index)) {
171     return false;
172   }
173   out << "]";
174 
175   return true;
176 }
177 
EmitMemberAccessor(std::ostream & out,const ast::MemberAccessorExpression * expr)178 bool GeneratorImpl::EmitMemberAccessor(
179     std::ostream& out,
180     const ast::MemberAccessorExpression* expr) {
181   bool paren_lhs =
182       !expr->structure->IsAnyOf<ast::IndexAccessorExpression,
183                                 ast::CallExpression, ast::IdentifierExpression,
184                                 ast::MemberAccessorExpression>();
185   if (paren_lhs) {
186     out << "(";
187   }
188   if (!EmitExpression(out, expr->structure)) {
189     return false;
190   }
191   if (paren_lhs) {
192     out << ")";
193   }
194 
195   out << ".";
196 
197   return EmitExpression(out, expr->member);
198 }
199 
EmitBitcast(std::ostream & out,const ast::BitcastExpression * expr)200 bool GeneratorImpl::EmitBitcast(std::ostream& out,
201                                 const ast::BitcastExpression* expr) {
202   out << "bitcast<";
203   if (!EmitType(out, expr->type)) {
204     return false;
205   }
206 
207   out << ">(";
208   if (!EmitExpression(out, expr->expr)) {
209     return false;
210   }
211 
212   out << ")";
213   return true;
214 }
215 
EmitCall(std::ostream & out,const ast::CallExpression * expr)216 bool GeneratorImpl::EmitCall(std::ostream& out,
217                              const ast::CallExpression* expr) {
218   if (expr->target.name) {
219     if (!EmitExpression(out, expr->target.name)) {
220       return false;
221     }
222   } else if (expr->target.type) {
223     if (!EmitType(out, expr->target.type)) {
224       return false;
225     }
226   } else {
227     TINT_ICE(Writer, diagnostics_)
228         << "CallExpression target had neither a name or type";
229     return false;
230   }
231   out << "(";
232 
233   bool first = true;
234   const auto& args = expr->args;
235   for (auto* arg : args) {
236     if (!first) {
237       out << ", ";
238     }
239     first = false;
240 
241     if (!EmitExpression(out, arg)) {
242       return false;
243     }
244   }
245 
246   out << ")";
247 
248   return true;
249 }
250 
EmitLiteral(std::ostream & out,const ast::LiteralExpression * lit)251 bool GeneratorImpl::EmitLiteral(std::ostream& out,
252                                 const ast::LiteralExpression* lit) {
253   if (auto* bl = lit->As<ast::BoolLiteralExpression>()) {
254     out << (bl->value ? "true" : "false");
255   } else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
256     out << FloatToBitPreservingString(fl->value);
257   } else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
258     out << sl->value;
259   } else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
260     out << ul->value << "u";
261   } else {
262     diagnostics_.add_error(diag::System::Writer, "unknown literal type");
263     return false;
264   }
265   return true;
266 }
267 
EmitIdentifier(std::ostream & out,const ast::IdentifierExpression * expr)268 bool GeneratorImpl::EmitIdentifier(std::ostream& out,
269                                    const ast::IdentifierExpression* expr) {
270   out << program_->Symbols().NameFor(expr->symbol);
271   return true;
272 }
273 
EmitFunction(const ast::Function * func)274 bool GeneratorImpl::EmitFunction(const ast::Function* func) {
275   if (func->decorations.size()) {
276     if (!EmitDecorations(line(), func->decorations)) {
277       return false;
278     }
279   }
280   {
281     auto out = line();
282     out << "fn " << program_->Symbols().NameFor(func->symbol) << "(";
283 
284     bool first = true;
285     for (auto* v : func->params) {
286       if (!first) {
287         out << ", ";
288       }
289       first = false;
290 
291       if (!v->decorations.empty()) {
292         if (!EmitDecorations(out, v->decorations)) {
293           return false;
294         }
295         out << " ";
296       }
297 
298       out << program_->Symbols().NameFor(v->symbol) << " : ";
299 
300       if (!EmitType(out, v->type)) {
301         return false;
302       }
303     }
304 
305     out << ")";
306 
307     if (!func->return_type->Is<ast::Void>() ||
308         !func->return_type_decorations.empty()) {
309       out << " -> ";
310 
311       if (!func->return_type_decorations.empty()) {
312         if (!EmitDecorations(out, func->return_type_decorations)) {
313           return false;
314         }
315         out << " ";
316       }
317 
318       if (!EmitType(out, func->return_type)) {
319         return false;
320       }
321     }
322 
323     if (func->body) {
324       out << " {";
325     }
326   }
327 
328   if (func->body) {
329     if (!EmitStatementsWithIndent(func->body->statements)) {
330       return false;
331     }
332     line() << "}";
333   }
334 
335   return true;
336 }
337 
EmitImageFormat(std::ostream & out,const ast::ImageFormat fmt)338 bool GeneratorImpl::EmitImageFormat(std::ostream& out,
339                                     const ast::ImageFormat fmt) {
340   switch (fmt) {
341     case ast::ImageFormat::kNone:
342       diagnostics_.add_error(diag::System::Writer, "unknown image format");
343       return false;
344     default:
345       out << fmt;
346   }
347   return true;
348 }
349 
EmitAccess(std::ostream & out,const ast::Access access)350 bool GeneratorImpl::EmitAccess(std::ostream& out, const ast::Access access) {
351   switch (access) {
352     case ast::Access::kRead:
353       out << "read";
354       return true;
355     case ast::Access::kWrite:
356       out << "write";
357       return true;
358     case ast::Access::kReadWrite:
359       out << "read_write";
360       return true;
361     default:
362       break;
363   }
364   diagnostics_.add_error(diag::System::Writer, "unknown access");
365   return false;
366 }
367 
EmitType(std::ostream & out,const ast::Type * ty)368 bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
369   if (auto* ary = ty->As<ast::Array>()) {
370     for (auto* deco : ary->decorations) {
371       if (auto* stride = deco->As<ast::StrideDecoration>()) {
372         out << "[[stride(" << stride->stride << ")]] ";
373       }
374     }
375 
376     out << "array<";
377     if (!EmitType(out, ary->type)) {
378       return false;
379     }
380 
381     if (!ary->IsRuntimeArray()) {
382       out << ", ";
383       if (!EmitExpression(out, ary->count)) {
384         return false;
385       }
386     }
387 
388     out << ">";
389   } else if (ty->Is<ast::Bool>()) {
390     out << "bool";
391   } else if (ty->Is<ast::F32>()) {
392     out << "f32";
393   } else if (ty->Is<ast::I32>()) {
394     out << "i32";
395   } else if (auto* mat = ty->As<ast::Matrix>()) {
396     out << "mat" << mat->columns << "x" << mat->rows << "<";
397     if (!EmitType(out, mat->type)) {
398       return false;
399     }
400     out << ">";
401   } else if (auto* ptr = ty->As<ast::Pointer>()) {
402     out << "ptr<" << ptr->storage_class << ", ";
403     if (!EmitType(out, ptr->type)) {
404       return false;
405     }
406     if (ptr->access != ast::Access::kUndefined) {
407       out << ", ";
408       if (!EmitAccess(out, ptr->access)) {
409         return false;
410       }
411     }
412     out << ">";
413   } else if (auto* atomic = ty->As<ast::Atomic>()) {
414     out << "atomic<";
415     if (!EmitType(out, atomic->type)) {
416       return false;
417     }
418     out << ">";
419   } else if (auto* sampler = ty->As<ast::Sampler>()) {
420     out << "sampler";
421 
422     if (sampler->IsComparison()) {
423       out << "_comparison";
424     }
425   } else if (ty->Is<ast::ExternalTexture>()) {
426     out << "texture_external";
427   } else if (auto* texture = ty->As<ast::Texture>()) {
428     out << "texture_";
429     if (texture->Is<ast::DepthTexture>()) {
430       out << "depth_";
431     } else if (texture->Is<ast::DepthMultisampledTexture>()) {
432       out << "depth_multisampled_";
433     } else if (texture->Is<ast::SampledTexture>()) {
434       /* nothing to emit */
435     } else if (texture->Is<ast::MultisampledTexture>()) {
436       out << "multisampled_";
437     } else if (texture->Is<ast::StorageTexture>()) {
438       out << "storage_";
439     } else {
440       diagnostics_.add_error(diag::System::Writer, "unknown texture type");
441       return false;
442     }
443 
444     switch (texture->dim) {
445       case ast::TextureDimension::k1d:
446         out << "1d";
447         break;
448       case ast::TextureDimension::k2d:
449         out << "2d";
450         break;
451       case ast::TextureDimension::k2dArray:
452         out << "2d_array";
453         break;
454       case ast::TextureDimension::k3d:
455         out << "3d";
456         break;
457       case ast::TextureDimension::kCube:
458         out << "cube";
459         break;
460       case ast::TextureDimension::kCubeArray:
461         out << "cube_array";
462         break;
463       default:
464         diagnostics_.add_error(diag::System::Writer,
465                                "unknown texture dimension");
466         return false;
467     }
468 
469     if (auto* sampled = texture->As<ast::SampledTexture>()) {
470       out << "<";
471       if (!EmitType(out, sampled->type)) {
472         return false;
473       }
474       out << ">";
475     } else if (auto* ms = texture->As<ast::MultisampledTexture>()) {
476       out << "<";
477       if (!EmitType(out, ms->type)) {
478         return false;
479       }
480       out << ">";
481     } else if (auto* storage = texture->As<ast::StorageTexture>()) {
482       out << "<";
483       if (!EmitImageFormat(out, storage->format)) {
484         return false;
485       }
486       out << ", ";
487       if (!EmitAccess(out, storage->access)) {
488         return false;
489       }
490       out << ">";
491     }
492 
493   } else if (ty->Is<ast::U32>()) {
494     out << "u32";
495   } else if (auto* vec = ty->As<ast::Vector>()) {
496     out << "vec" << vec->width << "<";
497     if (!EmitType(out, vec->type)) {
498       return false;
499     }
500     out << ">";
501   } else if (ty->Is<ast::Void>()) {
502     out << "void";
503   } else if (auto* tn = ty->As<ast::TypeName>()) {
504     out << program_->Symbols().NameFor(tn->name);
505   } else {
506     diagnostics_.add_error(
507         diag::System::Writer,
508         "unknown type in EmitType: " + std::string(ty->TypeInfo().name));
509     return false;
510   }
511   return true;
512 }
513 
EmitStructType(const ast::Struct * str)514 bool GeneratorImpl::EmitStructType(const ast::Struct* str) {
515   if (str->decorations.size()) {
516     if (!EmitDecorations(line(), str->decorations)) {
517       return false;
518     }
519   }
520   line() << "struct " << program_->Symbols().NameFor(str->name) << " {";
521 
522   auto add_padding = [&](uint32_t size) {
523     line() << "[[size(" << size << ")]]";
524 
525     // Note: u32 is the smallest primitive we currently support. When WGSL
526     // supports smaller types, this will need to be updated.
527     line() << UniqueIdentifier("padding") << " : u32;";
528   };
529 
530   increment_indent();
531   uint32_t offset = 0;
532   for (auto* mem : str->members) {
533     // TODO(crbug.com/tint/798) move the [[offset]] decoration handling to the
534     // transform::Wgsl sanitizer.
535     if (auto* mem_sem = program_->Sem().Get(mem)) {
536       offset = utils::RoundUp(mem_sem->Align(), offset);
537       if (uint32_t padding = mem_sem->Offset() - offset) {
538         add_padding(padding);
539         offset += padding;
540       }
541       offset += mem_sem->Size();
542     }
543 
544     // Offset decorations no longer exist in the WGSL spec, but are emitted
545     // by the SPIR-V reader and are consumed by the Resolver(). These should not
546     // be emitted, but instead struct padding fields should be emitted.
547     ast::DecorationList decorations_sanitized;
548     decorations_sanitized.reserve(mem->decorations.size());
549     for (auto* deco : mem->decorations) {
550       if (!deco->Is<ast::StructMemberOffsetDecoration>()) {
551         decorations_sanitized.emplace_back(deco);
552       }
553     }
554 
555     if (!decorations_sanitized.empty()) {
556       if (!EmitDecorations(line(), decorations_sanitized)) {
557         return false;
558       }
559     }
560 
561     auto out = line();
562     out << program_->Symbols().NameFor(mem->symbol) << " : ";
563     if (!EmitType(out, mem->type)) {
564       return false;
565     }
566     out << ";";
567   }
568   decrement_indent();
569 
570   line() << "};";
571   return true;
572 }
573 
EmitVariable(std::ostream & out,const ast::Variable * var)574 bool GeneratorImpl::EmitVariable(std::ostream& out, const ast::Variable* var) {
575   if (!var->decorations.empty()) {
576     if (!EmitDecorations(out, var->decorations)) {
577       return false;
578     }
579     out << " ";
580   }
581 
582   if (var->is_const) {
583     out << "let";
584   } else {
585     out << "var";
586     auto sc = var->declared_storage_class;
587     auto ac = var->declared_access;
588     if (sc != ast::StorageClass::kNone || ac != ast::Access::kUndefined) {
589       out << "<" << sc;
590       if (ac != ast::Access::kUndefined) {
591         out << ", ";
592         if (!EmitAccess(out, ac)) {
593           return false;
594         }
595       }
596       out << ">";
597     }
598   }
599 
600   out << " " << program_->Symbols().NameFor(var->symbol);
601 
602   if (auto* ty = var->type) {
603     out << " : ";
604     if (!EmitType(out, ty)) {
605       return false;
606     }
607   }
608 
609   if (var->constructor != nullptr) {
610     out << " = ";
611     if (!EmitExpression(out, var->constructor)) {
612       return false;
613     }
614   }
615   out << ";";
616 
617   return true;
618 }
619 
EmitDecorations(std::ostream & out,const ast::DecorationList & decos)620 bool GeneratorImpl::EmitDecorations(std::ostream& out,
621                                     const ast::DecorationList& decos) {
622   out << "[[";
623   bool first = true;
624   for (auto* deco : decos) {
625     if (!first) {
626       out << ", ";
627     }
628     first = false;
629 
630     if (auto* workgroup = deco->As<ast::WorkgroupDecoration>()) {
631       auto values = workgroup->Values();
632       out << "workgroup_size(";
633       for (int i = 0; i < 3; i++) {
634         if (values[i]) {
635           if (i > 0) {
636             out << ", ";
637           }
638           if (!EmitExpression(out, values[i])) {
639             return false;
640           }
641         }
642       }
643       out << ")";
644     } else if (deco->Is<ast::StructBlockDecoration>()) {
645       out << "block";
646     } else if (auto* stage = deco->As<ast::StageDecoration>()) {
647       out << "stage(" << stage->stage << ")";
648     } else if (auto* binding = deco->As<ast::BindingDecoration>()) {
649       out << "binding(" << binding->value << ")";
650     } else if (auto* group = deco->As<ast::GroupDecoration>()) {
651       out << "group(" << group->value << ")";
652     } else if (auto* location = deco->As<ast::LocationDecoration>()) {
653       out << "location(" << location->value << ")";
654     } else if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
655       out << "builtin(" << builtin->builtin << ")";
656     } else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
657       out << "interpolate(" << interpolate->type;
658       if (interpolate->sampling != ast::InterpolationSampling::kNone) {
659         out << ", " << interpolate->sampling;
660       }
661       out << ")";
662     } else if (deco->Is<ast::InvariantDecoration>()) {
663       out << "invariant";
664     } else if (auto* override_deco = deco->As<ast::OverrideDecoration>()) {
665       out << "override";
666       if (override_deco->has_value) {
667         out << "(" << override_deco->value << ")";
668       }
669     } else if (auto* size = deco->As<ast::StructMemberSizeDecoration>()) {
670       out << "size(" << size->size << ")";
671     } else if (auto* align = deco->As<ast::StructMemberAlignDecoration>()) {
672       out << "align(" << align->align << ")";
673     } else if (auto* stride = deco->As<ast::StrideDecoration>()) {
674       out << "stride(" << stride->stride << ")";
675     } else if (auto* internal = deco->As<ast::InternalDecoration>()) {
676       out << "internal(" << internal->InternalName() << ")";
677     } else {
678       TINT_ICE(Writer, diagnostics_)
679           << "Unsupported decoration '" << deco->TypeInfo().name << "'";
680       return false;
681     }
682   }
683   out << "]]";
684 
685   return true;
686 }
687 
EmitBinary(std::ostream & out,const ast::BinaryExpression * expr)688 bool GeneratorImpl::EmitBinary(std::ostream& out,
689                                const ast::BinaryExpression* expr) {
690   out << "(";
691 
692   if (!EmitExpression(out, expr->lhs)) {
693     return false;
694   }
695   out << " ";
696 
697   switch (expr->op) {
698     case ast::BinaryOp::kAnd:
699       out << "&";
700       break;
701     case ast::BinaryOp::kOr:
702       out << "|";
703       break;
704     case ast::BinaryOp::kXor:
705       out << "^";
706       break;
707     case ast::BinaryOp::kLogicalAnd:
708       out << "&&";
709       break;
710     case ast::BinaryOp::kLogicalOr:
711       out << "||";
712       break;
713     case ast::BinaryOp::kEqual:
714       out << "==";
715       break;
716     case ast::BinaryOp::kNotEqual:
717       out << "!=";
718       break;
719     case ast::BinaryOp::kLessThan:
720       out << "<";
721       break;
722     case ast::BinaryOp::kGreaterThan:
723       out << ">";
724       break;
725     case ast::BinaryOp::kLessThanEqual:
726       out << "<=";
727       break;
728     case ast::BinaryOp::kGreaterThanEqual:
729       out << ">=";
730       break;
731     case ast::BinaryOp::kShiftLeft:
732       out << "<<";
733       break;
734     case ast::BinaryOp::kShiftRight:
735       out << ">>";
736       break;
737     case ast::BinaryOp::kAdd:
738       out << "+";
739       break;
740     case ast::BinaryOp::kSubtract:
741       out << "-";
742       break;
743     case ast::BinaryOp::kMultiply:
744       out << "*";
745       break;
746     case ast::BinaryOp::kDivide:
747       out << "/";
748       break;
749     case ast::BinaryOp::kModulo:
750       out << "%";
751       break;
752     case ast::BinaryOp::kNone:
753       diagnostics_.add_error(diag::System::Writer,
754                              "missing binary operation type");
755       return false;
756   }
757   out << " ";
758 
759   if (!EmitExpression(out, expr->rhs)) {
760     return false;
761   }
762 
763   out << ")";
764   return true;
765 }
766 
EmitUnaryOp(std::ostream & out,const ast::UnaryOpExpression * expr)767 bool GeneratorImpl::EmitUnaryOp(std::ostream& out,
768                                 const ast::UnaryOpExpression* expr) {
769   switch (expr->op) {
770     case ast::UnaryOp::kAddressOf:
771       out << "&";
772       break;
773     case ast::UnaryOp::kComplement:
774       out << "~";
775       break;
776     case ast::UnaryOp::kIndirection:
777       out << "*";
778       break;
779     case ast::UnaryOp::kNot:
780       out << "!";
781       break;
782     case ast::UnaryOp::kNegation:
783       out << "-";
784       break;
785   }
786   out << "(";
787 
788   if (!EmitExpression(out, expr->expr)) {
789     return false;
790   }
791 
792   out << ")";
793 
794   return true;
795 }
796 
EmitBlock(const ast::BlockStatement * stmt)797 bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) {
798   line() << "{";
799   if (!EmitStatementsWithIndent(stmt->statements)) {
800     return false;
801   }
802   line() << "}";
803 
804   return true;
805 }
806 
EmitStatement(const ast::Statement * stmt)807 bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
808   if (auto* a = stmt->As<ast::AssignmentStatement>()) {
809     return EmitAssign(a);
810   }
811   if (auto* b = stmt->As<ast::BlockStatement>()) {
812     return EmitBlock(b);
813   }
814   if (auto* b = stmt->As<ast::BreakStatement>()) {
815     return EmitBreak(b);
816   }
817   if (auto* c = stmt->As<ast::CallStatement>()) {
818     auto out = line();
819     if (!EmitCall(out, c->expr)) {
820       return false;
821     }
822     out << ";";
823     return true;
824   }
825   if (auto* c = stmt->As<ast::ContinueStatement>()) {
826     return EmitContinue(c);
827   }
828   if (auto* d = stmt->As<ast::DiscardStatement>()) {
829     return EmitDiscard(d);
830   }
831   if (auto* f = stmt->As<ast::FallthroughStatement>()) {
832     return EmitFallthrough(f);
833   }
834   if (auto* i = stmt->As<ast::IfStatement>()) {
835     return EmitIf(i);
836   }
837   if (auto* l = stmt->As<ast::LoopStatement>()) {
838     return EmitLoop(l);
839   }
840   if (auto* l = stmt->As<ast::ForLoopStatement>()) {
841     return EmitForLoop(l);
842   }
843   if (auto* r = stmt->As<ast::ReturnStatement>()) {
844     return EmitReturn(r);
845   }
846   if (auto* s = stmt->As<ast::SwitchStatement>()) {
847     return EmitSwitch(s);
848   }
849   if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
850     return EmitVariable(line(), v->variable);
851   }
852 
853   diagnostics_.add_error(
854       diag::System::Writer,
855       "unknown statement type: " + std::string(stmt->TypeInfo().name));
856   return false;
857 }
858 
EmitStatements(const ast::StatementList & stmts)859 bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {
860   for (auto* s : stmts) {
861     if (!EmitStatement(s)) {
862       return false;
863     }
864   }
865   return true;
866 }
867 
EmitStatementsWithIndent(const ast::StatementList & stmts)868 bool GeneratorImpl::EmitStatementsWithIndent(const ast::StatementList& stmts) {
869   ScopedIndent si(this);
870   return EmitStatements(stmts);
871 }
872 
EmitAssign(const ast::AssignmentStatement * stmt)873 bool GeneratorImpl::EmitAssign(const ast::AssignmentStatement* stmt) {
874   auto out = line();
875 
876   if (!EmitExpression(out, stmt->lhs)) {
877     return false;
878   }
879 
880   out << " = ";
881 
882   if (!EmitExpression(out, stmt->rhs)) {
883     return false;
884   }
885 
886   out << ";";
887 
888   return true;
889 }
890 
EmitBreak(const ast::BreakStatement *)891 bool GeneratorImpl::EmitBreak(const ast::BreakStatement*) {
892   line() << "break;";
893   return true;
894 }
895 
EmitCase(const ast::CaseStatement * stmt)896 bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) {
897   if (stmt->IsDefault()) {
898     line() << "default: {";
899   } else {
900     auto out = line();
901     out << "case ";
902 
903     bool first = true;
904     for (auto* selector : stmt->selectors) {
905       if (!first) {
906         out << ", ";
907       }
908 
909       first = false;
910       if (!EmitLiteral(out, selector)) {
911         return false;
912       }
913     }
914     out << ": {";
915   }
916 
917   if (!EmitStatementsWithIndent(stmt->body->statements)) {
918     return false;
919   }
920 
921   line() << "}";
922   return true;
923 }
924 
EmitContinue(const ast::ContinueStatement *)925 bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) {
926   line() << "continue;";
927   return true;
928 }
929 
EmitFallthrough(const ast::FallthroughStatement *)930 bool GeneratorImpl::EmitFallthrough(const ast::FallthroughStatement*) {
931   line() << "fallthrough;";
932   return true;
933 }
934 
EmitIf(const ast::IfStatement * stmt)935 bool GeneratorImpl::EmitIf(const ast::IfStatement* stmt) {
936   {
937     auto out = line();
938     out << "if (";
939     if (!EmitExpression(out, stmt->condition)) {
940       return false;
941     }
942     out << ") {";
943   }
944 
945   if (!EmitStatementsWithIndent(stmt->body->statements)) {
946     return false;
947   }
948 
949   for (auto* e : stmt->else_statements) {
950     if (e->condition) {
951       auto out = line();
952       out << "} elseif (";
953       if (!EmitExpression(out, e->condition)) {
954         return false;
955       }
956       out << ") {";
957     } else {
958       line() << "} else {";
959     }
960 
961     if (!EmitStatementsWithIndent(e->body->statements)) {
962       return false;
963     }
964   }
965 
966   line() << "}";
967 
968   return true;
969 }
970 
EmitDiscard(const ast::DiscardStatement *)971 bool GeneratorImpl::EmitDiscard(const ast::DiscardStatement*) {
972   line() << "discard;";
973   return true;
974 }
975 
EmitLoop(const ast::LoopStatement * stmt)976 bool GeneratorImpl::EmitLoop(const ast::LoopStatement* stmt) {
977   line() << "loop {";
978   increment_indent();
979 
980   if (!EmitStatements(stmt->body->statements)) {
981     return false;
982   }
983 
984   if (stmt->continuing && !stmt->continuing->Empty()) {
985     line();
986     line() << "continuing {";
987     if (!EmitStatementsWithIndent(stmt->continuing->statements)) {
988       return false;
989     }
990     line() << "}";
991   }
992 
993   decrement_indent();
994   line() << "}";
995 
996   return true;
997 }
998 
EmitForLoop(const ast::ForLoopStatement * stmt)999 bool GeneratorImpl::EmitForLoop(const ast::ForLoopStatement* stmt) {
1000   TextBuffer init_buf;
1001   if (auto* init = stmt->initializer) {
1002     TINT_SCOPED_ASSIGNMENT(current_buffer_, &init_buf);
1003     if (!EmitStatement(init)) {
1004       return false;
1005     }
1006   }
1007 
1008   TextBuffer cont_buf;
1009   if (auto* cont = stmt->continuing) {
1010     TINT_SCOPED_ASSIGNMENT(current_buffer_, &cont_buf);
1011     if (!EmitStatement(cont)) {
1012       return false;
1013     }
1014   }
1015 
1016   {
1017     auto out = line();
1018     out << "for";
1019     {
1020       ScopedParen sp(out);
1021       switch (init_buf.lines.size()) {
1022         case 0:  // No initializer
1023           break;
1024         case 1:  // Single line initializer statement
1025           out << TrimSuffix(init_buf.lines[0].content, ";");
1026           break;
1027         default:  // Block initializer statement
1028           for (size_t i = 1; i < init_buf.lines.size(); i++) {
1029             // Indent all by the first line
1030             init_buf.lines[i].indent += current_buffer_->current_indent;
1031           }
1032           out << TrimSuffix(init_buf.String(), "\n");
1033           break;
1034       }
1035 
1036       out << "; ";
1037 
1038       if (auto* cond = stmt->condition) {
1039         if (!EmitExpression(out, cond)) {
1040           return false;
1041         }
1042       }
1043 
1044       out << "; ";
1045 
1046       switch (cont_buf.lines.size()) {
1047         case 0:  // No continuing
1048           break;
1049         case 1:  // Single line continuing statement
1050           out << TrimSuffix(cont_buf.lines[0].content, ";");
1051           break;
1052         default:  // Block continuing statement
1053           for (size_t i = 1; i < cont_buf.lines.size(); i++) {
1054             // Indent all by the first line
1055             cont_buf.lines[i].indent += current_buffer_->current_indent;
1056           }
1057           out << TrimSuffix(cont_buf.String(), "\n");
1058           break;
1059       }
1060     }
1061     out << " {";
1062   }
1063 
1064   if (!EmitStatementsWithIndent(stmt->body->statements)) {
1065     return false;
1066   }
1067 
1068   line() << "}";
1069 
1070   return true;
1071 }
1072 
EmitReturn(const ast::ReturnStatement * stmt)1073 bool GeneratorImpl::EmitReturn(const ast::ReturnStatement* stmt) {
1074   auto out = line();
1075   out << "return";
1076   if (stmt->value) {
1077     out << " ";
1078     if (!EmitExpression(out, stmt->value)) {
1079       return false;
1080     }
1081   }
1082   out << ";";
1083   return true;
1084 }
1085 
EmitSwitch(const ast::SwitchStatement * stmt)1086 bool GeneratorImpl::EmitSwitch(const ast::SwitchStatement* stmt) {
1087   {
1088     auto out = line();
1089     out << "switch(";
1090     if (!EmitExpression(out, stmt->condition)) {
1091       return false;
1092     }
1093     out << ") {";
1094   }
1095 
1096   {
1097     ScopedIndent si(this);
1098     for (auto* s : stmt->body) {
1099       if (!EmitCase(s)) {
1100         return false;
1101       }
1102     }
1103   }
1104 
1105   line() << "}";
1106   return true;
1107 }
1108 
1109 }  // namespace wgsl
1110 }  // namespace writer
1111 }  // namespace tint
1112