1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/writer/msl/generator_impl.h"
16
17 #include <algorithm>
18 #include <cmath>
19 #include <iomanip>
20 #include <limits>
21 #include <utility>
22 #include <vector>
23
24 #include "src/ast/alias.h"
25 #include "src/ast/bool_literal_expression.h"
26 #include "src/ast/call_statement.h"
27 #include "src/ast/disable_validation_decoration.h"
28 #include "src/ast/fallthrough_statement.h"
29 #include "src/ast/float_literal_expression.h"
30 #include "src/ast/interpolate_decoration.h"
31 #include "src/ast/module.h"
32 #include "src/ast/override_decoration.h"
33 #include "src/ast/sint_literal_expression.h"
34 #include "src/ast/uint_literal_expression.h"
35 #include "src/ast/variable_decl_statement.h"
36 #include "src/ast/void.h"
37 #include "src/sem/array.h"
38 #include "src/sem/atomic_type.h"
39 #include "src/sem/bool_type.h"
40 #include "src/sem/call.h"
41 #include "src/sem/depth_multisampled_texture_type.h"
42 #include "src/sem/depth_texture_type.h"
43 #include "src/sem/f32_type.h"
44 #include "src/sem/function.h"
45 #include "src/sem/i32_type.h"
46 #include "src/sem/matrix_type.h"
47 #include "src/sem/member_accessor_expression.h"
48 #include "src/sem/multisampled_texture_type.h"
49 #include "src/sem/pointer_type.h"
50 #include "src/sem/reference_type.h"
51 #include "src/sem/sampled_texture_type.h"
52 #include "src/sem/storage_texture_type.h"
53 #include "src/sem/struct.h"
54 #include "src/sem/type_constructor.h"
55 #include "src/sem/type_conversion.h"
56 #include "src/sem/u32_type.h"
57 #include "src/sem/variable.h"
58 #include "src/sem/vector_type.h"
59 #include "src/sem/void_type.h"
60 #include "src/transform/array_length_from_uniform.h"
61 #include "src/transform/canonicalize_entry_point_io.h"
62 #include "src/transform/external_texture_transform.h"
63 #include "src/transform/manager.h"
64 #include "src/transform/module_scope_var_to_entry_point_param.h"
65 #include "src/transform/pad_array_elements.h"
66 #include "src/transform/promote_initializers_to_const_var.h"
67 #include "src/transform/remove_phonies.h"
68 #include "src/transform/simplify_pointers.h"
69 #include "src/transform/unshadow.h"
70 #include "src/transform/vectorize_scalar_matrix_constructors.h"
71 #include "src/transform/wrap_arrays_in_structs.h"
72 #include "src/transform/zero_init_workgroup_memory.h"
73 #include "src/utils/defer.h"
74 #include "src/utils/map.h"
75 #include "src/utils/scoped_assignment.h"
76 #include "src/writer/float_to_string.h"
77
78 namespace tint {
79 namespace writer {
80 namespace msl {
81 namespace {
82
last_is_break_or_fallthrough(const ast::BlockStatement * stmts)83 bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
84 return IsAnyOf<ast::BreakStatement, ast::FallthroughStatement>(stmts->Last());
85 }
86
87 class ScopedBitCast {
88 public:
ScopedBitCast(GeneratorImpl * generator,std::ostream & stream,const sem::Type * curr_type,const sem::Type * target_type)89 ScopedBitCast(GeneratorImpl* generator,
90 std::ostream& stream,
91 const sem::Type* curr_type,
92 const sem::Type* target_type)
93 : s(stream) {
94 auto* target_vec_type = target_type->As<sem::Vector>();
95
96 // If we need to promote from scalar to vector, bitcast the scalar to the
97 // vector element type.
98 if (curr_type->is_scalar() && target_vec_type) {
99 target_type = target_vec_type->type();
100 }
101
102 // Bit cast
103 s << "as_type<";
104 generator->EmitType(s, target_type, "");
105 s << ">(";
106 }
107
~ScopedBitCast()108 ~ScopedBitCast() { s << ")"; }
109
110 private:
111 std::ostream& s;
112 };
113 } // namespace
114
// SanitizedResult's special members are defined out of line, in this file, as
// defaulted members. Presumably this keeps the member types' full definitions
// out of the header; verify against generator_impl.h.
SanitizedResult::SanitizedResult() = default;
SanitizedResult::~SanitizedResult() = default;
SanitizedResult::SanitizedResult(SanitizedResult&&) = default;
118
Sanitize(const Program * in,uint32_t buffer_size_ubo_index,uint32_t fixed_sample_mask,bool emit_vertex_point_size,bool disable_workgroup_init,const ArrayLengthFromUniformOptions & array_length_from_uniform)119 SanitizedResult Sanitize(
120 const Program* in,
121 uint32_t buffer_size_ubo_index,
122 uint32_t fixed_sample_mask,
123 bool emit_vertex_point_size,
124 bool disable_workgroup_init,
125 const ArrayLengthFromUniformOptions& array_length_from_uniform) {
126 transform::Manager manager;
127 transform::DataMap internal_inputs;
128
129 // Build the config for the internal ArrayLengthFromUniform transform.
130 transform::ArrayLengthFromUniform::Config array_length_from_uniform_cfg(
131 array_length_from_uniform.ubo_binding);
132 if (!array_length_from_uniform.bindpoint_to_size_index.empty()) {
133 // If |array_length_from_uniform| bindings are provided, use that config.
134 array_length_from_uniform_cfg.bindpoint_to_size_index =
135 array_length_from_uniform.bindpoint_to_size_index;
136 } else {
137 // If the binding map is empty, use the deprecated |buffer_size_ubo_index|
138 // and automatically choose indices using the binding numbers.
139 array_length_from_uniform_cfg = transform::ArrayLengthFromUniform::Config(
140 sem::BindingPoint{0, buffer_size_ubo_index});
141 // Use the SSBO binding numbers as the indices for the buffer size lookups.
142 for (auto* var : in->AST().GlobalVariables()) {
143 auto* global = in->Sem().Get<sem::GlobalVariable>(var);
144 if (global && global->StorageClass() == ast::StorageClass::kStorage) {
145 array_length_from_uniform_cfg.bindpoint_to_size_index.emplace(
146 global->BindingPoint(), global->BindingPoint().binding);
147 }
148 }
149 }
150
151 // Build the configs for the internal CanonicalizeEntryPointIO transform.
152 auto entry_point_io_cfg = transform::CanonicalizeEntryPointIO::Config(
153 transform::CanonicalizeEntryPointIO::ShaderStyle::kMsl, fixed_sample_mask,
154 emit_vertex_point_size);
155
156 manager.Add<transform::Unshadow>();
157
158 if (!disable_workgroup_init) {
159 // ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as
160 // ZeroInitWorkgroupMemory may inject new builtin parameters.
161 manager.Add<transform::ZeroInitWorkgroupMemory>();
162 }
163 manager.Add<transform::CanonicalizeEntryPointIO>();
164 manager.Add<transform::ExternalTextureTransform>();
165 manager.Add<transform::PromoteInitializersToConstVar>();
166 manager.Add<transform::VectorizeScalarMatrixConstructors>();
167 manager.Add<transform::WrapArraysInStructs>();
168 manager.Add<transform::PadArrayElements>();
169 manager.Add<transform::RemovePhonies>();
170 manager.Add<transform::SimplifyPointers>();
171 // ArrayLengthFromUniform must come after SimplifyPointers, as
172 // it assumes that the form of the array length argument is &var.array.
173 manager.Add<transform::ArrayLengthFromUniform>();
174 manager.Add<transform::ModuleScopeVarToEntryPointParam>();
175 internal_inputs.Add<transform::ArrayLengthFromUniform::Config>(
176 std::move(array_length_from_uniform_cfg));
177 internal_inputs.Add<transform::CanonicalizeEntryPointIO::Config>(
178 std::move(entry_point_io_cfg));
179 auto out = manager.Run(in, internal_inputs);
180
181 SanitizedResult result;
182 result.program = std::move(out.program);
183 if (!result.program.IsValid()) {
184 return result;
185 }
186 result.used_array_length_from_uniform_indices =
187 std::move(out.data.Get<transform::ArrayLengthFromUniform::Result>()
188 ->used_size_indices);
189 result.needs_storage_buffer_sizes =
190 !result.used_array_length_from_uniform_indices.empty();
191 return result;
192 }
193
GeneratorImpl(const Program * program)194 GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {}
195
196 GeneratorImpl::~GeneratorImpl() = default;
197
Generate()198 bool GeneratorImpl::Generate() {
199 line() << "#include <metal_stdlib>";
200 line();
201 line() << "using namespace metal;";
202
203 auto helpers_insertion_point = current_buffer_->lines.size();
204
205 for (auto* const type_decl : program_->AST().TypeDecls()) {
206 if (!type_decl->Is<ast::Alias>()) {
207 if (!EmitTypeDecl(TypeOf(type_decl))) {
208 return false;
209 }
210 }
211 }
212
213 if (!program_->AST().TypeDecls().empty()) {
214 line();
215 }
216
217 for (auto* var : program_->AST().GlobalVariables()) {
218 if (var->is_const) {
219 if (!EmitProgramConstVariable(var)) {
220 return false;
221 }
222 } else {
223 // These are pushed into the entry point by sanitizer transforms.
224 TINT_ICE(Writer, diagnostics_) << "module-scope variables should have "
225 "been handled by the MSL sanitizer";
226 break;
227 }
228 }
229
230 for (auto* func : program_->AST().Functions()) {
231 if (!func->IsEntryPoint()) {
232 if (!EmitFunction(func)) {
233 return false;
234 }
235 } else {
236 if (!EmitEntryPointFunction(func)) {
237 return false;
238 }
239 }
240 line();
241 }
242
243 if (!invariant_define_name_.empty()) {
244 // 'invariant' attribute requires MSL 2.1 or higher.
245 // WGSL can ignore the invariant attribute on pre MSL 2.1 devices.
246 // See: https://github.com/gpuweb/gpuweb/issues/893#issuecomment-745537465
247 line(&helpers_) << "#if __METAL_VERSION__ >= 210";
248 line(&helpers_) << "#define " << invariant_define_name_ << " [[invariant]]";
249 line(&helpers_) << "#else";
250 line(&helpers_) << "#define " << invariant_define_name_;
251 line(&helpers_) << "#endif";
252 line(&helpers_);
253 }
254
255 if (!helpers_.lines.empty()) {
256 current_buffer_->Insert("", helpers_insertion_point++, 0);
257 current_buffer_->Insert(helpers_, helpers_insertion_point++, 0);
258 }
259
260 return true;
261 }
262
EmitTypeDecl(const sem::Type * ty)263 bool GeneratorImpl::EmitTypeDecl(const sem::Type* ty) {
264 if (auto* str = ty->As<sem::Struct>()) {
265 if (!EmitStructType(current_buffer_, str)) {
266 return false;
267 }
268 } else {
269 diagnostics_.add_error(diag::System::Writer,
270 "unknown alias type: " + ty->type_name());
271 return false;
272 }
273
274 return true;
275 }
276
EmitIndexAccessor(std::ostream & out,const ast::IndexAccessorExpression * expr)277 bool GeneratorImpl::EmitIndexAccessor(
278 std::ostream& out,
279 const ast::IndexAccessorExpression* expr) {
280 bool paren_lhs =
281 !expr->object->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression,
282 ast::IdentifierExpression,
283 ast::MemberAccessorExpression>();
284
285 if (paren_lhs) {
286 out << "(";
287 }
288 if (!EmitExpression(out, expr->object)) {
289 return false;
290 }
291 if (paren_lhs) {
292 out << ")";
293 }
294
295 out << "[";
296
297 if (!EmitExpression(out, expr->index)) {
298 return false;
299 }
300 out << "]";
301
302 return true;
303 }
304
EmitBitcast(std::ostream & out,const ast::BitcastExpression * expr)305 bool GeneratorImpl::EmitBitcast(std::ostream& out,
306 const ast::BitcastExpression* expr) {
307 out << "as_type<";
308 if (!EmitType(out, TypeOf(expr)->UnwrapRef(), "")) {
309 return false;
310 }
311
312 out << ">(";
313 if (!EmitExpression(out, expr->expr)) {
314 return false;
315 }
316
317 out << ")";
318 return true;
319 }
320
EmitAssign(const ast::AssignmentStatement * stmt)321 bool GeneratorImpl::EmitAssign(const ast::AssignmentStatement* stmt) {
322 auto out = line();
323
324 if (!EmitExpression(out, stmt->lhs)) {
325 return false;
326 }
327
328 out << " = ";
329
330 if (!EmitExpression(out, stmt->rhs)) {
331 return false;
332 }
333
334 out << ";";
335
336 return true;
337 }
338
/// Emits a binary expression.
/// Three forms are produced:
///  * float `%`: emitted as `fmod(lhs, rhs)`.
///  * signed `+`, `-`, `*`: operands are bitcast to unsigned, the operation is
///    performed, and the result is bitcast back to signed, because WGSL
///    defines signed overflow and MSL does not.
///  * signed `<<`: the lhs is bitcast to unsigned and the result back to
///    signed.
/// Everything else is emitted as `(lhs op rhs)`.
/// NOTE: ScopedBitCast / ScopedParen write their closing `)` from their
/// destructors, so the nesting of the scopes below determines the output.
/// @param out the output stream
/// @param expr the binary expression
/// @returns true on success
bool GeneratorImpl::EmitBinary(std::ostream& out,
                               const ast::BinaryExpression* expr) {
  // Writes ` <op> ` for expr->op. Returns false (with an error) for kNone.
  auto emit_op = [&] {
    out << " ";

    switch (expr->op) {
      case ast::BinaryOp::kAnd:
        out << "&";
        break;
      case ast::BinaryOp::kOr:
        out << "|";
        break;
      case ast::BinaryOp::kXor:
        out << "^";
        break;
      case ast::BinaryOp::kLogicalAnd:
        out << "&&";
        break;
      case ast::BinaryOp::kLogicalOr:
        out << "||";
        break;
      case ast::BinaryOp::kEqual:
        out << "==";
        break;
      case ast::BinaryOp::kNotEqual:
        out << "!=";
        break;
      case ast::BinaryOp::kLessThan:
        out << "<";
        break;
      case ast::BinaryOp::kGreaterThan:
        out << ">";
        break;
      case ast::BinaryOp::kLessThanEqual:
        out << "<=";
        break;
      case ast::BinaryOp::kGreaterThanEqual:
        out << ">=";
        break;
      case ast::BinaryOp::kShiftLeft:
        out << "<<";
        break;
      case ast::BinaryOp::kShiftRight:
        // TODO(dsinclair): MSL is based on C++14, and >> in C++14 has
        // implementation-defined behaviour for negative LHS. We may have to
        // generate extra code to implement WGSL-specified behaviour for
        // negative LHS.
        out << R"(>>)";
        break;

      case ast::BinaryOp::kAdd:
        out << "+";
        break;
      case ast::BinaryOp::kSubtract:
        out << "-";
        break;
      case ast::BinaryOp::kMultiply:
        out << "*";
        break;
      case ast::BinaryOp::kDivide:
        out << "/";
        break;
      case ast::BinaryOp::kModulo:
        out << "%";
        break;
      case ast::BinaryOp::kNone:
        diagnostics_.add_error(diag::System::Writer,
                               "missing binary operation type");
        return false;
    }
    out << " ";
    return true;
  };

  // Returns the i32 (or vecN<i32>) type with the same shape as `ty`, or
  // nullptr if `ty` is neither an integer scalar nor a vector.
  auto signed_type_of = [&](const sem::Type* ty) -> const sem::Type* {
    if (ty->is_integer_scalar()) {
      return builder_.create<sem::I32>();
    } else if (auto* v = ty->As<sem::Vector>()) {
      return builder_.create<sem::Vector>(builder_.create<sem::I32>(),
                                          v->Width());
    }
    return {};
  };

  // Returns the u32 (or vecN<u32>) type with the same shape as `ty`, or
  // nullptr if `ty` is neither an integer scalar nor a vector.
  auto unsigned_type_of = [&](const sem::Type* ty) -> const sem::Type* {
    if (ty->is_integer_scalar()) {
      return builder_.create<sem::U32>();
    } else if (auto* v = ty->As<sem::Vector>()) {
      return builder_.create<sem::Vector>(builder_.create<sem::U32>(),
                                          v->Width());
    }
    return {};
  };

  auto* lhs_type = TypeOf(expr->lhs)->UnwrapRef();
  auto* rhs_type = TypeOf(expr->rhs)->UnwrapRef();

  // Handle fmod
  if (expr->op == ast::BinaryOp::kModulo &&
      lhs_type->is_float_scalar_or_vector()) {
    out << "fmod";
    ScopedParen sp(out);
    if (!EmitExpression(out, expr->lhs)) {
      return false;
    }
    out << ", ";
    if (!EmitExpression(out, expr->rhs)) {
      return false;
    }
    return true;
  }

  // Handle +/-/* of signed values
  if ((expr->IsAdd() || expr->IsSubtract() || expr->IsMultiply()) &&
      lhs_type->is_signed_scalar_or_vector() &&
      rhs_type->is_signed_scalar_or_vector()) {
    // If lhs or rhs is a vector, use that type (support implicit scalar to
    // vector promotion)
    auto* target_type =
        lhs_type->Is<sem::Vector>()
            ? lhs_type
            : (rhs_type->Is<sem::Vector>() ? rhs_type : lhs_type);

    // WGSL defines behaviour for signed overflow, MSL does not. For these
    // cases, bitcast operands to unsigned, then cast result to signed.
    // Output shape: as_type<iN>((as_type<uN>(lhs) op as_type<uN>(rhs)))
    ScopedBitCast outer_int_cast(this, out, target_type,
                                 signed_type_of(target_type));
    ScopedParen sp(out);
    {
      // Inner scope: the lhs bitcast's `)` is written before the operator.
      ScopedBitCast lhs_uint_cast(this, out, lhs_type,
                                  unsigned_type_of(target_type));
      if (!EmitExpression(out, expr->lhs)) {
        return false;
      }
    }
    if (!emit_op()) {
      return false;
    }
    {
      // Inner scope: the rhs bitcast's `)` is written before the outer ones.
      ScopedBitCast rhs_uint_cast(this, out, rhs_type,
                                  unsigned_type_of(target_type));
      if (!EmitExpression(out, expr->rhs)) {
        return false;
      }
    }
    return true;
  }

  // Handle left bit shifting a signed value
  // TODO(crbug.com/tint/1077): This may not be necessary. The MSL spec
  // seems to imply that left shifting a signed value is treated the same as
  // left shifting an unsigned value, but we need to make sure.
  if (expr->IsShiftLeft() && lhs_type->is_signed_scalar_or_vector()) {
    // Shift left: discards top bits, so convert first operand to unsigned
    // first, then convert result back to signed
    // Output shape: as_type<iN>((as_type<uN>(lhs) << rhs))
    ScopedBitCast outer_int_cast(this, out, lhs_type, signed_type_of(lhs_type));
    ScopedParen sp(out);
    {
      ScopedBitCast lhs_uint_cast(this, out, lhs_type,
                                  unsigned_type_of(lhs_type));
      if (!EmitExpression(out, expr->lhs)) {
        return false;
      }
    }
    if (!emit_op()) {
      return false;
    }
    if (!EmitExpression(out, expr->rhs)) {
      return false;
    }
    return true;
  }

  // Emit as usual
  ScopedParen sp(out);
  if (!EmitExpression(out, expr->lhs)) {
    return false;
  }
  if (!emit_op()) {
    return false;
  }
  if (!EmitExpression(out, expr->rhs)) {
    return false;
  }

  return true;
}
526
EmitBreak(const ast::BreakStatement *)527 bool GeneratorImpl::EmitBreak(const ast::BreakStatement*) {
528 line() << "break;";
529 return true;
530 }
531
EmitCall(std::ostream & out,const ast::CallExpression * expr)532 bool GeneratorImpl::EmitCall(std::ostream& out,
533 const ast::CallExpression* expr) {
534 auto* call = program_->Sem().Get(expr);
535 auto* target = call->Target();
536
537 if (auto* func = target->As<sem::Function>()) {
538 return EmitFunctionCall(out, call, func);
539 }
540 if (auto* intrinsic = target->As<sem::Intrinsic>()) {
541 return EmitIntrinsicCall(out, call, intrinsic);
542 }
543 if (auto* conv = target->As<sem::TypeConversion>()) {
544 return EmitTypeConversion(out, call, conv);
545 }
546 if (auto* ctor = target->As<sem::TypeConstructor>()) {
547 return EmitTypeConstructor(out, call, ctor);
548 }
549
550 TINT_ICE(Writer, diagnostics_)
551 << "unhandled call target: " << target->TypeInfo().name;
552 return false;
553 }
554
EmitFunctionCall(std::ostream & out,const sem::Call * call,const sem::Function *)555 bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
556 const sem::Call* call,
557 const sem::Function*) {
558 auto* ident = call->Declaration()->target.name;
559 out << program_->Symbols().NameFor(ident->symbol) << "(";
560
561 bool first = true;
562 for (auto* arg : call->Arguments()) {
563 if (!first) {
564 out << ", ";
565 }
566 first = false;
567
568 if (!EmitExpression(out, arg->Declaration())) {
569 return false;
570 }
571 }
572
573 out << ")";
574 return true;
575 }
576
EmitIntrinsicCall(std::ostream & out,const sem::Call * call,const sem::Intrinsic * intrinsic)577 bool GeneratorImpl::EmitIntrinsicCall(std::ostream& out,
578 const sem::Call* call,
579 const sem::Intrinsic* intrinsic) {
580 auto* expr = call->Declaration();
581 if (intrinsic->IsAtomic()) {
582 return EmitAtomicCall(out, expr, intrinsic);
583 }
584 if (intrinsic->IsTexture()) {
585 return EmitTextureCall(out, call, intrinsic);
586 }
587
588 auto name = generate_builtin_name(intrinsic);
589
590 switch (intrinsic->Type()) {
591 case sem::IntrinsicType::kDot:
592 return EmitDotCall(out, expr, intrinsic);
593 case sem::IntrinsicType::kModf:
594 return EmitModfCall(out, expr, intrinsic);
595 case sem::IntrinsicType::kFrexp:
596 return EmitFrexpCall(out, expr, intrinsic);
597
598 case sem::IntrinsicType::kPack2x16float:
599 case sem::IntrinsicType::kUnpack2x16float: {
600 if (intrinsic->Type() == sem::IntrinsicType::kPack2x16float) {
601 out << "as_type<uint>(half2(";
602 } else {
603 out << "float2(as_type<half2>(";
604 }
605 if (!EmitExpression(out, expr->args[0])) {
606 return false;
607 }
608 out << "))";
609 return true;
610 }
611 // TODO(crbug.com/tint/661): Combine sequential barriers to a single
612 // instruction.
613 case sem::IntrinsicType::kStorageBarrier: {
614 out << "threadgroup_barrier(mem_flags::mem_device)";
615 return true;
616 }
617 case sem::IntrinsicType::kWorkgroupBarrier: {
618 out << "threadgroup_barrier(mem_flags::mem_threadgroup)";
619 return true;
620 }
621 case sem::IntrinsicType::kIgnore: { // [DEPRECATED]
622 out << "(void) ";
623 if (!EmitExpression(out, expr->args[0])) {
624 return false;
625 }
626 return true;
627 }
628
629 case sem::IntrinsicType::kLength: {
630 auto* sem = builder_.Sem().Get(expr->args[0]);
631 if (sem->Type()->UnwrapRef()->is_scalar()) {
632 // Emulate scalar overload using fabs(x).
633 name = "fabs";
634 }
635 break;
636 }
637
638 case sem::IntrinsicType::kDistance: {
639 auto* sem = builder_.Sem().Get(expr->args[0]);
640 if (sem->Type()->UnwrapRef()->is_scalar()) {
641 // Emulate scalar overload using fabs(x - y);
642 out << "fabs";
643 ScopedParen sp(out);
644 if (!EmitExpression(out, expr->args[0])) {
645 return false;
646 }
647 out << " - ";
648 if (!EmitExpression(out, expr->args[1])) {
649 return false;
650 }
651 return true;
652 }
653 break;
654 }
655
656 default:
657 break;
658 }
659
660 if (name.empty()) {
661 return false;
662 }
663
664 out << name << "(";
665
666 bool first = true;
667 for (auto* arg : expr->args) {
668 if (!first) {
669 out << ", ";
670 }
671 first = false;
672
673 if (!EmitExpression(out, arg)) {
674 return false;
675 }
676 }
677
678 out << ")";
679 return true;
680 }
681
EmitTypeConversion(std::ostream & out,const sem::Call * call,const sem::TypeConversion * conv)682 bool GeneratorImpl::EmitTypeConversion(std::ostream& out,
683 const sem::Call* call,
684 const sem::TypeConversion* conv) {
685 if (!EmitType(out, conv->Target(), "")) {
686 return false;
687 }
688 out << "(";
689
690 if (!EmitExpression(out, call->Arguments()[0]->Declaration())) {
691 return false;
692 }
693
694 out << ")";
695 return true;
696 }
697
EmitTypeConstructor(std::ostream & out,const sem::Call * call,const sem::TypeConstructor * ctor)698 bool GeneratorImpl::EmitTypeConstructor(std::ostream& out,
699 const sem::Call* call,
700 const sem::TypeConstructor* ctor) {
701 auto* type = ctor->ReturnType();
702
703 if (type->IsAnyOf<sem::Array, sem::Struct>()) {
704 out << "{";
705 } else {
706 if (!EmitType(out, type, "")) {
707 return false;
708 }
709 out << "(";
710 }
711
712 int i = 0;
713 for (auto* arg : call->Arguments()) {
714 if (i > 0) {
715 out << ", ";
716 }
717
718 if (auto* struct_ty = type->As<sem::Struct>()) {
719 // Emit field designators for structures to account for padding members.
720 auto* member = struct_ty->Members()[i]->Declaration();
721 auto name = program_->Symbols().NameFor(member->symbol);
722 out << "." << name << "=";
723 }
724
725 if (!EmitExpression(out, arg->Declaration())) {
726 return false;
727 }
728
729 i++;
730 }
731
732 if (type->IsAnyOf<sem::Array, sem::Struct>()) {
733 out << "}";
734 } else {
735 out << ")";
736 }
737 return true;
738 }
739
EmitAtomicCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)740 bool GeneratorImpl::EmitAtomicCall(std::ostream& out,
741 const ast::CallExpression* expr,
742 const sem::Intrinsic* intrinsic) {
743 auto call = [&](const std::string& name, bool append_memory_order_relaxed) {
744 out << name;
745 {
746 ScopedParen sp(out);
747 for (size_t i = 0; i < expr->args.size(); i++) {
748 auto* arg = expr->args[i];
749 if (i > 0) {
750 out << ", ";
751 }
752 if (!EmitExpression(out, arg)) {
753 return false;
754 }
755 }
756 if (append_memory_order_relaxed) {
757 out << ", memory_order_relaxed";
758 }
759 }
760 return true;
761 };
762
763 switch (intrinsic->Type()) {
764 case sem::IntrinsicType::kAtomicLoad:
765 return call("atomic_load_explicit", true);
766
767 case sem::IntrinsicType::kAtomicStore:
768 return call("atomic_store_explicit", true);
769
770 case sem::IntrinsicType::kAtomicAdd:
771 return call("atomic_fetch_add_explicit", true);
772
773 case sem::IntrinsicType::kAtomicSub:
774 return call("atomic_fetch_sub_explicit", true);
775
776 case sem::IntrinsicType::kAtomicMax:
777 return call("atomic_fetch_max_explicit", true);
778
779 case sem::IntrinsicType::kAtomicMin:
780 return call("atomic_fetch_min_explicit", true);
781
782 case sem::IntrinsicType::kAtomicAnd:
783 return call("atomic_fetch_and_explicit", true);
784
785 case sem::IntrinsicType::kAtomicOr:
786 return call("atomic_fetch_or_explicit", true);
787
788 case sem::IntrinsicType::kAtomicXor:
789 return call("atomic_fetch_xor_explicit", true);
790
791 case sem::IntrinsicType::kAtomicExchange:
792 return call("atomic_exchange_explicit", true);
793
794 case sem::IntrinsicType::kAtomicCompareExchangeWeak: {
795 auto* ptr_ty = TypeOf(expr->args[0])->UnwrapRef()->As<sem::Pointer>();
796 auto sc = ptr_ty->StorageClass();
797
798 auto func = utils::GetOrCreate(
799 atomicCompareExchangeWeak_, sc, [&]() -> std::string {
800 auto name = UniqueIdentifier("atomicCompareExchangeWeak");
801 auto& buf = helpers_;
802
803 line(&buf) << "template <typename A, typename T>";
804 {
805 auto f = line(&buf);
806 f << "vec<T, 2> " << name << "(";
807 if (!EmitStorageClass(f, sc)) {
808 return "";
809 }
810 f << " A* atomic, T compare, T value) {";
811 }
812
813 buf.IncrementIndent();
814 TINT_DEFER({
815 buf.DecrementIndent();
816 line(&buf) << "}";
817 line(&buf);
818 });
819
820 line(&buf) << "T prev_value = compare;";
821 line(&buf) << "bool matched = "
822 "atomic_compare_exchange_weak_explicit(atomic, "
823 "&prev_value, value, memory_order_relaxed, "
824 "memory_order_relaxed);";
825 line(&buf) << "return {prev_value, matched};";
826 return name;
827 });
828
829 return call(func, false);
830 }
831
832 default:
833 break;
834 }
835
836 TINT_UNREACHABLE(Writer, diagnostics_)
837 << "unsupported atomic intrinsic: " << intrinsic->Type();
838 return false;
839 }
840
EmitTextureCall(std::ostream & out,const sem::Call * call,const sem::Intrinsic * intrinsic)841 bool GeneratorImpl::EmitTextureCall(std::ostream& out,
842 const sem::Call* call,
843 const sem::Intrinsic* intrinsic) {
844 using Usage = sem::ParameterUsage;
845
846 auto& signature = intrinsic->Signature();
847 auto* expr = call->Declaration();
848 auto& arguments = call->Arguments();
849
850 // Returns the argument with the given usage
851 auto arg = [&](Usage usage) {
852 int idx = signature.IndexOf(usage);
853 return (idx >= 0) ? arguments[idx] : nullptr;
854 };
855
856 auto* texture = arg(Usage::kTexture)->Declaration();
857 if (!texture) {
858 TINT_ICE(Writer, diagnostics_) << "missing texture arg";
859 return false;
860 }
861
862 auto* texture_type = TypeOf(texture)->UnwrapRef()->As<sem::Texture>();
863
864 // Helper to emit the texture expression, wrapped in parentheses if the
865 // expression includes an operator with lower precedence than the member
866 // accessor used for the function calls.
867 auto texture_expr = [&]() {
868 bool paren_lhs =
869 !texture->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression,
870 ast::IdentifierExpression,
871 ast::MemberAccessorExpression>();
872 if (paren_lhs) {
873 out << "(";
874 }
875 if (!EmitExpression(out, texture)) {
876 return false;
877 }
878 if (paren_lhs) {
879 out << ")";
880 }
881 return true;
882 };
883
884 switch (intrinsic->Type()) {
885 case sem::IntrinsicType::kTextureDimensions: {
886 std::vector<const char*> dims;
887 switch (texture_type->dim()) {
888 case ast::TextureDimension::kNone:
889 diagnostics_.add_error(diag::System::Writer,
890 "texture dimension is kNone");
891 return false;
892 case ast::TextureDimension::k1d:
893 dims = {"width"};
894 break;
895 case ast::TextureDimension::k2d:
896 case ast::TextureDimension::k2dArray:
897 case ast::TextureDimension::kCube:
898 case ast::TextureDimension::kCubeArray:
899 dims = {"width", "height"};
900 break;
901 case ast::TextureDimension::k3d:
902 dims = {"width", "height", "depth"};
903 break;
904 }
905
906 auto get_dim = [&](const char* name) {
907 if (!texture_expr()) {
908 return false;
909 }
910 out << ".get_" << name << "(";
911 if (auto* level = arg(Usage::kLevel)) {
912 if (!EmitExpression(out, level->Declaration())) {
913 return false;
914 }
915 }
916 out << ")";
917 return true;
918 };
919
920 if (dims.size() == 1) {
921 out << "int(";
922 get_dim(dims[0]);
923 out << ")";
924 } else {
925 EmitType(out, TypeOf(expr)->UnwrapRef(), "");
926 out << "(";
927 for (size_t i = 0; i < dims.size(); i++) {
928 if (i > 0) {
929 out << ", ";
930 }
931 get_dim(dims[i]);
932 }
933 out << ")";
934 }
935 return true;
936 }
937 case sem::IntrinsicType::kTextureNumLayers: {
938 out << "int(";
939 if (!texture_expr()) {
940 return false;
941 }
942 out << ".get_array_size())";
943 return true;
944 }
945 case sem::IntrinsicType::kTextureNumLevels: {
946 out << "int(";
947 if (!texture_expr()) {
948 return false;
949 }
950 out << ".get_num_mip_levels())";
951 return true;
952 }
953 case sem::IntrinsicType::kTextureNumSamples: {
954 out << "int(";
955 if (!texture_expr()) {
956 return false;
957 }
958 out << ".get_num_samples())";
959 return true;
960 }
961 default:
962 break;
963 }
964
965 if (!texture_expr()) {
966 return false;
967 }
968
969 bool lod_param_is_named = true;
970
971 switch (intrinsic->Type()) {
972 case sem::IntrinsicType::kTextureSample:
973 case sem::IntrinsicType::kTextureSampleBias:
974 case sem::IntrinsicType::kTextureSampleLevel:
975 case sem::IntrinsicType::kTextureSampleGrad:
976 out << ".sample(";
977 break;
978 case sem::IntrinsicType::kTextureSampleCompare:
979 case sem::IntrinsicType::kTextureSampleCompareLevel:
980 out << ".sample_compare(";
981 break;
982 case sem::IntrinsicType::kTextureGather:
983 out << ".gather(";
984 break;
985 case sem::IntrinsicType::kTextureGatherCompare:
986 out << ".gather_compare(";
987 break;
988 case sem::IntrinsicType::kTextureLoad:
989 out << ".read(";
990 lod_param_is_named = false;
991 break;
992 case sem::IntrinsicType::kTextureStore:
993 out << ".write(";
994 break;
995 default:
996 TINT_UNREACHABLE(Writer, diagnostics_)
997 << "Unhandled texture intrinsic '" << intrinsic->str() << "'";
998 return false;
999 }
1000
1001 bool first_arg = true;
1002 auto maybe_write_comma = [&] {
1003 if (!first_arg) {
1004 out << ", ";
1005 }
1006 first_arg = false;
1007 };
1008
1009 for (auto usage :
1010 {Usage::kValue, Usage::kSampler, Usage::kCoords, Usage::kArrayIndex,
1011 Usage::kDepthRef, Usage::kSampleIndex}) {
1012 if (auto* e = arg(usage)) {
1013 maybe_write_comma();
1014
1015 // Cast the coordinates to unsigned integers if necessary.
1016 bool casted = false;
1017 if (usage == Usage::kCoords &&
1018 e->Type()->UnwrapRef()->is_integer_scalar_or_vector()) {
1019 casted = true;
1020 switch (texture_type->dim()) {
1021 case ast::TextureDimension::k1d:
1022 out << "uint(";
1023 break;
1024 case ast::TextureDimension::k2d:
1025 case ast::TextureDimension::k2dArray:
1026 out << "uint2(";
1027 break;
1028 case ast::TextureDimension::k3d:
1029 out << "uint3(";
1030 break;
1031 default:
1032 TINT_ICE(Writer, diagnostics_)
1033 << "unhandled texture dimensionality";
1034 break;
1035 }
1036 }
1037
1038 if (!EmitExpression(out, e->Declaration()))
1039 return false;
1040
1041 if (casted) {
1042 out << ")";
1043 }
1044 }
1045 }
1046
1047 if (auto* bias = arg(Usage::kBias)) {
1048 maybe_write_comma();
1049 out << "bias(";
1050 if (!EmitExpression(out, bias->Declaration())) {
1051 return false;
1052 }
1053 out << ")";
1054 }
1055 if (auto* level = arg(Usage::kLevel)) {
1056 maybe_write_comma();
1057 if (lod_param_is_named) {
1058 out << "level(";
1059 }
1060 if (!EmitExpression(out, level->Declaration())) {
1061 return false;
1062 }
1063 if (lod_param_is_named) {
1064 out << ")";
1065 }
1066 }
1067 if (intrinsic->Type() == sem::IntrinsicType::kTextureSampleCompareLevel) {
1068 maybe_write_comma();
1069 out << "level(0)";
1070 }
1071 if (auto* ddx = arg(Usage::kDdx)) {
1072 auto dim = texture_type->dim();
1073 switch (dim) {
1074 case ast::TextureDimension::k2d:
1075 case ast::TextureDimension::k2dArray:
1076 maybe_write_comma();
1077 out << "gradient2d(";
1078 break;
1079 case ast::TextureDimension::k3d:
1080 maybe_write_comma();
1081 out << "gradient3d(";
1082 break;
1083 case ast::TextureDimension::kCube:
1084 case ast::TextureDimension::kCubeArray:
1085 maybe_write_comma();
1086 out << "gradientcube(";
1087 break;
1088 default: {
1089 std::stringstream err;
1090 err << "MSL does not support gradients for " << dim << " textures";
1091 diagnostics_.add_error(diag::System::Writer, err.str());
1092 return false;
1093 }
1094 }
1095 if (!EmitExpression(out, ddx->Declaration())) {
1096 return false;
1097 }
1098 out << ", ";
1099 if (!EmitExpression(out, arg(Usage::kDdy)->Declaration())) {
1100 return false;
1101 }
1102 out << ")";
1103 }
1104
1105 bool has_offset = false;
1106 if (auto* offset = arg(Usage::kOffset)) {
1107 has_offset = true;
1108 maybe_write_comma();
1109 if (!EmitExpression(out, offset->Declaration())) {
1110 return false;
1111 }
1112 }
1113
1114 if (auto* component = arg(Usage::kComponent)) {
1115 maybe_write_comma();
1116 if (!has_offset) {
1117 // offset argument may need to be provided if we have a component.
1118 switch (texture_type->dim()) {
1119 case ast::TextureDimension::k2d:
1120 case ast::TextureDimension::k2dArray:
1121 out << "int2(0), ";
1122 break;
1123 default:
1124 break; // Other texture dimensions don't have an offset
1125 }
1126 }
1127 auto c = component->ConstantValue().Elements()[0].i32;
1128 switch (c) {
1129 case 0:
1130 out << "component::x";
1131 break;
1132 case 1:
1133 out << "component::y";
1134 break;
1135 case 2:
1136 out << "component::z";
1137 break;
1138 case 3:
1139 out << "component::w";
1140 break;
1141 default:
1142 TINT_ICE(Writer, diagnostics_)
1143 << "invalid textureGather component: " << c;
1144 break;
1145 }
1146 }
1147
1148 out << ")";
1149
1150 return true;
1151 }
1152
EmitDotCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1153 bool GeneratorImpl::EmitDotCall(std::ostream& out,
1154 const ast::CallExpression* expr,
1155 const sem::Intrinsic* intrinsic) {
1156 auto* vec_ty = intrinsic->Parameters()[0]->Type()->As<sem::Vector>();
1157 std::string fn = "dot";
1158 if (vec_ty->type()->is_integer_scalar()) {
1159 // MSL does not have a builtin for dot() with integer vector types.
1160 // Generate the helper function if it hasn't been created already
1161 fn = utils::GetOrCreate(
1162 int_dot_funcs_, vec_ty->Width(), [&]() -> std::string {
1163 TextBuffer b;
1164 TINT_DEFER(helpers_.Append(b));
1165
1166 auto fn_name =
1167 UniqueIdentifier("tint_dot" + std::to_string(vec_ty->Width()));
1168 auto v = "vec<T," + std::to_string(vec_ty->Width()) + ">";
1169
1170 line(&b) << "template<typename T>";
1171 line(&b) << "T " << fn_name << "(" << v << " a, " << v << " b) {";
1172 {
1173 auto l = line(&b);
1174 l << " return ";
1175 for (uint32_t i = 0; i < vec_ty->Width(); i++) {
1176 if (i > 0) {
1177 l << " + ";
1178 }
1179 l << "a[" << i << "]*b[" << i << "]";
1180 }
1181 l << ";";
1182 }
1183 line(&b) << "}";
1184 return fn_name;
1185 });
1186 }
1187
1188 out << fn << "(";
1189 if (!EmitExpression(out, expr->args[0])) {
1190 return false;
1191 }
1192 out << ", ";
1193 if (!EmitExpression(out, expr->args[1])) {
1194 return false;
1195 }
1196 out << ")";
1197 return true;
1198 }
1199
EmitModfCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1200 bool GeneratorImpl::EmitModfCall(std::ostream& out,
1201 const ast::CallExpression* expr,
1202 const sem::Intrinsic* intrinsic) {
1203 return CallIntrinsicHelper(
1204 out, expr, intrinsic,
1205 [&](TextBuffer* b, const std::vector<std::string>& params) {
1206 auto* ty = intrinsic->Parameters()[0]->Type();
1207 auto in = params[0];
1208
1209 std::string width;
1210 if (auto* vec = ty->As<sem::Vector>()) {
1211 width = std::to_string(vec->Width());
1212 }
1213
1214 // Emit the builtin return type unique to this overload. This does not
1215 // exist in the AST, so it will not be generated in Generate().
1216 if (!EmitStructType(&helpers_,
1217 intrinsic->ReturnType()->As<sem::Struct>())) {
1218 return false;
1219 }
1220
1221 line(b) << "float" << width << " whole;";
1222 line(b) << "float" << width << " fract = modf(" << in << ", whole);";
1223 line(b) << "return {fract, whole};";
1224 return true;
1225 });
1226 }
1227
EmitFrexpCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1228 bool GeneratorImpl::EmitFrexpCall(std::ostream& out,
1229 const ast::CallExpression* expr,
1230 const sem::Intrinsic* intrinsic) {
1231 return CallIntrinsicHelper(
1232 out, expr, intrinsic,
1233 [&](TextBuffer* b, const std::vector<std::string>& params) {
1234 auto* ty = intrinsic->Parameters()[0]->Type();
1235 auto in = params[0];
1236
1237 std::string width;
1238 if (auto* vec = ty->As<sem::Vector>()) {
1239 width = std::to_string(vec->Width());
1240 }
1241
1242 // Emit the builtin return type unique to this overload. This does not
1243 // exist in the AST, so it will not be generated in Generate().
1244 if (!EmitStructType(&helpers_,
1245 intrinsic->ReturnType()->As<sem::Struct>())) {
1246 return false;
1247 }
1248
1249 line(b) << "int" << width << " exp;";
1250 line(b) << "float" << width << " sig = frexp(" << in << ", exp);";
1251 line(b) << "return {sig, exp};";
1252 return true;
1253 });
1254 }
1255
generate_builtin_name(const sem::Intrinsic * intrinsic)1256 std::string GeneratorImpl::generate_builtin_name(
1257 const sem::Intrinsic* intrinsic) {
1258 std::string out = "";
1259 switch (intrinsic->Type()) {
1260 case sem::IntrinsicType::kAcos:
1261 case sem::IntrinsicType::kAll:
1262 case sem::IntrinsicType::kAny:
1263 case sem::IntrinsicType::kAsin:
1264 case sem::IntrinsicType::kAtan:
1265 case sem::IntrinsicType::kAtan2:
1266 case sem::IntrinsicType::kCeil:
1267 case sem::IntrinsicType::kCos:
1268 case sem::IntrinsicType::kCosh:
1269 case sem::IntrinsicType::kCross:
1270 case sem::IntrinsicType::kDeterminant:
1271 case sem::IntrinsicType::kDistance:
1272 case sem::IntrinsicType::kDot:
1273 case sem::IntrinsicType::kExp:
1274 case sem::IntrinsicType::kExp2:
1275 case sem::IntrinsicType::kFloor:
1276 case sem::IntrinsicType::kFma:
1277 case sem::IntrinsicType::kFract:
1278 case sem::IntrinsicType::kFrexp:
1279 case sem::IntrinsicType::kLength:
1280 case sem::IntrinsicType::kLdexp:
1281 case sem::IntrinsicType::kLog:
1282 case sem::IntrinsicType::kLog2:
1283 case sem::IntrinsicType::kMix:
1284 case sem::IntrinsicType::kModf:
1285 case sem::IntrinsicType::kNormalize:
1286 case sem::IntrinsicType::kPow:
1287 case sem::IntrinsicType::kReflect:
1288 case sem::IntrinsicType::kRefract:
1289 case sem::IntrinsicType::kSelect:
1290 case sem::IntrinsicType::kSin:
1291 case sem::IntrinsicType::kSinh:
1292 case sem::IntrinsicType::kSqrt:
1293 case sem::IntrinsicType::kStep:
1294 case sem::IntrinsicType::kTan:
1295 case sem::IntrinsicType::kTanh:
1296 case sem::IntrinsicType::kTranspose:
1297 case sem::IntrinsicType::kTrunc:
1298 case sem::IntrinsicType::kSign:
1299 case sem::IntrinsicType::kClamp:
1300 out += intrinsic->str();
1301 break;
1302 case sem::IntrinsicType::kAbs:
1303 if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
1304 out += "fabs";
1305 } else {
1306 out += "abs";
1307 }
1308 break;
1309 case sem::IntrinsicType::kCountOneBits:
1310 out += "popcount";
1311 break;
1312 case sem::IntrinsicType::kDpdx:
1313 case sem::IntrinsicType::kDpdxCoarse:
1314 case sem::IntrinsicType::kDpdxFine:
1315 out += "dfdx";
1316 break;
1317 case sem::IntrinsicType::kDpdy:
1318 case sem::IntrinsicType::kDpdyCoarse:
1319 case sem::IntrinsicType::kDpdyFine:
1320 out += "dfdy";
1321 break;
1322 case sem::IntrinsicType::kFwidth:
1323 case sem::IntrinsicType::kFwidthCoarse:
1324 case sem::IntrinsicType::kFwidthFine:
1325 out += "fwidth";
1326 break;
1327 case sem::IntrinsicType::kIsFinite:
1328 out += "isfinite";
1329 break;
1330 case sem::IntrinsicType::kIsInf:
1331 out += "isinf";
1332 break;
1333 case sem::IntrinsicType::kIsNan:
1334 out += "isnan";
1335 break;
1336 case sem::IntrinsicType::kIsNormal:
1337 out += "isnormal";
1338 break;
1339 case sem::IntrinsicType::kMax:
1340 if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
1341 out += "fmax";
1342 } else {
1343 out += "max";
1344 }
1345 break;
1346 case sem::IntrinsicType::kMin:
1347 if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
1348 out += "fmin";
1349 } else {
1350 out += "min";
1351 }
1352 break;
1353 case sem::IntrinsicType::kFaceForward:
1354 out += "faceforward";
1355 break;
1356 case sem::IntrinsicType::kPack4x8snorm:
1357 out += "pack_float_to_snorm4x8";
1358 break;
1359 case sem::IntrinsicType::kPack4x8unorm:
1360 out += "pack_float_to_unorm4x8";
1361 break;
1362 case sem::IntrinsicType::kPack2x16snorm:
1363 out += "pack_float_to_snorm2x16";
1364 break;
1365 case sem::IntrinsicType::kPack2x16unorm:
1366 out += "pack_float_to_unorm2x16";
1367 break;
1368 case sem::IntrinsicType::kReverseBits:
1369 out += "reverse_bits";
1370 break;
1371 case sem::IntrinsicType::kRound:
1372 out += "rint";
1373 break;
1374 case sem::IntrinsicType::kSmoothStep:
1375 out += "smoothstep";
1376 break;
1377 case sem::IntrinsicType::kInverseSqrt:
1378 out += "rsqrt";
1379 break;
1380 case sem::IntrinsicType::kUnpack4x8snorm:
1381 out += "unpack_snorm4x8_to_float";
1382 break;
1383 case sem::IntrinsicType::kUnpack4x8unorm:
1384 out += "unpack_unorm4x8_to_float";
1385 break;
1386 case sem::IntrinsicType::kUnpack2x16snorm:
1387 out += "unpack_snorm2x16_to_float";
1388 break;
1389 case sem::IntrinsicType::kUnpack2x16unorm:
1390 out += "unpack_unorm2x16_to_float";
1391 break;
1392 case sem::IntrinsicType::kArrayLength:
1393 diagnostics_.add_error(
1394 diag::System::Writer,
1395 "Unable to translate builtin: " + std::string(intrinsic->str()) +
1396 "\nDid you forget to pass array_length_from_uniform generator "
1397 "options?");
1398 return "";
1399 default:
1400 diagnostics_.add_error(
1401 diag::System::Writer,
1402 "Unknown import method: " + std::string(intrinsic->str()));
1403 return "";
1404 }
1405 return out;
1406 }
1407
EmitCase(const ast::CaseStatement * stmt)1408 bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) {
1409 if (stmt->IsDefault()) {
1410 line() << "default: {";
1411 } else {
1412 for (auto* selector : stmt->selectors) {
1413 auto out = line();
1414 out << "case ";
1415 if (!EmitLiteral(out, selector)) {
1416 return false;
1417 }
1418 out << ":";
1419 if (selector == stmt->selectors.back()) {
1420 out << " {";
1421 }
1422 }
1423 }
1424
1425 {
1426 ScopedIndent si(this);
1427
1428 for (auto* s : stmt->body->statements) {
1429 if (!EmitStatement(s)) {
1430 return false;
1431 }
1432 }
1433
1434 if (!last_is_break_or_fallthrough(stmt->body)) {
1435 line() << "break;";
1436 }
1437 }
1438
1439 line() << "}";
1440
1441 return true;
1442 }
1443
EmitContinue(const ast::ContinueStatement *)1444 bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) {
1445 if (!emit_continuing_()) {
1446 return false;
1447 }
1448
1449 line() << "continue;";
1450 return true;
1451 }
1452
EmitZeroValue(std::ostream & out,const sem::Type * type)1453 bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
1454 if (type->Is<sem::Bool>()) {
1455 out << "false";
1456 } else if (type->Is<sem::F32>()) {
1457 out << "0.0f";
1458 } else if (type->Is<sem::I32>()) {
1459 out << "0";
1460 } else if (type->Is<sem::U32>()) {
1461 out << "0u";
1462 } else if (auto* vec = type->As<sem::Vector>()) {
1463 return EmitZeroValue(out, vec->type());
1464 } else if (auto* mat = type->As<sem::Matrix>()) {
1465 if (!EmitType(out, mat, "")) {
1466 return false;
1467 }
1468 out << "(";
1469 if (!EmitZeroValue(out, mat->type())) {
1470 return false;
1471 }
1472 out << ")";
1473 } else if (auto* arr = type->As<sem::Array>()) {
1474 out << "{";
1475 if (!EmitZeroValue(out, arr->ElemType())) {
1476 return false;
1477 }
1478 out << "}";
1479 } else if (type->As<sem::Struct>()) {
1480 out << "{}";
1481 } else {
1482 diagnostics_.add_error(
1483 diag::System::Writer,
1484 "Invalid type for zero emission: " + type->type_name());
1485 return false;
1486 }
1487 return true;
1488 }
1489
EmitLiteral(std::ostream & out,const ast::LiteralExpression * lit)1490 bool GeneratorImpl::EmitLiteral(std::ostream& out,
1491 const ast::LiteralExpression* lit) {
1492 if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
1493 out << (l->value ? "true" : "false");
1494 } else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
1495 if (std::isinf(fl->value)) {
1496 out << (fl->value >= 0 ? "INFINITY" : "-INFINITY");
1497 } else if (std::isnan(fl->value)) {
1498 out << "NAN";
1499 } else {
1500 out << FloatToString(fl->value) << "f";
1501 }
1502 } else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
1503 // MSL (and C++) parse `-2147483648` as a `long` because it parses unary
1504 // minus and `2147483648` as separate tokens, and the latter doesn't
1505 // fit into an (32-bit) `int`. WGSL, OTOH, parses this as an `i32`. To avoid
1506 // issues with `long` to `int` casts, emit `(2147483647 - 1)` instead, which
1507 // ensures the expression type is `int`.
1508 const auto int_min = std::numeric_limits<int32_t>::min();
1509 if (sl->ValueAsI32() == int_min) {
1510 out << "(" << int_min + 1 << " - 1)";
1511 } else {
1512 out << sl->value;
1513 }
1514 } else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
1515 out << ul->value << "u";
1516 } else {
1517 diagnostics_.add_error(diag::System::Writer, "unknown literal type");
1518 return false;
1519 }
1520 return true;
1521 }
1522
EmitExpression(std::ostream & out,const ast::Expression * expr)1523 bool GeneratorImpl::EmitExpression(std::ostream& out,
1524 const ast::Expression* expr) {
1525 if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
1526 return EmitIndexAccessor(out, a);
1527 }
1528 if (auto* b = expr->As<ast::BinaryExpression>()) {
1529 return EmitBinary(out, b);
1530 }
1531 if (auto* b = expr->As<ast::BitcastExpression>()) {
1532 return EmitBitcast(out, b);
1533 }
1534 if (auto* c = expr->As<ast::CallExpression>()) {
1535 return EmitCall(out, c);
1536 }
1537 if (auto* i = expr->As<ast::IdentifierExpression>()) {
1538 return EmitIdentifier(out, i);
1539 }
1540 if (auto* l = expr->As<ast::LiteralExpression>()) {
1541 return EmitLiteral(out, l);
1542 }
1543 if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
1544 return EmitMemberAccessor(out, m);
1545 }
1546 if (auto* u = expr->As<ast::UnaryOpExpression>()) {
1547 return EmitUnaryOp(out, u);
1548 }
1549
1550 diagnostics_.add_error(
1551 diag::System::Writer,
1552 "unknown expression type: " + std::string(expr->TypeInfo().name));
1553 return false;
1554 }
1555
EmitStage(std::ostream & out,ast::PipelineStage stage)1556 void GeneratorImpl::EmitStage(std::ostream& out, ast::PipelineStage stage) {
1557 switch (stage) {
1558 case ast::PipelineStage::kFragment:
1559 out << "fragment";
1560 break;
1561 case ast::PipelineStage::kVertex:
1562 out << "vertex";
1563 break;
1564 case ast::PipelineStage::kCompute:
1565 out << "kernel";
1566 break;
1567 case ast::PipelineStage::kNone:
1568 break;
1569 }
1570 return;
1571 }
1572
EmitFunction(const ast::Function * func)1573 bool GeneratorImpl::EmitFunction(const ast::Function* func) {
1574 auto* func_sem = program_->Sem().Get(func);
1575
1576 {
1577 auto out = line();
1578 if (!EmitType(out, func_sem->ReturnType(), "")) {
1579 return false;
1580 }
1581 out << " " << program_->Symbols().NameFor(func->symbol) << "(";
1582
1583 bool first = true;
1584 for (auto* v : func->params) {
1585 if (!first) {
1586 out << ", ";
1587 }
1588 first = false;
1589
1590 auto* type = program_->Sem().Get(v)->Type();
1591
1592 std::string param_name =
1593 "const " + program_->Symbols().NameFor(v->symbol);
1594 if (!EmitType(out, type, param_name)) {
1595 return false;
1596 }
1597 // Parameter name is output as part of the type for arrays and pointers.
1598 if (!type->Is<sem::Array>() && !type->Is<sem::Pointer>()) {
1599 out << " " << program_->Symbols().NameFor(v->symbol);
1600 }
1601 }
1602
1603 out << ") {";
1604 }
1605
1606 if (!EmitStatementsWithIndent(func->body->statements)) {
1607 return false;
1608 }
1609
1610 line() << "}";
1611
1612 return true;
1613 }
1614
builtin_to_attribute(ast::Builtin builtin) const1615 std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const {
1616 switch (builtin) {
1617 case ast::Builtin::kPosition:
1618 return "position";
1619 case ast::Builtin::kVertexIndex:
1620 return "vertex_id";
1621 case ast::Builtin::kInstanceIndex:
1622 return "instance_id";
1623 case ast::Builtin::kFrontFacing:
1624 return "front_facing";
1625 case ast::Builtin::kFragDepth:
1626 return "depth(any)";
1627 case ast::Builtin::kLocalInvocationId:
1628 return "thread_position_in_threadgroup";
1629 case ast::Builtin::kLocalInvocationIndex:
1630 return "thread_index_in_threadgroup";
1631 case ast::Builtin::kGlobalInvocationId:
1632 return "thread_position_in_grid";
1633 case ast::Builtin::kWorkgroupId:
1634 return "threadgroup_position_in_grid";
1635 case ast::Builtin::kNumWorkgroups:
1636 return "threadgroups_per_grid";
1637 case ast::Builtin::kSampleIndex:
1638 return "sample_id";
1639 case ast::Builtin::kSampleMask:
1640 return "sample_mask";
1641 case ast::Builtin::kPointSize:
1642 return "point_size";
1643 default:
1644 break;
1645 }
1646 return "";
1647 }
1648
interpolation_to_attribute(ast::InterpolationType type,ast::InterpolationSampling sampling) const1649 std::string GeneratorImpl::interpolation_to_attribute(
1650 ast::InterpolationType type,
1651 ast::InterpolationSampling sampling) const {
1652 std::string attr;
1653 switch (sampling) {
1654 case ast::InterpolationSampling::kCenter:
1655 attr = "center_";
1656 break;
1657 case ast::InterpolationSampling::kCentroid:
1658 attr = "centroid_";
1659 break;
1660 case ast::InterpolationSampling::kSample:
1661 attr = "sample_";
1662 break;
1663 case ast::InterpolationSampling::kNone:
1664 break;
1665 }
1666 switch (type) {
1667 case ast::InterpolationType::kPerspective:
1668 attr += "perspective";
1669 break;
1670 case ast::InterpolationType::kLinear:
1671 attr += "no_perspective";
1672 break;
1673 case ast::InterpolationType::kFlat:
1674 attr += "flat";
1675 break;
1676 }
1677 return attr;
1678 }
1679
EmitEntryPointFunction(const ast::Function * func)1680 bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
1681 auto func_name = program_->Symbols().NameFor(func->symbol);
1682
1683 // Returns the binding index of a variable, requiring that the group attribute
1684 // have a value of zero.
1685 const uint32_t kInvalidBindingIndex = std::numeric_limits<uint32_t>::max();
1686 auto get_binding_index = [&](const ast::Variable* var) -> uint32_t {
1687 auto bp = var->BindingPoint();
1688 if (bp.group == nullptr || bp.binding == nullptr) {
1689 TINT_ICE(Writer, diagnostics_)
1690 << "missing binding attributes for entry point parameter";
1691 return kInvalidBindingIndex;
1692 }
1693 if (bp.group->value != 0) {
1694 TINT_ICE(Writer, diagnostics_)
1695 << "encountered non-zero resource group index (use "
1696 "BindingRemapper to fix)";
1697 return kInvalidBindingIndex;
1698 }
1699 return bp.binding->value;
1700 };
1701
1702 {
1703 auto out = line();
1704
1705 EmitStage(out, func->PipelineStage());
1706 out << " " << func->return_type->FriendlyName(program_->Symbols());
1707 out << " " << func_name << "(";
1708
1709 // Emit entry point parameters.
1710 bool first = true;
1711 for (auto* var : func->params) {
1712 if (!first) {
1713 out << ", ";
1714 }
1715 first = false;
1716
1717 auto* type = program_->Sem().Get(var)->Type()->UnwrapRef();
1718
1719 auto param_name = program_->Symbols().NameFor(var->symbol);
1720 if (!EmitType(out, type, param_name)) {
1721 return false;
1722 }
1723 // Parameter name is output as part of the type for arrays and pointers.
1724 if (!type->Is<sem::Array>() && !type->Is<sem::Pointer>()) {
1725 out << " " << param_name;
1726 }
1727
1728 if (type->Is<sem::Struct>()) {
1729 out << " [[stage_in]]";
1730 } else if (type->is_handle()) {
1731 uint32_t binding = get_binding_index(var);
1732 if (binding == kInvalidBindingIndex) {
1733 return false;
1734 }
1735 if (var->type->Is<ast::Sampler>()) {
1736 out << " [[sampler(" << binding << ")]]";
1737 } else if (var->type->Is<ast::Texture>()) {
1738 out << " [[texture(" << binding << ")]]";
1739 } else {
1740 TINT_ICE(Writer, diagnostics_)
1741 << "invalid handle type entry point parameter";
1742 return false;
1743 }
1744 } else if (auto* ptr = var->type->As<ast::Pointer>()) {
1745 auto sc = ptr->storage_class;
1746 if (sc == ast::StorageClass::kWorkgroup) {
1747 auto& allocations = workgroup_allocations_[func_name];
1748 out << " [[threadgroup(" << allocations.size() << ")]]";
1749 allocations.push_back(program_->Sem().Get(ptr->type)->Size());
1750 } else if (sc == ast::StorageClass::kStorage ||
1751 sc == ast::StorageClass::kUniform) {
1752 uint32_t binding = get_binding_index(var);
1753 if (binding == kInvalidBindingIndex) {
1754 return false;
1755 }
1756 out << " [[buffer(" << binding << ")]]";
1757 } else {
1758 TINT_ICE(Writer, diagnostics_)
1759 << "invalid pointer storage class for entry point parameter";
1760 return false;
1761 }
1762 } else {
1763 auto& decos = var->decorations;
1764 bool builtin_found = false;
1765 for (auto* deco : decos) {
1766 auto* builtin = deco->As<ast::BuiltinDecoration>();
1767 if (!builtin) {
1768 continue;
1769 }
1770
1771 builtin_found = true;
1772
1773 auto attr = builtin_to_attribute(builtin->builtin);
1774 if (attr.empty()) {
1775 diagnostics_.add_error(diag::System::Writer, "unknown builtin");
1776 return false;
1777 }
1778 out << " [[" << attr << "]]";
1779 }
1780 if (!builtin_found) {
1781 TINT_ICE(Writer, diagnostics_) << "Unsupported entry point parameter";
1782 }
1783 }
1784 }
1785 out << ") {";
1786 }
1787
1788 {
1789 ScopedIndent si(this);
1790
1791 if (!EmitStatements(func->body->statements)) {
1792 return false;
1793 }
1794
1795 if (!Is<ast::ReturnStatement>(func->body->Last())) {
1796 ast::ReturnStatement ret(ProgramID{}, Source{});
1797 if (!EmitStatement(&ret)) {
1798 return false;
1799 }
1800 }
1801 }
1802
1803 line() << "}";
1804 return true;
1805 }
1806
EmitIdentifier(std::ostream & out,const ast::IdentifierExpression * expr)1807 bool GeneratorImpl::EmitIdentifier(std::ostream& out,
1808 const ast::IdentifierExpression* expr) {
1809 out << program_->Symbols().NameFor(expr->symbol);
1810 return true;
1811 }
1812
EmitLoop(const ast::LoopStatement * stmt)1813 bool GeneratorImpl::EmitLoop(const ast::LoopStatement* stmt) {
1814 auto emit_continuing = [this, stmt]() {
1815 if (stmt->continuing && !stmt->continuing->Empty()) {
1816 if (!EmitBlock(stmt->continuing)) {
1817 return false;
1818 }
1819 }
1820 return true;
1821 };
1822
1823 TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
1824 line() << "while (true) {";
1825 {
1826 ScopedIndent si(this);
1827 if (!EmitStatements(stmt->body->statements)) {
1828 return false;
1829 }
1830 if (!emit_continuing()) {
1831 return false;
1832 }
1833 }
1834 line() << "}";
1835
1836 return true;
1837 }
1838
// Emits a WGSL for-loop.
// The initializer, condition and continuing parts are first emitted into
// separate buffers, so their size is known before anything is written:
//  * init_buf: initializer statement line(s);
//  * cond_pre: statements the condition expression needed emitted before it;
//  * cond_buf: the condition expression text itself;
//  * cont_buf: continuing statement line(s).
// If cond_pre is non-empty or the continuing produced more than one line, the
// loop is lowered to `while (true) { if (!(cond)) { break; } ... }`.
// Otherwise a native MSL `for (...)` is emitted.
// Returns false if any part fails to emit.
bool GeneratorImpl::EmitForLoop(const ast::ForLoopStatement* stmt) {
  // Capture the initializer's output into init_buf instead of the current
  // buffer.
  TextBuffer init_buf;
  if (auto* init = stmt->initializer) {
    TINT_SCOPED_ASSIGNMENT(current_buffer_, &init_buf);
    if (!EmitStatement(init)) {
      return false;
    }
  }

  // The condition text goes to cond_buf. Any statements emitted while
  // generating it land in cond_pre.
  TextBuffer cond_pre;
  std::stringstream cond_buf;
  if (auto* cond = stmt->condition) {
    TINT_SCOPED_ASSIGNMENT(current_buffer_, &cond_pre);
    if (!EmitExpression(cond_buf, cond)) {
      return false;
    }
  }

  // Capture the continuing statement's output into cont_buf.
  TextBuffer cont_buf;
  if (auto* cont = stmt->continuing) {
    TINT_SCOPED_ASSIGNMENT(current_buffer_, &cont_buf);
    if (!EmitStatement(cont)) {
      return false;
    }
  }

  // If the for-loop has a multi-statement conditional and / or continuing, then
  // we cannot emit this as a regular for-loop in MSL. Instead we need to
  // generate a `while(true)` loop.
  bool emit_as_loop = cond_pre.lines.size() > 0 || cont_buf.lines.size() > 1;

  // If the for-loop has multi-statement initializer, or is going to be emitted
  // as a `while(true)` loop, then declare the initializer statement(s) before
  // the loop in a new block.
  bool nest_in_block =
      init_buf.lines.size() > 1 || (stmt->initializer && emit_as_loop);
  if (nest_in_block) {
    line() << "{";
    increment_indent();
    current_buffer_->Append(init_buf);
    init_buf.lines.clear();  // Don't emit the initializer again in the 'for'
  }
  // Closes the block opened above when this function exits, on both success
  // and error paths.
  TINT_DEFER({
    if (nest_in_block) {
      decrement_indent();
      line() << "}";
    }
  });

  if (emit_as_loop) {
    // Re-emits the captured continuing lines. Installed as emit_continuing_
    // so that `continue` statements in the body also emit them first.
    auto emit_continuing = [&]() {
      current_buffer_->Append(cont_buf);
      return true;
    };

    TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
    line() << "while (true) {";
    increment_indent();
    // Closes the `while` on exit. Declared after the nesting defer above, so
    // it is expected to run first (inner brace before outer brace).
    TINT_DEFER({
      decrement_indent();
      line() << "}";
    });

    // Re-evaluate the condition (and its pre-statements) on every iteration.
    if (stmt->condition) {
      current_buffer_->Append(cond_pre);
      line() << "if (!(" << cond_buf.str() << ")) { break; }";
    }

    if (!EmitStatements(stmt->body->statements)) {
      return false;
    }

    // Continuing code for the fall-through path at the end of the body.
    if (!emit_continuing()) {
      return false;
    }
  } else {
    // For-loop can be generated.
    {
      auto out = line();
      out << "for";
      {
        ScopedParen sp(out);

        // The initializer line already ends with its own ';'.
        if (!init_buf.lines.empty()) {
          out << init_buf.lines[0].content << " ";
        } else {
          out << "; ";
        }

        out << cond_buf.str() << "; ";

        // The continuing statement's trailing ';' is not valid inside the
        // for-header, so it is stripped.
        if (!cont_buf.lines.empty()) {
          out << TrimSuffix(cont_buf.lines[0].content, ";");
        }
      }
      out << " {";
    }
    {
      // The `for` header already runs the continuing statement, so
      // `continue` needs to emit nothing extra.
      auto emit_continuing = [] { return true; };
      TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
      if (!EmitStatementsWithIndent(stmt->body->statements)) {
        return false;
      }
    }
    line() << "}";
  }

  return true;
}
1948
EmitDiscard(const ast::DiscardStatement *)1949 bool GeneratorImpl::EmitDiscard(const ast::DiscardStatement*) {
1950 // TODO(dsinclair): Verify this is correct when the discard semantics are
1951 // defined for WGSL (https://github.com/gpuweb/gpuweb/issues/361)
1952 line() << "discard_fragment();";
1953 return true;
1954 }
1955
EmitIf(const ast::IfStatement * stmt)1956 bool GeneratorImpl::EmitIf(const ast::IfStatement* stmt) {
1957 {
1958 auto out = line();
1959 out << "if (";
1960 if (!EmitExpression(out, stmt->condition)) {
1961 return false;
1962 }
1963 out << ") {";
1964 }
1965
1966 if (!EmitStatementsWithIndent(stmt->body->statements)) {
1967 return false;
1968 }
1969
1970 for (auto* e : stmt->else_statements) {
1971 if (e->condition) {
1972 line() << "} else {";
1973 increment_indent();
1974
1975 {
1976 auto out = line();
1977 out << "if (";
1978 if (!EmitExpression(out, e->condition)) {
1979 return false;
1980 }
1981 out << ") {";
1982 }
1983 } else {
1984 line() << "} else {";
1985 }
1986
1987 if (!EmitStatementsWithIndent(e->body->statements)) {
1988 return false;
1989 }
1990 }
1991
1992 line() << "}";
1993
1994 for (auto* e : stmt->else_statements) {
1995 if (e->condition) {
1996 decrement_indent();
1997 line() << "}";
1998 }
1999 }
2000 return true;
2001 }
2002
EmitMemberAccessor(std::ostream & out,const ast::MemberAccessorExpression * expr)2003 bool GeneratorImpl::EmitMemberAccessor(
2004 std::ostream& out,
2005 const ast::MemberAccessorExpression* expr) {
2006 auto write_lhs = [&] {
2007 bool paren_lhs = !expr->structure->IsAnyOf<
2008 ast::IndexAccessorExpression, ast::CallExpression,
2009 ast::IdentifierExpression, ast::MemberAccessorExpression>();
2010 if (paren_lhs) {
2011 out << "(";
2012 }
2013 if (!EmitExpression(out, expr->structure)) {
2014 return false;
2015 }
2016 if (paren_lhs) {
2017 out << ")";
2018 }
2019 return true;
2020 };
2021
2022 auto& sem = program_->Sem();
2023
2024 if (auto* swizzle = sem.Get(expr)->As<sem::Swizzle>()) {
2025 // Metal 1.x does not support swizzling of packed vector types.
2026 // For single element swizzles, we can use the index operator.
2027 // For multi-element swizzles, we need to cast to a regular vector type
2028 // first. Note that we do not currently allow assignments to swizzles, so
2029 // the casting which will convert the l-value to r-value is fine.
2030 if (swizzle->Indices().size() == 1) {
2031 if (!write_lhs()) {
2032 return false;
2033 }
2034 out << "[" << swizzle->Indices()[0] << "]";
2035 } else {
2036 if (!EmitType(out, sem.Get(expr->structure)->Type()->UnwrapRef(), "")) {
2037 return false;
2038 }
2039 out << "(";
2040 if (!write_lhs()) {
2041 return false;
2042 }
2043 out << ")." << program_->Symbols().NameFor(expr->member->symbol);
2044 }
2045 } else {
2046 if (!write_lhs()) {
2047 return false;
2048 }
2049 out << ".";
2050 if (!EmitExpression(out, expr->member)) {
2051 return false;
2052 }
2053 }
2054
2055 return true;
2056 }
2057
EmitReturn(const ast::ReturnStatement * stmt)2058 bool GeneratorImpl::EmitReturn(const ast::ReturnStatement* stmt) {
2059 auto out = line();
2060 out << "return";
2061 if (stmt->value) {
2062 out << " ";
2063 if (!EmitExpression(out, stmt->value)) {
2064 return false;
2065 }
2066 }
2067 out << ";";
2068 return true;
2069 }
2070
EmitBlock(const ast::BlockStatement * stmt)2071 bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) {
2072 line() << "{";
2073
2074 if (!EmitStatementsWithIndent(stmt->statements)) {
2075 return false;
2076 }
2077
2078 line() << "}";
2079
2080 return true;
2081 }
2082
EmitStatement(const ast::Statement * stmt)2083 bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
2084 if (auto* a = stmt->As<ast::AssignmentStatement>()) {
2085 return EmitAssign(a);
2086 }
2087 if (auto* b = stmt->As<ast::BlockStatement>()) {
2088 return EmitBlock(b);
2089 }
2090 if (auto* b = stmt->As<ast::BreakStatement>()) {
2091 return EmitBreak(b);
2092 }
2093 if (auto* c = stmt->As<ast::CallStatement>()) {
2094 auto out = line();
2095 if (!EmitCall(out, c->expr)) {
2096 return false;
2097 }
2098 out << ";";
2099 return true;
2100 }
2101 if (auto* c = stmt->As<ast::ContinueStatement>()) {
2102 return EmitContinue(c);
2103 }
2104 if (auto* d = stmt->As<ast::DiscardStatement>()) {
2105 return EmitDiscard(d);
2106 }
2107 if (stmt->As<ast::FallthroughStatement>()) {
2108 line() << "/* fallthrough */";
2109 return true;
2110 }
2111 if (auto* i = stmt->As<ast::IfStatement>()) {
2112 return EmitIf(i);
2113 }
2114 if (auto* l = stmt->As<ast::LoopStatement>()) {
2115 return EmitLoop(l);
2116 }
2117 if (auto* l = stmt->As<ast::ForLoopStatement>()) {
2118 return EmitForLoop(l);
2119 }
2120 if (auto* r = stmt->As<ast::ReturnStatement>()) {
2121 return EmitReturn(r);
2122 }
2123 if (auto* s = stmt->As<ast::SwitchStatement>()) {
2124 return EmitSwitch(s);
2125 }
2126 if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
2127 auto* var = program_->Sem().Get(v->variable);
2128 return EmitVariable(var);
2129 }
2130
2131 diagnostics_.add_error(
2132 diag::System::Writer,
2133 "unknown statement type: " + std::string(stmt->TypeInfo().name));
2134 return false;
2135 }
2136
EmitStatements(const ast::StatementList & stmts)2137 bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {
2138 for (auto* s : stmts) {
2139 if (!EmitStatement(s)) {
2140 return false;
2141 }
2142 }
2143 return true;
2144 }
2145
EmitStatementsWithIndent(const ast::StatementList & stmts)2146 bool GeneratorImpl::EmitStatementsWithIndent(const ast::StatementList& stmts) {
2147 ScopedIndent si(this);
2148 return EmitStatements(stmts);
2149 }
2150
EmitSwitch(const ast::SwitchStatement * stmt)2151 bool GeneratorImpl::EmitSwitch(const ast::SwitchStatement* stmt) {
2152 {
2153 auto out = line();
2154 out << "switch(";
2155 if (!EmitExpression(out, stmt->condition)) {
2156 return false;
2157 }
2158 out << ") {";
2159 }
2160
2161 {
2162 ScopedIndent si(this);
2163 for (auto* s : stmt->body) {
2164 if (!EmitCase(s)) {
2165 return false;
2166 }
2167 }
2168 }
2169
2170 line() << "}";
2171
2172 return true;
2173 }
2174
EmitType(std::ostream & out,const sem::Type * type,const std::string & name,bool * name_printed)2175 bool GeneratorImpl::EmitType(std::ostream& out,
2176 const sem::Type* type,
2177 const std::string& name,
2178 bool* name_printed /* = nullptr */) {
2179 if (name_printed) {
2180 *name_printed = false;
2181 }
2182 if (auto* atomic = type->As<sem::Atomic>()) {
2183 if (atomic->Type()->Is<sem::I32>()) {
2184 out << "atomic_int";
2185 return true;
2186 }
2187 if (atomic->Type()->Is<sem::U32>()) {
2188 out << "atomic_uint";
2189 return true;
2190 }
2191 TINT_ICE(Writer, diagnostics_)
2192 << "unhandled atomic type " << atomic->Type()->type_name();
2193 return false;
2194 }
2195
2196 if (auto* ary = type->As<sem::Array>()) {
2197 const sem::Type* base_type = ary;
2198 std::vector<uint32_t> sizes;
2199 while (auto* arr = base_type->As<sem::Array>()) {
2200 if (arr->IsRuntimeSized()) {
2201 sizes.push_back(1);
2202 } else {
2203 sizes.push_back(arr->Count());
2204 }
2205 base_type = arr->ElemType();
2206 }
2207 if (!EmitType(out, base_type, "")) {
2208 return false;
2209 }
2210 if (!name.empty()) {
2211 out << " " << name;
2212 if (name_printed) {
2213 *name_printed = true;
2214 }
2215 }
2216 for (uint32_t size : sizes) {
2217 out << "[" << size << "]";
2218 }
2219 return true;
2220 }
2221
2222 if (type->Is<sem::Bool>()) {
2223 out << "bool";
2224 return true;
2225 }
2226
2227 if (type->Is<sem::F32>()) {
2228 out << "float";
2229 return true;
2230 }
2231
2232 if (type->Is<sem::I32>()) {
2233 out << "int";
2234 return true;
2235 }
2236
2237 if (auto* mat = type->As<sem::Matrix>()) {
2238 if (!EmitType(out, mat->type(), "")) {
2239 return false;
2240 }
2241 out << mat->columns() << "x" << mat->rows();
2242 return true;
2243 }
2244
2245 if (auto* ptr = type->As<sem::Pointer>()) {
2246 if (ptr->Access() == ast::Access::kRead) {
2247 out << "const ";
2248 }
2249 if (!EmitStorageClass(out, ptr->StorageClass())) {
2250 return false;
2251 }
2252 out << " ";
2253 if (ptr->StoreType()->Is<sem::Array>()) {
2254 std::string inner = "(*" + name + ")";
2255 if (!EmitType(out, ptr->StoreType(), inner)) {
2256 return false;
2257 }
2258 if (name_printed) {
2259 *name_printed = true;
2260 }
2261 } else {
2262 if (!EmitType(out, ptr->StoreType(), "")) {
2263 return false;
2264 }
2265 out << "* " << name;
2266 if (name_printed) {
2267 *name_printed = true;
2268 }
2269 }
2270 return true;
2271 }
2272
2273 if (type->Is<sem::Sampler>()) {
2274 out << "sampler";
2275 return true;
2276 }
2277
2278 if (auto* str = type->As<sem::Struct>()) {
2279 // The struct type emits as just the name. The declaration would be emitted
2280 // as part of emitting the declared types.
2281 out << StructName(str);
2282 return true;
2283 }
2284
2285 if (auto* tex = type->As<sem::Texture>()) {
2286 if (tex->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
2287 out << "depth";
2288 } else {
2289 out << "texture";
2290 }
2291
2292 switch (tex->dim()) {
2293 case ast::TextureDimension::k1d:
2294 out << "1d";
2295 break;
2296 case ast::TextureDimension::k2d:
2297 out << "2d";
2298 break;
2299 case ast::TextureDimension::k2dArray:
2300 out << "2d_array";
2301 break;
2302 case ast::TextureDimension::k3d:
2303 out << "3d";
2304 break;
2305 case ast::TextureDimension::kCube:
2306 out << "cube";
2307 break;
2308 case ast::TextureDimension::kCubeArray:
2309 out << "cube_array";
2310 break;
2311 default:
2312 diagnostics_.add_error(diag::System::Writer,
2313 "Invalid texture dimensions");
2314 return false;
2315 }
2316 if (tex->IsAnyOf<sem::MultisampledTexture,
2317 sem::DepthMultisampledTexture>()) {
2318 out << "_ms";
2319 }
2320 out << "<";
2321 if (tex->Is<sem::DepthTexture>()) {
2322 out << "float, access::sample";
2323 } else if (tex->Is<sem::DepthMultisampledTexture>()) {
2324 out << "float, access::read";
2325 } else if (auto* storage = tex->As<sem::StorageTexture>()) {
2326 if (!EmitType(out, storage->type(), "")) {
2327 return false;
2328 }
2329
2330 std::string access_str;
2331 if (storage->access() == ast::Access::kRead) {
2332 out << ", access::read";
2333 } else if (storage->access() == ast::Access::kWrite) {
2334 out << ", access::write";
2335 } else {
2336 diagnostics_.add_error(diag::System::Writer,
2337 "Invalid access control for storage texture");
2338 return false;
2339 }
2340 } else if (auto* ms = tex->As<sem::MultisampledTexture>()) {
2341 if (!EmitType(out, ms->type(), "")) {
2342 return false;
2343 }
2344 out << ", access::read";
2345 } else if (auto* sampled = tex->As<sem::SampledTexture>()) {
2346 if (!EmitType(out, sampled->type(), "")) {
2347 return false;
2348 }
2349 out << ", access::sample";
2350 } else {
2351 diagnostics_.add_error(diag::System::Writer, "invalid texture type");
2352 return false;
2353 }
2354 out << ">";
2355 return true;
2356 }
2357
2358 if (type->Is<sem::U32>()) {
2359 out << "uint";
2360 return true;
2361 }
2362
2363 if (auto* vec = type->As<sem::Vector>()) {
2364 if (!EmitType(out, vec->type(), "")) {
2365 return false;
2366 }
2367 out << vec->Width();
2368 return true;
2369 }
2370
2371 if (type->Is<sem::Void>()) {
2372 out << "void";
2373 return true;
2374 }
2375
2376 diagnostics_.add_error(diag::System::Writer,
2377 "unknown type in EmitType: " + type->type_name());
2378 return false;
2379 }
2380
EmitTypeAndName(std::ostream & out,const sem::Type * type,const std::string & name)2381 bool GeneratorImpl::EmitTypeAndName(std::ostream& out,
2382 const sem::Type* type,
2383 const std::string& name) {
2384 bool name_printed = false;
2385 if (!EmitType(out, type, name, &name_printed)) {
2386 return false;
2387 }
2388 if (!name_printed) {
2389 out << " " << name;
2390 }
2391 return true;
2392 }
2393
EmitStorageClass(std::ostream & out,ast::StorageClass sc)2394 bool GeneratorImpl::EmitStorageClass(std::ostream& out, ast::StorageClass sc) {
2395 switch (sc) {
2396 case ast::StorageClass::kFunction:
2397 case ast::StorageClass::kPrivate:
2398 case ast::StorageClass::kUniformConstant:
2399 out << "thread";
2400 return true;
2401 case ast::StorageClass::kWorkgroup:
2402 out << "threadgroup";
2403 return true;
2404 case ast::StorageClass::kStorage:
2405 out << "device";
2406 return true;
2407 case ast::StorageClass::kUniform:
2408 out << "constant";
2409 return true;
2410 default:
2411 break;
2412 }
2413 TINT_ICE(Writer, diagnostics_) << "unhandled storage class: " << sc;
2414 return false;
2415 }
2416
EmitPackedType(std::ostream & out,const sem::Type * type,const std::string & name)2417 bool GeneratorImpl::EmitPackedType(std::ostream& out,
2418 const sem::Type* type,
2419 const std::string& name) {
2420 auto* vec = type->As<sem::Vector>();
2421 if (vec && vec->Width() == 3) {
2422 out << "packed_";
2423 if (!EmitType(out, vec, "")) {
2424 return false;
2425 }
2426
2427 if (vec->is_float_vector() && !matrix_packed_vector_overloads_) {
2428 // Overload operators for matrix-vector arithmetic where the vector
2429 // operand is packed, as these overloads to not exist in the metal
2430 // namespace.
2431 TextBuffer b;
2432 TINT_DEFER(helpers_.Append(b));
2433 line(&b) << R"(template<typename T, int N, int M>
2434 inline vec<T, M> operator*(matrix<T, N, M> lhs, packed_vec<T, N> rhs) {
2435 return lhs * vec<T, N>(rhs);
2436 }
2437
2438 template<typename T, int N, int M>
2439 inline vec<T, N> operator*(packed_vec<T, M> lhs, matrix<T, N, M> rhs) {
2440 return vec<T, M>(lhs) * rhs;
2441 }
2442 )";
2443 matrix_packed_vector_overloads_ = true;
2444 }
2445
2446 return true;
2447 }
2448
2449 return EmitType(out, type, name);
2450 }
2451
EmitStructType(TextBuffer * b,const sem::Struct * str)2452 bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) {
2453 line(b) << "struct " << StructName(str) << " {";
2454
2455 bool is_host_shareable = str->IsHostShareable();
2456
2457 // Emits a `/* 0xnnnn */` byte offset comment for a struct member.
2458 auto add_byte_offset_comment = [&](std::ostream& out, uint32_t offset) {
2459 std::ios_base::fmtflags saved_flag_state(out.flags());
2460 out << "/* 0x" << std::hex << std::setfill('0') << std::setw(4) << offset
2461 << " */ ";
2462 out.flags(saved_flag_state);
2463 };
2464
2465 auto add_padding = [&](uint32_t size, uint32_t msl_offset) {
2466 std::string name;
2467 do {
2468 name = UniqueIdentifier("tint_pad");
2469 } while (str->FindMember(program_->Symbols().Get(name)));
2470
2471 auto out = line(b);
2472 add_byte_offset_comment(out, msl_offset);
2473 out << "int8_t " << name << "[" << size << "];";
2474 };
2475
2476 b->IncrementIndent();
2477
2478 uint32_t msl_offset = 0;
2479 for (auto* mem : str->Members()) {
2480 auto out = line(b);
2481 auto name = program_->Symbols().NameFor(mem->Name());
2482 auto wgsl_offset = mem->Offset();
2483
2484 if (is_host_shareable) {
2485 if (wgsl_offset < msl_offset) {
2486 // Unimplementable layout
2487 TINT_ICE(Writer, diagnostics_)
2488 << "Structure member WGSL offset (" << wgsl_offset
2489 << ") is behind MSL offset (" << msl_offset << ")";
2490 return false;
2491 }
2492
2493 // Generate padding if required
2494 if (auto padding = wgsl_offset - msl_offset) {
2495 add_padding(padding, msl_offset);
2496 msl_offset += padding;
2497 }
2498
2499 add_byte_offset_comment(out, msl_offset);
2500
2501 if (!EmitPackedType(out, mem->Type(), name)) {
2502 return false;
2503 }
2504 } else {
2505 if (!EmitType(out, mem->Type(), name)) {
2506 return false;
2507 }
2508 }
2509
2510 auto* ty = mem->Type();
2511
2512 // Array member name will be output with the type
2513 if (!ty->Is<sem::Array>()) {
2514 out << " " << name;
2515 }
2516
2517 // Emit decorations
2518 if (auto* decl = mem->Declaration()) {
2519 for (auto* deco : decl->decorations) {
2520 if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
2521 auto attr = builtin_to_attribute(builtin->builtin);
2522 if (attr.empty()) {
2523 diagnostics_.add_error(diag::System::Writer, "unknown builtin");
2524 return false;
2525 }
2526 out << " [[" << attr << "]]";
2527 } else if (auto* loc = deco->As<ast::LocationDecoration>()) {
2528 auto& pipeline_stage_uses = str->PipelineStageUses();
2529 if (pipeline_stage_uses.size() != 1) {
2530 TINT_ICE(Writer, diagnostics_)
2531 << "invalid entry point IO struct uses";
2532 }
2533
2534 if (pipeline_stage_uses.count(
2535 sem::PipelineStageUsage::kVertexInput)) {
2536 out << " [[attribute(" + std::to_string(loc->value) + ")]]";
2537 } else if (pipeline_stage_uses.count(
2538 sem::PipelineStageUsage::kVertexOutput)) {
2539 out << " [[user(locn" + std::to_string(loc->value) + ")]]";
2540 } else if (pipeline_stage_uses.count(
2541 sem::PipelineStageUsage::kFragmentInput)) {
2542 out << " [[user(locn" + std::to_string(loc->value) + ")]]";
2543 } else if (pipeline_stage_uses.count(
2544 sem::PipelineStageUsage::kFragmentOutput)) {
2545 out << " [[color(" + std::to_string(loc->value) + ")]]";
2546 } else {
2547 TINT_ICE(Writer, diagnostics_)
2548 << "invalid use of location decoration";
2549 }
2550 } else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
2551 auto attr = interpolation_to_attribute(interpolate->type,
2552 interpolate->sampling);
2553 if (attr.empty()) {
2554 diagnostics_.add_error(diag::System::Writer,
2555 "unknown interpolation attribute");
2556 return false;
2557 }
2558 out << " [[" << attr << "]]";
2559 } else if (deco->Is<ast::InvariantDecoration>()) {
2560 if (invariant_define_name_.empty()) {
2561 invariant_define_name_ = UniqueIdentifier("TINT_INVARIANT");
2562 }
2563 out << " " << invariant_define_name_;
2564 } else if (!deco->IsAnyOf<ast::StructMemberOffsetDecoration,
2565 ast::StructMemberAlignDecoration,
2566 ast::StructMemberSizeDecoration>()) {
2567 TINT_ICE(Writer, diagnostics_)
2568 << "unhandled struct member attribute: " << deco->Name();
2569 }
2570 }
2571 }
2572
2573 out << ";";
2574
2575 if (is_host_shareable) {
2576 // Calculate new MSL offset
2577 auto size_align = MslPackedTypeSizeAndAlign(ty);
2578 if (msl_offset % size_align.align) {
2579 TINT_ICE(Writer, diagnostics_)
2580 << "Misaligned MSL structure member "
2581 << ty->FriendlyName(program_->Symbols()) << " " << name;
2582 return false;
2583 }
2584 msl_offset += size_align.size;
2585 }
2586 }
2587
2588 if (is_host_shareable && str->Size() != msl_offset) {
2589 add_padding(str->Size() - msl_offset, msl_offset);
2590 }
2591
2592 b->DecrementIndent();
2593
2594 line(b) << "};";
2595 return true;
2596 }
2597
EmitUnaryOp(std::ostream & out,const ast::UnaryOpExpression * expr)2598 bool GeneratorImpl::EmitUnaryOp(std::ostream& out,
2599 const ast::UnaryOpExpression* expr) {
2600 // Handle `-e` when `e` is signed, so that we ensure that if `e` is the
2601 // largest negative value, it returns `e`.
2602 auto* expr_type = TypeOf(expr->expr)->UnwrapRef();
2603 if (expr->op == ast::UnaryOp::kNegation &&
2604 expr_type->is_signed_scalar_or_vector()) {
2605 auto fn =
2606 utils::GetOrCreate(unary_minus_funcs_, expr_type, [&]() -> std::string {
2607 // e.g.:
2608 // int tint_unary_minus(const int v) {
2609 // return (v == -2147483648) ? v : -v;
2610 // }
2611 TextBuffer b;
2612 TINT_DEFER(helpers_.Append(b));
2613
2614 auto fn_name = UniqueIdentifier("tint_unary_minus");
2615 {
2616 auto decl = line(&b);
2617 if (!EmitTypeAndName(decl, expr_type, fn_name)) {
2618 return "";
2619 }
2620 decl << "(const ";
2621 if (!EmitType(decl, expr_type, "")) {
2622 return "";
2623 }
2624 decl << " v) {";
2625 }
2626
2627 {
2628 ScopedIndent si(&b);
2629 const auto largest_negative_value =
2630 std::to_string(std::numeric_limits<int32_t>::min());
2631 line(&b) << "return select(-v, v, v == " << largest_negative_value
2632 << ");";
2633 }
2634 line(&b) << "}";
2635 line(&b);
2636 return fn_name;
2637 });
2638
2639 out << fn << "(";
2640 if (!EmitExpression(out, expr->expr)) {
2641 return false;
2642 }
2643 out << ")";
2644 return true;
2645 }
2646
2647 switch (expr->op) {
2648 case ast::UnaryOp::kAddressOf:
2649 out << "&";
2650 break;
2651 case ast::UnaryOp::kComplement:
2652 out << "~";
2653 break;
2654 case ast::UnaryOp::kIndirection:
2655 out << "*";
2656 break;
2657 case ast::UnaryOp::kNot:
2658 out << "!";
2659 break;
2660 case ast::UnaryOp::kNegation:
2661 out << "-";
2662 break;
2663 }
2664 out << "(";
2665
2666 if (!EmitExpression(out, expr->expr)) {
2667 return false;
2668 }
2669
2670 out << ")";
2671
2672 return true;
2673 }
2674
EmitVariable(const sem::Variable * var)2675 bool GeneratorImpl::EmitVariable(const sem::Variable* var) {
2676 auto* decl = var->Declaration();
2677
2678 for (auto* deco : decl->decorations) {
2679 if (!deco->Is<ast::InternalDecoration>()) {
2680 TINT_ICE(Writer, diagnostics_) << "unexpected variable decoration";
2681 return false;
2682 }
2683 }
2684
2685 auto out = line();
2686
2687 switch (var->StorageClass()) {
2688 case ast::StorageClass::kFunction:
2689 case ast::StorageClass::kUniformConstant:
2690 case ast::StorageClass::kNone:
2691 break;
2692 case ast::StorageClass::kPrivate:
2693 out << "thread ";
2694 break;
2695 case ast::StorageClass::kWorkgroup:
2696 out << "threadgroup ";
2697 break;
2698 default:
2699 TINT_ICE(Writer, diagnostics_) << "unhandled variable storage class";
2700 return false;
2701 }
2702
2703 auto* type = var->Type()->UnwrapRef();
2704
2705 std::string name = program_->Symbols().NameFor(decl->symbol);
2706 if (decl->is_const) {
2707 name = "const " + name;
2708 }
2709 if (!EmitType(out, type, name)) {
2710 return false;
2711 }
2712 // Variable name is output as part of the type for arrays and pointers.
2713 if (!type->Is<sem::Array>() && !type->Is<sem::Pointer>()) {
2714 out << " " << name;
2715 }
2716
2717 if (decl->constructor != nullptr) {
2718 out << " = ";
2719 if (!EmitExpression(out, decl->constructor)) {
2720 return false;
2721 }
2722 } else if (var->StorageClass() == ast::StorageClass::kPrivate ||
2723 var->StorageClass() == ast::StorageClass::kFunction ||
2724 var->StorageClass() == ast::StorageClass::kNone) {
2725 out << " = ";
2726 if (!EmitZeroValue(out, type)) {
2727 return false;
2728 }
2729 }
2730 out << ";";
2731
2732 return true;
2733 }
2734
EmitProgramConstVariable(const ast::Variable * var)2735 bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
2736 for (auto* d : var->decorations) {
2737 if (!d->Is<ast::OverrideDecoration>()) {
2738 diagnostics_.add_error(diag::System::Writer,
2739 "Decorated const values not valid");
2740 return false;
2741 }
2742 }
2743 if (!var->is_const) {
2744 diagnostics_.add_error(diag::System::Writer, "Expected a const value");
2745 return false;
2746 }
2747
2748 auto out = line();
2749 out << "constant ";
2750 auto* type = program_->Sem().Get(var)->Type()->UnwrapRef();
2751 if (!EmitType(out, type, program_->Symbols().NameFor(var->symbol))) {
2752 return false;
2753 }
2754 if (!type->Is<sem::Array>()) {
2755 out << " " << program_->Symbols().NameFor(var->symbol);
2756 }
2757
2758 auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
2759 if (global && global->IsOverridable()) {
2760 out << " [[function_constant(" << global->ConstantId() << ")]]";
2761 } else if (var->constructor != nullptr) {
2762 out << " = ";
2763 if (!EmitExpression(out, var->constructor)) {
2764 return false;
2765 }
2766 }
2767 out << ";";
2768
2769 return true;
2770 }
2771
MslPackedTypeSizeAndAlign(const sem::Type * ty)2772 GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(
2773 const sem::Type* ty) {
2774 if (ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
2775 // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
2776 // 2.1 Scalar Data Types
2777 return {4, 4};
2778 }
2779
2780 if (auto* vec = ty->As<sem::Vector>()) {
2781 auto num_els = vec->Width();
2782 auto* el_ty = vec->type();
2783 if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
2784 // Use a packed_vec type for 3-element vectors only.
2785 if (num_els == 3) {
2786 // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
2787 // 2.2.3 Packed Vector Types
2788 return SizeAndAlign{num_els * 4, 4};
2789 } else {
2790 // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
2791 // 2.2 Vector Data Types
2792 return SizeAndAlign{num_els * 4, num_els * 4};
2793 }
2794 }
2795 }
2796
2797 if (auto* mat = ty->As<sem::Matrix>()) {
2798 // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
2799 // 2.3 Matrix Data Types
2800 auto cols = mat->columns();
2801 auto rows = mat->rows();
2802 auto* el_ty = mat->type();
2803 if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
2804 static constexpr SizeAndAlign table[] = {
2805 /* float2x2 */ {16, 8},
2806 /* float2x3 */ {32, 16},
2807 /* float2x4 */ {32, 16},
2808 /* float3x2 */ {24, 8},
2809 /* float3x3 */ {48, 16},
2810 /* float3x4 */ {48, 16},
2811 /* float4x2 */ {32, 8},
2812 /* float4x3 */ {64, 16},
2813 /* float4x4 */ {64, 16},
2814 };
2815 if (cols >= 2 && cols <= 4 && rows >= 2 && rows <= 4) {
2816 return table[(3 * (cols - 2)) + (rows - 2)];
2817 }
2818 }
2819 }
2820
2821 if (auto* arr = ty->As<sem::Array>()) {
2822 if (!arr->IsStrideImplicit()) {
2823 TINT_ICE(Writer, diagnostics_)
2824 << "arrays with explicit strides should have "
2825 "removed with the PadArrayElements transform";
2826 return {};
2827 }
2828 auto num_els = std::max<uint32_t>(arr->Count(), 1);
2829 return SizeAndAlign{arr->Stride() * num_els, arr->Align()};
2830 }
2831
2832 if (auto* str = ty->As<sem::Struct>()) {
2833 // TODO(crbug.com/tint/650): There's an assumption here that MSL's default
2834 // structure size and alignment matches WGSL's. We need to confirm this.
2835 return SizeAndAlign{str->Size(), str->Align()};
2836 }
2837
2838 if (auto* atomic = ty->As<sem::Atomic>()) {
2839 return MslPackedTypeSizeAndAlign(atomic->Type());
2840 }
2841
2842 TINT_UNREACHABLE(Writer, diagnostics_)
2843 << "Unhandled type " << ty->TypeInfo().name;
2844 return {};
2845 }
2846
2847 template <typename F>
CallIntrinsicHelper(std::ostream & out,const ast::CallExpression * call,const sem::Intrinsic * intrinsic,F && build)2848 bool GeneratorImpl::CallIntrinsicHelper(std::ostream& out,
2849 const ast::CallExpression* call,
2850 const sem::Intrinsic* intrinsic,
2851 F&& build) {
2852 // Generate the helper function if it hasn't been created already
2853 auto fn = utils::GetOrCreate(intrinsics_, intrinsic, [&]() -> std::string {
2854 TextBuffer b;
2855 TINT_DEFER(helpers_.Append(b));
2856
2857 auto fn_name =
2858 UniqueIdentifier(std::string("tint_") + sem::str(intrinsic->Type()));
2859 std::vector<std::string> parameter_names;
2860 {
2861 auto decl = line(&b);
2862 if (!EmitTypeAndName(decl, intrinsic->ReturnType(), fn_name)) {
2863 return "";
2864 }
2865 {
2866 ScopedParen sp(decl);
2867 for (auto* param : intrinsic->Parameters()) {
2868 if (!parameter_names.empty()) {
2869 decl << ", ";
2870 }
2871 auto param_name = "param_" + std::to_string(parameter_names.size());
2872 if (!EmitTypeAndName(decl, param->Type(), param_name)) {
2873 return "";
2874 }
2875 parameter_names.emplace_back(std::move(param_name));
2876 }
2877 }
2878 decl << " {";
2879 }
2880 {
2881 ScopedIndent si(&b);
2882 if (!build(&b, parameter_names)) {
2883 return "";
2884 }
2885 }
2886 line(&b) << "}";
2887 line(&b);
2888 return fn_name;
2889 });
2890
2891 if (fn.empty()) {
2892 return false;
2893 }
2894
2895 // Call the helper
2896 out << fn;
2897 {
2898 ScopedParen sp(out);
2899 bool first = true;
2900 for (auto* arg : call->args) {
2901 if (!first) {
2902 out << ", ";
2903 }
2904 first = false;
2905 if (!EmitExpression(out, arg)) {
2906 return false;
2907 }
2908 }
2909 }
2910 return true;
2911 }
2912
2913 } // namespace msl
2914 } // namespace writer
2915 } // namespace tint
2916