1 // Copyright 2020 The Tint Authors. //
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 #include "src/writer/spirv/builder.h"
15
16 #include <algorithm>
17 #include <limits>
18 #include <utility>
19
20 #include "spirv/unified1/GLSL.std.450.h"
21 #include "src/ast/call_statement.h"
22 #include "src/ast/fallthrough_statement.h"
23 #include "src/ast/internal_decoration.h"
24 #include "src/ast/override_decoration.h"
25 #include "src/ast/traverse_expressions.h"
26 #include "src/sem/array.h"
27 #include "src/sem/atomic_type.h"
28 #include "src/sem/call.h"
29 #include "src/sem/depth_multisampled_texture_type.h"
30 #include "src/sem/depth_texture_type.h"
31 #include "src/sem/function.h"
32 #include "src/sem/intrinsic.h"
33 #include "src/sem/member_accessor_expression.h"
34 #include "src/sem/multisampled_texture_type.h"
35 #include "src/sem/reference_type.h"
36 #include "src/sem/sampled_texture_type.h"
37 #include "src/sem/statement.h"
38 #include "src/sem/struct.h"
39 #include "src/sem/type_constructor.h"
40 #include "src/sem/type_conversion.h"
41 #include "src/sem/variable.h"
42 #include "src/sem/vector_type.h"
43 #include "src/transform/add_empty_entry_point.h"
44 #include "src/transform/canonicalize_entry_point_io.h"
45 #include "src/transform/external_texture_transform.h"
46 #include "src/transform/fold_constants.h"
47 #include "src/transform/for_loop_to_loop.h"
48 #include "src/transform/manager.h"
49 #include "src/transform/simplify_pointers.h"
50 #include "src/transform/unshadow.h"
51 #include "src/transform/vectorize_scalar_matrix_constructors.h"
52 #include "src/transform/zero_init_workgroup_memory.h"
53 #include "src/utils/defer.h"
54 #include "src/utils/map.h"
55 #include "src/writer/append_vector.h"
56
57 namespace tint {
58 namespace writer {
59 namespace spirv {
60 namespace {
61
62 using IntrinsicType = sem::IntrinsicType;
63
64 const char kGLSLstd450[] = "GLSL.std.450";
65
size_of(const InstructionList & instructions)66 uint32_t size_of(const InstructionList& instructions) {
67 uint32_t size = 0;
68 for (const auto& inst : instructions)
69 size += inst.word_length();
70
71 return size;
72 }
73
pipeline_stage_to_execution_model(ast::PipelineStage stage)74 uint32_t pipeline_stage_to_execution_model(ast::PipelineStage stage) {
75 SpvExecutionModel model = SpvExecutionModelVertex;
76
77 switch (stage) {
78 case ast::PipelineStage::kFragment:
79 model = SpvExecutionModelFragment;
80 break;
81 case ast::PipelineStage::kVertex:
82 model = SpvExecutionModelVertex;
83 break;
84 case ast::PipelineStage::kCompute:
85 model = SpvExecutionModelGLCompute;
86 break;
87 case ast::PipelineStage::kNone:
88 model = SpvExecutionModelMax;
89 break;
90 }
91 return model;
92 }
93
LastIsFallthrough(const ast::BlockStatement * stmts)94 bool LastIsFallthrough(const ast::BlockStatement* stmts) {
95 return !stmts->Empty() && stmts->Last()->Is<ast::FallthroughStatement>();
96 }
97
98 // A terminator is anything which will cause a SPIR-V terminator to be emitted.
99 // This means things like breaks, fallthroughs and continues which all emit an
100 // OpBranch or return for the OpReturn emission.
LastIsTerminator(const ast::BlockStatement * stmts)101 bool LastIsTerminator(const ast::BlockStatement* stmts) {
102 if (IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
103 ast::DiscardStatement, ast::ReturnStatement,
104 ast::FallthroughStatement>(stmts->Last())) {
105 return true;
106 }
107
108 if (auto* block = As<ast::BlockStatement>(stmts->Last())) {
109 return LastIsTerminator(block);
110 }
111
112 return false;
113 }
114
115 /// Returns the matrix type that is `type` or that is wrapped by
116 /// one or more levels of an arrays inside of `type`.
117 /// @param type the given type, which must not be null
118 /// @returns the nested matrix type, or nullptr if none
GetNestedMatrixType(const sem::Type * type)119 const sem::Matrix* GetNestedMatrixType(const sem::Type* type) {
120 while (auto* arr = type->As<sem::Array>()) {
121 type = arr->ElemType();
122 }
123 return type->As<sem::Matrix>();
124 }
125
intrinsic_to_glsl_method(const sem::Intrinsic * intrinsic)126 uint32_t intrinsic_to_glsl_method(const sem::Intrinsic* intrinsic) {
127 switch (intrinsic->Type()) {
128 case IntrinsicType::kAcos:
129 return GLSLstd450Acos;
130 case IntrinsicType::kAsin:
131 return GLSLstd450Asin;
132 case IntrinsicType::kAtan:
133 return GLSLstd450Atan;
134 case IntrinsicType::kAtan2:
135 return GLSLstd450Atan2;
136 case IntrinsicType::kCeil:
137 return GLSLstd450Ceil;
138 case IntrinsicType::kClamp:
139 if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
140 return GLSLstd450NClamp;
141 } else if (intrinsic->ReturnType()->is_unsigned_scalar_or_vector()) {
142 return GLSLstd450UClamp;
143 } else {
144 return GLSLstd450SClamp;
145 }
146 case IntrinsicType::kCos:
147 return GLSLstd450Cos;
148 case IntrinsicType::kCosh:
149 return GLSLstd450Cosh;
150 case IntrinsicType::kCross:
151 return GLSLstd450Cross;
152 case IntrinsicType::kDeterminant:
153 return GLSLstd450Determinant;
154 case IntrinsicType::kDistance:
155 return GLSLstd450Distance;
156 case IntrinsicType::kExp:
157 return GLSLstd450Exp;
158 case IntrinsicType::kExp2:
159 return GLSLstd450Exp2;
160 case IntrinsicType::kFaceForward:
161 return GLSLstd450FaceForward;
162 case IntrinsicType::kFloor:
163 return GLSLstd450Floor;
164 case IntrinsicType::kFma:
165 return GLSLstd450Fma;
166 case IntrinsicType::kFract:
167 return GLSLstd450Fract;
168 case IntrinsicType::kFrexp:
169 return GLSLstd450FrexpStruct;
170 case IntrinsicType::kInverseSqrt:
171 return GLSLstd450InverseSqrt;
172 case IntrinsicType::kLdexp:
173 return GLSLstd450Ldexp;
174 case IntrinsicType::kLength:
175 return GLSLstd450Length;
176 case IntrinsicType::kLog:
177 return GLSLstd450Log;
178 case IntrinsicType::kLog2:
179 return GLSLstd450Log2;
180 case IntrinsicType::kMax:
181 if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
182 return GLSLstd450NMax;
183 } else if (intrinsic->ReturnType()->is_unsigned_scalar_or_vector()) {
184 return GLSLstd450UMax;
185 } else {
186 return GLSLstd450SMax;
187 }
188 case IntrinsicType::kMin:
189 if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
190 return GLSLstd450NMin;
191 } else if (intrinsic->ReturnType()->is_unsigned_scalar_or_vector()) {
192 return GLSLstd450UMin;
193 } else {
194 return GLSLstd450SMin;
195 }
196 case IntrinsicType::kMix:
197 return GLSLstd450FMix;
198 case IntrinsicType::kModf:
199 return GLSLstd450ModfStruct;
200 case IntrinsicType::kNormalize:
201 return GLSLstd450Normalize;
202 case IntrinsicType::kPack4x8snorm:
203 return GLSLstd450PackSnorm4x8;
204 case IntrinsicType::kPack4x8unorm:
205 return GLSLstd450PackUnorm4x8;
206 case IntrinsicType::kPack2x16snorm:
207 return GLSLstd450PackSnorm2x16;
208 case IntrinsicType::kPack2x16unorm:
209 return GLSLstd450PackUnorm2x16;
210 case IntrinsicType::kPack2x16float:
211 return GLSLstd450PackHalf2x16;
212 case IntrinsicType::kPow:
213 return GLSLstd450Pow;
214 case IntrinsicType::kReflect:
215 return GLSLstd450Reflect;
216 case IntrinsicType::kRefract:
217 return GLSLstd450Refract;
218 case IntrinsicType::kRound:
219 return GLSLstd450RoundEven;
220 case IntrinsicType::kSign:
221 return GLSLstd450FSign;
222 case IntrinsicType::kSin:
223 return GLSLstd450Sin;
224 case IntrinsicType::kSinh:
225 return GLSLstd450Sinh;
226 case IntrinsicType::kSmoothStep:
227 return GLSLstd450SmoothStep;
228 case IntrinsicType::kSqrt:
229 return GLSLstd450Sqrt;
230 case IntrinsicType::kStep:
231 return GLSLstd450Step;
232 case IntrinsicType::kTan:
233 return GLSLstd450Tan;
234 case IntrinsicType::kTanh:
235 return GLSLstd450Tanh;
236 case IntrinsicType::kTrunc:
237 return GLSLstd450Trunc;
238 case IntrinsicType::kUnpack4x8snorm:
239 return GLSLstd450UnpackSnorm4x8;
240 case IntrinsicType::kUnpack4x8unorm:
241 return GLSLstd450UnpackUnorm4x8;
242 case IntrinsicType::kUnpack2x16snorm:
243 return GLSLstd450UnpackSnorm2x16;
244 case IntrinsicType::kUnpack2x16unorm:
245 return GLSLstd450UnpackUnorm2x16;
246 case IntrinsicType::kUnpack2x16float:
247 return GLSLstd450UnpackHalf2x16;
248 default:
249 break;
250 }
251 return 0;
252 }
253
254 /// @return the vector element type if ty is a vector, otherwise return ty.
ElementTypeOf(const sem::Type * ty)255 const sem::Type* ElementTypeOf(const sem::Type* ty) {
256 if (auto* v = ty->As<sem::Vector>()) {
257 return v->type();
258 }
259 return ty;
260 }
261
262 } // namespace
263
Sanitize(const Program * in,bool emit_vertex_point_size,bool disable_workgroup_init)264 SanitizedResult Sanitize(const Program* in,
265 bool emit_vertex_point_size,
266 bool disable_workgroup_init) {
267 transform::Manager manager;
268 transform::DataMap data;
269
270 manager.Add<transform::Unshadow>();
271 if (!disable_workgroup_init) {
272 manager.Add<transform::ZeroInitWorkgroupMemory>();
273 }
274 manager.Add<transform::SimplifyPointers>(); // Required for arrayLength()
275 manager.Add<transform::FoldConstants>();
276 manager.Add<transform::ExternalTextureTransform>();
277 manager.Add<transform::VectorizeScalarMatrixConstructors>();
278 manager.Add<transform::ForLoopToLoop>(); // Must come after
279 // ZeroInitWorkgroupMemory
280 manager.Add<transform::CanonicalizeEntryPointIO>();
281 manager.Add<transform::AddEmptyEntryPoint>();
282
283 data.Add<transform::CanonicalizeEntryPointIO::Config>(
284 transform::CanonicalizeEntryPointIO::Config(
285 transform::CanonicalizeEntryPointIO::ShaderStyle::kSpirv, 0xFFFFFFFF,
286 emit_vertex_point_size));
287
288 SanitizedResult result;
289 result.program = std::move(manager.Run(in, data).program);
290 return result;
291 }
292
AccessorInfo()293 Builder::AccessorInfo::AccessorInfo() : source_id(0), source_type(nullptr) {}
294
~AccessorInfo()295 Builder::AccessorInfo::~AccessorInfo() {}
296
Builder(const Program * program)297 Builder::Builder(const Program* program)
298 : builder_(ProgramBuilder::Wrap(program)), scope_stack_({}) {}
299
300 Builder::~Builder() = default;
301
Build()302 bool Builder::Build() {
303 push_capability(SpvCapabilityShader);
304
305 push_memory_model(spv::Op::OpMemoryModel,
306 {Operand::Int(SpvAddressingModelLogical),
307 Operand::Int(SpvMemoryModelGLSL450)});
308
309 for (auto* var : builder_.AST().GlobalVariables()) {
310 if (!GenerateGlobalVariable(var)) {
311 return false;
312 }
313 }
314
315 for (auto* func : builder_.AST().Functions()) {
316 if (!GenerateFunction(func)) {
317 return false;
318 }
319 }
320
321 return true;
322 }
323
result_op()324 Operand Builder::result_op() {
325 return Operand::Int(next_id());
326 }
327
total_size() const328 uint32_t Builder::total_size() const {
329 // The 5 covers the magic, version, generator, id bound and reserved.
330 uint32_t size = 5;
331
332 size += size_of(capabilities_);
333 size += size_of(extensions_);
334 size += size_of(ext_imports_);
335 size += size_of(memory_model_);
336 size += size_of(entry_points_);
337 size += size_of(execution_modes_);
338 size += size_of(debug_);
339 size += size_of(annotations_);
340 size += size_of(types_);
341 for (const auto& func : functions_) {
342 size += func.word_length();
343 }
344
345 return size;
346 }
347
iterate(std::function<void (const Instruction &)> cb) const348 void Builder::iterate(std::function<void(const Instruction&)> cb) const {
349 for (const auto& inst : capabilities_) {
350 cb(inst);
351 }
352 for (const auto& inst : extensions_) {
353 cb(inst);
354 }
355 for (const auto& inst : ext_imports_) {
356 cb(inst);
357 }
358 for (const auto& inst : memory_model_) {
359 cb(inst);
360 }
361 for (const auto& inst : entry_points_) {
362 cb(inst);
363 }
364 for (const auto& inst : execution_modes_) {
365 cb(inst);
366 }
367 for (const auto& inst : debug_) {
368 cb(inst);
369 }
370 for (const auto& inst : annotations_) {
371 cb(inst);
372 }
373 for (const auto& inst : types_) {
374 cb(inst);
375 }
376 for (const auto& func : functions_) {
377 func.iterate(cb);
378 }
379 }
380
push_capability(uint32_t cap)381 void Builder::push_capability(uint32_t cap) {
382 if (capability_set_.count(cap) == 0) {
383 capability_set_.insert(cap);
384 capabilities_.push_back(
385 Instruction{spv::Op::OpCapability, {Operand::Int(cap)}});
386 }
387 }
388
GenerateLabel(uint32_t id)389 bool Builder::GenerateLabel(uint32_t id) {
390 if (!push_function_inst(spv::Op::OpLabel, {Operand::Int(id)})) {
391 return false;
392 }
393 current_label_id_ = id;
394 return true;
395 }
396
GenerateAssignStatement(const ast::AssignmentStatement * assign)397 bool Builder::GenerateAssignStatement(const ast::AssignmentStatement* assign) {
398 if (assign->lhs->Is<ast::PhonyExpression>()) {
399 auto rhs_id = GenerateExpression(assign->rhs);
400 if (rhs_id == 0) {
401 return false;
402 }
403 return true;
404 } else {
405 auto lhs_id = GenerateExpression(assign->lhs);
406 if (lhs_id == 0) {
407 return false;
408 }
409 auto rhs_id = GenerateExpression(assign->rhs);
410 if (rhs_id == 0) {
411 return false;
412 }
413
414 // If the thing we're assigning is a reference then we must load it first.
415 auto* type = TypeOf(assign->rhs);
416 rhs_id = GenerateLoadIfNeeded(type, rhs_id);
417
418 return GenerateStore(lhs_id, rhs_id);
419 }
420 }
421
GenerateBreakStatement(const ast::BreakStatement *)422 bool Builder::GenerateBreakStatement(const ast::BreakStatement*) {
423 if (merge_stack_.empty()) {
424 error_ = "Attempted to break without a merge block";
425 return false;
426 }
427 if (!push_function_inst(spv::Op::OpBranch,
428 {Operand::Int(merge_stack_.back())})) {
429 return false;
430 }
431 return true;
432 }
433
GenerateContinueStatement(const ast::ContinueStatement *)434 bool Builder::GenerateContinueStatement(const ast::ContinueStatement*) {
435 if (continue_stack_.empty()) {
436 error_ = "Attempted to continue without a continue block";
437 return false;
438 }
439 if (!push_function_inst(spv::Op::OpBranch,
440 {Operand::Int(continue_stack_.back())})) {
441 return false;
442 }
443 return true;
444 }
445
446 // TODO(dsinclair): This is generating an OpKill but the semantics of kill
447 // haven't been defined for WGSL yet. So, this may need to change.
448 // https://github.com/gpuweb/gpuweb/issues/676
GenerateDiscardStatement(const ast::DiscardStatement *)449 bool Builder::GenerateDiscardStatement(const ast::DiscardStatement*) {
450 if (!push_function_inst(spv::Op::OpKill, {})) {
451 return false;
452 }
453 return true;
454 }
455
GenerateEntryPoint(const ast::Function * func,uint32_t id)456 bool Builder::GenerateEntryPoint(const ast::Function* func, uint32_t id) {
457 auto stage = pipeline_stage_to_execution_model(func->PipelineStage());
458 if (stage == SpvExecutionModelMax) {
459 error_ = "Unknown pipeline stage provided";
460 return false;
461 }
462
463 OperandList operands = {
464 Operand::Int(stage), Operand::Int(id),
465 Operand::String(builder_.Symbols().NameFor(func->symbol))};
466
467 auto* func_sem = builder_.Sem().Get(func);
468 for (const auto* var : func_sem->TransitivelyReferencedGlobals()) {
469 // For SPIR-V 1.3 we only output Input/output variables. If we update to
470 // SPIR-V 1.4 or later this should be all variables.
471 if (var->StorageClass() != ast::StorageClass::kInput &&
472 var->StorageClass() != ast::StorageClass::kOutput) {
473 continue;
474 }
475
476 uint32_t var_id = scope_stack_.Get(var->Declaration()->symbol);
477 if (var_id == 0) {
478 error_ = "unable to find ID for global variable: " +
479 builder_.Symbols().NameFor(var->Declaration()->symbol);
480 return false;
481 }
482
483 operands.push_back(Operand::Int(var_id));
484 }
485 push_entry_point(spv::Op::OpEntryPoint, operands);
486
487 return true;
488 }
489
GenerateExecutionModes(const ast::Function * func,uint32_t id)490 bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) {
491 auto* func_sem = builder_.Sem().Get(func);
492
493 // WGSL fragment shader origin is upper left
494 if (func->PipelineStage() == ast::PipelineStage::kFragment) {
495 push_execution_mode(
496 spv::Op::OpExecutionMode,
497 {Operand::Int(id), Operand::Int(SpvExecutionModeOriginUpperLeft)});
498 } else if (func->PipelineStage() == ast::PipelineStage::kCompute) {
499 auto& wgsize = func_sem->WorkgroupSize();
500
501 // Check if the workgroup_size uses pipeline-overridable constants.
502 if (wgsize[0].overridable_const || wgsize[1].overridable_const ||
503 wgsize[2].overridable_const) {
504 if (has_overridable_workgroup_size_) {
505 // Only one stage can have a pipeline-overridable workgroup size.
506 // TODO(crbug.com/tint/810): Use LocalSizeId to handle this scenario.
507 TINT_ICE(Writer, builder_.Diagnostics())
508 << "multiple stages using pipeline-overridable workgroup sizes";
509 }
510 has_overridable_workgroup_size_ = true;
511
512 auto* vec3_u32 =
513 builder_.create<sem::Vector>(builder_.create<sem::U32>(), 3);
514 uint32_t vec3_u32_type_id = GenerateTypeIfNeeded(vec3_u32);
515 if (vec3_u32_type_id == 0) {
516 return 0;
517 }
518
519 OperandList wgsize_ops;
520 auto wgsize_result = result_op();
521 wgsize_ops.push_back(Operand::Int(vec3_u32_type_id));
522 wgsize_ops.push_back(wgsize_result);
523
524 // Generate OpConstant instructions for each dimension.
525 for (int i = 0; i < 3; i++) {
526 auto constant = ScalarConstant::U32(wgsize[i].value);
527 if (wgsize[i].overridable_const) {
528 // Make the constant specializable.
529 auto* sem_const = builder_.Sem().Get<sem::GlobalVariable>(
530 wgsize[i].overridable_const);
531 if (!sem_const->IsOverridable()) {
532 TINT_ICE(Writer, builder_.Diagnostics())
533 << "expected a pipeline-overridable constant";
534 }
535 constant.is_spec_op = true;
536 constant.constant_id = sem_const->ConstantId();
537 }
538
539 auto result = GenerateConstantIfNeeded(constant);
540 wgsize_ops.push_back(Operand::Int(result));
541 }
542
543 // Generate the WorkgroupSize builtin.
544 push_type(spv::Op::OpSpecConstantComposite, wgsize_ops);
545 push_annot(spv::Op::OpDecorate,
546 {wgsize_result, Operand::Int(SpvDecorationBuiltIn),
547 Operand::Int(SpvBuiltInWorkgroupSize)});
548 } else {
549 // Not overridable, so just use OpExecutionMode LocalSize.
550 uint32_t x = wgsize[0].value;
551 uint32_t y = wgsize[1].value;
552 uint32_t z = wgsize[2].value;
553 push_execution_mode(
554 spv::Op::OpExecutionMode,
555 {Operand::Int(id), Operand::Int(SpvExecutionModeLocalSize),
556 Operand::Int(x), Operand::Int(y), Operand::Int(z)});
557 }
558 }
559
560 for (auto builtin : func_sem->TransitivelyReferencedBuiltinVariables()) {
561 if (builtin.second->builtin == ast::Builtin::kFragDepth) {
562 push_execution_mode(
563 spv::Op::OpExecutionMode,
564 {Operand::Int(id), Operand::Int(SpvExecutionModeDepthReplacing)});
565 }
566 }
567
568 return true;
569 }
570
GenerateExpression(const ast::Expression * expr)571 uint32_t Builder::GenerateExpression(const ast::Expression* expr) {
572 if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
573 return GenerateAccessorExpression(a);
574 }
575 if (auto* b = expr->As<ast::BinaryExpression>()) {
576 return GenerateBinaryExpression(b);
577 }
578 if (auto* b = expr->As<ast::BitcastExpression>()) {
579 return GenerateBitcastExpression(b);
580 }
581 if (auto* c = expr->As<ast::CallExpression>()) {
582 return GenerateCallExpression(c);
583 }
584 if (auto* i = expr->As<ast::IdentifierExpression>()) {
585 return GenerateIdentifierExpression(i);
586 }
587 if (auto* l = expr->As<ast::LiteralExpression>()) {
588 return GenerateLiteralIfNeeded(nullptr, l);
589 }
590 if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
591 return GenerateAccessorExpression(m);
592 }
593 if (auto* u = expr->As<ast::UnaryOpExpression>()) {
594 return GenerateUnaryOpExpression(u);
595 }
596
597 error_ = "unknown expression type: " + std::string(expr->TypeInfo().name);
598 return 0;
599 }
600
GenerateFunction(const ast::Function * func_ast)601 bool Builder::GenerateFunction(const ast::Function* func_ast) {
602 auto* func = builder_.Sem().Get(func_ast);
603
604 uint32_t func_type_id = GenerateFunctionTypeIfNeeded(func);
605 if (func_type_id == 0) {
606 return false;
607 }
608
609 auto func_op = result_op();
610 auto func_id = func_op.to_i();
611
612 push_debug(spv::Op::OpName,
613 {Operand::Int(func_id),
614 Operand::String(builder_.Symbols().NameFor(func_ast->symbol))});
615
616 auto ret_id = GenerateTypeIfNeeded(func->ReturnType());
617 if (ret_id == 0) {
618 return false;
619 }
620
621 scope_stack_.Push();
622 TINT_DEFER(scope_stack_.Pop());
623
624 auto definition_inst = Instruction{
625 spv::Op::OpFunction,
626 {Operand::Int(ret_id), func_op, Operand::Int(SpvFunctionControlMaskNone),
627 Operand::Int(func_type_id)}};
628
629 InstructionList params;
630 for (auto* param : func->Parameters()) {
631 auto param_op = result_op();
632 auto param_id = param_op.to_i();
633
634 auto param_type_id = GenerateTypeIfNeeded(param->Type());
635 if (param_type_id == 0) {
636 return false;
637 }
638
639 push_debug(spv::Op::OpName, {Operand::Int(param_id),
640 Operand::String(builder_.Symbols().NameFor(
641 param->Declaration()->symbol))});
642 params.push_back(Instruction{spv::Op::OpFunctionParameter,
643 {Operand::Int(param_type_id), param_op}});
644
645 scope_stack_.Set(param->Declaration()->symbol, param_id);
646 }
647
648 push_function(Function{definition_inst, result_op(), std::move(params)});
649
650 for (auto* stmt : func_ast->body->statements) {
651 if (!GenerateStatement(stmt)) {
652 return false;
653 }
654 }
655
656 if (!LastIsTerminator(func_ast->body)) {
657 if (func->ReturnType()->Is<sem::Void>()) {
658 push_function_inst(spv::Op::OpReturn, {});
659 } else {
660 auto zero = GenerateConstantNullIfNeeded(func->ReturnType());
661 push_function_inst(spv::Op::OpReturnValue, {Operand::Int(zero)});
662 }
663 }
664
665 if (func_ast->IsEntryPoint()) {
666 if (!GenerateEntryPoint(func_ast, func_id)) {
667 return false;
668 }
669 if (!GenerateExecutionModes(func_ast, func_id)) {
670 return false;
671 }
672 }
673
674 func_symbol_to_id_[func_ast->symbol] = func_id;
675
676 return true;
677 }
678
GenerateFunctionTypeIfNeeded(const sem::Function * func)679 uint32_t Builder::GenerateFunctionTypeIfNeeded(const sem::Function* func) {
680 return utils::GetOrCreate(
681 func_sig_to_id_, func->Signature(), [&]() -> uint32_t {
682 auto func_op = result_op();
683 auto func_type_id = func_op.to_i();
684
685 auto ret_id = GenerateTypeIfNeeded(func->ReturnType());
686 if (ret_id == 0) {
687 return 0;
688 }
689
690 OperandList ops = {func_op, Operand::Int(ret_id)};
691 for (auto* param : func->Parameters()) {
692 auto param_type_id = GenerateTypeIfNeeded(param->Type());
693 if (param_type_id == 0) {
694 return 0;
695 }
696 ops.push_back(Operand::Int(param_type_id));
697 }
698
699 push_type(spv::Op::OpTypeFunction, std::move(ops));
700 return func_type_id;
701 });
702 }
703
GenerateFunctionVariable(const ast::Variable * var)704 bool Builder::GenerateFunctionVariable(const ast::Variable* var) {
705 uint32_t init_id = 0;
706 if (var->constructor) {
707 init_id = GenerateExpression(var->constructor);
708 if (init_id == 0) {
709 return false;
710 }
711 auto* type = TypeOf(var->constructor);
712 if (type->Is<sem::Reference>()) {
713 init_id = GenerateLoadIfNeeded(type, init_id);
714 }
715 }
716
717 if (var->is_const) {
718 if (!var->constructor) {
719 error_ = "missing constructor for constant";
720 return false;
721 }
722 scope_stack_.Set(var->symbol, init_id);
723 spirv_id_to_variable_[init_id] = var;
724 return true;
725 }
726
727 auto result = result_op();
728 auto var_id = result.to_i();
729 auto sc = ast::StorageClass::kFunction;
730 auto* type = builder_.Sem().Get(var)->Type();
731 auto type_id = GenerateTypeIfNeeded(type);
732 if (type_id == 0) {
733 return false;
734 }
735
736 push_debug(spv::Op::OpName,
737 {Operand::Int(var_id),
738 Operand::String(builder_.Symbols().NameFor(var->symbol))});
739
740 // TODO(dsinclair) We could detect if the constructor is fully const and emit
741 // an initializer value for the variable instead of doing the OpLoad.
742 auto null_id = GenerateConstantNullIfNeeded(type->UnwrapRef());
743 if (null_id == 0) {
744 return 0;
745 }
746 push_function_var({Operand::Int(type_id), result,
747 Operand::Int(ConvertStorageClass(sc)),
748 Operand::Int(null_id)});
749
750 if (var->constructor) {
751 if (!GenerateStore(var_id, init_id)) {
752 return false;
753 }
754 }
755
756 scope_stack_.Set(var->symbol, var_id);
757 spirv_id_to_variable_[var_id] = var;
758
759 return true;
760 }
761
GenerateStore(uint32_t to,uint32_t from)762 bool Builder::GenerateStore(uint32_t to, uint32_t from) {
763 return push_function_inst(spv::Op::OpStore,
764 {Operand::Int(to), Operand::Int(from)});
765 }
766
GenerateGlobalVariable(const ast::Variable * var)767 bool Builder::GenerateGlobalVariable(const ast::Variable* var) {
768 auto* sem = builder_.Sem().Get(var);
769 auto* type = sem->Type()->UnwrapRef();
770
771 uint32_t init_id = 0;
772 if (var->constructor) {
773 init_id = GenerateConstructorExpression(var, var->constructor);
774 if (init_id == 0) {
775 return false;
776 }
777 }
778
779 if (var->is_const) {
780 if (!var->constructor) {
781 // Constants must have an initializer unless they have an override
782 // decoration.
783 if (!ast::HasDecoration<ast::OverrideDecoration>(var->decorations)) {
784 error_ = "missing constructor for constant";
785 return false;
786 }
787
788 // SPIR-V requires specialization constants to have initializers.
789 if (type->Is<sem::F32>()) {
790 ast::FloatLiteralExpression l(ProgramID(), Source{}, 0.0f);
791 init_id = GenerateLiteralIfNeeded(var, &l);
792 } else if (type->Is<sem::U32>()) {
793 ast::UintLiteralExpression l(ProgramID(), Source{}, 0);
794 init_id = GenerateLiteralIfNeeded(var, &l);
795 } else if (type->Is<sem::I32>()) {
796 ast::SintLiteralExpression l(ProgramID(), Source{}, 0);
797 init_id = GenerateLiteralIfNeeded(var, &l);
798 } else if (type->Is<sem::Bool>()) {
799 ast::BoolLiteralExpression l(ProgramID(), Source{}, false);
800 init_id = GenerateLiteralIfNeeded(var, &l);
801 } else {
802 error_ = "invalid type for pipeline constant ID, must be scalar";
803 return false;
804 }
805 if (init_id == 0) {
806 return 0;
807 }
808 }
809 push_debug(spv::Op::OpName,
810 {Operand::Int(init_id),
811 Operand::String(builder_.Symbols().NameFor(var->symbol))});
812
813 scope_stack_.Set(var->symbol, init_id);
814 spirv_id_to_variable_[init_id] = var;
815 return true;
816 }
817
818 auto result = result_op();
819 auto var_id = result.to_i();
820
821 auto sc = sem->StorageClass() == ast::StorageClass::kNone
822 ? ast::StorageClass::kPrivate
823 : sem->StorageClass();
824
825 auto type_id = GenerateTypeIfNeeded(sem->Type());
826 if (type_id == 0) {
827 return false;
828 }
829
830 push_debug(spv::Op::OpName,
831 {Operand::Int(var_id),
832 Operand::String(builder_.Symbols().NameFor(var->symbol))});
833
834 OperandList ops = {Operand::Int(type_id), result,
835 Operand::Int(ConvertStorageClass(sc))};
836
837 if (var->constructor) {
838 ops.push_back(Operand::Int(init_id));
839 } else {
840 auto* st = type->As<sem::StorageTexture>();
841 if (st || type->Is<sem::Struct>()) {
842 // type is a sem::Struct or a sem::StorageTexture
843 auto access = st ? st->access() : sem->Access();
844 switch (access) {
845 case ast::Access::kWrite:
846 push_annot(
847 spv::Op::OpDecorate,
848 {Operand::Int(var_id), Operand::Int(SpvDecorationNonReadable)});
849 break;
850 case ast::Access::kRead:
851 push_annot(
852 spv::Op::OpDecorate,
853 {Operand::Int(var_id), Operand::Int(SpvDecorationNonWritable)});
854 break;
855 case ast::Access::kUndefined:
856 case ast::Access::kReadWrite:
857 break;
858 }
859 }
860 if (!type->Is<sem::Sampler>()) {
861 // If we don't have a constructor and we're an Output or Private
862 // variable, then WGSL requires that we zero-initialize.
863 if (sem->StorageClass() == ast::StorageClass::kPrivate ||
864 sem->StorageClass() == ast::StorageClass::kOutput) {
865 init_id = GenerateConstantNullIfNeeded(type);
866 if (init_id == 0) {
867 return 0;
868 }
869 ops.push_back(Operand::Int(init_id));
870 }
871 }
872 }
873
874 push_type(spv::Op::OpVariable, std::move(ops));
875
876 for (auto* deco : var->decorations) {
877 if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
878 push_annot(spv::Op::OpDecorate,
879 {Operand::Int(var_id), Operand::Int(SpvDecorationBuiltIn),
880 Operand::Int(
881 ConvertBuiltin(builtin->builtin, sem->StorageClass()))});
882 } else if (auto* location = deco->As<ast::LocationDecoration>()) {
883 push_annot(spv::Op::OpDecorate,
884 {Operand::Int(var_id), Operand::Int(SpvDecorationLocation),
885 Operand::Int(location->value)});
886 } else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
887 AddInterpolationDecorations(var_id, interpolate->type,
888 interpolate->sampling);
889 } else if (deco->Is<ast::InvariantDecoration>()) {
890 push_annot(spv::Op::OpDecorate,
891 {Operand::Int(var_id), Operand::Int(SpvDecorationInvariant)});
892 } else if (auto* binding = deco->As<ast::BindingDecoration>()) {
893 push_annot(spv::Op::OpDecorate,
894 {Operand::Int(var_id), Operand::Int(SpvDecorationBinding),
895 Operand::Int(binding->value)});
896 } else if (auto* group = deco->As<ast::GroupDecoration>()) {
897 push_annot(spv::Op::OpDecorate, {Operand::Int(var_id),
898 Operand::Int(SpvDecorationDescriptorSet),
899 Operand::Int(group->value)});
900 } else if (deco->Is<ast::OverrideDecoration>()) {
901 // Spec constants are handled elsewhere
902 } else if (!deco->Is<ast::InternalDecoration>()) {
903 error_ = "unknown decoration";
904 return false;
905 }
906 }
907
908 scope_stack_.Set(var->symbol, var_id);
909 spirv_id_to_variable_[var_id] = var;
910 return true;
911 }
912
GenerateIndexAccessor(const ast::IndexAccessorExpression * expr,AccessorInfo * info)913 bool Builder::GenerateIndexAccessor(const ast::IndexAccessorExpression* expr,
914 AccessorInfo* info) {
915 auto idx_id = GenerateExpression(expr->index);
916 if (idx_id == 0) {
917 return 0;
918 }
919 auto* type = TypeOf(expr->index);
920 idx_id = GenerateLoadIfNeeded(type, idx_id);
921
922 // If the source is a reference, we access chain into it.
923 // In the future, pointers may support access-chaining.
924 // See https://github.com/gpuweb/gpuweb/pull/1580
925 if (info->source_type->Is<sem::Reference>()) {
926 info->access_chain_indices.push_back(idx_id);
927 info->source_type = TypeOf(expr);
928 return true;
929 }
930
931 auto result_type_id = GenerateTypeIfNeeded(TypeOf(expr));
932 if (result_type_id == 0) {
933 return false;
934 }
935
936 // We don't have a pointer, so we can just directly extract the value.
937 auto extract = result_op();
938 auto extract_id = extract.to_i();
939
940 // If the index is a literal, we use OpCompositeExtract.
941 if (auto* literal = expr->index->As<ast::IntLiteralExpression>()) {
942 if (!push_function_inst(spv::Op::OpCompositeExtract,
943 {Operand::Int(result_type_id), extract,
944 Operand::Int(info->source_id),
945 Operand::Int(literal->ValueAsU32())})) {
946 return false;
947 }
948
949 info->source_id = extract_id;
950 info->source_type = TypeOf(expr);
951
952 return true;
953 }
954
955 // If the source is a vector, we use OpVectorExtractDynamic.
956 if (info->source_type->Is<sem::Vector>()) {
957 if (!push_function_inst(
958 spv::Op::OpVectorExtractDynamic,
959 {Operand::Int(result_type_id), extract,
960 Operand::Int(info->source_id), Operand::Int(idx_id)})) {
961 return false;
962 }
963
964 info->source_id = extract_id;
965 info->source_type = TypeOf(expr);
966
967 return true;
968 }
969
970 TINT_ICE(Writer, builder_.Diagnostics())
971 << "unsupported index accessor expression";
972 return false;
973 }
974
GenerateMemberAccessor(const ast::MemberAccessorExpression * expr,AccessorInfo * info)975 bool Builder::GenerateMemberAccessor(const ast::MemberAccessorExpression* expr,
976 AccessorInfo* info) {
977 auto* expr_sem = builder_.Sem().Get(expr);
978 auto* expr_type = expr_sem->Type();
979
980 if (auto* access = expr_sem->As<sem::StructMemberAccess>()) {
981 uint32_t idx = access->Member()->Index();
982
983 if (info->source_type->Is<sem::Reference>()) {
984 auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(idx));
985 if (idx_id == 0) {
986 return 0;
987 }
988 info->access_chain_indices.push_back(idx_id);
989 info->source_type = expr_type;
990 } else {
991 auto result_type_id = GenerateTypeIfNeeded(expr_type);
992 if (result_type_id == 0) {
993 return false;
994 }
995
996 auto extract = result_op();
997 auto extract_id = extract.to_i();
998 if (!push_function_inst(
999 spv::Op::OpCompositeExtract,
1000 {Operand::Int(result_type_id), extract,
1001 Operand::Int(info->source_id), Operand::Int(idx)})) {
1002 return false;
1003 }
1004
1005 info->source_id = extract_id;
1006 info->source_type = expr_type;
1007 }
1008
1009 return true;
1010 }
1011
1012 if (auto* swizzle = expr_sem->As<sem::Swizzle>()) {
1013 // Single element swizzle is either an access chain or a composite extract
1014 auto& indices = swizzle->Indices();
1015 if (indices.size() == 1) {
1016 if (info->source_type->Is<sem::Reference>()) {
1017 auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(indices[0]));
1018 if (idx_id == 0) {
1019 return 0;
1020 }
1021 info->access_chain_indices.push_back(idx_id);
1022 } else {
1023 auto result_type_id = GenerateTypeIfNeeded(expr_type);
1024 if (result_type_id == 0) {
1025 return 0;
1026 }
1027
1028 auto extract = result_op();
1029 auto extract_id = extract.to_i();
1030 if (!push_function_inst(
1031 spv::Op::OpCompositeExtract,
1032 {Operand::Int(result_type_id), extract,
1033 Operand::Int(info->source_id), Operand::Int(indices[0])})) {
1034 return false;
1035 }
1036
1037 info->source_id = extract_id;
1038 info->source_type = expr_type;
1039 }
1040 return true;
1041 }
1042
1043 // Store the type away as it may change if we run the access chain
1044 auto* incoming_type = info->source_type;
1045
1046 // Multi-item extract is a VectorShuffle. We have to emit any existing
1047 // access chain data, then load the access chain and shuffle that.
1048 if (!info->access_chain_indices.empty()) {
1049 auto result_type_id = GenerateTypeIfNeeded(info->source_type);
1050 if (result_type_id == 0) {
1051 return 0;
1052 }
1053 auto extract = result_op();
1054 auto extract_id = extract.to_i();
1055
1056 OperandList ops = {Operand::Int(result_type_id), extract,
1057 Operand::Int(info->source_id)};
1058 for (auto id : info->access_chain_indices) {
1059 ops.push_back(Operand::Int(id));
1060 }
1061
1062 if (!push_function_inst(spv::Op::OpAccessChain, ops)) {
1063 return false;
1064 }
1065
1066 info->source_id = GenerateLoadIfNeeded(expr_type, extract_id);
1067 info->source_type = expr_type->UnwrapRef();
1068 info->access_chain_indices.clear();
1069 }
1070
1071 auto result_type_id = GenerateTypeIfNeeded(expr_type);
1072 if (result_type_id == 0) {
1073 return false;
1074 }
1075
1076 auto vec_id = GenerateLoadIfNeeded(incoming_type, info->source_id);
1077
1078 auto result = result_op();
1079 auto result_id = result.to_i();
1080
1081 OperandList ops = {Operand::Int(result_type_id), result,
1082 Operand::Int(vec_id), Operand::Int(vec_id)};
1083
1084 for (auto idx : indices) {
1085 ops.push_back(Operand::Int(idx));
1086 }
1087
1088 if (!push_function_inst(spv::Op::OpVectorShuffle, ops)) {
1089 return false;
1090 }
1091 info->source_id = result_id;
1092 info->source_type = expr_type;
1093 return true;
1094 }
1095
1096 TINT_ICE(Writer, builder_.Diagnostics())
1097 << "unhandled member index type: " << expr_sem->TypeInfo().name;
1098 return false;
1099 }
1100
GenerateAccessorExpression(const ast::Expression * expr)1101 uint32_t Builder::GenerateAccessorExpression(const ast::Expression* expr) {
1102 if (!expr->IsAnyOf<ast::IndexAccessorExpression,
1103 ast::MemberAccessorExpression>()) {
1104 TINT_ICE(Writer, builder_.Diagnostics()) << "expression is not an accessor";
1105 return 0;
1106 }
1107
1108 // Gather a list of all the member and index accessors that are in this chain.
1109 // The list is built in reverse order as that's the order we need to access
1110 // the chain.
1111 std::vector<const ast::Expression*> accessors;
1112 const ast::Expression* source = expr;
1113 while (true) {
1114 if (auto* array = source->As<ast::IndexAccessorExpression>()) {
1115 accessors.insert(accessors.begin(), source);
1116 source = array->object;
1117 } else if (auto* member = source->As<ast::MemberAccessorExpression>()) {
1118 accessors.insert(accessors.begin(), source);
1119 source = member->structure;
1120 } else {
1121 break;
1122 }
1123 }
1124
1125 AccessorInfo info;
1126 info.source_id = GenerateExpression(source);
1127 if (info.source_id == 0) {
1128 return 0;
1129 }
1130 info.source_type = TypeOf(source);
1131
1132 for (auto* accessor : accessors) {
1133 if (auto* array = accessor->As<ast::IndexAccessorExpression>()) {
1134 if (!GenerateIndexAccessor(array, &info)) {
1135 return 0;
1136 }
1137 } else if (auto* member = accessor->As<ast::MemberAccessorExpression>()) {
1138 if (!GenerateMemberAccessor(member, &info)) {
1139 return 0;
1140 }
1141
1142 } else {
1143 error_ =
1144 "invalid accessor in list: " + std::string(accessor->TypeInfo().name);
1145 return 0;
1146 }
1147 }
1148
1149 if (!info.access_chain_indices.empty()) {
1150 auto* type = TypeOf(expr);
1151 auto result_type_id = GenerateTypeIfNeeded(type);
1152 if (result_type_id == 0) {
1153 return 0;
1154 }
1155
1156 auto result = result_op();
1157 auto result_id = result.to_i();
1158
1159 OperandList ops = {Operand::Int(result_type_id), result,
1160 Operand::Int(info.source_id)};
1161 for (auto id : info.access_chain_indices) {
1162 ops.push_back(Operand::Int(id));
1163 }
1164
1165 if (!push_function_inst(spv::Op::OpAccessChain, ops)) {
1166 return false;
1167 }
1168 info.source_id = result_id;
1169 }
1170
1171 return info.source_id;
1172 }
1173
GenerateIdentifierExpression(const ast::IdentifierExpression * expr)1174 uint32_t Builder::GenerateIdentifierExpression(
1175 const ast::IdentifierExpression* expr) {
1176 uint32_t val = scope_stack_.Get(expr->symbol);
1177 if (val == 0) {
1178 error_ = "unable to find variable with identifier: " +
1179 builder_.Symbols().NameFor(expr->symbol);
1180 }
1181 return val;
1182 }
1183
GenerateLoadIfNeeded(const sem::Type * type,uint32_t id)1184 uint32_t Builder::GenerateLoadIfNeeded(const sem::Type* type, uint32_t id) {
1185 if (auto* ref = type->As<sem::Reference>()) {
1186 type = ref->StoreType();
1187 } else {
1188 return id;
1189 }
1190
1191 auto type_id = GenerateTypeIfNeeded(type);
1192 auto result = result_op();
1193 auto result_id = result.to_i();
1194 if (!push_function_inst(spv::Op::OpLoad,
1195 {Operand::Int(type_id), result, Operand::Int(id)})) {
1196 return 0;
1197 }
1198 return result_id;
1199 }
1200
GenerateUnaryOpExpression(const ast::UnaryOpExpression * expr)1201 uint32_t Builder::GenerateUnaryOpExpression(
1202 const ast::UnaryOpExpression* expr) {
1203 auto result = result_op();
1204 auto result_id = result.to_i();
1205
1206 auto val_id = GenerateExpression(expr->expr);
1207 if (val_id == 0) {
1208 return 0;
1209 }
1210
1211 spv::Op op = spv::Op::OpNop;
1212 switch (expr->op) {
1213 case ast::UnaryOp::kComplement:
1214 op = spv::Op::OpNot;
1215 break;
1216 case ast::UnaryOp::kNegation:
1217 if (TypeOf(expr)->is_float_scalar_or_vector()) {
1218 op = spv::Op::OpFNegate;
1219 } else {
1220 op = spv::Op::OpSNegate;
1221 }
1222 break;
1223 case ast::UnaryOp::kNot:
1224 op = spv::Op::OpLogicalNot;
1225 break;
1226 case ast::UnaryOp::kAddressOf:
1227 case ast::UnaryOp::kIndirection:
1228 // Address-of converts a reference to a pointer, and dereference converts
1229 // a pointer to a reference. These are the same thing in SPIR-V, so this
1230 // is a no-op.
1231 return val_id;
1232 }
1233
1234 val_id = GenerateLoadIfNeeded(TypeOf(expr->expr), val_id);
1235
1236 auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
1237 if (type_id == 0) {
1238 return 0;
1239 }
1240
1241 if (!push_function_inst(
1242 op, {Operand::Int(type_id), result, Operand::Int(val_id)})) {
1243 return false;
1244 }
1245
1246 return result_id;
1247 }
1248
GetGLSLstd450Import()1249 uint32_t Builder::GetGLSLstd450Import() {
1250 auto where = import_name_to_id_.find(kGLSLstd450);
1251 if (where != import_name_to_id_.end()) {
1252 return where->second;
1253 }
1254
1255 // It doesn't exist yet. Generate it.
1256 auto result = result_op();
1257 auto id = result.to_i();
1258
1259 push_ext_import(spv::Op::OpExtInstImport,
1260 {result, Operand::String(kGLSLstd450)});
1261
1262 // Remember it for later.
1263 import_name_to_id_[kGLSLstd450] = id;
1264 return id;
1265 }
1266
GenerateConstructorExpression(const ast::Variable * var,const ast::Expression * expr)1267 uint32_t Builder::GenerateConstructorExpression(const ast::Variable* var,
1268 const ast::Expression* expr) {
1269 if (auto* literal = expr->As<ast::LiteralExpression>()) {
1270 return GenerateLiteralIfNeeded(var, literal);
1271 }
1272 if (auto* call = builder_.Sem().Get<sem::Call>(expr)) {
1273 if (call->Target()->IsAnyOf<sem::TypeConstructor, sem::TypeConversion>()) {
1274 return GenerateTypeConstructorOrConversion(call, var);
1275 }
1276 }
1277 error_ = "unknown constructor expression";
1278 return 0;
1279 }
1280
IsConstructorConst(const ast::Expression * expr)1281 bool Builder::IsConstructorConst(const ast::Expression* expr) {
1282 bool is_const = true;
1283 ast::TraverseExpressions(expr, builder_.Diagnostics(),
1284 [&](const ast::Expression* e) {
1285 if (e->Is<ast::LiteralExpression>()) {
1286 return ast::TraverseAction::Descend;
1287 }
1288 if (auto* ce = e->As<ast::CallExpression>()) {
1289 auto* call = builder_.Sem().Get(ce);
1290 if (call->Target()->Is<sem::TypeConstructor>()) {
1291 return ast::TraverseAction::Descend;
1292 }
1293 }
1294
1295 is_const = false;
1296 return ast::TraverseAction::Stop;
1297 });
1298 return is_const;
1299 }
1300
GenerateTypeConstructorOrConversion(const sem::Call * call,const ast::Variable * var)1301 uint32_t Builder::GenerateTypeConstructorOrConversion(
1302 const sem::Call* call,
1303 const ast::Variable* var) {
1304 auto& args = call->Arguments();
1305 auto* global_var = builder_.Sem().Get<sem::GlobalVariable>(var);
1306 auto* result_type = call->Type();
1307
1308 // Generate the zero initializer if there are no values provided.
1309 if (args.empty()) {
1310 if (global_var && global_var->IsOverridable()) {
1311 auto constant_id = global_var->ConstantId();
1312 if (result_type->Is<sem::I32>()) {
1313 return GenerateConstantIfNeeded(
1314 ScalarConstant::I32(0).AsSpecOp(constant_id));
1315 }
1316 if (result_type->Is<sem::U32>()) {
1317 return GenerateConstantIfNeeded(
1318 ScalarConstant::U32(0).AsSpecOp(constant_id));
1319 }
1320 if (result_type->Is<sem::F32>()) {
1321 return GenerateConstantIfNeeded(
1322 ScalarConstant::F32(0).AsSpecOp(constant_id));
1323 }
1324 if (result_type->Is<sem::Bool>()) {
1325 return GenerateConstantIfNeeded(
1326 ScalarConstant::Bool(false).AsSpecOp(constant_id));
1327 }
1328 }
1329 return GenerateConstantNullIfNeeded(result_type->UnwrapRef());
1330 }
1331
1332 std::ostringstream out;
1333 out << "__const_" << result_type->FriendlyName(builder_.Symbols()) << "_";
1334
1335 result_type = result_type->UnwrapRef();
1336 bool constructor_is_const = IsConstructorConst(call->Declaration());
1337 if (has_error()) {
1338 return 0;
1339 }
1340
1341 bool can_cast_or_copy = result_type->is_scalar();
1342
1343 if (auto* res_vec = result_type->As<sem::Vector>()) {
1344 if (res_vec->type()->is_scalar()) {
1345 auto* value_type = args[0]->Type()->UnwrapRef();
1346 if (auto* val_vec = value_type->As<sem::Vector>()) {
1347 if (val_vec->type()->is_scalar()) {
1348 can_cast_or_copy = res_vec->Width() == val_vec->Width();
1349 }
1350 }
1351 }
1352 }
1353
1354 if (can_cast_or_copy) {
1355 return GenerateCastOrCopyOrPassthrough(result_type, args[0]->Declaration(),
1356 global_var);
1357 }
1358
1359 auto type_id = GenerateTypeIfNeeded(result_type);
1360 if (type_id == 0) {
1361 return 0;
1362 }
1363
1364 bool result_is_constant_composite = constructor_is_const;
1365 bool result_is_spec_composite = false;
1366
1367 if (auto* vec = result_type->As<sem::Vector>()) {
1368 result_type = vec->type();
1369 }
1370
1371 OperandList ops;
1372 for (auto* e : args) {
1373 uint32_t id = 0;
1374 id = GenerateExpression(e->Declaration());
1375 if (id == 0) {
1376 return 0;
1377 }
1378 id = GenerateLoadIfNeeded(e->Type(), id);
1379 if (id == 0) {
1380 return 0;
1381 }
1382
1383 auto* value_type = e->Type()->UnwrapRef();
1384 // If the result and value types are the same we can just use the object.
1385 // If the result is not a vector then we should have validated that the
1386 // value type is a correctly sized vector so we can just use it directly.
1387 if (result_type == value_type || result_type->Is<sem::Matrix>() ||
1388 result_type->Is<sem::Array>() || result_type->Is<sem::Struct>()) {
1389 out << "_" << id;
1390
1391 ops.push_back(Operand::Int(id));
1392 continue;
1393 }
1394
1395 // Both scalars, but not the same type so we need to generate a conversion
1396 // of the value.
1397 if (value_type->is_scalar() && result_type->is_scalar()) {
1398 id = GenerateCastOrCopyOrPassthrough(result_type, args[0]->Declaration(),
1399 global_var);
1400 out << "_" << id;
1401 ops.push_back(Operand::Int(id));
1402 continue;
1403 }
1404
1405 // When handling vectors as the values there a few cases to take into
1406 // consideration:
1407 // 1. Module scoped vec3<f32>(vec2<f32>(1, 2), 3) -> OpSpecConstantOp
1408 // 2. Function scoped vec3<f32>(vec2<f32>(1, 2), 3) -> OpCompositeExtract
1409 // 3. Either array<vec3<f32>, 1>(vec3<f32>(1, 2, 3)) -> use the ID.
1410 // -> handled above
1411 //
1412 // For cases 1 and 2, if the type is different we also may need to insert
1413 // a type cast.
1414 if (auto* vec = value_type->As<sem::Vector>()) {
1415 auto* vec_type = vec->type();
1416
1417 auto value_type_id = GenerateTypeIfNeeded(vec_type);
1418 if (value_type_id == 0) {
1419 return 0;
1420 }
1421
1422 for (uint32_t i = 0; i < vec->Width(); ++i) {
1423 auto extract = result_op();
1424 auto extract_id = extract.to_i();
1425
1426 if (!global_var) {
1427 // A non-global initializer. Case 2.
1428 if (!push_function_inst(spv::Op::OpCompositeExtract,
1429 {Operand::Int(value_type_id), extract,
1430 Operand::Int(id), Operand::Int(i)})) {
1431 return false;
1432 }
1433
1434 // We no longer have a constant composite, but have to do a
1435 // composite construction as these calls are inside a function.
1436 result_is_constant_composite = false;
1437 } else {
1438 // A global initializer, must use OpSpecConstantOp. Case 1.
1439 auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(i));
1440 if (idx_id == 0) {
1441 return 0;
1442 }
1443 push_type(spv::Op::OpSpecConstantOp,
1444 {Operand::Int(value_type_id), extract,
1445 Operand::Int(SpvOpCompositeExtract), Operand::Int(id),
1446 Operand::Int(idx_id)});
1447
1448 result_is_spec_composite = true;
1449 }
1450
1451 out << "_" << extract_id;
1452 ops.push_back(Operand::Int(extract_id));
1453 }
1454 } else {
1455 error_ = "Unhandled type cast value type";
1456 return 0;
1457 }
1458 }
1459
1460 // For a single-value vector initializer, splat the initializer value.
1461 auto* const init_result_type = call->Type()->UnwrapRef();
1462 if (args.size() == 1 && init_result_type->is_scalar_vector() &&
1463 args[0]->Type()->UnwrapRef()->is_scalar()) {
1464 size_t vec_size = init_result_type->As<sem::Vector>()->Width();
1465 for (size_t i = 0; i < (vec_size - 1); ++i) {
1466 ops.push_back(ops[0]);
1467 }
1468 }
1469
1470 auto str = out.str();
1471 auto val = type_constructor_to_id_.find(str);
1472 if (val != type_constructor_to_id_.end()) {
1473 return val->second;
1474 }
1475
1476 auto result = result_op();
1477 ops.insert(ops.begin(), result);
1478 ops.insert(ops.begin(), Operand::Int(type_id));
1479
1480 type_constructor_to_id_[str] = result.to_i();
1481
1482 if (result_is_spec_composite) {
1483 push_type(spv::Op::OpSpecConstantComposite, ops);
1484 } else if (result_is_constant_composite) {
1485 push_type(spv::Op::OpConstantComposite, ops);
1486 } else {
1487 if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
1488 return 0;
1489 }
1490 }
1491
1492 return result.to_i();
1493 }
1494
GenerateCastOrCopyOrPassthrough(const sem::Type * to_type,const ast::Expression * from_expr,bool is_global_init)1495 uint32_t Builder::GenerateCastOrCopyOrPassthrough(
1496 const sem::Type* to_type,
1497 const ast::Expression* from_expr,
1498 bool is_global_init) {
1499 // This should not happen as we rely on constant folding to obviate
1500 // casts/conversions for module-scope variables
1501 if (is_global_init) {
1502 TINT_ICE(Writer, builder_.Diagnostics())
1503 << "Module-level conversions are not supported. Conversions should "
1504 "have already been constant-folded by the FoldConstants transform.";
1505 return 0;
1506 }
1507
1508 auto elem_type_of = [](const sem::Type* t) -> const sem::Type* {
1509 if (t->is_scalar()) {
1510 return t;
1511 }
1512 if (auto* v = t->As<sem::Vector>()) {
1513 return v->type();
1514 }
1515 return nullptr;
1516 };
1517
1518 auto result = result_op();
1519 auto result_id = result.to_i();
1520
1521 auto result_type_id = GenerateTypeIfNeeded(to_type);
1522 if (result_type_id == 0) {
1523 return 0;
1524 }
1525
1526 auto val_id = GenerateExpression(from_expr);
1527 if (val_id == 0) {
1528 return 0;
1529 }
1530 val_id = GenerateLoadIfNeeded(TypeOf(from_expr), val_id);
1531
1532 auto* from_type = TypeOf(from_expr)->UnwrapRef();
1533
1534 spv::Op op = spv::Op::OpNop;
1535 if ((from_type->Is<sem::I32>() && to_type->Is<sem::F32>()) ||
1536 (from_type->is_signed_integer_vector() && to_type->is_float_vector())) {
1537 op = spv::Op::OpConvertSToF;
1538 } else if ((from_type->Is<sem::U32>() && to_type->Is<sem::F32>()) ||
1539 (from_type->is_unsigned_integer_vector() &&
1540 to_type->is_float_vector())) {
1541 op = spv::Op::OpConvertUToF;
1542 } else if ((from_type->Is<sem::F32>() && to_type->Is<sem::I32>()) ||
1543 (from_type->is_float_vector() &&
1544 to_type->is_signed_integer_vector())) {
1545 op = spv::Op::OpConvertFToS;
1546 } else if ((from_type->Is<sem::F32>() && to_type->Is<sem::U32>()) ||
1547 (from_type->is_float_vector() &&
1548 to_type->is_unsigned_integer_vector())) {
1549 op = spv::Op::OpConvertFToU;
1550 } else if ((from_type->Is<sem::Bool>() && to_type->Is<sem::Bool>()) ||
1551 (from_type->Is<sem::U32>() && to_type->Is<sem::U32>()) ||
1552 (from_type->Is<sem::I32>() && to_type->Is<sem::I32>()) ||
1553 (from_type->Is<sem::F32>() && to_type->Is<sem::F32>()) ||
1554 (from_type->Is<sem::Vector>() && (from_type == to_type))) {
1555 return val_id;
1556 } else if ((from_type->Is<sem::I32>() && to_type->Is<sem::U32>()) ||
1557 (from_type->Is<sem::U32>() && to_type->Is<sem::I32>()) ||
1558 (from_type->is_signed_integer_vector() &&
1559 to_type->is_unsigned_integer_vector()) ||
1560 (from_type->is_unsigned_integer_vector() &&
1561 to_type->is_integer_scalar_or_vector())) {
1562 op = spv::Op::OpBitcast;
1563 } else if ((from_type->is_numeric_scalar() && to_type->Is<sem::Bool>()) ||
1564 (from_type->is_numeric_vector() && to_type->is_bool_vector())) {
1565 // Convert scalar (vector) to bool (vector)
1566
1567 // Return the result of comparing from_expr with zero
1568 uint32_t zero = GenerateConstantNullIfNeeded(from_type);
1569 const auto* from_elem_type = elem_type_of(from_type);
1570 op = from_elem_type->is_integer_scalar() ? spv::Op::OpINotEqual
1571 : spv::Op::OpFUnordNotEqual;
1572 if (!push_function_inst(
1573 op, {Operand::Int(result_type_id), Operand::Int(result_id),
1574 Operand::Int(val_id), Operand::Int(zero)})) {
1575 return 0;
1576 }
1577
1578 return result_id;
1579 } else if (from_type->is_bool_scalar_or_vector() &&
1580 to_type->is_numeric_scalar_or_vector()) {
1581 // Convert bool scalar/vector to numeric scalar/vector.
1582 // Use the bool to select between 1 (if true) and 0 (if false).
1583
1584 const auto* to_elem_type = elem_type_of(to_type);
1585 uint32_t one_id;
1586 uint32_t zero_id;
1587 if (to_elem_type->Is<sem::F32>()) {
1588 ast::FloatLiteralExpression one(ProgramID(), Source{}, 1.0f);
1589 ast::FloatLiteralExpression zero(ProgramID(), Source{}, 0.0f);
1590 one_id = GenerateLiteralIfNeeded(nullptr, &one);
1591 zero_id = GenerateLiteralIfNeeded(nullptr, &zero);
1592 } else if (to_elem_type->Is<sem::U32>()) {
1593 ast::UintLiteralExpression one(ProgramID(), Source{}, 1);
1594 ast::UintLiteralExpression zero(ProgramID(), Source{}, 0);
1595 one_id = GenerateLiteralIfNeeded(nullptr, &one);
1596 zero_id = GenerateLiteralIfNeeded(nullptr, &zero);
1597 } else if (to_elem_type->Is<sem::I32>()) {
1598 ast::SintLiteralExpression one(ProgramID(), Source{}, 1);
1599 ast::SintLiteralExpression zero(ProgramID(), Source{}, 0);
1600 one_id = GenerateLiteralIfNeeded(nullptr, &one);
1601 zero_id = GenerateLiteralIfNeeded(nullptr, &zero);
1602 } else {
1603 error_ = "invalid destination type for bool conversion";
1604 return false;
1605 }
1606 if (auto* to_vec = to_type->As<sem::Vector>()) {
1607 // Splat the scalars into vectors.
1608 one_id = GenerateConstantVectorSplatIfNeeded(to_vec, one_id);
1609 zero_id = GenerateConstantVectorSplatIfNeeded(to_vec, zero_id);
1610 }
1611 if (!one_id || !zero_id) {
1612 return false;
1613 }
1614
1615 op = spv::Op::OpSelect;
1616 if (!push_function_inst(
1617 op, {Operand::Int(result_type_id), Operand::Int(result_id),
1618 Operand::Int(val_id), Operand::Int(one_id),
1619 Operand::Int(zero_id)})) {
1620 return 0;
1621 }
1622
1623 return result_id;
1624 } else {
1625 TINT_ICE(Writer, builder_.Diagnostics()) << "Invalid from_type";
1626 }
1627
1628 if (op == spv::Op::OpNop) {
1629 error_ = "unable to determine conversion type for cast, from: " +
1630 from_type->type_name() + " to: " + to_type->type_name();
1631 return 0;
1632 }
1633
1634 if (!push_function_inst(
1635 op, {Operand::Int(result_type_id), result, Operand::Int(val_id)})) {
1636 return 0;
1637 }
1638
1639 return result_id;
1640 }
1641
GenerateLiteralIfNeeded(const ast::Variable * var,const ast::LiteralExpression * lit)1642 uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var,
1643 const ast::LiteralExpression* lit) {
1644 ScalarConstant constant;
1645
1646 auto* global = builder_.Sem().Get<sem::GlobalVariable>(var);
1647 if (global && global->IsOverridable()) {
1648 constant.is_spec_op = true;
1649 constant.constant_id = global->ConstantId();
1650 }
1651
1652 if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
1653 constant.kind = ScalarConstant::Kind::kBool;
1654 constant.value.b = l->value;
1655 } else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
1656 constant.kind = ScalarConstant::Kind::kI32;
1657 constant.value.i32 = sl->value;
1658 } else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
1659 constant.kind = ScalarConstant::Kind::kU32;
1660 constant.value.u32 = ul->value;
1661 } else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
1662 constant.kind = ScalarConstant::Kind::kF32;
1663 constant.value.f32 = fl->value;
1664 } else {
1665 error_ = "unknown literal type";
1666 return 0;
1667 }
1668
1669 return GenerateConstantIfNeeded(constant);
1670 }
1671
GenerateConstantIfNeeded(const ScalarConstant & constant)1672 uint32_t Builder::GenerateConstantIfNeeded(const ScalarConstant& constant) {
1673 auto it = const_to_id_.find(constant);
1674 if (it != const_to_id_.end()) {
1675 return it->second;
1676 }
1677
1678 uint32_t type_id = 0;
1679
1680 switch (constant.kind) {
1681 case ScalarConstant::Kind::kU32: {
1682 type_id = GenerateTypeIfNeeded(builder_.create<sem::U32>());
1683 break;
1684 }
1685 case ScalarConstant::Kind::kI32: {
1686 type_id = GenerateTypeIfNeeded(builder_.create<sem::I32>());
1687 break;
1688 }
1689 case ScalarConstant::Kind::kF32: {
1690 type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>());
1691 break;
1692 }
1693 case ScalarConstant::Kind::kBool: {
1694 type_id = GenerateTypeIfNeeded(builder_.create<sem::Bool>());
1695 break;
1696 }
1697 }
1698
1699 if (type_id == 0) {
1700 return 0;
1701 }
1702
1703 auto result = result_op();
1704 auto result_id = result.to_i();
1705
1706 if (constant.is_spec_op) {
1707 push_annot(spv::Op::OpDecorate,
1708 {Operand::Int(result_id), Operand::Int(SpvDecorationSpecId),
1709 Operand::Int(constant.constant_id)});
1710 }
1711
1712 switch (constant.kind) {
1713 case ScalarConstant::Kind::kU32: {
1714 push_type(
1715 constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
1716 {Operand::Int(type_id), result, Operand::Int(constant.value.u32)});
1717 break;
1718 }
1719 case ScalarConstant::Kind::kI32: {
1720 push_type(
1721 constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
1722 {Operand::Int(type_id), result, Operand::Int(constant.value.i32)});
1723 break;
1724 }
1725 case ScalarConstant::Kind::kF32: {
1726 push_type(
1727 constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
1728 {Operand::Int(type_id), result, Operand::Float(constant.value.f32)});
1729 break;
1730 }
1731 case ScalarConstant::Kind::kBool: {
1732 if (constant.value.b) {
1733 push_type(constant.is_spec_op ? spv::Op::OpSpecConstantTrue
1734 : spv::Op::OpConstantTrue,
1735 {Operand::Int(type_id), result});
1736 } else {
1737 push_type(constant.is_spec_op ? spv::Op::OpSpecConstantFalse
1738 : spv::Op::OpConstantFalse,
1739 {Operand::Int(type_id), result});
1740 }
1741 break;
1742 }
1743 }
1744
1745 const_to_id_[constant] = result_id;
1746 return result_id;
1747 }
1748
GenerateConstantNullIfNeeded(const sem::Type * type)1749 uint32_t Builder::GenerateConstantNullIfNeeded(const sem::Type* type) {
1750 auto type_id = GenerateTypeIfNeeded(type);
1751 if (type_id == 0) {
1752 return 0;
1753 }
1754
1755 auto name = type->type_name();
1756
1757 auto it = const_null_to_id_.find(name);
1758 if (it != const_null_to_id_.end()) {
1759 return it->second;
1760 }
1761
1762 auto result = result_op();
1763 auto result_id = result.to_i();
1764
1765 push_type(spv::Op::OpConstantNull, {Operand::Int(type_id), result});
1766
1767 const_null_to_id_[name] = result_id;
1768 return result_id;
1769 }
1770
GenerateConstantVectorSplatIfNeeded(const sem::Vector * type,uint32_t value_id)1771 uint32_t Builder::GenerateConstantVectorSplatIfNeeded(const sem::Vector* type,
1772 uint32_t value_id) {
1773 auto type_id = GenerateTypeIfNeeded(type);
1774 if (type_id == 0 || value_id == 0) {
1775 return 0;
1776 }
1777
1778 uint64_t key = (static_cast<uint64_t>(type->Width()) << 32) + value_id;
1779 return utils::GetOrCreate(const_splat_to_id_, key, [&] {
1780 auto result = result_op();
1781 auto result_id = result.to_i();
1782
1783 OperandList ops;
1784 ops.push_back(Operand::Int(type_id));
1785 ops.push_back(result);
1786 for (uint32_t i = 0; i < type->Width(); i++) {
1787 ops.push_back(Operand::Int(value_id));
1788 }
1789 push_type(spv::Op::OpConstantComposite, ops);
1790
1791 const_splat_to_id_[key] = result_id;
1792 return result_id;
1793 });
1794 }
1795
GenerateShortCircuitBinaryExpression(const ast::BinaryExpression * expr)1796 uint32_t Builder::GenerateShortCircuitBinaryExpression(
1797 const ast::BinaryExpression* expr) {
1798 auto lhs_id = GenerateExpression(expr->lhs);
1799 if (lhs_id == 0) {
1800 return false;
1801 }
1802 lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs), lhs_id);
1803
1804 // Get the ID of the basic block where control flow will diverge. It's the
1805 // last basic block generated for the left-hand-side of the operator.
1806 auto original_label_id = current_label_id_;
1807
1808 auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
1809 if (type_id == 0) {
1810 return 0;
1811 }
1812
1813 auto merge_block = result_op();
1814 auto merge_block_id = merge_block.to_i();
1815
1816 auto block = result_op();
1817 auto block_id = block.to_i();
1818
1819 auto true_block_id = block_id;
1820 auto false_block_id = merge_block_id;
1821
1822 // For a logical or we want to only check the RHS if the LHS is failed.
1823 if (expr->IsLogicalOr()) {
1824 std::swap(true_block_id, false_block_id);
1825 }
1826
1827 if (!push_function_inst(spv::Op::OpSelectionMerge,
1828 {Operand::Int(merge_block_id),
1829 Operand::Int(SpvSelectionControlMaskNone)})) {
1830 return 0;
1831 }
1832 if (!push_function_inst(spv::Op::OpBranchConditional,
1833 {Operand::Int(lhs_id), Operand::Int(true_block_id),
1834 Operand::Int(false_block_id)})) {
1835 return 0;
1836 }
1837
1838 // Output block to check the RHS
1839 if (!GenerateLabel(block_id)) {
1840 return 0;
1841 }
1842 auto rhs_id = GenerateExpression(expr->rhs);
1843 if (rhs_id == 0) {
1844 return 0;
1845 }
1846 rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs), rhs_id);
1847
1848 // Get the block ID of the last basic block generated for the right-hand-side
1849 // expression. That block will be an immediate predecessor to the merge block.
1850 auto rhs_block_id = current_label_id_;
1851 if (!push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)})) {
1852 return 0;
1853 }
1854
1855 // Output the merge block
1856 if (!GenerateLabel(merge_block_id)) {
1857 return 0;
1858 }
1859
1860 auto result = result_op();
1861 auto result_id = result.to_i();
1862
1863 if (!push_function_inst(spv::Op::OpPhi,
1864 {Operand::Int(type_id), result, Operand::Int(lhs_id),
1865 Operand::Int(original_label_id),
1866 Operand::Int(rhs_id), Operand::Int(rhs_block_id)})) {
1867 return 0;
1868 }
1869
1870 return result_id;
1871 }
1872
GenerateSplat(uint32_t scalar_id,const sem::Type * vec_type)1873 uint32_t Builder::GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type) {
1874 // Create a new vector to splat scalar into
1875 auto splat_vector = result_op();
1876 auto* splat_vector_type = builder_.create<sem::Pointer>(
1877 vec_type, ast::StorageClass::kFunction, ast::Access::kReadWrite);
1878 push_function_var(
1879 {Operand::Int(GenerateTypeIfNeeded(splat_vector_type)), splat_vector,
1880 Operand::Int(ConvertStorageClass(ast::StorageClass::kFunction)),
1881 Operand::Int(GenerateConstantNullIfNeeded(vec_type))});
1882
1883 // Splat scalar into vector
1884 auto splat_result = result_op();
1885 OperandList ops;
1886 ops.push_back(Operand::Int(GenerateTypeIfNeeded(vec_type)));
1887 ops.push_back(splat_result);
1888 for (size_t i = 0; i < vec_type->As<sem::Vector>()->Width(); ++i) {
1889 ops.push_back(Operand::Int(scalar_id));
1890 }
1891 if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
1892 return 0;
1893 }
1894
1895 return splat_result.to_i();
1896 }
1897
GenerateMatrixAddOrSub(uint32_t lhs_id,uint32_t rhs_id,const sem::Matrix * type,spv::Op op)1898 uint32_t Builder::GenerateMatrixAddOrSub(uint32_t lhs_id,
1899 uint32_t rhs_id,
1900 const sem::Matrix* type,
1901 spv::Op op) {
1902 // Example addition of two matrices:
1903 // %31 = OpLoad %mat3v4float %m34
1904 // %32 = OpLoad %mat3v4float %m34
1905 // %33 = OpCompositeExtract %v4float %31 0
1906 // %34 = OpCompositeExtract %v4float %32 0
1907 // %35 = OpFAdd %v4float %33 %34
1908 // %36 = OpCompositeExtract %v4float %31 1
1909 // %37 = OpCompositeExtract %v4float %32 1
1910 // %38 = OpFAdd %v4float %36 %37
1911 // %39 = OpCompositeExtract %v4float %31 2
1912 // %40 = OpCompositeExtract %v4float %32 2
1913 // %41 = OpFAdd %v4float %39 %40
1914 // %42 = OpCompositeConstruct %mat3v4float %35 %38 %41
1915
1916 auto* column_type = builder_.create<sem::Vector>(type->type(), type->rows());
1917 auto column_type_id = GenerateTypeIfNeeded(column_type);
1918
1919 OperandList ops;
1920
1921 for (uint32_t i = 0; i < type->columns(); ++i) {
1922 // Extract column `i` from lhs mat
1923 auto lhs_column_id = result_op();
1924 if (!push_function_inst(spv::Op::OpCompositeExtract,
1925 {Operand::Int(column_type_id), lhs_column_id,
1926 Operand::Int(lhs_id), Operand::Int(i)})) {
1927 return 0;
1928 }
1929
1930 // Extract column `i` from rhs mat
1931 auto rhs_column_id = result_op();
1932 if (!push_function_inst(spv::Op::OpCompositeExtract,
1933 {Operand::Int(column_type_id), rhs_column_id,
1934 Operand::Int(rhs_id), Operand::Int(i)})) {
1935 return 0;
1936 }
1937
1938 // Add or subtract the two columns
1939 auto result = result_op();
1940 if (!push_function_inst(op, {Operand::Int(column_type_id), result,
1941 lhs_column_id, rhs_column_id})) {
1942 return 0;
1943 }
1944
1945 ops.push_back(result);
1946 }
1947
1948 // Create the result matrix from the added/subtracted column vectors
1949 auto result_mat_id = result_op();
1950 ops.insert(ops.begin(), result_mat_id);
1951 ops.insert(ops.begin(), Operand::Int(GenerateTypeIfNeeded(type)));
1952 if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
1953 return 0;
1954 }
1955
1956 return result_mat_id.to_i();
1957 }
1958
GenerateBinaryExpression(const ast::BinaryExpression * expr)1959 uint32_t Builder::GenerateBinaryExpression(const ast::BinaryExpression* expr) {
1960 // There is special logic for short circuiting operators.
1961 if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
1962 return GenerateShortCircuitBinaryExpression(expr);
1963 }
1964
1965 auto lhs_id = GenerateExpression(expr->lhs);
1966 if (lhs_id == 0) {
1967 return 0;
1968 }
1969 lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs), lhs_id);
1970
1971 auto rhs_id = GenerateExpression(expr->rhs);
1972 if (rhs_id == 0) {
1973 return 0;
1974 }
1975 rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs), rhs_id);
1976
1977 auto result = result_op();
1978 auto result_id = result.to_i();
1979
1980 auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
1981 if (type_id == 0) {
1982 return 0;
1983 }
1984
1985 // Handle int and float and the vectors of those types. Other types
1986 // should have been rejected by validation.
1987 auto* lhs_type = TypeOf(expr->lhs)->UnwrapRef();
1988 auto* rhs_type = TypeOf(expr->rhs)->UnwrapRef();
1989
1990 // Handle matrix-matrix addition and subtraction
1991 if ((expr->IsAdd() || expr->IsSubtract()) && lhs_type->is_float_matrix() &&
1992 rhs_type->is_float_matrix()) {
1993 auto* lhs_mat = lhs_type->As<sem::Matrix>();
1994 auto* rhs_mat = rhs_type->As<sem::Matrix>();
1995
1996 // This should already have been validated by resolver
1997 if (lhs_mat->rows() != rhs_mat->rows() ||
1998 lhs_mat->columns() != rhs_mat->columns()) {
1999 error_ = "matrices must have same dimensionality for add or subtract";
2000 return 0;
2001 }
2002
2003 return GenerateMatrixAddOrSub(
2004 lhs_id, rhs_id, lhs_mat,
2005 expr->IsAdd() ? spv::Op::OpFAdd : spv::Op::OpFSub);
2006 }
2007
2008 // For vector-scalar arithmetic operations, splat scalar into a vector. We
2009 // skip this for multiply as we can use OpVectorTimesScalar.
2010 const bool is_float_scalar_vector_multiply =
2011 expr->IsMultiply() &&
2012 ((lhs_type->is_float_scalar() && rhs_type->is_float_vector()) ||
2013 (lhs_type->is_float_vector() && rhs_type->is_float_scalar()));
2014
2015 if (expr->IsArithmetic() && !is_float_scalar_vector_multiply) {
2016 if (lhs_type->Is<sem::Vector>() && rhs_type->is_numeric_scalar()) {
2017 uint32_t splat_vector_id = GenerateSplat(rhs_id, lhs_type);
2018 if (splat_vector_id == 0) {
2019 return 0;
2020 }
2021 rhs_id = splat_vector_id;
2022 rhs_type = lhs_type;
2023
2024 } else if (lhs_type->is_numeric_scalar() && rhs_type->Is<sem::Vector>()) {
2025 uint32_t splat_vector_id = GenerateSplat(lhs_id, rhs_type);
2026 if (splat_vector_id == 0) {
2027 return 0;
2028 }
2029 lhs_id = splat_vector_id;
2030 lhs_type = rhs_type;
2031 }
2032 }
2033
2034 bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
2035 bool lhs_is_bool_or_vec = lhs_type->is_bool_scalar_or_vector();
2036 bool lhs_is_integer_or_vec = lhs_type->is_integer_scalar_or_vector();
2037 bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();
2038
2039 spv::Op op = spv::Op::OpNop;
2040 if (expr->IsAnd()) {
2041 if (lhs_is_integer_or_vec) {
2042 op = spv::Op::OpBitwiseAnd;
2043 } else if (lhs_is_bool_or_vec) {
2044 op = spv::Op::OpLogicalAnd;
2045 } else {
2046 error_ = "invalid and expression";
2047 return 0;
2048 }
2049 } else if (expr->IsAdd()) {
2050 op = lhs_is_float_or_vec ? spv::Op::OpFAdd : spv::Op::OpIAdd;
2051 } else if (expr->IsDivide()) {
2052 if (lhs_is_float_or_vec) {
2053 op = spv::Op::OpFDiv;
2054 } else if (lhs_is_unsigned) {
2055 op = spv::Op::OpUDiv;
2056 } else {
2057 op = spv::Op::OpSDiv;
2058 }
2059 } else if (expr->IsEqual()) {
2060 if (lhs_is_float_or_vec) {
2061 op = spv::Op::OpFOrdEqual;
2062 } else if (lhs_is_bool_or_vec) {
2063 op = spv::Op::OpLogicalEqual;
2064 } else if (lhs_is_integer_or_vec) {
2065 op = spv::Op::OpIEqual;
2066 } else {
2067 error_ = "invalid equal expression";
2068 return 0;
2069 }
2070 } else if (expr->IsGreaterThan()) {
2071 if (lhs_is_float_or_vec) {
2072 op = spv::Op::OpFOrdGreaterThan;
2073 } else if (lhs_is_unsigned) {
2074 op = spv::Op::OpUGreaterThan;
2075 } else {
2076 op = spv::Op::OpSGreaterThan;
2077 }
2078 } else if (expr->IsGreaterThanEqual()) {
2079 if (lhs_is_float_or_vec) {
2080 op = spv::Op::OpFOrdGreaterThanEqual;
2081 } else if (lhs_is_unsigned) {
2082 op = spv::Op::OpUGreaterThanEqual;
2083 } else {
2084 op = spv::Op::OpSGreaterThanEqual;
2085 }
2086 } else if (expr->IsLessThan()) {
2087 if (lhs_is_float_or_vec) {
2088 op = spv::Op::OpFOrdLessThan;
2089 } else if (lhs_is_unsigned) {
2090 op = spv::Op::OpULessThan;
2091 } else {
2092 op = spv::Op::OpSLessThan;
2093 }
2094 } else if (expr->IsLessThanEqual()) {
2095 if (lhs_is_float_or_vec) {
2096 op = spv::Op::OpFOrdLessThanEqual;
2097 } else if (lhs_is_unsigned) {
2098 op = spv::Op::OpULessThanEqual;
2099 } else {
2100 op = spv::Op::OpSLessThanEqual;
2101 }
2102 } else if (expr->IsModulo()) {
2103 if (lhs_is_float_or_vec) {
2104 op = spv::Op::OpFRem;
2105 } else if (lhs_is_unsigned) {
2106 op = spv::Op::OpUMod;
2107 } else {
2108 op = spv::Op::OpSMod;
2109 }
2110 } else if (expr->IsMultiply()) {
2111 if (lhs_type->is_integer_scalar_or_vector()) {
2112 // If the left hand side is an integer then this _has_ to be OpIMul as
2113 // there there is no other integer multiplication.
2114 op = spv::Op::OpIMul;
2115 } else if (lhs_type->is_float_scalar() && rhs_type->is_float_scalar()) {
2116 // Float scalars multiply with OpFMul
2117 op = spv::Op::OpFMul;
2118 } else if (lhs_type->is_float_vector() && rhs_type->is_float_vector()) {
2119 // Float vectors must be validated to be the same size and then use OpFMul
2120 op = spv::Op::OpFMul;
2121 } else if (lhs_type->is_float_scalar() && rhs_type->is_float_vector()) {
2122 // Scalar * Vector we need to flip lhs and rhs types
2123 // because OpVectorTimesScalar expects <vector>, <scalar>
2124 std::swap(lhs_id, rhs_id);
2125 op = spv::Op::OpVectorTimesScalar;
2126 } else if (lhs_type->is_float_vector() && rhs_type->is_float_scalar()) {
2127 // float vector * scalar
2128 op = spv::Op::OpVectorTimesScalar;
2129 } else if (lhs_type->is_float_scalar() && rhs_type->is_float_matrix()) {
2130 // Scalar * Matrix we need to flip lhs and rhs types because
2131 // OpMatrixTimesScalar expects <matrix>, <scalar>
2132 std::swap(lhs_id, rhs_id);
2133 op = spv::Op::OpMatrixTimesScalar;
2134 } else if (lhs_type->is_float_matrix() && rhs_type->is_float_scalar()) {
2135 // float matrix * scalar
2136 op = spv::Op::OpMatrixTimesScalar;
2137 } else if (lhs_type->is_float_vector() && rhs_type->is_float_matrix()) {
2138 // float vector * matrix
2139 op = spv::Op::OpVectorTimesMatrix;
2140 } else if (lhs_type->is_float_matrix() && rhs_type->is_float_vector()) {
2141 // float matrix * vector
2142 op = spv::Op::OpMatrixTimesVector;
2143 } else if (lhs_type->is_float_matrix() && rhs_type->is_float_matrix()) {
2144 // float matrix * matrix
2145 op = spv::Op::OpMatrixTimesMatrix;
2146 } else {
2147 error_ = "invalid multiply expression";
2148 return 0;
2149 }
2150 } else if (expr->IsNotEqual()) {
2151 if (lhs_is_float_or_vec) {
2152 op = spv::Op::OpFOrdNotEqual;
2153 } else if (lhs_is_bool_or_vec) {
2154 op = spv::Op::OpLogicalNotEqual;
2155 } else if (lhs_is_integer_or_vec) {
2156 op = spv::Op::OpINotEqual;
2157 } else {
2158 error_ = "invalid not-equal expression";
2159 return 0;
2160 }
2161 } else if (expr->IsOr()) {
2162 if (lhs_is_integer_or_vec) {
2163 op = spv::Op::OpBitwiseOr;
2164 } else if (lhs_is_bool_or_vec) {
2165 op = spv::Op::OpLogicalOr;
2166 } else {
2167 error_ = "invalid and expression";
2168 return 0;
2169 }
2170 } else if (expr->IsShiftLeft()) {
2171 op = spv::Op::OpShiftLeftLogical;
2172 } else if (expr->IsShiftRight() && lhs_type->is_signed_scalar_or_vector()) {
2173 // A shift right with a signed LHS is an arithmetic shift.
2174 op = spv::Op::OpShiftRightArithmetic;
2175 } else if (expr->IsShiftRight()) {
2176 op = spv::Op::OpShiftRightLogical;
2177 } else if (expr->IsSubtract()) {
2178 op = lhs_is_float_or_vec ? spv::Op::OpFSub : spv::Op::OpISub;
2179 } else if (expr->IsXor()) {
2180 op = spv::Op::OpBitwiseXor;
2181 } else {
2182 error_ = "unknown binary expression";
2183 return 0;
2184 }
2185
2186 if (!push_function_inst(op, {Operand::Int(type_id), result,
2187 Operand::Int(lhs_id), Operand::Int(rhs_id)})) {
2188 return 0;
2189 }
2190 return result_id;
2191 }
2192
GenerateBlockStatement(const ast::BlockStatement * stmt)2193 bool Builder::GenerateBlockStatement(const ast::BlockStatement* stmt) {
2194 scope_stack_.Push();
2195 TINT_DEFER(scope_stack_.Pop());
2196 return GenerateBlockStatementWithoutScoping(stmt);
2197 }
2198
GenerateBlockStatementWithoutScoping(const ast::BlockStatement * stmt)2199 bool Builder::GenerateBlockStatementWithoutScoping(
2200 const ast::BlockStatement* stmt) {
2201 for (auto* block_stmt : stmt->statements) {
2202 if (!GenerateStatement(block_stmt)) {
2203 return false;
2204 }
2205 }
2206 return true;
2207 }
2208
GenerateCallExpression(const ast::CallExpression * expr)2209 uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) {
2210 auto* call = builder_.Sem().Get(expr);
2211 auto* target = call->Target();
2212
2213 if (auto* func = target->As<sem::Function>()) {
2214 return GenerateFunctionCall(call, func);
2215 }
2216 if (auto* intrinsic = target->As<sem::Intrinsic>()) {
2217 return GenerateIntrinsicCall(call, intrinsic);
2218 }
2219 if (target->IsAnyOf<sem::TypeConversion, sem::TypeConstructor>()) {
2220 return GenerateTypeConstructorOrConversion(call, nullptr);
2221 }
2222 TINT_ICE(Writer, builder_.Diagnostics())
2223 << "unhandled call target: " << target->TypeInfo().name;
2224 return false;
2225 }
2226
GenerateFunctionCall(const sem::Call * call,const sem::Function *)2227 uint32_t Builder::GenerateFunctionCall(const sem::Call* call,
2228 const sem::Function*) {
2229 auto* expr = call->Declaration();
2230 auto* ident = expr->target.name;
2231
2232 auto type_id = GenerateTypeIfNeeded(call->Type());
2233 if (type_id == 0) {
2234 return 0;
2235 }
2236
2237 auto result = result_op();
2238 auto result_id = result.to_i();
2239
2240 OperandList ops = {Operand::Int(type_id), result};
2241
2242 auto func_id = func_symbol_to_id_[ident->symbol];
2243 if (func_id == 0) {
2244 error_ = "unable to find called function: " +
2245 builder_.Symbols().NameFor(ident->symbol);
2246 return 0;
2247 }
2248 ops.push_back(Operand::Int(func_id));
2249
2250 size_t arg_idx = 0;
2251 for (auto* arg : expr->args) {
2252 auto id = GenerateExpression(arg);
2253 if (id == 0) {
2254 return 0;
2255 }
2256 id = GenerateLoadIfNeeded(TypeOf(arg), id);
2257 if (id == 0) {
2258 return 0;
2259 }
2260 ops.push_back(Operand::Int(id));
2261 arg_idx++;
2262 }
2263
2264 if (!push_function_inst(spv::Op::OpFunctionCall, std::move(ops))) {
2265 return 0;
2266 }
2267
2268 return result_id;
2269 }
2270
GenerateIntrinsicCall(const sem::Call * call,const sem::Intrinsic * intrinsic)2271 uint32_t Builder::GenerateIntrinsicCall(const sem::Call* call,
2272 const sem::Intrinsic* intrinsic) {
2273 auto result = result_op();
2274 auto result_id = result.to_i();
2275
2276 auto result_type_id = GenerateTypeIfNeeded(intrinsic->ReturnType());
2277 if (result_type_id == 0) {
2278 return 0;
2279 }
2280
2281 if (intrinsic->IsFineDerivative() || intrinsic->IsCoarseDerivative()) {
2282 push_capability(SpvCapabilityDerivativeControl);
2283 }
2284
2285 if (intrinsic->IsImageQuery()) {
2286 push_capability(SpvCapabilityImageQuery);
2287 }
2288
2289 if (intrinsic->IsTexture()) {
2290 if (!GenerateTextureIntrinsic(call, intrinsic, Operand::Int(result_type_id),
2291 result)) {
2292 return 0;
2293 }
2294 return result_id;
2295 }
2296
2297 if (intrinsic->IsBarrier()) {
2298 if (!GenerateControlBarrierIntrinsic(intrinsic)) {
2299 return 0;
2300 }
2301 return result_id;
2302 }
2303
2304 if (intrinsic->IsAtomic()) {
2305 if (!GenerateAtomicIntrinsic(call, intrinsic, Operand::Int(result_type_id),
2306 result)) {
2307 return 0;
2308 }
2309 return result_id;
2310 }
2311
2312 // Generates the SPIR-V ID for the expression for the indexed call argument,
2313 // and loads it if necessary. Returns 0 on error.
2314 auto get_arg_as_value_id = [&](size_t i,
2315 bool generate_load = true) -> uint32_t {
2316 auto* arg = call->Arguments()[i];
2317 auto* param = intrinsic->Parameters()[i];
2318 auto val_id = GenerateExpression(arg->Declaration());
2319 if (val_id == 0) {
2320 return 0;
2321 }
2322
2323 if (generate_load && !param->Type()->Is<sem::Pointer>()) {
2324 val_id = GenerateLoadIfNeeded(arg->Type(), val_id);
2325 }
2326 return val_id;
2327 };
2328
2329 OperandList params = {Operand::Int(result_type_id), result};
2330 spv::Op op = spv::Op::OpNop;
2331
2332 // Pushes the arguments for a GlslStd450 extended instruction, and sets op
2333 // to OpExtInst.
2334 auto glsl_std450 = [&](uint32_t inst_id) {
2335 auto set_id = GetGLSLstd450Import();
2336 params.push_back(Operand::Int(set_id));
2337 params.push_back(Operand::Int(inst_id));
2338 op = spv::Op::OpExtInst;
2339 };
2340
2341 switch (intrinsic->Type()) {
2342 case IntrinsicType::kAny:
2343 if (intrinsic->Parameters()[0]->Type()->Is<sem::Bool>()) {
2344 // any(v: bool) just resolves to v.
2345 return get_arg_as_value_id(0);
2346 }
2347 op = spv::Op::OpAny;
2348 break;
2349 case IntrinsicType::kAll:
2350 if (intrinsic->Parameters()[0]->Type()->Is<sem::Bool>()) {
2351 // all(v: bool) just resolves to v.
2352 return get_arg_as_value_id(0);
2353 }
2354 op = spv::Op::OpAll;
2355 break;
2356 case IntrinsicType::kArrayLength: {
2357 auto* address_of =
2358 call->Arguments()[0]->Declaration()->As<ast::UnaryOpExpression>();
2359 if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
2360 error_ = "arrayLength() expected pointer to member access, got " +
2361 std::string(address_of->TypeInfo().name);
2362 return 0;
2363 }
2364 auto* array_expr = address_of->expr;
2365
2366 auto* accessor = array_expr->As<ast::MemberAccessorExpression>();
2367 if (!accessor) {
2368 error_ =
2369 "arrayLength() expected pointer to member access, got pointer to " +
2370 std::string(array_expr->TypeInfo().name);
2371 return 0;
2372 }
2373
2374 auto struct_id = GenerateExpression(accessor->structure);
2375 if (struct_id == 0) {
2376 return 0;
2377 }
2378 params.push_back(Operand::Int(struct_id));
2379
2380 auto* type = TypeOf(accessor->structure)->UnwrapRef();
2381 if (!type->Is<sem::Struct>()) {
2382 error_ =
2383 "invalid type (" + type->type_name() + ") for runtime array length";
2384 return 0;
2385 }
2386 // Runtime array must be the last member in the structure
2387 params.push_back(Operand::Int(uint32_t(
2388 type->As<sem::Struct>()->Declaration()->members.size() - 1)));
2389
2390 if (!push_function_inst(spv::Op::OpArrayLength, params)) {
2391 return 0;
2392 }
2393 return result_id;
2394 }
2395 case IntrinsicType::kCountOneBits:
2396 op = spv::Op::OpBitCount;
2397 break;
2398 case IntrinsicType::kDot: {
2399 op = spv::Op::OpDot;
2400 auto* vec_ty = intrinsic->Parameters()[0]->Type()->As<sem::Vector>();
2401 if (vec_ty->type()->is_integer_scalar()) {
2402 // TODO(crbug.com/tint/1267): OpDot requires floating-point types, but
2403 // WGSL also supports integer types. SPV_KHR_integer_dot_product adds
2404 // support for integer vectors. Use it if it is available.
2405 auto el_ty = Operand::Int(GenerateTypeIfNeeded(vec_ty->type()));
2406 auto vec_a = Operand::Int(get_arg_as_value_id(0));
2407 auto vec_b = Operand::Int(get_arg_as_value_id(1));
2408 if (vec_a.to_i() == 0 || vec_b.to_i() == 0) {
2409 return 0;
2410 }
2411
2412 auto sum = Operand::Int(0);
2413 for (uint32_t i = 0; i < vec_ty->Width(); i++) {
2414 auto a = result_op();
2415 auto b = result_op();
2416 auto mul = result_op();
2417 if (!push_function_inst(spv::Op::OpCompositeExtract,
2418 {el_ty, a, vec_a, Operand::Int(i)}) ||
2419 !push_function_inst(spv::Op::OpCompositeExtract,
2420 {el_ty, b, vec_b, Operand::Int(i)}) ||
2421 !push_function_inst(spv::Op::OpIMul, {el_ty, mul, a, b})) {
2422 return 0;
2423 }
2424 if (i == 0) {
2425 sum = mul;
2426 } else {
2427 auto prev_sum = sum;
2428 auto is_last_el = i == (vec_ty->Width() - 1);
2429 sum = is_last_el ? Operand::Int(result_id) : result_op();
2430 if (!push_function_inst(spv::Op::OpIAdd,
2431 {el_ty, sum, prev_sum, mul})) {
2432 return 0;
2433 }
2434 }
2435 }
2436 return result_id;
2437 }
2438 break;
2439 }
2440 case IntrinsicType::kDpdx:
2441 op = spv::Op::OpDPdx;
2442 break;
2443 case IntrinsicType::kDpdxCoarse:
2444 op = spv::Op::OpDPdxCoarse;
2445 break;
2446 case IntrinsicType::kDpdxFine:
2447 op = spv::Op::OpDPdxFine;
2448 break;
2449 case IntrinsicType::kDpdy:
2450 op = spv::Op::OpDPdy;
2451 break;
2452 case IntrinsicType::kDpdyCoarse:
2453 op = spv::Op::OpDPdyCoarse;
2454 break;
2455 case IntrinsicType::kDpdyFine:
2456 op = spv::Op::OpDPdyFine;
2457 break;
2458 case IntrinsicType::kFwidth:
2459 op = spv::Op::OpFwidth;
2460 break;
2461 case IntrinsicType::kFwidthCoarse:
2462 op = spv::Op::OpFwidthCoarse;
2463 break;
2464 case IntrinsicType::kFwidthFine:
2465 op = spv::Op::OpFwidthFine;
2466 break;
2467 case IntrinsicType::kIgnore: // [DEPRECATED]
2468 // Evaluate the single argument, return the non-zero result_id which isn't
2469 // associated with any op (ignore returns void, so this cannot be used in
2470 // an expression).
2471 if (!get_arg_as_value_id(0, false)) {
2472 return 0;
2473 }
2474 return result_id;
2475 case IntrinsicType::kIsInf:
2476 op = spv::Op::OpIsInf;
2477 break;
2478 case IntrinsicType::kIsNan:
2479 op = spv::Op::OpIsNan;
2480 break;
2481 case IntrinsicType::kIsFinite: {
2482 // Implemented as: not(IsInf or IsNan)
2483 auto val_id = get_arg_as_value_id(0);
2484 if (!val_id) {
2485 return 0;
2486 }
2487 auto inf_result = result_op();
2488 auto nan_result = result_op();
2489 auto or_result = result_op();
2490 if (push_function_inst(spv::Op::OpIsInf,
2491 {Operand::Int(result_type_id), inf_result,
2492 Operand::Int(val_id)}) &&
2493 push_function_inst(spv::Op::OpIsNan,
2494 {Operand::Int(result_type_id), nan_result,
2495 Operand::Int(val_id)}) &&
2496 push_function_inst(spv::Op::OpLogicalOr,
2497 {Operand::Int(result_type_id), or_result,
2498 Operand::Int(inf_result.to_i()),
2499 Operand::Int(nan_result.to_i())}) &&
2500 push_function_inst(spv::Op::OpLogicalNot,
2501 {Operand::Int(result_type_id), result,
2502 Operand::Int(or_result.to_i())})) {
2503 return result_id;
2504 }
2505 return 0;
2506 }
2507 case IntrinsicType::kIsNormal: {
2508 // A normal number is finite, non-zero, and not subnormal.
2509 // Its exponent is neither of the extreme possible values.
2510 // Implemented as:
2511 // exponent_bits = bitcast<u32>(f);
2512 // clamped = uclamp(1,254,exponent_bits);
2513 // result = (clamped == exponent_bits);
2514 //
2515 auto val_id = get_arg_as_value_id(0);
2516 if (!val_id) {
2517 return 0;
2518 }
2519
2520 // These parameters are valid for IEEE 754 binary32
2521 const uint32_t kExponentMask = 0x7f80000;
2522 const uint32_t kMinNormalExponent = 0x0080000;
2523 const uint32_t kMaxNormalExponent = 0x7f00000;
2524
2525 auto set_id = GetGLSLstd450Import();
2526 auto* u32 = builder_.create<sem::U32>();
2527
2528 auto unsigned_id = GenerateTypeIfNeeded(u32);
2529 auto exponent_mask_id =
2530 GenerateConstantIfNeeded(ScalarConstant::U32(kExponentMask));
2531 auto min_exponent_id =
2532 GenerateConstantIfNeeded(ScalarConstant::U32(kMinNormalExponent));
2533 auto max_exponent_id =
2534 GenerateConstantIfNeeded(ScalarConstant::U32(kMaxNormalExponent));
2535 if (auto* fvec_ty = intrinsic->ReturnType()->As<sem::Vector>()) {
2536 // In the vector case, update the unsigned type to a vector type of the
2537 // same size, and create vector constants by replicating the scalars.
2538 // I expect backend compilers to fold these into unique constants, so
2539 // there is no loss of efficiency.
2540 auto* uvec_ty = builder_.create<sem::Vector>(u32, fvec_ty->Width());
2541 unsigned_id = GenerateTypeIfNeeded(uvec_ty);
2542 auto splat = [&](uint32_t scalar_id) -> uint32_t {
2543 auto splat_result = result_op();
2544 OperandList splat_params{Operand::Int(unsigned_id), splat_result};
2545 for (size_t i = 0; i < fvec_ty->Width(); i++) {
2546 splat_params.emplace_back(Operand::Int(scalar_id));
2547 }
2548 if (!push_function_inst(spv::Op::OpCompositeConstruct,
2549 std::move(splat_params))) {
2550 return 0;
2551 }
2552 return splat_result.to_i();
2553 };
2554 exponent_mask_id = splat(exponent_mask_id);
2555 min_exponent_id = splat(min_exponent_id);
2556 max_exponent_id = splat(max_exponent_id);
2557 }
2558 auto cast_result = result_op();
2559 auto exponent_bits_result = result_op();
2560 auto clamp_result = result_op();
2561
2562 if (set_id && unsigned_id && exponent_mask_id && min_exponent_id &&
2563 max_exponent_id &&
2564 push_function_inst(
2565 spv::Op::OpBitcast,
2566 {Operand::Int(unsigned_id), cast_result, Operand::Int(val_id)}) &&
2567 push_function_inst(spv::Op::OpBitwiseAnd,
2568 {Operand::Int(unsigned_id), exponent_bits_result,
2569 Operand::Int(cast_result.to_i()),
2570 Operand::Int(exponent_mask_id)}) &&
2571 push_function_inst(
2572 spv::Op::OpExtInst,
2573 {Operand::Int(unsigned_id), clamp_result, Operand::Int(set_id),
2574 Operand::Int(GLSLstd450UClamp),
2575 Operand::Int(exponent_bits_result.to_i()),
2576 Operand::Int(min_exponent_id), Operand::Int(max_exponent_id)}) &&
2577 push_function_inst(spv::Op::OpIEqual,
2578 {Operand::Int(result_type_id), result,
2579 Operand::Int(exponent_bits_result.to_i()),
2580 Operand::Int(clamp_result.to_i())})) {
2581 return result_id;
2582 }
2583 return 0;
2584 }
2585 case IntrinsicType::kMix: {
2586 auto std450 = Operand::Int(GetGLSLstd450Import());
2587
2588 auto a_id = get_arg_as_value_id(0);
2589 auto b_id = get_arg_as_value_id(1);
2590 auto f_id = get_arg_as_value_id(2);
2591 if (!a_id || !b_id || !f_id) {
2592 return 0;
2593 }
2594
2595 // If the interpolant is scalar but the objects are vectors, we need to
2596 // splat the interpolant into a vector of the same size.
2597 auto* result_vector_type = intrinsic->ReturnType()->As<sem::Vector>();
2598 if (result_vector_type &&
2599 intrinsic->Parameters()[2]->Type()->is_scalar()) {
2600 f_id = GenerateSplat(f_id, intrinsic->Parameters()[0]->Type());
2601 if (f_id == 0) {
2602 return 0;
2603 }
2604 }
2605
2606 if (!push_function_inst(spv::Op::OpExtInst,
2607 {Operand::Int(result_type_id), result, std450,
2608 Operand::Int(GLSLstd450FMix), Operand::Int(a_id),
2609 Operand::Int(b_id), Operand::Int(f_id)})) {
2610 return 0;
2611 }
2612 return result_id;
2613 }
2614 case IntrinsicType::kReverseBits:
2615 op = spv::Op::OpBitReverse;
2616 break;
2617 case IntrinsicType::kSelect: {
2618 // Note: Argument order is different in WGSL and SPIR-V
2619 auto cond_id = get_arg_as_value_id(2);
2620 auto true_id = get_arg_as_value_id(1);
2621 auto false_id = get_arg_as_value_id(0);
2622 if (!cond_id || !true_id || !false_id) {
2623 return 0;
2624 }
2625
2626 // If the condition is scalar but the objects are vectors, we need to
2627 // splat the condition into a vector of the same size.
2628 // TODO(jrprice): If we're targeting SPIR-V 1.4, we don't need to do this.
2629 auto* result_vector_type = intrinsic->ReturnType()->As<sem::Vector>();
2630 if (result_vector_type &&
2631 intrinsic->Parameters()[2]->Type()->is_scalar()) {
2632 auto* bool_vec_ty = builder_.create<sem::Vector>(
2633 builder_.create<sem::Bool>(), result_vector_type->Width());
2634 if (!GenerateTypeIfNeeded(bool_vec_ty)) {
2635 return 0;
2636 }
2637 cond_id = GenerateSplat(cond_id, bool_vec_ty);
2638 if (cond_id == 0) {
2639 return 0;
2640 }
2641 }
2642
2643 if (!push_function_inst(
2644 spv::Op::OpSelect,
2645 {Operand::Int(result_type_id), result, Operand::Int(cond_id),
2646 Operand::Int(true_id), Operand::Int(false_id)})) {
2647 return 0;
2648 }
2649 return result_id;
2650 }
2651 case IntrinsicType::kTranspose:
2652 op = spv::Op::OpTranspose;
2653 break;
2654 case IntrinsicType::kAbs:
2655 if (intrinsic->ReturnType()->is_unsigned_scalar_or_vector()) {
2656 // abs() only operates on *signed* integers.
2657 // This is a no-op for unsigned integers.
2658 return get_arg_as_value_id(0);
2659 }
2660 if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
2661 glsl_std450(GLSLstd450FAbs);
2662 } else {
2663 glsl_std450(GLSLstd450SAbs);
2664 }
2665 break;
2666 default: {
2667 auto inst_id = intrinsic_to_glsl_method(intrinsic);
2668 if (inst_id == 0) {
2669 error_ = "unknown method " + std::string(intrinsic->str());
2670 return 0;
2671 }
2672 glsl_std450(inst_id);
2673 break;
2674 }
2675 }
2676
2677 if (op == spv::Op::OpNop) {
2678 error_ =
2679 "unable to determine operator for: " + std::string(intrinsic->str());
2680 return 0;
2681 }
2682
2683 for (size_t i = 0; i < call->Arguments().size(); i++) {
2684 if (auto val_id = get_arg_as_value_id(i)) {
2685 params.emplace_back(Operand::Int(val_id));
2686 } else {
2687 return 0;
2688 }
2689 }
2690
2691 if (!push_function_inst(op, params)) {
2692 return 0;
2693 }
2694
2695 return result_id;
2696 }
2697
GenerateTextureIntrinsic(const sem::Call * call,const sem::Intrinsic * intrinsic,Operand result_type,Operand result_id)2698 bool Builder::GenerateTextureIntrinsic(const sem::Call* call,
2699 const sem::Intrinsic* intrinsic,
2700 Operand result_type,
2701 Operand result_id) {
2702 using Usage = sem::ParameterUsage;
2703
2704 auto& signature = intrinsic->Signature();
2705 auto& arguments = call->Arguments();
2706
2707 // Generates the given expression, returning the operand ID
2708 auto gen = [&](const sem::Expression* expr) {
2709 auto val_id = GenerateExpression(expr->Declaration());
2710 if (val_id == 0) {
2711 return Operand::Int(0);
2712 }
2713 val_id = GenerateLoadIfNeeded(expr->Type(), val_id);
2714
2715 return Operand::Int(val_id);
2716 };
2717
2718 // Returns the argument with the given usage
2719 auto arg = [&](Usage usage) {
2720 int idx = signature.IndexOf(usage);
2721 return (idx >= 0) ? arguments[idx] : nullptr;
2722 };
2723
2724 // Generates the argument with the given usage, returning the operand ID
2725 auto gen_arg = [&](Usage usage) {
2726 auto* argument = arg(usage);
2727 if (!argument) {
2728 TINT_ICE(Writer, builder_.Diagnostics())
2729 << "missing argument " << static_cast<int>(usage);
2730 }
2731 return gen(argument);
2732 };
2733
2734 auto* texture = arg(Usage::kTexture);
2735 if (!texture) {
2736 TINT_ICE(Writer, builder_.Diagnostics()) << "missing texture argument";
2737 }
2738
2739 auto* texture_type = texture->Type()->UnwrapRef()->As<sem::Texture>();
2740
2741 auto op = spv::Op::OpNop;
2742
2743 // Custom function to call after the texture-intrinsic op has been generated.
2744 std::function<bool()> post_emission = [] { return true; };
2745
2746 // Populate the spirv_params with common parameters
2747 OperandList spirv_params;
2748 spirv_params.reserve(8); // Enough to fit most parameter lists
2749
2750 // Extra image operands, appended to spirv_params.
2751 struct ImageOperand {
2752 SpvImageOperandsMask mask;
2753 Operand operand;
2754 };
2755 std::vector<ImageOperand> image_operands;
2756 image_operands.reserve(4); // Enough to fit most parameter lists
2757
2758 // Appends `result_type` and `result_id` to `spirv_params`
2759 auto append_result_type_and_id_to_spirv_params = [&]() {
2760 spirv_params.emplace_back(std::move(result_type));
2761 spirv_params.emplace_back(std::move(result_id));
2762 };
2763
2764 // Appends a result type and id to `spirv_params`, possibly adding a
2765 // post_emission step.
2766 //
2767 // If the texture is a depth texture, then this function wraps the result of
2768 // the op with a OpCompositeExtract to evaluate to the first element of the
2769 // returned vector. This is done as the WGSL texture reading functions for
2770 // depths return a single float scalar instead of a vector.
2771 //
2772 // If the texture is not a depth texture, then this function simply delegates
2773 // to calling append_result_type_and_id_to_spirv_params().
2774 auto append_result_type_and_id_to_spirv_params_for_read = [&]() {
2775 if (texture_type
2776 ->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
2777 auto* f32 = builder_.create<sem::F32>();
2778 auto* spirv_result_type = builder_.create<sem::Vector>(f32, 4);
2779 auto spirv_result = result_op();
2780 post_emission = [=] {
2781 return push_function_inst(
2782 spv::Op::OpCompositeExtract,
2783 {result_type, result_id, spirv_result, Operand::Int(0)});
2784 };
2785 auto spirv_result_type_id = GenerateTypeIfNeeded(spirv_result_type);
2786 if (spirv_result_type_id == 0) {
2787 return false;
2788 }
2789 spirv_params.emplace_back(Operand::Int(spirv_result_type_id));
2790 spirv_params.emplace_back(spirv_result);
2791 return true;
2792 }
2793
2794 append_result_type_and_id_to_spirv_params();
2795 return true;
2796 };
2797
2798 // Appends a result type and id to `spirv_params`, by first swizzling the
2799 // result of the op with `swizzle`.
2800 auto append_result_type_and_id_to_spirv_params_swizzled =
2801 [&](uint32_t spirv_result_width, std::vector<uint32_t> swizzle) {
2802 if (swizzle.empty()) {
2803 append_result_type_and_id_to_spirv_params();
2804 } else {
2805 // Assign post_emission to swizzle the result of the call to
2806 // OpImageQuerySize[Lod].
2807 auto* element_type = ElementTypeOf(call->Type());
2808 auto spirv_result = result_op();
2809 auto* spirv_result_type =
2810 builder_.create<sem::Vector>(element_type, spirv_result_width);
2811 if (swizzle.size() > 1) {
2812 post_emission = [=] {
2813 OperandList operands{
2814 result_type,
2815 result_id,
2816 spirv_result,
2817 spirv_result,
2818 };
2819 for (auto idx : swizzle) {
2820 operands.emplace_back(Operand::Int(idx));
2821 }
2822 return push_function_inst(spv::Op::OpVectorShuffle, operands);
2823 };
2824 } else {
2825 post_emission = [=] {
2826 return push_function_inst(spv::Op::OpCompositeExtract,
2827 {result_type, result_id, spirv_result,
2828 Operand::Int(swizzle[0])});
2829 };
2830 }
2831 auto spirv_result_type_id = GenerateTypeIfNeeded(spirv_result_type);
2832 if (spirv_result_type_id == 0) {
2833 return false;
2834 }
2835 spirv_params.emplace_back(Operand::Int(spirv_result_type_id));
2836 spirv_params.emplace_back(spirv_result);
2837 }
2838 return true;
2839 };
2840
2841 auto append_coords_to_spirv_params = [&]() -> bool {
2842 if (auto* array_index = arg(Usage::kArrayIndex)) {
2843 // Array index needs to be appended to the coordinates.
2844 auto* packed = AppendVector(&builder_, arg(Usage::kCoords)->Declaration(),
2845 array_index->Declaration());
2846 auto param = GenerateExpression(packed->Declaration());
2847 if (param == 0) {
2848 return false;
2849 }
2850 spirv_params.emplace_back(Operand::Int(param));
2851 } else {
2852 spirv_params.emplace_back(gen_arg(Usage::kCoords)); // coordinates
2853 }
2854 return true;
2855 };
2856
2857 auto append_image_and_coords_to_spirv_params = [&]() -> bool {
2858 auto sampler_param = gen_arg(Usage::kSampler);
2859 auto texture_param = gen_arg(Usage::kTexture);
2860 auto sampled_image =
2861 GenerateSampledImage(texture_type, texture_param, sampler_param);
2862
2863 // Populate the spirv_params with the common parameters
2864 spirv_params.emplace_back(Operand::Int(sampled_image)); // sampled image
2865 return append_coords_to_spirv_params();
2866 };
2867
2868 switch (intrinsic->Type()) {
2869 case IntrinsicType::kTextureDimensions: {
2870 // Number of returned elements from OpImageQuerySize[Lod] may not match
2871 // those of textureDimensions().
2872 // This might be due to an extra vector scalar describing the number of
2873 // array elements or textureDimensions() returning a vec3 for cubes
2874 // when only width / height is returned by OpImageQuerySize[Lod]
2875 // (see https://github.com/gpuweb/gpuweb/issues/1345).
2876 // Handle these mismatches by swizzling the returned vector.
2877 std::vector<uint32_t> swizzle;
2878 uint32_t spirv_dims = 0;
2879 switch (texture_type->dim()) {
2880 case ast::TextureDimension::kNone:
2881 error_ = "texture dimension is kNone";
2882 return false;
2883 case ast::TextureDimension::k1d:
2884 case ast::TextureDimension::k2d:
2885 case ast::TextureDimension::k3d:
2886 case ast::TextureDimension::kCube:
2887 break; // No swizzle needed
2888 case ast::TextureDimension::kCubeArray:
2889 case ast::TextureDimension::k2dArray:
2890 swizzle = {0, 1}; // Strip array index
2891 spirv_dims = 3; // [width, height, array_count]
2892 break;
2893 }
2894
2895 if (!append_result_type_and_id_to_spirv_params_swizzled(spirv_dims,
2896 swizzle)) {
2897 return false;
2898 }
2899
2900 spirv_params.emplace_back(gen_arg(Usage::kTexture));
2901 if (texture_type->IsAnyOf<sem::MultisampledTexture, //
2902 sem::DepthMultisampledTexture, //
2903 sem::StorageTexture>()) {
2904 op = spv::Op::OpImageQuerySize;
2905 } else if (auto* level = arg(Usage::kLevel)) {
2906 op = spv::Op::OpImageQuerySizeLod;
2907 spirv_params.emplace_back(gen(level));
2908 } else {
2909 ast::SintLiteralExpression i32_0(ProgramID(), Source{}, 0);
2910 op = spv::Op::OpImageQuerySizeLod;
2911 spirv_params.emplace_back(
2912 Operand::Int(GenerateLiteralIfNeeded(nullptr, &i32_0)));
2913 }
2914 break;
2915 }
2916 case IntrinsicType::kTextureNumLayers: {
2917 uint32_t spirv_dims = 0;
2918 switch (texture_type->dim()) {
2919 default:
2920 error_ = "texture is not arrayed";
2921 return false;
2922 case ast::TextureDimension::k2dArray:
2923 case ast::TextureDimension::kCubeArray:
2924 spirv_dims = 3;
2925 break;
2926 }
2927
2928 // OpImageQuerySize[Lod] packs the array count as the last element of the
2929 // returned vector. Extract this.
2930 if (!append_result_type_and_id_to_spirv_params_swizzled(
2931 spirv_dims, {spirv_dims - 1})) {
2932 return false;
2933 }
2934
2935 spirv_params.emplace_back(gen_arg(Usage::kTexture));
2936
2937 if (texture_type->Is<sem::MultisampledTexture>() ||
2938 texture_type->Is<sem::StorageTexture>()) {
2939 op = spv::Op::OpImageQuerySize;
2940 } else {
2941 ast::SintLiteralExpression i32_0(ProgramID(), Source{}, 0);
2942 op = spv::Op::OpImageQuerySizeLod;
2943 spirv_params.emplace_back(
2944 Operand::Int(GenerateLiteralIfNeeded(nullptr, &i32_0)));
2945 }
2946 break;
2947 }
2948 case IntrinsicType::kTextureNumLevels: {
2949 op = spv::Op::OpImageQueryLevels;
2950 append_result_type_and_id_to_spirv_params();
2951 spirv_params.emplace_back(gen_arg(Usage::kTexture));
2952 break;
2953 }
2954 case IntrinsicType::kTextureNumSamples: {
2955 op = spv::Op::OpImageQuerySamples;
2956 append_result_type_and_id_to_spirv_params();
2957 spirv_params.emplace_back(gen_arg(Usage::kTexture));
2958 break;
2959 }
2960 case IntrinsicType::kTextureLoad: {
2961 op = texture_type->Is<sem::StorageTexture>() ? spv::Op::OpImageRead
2962 : spv::Op::OpImageFetch;
2963 append_result_type_and_id_to_spirv_params_for_read();
2964 spirv_params.emplace_back(gen_arg(Usage::kTexture));
2965 if (!append_coords_to_spirv_params()) {
2966 return false;
2967 }
2968
2969 if (auto* level = arg(Usage::kLevel)) {
2970 image_operands.emplace_back(
2971 ImageOperand{SpvImageOperandsLodMask, gen(level)});
2972 }
2973
2974 if (auto* sample_index = arg(Usage::kSampleIndex)) {
2975 image_operands.emplace_back(
2976 ImageOperand{SpvImageOperandsSampleMask, gen(sample_index)});
2977 }
2978
2979 break;
2980 }
2981 case IntrinsicType::kTextureStore: {
2982 op = spv::Op::OpImageWrite;
2983 spirv_params.emplace_back(gen_arg(Usage::kTexture));
2984 if (!append_coords_to_spirv_params()) {
2985 return false;
2986 }
2987 spirv_params.emplace_back(gen_arg(Usage::kValue));
2988 break;
2989 }
2990 case IntrinsicType::kTextureGather: {
2991 op = spv::Op::OpImageGather;
2992 append_result_type_and_id_to_spirv_params();
2993 if (!append_image_and_coords_to_spirv_params()) {
2994 return false;
2995 }
2996 if (signature.IndexOf(Usage::kComponent) < 0) {
2997 spirv_params.emplace_back(
2998 Operand::Int(GenerateConstantIfNeeded(ScalarConstant::I32(0))));
2999 } else {
3000 spirv_params.emplace_back(gen_arg(Usage::kComponent));
3001 }
3002 break;
3003 }
3004 case IntrinsicType::kTextureGatherCompare: {
3005 op = spv::Op::OpImageDrefGather;
3006 append_result_type_and_id_to_spirv_params();
3007 if (!append_image_and_coords_to_spirv_params()) {
3008 return false;
3009 }
3010 spirv_params.emplace_back(gen_arg(Usage::kDepthRef));
3011 break;
3012 }
3013 case IntrinsicType::kTextureSample: {
3014 op = spv::Op::OpImageSampleImplicitLod;
3015 append_result_type_and_id_to_spirv_params_for_read();
3016 if (!append_image_and_coords_to_spirv_params()) {
3017 return false;
3018 }
3019 break;
3020 }
3021 case IntrinsicType::kTextureSampleBias: {
3022 op = spv::Op::OpImageSampleImplicitLod;
3023 append_result_type_and_id_to_spirv_params_for_read();
3024 if (!append_image_and_coords_to_spirv_params()) {
3025 return false;
3026 }
3027 image_operands.emplace_back(
3028 ImageOperand{SpvImageOperandsBiasMask, gen_arg(Usage::kBias)});
3029 break;
3030 }
3031 case IntrinsicType::kTextureSampleLevel: {
3032 op = spv::Op::OpImageSampleExplicitLod;
3033 append_result_type_and_id_to_spirv_params_for_read();
3034 if (!append_image_and_coords_to_spirv_params()) {
3035 return false;
3036 }
3037 auto level = Operand::Int(0);
3038 if (arg(Usage::kLevel)->Type()->UnwrapRef()->Is<sem::I32>()) {
3039 // Depth textures have i32 parameters for the level, but SPIR-V expects
3040 // F32. Cast.
3041 auto f32_type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>());
3042 if (f32_type_id == 0) {
3043 return 0;
3044 }
3045 level = result_op();
3046 if (!push_function_inst(
3047 spv::Op::OpConvertSToF,
3048 {Operand::Int(f32_type_id), level, gen_arg(Usage::kLevel)})) {
3049 return 0;
3050 }
3051 } else {
3052 level = gen_arg(Usage::kLevel);
3053 }
3054 image_operands.emplace_back(ImageOperand{SpvImageOperandsLodMask, level});
3055 break;
3056 }
3057 case IntrinsicType::kTextureSampleGrad: {
3058 op = spv::Op::OpImageSampleExplicitLod;
3059 append_result_type_and_id_to_spirv_params_for_read();
3060 if (!append_image_and_coords_to_spirv_params()) {
3061 return false;
3062 }
3063 image_operands.emplace_back(
3064 ImageOperand{SpvImageOperandsGradMask, gen_arg(Usage::kDdx)});
3065 image_operands.emplace_back(
3066 ImageOperand{SpvImageOperandsGradMask, gen_arg(Usage::kDdy)});
3067 break;
3068 }
3069 case IntrinsicType::kTextureSampleCompare: {
3070 op = spv::Op::OpImageSampleDrefImplicitLod;
3071 append_result_type_and_id_to_spirv_params();
3072 if (!append_image_and_coords_to_spirv_params()) {
3073 return false;
3074 }
3075 spirv_params.emplace_back(gen_arg(Usage::kDepthRef));
3076 break;
3077 }
3078 case IntrinsicType::kTextureSampleCompareLevel: {
3079 op = spv::Op::OpImageSampleDrefExplicitLod;
3080 append_result_type_and_id_to_spirv_params();
3081 if (!append_image_and_coords_to_spirv_params()) {
3082 return false;
3083 }
3084 spirv_params.emplace_back(gen_arg(Usage::kDepthRef));
3085
3086 ast::FloatLiteralExpression float_0(ProgramID(), Source{}, 0.0);
3087 image_operands.emplace_back(ImageOperand{
3088 SpvImageOperandsLodMask,
3089 Operand::Int(GenerateLiteralIfNeeded(nullptr, &float_0))});
3090 break;
3091 }
3092 default:
3093 TINT_UNREACHABLE(Writer, builder_.Diagnostics());
3094 return false;
3095 }
3096
3097 if (auto* offset = arg(Usage::kOffset)) {
3098 image_operands.emplace_back(
3099 ImageOperand{SpvImageOperandsConstOffsetMask, gen(offset)});
3100 }
3101
3102 if (!image_operands.empty()) {
3103 std::sort(image_operands.begin(), image_operands.end(),
3104 [](auto& a, auto& b) { return a.mask < b.mask; });
3105 uint32_t mask = 0;
3106 for (auto& image_operand : image_operands) {
3107 mask |= image_operand.mask;
3108 }
3109 spirv_params.emplace_back(Operand::Int(mask));
3110 for (auto& image_operand : image_operands) {
3111 spirv_params.emplace_back(image_operand.operand);
3112 }
3113 }
3114
3115 if (op == spv::Op::OpNop) {
3116 error_ =
3117 "unable to determine operator for: " + std::string(intrinsic->str());
3118 return false;
3119 }
3120
3121 if (!push_function_inst(op, spirv_params)) {
3122 return false;
3123 }
3124
3125 return post_emission();
3126 }
3127
GenerateControlBarrierIntrinsic(const sem::Intrinsic * intrinsic)3128 bool Builder::GenerateControlBarrierIntrinsic(const sem::Intrinsic* intrinsic) {
3129 auto const op = spv::Op::OpControlBarrier;
3130 uint32_t execution = 0;
3131 uint32_t memory = 0;
3132 uint32_t semantics = 0;
3133
3134 // TODO(crbug.com/tint/661): Combine sequential barriers to a single
3135 // instruction.
3136 if (intrinsic->Type() == sem::IntrinsicType::kWorkgroupBarrier) {
3137 execution = static_cast<uint32_t>(spv::Scope::Workgroup);
3138 memory = static_cast<uint32_t>(spv::Scope::Workgroup);
3139 semantics =
3140 static_cast<uint32_t>(spv::MemorySemanticsMask::AcquireRelease) |
3141 static_cast<uint32_t>(spv::MemorySemanticsMask::WorkgroupMemory);
3142 } else if (intrinsic->Type() == sem::IntrinsicType::kStorageBarrier) {
3143 execution = static_cast<uint32_t>(spv::Scope::Workgroup);
3144 memory = static_cast<uint32_t>(spv::Scope::Workgroup);
3145 semantics =
3146 static_cast<uint32_t>(spv::MemorySemanticsMask::AcquireRelease) |
3147 static_cast<uint32_t>(spv::MemorySemanticsMask::UniformMemory);
3148 } else {
3149 error_ = "unexpected barrier intrinsic type ";
3150 error_ += sem::str(intrinsic->Type());
3151 return false;
3152 }
3153
3154 auto execution_id = GenerateConstantIfNeeded(ScalarConstant::U32(execution));
3155 auto memory_id = GenerateConstantIfNeeded(ScalarConstant::U32(memory));
3156 auto semantics_id = GenerateConstantIfNeeded(ScalarConstant::U32(semantics));
3157 if (execution_id == 0 || memory_id == 0 || semantics_id == 0) {
3158 return false;
3159 }
3160
3161 return push_function_inst(op, {
3162 Operand::Int(execution_id),
3163 Operand::Int(memory_id),
3164 Operand::Int(semantics_id),
3165 });
3166 }
3167
GenerateAtomicIntrinsic(const sem::Call * call,const sem::Intrinsic * intrinsic,Operand result_type,Operand result_id)3168 bool Builder::GenerateAtomicIntrinsic(const sem::Call* call,
3169 const sem::Intrinsic* intrinsic,
3170 Operand result_type,
3171 Operand result_id) {
3172 auto is_value_signed = [&] {
3173 return intrinsic->Parameters()[1]->Type()->Is<sem::I32>();
3174 };
3175
3176 auto storage_class =
3177 intrinsic->Parameters()[0]->Type()->As<sem::Pointer>()->StorageClass();
3178
3179 uint32_t memory_id = 0;
3180 switch (
3181 intrinsic->Parameters()[0]->Type()->As<sem::Pointer>()->StorageClass()) {
3182 case ast::StorageClass::kWorkgroup:
3183 memory_id = GenerateConstantIfNeeded(
3184 ScalarConstant::U32(static_cast<uint32_t>(spv::Scope::Workgroup)));
3185 break;
3186 case ast::StorageClass::kStorage:
3187 memory_id = GenerateConstantIfNeeded(
3188 ScalarConstant::U32(static_cast<uint32_t>(spv::Scope::Device)));
3189 break;
3190 default:
3191 TINT_UNREACHABLE(Writer, builder_.Diagnostics())
3192 << "unhandled atomic storage class " << storage_class;
3193 return false;
3194 }
3195 if (memory_id == 0) {
3196 return false;
3197 }
3198
3199 uint32_t semantics_id = GenerateConstantIfNeeded(ScalarConstant::U32(
3200 static_cast<uint32_t>(spv::MemorySemanticsMask::MaskNone)));
3201 if (semantics_id == 0) {
3202 return false;
3203 }
3204
3205 uint32_t pointer_id = GenerateExpression(call->Arguments()[0]->Declaration());
3206 if (pointer_id == 0) {
3207 return false;
3208 }
3209
3210 uint32_t value_id = 0;
3211 if (call->Arguments().size() > 1) {
3212 value_id = GenerateExpression(call->Arguments().back()->Declaration());
3213 if (value_id == 0) {
3214 return false;
3215 }
3216 value_id = GenerateLoadIfNeeded(call->Arguments().back()->Type(), value_id);
3217 if (value_id == 0) {
3218 return false;
3219 }
3220 }
3221
3222 Operand pointer = Operand::Int(pointer_id);
3223 Operand value = Operand::Int(value_id);
3224 Operand memory = Operand::Int(memory_id);
3225 Operand semantics = Operand::Int(semantics_id);
3226
3227 switch (intrinsic->Type()) {
3228 case sem::IntrinsicType::kAtomicLoad:
3229 return push_function_inst(spv::Op::OpAtomicLoad, {
3230 result_type,
3231 result_id,
3232 pointer,
3233 memory,
3234 semantics,
3235 });
3236 case sem::IntrinsicType::kAtomicStore:
3237 return push_function_inst(spv::Op::OpAtomicStore, {
3238 pointer,
3239 memory,
3240 semantics,
3241 value,
3242 });
3243 case sem::IntrinsicType::kAtomicAdd:
3244 return push_function_inst(spv::Op::OpAtomicIAdd, {
3245 result_type,
3246 result_id,
3247 pointer,
3248 memory,
3249 semantics,
3250 value,
3251 });
3252 case sem::IntrinsicType::kAtomicSub:
3253 return push_function_inst(spv::Op::OpAtomicISub, {
3254 result_type,
3255 result_id,
3256 pointer,
3257 memory,
3258 semantics,
3259 value,
3260 });
3261 case sem::IntrinsicType::kAtomicMax:
3262 return push_function_inst(
3263 is_value_signed() ? spv::Op::OpAtomicSMax : spv::Op::OpAtomicUMax,
3264 {
3265 result_type,
3266 result_id,
3267 pointer,
3268 memory,
3269 semantics,
3270 value,
3271 });
3272 case sem::IntrinsicType::kAtomicMin:
3273 return push_function_inst(
3274 is_value_signed() ? spv::Op::OpAtomicSMin : spv::Op::OpAtomicUMin,
3275 {
3276 result_type,
3277 result_id,
3278 pointer,
3279 memory,
3280 semantics,
3281 value,
3282 });
3283 case sem::IntrinsicType::kAtomicAnd:
3284 return push_function_inst(spv::Op::OpAtomicAnd, {
3285 result_type,
3286 result_id,
3287 pointer,
3288 memory,
3289 semantics,
3290 value,
3291 });
3292 case sem::IntrinsicType::kAtomicOr:
3293 return push_function_inst(spv::Op::OpAtomicOr, {
3294 result_type,
3295 result_id,
3296 pointer,
3297 memory,
3298 semantics,
3299 value,
3300 });
3301 case sem::IntrinsicType::kAtomicXor:
3302 return push_function_inst(spv::Op::OpAtomicXor, {
3303 result_type,
3304 result_id,
3305 pointer,
3306 memory,
3307 semantics,
3308 value,
3309 });
3310 case sem::IntrinsicType::kAtomicExchange:
3311 return push_function_inst(spv::Op::OpAtomicExchange, {
3312 result_type,
3313 result_id,
3314 pointer,
3315 memory,
3316 semantics,
3317 value,
3318 });
3319 case sem::IntrinsicType::kAtomicCompareExchangeWeak: {
3320 auto comparator = GenerateExpression(call->Arguments()[1]->Declaration());
3321 if (comparator == 0) {
3322 return false;
3323 }
3324
3325 auto* value_sem_type = TypeOf(call->Arguments()[2]->Declaration());
3326
3327 auto value_type = GenerateTypeIfNeeded(value_sem_type);
3328 if (value_type == 0) {
3329 return false;
3330 }
3331
3332 auto* bool_sem_ty = builder_.create<sem::Bool>();
3333 auto bool_type = GenerateTypeIfNeeded(bool_sem_ty);
3334 if (bool_type == 0) {
3335 return false;
3336 }
3337
3338 // original_value := OpAtomicCompareExchange(pointer, memory, semantics,
3339 // semantics, value, comparator)
3340 auto original_value = result_op();
3341 if (!push_function_inst(spv::Op::OpAtomicCompareExchange,
3342 {
3343 Operand::Int(value_type),
3344 original_value,
3345 pointer,
3346 memory,
3347 semantics,
3348 semantics,
3349 value,
3350 Operand::Int(comparator),
3351 })) {
3352 return false;
3353 }
3354
3355 // values_equal := original_value == value
3356 auto values_equal = result_op();
3357 if (!push_function_inst(spv::Op::OpIEqual, {
3358 Operand::Int(bool_type),
3359 values_equal,
3360 original_value,
3361 value,
3362 })) {
3363 return false;
3364 }
3365
3366 // zero := T(0)
3367 // one := T(1)
3368 uint32_t zero = 0;
3369 uint32_t one = 0;
3370 if (value_sem_type->Is<sem::I32>()) {
3371 zero = GenerateConstantIfNeeded(ScalarConstant::I32(0u));
3372 one = GenerateConstantIfNeeded(ScalarConstant::I32(1u));
3373 } else if (value_sem_type->Is<sem::U32>()) {
3374 zero = GenerateConstantIfNeeded(ScalarConstant::U32(0u));
3375 one = GenerateConstantIfNeeded(ScalarConstant::U32(1u));
3376 } else {
3377 TINT_UNREACHABLE(Writer, builder_.Diagnostics())
3378 << "unsupported atomic type " << value_sem_type->TypeInfo().name;
3379 }
3380 if (zero == 0 || one == 0) {
3381 return false;
3382 }
3383
3384 // xchg_success := values_equal ? one : zero
3385 auto xchg_success = result_op();
3386 if (!push_function_inst(spv::Op::OpSelect, {
3387 Operand::Int(value_type),
3388 xchg_success,
3389 values_equal,
3390 Operand::Int(one),
3391 Operand::Int(zero),
3392 })) {
3393 return false;
3394 }
3395
3396 // result := vec2<T>(original_value, xchg_success)
3397 return push_function_inst(spv::Op::OpCompositeConstruct,
3398 {
3399 result_type,
3400 result_id,
3401 original_value,
3402 xchg_success,
3403 });
3404 }
3405 default:
3406 TINT_UNREACHABLE(Writer, builder_.Diagnostics())
3407 << "unhandled atomic intrinsic " << intrinsic->Type();
3408 return false;
3409 }
3410 }
3411
GenerateSampledImage(const sem::Type * texture_type,Operand texture_operand,Operand sampler_operand)3412 uint32_t Builder::GenerateSampledImage(const sem::Type* texture_type,
3413 Operand texture_operand,
3414 Operand sampler_operand) {
3415 uint32_t sampled_image_type_id = 0;
3416 auto val = texture_type_name_to_sampled_image_type_id_.find(
3417 texture_type->type_name());
3418 if (val != texture_type_name_to_sampled_image_type_id_.end()) {
3419 // The sampled image type is already created.
3420 sampled_image_type_id = val->second;
3421 } else {
3422 // We need to create the sampled image type and cache the result.
3423 auto sampled_image_type = result_op();
3424 sampled_image_type_id = sampled_image_type.to_i();
3425 auto texture_type_id = GenerateTypeIfNeeded(texture_type);
3426 push_type(spv::Op::OpTypeSampledImage,
3427 {sampled_image_type, Operand::Int(texture_type_id)});
3428 texture_type_name_to_sampled_image_type_id_[texture_type->type_name()] =
3429 sampled_image_type_id;
3430 }
3431
3432 auto sampled_image = result_op();
3433 if (!push_function_inst(spv::Op::OpSampledImage,
3434 {Operand::Int(sampled_image_type_id), sampled_image,
3435 texture_operand, sampler_operand})) {
3436 return 0;
3437 }
3438
3439 return sampled_image.to_i();
3440 }
3441
GenerateBitcastExpression(const ast::BitcastExpression * expr)3442 uint32_t Builder::GenerateBitcastExpression(
3443 const ast::BitcastExpression* expr) {
3444 auto result = result_op();
3445 auto result_id = result.to_i();
3446
3447 auto result_type_id = GenerateTypeIfNeeded(TypeOf(expr));
3448 if (result_type_id == 0) {
3449 return 0;
3450 }
3451
3452 auto val_id = GenerateExpression(expr->expr);
3453 if (val_id == 0) {
3454 return 0;
3455 }
3456 val_id = GenerateLoadIfNeeded(TypeOf(expr->expr), val_id);
3457
3458 // Bitcast does not allow same types, just emit a CopyObject
3459 auto* to_type = TypeOf(expr)->UnwrapRef();
3460 auto* from_type = TypeOf(expr->expr)->UnwrapRef();
3461 if (to_type->type_name() == from_type->type_name()) {
3462 if (!push_function_inst(
3463 spv::Op::OpCopyObject,
3464 {Operand::Int(result_type_id), result, Operand::Int(val_id)})) {
3465 return 0;
3466 }
3467 return result_id;
3468 }
3469
3470 if (!push_function_inst(spv::Op::OpBitcast, {Operand::Int(result_type_id),
3471 result, Operand::Int(val_id)})) {
3472 return 0;
3473 }
3474
3475 return result_id;
3476 }
3477
GenerateConditionalBlock(const ast::Expression * cond,const ast::BlockStatement * true_body,size_t cur_else_idx,const ast::ElseStatementList & else_stmts)3478 bool Builder::GenerateConditionalBlock(
3479 const ast::Expression* cond,
3480 const ast::BlockStatement* true_body,
3481 size_t cur_else_idx,
3482 const ast::ElseStatementList& else_stmts) {
3483 auto cond_id = GenerateExpression(cond);
3484 if (cond_id == 0) {
3485 return false;
3486 }
3487 cond_id = GenerateLoadIfNeeded(TypeOf(cond), cond_id);
3488
3489 auto merge_block = result_op();
3490 auto merge_block_id = merge_block.to_i();
3491
3492 if (!push_function_inst(spv::Op::OpSelectionMerge,
3493 {Operand::Int(merge_block_id),
3494 Operand::Int(SpvSelectionControlMaskNone)})) {
3495 return false;
3496 }
3497
3498 auto true_block = result_op();
3499 auto true_block_id = true_block.to_i();
3500
3501 // if there are no more else statements we branch on false to the merge
3502 // block otherwise we branch to the false block
3503 auto false_block_id =
3504 cur_else_idx < else_stmts.size() ? next_id() : merge_block_id;
3505
3506 if (!push_function_inst(spv::Op::OpBranchConditional,
3507 {Operand::Int(cond_id), Operand::Int(true_block_id),
3508 Operand::Int(false_block_id)})) {
3509 return false;
3510 }
3511
3512 // Output true block
3513 if (!GenerateLabel(true_block_id)) {
3514 return false;
3515 }
3516 if (!GenerateBlockStatement(true_body)) {
3517 return false;
3518 }
3519 // We only branch if the last element of the body didn't already branch.
3520 if (!LastIsTerminator(true_body)) {
3521 if (!push_function_inst(spv::Op::OpBranch,
3522 {Operand::Int(merge_block_id)})) {
3523 return false;
3524 }
3525 }
3526
3527 // Start the false block if needed
3528 if (false_block_id != merge_block_id) {
3529 if (!GenerateLabel(false_block_id)) {
3530 return false;
3531 }
3532
3533 auto* else_stmt = else_stmts[cur_else_idx];
3534 // Handle the else case by just outputting the statements.
3535 if (!else_stmt->condition) {
3536 if (!GenerateBlockStatement(else_stmt->body)) {
3537 return false;
3538 }
3539 } else {
3540 if (!GenerateConditionalBlock(else_stmt->condition, else_stmt->body,
3541 cur_else_idx + 1, else_stmts)) {
3542 return false;
3543 }
3544 }
3545 if (!LastIsTerminator(else_stmt->body)) {
3546 if (!push_function_inst(spv::Op::OpBranch,
3547 {Operand::Int(merge_block_id)})) {
3548 return false;
3549 }
3550 }
3551 }
3552
3553 // Output the merge block
3554 return GenerateLabel(merge_block_id);
3555 }
3556
GenerateIfStatement(const ast::IfStatement * stmt)3557 bool Builder::GenerateIfStatement(const ast::IfStatement* stmt) {
3558 if (!continuing_stack_.empty() &&
3559 stmt == continuing_stack_.back().last_statement->As<ast::IfStatement>()) {
3560 const ContinuingInfo& ci = continuing_stack_.back();
3561 // Match one of two patterns: the break-if and break-unless patterns.
3562 //
3563 // The break-if pattern:
3564 // continuing { ...
3565 // if (cond) { break; }
3566 // }
3567 //
3568 // The break-unless pattern:
3569 // continuing { ...
3570 // if (cond) {} else {break;}
3571 // }
3572 auto is_just_a_break = [](const ast::BlockStatement* block) {
3573 return block && (block->statements.size() == 1) &&
3574 block->Last()->Is<ast::BreakStatement>();
3575 };
3576 if (is_just_a_break(stmt->body) && stmt->else_statements.empty()) {
3577 // It's a break-if.
3578 TINT_ASSERT(Writer, !backedge_stack_.empty());
3579 const auto cond_id = GenerateExpression(stmt->condition);
3580 backedge_stack_.back() =
3581 Backedge(spv::Op::OpBranchConditional,
3582 {Operand::Int(cond_id), Operand::Int(ci.break_target_id),
3583 Operand::Int(ci.loop_header_id)});
3584 return true;
3585 } else if (stmt->body->Empty()) {
3586 const auto& es = stmt->else_statements;
3587 if (es.size() == 1 && !es.back()->condition &&
3588 is_just_a_break(es.back()->body)) {
3589 // It's a break-unless.
3590 TINT_ASSERT(Writer, !backedge_stack_.empty());
3591 const auto cond_id = GenerateExpression(stmt->condition);
3592 backedge_stack_.back() =
3593 Backedge(spv::Op::OpBranchConditional,
3594 {Operand::Int(cond_id), Operand::Int(ci.loop_header_id),
3595 Operand::Int(ci.break_target_id)});
3596 return true;
3597 }
3598 }
3599 }
3600
3601 if (!GenerateConditionalBlock(stmt->condition, stmt->body, 0,
3602 stmt->else_statements)) {
3603 return false;
3604 }
3605 return true;
3606 }
3607
GenerateSwitchStatement(const ast::SwitchStatement * stmt)3608 bool Builder::GenerateSwitchStatement(const ast::SwitchStatement* stmt) {
3609 auto merge_block = result_op();
3610 auto merge_block_id = merge_block.to_i();
3611
3612 merge_stack_.push_back(merge_block_id);
3613
3614 auto cond_id = GenerateExpression(stmt->condition);
3615 if (cond_id == 0) {
3616 return false;
3617 }
3618 cond_id = GenerateLoadIfNeeded(TypeOf(stmt->condition), cond_id);
3619
3620 auto default_block = result_op();
3621 auto default_block_id = default_block.to_i();
3622
3623 OperandList params = {Operand::Int(cond_id), Operand::Int(default_block_id)};
3624
3625 std::vector<uint32_t> case_ids;
3626 for (const auto* item : stmt->body) {
3627 if (item->IsDefault()) {
3628 case_ids.push_back(default_block_id);
3629 continue;
3630 }
3631
3632 auto block = result_op();
3633 auto block_id = block.to_i();
3634
3635 case_ids.push_back(block_id);
3636 for (auto* selector : item->selectors) {
3637 auto* int_literal = selector->As<ast::IntLiteralExpression>();
3638 if (!int_literal) {
3639 error_ = "expected integer literal for switch case label";
3640 return false;
3641 }
3642
3643 params.push_back(Operand::Int(int_literal->ValueAsU32()));
3644 params.push_back(Operand::Int(block_id));
3645 }
3646 }
3647
3648 if (!push_function_inst(spv::Op::OpSelectionMerge,
3649 {Operand::Int(merge_block_id),
3650 Operand::Int(SpvSelectionControlMaskNone)})) {
3651 return false;
3652 }
3653 if (!push_function_inst(spv::Op::OpSwitch, params)) {
3654 return false;
3655 }
3656
3657 bool generated_default = false;
3658 auto& body = stmt->body;
3659 // We output the case statements in order they were entered in the original
3660 // source. Each fallthrough goes to the next case entry, so is a forward
3661 // branch, otherwise the branch is to the merge block which comes after
3662 // the switch statement.
3663 for (uint32_t i = 0; i < body.size(); i++) {
3664 auto* item = body[i];
3665
3666 if (item->IsDefault()) {
3667 generated_default = true;
3668 }
3669
3670 if (!GenerateLabel(case_ids[i])) {
3671 return false;
3672 }
3673 if (!GenerateBlockStatement(item->body)) {
3674 return false;
3675 }
3676
3677 if (LastIsFallthrough(item->body)) {
3678 if (i == (body.size() - 1)) {
3679 // This case is caught by Resolver validation
3680 TINT_UNREACHABLE(Writer, builder_.Diagnostics());
3681 return false;
3682 }
3683 if (!push_function_inst(spv::Op::OpBranch,
3684 {Operand::Int(case_ids[i + 1])})) {
3685 return false;
3686 }
3687 } else if (!LastIsTerminator(item->body)) {
3688 if (!push_function_inst(spv::Op::OpBranch,
3689 {Operand::Int(merge_block_id)})) {
3690 return false;
3691 }
3692 }
3693 }
3694
3695 if (!generated_default) {
3696 if (!GenerateLabel(default_block_id)) {
3697 return false;
3698 }
3699 if (!push_function_inst(spv::Op::OpBranch,
3700 {Operand::Int(merge_block_id)})) {
3701 return false;
3702 }
3703 }
3704
3705 merge_stack_.pop_back();
3706
3707 return GenerateLabel(merge_block_id);
3708 }
3709
GenerateReturnStatement(const ast::ReturnStatement * stmt)3710 bool Builder::GenerateReturnStatement(const ast::ReturnStatement* stmt) {
3711 if (stmt->value) {
3712 auto val_id = GenerateExpression(stmt->value);
3713 if (val_id == 0) {
3714 return false;
3715 }
3716 val_id = GenerateLoadIfNeeded(TypeOf(stmt->value), val_id);
3717 if (!push_function_inst(spv::Op::OpReturnValue, {Operand::Int(val_id)})) {
3718 return false;
3719 }
3720 } else {
3721 if (!push_function_inst(spv::Op::OpReturn, {})) {
3722 return false;
3723 }
3724 }
3725
3726 return true;
3727 }
3728
// Emits a WGSL `loop` as a SPIR-V structured loop.
//
// The generated layout is:
//   OpBranch %header
//   %header:   OpLoopMerge %merge %continue None
//              OpBranch %body
//   %body:     <loop body> [OpBranch %continue]
//   %continue: <continuing statements>
//              <backedge>
//   %merge:
//
// Labels are allocated in the order header, merge, continue, body.
// The backedge defaults to an unconditional OpBranch to the header. While
// the continuing statements are generated, GenerateIfStatement may replace it
// with an OpBranchConditional when the last continuing statement is a
// break-if / break-unless.
//
// Returns false on failure.
bool Builder::GenerateLoopStatement(const ast::LoopStatement* stmt) {
  auto loop_header = result_op();
  auto loop_header_id = loop_header.to_i();
  if (!push_function_inst(spv::Op::OpBranch, {Operand::Int(loop_header_id)})) {
    return false;
  }
  if (!GenerateLabel(loop_header_id)) {
    return false;
  }

  auto merge_block = result_op();
  auto merge_block_id = merge_block.to_i();
  auto continue_block = result_op();
  auto continue_block_id = continue_block.to_i();

  auto body_block = result_op();
  auto body_block_id = body_block.to_i();

  if (!push_function_inst(
          spv::Op::OpLoopMerge,
          {Operand::Int(merge_block_id), Operand::Int(continue_block_id),
           Operand::Int(SpvLoopControlMaskNone)})) {
    return false;
  }

  // Record the targets for `continue` and `break` statements inside this loop.
  continue_stack_.push_back(continue_block_id);
  merge_stack_.push_back(merge_block_id);

  // Usually, the backedge is a simple branch.  This will be modified if the
  // backedge block in the continuing construct has an exiting edge.
  backedge_stack_.emplace_back(spv::Op::OpBranch,
                               OperandList{Operand::Int(loop_header_id)});

  // The loop header block only contains the OpLoopMerge, so branch straight
  // into the body block.
  if (!push_function_inst(spv::Op::OpBranch, {Operand::Int(body_block_id)})) {
    return false;
  }
  if (!GenerateLabel(body_block_id)) {
    return false;
  }

  // We need variables from the body to be visible in the continuing block, so
  // manage scope outside of GenerateBlockStatement.
  {
    scope_stack_.Push();
    TINT_DEFER(scope_stack_.Pop());

    if (!GenerateBlockStatementWithoutScoping(stmt->body)) {
      return false;
    }

    // We only branch if the last element of the body didn't already branch.
    if (!LastIsTerminator(stmt->body)) {
      if (!push_function_inst(spv::Op::OpBranch,
                              {Operand::Int(continue_block_id)})) {
        return false;
      }
    }

    // The continue target label is always emitted, even when the continuing
    // block is absent or empty.
    if (!GenerateLabel(continue_block_id)) {
      return false;
    }
    if (stmt->continuing && !stmt->continuing->Empty()) {
      // Lets GenerateIfStatement recognise the last continuing statement as
      // a potential break-if / break-unless.
      continuing_stack_.emplace_back(stmt->continuing->Last(), loop_header_id,
                                     merge_block_id);
      if (!GenerateBlockStatementWithoutScoping(stmt->continuing)) {
        return false;
      }
      continuing_stack_.pop_back();
    }
  }

  // Generate the backedge.
  TINT_ASSERT(Writer, !backedge_stack_.empty());
  const Backedge& backedge = backedge_stack_.back();
  if (!push_function_inst(backedge.opcode, backedge.operands)) {
    return false;
  }
  backedge_stack_.pop_back();

  merge_stack_.pop_back();
  continue_stack_.pop_back();

  return GenerateLabel(merge_block_id);
}
3813
GenerateStatement(const ast::Statement * stmt)3814 bool Builder::GenerateStatement(const ast::Statement* stmt) {
3815 if (auto* a = stmt->As<ast::AssignmentStatement>()) {
3816 return GenerateAssignStatement(a);
3817 }
3818 if (auto* b = stmt->As<ast::BlockStatement>()) {
3819 return GenerateBlockStatement(b);
3820 }
3821 if (auto* b = stmt->As<ast::BreakStatement>()) {
3822 return GenerateBreakStatement(b);
3823 }
3824 if (auto* c = stmt->As<ast::CallStatement>()) {
3825 return GenerateCallExpression(c->expr) != 0;
3826 }
3827 if (auto* c = stmt->As<ast::ContinueStatement>()) {
3828 return GenerateContinueStatement(c);
3829 }
3830 if (auto* d = stmt->As<ast::DiscardStatement>()) {
3831 return GenerateDiscardStatement(d);
3832 }
3833 if (stmt->Is<ast::FallthroughStatement>()) {
3834 // Do nothing here, the fallthrough gets handled by the switch code.
3835 return true;
3836 }
3837 if (auto* i = stmt->As<ast::IfStatement>()) {
3838 return GenerateIfStatement(i);
3839 }
3840 if (auto* l = stmt->As<ast::LoopStatement>()) {
3841 return GenerateLoopStatement(l);
3842 }
3843 if (auto* r = stmt->As<ast::ReturnStatement>()) {
3844 return GenerateReturnStatement(r);
3845 }
3846 if (auto* s = stmt->As<ast::SwitchStatement>()) {
3847 return GenerateSwitchStatement(s);
3848 }
3849 if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
3850 return GenerateVariableDeclStatement(v);
3851 }
3852
3853 error_ = "Unknown statement: " + std::string(stmt->TypeInfo().name);
3854 return false;
3855 }
3856
GenerateVariableDeclStatement(const ast::VariableDeclStatement * stmt)3857 bool Builder::GenerateVariableDeclStatement(
3858 const ast::VariableDeclStatement* stmt) {
3859 return GenerateFunctionVariable(stmt->variable);
3860 }
3861
GenerateTypeIfNeeded(const sem::Type * type)3862 uint32_t Builder::GenerateTypeIfNeeded(const sem::Type* type) {
3863 if (type == nullptr) {
3864 error_ = "attempting to generate type from null type";
3865 return 0;
3866 }
3867
3868 // Atomics are a type in WGSL, but aren't a distinct type in SPIR-V.
3869 // Just emit the type inside the atomic.
3870 if (auto* atomic = type->As<sem::Atomic>()) {
3871 return GenerateTypeIfNeeded(atomic->Type());
3872 }
3873
3874 // Pointers and references with differing accesses should not result in a
3875 // different SPIR-V types, so we explicitly ignore the access.
3876 // Pointers and References both map to a SPIR-V pointer type.
3877 // Transform a Reference to a Pointer to prevent these having duplicated
3878 // definitions in the generated SPIR-V. Note that nested pointers and
3879 // references are not legal in WGSL, so only considering the top-level type is
3880 // fine.
3881 std::string type_name;
3882 if (auto* ptr = type->As<sem::Pointer>()) {
3883 type_name =
3884 sem::Pointer(ptr->StoreType(), ptr->StorageClass(), ast::kReadWrite)
3885 .type_name();
3886 } else if (auto* ref = type->As<sem::Reference>()) {
3887 type_name =
3888 sem::Pointer(ref->StoreType(), ref->StorageClass(), ast::kReadWrite)
3889 .type_name();
3890 } else {
3891 type_name = type->type_name();
3892 }
3893
3894 return utils::GetOrCreate(type_name_to_id_, type_name, [&]() -> uint32_t {
3895 auto result = result_op();
3896 auto id = result.to_i();
3897 if (auto* arr = type->As<sem::Array>()) {
3898 if (!GenerateArrayType(arr, result)) {
3899 return 0;
3900 }
3901 } else if (type->Is<sem::Bool>()) {
3902 push_type(spv::Op::OpTypeBool, {result});
3903 } else if (type->Is<sem::F32>()) {
3904 push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)});
3905 } else if (type->Is<sem::I32>()) {
3906 push_type(spv::Op::OpTypeInt,
3907 {result, Operand::Int(32), Operand::Int(1)});
3908 } else if (auto* mat = type->As<sem::Matrix>()) {
3909 if (!GenerateMatrixType(mat, result)) {
3910 return 0;
3911 }
3912 } else if (auto* ptr = type->As<sem::Pointer>()) {
3913 if (!GeneratePointerType(ptr, result)) {
3914 return 0;
3915 }
3916 } else if (auto* ref = type->As<sem::Reference>()) {
3917 if (!GenerateReferenceType(ref, result)) {
3918 return 0;
3919 }
3920 } else if (auto* str = type->As<sem::Struct>()) {
3921 if (!GenerateStructType(str, result)) {
3922 return 0;
3923 }
3924 } else if (type->Is<sem::U32>()) {
3925 push_type(spv::Op::OpTypeInt,
3926 {result, Operand::Int(32), Operand::Int(0)});
3927 } else if (auto* vec = type->As<sem::Vector>()) {
3928 if (!GenerateVectorType(vec, result)) {
3929 return 0;
3930 }
3931 } else if (type->Is<sem::Void>()) {
3932 push_type(spv::Op::OpTypeVoid, {result});
3933 } else if (auto* tex = type->As<sem::Texture>()) {
3934 if (!GenerateTextureType(tex, result)) {
3935 return 0;
3936 }
3937
3938 if (auto* st = tex->As<sem::StorageTexture>()) {
3939 // Register all three access types of StorageTexture names. In SPIR-V,
3940 // we must output a single type, while the variable is annotated with
3941 // the access type. Doing this ensures we de-dupe.
3942 type_name_to_id_[builder_
3943 .create<sem::StorageTexture>(
3944 st->dim(), st->image_format(),
3945 ast::Access::kRead, st->type())
3946 ->type_name()] = id;
3947 type_name_to_id_[builder_
3948 .create<sem::StorageTexture>(
3949 st->dim(), st->image_format(),
3950 ast::Access::kWrite, st->type())
3951 ->type_name()] = id;
3952 type_name_to_id_[builder_
3953 .create<sem::StorageTexture>(
3954 st->dim(), st->image_format(),
3955 ast::Access::kReadWrite, st->type())
3956 ->type_name()] = id;
3957 }
3958
3959 } else if (type->Is<sem::Sampler>()) {
3960 push_type(spv::Op::OpTypeSampler, {result});
3961
3962 // Register both of the sampler type names. In SPIR-V they're the same
3963 // sampler type, so we need to match that when we do the dedup check.
3964 type_name_to_id_["__sampler_sampler"] = id;
3965 type_name_to_id_["__sampler_comparison"] = id;
3966
3967 } else {
3968 error_ = "unable to convert type: " + type->type_name();
3969 return 0;
3970 }
3971
3972 return id;
3973 });
3974 }
3975
GenerateTextureType(const sem::Texture * texture,const Operand & result)3976 bool Builder::GenerateTextureType(const sem::Texture* texture,
3977 const Operand& result) {
3978 uint32_t array_literal = 0u;
3979 const auto dim = texture->dim();
3980 if (dim == ast::TextureDimension::k2dArray ||
3981 dim == ast::TextureDimension::kCubeArray) {
3982 array_literal = 1u;
3983 }
3984
3985 uint32_t dim_literal = SpvDim2D;
3986 if (dim == ast::TextureDimension::k1d) {
3987 dim_literal = SpvDim1D;
3988 if (texture->Is<sem::SampledTexture>()) {
3989 push_capability(SpvCapabilitySampled1D);
3990 } else if (texture->Is<sem::StorageTexture>()) {
3991 push_capability(SpvCapabilityImage1D);
3992 }
3993 }
3994 if (dim == ast::TextureDimension::k3d) {
3995 dim_literal = SpvDim3D;
3996 }
3997 if (dim == ast::TextureDimension::kCube ||
3998 dim == ast::TextureDimension::kCubeArray) {
3999 dim_literal = SpvDimCube;
4000 }
4001
4002 uint32_t ms_literal = 0u;
4003 if (texture->IsAnyOf<sem::MultisampledTexture,
4004 sem::DepthMultisampledTexture>()) {
4005 ms_literal = 1u;
4006 }
4007
4008 uint32_t depth_literal = 0u;
4009 if (texture->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
4010 depth_literal = 1u;
4011 }
4012
4013 uint32_t sampled_literal = 2u;
4014 if (texture->IsAnyOf<sem::MultisampledTexture, sem::SampledTexture,
4015 sem::DepthTexture, sem::DepthMultisampledTexture>()) {
4016 sampled_literal = 1u;
4017 }
4018
4019 if (dim == ast::TextureDimension::kCubeArray) {
4020 if (texture->Is<sem::SampledTexture>() ||
4021 texture->Is<sem::DepthTexture>()) {
4022 push_capability(SpvCapabilitySampledCubeArray);
4023 }
4024 }
4025
4026 uint32_t type_id = 0u;
4027 if (texture->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
4028 type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>());
4029 } else if (auto* s = texture->As<sem::SampledTexture>()) {
4030 type_id = GenerateTypeIfNeeded(s->type());
4031 } else if (auto* ms = texture->As<sem::MultisampledTexture>()) {
4032 type_id = GenerateTypeIfNeeded(ms->type());
4033 } else if (auto* st = texture->As<sem::StorageTexture>()) {
4034 type_id = GenerateTypeIfNeeded(st->type());
4035 }
4036 if (type_id == 0u) {
4037 return false;
4038 }
4039
4040 uint32_t format_literal = SpvImageFormat_::SpvImageFormatUnknown;
4041 if (auto* t = texture->As<sem::StorageTexture>()) {
4042 format_literal = convert_image_format_to_spv(t->image_format());
4043 }
4044
4045 push_type(spv::Op::OpTypeImage,
4046 {result, Operand::Int(type_id), Operand::Int(dim_literal),
4047 Operand::Int(depth_literal), Operand::Int(array_literal),
4048 Operand::Int(ms_literal), Operand::Int(sampled_literal),
4049 Operand::Int(format_literal)});
4050
4051 return true;
4052 }
4053
GenerateArrayType(const sem::Array * ary,const Operand & result)4054 bool Builder::GenerateArrayType(const sem::Array* ary, const Operand& result) {
4055 auto elem_type = GenerateTypeIfNeeded(ary->ElemType());
4056 if (elem_type == 0) {
4057 return false;
4058 }
4059
4060 auto result_id = result.to_i();
4061 if (ary->IsRuntimeSized()) {
4062 push_type(spv::Op::OpTypeRuntimeArray, {result, Operand::Int(elem_type)});
4063 } else {
4064 auto len_id = GenerateConstantIfNeeded(ScalarConstant::U32(ary->Count()));
4065 if (len_id == 0) {
4066 return false;
4067 }
4068
4069 push_type(spv::Op::OpTypeArray,
4070 {result, Operand::Int(elem_type), Operand::Int(len_id)});
4071 }
4072
4073 push_annot(spv::Op::OpDecorate,
4074 {Operand::Int(result_id), Operand::Int(SpvDecorationArrayStride),
4075 Operand::Int(ary->Stride())});
4076 return true;
4077 }
4078
GenerateMatrixType(const sem::Matrix * mat,const Operand & result)4079 bool Builder::GenerateMatrixType(const sem::Matrix* mat,
4080 const Operand& result) {
4081 auto* col_type = builder_.create<sem::Vector>(mat->type(), mat->rows());
4082 auto col_type_id = GenerateTypeIfNeeded(col_type);
4083 if (has_error()) {
4084 return false;
4085 }
4086
4087 push_type(spv::Op::OpTypeMatrix,
4088 {result, Operand::Int(col_type_id), Operand::Int(mat->columns())});
4089 return true;
4090 }
4091
GeneratePointerType(const sem::Pointer * ptr,const Operand & result)4092 bool Builder::GeneratePointerType(const sem::Pointer* ptr,
4093 const Operand& result) {
4094 auto subtype_id = GenerateTypeIfNeeded(ptr->StoreType());
4095 if (subtype_id == 0) {
4096 return false;
4097 }
4098
4099 auto stg_class = ConvertStorageClass(ptr->StorageClass());
4100 if (stg_class == SpvStorageClassMax) {
4101 error_ = "invalid storage class for pointer";
4102 return false;
4103 }
4104
4105 push_type(spv::Op::OpTypePointer,
4106 {result, Operand::Int(stg_class), Operand::Int(subtype_id)});
4107
4108 return true;
4109 }
4110
GenerateReferenceType(const sem::Reference * ref,const Operand & result)4111 bool Builder::GenerateReferenceType(const sem::Reference* ref,
4112 const Operand& result) {
4113 auto subtype_id = GenerateTypeIfNeeded(ref->StoreType());
4114 if (subtype_id == 0) {
4115 return false;
4116 }
4117
4118 auto stg_class = ConvertStorageClass(ref->StorageClass());
4119 if (stg_class == SpvStorageClassMax) {
4120 error_ = "invalid storage class for reference";
4121 return false;
4122 }
4123
4124 push_type(spv::Op::OpTypePointer,
4125 {result, Operand::Int(stg_class), Operand::Int(subtype_id)});
4126
4127 return true;
4128 }
4129
GenerateStructType(const sem::Struct * struct_type,const Operand & result)4130 bool Builder::GenerateStructType(const sem::Struct* struct_type,
4131 const Operand& result) {
4132 auto struct_id = result.to_i();
4133
4134 if (struct_type->Name().IsValid()) {
4135 push_debug(
4136 spv::Op::OpName,
4137 {Operand::Int(struct_id),
4138 Operand::String(builder_.Symbols().NameFor(struct_type->Name()))});
4139 }
4140
4141 OperandList ops;
4142 ops.push_back(result);
4143
4144 auto* decl = struct_type->Declaration();
4145 if (decl && decl->IsBlockDecorated()) {
4146 push_annot(spv::Op::OpDecorate,
4147 {Operand::Int(struct_id), Operand::Int(SpvDecorationBlock)});
4148 }
4149
4150 for (uint32_t i = 0; i < struct_type->Members().size(); ++i) {
4151 auto mem_id = GenerateStructMember(struct_id, i, struct_type->Members()[i]);
4152 if (mem_id == 0) {
4153 return false;
4154 }
4155
4156 ops.push_back(Operand::Int(mem_id));
4157 }
4158
4159 push_type(spv::Op::OpTypeStruct, std::move(ops));
4160 return true;
4161 }
4162
GenerateStructMember(uint32_t struct_id,uint32_t idx,const sem::StructMember * member)4163 uint32_t Builder::GenerateStructMember(uint32_t struct_id,
4164 uint32_t idx,
4165 const sem::StructMember* member) {
4166 push_debug(spv::Op::OpMemberName,
4167 {Operand::Int(struct_id), Operand::Int(idx),
4168 Operand::String(builder_.Symbols().NameFor(member->Name()))});
4169
4170 // Note: This will generate layout annotations for *all* structs, whether or
4171 // not they are used in host-shareable variables. This is officially ok in
4172 // SPIR-V 1.0 through 1.3. If / when we migrate to using SPIR-V 1.4 we'll have
4173 // to only generate the layout info for structs used for certain storage
4174 // classes.
4175
4176 push_annot(
4177 spv::Op::OpMemberDecorate,
4178 {Operand::Int(struct_id), Operand::Int(idx),
4179 Operand::Int(SpvDecorationOffset), Operand::Int(member->Offset())});
4180
4181 // Infer and emit matrix layout.
4182 auto* matrix_type = GetNestedMatrixType(member->Type());
4183 if (matrix_type) {
4184 push_annot(spv::Op::OpMemberDecorate,
4185 {Operand::Int(struct_id), Operand::Int(idx),
4186 Operand::Int(SpvDecorationColMajor)});
4187 if (!matrix_type->type()->Is<sem::F32>()) {
4188 error_ = "matrix scalar element type must be f32";
4189 return 0;
4190 }
4191 const auto scalar_elem_size = 4;
4192 const auto effective_row_count = (matrix_type->rows() == 2) ? 2 : 4;
4193 push_annot(spv::Op::OpMemberDecorate,
4194 {Operand::Int(struct_id), Operand::Int(idx),
4195 Operand::Int(SpvDecorationMatrixStride),
4196 Operand::Int(effective_row_count * scalar_elem_size)});
4197 }
4198
4199 return GenerateTypeIfNeeded(member->Type());
4200 }
4201
GenerateVectorType(const sem::Vector * vec,const Operand & result)4202 bool Builder::GenerateVectorType(const sem::Vector* vec,
4203 const Operand& result) {
4204 auto type_id = GenerateTypeIfNeeded(vec->type());
4205 if (has_error()) {
4206 return false;
4207 }
4208
4209 push_type(spv::Op::OpTypeVector,
4210 {result, Operand::Int(type_id), Operand::Int(vec->Width())});
4211 return true;
4212 }
4213
ConvertStorageClass(ast::StorageClass klass) const4214 SpvStorageClass Builder::ConvertStorageClass(ast::StorageClass klass) const {
4215 switch (klass) {
4216 case ast::StorageClass::kInvalid:
4217 return SpvStorageClassMax;
4218 case ast::StorageClass::kInput:
4219 return SpvStorageClassInput;
4220 case ast::StorageClass::kOutput:
4221 return SpvStorageClassOutput;
4222 case ast::StorageClass::kUniform:
4223 return SpvStorageClassUniform;
4224 case ast::StorageClass::kWorkgroup:
4225 return SpvStorageClassWorkgroup;
4226 case ast::StorageClass::kUniformConstant:
4227 return SpvStorageClassUniformConstant;
4228 case ast::StorageClass::kStorage:
4229 return SpvStorageClassStorageBuffer;
4230 case ast::StorageClass::kImage:
4231 return SpvStorageClassImage;
4232 case ast::StorageClass::kPrivate:
4233 return SpvStorageClassPrivate;
4234 case ast::StorageClass::kFunction:
4235 return SpvStorageClassFunction;
4236 case ast::StorageClass::kNone:
4237 break;
4238 }
4239 return SpvStorageClassMax;
4240 }
4241
ConvertBuiltin(ast::Builtin builtin,ast::StorageClass storage)4242 SpvBuiltIn Builder::ConvertBuiltin(ast::Builtin builtin,
4243 ast::StorageClass storage) {
4244 switch (builtin) {
4245 case ast::Builtin::kPosition:
4246 if (storage == ast::StorageClass::kInput) {
4247 return SpvBuiltInFragCoord;
4248 } else if (storage == ast::StorageClass::kOutput) {
4249 return SpvBuiltInPosition;
4250 } else {
4251 TINT_ICE(Writer, builder_.Diagnostics())
4252 << "invalid storage class for builtin";
4253 break;
4254 }
4255 case ast::Builtin::kVertexIndex:
4256 return SpvBuiltInVertexIndex;
4257 case ast::Builtin::kInstanceIndex:
4258 return SpvBuiltInInstanceIndex;
4259 case ast::Builtin::kFrontFacing:
4260 return SpvBuiltInFrontFacing;
4261 case ast::Builtin::kFragDepth:
4262 return SpvBuiltInFragDepth;
4263 case ast::Builtin::kLocalInvocationId:
4264 return SpvBuiltInLocalInvocationId;
4265 case ast::Builtin::kLocalInvocationIndex:
4266 return SpvBuiltInLocalInvocationIndex;
4267 case ast::Builtin::kGlobalInvocationId:
4268 return SpvBuiltInGlobalInvocationId;
4269 case ast::Builtin::kPointSize:
4270 return SpvBuiltInPointSize;
4271 case ast::Builtin::kWorkgroupId:
4272 return SpvBuiltInWorkgroupId;
4273 case ast::Builtin::kNumWorkgroups:
4274 return SpvBuiltInNumWorkgroups;
4275 case ast::Builtin::kSampleIndex:
4276 push_capability(SpvCapabilitySampleRateShading);
4277 return SpvBuiltInSampleId;
4278 case ast::Builtin::kSampleMask:
4279 return SpvBuiltInSampleMask;
4280 case ast::Builtin::kNone:
4281 break;
4282 }
4283 return SpvBuiltInMax;
4284 }
4285
AddInterpolationDecorations(uint32_t id,ast::InterpolationType type,ast::InterpolationSampling sampling)4286 void Builder::AddInterpolationDecorations(uint32_t id,
4287 ast::InterpolationType type,
4288 ast::InterpolationSampling sampling) {
4289 switch (type) {
4290 case ast::InterpolationType::kLinear:
4291 push_annot(spv::Op::OpDecorate,
4292 {Operand::Int(id), Operand::Int(SpvDecorationNoPerspective)});
4293 break;
4294 case ast::InterpolationType::kFlat:
4295 push_annot(spv::Op::OpDecorate,
4296 {Operand::Int(id), Operand::Int(SpvDecorationFlat)});
4297 break;
4298 case ast::InterpolationType::kPerspective:
4299 break;
4300 }
4301 switch (sampling) {
4302 case ast::InterpolationSampling::kCentroid:
4303 push_annot(spv::Op::OpDecorate,
4304 {Operand::Int(id), Operand::Int(SpvDecorationCentroid)});
4305 break;
4306 case ast::InterpolationSampling::kSample:
4307 push_capability(SpvCapabilitySampleRateShading);
4308 push_annot(spv::Op::OpDecorate,
4309 {Operand::Int(id), Operand::Int(SpvDecorationSample)});
4310 break;
4311 case ast::InterpolationSampling::kCenter:
4312 case ast::InterpolationSampling::kNone:
4313 break;
4314 }
4315 }
4316
convert_image_format_to_spv(const ast::ImageFormat format)4317 SpvImageFormat Builder::convert_image_format_to_spv(
4318 const ast::ImageFormat format) {
4319 switch (format) {
4320 case ast::ImageFormat::kR8Unorm:
4321 push_capability(SpvCapabilityStorageImageExtendedFormats);
4322 return SpvImageFormatR8;
4323 case ast::ImageFormat::kR8Snorm:
4324 push_capability(SpvCapabilityStorageImageExtendedFormats);
4325 return SpvImageFormatR8Snorm;
4326 case ast::ImageFormat::kR8Uint:
4327 push_capability(SpvCapabilityStorageImageExtendedFormats);
4328 return SpvImageFormatR8ui;
4329 case ast::ImageFormat::kR8Sint:
4330 push_capability(SpvCapabilityStorageImageExtendedFormats);
4331 return SpvImageFormatR8i;
4332 case ast::ImageFormat::kR16Uint:
4333 push_capability(SpvCapabilityStorageImageExtendedFormats);
4334 return SpvImageFormatR16ui;
4335 case ast::ImageFormat::kR16Sint:
4336 push_capability(SpvCapabilityStorageImageExtendedFormats);
4337 return SpvImageFormatR16i;
4338 case ast::ImageFormat::kR16Float:
4339 push_capability(SpvCapabilityStorageImageExtendedFormats);
4340 return SpvImageFormatR16f;
4341 case ast::ImageFormat::kRg8Unorm:
4342 push_capability(SpvCapabilityStorageImageExtendedFormats);
4343 return SpvImageFormatRg8;
4344 case ast::ImageFormat::kRg8Snorm:
4345 push_capability(SpvCapabilityStorageImageExtendedFormats);
4346 return SpvImageFormatRg8Snorm;
4347 case ast::ImageFormat::kRg8Uint:
4348 push_capability(SpvCapabilityStorageImageExtendedFormats);
4349 return SpvImageFormatRg8ui;
4350 case ast::ImageFormat::kRg8Sint:
4351 push_capability(SpvCapabilityStorageImageExtendedFormats);
4352 return SpvImageFormatRg8i;
4353 case ast::ImageFormat::kR32Uint:
4354 return SpvImageFormatR32ui;
4355 case ast::ImageFormat::kR32Sint:
4356 return SpvImageFormatR32i;
4357 case ast::ImageFormat::kR32Float:
4358 return SpvImageFormatR32f;
4359 case ast::ImageFormat::kRg16Uint:
4360 push_capability(SpvCapabilityStorageImageExtendedFormats);
4361 return SpvImageFormatRg16ui;
4362 case ast::ImageFormat::kRg16Sint:
4363 push_capability(SpvCapabilityStorageImageExtendedFormats);
4364 return SpvImageFormatRg16i;
4365 case ast::ImageFormat::kRg16Float:
4366 push_capability(SpvCapabilityStorageImageExtendedFormats);
4367 return SpvImageFormatRg16f;
4368 case ast::ImageFormat::kRgba8Unorm:
4369 return SpvImageFormatRgba8;
4370 case ast::ImageFormat::kRgba8UnormSrgb:
4371 return SpvImageFormatUnknown;
4372 case ast::ImageFormat::kRgba8Snorm:
4373 return SpvImageFormatRgba8Snorm;
4374 case ast::ImageFormat::kRgba8Uint:
4375 return SpvImageFormatRgba8ui;
4376 case ast::ImageFormat::kRgba8Sint:
4377 return SpvImageFormatRgba8i;
4378 case ast::ImageFormat::kBgra8Unorm:
4379 return SpvImageFormatUnknown;
4380 case ast::ImageFormat::kBgra8UnormSrgb:
4381 return SpvImageFormatUnknown;
4382 case ast::ImageFormat::kRgb10A2Unorm:
4383 push_capability(SpvCapabilityStorageImageExtendedFormats);
4384 return SpvImageFormatRgb10A2;
4385 case ast::ImageFormat::kRg11B10Float:
4386 push_capability(SpvCapabilityStorageImageExtendedFormats);
4387 return SpvImageFormatR11fG11fB10f;
4388 case ast::ImageFormat::kRg32Uint:
4389 push_capability(SpvCapabilityStorageImageExtendedFormats);
4390 return SpvImageFormatRg32ui;
4391 case ast::ImageFormat::kRg32Sint:
4392 push_capability(SpvCapabilityStorageImageExtendedFormats);
4393 return SpvImageFormatRg32i;
4394 case ast::ImageFormat::kRg32Float:
4395 push_capability(SpvCapabilityStorageImageExtendedFormats);
4396 return SpvImageFormatRg32f;
4397 case ast::ImageFormat::kRgba16Uint:
4398 return SpvImageFormatRgba16ui;
4399 case ast::ImageFormat::kRgba16Sint:
4400 return SpvImageFormatRgba16i;
4401 case ast::ImageFormat::kRgba16Float:
4402 return SpvImageFormatRgba16f;
4403 case ast::ImageFormat::kRgba32Uint:
4404 return SpvImageFormatRgba32ui;
4405 case ast::ImageFormat::kRgba32Sint:
4406 return SpvImageFormatRgba32i;
4407 case ast::ImageFormat::kRgba32Float:
4408 return SpvImageFormatRgba32f;
4409 case ast::ImageFormat::kNone:
4410 return SpvImageFormatUnknown;
4411 }
4412 return SpvImageFormatUnknown;
4413 }
4414
push_function_inst(spv::Op op,const OperandList & operands)4415 bool Builder::push_function_inst(spv::Op op, const OperandList& operands) {
4416 if (functions_.empty()) {
4417 std::ostringstream ss;
4418 ss << "Internal error: trying to add SPIR-V instruction " << int(op)
4419 << " outside a function";
4420 error_ = ss.str();
4421 return false;
4422 }
4423 functions_.back().push_inst(op, operands);
4424 return true;
4425 }
4426
// Loop context pushed onto continuing_stack_ while a loop's continuing block is
// generated. It holds:
//  * the continuing block's final statement;
//  * the loop header block id;
//  * the loop's merge (break target) block id.
// All three must be non-null / non-zero; this is checked with TINT_ASSERT.
Builder::ContinuingInfo::ContinuingInfo(
    const ast::Statement* the_last_statement,
    uint32_t loop_id,
    uint32_t break_id)
    : last_statement(the_last_statement),
      loop_header_id(loop_id),
      break_target_id(break_id) {
  TINT_ASSERT(Writer, last_statement != nullptr);
  TINT_ASSERT(Writer, loop_header_id != 0u);
  TINT_ASSERT(Writer, break_target_id != 0u);
}
4438
Backedge(spv::Op the_opcode,OperandList the_operands)4439 Builder::Backedge::Backedge(spv::Op the_opcode, OperandList the_operands)
4440 : opcode(the_opcode), operands(the_operands) {}
4441
4442 Builder::Backedge::Backedge(const Builder::Backedge& other) = default;
4443 Builder::Backedge& Builder::Backedge::operator=(
4444 const Builder::Backedge& other) = default;
4445 Builder::Backedge::~Backedge() = default;
4446
4447 } // namespace spirv
4448 } // namespace writer
4449 } // namespace tint
4450