• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/writer/msl/generator_impl.h"
16 
17 #include <algorithm>
18 #include <cmath>
19 #include <iomanip>
20 #include <limits>
21 #include <utility>
22 #include <vector>
23 
24 #include "src/ast/alias.h"
25 #include "src/ast/bool_literal_expression.h"
26 #include "src/ast/call_statement.h"
27 #include "src/ast/disable_validation_decoration.h"
28 #include "src/ast/fallthrough_statement.h"
29 #include "src/ast/float_literal_expression.h"
30 #include "src/ast/interpolate_decoration.h"
31 #include "src/ast/module.h"
32 #include "src/ast/override_decoration.h"
33 #include "src/ast/sint_literal_expression.h"
34 #include "src/ast/uint_literal_expression.h"
35 #include "src/ast/variable_decl_statement.h"
36 #include "src/ast/void.h"
37 #include "src/sem/array.h"
38 #include "src/sem/atomic_type.h"
39 #include "src/sem/bool_type.h"
40 #include "src/sem/call.h"
41 #include "src/sem/depth_multisampled_texture_type.h"
42 #include "src/sem/depth_texture_type.h"
43 #include "src/sem/f32_type.h"
44 #include "src/sem/function.h"
45 #include "src/sem/i32_type.h"
46 #include "src/sem/matrix_type.h"
47 #include "src/sem/member_accessor_expression.h"
48 #include "src/sem/multisampled_texture_type.h"
49 #include "src/sem/pointer_type.h"
50 #include "src/sem/reference_type.h"
51 #include "src/sem/sampled_texture_type.h"
52 #include "src/sem/storage_texture_type.h"
53 #include "src/sem/struct.h"
54 #include "src/sem/type_constructor.h"
55 #include "src/sem/type_conversion.h"
56 #include "src/sem/u32_type.h"
57 #include "src/sem/variable.h"
58 #include "src/sem/vector_type.h"
59 #include "src/sem/void_type.h"
60 #include "src/transform/array_length_from_uniform.h"
61 #include "src/transform/canonicalize_entry_point_io.h"
62 #include "src/transform/external_texture_transform.h"
63 #include "src/transform/manager.h"
64 #include "src/transform/module_scope_var_to_entry_point_param.h"
65 #include "src/transform/pad_array_elements.h"
66 #include "src/transform/promote_initializers_to_const_var.h"
67 #include "src/transform/remove_phonies.h"
68 #include "src/transform/simplify_pointers.h"
69 #include "src/transform/unshadow.h"
70 #include "src/transform/vectorize_scalar_matrix_constructors.h"
71 #include "src/transform/wrap_arrays_in_structs.h"
72 #include "src/transform/zero_init_workgroup_memory.h"
73 #include "src/utils/defer.h"
74 #include "src/utils/map.h"
75 #include "src/utils/scoped_assignment.h"
76 #include "src/writer/float_to_string.h"
77 
78 namespace tint {
79 namespace writer {
80 namespace msl {
81 namespace {
82 
last_is_break_or_fallthrough(const ast::BlockStatement * stmts)83 bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
84   return IsAnyOf<ast::BreakStatement, ast::FallthroughStatement>(stmts->Last());
85 }
86 
87 class ScopedBitCast {
88  public:
ScopedBitCast(GeneratorImpl * generator,std::ostream & stream,const sem::Type * curr_type,const sem::Type * target_type)89   ScopedBitCast(GeneratorImpl* generator,
90                 std::ostream& stream,
91                 const sem::Type* curr_type,
92                 const sem::Type* target_type)
93       : s(stream) {
94     auto* target_vec_type = target_type->As<sem::Vector>();
95 
96     // If we need to promote from scalar to vector, bitcast the scalar to the
97     // vector element type.
98     if (curr_type->is_scalar() && target_vec_type) {
99       target_type = target_vec_type->type();
100     }
101 
102     // Bit cast
103     s << "as_type<";
104     generator->EmitType(s, target_type, "");
105     s << ">(";
106   }
107 
~ScopedBitCast()108   ~ScopedBitCast() { s << ")"; }
109 
110  private:
111   std::ostream& s;
112 };
113 }  // namespace
114 
115 SanitizedResult::SanitizedResult() = default;
116 SanitizedResult::~SanitizedResult() = default;
117 SanitizedResult::SanitizedResult(SanitizedResult&&) = default;
118 
Sanitize(const Program * in,uint32_t buffer_size_ubo_index,uint32_t fixed_sample_mask,bool emit_vertex_point_size,bool disable_workgroup_init,const ArrayLengthFromUniformOptions & array_length_from_uniform)119 SanitizedResult Sanitize(
120     const Program* in,
121     uint32_t buffer_size_ubo_index,
122     uint32_t fixed_sample_mask,
123     bool emit_vertex_point_size,
124     bool disable_workgroup_init,
125     const ArrayLengthFromUniformOptions& array_length_from_uniform) {
126   transform::Manager manager;
127   transform::DataMap internal_inputs;
128 
129   // Build the config for the internal ArrayLengthFromUniform transform.
130   transform::ArrayLengthFromUniform::Config array_length_from_uniform_cfg(
131       array_length_from_uniform.ubo_binding);
132   if (!array_length_from_uniform.bindpoint_to_size_index.empty()) {
133     // If |array_length_from_uniform| bindings are provided, use that config.
134     array_length_from_uniform_cfg.bindpoint_to_size_index =
135         array_length_from_uniform.bindpoint_to_size_index;
136   } else {
137     // If the binding map is empty, use the deprecated |buffer_size_ubo_index|
138     // and automatically choose indices using the binding numbers.
139     array_length_from_uniform_cfg = transform::ArrayLengthFromUniform::Config(
140         sem::BindingPoint{0, buffer_size_ubo_index});
141     // Use the SSBO binding numbers as the indices for the buffer size lookups.
142     for (auto* var : in->AST().GlobalVariables()) {
143       auto* global = in->Sem().Get<sem::GlobalVariable>(var);
144       if (global && global->StorageClass() == ast::StorageClass::kStorage) {
145         array_length_from_uniform_cfg.bindpoint_to_size_index.emplace(
146             global->BindingPoint(), global->BindingPoint().binding);
147       }
148     }
149   }
150 
151   // Build the configs for the internal CanonicalizeEntryPointIO transform.
152   auto entry_point_io_cfg = transform::CanonicalizeEntryPointIO::Config(
153       transform::CanonicalizeEntryPointIO::ShaderStyle::kMsl, fixed_sample_mask,
154       emit_vertex_point_size);
155 
156   manager.Add<transform::Unshadow>();
157 
158   if (!disable_workgroup_init) {
159     // ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as
160     // ZeroInitWorkgroupMemory may inject new builtin parameters.
161     manager.Add<transform::ZeroInitWorkgroupMemory>();
162   }
163   manager.Add<transform::CanonicalizeEntryPointIO>();
164   manager.Add<transform::ExternalTextureTransform>();
165   manager.Add<transform::PromoteInitializersToConstVar>();
166   manager.Add<transform::VectorizeScalarMatrixConstructors>();
167   manager.Add<transform::WrapArraysInStructs>();
168   manager.Add<transform::PadArrayElements>();
169   manager.Add<transform::RemovePhonies>();
170   manager.Add<transform::SimplifyPointers>();
171   // ArrayLengthFromUniform must come after SimplifyPointers, as
172   // it assumes that the form of the array length argument is &var.array.
173   manager.Add<transform::ArrayLengthFromUniform>();
174   manager.Add<transform::ModuleScopeVarToEntryPointParam>();
175   internal_inputs.Add<transform::ArrayLengthFromUniform::Config>(
176       std::move(array_length_from_uniform_cfg));
177   internal_inputs.Add<transform::CanonicalizeEntryPointIO::Config>(
178       std::move(entry_point_io_cfg));
179   auto out = manager.Run(in, internal_inputs);
180 
181   SanitizedResult result;
182   result.program = std::move(out.program);
183   if (!result.program.IsValid()) {
184     return result;
185   }
186   result.used_array_length_from_uniform_indices =
187       std::move(out.data.Get<transform::ArrayLengthFromUniform::Result>()
188                     ->used_size_indices);
189   result.needs_storage_buffer_sizes =
190       !result.used_array_length_from_uniform_indices.empty();
191   return result;
192 }
193 
GeneratorImpl(const Program * program)194 GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {}
195 
196 GeneratorImpl::~GeneratorImpl() = default;
197 
Generate()198 bool GeneratorImpl::Generate() {
199   line() << "#include <metal_stdlib>";
200   line();
201   line() << "using namespace metal;";
202 
203   auto helpers_insertion_point = current_buffer_->lines.size();
204 
205   for (auto* const type_decl : program_->AST().TypeDecls()) {
206     if (!type_decl->Is<ast::Alias>()) {
207       if (!EmitTypeDecl(TypeOf(type_decl))) {
208         return false;
209       }
210     }
211   }
212 
213   if (!program_->AST().TypeDecls().empty()) {
214     line();
215   }
216 
217   for (auto* var : program_->AST().GlobalVariables()) {
218     if (var->is_const) {
219       if (!EmitProgramConstVariable(var)) {
220         return false;
221       }
222     } else {
223       // These are pushed into the entry point by sanitizer transforms.
224       TINT_ICE(Writer, diagnostics_) << "module-scope variables should have "
225                                         "been handled by the MSL sanitizer";
226       break;
227     }
228   }
229 
230   for (auto* func : program_->AST().Functions()) {
231     if (!func->IsEntryPoint()) {
232       if (!EmitFunction(func)) {
233         return false;
234       }
235     } else {
236       if (!EmitEntryPointFunction(func)) {
237         return false;
238       }
239     }
240     line();
241   }
242 
243   if (!invariant_define_name_.empty()) {
244     // 'invariant' attribute requires MSL 2.1 or higher.
245     // WGSL can ignore the invariant attribute on pre MSL 2.1 devices.
246     // See: https://github.com/gpuweb/gpuweb/issues/893#issuecomment-745537465
247     line(&helpers_) << "#if __METAL_VERSION__ >= 210";
248     line(&helpers_) << "#define " << invariant_define_name_ << " [[invariant]]";
249     line(&helpers_) << "#else";
250     line(&helpers_) << "#define " << invariant_define_name_;
251     line(&helpers_) << "#endif";
252     line(&helpers_);
253   }
254 
255   if (!helpers_.lines.empty()) {
256     current_buffer_->Insert("", helpers_insertion_point++, 0);
257     current_buffer_->Insert(helpers_, helpers_insertion_point++, 0);
258   }
259 
260   return true;
261 }
262 
EmitTypeDecl(const sem::Type * ty)263 bool GeneratorImpl::EmitTypeDecl(const sem::Type* ty) {
264   if (auto* str = ty->As<sem::Struct>()) {
265     if (!EmitStructType(current_buffer_, str)) {
266       return false;
267     }
268   } else {
269     diagnostics_.add_error(diag::System::Writer,
270                            "unknown alias type: " + ty->type_name());
271     return false;
272   }
273 
274   return true;
275 }
276 
EmitIndexAccessor(std::ostream & out,const ast::IndexAccessorExpression * expr)277 bool GeneratorImpl::EmitIndexAccessor(
278     std::ostream& out,
279     const ast::IndexAccessorExpression* expr) {
280   bool paren_lhs =
281       !expr->object->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression,
282                              ast::IdentifierExpression,
283                              ast::MemberAccessorExpression>();
284 
285   if (paren_lhs) {
286     out << "(";
287   }
288   if (!EmitExpression(out, expr->object)) {
289     return false;
290   }
291   if (paren_lhs) {
292     out << ")";
293   }
294 
295   out << "[";
296 
297   if (!EmitExpression(out, expr->index)) {
298     return false;
299   }
300   out << "]";
301 
302   return true;
303 }
304 
EmitBitcast(std::ostream & out,const ast::BitcastExpression * expr)305 bool GeneratorImpl::EmitBitcast(std::ostream& out,
306                                 const ast::BitcastExpression* expr) {
307   out << "as_type<";
308   if (!EmitType(out, TypeOf(expr)->UnwrapRef(), "")) {
309     return false;
310   }
311 
312   out << ">(";
313   if (!EmitExpression(out, expr->expr)) {
314     return false;
315   }
316 
317   out << ")";
318   return true;
319 }
320 
EmitAssign(const ast::AssignmentStatement * stmt)321 bool GeneratorImpl::EmitAssign(const ast::AssignmentStatement* stmt) {
322   auto out = line();
323 
324   if (!EmitExpression(out, stmt->lhs)) {
325     return false;
326   }
327 
328   out << " = ";
329 
330   if (!EmitExpression(out, stmt->rhs)) {
331     return false;
332   }
333 
334   out << ";";
335 
336   return true;
337 }
338 
EmitBinary(std::ostream & out,const ast::BinaryExpression * expr)339 bool GeneratorImpl::EmitBinary(std::ostream& out,
340                                const ast::BinaryExpression* expr) {
341   auto emit_op = [&] {
342     out << " ";
343 
344     switch (expr->op) {
345       case ast::BinaryOp::kAnd:
346         out << "&";
347         break;
348       case ast::BinaryOp::kOr:
349         out << "|";
350         break;
351       case ast::BinaryOp::kXor:
352         out << "^";
353         break;
354       case ast::BinaryOp::kLogicalAnd:
355         out << "&&";
356         break;
357       case ast::BinaryOp::kLogicalOr:
358         out << "||";
359         break;
360       case ast::BinaryOp::kEqual:
361         out << "==";
362         break;
363       case ast::BinaryOp::kNotEqual:
364         out << "!=";
365         break;
366       case ast::BinaryOp::kLessThan:
367         out << "<";
368         break;
369       case ast::BinaryOp::kGreaterThan:
370         out << ">";
371         break;
372       case ast::BinaryOp::kLessThanEqual:
373         out << "<=";
374         break;
375       case ast::BinaryOp::kGreaterThanEqual:
376         out << ">=";
377         break;
378       case ast::BinaryOp::kShiftLeft:
379         out << "<<";
380         break;
381       case ast::BinaryOp::kShiftRight:
382         // TODO(dsinclair): MSL is based on C++14, and >> in C++14 has
383         // implementation-defined behaviour for negative LHS.  We may have to
384         // generate extra code to implement WGSL-specified behaviour for
385         // negative LHS.
386         out << R"(>>)";
387         break;
388 
389       case ast::BinaryOp::kAdd:
390         out << "+";
391         break;
392       case ast::BinaryOp::kSubtract:
393         out << "-";
394         break;
395       case ast::BinaryOp::kMultiply:
396         out << "*";
397         break;
398       case ast::BinaryOp::kDivide:
399         out << "/";
400         break;
401       case ast::BinaryOp::kModulo:
402         out << "%";
403         break;
404       case ast::BinaryOp::kNone:
405         diagnostics_.add_error(diag::System::Writer,
406                                "missing binary operation type");
407         return false;
408     }
409     out << " ";
410     return true;
411   };
412 
413   auto signed_type_of = [&](const sem::Type* ty) -> const sem::Type* {
414     if (ty->is_integer_scalar()) {
415       return builder_.create<sem::I32>();
416     } else if (auto* v = ty->As<sem::Vector>()) {
417       return builder_.create<sem::Vector>(builder_.create<sem::I32>(),
418                                           v->Width());
419     }
420     return {};
421   };
422 
423   auto unsigned_type_of = [&](const sem::Type* ty) -> const sem::Type* {
424     if (ty->is_integer_scalar()) {
425       return builder_.create<sem::U32>();
426     } else if (auto* v = ty->As<sem::Vector>()) {
427       return builder_.create<sem::Vector>(builder_.create<sem::U32>(),
428                                           v->Width());
429     }
430     return {};
431   };
432 
433   auto* lhs_type = TypeOf(expr->lhs)->UnwrapRef();
434   auto* rhs_type = TypeOf(expr->rhs)->UnwrapRef();
435 
436   // Handle fmod
437   if (expr->op == ast::BinaryOp::kModulo &&
438       lhs_type->is_float_scalar_or_vector()) {
439     out << "fmod";
440     ScopedParen sp(out);
441     if (!EmitExpression(out, expr->lhs)) {
442       return false;
443     }
444     out << ", ";
445     if (!EmitExpression(out, expr->rhs)) {
446       return false;
447     }
448     return true;
449   }
450 
451   // Handle +/-/* of signed values
452   if ((expr->IsAdd() || expr->IsSubtract() || expr->IsMultiply()) &&
453       lhs_type->is_signed_scalar_or_vector() &&
454       rhs_type->is_signed_scalar_or_vector()) {
455     // If lhs or rhs is a vector, use that type (support implicit scalar to
456     // vector promotion)
457     auto* target_type =
458         lhs_type->Is<sem::Vector>()
459             ? lhs_type
460             : (rhs_type->Is<sem::Vector>() ? rhs_type : lhs_type);
461 
462     // WGSL defines behaviour for signed overflow, MSL does not. For these
463     // cases, bitcast operands to unsigned, then cast result to signed.
464     ScopedBitCast outer_int_cast(this, out, target_type,
465                                  signed_type_of(target_type));
466     ScopedParen sp(out);
467     {
468       ScopedBitCast lhs_uint_cast(this, out, lhs_type,
469                                   unsigned_type_of(target_type));
470       if (!EmitExpression(out, expr->lhs)) {
471         return false;
472       }
473     }
474     if (!emit_op()) {
475       return false;
476     }
477     {
478       ScopedBitCast rhs_uint_cast(this, out, rhs_type,
479                                   unsigned_type_of(target_type));
480       if (!EmitExpression(out, expr->rhs)) {
481         return false;
482       }
483     }
484     return true;
485   }
486 
487   // Handle left bit shifting a signed value
488   // TODO(crbug.com/tint/1077): This may not be necessary. The MSL spec
489   // seems to imply that left shifting a signed value is treated the same as
490   // left shifting an unsigned value, but we need to make sure.
491   if (expr->IsShiftLeft() && lhs_type->is_signed_scalar_or_vector()) {
492     // Shift left: discards top bits, so convert first operand to unsigned
493     // first, then convert result back to signed
494     ScopedBitCast outer_int_cast(this, out, lhs_type, signed_type_of(lhs_type));
495     ScopedParen sp(out);
496     {
497       ScopedBitCast lhs_uint_cast(this, out, lhs_type,
498                                   unsigned_type_of(lhs_type));
499       if (!EmitExpression(out, expr->lhs)) {
500         return false;
501       }
502     }
503     if (!emit_op()) {
504       return false;
505     }
506     if (!EmitExpression(out, expr->rhs)) {
507       return false;
508     }
509     return true;
510   }
511 
512   // Emit as usual
513   ScopedParen sp(out);
514   if (!EmitExpression(out, expr->lhs)) {
515     return false;
516   }
517   if (!emit_op()) {
518     return false;
519   }
520   if (!EmitExpression(out, expr->rhs)) {
521     return false;
522   }
523 
524   return true;
525 }
526 
EmitBreak(const ast::BreakStatement *)527 bool GeneratorImpl::EmitBreak(const ast::BreakStatement*) {
528   line() << "break;";
529   return true;
530 }
531 
EmitCall(std::ostream & out,const ast::CallExpression * expr)532 bool GeneratorImpl::EmitCall(std::ostream& out,
533                              const ast::CallExpression* expr) {
534   auto* call = program_->Sem().Get(expr);
535   auto* target = call->Target();
536 
537   if (auto* func = target->As<sem::Function>()) {
538     return EmitFunctionCall(out, call, func);
539   }
540   if (auto* intrinsic = target->As<sem::Intrinsic>()) {
541     return EmitIntrinsicCall(out, call, intrinsic);
542   }
543   if (auto* conv = target->As<sem::TypeConversion>()) {
544     return EmitTypeConversion(out, call, conv);
545   }
546   if (auto* ctor = target->As<sem::TypeConstructor>()) {
547     return EmitTypeConstructor(out, call, ctor);
548   }
549 
550   TINT_ICE(Writer, diagnostics_)
551       << "unhandled call target: " << target->TypeInfo().name;
552   return false;
553 }
554 
EmitFunctionCall(std::ostream & out,const sem::Call * call,const sem::Function *)555 bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
556                                      const sem::Call* call,
557                                      const sem::Function*) {
558   auto* ident = call->Declaration()->target.name;
559   out << program_->Symbols().NameFor(ident->symbol) << "(";
560 
561   bool first = true;
562   for (auto* arg : call->Arguments()) {
563     if (!first) {
564       out << ", ";
565     }
566     first = false;
567 
568     if (!EmitExpression(out, arg->Declaration())) {
569       return false;
570     }
571   }
572 
573   out << ")";
574   return true;
575 }
576 
EmitIntrinsicCall(std::ostream & out,const sem::Call * call,const sem::Intrinsic * intrinsic)577 bool GeneratorImpl::EmitIntrinsicCall(std::ostream& out,
578                                       const sem::Call* call,
579                                       const sem::Intrinsic* intrinsic) {
580   auto* expr = call->Declaration();
581   if (intrinsic->IsAtomic()) {
582     return EmitAtomicCall(out, expr, intrinsic);
583   }
584   if (intrinsic->IsTexture()) {
585     return EmitTextureCall(out, call, intrinsic);
586   }
587 
588   auto name = generate_builtin_name(intrinsic);
589 
590   switch (intrinsic->Type()) {
591     case sem::IntrinsicType::kDot:
592       return EmitDotCall(out, expr, intrinsic);
593     case sem::IntrinsicType::kModf:
594       return EmitModfCall(out, expr, intrinsic);
595     case sem::IntrinsicType::kFrexp:
596       return EmitFrexpCall(out, expr, intrinsic);
597 
598     case sem::IntrinsicType::kPack2x16float:
599     case sem::IntrinsicType::kUnpack2x16float: {
600       if (intrinsic->Type() == sem::IntrinsicType::kPack2x16float) {
601         out << "as_type<uint>(half2(";
602       } else {
603         out << "float2(as_type<half2>(";
604       }
605       if (!EmitExpression(out, expr->args[0])) {
606         return false;
607       }
608       out << "))";
609       return true;
610     }
611     // TODO(crbug.com/tint/661): Combine sequential barriers to a single
612     // instruction.
613     case sem::IntrinsicType::kStorageBarrier: {
614       out << "threadgroup_barrier(mem_flags::mem_device)";
615       return true;
616     }
617     case sem::IntrinsicType::kWorkgroupBarrier: {
618       out << "threadgroup_barrier(mem_flags::mem_threadgroup)";
619       return true;
620     }
621     case sem::IntrinsicType::kIgnore: {  // [DEPRECATED]
622       out << "(void) ";
623       if (!EmitExpression(out, expr->args[0])) {
624         return false;
625       }
626       return true;
627     }
628 
629     case sem::IntrinsicType::kLength: {
630       auto* sem = builder_.Sem().Get(expr->args[0]);
631       if (sem->Type()->UnwrapRef()->is_scalar()) {
632         // Emulate scalar overload using fabs(x).
633         name = "fabs";
634       }
635       break;
636     }
637 
638     case sem::IntrinsicType::kDistance: {
639       auto* sem = builder_.Sem().Get(expr->args[0]);
640       if (sem->Type()->UnwrapRef()->is_scalar()) {
641         // Emulate scalar overload using fabs(x - y);
642         out << "fabs";
643         ScopedParen sp(out);
644         if (!EmitExpression(out, expr->args[0])) {
645           return false;
646         }
647         out << " - ";
648         if (!EmitExpression(out, expr->args[1])) {
649           return false;
650         }
651         return true;
652       }
653       break;
654     }
655 
656     default:
657       break;
658   }
659 
660   if (name.empty()) {
661     return false;
662   }
663 
664   out << name << "(";
665 
666   bool first = true;
667   for (auto* arg : expr->args) {
668     if (!first) {
669       out << ", ";
670     }
671     first = false;
672 
673     if (!EmitExpression(out, arg)) {
674       return false;
675     }
676   }
677 
678   out << ")";
679   return true;
680 }
681 
EmitTypeConversion(std::ostream & out,const sem::Call * call,const sem::TypeConversion * conv)682 bool GeneratorImpl::EmitTypeConversion(std::ostream& out,
683                                        const sem::Call* call,
684                                        const sem::TypeConversion* conv) {
685   if (!EmitType(out, conv->Target(), "")) {
686     return false;
687   }
688   out << "(";
689 
690   if (!EmitExpression(out, call->Arguments()[0]->Declaration())) {
691     return false;
692   }
693 
694   out << ")";
695   return true;
696 }
697 
EmitTypeConstructor(std::ostream & out,const sem::Call * call,const sem::TypeConstructor * ctor)698 bool GeneratorImpl::EmitTypeConstructor(std::ostream& out,
699                                         const sem::Call* call,
700                                         const sem::TypeConstructor* ctor) {
701   auto* type = ctor->ReturnType();
702 
703   if (type->IsAnyOf<sem::Array, sem::Struct>()) {
704     out << "{";
705   } else {
706     if (!EmitType(out, type, "")) {
707       return false;
708     }
709     out << "(";
710   }
711 
712   int i = 0;
713   for (auto* arg : call->Arguments()) {
714     if (i > 0) {
715       out << ", ";
716     }
717 
718     if (auto* struct_ty = type->As<sem::Struct>()) {
719       // Emit field designators for structures to account for padding members.
720       auto* member = struct_ty->Members()[i]->Declaration();
721       auto name = program_->Symbols().NameFor(member->symbol);
722       out << "." << name << "=";
723     }
724 
725     if (!EmitExpression(out, arg->Declaration())) {
726       return false;
727     }
728 
729     i++;
730   }
731 
732   if (type->IsAnyOf<sem::Array, sem::Struct>()) {
733     out << "}";
734   } else {
735     out << ")";
736   }
737   return true;
738 }
739 
EmitAtomicCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)740 bool GeneratorImpl::EmitAtomicCall(std::ostream& out,
741                                    const ast::CallExpression* expr,
742                                    const sem::Intrinsic* intrinsic) {
743   auto call = [&](const std::string& name, bool append_memory_order_relaxed) {
744     out << name;
745     {
746       ScopedParen sp(out);
747       for (size_t i = 0; i < expr->args.size(); i++) {
748         auto* arg = expr->args[i];
749         if (i > 0) {
750           out << ", ";
751         }
752         if (!EmitExpression(out, arg)) {
753           return false;
754         }
755       }
756       if (append_memory_order_relaxed) {
757         out << ", memory_order_relaxed";
758       }
759     }
760     return true;
761   };
762 
763   switch (intrinsic->Type()) {
764     case sem::IntrinsicType::kAtomicLoad:
765       return call("atomic_load_explicit", true);
766 
767     case sem::IntrinsicType::kAtomicStore:
768       return call("atomic_store_explicit", true);
769 
770     case sem::IntrinsicType::kAtomicAdd:
771       return call("atomic_fetch_add_explicit", true);
772 
773     case sem::IntrinsicType::kAtomicSub:
774       return call("atomic_fetch_sub_explicit", true);
775 
776     case sem::IntrinsicType::kAtomicMax:
777       return call("atomic_fetch_max_explicit", true);
778 
779     case sem::IntrinsicType::kAtomicMin:
780       return call("atomic_fetch_min_explicit", true);
781 
782     case sem::IntrinsicType::kAtomicAnd:
783       return call("atomic_fetch_and_explicit", true);
784 
785     case sem::IntrinsicType::kAtomicOr:
786       return call("atomic_fetch_or_explicit", true);
787 
788     case sem::IntrinsicType::kAtomicXor:
789       return call("atomic_fetch_xor_explicit", true);
790 
791     case sem::IntrinsicType::kAtomicExchange:
792       return call("atomic_exchange_explicit", true);
793 
794     case sem::IntrinsicType::kAtomicCompareExchangeWeak: {
795       auto* ptr_ty = TypeOf(expr->args[0])->UnwrapRef()->As<sem::Pointer>();
796       auto sc = ptr_ty->StorageClass();
797 
798       auto func = utils::GetOrCreate(
799           atomicCompareExchangeWeak_, sc, [&]() -> std::string {
800             auto name = UniqueIdentifier("atomicCompareExchangeWeak");
801             auto& buf = helpers_;
802 
803             line(&buf) << "template <typename A, typename T>";
804             {
805               auto f = line(&buf);
806               f << "vec<T, 2> " << name << "(";
807               if (!EmitStorageClass(f, sc)) {
808                 return "";
809               }
810               f << " A* atomic, T compare, T value) {";
811             }
812 
813             buf.IncrementIndent();
814             TINT_DEFER({
815               buf.DecrementIndent();
816               line(&buf) << "}";
817               line(&buf);
818             });
819 
820             line(&buf) << "T prev_value = compare;";
821             line(&buf) << "bool matched = "
822                           "atomic_compare_exchange_weak_explicit(atomic, "
823                           "&prev_value, value, memory_order_relaxed, "
824                           "memory_order_relaxed);";
825             line(&buf) << "return {prev_value, matched};";
826             return name;
827           });
828 
829       return call(func, false);
830     }
831 
832     default:
833       break;
834   }
835 
836   TINT_UNREACHABLE(Writer, diagnostics_)
837       << "unsupported atomic intrinsic: " << intrinsic->Type();
838   return false;
839 }
840 
EmitTextureCall(std::ostream & out,const sem::Call * call,const sem::Intrinsic * intrinsic)841 bool GeneratorImpl::EmitTextureCall(std::ostream& out,
842                                     const sem::Call* call,
843                                     const sem::Intrinsic* intrinsic) {
844   using Usage = sem::ParameterUsage;
845 
846   auto& signature = intrinsic->Signature();
847   auto* expr = call->Declaration();
848   auto& arguments = call->Arguments();
849 
850   // Returns the argument with the given usage
851   auto arg = [&](Usage usage) {
852     int idx = signature.IndexOf(usage);
853     return (idx >= 0) ? arguments[idx] : nullptr;
854   };
855 
856   auto* texture = arg(Usage::kTexture)->Declaration();
857   if (!texture) {
858     TINT_ICE(Writer, diagnostics_) << "missing texture arg";
859     return false;
860   }
861 
862   auto* texture_type = TypeOf(texture)->UnwrapRef()->As<sem::Texture>();
863 
864   // Helper to emit the texture expression, wrapped in parentheses if the
865   // expression includes an operator with lower precedence than the member
866   // accessor used for the function calls.
867   auto texture_expr = [&]() {
868     bool paren_lhs =
869         !texture->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression,
870                           ast::IdentifierExpression,
871                           ast::MemberAccessorExpression>();
872     if (paren_lhs) {
873       out << "(";
874     }
875     if (!EmitExpression(out, texture)) {
876       return false;
877     }
878     if (paren_lhs) {
879       out << ")";
880     }
881     return true;
882   };
883 
884   switch (intrinsic->Type()) {
885     case sem::IntrinsicType::kTextureDimensions: {
886       std::vector<const char*> dims;
887       switch (texture_type->dim()) {
888         case ast::TextureDimension::kNone:
889           diagnostics_.add_error(diag::System::Writer,
890                                  "texture dimension is kNone");
891           return false;
892         case ast::TextureDimension::k1d:
893           dims = {"width"};
894           break;
895         case ast::TextureDimension::k2d:
896         case ast::TextureDimension::k2dArray:
897         case ast::TextureDimension::kCube:
898         case ast::TextureDimension::kCubeArray:
899           dims = {"width", "height"};
900           break;
901         case ast::TextureDimension::k3d:
902           dims = {"width", "height", "depth"};
903           break;
904       }
905 
906       auto get_dim = [&](const char* name) {
907         if (!texture_expr()) {
908           return false;
909         }
910         out << ".get_" << name << "(";
911         if (auto* level = arg(Usage::kLevel)) {
912           if (!EmitExpression(out, level->Declaration())) {
913             return false;
914           }
915         }
916         out << ")";
917         return true;
918       };
919 
920       if (dims.size() == 1) {
921         out << "int(";
922         get_dim(dims[0]);
923         out << ")";
924       } else {
925         EmitType(out, TypeOf(expr)->UnwrapRef(), "");
926         out << "(";
927         for (size_t i = 0; i < dims.size(); i++) {
928           if (i > 0) {
929             out << ", ";
930           }
931           get_dim(dims[i]);
932         }
933         out << ")";
934       }
935       return true;
936     }
937     case sem::IntrinsicType::kTextureNumLayers: {
938       out << "int(";
939       if (!texture_expr()) {
940         return false;
941       }
942       out << ".get_array_size())";
943       return true;
944     }
945     case sem::IntrinsicType::kTextureNumLevels: {
946       out << "int(";
947       if (!texture_expr()) {
948         return false;
949       }
950       out << ".get_num_mip_levels())";
951       return true;
952     }
953     case sem::IntrinsicType::kTextureNumSamples: {
954       out << "int(";
955       if (!texture_expr()) {
956         return false;
957       }
958       out << ".get_num_samples())";
959       return true;
960     }
961     default:
962       break;
963   }
964 
965   if (!texture_expr()) {
966     return false;
967   }
968 
969   bool lod_param_is_named = true;
970 
971   switch (intrinsic->Type()) {
972     case sem::IntrinsicType::kTextureSample:
973     case sem::IntrinsicType::kTextureSampleBias:
974     case sem::IntrinsicType::kTextureSampleLevel:
975     case sem::IntrinsicType::kTextureSampleGrad:
976       out << ".sample(";
977       break;
978     case sem::IntrinsicType::kTextureSampleCompare:
979     case sem::IntrinsicType::kTextureSampleCompareLevel:
980       out << ".sample_compare(";
981       break;
982     case sem::IntrinsicType::kTextureGather:
983       out << ".gather(";
984       break;
985     case sem::IntrinsicType::kTextureGatherCompare:
986       out << ".gather_compare(";
987       break;
988     case sem::IntrinsicType::kTextureLoad:
989       out << ".read(";
990       lod_param_is_named = false;
991       break;
992     case sem::IntrinsicType::kTextureStore:
993       out << ".write(";
994       break;
995     default:
996       TINT_UNREACHABLE(Writer, diagnostics_)
997           << "Unhandled texture intrinsic '" << intrinsic->str() << "'";
998       return false;
999   }
1000 
1001   bool first_arg = true;
1002   auto maybe_write_comma = [&] {
1003     if (!first_arg) {
1004       out << ", ";
1005     }
1006     first_arg = false;
1007   };
1008 
1009   for (auto usage :
1010        {Usage::kValue, Usage::kSampler, Usage::kCoords, Usage::kArrayIndex,
1011         Usage::kDepthRef, Usage::kSampleIndex}) {
1012     if (auto* e = arg(usage)) {
1013       maybe_write_comma();
1014 
1015       // Cast the coordinates to unsigned integers if necessary.
1016       bool casted = false;
1017       if (usage == Usage::kCoords &&
1018           e->Type()->UnwrapRef()->is_integer_scalar_or_vector()) {
1019         casted = true;
1020         switch (texture_type->dim()) {
1021           case ast::TextureDimension::k1d:
1022             out << "uint(";
1023             break;
1024           case ast::TextureDimension::k2d:
1025           case ast::TextureDimension::k2dArray:
1026             out << "uint2(";
1027             break;
1028           case ast::TextureDimension::k3d:
1029             out << "uint3(";
1030             break;
1031           default:
1032             TINT_ICE(Writer, diagnostics_)
1033                 << "unhandled texture dimensionality";
1034             break;
1035         }
1036       }
1037 
1038       if (!EmitExpression(out, e->Declaration()))
1039         return false;
1040 
1041       if (casted) {
1042         out << ")";
1043       }
1044     }
1045   }
1046 
1047   if (auto* bias = arg(Usage::kBias)) {
1048     maybe_write_comma();
1049     out << "bias(";
1050     if (!EmitExpression(out, bias->Declaration())) {
1051       return false;
1052     }
1053     out << ")";
1054   }
1055   if (auto* level = arg(Usage::kLevel)) {
1056     maybe_write_comma();
1057     if (lod_param_is_named) {
1058       out << "level(";
1059     }
1060     if (!EmitExpression(out, level->Declaration())) {
1061       return false;
1062     }
1063     if (lod_param_is_named) {
1064       out << ")";
1065     }
1066   }
1067   if (intrinsic->Type() == sem::IntrinsicType::kTextureSampleCompareLevel) {
1068     maybe_write_comma();
1069     out << "level(0)";
1070   }
1071   if (auto* ddx = arg(Usage::kDdx)) {
1072     auto dim = texture_type->dim();
1073     switch (dim) {
1074       case ast::TextureDimension::k2d:
1075       case ast::TextureDimension::k2dArray:
1076         maybe_write_comma();
1077         out << "gradient2d(";
1078         break;
1079       case ast::TextureDimension::k3d:
1080         maybe_write_comma();
1081         out << "gradient3d(";
1082         break;
1083       case ast::TextureDimension::kCube:
1084       case ast::TextureDimension::kCubeArray:
1085         maybe_write_comma();
1086         out << "gradientcube(";
1087         break;
1088       default: {
1089         std::stringstream err;
1090         err << "MSL does not support gradients for " << dim << " textures";
1091         diagnostics_.add_error(diag::System::Writer, err.str());
1092         return false;
1093       }
1094     }
1095     if (!EmitExpression(out, ddx->Declaration())) {
1096       return false;
1097     }
1098     out << ", ";
1099     if (!EmitExpression(out, arg(Usage::kDdy)->Declaration())) {
1100       return false;
1101     }
1102     out << ")";
1103   }
1104 
1105   bool has_offset = false;
1106   if (auto* offset = arg(Usage::kOffset)) {
1107     has_offset = true;
1108     maybe_write_comma();
1109     if (!EmitExpression(out, offset->Declaration())) {
1110       return false;
1111     }
1112   }
1113 
1114   if (auto* component = arg(Usage::kComponent)) {
1115     maybe_write_comma();
1116     if (!has_offset) {
1117       // offset argument may need to be provided if we have a component.
1118       switch (texture_type->dim()) {
1119         case ast::TextureDimension::k2d:
1120         case ast::TextureDimension::k2dArray:
1121           out << "int2(0), ";
1122           break;
1123         default:
1124           break;  // Other texture dimensions don't have an offset
1125       }
1126     }
1127     auto c = component->ConstantValue().Elements()[0].i32;
1128     switch (c) {
1129       case 0:
1130         out << "component::x";
1131         break;
1132       case 1:
1133         out << "component::y";
1134         break;
1135       case 2:
1136         out << "component::z";
1137         break;
1138       case 3:
1139         out << "component::w";
1140         break;
1141       default:
1142         TINT_ICE(Writer, diagnostics_)
1143             << "invalid textureGather component: " << c;
1144         break;
1145     }
1146   }
1147 
1148   out << ")";
1149 
1150   return true;
1151 }
1152 
EmitDotCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1153 bool GeneratorImpl::EmitDotCall(std::ostream& out,
1154                                 const ast::CallExpression* expr,
1155                                 const sem::Intrinsic* intrinsic) {
1156   auto* vec_ty = intrinsic->Parameters()[0]->Type()->As<sem::Vector>();
1157   std::string fn = "dot";
1158   if (vec_ty->type()->is_integer_scalar()) {
1159     // MSL does not have a builtin for dot() with integer vector types.
1160     // Generate the helper function if it hasn't been created already
1161     fn = utils::GetOrCreate(
1162         int_dot_funcs_, vec_ty->Width(), [&]() -> std::string {
1163           TextBuffer b;
1164           TINT_DEFER(helpers_.Append(b));
1165 
1166           auto fn_name =
1167               UniqueIdentifier("tint_dot" + std::to_string(vec_ty->Width()));
1168           auto v = "vec<T," + std::to_string(vec_ty->Width()) + ">";
1169 
1170           line(&b) << "template<typename T>";
1171           line(&b) << "T " << fn_name << "(" << v << " a, " << v << " b) {";
1172           {
1173             auto l = line(&b);
1174             l << "  return ";
1175             for (uint32_t i = 0; i < vec_ty->Width(); i++) {
1176               if (i > 0) {
1177                 l << " + ";
1178               }
1179               l << "a[" << i << "]*b[" << i << "]";
1180             }
1181             l << ";";
1182           }
1183           line(&b) << "}";
1184           return fn_name;
1185         });
1186   }
1187 
1188   out << fn << "(";
1189   if (!EmitExpression(out, expr->args[0])) {
1190     return false;
1191   }
1192   out << ", ";
1193   if (!EmitExpression(out, expr->args[1])) {
1194     return false;
1195   }
1196   out << ")";
1197   return true;
1198 }
1199 
EmitModfCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1200 bool GeneratorImpl::EmitModfCall(std::ostream& out,
1201                                  const ast::CallExpression* expr,
1202                                  const sem::Intrinsic* intrinsic) {
1203   return CallIntrinsicHelper(
1204       out, expr, intrinsic,
1205       [&](TextBuffer* b, const std::vector<std::string>& params) {
1206         auto* ty = intrinsic->Parameters()[0]->Type();
1207         auto in = params[0];
1208 
1209         std::string width;
1210         if (auto* vec = ty->As<sem::Vector>()) {
1211           width = std::to_string(vec->Width());
1212         }
1213 
1214         // Emit the builtin return type unique to this overload. This does not
1215         // exist in the AST, so it will not be generated in Generate().
1216         if (!EmitStructType(&helpers_,
1217                             intrinsic->ReturnType()->As<sem::Struct>())) {
1218           return false;
1219         }
1220 
1221         line(b) << "float" << width << " whole;";
1222         line(b) << "float" << width << " fract = modf(" << in << ", whole);";
1223         line(b) << "return {fract, whole};";
1224         return true;
1225       });
1226 }
1227 
EmitFrexpCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1228 bool GeneratorImpl::EmitFrexpCall(std::ostream& out,
1229                                   const ast::CallExpression* expr,
1230                                   const sem::Intrinsic* intrinsic) {
1231   return CallIntrinsicHelper(
1232       out, expr, intrinsic,
1233       [&](TextBuffer* b, const std::vector<std::string>& params) {
1234         auto* ty = intrinsic->Parameters()[0]->Type();
1235         auto in = params[0];
1236 
1237         std::string width;
1238         if (auto* vec = ty->As<sem::Vector>()) {
1239           width = std::to_string(vec->Width());
1240         }
1241 
1242         // Emit the builtin return type unique to this overload. This does not
1243         // exist in the AST, so it will not be generated in Generate().
1244         if (!EmitStructType(&helpers_,
1245                             intrinsic->ReturnType()->As<sem::Struct>())) {
1246           return false;
1247         }
1248 
1249         line(b) << "int" << width << " exp;";
1250         line(b) << "float" << width << " sig = frexp(" << in << ", exp);";
1251         line(b) << "return {sig, exp};";
1252         return true;
1253       });
1254 }
1255 
generate_builtin_name(const sem::Intrinsic * intrinsic)1256 std::string GeneratorImpl::generate_builtin_name(
1257     const sem::Intrinsic* intrinsic) {
1258   std::string out = "";
1259   switch (intrinsic->Type()) {
1260     case sem::IntrinsicType::kAcos:
1261     case sem::IntrinsicType::kAll:
1262     case sem::IntrinsicType::kAny:
1263     case sem::IntrinsicType::kAsin:
1264     case sem::IntrinsicType::kAtan:
1265     case sem::IntrinsicType::kAtan2:
1266     case sem::IntrinsicType::kCeil:
1267     case sem::IntrinsicType::kCos:
1268     case sem::IntrinsicType::kCosh:
1269     case sem::IntrinsicType::kCross:
1270     case sem::IntrinsicType::kDeterminant:
1271     case sem::IntrinsicType::kDistance:
1272     case sem::IntrinsicType::kDot:
1273     case sem::IntrinsicType::kExp:
1274     case sem::IntrinsicType::kExp2:
1275     case sem::IntrinsicType::kFloor:
1276     case sem::IntrinsicType::kFma:
1277     case sem::IntrinsicType::kFract:
1278     case sem::IntrinsicType::kFrexp:
1279     case sem::IntrinsicType::kLength:
1280     case sem::IntrinsicType::kLdexp:
1281     case sem::IntrinsicType::kLog:
1282     case sem::IntrinsicType::kLog2:
1283     case sem::IntrinsicType::kMix:
1284     case sem::IntrinsicType::kModf:
1285     case sem::IntrinsicType::kNormalize:
1286     case sem::IntrinsicType::kPow:
1287     case sem::IntrinsicType::kReflect:
1288     case sem::IntrinsicType::kRefract:
1289     case sem::IntrinsicType::kSelect:
1290     case sem::IntrinsicType::kSin:
1291     case sem::IntrinsicType::kSinh:
1292     case sem::IntrinsicType::kSqrt:
1293     case sem::IntrinsicType::kStep:
1294     case sem::IntrinsicType::kTan:
1295     case sem::IntrinsicType::kTanh:
1296     case sem::IntrinsicType::kTranspose:
1297     case sem::IntrinsicType::kTrunc:
1298     case sem::IntrinsicType::kSign:
1299     case sem::IntrinsicType::kClamp:
1300       out += intrinsic->str();
1301       break;
1302     case sem::IntrinsicType::kAbs:
1303       if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
1304         out += "fabs";
1305       } else {
1306         out += "abs";
1307       }
1308       break;
1309     case sem::IntrinsicType::kCountOneBits:
1310       out += "popcount";
1311       break;
1312     case sem::IntrinsicType::kDpdx:
1313     case sem::IntrinsicType::kDpdxCoarse:
1314     case sem::IntrinsicType::kDpdxFine:
1315       out += "dfdx";
1316       break;
1317     case sem::IntrinsicType::kDpdy:
1318     case sem::IntrinsicType::kDpdyCoarse:
1319     case sem::IntrinsicType::kDpdyFine:
1320       out += "dfdy";
1321       break;
1322     case sem::IntrinsicType::kFwidth:
1323     case sem::IntrinsicType::kFwidthCoarse:
1324     case sem::IntrinsicType::kFwidthFine:
1325       out += "fwidth";
1326       break;
1327     case sem::IntrinsicType::kIsFinite:
1328       out += "isfinite";
1329       break;
1330     case sem::IntrinsicType::kIsInf:
1331       out += "isinf";
1332       break;
1333     case sem::IntrinsicType::kIsNan:
1334       out += "isnan";
1335       break;
1336     case sem::IntrinsicType::kIsNormal:
1337       out += "isnormal";
1338       break;
1339     case sem::IntrinsicType::kMax:
1340       if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
1341         out += "fmax";
1342       } else {
1343         out += "max";
1344       }
1345       break;
1346     case sem::IntrinsicType::kMin:
1347       if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
1348         out += "fmin";
1349       } else {
1350         out += "min";
1351       }
1352       break;
1353     case sem::IntrinsicType::kFaceForward:
1354       out += "faceforward";
1355       break;
1356     case sem::IntrinsicType::kPack4x8snorm:
1357       out += "pack_float_to_snorm4x8";
1358       break;
1359     case sem::IntrinsicType::kPack4x8unorm:
1360       out += "pack_float_to_unorm4x8";
1361       break;
1362     case sem::IntrinsicType::kPack2x16snorm:
1363       out += "pack_float_to_snorm2x16";
1364       break;
1365     case sem::IntrinsicType::kPack2x16unorm:
1366       out += "pack_float_to_unorm2x16";
1367       break;
1368     case sem::IntrinsicType::kReverseBits:
1369       out += "reverse_bits";
1370       break;
1371     case sem::IntrinsicType::kRound:
1372       out += "rint";
1373       break;
1374     case sem::IntrinsicType::kSmoothStep:
1375       out += "smoothstep";
1376       break;
1377     case sem::IntrinsicType::kInverseSqrt:
1378       out += "rsqrt";
1379       break;
1380     case sem::IntrinsicType::kUnpack4x8snorm:
1381       out += "unpack_snorm4x8_to_float";
1382       break;
1383     case sem::IntrinsicType::kUnpack4x8unorm:
1384       out += "unpack_unorm4x8_to_float";
1385       break;
1386     case sem::IntrinsicType::kUnpack2x16snorm:
1387       out += "unpack_snorm2x16_to_float";
1388       break;
1389     case sem::IntrinsicType::kUnpack2x16unorm:
1390       out += "unpack_unorm2x16_to_float";
1391       break;
1392     case sem::IntrinsicType::kArrayLength:
1393       diagnostics_.add_error(
1394           diag::System::Writer,
1395           "Unable to translate builtin: " + std::string(intrinsic->str()) +
1396               "\nDid you forget to pass array_length_from_uniform generator "
1397               "options?");
1398       return "";
1399     default:
1400       diagnostics_.add_error(
1401           diag::System::Writer,
1402           "Unknown import method: " + std::string(intrinsic->str()));
1403       return "";
1404   }
1405   return out;
1406 }
1407 
EmitCase(const ast::CaseStatement * stmt)1408 bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) {
1409   if (stmt->IsDefault()) {
1410     line() << "default: {";
1411   } else {
1412     for (auto* selector : stmt->selectors) {
1413       auto out = line();
1414       out << "case ";
1415       if (!EmitLiteral(out, selector)) {
1416         return false;
1417       }
1418       out << ":";
1419       if (selector == stmt->selectors.back()) {
1420         out << " {";
1421       }
1422     }
1423   }
1424 
1425   {
1426     ScopedIndent si(this);
1427 
1428     for (auto* s : stmt->body->statements) {
1429       if (!EmitStatement(s)) {
1430         return false;
1431       }
1432     }
1433 
1434     if (!last_is_break_or_fallthrough(stmt->body)) {
1435       line() << "break;";
1436     }
1437   }
1438 
1439   line() << "}";
1440 
1441   return true;
1442 }
1443 
EmitContinue(const ast::ContinueStatement *)1444 bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) {
1445   if (!emit_continuing_()) {
1446     return false;
1447   }
1448 
1449   line() << "continue;";
1450   return true;
1451 }
1452 
EmitZeroValue(std::ostream & out,const sem::Type * type)1453 bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
1454   if (type->Is<sem::Bool>()) {
1455     out << "false";
1456   } else if (type->Is<sem::F32>()) {
1457     out << "0.0f";
1458   } else if (type->Is<sem::I32>()) {
1459     out << "0";
1460   } else if (type->Is<sem::U32>()) {
1461     out << "0u";
1462   } else if (auto* vec = type->As<sem::Vector>()) {
1463     return EmitZeroValue(out, vec->type());
1464   } else if (auto* mat = type->As<sem::Matrix>()) {
1465     if (!EmitType(out, mat, "")) {
1466       return false;
1467     }
1468     out << "(";
1469     if (!EmitZeroValue(out, mat->type())) {
1470       return false;
1471     }
1472     out << ")";
1473   } else if (auto* arr = type->As<sem::Array>()) {
1474     out << "{";
1475     if (!EmitZeroValue(out, arr->ElemType())) {
1476       return false;
1477     }
1478     out << "}";
1479   } else if (type->As<sem::Struct>()) {
1480     out << "{}";
1481   } else {
1482     diagnostics_.add_error(
1483         diag::System::Writer,
1484         "Invalid type for zero emission: " + type->type_name());
1485     return false;
1486   }
1487   return true;
1488 }
1489 
EmitLiteral(std::ostream & out,const ast::LiteralExpression * lit)1490 bool GeneratorImpl::EmitLiteral(std::ostream& out,
1491                                 const ast::LiteralExpression* lit) {
1492   if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
1493     out << (l->value ? "true" : "false");
1494   } else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
1495     if (std::isinf(fl->value)) {
1496       out << (fl->value >= 0 ? "INFINITY" : "-INFINITY");
1497     } else if (std::isnan(fl->value)) {
1498       out << "NAN";
1499     } else {
1500       out << FloatToString(fl->value) << "f";
1501     }
1502   } else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
1503     // MSL (and C++) parse `-2147483648` as a `long` because it parses unary
1504     // minus and `2147483648` as separate tokens, and the latter doesn't
1505     // fit into an (32-bit) `int`. WGSL, OTOH, parses this as an `i32`. To avoid
1506     // issues with `long` to `int` casts, emit `(2147483647 - 1)` instead, which
1507     // ensures the expression type is `int`.
1508     const auto int_min = std::numeric_limits<int32_t>::min();
1509     if (sl->ValueAsI32() == int_min) {
1510       out << "(" << int_min + 1 << " - 1)";
1511     } else {
1512       out << sl->value;
1513     }
1514   } else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
1515     out << ul->value << "u";
1516   } else {
1517     diagnostics_.add_error(diag::System::Writer, "unknown literal type");
1518     return false;
1519   }
1520   return true;
1521 }
1522 
EmitExpression(std::ostream & out,const ast::Expression * expr)1523 bool GeneratorImpl::EmitExpression(std::ostream& out,
1524                                    const ast::Expression* expr) {
1525   if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
1526     return EmitIndexAccessor(out, a);
1527   }
1528   if (auto* b = expr->As<ast::BinaryExpression>()) {
1529     return EmitBinary(out, b);
1530   }
1531   if (auto* b = expr->As<ast::BitcastExpression>()) {
1532     return EmitBitcast(out, b);
1533   }
1534   if (auto* c = expr->As<ast::CallExpression>()) {
1535     return EmitCall(out, c);
1536   }
1537   if (auto* i = expr->As<ast::IdentifierExpression>()) {
1538     return EmitIdentifier(out, i);
1539   }
1540   if (auto* l = expr->As<ast::LiteralExpression>()) {
1541     return EmitLiteral(out, l);
1542   }
1543   if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
1544     return EmitMemberAccessor(out, m);
1545   }
1546   if (auto* u = expr->As<ast::UnaryOpExpression>()) {
1547     return EmitUnaryOp(out, u);
1548   }
1549 
1550   diagnostics_.add_error(
1551       diag::System::Writer,
1552       "unknown expression type: " + std::string(expr->TypeInfo().name));
1553   return false;
1554 }
1555 
EmitStage(std::ostream & out,ast::PipelineStage stage)1556 void GeneratorImpl::EmitStage(std::ostream& out, ast::PipelineStage stage) {
1557   switch (stage) {
1558     case ast::PipelineStage::kFragment:
1559       out << "fragment";
1560       break;
1561     case ast::PipelineStage::kVertex:
1562       out << "vertex";
1563       break;
1564     case ast::PipelineStage::kCompute:
1565       out << "kernel";
1566       break;
1567     case ast::PipelineStage::kNone:
1568       break;
1569   }
1570   return;
1571 }
1572 
EmitFunction(const ast::Function * func)1573 bool GeneratorImpl::EmitFunction(const ast::Function* func) {
1574   auto* func_sem = program_->Sem().Get(func);
1575 
1576   {
1577     auto out = line();
1578     if (!EmitType(out, func_sem->ReturnType(), "")) {
1579       return false;
1580     }
1581     out << " " << program_->Symbols().NameFor(func->symbol) << "(";
1582 
1583     bool first = true;
1584     for (auto* v : func->params) {
1585       if (!first) {
1586         out << ", ";
1587       }
1588       first = false;
1589 
1590       auto* type = program_->Sem().Get(v)->Type();
1591 
1592       std::string param_name =
1593           "const " + program_->Symbols().NameFor(v->symbol);
1594       if (!EmitType(out, type, param_name)) {
1595         return false;
1596       }
1597       // Parameter name is output as part of the type for arrays and pointers.
1598       if (!type->Is<sem::Array>() && !type->Is<sem::Pointer>()) {
1599         out << " " << program_->Symbols().NameFor(v->symbol);
1600       }
1601     }
1602 
1603     out << ") {";
1604   }
1605 
1606   if (!EmitStatementsWithIndent(func->body->statements)) {
1607     return false;
1608   }
1609 
1610   line() << "}";
1611 
1612   return true;
1613 }
1614 
builtin_to_attribute(ast::Builtin builtin) const1615 std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const {
1616   switch (builtin) {
1617     case ast::Builtin::kPosition:
1618       return "position";
1619     case ast::Builtin::kVertexIndex:
1620       return "vertex_id";
1621     case ast::Builtin::kInstanceIndex:
1622       return "instance_id";
1623     case ast::Builtin::kFrontFacing:
1624       return "front_facing";
1625     case ast::Builtin::kFragDepth:
1626       return "depth(any)";
1627     case ast::Builtin::kLocalInvocationId:
1628       return "thread_position_in_threadgroup";
1629     case ast::Builtin::kLocalInvocationIndex:
1630       return "thread_index_in_threadgroup";
1631     case ast::Builtin::kGlobalInvocationId:
1632       return "thread_position_in_grid";
1633     case ast::Builtin::kWorkgroupId:
1634       return "threadgroup_position_in_grid";
1635     case ast::Builtin::kNumWorkgroups:
1636       return "threadgroups_per_grid";
1637     case ast::Builtin::kSampleIndex:
1638       return "sample_id";
1639     case ast::Builtin::kSampleMask:
1640       return "sample_mask";
1641     case ast::Builtin::kPointSize:
1642       return "point_size";
1643     default:
1644       break;
1645   }
1646   return "";
1647 }
1648 
interpolation_to_attribute(ast::InterpolationType type,ast::InterpolationSampling sampling) const1649 std::string GeneratorImpl::interpolation_to_attribute(
1650     ast::InterpolationType type,
1651     ast::InterpolationSampling sampling) const {
1652   std::string attr;
1653   switch (sampling) {
1654     case ast::InterpolationSampling::kCenter:
1655       attr = "center_";
1656       break;
1657     case ast::InterpolationSampling::kCentroid:
1658       attr = "centroid_";
1659       break;
1660     case ast::InterpolationSampling::kSample:
1661       attr = "sample_";
1662       break;
1663     case ast::InterpolationSampling::kNone:
1664       break;
1665   }
1666   switch (type) {
1667     case ast::InterpolationType::kPerspective:
1668       attr += "perspective";
1669       break;
1670     case ast::InterpolationType::kLinear:
1671       attr += "no_perspective";
1672       break;
1673     case ast::InterpolationType::kFlat:
1674       attr += "flat";
1675       break;
1676   }
1677   return attr;
1678 }
1679 
EmitEntryPointFunction(const ast::Function * func)1680 bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
1681   auto func_name = program_->Symbols().NameFor(func->symbol);
1682 
1683   // Returns the binding index of a variable, requiring that the group attribute
1684   // have a value of zero.
1685   const uint32_t kInvalidBindingIndex = std::numeric_limits<uint32_t>::max();
1686   auto get_binding_index = [&](const ast::Variable* var) -> uint32_t {
1687     auto bp = var->BindingPoint();
1688     if (bp.group == nullptr || bp.binding == nullptr) {
1689       TINT_ICE(Writer, diagnostics_)
1690           << "missing binding attributes for entry point parameter";
1691       return kInvalidBindingIndex;
1692     }
1693     if (bp.group->value != 0) {
1694       TINT_ICE(Writer, diagnostics_)
1695           << "encountered non-zero resource group index (use "
1696              "BindingRemapper to fix)";
1697       return kInvalidBindingIndex;
1698     }
1699     return bp.binding->value;
1700   };
1701 
1702   {
1703     auto out = line();
1704 
1705     EmitStage(out, func->PipelineStage());
1706     out << " " << func->return_type->FriendlyName(program_->Symbols());
1707     out << " " << func_name << "(";
1708 
1709     // Emit entry point parameters.
1710     bool first = true;
1711     for (auto* var : func->params) {
1712       if (!first) {
1713         out << ", ";
1714       }
1715       first = false;
1716 
1717       auto* type = program_->Sem().Get(var)->Type()->UnwrapRef();
1718 
1719       auto param_name = program_->Symbols().NameFor(var->symbol);
1720       if (!EmitType(out, type, param_name)) {
1721         return false;
1722       }
1723       // Parameter name is output as part of the type for arrays and pointers.
1724       if (!type->Is<sem::Array>() && !type->Is<sem::Pointer>()) {
1725         out << " " << param_name;
1726       }
1727 
1728       if (type->Is<sem::Struct>()) {
1729         out << " [[stage_in]]";
1730       } else if (type->is_handle()) {
1731         uint32_t binding = get_binding_index(var);
1732         if (binding == kInvalidBindingIndex) {
1733           return false;
1734         }
1735         if (var->type->Is<ast::Sampler>()) {
1736           out << " [[sampler(" << binding << ")]]";
1737         } else if (var->type->Is<ast::Texture>()) {
1738           out << " [[texture(" << binding << ")]]";
1739         } else {
1740           TINT_ICE(Writer, diagnostics_)
1741               << "invalid handle type entry point parameter";
1742           return false;
1743         }
1744       } else if (auto* ptr = var->type->As<ast::Pointer>()) {
1745         auto sc = ptr->storage_class;
1746         if (sc == ast::StorageClass::kWorkgroup) {
1747           auto& allocations = workgroup_allocations_[func_name];
1748           out << " [[threadgroup(" << allocations.size() << ")]]";
1749           allocations.push_back(program_->Sem().Get(ptr->type)->Size());
1750         } else if (sc == ast::StorageClass::kStorage ||
1751                    sc == ast::StorageClass::kUniform) {
1752           uint32_t binding = get_binding_index(var);
1753           if (binding == kInvalidBindingIndex) {
1754             return false;
1755           }
1756           out << " [[buffer(" << binding << ")]]";
1757         } else {
1758           TINT_ICE(Writer, diagnostics_)
1759               << "invalid pointer storage class for entry point parameter";
1760           return false;
1761         }
1762       } else {
1763         auto& decos = var->decorations;
1764         bool builtin_found = false;
1765         for (auto* deco : decos) {
1766           auto* builtin = deco->As<ast::BuiltinDecoration>();
1767           if (!builtin) {
1768             continue;
1769           }
1770 
1771           builtin_found = true;
1772 
1773           auto attr = builtin_to_attribute(builtin->builtin);
1774           if (attr.empty()) {
1775             diagnostics_.add_error(diag::System::Writer, "unknown builtin");
1776             return false;
1777           }
1778           out << " [[" << attr << "]]";
1779         }
1780         if (!builtin_found) {
1781           TINT_ICE(Writer, diagnostics_) << "Unsupported entry point parameter";
1782         }
1783       }
1784     }
1785     out << ") {";
1786   }
1787 
1788   {
1789     ScopedIndent si(this);
1790 
1791     if (!EmitStatements(func->body->statements)) {
1792       return false;
1793     }
1794 
1795     if (!Is<ast::ReturnStatement>(func->body->Last())) {
1796       ast::ReturnStatement ret(ProgramID{}, Source{});
1797       if (!EmitStatement(&ret)) {
1798         return false;
1799       }
1800     }
1801   }
1802 
1803   line() << "}";
1804   return true;
1805 }
1806 
EmitIdentifier(std::ostream & out,const ast::IdentifierExpression * expr)1807 bool GeneratorImpl::EmitIdentifier(std::ostream& out,
1808                                    const ast::IdentifierExpression* expr) {
1809   out << program_->Symbols().NameFor(expr->symbol);
1810   return true;
1811 }
1812 
EmitLoop(const ast::LoopStatement * stmt)1813 bool GeneratorImpl::EmitLoop(const ast::LoopStatement* stmt) {
1814   auto emit_continuing = [this, stmt]() {
1815     if (stmt->continuing && !stmt->continuing->Empty()) {
1816       if (!EmitBlock(stmt->continuing)) {
1817         return false;
1818       }
1819     }
1820     return true;
1821   };
1822 
1823   TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
1824   line() << "while (true) {";
1825   {
1826     ScopedIndent si(this);
1827     if (!EmitStatements(stmt->body->statements)) {
1828       return false;
1829     }
1830     if (!emit_continuing()) {
1831       return false;
1832     }
1833   }
1834   line() << "}";
1835 
1836   return true;
1837 }
1838 
EmitForLoop(const ast::ForLoopStatement * stmt)1839 bool GeneratorImpl::EmitForLoop(const ast::ForLoopStatement* stmt) {
1840   TextBuffer init_buf;
1841   if (auto* init = stmt->initializer) {
1842     TINT_SCOPED_ASSIGNMENT(current_buffer_, &init_buf);
1843     if (!EmitStatement(init)) {
1844       return false;
1845     }
1846   }
1847 
1848   TextBuffer cond_pre;
1849   std::stringstream cond_buf;
1850   if (auto* cond = stmt->condition) {
1851     TINT_SCOPED_ASSIGNMENT(current_buffer_, &cond_pre);
1852     if (!EmitExpression(cond_buf, cond)) {
1853       return false;
1854     }
1855   }
1856 
1857   TextBuffer cont_buf;
1858   if (auto* cont = stmt->continuing) {
1859     TINT_SCOPED_ASSIGNMENT(current_buffer_, &cont_buf);
1860     if (!EmitStatement(cont)) {
1861       return false;
1862     }
1863   }
1864 
1865   // If the for-loop has a multi-statement conditional and / or continuing, then
1866   // we cannot emit this as a regular for-loop in MSL. Instead we need to
1867   // generate a `while(true)` loop.
1868   bool emit_as_loop = cond_pre.lines.size() > 0 || cont_buf.lines.size() > 1;
1869 
1870   // If the for-loop has multi-statement initializer, or is going to be emitted
1871   // as a `while(true)` loop, then declare the initializer statement(s) before
1872   // the loop in a new block.
1873   bool nest_in_block =
1874       init_buf.lines.size() > 1 || (stmt->initializer && emit_as_loop);
1875   if (nest_in_block) {
1876     line() << "{";
1877     increment_indent();
1878     current_buffer_->Append(init_buf);
1879     init_buf.lines.clear();  // Don't emit the initializer again in the 'for'
1880   }
1881   TINT_DEFER({
1882     if (nest_in_block) {
1883       decrement_indent();
1884       line() << "}";
1885     }
1886   });
1887 
1888   if (emit_as_loop) {
1889     auto emit_continuing = [&]() {
1890       current_buffer_->Append(cont_buf);
1891       return true;
1892     };
1893 
1894     TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
1895     line() << "while (true) {";
1896     increment_indent();
1897     TINT_DEFER({
1898       decrement_indent();
1899       line() << "}";
1900     });
1901 
1902     if (stmt->condition) {
1903       current_buffer_->Append(cond_pre);
1904       line() << "if (!(" << cond_buf.str() << ")) { break; }";
1905     }
1906 
1907     if (!EmitStatements(stmt->body->statements)) {
1908       return false;
1909     }
1910 
1911     if (!emit_continuing()) {
1912       return false;
1913     }
1914   } else {
1915     // For-loop can be generated.
1916     {
1917       auto out = line();
1918       out << "for";
1919       {
1920         ScopedParen sp(out);
1921 
1922         if (!init_buf.lines.empty()) {
1923           out << init_buf.lines[0].content << " ";
1924         } else {
1925           out << "; ";
1926         }
1927 
1928         out << cond_buf.str() << "; ";
1929 
1930         if (!cont_buf.lines.empty()) {
1931           out << TrimSuffix(cont_buf.lines[0].content, ";");
1932         }
1933       }
1934       out << " {";
1935     }
1936     {
1937       auto emit_continuing = [] { return true; };
1938       TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
1939       if (!EmitStatementsWithIndent(stmt->body->statements)) {
1940         return false;
1941       }
1942     }
1943     line() << "}";
1944   }
1945 
1946   return true;
1947 }
1948 
EmitDiscard(const ast::DiscardStatement *)1949 bool GeneratorImpl::EmitDiscard(const ast::DiscardStatement*) {
1950   // TODO(dsinclair): Verify this is correct when the discard semantics are
1951   // defined for WGSL (https://github.com/gpuweb/gpuweb/issues/361)
1952   line() << "discard_fragment();";
1953   return true;
1954 }
1955 
EmitIf(const ast::IfStatement * stmt)1956 bool GeneratorImpl::EmitIf(const ast::IfStatement* stmt) {
1957   {
1958     auto out = line();
1959     out << "if (";
1960     if (!EmitExpression(out, stmt->condition)) {
1961       return false;
1962     }
1963     out << ") {";
1964   }
1965 
1966   if (!EmitStatementsWithIndent(stmt->body->statements)) {
1967     return false;
1968   }
1969 
1970   for (auto* e : stmt->else_statements) {
1971     if (e->condition) {
1972       line() << "} else {";
1973       increment_indent();
1974 
1975       {
1976         auto out = line();
1977         out << "if (";
1978         if (!EmitExpression(out, e->condition)) {
1979           return false;
1980         }
1981         out << ") {";
1982       }
1983     } else {
1984       line() << "} else {";
1985     }
1986 
1987     if (!EmitStatementsWithIndent(e->body->statements)) {
1988       return false;
1989     }
1990   }
1991 
1992   line() << "}";
1993 
1994   for (auto* e : stmt->else_statements) {
1995     if (e->condition) {
1996       decrement_indent();
1997       line() << "}";
1998     }
1999   }
2000   return true;
2001 }
2002 
EmitMemberAccessor(std::ostream & out,const ast::MemberAccessorExpression * expr)2003 bool GeneratorImpl::EmitMemberAccessor(
2004     std::ostream& out,
2005     const ast::MemberAccessorExpression* expr) {
2006   auto write_lhs = [&] {
2007     bool paren_lhs = !expr->structure->IsAnyOf<
2008         ast::IndexAccessorExpression, ast::CallExpression,
2009         ast::IdentifierExpression, ast::MemberAccessorExpression>();
2010     if (paren_lhs) {
2011       out << "(";
2012     }
2013     if (!EmitExpression(out, expr->structure)) {
2014       return false;
2015     }
2016     if (paren_lhs) {
2017       out << ")";
2018     }
2019     return true;
2020   };
2021 
2022   auto& sem = program_->Sem();
2023 
2024   if (auto* swizzle = sem.Get(expr)->As<sem::Swizzle>()) {
2025     // Metal 1.x does not support swizzling of packed vector types.
2026     // For single element swizzles, we can use the index operator.
2027     // For multi-element swizzles, we need to cast to a regular vector type
2028     // first. Note that we do not currently allow assignments to swizzles, so
2029     // the casting which will convert the l-value to r-value is fine.
2030     if (swizzle->Indices().size() == 1) {
2031       if (!write_lhs()) {
2032         return false;
2033       }
2034       out << "[" << swizzle->Indices()[0] << "]";
2035     } else {
2036       if (!EmitType(out, sem.Get(expr->structure)->Type()->UnwrapRef(), "")) {
2037         return false;
2038       }
2039       out << "(";
2040       if (!write_lhs()) {
2041         return false;
2042       }
2043       out << ")." << program_->Symbols().NameFor(expr->member->symbol);
2044     }
2045   } else {
2046     if (!write_lhs()) {
2047       return false;
2048     }
2049     out << ".";
2050     if (!EmitExpression(out, expr->member)) {
2051       return false;
2052     }
2053   }
2054 
2055   return true;
2056 }
2057 
EmitReturn(const ast::ReturnStatement * stmt)2058 bool GeneratorImpl::EmitReturn(const ast::ReturnStatement* stmt) {
2059   auto out = line();
2060   out << "return";
2061   if (stmt->value) {
2062     out << " ";
2063     if (!EmitExpression(out, stmt->value)) {
2064       return false;
2065     }
2066   }
2067   out << ";";
2068   return true;
2069 }
2070 
EmitBlock(const ast::BlockStatement * stmt)2071 bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) {
2072   line() << "{";
2073 
2074   if (!EmitStatementsWithIndent(stmt->statements)) {
2075     return false;
2076   }
2077 
2078   line() << "}";
2079 
2080   return true;
2081 }
2082 
EmitStatement(const ast::Statement * stmt)2083 bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
2084   if (auto* a = stmt->As<ast::AssignmentStatement>()) {
2085     return EmitAssign(a);
2086   }
2087   if (auto* b = stmt->As<ast::BlockStatement>()) {
2088     return EmitBlock(b);
2089   }
2090   if (auto* b = stmt->As<ast::BreakStatement>()) {
2091     return EmitBreak(b);
2092   }
2093   if (auto* c = stmt->As<ast::CallStatement>()) {
2094     auto out = line();
2095     if (!EmitCall(out, c->expr)) {
2096       return false;
2097     }
2098     out << ";";
2099     return true;
2100   }
2101   if (auto* c = stmt->As<ast::ContinueStatement>()) {
2102     return EmitContinue(c);
2103   }
2104   if (auto* d = stmt->As<ast::DiscardStatement>()) {
2105     return EmitDiscard(d);
2106   }
2107   if (stmt->As<ast::FallthroughStatement>()) {
2108     line() << "/* fallthrough */";
2109     return true;
2110   }
2111   if (auto* i = stmt->As<ast::IfStatement>()) {
2112     return EmitIf(i);
2113   }
2114   if (auto* l = stmt->As<ast::LoopStatement>()) {
2115     return EmitLoop(l);
2116   }
2117   if (auto* l = stmt->As<ast::ForLoopStatement>()) {
2118     return EmitForLoop(l);
2119   }
2120   if (auto* r = stmt->As<ast::ReturnStatement>()) {
2121     return EmitReturn(r);
2122   }
2123   if (auto* s = stmt->As<ast::SwitchStatement>()) {
2124     return EmitSwitch(s);
2125   }
2126   if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
2127     auto* var = program_->Sem().Get(v->variable);
2128     return EmitVariable(var);
2129   }
2130 
2131   diagnostics_.add_error(
2132       diag::System::Writer,
2133       "unknown statement type: " + std::string(stmt->TypeInfo().name));
2134   return false;
2135 }
2136 
EmitStatements(const ast::StatementList & stmts)2137 bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {
2138   for (auto* s : stmts) {
2139     if (!EmitStatement(s)) {
2140       return false;
2141     }
2142   }
2143   return true;
2144 }
2145 
EmitStatementsWithIndent(const ast::StatementList & stmts)2146 bool GeneratorImpl::EmitStatementsWithIndent(const ast::StatementList& stmts) {
2147   ScopedIndent si(this);
2148   return EmitStatements(stmts);
2149 }
2150 
EmitSwitch(const ast::SwitchStatement * stmt)2151 bool GeneratorImpl::EmitSwitch(const ast::SwitchStatement* stmt) {
2152   {
2153     auto out = line();
2154     out << "switch(";
2155     if (!EmitExpression(out, stmt->condition)) {
2156       return false;
2157     }
2158     out << ") {";
2159   }
2160 
2161   {
2162     ScopedIndent si(this);
2163     for (auto* s : stmt->body) {
2164       if (!EmitCase(s)) {
2165         return false;
2166       }
2167     }
2168   }
2169 
2170   line() << "}";
2171 
2172   return true;
2173 }
2174 
EmitType(std::ostream & out,const sem::Type * type,const std::string & name,bool * name_printed)2175 bool GeneratorImpl::EmitType(std::ostream& out,
2176                              const sem::Type* type,
2177                              const std::string& name,
2178                              bool* name_printed /* = nullptr */) {
2179   if (name_printed) {
2180     *name_printed = false;
2181   }
2182   if (auto* atomic = type->As<sem::Atomic>()) {
2183     if (atomic->Type()->Is<sem::I32>()) {
2184       out << "atomic_int";
2185       return true;
2186     }
2187     if (atomic->Type()->Is<sem::U32>()) {
2188       out << "atomic_uint";
2189       return true;
2190     }
2191     TINT_ICE(Writer, diagnostics_)
2192         << "unhandled atomic type " << atomic->Type()->type_name();
2193     return false;
2194   }
2195 
2196   if (auto* ary = type->As<sem::Array>()) {
2197     const sem::Type* base_type = ary;
2198     std::vector<uint32_t> sizes;
2199     while (auto* arr = base_type->As<sem::Array>()) {
2200       if (arr->IsRuntimeSized()) {
2201         sizes.push_back(1);
2202       } else {
2203         sizes.push_back(arr->Count());
2204       }
2205       base_type = arr->ElemType();
2206     }
2207     if (!EmitType(out, base_type, "")) {
2208       return false;
2209     }
2210     if (!name.empty()) {
2211       out << " " << name;
2212       if (name_printed) {
2213         *name_printed = true;
2214       }
2215     }
2216     for (uint32_t size : sizes) {
2217       out << "[" << size << "]";
2218     }
2219     return true;
2220   }
2221 
2222   if (type->Is<sem::Bool>()) {
2223     out << "bool";
2224     return true;
2225   }
2226 
2227   if (type->Is<sem::F32>()) {
2228     out << "float";
2229     return true;
2230   }
2231 
2232   if (type->Is<sem::I32>()) {
2233     out << "int";
2234     return true;
2235   }
2236 
2237   if (auto* mat = type->As<sem::Matrix>()) {
2238     if (!EmitType(out, mat->type(), "")) {
2239       return false;
2240     }
2241     out << mat->columns() << "x" << mat->rows();
2242     return true;
2243   }
2244 
2245   if (auto* ptr = type->As<sem::Pointer>()) {
2246     if (ptr->Access() == ast::Access::kRead) {
2247       out << "const ";
2248     }
2249     if (!EmitStorageClass(out, ptr->StorageClass())) {
2250       return false;
2251     }
2252     out << " ";
2253     if (ptr->StoreType()->Is<sem::Array>()) {
2254       std::string inner = "(*" + name + ")";
2255       if (!EmitType(out, ptr->StoreType(), inner)) {
2256         return false;
2257       }
2258       if (name_printed) {
2259         *name_printed = true;
2260       }
2261     } else {
2262       if (!EmitType(out, ptr->StoreType(), "")) {
2263         return false;
2264       }
2265       out << "* " << name;
2266       if (name_printed) {
2267         *name_printed = true;
2268       }
2269     }
2270     return true;
2271   }
2272 
2273   if (type->Is<sem::Sampler>()) {
2274     out << "sampler";
2275     return true;
2276   }
2277 
2278   if (auto* str = type->As<sem::Struct>()) {
2279     // The struct type emits as just the name. The declaration would be emitted
2280     // as part of emitting the declared types.
2281     out << StructName(str);
2282     return true;
2283   }
2284 
2285   if (auto* tex = type->As<sem::Texture>()) {
2286     if (tex->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
2287       out << "depth";
2288     } else {
2289       out << "texture";
2290     }
2291 
2292     switch (tex->dim()) {
2293       case ast::TextureDimension::k1d:
2294         out << "1d";
2295         break;
2296       case ast::TextureDimension::k2d:
2297         out << "2d";
2298         break;
2299       case ast::TextureDimension::k2dArray:
2300         out << "2d_array";
2301         break;
2302       case ast::TextureDimension::k3d:
2303         out << "3d";
2304         break;
2305       case ast::TextureDimension::kCube:
2306         out << "cube";
2307         break;
2308       case ast::TextureDimension::kCubeArray:
2309         out << "cube_array";
2310         break;
2311       default:
2312         diagnostics_.add_error(diag::System::Writer,
2313                                "Invalid texture dimensions");
2314         return false;
2315     }
2316     if (tex->IsAnyOf<sem::MultisampledTexture,
2317                      sem::DepthMultisampledTexture>()) {
2318       out << "_ms";
2319     }
2320     out << "<";
2321     if (tex->Is<sem::DepthTexture>()) {
2322       out << "float, access::sample";
2323     } else if (tex->Is<sem::DepthMultisampledTexture>()) {
2324       out << "float, access::read";
2325     } else if (auto* storage = tex->As<sem::StorageTexture>()) {
2326       if (!EmitType(out, storage->type(), "")) {
2327         return false;
2328       }
2329 
2330       std::string access_str;
2331       if (storage->access() == ast::Access::kRead) {
2332         out << ", access::read";
2333       } else if (storage->access() == ast::Access::kWrite) {
2334         out << ", access::write";
2335       } else {
2336         diagnostics_.add_error(diag::System::Writer,
2337                                "Invalid access control for storage texture");
2338         return false;
2339       }
2340     } else if (auto* ms = tex->As<sem::MultisampledTexture>()) {
2341       if (!EmitType(out, ms->type(), "")) {
2342         return false;
2343       }
2344       out << ", access::read";
2345     } else if (auto* sampled = tex->As<sem::SampledTexture>()) {
2346       if (!EmitType(out, sampled->type(), "")) {
2347         return false;
2348       }
2349       out << ", access::sample";
2350     } else {
2351       diagnostics_.add_error(diag::System::Writer, "invalid texture type");
2352       return false;
2353     }
2354     out << ">";
2355     return true;
2356   }
2357 
2358   if (type->Is<sem::U32>()) {
2359     out << "uint";
2360     return true;
2361   }
2362 
2363   if (auto* vec = type->As<sem::Vector>()) {
2364     if (!EmitType(out, vec->type(), "")) {
2365       return false;
2366     }
2367     out << vec->Width();
2368     return true;
2369   }
2370 
2371   if (type->Is<sem::Void>()) {
2372     out << "void";
2373     return true;
2374   }
2375 
2376   diagnostics_.add_error(diag::System::Writer,
2377                          "unknown type in EmitType: " + type->type_name());
2378   return false;
2379 }
2380 
EmitTypeAndName(std::ostream & out,const sem::Type * type,const std::string & name)2381 bool GeneratorImpl::EmitTypeAndName(std::ostream& out,
2382                                     const sem::Type* type,
2383                                     const std::string& name) {
2384   bool name_printed = false;
2385   if (!EmitType(out, type, name, &name_printed)) {
2386     return false;
2387   }
2388   if (!name_printed) {
2389     out << " " << name;
2390   }
2391   return true;
2392 }
2393 
EmitStorageClass(std::ostream & out,ast::StorageClass sc)2394 bool GeneratorImpl::EmitStorageClass(std::ostream& out, ast::StorageClass sc) {
2395   switch (sc) {
2396     case ast::StorageClass::kFunction:
2397     case ast::StorageClass::kPrivate:
2398     case ast::StorageClass::kUniformConstant:
2399       out << "thread";
2400       return true;
2401     case ast::StorageClass::kWorkgroup:
2402       out << "threadgroup";
2403       return true;
2404     case ast::StorageClass::kStorage:
2405       out << "device";
2406       return true;
2407     case ast::StorageClass::kUniform:
2408       out << "constant";
2409       return true;
2410     default:
2411       break;
2412   }
2413   TINT_ICE(Writer, diagnostics_) << "unhandled storage class: " << sc;
2414   return false;
2415 }
2416 
EmitPackedType(std::ostream & out,const sem::Type * type,const std::string & name)2417 bool GeneratorImpl::EmitPackedType(std::ostream& out,
2418                                    const sem::Type* type,
2419                                    const std::string& name) {
2420   auto* vec = type->As<sem::Vector>();
2421   if (vec && vec->Width() == 3) {
2422     out << "packed_";
2423     if (!EmitType(out, vec, "")) {
2424       return false;
2425     }
2426 
2427     if (vec->is_float_vector() && !matrix_packed_vector_overloads_) {
2428       // Overload operators for matrix-vector arithmetic where the vector
2429       // operand is packed, as these overloads to not exist in the metal
2430       // namespace.
2431       TextBuffer b;
2432       TINT_DEFER(helpers_.Append(b));
2433       line(&b) << R"(template<typename T, int N, int M>
2434 inline vec<T, M> operator*(matrix<T, N, M> lhs, packed_vec<T, N> rhs) {
2435   return lhs * vec<T, N>(rhs);
2436 }
2437 
2438 template<typename T, int N, int M>
2439 inline vec<T, N> operator*(packed_vec<T, M> lhs, matrix<T, N, M> rhs) {
2440   return vec<T, M>(lhs) * rhs;
2441 }
2442 )";
2443       matrix_packed_vector_overloads_ = true;
2444     }
2445 
2446     return true;
2447   }
2448 
2449   return EmitType(out, type, name);
2450 }
2451 
EmitStructType(TextBuffer * b,const sem::Struct * str)2452 bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) {
2453   line(b) << "struct " << StructName(str) << " {";
2454 
2455   bool is_host_shareable = str->IsHostShareable();
2456 
2457   // Emits a `/* 0xnnnn */` byte offset comment for a struct member.
2458   auto add_byte_offset_comment = [&](std::ostream& out, uint32_t offset) {
2459     std::ios_base::fmtflags saved_flag_state(out.flags());
2460     out << "/* 0x" << std::hex << std::setfill('0') << std::setw(4) << offset
2461         << " */ ";
2462     out.flags(saved_flag_state);
2463   };
2464 
2465   auto add_padding = [&](uint32_t size, uint32_t msl_offset) {
2466     std::string name;
2467     do {
2468       name = UniqueIdentifier("tint_pad");
2469     } while (str->FindMember(program_->Symbols().Get(name)));
2470 
2471     auto out = line(b);
2472     add_byte_offset_comment(out, msl_offset);
2473     out << "int8_t " << name << "[" << size << "];";
2474   };
2475 
2476   b->IncrementIndent();
2477 
2478   uint32_t msl_offset = 0;
2479   for (auto* mem : str->Members()) {
2480     auto out = line(b);
2481     auto name = program_->Symbols().NameFor(mem->Name());
2482     auto wgsl_offset = mem->Offset();
2483 
2484     if (is_host_shareable) {
2485       if (wgsl_offset < msl_offset) {
2486         // Unimplementable layout
2487         TINT_ICE(Writer, diagnostics_)
2488             << "Structure member WGSL offset (" << wgsl_offset
2489             << ") is behind MSL offset (" << msl_offset << ")";
2490         return false;
2491       }
2492 
2493       // Generate padding if required
2494       if (auto padding = wgsl_offset - msl_offset) {
2495         add_padding(padding, msl_offset);
2496         msl_offset += padding;
2497       }
2498 
2499       add_byte_offset_comment(out, msl_offset);
2500 
2501       if (!EmitPackedType(out, mem->Type(), name)) {
2502         return false;
2503       }
2504     } else {
2505       if (!EmitType(out, mem->Type(), name)) {
2506         return false;
2507       }
2508     }
2509 
2510     auto* ty = mem->Type();
2511 
2512     // Array member name will be output with the type
2513     if (!ty->Is<sem::Array>()) {
2514       out << " " << name;
2515     }
2516 
2517     // Emit decorations
2518     if (auto* decl = mem->Declaration()) {
2519       for (auto* deco : decl->decorations) {
2520         if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
2521           auto attr = builtin_to_attribute(builtin->builtin);
2522           if (attr.empty()) {
2523             diagnostics_.add_error(diag::System::Writer, "unknown builtin");
2524             return false;
2525           }
2526           out << " [[" << attr << "]]";
2527         } else if (auto* loc = deco->As<ast::LocationDecoration>()) {
2528           auto& pipeline_stage_uses = str->PipelineStageUses();
2529           if (pipeline_stage_uses.size() != 1) {
2530             TINT_ICE(Writer, diagnostics_)
2531                 << "invalid entry point IO struct uses";
2532           }
2533 
2534           if (pipeline_stage_uses.count(
2535                   sem::PipelineStageUsage::kVertexInput)) {
2536             out << " [[attribute(" + std::to_string(loc->value) + ")]]";
2537           } else if (pipeline_stage_uses.count(
2538                          sem::PipelineStageUsage::kVertexOutput)) {
2539             out << " [[user(locn" + std::to_string(loc->value) + ")]]";
2540           } else if (pipeline_stage_uses.count(
2541                          sem::PipelineStageUsage::kFragmentInput)) {
2542             out << " [[user(locn" + std::to_string(loc->value) + ")]]";
2543           } else if (pipeline_stage_uses.count(
2544                          sem::PipelineStageUsage::kFragmentOutput)) {
2545             out << " [[color(" + std::to_string(loc->value) + ")]]";
2546           } else {
2547             TINT_ICE(Writer, diagnostics_)
2548                 << "invalid use of location decoration";
2549           }
2550         } else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
2551           auto attr = interpolation_to_attribute(interpolate->type,
2552                                                  interpolate->sampling);
2553           if (attr.empty()) {
2554             diagnostics_.add_error(diag::System::Writer,
2555                                    "unknown interpolation attribute");
2556             return false;
2557           }
2558           out << " [[" << attr << "]]";
2559         } else if (deco->Is<ast::InvariantDecoration>()) {
2560           if (invariant_define_name_.empty()) {
2561             invariant_define_name_ = UniqueIdentifier("TINT_INVARIANT");
2562           }
2563           out << " " << invariant_define_name_;
2564         } else if (!deco->IsAnyOf<ast::StructMemberOffsetDecoration,
2565                                   ast::StructMemberAlignDecoration,
2566                                   ast::StructMemberSizeDecoration>()) {
2567           TINT_ICE(Writer, diagnostics_)
2568               << "unhandled struct member attribute: " << deco->Name();
2569         }
2570       }
2571     }
2572 
2573     out << ";";
2574 
2575     if (is_host_shareable) {
2576       // Calculate new MSL offset
2577       auto size_align = MslPackedTypeSizeAndAlign(ty);
2578       if (msl_offset % size_align.align) {
2579         TINT_ICE(Writer, diagnostics_)
2580             << "Misaligned MSL structure member "
2581             << ty->FriendlyName(program_->Symbols()) << " " << name;
2582         return false;
2583       }
2584       msl_offset += size_align.size;
2585     }
2586   }
2587 
2588   if (is_host_shareable && str->Size() != msl_offset) {
2589     add_padding(str->Size() - msl_offset, msl_offset);
2590   }
2591 
2592   b->DecrementIndent();
2593 
2594   line(b) << "};";
2595   return true;
2596 }
2597 
EmitUnaryOp(std::ostream & out,const ast::UnaryOpExpression * expr)2598 bool GeneratorImpl::EmitUnaryOp(std::ostream& out,
2599                                 const ast::UnaryOpExpression* expr) {
2600   // Handle `-e` when `e` is signed, so that we ensure that if `e` is the
2601   // largest negative value, it returns `e`.
2602   auto* expr_type = TypeOf(expr->expr)->UnwrapRef();
2603   if (expr->op == ast::UnaryOp::kNegation &&
2604       expr_type->is_signed_scalar_or_vector()) {
2605     auto fn =
2606         utils::GetOrCreate(unary_minus_funcs_, expr_type, [&]() -> std::string {
2607           // e.g.:
2608           // int tint_unary_minus(const int v) {
2609           //     return (v == -2147483648) ? v : -v;
2610           // }
2611           TextBuffer b;
2612           TINT_DEFER(helpers_.Append(b));
2613 
2614           auto fn_name = UniqueIdentifier("tint_unary_minus");
2615           {
2616             auto decl = line(&b);
2617             if (!EmitTypeAndName(decl, expr_type, fn_name)) {
2618               return "";
2619             }
2620             decl << "(const ";
2621             if (!EmitType(decl, expr_type, "")) {
2622               return "";
2623             }
2624             decl << " v) {";
2625           }
2626 
2627           {
2628             ScopedIndent si(&b);
2629             const auto largest_negative_value =
2630                 std::to_string(std::numeric_limits<int32_t>::min());
2631             line(&b) << "return select(-v, v, v == " << largest_negative_value
2632                      << ");";
2633           }
2634           line(&b) << "}";
2635           line(&b);
2636           return fn_name;
2637         });
2638 
2639     out << fn << "(";
2640     if (!EmitExpression(out, expr->expr)) {
2641       return false;
2642     }
2643     out << ")";
2644     return true;
2645   }
2646 
2647   switch (expr->op) {
2648     case ast::UnaryOp::kAddressOf:
2649       out << "&";
2650       break;
2651     case ast::UnaryOp::kComplement:
2652       out << "~";
2653       break;
2654     case ast::UnaryOp::kIndirection:
2655       out << "*";
2656       break;
2657     case ast::UnaryOp::kNot:
2658       out << "!";
2659       break;
2660     case ast::UnaryOp::kNegation:
2661       out << "-";
2662       break;
2663   }
2664   out << "(";
2665 
2666   if (!EmitExpression(out, expr->expr)) {
2667     return false;
2668   }
2669 
2670   out << ")";
2671 
2672   return true;
2673 }
2674 
EmitVariable(const sem::Variable * var)2675 bool GeneratorImpl::EmitVariable(const sem::Variable* var) {
2676   auto* decl = var->Declaration();
2677 
2678   for (auto* deco : decl->decorations) {
2679     if (!deco->Is<ast::InternalDecoration>()) {
2680       TINT_ICE(Writer, diagnostics_) << "unexpected variable decoration";
2681       return false;
2682     }
2683   }
2684 
2685   auto out = line();
2686 
2687   switch (var->StorageClass()) {
2688     case ast::StorageClass::kFunction:
2689     case ast::StorageClass::kUniformConstant:
2690     case ast::StorageClass::kNone:
2691       break;
2692     case ast::StorageClass::kPrivate:
2693       out << "thread ";
2694       break;
2695     case ast::StorageClass::kWorkgroup:
2696       out << "threadgroup ";
2697       break;
2698     default:
2699       TINT_ICE(Writer, diagnostics_) << "unhandled variable storage class";
2700       return false;
2701   }
2702 
2703   auto* type = var->Type()->UnwrapRef();
2704 
2705   std::string name = program_->Symbols().NameFor(decl->symbol);
2706   if (decl->is_const) {
2707     name = "const " + name;
2708   }
2709   if (!EmitType(out, type, name)) {
2710     return false;
2711   }
2712   // Variable name is output as part of the type for arrays and pointers.
2713   if (!type->Is<sem::Array>() && !type->Is<sem::Pointer>()) {
2714     out << " " << name;
2715   }
2716 
2717   if (decl->constructor != nullptr) {
2718     out << " = ";
2719     if (!EmitExpression(out, decl->constructor)) {
2720       return false;
2721     }
2722   } else if (var->StorageClass() == ast::StorageClass::kPrivate ||
2723              var->StorageClass() == ast::StorageClass::kFunction ||
2724              var->StorageClass() == ast::StorageClass::kNone) {
2725     out << " = ";
2726     if (!EmitZeroValue(out, type)) {
2727       return false;
2728     }
2729   }
2730   out << ";";
2731 
2732   return true;
2733 }
2734 
EmitProgramConstVariable(const ast::Variable * var)2735 bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
2736   for (auto* d : var->decorations) {
2737     if (!d->Is<ast::OverrideDecoration>()) {
2738       diagnostics_.add_error(diag::System::Writer,
2739                              "Decorated const values not valid");
2740       return false;
2741     }
2742   }
2743   if (!var->is_const) {
2744     diagnostics_.add_error(diag::System::Writer, "Expected a const value");
2745     return false;
2746   }
2747 
2748   auto out = line();
2749   out << "constant ";
2750   auto* type = program_->Sem().Get(var)->Type()->UnwrapRef();
2751   if (!EmitType(out, type, program_->Symbols().NameFor(var->symbol))) {
2752     return false;
2753   }
2754   if (!type->Is<sem::Array>()) {
2755     out << " " << program_->Symbols().NameFor(var->symbol);
2756   }
2757 
2758   auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
2759   if (global && global->IsOverridable()) {
2760     out << " [[function_constant(" << global->ConstantId() << ")]]";
2761   } else if (var->constructor != nullptr) {
2762     out << " = ";
2763     if (!EmitExpression(out, var->constructor)) {
2764       return false;
2765     }
2766   }
2767   out << ";";
2768 
2769   return true;
2770 }
2771 
MslPackedTypeSizeAndAlign(const sem::Type * ty)2772 GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(
2773     const sem::Type* ty) {
2774   if (ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
2775     // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
2776     // 2.1 Scalar Data Types
2777     return {4, 4};
2778   }
2779 
2780   if (auto* vec = ty->As<sem::Vector>()) {
2781     auto num_els = vec->Width();
2782     auto* el_ty = vec->type();
2783     if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
2784       // Use a packed_vec type for 3-element vectors only.
2785       if (num_els == 3) {
2786         // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
2787         // 2.2.3 Packed Vector Types
2788         return SizeAndAlign{num_els * 4, 4};
2789       } else {
2790         // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
2791         // 2.2 Vector Data Types
2792         return SizeAndAlign{num_els * 4, num_els * 4};
2793       }
2794     }
2795   }
2796 
2797   if (auto* mat = ty->As<sem::Matrix>()) {
2798     // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
2799     // 2.3 Matrix Data Types
2800     auto cols = mat->columns();
2801     auto rows = mat->rows();
2802     auto* el_ty = mat->type();
2803     if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
2804       static constexpr SizeAndAlign table[] = {
2805           /* float2x2 */ {16, 8},
2806           /* float2x3 */ {32, 16},
2807           /* float2x4 */ {32, 16},
2808           /* float3x2 */ {24, 8},
2809           /* float3x3 */ {48, 16},
2810           /* float3x4 */ {48, 16},
2811           /* float4x2 */ {32, 8},
2812           /* float4x3 */ {64, 16},
2813           /* float4x4 */ {64, 16},
2814       };
2815       if (cols >= 2 && cols <= 4 && rows >= 2 && rows <= 4) {
2816         return table[(3 * (cols - 2)) + (rows - 2)];
2817       }
2818     }
2819   }
2820 
2821   if (auto* arr = ty->As<sem::Array>()) {
2822     if (!arr->IsStrideImplicit()) {
2823       TINT_ICE(Writer, diagnostics_)
2824           << "arrays with explicit strides should have "
2825              "removed with the PadArrayElements transform";
2826       return {};
2827     }
2828     auto num_els = std::max<uint32_t>(arr->Count(), 1);
2829     return SizeAndAlign{arr->Stride() * num_els, arr->Align()};
2830   }
2831 
2832   if (auto* str = ty->As<sem::Struct>()) {
2833     // TODO(crbug.com/tint/650): There's an assumption here that MSL's default
2834     // structure size and alignment matches WGSL's. We need to confirm this.
2835     return SizeAndAlign{str->Size(), str->Align()};
2836   }
2837 
2838   if (auto* atomic = ty->As<sem::Atomic>()) {
2839     return MslPackedTypeSizeAndAlign(atomic->Type());
2840   }
2841 
2842   TINT_UNREACHABLE(Writer, diagnostics_)
2843       << "Unhandled type " << ty->TypeInfo().name;
2844   return {};
2845 }
2846 
2847 template <typename F>
CallIntrinsicHelper(std::ostream & out,const ast::CallExpression * call,const sem::Intrinsic * intrinsic,F && build)2848 bool GeneratorImpl::CallIntrinsicHelper(std::ostream& out,
2849                                         const ast::CallExpression* call,
2850                                         const sem::Intrinsic* intrinsic,
2851                                         F&& build) {
2852   // Generate the helper function if it hasn't been created already
2853   auto fn = utils::GetOrCreate(intrinsics_, intrinsic, [&]() -> std::string {
2854     TextBuffer b;
2855     TINT_DEFER(helpers_.Append(b));
2856 
2857     auto fn_name =
2858         UniqueIdentifier(std::string("tint_") + sem::str(intrinsic->Type()));
2859     std::vector<std::string> parameter_names;
2860     {
2861       auto decl = line(&b);
2862       if (!EmitTypeAndName(decl, intrinsic->ReturnType(), fn_name)) {
2863         return "";
2864       }
2865       {
2866         ScopedParen sp(decl);
2867         for (auto* param : intrinsic->Parameters()) {
2868           if (!parameter_names.empty()) {
2869             decl << ", ";
2870           }
2871           auto param_name = "param_" + std::to_string(parameter_names.size());
2872           if (!EmitTypeAndName(decl, param->Type(), param_name)) {
2873             return "";
2874           }
2875           parameter_names.emplace_back(std::move(param_name));
2876         }
2877       }
2878       decl << " {";
2879     }
2880     {
2881       ScopedIndent si(&b);
2882       if (!build(&b, parameter_names)) {
2883         return "";
2884       }
2885     }
2886     line(&b) << "}";
2887     line(&b);
2888     return fn_name;
2889   });
2890 
2891   if (fn.empty()) {
2892     return false;
2893   }
2894 
2895   // Call the helper
2896   out << fn;
2897   {
2898     ScopedParen sp(out);
2899     bool first = true;
2900     for (auto* arg : call->args) {
2901       if (!first) {
2902         out << ", ";
2903       }
2904       first = false;
2905       if (!EmitExpression(out, arg)) {
2906         return false;
2907       }
2908     }
2909   }
2910   return true;
2911 }
2912 
2913 }  // namespace msl
2914 }  // namespace writer
2915 }  // namespace tint
2916