• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /// Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/writer/hlsl/generator_impl.h"
16 
17 #include <algorithm>
18 #include <cmath>
19 #include <iomanip>
20 #include <set>
21 #include <utility>
22 #include <vector>
23 
24 #include "src/ast/call_statement.h"
25 #include "src/ast/fallthrough_statement.h"
26 #include "src/ast/internal_decoration.h"
27 #include "src/ast/interpolate_decoration.h"
28 #include "src/ast/override_decoration.h"
29 #include "src/ast/variable_decl_statement.h"
30 #include "src/debug.h"
31 #include "src/sem/array.h"
32 #include "src/sem/atomic_type.h"
33 #include "src/sem/block_statement.h"
34 #include "src/sem/call.h"
35 #include "src/sem/depth_multisampled_texture_type.h"
36 #include "src/sem/depth_texture_type.h"
37 #include "src/sem/function.h"
38 #include "src/sem/member_accessor_expression.h"
39 #include "src/sem/multisampled_texture_type.h"
40 #include "src/sem/sampled_texture_type.h"
41 #include "src/sem/statement.h"
42 #include "src/sem/storage_texture_type.h"
43 #include "src/sem/struct.h"
44 #include "src/sem/type_constructor.h"
45 #include "src/sem/type_conversion.h"
46 #include "src/sem/variable.h"
47 #include "src/transform/add_empty_entry_point.h"
48 #include "src/transform/array_length_from_uniform.h"
49 #include "src/transform/calculate_array_length.h"
50 #include "src/transform/canonicalize_entry_point_io.h"
51 #include "src/transform/decompose_memory_access.h"
52 #include "src/transform/external_texture_transform.h"
53 #include "src/transform/fold_trivial_single_use_lets.h"
54 #include "src/transform/loop_to_for_loop.h"
55 #include "src/transform/manager.h"
56 #include "src/transform/num_workgroups_from_uniform.h"
57 #include "src/transform/pad_array_elements.h"
58 #include "src/transform/promote_initializers_to_const_var.h"
59 #include "src/transform/remove_phonies.h"
60 #include "src/transform/simplify_pointers.h"
61 #include "src/transform/unshadow.h"
62 #include "src/transform/zero_init_workgroup_memory.h"
63 #include "src/utils/defer.h"
64 #include "src/utils/map.h"
65 #include "src/utils/scoped_assignment.h"
66 #include "src/writer/append_vector.h"
67 #include "src/writer/float_to_string.h"
68 
69 namespace tint {
70 namespace writer {
71 namespace hlsl {
72 namespace {
73 
// Prefix passed to UniqueIdentifier() when the writer needs a temporary
// variable (e.g. the short-circuit lowering of && / || in EmitBinary).
constexpr char kTempNamePrefix[] = "tint_tmp";
// Prefix for specialization-constant identifiers; NOTE(review): its users are
// outside this view.
constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
76 
image_format_to_rwtexture_type(ast::ImageFormat image_format)77 const char* image_format_to_rwtexture_type(ast::ImageFormat image_format) {
78   switch (image_format) {
79     case ast::ImageFormat::kRgba8Unorm:
80     case ast::ImageFormat::kRgba8Snorm:
81     case ast::ImageFormat::kRgba16Float:
82     case ast::ImageFormat::kR32Float:
83     case ast::ImageFormat::kRg32Float:
84     case ast::ImageFormat::kRgba32Float:
85       return "float4";
86     case ast::ImageFormat::kRgba8Uint:
87     case ast::ImageFormat::kRgba16Uint:
88     case ast::ImageFormat::kR32Uint:
89     case ast::ImageFormat::kRg32Uint:
90     case ast::ImageFormat::kRgba32Uint:
91       return "uint4";
92     case ast::ImageFormat::kRgba8Sint:
93     case ast::ImageFormat::kRgba16Sint:
94     case ast::ImageFormat::kR32Sint:
95     case ast::ImageFormat::kRg32Sint:
96     case ast::ImageFormat::kRgba32Sint:
97       return "int4";
98     default:
99       return nullptr;
100   }
101 }
102 
103 // Helper for writing " : register(RX, spaceY)", where R is the register, X is
104 // the binding point binding value, and Y is the binding point group value.
105 struct RegisterAndSpace {
RegisterAndSpacetint::writer::hlsl::__anon72f09d1b0111::RegisterAndSpace106   RegisterAndSpace(char r, ast::VariableBindingPoint bp)
107       : reg(r), binding_point(bp) {}
108 
109   const char reg;
110   ast::VariableBindingPoint const binding_point;
111 };
112 
operator <<(std::ostream & s,const RegisterAndSpace & rs)113 std::ostream& operator<<(std::ostream& s, const RegisterAndSpace& rs) {
114   s << " : register(" << rs.reg << rs.binding_point.binding->value << ", space"
115     << rs.binding_point.group->value << ")";
116   return s;
117 }
118 
/// @returns the attribute prefix that forbids loop unrolling. FXC can attempt,
/// and fail, to unroll loops containing gradient operations, so unrolling is
/// disabled with `[loop]`.
/// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-while
const char* LoopAttribute() {
  static constexpr char kNoUnroll[] = "[loop] ";
  return kNoUnroll;
}
125 
126 }  // namespace
127 
128 SanitizedResult::SanitizedResult() = default;
129 SanitizedResult::~SanitizedResult() = default;
130 SanitizedResult::SanitizedResult(SanitizedResult&&) = default;
131 
Sanitize(const Program * in,sem::BindingPoint root_constant_binding_point,bool disable_workgroup_init,const ArrayLengthFromUniformOptions & array_length_from_uniform)132 SanitizedResult Sanitize(
133     const Program* in,
134     sem::BindingPoint root_constant_binding_point,
135     bool disable_workgroup_init,
136     const ArrayLengthFromUniformOptions& array_length_from_uniform) {
137   transform::Manager manager;
138   transform::DataMap data;
139 
140   // Build the config for the internal ArrayLengthFromUniform transform.
141   transform::ArrayLengthFromUniform::Config array_length_from_uniform_cfg(
142       array_length_from_uniform.ubo_binding);
143   array_length_from_uniform_cfg.bindpoint_to_size_index =
144       array_length_from_uniform.bindpoint_to_size_index;
145 
146   manager.Add<transform::Unshadow>();
147 
148   // Attempt to convert `loop`s into for-loops. This is to try and massage the
149   // output into something that will not cause FXC to choke or misbehave.
150   manager.Add<transform::FoldTrivialSingleUseLets>();
151   manager.Add<transform::LoopToForLoop>();
152 
153   if (!disable_workgroup_init) {
154     // ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as
155     // ZeroInitWorkgroupMemory may inject new builtin parameters.
156     manager.Add<transform::ZeroInitWorkgroupMemory>();
157   }
158   manager.Add<transform::CanonicalizeEntryPointIO>();
159   // NumWorkgroupsFromUniform must come after CanonicalizeEntryPointIO, as it
160   // assumes that num_workgroups builtins only appear as struct members and are
161   // only accessed directly via member accessors.
162   manager.Add<transform::NumWorkgroupsFromUniform>();
163   manager.Add<transform::SimplifyPointers>();
164   manager.Add<transform::RemovePhonies>();
165   // ArrayLengthFromUniform must come after InlinePointerLets and Simplify, as
166   // it assumes that the form of the array length argument is &var.array.
167   manager.Add<transform::ArrayLengthFromUniform>();
168   data.Add<transform::ArrayLengthFromUniform::Config>(
169       std::move(array_length_from_uniform_cfg));
170   // DecomposeMemoryAccess must come after:
171   // * InlinePointerLets, as we cannot take the address of calls to
172   //   DecomposeMemoryAccess::Intrinsic.
173   // * Simplify, as we need to fold away the address-of and dereferences of
174   // `*(&(intrinsic_load()))` expressions.
175   // * RemovePhonies, as phonies can be assigned a pointer to a
176   //   non-constructible buffer, or dynamic array, which DMA cannot cope with.
177   manager.Add<transform::DecomposeMemoryAccess>();
178   // CalculateArrayLength must come after DecomposeMemoryAccess, as
179   // DecomposeMemoryAccess special-cases the arrayLength() intrinsic, which
180   // will be transformed by CalculateArrayLength
181   manager.Add<transform::CalculateArrayLength>();
182   manager.Add<transform::ExternalTextureTransform>();
183   manager.Add<transform::PromoteInitializersToConstVar>();
184   manager.Add<transform::PadArrayElements>();
185   manager.Add<transform::AddEmptyEntryPoint>();
186 
187   data.Add<transform::CanonicalizeEntryPointIO::Config>(
188       transform::CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
189   data.Add<transform::NumWorkgroupsFromUniform::Config>(
190       root_constant_binding_point);
191 
192   auto out = manager.Run(in, data);
193 
194   SanitizedResult result;
195   result.program = std::move(out.program);
196   if (auto* res = out.data.Get<transform::ArrayLengthFromUniform::Result>()) {
197     result.used_array_length_from_uniform_indices =
198         std::move(res->used_size_indices);
199   }
200   return result;
201 }
202 
GeneratorImpl(const Program * program)203 GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {}
204 
205 GeneratorImpl::~GeneratorImpl() = default;
206 
Generate()207 bool GeneratorImpl::Generate() {
208   const TypeInfo* last_kind = nullptr;
209   size_t last_padding_line = 0;
210 
211   for (auto* decl : builder_.AST().GlobalDeclarations()) {
212     if (decl->Is<ast::Alias>()) {
213       continue;  // Ignore aliases.
214     }
215 
216     // Emit a new line between declarations if the type of declaration has
217     // changed, or we're about to emit a function
218     auto* kind = &decl->TypeInfo();
219     if (current_buffer_->lines.size() != last_padding_line) {
220       if (last_kind && (last_kind != kind || decl->Is<ast::Function>())) {
221         line();
222         last_padding_line = current_buffer_->lines.size();
223       }
224     }
225     last_kind = kind;
226 
227     if (auto* global = decl->As<ast::Variable>()) {
228       if (!EmitGlobalVariable(global)) {
229         return false;
230       }
231     } else if (auto* str = decl->As<ast::Struct>()) {
232       auto* ty = builder_.Sem().Get(str);
233       auto storage_class_uses = ty->StorageClassUsage();
234       if (storage_class_uses.size() !=
235           (storage_class_uses.count(ast::StorageClass::kStorage) +
236            storage_class_uses.count(ast::StorageClass::kUniform))) {
237         // The structure is used as something other than a storage buffer or
238         // uniform buffer, so it needs to be emitted.
239         // Storage buffer are read and written to via a ByteAddressBuffer
240         // instead of true structure.
241         // Structures used as uniform buffer are read from an array of vectors
242         // instead of true structure.
243         if (!EmitStructType(current_buffer_, ty)) {
244           return false;
245         }
246       }
247     } else if (auto* func = decl->As<ast::Function>()) {
248       if (func->IsEntryPoint()) {
249         if (!EmitEntryPointFunction(func)) {
250           return false;
251         }
252       } else {
253         if (!EmitFunction(func)) {
254           return false;
255         }
256       }
257     } else {
258       TINT_ICE(Writer, diagnostics_)
259           << "unhandled module-scope declaration: " << decl->TypeInfo().name;
260       return false;
261     }
262   }
263 
264   if (!helpers_.lines.empty()) {
265     current_buffer_->Insert(helpers_, 0, 0);
266   }
267 
268   return true;
269 }
270 
EmitDynamicVectorAssignment(const ast::AssignmentStatement * stmt,const sem::Vector * vec)271 bool GeneratorImpl::EmitDynamicVectorAssignment(
272     const ast::AssignmentStatement* stmt,
273     const sem::Vector* vec) {
274   auto name =
275       utils::GetOrCreate(dynamic_vector_write_, vec, [&]() -> std::string {
276         std::string fn;
277         {
278           std::ostringstream ss;
279           if (!EmitType(ss, vec, tint::ast::StorageClass::kInvalid,
280                         ast::Access::kUndefined, "")) {
281             return "";
282           }
283           fn = UniqueIdentifier("set_" + ss.str());
284         }
285         {
286           auto out = line(&helpers_);
287           out << "void " << fn << "(inout ";
288           if (!EmitTypeAndName(out, vec, ast::StorageClass::kInvalid,
289                                ast::Access::kUndefined, "vec")) {
290             return "";
291           }
292           out << ", int idx, ";
293           if (!EmitTypeAndName(out, vec->type(), ast::StorageClass::kInvalid,
294                                ast::Access::kUndefined, "val")) {
295             return "";
296           }
297           out << ") {";
298         }
299         {
300           ScopedIndent si(&helpers_);
301           auto out = line(&helpers_);
302           switch (vec->Width()) {
303             case 2:
304               out << "vec = (idx.xx == int2(0, 1)) ? val.xx : vec;";
305               break;
306             case 3:
307               out << "vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;";
308               break;
309             case 4:
310               out << "vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;";
311               break;
312             default:
313               TINT_UNREACHABLE(Writer, builder_.Diagnostics())
314                   << "invalid vector size " << vec->Width();
315               break;
316           }
317         }
318         line(&helpers_) << "}";
319         line(&helpers_);
320         return fn;
321       });
322 
323   if (name.empty()) {
324     return false;
325   }
326 
327   auto* ast_access_expr = stmt->lhs->As<ast::IndexAccessorExpression>();
328 
329   auto out = line();
330   out << name << "(";
331   if (!EmitExpression(out, ast_access_expr->object)) {
332     return false;
333   }
334   out << ", ";
335   if (!EmitExpression(out, ast_access_expr->index)) {
336     return false;
337   }
338   out << ", ";
339   if (!EmitExpression(out, stmt->rhs)) {
340     return false;
341   }
342   out << ");";
343 
344   return true;
345 }
346 
EmitDynamicMatrixVectorAssignment(const ast::AssignmentStatement * stmt,const sem::Matrix * mat)347 bool GeneratorImpl::EmitDynamicMatrixVectorAssignment(
348     const ast::AssignmentStatement* stmt,
349     const sem::Matrix* mat) {
350   auto name = utils::GetOrCreate(
351       dynamic_matrix_vector_write_, mat, [&]() -> std::string {
352         std::string fn;
353         {
354           std::ostringstream ss;
355           if (!EmitType(ss, mat, tint::ast::StorageClass::kInvalid,
356                         ast::Access::kUndefined, "")) {
357             return "";
358           }
359           fn = UniqueIdentifier("set_vector_" + ss.str());
360         }
361         {
362           auto out = line(&helpers_);
363           out << "void " << fn << "(inout ";
364           if (!EmitTypeAndName(out, mat, ast::StorageClass::kInvalid,
365                                ast::Access::kUndefined, "mat")) {
366             return "";
367           }
368           out << ", int col, ";
369           if (!EmitTypeAndName(out, mat->ColumnType(),
370                                ast::StorageClass::kInvalid,
371                                ast::Access::kUndefined, "val")) {
372             return "";
373           }
374           out << ") {";
375         }
376         {
377           ScopedIndent si(&helpers_);
378           line(&helpers_) << "switch (col) {";
379           {
380             ScopedIndent si2(&helpers_);
381             for (uint32_t i = 0; i < mat->columns(); ++i) {
382               line(&helpers_)
383                   << "case " << i << ": mat[" << i << "] = val; break;";
384             }
385           }
386           line(&helpers_) << "}";
387         }
388         line(&helpers_) << "}";
389         line(&helpers_);
390         return fn;
391       });
392 
393   if (name.empty()) {
394     return false;
395   }
396 
397   auto* ast_access_expr = stmt->lhs->As<ast::IndexAccessorExpression>();
398 
399   auto out = line();
400   out << name << "(";
401   if (!EmitExpression(out, ast_access_expr->object)) {
402     return false;
403   }
404   out << ", ";
405   if (!EmitExpression(out, ast_access_expr->index)) {
406     return false;
407   }
408   out << ", ";
409   if (!EmitExpression(out, stmt->rhs)) {
410     return false;
411   }
412   out << ");";
413 
414   return true;
415 }
416 
EmitDynamicMatrixScalarAssignment(const ast::AssignmentStatement * stmt,const sem::Matrix * mat)417 bool GeneratorImpl::EmitDynamicMatrixScalarAssignment(
418     const ast::AssignmentStatement* stmt,
419     const sem::Matrix* mat) {
420   auto* lhs_col_access = stmt->lhs->As<ast::IndexAccessorExpression>();
421   auto* lhs_row_access =
422       lhs_col_access->object->As<ast::IndexAccessorExpression>();
423 
424   auto name = utils::GetOrCreate(
425       dynamic_matrix_scalar_write_, mat, [&]() -> std::string {
426         std::string fn;
427         {
428           std::ostringstream ss;
429           if (!EmitType(ss, mat, tint::ast::StorageClass::kInvalid,
430                         ast::Access::kUndefined, "")) {
431             return "";
432           }
433           fn = UniqueIdentifier("set_scalar_" + ss.str());
434         }
435         {
436           auto out = line(&helpers_);
437           out << "void " << fn << "(inout ";
438           if (!EmitTypeAndName(out, mat, ast::StorageClass::kInvalid,
439                                ast::Access::kUndefined, "mat")) {
440             return "";
441           }
442           out << ", int col, int row, ";
443           if (!EmitTypeAndName(out, mat->type(), ast::StorageClass::kInvalid,
444                                ast::Access::kUndefined, "val")) {
445             return "";
446           }
447           out << ") {";
448         }
449         {
450           ScopedIndent si(&helpers_);
451           line(&helpers_) << "switch (col) {";
452           {
453             ScopedIndent si2(&helpers_);
454             auto* vec =
455                 TypeOf(lhs_row_access->object)->UnwrapRef()->As<sem::Vector>();
456             for (uint32_t i = 0; i < mat->columns(); ++i) {
457               line(&helpers_) << "case " << i << ":";
458               {
459                 auto vec_name = "mat[" + std::to_string(i) + "]";
460                 ScopedIndent si3(&helpers_);
461                 {
462                   auto out = line(&helpers_);
463                   switch (mat->rows()) {
464                     case 2:
465                       out << vec_name
466                           << " = (row.xx == int2(0, 1)) ? val.xx : " << vec_name
467                           << ";";
468                       break;
469                     case 3:
470                       out << vec_name
471                           << " = (row.xxx == int3(0, 1, 2)) ? val.xxx : "
472                           << vec_name << ";";
473                       break;
474                     case 4:
475                       out << vec_name
476                           << " = (row.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : "
477                           << vec_name << ";";
478                       break;
479                     default:
480                       TINT_UNREACHABLE(Writer, builder_.Diagnostics())
481                           << "invalid vector size " << vec->Width();
482                       break;
483                   }
484                 }
485                 line(&helpers_) << "break;";
486               }
487             }
488           }
489           line(&helpers_) << "}";
490         }
491         line(&helpers_) << "}";
492         line(&helpers_);
493         return fn;
494       });
495 
496   if (name.empty()) {
497     return false;
498   }
499 
500   auto out = line();
501   out << name << "(";
502   if (!EmitExpression(out, lhs_row_access->object)) {
503     return false;
504   }
505   out << ", ";
506   if (!EmitExpression(out, lhs_col_access->index)) {
507     return false;
508   }
509   out << ", ";
510   if (!EmitExpression(out, lhs_row_access->index)) {
511     return false;
512   }
513   out << ", ";
514   if (!EmitExpression(out, stmt->rhs)) {
515     return false;
516   }
517   out << ");";
518 
519   return true;
520 }
521 
EmitIndexAccessor(std::ostream & out,const ast::IndexAccessorExpression * expr)522 bool GeneratorImpl::EmitIndexAccessor(
523     std::ostream& out,
524     const ast::IndexAccessorExpression* expr) {
525   if (!EmitExpression(out, expr->object)) {
526     return false;
527   }
528   out << "[";
529 
530   if (!EmitExpression(out, expr->index)) {
531     return false;
532   }
533   out << "]";
534 
535   return true;
536 }
537 
EmitBitcast(std::ostream & out,const ast::BitcastExpression * expr)538 bool GeneratorImpl::EmitBitcast(std::ostream& out,
539                                 const ast::BitcastExpression* expr) {
540   auto* type = TypeOf(expr);
541   if (auto* vec = type->UnwrapRef()->As<sem::Vector>()) {
542     type = vec->type();
543   }
544 
545   if (!type->is_integer_scalar() && !type->is_float_scalar()) {
546     diagnostics_.add_error(diag::System::Writer,
547                            "Unable to do bitcast to type " + type->type_name());
548     return false;
549   }
550 
551   out << "as";
552   if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
553                 "")) {
554     return false;
555   }
556   out << "(";
557   if (!EmitExpression(out, expr->expr)) {
558     return false;
559   }
560   out << ")";
561   return true;
562 }
563 
EmitAssign(const ast::AssignmentStatement * stmt)564 bool GeneratorImpl::EmitAssign(const ast::AssignmentStatement* stmt) {
565   if (auto* lhs_access = stmt->lhs->As<ast::IndexAccessorExpression>()) {
566     // BUG(crbug.com/tint/1333): work around assignment of scalar to matrices
567     // with at least one dynamic index
568     if (auto* lhs_sub_access =
569             lhs_access->object->As<ast::IndexAccessorExpression>()) {
570       if (auto* mat =
571               TypeOf(lhs_sub_access->object)->UnwrapRef()->As<sem::Matrix>()) {
572         auto* rhs_col_idx_sem = builder_.Sem().Get(lhs_access->index);
573         auto* rhs_row_idx_sem = builder_.Sem().Get(lhs_sub_access->index);
574         if (!rhs_col_idx_sem->ConstantValue().IsValid() ||
575             !rhs_row_idx_sem->ConstantValue().IsValid()) {
576           return EmitDynamicMatrixScalarAssignment(stmt, mat);
577         }
578       }
579     }
580     // BUG(crbug.com/tint/1333): work around assignment of vector to matrices
581     // with dynamic indices
582     const auto* lhs_access_type = TypeOf(lhs_access->object)->UnwrapRef();
583     if (auto* mat = lhs_access_type->As<sem::Matrix>()) {
584       auto* lhs_index_sem = builder_.Sem().Get(lhs_access->index);
585       if (!lhs_index_sem->ConstantValue().IsValid()) {
586         return EmitDynamicMatrixVectorAssignment(stmt, mat);
587       }
588     }
589     // BUG(crbug.com/tint/534): work around assignment to vectors with dynamic
590     // indices
591     if (auto* vec = lhs_access_type->As<sem::Vector>()) {
592       auto* rhs_sem = builder_.Sem().Get(lhs_access->index);
593       if (!rhs_sem->ConstantValue().IsValid()) {
594         return EmitDynamicVectorAssignment(stmt, vec);
595       }
596     }
597   }
598 
599   auto out = line();
600   if (!EmitExpression(out, stmt->lhs)) {
601     return false;
602   }
603   out << " = ";
604   if (!EmitExpression(out, stmt->rhs)) {
605     return false;
606   }
607   out << ";";
608   return true;
609 }
610 
EmitBinary(std::ostream & out,const ast::BinaryExpression * expr)611 bool GeneratorImpl::EmitBinary(std::ostream& out,
612                                const ast::BinaryExpression* expr) {
613   if (expr->op == ast::BinaryOp::kLogicalAnd ||
614       expr->op == ast::BinaryOp::kLogicalOr) {
615     auto name = UniqueIdentifier(kTempNamePrefix);
616 
617     {
618       auto pre = line();
619       pre << "bool " << name << " = ";
620       if (!EmitExpression(pre, expr->lhs)) {
621         return false;
622       }
623       pre << ";";
624     }
625 
626     if (expr->op == ast::BinaryOp::kLogicalOr) {
627       line() << "if (!" << name << ") {";
628     } else {
629       line() << "if (" << name << ") {";
630     }
631 
632     {
633       ScopedIndent si(this);
634       auto pre = line();
635       pre << name << " = ";
636       if (!EmitExpression(pre, expr->rhs)) {
637         return false;
638       }
639       pre << ";";
640     }
641 
642     line() << "}";
643 
644     out << "(" << name << ")";
645     return true;
646   }
647 
648   auto* lhs_type = TypeOf(expr->lhs)->UnwrapRef();
649   auto* rhs_type = TypeOf(expr->rhs)->UnwrapRef();
650   // Multiplying by a matrix requires the use of `mul` in order to get the
651   // type of multiply we desire.
652   if (expr->op == ast::BinaryOp::kMultiply &&
653       ((lhs_type->Is<sem::Vector>() && rhs_type->Is<sem::Matrix>()) ||
654        (lhs_type->Is<sem::Matrix>() && rhs_type->Is<sem::Vector>()) ||
655        (lhs_type->Is<sem::Matrix>() && rhs_type->Is<sem::Matrix>()))) {
656     // Matrices are transposed, so swap LHS and RHS.
657     out << "mul(";
658     if (!EmitExpression(out, expr->rhs)) {
659       return false;
660     }
661     out << ", ";
662     if (!EmitExpression(out, expr->lhs)) {
663       return false;
664     }
665     out << ")";
666 
667     return true;
668   }
669 
670   out << "(";
671   TINT_DEFER(out << ")");
672 
673   if (!EmitExpression(out, expr->lhs)) {
674     return false;
675   }
676   out << " ";
677 
678   switch (expr->op) {
679     case ast::BinaryOp::kAnd:
680       out << "&";
681       break;
682     case ast::BinaryOp::kOr:
683       out << "|";
684       break;
685     case ast::BinaryOp::kXor:
686       out << "^";
687       break;
688     case ast::BinaryOp::kLogicalAnd:
689     case ast::BinaryOp::kLogicalOr: {
690       // These are both handled above.
691       TINT_UNREACHABLE(Writer, diagnostics_);
692       return false;
693     }
694     case ast::BinaryOp::kEqual:
695       out << "==";
696       break;
697     case ast::BinaryOp::kNotEqual:
698       out << "!=";
699       break;
700     case ast::BinaryOp::kLessThan:
701       out << "<";
702       break;
703     case ast::BinaryOp::kGreaterThan:
704       out << ">";
705       break;
706     case ast::BinaryOp::kLessThanEqual:
707       out << "<=";
708       break;
709     case ast::BinaryOp::kGreaterThanEqual:
710       out << ">=";
711       break;
712     case ast::BinaryOp::kShiftLeft:
713       out << "<<";
714       break;
715     case ast::BinaryOp::kShiftRight:
716       // TODO(dsinclair): MSL is based on C++14, and >> in C++14 has
717       // implementation-defined behaviour for negative LHS.  We may have to
718       // generate extra code to implement WGSL-specified behaviour for negative
719       // LHS.
720       out << R"(>>)";
721       break;
722 
723     case ast::BinaryOp::kAdd:
724       out << "+";
725       break;
726     case ast::BinaryOp::kSubtract:
727       out << "-";
728       break;
729     case ast::BinaryOp::kMultiply:
730       out << "*";
731       break;
732     case ast::BinaryOp::kDivide:
733       out << "/";
734 
735       if (auto val = builder_.Sem().Get(expr->rhs)->ConstantValue()) {
736         // Integer divide by zero is a DXC compile error, and undefined behavior
737         // in WGSL. Replace the 0 with 1.
738         if (val.Type()->Is<sem::I32>() && val.Elements()[0].i32 == 0) {
739           out << " 1";
740           return true;
741         }
742         if (val.Type()->Is<sem::U32>() && val.Elements()[0].u32 == 0u) {
743           out << " 1u";
744           return true;
745         }
746       }
747       break;
748     case ast::BinaryOp::kModulo:
749       out << "%";
750       break;
751     case ast::BinaryOp::kNone:
752       diagnostics_.add_error(diag::System::Writer,
753                              "missing binary operation type");
754       return false;
755   }
756   out << " ";
757 
758   if (!EmitExpression(out, expr->rhs)) {
759     return false;
760   }
761 
762   return true;
763 }
764 
EmitStatements(const ast::StatementList & stmts)765 bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {
766   for (auto* s : stmts) {
767     if (!EmitStatement(s)) {
768       return false;
769     }
770   }
771   return true;
772 }
773 
EmitStatementsWithIndent(const ast::StatementList & stmts)774 bool GeneratorImpl::EmitStatementsWithIndent(const ast::StatementList& stmts) {
775   ScopedIndent si(this);
776   return EmitStatements(stmts);
777 }
778 
EmitBlock(const ast::BlockStatement * stmt)779 bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) {
780   line() << "{";
781   if (!EmitStatementsWithIndent(stmt->statements)) {
782     return false;
783   }
784   line() << "}";
785   return true;
786 }
787 
EmitBreak(const ast::BreakStatement *)788 bool GeneratorImpl::EmitBreak(const ast::BreakStatement*) {
789   line() << "break;";
790   return true;
791 }
792 
EmitCall(std::ostream & out,const ast::CallExpression * expr)793 bool GeneratorImpl::EmitCall(std::ostream& out,
794                              const ast::CallExpression* expr) {
795   auto* call = builder_.Sem().Get(expr);
796   auto* target = call->Target();
797 
798   if (auto* func = target->As<sem::Function>()) {
799     return EmitFunctionCall(out, call, func);
800   }
801   if (auto* intrinsic = target->As<sem::Intrinsic>()) {
802     return EmitIntrinsicCall(out, call, intrinsic);
803   }
804   if (auto* conv = target->As<sem::TypeConversion>()) {
805     return EmitTypeConversion(out, call, conv);
806   }
807   if (auto* ctor = target->As<sem::TypeConstructor>()) {
808     return EmitTypeConstructor(out, call, ctor);
809   }
810   TINT_ICE(Writer, diagnostics_)
811       << "unhandled call target: " << target->TypeInfo().name;
812   return false;
813 }
814 
EmitFunctionCall(std::ostream & out,const sem::Call * call,const sem::Function * func)815 bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
816                                      const sem::Call* call,
817                                      const sem::Function* func) {
818   auto* expr = call->Declaration();
819 
820   if (ast::HasDecoration<transform::CalculateArrayLength::BufferSizeIntrinsic>(
821           func->Declaration()->decorations)) {
822     // Special function generated by the CalculateArrayLength transform for
823     // calling X.GetDimensions(Y)
824     if (!EmitExpression(out, call->Arguments()[0]->Declaration())) {
825       return false;
826     }
827     out << ".GetDimensions(";
828     if (!EmitExpression(out, call->Arguments()[1]->Declaration())) {
829       return false;
830     }
831     out << ")";
832     return true;
833   }
834 
835   if (auto* intrinsic =
836           ast::GetDecoration<transform::DecomposeMemoryAccess::Intrinsic>(
837               func->Declaration()->decorations)) {
838     switch (intrinsic->storage_class) {
839       case ast::StorageClass::kUniform:
840         return EmitUniformBufferAccess(out, expr, intrinsic);
841       case ast::StorageClass::kStorage:
842         return EmitStorageBufferAccess(out, expr, intrinsic);
843       default:
844         TINT_UNREACHABLE(Writer, diagnostics_)
845             << "unsupported DecomposeMemoryAccess::Intrinsic storage class:"
846             << intrinsic->storage_class;
847         return false;
848     }
849   }
850 
851   out << builder_.Symbols().NameFor(func->Declaration()->symbol) << "(";
852 
853   bool first = true;
854   for (auto* arg : call->Arguments()) {
855     if (!first) {
856       out << ", ";
857     }
858     first = false;
859 
860     if (!EmitExpression(out, arg->Declaration())) {
861       return false;
862     }
863   }
864 
865   out << ")";
866   return true;
867 }
868 
EmitIntrinsicCall(std::ostream & out,const sem::Call * call,const sem::Intrinsic * intrinsic)869 bool GeneratorImpl::EmitIntrinsicCall(std::ostream& out,
870                                       const sem::Call* call,
871                                       const sem::Intrinsic* intrinsic) {
872   auto* expr = call->Declaration();
873   if (intrinsic->IsTexture()) {
874     return EmitTextureCall(out, call, intrinsic);
875   }
876   if (intrinsic->Type() == sem::IntrinsicType::kSelect) {
877     return EmitSelectCall(out, expr);
878   }
879   if (intrinsic->Type() == sem::IntrinsicType::kModf) {
880     return EmitModfCall(out, expr, intrinsic);
881   }
882   if (intrinsic->Type() == sem::IntrinsicType::kFrexp) {
883     return EmitFrexpCall(out, expr, intrinsic);
884   }
885   if (intrinsic->Type() == sem::IntrinsicType::kIsNormal) {
886     return EmitIsNormalCall(out, expr, intrinsic);
887   }
888   if (intrinsic->Type() == sem::IntrinsicType::kIgnore) {
889     return EmitExpression(out, expr->args[0]);  // [DEPRECATED]
890   }
891   if (intrinsic->IsDataPacking()) {
892     return EmitDataPackingCall(out, expr, intrinsic);
893   }
894   if (intrinsic->IsDataUnpacking()) {
895     return EmitDataUnpackingCall(out, expr, intrinsic);
896   }
897   if (intrinsic->IsBarrier()) {
898     return EmitBarrierCall(out, intrinsic);
899   }
900   if (intrinsic->IsAtomic()) {
901     return EmitWorkgroupAtomicCall(out, expr, intrinsic);
902   }
903 
904   auto name = generate_builtin_name(intrinsic);
905   if (name.empty()) {
906     return false;
907   }
908 
909   out << name << "(";
910 
911   bool first = true;
912   for (auto* arg : call->Arguments()) {
913     if (!first) {
914       out << ", ";
915     }
916     first = false;
917 
918     if (!EmitExpression(out, arg->Declaration())) {
919       return false;
920     }
921   }
922 
923   out << ")";
924   return true;
925 }
926 
EmitTypeConversion(std::ostream & out,const sem::Call * call,const sem::TypeConversion * conv)927 bool GeneratorImpl::EmitTypeConversion(std::ostream& out,
928                                        const sem::Call* call,
929                                        const sem::TypeConversion* conv) {
930   if (!EmitType(out, conv->Target(), ast::StorageClass::kNone,
931                 ast::Access::kReadWrite, "")) {
932     return false;
933   }
934   out << "(";
935 
936   if (!EmitExpression(out, call->Arguments()[0]->Declaration())) {
937     return false;
938   }
939 
940   out << ")";
941   return true;
942 }
943 
EmitTypeConstructor(std::ostream & out,const sem::Call * call,const sem::TypeConstructor * ctor)944 bool GeneratorImpl::EmitTypeConstructor(std::ostream& out,
945                                         const sem::Call* call,
946                                         const sem::TypeConstructor* ctor) {
947   auto* type = call->Type();
948 
949   // If the type constructor is empty then we need to construct with the zero
950   // value for all components.
951   if (call->Arguments().empty()) {
952     return EmitZeroValue(out, type);
953   }
954 
955   bool brackets = type->IsAnyOf<sem::Array, sem::Struct>();
956 
957   // For single-value vector initializers, swizzle the scalar to the right
958   // vector dimension using .x
959   const bool is_single_value_vector_init =
960       type->is_scalar_vector() && call->Arguments().size() == 1 &&
961       ctor->Parameters()[0]->Type()->is_scalar();
962 
963   auto it = structure_builders_.find(As<sem::Struct>(type));
964   if (it != structure_builders_.end()) {
965     out << it->second << "(";
966     brackets = false;
967   } else if (brackets) {
968     out << "{";
969   } else {
970     if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
971                   "")) {
972       return false;
973     }
974     out << "(";
975   }
976 
977   if (is_single_value_vector_init) {
978     out << "(";
979   }
980 
981   bool first = true;
982   for (auto* e : call->Arguments()) {
983     if (!first) {
984       out << ", ";
985     }
986     first = false;
987 
988     if (!EmitExpression(out, e->Declaration())) {
989       return false;
990     }
991   }
992 
993   if (is_single_value_vector_init) {
994     out << ")." << std::string(type->As<sem::Vector>()->Width(), 'x');
995   }
996 
997   out << (brackets ? "}" : ")");
998   return true;
999 }
1000 
EmitUniformBufferAccess(std::ostream & out,const ast::CallExpression * expr,const transform::DecomposeMemoryAccess::Intrinsic * intrinsic)1001 bool GeneratorImpl::EmitUniformBufferAccess(
1002     std::ostream& out,
1003     const ast::CallExpression* expr,
1004     const transform::DecomposeMemoryAccess::Intrinsic* intrinsic) {
1005   const auto& args = expr->args;
1006   auto* offset_arg = builder_.Sem().Get(args[1]);
1007 
1008   uint32_t scalar_offset_value = 0;
1009   std::string scalar_offset_expr;
1010 
1011   // If true, use scalar_offset_value, otherwise use scalar_offset_expr
1012   bool scalar_offset_constant = false;
1013 
1014   if (auto val = offset_arg->ConstantValue()) {
1015     TINT_ASSERT(Writer, val.Type()->Is<sem::U32>());
1016     scalar_offset_value = val.Elements()[0].u32;
1017     scalar_offset_value /= 4;  // bytes -> scalar index
1018     scalar_offset_constant = true;
1019   }
1020 
1021   if (!scalar_offset_constant) {
1022     // UBO offset not compile-time known.
1023     // Calculate the scalar offset into a temporary.
1024     scalar_offset_expr = UniqueIdentifier("scalar_offset");
1025     auto pre = line();
1026     pre << "const uint " << scalar_offset_expr << " = (";
1027     if (!EmitExpression(pre, args[1])) {  // offset
1028       return false;
1029     }
1030     pre << ") / 4;";
1031   }
1032 
1033   using Op = transform::DecomposeMemoryAccess::Intrinsic::Op;
1034   using DataType = transform::DecomposeMemoryAccess::Intrinsic::DataType;
1035   switch (intrinsic->op) {
1036     case Op::kLoad: {
1037       auto cast = [&](const char* to, auto&& load) {
1038         out << to << "(";
1039         auto result = load();
1040         out << ")";
1041         return result;
1042       };
1043       auto load_scalar = [&]() {
1044         if (!EmitExpression(out, args[0])) {  // buffer
1045           return false;
1046         }
1047         if (scalar_offset_constant) {
1048           char swizzle[] = {'x', 'y', 'z', 'w'};
1049           out << "[" << (scalar_offset_value / 4) << "]."
1050               << swizzle[scalar_offset_value & 3];
1051         } else {
1052           out << "[" << scalar_offset_expr << " / 4][" << scalar_offset_expr
1053               << " % 4]";
1054         }
1055         return true;
1056       };
1057       // Has a minimum alignment of 8 bytes, so is either .xy or .zw
1058       auto load_vec2 = [&] {
1059         if (scalar_offset_constant) {
1060           if (!EmitExpression(out, args[0])) {  // buffer
1061             return false;
1062           }
1063           out << "[" << (scalar_offset_value / 4) << "]";
1064           out << ((scalar_offset_value & 2) == 0 ? ".xy" : ".zw");
1065         } else {
1066           std::string ubo_load = UniqueIdentifier("ubo_load");
1067           {
1068             auto pre = line();
1069             pre << "uint4 " << ubo_load << " = ";
1070             if (!EmitExpression(pre, args[0])) {  // buffer
1071               return false;
1072             }
1073             pre << "[" << scalar_offset_expr << " / 4];";
1074           }
1075           out << "((" << scalar_offset_expr << " & 2) ? " << ubo_load
1076               << ".zw : " << ubo_load << ".xy)";
1077         }
1078         return true;
1079       };
1080       // vec4 has a minimum alignment of 16 bytes, easiest case
1081       auto load_vec4 = [&] {
1082         if (!EmitExpression(out, args[0])) {  // buffer
1083           return false;
1084         }
1085         if (scalar_offset_constant) {
1086           out << "[" << (scalar_offset_value / 4) << "]";
1087         } else {
1088           out << "[" << scalar_offset_expr << " / 4]";
1089         }
1090         return true;
1091       };
1092       // vec3 has a minimum alignment of 16 bytes, so is just a .xyz swizzle
1093       auto load_vec3 = [&] {
1094         if (!load_vec4()) {
1095           return false;
1096         }
1097         out << ".xyz";
1098         return true;
1099       };
1100       switch (intrinsic->type) {
1101         case DataType::kU32:
1102           return load_scalar();
1103         case DataType::kF32:
1104           return cast("asfloat", load_scalar);
1105         case DataType::kI32:
1106           return cast("asint", load_scalar);
1107         case DataType::kVec2U32:
1108           return load_vec2();
1109         case DataType::kVec2F32:
1110           return cast("asfloat", load_vec2);
1111         case DataType::kVec2I32:
1112           return cast("asint", load_vec2);
1113         case DataType::kVec3U32:
1114           return load_vec3();
1115         case DataType::kVec3F32:
1116           return cast("asfloat", load_vec3);
1117         case DataType::kVec3I32:
1118           return cast("asint", load_vec3);
1119         case DataType::kVec4U32:
1120           return load_vec4();
1121         case DataType::kVec4F32:
1122           return cast("asfloat", load_vec4);
1123         case DataType::kVec4I32:
1124           return cast("asint", load_vec4);
1125       }
1126       TINT_UNREACHABLE(Writer, diagnostics_)
1127           << "unsupported DecomposeMemoryAccess::Intrinsic::DataType: "
1128           << static_cast<int>(intrinsic->type);
1129       return false;
1130     }
1131     default:
1132       break;
1133   }
1134   TINT_UNREACHABLE(Writer, diagnostics_)
1135       << "unsupported DecomposeMemoryAccess::Intrinsic::Op: "
1136       << static_cast<int>(intrinsic->op);
1137   return false;
1138 }
1139 
EmitStorageBufferAccess(std::ostream & out,const ast::CallExpression * expr,const transform::DecomposeMemoryAccess::Intrinsic * intrinsic)1140 bool GeneratorImpl::EmitStorageBufferAccess(
1141     std::ostream& out,
1142     const ast::CallExpression* expr,
1143     const transform::DecomposeMemoryAccess::Intrinsic* intrinsic) {
1144   const auto& args = expr->args;
1145 
1146   using Op = transform::DecomposeMemoryAccess::Intrinsic::Op;
1147   using DataType = transform::DecomposeMemoryAccess::Intrinsic::DataType;
1148   switch (intrinsic->op) {
1149     case Op::kLoad: {
1150       auto load = [&](const char* cast, int n) {
1151         if (cast) {
1152           out << cast << "(";
1153         }
1154         if (!EmitExpression(out, args[0])) {  // buffer
1155           return false;
1156         }
1157         out << ".Load";
1158         if (n > 1) {
1159           out << n;
1160         }
1161         ScopedParen sp(out);
1162         if (!EmitExpression(out, args[1])) {  // offset
1163           return false;
1164         }
1165         if (cast) {
1166           out << ")";
1167         }
1168         return true;
1169       };
1170       switch (intrinsic->type) {
1171         case DataType::kU32:
1172           return load(nullptr, 1);
1173         case DataType::kF32:
1174           return load("asfloat", 1);
1175         case DataType::kI32:
1176           return load("asint", 1);
1177         case DataType::kVec2U32:
1178           return load(nullptr, 2);
1179         case DataType::kVec2F32:
1180           return load("asfloat", 2);
1181         case DataType::kVec2I32:
1182           return load("asint", 2);
1183         case DataType::kVec3U32:
1184           return load(nullptr, 3);
1185         case DataType::kVec3F32:
1186           return load("asfloat", 3);
1187         case DataType::kVec3I32:
1188           return load("asint", 3);
1189         case DataType::kVec4U32:
1190           return load(nullptr, 4);
1191         case DataType::kVec4F32:
1192           return load("asfloat", 4);
1193         case DataType::kVec4I32:
1194           return load("asint", 4);
1195       }
1196       TINT_UNREACHABLE(Writer, diagnostics_)
1197           << "unsupported DecomposeMemoryAccess::Intrinsic::DataType: "
1198           << static_cast<int>(intrinsic->type);
1199       return false;
1200     }
1201 
1202     case Op::kStore: {
1203       auto store = [&](int n) {
1204         if (!EmitExpression(out, args[0])) {  // buffer
1205           return false;
1206         }
1207         out << ".Store";
1208         if (n > 1) {
1209           out << n;
1210         }
1211         ScopedParen sp1(out);
1212         if (!EmitExpression(out, args[1])) {  // offset
1213           return false;
1214         }
1215         out << ", asuint";
1216         ScopedParen sp2(out);
1217         if (!EmitExpression(out, args[2])) {  // value
1218           return false;
1219         }
1220         return true;
1221       };
1222       switch (intrinsic->type) {
1223         case DataType::kU32:
1224           return store(1);
1225         case DataType::kF32:
1226           return store(1);
1227         case DataType::kI32:
1228           return store(1);
1229         case DataType::kVec2U32:
1230           return store(2);
1231         case DataType::kVec2F32:
1232           return store(2);
1233         case DataType::kVec2I32:
1234           return store(2);
1235         case DataType::kVec3U32:
1236           return store(3);
1237         case DataType::kVec3F32:
1238           return store(3);
1239         case DataType::kVec3I32:
1240           return store(3);
1241         case DataType::kVec4U32:
1242           return store(4);
1243         case DataType::kVec4F32:
1244           return store(4);
1245         case DataType::kVec4I32:
1246           return store(4);
1247       }
1248       TINT_UNREACHABLE(Writer, diagnostics_)
1249           << "unsupported DecomposeMemoryAccess::Intrinsic::DataType: "
1250           << static_cast<int>(intrinsic->type);
1251       return false;
1252     }
1253 
1254     case Op::kAtomicLoad:
1255     case Op::kAtomicStore:
1256     case Op::kAtomicAdd:
1257     case Op::kAtomicSub:
1258     case Op::kAtomicMax:
1259     case Op::kAtomicMin:
1260     case Op::kAtomicAnd:
1261     case Op::kAtomicOr:
1262     case Op::kAtomicXor:
1263     case Op::kAtomicExchange:
1264     case Op::kAtomicCompareExchangeWeak:
1265       return EmitStorageAtomicCall(out, expr, intrinsic);
1266   }
1267 
1268   TINT_UNREACHABLE(Writer, diagnostics_)
1269       << "unsupported DecomposeMemoryAccess::Intrinsic::Op: "
1270       << static_cast<int>(intrinsic->op);
1271   return false;
1272 }
1273 
EmitStorageAtomicCall(std::ostream & out,const ast::CallExpression * expr,const transform::DecomposeMemoryAccess::Intrinsic * intrinsic)1274 bool GeneratorImpl::EmitStorageAtomicCall(
1275     std::ostream& out,
1276     const ast::CallExpression* expr,
1277     const transform::DecomposeMemoryAccess::Intrinsic* intrinsic) {
1278   using Op = transform::DecomposeMemoryAccess::Intrinsic::Op;
1279 
1280   auto* result_ty = TypeOf(expr);
1281 
1282   auto& buf = helpers_;
1283 
1284   // generate_helper() generates a helper function that translates the
1285   // DecomposeMemoryAccess::Intrinsic call into the corresponding HLSL
1286   // atomic intrinsic function.
1287   auto generate_helper = [&]() -> std::string {
1288     auto rmw = [&](const char* wgsl, const char* hlsl) -> std::string {
1289       auto name = UniqueIdentifier(wgsl);
1290       {
1291         auto fn = line(&buf);
1292         if (!EmitTypeAndName(fn, result_ty, ast::StorageClass::kNone,
1293                              ast::Access::kUndefined, name)) {
1294           return "";
1295         }
1296         fn << "(RWByteAddressBuffer buffer, uint offset, ";
1297         if (!EmitTypeAndName(fn, result_ty, ast::StorageClass::kNone,
1298                              ast::Access::kUndefined, "value")) {
1299           return "";
1300         }
1301         fn << ") {";
1302       }
1303 
1304       buf.IncrementIndent();
1305       TINT_DEFER({
1306         buf.DecrementIndent();
1307         line(&buf) << "}";
1308         line(&buf);
1309       });
1310 
1311       {
1312         auto l = line(&buf);
1313         if (!EmitTypeAndName(l, result_ty, ast::StorageClass::kNone,
1314                              ast::Access::kUndefined, "original_value")) {
1315           return "";
1316         }
1317         l << " = 0;";
1318       }
1319       {
1320         auto l = line(&buf);
1321         l << "buffer." << hlsl << "(offset, ";
1322         if (intrinsic->op == Op::kAtomicSub) {
1323           l << "-";
1324         }
1325         l << "value, original_value);";
1326       }
1327       line(&buf) << "return original_value;";
1328       return name;
1329     };
1330 
1331     switch (intrinsic->op) {
1332       case Op::kAtomicAdd:
1333         return rmw("atomicAdd", "InterlockedAdd");
1334 
1335       case Op::kAtomicSub:
1336         // Use add with the operand negated.
1337         return rmw("atomicSub", "InterlockedAdd");
1338 
1339       case Op::kAtomicMax:
1340         return rmw("atomicMax", "InterlockedMax");
1341 
1342       case Op::kAtomicMin:
1343         return rmw("atomicMin", "InterlockedMin");
1344 
1345       case Op::kAtomicAnd:
1346         return rmw("atomicAnd", "InterlockedAnd");
1347 
1348       case Op::kAtomicOr:
1349         return rmw("atomicOr", "InterlockedOr");
1350 
1351       case Op::kAtomicXor:
1352         return rmw("atomicXor", "InterlockedXor");
1353 
1354       case Op::kAtomicExchange:
1355         return rmw("atomicExchange", "InterlockedExchange");
1356 
1357       case Op::kAtomicLoad: {
1358         // HLSL does not have an InterlockedLoad, so we emulate it with
1359         // InterlockedOr using 0 as the OR value
1360         auto name = UniqueIdentifier("atomicLoad");
1361         {
1362           auto fn = line(&buf);
1363           if (!EmitTypeAndName(fn, result_ty, ast::StorageClass::kNone,
1364                                ast::Access::kUndefined, name)) {
1365             return "";
1366           }
1367           fn << "(RWByteAddressBuffer buffer, uint offset) {";
1368         }
1369 
1370         buf.IncrementIndent();
1371         TINT_DEFER({
1372           buf.DecrementIndent();
1373           line(&buf) << "}";
1374           line(&buf);
1375         });
1376 
1377         {
1378           auto l = line(&buf);
1379           if (!EmitTypeAndName(l, result_ty, ast::StorageClass::kNone,
1380                                ast::Access::kUndefined, "value")) {
1381             return "";
1382           }
1383           l << " = 0;";
1384         }
1385 
1386         line(&buf) << "buffer.InterlockedOr(offset, 0, value);";
1387         line(&buf) << "return value;";
1388         return name;
1389       }
1390       case Op::kAtomicStore: {
1391         // HLSL does not have an InterlockedStore, so we emulate it with
1392         // InterlockedExchange and discard the returned value
1393         auto* value_ty = TypeOf(expr->args[2])->UnwrapRef();
1394         auto name = UniqueIdentifier("atomicStore");
1395         {
1396           auto fn = line(&buf);
1397           fn << "void " << name << "(RWByteAddressBuffer buffer, uint offset, ";
1398           if (!EmitTypeAndName(fn, value_ty, ast::StorageClass::kNone,
1399                                ast::Access::kUndefined, "value")) {
1400             return "";
1401           }
1402           fn << ") {";
1403         }
1404 
1405         buf.IncrementIndent();
1406         TINT_DEFER({
1407           buf.DecrementIndent();
1408           line(&buf) << "}";
1409           line(&buf);
1410         });
1411 
1412         {
1413           auto l = line(&buf);
1414           if (!EmitTypeAndName(l, value_ty, ast::StorageClass::kNone,
1415                                ast::Access::kUndefined, "ignored")) {
1416             return "";
1417           }
1418           l << ";";
1419         }
1420         line(&buf) << "buffer.InterlockedExchange(offset, value, ignored);";
1421         return name;
1422       }
1423       case Op::kAtomicCompareExchangeWeak: {
1424         auto* value_ty = TypeOf(expr->args[2])->UnwrapRef();
1425 
1426         auto name = UniqueIdentifier("atomicCompareExchangeWeak");
1427         {
1428           auto fn = line(&buf);
1429           if (!EmitTypeAndName(fn, result_ty, ast::StorageClass::kNone,
1430                                ast::Access::kUndefined, name)) {
1431             return "";
1432           }
1433           fn << "(RWByteAddressBuffer buffer, uint offset, ";
1434           if (!EmitTypeAndName(fn, value_ty, ast::StorageClass::kNone,
1435                                ast::Access::kUndefined, "compare")) {
1436             return "";
1437           }
1438           fn << ", ";
1439           if (!EmitTypeAndName(fn, value_ty, ast::StorageClass::kNone,
1440                                ast::Access::kUndefined, "value")) {
1441             return "";
1442           }
1443           fn << ") {";
1444         }
1445 
1446         buf.IncrementIndent();
1447         TINT_DEFER({
1448           buf.DecrementIndent();
1449           line(&buf) << "}";
1450           line(&buf);
1451         });
1452 
1453         {  // T result = {0, 0};
1454           auto l = line(&buf);
1455           if (!EmitTypeAndName(l, result_ty, ast::StorageClass::kNone,
1456                                ast::Access::kUndefined, "result")) {
1457             return "";
1458           }
1459           l << " = {0, 0};";
1460         }
1461         line(&buf) << "buffer.InterlockedCompareExchange(offset, compare, "
1462                       "value, result.x);";
1463         line(&buf) << "result.y = result.x == compare;";
1464         line(&buf) << "return result;";
1465         return name;
1466       }
1467       default:
1468         break;
1469     }
1470     TINT_UNREACHABLE(Writer, diagnostics_)
1471         << "unsupported atomic DecomposeMemoryAccess::Intrinsic::Op: "
1472         << static_cast<int>(intrinsic->op);
1473     return "";
1474   };
1475 
1476   auto func = utils::GetOrCreate(dma_intrinsics_,
1477                                  DMAIntrinsic{intrinsic->op, intrinsic->type},
1478                                  generate_helper);
1479   if (func.empty()) {
1480     return false;
1481   }
1482 
1483   out << func;
1484   {
1485     ScopedParen sp(out);
1486     bool first = true;
1487     for (auto* arg : expr->args) {
1488       if (!first) {
1489         out << ", ";
1490       }
1491       first = false;
1492       if (!EmitExpression(out, arg)) {
1493         return false;
1494       }
1495     }
1496   }
1497 
1498   return true;
1499 }
1500 
EmitWorkgroupAtomicCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1501 bool GeneratorImpl::EmitWorkgroupAtomicCall(std::ostream& out,
1502                                             const ast::CallExpression* expr,
1503                                             const sem::Intrinsic* intrinsic) {
1504   std::string result = UniqueIdentifier("atomic_result");
1505 
1506   if (!intrinsic->ReturnType()->Is<sem::Void>()) {
1507     auto pre = line();
1508     if (!EmitTypeAndName(pre, intrinsic->ReturnType(), ast::StorageClass::kNone,
1509                          ast::Access::kUndefined, result)) {
1510       return false;
1511     }
1512     pre << " = ";
1513     if (!EmitZeroValue(pre, intrinsic->ReturnType())) {
1514       return false;
1515     }
1516     pre << ";";
1517   }
1518 
1519   auto call = [&](const char* name) {
1520     auto pre = line();
1521     pre << name;
1522 
1523     {
1524       ScopedParen sp(pre);
1525       for (size_t i = 0; i < expr->args.size(); i++) {
1526         auto* arg = expr->args[i];
1527         if (i > 0) {
1528           pre << ", ";
1529         }
1530         if (i == 1 && intrinsic->Type() == sem::IntrinsicType::kAtomicSub) {
1531           // Sub uses InterlockedAdd with the operand negated.
1532           pre << "-";
1533         }
1534         if (!EmitExpression(pre, arg)) {
1535           return false;
1536         }
1537       }
1538 
1539       pre << ", " << result;
1540     }
1541 
1542     pre << ";";
1543 
1544     out << result;
1545     return true;
1546   };
1547 
1548   switch (intrinsic->Type()) {
1549     case sem::IntrinsicType::kAtomicLoad: {
1550       // HLSL does not have an InterlockedLoad, so we emulate it with
1551       // InterlockedOr using 0 as the OR value
1552       auto pre = line();
1553       pre << "InterlockedOr";
1554       {
1555         ScopedParen sp(pre);
1556         if (!EmitExpression(pre, expr->args[0])) {
1557           return false;
1558         }
1559         pre << ", 0, " << result;
1560       }
1561       pre << ";";
1562 
1563       out << result;
1564       return true;
1565     }
1566     case sem::IntrinsicType::kAtomicStore: {
1567       // HLSL does not have an InterlockedStore, so we emulate it with
1568       // InterlockedExchange and discard the returned value
1569       {  // T result = 0;
1570         auto pre = line();
1571         auto* value_ty = intrinsic->Parameters()[1]->Type()->UnwrapRef();
1572         if (!EmitTypeAndName(pre, value_ty, ast::StorageClass::kNone,
1573                              ast::Access::kUndefined, result)) {
1574           return false;
1575         }
1576         pre << " = ";
1577         if (!EmitZeroValue(pre, value_ty)) {
1578           return false;
1579         }
1580         pre << ";";
1581       }
1582 
1583       out << "InterlockedExchange";
1584       {
1585         ScopedParen sp(out);
1586         if (!EmitExpression(out, expr->args[0])) {
1587           return false;
1588         }
1589         out << ", ";
1590         if (!EmitExpression(out, expr->args[1])) {
1591           return false;
1592         }
1593         out << ", " << result;
1594       }
1595       return true;
1596     }
1597     case sem::IntrinsicType::kAtomicCompareExchangeWeak: {
1598       auto* dest = expr->args[0];
1599       auto* compare_value = expr->args[1];
1600       auto* value = expr->args[2];
1601 
1602       std::string compare = UniqueIdentifier("atomic_compare_value");
1603 
1604       {  // T compare_value = <compare_value>;
1605         auto pre = line();
1606         if (!EmitTypeAndName(pre, TypeOf(compare_value),
1607                              ast::StorageClass::kNone, ast::Access::kUndefined,
1608                              compare)) {
1609           return false;
1610         }
1611         pre << " = ";
1612         if (!EmitExpression(pre, compare_value)) {
1613           return false;
1614         }
1615         pre << ";";
1616       }
1617 
1618       {  // InterlockedCompareExchange(dst, compare, value, result.x);
1619         auto pre = line();
1620         pre << "InterlockedCompareExchange";
1621         {
1622           ScopedParen sp(pre);
1623           if (!EmitExpression(pre, dest)) {
1624             return false;
1625           }
1626           pre << ", " << compare << ", ";
1627           if (!EmitExpression(pre, value)) {
1628             return false;
1629           }
1630           pre << ", " << result << ".x";
1631         }
1632         pre << ";";
1633       }
1634 
1635       {  // result.y = result.x == compare;
1636         line() << result << ".y = " << result << ".x == " << compare << ";";
1637       }
1638 
1639       out << result;
1640       return true;
1641     }
1642 
1643     case sem::IntrinsicType::kAtomicAdd:
1644     case sem::IntrinsicType::kAtomicSub:
1645       return call("InterlockedAdd");
1646 
1647     case sem::IntrinsicType::kAtomicMax:
1648       return call("InterlockedMax");
1649 
1650     case sem::IntrinsicType::kAtomicMin:
1651       return call("InterlockedMin");
1652 
1653     case sem::IntrinsicType::kAtomicAnd:
1654       return call("InterlockedAnd");
1655 
1656     case sem::IntrinsicType::kAtomicOr:
1657       return call("InterlockedOr");
1658 
1659     case sem::IntrinsicType::kAtomicXor:
1660       return call("InterlockedXor");
1661 
1662     case sem::IntrinsicType::kAtomicExchange:
1663       return call("InterlockedExchange");
1664 
1665     default:
1666       break;
1667   }
1668 
1669   TINT_UNREACHABLE(Writer, diagnostics_)
1670       << "unsupported atomic intrinsic: " << intrinsic->Type();
1671   return false;
1672 }
1673 
EmitSelectCall(std::ostream & out,const ast::CallExpression * expr)1674 bool GeneratorImpl::EmitSelectCall(std::ostream& out,
1675                                    const ast::CallExpression* expr) {
1676   auto* expr_false = expr->args[0];
1677   auto* expr_true = expr->args[1];
1678   auto* expr_cond = expr->args[2];
1679   ScopedParen paren(out);
1680   if (!EmitExpression(out, expr_cond)) {
1681     return false;
1682   }
1683 
1684   out << " ? ";
1685 
1686   if (!EmitExpression(out, expr_true)) {
1687     return false;
1688   }
1689 
1690   out << " : ";
1691 
1692   if (!EmitExpression(out, expr_false)) {
1693     return false;
1694   }
1695 
1696   return true;
1697 }
1698 
EmitModfCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1699 bool GeneratorImpl::EmitModfCall(std::ostream& out,
1700                                  const ast::CallExpression* expr,
1701                                  const sem::Intrinsic* intrinsic) {
1702   return CallIntrinsicHelper(
1703       out, expr, intrinsic,
1704       [&](TextBuffer* b, const std::vector<std::string>& params) {
1705         auto* ty = intrinsic->Parameters()[0]->Type();
1706         auto in = params[0];
1707 
1708         std::string width;
1709         if (auto* vec = ty->As<sem::Vector>()) {
1710           width = std::to_string(vec->Width());
1711         }
1712 
1713         // Emit the builtin return type unique to this overload. This does not
1714         // exist in the AST, so it will not be generated in Generate().
1715         if (!EmitStructType(&helpers_,
1716                             intrinsic->ReturnType()->As<sem::Struct>())) {
1717           return false;
1718         }
1719 
1720         line(b) << "float" << width << " whole;";
1721         line(b) << "float" << width << " fract = modf(" << in << ", whole);";
1722         {
1723           auto l = line(b);
1724           if (!EmitType(l, intrinsic->ReturnType(), ast::StorageClass::kNone,
1725                         ast::Access::kUndefined, "")) {
1726             return false;
1727           }
1728           l << " result = {fract, whole};";
1729         }
1730         line(b) << "return result;";
1731         return true;
1732       });
1733 }
1734 
EmitFrexpCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1735 bool GeneratorImpl::EmitFrexpCall(std::ostream& out,
1736                                   const ast::CallExpression* expr,
1737                                   const sem::Intrinsic* intrinsic) {
1738   return CallIntrinsicHelper(
1739       out, expr, intrinsic,
1740       [&](TextBuffer* b, const std::vector<std::string>& params) {
1741         auto* ty = intrinsic->Parameters()[0]->Type();
1742         auto in = params[0];
1743 
1744         std::string width;
1745         if (auto* vec = ty->As<sem::Vector>()) {
1746           width = std::to_string(vec->Width());
1747         }
1748 
1749         // Emit the builtin return type unique to this overload. This does not
1750         // exist in the AST, so it will not be generated in Generate().
1751         if (!EmitStructType(&helpers_,
1752                             intrinsic->ReturnType()->As<sem::Struct>())) {
1753           return false;
1754         }
1755 
1756         line(b) << "float" << width << " exp;";
1757         line(b) << "float" << width << " sig = frexp(" << in << ", exp);";
1758         {
1759           auto l = line(b);
1760           if (!EmitType(l, intrinsic->ReturnType(), ast::StorageClass::kNone,
1761                         ast::Access::kUndefined, "")) {
1762             return false;
1763           }
1764           l << " result = {sig, int" << width << "(exp)};";
1765         }
1766         line(b) << "return result;";
1767         return true;
1768       });
1769 }
1770 
EmitIsNormalCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1771 bool GeneratorImpl::EmitIsNormalCall(std::ostream& out,
1772                                      const ast::CallExpression* expr,
1773                                      const sem::Intrinsic* intrinsic) {
1774   // HLSL doesn't have a isNormal intrinsic, we need to emulate
1775   return CallIntrinsicHelper(
1776       out, expr, intrinsic,
1777       [&](TextBuffer* b, const std::vector<std::string>& params) {
1778         auto* input_ty = intrinsic->Parameters()[0]->Type();
1779 
1780         std::string width;
1781         if (auto* vec = input_ty->As<sem::Vector>()) {
1782           width = std::to_string(vec->Width());
1783         }
1784 
1785         constexpr auto* kExponentMask = "0x7f80000";
1786         constexpr auto* kMinNormalExponent = "0x0080000";
1787         constexpr auto* kMaxNormalExponent = "0x7f00000";
1788 
1789         line(b) << "uint" << width << " exponent = asuint(" << params[0]
1790                 << ") & " << kExponentMask << ";";
1791         line(b) << "uint" << width << " clamped = "
1792                 << "clamp(exponent, " << kMinNormalExponent << ", "
1793                 << kMaxNormalExponent << ");";
1794         line(b) << "return clamped == exponent;";
1795         return true;
1796       });
1797 }
1798 
EmitDataPackingCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1799 bool GeneratorImpl::EmitDataPackingCall(std::ostream& out,
1800                                         const ast::CallExpression* expr,
1801                                         const sem::Intrinsic* intrinsic) {
1802   return CallIntrinsicHelper(
1803       out, expr, intrinsic,
1804       [&](TextBuffer* b, const std::vector<std::string>& params) {
1805         uint32_t dims = 2;
1806         bool is_signed = false;
1807         uint32_t scale = 65535;
1808         if (intrinsic->Type() == sem::IntrinsicType::kPack4x8snorm ||
1809             intrinsic->Type() == sem::IntrinsicType::kPack4x8unorm) {
1810           dims = 4;
1811           scale = 255;
1812         }
1813         if (intrinsic->Type() == sem::IntrinsicType::kPack4x8snorm ||
1814             intrinsic->Type() == sem::IntrinsicType::kPack2x16snorm) {
1815           is_signed = true;
1816           scale = (scale - 1) / 2;
1817         }
1818         switch (intrinsic->Type()) {
1819           case sem::IntrinsicType::kPack4x8snorm:
1820           case sem::IntrinsicType::kPack4x8unorm:
1821           case sem::IntrinsicType::kPack2x16snorm:
1822           case sem::IntrinsicType::kPack2x16unorm: {
1823             {
1824               auto l = line(b);
1825               l << (is_signed ? "" : "u") << "int" << dims
1826                 << " i = " << (is_signed ? "" : "u") << "int" << dims
1827                 << "(round(clamp(" << params[0] << ", "
1828                 << (is_signed ? "-1.0" : "0.0") << ", 1.0) * " << scale
1829                 << ".0))";
1830               if (is_signed) {
1831                 l << " & " << (dims == 4 ? "0xff" : "0xffff");
1832               }
1833               l << ";";
1834             }
1835             {
1836               auto l = line(b);
1837               l << "return ";
1838               if (is_signed) {
1839                 l << "asuint";
1840               }
1841               l << "(i.x | i.y << " << (32 / dims);
1842               if (dims == 4) {
1843                 l << " | i.z << 16 | i.w << 24";
1844               }
1845               l << ");";
1846             }
1847             break;
1848           }
1849           case sem::IntrinsicType::kPack2x16float: {
1850             line(b) << "uint2 i = f32tof16(" << params[0] << ");";
1851             line(b) << "return i.x | (i.y << 16);";
1852             break;
1853           }
1854           default:
1855             diagnostics_.add_error(
1856                 diag::System::Writer,
1857                 "Internal error: unhandled data packing intrinsic");
1858             return false;
1859         }
1860 
1861         return true;
1862       });
1863 }
1864 
EmitDataUnpackingCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1865 bool GeneratorImpl::EmitDataUnpackingCall(std::ostream& out,
1866                                           const ast::CallExpression* expr,
1867                                           const sem::Intrinsic* intrinsic) {
1868   return CallIntrinsicHelper(
1869       out, expr, intrinsic,
1870       [&](TextBuffer* b, const std::vector<std::string>& params) {
1871         uint32_t dims = 2;
1872         bool is_signed = false;
1873         uint32_t scale = 65535;
1874         if (intrinsic->Type() == sem::IntrinsicType::kUnpack4x8snorm ||
1875             intrinsic->Type() == sem::IntrinsicType::kUnpack4x8unorm) {
1876           dims = 4;
1877           scale = 255;
1878         }
1879         if (intrinsic->Type() == sem::IntrinsicType::kUnpack4x8snorm ||
1880             intrinsic->Type() == sem::IntrinsicType::kUnpack2x16snorm) {
1881           is_signed = true;
1882           scale = (scale - 1) / 2;
1883         }
1884         switch (intrinsic->Type()) {
1885           case sem::IntrinsicType::kUnpack4x8snorm:
1886           case sem::IntrinsicType::kUnpack2x16snorm: {
1887             line(b) << "int j = int(" << params[0] << ");";
1888             {  // Perform sign extension on the converted values.
1889               auto l = line(b);
1890               l << "int" << dims << " i = int" << dims << "(";
1891               if (dims == 2) {
1892                 l << "j << 16, j) >> 16";
1893               } else {
1894                 l << "j << 24, j << 16, j << 8, j) >> 24";
1895               }
1896               l << ";";
1897             }
1898             line(b) << "return clamp(float" << dims << "(i) / " << scale
1899                     << ".0, " << (is_signed ? "-1.0" : "0.0") << ", 1.0);";
1900             break;
1901           }
1902           case sem::IntrinsicType::kUnpack4x8unorm:
1903           case sem::IntrinsicType::kUnpack2x16unorm: {
1904             line(b) << "uint j = " << params[0] << ";";
1905             {
1906               auto l = line(b);
1907               l << "uint" << dims << " i = uint" << dims << "(";
1908               l << "j & " << (dims == 2 ? "0xffff" : "0xff") << ", ";
1909               if (dims == 4) {
1910                 l << "(j >> " << (32 / dims)
1911                   << ") & 0xff, (j >> 16) & 0xff, j >> 24";
1912               } else {
1913                 l << "j >> " << (32 / dims);
1914               }
1915               l << ");";
1916             }
1917             line(b) << "return float" << dims << "(i) / " << scale << ".0;";
1918             break;
1919           }
1920           case sem::IntrinsicType::kUnpack2x16float:
1921             line(b) << "uint i = " << params[0] << ";";
1922             line(b) << "return f16tof32(uint2(i & 0xffff, i >> 16));";
1923             break;
1924           default:
1925             diagnostics_.add_error(
1926                 diag::System::Writer,
1927                 "Internal error: unhandled data packing intrinsic");
1928             return false;
1929         }
1930 
1931         return true;
1932       });
1933 }
1934 
EmitBarrierCall(std::ostream & out,const sem::Intrinsic * intrinsic)1935 bool GeneratorImpl::EmitBarrierCall(std::ostream& out,
1936                                     const sem::Intrinsic* intrinsic) {
1937   // TODO(crbug.com/tint/661): Combine sequential barriers to a single
1938   // instruction.
1939   if (intrinsic->Type() == sem::IntrinsicType::kWorkgroupBarrier) {
1940     out << "GroupMemoryBarrierWithGroupSync()";
1941   } else if (intrinsic->Type() == sem::IntrinsicType::kStorageBarrier) {
1942     out << "DeviceMemoryBarrierWithGroupSync()";
1943   } else {
1944     TINT_UNREACHABLE(Writer, diagnostics_)
1945         << "unexpected barrier intrinsic type " << sem::str(intrinsic->Type());
1946     return false;
1947   }
1948   return true;
1949 }
1950 
EmitTextureCall(std::ostream & out,const sem::Call * call,const sem::Intrinsic * intrinsic)1951 bool GeneratorImpl::EmitTextureCall(std::ostream& out,
1952                                     const sem::Call* call,
1953                                     const sem::Intrinsic* intrinsic) {
1954   using Usage = sem::ParameterUsage;
1955 
1956   auto& signature = intrinsic->Signature();
1957   auto* expr = call->Declaration();
1958   auto arguments = expr->args;
1959 
1960   // Returns the argument with the given usage
1961   auto arg = [&](Usage usage) {
1962     int idx = signature.IndexOf(usage);
1963     return (idx >= 0) ? arguments[idx] : nullptr;
1964   };
1965 
1966   auto* texture = arg(Usage::kTexture);
1967   if (!texture) {
1968     TINT_ICE(Writer, diagnostics_) << "missing texture argument";
1969     return false;
1970   }
1971 
1972   auto* texture_type = TypeOf(texture)->UnwrapRef()->As<sem::Texture>();
1973 
1974   switch (intrinsic->Type()) {
1975     case sem::IntrinsicType::kTextureDimensions:
1976     case sem::IntrinsicType::kTextureNumLayers:
1977     case sem::IntrinsicType::kTextureNumLevels:
1978     case sem::IntrinsicType::kTextureNumSamples: {
1979       // All of these intrinsics use the GetDimensions() method on the texture
1980       bool is_ms = texture_type->IsAnyOf<sem::MultisampledTexture,
1981                                          sem::DepthMultisampledTexture>();
1982       int num_dimensions = 0;
1983       std::string swizzle;
1984 
1985       switch (intrinsic->Type()) {
1986         case sem::IntrinsicType::kTextureDimensions:
1987           switch (texture_type->dim()) {
1988             case ast::TextureDimension::kNone:
1989               TINT_ICE(Writer, diagnostics_) << "texture dimension is kNone";
1990               return false;
1991             case ast::TextureDimension::k1d:
1992               num_dimensions = 1;
1993               break;
1994             case ast::TextureDimension::k2d:
1995               num_dimensions = is_ms ? 3 : 2;
1996               swizzle = is_ms ? ".xy" : "";
1997               break;
1998             case ast::TextureDimension::k2dArray:
1999               num_dimensions = is_ms ? 4 : 3;
2000               swizzle = ".xy";
2001               break;
2002             case ast::TextureDimension::k3d:
2003               num_dimensions = 3;
2004               break;
2005             case ast::TextureDimension::kCube:
2006               num_dimensions = 2;
2007               break;
2008             case ast::TextureDimension::kCubeArray:
2009               num_dimensions = 3;
2010               swizzle = ".xy";
2011               break;
2012           }
2013           break;
2014         case sem::IntrinsicType::kTextureNumLayers:
2015           switch (texture_type->dim()) {
2016             default:
2017               TINT_ICE(Writer, diagnostics_)
2018                   << "texture dimension is not arrayed";
2019               return false;
2020             case ast::TextureDimension::k2dArray:
2021               num_dimensions = is_ms ? 4 : 3;
2022               swizzle = ".z";
2023               break;
2024             case ast::TextureDimension::kCubeArray:
2025               num_dimensions = 3;
2026               swizzle = ".z";
2027               break;
2028           }
2029           break;
2030         case sem::IntrinsicType::kTextureNumLevels:
2031           switch (texture_type->dim()) {
2032             default:
2033               TINT_ICE(Writer, diagnostics_)
2034                   << "texture dimension does not support mips";
2035               return false;
2036             case ast::TextureDimension::k1d:
2037               num_dimensions = 2;
2038               swizzle = ".y";
2039               break;
2040             case ast::TextureDimension::k2d:
2041             case ast::TextureDimension::kCube:
2042               num_dimensions = 3;
2043               swizzle = ".z";
2044               break;
2045             case ast::TextureDimension::k2dArray:
2046             case ast::TextureDimension::k3d:
2047             case ast::TextureDimension::kCubeArray:
2048               num_dimensions = 4;
2049               swizzle = ".w";
2050               break;
2051           }
2052           break;
2053         case sem::IntrinsicType::kTextureNumSamples:
2054           switch (texture_type->dim()) {
2055             default:
2056               TINT_ICE(Writer, diagnostics_)
2057                   << "texture dimension does not support multisampling";
2058               return false;
2059             case ast::TextureDimension::k2d:
2060               num_dimensions = 3;
2061               swizzle = ".z";
2062               break;
2063             case ast::TextureDimension::k2dArray:
2064               num_dimensions = 4;
2065               swizzle = ".w";
2066               break;
2067           }
2068           break;
2069         default:
2070           TINT_ICE(Writer, diagnostics_) << "unexpected intrinsic";
2071           return false;
2072       }
2073 
2074       auto* level_arg = arg(Usage::kLevel);
2075 
2076       if (level_arg) {
2077         // `NumberOfLevels` is a non-optional argument if `MipLevel` was passed.
2078         // Increment the number of dimensions for the temporary vector to
2079         // accommodate this.
2080         num_dimensions++;
2081 
2082         // If the swizzle was empty, the expression will evaluate to the whole
2083         // vector. As we've grown the vector by one element, we now need to
2084         // swizzle to keep the result expression equivalent.
2085         if (swizzle.empty()) {
2086           static constexpr const char* swizzles[] = {"", ".x", ".xy", ".xyz"};
2087           swizzle = swizzles[num_dimensions - 1];
2088         }
2089       }
2090 
2091       if (num_dimensions > 4) {
2092         TINT_ICE(Writer, diagnostics_)
2093             << "Texture query intrinsic temporary vector has " << num_dimensions
2094             << " dimensions";
2095         return false;
2096       }
2097 
2098       // Declare a variable to hold the queried texture info
2099       auto dims = UniqueIdentifier(kTempNamePrefix);
2100       if (num_dimensions == 1) {
2101         line() << "int " << dims << ";";
2102       } else {
2103         line() << "int" << num_dimensions << " " << dims << ";";
2104       }
2105 
2106       {  // texture.GetDimensions(...)
2107         auto pre = line();
2108         if (!EmitExpression(pre, texture)) {
2109           return false;
2110         }
2111         pre << ".GetDimensions(";
2112 
2113         if (level_arg) {
2114           if (!EmitExpression(pre, level_arg)) {
2115             return false;
2116           }
2117           pre << ", ";
2118         } else if (intrinsic->Type() == sem::IntrinsicType::kTextureNumLevels) {
2119           pre << "0, ";
2120         }
2121 
2122         if (num_dimensions == 1) {
2123           pre << dims;
2124         } else {
2125           static constexpr char xyzw[] = {'x', 'y', 'z', 'w'};
2126           if (num_dimensions < 0 || num_dimensions > 4) {
2127             TINT_ICE(Writer, diagnostics_)
2128                 << "vector dimensions are " << num_dimensions;
2129             return false;
2130           }
2131           for (int i = 0; i < num_dimensions; i++) {
2132             if (i > 0) {
2133               pre << ", ";
2134             }
2135             pre << dims << "." << xyzw[i];
2136           }
2137         }
2138 
2139         pre << ");";
2140       }
2141 
2142       // The out parameters of the GetDimensions() call is now in temporary
2143       // `dims` variable. This may be packed with other data, so the final
2144       // expression may require a swizzle.
2145       out << dims << swizzle;
2146       return true;
2147     }
2148     default:
2149       break;
2150   }
2151 
2152   if (!EmitExpression(out, texture))
2153     return false;
2154 
2155   // If pack_level_in_coords is true, then the mip level will be appended as the
2156   // last value of the coordinates argument. If the WGSL intrinsic overload does
2157   // not have a level parameter and pack_level_in_coords is true, then a zero
2158   // mip level will be inserted.
2159   bool pack_level_in_coords = false;
2160 
2161   uint32_t hlsl_ret_width = 4u;
2162 
2163   switch (intrinsic->Type()) {
2164     case sem::IntrinsicType::kTextureSample:
2165       out << ".Sample(";
2166       break;
2167     case sem::IntrinsicType::kTextureSampleBias:
2168       out << ".SampleBias(";
2169       break;
2170     case sem::IntrinsicType::kTextureSampleLevel:
2171       out << ".SampleLevel(";
2172       break;
2173     case sem::IntrinsicType::kTextureSampleGrad:
2174       out << ".SampleGrad(";
2175       break;
2176     case sem::IntrinsicType::kTextureSampleCompare:
2177       out << ".SampleCmp(";
2178       hlsl_ret_width = 1;
2179       break;
2180     case sem::IntrinsicType::kTextureSampleCompareLevel:
2181       out << ".SampleCmpLevelZero(";
2182       hlsl_ret_width = 1;
2183       break;
2184     case sem::IntrinsicType::kTextureLoad:
2185       out << ".Load(";
2186       // Multisampled textures do not support mip-levels.
2187       if (!texture_type->Is<sem::MultisampledTexture>()) {
2188         pack_level_in_coords = true;
2189       }
2190       break;
2191     case sem::IntrinsicType::kTextureGather:
2192       out << ".Gather";
2193       if (intrinsic->Parameters()[0]->Usage() ==
2194           sem::ParameterUsage::kComponent) {
2195         switch (call->Arguments()[0]->ConstantValue().Elements()[0].i32) {
2196           case 0:
2197             out << "Red";
2198             break;
2199           case 1:
2200             out << "Green";
2201             break;
2202           case 2:
2203             out << "Blue";
2204             break;
2205           case 3:
2206             out << "Alpha";
2207             break;
2208         }
2209       }
2210       out << "(";
2211       break;
2212     case sem::IntrinsicType::kTextureGatherCompare:
2213       out << ".GatherCmp(";
2214       break;
2215     case sem::IntrinsicType::kTextureStore:
2216       out << "[";
2217       break;
2218     default:
2219       diagnostics_.add_error(
2220           diag::System::Writer,
2221           "Internal compiler error: Unhandled texture intrinsic '" +
2222               std::string(intrinsic->str()) + "'");
2223       return false;
2224   }
2225 
2226   if (auto* sampler = arg(Usage::kSampler)) {
2227     if (!EmitExpression(out, sampler))
2228       return false;
2229     out << ", ";
2230   }
2231 
2232   auto* param_coords = arg(Usage::kCoords);
2233   if (!param_coords) {
2234     TINT_ICE(Writer, diagnostics_) << "missing coords argument";
2235     return false;
2236   }
2237 
2238   auto emit_vector_appended_with_i32_zero = [&](const ast::Expression* vector) {
2239     auto* i32 = builder_.create<sem::I32>();
2240     auto* zero = builder_.Expr(0);
2241     auto* stmt = builder_.Sem().Get(vector)->Stmt();
2242     builder_.Sem().Add(zero, builder_.create<sem::Expression>(zero, i32, stmt,
2243                                                               sem::Constant{}));
2244     auto* packed = AppendVector(&builder_, vector, zero);
2245     return EmitExpression(out, packed->Declaration());
2246   };
2247 
2248   auto emit_vector_appended_with_level = [&](const ast::Expression* vector) {
2249     if (auto* level = arg(Usage::kLevel)) {
2250       auto* packed = AppendVector(&builder_, vector, level);
2251       return EmitExpression(out, packed->Declaration());
2252     }
2253     return emit_vector_appended_with_i32_zero(vector);
2254   };
2255 
2256   if (auto* array_index = arg(Usage::kArrayIndex)) {
2257     // Array index needs to be appended to the coordinates.
2258     auto* packed = AppendVector(&builder_, param_coords, array_index);
2259     if (pack_level_in_coords) {
2260       // Then mip level needs to be appended to the coordinates.
2261       if (!emit_vector_appended_with_level(packed->Declaration())) {
2262         return false;
2263       }
2264     } else {
2265       if (!EmitExpression(out, packed->Declaration())) {
2266         return false;
2267       }
2268     }
2269   } else if (pack_level_in_coords) {
2270     // Mip level needs to be appended to the coordinates.
2271     if (!emit_vector_appended_with_level(param_coords)) {
2272       return false;
2273     }
2274   } else {
2275     if (!EmitExpression(out, param_coords)) {
2276       return false;
2277     }
2278   }
2279 
2280   for (auto usage : {Usage::kDepthRef, Usage::kBias, Usage::kLevel, Usage::kDdx,
2281                      Usage::kDdy, Usage::kSampleIndex, Usage::kOffset}) {
2282     if (usage == Usage::kLevel && pack_level_in_coords) {
2283       continue;  // mip level already packed in coordinates.
2284     }
2285     if (auto* e = arg(usage)) {
2286       out << ", ";
2287       if (!EmitExpression(out, e)) {
2288         return false;
2289       }
2290     }
2291   }
2292 
2293   if (intrinsic->Type() == sem::IntrinsicType::kTextureStore) {
2294     out << "] = ";
2295     if (!EmitExpression(out, arg(Usage::kValue))) {
2296       return false;
2297     }
2298   } else {
2299     out << ")";
2300 
2301     // If the intrinsic return type does not match the number of elements of the
2302     // HLSL intrinsic, we need to swizzle the expression to generate the correct
2303     // number of components.
2304     uint32_t wgsl_ret_width = 1;
2305     if (auto* vec = intrinsic->ReturnType()->As<sem::Vector>()) {
2306       wgsl_ret_width = vec->Width();
2307     }
2308     if (wgsl_ret_width < hlsl_ret_width) {
2309       out << ".";
2310       for (uint32_t i = 0; i < wgsl_ret_width; i++) {
2311         out << "xyz"[i];
2312       }
2313     }
2314     if (wgsl_ret_width > hlsl_ret_width) {
2315       TINT_ICE(Writer, diagnostics_)
2316           << "WGSL return width (" << wgsl_ret_width
2317           << ") is wider than HLSL return width (" << hlsl_ret_width << ") for "
2318           << intrinsic->Type();
2319       return false;
2320     }
2321   }
2322 
2323   return true;
2324 }
2325 
generate_builtin_name(const sem::Intrinsic * intrinsic)2326 std::string GeneratorImpl::generate_builtin_name(
2327     const sem::Intrinsic* intrinsic) {
2328   switch (intrinsic->Type()) {
2329     case sem::IntrinsicType::kAbs:
2330     case sem::IntrinsicType::kAcos:
2331     case sem::IntrinsicType::kAll:
2332     case sem::IntrinsicType::kAny:
2333     case sem::IntrinsicType::kAsin:
2334     case sem::IntrinsicType::kAtan:
2335     case sem::IntrinsicType::kAtan2:
2336     case sem::IntrinsicType::kCeil:
2337     case sem::IntrinsicType::kClamp:
2338     case sem::IntrinsicType::kCos:
2339     case sem::IntrinsicType::kCosh:
2340     case sem::IntrinsicType::kCross:
2341     case sem::IntrinsicType::kDeterminant:
2342     case sem::IntrinsicType::kDistance:
2343     case sem::IntrinsicType::kDot:
2344     case sem::IntrinsicType::kExp:
2345     case sem::IntrinsicType::kExp2:
2346     case sem::IntrinsicType::kFloor:
2347     case sem::IntrinsicType::kFrexp:
2348     case sem::IntrinsicType::kLdexp:
2349     case sem::IntrinsicType::kLength:
2350     case sem::IntrinsicType::kLog:
2351     case sem::IntrinsicType::kLog2:
2352     case sem::IntrinsicType::kMax:
2353     case sem::IntrinsicType::kMin:
2354     case sem::IntrinsicType::kModf:
2355     case sem::IntrinsicType::kNormalize:
2356     case sem::IntrinsicType::kPow:
2357     case sem::IntrinsicType::kReflect:
2358     case sem::IntrinsicType::kRefract:
2359     case sem::IntrinsicType::kRound:
2360     case sem::IntrinsicType::kSign:
2361     case sem::IntrinsicType::kSin:
2362     case sem::IntrinsicType::kSinh:
2363     case sem::IntrinsicType::kSqrt:
2364     case sem::IntrinsicType::kStep:
2365     case sem::IntrinsicType::kTan:
2366     case sem::IntrinsicType::kTanh:
2367     case sem::IntrinsicType::kTranspose:
2368     case sem::IntrinsicType::kTrunc:
2369       return intrinsic->str();
2370     case sem::IntrinsicType::kCountOneBits:
2371       return "countbits";
2372     case sem::IntrinsicType::kDpdx:
2373       return "ddx";
2374     case sem::IntrinsicType::kDpdxCoarse:
2375       return "ddx_coarse";
2376     case sem::IntrinsicType::kDpdxFine:
2377       return "ddx_fine";
2378     case sem::IntrinsicType::kDpdy:
2379       return "ddy";
2380     case sem::IntrinsicType::kDpdyCoarse:
2381       return "ddy_coarse";
2382     case sem::IntrinsicType::kDpdyFine:
2383       return "ddy_fine";
2384     case sem::IntrinsicType::kFaceForward:
2385       return "faceforward";
2386     case sem::IntrinsicType::kFract:
2387       return "frac";
2388     case sem::IntrinsicType::kFma:
2389       return "mad";
2390     case sem::IntrinsicType::kFwidth:
2391     case sem::IntrinsicType::kFwidthCoarse:
2392     case sem::IntrinsicType::kFwidthFine:
2393       return "fwidth";
2394     case sem::IntrinsicType::kInverseSqrt:
2395       return "rsqrt";
2396     case sem::IntrinsicType::kIsFinite:
2397       return "isfinite";
2398     case sem::IntrinsicType::kIsInf:
2399       return "isinf";
2400     case sem::IntrinsicType::kIsNan:
2401       return "isnan";
2402     case sem::IntrinsicType::kMix:
2403       return "lerp";
2404     case sem::IntrinsicType::kReverseBits:
2405       return "reversebits";
2406     case sem::IntrinsicType::kSmoothStep:
2407       return "smoothstep";
2408     default:
2409       diagnostics_.add_error(
2410           diag::System::Writer,
2411           "Unknown builtin method: " + std::string(intrinsic->str()));
2412   }
2413 
2414   return "";
2415 }
2416 
EmitCase(const ast::SwitchStatement * s,size_t case_idx)2417 bool GeneratorImpl::EmitCase(const ast::SwitchStatement* s, size_t case_idx) {
2418   auto* stmt = s->body[case_idx];
2419   if (stmt->IsDefault()) {
2420     line() << "default: {";
2421   } else {
2422     for (auto* selector : stmt->selectors) {
2423       auto out = line();
2424       out << "case ";
2425       if (!EmitLiteral(out, selector)) {
2426         return false;
2427       }
2428       out << ":";
2429       if (selector == stmt->selectors.back()) {
2430         out << " {";
2431       }
2432     }
2433   }
2434 
2435   increment_indent();
2436   TINT_DEFER({
2437     decrement_indent();
2438     line() << "}";
2439   });
2440 
2441   // Emit the case statement
2442   if (!EmitStatements(stmt->body->statements)) {
2443     return false;
2444   }
2445 
2446   // Inline all fallthrough case statements. FXC cannot handle fallthroughs.
2447   while (tint::Is<ast::FallthroughStatement>(stmt->body->Last())) {
2448     case_idx++;
2449     stmt = s->body[case_idx];
2450     // Generate each fallthrough case statement in a new block. This is done to
2451     // prevent symbol collision of variables declared in these cases statements.
2452     if (!EmitBlock(stmt->body)) {
2453       return false;
2454     }
2455   }
2456 
2457   if (!tint::IsAnyOf<ast::BreakStatement, ast::FallthroughStatement>(
2458           stmt->body->Last())) {
2459     line() << "break;";
2460   }
2461 
2462   return true;
2463 }
2464 
EmitContinue(const ast::ContinueStatement *)2465 bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) {
2466   if (!emit_continuing_()) {
2467     return false;
2468   }
2469   line() << "continue;";
2470   return true;
2471 }
2472 
EmitDiscard(const ast::DiscardStatement *)2473 bool GeneratorImpl::EmitDiscard(const ast::DiscardStatement*) {
2474   // TODO(dsinclair): Verify this is correct when the discard semantics are
2475   // defined for WGSL (https://github.com/gpuweb/gpuweb/issues/361)
2476   line() << "discard;";
2477   return true;
2478 }
2479 
EmitExpression(std::ostream & out,const ast::Expression * expr)2480 bool GeneratorImpl::EmitExpression(std::ostream& out,
2481                                    const ast::Expression* expr) {
2482   if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
2483     return EmitIndexAccessor(out, a);
2484   }
2485   if (auto* b = expr->As<ast::BinaryExpression>()) {
2486     return EmitBinary(out, b);
2487   }
2488   if (auto* b = expr->As<ast::BitcastExpression>()) {
2489     return EmitBitcast(out, b);
2490   }
2491   if (auto* c = expr->As<ast::CallExpression>()) {
2492     return EmitCall(out, c);
2493   }
2494   if (auto* i = expr->As<ast::IdentifierExpression>()) {
2495     return EmitIdentifier(out, i);
2496   }
2497   if (auto* l = expr->As<ast::LiteralExpression>()) {
2498     return EmitLiteral(out, l);
2499   }
2500   if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
2501     return EmitMemberAccessor(out, m);
2502   }
2503   if (auto* u = expr->As<ast::UnaryOpExpression>()) {
2504     return EmitUnaryOp(out, u);
2505   }
2506 
2507   diagnostics_.add_error(
2508       diag::System::Writer,
2509       "unknown expression type: " + std::string(expr->TypeInfo().name));
2510   return false;
2511 }
2512 
EmitIdentifier(std::ostream & out,const ast::IdentifierExpression * expr)2513 bool GeneratorImpl::EmitIdentifier(std::ostream& out,
2514                                    const ast::IdentifierExpression* expr) {
2515   out << builder_.Symbols().NameFor(expr->symbol);
2516   return true;
2517 }
2518 
EmitIf(const ast::IfStatement * stmt)2519 bool GeneratorImpl::EmitIf(const ast::IfStatement* stmt) {
2520   {
2521     auto out = line();
2522     out << "if (";
2523     if (!EmitExpression(out, stmt->condition)) {
2524       return false;
2525     }
2526     out << ") {";
2527   }
2528 
2529   if (!EmitStatementsWithIndent(stmt->body->statements)) {
2530     return false;
2531   }
2532 
2533   for (auto* e : stmt->else_statements) {
2534     if (e->condition) {
2535       line() << "} else {";
2536       increment_indent();
2537 
2538       {
2539         auto out = line();
2540         out << "if (";
2541         if (!EmitExpression(out, e->condition)) {
2542           return false;
2543         }
2544         out << ") {";
2545       }
2546     } else {
2547       line() << "} else {";
2548     }
2549 
2550     if (!EmitStatementsWithIndent(e->body->statements)) {
2551       return false;
2552     }
2553   }
2554 
2555   line() << "}";
2556 
2557   for (auto* e : stmt->else_statements) {
2558     if (e->condition) {
2559       decrement_indent();
2560       line() << "}";
2561     }
2562   }
2563   return true;
2564 }
2565 
EmitFunction(const ast::Function * func)2566 bool GeneratorImpl::EmitFunction(const ast::Function* func) {
2567   auto* sem = builder_.Sem().Get(func);
2568 
2569   if (ast::HasDecoration<ast::InternalDecoration>(func->decorations)) {
2570     // An internal function. Do not emit.
2571     return true;
2572   }
2573 
2574   {
2575     auto out = line();
2576     auto name = builder_.Symbols().NameFor(func->symbol);
2577     // If the function returns an array, then we need to declare a typedef for
2578     // this.
2579     if (sem->ReturnType()->Is<sem::Array>()) {
2580       auto typedef_name = UniqueIdentifier(name + "_ret");
2581       auto pre = line();
2582       pre << "typedef ";
2583       if (!EmitTypeAndName(pre, sem->ReturnType(), ast::StorageClass::kNone,
2584                            ast::Access::kReadWrite, typedef_name)) {
2585         return false;
2586       }
2587       pre << ";";
2588       out << typedef_name;
2589     } else {
2590       if (!EmitType(out, sem->ReturnType(), ast::StorageClass::kNone,
2591                     ast::Access::kReadWrite, "")) {
2592         return false;
2593       }
2594     }
2595 
2596     out << " " << name << "(";
2597 
2598     bool first = true;
2599 
2600     for (auto* v : sem->Parameters()) {
2601       if (!first) {
2602         out << ", ";
2603       }
2604       first = false;
2605 
2606       auto const* type = v->Type();
2607 
2608       if (auto* ptr = type->As<sem::Pointer>()) {
2609         // Transform pointer parameters in to `inout` parameters.
2610         // The WGSL spec is highly restrictive in what can be passed in pointer
2611         // parameters, which allows for this transformation. See:
2612         // https://gpuweb.github.io/gpuweb/wgsl/#function-restriction
2613         out << "inout ";
2614         type = ptr->StoreType();
2615       }
2616 
2617       // Note: WGSL only allows for StorageClass::kNone on parameters, however
2618       // the sanitizer transforms generates load / store functions for storage
2619       // or uniform buffers. These functions have a buffer parameter with
2620       // StorageClass::kStorage or StorageClass::kUniform. This is required to
2621       // correctly translate the parameter to a [RW]ByteAddressBuffer for
2622       // storage buffers and a uint4[N] for uniform buffers.
2623       if (!EmitTypeAndName(
2624               out, type, v->StorageClass(), v->Access(),
2625               builder_.Symbols().NameFor(v->Declaration()->symbol))) {
2626         return false;
2627       }
2628     }
2629     out << ") {";
2630   }
2631 
2632   if (sem->HasDiscard() && !sem->ReturnType()->Is<sem::Void>()) {
2633     // BUG(crbug.com/tint/1081): work around non-void functions with discard
2634     // failing compilation sometimes
2635     if (!EmitFunctionBodyWithDiscard(func)) {
2636       return false;
2637     }
2638   } else {
2639     if (!EmitStatementsWithIndent(func->body->statements)) {
2640       return false;
2641     }
2642   }
2643 
2644   line() << "}";
2645 
2646   return true;
2647 }
2648 
EmitFunctionBodyWithDiscard(const ast::Function * func)2649 bool GeneratorImpl::EmitFunctionBodyWithDiscard(const ast::Function* func) {
2650   // FXC sometimes fails to compile functions that discard with 'Not all control
2651   // paths return a value'. We work around this by wrapping the function body
2652   // within an "if (true) { <body> } return <default return type obj>;" so that
2653   // there is always an (unused) return statement.
2654 
2655   auto* sem = builder_.Sem().Get(func);
2656   TINT_ASSERT(Writer, sem->HasDiscard() && !sem->ReturnType()->Is<sem::Void>());
2657 
2658   ScopedIndent si(this);
2659   line() << "if (true) {";
2660 
2661   if (!EmitStatementsWithIndent(func->body->statements)) {
2662     return false;
2663   }
2664 
2665   line() << "}";
2666 
2667   // Return an unused result that matches the type of the return value
2668   auto name = builder_.Symbols().NameFor(builder_.Symbols().New("unused"));
2669   {
2670     auto out = line();
2671     if (!EmitTypeAndName(out, sem->ReturnType(), ast::StorageClass::kNone,
2672                          ast::Access::kReadWrite, name)) {
2673       return false;
2674     }
2675     out << ";";
2676   }
2677   line() << "return " << name << ";";
2678 
2679   return true;
2680 }
2681 
EmitGlobalVariable(const ast::Variable * global)2682 bool GeneratorImpl::EmitGlobalVariable(const ast::Variable* global) {
2683   if (global->is_const) {
2684     return EmitProgramConstVariable(global);
2685   }
2686 
2687   auto* sem = builder_.Sem().Get(global);
2688   switch (sem->StorageClass()) {
2689     case ast::StorageClass::kUniform:
2690       return EmitUniformVariable(sem);
2691     case ast::StorageClass::kStorage:
2692       return EmitStorageVariable(sem);
2693     case ast::StorageClass::kUniformConstant:
2694       return EmitHandleVariable(sem);
2695     case ast::StorageClass::kPrivate:
2696       return EmitPrivateVariable(sem);
2697     case ast::StorageClass::kWorkgroup:
2698       return EmitWorkgroupVariable(sem);
2699     default:
2700       break;
2701   }
2702 
2703   TINT_ICE(Writer, diagnostics_)
2704       << "unhandled storage class " << sem->StorageClass();
2705   return false;
2706 }
2707 
EmitUniformVariable(const sem::Variable * var)2708 bool GeneratorImpl::EmitUniformVariable(const sem::Variable* var) {
2709   auto* decl = var->Declaration();
2710   auto binding_point = decl->BindingPoint();
2711   auto* type = var->Type()->UnwrapRef();
2712 
2713   auto* str = type->As<sem::Struct>();
2714   if (!str) {
2715     // https://www.w3.org/TR/WGSL/#module-scope-variables
2716     TINT_ICE(Writer, diagnostics_)
2717         << "variables with uniform storage must be structure";
2718   }
2719 
2720   auto name = builder_.Symbols().NameFor(decl->symbol);
2721   line() << "cbuffer cbuffer_" << name << RegisterAndSpace('b', binding_point)
2722          << " {";
2723 
2724   {
2725     ScopedIndent si(this);
2726     auto out = line();
2727     if (!EmitTypeAndName(out, type, ast::StorageClass::kUniform, var->Access(),
2728                          name)) {
2729       return false;
2730     }
2731     out << ";";
2732   }
2733 
2734   line() << "};";
2735 
2736   return true;
2737 }
2738 
EmitStorageVariable(const sem::Variable * var)2739 bool GeneratorImpl::EmitStorageVariable(const sem::Variable* var) {
2740   auto* decl = var->Declaration();
2741   auto* type = var->Type()->UnwrapRef();
2742   auto out = line();
2743   if (!EmitTypeAndName(out, type, ast::StorageClass::kStorage, var->Access(),
2744                        builder_.Symbols().NameFor(decl->symbol))) {
2745     return false;
2746   }
2747 
2748   out << RegisterAndSpace(var->Access() == ast::Access::kRead ? 't' : 'u',
2749                           decl->BindingPoint())
2750       << ";";
2751 
2752   return true;
2753 }
2754 
EmitHandleVariable(const sem::Variable * var)2755 bool GeneratorImpl::EmitHandleVariable(const sem::Variable* var) {
2756   auto* decl = var->Declaration();
2757   auto* unwrapped_type = var->Type()->UnwrapRef();
2758   auto out = line();
2759 
2760   auto name = builder_.Symbols().NameFor(decl->symbol);
2761   auto* type = var->Type()->UnwrapRef();
2762   if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) {
2763     return false;
2764   }
2765 
2766   const char* register_space = nullptr;
2767 
2768   if (unwrapped_type->Is<sem::Texture>()) {
2769     register_space = "t";
2770     if (unwrapped_type->Is<sem::StorageTexture>()) {
2771       register_space = "u";
2772     }
2773   } else if (unwrapped_type->Is<sem::Sampler>()) {
2774     register_space = "s";
2775   }
2776 
2777   if (register_space) {
2778     auto bp = decl->BindingPoint();
2779     out << " : register(" << register_space << bp.binding->value << ", space"
2780         << bp.group->value << ")";
2781   }
2782 
2783   out << ";";
2784   return true;
2785 }
2786 
EmitPrivateVariable(const sem::Variable * var)2787 bool GeneratorImpl::EmitPrivateVariable(const sem::Variable* var) {
2788   auto* decl = var->Declaration();
2789   auto out = line();
2790 
2791   out << "static ";
2792 
2793   auto name = builder_.Symbols().NameFor(decl->symbol);
2794   auto* type = var->Type()->UnwrapRef();
2795   if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) {
2796     return false;
2797   }
2798 
2799   out << " = ";
2800   if (auto* constructor = decl->constructor) {
2801     if (!EmitExpression(out, constructor)) {
2802       return false;
2803     }
2804   } else {
2805     if (!EmitZeroValue(out, var->Type()->UnwrapRef())) {
2806       return false;
2807     }
2808   }
2809 
2810   out << ";";
2811   return true;
2812 }
2813 
EmitWorkgroupVariable(const sem::Variable * var)2814 bool GeneratorImpl::EmitWorkgroupVariable(const sem::Variable* var) {
2815   auto* decl = var->Declaration();
2816   auto out = line();
2817 
2818   out << "groupshared ";
2819 
2820   auto name = builder_.Symbols().NameFor(decl->symbol);
2821   auto* type = var->Type()->UnwrapRef();
2822   if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) {
2823     return false;
2824   }
2825 
2826   if (auto* constructor = decl->constructor) {
2827     out << " = ";
2828     if (!EmitExpression(out, constructor)) {
2829       return false;
2830     }
2831   }
2832 
2833   out << ";";
2834   return true;
2835 }
2836 
builtin_to_attribute(ast::Builtin builtin) const2837 std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const {
2838   switch (builtin) {
2839     case ast::Builtin::kPosition:
2840       return "SV_Position";
2841     case ast::Builtin::kVertexIndex:
2842       return "SV_VertexID";
2843     case ast::Builtin::kInstanceIndex:
2844       return "SV_InstanceID";
2845     case ast::Builtin::kFrontFacing:
2846       return "SV_IsFrontFace";
2847     case ast::Builtin::kFragDepth:
2848       return "SV_Depth";
2849     case ast::Builtin::kLocalInvocationId:
2850       return "SV_GroupThreadID";
2851     case ast::Builtin::kLocalInvocationIndex:
2852       return "SV_GroupIndex";
2853     case ast::Builtin::kGlobalInvocationId:
2854       return "SV_DispatchThreadID";
2855     case ast::Builtin::kWorkgroupId:
2856       return "SV_GroupID";
2857     case ast::Builtin::kSampleIndex:
2858       return "SV_SampleIndex";
2859     case ast::Builtin::kSampleMask:
2860       return "SV_Coverage";
2861     default:
2862       break;
2863   }
2864   return "";
2865 }
2866 
interpolation_to_modifiers(ast::InterpolationType type,ast::InterpolationSampling sampling) const2867 std::string GeneratorImpl::interpolation_to_modifiers(
2868     ast::InterpolationType type,
2869     ast::InterpolationSampling sampling) const {
2870   std::string modifiers;
2871   switch (type) {
2872     case ast::InterpolationType::kPerspective:
2873       modifiers += "linear ";
2874       break;
2875     case ast::InterpolationType::kLinear:
2876       modifiers += "noperspective ";
2877       break;
2878     case ast::InterpolationType::kFlat:
2879       modifiers += "nointerpolation ";
2880       break;
2881   }
2882   switch (sampling) {
2883     case ast::InterpolationSampling::kCentroid:
2884       modifiers += "centroid ";
2885       break;
2886     case ast::InterpolationSampling::kSample:
2887       modifiers += "sample ";
2888       break;
2889     case ast::InterpolationSampling::kCenter:
2890     case ast::InterpolationSampling::kNone:
2891       break;
2892   }
2893   return modifiers;
2894 }
2895 
EmitEntryPointFunction(const ast::Function * func)2896 bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
2897   auto* func_sem = builder_.Sem().Get(func);
2898 
2899   {
2900     auto out = line();
2901     if (func->PipelineStage() == ast::PipelineStage::kCompute) {
2902       // Emit the workgroup_size attribute.
2903       auto wgsize = func_sem->WorkgroupSize();
2904       out << "[numthreads(";
2905       for (int i = 0; i < 3; i++) {
2906         if (i > 0) {
2907           out << ", ";
2908         }
2909 
2910         if (wgsize[i].overridable_const) {
2911           auto* global = builder_.Sem().Get<sem::GlobalVariable>(
2912               wgsize[i].overridable_const);
2913           if (!global->IsOverridable()) {
2914             TINT_ICE(Writer, builder_.Diagnostics())
2915                 << "expected a pipeline-overridable constant";
2916           }
2917           out << kSpecConstantPrefix << global->ConstantId();
2918         } else {
2919           out << std::to_string(wgsize[i].value);
2920         }
2921       }
2922       out << ")]" << std::endl;
2923     }
2924 
2925     out << func->return_type->FriendlyName(builder_.Symbols());
2926 
2927     out << " " << builder_.Symbols().NameFor(func->symbol) << "(";
2928 
2929     bool first = true;
2930 
2931     // Emit entry point parameters.
2932     for (auto* var : func->params) {
2933       auto* sem = builder_.Sem().Get(var);
2934       auto* type = sem->Type();
2935       if (!type->Is<sem::Struct>()) {
2936         // ICE likely indicates that the CanonicalizeEntryPointIO transform was
2937         // not run, or a builtin parameter was added after it was run.
2938         TINT_ICE(Writer, diagnostics_)
2939             << "Unsupported non-struct entry point parameter";
2940       }
2941 
2942       if (!first) {
2943         out << ", ";
2944       }
2945       first = false;
2946 
2947       if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
2948                            builder_.Symbols().NameFor(var->symbol))) {
2949         return false;
2950       }
2951     }
2952 
2953     out << ") {";
2954   }
2955 
2956   {
2957     ScopedIndent si(this);
2958 
2959     if (!EmitStatements(func->body->statements)) {
2960       return false;
2961     }
2962 
2963     if (!Is<ast::ReturnStatement>(func->body->Last())) {
2964       ast::ReturnStatement ret(ProgramID(), Source{});
2965       if (!EmitStatement(&ret)) {
2966         return false;
2967       }
2968     }
2969   }
2970 
2971   line() << "}";
2972 
2973   return true;
2974 }
2975 
EmitLiteral(std::ostream & out,const ast::LiteralExpression * lit)2976 bool GeneratorImpl::EmitLiteral(std::ostream& out,
2977                                 const ast::LiteralExpression* lit) {
2978   if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
2979     out << (l->value ? "true" : "false");
2980   } else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
2981     if (std::isinf(fl->value)) {
2982       out << (fl->value >= 0 ? "asfloat(0x7f800000u)" : "asfloat(0xff800000u)");
2983     } else if (std::isnan(fl->value)) {
2984       out << "asfloat(0x7fc00000u)";
2985     } else {
2986       out << FloatToString(fl->value) << "f";
2987     }
2988   } else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
2989     out << sl->value;
2990   } else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
2991     out << ul->value << "u";
2992   } else {
2993     diagnostics_.add_error(diag::System::Writer, "unknown literal type");
2994     return false;
2995   }
2996   return true;
2997 }
2998 
EmitZeroValue(std::ostream & out,const sem::Type * type)2999 bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
3000   if (type->Is<sem::Bool>()) {
3001     out << "false";
3002   } else if (type->Is<sem::F32>()) {
3003     out << "0.0f";
3004   } else if (type->Is<sem::I32>()) {
3005     out << "0";
3006   } else if (type->Is<sem::U32>()) {
3007     out << "0u";
3008   } else if (auto* vec = type->As<sem::Vector>()) {
3009     if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
3010                   "")) {
3011       return false;
3012     }
3013     ScopedParen sp(out);
3014     for (uint32_t i = 0; i < vec->Width(); i++) {
3015       if (i != 0) {
3016         out << ", ";
3017       }
3018       if (!EmitZeroValue(out, vec->type())) {
3019         return false;
3020       }
3021     }
3022   } else if (auto* mat = type->As<sem::Matrix>()) {
3023     if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
3024                   "")) {
3025       return false;
3026     }
3027     ScopedParen sp(out);
3028     for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) {
3029       if (i != 0) {
3030         out << ", ";
3031       }
3032       if (!EmitZeroValue(out, mat->type())) {
3033         return false;
3034       }
3035     }
3036   } else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
3037     out << "(";
3038     if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
3039                   "")) {
3040       return false;
3041     }
3042     out << ")0";
3043   } else {
3044     diagnostics_.add_error(
3045         diag::System::Writer,
3046         "Invalid type for zero emission: " + type->type_name());
3047     return false;
3048   }
3049   return true;
3050 }
3051 
EmitLoop(const ast::LoopStatement * stmt)3052 bool GeneratorImpl::EmitLoop(const ast::LoopStatement* stmt) {
3053   auto emit_continuing = [this, stmt]() {
3054     if (stmt->continuing && !stmt->continuing->Empty()) {
3055       if (!EmitBlock(stmt->continuing)) {
3056         return false;
3057       }
3058     }
3059     return true;
3060   };
3061 
3062   TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
3063   line() << LoopAttribute() << "while (true) {";
3064   {
3065     ScopedIndent si(this);
3066     if (!EmitStatements(stmt->body->statements)) {
3067       return false;
3068     }
3069     if (!emit_continuing()) {
3070       return false;
3071     }
3072   }
3073   line() << "}";
3074 
3075   return true;
3076 }
3077 
EmitForLoop(const ast::ForLoopStatement * stmt)3078 bool GeneratorImpl::EmitForLoop(const ast::ForLoopStatement* stmt) {
3079   // Nest a for loop with a new block. In HLSL the initializer scope is not
3080   // nested by the for-loop, so we may get variable redefinitions.
3081   line() << "{";
3082   increment_indent();
3083   TINT_DEFER({
3084     decrement_indent();
3085     line() << "}";
3086   });
3087 
3088   TextBuffer init_buf;
3089   if (auto* init = stmt->initializer) {
3090     TINT_SCOPED_ASSIGNMENT(current_buffer_, &init_buf);
3091     if (!EmitStatement(init)) {
3092       return false;
3093     }
3094   }
3095 
3096   TextBuffer cond_pre;
3097   std::stringstream cond_buf;
3098   if (auto* cond = stmt->condition) {
3099     TINT_SCOPED_ASSIGNMENT(current_buffer_, &cond_pre);
3100     if (!EmitExpression(cond_buf, cond)) {
3101       return false;
3102     }
3103   }
3104 
3105   TextBuffer cont_buf;
3106   if (auto* cont = stmt->continuing) {
3107     TINT_SCOPED_ASSIGNMENT(current_buffer_, &cont_buf);
3108     if (!EmitStatement(cont)) {
3109       return false;
3110     }
3111   }
3112 
3113   // If the for-loop has a multi-statement conditional and / or continuing, then
3114   // we cannot emit this as a regular for-loop in HLSL. Instead we need to
3115   // generate a `while(true)` loop.
3116   bool emit_as_loop = cond_pre.lines.size() > 0 || cont_buf.lines.size() > 1;
3117 
3118   // If the for-loop has multi-statement initializer, or is going to be emitted
3119   // as a `while(true)` loop, then declare the initializer statement(s) before
3120   // the loop.
3121   if (init_buf.lines.size() > 1 || (stmt->initializer && emit_as_loop)) {
3122     current_buffer_->Append(init_buf);
3123     init_buf.lines.clear();  // Don't emit the initializer again in the 'for'
3124   }
3125 
3126   if (emit_as_loop) {
3127     auto emit_continuing = [&]() {
3128       current_buffer_->Append(cont_buf);
3129       return true;
3130     };
3131 
3132     TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
3133     line() << LoopAttribute() << "while (true) {";
3134     increment_indent();
3135     TINT_DEFER({
3136       decrement_indent();
3137       line() << "}";
3138     });
3139 
3140     if (stmt->condition) {
3141       current_buffer_->Append(cond_pre);
3142       line() << "if (!(" << cond_buf.str() << ")) { break; }";
3143     }
3144 
3145     if (!EmitStatements(stmt->body->statements)) {
3146       return false;
3147     }
3148 
3149     if (!emit_continuing()) {
3150       return false;
3151     }
3152   } else {
3153     // For-loop can be generated.
3154     {
3155       auto out = line();
3156       out << LoopAttribute() << "for";
3157       {
3158         ScopedParen sp(out);
3159 
3160         if (!init_buf.lines.empty()) {
3161           out << init_buf.lines[0].content << " ";
3162         } else {
3163           out << "; ";
3164         }
3165 
3166         out << cond_buf.str() << "; ";
3167 
3168         if (!cont_buf.lines.empty()) {
3169           out << TrimSuffix(cont_buf.lines[0].content, ";");
3170         }
3171       }
3172       out << " {";
3173     }
3174     {
3175       auto emit_continuing = [] { return true; };
3176       TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
3177       if (!EmitStatementsWithIndent(stmt->body->statements)) {
3178         return false;
3179       }
3180     }
3181     line() << "}";
3182   }
3183 
3184   return true;
3185 }
3186 
EmitMemberAccessor(std::ostream & out,const ast::MemberAccessorExpression * expr)3187 bool GeneratorImpl::EmitMemberAccessor(
3188     std::ostream& out,
3189     const ast::MemberAccessorExpression* expr) {
3190   if (!EmitExpression(out, expr->structure)) {
3191     return false;
3192   }
3193   out << ".";
3194 
3195   // Swizzles output the name directly
3196   if (builder_.Sem().Get(expr)->Is<sem::Swizzle>()) {
3197     out << builder_.Symbols().NameFor(expr->member->symbol);
3198   } else if (!EmitExpression(out, expr->member)) {
3199     return false;
3200   }
3201 
3202   return true;
3203 }
3204 
EmitReturn(const ast::ReturnStatement * stmt)3205 bool GeneratorImpl::EmitReturn(const ast::ReturnStatement* stmt) {
3206   if (stmt->value) {
3207     auto out = line();
3208     out << "return ";
3209     if (!EmitExpression(out, stmt->value)) {
3210       return false;
3211     }
3212     out << ";";
3213   } else {
3214     line() << "return;";
3215   }
3216   return true;
3217 }
3218 
EmitStatement(const ast::Statement * stmt)3219 bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
3220   if (auto* a = stmt->As<ast::AssignmentStatement>()) {
3221     return EmitAssign(a);
3222   }
3223   if (auto* b = stmt->As<ast::BlockStatement>()) {
3224     return EmitBlock(b);
3225   }
3226   if (auto* b = stmt->As<ast::BreakStatement>()) {
3227     return EmitBreak(b);
3228   }
3229   if (auto* c = stmt->As<ast::CallStatement>()) {
3230     auto out = line();
3231     if (!EmitCall(out, c->expr)) {
3232       return false;
3233     }
3234     out << ";";
3235     return true;
3236   }
3237   if (auto* c = stmt->As<ast::ContinueStatement>()) {
3238     return EmitContinue(c);
3239   }
3240   if (auto* d = stmt->As<ast::DiscardStatement>()) {
3241     return EmitDiscard(d);
3242   }
3243   if (stmt->As<ast::FallthroughStatement>()) {
3244     line() << "/* fallthrough */";
3245     return true;
3246   }
3247   if (auto* i = stmt->As<ast::IfStatement>()) {
3248     return EmitIf(i);
3249   }
3250   if (auto* l = stmt->As<ast::LoopStatement>()) {
3251     return EmitLoop(l);
3252   }
3253   if (auto* l = stmt->As<ast::ForLoopStatement>()) {
3254     return EmitForLoop(l);
3255   }
3256   if (auto* r = stmt->As<ast::ReturnStatement>()) {
3257     return EmitReturn(r);
3258   }
3259   if (auto* s = stmt->As<ast::SwitchStatement>()) {
3260     return EmitSwitch(s);
3261   }
3262   if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
3263     return EmitVariable(v->variable);
3264   }
3265 
3266   diagnostics_.add_error(
3267       diag::System::Writer,
3268       "unknown statement type: " + std::string(stmt->TypeInfo().name));
3269   return false;
3270 }
3271 
EmitDefaultOnlySwitch(const ast::SwitchStatement * stmt)3272 bool GeneratorImpl::EmitDefaultOnlySwitch(const ast::SwitchStatement* stmt) {
3273   TINT_ASSERT(Writer, stmt->body.size() == 1 && stmt->body[0]->IsDefault());
3274 
3275   // FXC fails to compile a switch with just a default case, ignoring the
3276   // default case body. We work around this here by emitting the default case
3277   // without the switch.
3278 
3279   // Emit the switch condition as-is in case it has side-effects (e.g.
3280   // function call). Note that's it's fine not to assign the result of the
3281   // expression.
3282   {
3283     auto out = line();
3284     if (!EmitExpression(out, stmt->condition)) {
3285       return false;
3286     }
3287     out << ";";
3288   }
3289 
3290   // Emit "do { <default case body> } while(false);". We use a 'do' loop so
3291   // that break statements work as expected, and make it 'while (false)' in
3292   // case there isn't a break statement.
3293   line() << "do {";
3294   {
3295     ScopedIndent si(this);
3296     if (!EmitStatements(stmt->body[0]->body->statements)) {
3297       return false;
3298     }
3299   }
3300   line() << "} while (false);";
3301   return true;
3302 }
3303 
EmitSwitch(const ast::SwitchStatement * stmt)3304 bool GeneratorImpl::EmitSwitch(const ast::SwitchStatement* stmt) {
3305   // BUG(crbug.com/tint/1188): work around default-only switches
3306   if (stmt->body.size() == 1 && stmt->body[0]->IsDefault()) {
3307     return EmitDefaultOnlySwitch(stmt);
3308   }
3309 
3310   {  // switch(expr) {
3311     auto out = line();
3312     out << "switch(";
3313     if (!EmitExpression(out, stmt->condition)) {
3314       return false;
3315     }
3316     out << ") {";
3317   }
3318 
3319   {
3320     ScopedIndent si(this);
3321     for (size_t i = 0; i < stmt->body.size(); i++) {
3322       if (!EmitCase(stmt, i)) {
3323         return false;
3324       }
3325     }
3326   }
3327 
3328   line() << "}";
3329 
3330   return true;
3331 }
3332 
EmitType(std::ostream & out,const sem::Type * type,ast::StorageClass storage_class,ast::Access access,const std::string & name,bool * name_printed)3333 bool GeneratorImpl::EmitType(std::ostream& out,
3334                              const sem::Type* type,
3335                              ast::StorageClass storage_class,
3336                              ast::Access access,
3337                              const std::string& name,
3338                              bool* name_printed /* = nullptr */) {
3339   if (name_printed) {
3340     *name_printed = false;
3341   }
3342   switch (storage_class) {
3343     case ast::StorageClass::kStorage:
3344       if (access != ast::Access::kRead) {
3345         out << "RW";
3346       }
3347       out << "ByteAddressBuffer";
3348       return true;
3349     case ast::StorageClass::kUniform: {
3350       auto* str = type->As<sem::Struct>();
3351       if (!str) {
3352         // https://www.w3.org/TR/WGSL/#module-scope-variables
3353         TINT_ICE(Writer, diagnostics_)
3354             << "variables with uniform storage must be structure";
3355       }
3356       auto array_length = (str->Size() + 15) / 16;
3357       out << "uint4 " << name << "[" << array_length << "]";
3358       if (name_printed) {
3359         *name_printed = true;
3360       }
3361       return true;
3362     }
3363     default:
3364       break;
3365   }
3366 
3367   if (auto* ary = type->As<sem::Array>()) {
3368     const sem::Type* base_type = ary;
3369     std::vector<uint32_t> sizes;
3370     while (auto* arr = base_type->As<sem::Array>()) {
3371       if (arr->IsRuntimeSized()) {
3372         TINT_ICE(Writer, diagnostics_)
3373             << "Runtime arrays may only exist in storage buffers, which should "
3374                "have been transformed into a ByteAddressBuffer";
3375         return false;
3376       }
3377       sizes.push_back(arr->Count());
3378       base_type = arr->ElemType();
3379     }
3380     if (!EmitType(out, base_type, storage_class, access, "")) {
3381       return false;
3382     }
3383     if (!name.empty()) {
3384       out << " " << name;
3385       if (name_printed) {
3386         *name_printed = true;
3387       }
3388     }
3389     for (uint32_t size : sizes) {
3390       out << "[" << size << "]";
3391     }
3392   } else if (type->Is<sem::Bool>()) {
3393     out << "bool";
3394   } else if (type->Is<sem::F32>()) {
3395     out << "float";
3396   } else if (type->Is<sem::I32>()) {
3397     out << "int";
3398   } else if (auto* mat = type->As<sem::Matrix>()) {
3399     if (!EmitType(out, mat->type(), storage_class, access, "")) {
3400       return false;
3401     }
3402     // Note: HLSL's matrices are declared as <type>NxM, where N is the number of
3403     // rows and M is the number of columns. Despite HLSL's matrices being
3404     // column-major by default, the index operator and constructors actually
3405     // operate on row-vectors, where as WGSL operates on column vectors.
3406     // To simplify everything we use the transpose of the matrices.
3407     // See:
3408     // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-per-component-math#matrix-ordering
3409     out << mat->columns() << "x" << mat->rows();
3410   } else if (type->Is<sem::Pointer>()) {
3411     TINT_ICE(Writer, diagnostics_)
3412         << "Attempting to emit pointer type. These should have been removed "
3413            "with the InlinePointerLets transform";
3414     return false;
3415   } else if (auto* sampler = type->As<sem::Sampler>()) {
3416     out << "Sampler";
3417     if (sampler->IsComparison()) {
3418       out << "Comparison";
3419     }
3420     out << "State";
3421   } else if (auto* str = type->As<sem::Struct>()) {
3422     out << StructName(str);
3423   } else if (auto* tex = type->As<sem::Texture>()) {
3424     auto* storage = tex->As<sem::StorageTexture>();
3425     auto* ms = tex->As<sem::MultisampledTexture>();
3426     auto* depth_ms = tex->As<sem::DepthMultisampledTexture>();
3427     auto* sampled = tex->As<sem::SampledTexture>();
3428 
3429     if (storage && storage->access() != ast::Access::kRead) {
3430       out << "RW";
3431     }
3432     out << "Texture";
3433 
3434     switch (tex->dim()) {
3435       case ast::TextureDimension::k1d:
3436         out << "1D";
3437         break;
3438       case ast::TextureDimension::k2d:
3439         out << ((ms || depth_ms) ? "2DMS" : "2D");
3440         break;
3441       case ast::TextureDimension::k2dArray:
3442         out << ((ms || depth_ms) ? "2DMSArray" : "2DArray");
3443         break;
3444       case ast::TextureDimension::k3d:
3445         out << "3D";
3446         break;
3447       case ast::TextureDimension::kCube:
3448         out << "Cube";
3449         break;
3450       case ast::TextureDimension::kCubeArray:
3451         out << "CubeArray";
3452         break;
3453       default:
3454         TINT_UNREACHABLE(Writer, diagnostics_)
3455             << "unexpected TextureDimension " << tex->dim();
3456         return false;
3457     }
3458 
3459     if (storage) {
3460       auto* component = image_format_to_rwtexture_type(storage->image_format());
3461       if (component == nullptr) {
3462         TINT_ICE(Writer, diagnostics_)
3463             << "Unsupported StorageTexture ImageFormat: "
3464             << static_cast<int>(storage->image_format());
3465         return false;
3466       }
3467       out << "<" << component << ">";
3468     } else if (depth_ms) {
3469       out << "<float4>";
3470     } else if (sampled || ms) {
3471       auto* subtype = sampled ? sampled->type() : ms->type();
3472       out << "<";
3473       if (subtype->Is<sem::F32>()) {
3474         out << "float4";
3475       } else if (subtype->Is<sem::I32>()) {
3476         out << "int4";
3477       } else if (subtype->Is<sem::U32>()) {
3478         out << "uint4";
3479       } else {
3480         TINT_ICE(Writer, diagnostics_)
3481             << "Unsupported multisampled texture type";
3482         return false;
3483       }
3484       out << ">";
3485     }
3486   } else if (type->Is<sem::U32>()) {
3487     out << "uint";
3488   } else if (auto* vec = type->As<sem::Vector>()) {
3489     auto width = vec->Width();
3490     if (vec->type()->Is<sem::F32>() && width >= 1 && width <= 4) {
3491       out << "float" << width;
3492     } else if (vec->type()->Is<sem::I32>() && width >= 1 && width <= 4) {
3493       out << "int" << width;
3494     } else if (vec->type()->Is<sem::U32>() && width >= 1 && width <= 4) {
3495       out << "uint" << width;
3496     } else if (vec->type()->Is<sem::Bool>() && width >= 1 && width <= 4) {
3497       out << "bool" << width;
3498     } else {
3499       out << "vector<";
3500       if (!EmitType(out, vec->type(), storage_class, access, "")) {
3501         return false;
3502       }
3503       out << ", " << width << ">";
3504     }
3505   } else if (auto* atomic = type->As<sem::Atomic>()) {
3506     if (!EmitType(out, atomic->Type(), storage_class, access, name)) {
3507       return false;
3508     }
3509   } else if (type->Is<sem::Void>()) {
3510     out << "void";
3511   } else {
3512     diagnostics_.add_error(diag::System::Writer, "unknown type in EmitType");
3513     return false;
3514   }
3515 
3516   return true;
3517 }
3518 
EmitTypeAndName(std::ostream & out,const sem::Type * type,ast::StorageClass storage_class,ast::Access access,const std::string & name)3519 bool GeneratorImpl::EmitTypeAndName(std::ostream& out,
3520                                     const sem::Type* type,
3521                                     ast::StorageClass storage_class,
3522                                     ast::Access access,
3523                                     const std::string& name) {
3524   bool name_printed = false;
3525   if (!EmitType(out, type, storage_class, access, name, &name_printed)) {
3526     return false;
3527   }
3528   if (!name.empty() && !name_printed) {
3529     out << " " << name;
3530   }
3531   return true;
3532 }
3533 
EmitStructType(TextBuffer * b,const sem::Struct * str)3534 bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) {
3535   line(b) << "struct " << StructName(str) << " {";
3536   {
3537     ScopedIndent si(b);
3538     for (auto* mem : str->Members()) {
3539       auto name = builder_.Symbols().NameFor(mem->Name());
3540 
3541       auto* ty = mem->Type();
3542 
3543       auto out = line(b);
3544 
3545       std::string pre, post;
3546 
3547       if (auto* decl = mem->Declaration()) {
3548         for (auto* deco : decl->decorations) {
3549           if (auto* location = deco->As<ast::LocationDecoration>()) {
3550             auto& pipeline_stage_uses = str->PipelineStageUses();
3551             if (pipeline_stage_uses.size() != 1) {
3552               TINT_ICE(Writer, diagnostics_)
3553                   << "invalid entry point IO struct uses";
3554             }
3555 
3556             if (pipeline_stage_uses.count(
3557                     sem::PipelineStageUsage::kVertexInput)) {
3558               post += " : TEXCOORD" + std::to_string(location->value);
3559             } else if (pipeline_stage_uses.count(
3560                            sem::PipelineStageUsage::kVertexOutput)) {
3561               post += " : TEXCOORD" + std::to_string(location->value);
3562             } else if (pipeline_stage_uses.count(
3563                            sem::PipelineStageUsage::kFragmentInput)) {
3564               post += " : TEXCOORD" + std::to_string(location->value);
3565             } else if (pipeline_stage_uses.count(
3566                            sem::PipelineStageUsage::kFragmentOutput)) {
3567               post += " : SV_Target" + std::to_string(location->value);
3568             } else {
3569               TINT_ICE(Writer, diagnostics_)
3570                   << "invalid use of location decoration";
3571             }
3572           } else if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
3573             auto attr = builtin_to_attribute(builtin->builtin);
3574             if (attr.empty()) {
3575               diagnostics_.add_error(diag::System::Writer,
3576                                      "unsupported builtin");
3577               return false;
3578             }
3579             post += " : " + attr;
3580           } else if (auto* interpolate =
3581                          deco->As<ast::InterpolateDecoration>()) {
3582             auto mod = interpolation_to_modifiers(interpolate->type,
3583                                                   interpolate->sampling);
3584             if (mod.empty()) {
3585               diagnostics_.add_error(diag::System::Writer,
3586                                      "unsupported interpolation");
3587               return false;
3588             }
3589             pre += mod;
3590 
3591           } else if (deco->Is<ast::InvariantDecoration>()) {
3592             // Note: `precise` is not exactly the same as `invariant`, but is
3593             // stricter and therefore provides the necessary guarantees.
3594             // See discussion here: https://github.com/gpuweb/gpuweb/issues/893
3595             pre += "precise ";
3596           } else if (!deco->IsAnyOf<ast::StructMemberAlignDecoration,
3597                                     ast::StructMemberOffsetDecoration,
3598                                     ast::StructMemberSizeDecoration>()) {
3599             TINT_ICE(Writer, diagnostics_)
3600                 << "unhandled struct member attribute: " << deco->Name();
3601             return false;
3602           }
3603         }
3604       }
3605 
3606       out << pre;
3607       if (!EmitTypeAndName(out, ty, ast::StorageClass::kNone,
3608                            ast::Access::kReadWrite, name)) {
3609         return false;
3610       }
3611       out << post << ";";
3612     }
3613   }
3614 
3615   line(b) << "};";
3616 
3617   return true;
3618 }
3619 
EmitUnaryOp(std::ostream & out,const ast::UnaryOpExpression * expr)3620 bool GeneratorImpl::EmitUnaryOp(std::ostream& out,
3621                                 const ast::UnaryOpExpression* expr) {
3622   switch (expr->op) {
3623     case ast::UnaryOp::kIndirection:
3624     case ast::UnaryOp::kAddressOf:
3625       return EmitExpression(out, expr->expr);
3626     case ast::UnaryOp::kComplement:
3627       out << "~";
3628       break;
3629     case ast::UnaryOp::kNot:
3630       out << "!";
3631       break;
3632     case ast::UnaryOp::kNegation:
3633       out << "-";
3634       break;
3635   }
3636   out << "(";
3637 
3638   if (!EmitExpression(out, expr->expr)) {
3639     return false;
3640   }
3641 
3642   out << ")";
3643 
3644   return true;
3645 }
3646 
EmitVariable(const ast::Variable * var)3647 bool GeneratorImpl::EmitVariable(const ast::Variable* var) {
3648   auto* sem = builder_.Sem().Get(var);
3649   auto* type = sem->Type()->UnwrapRef();
3650 
3651   // TODO(dsinclair): Handle variable decorations
3652   if (!var->decorations.empty()) {
3653     diagnostics_.add_error(diag::System::Writer,
3654                            "Variable decorations are not handled yet");
3655     return false;
3656   }
3657 
3658   auto out = line();
3659   if (var->is_const) {
3660     out << "const ";
3661   }
3662   if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
3663                        builder_.Symbols().NameFor(var->symbol))) {
3664     return false;
3665   }
3666 
3667   out << " = ";
3668 
3669   if (var->constructor) {
3670     if (!EmitExpression(out, var->constructor)) {
3671       return false;
3672     }
3673   } else {
3674     if (!EmitZeroValue(out, type)) {
3675       return false;
3676     }
3677   }
3678   out << ";";
3679 
3680   return true;
3681 }
3682 
EmitProgramConstVariable(const ast::Variable * var)3683 bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
3684   for (auto* d : var->decorations) {
3685     if (!d->Is<ast::OverrideDecoration>()) {
3686       diagnostics_.add_error(diag::System::Writer,
3687                              "Decorated const values not valid");
3688       return false;
3689     }
3690   }
3691   if (!var->is_const) {
3692     diagnostics_.add_error(diag::System::Writer, "Expected a const value");
3693     return false;
3694   }
3695 
3696   auto* sem = builder_.Sem().Get(var);
3697   auto* type = sem->Type();
3698 
3699   auto* global = sem->As<sem::GlobalVariable>();
3700   if (global && global->IsOverridable()) {
3701     auto const_id = global->ConstantId();
3702 
3703     line() << "#ifndef " << kSpecConstantPrefix << const_id;
3704 
3705     if (var->constructor != nullptr) {
3706       auto out = line();
3707       out << "#define " << kSpecConstantPrefix << const_id << " ";
3708       if (!EmitExpression(out, var->constructor)) {
3709         return false;
3710       }
3711     } else {
3712       line() << "#error spec constant required for constant id " << const_id;
3713     }
3714     line() << "#endif";
3715     {
3716       auto out = line();
3717       out << "static const ";
3718       if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
3719                            builder_.Symbols().NameFor(var->symbol))) {
3720         return false;
3721       }
3722       out << " = " << kSpecConstantPrefix << const_id << ";";
3723     }
3724   } else {
3725     auto out = line();
3726     out << "static const ";
3727     if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
3728                          builder_.Symbols().NameFor(var->symbol))) {
3729       return false;
3730     }
3731     out << " = ";
3732     if (!EmitExpression(out, var->constructor)) {
3733       return false;
3734     }
3735     out << ";";
3736   }
3737 
3738   return true;
3739 }
3740 
3741 template <typename F>
CallIntrinsicHelper(std::ostream & out,const ast::CallExpression * call,const sem::Intrinsic * intrinsic,F && build)3742 bool GeneratorImpl::CallIntrinsicHelper(std::ostream& out,
3743                                         const ast::CallExpression* call,
3744                                         const sem::Intrinsic* intrinsic,
3745                                         F&& build) {
3746   // Generate the helper function if it hasn't been created already
3747   auto fn = utils::GetOrCreate(intrinsics_, intrinsic, [&]() -> std::string {
3748     TextBuffer b;
3749     TINT_DEFER(helpers_.Append(b));
3750 
3751     auto fn_name =
3752         UniqueIdentifier(std::string("tint_") + sem::str(intrinsic->Type()));
3753     std::vector<std::string> parameter_names;
3754     {
3755       auto decl = line(&b);
3756       if (!EmitTypeAndName(decl, intrinsic->ReturnType(),
3757                            ast::StorageClass::kNone, ast::Access::kUndefined,
3758                            fn_name)) {
3759         return "";
3760       }
3761       {
3762         ScopedParen sp(decl);
3763         for (auto* param : intrinsic->Parameters()) {
3764           if (!parameter_names.empty()) {
3765             decl << ", ";
3766           }
3767           auto param_name = "param_" + std::to_string(parameter_names.size());
3768           const auto* ty = param->Type();
3769           if (auto* ptr = ty->As<sem::Pointer>()) {
3770             decl << "inout ";
3771             ty = ptr->StoreType();
3772           }
3773           if (!EmitTypeAndName(decl, ty, ast::StorageClass::kNone,
3774                                ast::Access::kUndefined, param_name)) {
3775             return "";
3776           }
3777           parameter_names.emplace_back(std::move(param_name));
3778         }
3779       }
3780       decl << " {";
3781     }
3782     {
3783       ScopedIndent si(&b);
3784       if (!build(&b, parameter_names)) {
3785         return "";
3786       }
3787     }
3788     line(&b) << "}";
3789     line(&b);
3790     return fn_name;
3791   });
3792 
3793   if (fn.empty()) {
3794     return false;
3795   }
3796 
3797   // Call the helper
3798   out << fn;
3799   {
3800     ScopedParen sp(out);
3801     bool first = true;
3802     for (auto* arg : call->args) {
3803       if (!first) {
3804         out << ", ";
3805       }
3806       first = false;
3807       if (!EmitExpression(out, arg)) {
3808         return false;
3809       }
3810     }
3811   }
3812   return true;
3813 }
3814 
3815 }  // namespace hlsl
3816 }  // namespace writer
3817 }  // namespace tint
3818