1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/writer/wgsl/generator_impl.h"
16
17 #include <algorithm>
18
19 #include "src/ast/access.h"
20 #include "src/ast/alias.h"
21 #include "src/ast/array.h"
22 #include "src/ast/atomic.h"
23 #include "src/ast/bool.h"
24 #include "src/ast/bool_literal_expression.h"
25 #include "src/ast/call_statement.h"
26 #include "src/ast/depth_texture.h"
27 #include "src/ast/external_texture.h"
28 #include "src/ast/f32.h"
29 #include "src/ast/float_literal_expression.h"
30 #include "src/ast/i32.h"
31 #include "src/ast/internal_decoration.h"
32 #include "src/ast/interpolate_decoration.h"
33 #include "src/ast/invariant_decoration.h"
34 #include "src/ast/matrix.h"
35 #include "src/ast/module.h"
36 #include "src/ast/multisampled_texture.h"
37 #include "src/ast/override_decoration.h"
38 #include "src/ast/pointer.h"
39 #include "src/ast/sampled_texture.h"
40 #include "src/ast/sint_literal_expression.h"
41 #include "src/ast/stage_decoration.h"
42 #include "src/ast/storage_texture.h"
43 #include "src/ast/stride_decoration.h"
44 #include "src/ast/struct_block_decoration.h"
45 #include "src/ast/struct_member_align_decoration.h"
46 #include "src/ast/struct_member_offset_decoration.h"
47 #include "src/ast/struct_member_size_decoration.h"
48 #include "src/ast/type_name.h"
49 #include "src/ast/u32.h"
50 #include "src/ast/uint_literal_expression.h"
51 #include "src/ast/variable_decl_statement.h"
52 #include "src/ast/vector.h"
53 #include "src/ast/void.h"
54 #include "src/ast/workgroup_decoration.h"
55 #include "src/sem/struct.h"
56 #include "src/utils/math.h"
57 #include "src/utils/scoped_assignment.h"
58 #include "src/writer/float_to_string.h"
59
60 namespace tint {
61 namespace writer {
62 namespace wgsl {
63
GeneratorImpl(const Program * program)64 GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {}
65
66 GeneratorImpl::~GeneratorImpl() = default;
67
Generate()68 bool GeneratorImpl::Generate() {
69 // Generate global declarations in the order they appear in the module.
70 for (auto* decl : program_->AST().GlobalDeclarations()) {
71 if (auto* td = decl->As<ast::TypeDecl>()) {
72 if (!EmitTypeDecl(td)) {
73 return false;
74 }
75 } else if (auto* func = decl->As<ast::Function>()) {
76 if (!EmitFunction(func)) {
77 return false;
78 }
79 } else if (auto* var = decl->As<ast::Variable>()) {
80 if (!EmitVariable(line(), var)) {
81 return false;
82 }
83 } else {
84 TINT_UNREACHABLE(Writer, diagnostics_);
85 return false;
86 }
87
88 if (decl != program_->AST().GlobalDeclarations().back()) {
89 line();
90 }
91 }
92
93 return true;
94 }
95
EmitTypeDecl(const ast::TypeDecl * ty)96 bool GeneratorImpl::EmitTypeDecl(const ast::TypeDecl* ty) {
97 if (auto* alias = ty->As<ast::Alias>()) {
98 auto out = line();
99 out << "type " << program_->Symbols().NameFor(alias->name) << " = ";
100 if (!EmitType(out, alias->type)) {
101 return false;
102 }
103 out << ";";
104 } else if (auto* str = ty->As<ast::Struct>()) {
105 if (!EmitStructType(str)) {
106 return false;
107 }
108 } else {
109 diagnostics_.add_error(
110 diag::System::Writer,
111 "unknown declared type: " + std::string(ty->TypeInfo().name));
112 return false;
113 }
114 return true;
115 }
116
EmitExpression(std::ostream & out,const ast::Expression * expr)117 bool GeneratorImpl::EmitExpression(std::ostream& out,
118 const ast::Expression* expr) {
119 if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
120 return EmitIndexAccessor(out, a);
121 }
122 if (auto* b = expr->As<ast::BinaryExpression>()) {
123 return EmitBinary(out, b);
124 }
125 if (auto* b = expr->As<ast::BitcastExpression>()) {
126 return EmitBitcast(out, b);
127 }
128 if (auto* c = expr->As<ast::CallExpression>()) {
129 return EmitCall(out, c);
130 }
131 if (auto* i = expr->As<ast::IdentifierExpression>()) {
132 return EmitIdentifier(out, i);
133 }
134 if (auto* l = expr->As<ast::LiteralExpression>()) {
135 return EmitLiteral(out, l);
136 }
137 if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
138 return EmitMemberAccessor(out, m);
139 }
140 if (expr->Is<ast::PhonyExpression>()) {
141 out << "_";
142 return true;
143 }
144 if (auto* u = expr->As<ast::UnaryOpExpression>()) {
145 return EmitUnaryOp(out, u);
146 }
147
148 diagnostics_.add_error(diag::System::Writer, "unknown expression type");
149 return false;
150 }
151
EmitIndexAccessor(std::ostream & out,const ast::IndexAccessorExpression * expr)152 bool GeneratorImpl::EmitIndexAccessor(
153 std::ostream& out,
154 const ast::IndexAccessorExpression* expr) {
155 bool paren_lhs =
156 !expr->object->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression,
157 ast::IdentifierExpression,
158 ast::MemberAccessorExpression>();
159 if (paren_lhs) {
160 out << "(";
161 }
162 if (!EmitExpression(out, expr->object)) {
163 return false;
164 }
165 if (paren_lhs) {
166 out << ")";
167 }
168 out << "[";
169
170 if (!EmitExpression(out, expr->index)) {
171 return false;
172 }
173 out << "]";
174
175 return true;
176 }
177
EmitMemberAccessor(std::ostream & out,const ast::MemberAccessorExpression * expr)178 bool GeneratorImpl::EmitMemberAccessor(
179 std::ostream& out,
180 const ast::MemberAccessorExpression* expr) {
181 bool paren_lhs =
182 !expr->structure->IsAnyOf<ast::IndexAccessorExpression,
183 ast::CallExpression, ast::IdentifierExpression,
184 ast::MemberAccessorExpression>();
185 if (paren_lhs) {
186 out << "(";
187 }
188 if (!EmitExpression(out, expr->structure)) {
189 return false;
190 }
191 if (paren_lhs) {
192 out << ")";
193 }
194
195 out << ".";
196
197 return EmitExpression(out, expr->member);
198 }
199
EmitBitcast(std::ostream & out,const ast::BitcastExpression * expr)200 bool GeneratorImpl::EmitBitcast(std::ostream& out,
201 const ast::BitcastExpression* expr) {
202 out << "bitcast<";
203 if (!EmitType(out, expr->type)) {
204 return false;
205 }
206
207 out << ">(";
208 if (!EmitExpression(out, expr->expr)) {
209 return false;
210 }
211
212 out << ")";
213 return true;
214 }
215
EmitCall(std::ostream & out,const ast::CallExpression * expr)216 bool GeneratorImpl::EmitCall(std::ostream& out,
217 const ast::CallExpression* expr) {
218 if (expr->target.name) {
219 if (!EmitExpression(out, expr->target.name)) {
220 return false;
221 }
222 } else if (expr->target.type) {
223 if (!EmitType(out, expr->target.type)) {
224 return false;
225 }
226 } else {
227 TINT_ICE(Writer, diagnostics_)
228 << "CallExpression target had neither a name or type";
229 return false;
230 }
231 out << "(";
232
233 bool first = true;
234 const auto& args = expr->args;
235 for (auto* arg : args) {
236 if (!first) {
237 out << ", ";
238 }
239 first = false;
240
241 if (!EmitExpression(out, arg)) {
242 return false;
243 }
244 }
245
246 out << ")";
247
248 return true;
249 }
250
EmitLiteral(std::ostream & out,const ast::LiteralExpression * lit)251 bool GeneratorImpl::EmitLiteral(std::ostream& out,
252 const ast::LiteralExpression* lit) {
253 if (auto* bl = lit->As<ast::BoolLiteralExpression>()) {
254 out << (bl->value ? "true" : "false");
255 } else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
256 out << FloatToBitPreservingString(fl->value);
257 } else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
258 out << sl->value;
259 } else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
260 out << ul->value << "u";
261 } else {
262 diagnostics_.add_error(diag::System::Writer, "unknown literal type");
263 return false;
264 }
265 return true;
266 }
267
EmitIdentifier(std::ostream & out,const ast::IdentifierExpression * expr)268 bool GeneratorImpl::EmitIdentifier(std::ostream& out,
269 const ast::IdentifierExpression* expr) {
270 out << program_->Symbols().NameFor(expr->symbol);
271 return true;
272 }
273
EmitFunction(const ast::Function * func)274 bool GeneratorImpl::EmitFunction(const ast::Function* func) {
275 if (func->decorations.size()) {
276 if (!EmitDecorations(line(), func->decorations)) {
277 return false;
278 }
279 }
280 {
281 auto out = line();
282 out << "fn " << program_->Symbols().NameFor(func->symbol) << "(";
283
284 bool first = true;
285 for (auto* v : func->params) {
286 if (!first) {
287 out << ", ";
288 }
289 first = false;
290
291 if (!v->decorations.empty()) {
292 if (!EmitDecorations(out, v->decorations)) {
293 return false;
294 }
295 out << " ";
296 }
297
298 out << program_->Symbols().NameFor(v->symbol) << " : ";
299
300 if (!EmitType(out, v->type)) {
301 return false;
302 }
303 }
304
305 out << ")";
306
307 if (!func->return_type->Is<ast::Void>() ||
308 !func->return_type_decorations.empty()) {
309 out << " -> ";
310
311 if (!func->return_type_decorations.empty()) {
312 if (!EmitDecorations(out, func->return_type_decorations)) {
313 return false;
314 }
315 out << " ";
316 }
317
318 if (!EmitType(out, func->return_type)) {
319 return false;
320 }
321 }
322
323 if (func->body) {
324 out << " {";
325 }
326 }
327
328 if (func->body) {
329 if (!EmitStatementsWithIndent(func->body->statements)) {
330 return false;
331 }
332 line() << "}";
333 }
334
335 return true;
336 }
337
EmitImageFormat(std::ostream & out,const ast::ImageFormat fmt)338 bool GeneratorImpl::EmitImageFormat(std::ostream& out,
339 const ast::ImageFormat fmt) {
340 switch (fmt) {
341 case ast::ImageFormat::kNone:
342 diagnostics_.add_error(diag::System::Writer, "unknown image format");
343 return false;
344 default:
345 out << fmt;
346 }
347 return true;
348 }
349
EmitAccess(std::ostream & out,const ast::Access access)350 bool GeneratorImpl::EmitAccess(std::ostream& out, const ast::Access access) {
351 switch (access) {
352 case ast::Access::kRead:
353 out << "read";
354 return true;
355 case ast::Access::kWrite:
356 out << "write";
357 return true;
358 case ast::Access::kReadWrite:
359 out << "read_write";
360 return true;
361 default:
362 break;
363 }
364 diagnostics_.add_error(diag::System::Writer, "unknown access");
365 return false;
366 }
367
EmitType(std::ostream & out,const ast::Type * ty)368 bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
369 if (auto* ary = ty->As<ast::Array>()) {
370 for (auto* deco : ary->decorations) {
371 if (auto* stride = deco->As<ast::StrideDecoration>()) {
372 out << "[[stride(" << stride->stride << ")]] ";
373 }
374 }
375
376 out << "array<";
377 if (!EmitType(out, ary->type)) {
378 return false;
379 }
380
381 if (!ary->IsRuntimeArray()) {
382 out << ", ";
383 if (!EmitExpression(out, ary->count)) {
384 return false;
385 }
386 }
387
388 out << ">";
389 } else if (ty->Is<ast::Bool>()) {
390 out << "bool";
391 } else if (ty->Is<ast::F32>()) {
392 out << "f32";
393 } else if (ty->Is<ast::I32>()) {
394 out << "i32";
395 } else if (auto* mat = ty->As<ast::Matrix>()) {
396 out << "mat" << mat->columns << "x" << mat->rows << "<";
397 if (!EmitType(out, mat->type)) {
398 return false;
399 }
400 out << ">";
401 } else if (auto* ptr = ty->As<ast::Pointer>()) {
402 out << "ptr<" << ptr->storage_class << ", ";
403 if (!EmitType(out, ptr->type)) {
404 return false;
405 }
406 if (ptr->access != ast::Access::kUndefined) {
407 out << ", ";
408 if (!EmitAccess(out, ptr->access)) {
409 return false;
410 }
411 }
412 out << ">";
413 } else if (auto* atomic = ty->As<ast::Atomic>()) {
414 out << "atomic<";
415 if (!EmitType(out, atomic->type)) {
416 return false;
417 }
418 out << ">";
419 } else if (auto* sampler = ty->As<ast::Sampler>()) {
420 out << "sampler";
421
422 if (sampler->IsComparison()) {
423 out << "_comparison";
424 }
425 } else if (ty->Is<ast::ExternalTexture>()) {
426 out << "texture_external";
427 } else if (auto* texture = ty->As<ast::Texture>()) {
428 out << "texture_";
429 if (texture->Is<ast::DepthTexture>()) {
430 out << "depth_";
431 } else if (texture->Is<ast::DepthMultisampledTexture>()) {
432 out << "depth_multisampled_";
433 } else if (texture->Is<ast::SampledTexture>()) {
434 /* nothing to emit */
435 } else if (texture->Is<ast::MultisampledTexture>()) {
436 out << "multisampled_";
437 } else if (texture->Is<ast::StorageTexture>()) {
438 out << "storage_";
439 } else {
440 diagnostics_.add_error(diag::System::Writer, "unknown texture type");
441 return false;
442 }
443
444 switch (texture->dim) {
445 case ast::TextureDimension::k1d:
446 out << "1d";
447 break;
448 case ast::TextureDimension::k2d:
449 out << "2d";
450 break;
451 case ast::TextureDimension::k2dArray:
452 out << "2d_array";
453 break;
454 case ast::TextureDimension::k3d:
455 out << "3d";
456 break;
457 case ast::TextureDimension::kCube:
458 out << "cube";
459 break;
460 case ast::TextureDimension::kCubeArray:
461 out << "cube_array";
462 break;
463 default:
464 diagnostics_.add_error(diag::System::Writer,
465 "unknown texture dimension");
466 return false;
467 }
468
469 if (auto* sampled = texture->As<ast::SampledTexture>()) {
470 out << "<";
471 if (!EmitType(out, sampled->type)) {
472 return false;
473 }
474 out << ">";
475 } else if (auto* ms = texture->As<ast::MultisampledTexture>()) {
476 out << "<";
477 if (!EmitType(out, ms->type)) {
478 return false;
479 }
480 out << ">";
481 } else if (auto* storage = texture->As<ast::StorageTexture>()) {
482 out << "<";
483 if (!EmitImageFormat(out, storage->format)) {
484 return false;
485 }
486 out << ", ";
487 if (!EmitAccess(out, storage->access)) {
488 return false;
489 }
490 out << ">";
491 }
492
493 } else if (ty->Is<ast::U32>()) {
494 out << "u32";
495 } else if (auto* vec = ty->As<ast::Vector>()) {
496 out << "vec" << vec->width << "<";
497 if (!EmitType(out, vec->type)) {
498 return false;
499 }
500 out << ">";
501 } else if (ty->Is<ast::Void>()) {
502 out << "void";
503 } else if (auto* tn = ty->As<ast::TypeName>()) {
504 out << program_->Symbols().NameFor(tn->name);
505 } else {
506 diagnostics_.add_error(
507 diag::System::Writer,
508 "unknown type in EmitType: " + std::string(ty->TypeInfo().name));
509 return false;
510 }
511 return true;
512 }
513
EmitStructType(const ast::Struct * str)514 bool GeneratorImpl::EmitStructType(const ast::Struct* str) {
515 if (str->decorations.size()) {
516 if (!EmitDecorations(line(), str->decorations)) {
517 return false;
518 }
519 }
520 line() << "struct " << program_->Symbols().NameFor(str->name) << " {";
521
522 auto add_padding = [&](uint32_t size) {
523 line() << "[[size(" << size << ")]]";
524
525 // Note: u32 is the smallest primitive we currently support. When WGSL
526 // supports smaller types, this will need to be updated.
527 line() << UniqueIdentifier("padding") << " : u32;";
528 };
529
530 increment_indent();
531 uint32_t offset = 0;
532 for (auto* mem : str->members) {
533 // TODO(crbug.com/tint/798) move the [[offset]] decoration handling to the
534 // transform::Wgsl sanitizer.
535 if (auto* mem_sem = program_->Sem().Get(mem)) {
536 offset = utils::RoundUp(mem_sem->Align(), offset);
537 if (uint32_t padding = mem_sem->Offset() - offset) {
538 add_padding(padding);
539 offset += padding;
540 }
541 offset += mem_sem->Size();
542 }
543
544 // Offset decorations no longer exist in the WGSL spec, but are emitted
545 // by the SPIR-V reader and are consumed by the Resolver(). These should not
546 // be emitted, but instead struct padding fields should be emitted.
547 ast::DecorationList decorations_sanitized;
548 decorations_sanitized.reserve(mem->decorations.size());
549 for (auto* deco : mem->decorations) {
550 if (!deco->Is<ast::StructMemberOffsetDecoration>()) {
551 decorations_sanitized.emplace_back(deco);
552 }
553 }
554
555 if (!decorations_sanitized.empty()) {
556 if (!EmitDecorations(line(), decorations_sanitized)) {
557 return false;
558 }
559 }
560
561 auto out = line();
562 out << program_->Symbols().NameFor(mem->symbol) << " : ";
563 if (!EmitType(out, mem->type)) {
564 return false;
565 }
566 out << ";";
567 }
568 decrement_indent();
569
570 line() << "};";
571 return true;
572 }
573
EmitVariable(std::ostream & out,const ast::Variable * var)574 bool GeneratorImpl::EmitVariable(std::ostream& out, const ast::Variable* var) {
575 if (!var->decorations.empty()) {
576 if (!EmitDecorations(out, var->decorations)) {
577 return false;
578 }
579 out << " ";
580 }
581
582 if (var->is_const) {
583 out << "let";
584 } else {
585 out << "var";
586 auto sc = var->declared_storage_class;
587 auto ac = var->declared_access;
588 if (sc != ast::StorageClass::kNone || ac != ast::Access::kUndefined) {
589 out << "<" << sc;
590 if (ac != ast::Access::kUndefined) {
591 out << ", ";
592 if (!EmitAccess(out, ac)) {
593 return false;
594 }
595 }
596 out << ">";
597 }
598 }
599
600 out << " " << program_->Symbols().NameFor(var->symbol);
601
602 if (auto* ty = var->type) {
603 out << " : ";
604 if (!EmitType(out, ty)) {
605 return false;
606 }
607 }
608
609 if (var->constructor != nullptr) {
610 out << " = ";
611 if (!EmitExpression(out, var->constructor)) {
612 return false;
613 }
614 }
615 out << ";";
616
617 return true;
618 }
619
EmitDecorations(std::ostream & out,const ast::DecorationList & decos)620 bool GeneratorImpl::EmitDecorations(std::ostream& out,
621 const ast::DecorationList& decos) {
622 out << "[[";
623 bool first = true;
624 for (auto* deco : decos) {
625 if (!first) {
626 out << ", ";
627 }
628 first = false;
629
630 if (auto* workgroup = deco->As<ast::WorkgroupDecoration>()) {
631 auto values = workgroup->Values();
632 out << "workgroup_size(";
633 for (int i = 0; i < 3; i++) {
634 if (values[i]) {
635 if (i > 0) {
636 out << ", ";
637 }
638 if (!EmitExpression(out, values[i])) {
639 return false;
640 }
641 }
642 }
643 out << ")";
644 } else if (deco->Is<ast::StructBlockDecoration>()) {
645 out << "block";
646 } else if (auto* stage = deco->As<ast::StageDecoration>()) {
647 out << "stage(" << stage->stage << ")";
648 } else if (auto* binding = deco->As<ast::BindingDecoration>()) {
649 out << "binding(" << binding->value << ")";
650 } else if (auto* group = deco->As<ast::GroupDecoration>()) {
651 out << "group(" << group->value << ")";
652 } else if (auto* location = deco->As<ast::LocationDecoration>()) {
653 out << "location(" << location->value << ")";
654 } else if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
655 out << "builtin(" << builtin->builtin << ")";
656 } else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
657 out << "interpolate(" << interpolate->type;
658 if (interpolate->sampling != ast::InterpolationSampling::kNone) {
659 out << ", " << interpolate->sampling;
660 }
661 out << ")";
662 } else if (deco->Is<ast::InvariantDecoration>()) {
663 out << "invariant";
664 } else if (auto* override_deco = deco->As<ast::OverrideDecoration>()) {
665 out << "override";
666 if (override_deco->has_value) {
667 out << "(" << override_deco->value << ")";
668 }
669 } else if (auto* size = deco->As<ast::StructMemberSizeDecoration>()) {
670 out << "size(" << size->size << ")";
671 } else if (auto* align = deco->As<ast::StructMemberAlignDecoration>()) {
672 out << "align(" << align->align << ")";
673 } else if (auto* stride = deco->As<ast::StrideDecoration>()) {
674 out << "stride(" << stride->stride << ")";
675 } else if (auto* internal = deco->As<ast::InternalDecoration>()) {
676 out << "internal(" << internal->InternalName() << ")";
677 } else {
678 TINT_ICE(Writer, diagnostics_)
679 << "Unsupported decoration '" << deco->TypeInfo().name << "'";
680 return false;
681 }
682 }
683 out << "]]";
684
685 return true;
686 }
687
EmitBinary(std::ostream & out,const ast::BinaryExpression * expr)688 bool GeneratorImpl::EmitBinary(std::ostream& out,
689 const ast::BinaryExpression* expr) {
690 out << "(";
691
692 if (!EmitExpression(out, expr->lhs)) {
693 return false;
694 }
695 out << " ";
696
697 switch (expr->op) {
698 case ast::BinaryOp::kAnd:
699 out << "&";
700 break;
701 case ast::BinaryOp::kOr:
702 out << "|";
703 break;
704 case ast::BinaryOp::kXor:
705 out << "^";
706 break;
707 case ast::BinaryOp::kLogicalAnd:
708 out << "&&";
709 break;
710 case ast::BinaryOp::kLogicalOr:
711 out << "||";
712 break;
713 case ast::BinaryOp::kEqual:
714 out << "==";
715 break;
716 case ast::BinaryOp::kNotEqual:
717 out << "!=";
718 break;
719 case ast::BinaryOp::kLessThan:
720 out << "<";
721 break;
722 case ast::BinaryOp::kGreaterThan:
723 out << ">";
724 break;
725 case ast::BinaryOp::kLessThanEqual:
726 out << "<=";
727 break;
728 case ast::BinaryOp::kGreaterThanEqual:
729 out << ">=";
730 break;
731 case ast::BinaryOp::kShiftLeft:
732 out << "<<";
733 break;
734 case ast::BinaryOp::kShiftRight:
735 out << ">>";
736 break;
737 case ast::BinaryOp::kAdd:
738 out << "+";
739 break;
740 case ast::BinaryOp::kSubtract:
741 out << "-";
742 break;
743 case ast::BinaryOp::kMultiply:
744 out << "*";
745 break;
746 case ast::BinaryOp::kDivide:
747 out << "/";
748 break;
749 case ast::BinaryOp::kModulo:
750 out << "%";
751 break;
752 case ast::BinaryOp::kNone:
753 diagnostics_.add_error(diag::System::Writer,
754 "missing binary operation type");
755 return false;
756 }
757 out << " ";
758
759 if (!EmitExpression(out, expr->rhs)) {
760 return false;
761 }
762
763 out << ")";
764 return true;
765 }
766
EmitUnaryOp(std::ostream & out,const ast::UnaryOpExpression * expr)767 bool GeneratorImpl::EmitUnaryOp(std::ostream& out,
768 const ast::UnaryOpExpression* expr) {
769 switch (expr->op) {
770 case ast::UnaryOp::kAddressOf:
771 out << "&";
772 break;
773 case ast::UnaryOp::kComplement:
774 out << "~";
775 break;
776 case ast::UnaryOp::kIndirection:
777 out << "*";
778 break;
779 case ast::UnaryOp::kNot:
780 out << "!";
781 break;
782 case ast::UnaryOp::kNegation:
783 out << "-";
784 break;
785 }
786 out << "(";
787
788 if (!EmitExpression(out, expr->expr)) {
789 return false;
790 }
791
792 out << ")";
793
794 return true;
795 }
796
EmitBlock(const ast::BlockStatement * stmt)797 bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) {
798 line() << "{";
799 if (!EmitStatementsWithIndent(stmt->statements)) {
800 return false;
801 }
802 line() << "}";
803
804 return true;
805 }
806
EmitStatement(const ast::Statement * stmt)807 bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
808 if (auto* a = stmt->As<ast::AssignmentStatement>()) {
809 return EmitAssign(a);
810 }
811 if (auto* b = stmt->As<ast::BlockStatement>()) {
812 return EmitBlock(b);
813 }
814 if (auto* b = stmt->As<ast::BreakStatement>()) {
815 return EmitBreak(b);
816 }
817 if (auto* c = stmt->As<ast::CallStatement>()) {
818 auto out = line();
819 if (!EmitCall(out, c->expr)) {
820 return false;
821 }
822 out << ";";
823 return true;
824 }
825 if (auto* c = stmt->As<ast::ContinueStatement>()) {
826 return EmitContinue(c);
827 }
828 if (auto* d = stmt->As<ast::DiscardStatement>()) {
829 return EmitDiscard(d);
830 }
831 if (auto* f = stmt->As<ast::FallthroughStatement>()) {
832 return EmitFallthrough(f);
833 }
834 if (auto* i = stmt->As<ast::IfStatement>()) {
835 return EmitIf(i);
836 }
837 if (auto* l = stmt->As<ast::LoopStatement>()) {
838 return EmitLoop(l);
839 }
840 if (auto* l = stmt->As<ast::ForLoopStatement>()) {
841 return EmitForLoop(l);
842 }
843 if (auto* r = stmt->As<ast::ReturnStatement>()) {
844 return EmitReturn(r);
845 }
846 if (auto* s = stmt->As<ast::SwitchStatement>()) {
847 return EmitSwitch(s);
848 }
849 if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
850 return EmitVariable(line(), v->variable);
851 }
852
853 diagnostics_.add_error(
854 diag::System::Writer,
855 "unknown statement type: " + std::string(stmt->TypeInfo().name));
856 return false;
857 }
858
EmitStatements(const ast::StatementList & stmts)859 bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {
860 for (auto* s : stmts) {
861 if (!EmitStatement(s)) {
862 return false;
863 }
864 }
865 return true;
866 }
867
EmitStatementsWithIndent(const ast::StatementList & stmts)868 bool GeneratorImpl::EmitStatementsWithIndent(const ast::StatementList& stmts) {
869 ScopedIndent si(this);
870 return EmitStatements(stmts);
871 }
872
EmitAssign(const ast::AssignmentStatement * stmt)873 bool GeneratorImpl::EmitAssign(const ast::AssignmentStatement* stmt) {
874 auto out = line();
875
876 if (!EmitExpression(out, stmt->lhs)) {
877 return false;
878 }
879
880 out << " = ";
881
882 if (!EmitExpression(out, stmt->rhs)) {
883 return false;
884 }
885
886 out << ";";
887
888 return true;
889 }
890
EmitBreak(const ast::BreakStatement *)891 bool GeneratorImpl::EmitBreak(const ast::BreakStatement*) {
892 line() << "break;";
893 return true;
894 }
895
EmitCase(const ast::CaseStatement * stmt)896 bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) {
897 if (stmt->IsDefault()) {
898 line() << "default: {";
899 } else {
900 auto out = line();
901 out << "case ";
902
903 bool first = true;
904 for (auto* selector : stmt->selectors) {
905 if (!first) {
906 out << ", ";
907 }
908
909 first = false;
910 if (!EmitLiteral(out, selector)) {
911 return false;
912 }
913 }
914 out << ": {";
915 }
916
917 if (!EmitStatementsWithIndent(stmt->body->statements)) {
918 return false;
919 }
920
921 line() << "}";
922 return true;
923 }
924
EmitContinue(const ast::ContinueStatement *)925 bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) {
926 line() << "continue;";
927 return true;
928 }
929
EmitFallthrough(const ast::FallthroughStatement *)930 bool GeneratorImpl::EmitFallthrough(const ast::FallthroughStatement*) {
931 line() << "fallthrough;";
932 return true;
933 }
934
EmitIf(const ast::IfStatement * stmt)935 bool GeneratorImpl::EmitIf(const ast::IfStatement* stmt) {
936 {
937 auto out = line();
938 out << "if (";
939 if (!EmitExpression(out, stmt->condition)) {
940 return false;
941 }
942 out << ") {";
943 }
944
945 if (!EmitStatementsWithIndent(stmt->body->statements)) {
946 return false;
947 }
948
949 for (auto* e : stmt->else_statements) {
950 if (e->condition) {
951 auto out = line();
952 out << "} elseif (";
953 if (!EmitExpression(out, e->condition)) {
954 return false;
955 }
956 out << ") {";
957 } else {
958 line() << "} else {";
959 }
960
961 if (!EmitStatementsWithIndent(e->body->statements)) {
962 return false;
963 }
964 }
965
966 line() << "}";
967
968 return true;
969 }
970
EmitDiscard(const ast::DiscardStatement *)971 bool GeneratorImpl::EmitDiscard(const ast::DiscardStatement*) {
972 line() << "discard;";
973 return true;
974 }
975
EmitLoop(const ast::LoopStatement * stmt)976 bool GeneratorImpl::EmitLoop(const ast::LoopStatement* stmt) {
977 line() << "loop {";
978 increment_indent();
979
980 if (!EmitStatements(stmt->body->statements)) {
981 return false;
982 }
983
984 if (stmt->continuing && !stmt->continuing->Empty()) {
985 line();
986 line() << "continuing {";
987 if (!EmitStatementsWithIndent(stmt->continuing->statements)) {
988 return false;
989 }
990 line() << "}";
991 }
992
993 decrement_indent();
994 line() << "}";
995
996 return true;
997 }
998
EmitForLoop(const ast::ForLoopStatement * stmt)999 bool GeneratorImpl::EmitForLoop(const ast::ForLoopStatement* stmt) {
1000 TextBuffer init_buf;
1001 if (auto* init = stmt->initializer) {
1002 TINT_SCOPED_ASSIGNMENT(current_buffer_, &init_buf);
1003 if (!EmitStatement(init)) {
1004 return false;
1005 }
1006 }
1007
1008 TextBuffer cont_buf;
1009 if (auto* cont = stmt->continuing) {
1010 TINT_SCOPED_ASSIGNMENT(current_buffer_, &cont_buf);
1011 if (!EmitStatement(cont)) {
1012 return false;
1013 }
1014 }
1015
1016 {
1017 auto out = line();
1018 out << "for";
1019 {
1020 ScopedParen sp(out);
1021 switch (init_buf.lines.size()) {
1022 case 0: // No initializer
1023 break;
1024 case 1: // Single line initializer statement
1025 out << TrimSuffix(init_buf.lines[0].content, ";");
1026 break;
1027 default: // Block initializer statement
1028 for (size_t i = 1; i < init_buf.lines.size(); i++) {
1029 // Indent all by the first line
1030 init_buf.lines[i].indent += current_buffer_->current_indent;
1031 }
1032 out << TrimSuffix(init_buf.String(), "\n");
1033 break;
1034 }
1035
1036 out << "; ";
1037
1038 if (auto* cond = stmt->condition) {
1039 if (!EmitExpression(out, cond)) {
1040 return false;
1041 }
1042 }
1043
1044 out << "; ";
1045
1046 switch (cont_buf.lines.size()) {
1047 case 0: // No continuing
1048 break;
1049 case 1: // Single line continuing statement
1050 out << TrimSuffix(cont_buf.lines[0].content, ";");
1051 break;
1052 default: // Block continuing statement
1053 for (size_t i = 1; i < cont_buf.lines.size(); i++) {
1054 // Indent all by the first line
1055 cont_buf.lines[i].indent += current_buffer_->current_indent;
1056 }
1057 out << TrimSuffix(cont_buf.String(), "\n");
1058 break;
1059 }
1060 }
1061 out << " {";
1062 }
1063
1064 if (!EmitStatementsWithIndent(stmt->body->statements)) {
1065 return false;
1066 }
1067
1068 line() << "}";
1069
1070 return true;
1071 }
1072
EmitReturn(const ast::ReturnStatement * stmt)1073 bool GeneratorImpl::EmitReturn(const ast::ReturnStatement* stmt) {
1074 auto out = line();
1075 out << "return";
1076 if (stmt->value) {
1077 out << " ";
1078 if (!EmitExpression(out, stmt->value)) {
1079 return false;
1080 }
1081 }
1082 out << ";";
1083 return true;
1084 }
1085
EmitSwitch(const ast::SwitchStatement * stmt)1086 bool GeneratorImpl::EmitSwitch(const ast::SwitchStatement* stmt) {
1087 {
1088 auto out = line();
1089 out << "switch(";
1090 if (!EmitExpression(out, stmt->condition)) {
1091 return false;
1092 }
1093 out << ") {";
1094 }
1095
1096 {
1097 ScopedIndent si(this);
1098 for (auto* s : stmt->body) {
1099 if (!EmitCase(s)) {
1100 return false;
1101 }
1102 }
1103 }
1104
1105 line() << "}";
1106 return true;
1107 }
1108
1109 } // namespace wgsl
1110 } // namespace writer
1111 } // namespace tint
1112