• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /// Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/writer/glsl/generator_impl.h"
16 
17 #include <algorithm>
18 #include <cmath>
19 #include <iomanip>
20 #include <set>
21 #include <utility>
22 #include <vector>
23 
24 #include "src/ast/call_statement.h"
25 #include "src/ast/fallthrough_statement.h"
26 #include "src/ast/internal_decoration.h"
27 #include "src/ast/interpolate_decoration.h"
28 #include "src/ast/override_decoration.h"
29 #include "src/ast/variable_decl_statement.h"
30 #include "src/debug.h"
31 #include "src/sem/array.h"
32 #include "src/sem/atomic_type.h"
33 #include "src/sem/block_statement.h"
34 #include "src/sem/call.h"
35 #include "src/sem/depth_multisampled_texture_type.h"
36 #include "src/sem/depth_texture_type.h"
37 #include "src/sem/function.h"
38 #include "src/sem/member_accessor_expression.h"
39 #include "src/sem/multisampled_texture_type.h"
40 #include "src/sem/sampled_texture_type.h"
41 #include "src/sem/statement.h"
42 #include "src/sem/storage_texture_type.h"
43 #include "src/sem/struct.h"
44 #include "src/sem/type_constructor.h"
45 #include "src/sem/type_conversion.h"
46 #include "src/sem/variable.h"
47 #include "src/transform/calculate_array_length.h"
48 #include "src/transform/glsl.h"
49 #include "src/utils/defer.h"
50 #include "src/utils/map.h"
51 #include "src/utils/scoped_assignment.h"
52 #include "src/writer/append_vector.h"
53 #include "src/writer/float_to_string.h"
54 
55 namespace {
56 
IsRelational(tint::ast::BinaryOp op)57 bool IsRelational(tint::ast::BinaryOp op) {
58   return op == tint::ast::BinaryOp::kEqual ||
59          op == tint::ast::BinaryOp::kNotEqual ||
60          op == tint::ast::BinaryOp::kLessThan ||
61          op == tint::ast::BinaryOp::kGreaterThan ||
62          op == tint::ast::BinaryOp::kLessThanEqual ||
63          op == tint::ast::BinaryOp::kGreaterThanEqual;
64 }
65 
66 }  // namespace
67 
68 namespace tint {
69 namespace writer {
70 namespace glsl {
71 namespace {
72 
73 const char kTempNamePrefix[] = "tint_tmp";
74 const char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
75 
last_is_break_or_fallthrough(const ast::BlockStatement * stmts)76 bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
77   return IsAnyOf<ast::BreakStatement, ast::FallthroughStatement>(stmts->Last());
78 }
79 
80 }  // namespace
81 
GeneratorImpl(const Program * program)82 GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {}
83 
84 GeneratorImpl::~GeneratorImpl() = default;
85 
Generate()86 bool GeneratorImpl::Generate() {
87   if (!builder_.HasTransformApplied<transform::Glsl>()) {
88     diagnostics_.add_error(
89         diag::System::Writer,
90         "GLSL writer requires the transform::Glsl sanitizer to have been "
91         "applied to the input program");
92     return false;
93   }
94 
95   const TypeInfo* last_kind = nullptr;
96   size_t last_padding_line = 0;
97 
98   line() << "#version 310 es";
99   line() << "precision mediump float;";
100 
101   auto helpers_insertion_point = current_buffer_->lines.size();
102 
103   line();
104 
105   for (auto* decl : builder_.AST().GlobalDeclarations()) {
106     if (decl->Is<ast::Alias>()) {
107       continue;  // Ignore aliases.
108     }
109 
110     // Emit a new line between declarations if the type of declaration has
111     // changed, or we're about to emit a function
112     auto* kind = &decl->TypeInfo();
113     if (current_buffer_->lines.size() != last_padding_line) {
114       if (last_kind && (last_kind != kind || decl->Is<ast::Function>())) {
115         line();
116         last_padding_line = current_buffer_->lines.size();
117       }
118     }
119     last_kind = kind;
120 
121     if (auto* global = decl->As<ast::Variable>()) {
122       if (!EmitGlobalVariable(global)) {
123         return false;
124       }
125     } else if (auto* str = decl->As<ast::Struct>()) {
126       if (!str->IsBlockDecorated()) {
127         if (!EmitStructType(current_buffer_, builder_.Sem().Get(str))) {
128           return false;
129         }
130       }
131     } else if (auto* func = decl->As<ast::Function>()) {
132       if (func->IsEntryPoint()) {
133         if (!EmitEntryPointFunction(func)) {
134           return false;
135         }
136       } else {
137         if (!EmitFunction(func)) {
138           return false;
139         }
140       }
141     } else {
142       TINT_ICE(Writer, diagnostics_)
143           << "unhandled module-scope declaration: " << decl->TypeInfo().name;
144       return false;
145     }
146   }
147 
148   if (!helpers_.lines.empty()) {
149     current_buffer_->Insert("", helpers_insertion_point++, 0);
150     current_buffer_->Insert(helpers_, helpers_insertion_point++, 0);
151   }
152 
153   return true;
154 }
155 
EmitIndexAccessor(std::ostream & out,const ast::IndexAccessorExpression * expr)156 bool GeneratorImpl::EmitIndexAccessor(
157     std::ostream& out,
158     const ast::IndexAccessorExpression* expr) {
159   if (!EmitExpression(out, expr->object)) {
160     return false;
161   }
162   out << "[";
163 
164   if (!EmitExpression(out, expr->index)) {
165     return false;
166   }
167   out << "]";
168 
169   return true;
170 }
171 
EmitBitcast(std::ostream & out,const ast::BitcastExpression * expr)172 bool GeneratorImpl::EmitBitcast(std::ostream& out,
173                                 const ast::BitcastExpression* expr) {
174   auto* src_type = TypeOf(expr->expr);
175   auto* dst_type = TypeOf(expr);
176 
177   if (!dst_type->is_integer_scalar_or_vector() &&
178       !dst_type->is_float_scalar_or_vector()) {
179     diagnostics_.add_error(
180         diag::System::Writer,
181         "Unable to do bitcast to type " + dst_type->type_name());
182     return false;
183   }
184 
185   if (src_type == dst_type) {
186     return EmitExpression(out, expr->expr);
187   }
188 
189   if (src_type->is_float_scalar_or_vector() &&
190       dst_type->is_signed_scalar_or_vector()) {
191     out << "floatBitsToInt";
192   } else if (src_type->is_float_scalar_or_vector() &&
193              dst_type->is_unsigned_scalar_or_vector()) {
194     out << "floatBitsToUint";
195   } else if (src_type->is_signed_scalar_or_vector() &&
196              dst_type->is_float_scalar_or_vector()) {
197     out << "intBitsToFloat";
198   } else if (src_type->is_unsigned_scalar_or_vector() &&
199              dst_type->is_float_scalar_or_vector()) {
200     out << "uintBitsToFloat";
201   } else {
202     if (!EmitType(out, dst_type, ast::StorageClass::kNone,
203                   ast::Access::kReadWrite, "")) {
204       return false;
205     }
206   }
207   out << "(";
208   if (!EmitExpression(out, expr->expr)) {
209     return false;
210   }
211   out << ")";
212   return true;
213 }
214 
EmitAssign(const ast::AssignmentStatement * stmt)215 bool GeneratorImpl::EmitAssign(const ast::AssignmentStatement* stmt) {
216   auto out = line();
217   if (!EmitExpression(out, stmt->lhs)) {
218     return false;
219   }
220   out << " = ";
221   if (!EmitExpression(out, stmt->rhs)) {
222     return false;
223   }
224   out << ";";
225   return true;
226 }
227 
EmitVectorRelational(std::ostream & out,const ast::BinaryExpression * expr)228 bool GeneratorImpl::EmitVectorRelational(std::ostream& out,
229                                          const ast::BinaryExpression* expr) {
230   switch (expr->op) {
231     case ast::BinaryOp::kEqual:
232       out << "equal";
233       break;
234     case ast::BinaryOp::kNotEqual:
235       out << "notEqual";
236       break;
237     case ast::BinaryOp::kLessThan:
238       out << "lessThan";
239       break;
240     case ast::BinaryOp::kGreaterThan:
241       out << "greaterThan";
242       break;
243     case ast::BinaryOp::kLessThanEqual:
244       out << "lessThanEqual";
245       break;
246     case ast::BinaryOp::kGreaterThanEqual:
247       out << "greaterThanEqual";
248       break;
249     default:
250       break;
251   }
252   out << "(";
253   if (!EmitExpression(out, expr->lhs)) {
254     return false;
255   }
256   out << ", ";
257   if (!EmitExpression(out, expr->rhs)) {
258     return false;
259   }
260   out << ")";
261   return true;
262 }
263 
EmitBinary(std::ostream & out,const ast::BinaryExpression * expr)264 bool GeneratorImpl::EmitBinary(std::ostream& out,
265                                const ast::BinaryExpression* expr) {
266   if (IsRelational(expr->op) && !TypeOf(expr->lhs)->UnwrapRef()->is_scalar()) {
267     return EmitVectorRelational(out, expr);
268   }
269   if (expr->op == ast::BinaryOp::kLogicalAnd ||
270       expr->op == ast::BinaryOp::kLogicalOr) {
271     auto name = UniqueIdentifier(kTempNamePrefix);
272 
273     {
274       auto pre = line();
275       pre << "bool " << name << " = ";
276       if (!EmitExpression(pre, expr->lhs)) {
277         return false;
278       }
279       pre << ";";
280     }
281 
282     if (expr->op == ast::BinaryOp::kLogicalOr) {
283       line() << "if (!" << name << ") {";
284     } else {
285       line() << "if (" << name << ") {";
286     }
287 
288     {
289       ScopedIndent si(this);
290       auto pre = line();
291       pre << name << " = ";
292       if (!EmitExpression(pre, expr->rhs)) {
293         return false;
294       }
295       pre << ";";
296     }
297 
298     line() << "}";
299 
300     out << "(" << name << ")";
301     return true;
302   }
303 
304   out << "(";
305   if (!EmitExpression(out, expr->lhs)) {
306     return false;
307   }
308   out << " ";
309 
310   switch (expr->op) {
311     case ast::BinaryOp::kAnd:
312       out << "&";
313       break;
314     case ast::BinaryOp::kOr:
315       out << "|";
316       break;
317     case ast::BinaryOp::kXor:
318       out << "^";
319       break;
320     case ast::BinaryOp::kLogicalAnd:
321     case ast::BinaryOp::kLogicalOr: {
322       // These are both handled above.
323       TINT_UNREACHABLE(Writer, diagnostics_);
324       return false;
325     }
326     case ast::BinaryOp::kEqual:
327       out << "==";
328       break;
329     case ast::BinaryOp::kNotEqual:
330       out << "!=";
331       break;
332     case ast::BinaryOp::kLessThan:
333       out << "<";
334       break;
335     case ast::BinaryOp::kGreaterThan:
336       out << ">";
337       break;
338     case ast::BinaryOp::kLessThanEqual:
339       out << "<=";
340       break;
341     case ast::BinaryOp::kGreaterThanEqual:
342       out << ">=";
343       break;
344     case ast::BinaryOp::kShiftLeft:
345       out << "<<";
346       break;
347     case ast::BinaryOp::kShiftRight:
348       // TODO(dsinclair): MSL is based on C++14, and >> in C++14 has
349       // implementation-defined behaviour for negative LHS.  We may have to
350       // generate extra code to implement WGSL-specified behaviour for negative
351       // LHS.
352       out << R"(>>)";
353       break;
354 
355     case ast::BinaryOp::kAdd:
356       out << "+";
357       break;
358     case ast::BinaryOp::kSubtract:
359       out << "-";
360       break;
361     case ast::BinaryOp::kMultiply:
362       out << "*";
363       break;
364     case ast::BinaryOp::kDivide:
365       out << "/";
366       break;
367     case ast::BinaryOp::kModulo:
368       out << "%";
369       break;
370     case ast::BinaryOp::kNone:
371       diagnostics_.add_error(diag::System::Writer,
372                              "missing binary operation type");
373       return false;
374   }
375   out << " ";
376 
377   if (!EmitExpression(out, expr->rhs)) {
378     return false;
379   }
380 
381   out << ")";
382   return true;
383 }
384 
EmitStatements(const ast::StatementList & stmts)385 bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {
386   for (auto* s : stmts) {
387     if (!EmitStatement(s)) {
388       return false;
389     }
390   }
391   return true;
392 }
393 
EmitStatementsWithIndent(const ast::StatementList & stmts)394 bool GeneratorImpl::EmitStatementsWithIndent(const ast::StatementList& stmts) {
395   ScopedIndent si(this);
396   return EmitStatements(stmts);
397 }
398 
EmitBlock(const ast::BlockStatement * stmt)399 bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) {
400   line() << "{";
401   if (!EmitStatementsWithIndent(stmt->statements)) {
402     return false;
403   }
404   line() << "}";
405   return true;
406 }
407 
EmitBreak(const ast::BreakStatement *)408 bool GeneratorImpl::EmitBreak(const ast::BreakStatement*) {
409   line() << "break;";
410   return true;
411 }
412 
EmitCall(std::ostream & out,const ast::CallExpression * expr)413 bool GeneratorImpl::EmitCall(std::ostream& out,
414                              const ast::CallExpression* expr) {
415   auto* call = builder_.Sem().Get(expr);
416   auto* target = call->Target();
417 
418   if (auto* func = target->As<sem::Function>()) {
419     return EmitFunctionCall(out, call, func);
420   }
421   if (auto* intrinsic = target->As<sem::Intrinsic>()) {
422     return EmitIntrinsicCall(out, call, intrinsic);
423   }
424   if (auto* cast = target->As<sem::TypeConversion>()) {
425     return EmitTypeConversion(out, call, cast);
426   }
427   if (auto* ctor = target->As<sem::TypeConstructor>()) {
428     return EmitTypeConstructor(out, call, ctor);
429   }
430   TINT_ICE(Writer, diagnostics_)
431       << "unhandled call target: " << target->TypeInfo().name;
432   return false;
433 }
434 
EmitFunctionCall(std::ostream & out,const sem::Call * call,const sem::Function * func)435 bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
436                                      const sem::Call* call,
437                                      const sem::Function* func) {
438   const auto& args = call->Arguments();
439   auto* decl = call->Declaration();
440   auto* ident = decl->target.name;
441 
442   auto name = builder_.Symbols().NameFor(ident->symbol);
443   auto caller_sym = ident->symbol;
444 
445   if (ast::HasDecoration<transform::CalculateArrayLength::BufferSizeIntrinsic>(
446           func->Declaration()->decorations)) {
447     // Special function generated by the CalculateArrayLength transform for
448     // calling X.GetDimensions(Y)
449     if (!EmitExpression(out, args[0]->Declaration())) {
450       return false;
451     }
452     out << ".GetDimensions(";
453     if (!EmitExpression(out, args[1]->Declaration())) {
454       return false;
455     }
456     out << ")";
457     return true;
458   }
459 
460   out << name << "(";
461 
462   bool first = true;
463   for (auto* arg : args) {
464     if (!first) {
465       out << ", ";
466     }
467     first = false;
468 
469     if (!EmitExpression(out, arg->Declaration())) {
470       return false;
471     }
472   }
473 
474   out << ")";
475   return true;
476 }
477 
EmitIntrinsicCall(std::ostream & out,const sem::Call * call,const sem::Intrinsic * intrinsic)478 bool GeneratorImpl::EmitIntrinsicCall(std::ostream& out,
479                                       const sem::Call* call,
480                                       const sem::Intrinsic* intrinsic) {
481   auto* expr = call->Declaration();
482   if (intrinsic->IsTexture()) {
483     return EmitTextureCall(out, call, intrinsic);
484   }
485   if (intrinsic->Type() == sem::IntrinsicType::kSelect) {
486     return EmitSelectCall(out, expr);
487   }
488   if (intrinsic->Type() == sem::IntrinsicType::kDot) {
489     return EmitDotCall(out, expr, intrinsic);
490   }
491   if (intrinsic->Type() == sem::IntrinsicType::kModf) {
492     return EmitModfCall(out, expr, intrinsic);
493   }
494   if (intrinsic->Type() == sem::IntrinsicType::kFrexp) {
495     return EmitFrexpCall(out, expr, intrinsic);
496   }
497   if (intrinsic->Type() == sem::IntrinsicType::kIsNormal) {
498     return EmitIsNormalCall(out, expr, intrinsic);
499   }
500   if (intrinsic->Type() == sem::IntrinsicType::kIgnore) {
501     return EmitExpression(out, expr->args[0]);  // [DEPRECATED]
502   }
503   if (intrinsic->IsDataPacking()) {
504     return EmitDataPackingCall(out, expr, intrinsic);
505   }
506   if (intrinsic->IsDataUnpacking()) {
507     return EmitDataUnpackingCall(out, expr, intrinsic);
508   }
509   if (intrinsic->IsBarrier()) {
510     return EmitBarrierCall(out, intrinsic);
511   }
512   if (intrinsic->IsAtomic()) {
513     return EmitWorkgroupAtomicCall(out, expr, intrinsic);
514   }
515   auto name = generate_builtin_name(intrinsic);
516   if (name.empty()) {
517     return false;
518   }
519 
520   out << name << "(";
521 
522   bool first = true;
523   for (auto* arg : call->Arguments()) {
524     if (!first) {
525       out << ", ";
526     }
527     first = false;
528 
529     if (!EmitExpression(out, arg->Declaration())) {
530       return false;
531     }
532   }
533 
534   out << ")";
535   return true;
536 }
537 
EmitTypeConversion(std::ostream & out,const sem::Call * call,const sem::TypeConversion * conv)538 bool GeneratorImpl::EmitTypeConversion(std::ostream& out,
539                                        const sem::Call* call,
540                                        const sem::TypeConversion* conv) {
541   if (!EmitType(out, conv->Target(), ast::StorageClass::kNone,
542                 ast::Access::kReadWrite, "")) {
543     return false;
544   }
545   out << "(";
546 
547   if (!EmitExpression(out, call->Arguments()[0]->Declaration())) {
548     return false;
549   }
550 
551   out << ")";
552   return true;
553 }
554 
EmitTypeConstructor(std::ostream & out,const sem::Call * call,const sem::TypeConstructor * ctor)555 bool GeneratorImpl::EmitTypeConstructor(std::ostream& out,
556                                         const sem::Call* call,
557                                         const sem::TypeConstructor* ctor) {
558   auto* type = ctor->ReturnType();
559 
560   // If the type constructor is empty then we need to construct with the zero
561   // value for all components.
562   if (call->Arguments().empty()) {
563     return EmitZeroValue(out, type);
564   }
565 
566   auto it = structure_builders_.find(As<sem::Struct>(type));
567   if (it != structure_builders_.end()) {
568     out << it->second << "(";
569   } else {
570     if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
571                   "")) {
572       return false;
573     }
574     out << "(";
575   }
576 
577   bool first = true;
578   for (auto* arg : call->Arguments()) {
579     if (!first) {
580       out << ", ";
581     }
582     first = false;
583 
584     if (!EmitExpression(out, arg->Declaration())) {
585       return false;
586     }
587   }
588 
589   out << ")";
590   return true;
591 }
592 
EmitWorkgroupAtomicCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)593 bool GeneratorImpl::EmitWorkgroupAtomicCall(std::ostream& out,
594                                             const ast::CallExpression* expr,
595                                             const sem::Intrinsic* intrinsic) {
596   auto call = [&](const char* name) {
597     out << name;
598     {
599       ScopedParen sp(out);
600       for (size_t i = 0; i < expr->args.size(); i++) {
601         auto* arg = expr->args[i];
602         if (i > 0) {
603           out << ", ";
604         }
605         if (!EmitExpression(out, arg)) {
606           return false;
607         }
608       }
609     }
610     return true;
611   };
612 
613   switch (intrinsic->Type()) {
614     case sem::IntrinsicType::kAtomicLoad: {
615       // GLSL does not have an atomicLoad, so we emulate it with
616       // atomicOr using 0 as the OR value
617       out << "atomicOr";
618       {
619         ScopedParen sp(out);
620         if (!EmitExpression(out, expr->args[0])) {
621           return false;
622         }
623         out << ", 0";
624         if (intrinsic->ReturnType()->Is<sem::U32>()) {
625           out << "u";
626         }
627       }
628       return true;
629     }
630     case sem::IntrinsicType::kAtomicCompareExchangeWeak: {
631       return CallIntrinsicHelper(
632           out, expr, intrinsic,
633           [&](TextBuffer* b, const std::vector<std::string>& params) {
634             {
635               auto pre = line(b);
636               if (!EmitTypeAndName(pre, intrinsic->ReturnType(),
637                                    ast::StorageClass::kNone,
638                                    ast::Access::kUndefined, "result")) {
639                 return false;
640               }
641               pre << ";";
642             }
643             {
644               auto pre = line(b);
645               pre << "result.x = atomicCompSwap";
646               {
647                 ScopedParen sp(pre);
648                 pre << params[0];
649                 pre << ", " << params[1];
650                 pre << ", " << params[2];
651               }
652               pre << ";";
653             }
654             {
655               auto pre = line(b);
656               pre << "result.y = result.x == " << params[2] << " ? ";
657               if (TypeOf(expr->args[2])->Is<sem::U32>()) {
658                 pre << "1u : 0u;";
659               } else {
660                 pre << "1 : 0;";
661               }
662             }
663             line(b) << "return result;";
664             return true;
665           });
666     }
667 
668     case sem::IntrinsicType::kAtomicAdd:
669     case sem::IntrinsicType::kAtomicSub:
670       return call("atomicAdd");
671 
672     case sem::IntrinsicType::kAtomicMax:
673       return call("atomicMax");
674 
675     case sem::IntrinsicType::kAtomicMin:
676       return call("atomicMin");
677 
678     case sem::IntrinsicType::kAtomicAnd:
679       return call("atomicAnd");
680 
681     case sem::IntrinsicType::kAtomicOr:
682       return call("atomicOr");
683 
684     case sem::IntrinsicType::kAtomicXor:
685       return call("atomicXor");
686 
687     case sem::IntrinsicType::kAtomicExchange:
688     case sem::IntrinsicType::kAtomicStore:
689       // GLSL does not have an atomicStore, so we emulate it with
690       // atomicExchange.
691       return call("atomicExchange");
692 
693     default:
694       break;
695   }
696 
697   TINT_UNREACHABLE(Writer, diagnostics_)
698       << "unsupported atomic intrinsic: " << intrinsic->Type();
699   return false;
700 }
701 
EmitSelectCall(std::ostream & out,const ast::CallExpression * expr)702 bool GeneratorImpl::EmitSelectCall(std::ostream& out,
703                                    const ast::CallExpression* expr) {
704   auto* expr_false = expr->args[0];
705   auto* expr_true = expr->args[1];
706   auto* expr_cond = expr->args[2];
707   ScopedParen paren(out);
708   if (!EmitExpression(out, expr_cond)) {
709     return false;
710   }
711 
712   out << " ? ";
713 
714   if (!EmitExpression(out, expr_true)) {
715     return false;
716   }
717 
718   out << " : ";
719 
720   if (!EmitExpression(out, expr_false)) {
721     return false;
722   }
723 
724   return true;
725 }
726 
EmitDotCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)727 bool GeneratorImpl::EmitDotCall(std::ostream& out,
728                                 const ast::CallExpression* expr,
729                                 const sem::Intrinsic* intrinsic) {
730   auto* vec_ty = intrinsic->Parameters()[0]->Type()->As<sem::Vector>();
731   std::string fn = "dot";
732   if (vec_ty->type()->is_integer_scalar()) {
733     // GLSL does not have a builtin for dot() with integer vector types.
734     // Generate the helper function if it hasn't been created already
735     fn = utils::GetOrCreate(int_dot_funcs_, vec_ty, [&]() -> std::string {
736       TextBuffer b;
737       TINT_DEFER(helpers_.Append(b));
738 
739       auto fn_name = UniqueIdentifier("tint_int_dot");
740 
741       std::string v;
742       {
743         std::stringstream s;
744         if (!EmitType(s, vec_ty->type(), ast::StorageClass::kNone,
745                       ast::Access::kRead, "")) {
746           return "";
747         }
748         v = s.str();
749       }
750       {  // (u)int tint_int_dot([i|u]vecN a, [i|u]vecN b) {
751         auto l = line(&b);
752         if (!EmitType(l, vec_ty->type(), ast::StorageClass::kNone,
753                       ast::Access::kRead, "")) {
754           return "";
755         }
756         l << " " << fn_name << "(";
757         if (!EmitType(l, vec_ty, ast::StorageClass::kNone, ast::Access::kRead,
758                       "")) {
759           return "";
760         }
761         l << " a, ";
762         if (!EmitType(l, vec_ty, ast::StorageClass::kNone, ast::Access::kRead,
763                       "")) {
764           return "";
765         }
766         l << " b) {";
767       }
768       {
769         auto l = line(&b);
770         l << "  return ";
771         for (uint32_t i = 0; i < vec_ty->Width(); i++) {
772           if (i > 0) {
773             l << " + ";
774           }
775           l << "a[" << i << "]*b[" << i << "]";
776         }
777         l << ";";
778       }
779       line(&b) << "}";
780       return fn_name;
781     });
782     if (fn.empty()) {
783       return false;
784     }
785   }
786 
787   out << fn << "(";
788   if (!EmitExpression(out, expr->args[0])) {
789     return false;
790   }
791   out << ", ";
792   if (!EmitExpression(out, expr->args[1])) {
793     return false;
794   }
795   out << ")";
796   return true;
797 }
798 
EmitModfCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)799 bool GeneratorImpl::EmitModfCall(std::ostream& out,
800                                  const ast::CallExpression* expr,
801                                  const sem::Intrinsic* intrinsic) {
802   if (expr->args.size() == 1) {
803     return CallIntrinsicHelper(
804         out, expr, intrinsic,
805         [&](TextBuffer* b, const std::vector<std::string>& params) {
806           auto* ty = intrinsic->Parameters()[0]->Type();
807           auto in = params[0];
808 
809           std::string width;
810           if (auto* vec = ty->As<sem::Vector>()) {
811             width = std::to_string(vec->Width());
812           }
813 
814           // Emit the builtin return type unique to this overload. This does not
815           // exist in the AST, so it will not be generated in Generate().
816           if (!EmitStructType(&helpers_,
817                               intrinsic->ReturnType()->As<sem::Struct>())) {
818             return false;
819           }
820 
821           line(b) << "float" << width << " whole;";
822           line(b) << "float" << width << " fract = modf(" << in << ", whole);";
823           {
824             auto l = line(b);
825             if (!EmitType(l, intrinsic->ReturnType(), ast::StorageClass::kNone,
826                           ast::Access::kUndefined, "")) {
827               return false;
828             }
829             l << " result = {fract, whole};";
830           }
831           line(b) << "return result;";
832           return true;
833         });
834   }
835 
836   // DEPRECATED
837   out << "modf";
838   ScopedParen sp(out);
839   if (!EmitExpression(out, expr->args[0])) {
840     return false;
841   }
842   out << ", ";
843   if (!EmitExpression(out, expr->args[1])) {
844     return false;
845   }
846   return true;
847 }
848 
EmitFrexpCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)849 bool GeneratorImpl::EmitFrexpCall(std::ostream& out,
850                                   const ast::CallExpression* expr,
851                                   const sem::Intrinsic* intrinsic) {
852   if (expr->args.size() == 1) {
853     return CallIntrinsicHelper(
854         out, expr, intrinsic,
855         [&](TextBuffer* b, const std::vector<std::string>& params) {
856           auto* ty = intrinsic->Parameters()[0]->Type();
857           auto in = params[0];
858 
859           std::string width;
860           if (auto* vec = ty->As<sem::Vector>()) {
861             width = std::to_string(vec->Width());
862           }
863 
864           // Emit the builtin return type unique to this overload. This does not
865           // exist in the AST, so it will not be generated in Generate().
866           if (!EmitStructType(&helpers_,
867                               intrinsic->ReturnType()->As<sem::Struct>())) {
868             return false;
869           }
870 
871           line(b) << "float" << width << " exp;";
872           line(b) << "float" << width << " sig = frexp(" << in << ", exp);";
873           {
874             auto l = line(b);
875             if (!EmitType(l, intrinsic->ReturnType(), ast::StorageClass::kNone,
876                           ast::Access::kUndefined, "")) {
877               return false;
878             }
879             l << " result = {sig, int" << width << "(exp)};";
880           }
881           line(b) << "return result;";
882           return true;
883         });
884   }
885   // DEPRECATED
886   // Exponent is an integer in WGSL, but HLSL wants a float.
887   // We need to make the call with a temporary float, and then cast.
888   return CallIntrinsicHelper(
889       out, expr, intrinsic,
890       [&](TextBuffer* b, const std::vector<std::string>& params) {
891         auto* significand_ty = intrinsic->Parameters()[0]->Type();
892         auto significand = params[0];
893         auto* exponent_ty = intrinsic->Parameters()[1]->Type();
894         auto exponent = params[1];
895 
896         std::string width;
897         if (auto* vec = significand_ty->As<sem::Vector>()) {
898           width = std::to_string(vec->Width());
899         }
900 
901         // Exponent is an integer, which HLSL does not have an overload for.
902         // We need to cast from a float.
903         line(b) << "float" << width << " float_exp;";
904         line(b) << "float" << width << " significand = frexp(" << significand
905                 << ", float_exp);";
906         {
907           auto l = line(b);
908           l << exponent << " = ";
909           if (!EmitType(l, exponent_ty->UnwrapPtr(), ast::StorageClass::kNone,
910                         ast::Access::kUndefined, "")) {
911             return false;
912           }
913           l << "(float_exp);";
914         }
915         line(b) << "return significand;";
916         return true;
917       });
918 }
919 
EmitIsNormalCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)920 bool GeneratorImpl::EmitIsNormalCall(std::ostream& out,
921                                      const ast::CallExpression* expr,
922                                      const sem::Intrinsic* intrinsic) {
923   // GLSL doesn't have a isNormal intrinsic, we need to emulate
924   return CallIntrinsicHelper(
925       out, expr, intrinsic,
926       [&](TextBuffer* b, const std::vector<std::string>& params) {
927         auto* input_ty = intrinsic->Parameters()[0]->Type();
928 
929         std::string width;
930         if (auto* vec = input_ty->As<sem::Vector>()) {
931           width = std::to_string(vec->Width());
932         }
933 
934         constexpr auto* kExponentMask = "0x7f80000";
935         constexpr auto* kMinNormalExponent = "0x0080000";
936         constexpr auto* kMaxNormalExponent = "0x7f00000";
937 
938         line(b) << "uint" << width << " exponent = asuint(" << params[0]
939                 << ") & " << kExponentMask << ";";
940         line(b) << "uint" << width << " clamped = "
941                 << "clamp(exponent, " << kMinNormalExponent << ", "
942                 << kMaxNormalExponent << ");";
943         line(b) << "return clamped == exponent;";
944         return true;
945       });
946 }
947 
EmitDataPackingCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)948 bool GeneratorImpl::EmitDataPackingCall(std::ostream& out,
949                                         const ast::CallExpression* expr,
950                                         const sem::Intrinsic* intrinsic) {
951   return CallIntrinsicHelper(
952       out, expr, intrinsic,
953       [&](TextBuffer* b, const std::vector<std::string>& params) {
954         uint32_t dims = 2;
955         bool is_signed = false;
956         uint32_t scale = 65535;
957         if (intrinsic->Type() == sem::IntrinsicType::kPack4x8snorm ||
958             intrinsic->Type() == sem::IntrinsicType::kPack4x8unorm) {
959           dims = 4;
960           scale = 255;
961         }
962         if (intrinsic->Type() == sem::IntrinsicType::kPack4x8snorm ||
963             intrinsic->Type() == sem::IntrinsicType::kPack2x16snorm) {
964           is_signed = true;
965           scale = (scale - 1) / 2;
966         }
967         switch (intrinsic->Type()) {
968           case sem::IntrinsicType::kPack4x8snorm:
969           case sem::IntrinsicType::kPack4x8unorm:
970           case sem::IntrinsicType::kPack2x16snorm:
971           case sem::IntrinsicType::kPack2x16unorm: {
972             {
973               auto l = line(b);
974               l << (is_signed ? "" : "u") << "int" << dims
975                 << " i = " << (is_signed ? "" : "u") << "int" << dims
976                 << "(round(clamp(" << params[0] << ", "
977                 << (is_signed ? "-1.0" : "0.0") << ", 1.0) * " << scale
978                 << ".0))";
979               if (is_signed) {
980                 l << " & " << (dims == 4 ? "0xff" : "0xffff");
981               }
982               l << ";";
983             }
984             {
985               auto l = line(b);
986               l << "return ";
987               if (is_signed) {
988                 l << "asuint";
989               }
990               l << "(i.x | i.y << " << (32 / dims);
991               if (dims == 4) {
992                 l << " | i.z << 16 | i.w << 24";
993               }
994               l << ");";
995             }
996             break;
997           }
998           case sem::IntrinsicType::kPack2x16float: {
999             line(b) << "uint2 i = f32tof16(" << params[0] << ");";
1000             line(b) << "return i.x | (i.y << 16);";
1001             break;
1002           }
1003           default:
1004             diagnostics_.add_error(
1005                 diag::System::Writer,
1006                 "Internal error: unhandled data packing intrinsic");
1007             return false;
1008         }
1009 
1010         return true;
1011       });
1012 }
1013 
EmitDataUnpackingCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1014 bool GeneratorImpl::EmitDataUnpackingCall(std::ostream& out,
1015                                           const ast::CallExpression* expr,
1016                                           const sem::Intrinsic* intrinsic) {
1017   return CallIntrinsicHelper(
1018       out, expr, intrinsic,
1019       [&](TextBuffer* b, const std::vector<std::string>& params) {
1020         uint32_t dims = 2;
1021         bool is_signed = false;
1022         uint32_t scale = 65535;
1023         if (intrinsic->Type() == sem::IntrinsicType::kUnpack4x8snorm ||
1024             intrinsic->Type() == sem::IntrinsicType::kUnpack4x8unorm) {
1025           dims = 4;
1026           scale = 255;
1027         }
1028         if (intrinsic->Type() == sem::IntrinsicType::kUnpack4x8snorm ||
1029             intrinsic->Type() == sem::IntrinsicType::kUnpack2x16snorm) {
1030           is_signed = true;
1031           scale = (scale - 1) / 2;
1032         }
1033         switch (intrinsic->Type()) {
1034           case sem::IntrinsicType::kUnpack4x8snorm:
1035           case sem::IntrinsicType::kUnpack2x16snorm: {
1036             line(b) << "int j = int(" << params[0] << ");";
1037             {  // Perform sign extension on the converted values.
1038               auto l = line(b);
1039               l << "int" << dims << " i = int" << dims << "(";
1040               if (dims == 2) {
1041                 l << "j << 16, j) >> 16";
1042               } else {
1043                 l << "j << 24, j << 16, j << 8, j) >> 24";
1044               }
1045               l << ";";
1046             }
1047             line(b) << "return clamp(float" << dims << "(i) / " << scale
1048                     << ".0, " << (is_signed ? "-1.0" : "0.0") << ", 1.0);";
1049             break;
1050           }
1051           case sem::IntrinsicType::kUnpack4x8unorm:
1052           case sem::IntrinsicType::kUnpack2x16unorm: {
1053             line(b) << "uint j = " << params[0] << ";";
1054             {
1055               auto l = line(b);
1056               l << "uint" << dims << " i = uint" << dims << "(";
1057               l << "j & " << (dims == 2 ? "0xffff" : "0xff") << ", ";
1058               if (dims == 4) {
1059                 l << "(j >> " << (32 / dims)
1060                   << ") & 0xff, (j >> 16) & 0xff, j >> 24";
1061               } else {
1062                 l << "j >> " << (32 / dims);
1063               }
1064               l << ");";
1065             }
1066             line(b) << "return float" << dims << "(i) / " << scale << ".0;";
1067             break;
1068           }
1069           case sem::IntrinsicType::kUnpack2x16float:
1070             line(b) << "uint i = " << params[0] << ";";
1071             line(b) << "return f16tof32(uint2(i & 0xffff, i >> 16));";
1072             break;
1073           default:
1074             diagnostics_.add_error(
1075                 diag::System::Writer,
1076                 "Internal error: unhandled data packing intrinsic");
1077             return false;
1078         }
1079 
1080         return true;
1081       });
1082 }
1083 
EmitBarrierCall(std::ostream & out,const sem::Intrinsic * intrinsic)1084 bool GeneratorImpl::EmitBarrierCall(std::ostream& out,
1085                                     const sem::Intrinsic* intrinsic) {
1086   // TODO(crbug.com/tint/661): Combine sequential barriers to a single
1087   // instruction.
1088   if (intrinsic->Type() == sem::IntrinsicType::kWorkgroupBarrier) {
1089     out << "memoryBarrierShared()";
1090   } else if (intrinsic->Type() == sem::IntrinsicType::kStorageBarrier) {
1091     out << "memoryBarrierBuffer()";
1092   } else {
1093     TINT_UNREACHABLE(Writer, diagnostics_)
1094         << "unexpected barrier intrinsic type " << sem::str(intrinsic->Type());
1095     return false;
1096   }
1097   return true;
1098 }
1099 
EmitTextureCall(std::ostream & out,const sem::Call * call,const sem::Intrinsic * intrinsic)1100 bool GeneratorImpl::EmitTextureCall(std::ostream& out,
1101                                     const sem::Call* call,
1102                                     const sem::Intrinsic* intrinsic) {
1103   using Usage = sem::ParameterUsage;
1104 
1105   auto& signature = intrinsic->Signature();
1106   auto* expr = call->Declaration();
1107   auto arguments = expr->args;
1108 
1109   // Returns the argument with the given usage
1110   auto arg = [&](Usage usage) {
1111     int idx = signature.IndexOf(usage);
1112     return (idx >= 0) ? arguments[idx] : nullptr;
1113   };
1114 
1115   auto* texture = arg(Usage::kTexture);
1116   if (!texture) {
1117     TINT_ICE(Writer, diagnostics_) << "missing texture argument";
1118     return false;
1119   }
1120 
1121   auto* texture_type = TypeOf(texture)->UnwrapRef()->As<sem::Texture>();
1122 
1123   switch (intrinsic->Type()) {
1124     case sem::IntrinsicType::kTextureDimensions: {
1125       if (texture_type->Is<sem::StorageTexture>()) {
1126         out << "imageSize(";
1127       } else {
1128         out << "textureSize(";
1129       }
1130       if (!EmitExpression(out, texture)) {
1131         return false;
1132       }
1133 
1134       // The LOD parameter is mandatory on textureSize() for non-multisampled
1135       // textures.
1136       if (!texture_type->Is<sem::StorageTexture>() &&
1137           !texture_type->Is<sem::MultisampledTexture>() &&
1138           !texture_type->Is<sem::DepthMultisampledTexture>()) {
1139         out << ", ";
1140         if (auto* level_arg = arg(Usage::kLevel)) {
1141           if (!EmitExpression(out, level_arg)) {
1142             return false;
1143           }
1144         } else {
1145           out << "0";
1146         }
1147       }
1148       out << ")";
1149       // textureSize() on sampler2dArray returns the array size in the
1150       // final component, so strip it out.
1151       if (texture_type->dim() == ast::TextureDimension::k2dArray) {
1152         out << ".xy";
1153       }
1154       return true;
1155     }
1156     // TODO(senorblanco): determine if this works for array textures
1157     case sem::IntrinsicType::kTextureNumLayers:
1158     case sem::IntrinsicType::kTextureNumLevels: {
1159       out << "textureQueryLevels(";
1160       if (!EmitExpression(out, texture)) {
1161         return false;
1162       }
1163       out << ");";
1164       return true;
1165     }
1166     case sem::IntrinsicType::kTextureNumSamples: {
1167       out << "textureSamples(";
1168       if (!EmitExpression(out, texture)) {
1169         return false;
1170       }
1171       out << ");";
1172       return true;
1173     }
1174     default:
1175       break;
1176   }
1177 
1178   uint32_t glsl_ret_width = 4u;
1179 
1180   switch (intrinsic->Type()) {
1181     case sem::IntrinsicType::kTextureSample:
1182     case sem::IntrinsicType::kTextureSampleBias:
1183       out << "texture(";
1184       break;
1185     case sem::IntrinsicType::kTextureSampleLevel:
1186       out << "textureLod(";
1187       break;
1188     case sem::IntrinsicType::kTextureGather:
1189     case sem::IntrinsicType::kTextureGatherCompare:
1190       out << (intrinsic->Signature().IndexOf(sem::ParameterUsage::kOffset) < 0
1191                   ? "textureGather("
1192                   : "textureGatherOffset(");
1193       break;
1194     case sem::IntrinsicType::kTextureSampleGrad:
1195       out << "textureGrad(";
1196       break;
1197     case sem::IntrinsicType::kTextureSampleCompare:
1198       out << "texture(";
1199       glsl_ret_width = 1;
1200       break;
1201     case sem::IntrinsicType::kTextureSampleCompareLevel:
1202       out << "texture(";
1203       glsl_ret_width = 1;
1204       break;
1205     case sem::IntrinsicType::kTextureLoad:
1206       out << "texelFetch(";
1207       break;
1208     case sem::IntrinsicType::kTextureStore:
1209       out << "imageStore(";
1210       break;
1211     default:
1212       diagnostics_.add_error(
1213           diag::System::Writer,
1214           "Internal compiler error: Unhandled texture intrinsic '" +
1215               std::string(intrinsic->str()) + "'");
1216       return false;
1217   }
1218 
1219   if (!EmitExpression(out, texture))
1220     return false;
1221 
1222   out << ", ";
1223 
1224   auto* param_coords = arg(Usage::kCoords);
1225   if (!param_coords) {
1226     TINT_ICE(Writer, diagnostics_) << "missing coords argument";
1227     return false;
1228   }
1229 
1230   if (auto* array_index = arg(Usage::kArrayIndex)) {
1231     // Array index needs to be appended to the coordinates.
1232     auto* packed = AppendVector(&builder_, param_coords, array_index);
1233     if (!EmitExpression(out, packed->Declaration())) {
1234       return false;
1235     }
1236   } else {
1237     if (!EmitExpression(out, param_coords)) {
1238       return false;
1239     }
1240   }
1241 
1242   for (auto usage : {Usage::kDepthRef, Usage::kBias, Usage::kLevel, Usage::kDdx,
1243                      Usage::kDdy, Usage::kSampleIndex, Usage::kOffset,
1244                      Usage::kComponent, Usage::kValue}) {
1245     if (auto* e = arg(usage)) {
1246       out << ", ";
1247       if (!EmitExpression(out, e)) {
1248         return false;
1249       }
1250     }
1251   }
1252 
1253   out << ")";
1254 
1255   if (intrinsic->ReturnType()->Is<sem::Void>()) {
1256     return true;
1257   }
1258   // If the intrinsic return type does not match the number of elements of the
1259   // GLSL intrinsic, we need to swizzle the expression to generate the correct
1260   // number of components.
1261   uint32_t wgsl_ret_width = 1;
1262   if (auto* vec = intrinsic->ReturnType()->As<sem::Vector>()) {
1263     wgsl_ret_width = vec->Width();
1264   }
1265   if (wgsl_ret_width < glsl_ret_width) {
1266     out << ".";
1267     for (uint32_t i = 0; i < wgsl_ret_width; i++) {
1268       out << "xyz"[i];
1269     }
1270   }
1271   if (wgsl_ret_width > glsl_ret_width) {
1272     TINT_ICE(Writer, diagnostics_)
1273         << "WGSL return width (" << wgsl_ret_width
1274         << ") is wider than GLSL return width (" << glsl_ret_width << ") for "
1275         << intrinsic->Type();
1276     return false;
1277   }
1278 
1279   return true;
1280 }
1281 
generate_builtin_name(const sem::Intrinsic * intrinsic)1282 std::string GeneratorImpl::generate_builtin_name(
1283     const sem::Intrinsic* intrinsic) {
1284   switch (intrinsic->Type()) {
1285     case sem::IntrinsicType::kAbs:
1286     case sem::IntrinsicType::kAcos:
1287     case sem::IntrinsicType::kAll:
1288     case sem::IntrinsicType::kAny:
1289     case sem::IntrinsicType::kAsin:
1290     case sem::IntrinsicType::kAtan:
1291     case sem::IntrinsicType::kCeil:
1292     case sem::IntrinsicType::kClamp:
1293     case sem::IntrinsicType::kCos:
1294     case sem::IntrinsicType::kCosh:
1295     case sem::IntrinsicType::kCross:
1296     case sem::IntrinsicType::kDeterminant:
1297     case sem::IntrinsicType::kDistance:
1298     case sem::IntrinsicType::kDot:
1299     case sem::IntrinsicType::kExp:
1300     case sem::IntrinsicType::kExp2:
1301     case sem::IntrinsicType::kFloor:
1302     case sem::IntrinsicType::kFrexp:
1303     case sem::IntrinsicType::kLdexp:
1304     case sem::IntrinsicType::kLength:
1305     case sem::IntrinsicType::kLog:
1306     case sem::IntrinsicType::kLog2:
1307     case sem::IntrinsicType::kMax:
1308     case sem::IntrinsicType::kMin:
1309     case sem::IntrinsicType::kModf:
1310     case sem::IntrinsicType::kNormalize:
1311     case sem::IntrinsicType::kPow:
1312     case sem::IntrinsicType::kReflect:
1313     case sem::IntrinsicType::kRefract:
1314     case sem::IntrinsicType::kRound:
1315     case sem::IntrinsicType::kSign:
1316     case sem::IntrinsicType::kSin:
1317     case sem::IntrinsicType::kSinh:
1318     case sem::IntrinsicType::kSqrt:
1319     case sem::IntrinsicType::kStep:
1320     case sem::IntrinsicType::kTan:
1321     case sem::IntrinsicType::kTanh:
1322     case sem::IntrinsicType::kTranspose:
1323     case sem::IntrinsicType::kTrunc:
1324       return intrinsic->str();
1325     case sem::IntrinsicType::kAtan2:
1326       return "atan";
1327     case sem::IntrinsicType::kCountOneBits:
1328       return "countbits";
1329     case sem::IntrinsicType::kDpdx:
1330       return "ddx";
1331     case sem::IntrinsicType::kDpdxCoarse:
1332       return "ddx_coarse";
1333     case sem::IntrinsicType::kDpdxFine:
1334       return "ddx_fine";
1335     case sem::IntrinsicType::kDpdy:
1336       return "ddy";
1337     case sem::IntrinsicType::kDpdyCoarse:
1338       return "ddy_coarse";
1339     case sem::IntrinsicType::kDpdyFine:
1340       return "ddy_fine";
1341     case sem::IntrinsicType::kFaceForward:
1342       return "faceforward";
1343     case sem::IntrinsicType::kFract:
1344       return "frac";
1345     case sem::IntrinsicType::kFma:
1346       return "mad";
1347     case sem::IntrinsicType::kFwidth:
1348     case sem::IntrinsicType::kFwidthCoarse:
1349     case sem::IntrinsicType::kFwidthFine:
1350       return "fwidth";
1351     case sem::IntrinsicType::kInverseSqrt:
1352       return "rsqrt";
1353     case sem::IntrinsicType::kIsFinite:
1354       return "isfinite";
1355     case sem::IntrinsicType::kIsInf:
1356       return "isinf";
1357     case sem::IntrinsicType::kIsNan:
1358       return "isnan";
1359     case sem::IntrinsicType::kMix:
1360       return "mix";
1361     case sem::IntrinsicType::kReverseBits:
1362       return "reversebits";
1363     case sem::IntrinsicType::kSmoothStep:
1364       return "smoothstep";
1365     default:
1366       diagnostics_.add_error(
1367           diag::System::Writer,
1368           "Unknown builtin method: " + std::string(intrinsic->str()));
1369   }
1370 
1371   return "";
1372 }
1373 
EmitCase(const ast::CaseStatement * stmt)1374 bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) {
1375   if (stmt->IsDefault()) {
1376     line() << "default: {";
1377   } else {
1378     for (auto* selector : stmt->selectors) {
1379       auto out = line();
1380       out << "case ";
1381       if (!EmitLiteral(out, selector)) {
1382         return false;
1383       }
1384       out << ":";
1385       if (selector == stmt->selectors.back()) {
1386         out << " {";
1387       }
1388     }
1389   }
1390 
1391   {
1392     ScopedIndent si(this);
1393     if (!EmitStatements(stmt->body->statements)) {
1394       return false;
1395     }
1396     if (!last_is_break_or_fallthrough(stmt->body)) {
1397       line() << "break;";
1398     }
1399   }
1400 
1401   line() << "}";
1402 
1403   return true;
1404 }
1405 
EmitContinue(const ast::ContinueStatement *)1406 bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) {
1407   if (!emit_continuing_()) {
1408     return false;
1409   }
1410   line() << "continue;";
1411   return true;
1412 }
1413 
EmitDiscard(const ast::DiscardStatement *)1414 bool GeneratorImpl::EmitDiscard(const ast::DiscardStatement*) {
1415   // TODO(dsinclair): Verify this is correct when the discard semantics are
1416   // defined for WGSL (https://github.com/gpuweb/gpuweb/issues/361)
1417   line() << "discard;";
1418   return true;
1419 }
1420 
EmitExpression(std::ostream & out,const ast::Expression * expr)1421 bool GeneratorImpl::EmitExpression(std::ostream& out,
1422                                    const ast::Expression* expr) {
1423   if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
1424     return EmitIndexAccessor(out, a);
1425   }
1426   if (auto* b = expr->As<ast::BinaryExpression>()) {
1427     return EmitBinary(out, b);
1428   }
1429   if (auto* b = expr->As<ast::BitcastExpression>()) {
1430     return EmitBitcast(out, b);
1431   }
1432   if (auto* c = expr->As<ast::CallExpression>()) {
1433     return EmitCall(out, c);
1434   }
1435   if (auto* i = expr->As<ast::IdentifierExpression>()) {
1436     return EmitIdentifier(out, i);
1437   }
1438   if (auto* l = expr->As<ast::LiteralExpression>()) {
1439     return EmitLiteral(out, l);
1440   }
1441   if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
1442     return EmitMemberAccessor(out, m);
1443   }
1444   if (auto* u = expr->As<ast::UnaryOpExpression>()) {
1445     return EmitUnaryOp(out, u);
1446   }
1447 
1448   diagnostics_.add_error(
1449       diag::System::Writer,
1450       "unknown expression type: " + std::string(expr->TypeInfo().name));
1451   return false;
1452 }
1453 
EmitIdentifier(std::ostream & out,const ast::IdentifierExpression * expr)1454 bool GeneratorImpl::EmitIdentifier(std::ostream& out,
1455                                    const ast::IdentifierExpression* expr) {
1456   out << builder_.Symbols().NameFor(expr->symbol);
1457   return true;
1458 }
1459 
EmitIf(const ast::IfStatement * stmt)1460 bool GeneratorImpl::EmitIf(const ast::IfStatement* stmt) {
1461   {
1462     auto out = line();
1463     out << "if (";
1464     if (!EmitExpression(out, stmt->condition)) {
1465       return false;
1466     }
1467     out << ") {";
1468   }
1469 
1470   if (!EmitStatementsWithIndent(stmt->body->statements)) {
1471     return false;
1472   }
1473 
1474   for (auto* e : stmt->else_statements) {
1475     if (e->condition) {
1476       line() << "} else {";
1477       increment_indent();
1478 
1479       {
1480         auto out = line();
1481         out << "if (";
1482         if (!EmitExpression(out, e->condition)) {
1483           return false;
1484         }
1485         out << ") {";
1486       }
1487     } else {
1488       line() << "} else {";
1489     }
1490 
1491     if (!EmitStatementsWithIndent(e->body->statements)) {
1492       return false;
1493     }
1494   }
1495 
1496   line() << "}";
1497 
1498   for (auto* e : stmt->else_statements) {
1499     if (e->condition) {
1500       decrement_indent();
1501       line() << "}";
1502     }
1503   }
1504   return true;
1505 }
1506 
EmitFunction(const ast::Function * func)1507 bool GeneratorImpl::EmitFunction(const ast::Function* func) {
1508   auto* sem = builder_.Sem().Get(func);
1509 
1510   if (ast::HasDecoration<ast::InternalDecoration>(func->decorations)) {
1511     // An internal function. Do not emit.
1512     return true;
1513   }
1514 
1515   {
1516     auto out = line();
1517     auto name = builder_.Symbols().NameFor(func->symbol);
1518     if (!EmitType(out, sem->ReturnType(), ast::StorageClass::kNone,
1519                   ast::Access::kReadWrite, "")) {
1520       return false;
1521     }
1522 
1523     out << " " << name << "(";
1524 
1525     bool first = true;
1526 
1527     for (auto* v : sem->Parameters()) {
1528       if (!first) {
1529         out << ", ";
1530       }
1531       first = false;
1532 
1533       auto const* type = v->Type();
1534 
1535       if (auto* ptr = type->As<sem::Pointer>()) {
1536         // Transform pointer parameters in to `inout` parameters.
1537         // The WGSL spec is highly restrictive in what can be passed in pointer
1538         // parameters, which allows for this transformation. See:
1539         // https://gpuweb.github.io/gpuweb/wgsl/#function-restriction
1540         out << "inout ";
1541         type = ptr->StoreType();
1542       }
1543 
1544       // Note: WGSL only allows for StorageClass::kNone on parameters, however
1545       // the sanitizer transforms generates load / store functions for storage
1546       // or uniform buffers. These functions have a buffer parameter with
1547       // StorageClass::kStorage or StorageClass::kUniform. This is required to
1548       // correctly translate the parameter to a [RW]ByteAddressBuffer for
1549       // storage buffers and a uint4[N] for uniform buffers.
1550       if (!EmitTypeAndName(
1551               out, type, v->StorageClass(), v->Access(),
1552               builder_.Symbols().NameFor(v->Declaration()->symbol))) {
1553         return false;
1554       }
1555     }
1556     out << ") {";
1557   }
1558 
1559   if (!EmitStatementsWithIndent(func->body->statements)) {
1560     return false;
1561   }
1562 
1563   line() << "}";
1564 
1565   return true;
1566 }
1567 
EmitGlobalVariable(const ast::Variable * global)1568 bool GeneratorImpl::EmitGlobalVariable(const ast::Variable* global) {
1569   if (global->is_const) {
1570     return EmitProgramConstVariable(global);
1571   }
1572 
1573   auto* sem = builder_.Sem().Get(global);
1574   switch (sem->StorageClass()) {
1575     case ast::StorageClass::kUniform:
1576       return EmitUniformVariable(sem);
1577     case ast::StorageClass::kStorage:
1578       return EmitStorageVariable(sem);
1579     case ast::StorageClass::kUniformConstant:
1580       return EmitHandleVariable(sem);
1581     case ast::StorageClass::kPrivate:
1582       return EmitPrivateVariable(sem);
1583     case ast::StorageClass::kWorkgroup:
1584       return EmitWorkgroupVariable(sem);
1585     default:
1586       break;
1587   }
1588 
1589   TINT_ICE(Writer, diagnostics_)
1590       << "unhandled storage class " << sem->StorageClass();
1591   return false;
1592 }
1593 
EmitUniformVariable(const sem::Variable * var)1594 bool GeneratorImpl::EmitUniformVariable(const sem::Variable* var) {
1595   auto* decl = var->Declaration();
1596   auto* type = var->Type()->UnwrapRef();
1597   auto* str = type->As<sem::Struct>();
1598   if (!str) {
1599     TINT_ICE(Writer, builder_.Diagnostics())
1600         << "storage variable must be of struct type";
1601     return false;
1602   }
1603   ast::VariableBindingPoint bp = decl->BindingPoint();
1604   line() << "layout (binding = " << bp.binding->value << ") uniform "
1605          << UniqueIdentifier(StructName(str)) << " {";
1606   EmitStructMembers(current_buffer_, str);
1607   auto name = builder_.Symbols().NameFor(decl->symbol);
1608   line() << "} " << name << ";";
1609 
1610   return true;
1611 }
1612 
EmitStorageVariable(const sem::Variable * var)1613 bool GeneratorImpl::EmitStorageVariable(const sem::Variable* var) {
1614   auto* decl = var->Declaration();
1615   auto* type = var->Type()->UnwrapRef();
1616   auto* str = type->As<sem::Struct>();
1617   if (!str) {
1618     TINT_ICE(Writer, builder_.Diagnostics())
1619         << "storage variable must be of struct type";
1620     return false;
1621   }
1622   ast::VariableBindingPoint bp = decl->BindingPoint();
1623   line() << "layout (binding = " << bp.binding->value << ") buffer "
1624          << UniqueIdentifier(StructName(str)) << " {";
1625   EmitStructMembers(current_buffer_, str);
1626   auto name = builder_.Symbols().NameFor(decl->symbol);
1627   line() << "} " << name << ";";
1628   return true;
1629 }
1630 
EmitHandleVariable(const sem::Variable * var)1631 bool GeneratorImpl::EmitHandleVariable(const sem::Variable* var) {
1632   auto* decl = var->Declaration();
1633   auto out = line();
1634 
1635   auto name = builder_.Symbols().NameFor(decl->symbol);
1636   auto* type = var->Type()->UnwrapRef();
1637   if (type->As<sem::Sampler>()) {
1638     // GLSL ignores Sampler variables.
1639     return true;
1640   }
1641   if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) {
1642     return false;
1643   }
1644 
1645   out << ";";
1646   return true;
1647 }
1648 
EmitPrivateVariable(const sem::Variable * var)1649 bool GeneratorImpl::EmitPrivateVariable(const sem::Variable* var) {
1650   auto* decl = var->Declaration();
1651   auto out = line();
1652 
1653   auto name = builder_.Symbols().NameFor(decl->symbol);
1654   auto* type = var->Type()->UnwrapRef();
1655   if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) {
1656     return false;
1657   }
1658 
1659   out << " = ";
1660   if (auto* constructor = decl->constructor) {
1661     if (!EmitExpression(out, constructor)) {
1662       return false;
1663     }
1664   } else {
1665     if (!EmitZeroValue(out, var->Type()->UnwrapRef())) {
1666       return false;
1667     }
1668   }
1669 
1670   out << ";";
1671   return true;
1672 }
1673 
EmitWorkgroupVariable(const sem::Variable * var)1674 bool GeneratorImpl::EmitWorkgroupVariable(const sem::Variable* var) {
1675   auto* decl = var->Declaration();
1676   auto out = line();
1677 
1678   out << "shared ";
1679 
1680   auto name = builder_.Symbols().NameFor(decl->symbol);
1681   auto* type = var->Type()->UnwrapRef();
1682   if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) {
1683     return false;
1684   }
1685 
1686   if (auto* constructor = decl->constructor) {
1687     out << " = ";
1688     if (!EmitExpression(out, constructor)) {
1689       return false;
1690     }
1691   }
1692 
1693   out << ";";
1694   return true;
1695 }
1696 
builtin_type(ast::Builtin builtin)1697 sem::Type* GeneratorImpl::builtin_type(ast::Builtin builtin) {
1698   switch (builtin) {
1699     case ast::Builtin::kPosition: {
1700       auto* f32 = builder_.create<sem::F32>();
1701       return builder_.create<sem::Vector>(f32, 4);
1702     }
1703     case ast::Builtin::kVertexIndex:
1704     case ast::Builtin::kInstanceIndex: {
1705       return builder_.create<sem::I32>();
1706     }
1707     case ast::Builtin::kFrontFacing: {
1708       return builder_.create<sem::Bool>();
1709     }
1710     case ast::Builtin::kFragDepth: {
1711       return builder_.create<sem::F32>();
1712     }
1713     case ast::Builtin::kLocalInvocationId:
1714     case ast::Builtin::kGlobalInvocationId:
1715     case ast::Builtin::kWorkgroupId: {
1716       auto* u32 = builder_.create<sem::U32>();
1717       return builder_.create<sem::Vector>(u32, 3);
1718     }
1719     case ast::Builtin::kSampleIndex: {
1720       return builder_.create<sem::I32>();
1721     }
1722     case ast::Builtin::kSampleMask:
1723     default:
1724       return nullptr;
1725   }
1726 }
1727 
builtin_to_string(ast::Builtin builtin,ast::PipelineStage stage)1728 const char* GeneratorImpl::builtin_to_string(ast::Builtin builtin,
1729                                              ast::PipelineStage stage) {
1730   switch (builtin) {
1731     case ast::Builtin::kPosition:
1732       switch (stage) {
1733         case ast::PipelineStage::kVertex:
1734           return "gl_Position";
1735         case ast::PipelineStage::kFragment:
1736           return "gl_FragCoord";
1737         default:
1738           TINT_ICE(Writer, builder_.Diagnostics())
1739               << "position builtin unexpected in this pipeline stage";
1740           return "";
1741       }
1742     case ast::Builtin::kVertexIndex:
1743       return "gl_VertexID";
1744     case ast::Builtin::kInstanceIndex:
1745       return "gl_InstanceID";
1746     case ast::Builtin::kFrontFacing:
1747       return "gl_FrontFacing";
1748     case ast::Builtin::kFragDepth:
1749       return "gl_FragDepth";
1750     case ast::Builtin::kLocalInvocationId:
1751       return "gl_LocalInvocationID";
1752     case ast::Builtin::kLocalInvocationIndex:
1753       return "gl_LocalInvocationIndex";
1754     case ast::Builtin::kGlobalInvocationId:
1755       return "gl_GlobalInvocationID";
1756     case ast::Builtin::kWorkgroupId:
1757       return "gl_WorkGroupID";
1758     case ast::Builtin::kSampleIndex:
1759       return "gl_SampleID";
1760     case ast::Builtin::kSampleMask:
1761       // FIXME: is this always available?
1762       return "gl_SampleMask";
1763     default:
1764       return "";
1765   }
1766 }
1767 
interpolation_to_modifiers(ast::InterpolationType type,ast::InterpolationSampling sampling) const1768 std::string GeneratorImpl::interpolation_to_modifiers(
1769     ast::InterpolationType type,
1770     ast::InterpolationSampling sampling) const {
1771   std::string modifiers;
1772   switch (type) {
1773     case ast::InterpolationType::kPerspective:
1774       modifiers += "linear ";
1775       break;
1776     case ast::InterpolationType::kLinear:
1777       modifiers += "noperspective ";
1778       break;
1779     case ast::InterpolationType::kFlat:
1780       modifiers += "nointerpolation ";
1781       break;
1782   }
1783   switch (sampling) {
1784     case ast::InterpolationSampling::kCentroid:
1785       modifiers += "centroid ";
1786       break;
1787     case ast::InterpolationSampling::kSample:
1788       modifiers += "sample ";
1789       break;
1790     case ast::InterpolationSampling::kCenter:
1791     case ast::InterpolationSampling::kNone:
1792       break;
1793   }
1794   return modifiers;
1795 }
1796 
EmitEntryPointFunction(const ast::Function * func)1797 bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
1798   auto* func_sem = builder_.Sem().Get(func);
1799 
1800   {
1801     auto out = line();
1802     if (func->PipelineStage() == ast::PipelineStage::kCompute) {
1803       // Emit the layout(local_size) attributes.
1804       auto wgsize = func_sem->WorkgroupSize();
1805       out << "layout(";
1806       for (int i = 0; i < 3; i++) {
1807         if (i > 0) {
1808           out << ", ";
1809         }
1810         out << "local_size_" << (i == 0 ? "x" : i == 1 ? "y" : "z") << " = ";
1811 
1812         if (wgsize[i].overridable_const) {
1813           auto* global = builder_.Sem().Get<sem::GlobalVariable>(
1814               wgsize[i].overridable_const);
1815           if (!global->IsOverridable()) {
1816             TINT_ICE(Writer, builder_.Diagnostics())
1817                 << "expected a pipeline-overridable constant";
1818           }
1819           out << kSpecConstantPrefix << global->ConstantId();
1820         } else {
1821           out << std::to_string(wgsize[i].value);
1822         }
1823       }
1824       out << ") in;" << std::endl;
1825     }
1826 
1827     out << func->return_type->FriendlyName(builder_.Symbols());
1828 
1829     out << " " << builder_.Symbols().NameFor(func->symbol) << "(";
1830 
1831     bool first = true;
1832 
1833     // Emit entry point parameters.
1834     for (auto* var : func->params) {
1835       auto* sem = builder_.Sem().Get(var);
1836       auto* type = sem->Type();
1837       if (!type->Is<sem::Struct>()) {
1838         // ICE likely indicates that the CanonicalizeEntryPointIO transform was
1839         // not run, or a builtin parameter was added after it was run.
1840         TINT_ICE(Writer, diagnostics_)
1841             << "Unsupported non-struct entry point parameter";
1842       }
1843 
1844       if (!first) {
1845         out << ", ";
1846       }
1847       first = false;
1848 
1849       if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
1850                            builder_.Symbols().NameFor(var->symbol))) {
1851         return false;
1852       }
1853     }
1854 
1855     out << ") {";
1856   }
1857 
1858   {
1859     ScopedIndent si(this);
1860 
1861     if (!EmitStatements(func->body->statements)) {
1862       return false;
1863     }
1864 
1865     if (!Is<ast::ReturnStatement>(func->body->Last())) {
1866       ast::ReturnStatement ret(ProgramID(), Source{});
1867       if (!EmitStatement(&ret)) {
1868         return false;
1869       }
1870     }
1871   }
1872 
1873   line() << "}";
1874 
1875   auto out = line();
1876 
1877   // Declare entry point input variables
1878   for (auto* var : func->params) {
1879     auto* sem = builder_.Sem().Get(var);
1880     auto* str = sem->Type()->As<sem::Struct>();
1881     for (auto* member : str->Members()) {
1882       if (ast::HasDecoration<ast::BuiltinDecoration>(
1883               member->Declaration()->decorations)) {
1884         continue;
1885       }
1886       if (!EmitTypeAndName(
1887               out, member->Type(), ast::StorageClass::kInput,
1888               ast::Access::kReadWrite,
1889               builder_.Symbols().NameFor(member->Declaration()->symbol))) {
1890         return false;
1891       }
1892       out << ";" << std::endl;
1893     }
1894   }
1895 
1896   // Declare entry point output variables
1897   auto* return_type = func_sem->ReturnType()->As<sem::Struct>();
1898   if (return_type) {
1899     for (auto* member : return_type->Members()) {
1900       if (ast::HasDecoration<ast::BuiltinDecoration>(
1901               member->Declaration()->decorations)) {
1902         continue;
1903       }
1904       if (!EmitTypeAndName(
1905               out, member->Type(), ast::StorageClass::kOutput,
1906               ast::Access::kReadWrite,
1907               builder_.Symbols().NameFor(member->Declaration()->symbol))) {
1908         return false;
1909       }
1910       out << ";" << std::endl;
1911     }
1912   }
1913 
1914   // Create a main() function which calls the entry point.
1915   out << "void main() {" << std::endl;
1916   std::string printed_name;
1917   for (auto* var : func->params) {
1918     out << "  ";
1919     auto* sem = builder_.Sem().Get(var);
1920     if (!EmitTypeAndName(out, sem->Type(), sem->StorageClass(), sem->Access(),
1921                          "inputs")) {
1922       return false;
1923     }
1924     out << ";" << std::endl;
1925     auto* type = sem->Type();
1926     auto* str = type->As<sem::Struct>();
1927     for (auto* member : str->Members()) {
1928       std::string name =
1929           builder_.Symbols().NameFor(member->Declaration()->symbol);
1930       out << "  inputs." << name << " = ";
1931       if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
1932               member->Declaration()->decorations)) {
1933         if (builtin_type(builtin->builtin) != member->Type()) {
1934           if (!EmitType(out, member->Type(), ast::StorageClass::kNone,
1935                         ast::Access::kReadWrite, "")) {
1936             return false;
1937           }
1938           out << "(";
1939           out << builtin_to_string(builtin->builtin, func->PipelineStage());
1940           out << ")";
1941         } else {
1942           out << builtin_to_string(builtin->builtin, func->PipelineStage());
1943         }
1944       } else {
1945         out << name;
1946       }
1947       out << ";" << std::endl;
1948     }
1949   }
1950   out << "  ";
1951   if (return_type) {
1952     out << return_type->FriendlyName(builder_.Symbols()) << " "
1953         << "outputs;" << std::endl;
1954     out << "  outputs = ";
1955   }
1956   out << builder_.Symbols().NameFor(func->symbol);
1957   if (func->params.empty()) {
1958     out << "()";
1959   } else {
1960     out << "(inputs)";
1961   }
1962   out << ";" << std::endl;
1963 
1964   auto* str = func_sem->ReturnType()->As<sem::Struct>();
1965   if (str) {
1966     for (auto* member : str->Members()) {
1967       std::string name =
1968           builder_.Symbols().NameFor(member->Declaration()->symbol);
1969       out << "  ";
1970       if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
1971               member->Declaration()->decorations)) {
1972         out << builtin_to_string(builtin->builtin, func->PipelineStage());
1973       } else {
1974         out << name;
1975       }
1976       out << " = outputs." << name << ";" << std::endl;
1977     }
1978   }
1979   if (func->PipelineStage() == ast::PipelineStage::kVertex) {
1980     out << "  gl_Position.y = -gl_Position.y;" << std::endl;
1981   }
1982 
1983   out << "}" << std::endl << std::endl;
1984 
1985   return true;
1986 }
1987 
EmitLiteral(std::ostream & out,const ast::LiteralExpression * lit)1988 bool GeneratorImpl::EmitLiteral(std::ostream& out,
1989                                 const ast::LiteralExpression* lit) {
1990   if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
1991     out << (l->value ? "true" : "false");
1992   } else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
1993     if (std::isinf(fl->value)) {
1994       out << (fl->value >= 0 ? "uintBitsToFloat(0x7f800000u)"
1995                              : "uintBitsToFloat(0xff800000u)");
1996     } else if (std::isnan(fl->value)) {
1997       out << "uintBitsToFloat(0x7fc00000u)";
1998     } else {
1999       out << FloatToString(fl->value) << "f";
2000     }
2001   } else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
2002     out << sl->value;
2003   } else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
2004     out << ul->value << "u";
2005   } else {
2006     diagnostics_.add_error(diag::System::Writer, "unknown literal type");
2007     return false;
2008   }
2009   return true;
2010 }
2011 
EmitZeroValue(std::ostream & out,const sem::Type * type)2012 bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
2013   if (type->Is<sem::Bool>()) {
2014     out << "false";
2015   } else if (type->Is<sem::F32>()) {
2016     out << "0.0f";
2017   } else if (type->Is<sem::I32>()) {
2018     out << "0";
2019   } else if (type->Is<sem::U32>()) {
2020     out << "0u";
2021   } else if (auto* vec = type->As<sem::Vector>()) {
2022     if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
2023                   "")) {
2024       return false;
2025     }
2026     ScopedParen sp(out);
2027     for (uint32_t i = 0; i < vec->Width(); i++) {
2028       if (i != 0) {
2029         out << ", ";
2030       }
2031       if (!EmitZeroValue(out, vec->type())) {
2032         return false;
2033       }
2034     }
2035   } else if (auto* mat = type->As<sem::Matrix>()) {
2036     if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
2037                   "")) {
2038       return false;
2039     }
2040     ScopedParen sp(out);
2041     for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) {
2042       if (i != 0) {
2043         out << ", ";
2044       }
2045       if (!EmitZeroValue(out, mat->type())) {
2046         return false;
2047       }
2048     }
2049   } else if (auto* str = type->As<sem::Struct>()) {
2050     if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
2051                   "")) {
2052       return false;
2053     }
2054     bool first = true;
2055     out << "(";
2056     for (auto* member : str->Members()) {
2057       if (!first) {
2058         out << ", ";
2059       } else {
2060         first = false;
2061       }
2062       EmitZeroValue(out, member->Type());
2063     }
2064     out << ")";
2065   } else if (auto* array = type->As<sem::Array>()) {
2066     if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
2067                   "")) {
2068       return false;
2069     }
2070     out << "(";
2071     for (uint32_t i = 0; i < array->Count(); i++) {
2072       if (i != 0) {
2073         out << ", ";
2074       }
2075       EmitZeroValue(out, array->ElemType());
2076     }
2077     out << ")";
2078   } else {
2079     diagnostics_.add_error(
2080         diag::System::Writer,
2081         "Invalid type for zero emission: " + type->type_name());
2082     return false;
2083   }
2084   return true;
2085 }
2086 
EmitLoop(const ast::LoopStatement * stmt)2087 bool GeneratorImpl::EmitLoop(const ast::LoopStatement* stmt) {
2088   auto emit_continuing = [this, stmt]() {
2089     if (stmt->continuing && !stmt->continuing->Empty()) {
2090       if (!EmitBlock(stmt->continuing)) {
2091         return false;
2092       }
2093     }
2094     return true;
2095   };
2096 
2097   TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
2098   line() << "while (true) {";
2099   {
2100     ScopedIndent si(this);
2101     if (!EmitStatements(stmt->body->statements)) {
2102       return false;
2103     }
2104     if (!emit_continuing()) {
2105       return false;
2106     }
2107   }
2108   line() << "}";
2109 
2110   return true;
2111 }
2112 
/// Emits a WGSL `for` loop.
///
/// The initializer, condition and continuing statement are first emitted into
/// private buffers. If each fits into a single line, a native GLSL
/// `for (init; cond; cont) { ... }` is emitted. Otherwise the loop is lowered
/// to `while (true) { if (!(cond)) { break; } ... <continuing> }`, with the
/// initializer emitted ahead of the loop.
/// @param stmt the for-loop statement
/// @returns true on success, false on error
bool GeneratorImpl::EmitForLoop(const ast::ForLoopStatement* stmt) {
  // Nest the for-loop in a new block. The initializer may be emitted before
  // the loop rather than inside the `for` header (see below); the enclosing
  // block keeps any declarations it makes scoped to the loop, avoiding
  // variable redefinitions in the surrounding scope.
  line() << "{";
  increment_indent();
  TINT_DEFER({
    decrement_indent();
    line() << "}";
  });

  // Initializer statement(s). A single-line initializer already ends in ';'.
  TextBuffer init_buf;
  if (auto* init = stmt->initializer) {
    TINT_SCOPED_ASSIGNMENT(current_buffer_, &init_buf);
    if (!EmitStatement(init)) {
      return false;
    }
  }

  // Condition: `cond_buf` receives the expression itself, while `cond_pre`
  // collects any statements emitted while generating it, which must precede
  // the expression.
  TextBuffer cond_pre;
  std::stringstream cond_buf;
  if (auto* cond = stmt->condition) {
    TINT_SCOPED_ASSIGNMENT(current_buffer_, &cond_pre);
    if (!EmitExpression(cond_buf, cond)) {
      return false;
    }
  }

  // Continuing statement(s).
  TextBuffer cont_buf;
  if (auto* cont = stmt->continuing) {
    TINT_SCOPED_ASSIGNMENT(current_buffer_, &cont_buf);
    if (!EmitStatement(cont)) {
      return false;
    }
  }

  // If the for-loop has a multi-statement conditional and / or continuing, then
  // we cannot emit this as a regular for-loop in GLSL, as its header only has
  // room for a single condition expression and a single continuing statement.
  // Instead we need to generate a `while(true)` loop.
  bool emit_as_loop = cond_pre.lines.size() > 0 || cont_buf.lines.size() > 1;

  // If the for-loop has multi-statement initializer, or is going to be emitted
  // as a `while(true)` loop, then declare the initializer statement(s) before
  // the loop.
  if (init_buf.lines.size() > 1 || (stmt->initializer && emit_as_loop)) {
    current_buffer_->Append(init_buf);
    init_buf.lines.clear();  // Don't emit the initializer again in the 'for'
  }

  if (emit_as_loop) {
    // The continuing statements are replayed by appending `cont_buf`. This is
    // exposed through emit_continuing_ while the body is emitted.
    // NOTE(review): presumably so that `continue` statements can emit the
    // continuing code before jumping — confirm in EmitContinue().
    auto emit_continuing = [&]() {
      current_buffer_->Append(cont_buf);
      return true;
    };

    TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
    line() << "while (true) {";
    increment_indent();
    TINT_DEFER({
      decrement_indent();
      line() << "}";
    });

    if (stmt->condition) {
      current_buffer_->Append(cond_pre);
      line() << "if (!(" << cond_buf.str() << ")) { break; }";
    }

    if (!EmitStatements(stmt->body->statements)) {
      return false;
    }

    if (!emit_continuing()) {
      return false;
    }
  } else {
    // For-loop can be generated.
    {
      auto out = line();
      out << "for";
      {
        ScopedParen sp(out);

        if (!init_buf.lines.empty()) {
          out << init_buf.lines[0].content << " ";
        } else {
          out << "; ";
        }

        out << cond_buf.str() << "; ";

        // The continuing statement's trailing ';' is dropped inside the header.
        if (!cont_buf.lines.empty()) {
          out << TrimSuffix(cont_buf.lines[0].content, ";");
        }
      }
      out << " {";
    }
    {
      // The native for-loop header already runs the continuing statement, so
      // nothing extra needs to be emitted for it inside the body.
      auto emit_continuing = [] { return true; };
      TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
      if (!EmitStatementsWithIndent(stmt->body->statements)) {
        return false;
      }
    }
    line() << "}";
  }

  return true;
}
2221 
EmitMemberAccessor(std::ostream & out,const ast::MemberAccessorExpression * expr)2222 bool GeneratorImpl::EmitMemberAccessor(
2223     std::ostream& out,
2224     const ast::MemberAccessorExpression* expr) {
2225   if (!EmitExpression(out, expr->structure)) {
2226     return false;
2227   }
2228   out << ".";
2229 
2230   // Swizzles output the name directly
2231   if (builder_.Sem().Get(expr)->Is<sem::Swizzle>()) {
2232     out << builder_.Symbols().NameFor(expr->member->symbol);
2233   } else if (!EmitExpression(out, expr->member)) {
2234     return false;
2235   }
2236 
2237   return true;
2238 }
2239 
EmitReturn(const ast::ReturnStatement * stmt)2240 bool GeneratorImpl::EmitReturn(const ast::ReturnStatement* stmt) {
2241   if (stmt->value) {
2242     auto out = line();
2243     out << "return ";
2244     if (!EmitExpression(out, stmt->value)) {
2245       return false;
2246     }
2247     out << ";";
2248   } else {
2249     line() << "return;";
2250   }
2251   return true;
2252 }
2253 
EmitStatement(const ast::Statement * stmt)2254 bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
2255   if (auto* a = stmt->As<ast::AssignmentStatement>()) {
2256     return EmitAssign(a);
2257   }
2258   if (auto* b = stmt->As<ast::BlockStatement>()) {
2259     return EmitBlock(b);
2260   }
2261   if (auto* b = stmt->As<ast::BreakStatement>()) {
2262     return EmitBreak(b);
2263   }
2264   if (auto* c = stmt->As<ast::CallStatement>()) {
2265     auto out = line();
2266     if (!EmitCall(out, c->expr)) {
2267       return false;
2268     }
2269     out << ";";
2270     return true;
2271   }
2272   if (auto* c = stmt->As<ast::ContinueStatement>()) {
2273     return EmitContinue(c);
2274   }
2275   if (auto* d = stmt->As<ast::DiscardStatement>()) {
2276     return EmitDiscard(d);
2277   }
2278   if (stmt->As<ast::FallthroughStatement>()) {
2279     line() << "/* fallthrough */";
2280     return true;
2281   }
2282   if (auto* i = stmt->As<ast::IfStatement>()) {
2283     return EmitIf(i);
2284   }
2285   if (auto* l = stmt->As<ast::LoopStatement>()) {
2286     return EmitLoop(l);
2287   }
2288   if (auto* l = stmt->As<ast::ForLoopStatement>()) {
2289     return EmitForLoop(l);
2290   }
2291   if (auto* r = stmt->As<ast::ReturnStatement>()) {
2292     return EmitReturn(r);
2293   }
2294   if (auto* s = stmt->As<ast::SwitchStatement>()) {
2295     return EmitSwitch(s);
2296   }
2297   if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
2298     return EmitVariable(v->variable);
2299   }
2300 
2301   diagnostics_.add_error(
2302       diag::System::Writer,
2303       "unknown statement type: " + std::string(stmt->TypeInfo().name));
2304   return false;
2305 }
2306 
EmitSwitch(const ast::SwitchStatement * stmt)2307 bool GeneratorImpl::EmitSwitch(const ast::SwitchStatement* stmt) {
2308   {  // switch(expr) {
2309     auto out = line();
2310     out << "switch(";
2311     if (!EmitExpression(out, stmt->condition)) {
2312       return false;
2313     }
2314     out << ") {";
2315   }
2316 
2317   {
2318     ScopedIndent si(this);
2319     for (auto* s : stmt->body) {
2320       if (!EmitCase(s)) {
2321         return false;
2322       }
2323     }
2324   }
2325 
2326   line() << "}";
2327 
2328   return true;
2329 }
2330 
EmitType(std::ostream & out,const sem::Type * type,ast::StorageClass storage_class,ast::Access access,const std::string & name,bool * name_printed)2331 bool GeneratorImpl::EmitType(std::ostream& out,
2332                              const sem::Type* type,
2333                              ast::StorageClass storage_class,
2334                              ast::Access access,
2335                              const std::string& name,
2336                              bool* name_printed /* = nullptr */) {
2337   if (name_printed) {
2338     *name_printed = false;
2339   }
2340   switch (storage_class) {
2341     case ast::StorageClass::kInput: {
2342       out << "in ";
2343       break;
2344     }
2345     case ast::StorageClass::kOutput: {
2346       out << "out ";
2347       break;
2348     }
2349     case ast::StorageClass::kUniform: {
2350       out << "uniform ";
2351       break;
2352     }
2353     default:
2354       break;
2355   }
2356 
2357   if (auto* ary = type->As<sem::Array>()) {
2358     const sem::Type* base_type = ary;
2359     std::vector<uint32_t> sizes;
2360     while (auto* arr = base_type->As<sem::Array>()) {
2361       sizes.push_back(arr->Count());
2362       base_type = arr->ElemType();
2363     }
2364     if (!EmitType(out, base_type, storage_class, access, "")) {
2365       return false;
2366     }
2367     if (!name.empty()) {
2368       out << " " << name;
2369       if (name_printed) {
2370         *name_printed = true;
2371       }
2372     }
2373     for (uint32_t size : sizes) {
2374       if (size > 0) {
2375         out << "[" << size << "]";
2376       } else {
2377         out << "[]";
2378       }
2379     }
2380   } else if (type->Is<sem::Bool>()) {
2381     out << "bool";
2382   } else if (type->Is<sem::F32>()) {
2383     out << "float";
2384   } else if (type->Is<sem::I32>()) {
2385     out << "int";
2386   } else if (auto* mat = type->As<sem::Matrix>()) {
2387     TINT_ASSERT(Writer, mat->type()->Is<sem::F32>());
2388     out << "mat" << mat->columns();
2389     if (mat->rows() != mat->columns()) {
2390       out << "x" << mat->rows();
2391     }
2392   } else if (type->Is<sem::Pointer>()) {
2393     TINT_ICE(Writer, diagnostics_)
2394         << "Attempting to emit pointer type. These should have been removed "
2395            "with the InlinePointerLets transform";
2396     return false;
2397   } else if (type->Is<sem::Sampler>()) {
2398     return false;
2399   } else if (auto* str = type->As<sem::Struct>()) {
2400     out << StructName(str);
2401   } else if (auto* tex = type->As<sem::Texture>()) {
2402     auto* storage = tex->As<sem::StorageTexture>();
2403     auto* ms = tex->As<sem::MultisampledTexture>();
2404     auto* depth_ms = tex->As<sem::DepthMultisampledTexture>();
2405     auto* sampled = tex->As<sem::SampledTexture>();
2406 
2407     out << "uniform highp ";
2408 
2409     if (storage && storage->access() != ast::Access::kRead) {
2410       out << "writeonly ";
2411     }
2412     auto* subtype = sampled
2413                         ? sampled->type()
2414                         : storage ? storage->type() : ms ? ms->type() : nullptr;
2415     if (!subtype || subtype->Is<sem::F32>()) {
2416     } else if (subtype->Is<sem::I32>()) {
2417       out << "i";
2418     } else if (subtype->Is<sem::U32>()) {
2419       out << "u";
2420     } else {
2421       TINT_ICE(Writer, diagnostics_) << "Unsupported texture type";
2422       return false;
2423     }
2424 
2425     out << (storage ? "image" : "sampler");
2426 
2427     switch (tex->dim()) {
2428       case ast::TextureDimension::k1d:
2429         out << "1D";
2430         break;
2431       case ast::TextureDimension::k2d:
2432         out << ((ms || depth_ms) ? "2DMS" : "2D");
2433         break;
2434       case ast::TextureDimension::k2dArray:
2435         out << ((ms || depth_ms) ? "2DMSArray" : "2DArray");
2436         break;
2437       case ast::TextureDimension::k3d:
2438         out << "3D";
2439         break;
2440       case ast::TextureDimension::kCube:
2441         out << "Cube";
2442         break;
2443       case ast::TextureDimension::kCubeArray:
2444         out << "CubeArray";
2445         break;
2446       default:
2447         TINT_UNREACHABLE(Writer, diagnostics_)
2448             << "unexpected TextureDimension " << tex->dim();
2449         return false;
2450     }
2451   } else if (type->Is<sem::U32>()) {
2452     out << "uint";
2453   } else if (auto* vec = type->As<sem::Vector>()) {
2454     auto width = vec->Width();
2455     if (vec->type()->Is<sem::F32>() && width >= 1 && width <= 4) {
2456       out << "vec" << width;
2457     } else if (vec->type()->Is<sem::I32>() && width >= 1 && width <= 4) {
2458       out << "ivec" << width;
2459     } else if (vec->type()->Is<sem::U32>() && width >= 1 && width <= 4) {
2460       out << "uvec" << width;
2461     } else if (vec->type()->Is<sem::Bool>() && width >= 1 && width <= 4) {
2462       out << "bvec" << width;
2463     } else {
2464       out << "vector<";
2465       if (!EmitType(out, vec->type(), storage_class, access, "")) {
2466         return false;
2467       }
2468       out << ", " << width << ">";
2469     }
2470   } else if (auto* atomic = type->As<sem::Atomic>()) {
2471     if (!EmitType(out, atomic->Type(), storage_class, access, name)) {
2472       return false;
2473     }
2474   } else if (type->Is<sem::Void>()) {
2475     out << "void";
2476   } else {
2477     diagnostics_.add_error(diag::System::Writer, "unknown type in EmitType");
2478     return false;
2479   }
2480 
2481   return true;
2482 }
2483 
EmitTypeAndName(std::ostream & out,const sem::Type * type,ast::StorageClass storage_class,ast::Access access,const std::string & name)2484 bool GeneratorImpl::EmitTypeAndName(std::ostream& out,
2485                                     const sem::Type* type,
2486                                     ast::StorageClass storage_class,
2487                                     ast::Access access,
2488                                     const std::string& name) {
2489   bool printed_name = false;
2490   if (!EmitType(out, type, storage_class, access, name, &printed_name)) {
2491     return false;
2492   }
2493   if (!name.empty() && !printed_name) {
2494     out << " " << name;
2495   }
2496   return true;
2497 }
2498 
EmitStructType(TextBuffer * b,const sem::Struct * str)2499 bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) {
2500   auto storage_class_uses = str->StorageClassUsage();
2501   line(b) << "struct " << StructName(str) << " {";
2502   EmitStructMembers(b, str);
2503   line(b) << "};";
2504 
2505   return true;
2506 }
2507 
EmitStructMembers(TextBuffer * b,const sem::Struct * str)2508 bool GeneratorImpl::EmitStructMembers(TextBuffer* b, const sem::Struct* str) {
2509   ScopedIndent si(b);
2510   for (auto* mem : str->Members()) {
2511     auto name = builder_.Symbols().NameFor(mem->Name());
2512 
2513     auto* ty = mem->Type();
2514 
2515     auto out = line(b);
2516 
2517     std::string pre, post;
2518 
2519     if (auto* decl = mem->Declaration()) {
2520       for (auto* deco : decl->decorations) {
2521         if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
2522           auto mod = interpolation_to_modifiers(interpolate->type,
2523                                                 interpolate->sampling);
2524           if (mod.empty()) {
2525             diagnostics_.add_error(diag::System::Writer,
2526                                    "unsupported interpolation");
2527             return false;
2528           }
2529         }
2530       }
2531     }
2532 
2533     out << pre;
2534     if (!EmitTypeAndName(out, ty, ast::StorageClass::kNone,
2535                          ast::Access::kReadWrite, name)) {
2536       return false;
2537     }
2538     out << post << ";";
2539   }
2540   return true;
2541 }
2542 
EmitUnaryOp(std::ostream & out,const ast::UnaryOpExpression * expr)2543 bool GeneratorImpl::EmitUnaryOp(std::ostream& out,
2544                                 const ast::UnaryOpExpression* expr) {
2545   switch (expr->op) {
2546     case ast::UnaryOp::kIndirection:
2547     case ast::UnaryOp::kAddressOf:
2548       return EmitExpression(out, expr->expr);
2549     case ast::UnaryOp::kComplement:
2550       out << "~";
2551       break;
2552     case ast::UnaryOp::kNot:
2553       out << "!";
2554       break;
2555     case ast::UnaryOp::kNegation:
2556       out << "-";
2557       break;
2558   }
2559   out << "(";
2560 
2561   if (!EmitExpression(out, expr->expr)) {
2562     return false;
2563   }
2564 
2565   out << ")";
2566 
2567   return true;
2568 }
2569 
EmitVariable(const ast::Variable * var)2570 bool GeneratorImpl::EmitVariable(const ast::Variable* var) {
2571   auto* sem = builder_.Sem().Get(var);
2572   auto* type = sem->Type()->UnwrapRef();
2573 
2574   // TODO(dsinclair): Handle variable decorations
2575   if (!var->decorations.empty()) {
2576     diagnostics_.add_error(diag::System::Writer,
2577                            "Variable decorations are not handled yet");
2578     return false;
2579   }
2580 
2581   auto out = line();
2582   // TODO(senorblanco): handle const
2583   if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
2584                        builder_.Symbols().NameFor(var->symbol))) {
2585     return false;
2586   }
2587 
2588   out << " = ";
2589 
2590   if (var->constructor) {
2591     if (!EmitExpression(out, var->constructor)) {
2592       return false;
2593     }
2594   } else {
2595     if (!EmitZeroValue(out, type)) {
2596       return false;
2597     }
2598   }
2599   out << ";";
2600 
2601   return true;
2602 }
2603 
EmitProgramConstVariable(const ast::Variable * var)2604 bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
2605   for (auto* d : var->decorations) {
2606     if (!d->Is<ast::OverrideDecoration>()) {
2607       diagnostics_.add_error(diag::System::Writer,
2608                              "Decorated const values not valid");
2609       return false;
2610     }
2611   }
2612   if (!var->is_const) {
2613     diagnostics_.add_error(diag::System::Writer, "Expected a const value");
2614     return false;
2615   }
2616 
2617   auto* sem = builder_.Sem().Get(var);
2618   auto* type = sem->Type();
2619 
2620   auto* global = sem->As<sem::GlobalVariable>();
2621   if (global && global->IsOverridable()) {
2622     auto const_id = global->ConstantId();
2623 
2624     line() << "#ifndef " << kSpecConstantPrefix << const_id;
2625 
2626     if (var->constructor != nullptr) {
2627       auto out = line();
2628       out << "#define " << kSpecConstantPrefix << const_id << " ";
2629       if (!EmitExpression(out, var->constructor)) {
2630         return false;
2631       }
2632     } else {
2633       line() << "#error spec constant required for constant id " << const_id;
2634     }
2635     line() << "#endif";
2636     {
2637       auto out = line();
2638       out << "const ";
2639       if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
2640                            builder_.Symbols().NameFor(var->symbol))) {
2641         return false;
2642       }
2643       out << " = " << kSpecConstantPrefix << const_id << ";";
2644     }
2645   } else {
2646     auto out = line();
2647     out << "const ";
2648     if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
2649                          builder_.Symbols().NameFor(var->symbol))) {
2650       return false;
2651     }
2652     out << " = ";
2653     if (!EmitExpression(out, var->constructor)) {
2654       return false;
2655     }
2656     out << ";";
2657   }
2658 
2659   return true;
2660 }
2661 
2662 template <typename F>
CallIntrinsicHelper(std::ostream & out,const ast::CallExpression * call,const sem::Intrinsic * intrinsic,F && build)2663 bool GeneratorImpl::CallIntrinsicHelper(std::ostream& out,
2664                                         const ast::CallExpression* call,
2665                                         const sem::Intrinsic* intrinsic,
2666                                         F&& build) {
2667   // Generate the helper function if it hasn't been created already
2668   auto fn = utils::GetOrCreate(intrinsics_, intrinsic, [&]() -> std::string {
2669     TextBuffer b;
2670     TINT_DEFER(helpers_.Append(b));
2671 
2672     auto fn_name =
2673         UniqueIdentifier(std::string("tint_") + sem::str(intrinsic->Type()));
2674     std::vector<std::string> parameter_names;
2675     {
2676       auto decl = line(&b);
2677       if (!EmitTypeAndName(decl, intrinsic->ReturnType(),
2678                            ast::StorageClass::kNone, ast::Access::kUndefined,
2679                            fn_name)) {
2680         return "";
2681       }
2682       {
2683         ScopedParen sp(decl);
2684         for (auto* param : intrinsic->Parameters()) {
2685           if (!parameter_names.empty()) {
2686             decl << ", ";
2687           }
2688           auto param_name = "param_" + std::to_string(parameter_names.size());
2689           const auto* ty = param->Type();
2690           if (auto* ptr = ty->As<sem::Pointer>()) {
2691             decl << "inout ";
2692             ty = ptr->StoreType();
2693           }
2694           if (!EmitTypeAndName(decl, ty, ast::StorageClass::kNone,
2695                                ast::Access::kUndefined, param_name)) {
2696             return "";
2697           }
2698           parameter_names.emplace_back(std::move(param_name));
2699         }
2700       }
2701       decl << " {";
2702     }
2703     {
2704       ScopedIndent si(&b);
2705       if (!build(&b, parameter_names)) {
2706         return "";
2707       }
2708     }
2709     line(&b) << "}";
2710     line(&b);
2711     return fn_name;
2712   });
2713 
2714   if (fn.empty()) {
2715     return false;
2716   }
2717 
2718   // Call the helper
2719   out << fn;
2720   {
2721     ScopedParen sp(out);
2722     bool first = true;
2723     for (auto* arg : call->args) {
2724       if (!first) {
2725         out << ", ";
2726       }
2727       first = false;
2728       if (!EmitExpression(out, arg)) {
2729         return false;
2730       }
2731     }
2732   }
2733   return true;
2734 }
2735 
2736 }  // namespace glsl
2737 }  // namespace writer
2738 }  // namespace tint
2739