1 /// Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/writer/hlsl/generator_impl.h"
16
17 #include <algorithm>
18 #include <cmath>
19 #include <iomanip>
20 #include <set>
21 #include <utility>
22 #include <vector>
23
24 #include "src/ast/call_statement.h"
25 #include "src/ast/fallthrough_statement.h"
26 #include "src/ast/internal_decoration.h"
27 #include "src/ast/interpolate_decoration.h"
28 #include "src/ast/override_decoration.h"
29 #include "src/ast/variable_decl_statement.h"
30 #include "src/debug.h"
31 #include "src/sem/array.h"
32 #include "src/sem/atomic_type.h"
33 #include "src/sem/block_statement.h"
34 #include "src/sem/call.h"
35 #include "src/sem/depth_multisampled_texture_type.h"
36 #include "src/sem/depth_texture_type.h"
37 #include "src/sem/function.h"
38 #include "src/sem/member_accessor_expression.h"
39 #include "src/sem/multisampled_texture_type.h"
40 #include "src/sem/sampled_texture_type.h"
41 #include "src/sem/statement.h"
42 #include "src/sem/storage_texture_type.h"
43 #include "src/sem/struct.h"
44 #include "src/sem/type_constructor.h"
45 #include "src/sem/type_conversion.h"
46 #include "src/sem/variable.h"
47 #include "src/transform/add_empty_entry_point.h"
48 #include "src/transform/array_length_from_uniform.h"
49 #include "src/transform/calculate_array_length.h"
50 #include "src/transform/canonicalize_entry_point_io.h"
51 #include "src/transform/decompose_memory_access.h"
52 #include "src/transform/external_texture_transform.h"
53 #include "src/transform/fold_trivial_single_use_lets.h"
54 #include "src/transform/loop_to_for_loop.h"
55 #include "src/transform/manager.h"
56 #include "src/transform/num_workgroups_from_uniform.h"
57 #include "src/transform/pad_array_elements.h"
58 #include "src/transform/promote_initializers_to_const_var.h"
59 #include "src/transform/remove_phonies.h"
60 #include "src/transform/simplify_pointers.h"
61 #include "src/transform/unshadow.h"
62 #include "src/transform/zero_init_workgroup_memory.h"
63 #include "src/utils/defer.h"
64 #include "src/utils/map.h"
65 #include "src/utils/scoped_assignment.h"
66 #include "src/writer/append_vector.h"
67 #include "src/writer/float_to_string.h"
68
69 namespace tint {
70 namespace writer {
71 namespace hlsl {
72 namespace {
73
74 const char kTempNamePrefix[] = "tint_tmp";
75 const char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
76
image_format_to_rwtexture_type(ast::ImageFormat image_format)77 const char* image_format_to_rwtexture_type(ast::ImageFormat image_format) {
78 switch (image_format) {
79 case ast::ImageFormat::kRgba8Unorm:
80 case ast::ImageFormat::kRgba8Snorm:
81 case ast::ImageFormat::kRgba16Float:
82 case ast::ImageFormat::kR32Float:
83 case ast::ImageFormat::kRg32Float:
84 case ast::ImageFormat::kRgba32Float:
85 return "float4";
86 case ast::ImageFormat::kRgba8Uint:
87 case ast::ImageFormat::kRgba16Uint:
88 case ast::ImageFormat::kR32Uint:
89 case ast::ImageFormat::kRg32Uint:
90 case ast::ImageFormat::kRgba32Uint:
91 return "uint4";
92 case ast::ImageFormat::kRgba8Sint:
93 case ast::ImageFormat::kRgba16Sint:
94 case ast::ImageFormat::kR32Sint:
95 case ast::ImageFormat::kRg32Sint:
96 case ast::ImageFormat::kRgba32Sint:
97 return "int4";
98 default:
99 return nullptr;
100 }
101 }
102
103 // Helper for writing " : register(RX, spaceY)", where R is the register, X is
104 // the binding point binding value, and Y is the binding point group value.
105 struct RegisterAndSpace {
RegisterAndSpacetint::writer::hlsl::__anon72f09d1b0111::RegisterAndSpace106 RegisterAndSpace(char r, ast::VariableBindingPoint bp)
107 : reg(r), binding_point(bp) {}
108
109 const char reg;
110 ast::VariableBindingPoint const binding_point;
111 };
112
operator <<(std::ostream & s,const RegisterAndSpace & rs)113 std::ostream& operator<<(std::ostream& s, const RegisterAndSpace& rs) {
114 s << " : register(" << rs.reg << rs.binding_point.binding->value << ", space"
115 << rs.binding_point.group->value << ")";
116 return s;
117 }
118
/// @returns the attribute text "[loop] " (note the trailing space).
/// The `[loop]` attribute asks the compiler not to unroll the loop. This works
/// around FXC compilation failures, which occur when FXC tries - and fails -
/// to unroll a loop that contains gradient operations.
/// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-while
const char* LoopAttribute() {
  static constexpr char kLoopAttribute[] = "[loop] ";
  return kLoopAttribute;
}
125
126 } // namespace
127
128 SanitizedResult::SanitizedResult() = default;
129 SanitizedResult::~SanitizedResult() = default;
130 SanitizedResult::SanitizedResult(SanitizedResult&&) = default;
131
Sanitize(const Program * in,sem::BindingPoint root_constant_binding_point,bool disable_workgroup_init,const ArrayLengthFromUniformOptions & array_length_from_uniform)132 SanitizedResult Sanitize(
133 const Program* in,
134 sem::BindingPoint root_constant_binding_point,
135 bool disable_workgroup_init,
136 const ArrayLengthFromUniformOptions& array_length_from_uniform) {
137 transform::Manager manager;
138 transform::DataMap data;
139
140 // Build the config for the internal ArrayLengthFromUniform transform.
141 transform::ArrayLengthFromUniform::Config array_length_from_uniform_cfg(
142 array_length_from_uniform.ubo_binding);
143 array_length_from_uniform_cfg.bindpoint_to_size_index =
144 array_length_from_uniform.bindpoint_to_size_index;
145
146 manager.Add<transform::Unshadow>();
147
148 // Attempt to convert `loop`s into for-loops. This is to try and massage the
149 // output into something that will not cause FXC to choke or misbehave.
150 manager.Add<transform::FoldTrivialSingleUseLets>();
151 manager.Add<transform::LoopToForLoop>();
152
153 if (!disable_workgroup_init) {
154 // ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as
155 // ZeroInitWorkgroupMemory may inject new builtin parameters.
156 manager.Add<transform::ZeroInitWorkgroupMemory>();
157 }
158 manager.Add<transform::CanonicalizeEntryPointIO>();
159 // NumWorkgroupsFromUniform must come after CanonicalizeEntryPointIO, as it
160 // assumes that num_workgroups builtins only appear as struct members and are
161 // only accessed directly via member accessors.
162 manager.Add<transform::NumWorkgroupsFromUniform>();
163 manager.Add<transform::SimplifyPointers>();
164 manager.Add<transform::RemovePhonies>();
165 // ArrayLengthFromUniform must come after InlinePointerLets and Simplify, as
166 // it assumes that the form of the array length argument is &var.array.
167 manager.Add<transform::ArrayLengthFromUniform>();
168 data.Add<transform::ArrayLengthFromUniform::Config>(
169 std::move(array_length_from_uniform_cfg));
170 // DecomposeMemoryAccess must come after:
171 // * InlinePointerLets, as we cannot take the address of calls to
172 // DecomposeMemoryAccess::Intrinsic.
173 // * Simplify, as we need to fold away the address-of and dereferences of
174 // `*(&(intrinsic_load()))` expressions.
175 // * RemovePhonies, as phonies can be assigned a pointer to a
176 // non-constructible buffer, or dynamic array, which DMA cannot cope with.
177 manager.Add<transform::DecomposeMemoryAccess>();
178 // CalculateArrayLength must come after DecomposeMemoryAccess, as
179 // DecomposeMemoryAccess special-cases the arrayLength() intrinsic, which
180 // will be transformed by CalculateArrayLength
181 manager.Add<transform::CalculateArrayLength>();
182 manager.Add<transform::ExternalTextureTransform>();
183 manager.Add<transform::PromoteInitializersToConstVar>();
184 manager.Add<transform::PadArrayElements>();
185 manager.Add<transform::AddEmptyEntryPoint>();
186
187 data.Add<transform::CanonicalizeEntryPointIO::Config>(
188 transform::CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
189 data.Add<transform::NumWorkgroupsFromUniform::Config>(
190 root_constant_binding_point);
191
192 auto out = manager.Run(in, data);
193
194 SanitizedResult result;
195 result.program = std::move(out.program);
196 if (auto* res = out.data.Get<transform::ArrayLengthFromUniform::Result>()) {
197 result.used_array_length_from_uniform_indices =
198 std::move(res->used_size_indices);
199 }
200 return result;
201 }
202
GeneratorImpl(const Program * program)203 GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {}
204
205 GeneratorImpl::~GeneratorImpl() = default;
206
Generate()207 bool GeneratorImpl::Generate() {
208 const TypeInfo* last_kind = nullptr;
209 size_t last_padding_line = 0;
210
211 for (auto* decl : builder_.AST().GlobalDeclarations()) {
212 if (decl->Is<ast::Alias>()) {
213 continue; // Ignore aliases.
214 }
215
216 // Emit a new line between declarations if the type of declaration has
217 // changed, or we're about to emit a function
218 auto* kind = &decl->TypeInfo();
219 if (current_buffer_->lines.size() != last_padding_line) {
220 if (last_kind && (last_kind != kind || decl->Is<ast::Function>())) {
221 line();
222 last_padding_line = current_buffer_->lines.size();
223 }
224 }
225 last_kind = kind;
226
227 if (auto* global = decl->As<ast::Variable>()) {
228 if (!EmitGlobalVariable(global)) {
229 return false;
230 }
231 } else if (auto* str = decl->As<ast::Struct>()) {
232 auto* ty = builder_.Sem().Get(str);
233 auto storage_class_uses = ty->StorageClassUsage();
234 if (storage_class_uses.size() !=
235 (storage_class_uses.count(ast::StorageClass::kStorage) +
236 storage_class_uses.count(ast::StorageClass::kUniform))) {
237 // The structure is used as something other than a storage buffer or
238 // uniform buffer, so it needs to be emitted.
239 // Storage buffer are read and written to via a ByteAddressBuffer
240 // instead of true structure.
241 // Structures used as uniform buffer are read from an array of vectors
242 // instead of true structure.
243 if (!EmitStructType(current_buffer_, ty)) {
244 return false;
245 }
246 }
247 } else if (auto* func = decl->As<ast::Function>()) {
248 if (func->IsEntryPoint()) {
249 if (!EmitEntryPointFunction(func)) {
250 return false;
251 }
252 } else {
253 if (!EmitFunction(func)) {
254 return false;
255 }
256 }
257 } else {
258 TINT_ICE(Writer, diagnostics_)
259 << "unhandled module-scope declaration: " << decl->TypeInfo().name;
260 return false;
261 }
262 }
263
264 if (!helpers_.lines.empty()) {
265 current_buffer_->Insert(helpers_, 0, 0);
266 }
267
268 return true;
269 }
270
EmitDynamicVectorAssignment(const ast::AssignmentStatement * stmt,const sem::Vector * vec)271 bool GeneratorImpl::EmitDynamicVectorAssignment(
272 const ast::AssignmentStatement* stmt,
273 const sem::Vector* vec) {
274 auto name =
275 utils::GetOrCreate(dynamic_vector_write_, vec, [&]() -> std::string {
276 std::string fn;
277 {
278 std::ostringstream ss;
279 if (!EmitType(ss, vec, tint::ast::StorageClass::kInvalid,
280 ast::Access::kUndefined, "")) {
281 return "";
282 }
283 fn = UniqueIdentifier("set_" + ss.str());
284 }
285 {
286 auto out = line(&helpers_);
287 out << "void " << fn << "(inout ";
288 if (!EmitTypeAndName(out, vec, ast::StorageClass::kInvalid,
289 ast::Access::kUndefined, "vec")) {
290 return "";
291 }
292 out << ", int idx, ";
293 if (!EmitTypeAndName(out, vec->type(), ast::StorageClass::kInvalid,
294 ast::Access::kUndefined, "val")) {
295 return "";
296 }
297 out << ") {";
298 }
299 {
300 ScopedIndent si(&helpers_);
301 auto out = line(&helpers_);
302 switch (vec->Width()) {
303 case 2:
304 out << "vec = (idx.xx == int2(0, 1)) ? val.xx : vec;";
305 break;
306 case 3:
307 out << "vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;";
308 break;
309 case 4:
310 out << "vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;";
311 break;
312 default:
313 TINT_UNREACHABLE(Writer, builder_.Diagnostics())
314 << "invalid vector size " << vec->Width();
315 break;
316 }
317 }
318 line(&helpers_) << "}";
319 line(&helpers_);
320 return fn;
321 });
322
323 if (name.empty()) {
324 return false;
325 }
326
327 auto* ast_access_expr = stmt->lhs->As<ast::IndexAccessorExpression>();
328
329 auto out = line();
330 out << name << "(";
331 if (!EmitExpression(out, ast_access_expr->object)) {
332 return false;
333 }
334 out << ", ";
335 if (!EmitExpression(out, ast_access_expr->index)) {
336 return false;
337 }
338 out << ", ";
339 if (!EmitExpression(out, stmt->rhs)) {
340 return false;
341 }
342 out << ");";
343
344 return true;
345 }
346
EmitDynamicMatrixVectorAssignment(const ast::AssignmentStatement * stmt,const sem::Matrix * mat)347 bool GeneratorImpl::EmitDynamicMatrixVectorAssignment(
348 const ast::AssignmentStatement* stmt,
349 const sem::Matrix* mat) {
350 auto name = utils::GetOrCreate(
351 dynamic_matrix_vector_write_, mat, [&]() -> std::string {
352 std::string fn;
353 {
354 std::ostringstream ss;
355 if (!EmitType(ss, mat, tint::ast::StorageClass::kInvalid,
356 ast::Access::kUndefined, "")) {
357 return "";
358 }
359 fn = UniqueIdentifier("set_vector_" + ss.str());
360 }
361 {
362 auto out = line(&helpers_);
363 out << "void " << fn << "(inout ";
364 if (!EmitTypeAndName(out, mat, ast::StorageClass::kInvalid,
365 ast::Access::kUndefined, "mat")) {
366 return "";
367 }
368 out << ", int col, ";
369 if (!EmitTypeAndName(out, mat->ColumnType(),
370 ast::StorageClass::kInvalid,
371 ast::Access::kUndefined, "val")) {
372 return "";
373 }
374 out << ") {";
375 }
376 {
377 ScopedIndent si(&helpers_);
378 line(&helpers_) << "switch (col) {";
379 {
380 ScopedIndent si2(&helpers_);
381 for (uint32_t i = 0; i < mat->columns(); ++i) {
382 line(&helpers_)
383 << "case " << i << ": mat[" << i << "] = val; break;";
384 }
385 }
386 line(&helpers_) << "}";
387 }
388 line(&helpers_) << "}";
389 line(&helpers_);
390 return fn;
391 });
392
393 if (name.empty()) {
394 return false;
395 }
396
397 auto* ast_access_expr = stmt->lhs->As<ast::IndexAccessorExpression>();
398
399 auto out = line();
400 out << name << "(";
401 if (!EmitExpression(out, ast_access_expr->object)) {
402 return false;
403 }
404 out << ", ";
405 if (!EmitExpression(out, ast_access_expr->index)) {
406 return false;
407 }
408 out << ", ";
409 if (!EmitExpression(out, stmt->rhs)) {
410 return false;
411 }
412 out << ");";
413
414 return true;
415 }
416
EmitDynamicMatrixScalarAssignment(const ast::AssignmentStatement * stmt,const sem::Matrix * mat)417 bool GeneratorImpl::EmitDynamicMatrixScalarAssignment(
418 const ast::AssignmentStatement* stmt,
419 const sem::Matrix* mat) {
420 auto* lhs_col_access = stmt->lhs->As<ast::IndexAccessorExpression>();
421 auto* lhs_row_access =
422 lhs_col_access->object->As<ast::IndexAccessorExpression>();
423
424 auto name = utils::GetOrCreate(
425 dynamic_matrix_scalar_write_, mat, [&]() -> std::string {
426 std::string fn;
427 {
428 std::ostringstream ss;
429 if (!EmitType(ss, mat, tint::ast::StorageClass::kInvalid,
430 ast::Access::kUndefined, "")) {
431 return "";
432 }
433 fn = UniqueIdentifier("set_scalar_" + ss.str());
434 }
435 {
436 auto out = line(&helpers_);
437 out << "void " << fn << "(inout ";
438 if (!EmitTypeAndName(out, mat, ast::StorageClass::kInvalid,
439 ast::Access::kUndefined, "mat")) {
440 return "";
441 }
442 out << ", int col, int row, ";
443 if (!EmitTypeAndName(out, mat->type(), ast::StorageClass::kInvalid,
444 ast::Access::kUndefined, "val")) {
445 return "";
446 }
447 out << ") {";
448 }
449 {
450 ScopedIndent si(&helpers_);
451 line(&helpers_) << "switch (col) {";
452 {
453 ScopedIndent si2(&helpers_);
454 auto* vec =
455 TypeOf(lhs_row_access->object)->UnwrapRef()->As<sem::Vector>();
456 for (uint32_t i = 0; i < mat->columns(); ++i) {
457 line(&helpers_) << "case " << i << ":";
458 {
459 auto vec_name = "mat[" + std::to_string(i) + "]";
460 ScopedIndent si3(&helpers_);
461 {
462 auto out = line(&helpers_);
463 switch (mat->rows()) {
464 case 2:
465 out << vec_name
466 << " = (row.xx == int2(0, 1)) ? val.xx : " << vec_name
467 << ";";
468 break;
469 case 3:
470 out << vec_name
471 << " = (row.xxx == int3(0, 1, 2)) ? val.xxx : "
472 << vec_name << ";";
473 break;
474 case 4:
475 out << vec_name
476 << " = (row.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : "
477 << vec_name << ";";
478 break;
479 default:
480 TINT_UNREACHABLE(Writer, builder_.Diagnostics())
481 << "invalid vector size " << vec->Width();
482 break;
483 }
484 }
485 line(&helpers_) << "break;";
486 }
487 }
488 }
489 line(&helpers_) << "}";
490 }
491 line(&helpers_) << "}";
492 line(&helpers_);
493 return fn;
494 });
495
496 if (name.empty()) {
497 return false;
498 }
499
500 auto out = line();
501 out << name << "(";
502 if (!EmitExpression(out, lhs_row_access->object)) {
503 return false;
504 }
505 out << ", ";
506 if (!EmitExpression(out, lhs_col_access->index)) {
507 return false;
508 }
509 out << ", ";
510 if (!EmitExpression(out, lhs_row_access->index)) {
511 return false;
512 }
513 out << ", ";
514 if (!EmitExpression(out, stmt->rhs)) {
515 return false;
516 }
517 out << ");";
518
519 return true;
520 }
521
EmitIndexAccessor(std::ostream & out,const ast::IndexAccessorExpression * expr)522 bool GeneratorImpl::EmitIndexAccessor(
523 std::ostream& out,
524 const ast::IndexAccessorExpression* expr) {
525 if (!EmitExpression(out, expr->object)) {
526 return false;
527 }
528 out << "[";
529
530 if (!EmitExpression(out, expr->index)) {
531 return false;
532 }
533 out << "]";
534
535 return true;
536 }
537
EmitBitcast(std::ostream & out,const ast::BitcastExpression * expr)538 bool GeneratorImpl::EmitBitcast(std::ostream& out,
539 const ast::BitcastExpression* expr) {
540 auto* type = TypeOf(expr);
541 if (auto* vec = type->UnwrapRef()->As<sem::Vector>()) {
542 type = vec->type();
543 }
544
545 if (!type->is_integer_scalar() && !type->is_float_scalar()) {
546 diagnostics_.add_error(diag::System::Writer,
547 "Unable to do bitcast to type " + type->type_name());
548 return false;
549 }
550
551 out << "as";
552 if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
553 "")) {
554 return false;
555 }
556 out << "(";
557 if (!EmitExpression(out, expr->expr)) {
558 return false;
559 }
560 out << ")";
561 return true;
562 }
563
EmitAssign(const ast::AssignmentStatement * stmt)564 bool GeneratorImpl::EmitAssign(const ast::AssignmentStatement* stmt) {
565 if (auto* lhs_access = stmt->lhs->As<ast::IndexAccessorExpression>()) {
566 // BUG(crbug.com/tint/1333): work around assignment of scalar to matrices
567 // with at least one dynamic index
568 if (auto* lhs_sub_access =
569 lhs_access->object->As<ast::IndexAccessorExpression>()) {
570 if (auto* mat =
571 TypeOf(lhs_sub_access->object)->UnwrapRef()->As<sem::Matrix>()) {
572 auto* rhs_col_idx_sem = builder_.Sem().Get(lhs_access->index);
573 auto* rhs_row_idx_sem = builder_.Sem().Get(lhs_sub_access->index);
574 if (!rhs_col_idx_sem->ConstantValue().IsValid() ||
575 !rhs_row_idx_sem->ConstantValue().IsValid()) {
576 return EmitDynamicMatrixScalarAssignment(stmt, mat);
577 }
578 }
579 }
580 // BUG(crbug.com/tint/1333): work around assignment of vector to matrices
581 // with dynamic indices
582 const auto* lhs_access_type = TypeOf(lhs_access->object)->UnwrapRef();
583 if (auto* mat = lhs_access_type->As<sem::Matrix>()) {
584 auto* lhs_index_sem = builder_.Sem().Get(lhs_access->index);
585 if (!lhs_index_sem->ConstantValue().IsValid()) {
586 return EmitDynamicMatrixVectorAssignment(stmt, mat);
587 }
588 }
589 // BUG(crbug.com/tint/534): work around assignment to vectors with dynamic
590 // indices
591 if (auto* vec = lhs_access_type->As<sem::Vector>()) {
592 auto* rhs_sem = builder_.Sem().Get(lhs_access->index);
593 if (!rhs_sem->ConstantValue().IsValid()) {
594 return EmitDynamicVectorAssignment(stmt, vec);
595 }
596 }
597 }
598
599 auto out = line();
600 if (!EmitExpression(out, stmt->lhs)) {
601 return false;
602 }
603 out << " = ";
604 if (!EmitExpression(out, stmt->rhs)) {
605 return false;
606 }
607 out << ";";
608 return true;
609 }
610
EmitBinary(std::ostream & out,const ast::BinaryExpression * expr)611 bool GeneratorImpl::EmitBinary(std::ostream& out,
612 const ast::BinaryExpression* expr) {
613 if (expr->op == ast::BinaryOp::kLogicalAnd ||
614 expr->op == ast::BinaryOp::kLogicalOr) {
615 auto name = UniqueIdentifier(kTempNamePrefix);
616
617 {
618 auto pre = line();
619 pre << "bool " << name << " = ";
620 if (!EmitExpression(pre, expr->lhs)) {
621 return false;
622 }
623 pre << ";";
624 }
625
626 if (expr->op == ast::BinaryOp::kLogicalOr) {
627 line() << "if (!" << name << ") {";
628 } else {
629 line() << "if (" << name << ") {";
630 }
631
632 {
633 ScopedIndent si(this);
634 auto pre = line();
635 pre << name << " = ";
636 if (!EmitExpression(pre, expr->rhs)) {
637 return false;
638 }
639 pre << ";";
640 }
641
642 line() << "}";
643
644 out << "(" << name << ")";
645 return true;
646 }
647
648 auto* lhs_type = TypeOf(expr->lhs)->UnwrapRef();
649 auto* rhs_type = TypeOf(expr->rhs)->UnwrapRef();
650 // Multiplying by a matrix requires the use of `mul` in order to get the
651 // type of multiply we desire.
652 if (expr->op == ast::BinaryOp::kMultiply &&
653 ((lhs_type->Is<sem::Vector>() && rhs_type->Is<sem::Matrix>()) ||
654 (lhs_type->Is<sem::Matrix>() && rhs_type->Is<sem::Vector>()) ||
655 (lhs_type->Is<sem::Matrix>() && rhs_type->Is<sem::Matrix>()))) {
656 // Matrices are transposed, so swap LHS and RHS.
657 out << "mul(";
658 if (!EmitExpression(out, expr->rhs)) {
659 return false;
660 }
661 out << ", ";
662 if (!EmitExpression(out, expr->lhs)) {
663 return false;
664 }
665 out << ")";
666
667 return true;
668 }
669
670 out << "(";
671 TINT_DEFER(out << ")");
672
673 if (!EmitExpression(out, expr->lhs)) {
674 return false;
675 }
676 out << " ";
677
678 switch (expr->op) {
679 case ast::BinaryOp::kAnd:
680 out << "&";
681 break;
682 case ast::BinaryOp::kOr:
683 out << "|";
684 break;
685 case ast::BinaryOp::kXor:
686 out << "^";
687 break;
688 case ast::BinaryOp::kLogicalAnd:
689 case ast::BinaryOp::kLogicalOr: {
690 // These are both handled above.
691 TINT_UNREACHABLE(Writer, diagnostics_);
692 return false;
693 }
694 case ast::BinaryOp::kEqual:
695 out << "==";
696 break;
697 case ast::BinaryOp::kNotEqual:
698 out << "!=";
699 break;
700 case ast::BinaryOp::kLessThan:
701 out << "<";
702 break;
703 case ast::BinaryOp::kGreaterThan:
704 out << ">";
705 break;
706 case ast::BinaryOp::kLessThanEqual:
707 out << "<=";
708 break;
709 case ast::BinaryOp::kGreaterThanEqual:
710 out << ">=";
711 break;
712 case ast::BinaryOp::kShiftLeft:
713 out << "<<";
714 break;
715 case ast::BinaryOp::kShiftRight:
716 // TODO(dsinclair): MSL is based on C++14, and >> in C++14 has
717 // implementation-defined behaviour for negative LHS. We may have to
718 // generate extra code to implement WGSL-specified behaviour for negative
719 // LHS.
720 out << R"(>>)";
721 break;
722
723 case ast::BinaryOp::kAdd:
724 out << "+";
725 break;
726 case ast::BinaryOp::kSubtract:
727 out << "-";
728 break;
729 case ast::BinaryOp::kMultiply:
730 out << "*";
731 break;
732 case ast::BinaryOp::kDivide:
733 out << "/";
734
735 if (auto val = builder_.Sem().Get(expr->rhs)->ConstantValue()) {
736 // Integer divide by zero is a DXC compile error, and undefined behavior
737 // in WGSL. Replace the 0 with 1.
738 if (val.Type()->Is<sem::I32>() && val.Elements()[0].i32 == 0) {
739 out << " 1";
740 return true;
741 }
742 if (val.Type()->Is<sem::U32>() && val.Elements()[0].u32 == 0u) {
743 out << " 1u";
744 return true;
745 }
746 }
747 break;
748 case ast::BinaryOp::kModulo:
749 out << "%";
750 break;
751 case ast::BinaryOp::kNone:
752 diagnostics_.add_error(diag::System::Writer,
753 "missing binary operation type");
754 return false;
755 }
756 out << " ";
757
758 if (!EmitExpression(out, expr->rhs)) {
759 return false;
760 }
761
762 return true;
763 }
764
EmitStatements(const ast::StatementList & stmts)765 bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {
766 for (auto* s : stmts) {
767 if (!EmitStatement(s)) {
768 return false;
769 }
770 }
771 return true;
772 }
773
EmitStatementsWithIndent(const ast::StatementList & stmts)774 bool GeneratorImpl::EmitStatementsWithIndent(const ast::StatementList& stmts) {
775 ScopedIndent si(this);
776 return EmitStatements(stmts);
777 }
778
EmitBlock(const ast::BlockStatement * stmt)779 bool GeneratorImpl::EmitBlock(const ast::BlockStatement* stmt) {
780 line() << "{";
781 if (!EmitStatementsWithIndent(stmt->statements)) {
782 return false;
783 }
784 line() << "}";
785 return true;
786 }
787
EmitBreak(const ast::BreakStatement *)788 bool GeneratorImpl::EmitBreak(const ast::BreakStatement*) {
789 line() << "break;";
790 return true;
791 }
792
EmitCall(std::ostream & out,const ast::CallExpression * expr)793 bool GeneratorImpl::EmitCall(std::ostream& out,
794 const ast::CallExpression* expr) {
795 auto* call = builder_.Sem().Get(expr);
796 auto* target = call->Target();
797
798 if (auto* func = target->As<sem::Function>()) {
799 return EmitFunctionCall(out, call, func);
800 }
801 if (auto* intrinsic = target->As<sem::Intrinsic>()) {
802 return EmitIntrinsicCall(out, call, intrinsic);
803 }
804 if (auto* conv = target->As<sem::TypeConversion>()) {
805 return EmitTypeConversion(out, call, conv);
806 }
807 if (auto* ctor = target->As<sem::TypeConstructor>()) {
808 return EmitTypeConstructor(out, call, ctor);
809 }
810 TINT_ICE(Writer, diagnostics_)
811 << "unhandled call target: " << target->TypeInfo().name;
812 return false;
813 }
814
EmitFunctionCall(std::ostream & out,const sem::Call * call,const sem::Function * func)815 bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
816 const sem::Call* call,
817 const sem::Function* func) {
818 auto* expr = call->Declaration();
819
820 if (ast::HasDecoration<transform::CalculateArrayLength::BufferSizeIntrinsic>(
821 func->Declaration()->decorations)) {
822 // Special function generated by the CalculateArrayLength transform for
823 // calling X.GetDimensions(Y)
824 if (!EmitExpression(out, call->Arguments()[0]->Declaration())) {
825 return false;
826 }
827 out << ".GetDimensions(";
828 if (!EmitExpression(out, call->Arguments()[1]->Declaration())) {
829 return false;
830 }
831 out << ")";
832 return true;
833 }
834
835 if (auto* intrinsic =
836 ast::GetDecoration<transform::DecomposeMemoryAccess::Intrinsic>(
837 func->Declaration()->decorations)) {
838 switch (intrinsic->storage_class) {
839 case ast::StorageClass::kUniform:
840 return EmitUniformBufferAccess(out, expr, intrinsic);
841 case ast::StorageClass::kStorage:
842 return EmitStorageBufferAccess(out, expr, intrinsic);
843 default:
844 TINT_UNREACHABLE(Writer, diagnostics_)
845 << "unsupported DecomposeMemoryAccess::Intrinsic storage class:"
846 << intrinsic->storage_class;
847 return false;
848 }
849 }
850
851 out << builder_.Symbols().NameFor(func->Declaration()->symbol) << "(";
852
853 bool first = true;
854 for (auto* arg : call->Arguments()) {
855 if (!first) {
856 out << ", ";
857 }
858 first = false;
859
860 if (!EmitExpression(out, arg->Declaration())) {
861 return false;
862 }
863 }
864
865 out << ")";
866 return true;
867 }
868
EmitIntrinsicCall(std::ostream & out,const sem::Call * call,const sem::Intrinsic * intrinsic)869 bool GeneratorImpl::EmitIntrinsicCall(std::ostream& out,
870 const sem::Call* call,
871 const sem::Intrinsic* intrinsic) {
872 auto* expr = call->Declaration();
873 if (intrinsic->IsTexture()) {
874 return EmitTextureCall(out, call, intrinsic);
875 }
876 if (intrinsic->Type() == sem::IntrinsicType::kSelect) {
877 return EmitSelectCall(out, expr);
878 }
879 if (intrinsic->Type() == sem::IntrinsicType::kModf) {
880 return EmitModfCall(out, expr, intrinsic);
881 }
882 if (intrinsic->Type() == sem::IntrinsicType::kFrexp) {
883 return EmitFrexpCall(out, expr, intrinsic);
884 }
885 if (intrinsic->Type() == sem::IntrinsicType::kIsNormal) {
886 return EmitIsNormalCall(out, expr, intrinsic);
887 }
888 if (intrinsic->Type() == sem::IntrinsicType::kIgnore) {
889 return EmitExpression(out, expr->args[0]); // [DEPRECATED]
890 }
891 if (intrinsic->IsDataPacking()) {
892 return EmitDataPackingCall(out, expr, intrinsic);
893 }
894 if (intrinsic->IsDataUnpacking()) {
895 return EmitDataUnpackingCall(out, expr, intrinsic);
896 }
897 if (intrinsic->IsBarrier()) {
898 return EmitBarrierCall(out, intrinsic);
899 }
900 if (intrinsic->IsAtomic()) {
901 return EmitWorkgroupAtomicCall(out, expr, intrinsic);
902 }
903
904 auto name = generate_builtin_name(intrinsic);
905 if (name.empty()) {
906 return false;
907 }
908
909 out << name << "(";
910
911 bool first = true;
912 for (auto* arg : call->Arguments()) {
913 if (!first) {
914 out << ", ";
915 }
916 first = false;
917
918 if (!EmitExpression(out, arg->Declaration())) {
919 return false;
920 }
921 }
922
923 out << ")";
924 return true;
925 }
926
EmitTypeConversion(std::ostream & out,const sem::Call * call,const sem::TypeConversion * conv)927 bool GeneratorImpl::EmitTypeConversion(std::ostream& out,
928 const sem::Call* call,
929 const sem::TypeConversion* conv) {
930 if (!EmitType(out, conv->Target(), ast::StorageClass::kNone,
931 ast::Access::kReadWrite, "")) {
932 return false;
933 }
934 out << "(";
935
936 if (!EmitExpression(out, call->Arguments()[0]->Declaration())) {
937 return false;
938 }
939
940 out << ")";
941 return true;
942 }
943
EmitTypeConstructor(std::ostream & out,const sem::Call * call,const sem::TypeConstructor * ctor)944 bool GeneratorImpl::EmitTypeConstructor(std::ostream& out,
945 const sem::Call* call,
946 const sem::TypeConstructor* ctor) {
947 auto* type = call->Type();
948
949 // If the type constructor is empty then we need to construct with the zero
950 // value for all components.
951 if (call->Arguments().empty()) {
952 return EmitZeroValue(out, type);
953 }
954
955 bool brackets = type->IsAnyOf<sem::Array, sem::Struct>();
956
957 // For single-value vector initializers, swizzle the scalar to the right
958 // vector dimension using .x
959 const bool is_single_value_vector_init =
960 type->is_scalar_vector() && call->Arguments().size() == 1 &&
961 ctor->Parameters()[0]->Type()->is_scalar();
962
963 auto it = structure_builders_.find(As<sem::Struct>(type));
964 if (it != structure_builders_.end()) {
965 out << it->second << "(";
966 brackets = false;
967 } else if (brackets) {
968 out << "{";
969 } else {
970 if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
971 "")) {
972 return false;
973 }
974 out << "(";
975 }
976
977 if (is_single_value_vector_init) {
978 out << "(";
979 }
980
981 bool first = true;
982 for (auto* e : call->Arguments()) {
983 if (!first) {
984 out << ", ";
985 }
986 first = false;
987
988 if (!EmitExpression(out, e->Declaration())) {
989 return false;
990 }
991 }
992
993 if (is_single_value_vector_init) {
994 out << ")." << std::string(type->As<sem::Vector>()->Width(), 'x');
995 }
996
997 out << (brackets ? "}" : ")");
998 return true;
999 }
1000
EmitUniformBufferAccess(std::ostream & out,const ast::CallExpression * expr,const transform::DecomposeMemoryAccess::Intrinsic * intrinsic)1001 bool GeneratorImpl::EmitUniformBufferAccess(
1002 std::ostream& out,
1003 const ast::CallExpression* expr,
1004 const transform::DecomposeMemoryAccess::Intrinsic* intrinsic) {
1005 const auto& args = expr->args;
1006 auto* offset_arg = builder_.Sem().Get(args[1]);
1007
1008 uint32_t scalar_offset_value = 0;
1009 std::string scalar_offset_expr;
1010
1011 // If true, use scalar_offset_value, otherwise use scalar_offset_expr
1012 bool scalar_offset_constant = false;
1013
1014 if (auto val = offset_arg->ConstantValue()) {
1015 TINT_ASSERT(Writer, val.Type()->Is<sem::U32>());
1016 scalar_offset_value = val.Elements()[0].u32;
1017 scalar_offset_value /= 4; // bytes -> scalar index
1018 scalar_offset_constant = true;
1019 }
1020
1021 if (!scalar_offset_constant) {
1022 // UBO offset not compile-time known.
1023 // Calculate the scalar offset into a temporary.
1024 scalar_offset_expr = UniqueIdentifier("scalar_offset");
1025 auto pre = line();
1026 pre << "const uint " << scalar_offset_expr << " = (";
1027 if (!EmitExpression(pre, args[1])) { // offset
1028 return false;
1029 }
1030 pre << ") / 4;";
1031 }
1032
1033 using Op = transform::DecomposeMemoryAccess::Intrinsic::Op;
1034 using DataType = transform::DecomposeMemoryAccess::Intrinsic::DataType;
1035 switch (intrinsic->op) {
1036 case Op::kLoad: {
1037 auto cast = [&](const char* to, auto&& load) {
1038 out << to << "(";
1039 auto result = load();
1040 out << ")";
1041 return result;
1042 };
1043 auto load_scalar = [&]() {
1044 if (!EmitExpression(out, args[0])) { // buffer
1045 return false;
1046 }
1047 if (scalar_offset_constant) {
1048 char swizzle[] = {'x', 'y', 'z', 'w'};
1049 out << "[" << (scalar_offset_value / 4) << "]."
1050 << swizzle[scalar_offset_value & 3];
1051 } else {
1052 out << "[" << scalar_offset_expr << " / 4][" << scalar_offset_expr
1053 << " % 4]";
1054 }
1055 return true;
1056 };
1057 // Has a minimum alignment of 8 bytes, so is either .xy or .zw
1058 auto load_vec2 = [&] {
1059 if (scalar_offset_constant) {
1060 if (!EmitExpression(out, args[0])) { // buffer
1061 return false;
1062 }
1063 out << "[" << (scalar_offset_value / 4) << "]";
1064 out << ((scalar_offset_value & 2) == 0 ? ".xy" : ".zw");
1065 } else {
1066 std::string ubo_load = UniqueIdentifier("ubo_load");
1067 {
1068 auto pre = line();
1069 pre << "uint4 " << ubo_load << " = ";
1070 if (!EmitExpression(pre, args[0])) { // buffer
1071 return false;
1072 }
1073 pre << "[" << scalar_offset_expr << " / 4];";
1074 }
1075 out << "((" << scalar_offset_expr << " & 2) ? " << ubo_load
1076 << ".zw : " << ubo_load << ".xy)";
1077 }
1078 return true;
1079 };
1080 // vec4 has a minimum alignment of 16 bytes, easiest case
1081 auto load_vec4 = [&] {
1082 if (!EmitExpression(out, args[0])) { // buffer
1083 return false;
1084 }
1085 if (scalar_offset_constant) {
1086 out << "[" << (scalar_offset_value / 4) << "]";
1087 } else {
1088 out << "[" << scalar_offset_expr << " / 4]";
1089 }
1090 return true;
1091 };
1092 // vec3 has a minimum alignment of 16 bytes, so is just a .xyz swizzle
1093 auto load_vec3 = [&] {
1094 if (!load_vec4()) {
1095 return false;
1096 }
1097 out << ".xyz";
1098 return true;
1099 };
1100 switch (intrinsic->type) {
1101 case DataType::kU32:
1102 return load_scalar();
1103 case DataType::kF32:
1104 return cast("asfloat", load_scalar);
1105 case DataType::kI32:
1106 return cast("asint", load_scalar);
1107 case DataType::kVec2U32:
1108 return load_vec2();
1109 case DataType::kVec2F32:
1110 return cast("asfloat", load_vec2);
1111 case DataType::kVec2I32:
1112 return cast("asint", load_vec2);
1113 case DataType::kVec3U32:
1114 return load_vec3();
1115 case DataType::kVec3F32:
1116 return cast("asfloat", load_vec3);
1117 case DataType::kVec3I32:
1118 return cast("asint", load_vec3);
1119 case DataType::kVec4U32:
1120 return load_vec4();
1121 case DataType::kVec4F32:
1122 return cast("asfloat", load_vec4);
1123 case DataType::kVec4I32:
1124 return cast("asint", load_vec4);
1125 }
1126 TINT_UNREACHABLE(Writer, diagnostics_)
1127 << "unsupported DecomposeMemoryAccess::Intrinsic::DataType: "
1128 << static_cast<int>(intrinsic->type);
1129 return false;
1130 }
1131 default:
1132 break;
1133 }
1134 TINT_UNREACHABLE(Writer, diagnostics_)
1135 << "unsupported DecomposeMemoryAccess::Intrinsic::Op: "
1136 << static_cast<int>(intrinsic->op);
1137 return false;
1138 }
1139
EmitStorageBufferAccess(std::ostream & out,const ast::CallExpression * expr,const transform::DecomposeMemoryAccess::Intrinsic * intrinsic)1140 bool GeneratorImpl::EmitStorageBufferAccess(
1141 std::ostream& out,
1142 const ast::CallExpression* expr,
1143 const transform::DecomposeMemoryAccess::Intrinsic* intrinsic) {
1144 const auto& args = expr->args;
1145
1146 using Op = transform::DecomposeMemoryAccess::Intrinsic::Op;
1147 using DataType = transform::DecomposeMemoryAccess::Intrinsic::DataType;
1148 switch (intrinsic->op) {
1149 case Op::kLoad: {
1150 auto load = [&](const char* cast, int n) {
1151 if (cast) {
1152 out << cast << "(";
1153 }
1154 if (!EmitExpression(out, args[0])) { // buffer
1155 return false;
1156 }
1157 out << ".Load";
1158 if (n > 1) {
1159 out << n;
1160 }
1161 ScopedParen sp(out);
1162 if (!EmitExpression(out, args[1])) { // offset
1163 return false;
1164 }
1165 if (cast) {
1166 out << ")";
1167 }
1168 return true;
1169 };
1170 switch (intrinsic->type) {
1171 case DataType::kU32:
1172 return load(nullptr, 1);
1173 case DataType::kF32:
1174 return load("asfloat", 1);
1175 case DataType::kI32:
1176 return load("asint", 1);
1177 case DataType::kVec2U32:
1178 return load(nullptr, 2);
1179 case DataType::kVec2F32:
1180 return load("asfloat", 2);
1181 case DataType::kVec2I32:
1182 return load("asint", 2);
1183 case DataType::kVec3U32:
1184 return load(nullptr, 3);
1185 case DataType::kVec3F32:
1186 return load("asfloat", 3);
1187 case DataType::kVec3I32:
1188 return load("asint", 3);
1189 case DataType::kVec4U32:
1190 return load(nullptr, 4);
1191 case DataType::kVec4F32:
1192 return load("asfloat", 4);
1193 case DataType::kVec4I32:
1194 return load("asint", 4);
1195 }
1196 TINT_UNREACHABLE(Writer, diagnostics_)
1197 << "unsupported DecomposeMemoryAccess::Intrinsic::DataType: "
1198 << static_cast<int>(intrinsic->type);
1199 return false;
1200 }
1201
1202 case Op::kStore: {
1203 auto store = [&](int n) {
1204 if (!EmitExpression(out, args[0])) { // buffer
1205 return false;
1206 }
1207 out << ".Store";
1208 if (n > 1) {
1209 out << n;
1210 }
1211 ScopedParen sp1(out);
1212 if (!EmitExpression(out, args[1])) { // offset
1213 return false;
1214 }
1215 out << ", asuint";
1216 ScopedParen sp2(out);
1217 if (!EmitExpression(out, args[2])) { // value
1218 return false;
1219 }
1220 return true;
1221 };
1222 switch (intrinsic->type) {
1223 case DataType::kU32:
1224 return store(1);
1225 case DataType::kF32:
1226 return store(1);
1227 case DataType::kI32:
1228 return store(1);
1229 case DataType::kVec2U32:
1230 return store(2);
1231 case DataType::kVec2F32:
1232 return store(2);
1233 case DataType::kVec2I32:
1234 return store(2);
1235 case DataType::kVec3U32:
1236 return store(3);
1237 case DataType::kVec3F32:
1238 return store(3);
1239 case DataType::kVec3I32:
1240 return store(3);
1241 case DataType::kVec4U32:
1242 return store(4);
1243 case DataType::kVec4F32:
1244 return store(4);
1245 case DataType::kVec4I32:
1246 return store(4);
1247 }
1248 TINT_UNREACHABLE(Writer, diagnostics_)
1249 << "unsupported DecomposeMemoryAccess::Intrinsic::DataType: "
1250 << static_cast<int>(intrinsic->type);
1251 return false;
1252 }
1253
1254 case Op::kAtomicLoad:
1255 case Op::kAtomicStore:
1256 case Op::kAtomicAdd:
1257 case Op::kAtomicSub:
1258 case Op::kAtomicMax:
1259 case Op::kAtomicMin:
1260 case Op::kAtomicAnd:
1261 case Op::kAtomicOr:
1262 case Op::kAtomicXor:
1263 case Op::kAtomicExchange:
1264 case Op::kAtomicCompareExchangeWeak:
1265 return EmitStorageAtomicCall(out, expr, intrinsic);
1266 }
1267
1268 TINT_UNREACHABLE(Writer, diagnostics_)
1269 << "unsupported DecomposeMemoryAccess::Intrinsic::Op: "
1270 << static_cast<int>(intrinsic->op);
1271 return false;
1272 }
1273
EmitStorageAtomicCall(std::ostream & out,const ast::CallExpression * expr,const transform::DecomposeMemoryAccess::Intrinsic * intrinsic)1274 bool GeneratorImpl::EmitStorageAtomicCall(
1275 std::ostream& out,
1276 const ast::CallExpression* expr,
1277 const transform::DecomposeMemoryAccess::Intrinsic* intrinsic) {
1278 using Op = transform::DecomposeMemoryAccess::Intrinsic::Op;
1279
1280 auto* result_ty = TypeOf(expr);
1281
1282 auto& buf = helpers_;
1283
1284 // generate_helper() generates a helper function that translates the
1285 // DecomposeMemoryAccess::Intrinsic call into the corresponding HLSL
1286 // atomic intrinsic function.
1287 auto generate_helper = [&]() -> std::string {
1288 auto rmw = [&](const char* wgsl, const char* hlsl) -> std::string {
1289 auto name = UniqueIdentifier(wgsl);
1290 {
1291 auto fn = line(&buf);
1292 if (!EmitTypeAndName(fn, result_ty, ast::StorageClass::kNone,
1293 ast::Access::kUndefined, name)) {
1294 return "";
1295 }
1296 fn << "(RWByteAddressBuffer buffer, uint offset, ";
1297 if (!EmitTypeAndName(fn, result_ty, ast::StorageClass::kNone,
1298 ast::Access::kUndefined, "value")) {
1299 return "";
1300 }
1301 fn << ") {";
1302 }
1303
1304 buf.IncrementIndent();
1305 TINT_DEFER({
1306 buf.DecrementIndent();
1307 line(&buf) << "}";
1308 line(&buf);
1309 });
1310
1311 {
1312 auto l = line(&buf);
1313 if (!EmitTypeAndName(l, result_ty, ast::StorageClass::kNone,
1314 ast::Access::kUndefined, "original_value")) {
1315 return "";
1316 }
1317 l << " = 0;";
1318 }
1319 {
1320 auto l = line(&buf);
1321 l << "buffer." << hlsl << "(offset, ";
1322 if (intrinsic->op == Op::kAtomicSub) {
1323 l << "-";
1324 }
1325 l << "value, original_value);";
1326 }
1327 line(&buf) << "return original_value;";
1328 return name;
1329 };
1330
1331 switch (intrinsic->op) {
1332 case Op::kAtomicAdd:
1333 return rmw("atomicAdd", "InterlockedAdd");
1334
1335 case Op::kAtomicSub:
1336 // Use add with the operand negated.
1337 return rmw("atomicSub", "InterlockedAdd");
1338
1339 case Op::kAtomicMax:
1340 return rmw("atomicMax", "InterlockedMax");
1341
1342 case Op::kAtomicMin:
1343 return rmw("atomicMin", "InterlockedMin");
1344
1345 case Op::kAtomicAnd:
1346 return rmw("atomicAnd", "InterlockedAnd");
1347
1348 case Op::kAtomicOr:
1349 return rmw("atomicOr", "InterlockedOr");
1350
1351 case Op::kAtomicXor:
1352 return rmw("atomicXor", "InterlockedXor");
1353
1354 case Op::kAtomicExchange:
1355 return rmw("atomicExchange", "InterlockedExchange");
1356
1357 case Op::kAtomicLoad: {
1358 // HLSL does not have an InterlockedLoad, so we emulate it with
1359 // InterlockedOr using 0 as the OR value
1360 auto name = UniqueIdentifier("atomicLoad");
1361 {
1362 auto fn = line(&buf);
1363 if (!EmitTypeAndName(fn, result_ty, ast::StorageClass::kNone,
1364 ast::Access::kUndefined, name)) {
1365 return "";
1366 }
1367 fn << "(RWByteAddressBuffer buffer, uint offset) {";
1368 }
1369
1370 buf.IncrementIndent();
1371 TINT_DEFER({
1372 buf.DecrementIndent();
1373 line(&buf) << "}";
1374 line(&buf);
1375 });
1376
1377 {
1378 auto l = line(&buf);
1379 if (!EmitTypeAndName(l, result_ty, ast::StorageClass::kNone,
1380 ast::Access::kUndefined, "value")) {
1381 return "";
1382 }
1383 l << " = 0;";
1384 }
1385
1386 line(&buf) << "buffer.InterlockedOr(offset, 0, value);";
1387 line(&buf) << "return value;";
1388 return name;
1389 }
1390 case Op::kAtomicStore: {
1391 // HLSL does not have an InterlockedStore, so we emulate it with
1392 // InterlockedExchange and discard the returned value
1393 auto* value_ty = TypeOf(expr->args[2])->UnwrapRef();
1394 auto name = UniqueIdentifier("atomicStore");
1395 {
1396 auto fn = line(&buf);
1397 fn << "void " << name << "(RWByteAddressBuffer buffer, uint offset, ";
1398 if (!EmitTypeAndName(fn, value_ty, ast::StorageClass::kNone,
1399 ast::Access::kUndefined, "value")) {
1400 return "";
1401 }
1402 fn << ") {";
1403 }
1404
1405 buf.IncrementIndent();
1406 TINT_DEFER({
1407 buf.DecrementIndent();
1408 line(&buf) << "}";
1409 line(&buf);
1410 });
1411
1412 {
1413 auto l = line(&buf);
1414 if (!EmitTypeAndName(l, value_ty, ast::StorageClass::kNone,
1415 ast::Access::kUndefined, "ignored")) {
1416 return "";
1417 }
1418 l << ";";
1419 }
1420 line(&buf) << "buffer.InterlockedExchange(offset, value, ignored);";
1421 return name;
1422 }
1423 case Op::kAtomicCompareExchangeWeak: {
1424 auto* value_ty = TypeOf(expr->args[2])->UnwrapRef();
1425
1426 auto name = UniqueIdentifier("atomicCompareExchangeWeak");
1427 {
1428 auto fn = line(&buf);
1429 if (!EmitTypeAndName(fn, result_ty, ast::StorageClass::kNone,
1430 ast::Access::kUndefined, name)) {
1431 return "";
1432 }
1433 fn << "(RWByteAddressBuffer buffer, uint offset, ";
1434 if (!EmitTypeAndName(fn, value_ty, ast::StorageClass::kNone,
1435 ast::Access::kUndefined, "compare")) {
1436 return "";
1437 }
1438 fn << ", ";
1439 if (!EmitTypeAndName(fn, value_ty, ast::StorageClass::kNone,
1440 ast::Access::kUndefined, "value")) {
1441 return "";
1442 }
1443 fn << ") {";
1444 }
1445
1446 buf.IncrementIndent();
1447 TINT_DEFER({
1448 buf.DecrementIndent();
1449 line(&buf) << "}";
1450 line(&buf);
1451 });
1452
1453 { // T result = {0, 0};
1454 auto l = line(&buf);
1455 if (!EmitTypeAndName(l, result_ty, ast::StorageClass::kNone,
1456 ast::Access::kUndefined, "result")) {
1457 return "";
1458 }
1459 l << " = {0, 0};";
1460 }
1461 line(&buf) << "buffer.InterlockedCompareExchange(offset, compare, "
1462 "value, result.x);";
1463 line(&buf) << "result.y = result.x == compare;";
1464 line(&buf) << "return result;";
1465 return name;
1466 }
1467 default:
1468 break;
1469 }
1470 TINT_UNREACHABLE(Writer, diagnostics_)
1471 << "unsupported atomic DecomposeMemoryAccess::Intrinsic::Op: "
1472 << static_cast<int>(intrinsic->op);
1473 return "";
1474 };
1475
1476 auto func = utils::GetOrCreate(dma_intrinsics_,
1477 DMAIntrinsic{intrinsic->op, intrinsic->type},
1478 generate_helper);
1479 if (func.empty()) {
1480 return false;
1481 }
1482
1483 out << func;
1484 {
1485 ScopedParen sp(out);
1486 bool first = true;
1487 for (auto* arg : expr->args) {
1488 if (!first) {
1489 out << ", ";
1490 }
1491 first = false;
1492 if (!EmitExpression(out, arg)) {
1493 return false;
1494 }
1495 }
1496 }
1497
1498 return true;
1499 }
1500
EmitWorkgroupAtomicCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1501 bool GeneratorImpl::EmitWorkgroupAtomicCall(std::ostream& out,
1502 const ast::CallExpression* expr,
1503 const sem::Intrinsic* intrinsic) {
1504 std::string result = UniqueIdentifier("atomic_result");
1505
1506 if (!intrinsic->ReturnType()->Is<sem::Void>()) {
1507 auto pre = line();
1508 if (!EmitTypeAndName(pre, intrinsic->ReturnType(), ast::StorageClass::kNone,
1509 ast::Access::kUndefined, result)) {
1510 return false;
1511 }
1512 pre << " = ";
1513 if (!EmitZeroValue(pre, intrinsic->ReturnType())) {
1514 return false;
1515 }
1516 pre << ";";
1517 }
1518
1519 auto call = [&](const char* name) {
1520 auto pre = line();
1521 pre << name;
1522
1523 {
1524 ScopedParen sp(pre);
1525 for (size_t i = 0; i < expr->args.size(); i++) {
1526 auto* arg = expr->args[i];
1527 if (i > 0) {
1528 pre << ", ";
1529 }
1530 if (i == 1 && intrinsic->Type() == sem::IntrinsicType::kAtomicSub) {
1531 // Sub uses InterlockedAdd with the operand negated.
1532 pre << "-";
1533 }
1534 if (!EmitExpression(pre, arg)) {
1535 return false;
1536 }
1537 }
1538
1539 pre << ", " << result;
1540 }
1541
1542 pre << ";";
1543
1544 out << result;
1545 return true;
1546 };
1547
1548 switch (intrinsic->Type()) {
1549 case sem::IntrinsicType::kAtomicLoad: {
1550 // HLSL does not have an InterlockedLoad, so we emulate it with
1551 // InterlockedOr using 0 as the OR value
1552 auto pre = line();
1553 pre << "InterlockedOr";
1554 {
1555 ScopedParen sp(pre);
1556 if (!EmitExpression(pre, expr->args[0])) {
1557 return false;
1558 }
1559 pre << ", 0, " << result;
1560 }
1561 pre << ";";
1562
1563 out << result;
1564 return true;
1565 }
1566 case sem::IntrinsicType::kAtomicStore: {
1567 // HLSL does not have an InterlockedStore, so we emulate it with
1568 // InterlockedExchange and discard the returned value
1569 { // T result = 0;
1570 auto pre = line();
1571 auto* value_ty = intrinsic->Parameters()[1]->Type()->UnwrapRef();
1572 if (!EmitTypeAndName(pre, value_ty, ast::StorageClass::kNone,
1573 ast::Access::kUndefined, result)) {
1574 return false;
1575 }
1576 pre << " = ";
1577 if (!EmitZeroValue(pre, value_ty)) {
1578 return false;
1579 }
1580 pre << ";";
1581 }
1582
1583 out << "InterlockedExchange";
1584 {
1585 ScopedParen sp(out);
1586 if (!EmitExpression(out, expr->args[0])) {
1587 return false;
1588 }
1589 out << ", ";
1590 if (!EmitExpression(out, expr->args[1])) {
1591 return false;
1592 }
1593 out << ", " << result;
1594 }
1595 return true;
1596 }
1597 case sem::IntrinsicType::kAtomicCompareExchangeWeak: {
1598 auto* dest = expr->args[0];
1599 auto* compare_value = expr->args[1];
1600 auto* value = expr->args[2];
1601
1602 std::string compare = UniqueIdentifier("atomic_compare_value");
1603
1604 { // T compare_value = <compare_value>;
1605 auto pre = line();
1606 if (!EmitTypeAndName(pre, TypeOf(compare_value),
1607 ast::StorageClass::kNone, ast::Access::kUndefined,
1608 compare)) {
1609 return false;
1610 }
1611 pre << " = ";
1612 if (!EmitExpression(pre, compare_value)) {
1613 return false;
1614 }
1615 pre << ";";
1616 }
1617
1618 { // InterlockedCompareExchange(dst, compare, value, result.x);
1619 auto pre = line();
1620 pre << "InterlockedCompareExchange";
1621 {
1622 ScopedParen sp(pre);
1623 if (!EmitExpression(pre, dest)) {
1624 return false;
1625 }
1626 pre << ", " << compare << ", ";
1627 if (!EmitExpression(pre, value)) {
1628 return false;
1629 }
1630 pre << ", " << result << ".x";
1631 }
1632 pre << ";";
1633 }
1634
1635 { // result.y = result.x == compare;
1636 line() << result << ".y = " << result << ".x == " << compare << ";";
1637 }
1638
1639 out << result;
1640 return true;
1641 }
1642
1643 case sem::IntrinsicType::kAtomicAdd:
1644 case sem::IntrinsicType::kAtomicSub:
1645 return call("InterlockedAdd");
1646
1647 case sem::IntrinsicType::kAtomicMax:
1648 return call("InterlockedMax");
1649
1650 case sem::IntrinsicType::kAtomicMin:
1651 return call("InterlockedMin");
1652
1653 case sem::IntrinsicType::kAtomicAnd:
1654 return call("InterlockedAnd");
1655
1656 case sem::IntrinsicType::kAtomicOr:
1657 return call("InterlockedOr");
1658
1659 case sem::IntrinsicType::kAtomicXor:
1660 return call("InterlockedXor");
1661
1662 case sem::IntrinsicType::kAtomicExchange:
1663 return call("InterlockedExchange");
1664
1665 default:
1666 break;
1667 }
1668
1669 TINT_UNREACHABLE(Writer, diagnostics_)
1670 << "unsupported atomic intrinsic: " << intrinsic->Type();
1671 return false;
1672 }
1673
EmitSelectCall(std::ostream & out,const ast::CallExpression * expr)1674 bool GeneratorImpl::EmitSelectCall(std::ostream& out,
1675 const ast::CallExpression* expr) {
1676 auto* expr_false = expr->args[0];
1677 auto* expr_true = expr->args[1];
1678 auto* expr_cond = expr->args[2];
1679 ScopedParen paren(out);
1680 if (!EmitExpression(out, expr_cond)) {
1681 return false;
1682 }
1683
1684 out << " ? ";
1685
1686 if (!EmitExpression(out, expr_true)) {
1687 return false;
1688 }
1689
1690 out << " : ";
1691
1692 if (!EmitExpression(out, expr_false)) {
1693 return false;
1694 }
1695
1696 return true;
1697 }
1698
EmitModfCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1699 bool GeneratorImpl::EmitModfCall(std::ostream& out,
1700 const ast::CallExpression* expr,
1701 const sem::Intrinsic* intrinsic) {
1702 return CallIntrinsicHelper(
1703 out, expr, intrinsic,
1704 [&](TextBuffer* b, const std::vector<std::string>& params) {
1705 auto* ty = intrinsic->Parameters()[0]->Type();
1706 auto in = params[0];
1707
1708 std::string width;
1709 if (auto* vec = ty->As<sem::Vector>()) {
1710 width = std::to_string(vec->Width());
1711 }
1712
1713 // Emit the builtin return type unique to this overload. This does not
1714 // exist in the AST, so it will not be generated in Generate().
1715 if (!EmitStructType(&helpers_,
1716 intrinsic->ReturnType()->As<sem::Struct>())) {
1717 return false;
1718 }
1719
1720 line(b) << "float" << width << " whole;";
1721 line(b) << "float" << width << " fract = modf(" << in << ", whole);";
1722 {
1723 auto l = line(b);
1724 if (!EmitType(l, intrinsic->ReturnType(), ast::StorageClass::kNone,
1725 ast::Access::kUndefined, "")) {
1726 return false;
1727 }
1728 l << " result = {fract, whole};";
1729 }
1730 line(b) << "return result;";
1731 return true;
1732 });
1733 }
1734
EmitFrexpCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1735 bool GeneratorImpl::EmitFrexpCall(std::ostream& out,
1736 const ast::CallExpression* expr,
1737 const sem::Intrinsic* intrinsic) {
1738 return CallIntrinsicHelper(
1739 out, expr, intrinsic,
1740 [&](TextBuffer* b, const std::vector<std::string>& params) {
1741 auto* ty = intrinsic->Parameters()[0]->Type();
1742 auto in = params[0];
1743
1744 std::string width;
1745 if (auto* vec = ty->As<sem::Vector>()) {
1746 width = std::to_string(vec->Width());
1747 }
1748
1749 // Emit the builtin return type unique to this overload. This does not
1750 // exist in the AST, so it will not be generated in Generate().
1751 if (!EmitStructType(&helpers_,
1752 intrinsic->ReturnType()->As<sem::Struct>())) {
1753 return false;
1754 }
1755
1756 line(b) << "float" << width << " exp;";
1757 line(b) << "float" << width << " sig = frexp(" << in << ", exp);";
1758 {
1759 auto l = line(b);
1760 if (!EmitType(l, intrinsic->ReturnType(), ast::StorageClass::kNone,
1761 ast::Access::kUndefined, "")) {
1762 return false;
1763 }
1764 l << " result = {sig, int" << width << "(exp)};";
1765 }
1766 line(b) << "return result;";
1767 return true;
1768 });
1769 }
1770
EmitIsNormalCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1771 bool GeneratorImpl::EmitIsNormalCall(std::ostream& out,
1772 const ast::CallExpression* expr,
1773 const sem::Intrinsic* intrinsic) {
1774 // HLSL doesn't have a isNormal intrinsic, we need to emulate
1775 return CallIntrinsicHelper(
1776 out, expr, intrinsic,
1777 [&](TextBuffer* b, const std::vector<std::string>& params) {
1778 auto* input_ty = intrinsic->Parameters()[0]->Type();
1779
1780 std::string width;
1781 if (auto* vec = input_ty->As<sem::Vector>()) {
1782 width = std::to_string(vec->Width());
1783 }
1784
1785 constexpr auto* kExponentMask = "0x7f80000";
1786 constexpr auto* kMinNormalExponent = "0x0080000";
1787 constexpr auto* kMaxNormalExponent = "0x7f00000";
1788
1789 line(b) << "uint" << width << " exponent = asuint(" << params[0]
1790 << ") & " << kExponentMask << ";";
1791 line(b) << "uint" << width << " clamped = "
1792 << "clamp(exponent, " << kMinNormalExponent << ", "
1793 << kMaxNormalExponent << ");";
1794 line(b) << "return clamped == exponent;";
1795 return true;
1796 });
1797 }
1798
EmitDataPackingCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1799 bool GeneratorImpl::EmitDataPackingCall(std::ostream& out,
1800 const ast::CallExpression* expr,
1801 const sem::Intrinsic* intrinsic) {
1802 return CallIntrinsicHelper(
1803 out, expr, intrinsic,
1804 [&](TextBuffer* b, const std::vector<std::string>& params) {
1805 uint32_t dims = 2;
1806 bool is_signed = false;
1807 uint32_t scale = 65535;
1808 if (intrinsic->Type() == sem::IntrinsicType::kPack4x8snorm ||
1809 intrinsic->Type() == sem::IntrinsicType::kPack4x8unorm) {
1810 dims = 4;
1811 scale = 255;
1812 }
1813 if (intrinsic->Type() == sem::IntrinsicType::kPack4x8snorm ||
1814 intrinsic->Type() == sem::IntrinsicType::kPack2x16snorm) {
1815 is_signed = true;
1816 scale = (scale - 1) / 2;
1817 }
1818 switch (intrinsic->Type()) {
1819 case sem::IntrinsicType::kPack4x8snorm:
1820 case sem::IntrinsicType::kPack4x8unorm:
1821 case sem::IntrinsicType::kPack2x16snorm:
1822 case sem::IntrinsicType::kPack2x16unorm: {
1823 {
1824 auto l = line(b);
1825 l << (is_signed ? "" : "u") << "int" << dims
1826 << " i = " << (is_signed ? "" : "u") << "int" << dims
1827 << "(round(clamp(" << params[0] << ", "
1828 << (is_signed ? "-1.0" : "0.0") << ", 1.0) * " << scale
1829 << ".0))";
1830 if (is_signed) {
1831 l << " & " << (dims == 4 ? "0xff" : "0xffff");
1832 }
1833 l << ";";
1834 }
1835 {
1836 auto l = line(b);
1837 l << "return ";
1838 if (is_signed) {
1839 l << "asuint";
1840 }
1841 l << "(i.x | i.y << " << (32 / dims);
1842 if (dims == 4) {
1843 l << " | i.z << 16 | i.w << 24";
1844 }
1845 l << ");";
1846 }
1847 break;
1848 }
1849 case sem::IntrinsicType::kPack2x16float: {
1850 line(b) << "uint2 i = f32tof16(" << params[0] << ");";
1851 line(b) << "return i.x | (i.y << 16);";
1852 break;
1853 }
1854 default:
1855 diagnostics_.add_error(
1856 diag::System::Writer,
1857 "Internal error: unhandled data packing intrinsic");
1858 return false;
1859 }
1860
1861 return true;
1862 });
1863 }
1864
EmitDataUnpackingCall(std::ostream & out,const ast::CallExpression * expr,const sem::Intrinsic * intrinsic)1865 bool GeneratorImpl::EmitDataUnpackingCall(std::ostream& out,
1866 const ast::CallExpression* expr,
1867 const sem::Intrinsic* intrinsic) {
1868 return CallIntrinsicHelper(
1869 out, expr, intrinsic,
1870 [&](TextBuffer* b, const std::vector<std::string>& params) {
1871 uint32_t dims = 2;
1872 bool is_signed = false;
1873 uint32_t scale = 65535;
1874 if (intrinsic->Type() == sem::IntrinsicType::kUnpack4x8snorm ||
1875 intrinsic->Type() == sem::IntrinsicType::kUnpack4x8unorm) {
1876 dims = 4;
1877 scale = 255;
1878 }
1879 if (intrinsic->Type() == sem::IntrinsicType::kUnpack4x8snorm ||
1880 intrinsic->Type() == sem::IntrinsicType::kUnpack2x16snorm) {
1881 is_signed = true;
1882 scale = (scale - 1) / 2;
1883 }
1884 switch (intrinsic->Type()) {
1885 case sem::IntrinsicType::kUnpack4x8snorm:
1886 case sem::IntrinsicType::kUnpack2x16snorm: {
1887 line(b) << "int j = int(" << params[0] << ");";
1888 { // Perform sign extension on the converted values.
1889 auto l = line(b);
1890 l << "int" << dims << " i = int" << dims << "(";
1891 if (dims == 2) {
1892 l << "j << 16, j) >> 16";
1893 } else {
1894 l << "j << 24, j << 16, j << 8, j) >> 24";
1895 }
1896 l << ";";
1897 }
1898 line(b) << "return clamp(float" << dims << "(i) / " << scale
1899 << ".0, " << (is_signed ? "-1.0" : "0.0") << ", 1.0);";
1900 break;
1901 }
1902 case sem::IntrinsicType::kUnpack4x8unorm:
1903 case sem::IntrinsicType::kUnpack2x16unorm: {
1904 line(b) << "uint j = " << params[0] << ";";
1905 {
1906 auto l = line(b);
1907 l << "uint" << dims << " i = uint" << dims << "(";
1908 l << "j & " << (dims == 2 ? "0xffff" : "0xff") << ", ";
1909 if (dims == 4) {
1910 l << "(j >> " << (32 / dims)
1911 << ") & 0xff, (j >> 16) & 0xff, j >> 24";
1912 } else {
1913 l << "j >> " << (32 / dims);
1914 }
1915 l << ");";
1916 }
1917 line(b) << "return float" << dims << "(i) / " << scale << ".0;";
1918 break;
1919 }
1920 case sem::IntrinsicType::kUnpack2x16float:
1921 line(b) << "uint i = " << params[0] << ";";
1922 line(b) << "return f16tof32(uint2(i & 0xffff, i >> 16));";
1923 break;
1924 default:
1925 diagnostics_.add_error(
1926 diag::System::Writer,
1927 "Internal error: unhandled data packing intrinsic");
1928 return false;
1929 }
1930
1931 return true;
1932 });
1933 }
1934
EmitBarrierCall(std::ostream & out,const sem::Intrinsic * intrinsic)1935 bool GeneratorImpl::EmitBarrierCall(std::ostream& out,
1936 const sem::Intrinsic* intrinsic) {
1937 // TODO(crbug.com/tint/661): Combine sequential barriers to a single
1938 // instruction.
1939 if (intrinsic->Type() == sem::IntrinsicType::kWorkgroupBarrier) {
1940 out << "GroupMemoryBarrierWithGroupSync()";
1941 } else if (intrinsic->Type() == sem::IntrinsicType::kStorageBarrier) {
1942 out << "DeviceMemoryBarrierWithGroupSync()";
1943 } else {
1944 TINT_UNREACHABLE(Writer, diagnostics_)
1945 << "unexpected barrier intrinsic type " << sem::str(intrinsic->Type());
1946 return false;
1947 }
1948 return true;
1949 }
1950
EmitTextureCall(std::ostream & out,const sem::Call * call,const sem::Intrinsic * intrinsic)1951 bool GeneratorImpl::EmitTextureCall(std::ostream& out,
1952 const sem::Call* call,
1953 const sem::Intrinsic* intrinsic) {
1954 using Usage = sem::ParameterUsage;
1955
1956 auto& signature = intrinsic->Signature();
1957 auto* expr = call->Declaration();
1958 auto arguments = expr->args;
1959
1960 // Returns the argument with the given usage
1961 auto arg = [&](Usage usage) {
1962 int idx = signature.IndexOf(usage);
1963 return (idx >= 0) ? arguments[idx] : nullptr;
1964 };
1965
1966 auto* texture = arg(Usage::kTexture);
1967 if (!texture) {
1968 TINT_ICE(Writer, diagnostics_) << "missing texture argument";
1969 return false;
1970 }
1971
1972 auto* texture_type = TypeOf(texture)->UnwrapRef()->As<sem::Texture>();
1973
1974 switch (intrinsic->Type()) {
1975 case sem::IntrinsicType::kTextureDimensions:
1976 case sem::IntrinsicType::kTextureNumLayers:
1977 case sem::IntrinsicType::kTextureNumLevels:
1978 case sem::IntrinsicType::kTextureNumSamples: {
1979 // All of these intrinsics use the GetDimensions() method on the texture
1980 bool is_ms = texture_type->IsAnyOf<sem::MultisampledTexture,
1981 sem::DepthMultisampledTexture>();
1982 int num_dimensions = 0;
1983 std::string swizzle;
1984
1985 switch (intrinsic->Type()) {
1986 case sem::IntrinsicType::kTextureDimensions:
1987 switch (texture_type->dim()) {
1988 case ast::TextureDimension::kNone:
1989 TINT_ICE(Writer, diagnostics_) << "texture dimension is kNone";
1990 return false;
1991 case ast::TextureDimension::k1d:
1992 num_dimensions = 1;
1993 break;
1994 case ast::TextureDimension::k2d:
1995 num_dimensions = is_ms ? 3 : 2;
1996 swizzle = is_ms ? ".xy" : "";
1997 break;
1998 case ast::TextureDimension::k2dArray:
1999 num_dimensions = is_ms ? 4 : 3;
2000 swizzle = ".xy";
2001 break;
2002 case ast::TextureDimension::k3d:
2003 num_dimensions = 3;
2004 break;
2005 case ast::TextureDimension::kCube:
2006 num_dimensions = 2;
2007 break;
2008 case ast::TextureDimension::kCubeArray:
2009 num_dimensions = 3;
2010 swizzle = ".xy";
2011 break;
2012 }
2013 break;
2014 case sem::IntrinsicType::kTextureNumLayers:
2015 switch (texture_type->dim()) {
2016 default:
2017 TINT_ICE(Writer, diagnostics_)
2018 << "texture dimension is not arrayed";
2019 return false;
2020 case ast::TextureDimension::k2dArray:
2021 num_dimensions = is_ms ? 4 : 3;
2022 swizzle = ".z";
2023 break;
2024 case ast::TextureDimension::kCubeArray:
2025 num_dimensions = 3;
2026 swizzle = ".z";
2027 break;
2028 }
2029 break;
2030 case sem::IntrinsicType::kTextureNumLevels:
2031 switch (texture_type->dim()) {
2032 default:
2033 TINT_ICE(Writer, diagnostics_)
2034 << "texture dimension does not support mips";
2035 return false;
2036 case ast::TextureDimension::k1d:
2037 num_dimensions = 2;
2038 swizzle = ".y";
2039 break;
2040 case ast::TextureDimension::k2d:
2041 case ast::TextureDimension::kCube:
2042 num_dimensions = 3;
2043 swizzle = ".z";
2044 break;
2045 case ast::TextureDimension::k2dArray:
2046 case ast::TextureDimension::k3d:
2047 case ast::TextureDimension::kCubeArray:
2048 num_dimensions = 4;
2049 swizzle = ".w";
2050 break;
2051 }
2052 break;
2053 case sem::IntrinsicType::kTextureNumSamples:
2054 switch (texture_type->dim()) {
2055 default:
2056 TINT_ICE(Writer, diagnostics_)
2057 << "texture dimension does not support multisampling";
2058 return false;
2059 case ast::TextureDimension::k2d:
2060 num_dimensions = 3;
2061 swizzle = ".z";
2062 break;
2063 case ast::TextureDimension::k2dArray:
2064 num_dimensions = 4;
2065 swizzle = ".w";
2066 break;
2067 }
2068 break;
2069 default:
2070 TINT_ICE(Writer, diagnostics_) << "unexpected intrinsic";
2071 return false;
2072 }
2073
2074 auto* level_arg = arg(Usage::kLevel);
2075
2076 if (level_arg) {
2077 // `NumberOfLevels` is a non-optional argument if `MipLevel` was passed.
2078 // Increment the number of dimensions for the temporary vector to
2079 // accommodate this.
2080 num_dimensions++;
2081
2082 // If the swizzle was empty, the expression will evaluate to the whole
2083 // vector. As we've grown the vector by one element, we now need to
2084 // swizzle to keep the result expression equivalent.
2085 if (swizzle.empty()) {
2086 static constexpr const char* swizzles[] = {"", ".x", ".xy", ".xyz"};
2087 swizzle = swizzles[num_dimensions - 1];
2088 }
2089 }
2090
2091 if (num_dimensions > 4) {
2092 TINT_ICE(Writer, diagnostics_)
2093 << "Texture query intrinsic temporary vector has " << num_dimensions
2094 << " dimensions";
2095 return false;
2096 }
2097
2098 // Declare a variable to hold the queried texture info
2099 auto dims = UniqueIdentifier(kTempNamePrefix);
2100 if (num_dimensions == 1) {
2101 line() << "int " << dims << ";";
2102 } else {
2103 line() << "int" << num_dimensions << " " << dims << ";";
2104 }
2105
2106 { // texture.GetDimensions(...)
2107 auto pre = line();
2108 if (!EmitExpression(pre, texture)) {
2109 return false;
2110 }
2111 pre << ".GetDimensions(";
2112
2113 if (level_arg) {
2114 if (!EmitExpression(pre, level_arg)) {
2115 return false;
2116 }
2117 pre << ", ";
2118 } else if (intrinsic->Type() == sem::IntrinsicType::kTextureNumLevels) {
2119 pre << "0, ";
2120 }
2121
2122 if (num_dimensions == 1) {
2123 pre << dims;
2124 } else {
2125 static constexpr char xyzw[] = {'x', 'y', 'z', 'w'};
2126 if (num_dimensions < 0 || num_dimensions > 4) {
2127 TINT_ICE(Writer, diagnostics_)
2128 << "vector dimensions are " << num_dimensions;
2129 return false;
2130 }
2131 for (int i = 0; i < num_dimensions; i++) {
2132 if (i > 0) {
2133 pre << ", ";
2134 }
2135 pre << dims << "." << xyzw[i];
2136 }
2137 }
2138
2139 pre << ");";
2140 }
2141
2142 // The out parameters of the GetDimensions() call is now in temporary
2143 // `dims` variable. This may be packed with other data, so the final
2144 // expression may require a swizzle.
2145 out << dims << swizzle;
2146 return true;
2147 }
2148 default:
2149 break;
2150 }
2151
2152 if (!EmitExpression(out, texture))
2153 return false;
2154
2155 // If pack_level_in_coords is true, then the mip level will be appended as the
2156 // last value of the coordinates argument. If the WGSL intrinsic overload does
2157 // not have a level parameter and pack_level_in_coords is true, then a zero
2158 // mip level will be inserted.
2159 bool pack_level_in_coords = false;
2160
2161 uint32_t hlsl_ret_width = 4u;
2162
2163 switch (intrinsic->Type()) {
2164 case sem::IntrinsicType::kTextureSample:
2165 out << ".Sample(";
2166 break;
2167 case sem::IntrinsicType::kTextureSampleBias:
2168 out << ".SampleBias(";
2169 break;
2170 case sem::IntrinsicType::kTextureSampleLevel:
2171 out << ".SampleLevel(";
2172 break;
2173 case sem::IntrinsicType::kTextureSampleGrad:
2174 out << ".SampleGrad(";
2175 break;
2176 case sem::IntrinsicType::kTextureSampleCompare:
2177 out << ".SampleCmp(";
2178 hlsl_ret_width = 1;
2179 break;
2180 case sem::IntrinsicType::kTextureSampleCompareLevel:
2181 out << ".SampleCmpLevelZero(";
2182 hlsl_ret_width = 1;
2183 break;
2184 case sem::IntrinsicType::kTextureLoad:
2185 out << ".Load(";
2186 // Multisampled textures do not support mip-levels.
2187 if (!texture_type->Is<sem::MultisampledTexture>()) {
2188 pack_level_in_coords = true;
2189 }
2190 break;
2191 case sem::IntrinsicType::kTextureGather:
2192 out << ".Gather";
2193 if (intrinsic->Parameters()[0]->Usage() ==
2194 sem::ParameterUsage::kComponent) {
2195 switch (call->Arguments()[0]->ConstantValue().Elements()[0].i32) {
2196 case 0:
2197 out << "Red";
2198 break;
2199 case 1:
2200 out << "Green";
2201 break;
2202 case 2:
2203 out << "Blue";
2204 break;
2205 case 3:
2206 out << "Alpha";
2207 break;
2208 }
2209 }
2210 out << "(";
2211 break;
2212 case sem::IntrinsicType::kTextureGatherCompare:
2213 out << ".GatherCmp(";
2214 break;
2215 case sem::IntrinsicType::kTextureStore:
2216 out << "[";
2217 break;
2218 default:
2219 diagnostics_.add_error(
2220 diag::System::Writer,
2221 "Internal compiler error: Unhandled texture intrinsic '" +
2222 std::string(intrinsic->str()) + "'");
2223 return false;
2224 }
2225
2226 if (auto* sampler = arg(Usage::kSampler)) {
2227 if (!EmitExpression(out, sampler))
2228 return false;
2229 out << ", ";
2230 }
2231
2232 auto* param_coords = arg(Usage::kCoords);
2233 if (!param_coords) {
2234 TINT_ICE(Writer, diagnostics_) << "missing coords argument";
2235 return false;
2236 }
2237
2238 auto emit_vector_appended_with_i32_zero = [&](const ast::Expression* vector) {
2239 auto* i32 = builder_.create<sem::I32>();
2240 auto* zero = builder_.Expr(0);
2241 auto* stmt = builder_.Sem().Get(vector)->Stmt();
2242 builder_.Sem().Add(zero, builder_.create<sem::Expression>(zero, i32, stmt,
2243 sem::Constant{}));
2244 auto* packed = AppendVector(&builder_, vector, zero);
2245 return EmitExpression(out, packed->Declaration());
2246 };
2247
2248 auto emit_vector_appended_with_level = [&](const ast::Expression* vector) {
2249 if (auto* level = arg(Usage::kLevel)) {
2250 auto* packed = AppendVector(&builder_, vector, level);
2251 return EmitExpression(out, packed->Declaration());
2252 }
2253 return emit_vector_appended_with_i32_zero(vector);
2254 };
2255
2256 if (auto* array_index = arg(Usage::kArrayIndex)) {
2257 // Array index needs to be appended to the coordinates.
2258 auto* packed = AppendVector(&builder_, param_coords, array_index);
2259 if (pack_level_in_coords) {
2260 // Then mip level needs to be appended to the coordinates.
2261 if (!emit_vector_appended_with_level(packed->Declaration())) {
2262 return false;
2263 }
2264 } else {
2265 if (!EmitExpression(out, packed->Declaration())) {
2266 return false;
2267 }
2268 }
2269 } else if (pack_level_in_coords) {
2270 // Mip level needs to be appended to the coordinates.
2271 if (!emit_vector_appended_with_level(param_coords)) {
2272 return false;
2273 }
2274 } else {
2275 if (!EmitExpression(out, param_coords)) {
2276 return false;
2277 }
2278 }
2279
2280 for (auto usage : {Usage::kDepthRef, Usage::kBias, Usage::kLevel, Usage::kDdx,
2281 Usage::kDdy, Usage::kSampleIndex, Usage::kOffset}) {
2282 if (usage == Usage::kLevel && pack_level_in_coords) {
2283 continue; // mip level already packed in coordinates.
2284 }
2285 if (auto* e = arg(usage)) {
2286 out << ", ";
2287 if (!EmitExpression(out, e)) {
2288 return false;
2289 }
2290 }
2291 }
2292
2293 if (intrinsic->Type() == sem::IntrinsicType::kTextureStore) {
2294 out << "] = ";
2295 if (!EmitExpression(out, arg(Usage::kValue))) {
2296 return false;
2297 }
2298 } else {
2299 out << ")";
2300
2301 // If the intrinsic return type does not match the number of elements of the
2302 // HLSL intrinsic, we need to swizzle the expression to generate the correct
2303 // number of components.
2304 uint32_t wgsl_ret_width = 1;
2305 if (auto* vec = intrinsic->ReturnType()->As<sem::Vector>()) {
2306 wgsl_ret_width = vec->Width();
2307 }
2308 if (wgsl_ret_width < hlsl_ret_width) {
2309 out << ".";
2310 for (uint32_t i = 0; i < wgsl_ret_width; i++) {
2311 out << "xyz"[i];
2312 }
2313 }
2314 if (wgsl_ret_width > hlsl_ret_width) {
2315 TINT_ICE(Writer, diagnostics_)
2316 << "WGSL return width (" << wgsl_ret_width
2317 << ") is wider than HLSL return width (" << hlsl_ret_width << ") for "
2318 << intrinsic->Type();
2319 return false;
2320 }
2321 }
2322
2323 return true;
2324 }
2325
generate_builtin_name(const sem::Intrinsic * intrinsic)2326 std::string GeneratorImpl::generate_builtin_name(
2327 const sem::Intrinsic* intrinsic) {
2328 switch (intrinsic->Type()) {
2329 case sem::IntrinsicType::kAbs:
2330 case sem::IntrinsicType::kAcos:
2331 case sem::IntrinsicType::kAll:
2332 case sem::IntrinsicType::kAny:
2333 case sem::IntrinsicType::kAsin:
2334 case sem::IntrinsicType::kAtan:
2335 case sem::IntrinsicType::kAtan2:
2336 case sem::IntrinsicType::kCeil:
2337 case sem::IntrinsicType::kClamp:
2338 case sem::IntrinsicType::kCos:
2339 case sem::IntrinsicType::kCosh:
2340 case sem::IntrinsicType::kCross:
2341 case sem::IntrinsicType::kDeterminant:
2342 case sem::IntrinsicType::kDistance:
2343 case sem::IntrinsicType::kDot:
2344 case sem::IntrinsicType::kExp:
2345 case sem::IntrinsicType::kExp2:
2346 case sem::IntrinsicType::kFloor:
2347 case sem::IntrinsicType::kFrexp:
2348 case sem::IntrinsicType::kLdexp:
2349 case sem::IntrinsicType::kLength:
2350 case sem::IntrinsicType::kLog:
2351 case sem::IntrinsicType::kLog2:
2352 case sem::IntrinsicType::kMax:
2353 case sem::IntrinsicType::kMin:
2354 case sem::IntrinsicType::kModf:
2355 case sem::IntrinsicType::kNormalize:
2356 case sem::IntrinsicType::kPow:
2357 case sem::IntrinsicType::kReflect:
2358 case sem::IntrinsicType::kRefract:
2359 case sem::IntrinsicType::kRound:
2360 case sem::IntrinsicType::kSign:
2361 case sem::IntrinsicType::kSin:
2362 case sem::IntrinsicType::kSinh:
2363 case sem::IntrinsicType::kSqrt:
2364 case sem::IntrinsicType::kStep:
2365 case sem::IntrinsicType::kTan:
2366 case sem::IntrinsicType::kTanh:
2367 case sem::IntrinsicType::kTranspose:
2368 case sem::IntrinsicType::kTrunc:
2369 return intrinsic->str();
2370 case sem::IntrinsicType::kCountOneBits:
2371 return "countbits";
2372 case sem::IntrinsicType::kDpdx:
2373 return "ddx";
2374 case sem::IntrinsicType::kDpdxCoarse:
2375 return "ddx_coarse";
2376 case sem::IntrinsicType::kDpdxFine:
2377 return "ddx_fine";
2378 case sem::IntrinsicType::kDpdy:
2379 return "ddy";
2380 case sem::IntrinsicType::kDpdyCoarse:
2381 return "ddy_coarse";
2382 case sem::IntrinsicType::kDpdyFine:
2383 return "ddy_fine";
2384 case sem::IntrinsicType::kFaceForward:
2385 return "faceforward";
2386 case sem::IntrinsicType::kFract:
2387 return "frac";
2388 case sem::IntrinsicType::kFma:
2389 return "mad";
2390 case sem::IntrinsicType::kFwidth:
2391 case sem::IntrinsicType::kFwidthCoarse:
2392 case sem::IntrinsicType::kFwidthFine:
2393 return "fwidth";
2394 case sem::IntrinsicType::kInverseSqrt:
2395 return "rsqrt";
2396 case sem::IntrinsicType::kIsFinite:
2397 return "isfinite";
2398 case sem::IntrinsicType::kIsInf:
2399 return "isinf";
2400 case sem::IntrinsicType::kIsNan:
2401 return "isnan";
2402 case sem::IntrinsicType::kMix:
2403 return "lerp";
2404 case sem::IntrinsicType::kReverseBits:
2405 return "reversebits";
2406 case sem::IntrinsicType::kSmoothStep:
2407 return "smoothstep";
2408 default:
2409 diagnostics_.add_error(
2410 diag::System::Writer,
2411 "Unknown builtin method: " + std::string(intrinsic->str()));
2412 }
2413
2414 return "";
2415 }
2416
EmitCase(const ast::SwitchStatement * s,size_t case_idx)2417 bool GeneratorImpl::EmitCase(const ast::SwitchStatement* s, size_t case_idx) {
2418 auto* stmt = s->body[case_idx];
2419 if (stmt->IsDefault()) {
2420 line() << "default: {";
2421 } else {
2422 for (auto* selector : stmt->selectors) {
2423 auto out = line();
2424 out << "case ";
2425 if (!EmitLiteral(out, selector)) {
2426 return false;
2427 }
2428 out << ":";
2429 if (selector == stmt->selectors.back()) {
2430 out << " {";
2431 }
2432 }
2433 }
2434
2435 increment_indent();
2436 TINT_DEFER({
2437 decrement_indent();
2438 line() << "}";
2439 });
2440
2441 // Emit the case statement
2442 if (!EmitStatements(stmt->body->statements)) {
2443 return false;
2444 }
2445
2446 // Inline all fallthrough case statements. FXC cannot handle fallthroughs.
2447 while (tint::Is<ast::FallthroughStatement>(stmt->body->Last())) {
2448 case_idx++;
2449 stmt = s->body[case_idx];
2450 // Generate each fallthrough case statement in a new block. This is done to
2451 // prevent symbol collision of variables declared in these cases statements.
2452 if (!EmitBlock(stmt->body)) {
2453 return false;
2454 }
2455 }
2456
2457 if (!tint::IsAnyOf<ast::BreakStatement, ast::FallthroughStatement>(
2458 stmt->body->Last())) {
2459 line() << "break;";
2460 }
2461
2462 return true;
2463 }
2464
EmitContinue(const ast::ContinueStatement *)2465 bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) {
2466 if (!emit_continuing_()) {
2467 return false;
2468 }
2469 line() << "continue;";
2470 return true;
2471 }
2472
EmitDiscard(const ast::DiscardStatement *)2473 bool GeneratorImpl::EmitDiscard(const ast::DiscardStatement*) {
2474 // TODO(dsinclair): Verify this is correct when the discard semantics are
2475 // defined for WGSL (https://github.com/gpuweb/gpuweb/issues/361)
2476 line() << "discard;";
2477 return true;
2478 }
2479
EmitExpression(std::ostream & out,const ast::Expression * expr)2480 bool GeneratorImpl::EmitExpression(std::ostream& out,
2481 const ast::Expression* expr) {
2482 if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
2483 return EmitIndexAccessor(out, a);
2484 }
2485 if (auto* b = expr->As<ast::BinaryExpression>()) {
2486 return EmitBinary(out, b);
2487 }
2488 if (auto* b = expr->As<ast::BitcastExpression>()) {
2489 return EmitBitcast(out, b);
2490 }
2491 if (auto* c = expr->As<ast::CallExpression>()) {
2492 return EmitCall(out, c);
2493 }
2494 if (auto* i = expr->As<ast::IdentifierExpression>()) {
2495 return EmitIdentifier(out, i);
2496 }
2497 if (auto* l = expr->As<ast::LiteralExpression>()) {
2498 return EmitLiteral(out, l);
2499 }
2500 if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
2501 return EmitMemberAccessor(out, m);
2502 }
2503 if (auto* u = expr->As<ast::UnaryOpExpression>()) {
2504 return EmitUnaryOp(out, u);
2505 }
2506
2507 diagnostics_.add_error(
2508 diag::System::Writer,
2509 "unknown expression type: " + std::string(expr->TypeInfo().name));
2510 return false;
2511 }
2512
EmitIdentifier(std::ostream & out,const ast::IdentifierExpression * expr)2513 bool GeneratorImpl::EmitIdentifier(std::ostream& out,
2514 const ast::IdentifierExpression* expr) {
2515 out << builder_.Symbols().NameFor(expr->symbol);
2516 return true;
2517 }
2518
EmitIf(const ast::IfStatement * stmt)2519 bool GeneratorImpl::EmitIf(const ast::IfStatement* stmt) {
2520 {
2521 auto out = line();
2522 out << "if (";
2523 if (!EmitExpression(out, stmt->condition)) {
2524 return false;
2525 }
2526 out << ") {";
2527 }
2528
2529 if (!EmitStatementsWithIndent(stmt->body->statements)) {
2530 return false;
2531 }
2532
2533 for (auto* e : stmt->else_statements) {
2534 if (e->condition) {
2535 line() << "} else {";
2536 increment_indent();
2537
2538 {
2539 auto out = line();
2540 out << "if (";
2541 if (!EmitExpression(out, e->condition)) {
2542 return false;
2543 }
2544 out << ") {";
2545 }
2546 } else {
2547 line() << "} else {";
2548 }
2549
2550 if (!EmitStatementsWithIndent(e->body->statements)) {
2551 return false;
2552 }
2553 }
2554
2555 line() << "}";
2556
2557 for (auto* e : stmt->else_statements) {
2558 if (e->condition) {
2559 decrement_indent();
2560 line() << "}";
2561 }
2562 }
2563 return true;
2564 }
2565
/// Emits a (non entry point) function: its signature line, its body and the
/// closing brace. Functions carrying an ast::InternalDecoration are skipped.
/// @param func the function to emit
/// @returns true on success (including the skipped case), false if emitting
/// any type, parameter or statement failed
bool GeneratorImpl::EmitFunction(const ast::Function* func) {
  auto* sem = builder_.Sem().Get(func);

  if (ast::HasDecoration<ast::InternalDecoration>(func->decorations)) {
    // An internal function. Do not emit.
    return true;
  }

  {
    // `out` accumulates the signature line: `<ret> <name>(<params>) {`.
    auto out = line();
    auto name = builder_.Symbols().NameFor(func->symbol);
    // If the function returns an array, then we need to declare a typedef for
    // this.
    if (sem->ReturnType()->Is<sem::Array>()) {
      auto typedef_name = UniqueIdentifier(name + "_ret");
      // NOTE(review): `pre` is created after `out` but leaves scope before it.
      // Presumably that is what places the typedef line ahead of the signature
      // line - confirm against the writer returned by line().
      auto pre = line();
      pre << "typedef ";
      if (!EmitTypeAndName(pre, sem->ReturnType(), ast::StorageClass::kNone,
                           ast::Access::kReadWrite, typedef_name)) {
        return false;
      }
      pre << ";";
      // The signature uses the typedef name as the return type.
      out << typedef_name;
    } else {
      if (!EmitType(out, sem->ReturnType(), ast::StorageClass::kNone,
                    ast::Access::kReadWrite, "")) {
        return false;
      }
    }

    out << " " << name << "(";

    // Tracks whether a ", " separator is needed before the next parameter.
    bool first = true;

    for (auto* v : sem->Parameters()) {
      if (!first) {
        out << ", ";
      }
      first = false;

      auto const* type = v->Type();

      if (auto* ptr = type->As<sem::Pointer>()) {
        // Transform pointer parameters into `inout` parameters.
        // The WGSL spec is highly restrictive in what can be passed in pointer
        // parameters, which allows for this transformation. See:
        // https://gpuweb.github.io/gpuweb/wgsl/#function-restriction
        out << "inout ";
        // The parameter is declared with the pointee type.
        type = ptr->StoreType();
      }

      // Note: WGSL only allows for StorageClass::kNone on parameters, however
      // the sanitizer transforms generates load / store functions for storage
      // or uniform buffers. These functions have a buffer parameter with
      // StorageClass::kStorage or StorageClass::kUniform. This is required to
      // correctly translate the parameter to a [RW]ByteAddressBuffer for
      // storage buffers and a uint4[N] for uniform buffers.
      if (!EmitTypeAndName(
              out, type, v->StorageClass(), v->Access(),
              builder_.Symbols().NameFor(v->Declaration()->symbol))) {
        return false;
      }
    }
    out << ") {";
  }

  if (sem->HasDiscard() && !sem->ReturnType()->Is<sem::Void>()) {
    // BUG(crbug.com/tint/1081): work around non-void functions with discard
    // failing compilation sometimes
    if (!EmitFunctionBodyWithDiscard(func)) {
      return false;
    }
  } else {
    if (!EmitStatementsWithIndent(func->body->statements)) {
      return false;
    }
  }

  // Closes the brace opened at the end of the signature line.
  line() << "}";

  return true;
}
2648
EmitFunctionBodyWithDiscard(const ast::Function * func)2649 bool GeneratorImpl::EmitFunctionBodyWithDiscard(const ast::Function* func) {
2650 // FXC sometimes fails to compile functions that discard with 'Not all control
2651 // paths return a value'. We work around this by wrapping the function body
2652 // within an "if (true) { <body> } return <default return type obj>;" so that
2653 // there is always an (unused) return statement.
2654
2655 auto* sem = builder_.Sem().Get(func);
2656 TINT_ASSERT(Writer, sem->HasDiscard() && !sem->ReturnType()->Is<sem::Void>());
2657
2658 ScopedIndent si(this);
2659 line() << "if (true) {";
2660
2661 if (!EmitStatementsWithIndent(func->body->statements)) {
2662 return false;
2663 }
2664
2665 line() << "}";
2666
2667 // Return an unused result that matches the type of the return value
2668 auto name = builder_.Symbols().NameFor(builder_.Symbols().New("unused"));
2669 {
2670 auto out = line();
2671 if (!EmitTypeAndName(out, sem->ReturnType(), ast::StorageClass::kNone,
2672 ast::Access::kReadWrite, name)) {
2673 return false;
2674 }
2675 out << ";";
2676 }
2677 line() << "return " << name << ";";
2678
2679 return true;
2680 }
2681
EmitGlobalVariable(const ast::Variable * global)2682 bool GeneratorImpl::EmitGlobalVariable(const ast::Variable* global) {
2683 if (global->is_const) {
2684 return EmitProgramConstVariable(global);
2685 }
2686
2687 auto* sem = builder_.Sem().Get(global);
2688 switch (sem->StorageClass()) {
2689 case ast::StorageClass::kUniform:
2690 return EmitUniformVariable(sem);
2691 case ast::StorageClass::kStorage:
2692 return EmitStorageVariable(sem);
2693 case ast::StorageClass::kUniformConstant:
2694 return EmitHandleVariable(sem);
2695 case ast::StorageClass::kPrivate:
2696 return EmitPrivateVariable(sem);
2697 case ast::StorageClass::kWorkgroup:
2698 return EmitWorkgroupVariable(sem);
2699 default:
2700 break;
2701 }
2702
2703 TINT_ICE(Writer, diagnostics_)
2704 << "unhandled storage class " << sem->StorageClass();
2705 return false;
2706 }
2707
EmitUniformVariable(const sem::Variable * var)2708 bool GeneratorImpl::EmitUniformVariable(const sem::Variable* var) {
2709 auto* decl = var->Declaration();
2710 auto binding_point = decl->BindingPoint();
2711 auto* type = var->Type()->UnwrapRef();
2712
2713 auto* str = type->As<sem::Struct>();
2714 if (!str) {
2715 // https://www.w3.org/TR/WGSL/#module-scope-variables
2716 TINT_ICE(Writer, diagnostics_)
2717 << "variables with uniform storage must be structure";
2718 }
2719
2720 auto name = builder_.Symbols().NameFor(decl->symbol);
2721 line() << "cbuffer cbuffer_" << name << RegisterAndSpace('b', binding_point)
2722 << " {";
2723
2724 {
2725 ScopedIndent si(this);
2726 auto out = line();
2727 if (!EmitTypeAndName(out, type, ast::StorageClass::kUniform, var->Access(),
2728 name)) {
2729 return false;
2730 }
2731 out << ";";
2732 }
2733
2734 line() << "};";
2735
2736 return true;
2737 }
2738
EmitStorageVariable(const sem::Variable * var)2739 bool GeneratorImpl::EmitStorageVariable(const sem::Variable* var) {
2740 auto* decl = var->Declaration();
2741 auto* type = var->Type()->UnwrapRef();
2742 auto out = line();
2743 if (!EmitTypeAndName(out, type, ast::StorageClass::kStorage, var->Access(),
2744 builder_.Symbols().NameFor(decl->symbol))) {
2745 return false;
2746 }
2747
2748 out << RegisterAndSpace(var->Access() == ast::Access::kRead ? 't' : 'u',
2749 decl->BindingPoint())
2750 << ";";
2751
2752 return true;
2753 }
2754
EmitHandleVariable(const sem::Variable * var)2755 bool GeneratorImpl::EmitHandleVariable(const sem::Variable* var) {
2756 auto* decl = var->Declaration();
2757 auto* unwrapped_type = var->Type()->UnwrapRef();
2758 auto out = line();
2759
2760 auto name = builder_.Symbols().NameFor(decl->symbol);
2761 auto* type = var->Type()->UnwrapRef();
2762 if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) {
2763 return false;
2764 }
2765
2766 const char* register_space = nullptr;
2767
2768 if (unwrapped_type->Is<sem::Texture>()) {
2769 register_space = "t";
2770 if (unwrapped_type->Is<sem::StorageTexture>()) {
2771 register_space = "u";
2772 }
2773 } else if (unwrapped_type->Is<sem::Sampler>()) {
2774 register_space = "s";
2775 }
2776
2777 if (register_space) {
2778 auto bp = decl->BindingPoint();
2779 out << " : register(" << register_space << bp.binding->value << ", space"
2780 << bp.group->value << ")";
2781 }
2782
2783 out << ";";
2784 return true;
2785 }
2786
EmitPrivateVariable(const sem::Variable * var)2787 bool GeneratorImpl::EmitPrivateVariable(const sem::Variable* var) {
2788 auto* decl = var->Declaration();
2789 auto out = line();
2790
2791 out << "static ";
2792
2793 auto name = builder_.Symbols().NameFor(decl->symbol);
2794 auto* type = var->Type()->UnwrapRef();
2795 if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) {
2796 return false;
2797 }
2798
2799 out << " = ";
2800 if (auto* constructor = decl->constructor) {
2801 if (!EmitExpression(out, constructor)) {
2802 return false;
2803 }
2804 } else {
2805 if (!EmitZeroValue(out, var->Type()->UnwrapRef())) {
2806 return false;
2807 }
2808 }
2809
2810 out << ";";
2811 return true;
2812 }
2813
EmitWorkgroupVariable(const sem::Variable * var)2814 bool GeneratorImpl::EmitWorkgroupVariable(const sem::Variable* var) {
2815 auto* decl = var->Declaration();
2816 auto out = line();
2817
2818 out << "groupshared ";
2819
2820 auto name = builder_.Symbols().NameFor(decl->symbol);
2821 auto* type = var->Type()->UnwrapRef();
2822 if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) {
2823 return false;
2824 }
2825
2826 if (auto* constructor = decl->constructor) {
2827 out << " = ";
2828 if (!EmitExpression(out, constructor)) {
2829 return false;
2830 }
2831 }
2832
2833 out << ";";
2834 return true;
2835 }
2836
builtin_to_attribute(ast::Builtin builtin) const2837 std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const {
2838 switch (builtin) {
2839 case ast::Builtin::kPosition:
2840 return "SV_Position";
2841 case ast::Builtin::kVertexIndex:
2842 return "SV_VertexID";
2843 case ast::Builtin::kInstanceIndex:
2844 return "SV_InstanceID";
2845 case ast::Builtin::kFrontFacing:
2846 return "SV_IsFrontFace";
2847 case ast::Builtin::kFragDepth:
2848 return "SV_Depth";
2849 case ast::Builtin::kLocalInvocationId:
2850 return "SV_GroupThreadID";
2851 case ast::Builtin::kLocalInvocationIndex:
2852 return "SV_GroupIndex";
2853 case ast::Builtin::kGlobalInvocationId:
2854 return "SV_DispatchThreadID";
2855 case ast::Builtin::kWorkgroupId:
2856 return "SV_GroupID";
2857 case ast::Builtin::kSampleIndex:
2858 return "SV_SampleIndex";
2859 case ast::Builtin::kSampleMask:
2860 return "SV_Coverage";
2861 default:
2862 break;
2863 }
2864 return "";
2865 }
2866
interpolation_to_modifiers(ast::InterpolationType type,ast::InterpolationSampling sampling) const2867 std::string GeneratorImpl::interpolation_to_modifiers(
2868 ast::InterpolationType type,
2869 ast::InterpolationSampling sampling) const {
2870 std::string modifiers;
2871 switch (type) {
2872 case ast::InterpolationType::kPerspective:
2873 modifiers += "linear ";
2874 break;
2875 case ast::InterpolationType::kLinear:
2876 modifiers += "noperspective ";
2877 break;
2878 case ast::InterpolationType::kFlat:
2879 modifiers += "nointerpolation ";
2880 break;
2881 }
2882 switch (sampling) {
2883 case ast::InterpolationSampling::kCentroid:
2884 modifiers += "centroid ";
2885 break;
2886 case ast::InterpolationSampling::kSample:
2887 modifiers += "sample ";
2888 break;
2889 case ast::InterpolationSampling::kCenter:
2890 case ast::InterpolationSampling::kNone:
2891 break;
2892 }
2893 return modifiers;
2894 }
2895
EmitEntryPointFunction(const ast::Function * func)2896 bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
2897 auto* func_sem = builder_.Sem().Get(func);
2898
2899 {
2900 auto out = line();
2901 if (func->PipelineStage() == ast::PipelineStage::kCompute) {
2902 // Emit the workgroup_size attribute.
2903 auto wgsize = func_sem->WorkgroupSize();
2904 out << "[numthreads(";
2905 for (int i = 0; i < 3; i++) {
2906 if (i > 0) {
2907 out << ", ";
2908 }
2909
2910 if (wgsize[i].overridable_const) {
2911 auto* global = builder_.Sem().Get<sem::GlobalVariable>(
2912 wgsize[i].overridable_const);
2913 if (!global->IsOverridable()) {
2914 TINT_ICE(Writer, builder_.Diagnostics())
2915 << "expected a pipeline-overridable constant";
2916 }
2917 out << kSpecConstantPrefix << global->ConstantId();
2918 } else {
2919 out << std::to_string(wgsize[i].value);
2920 }
2921 }
2922 out << ")]" << std::endl;
2923 }
2924
2925 out << func->return_type->FriendlyName(builder_.Symbols());
2926
2927 out << " " << builder_.Symbols().NameFor(func->symbol) << "(";
2928
2929 bool first = true;
2930
2931 // Emit entry point parameters.
2932 for (auto* var : func->params) {
2933 auto* sem = builder_.Sem().Get(var);
2934 auto* type = sem->Type();
2935 if (!type->Is<sem::Struct>()) {
2936 // ICE likely indicates that the CanonicalizeEntryPointIO transform was
2937 // not run, or a builtin parameter was added after it was run.
2938 TINT_ICE(Writer, diagnostics_)
2939 << "Unsupported non-struct entry point parameter";
2940 }
2941
2942 if (!first) {
2943 out << ", ";
2944 }
2945 first = false;
2946
2947 if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
2948 builder_.Symbols().NameFor(var->symbol))) {
2949 return false;
2950 }
2951 }
2952
2953 out << ") {";
2954 }
2955
2956 {
2957 ScopedIndent si(this);
2958
2959 if (!EmitStatements(func->body->statements)) {
2960 return false;
2961 }
2962
2963 if (!Is<ast::ReturnStatement>(func->body->Last())) {
2964 ast::ReturnStatement ret(ProgramID(), Source{});
2965 if (!EmitStatement(&ret)) {
2966 return false;
2967 }
2968 }
2969 }
2970
2971 line() << "}";
2972
2973 return true;
2974 }
2975
EmitLiteral(std::ostream & out,const ast::LiteralExpression * lit)2976 bool GeneratorImpl::EmitLiteral(std::ostream& out,
2977 const ast::LiteralExpression* lit) {
2978 if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
2979 out << (l->value ? "true" : "false");
2980 } else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
2981 if (std::isinf(fl->value)) {
2982 out << (fl->value >= 0 ? "asfloat(0x7f800000u)" : "asfloat(0xff800000u)");
2983 } else if (std::isnan(fl->value)) {
2984 out << "asfloat(0x7fc00000u)";
2985 } else {
2986 out << FloatToString(fl->value) << "f";
2987 }
2988 } else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
2989 out << sl->value;
2990 } else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
2991 out << ul->value << "u";
2992 } else {
2993 diagnostics_.add_error(diag::System::Writer, "unknown literal type");
2994 return false;
2995 }
2996 return true;
2997 }
2998
EmitZeroValue(std::ostream & out,const sem::Type * type)2999 bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
3000 if (type->Is<sem::Bool>()) {
3001 out << "false";
3002 } else if (type->Is<sem::F32>()) {
3003 out << "0.0f";
3004 } else if (type->Is<sem::I32>()) {
3005 out << "0";
3006 } else if (type->Is<sem::U32>()) {
3007 out << "0u";
3008 } else if (auto* vec = type->As<sem::Vector>()) {
3009 if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
3010 "")) {
3011 return false;
3012 }
3013 ScopedParen sp(out);
3014 for (uint32_t i = 0; i < vec->Width(); i++) {
3015 if (i != 0) {
3016 out << ", ";
3017 }
3018 if (!EmitZeroValue(out, vec->type())) {
3019 return false;
3020 }
3021 }
3022 } else if (auto* mat = type->As<sem::Matrix>()) {
3023 if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
3024 "")) {
3025 return false;
3026 }
3027 ScopedParen sp(out);
3028 for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) {
3029 if (i != 0) {
3030 out << ", ";
3031 }
3032 if (!EmitZeroValue(out, mat->type())) {
3033 return false;
3034 }
3035 }
3036 } else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
3037 out << "(";
3038 if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
3039 "")) {
3040 return false;
3041 }
3042 out << ")0";
3043 } else {
3044 diagnostics_.add_error(
3045 diag::System::Writer,
3046 "Invalid type for zero emission: " + type->type_name());
3047 return false;
3048 }
3049 return true;
3050 }
3051
EmitLoop(const ast::LoopStatement * stmt)3052 bool GeneratorImpl::EmitLoop(const ast::LoopStatement* stmt) {
3053 auto emit_continuing = [this, stmt]() {
3054 if (stmt->continuing && !stmt->continuing->Empty()) {
3055 if (!EmitBlock(stmt->continuing)) {
3056 return false;
3057 }
3058 }
3059 return true;
3060 };
3061
3062 TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
3063 line() << LoopAttribute() << "while (true) {";
3064 {
3065 ScopedIndent si(this);
3066 if (!EmitStatements(stmt->body->statements)) {
3067 return false;
3068 }
3069 if (!emit_continuing()) {
3070 return false;
3071 }
3072 }
3073 line() << "}";
3074
3075 return true;
3076 }
3077
/// Emits a for-loop statement.
/// The initializer, condition and continuing parts are first generated into
/// temporary buffers. Depending on how many lines they produced, the loop is
/// then written either as a regular `for (...)` loop or as a `while (true)`
/// loop with an explicit break test and an appended continuing statement.
/// @param stmt the for-loop statement to emit
/// @returns true if the loop was emitted successfully
bool GeneratorImpl::EmitForLoop(const ast::ForLoopStatement* stmt) {
  // Nest a for loop with a new block. In HLSL the initializer scope is not
  // nested by the for-loop, so we may get variable redefinitions.
  line() << "{";
  increment_indent();
  // Closes the wrapping block on every exit path, including early returns.
  TINT_DEFER({
    decrement_indent();
    line() << "}";
  });

  // Generate the initializer with current_buffer_ pointing at `init_buf`
  // instead of the real output.
  TextBuffer init_buf;
  if (auto* init = stmt->initializer) {
    TINT_SCOPED_ASSIGNMENT(current_buffer_, &init_buf);
    if (!EmitStatement(init)) {
      return false;
    }
  }

  // The condition expression text goes to `cond_buf`. `cond_pre` is the
  // current buffer while the condition is generated, so it receives any lines
  // that are written to current_buffer_ in the process.
  TextBuffer cond_pre;
  std::stringstream cond_buf;
  if (auto* cond = stmt->condition) {
    TINT_SCOPED_ASSIGNMENT(current_buffer_, &cond_pre);
    if (!EmitExpression(cond_buf, cond)) {
      return false;
    }
  }

  // Generate the continuing statement with current_buffer_ pointing at
  // `cont_buf`.
  TextBuffer cont_buf;
  if (auto* cont = stmt->continuing) {
    TINT_SCOPED_ASSIGNMENT(current_buffer_, &cont_buf);
    if (!EmitStatement(cont)) {
      return false;
    }
  }

  // If the for-loop has a multi-statement conditional and / or continuing, then
  // we cannot emit this as a regular for-loop in HLSL. Instead we need to
  // generate a `while(true)` loop.
  bool emit_as_loop = cond_pre.lines.size() > 0 || cont_buf.lines.size() > 1;

  // If the for-loop has multi-statement initializer, or is going to be emitted
  // as a `while(true)` loop, then declare the initializer statement(s) before
  // the loop.
  if (init_buf.lines.size() > 1 || (stmt->initializer && emit_as_loop)) {
    current_buffer_->Append(init_buf);
    init_buf.lines.clear();  // Don't emit the initializer again in the 'for'
  }

  if (emit_as_loop) {
    // Installed as emit_continuing_ while this loop is emitted: appends the
    // buffered continuing lines to the current buffer.
    auto emit_continuing = [&]() {
      current_buffer_->Append(cont_buf);
      return true;
    };

    TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
    line() << LoopAttribute() << "while (true) {";
    increment_indent();
    // Closes the `while` block on every exit path from this branch.
    TINT_DEFER({
      decrement_indent();
      line() << "}";
    });

    if (stmt->condition) {
      // Emit the lines captured while generating the condition, then break
      // out of the loop when the condition is false.
      current_buffer_->Append(cond_pre);
      line() << "if (!(" << cond_buf.str() << ")) { break; }";
    }

    if (!EmitStatements(stmt->body->statements)) {
      return false;
    }

    // The continuing lines are also appended at the end of the loop body.
    if (!emit_continuing()) {
      return false;
    }
  } else {
    // For-loop can be generated.
    {
      auto out = line();
      out << LoopAttribute() << "for";
      {
        ScopedParen sp(out);

        // Initializer clause: the first buffered initializer line if there is
        // one, otherwise just the separator.
        if (!init_buf.lines.empty()) {
          out << init_buf.lines[0].content << " ";
        } else {
          out << "; ";
        }

        out << cond_buf.str() << "; ";

        // Continuing clause: the buffered line with its trailing ';' trimmed.
        if (!cont_buf.lines.empty()) {
          out << TrimSuffix(cont_buf.lines[0].content, ";");
        }
      }
      out << " {";
    }
    {
      // The continuing statement is part of the `for (...)` header here, so
      // the callback installed while the body is emitted does nothing.
      auto emit_continuing = [] { return true; };
      TINT_SCOPED_ASSIGNMENT(emit_continuing_, emit_continuing);
      if (!EmitStatementsWithIndent(stmt->body->statements)) {
        return false;
      }
    }
    line() << "}";
  }

  return true;
}
3186
EmitMemberAccessor(std::ostream & out,const ast::MemberAccessorExpression * expr)3187 bool GeneratorImpl::EmitMemberAccessor(
3188 std::ostream& out,
3189 const ast::MemberAccessorExpression* expr) {
3190 if (!EmitExpression(out, expr->structure)) {
3191 return false;
3192 }
3193 out << ".";
3194
3195 // Swizzles output the name directly
3196 if (builder_.Sem().Get(expr)->Is<sem::Swizzle>()) {
3197 out << builder_.Symbols().NameFor(expr->member->symbol);
3198 } else if (!EmitExpression(out, expr->member)) {
3199 return false;
3200 }
3201
3202 return true;
3203 }
3204
EmitReturn(const ast::ReturnStatement * stmt)3205 bool GeneratorImpl::EmitReturn(const ast::ReturnStatement* stmt) {
3206 if (stmt->value) {
3207 auto out = line();
3208 out << "return ";
3209 if (!EmitExpression(out, stmt->value)) {
3210 return false;
3211 }
3212 out << ";";
3213 } else {
3214 line() << "return;";
3215 }
3216 return true;
3217 }
3218
EmitStatement(const ast::Statement * stmt)3219 bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
3220 if (auto* a = stmt->As<ast::AssignmentStatement>()) {
3221 return EmitAssign(a);
3222 }
3223 if (auto* b = stmt->As<ast::BlockStatement>()) {
3224 return EmitBlock(b);
3225 }
3226 if (auto* b = stmt->As<ast::BreakStatement>()) {
3227 return EmitBreak(b);
3228 }
3229 if (auto* c = stmt->As<ast::CallStatement>()) {
3230 auto out = line();
3231 if (!EmitCall(out, c->expr)) {
3232 return false;
3233 }
3234 out << ";";
3235 return true;
3236 }
3237 if (auto* c = stmt->As<ast::ContinueStatement>()) {
3238 return EmitContinue(c);
3239 }
3240 if (auto* d = stmt->As<ast::DiscardStatement>()) {
3241 return EmitDiscard(d);
3242 }
3243 if (stmt->As<ast::FallthroughStatement>()) {
3244 line() << "/* fallthrough */";
3245 return true;
3246 }
3247 if (auto* i = stmt->As<ast::IfStatement>()) {
3248 return EmitIf(i);
3249 }
3250 if (auto* l = stmt->As<ast::LoopStatement>()) {
3251 return EmitLoop(l);
3252 }
3253 if (auto* l = stmt->As<ast::ForLoopStatement>()) {
3254 return EmitForLoop(l);
3255 }
3256 if (auto* r = stmt->As<ast::ReturnStatement>()) {
3257 return EmitReturn(r);
3258 }
3259 if (auto* s = stmt->As<ast::SwitchStatement>()) {
3260 return EmitSwitch(s);
3261 }
3262 if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
3263 return EmitVariable(v->variable);
3264 }
3265
3266 diagnostics_.add_error(
3267 diag::System::Writer,
3268 "unknown statement type: " + std::string(stmt->TypeInfo().name));
3269 return false;
3270 }
3271
EmitDefaultOnlySwitch(const ast::SwitchStatement * stmt)3272 bool GeneratorImpl::EmitDefaultOnlySwitch(const ast::SwitchStatement* stmt) {
3273 TINT_ASSERT(Writer, stmt->body.size() == 1 && stmt->body[0]->IsDefault());
3274
3275 // FXC fails to compile a switch with just a default case, ignoring the
3276 // default case body. We work around this here by emitting the default case
3277 // without the switch.
3278
3279 // Emit the switch condition as-is in case it has side-effects (e.g.
3280 // function call). Note that's it's fine not to assign the result of the
3281 // expression.
3282 {
3283 auto out = line();
3284 if (!EmitExpression(out, stmt->condition)) {
3285 return false;
3286 }
3287 out << ";";
3288 }
3289
3290 // Emit "do { <default case body> } while(false);". We use a 'do' loop so
3291 // that break statements work as expected, and make it 'while (false)' in
3292 // case there isn't a break statement.
3293 line() << "do {";
3294 {
3295 ScopedIndent si(this);
3296 if (!EmitStatements(stmt->body[0]->body->statements)) {
3297 return false;
3298 }
3299 }
3300 line() << "} while (false);";
3301 return true;
3302 }
3303
EmitSwitch(const ast::SwitchStatement * stmt)3304 bool GeneratorImpl::EmitSwitch(const ast::SwitchStatement* stmt) {
3305 // BUG(crbug.com/tint/1188): work around default-only switches
3306 if (stmt->body.size() == 1 && stmt->body[0]->IsDefault()) {
3307 return EmitDefaultOnlySwitch(stmt);
3308 }
3309
3310 { // switch(expr) {
3311 auto out = line();
3312 out << "switch(";
3313 if (!EmitExpression(out, stmt->condition)) {
3314 return false;
3315 }
3316 out << ") {";
3317 }
3318
3319 {
3320 ScopedIndent si(this);
3321 for (size_t i = 0; i < stmt->body.size(); i++) {
3322 if (!EmitCase(stmt, i)) {
3323 return false;
3324 }
3325 }
3326 }
3327
3328 line() << "}";
3329
3330 return true;
3331 }
3332
EmitType(std::ostream & out,const sem::Type * type,ast::StorageClass storage_class,ast::Access access,const std::string & name,bool * name_printed)3333 bool GeneratorImpl::EmitType(std::ostream& out,
3334 const sem::Type* type,
3335 ast::StorageClass storage_class,
3336 ast::Access access,
3337 const std::string& name,
3338 bool* name_printed /* = nullptr */) {
3339 if (name_printed) {
3340 *name_printed = false;
3341 }
3342 switch (storage_class) {
3343 case ast::StorageClass::kStorage:
3344 if (access != ast::Access::kRead) {
3345 out << "RW";
3346 }
3347 out << "ByteAddressBuffer";
3348 return true;
3349 case ast::StorageClass::kUniform: {
3350 auto* str = type->As<sem::Struct>();
3351 if (!str) {
3352 // https://www.w3.org/TR/WGSL/#module-scope-variables
3353 TINT_ICE(Writer, diagnostics_)
3354 << "variables with uniform storage must be structure";
3355 }
3356 auto array_length = (str->Size() + 15) / 16;
3357 out << "uint4 " << name << "[" << array_length << "]";
3358 if (name_printed) {
3359 *name_printed = true;
3360 }
3361 return true;
3362 }
3363 default:
3364 break;
3365 }
3366
3367 if (auto* ary = type->As<sem::Array>()) {
3368 const sem::Type* base_type = ary;
3369 std::vector<uint32_t> sizes;
3370 while (auto* arr = base_type->As<sem::Array>()) {
3371 if (arr->IsRuntimeSized()) {
3372 TINT_ICE(Writer, diagnostics_)
3373 << "Runtime arrays may only exist in storage buffers, which should "
3374 "have been transformed into a ByteAddressBuffer";
3375 return false;
3376 }
3377 sizes.push_back(arr->Count());
3378 base_type = arr->ElemType();
3379 }
3380 if (!EmitType(out, base_type, storage_class, access, "")) {
3381 return false;
3382 }
3383 if (!name.empty()) {
3384 out << " " << name;
3385 if (name_printed) {
3386 *name_printed = true;
3387 }
3388 }
3389 for (uint32_t size : sizes) {
3390 out << "[" << size << "]";
3391 }
3392 } else if (type->Is<sem::Bool>()) {
3393 out << "bool";
3394 } else if (type->Is<sem::F32>()) {
3395 out << "float";
3396 } else if (type->Is<sem::I32>()) {
3397 out << "int";
3398 } else if (auto* mat = type->As<sem::Matrix>()) {
3399 if (!EmitType(out, mat->type(), storage_class, access, "")) {
3400 return false;
3401 }
3402 // Note: HLSL's matrices are declared as <type>NxM, where N is the number of
3403 // rows and M is the number of columns. Despite HLSL's matrices being
3404 // column-major by default, the index operator and constructors actually
3405 // operate on row-vectors, where as WGSL operates on column vectors.
3406 // To simplify everything we use the transpose of the matrices.
3407 // See:
3408 // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-per-component-math#matrix-ordering
3409 out << mat->columns() << "x" << mat->rows();
3410 } else if (type->Is<sem::Pointer>()) {
3411 TINT_ICE(Writer, diagnostics_)
3412 << "Attempting to emit pointer type. These should have been removed "
3413 "with the InlinePointerLets transform";
3414 return false;
3415 } else if (auto* sampler = type->As<sem::Sampler>()) {
3416 out << "Sampler";
3417 if (sampler->IsComparison()) {
3418 out << "Comparison";
3419 }
3420 out << "State";
3421 } else if (auto* str = type->As<sem::Struct>()) {
3422 out << StructName(str);
3423 } else if (auto* tex = type->As<sem::Texture>()) {
3424 auto* storage = tex->As<sem::StorageTexture>();
3425 auto* ms = tex->As<sem::MultisampledTexture>();
3426 auto* depth_ms = tex->As<sem::DepthMultisampledTexture>();
3427 auto* sampled = tex->As<sem::SampledTexture>();
3428
3429 if (storage && storage->access() != ast::Access::kRead) {
3430 out << "RW";
3431 }
3432 out << "Texture";
3433
3434 switch (tex->dim()) {
3435 case ast::TextureDimension::k1d:
3436 out << "1D";
3437 break;
3438 case ast::TextureDimension::k2d:
3439 out << ((ms || depth_ms) ? "2DMS" : "2D");
3440 break;
3441 case ast::TextureDimension::k2dArray:
3442 out << ((ms || depth_ms) ? "2DMSArray" : "2DArray");
3443 break;
3444 case ast::TextureDimension::k3d:
3445 out << "3D";
3446 break;
3447 case ast::TextureDimension::kCube:
3448 out << "Cube";
3449 break;
3450 case ast::TextureDimension::kCubeArray:
3451 out << "CubeArray";
3452 break;
3453 default:
3454 TINT_UNREACHABLE(Writer, diagnostics_)
3455 << "unexpected TextureDimension " << tex->dim();
3456 return false;
3457 }
3458
3459 if (storage) {
3460 auto* component = image_format_to_rwtexture_type(storage->image_format());
3461 if (component == nullptr) {
3462 TINT_ICE(Writer, diagnostics_)
3463 << "Unsupported StorageTexture ImageFormat: "
3464 << static_cast<int>(storage->image_format());
3465 return false;
3466 }
3467 out << "<" << component << ">";
3468 } else if (depth_ms) {
3469 out << "<float4>";
3470 } else if (sampled || ms) {
3471 auto* subtype = sampled ? sampled->type() : ms->type();
3472 out << "<";
3473 if (subtype->Is<sem::F32>()) {
3474 out << "float4";
3475 } else if (subtype->Is<sem::I32>()) {
3476 out << "int4";
3477 } else if (subtype->Is<sem::U32>()) {
3478 out << "uint4";
3479 } else {
3480 TINT_ICE(Writer, diagnostics_)
3481 << "Unsupported multisampled texture type";
3482 return false;
3483 }
3484 out << ">";
3485 }
3486 } else if (type->Is<sem::U32>()) {
3487 out << "uint";
3488 } else if (auto* vec = type->As<sem::Vector>()) {
3489 auto width = vec->Width();
3490 if (vec->type()->Is<sem::F32>() && width >= 1 && width <= 4) {
3491 out << "float" << width;
3492 } else if (vec->type()->Is<sem::I32>() && width >= 1 && width <= 4) {
3493 out << "int" << width;
3494 } else if (vec->type()->Is<sem::U32>() && width >= 1 && width <= 4) {
3495 out << "uint" << width;
3496 } else if (vec->type()->Is<sem::Bool>() && width >= 1 && width <= 4) {
3497 out << "bool" << width;
3498 } else {
3499 out << "vector<";
3500 if (!EmitType(out, vec->type(), storage_class, access, "")) {
3501 return false;
3502 }
3503 out << ", " << width << ">";
3504 }
3505 } else if (auto* atomic = type->As<sem::Atomic>()) {
3506 if (!EmitType(out, atomic->Type(), storage_class, access, name)) {
3507 return false;
3508 }
3509 } else if (type->Is<sem::Void>()) {
3510 out << "void";
3511 } else {
3512 diagnostics_.add_error(diag::System::Writer, "unknown type in EmitType");
3513 return false;
3514 }
3515
3516 return true;
3517 }
3518
EmitTypeAndName(std::ostream & out,const sem::Type * type,ast::StorageClass storage_class,ast::Access access,const std::string & name)3519 bool GeneratorImpl::EmitTypeAndName(std::ostream& out,
3520 const sem::Type* type,
3521 ast::StorageClass storage_class,
3522 ast::Access access,
3523 const std::string& name) {
3524 bool name_printed = false;
3525 if (!EmitType(out, type, storage_class, access, name, &name_printed)) {
3526 return false;
3527 }
3528 if (!name.empty() && !name_printed) {
3529 out << " " << name;
3530 }
3531 return true;
3532 }
3533
EmitStructType(TextBuffer * b,const sem::Struct * str)3534 bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) {
3535 line(b) << "struct " << StructName(str) << " {";
3536 {
3537 ScopedIndent si(b);
3538 for (auto* mem : str->Members()) {
3539 auto name = builder_.Symbols().NameFor(mem->Name());
3540
3541 auto* ty = mem->Type();
3542
3543 auto out = line(b);
3544
3545 std::string pre, post;
3546
3547 if (auto* decl = mem->Declaration()) {
3548 for (auto* deco : decl->decorations) {
3549 if (auto* location = deco->As<ast::LocationDecoration>()) {
3550 auto& pipeline_stage_uses = str->PipelineStageUses();
3551 if (pipeline_stage_uses.size() != 1) {
3552 TINT_ICE(Writer, diagnostics_)
3553 << "invalid entry point IO struct uses";
3554 }
3555
3556 if (pipeline_stage_uses.count(
3557 sem::PipelineStageUsage::kVertexInput)) {
3558 post += " : TEXCOORD" + std::to_string(location->value);
3559 } else if (pipeline_stage_uses.count(
3560 sem::PipelineStageUsage::kVertexOutput)) {
3561 post += " : TEXCOORD" + std::to_string(location->value);
3562 } else if (pipeline_stage_uses.count(
3563 sem::PipelineStageUsage::kFragmentInput)) {
3564 post += " : TEXCOORD" + std::to_string(location->value);
3565 } else if (pipeline_stage_uses.count(
3566 sem::PipelineStageUsage::kFragmentOutput)) {
3567 post += " : SV_Target" + std::to_string(location->value);
3568 } else {
3569 TINT_ICE(Writer, diagnostics_)
3570 << "invalid use of location decoration";
3571 }
3572 } else if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
3573 auto attr = builtin_to_attribute(builtin->builtin);
3574 if (attr.empty()) {
3575 diagnostics_.add_error(diag::System::Writer,
3576 "unsupported builtin");
3577 return false;
3578 }
3579 post += " : " + attr;
3580 } else if (auto* interpolate =
3581 deco->As<ast::InterpolateDecoration>()) {
3582 auto mod = interpolation_to_modifiers(interpolate->type,
3583 interpolate->sampling);
3584 if (mod.empty()) {
3585 diagnostics_.add_error(diag::System::Writer,
3586 "unsupported interpolation");
3587 return false;
3588 }
3589 pre += mod;
3590
3591 } else if (deco->Is<ast::InvariantDecoration>()) {
3592 // Note: `precise` is not exactly the same as `invariant`, but is
3593 // stricter and therefore provides the necessary guarantees.
3594 // See discussion here: https://github.com/gpuweb/gpuweb/issues/893
3595 pre += "precise ";
3596 } else if (!deco->IsAnyOf<ast::StructMemberAlignDecoration,
3597 ast::StructMemberOffsetDecoration,
3598 ast::StructMemberSizeDecoration>()) {
3599 TINT_ICE(Writer, diagnostics_)
3600 << "unhandled struct member attribute: " << deco->Name();
3601 return false;
3602 }
3603 }
3604 }
3605
3606 out << pre;
3607 if (!EmitTypeAndName(out, ty, ast::StorageClass::kNone,
3608 ast::Access::kReadWrite, name)) {
3609 return false;
3610 }
3611 out << post << ";";
3612 }
3613 }
3614
3615 line(b) << "};";
3616
3617 return true;
3618 }
3619
EmitUnaryOp(std::ostream & out,const ast::UnaryOpExpression * expr)3620 bool GeneratorImpl::EmitUnaryOp(std::ostream& out,
3621 const ast::UnaryOpExpression* expr) {
3622 switch (expr->op) {
3623 case ast::UnaryOp::kIndirection:
3624 case ast::UnaryOp::kAddressOf:
3625 return EmitExpression(out, expr->expr);
3626 case ast::UnaryOp::kComplement:
3627 out << "~";
3628 break;
3629 case ast::UnaryOp::kNot:
3630 out << "!";
3631 break;
3632 case ast::UnaryOp::kNegation:
3633 out << "-";
3634 break;
3635 }
3636 out << "(";
3637
3638 if (!EmitExpression(out, expr->expr)) {
3639 return false;
3640 }
3641
3642 out << ")";
3643
3644 return true;
3645 }
3646
EmitVariable(const ast::Variable * var)3647 bool GeneratorImpl::EmitVariable(const ast::Variable* var) {
3648 auto* sem = builder_.Sem().Get(var);
3649 auto* type = sem->Type()->UnwrapRef();
3650
3651 // TODO(dsinclair): Handle variable decorations
3652 if (!var->decorations.empty()) {
3653 diagnostics_.add_error(diag::System::Writer,
3654 "Variable decorations are not handled yet");
3655 return false;
3656 }
3657
3658 auto out = line();
3659 if (var->is_const) {
3660 out << "const ";
3661 }
3662 if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
3663 builder_.Symbols().NameFor(var->symbol))) {
3664 return false;
3665 }
3666
3667 out << " = ";
3668
3669 if (var->constructor) {
3670 if (!EmitExpression(out, var->constructor)) {
3671 return false;
3672 }
3673 } else {
3674 if (!EmitZeroValue(out, type)) {
3675 return false;
3676 }
3677 }
3678 out << ";";
3679
3680 return true;
3681 }
3682
EmitProgramConstVariable(const ast::Variable * var)3683 bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
3684 for (auto* d : var->decorations) {
3685 if (!d->Is<ast::OverrideDecoration>()) {
3686 diagnostics_.add_error(diag::System::Writer,
3687 "Decorated const values not valid");
3688 return false;
3689 }
3690 }
3691 if (!var->is_const) {
3692 diagnostics_.add_error(diag::System::Writer, "Expected a const value");
3693 return false;
3694 }
3695
3696 auto* sem = builder_.Sem().Get(var);
3697 auto* type = sem->Type();
3698
3699 auto* global = sem->As<sem::GlobalVariable>();
3700 if (global && global->IsOverridable()) {
3701 auto const_id = global->ConstantId();
3702
3703 line() << "#ifndef " << kSpecConstantPrefix << const_id;
3704
3705 if (var->constructor != nullptr) {
3706 auto out = line();
3707 out << "#define " << kSpecConstantPrefix << const_id << " ";
3708 if (!EmitExpression(out, var->constructor)) {
3709 return false;
3710 }
3711 } else {
3712 line() << "#error spec constant required for constant id " << const_id;
3713 }
3714 line() << "#endif";
3715 {
3716 auto out = line();
3717 out << "static const ";
3718 if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
3719 builder_.Symbols().NameFor(var->symbol))) {
3720 return false;
3721 }
3722 out << " = " << kSpecConstantPrefix << const_id << ";";
3723 }
3724 } else {
3725 auto out = line();
3726 out << "static const ";
3727 if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
3728 builder_.Symbols().NameFor(var->symbol))) {
3729 return false;
3730 }
3731 out << " = ";
3732 if (!EmitExpression(out, var->constructor)) {
3733 return false;
3734 }
3735 out << ";";
3736 }
3737
3738 return true;
3739 }
3740
3741 template <typename F>
CallIntrinsicHelper(std::ostream & out,const ast::CallExpression * call,const sem::Intrinsic * intrinsic,F && build)3742 bool GeneratorImpl::CallIntrinsicHelper(std::ostream& out,
3743 const ast::CallExpression* call,
3744 const sem::Intrinsic* intrinsic,
3745 F&& build) {
3746 // Generate the helper function if it hasn't been created already
3747 auto fn = utils::GetOrCreate(intrinsics_, intrinsic, [&]() -> std::string {
3748 TextBuffer b;
3749 TINT_DEFER(helpers_.Append(b));
3750
3751 auto fn_name =
3752 UniqueIdentifier(std::string("tint_") + sem::str(intrinsic->Type()));
3753 std::vector<std::string> parameter_names;
3754 {
3755 auto decl = line(&b);
3756 if (!EmitTypeAndName(decl, intrinsic->ReturnType(),
3757 ast::StorageClass::kNone, ast::Access::kUndefined,
3758 fn_name)) {
3759 return "";
3760 }
3761 {
3762 ScopedParen sp(decl);
3763 for (auto* param : intrinsic->Parameters()) {
3764 if (!parameter_names.empty()) {
3765 decl << ", ";
3766 }
3767 auto param_name = "param_" + std::to_string(parameter_names.size());
3768 const auto* ty = param->Type();
3769 if (auto* ptr = ty->As<sem::Pointer>()) {
3770 decl << "inout ";
3771 ty = ptr->StoreType();
3772 }
3773 if (!EmitTypeAndName(decl, ty, ast::StorageClass::kNone,
3774 ast::Access::kUndefined, param_name)) {
3775 return "";
3776 }
3777 parameter_names.emplace_back(std::move(param_name));
3778 }
3779 }
3780 decl << " {";
3781 }
3782 {
3783 ScopedIndent si(&b);
3784 if (!build(&b, parameter_names)) {
3785 return "";
3786 }
3787 }
3788 line(&b) << "}";
3789 line(&b);
3790 return fn_name;
3791 });
3792
3793 if (fn.empty()) {
3794 return false;
3795 }
3796
3797 // Call the helper
3798 out << fn;
3799 {
3800 ScopedParen sp(out);
3801 bool first = true;
3802 for (auto* arg : call->args) {
3803 if (!first) {
3804 out << ", ";
3805 }
3806 first = false;
3807 if (!EmitExpression(out, arg)) {
3808 return false;
3809 }
3810 }
3811 }
3812 return true;
3813 }
3814
3815 } // namespace hlsl
3816 } // namespace writer
3817 } // namespace tint
3818