• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Tint Authors.  //
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 #include "src/writer/spirv/builder.h"
15 
16 #include <algorithm>
17 #include <limits>
18 #include <utility>
19 
20 #include "spirv/unified1/GLSL.std.450.h"
21 #include "src/ast/call_statement.h"
22 #include "src/ast/fallthrough_statement.h"
23 #include "src/ast/internal_decoration.h"
24 #include "src/ast/override_decoration.h"
25 #include "src/ast/traverse_expressions.h"
26 #include "src/sem/array.h"
27 #include "src/sem/atomic_type.h"
28 #include "src/sem/call.h"
29 #include "src/sem/depth_multisampled_texture_type.h"
30 #include "src/sem/depth_texture_type.h"
31 #include "src/sem/function.h"
32 #include "src/sem/intrinsic.h"
33 #include "src/sem/member_accessor_expression.h"
34 #include "src/sem/multisampled_texture_type.h"
35 #include "src/sem/reference_type.h"
36 #include "src/sem/sampled_texture_type.h"
37 #include "src/sem/statement.h"
38 #include "src/sem/struct.h"
39 #include "src/sem/type_constructor.h"
40 #include "src/sem/type_conversion.h"
41 #include "src/sem/variable.h"
42 #include "src/sem/vector_type.h"
43 #include "src/transform/add_empty_entry_point.h"
44 #include "src/transform/canonicalize_entry_point_io.h"
45 #include "src/transform/external_texture_transform.h"
46 #include "src/transform/fold_constants.h"
47 #include "src/transform/for_loop_to_loop.h"
48 #include "src/transform/manager.h"
49 #include "src/transform/simplify_pointers.h"
50 #include "src/transform/unshadow.h"
51 #include "src/transform/vectorize_scalar_matrix_constructors.h"
52 #include "src/transform/zero_init_workgroup_memory.h"
53 #include "src/utils/defer.h"
54 #include "src/utils/map.h"
55 #include "src/writer/append_vector.h"
56 
57 namespace tint {
58 namespace writer {
59 namespace spirv {
60 namespace {
61 
62 using IntrinsicType = sem::IntrinsicType;
63 
64 const char kGLSLstd450[] = "GLSL.std.450";
65 
size_of(const InstructionList & instructions)66 uint32_t size_of(const InstructionList& instructions) {
67   uint32_t size = 0;
68   for (const auto& inst : instructions)
69     size += inst.word_length();
70 
71   return size;
72 }
73 
pipeline_stage_to_execution_model(ast::PipelineStage stage)74 uint32_t pipeline_stage_to_execution_model(ast::PipelineStage stage) {
75   SpvExecutionModel model = SpvExecutionModelVertex;
76 
77   switch (stage) {
78     case ast::PipelineStage::kFragment:
79       model = SpvExecutionModelFragment;
80       break;
81     case ast::PipelineStage::kVertex:
82       model = SpvExecutionModelVertex;
83       break;
84     case ast::PipelineStage::kCompute:
85       model = SpvExecutionModelGLCompute;
86       break;
87     case ast::PipelineStage::kNone:
88       model = SpvExecutionModelMax;
89       break;
90   }
91   return model;
92 }
93 
LastIsFallthrough(const ast::BlockStatement * stmts)94 bool LastIsFallthrough(const ast::BlockStatement* stmts) {
95   return !stmts->Empty() && stmts->Last()->Is<ast::FallthroughStatement>();
96 }
97 
98 // A terminator is anything which will cause a SPIR-V terminator to be emitted.
99 // This means things like breaks, fallthroughs and continues which all emit an
100 // OpBranch or return for the OpReturn emission.
LastIsTerminator(const ast::BlockStatement * stmts)101 bool LastIsTerminator(const ast::BlockStatement* stmts) {
102   if (IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
103               ast::DiscardStatement, ast::ReturnStatement,
104               ast::FallthroughStatement>(stmts->Last())) {
105     return true;
106   }
107 
108   if (auto* block = As<ast::BlockStatement>(stmts->Last())) {
109     return LastIsTerminator(block);
110   }
111 
112   return false;
113 }
114 
115 /// Returns the matrix type that is `type` or that is wrapped by
116 /// one or more levels of an arrays inside of `type`.
117 /// @param type the given type, which must not be null
118 /// @returns the nested matrix type, or nullptr if none
GetNestedMatrixType(const sem::Type * type)119 const sem::Matrix* GetNestedMatrixType(const sem::Type* type) {
120   while (auto* arr = type->As<sem::Array>()) {
121     type = arr->ElemType();
122   }
123   return type->As<sem::Matrix>();
124 }
125 
/// Maps a WGSL intrinsic onto the matching GLSL.std.450 extended-instruction
/// opcode.
/// @param intrinsic the semantic intrinsic to map
/// @returns the GLSLstd450 opcode, or 0 if the intrinsic has no GLSL.std.450
/// equivalent (callers must treat 0 as "not an extended instruction")
uint32_t intrinsic_to_glsl_method(const sem::Intrinsic* intrinsic) {
  switch (intrinsic->Type()) {
    case IntrinsicType::kAcos:
      return GLSLstd450Acos;
    case IntrinsicType::kAsin:
      return GLSLstd450Asin;
    case IntrinsicType::kAtan:
      return GLSLstd450Atan;
    case IntrinsicType::kAtan2:
      return GLSLstd450Atan2;
    case IntrinsicType::kCeil:
      return GLSLstd450Ceil;
    case IntrinsicType::kClamp:
      // clamp, min and max pick the N (float), U (unsigned) or S (signed)
      // opcode variant based on the intrinsic's return type.
      if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
        return GLSLstd450NClamp;
      } else if (intrinsic->ReturnType()->is_unsigned_scalar_or_vector()) {
        return GLSLstd450UClamp;
      } else {
        return GLSLstd450SClamp;
      }
    case IntrinsicType::kCos:
      return GLSLstd450Cos;
    case IntrinsicType::kCosh:
      return GLSLstd450Cosh;
    case IntrinsicType::kCross:
      return GLSLstd450Cross;
    case IntrinsicType::kDeterminant:
      return GLSLstd450Determinant;
    case IntrinsicType::kDistance:
      return GLSLstd450Distance;
    case IntrinsicType::kExp:
      return GLSLstd450Exp;
    case IntrinsicType::kExp2:
      return GLSLstd450Exp2;
    case IntrinsicType::kFaceForward:
      return GLSLstd450FaceForward;
    case IntrinsicType::kFloor:
      return GLSLstd450Floor;
    case IntrinsicType::kFma:
      return GLSLstd450Fma;
    case IntrinsicType::kFract:
      return GLSLstd450Fract;
    case IntrinsicType::kFrexp:
      // frexp/modf use the *Struct forms, which return both results in a
      // composite instead of writing through an out-pointer.
      return GLSLstd450FrexpStruct;
    case IntrinsicType::kInverseSqrt:
      return GLSLstd450InverseSqrt;
    case IntrinsicType::kLdexp:
      return GLSLstd450Ldexp;
    case IntrinsicType::kLength:
      return GLSLstd450Length;
    case IntrinsicType::kLog:
      return GLSLstd450Log;
    case IntrinsicType::kLog2:
      return GLSLstd450Log2;
    case IntrinsicType::kMax:
      if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
        return GLSLstd450NMax;
      } else if (intrinsic->ReturnType()->is_unsigned_scalar_or_vector()) {
        return GLSLstd450UMax;
      } else {
        return GLSLstd450SMax;
      }
    case IntrinsicType::kMin:
      if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
        return GLSLstd450NMin;
      } else if (intrinsic->ReturnType()->is_unsigned_scalar_or_vector()) {
        return GLSLstd450UMin;
      } else {
        return GLSLstd450SMin;
      }
    case IntrinsicType::kMix:
      return GLSLstd450FMix;
    case IntrinsicType::kModf:
      return GLSLstd450ModfStruct;
    case IntrinsicType::kNormalize:
      return GLSLstd450Normalize;
    case IntrinsicType::kPack4x8snorm:
      return GLSLstd450PackSnorm4x8;
    case IntrinsicType::kPack4x8unorm:
      return GLSLstd450PackUnorm4x8;
    case IntrinsicType::kPack2x16snorm:
      return GLSLstd450PackSnorm2x16;
    case IntrinsicType::kPack2x16unorm:
      return GLSLstd450PackUnorm2x16;
    case IntrinsicType::kPack2x16float:
      return GLSLstd450PackHalf2x16;
    case IntrinsicType::kPow:
      return GLSLstd450Pow;
    case IntrinsicType::kReflect:
      return GLSLstd450Reflect;
    case IntrinsicType::kRefract:
      return GLSLstd450Refract;
    case IntrinsicType::kRound:
      // WGSL round() is round-half-to-even, hence RoundEven not Round.
      return GLSLstd450RoundEven;
    case IntrinsicType::kSign:
      return GLSLstd450FSign;
    case IntrinsicType::kSin:
      return GLSLstd450Sin;
    case IntrinsicType::kSinh:
      return GLSLstd450Sinh;
    case IntrinsicType::kSmoothStep:
      return GLSLstd450SmoothStep;
    case IntrinsicType::kSqrt:
      return GLSLstd450Sqrt;
    case IntrinsicType::kStep:
      return GLSLstd450Step;
    case IntrinsicType::kTan:
      return GLSLstd450Tan;
    case IntrinsicType::kTanh:
      return GLSLstd450Tanh;
    case IntrinsicType::kTrunc:
      return GLSLstd450Trunc;
    case IntrinsicType::kUnpack4x8snorm:
      return GLSLstd450UnpackSnorm4x8;
    case IntrinsicType::kUnpack4x8unorm:
      return GLSLstd450UnpackUnorm4x8;
    case IntrinsicType::kUnpack2x16snorm:
      return GLSLstd450UnpackSnorm2x16;
    case IntrinsicType::kUnpack2x16unorm:
      return GLSLstd450UnpackUnorm2x16;
    case IntrinsicType::kUnpack2x16float:
      return GLSLstd450UnpackHalf2x16;
    default:
      break;
  }
  return 0;
}
253 
254 /// @return the vector element type if ty is a vector, otherwise return ty.
ElementTypeOf(const sem::Type * ty)255 const sem::Type* ElementTypeOf(const sem::Type* ty) {
256   if (auto* v = ty->As<sem::Vector>()) {
257     return v->type();
258   }
259   return ty;
260 }
261 
262 }  // namespace
263 
/// Runs the transform passes that rewrite a program into the form the SPIR-V
/// writer expects.
/// @param in the input program
/// @param emit_vertex_point_size passed to CanonicalizeEntryPointIO to emit a
/// vertex PointSize builtin
/// @param disable_workgroup_init if true, skip zero-initialization of
/// workgroup memory
/// @returns the sanitized program wrapped in a SanitizedResult
/// NOTE(review): the transform order below is load-bearing — reorder with
/// care (see the inline comments carried over from the original).
SanitizedResult Sanitize(const Program* in,
                         bool emit_vertex_point_size,
                         bool disable_workgroup_init) {
  transform::Manager manager;
  transform::DataMap data;

  manager.Add<transform::Unshadow>();
  if (!disable_workgroup_init) {
    manager.Add<transform::ZeroInitWorkgroupMemory>();
  }
  manager.Add<transform::SimplifyPointers>();  // Required for arrayLength()
  manager.Add<transform::FoldConstants>();
  manager.Add<transform::ExternalTextureTransform>();
  manager.Add<transform::VectorizeScalarMatrixConstructors>();
  manager.Add<transform::ForLoopToLoop>();  // Must come after
                                            // ZeroInitWorkgroupMemory
  manager.Add<transform::CanonicalizeEntryPointIO>();
  manager.Add<transform::AddEmptyEntryPoint>();

  // 0xFFFFFFFF: fixed sample mask is unused for the SPIR-V backend config.
  data.Add<transform::CanonicalizeEntryPointIO::Config>(
      transform::CanonicalizeEntryPointIO::Config(
          transform::CanonicalizeEntryPointIO::ShaderStyle::kSpirv, 0xFFFFFFFF,
          emit_vertex_point_size));

  SanitizedResult result;
  result.program = std::move(manager.Run(in, data).program);
  return result;
}
292 
// Initializes an empty accessor chain: no source id and no source type yet.
Builder::AccessorInfo::AccessorInfo() : source_id(0), source_type(nullptr) {}
294 
// Trivial destructor; AccessorInfo owns no resources.
Builder::AccessorInfo::~AccessorInfo() {}
296 
// Wraps the (immutable) input program in a ProgramBuilder so the writer can
// create transient semantic nodes, and starts with an empty scope stack.
Builder::Builder(const Program* program)
    : builder_(ProgramBuilder::Wrap(program)), scope_stack_({}) {}
299 
// All members clean up via their own destructors.
Builder::~Builder() = default;
301 
/// Builds the SPIR-V module for the wrapped program: capability and memory
/// model instructions, then all module-scope variables, then all functions.
/// @returns true on success; false if any generation step fails
bool Builder::Build() {
  // Every module emitted here requires the Shader capability.
  push_capability(SpvCapabilityShader);

  // Logical addressing with the GLSL450 memory model.
  push_memory_model(spv::Op::OpMemoryModel,
                    {Operand::Int(SpvAddressingModelLogical),
                     Operand::Int(SpvMemoryModelGLSL450)});

  // Globals are emitted before functions so function bodies can reference
  // their ids through the scope stack.
  for (auto* var : builder_.AST().GlobalVariables()) {
    if (!GenerateGlobalVariable(var)) {
      return false;
    }
  }

  for (auto* func : builder_.AST().Functions()) {
    if (!GenerateFunction(func)) {
      return false;
    }
  }

  return true;
}
323 
/// @returns a fresh result-id operand, consuming the next SPIR-V id.
Operand Builder::result_op() {
  return Operand::Int(next_id());
}
327 
total_size() const328 uint32_t Builder::total_size() const {
329   // The 5 covers the magic, version, generator, id bound and reserved.
330   uint32_t size = 5;
331 
332   size += size_of(capabilities_);
333   size += size_of(extensions_);
334   size += size_of(ext_imports_);
335   size += size_of(memory_model_);
336   size += size_of(entry_points_);
337   size += size_of(execution_modes_);
338   size += size_of(debug_);
339   size += size_of(annotations_);
340   size += size_of(types_);
341   for (const auto& func : functions_) {
342     size += func.word_length();
343   }
344 
345   return size;
346 }
347 
iterate(std::function<void (const Instruction &)> cb) const348 void Builder::iterate(std::function<void(const Instruction&)> cb) const {
349   for (const auto& inst : capabilities_) {
350     cb(inst);
351   }
352   for (const auto& inst : extensions_) {
353     cb(inst);
354   }
355   for (const auto& inst : ext_imports_) {
356     cb(inst);
357   }
358   for (const auto& inst : memory_model_) {
359     cb(inst);
360   }
361   for (const auto& inst : entry_points_) {
362     cb(inst);
363   }
364   for (const auto& inst : execution_modes_) {
365     cb(inst);
366   }
367   for (const auto& inst : debug_) {
368     cb(inst);
369   }
370   for (const auto& inst : annotations_) {
371     cb(inst);
372   }
373   for (const auto& inst : types_) {
374     cb(inst);
375   }
376   for (const auto& func : functions_) {
377     func.iterate(cb);
378   }
379 }
380 
push_capability(uint32_t cap)381 void Builder::push_capability(uint32_t cap) {
382   if (capability_set_.count(cap) == 0) {
383     capability_set_.insert(cap);
384     capabilities_.push_back(
385         Instruction{spv::Op::OpCapability, {Operand::Int(cap)}});
386   }
387 }
388 
GenerateLabel(uint32_t id)389 bool Builder::GenerateLabel(uint32_t id) {
390   if (!push_function_inst(spv::Op::OpLabel, {Operand::Int(id)})) {
391     return false;
392   }
393   current_label_id_ = id;
394   return true;
395 }
396 
GenerateAssignStatement(const ast::AssignmentStatement * assign)397 bool Builder::GenerateAssignStatement(const ast::AssignmentStatement* assign) {
398   if (assign->lhs->Is<ast::PhonyExpression>()) {
399     auto rhs_id = GenerateExpression(assign->rhs);
400     if (rhs_id == 0) {
401       return false;
402     }
403     return true;
404   } else {
405     auto lhs_id = GenerateExpression(assign->lhs);
406     if (lhs_id == 0) {
407       return false;
408     }
409     auto rhs_id = GenerateExpression(assign->rhs);
410     if (rhs_id == 0) {
411       return false;
412     }
413 
414     // If the thing we're assigning is a reference then we must load it first.
415     auto* type = TypeOf(assign->rhs);
416     rhs_id = GenerateLoadIfNeeded(type, rhs_id);
417 
418     return GenerateStore(lhs_id, rhs_id);
419   }
420 }
421 
GenerateBreakStatement(const ast::BreakStatement *)422 bool Builder::GenerateBreakStatement(const ast::BreakStatement*) {
423   if (merge_stack_.empty()) {
424     error_ = "Attempted to break without a merge block";
425     return false;
426   }
427   if (!push_function_inst(spv::Op::OpBranch,
428                           {Operand::Int(merge_stack_.back())})) {
429     return false;
430   }
431   return true;
432 }
433 
GenerateContinueStatement(const ast::ContinueStatement *)434 bool Builder::GenerateContinueStatement(const ast::ContinueStatement*) {
435   if (continue_stack_.empty()) {
436     error_ = "Attempted to continue without a continue block";
437     return false;
438   }
439   if (!push_function_inst(spv::Op::OpBranch,
440                           {Operand::Int(continue_stack_.back())})) {
441     return false;
442   }
443   return true;
444 }
445 
446 // TODO(dsinclair): This is generating an OpKill but the semantics of kill
447 // haven't been defined for WGSL yet. So, this may need to change.
448 // https://github.com/gpuweb/gpuweb/issues/676
GenerateDiscardStatement(const ast::DiscardStatement *)449 bool Builder::GenerateDiscardStatement(const ast::DiscardStatement*) {
450   if (!push_function_inst(spv::Op::OpKill, {})) {
451     return false;
452   }
453   return true;
454 }
455 
/// Emits the OpEntryPoint instruction for an entry-point function.
/// @param func the AST function marked as an entry point
/// @param id the function's SPIR-V result id
/// @returns true on success; false on unknown stage or unresolved global
bool Builder::GenerateEntryPoint(const ast::Function* func, uint32_t id) {
  auto stage = pipeline_stage_to_execution_model(func->PipelineStage());
  if (stage == SpvExecutionModelMax) {
    error_ = "Unknown pipeline stage provided";
    return false;
  }

  // OpEntryPoint: execution model, function id, then the entry point name.
  OperandList operands = {
      Operand::Int(stage), Operand::Int(id),
      Operand::String(builder_.Symbols().NameFor(func->symbol))};

  auto* func_sem = builder_.Sem().Get(func);
  for (const auto* var : func_sem->TransitivelyReferencedGlobals()) {
    // For SPIR-V 1.3 we only output Input/output variables. If we update to
    // SPIR-V 1.4 or later this should be all variables.
    if (var->StorageClass() != ast::StorageClass::kInput &&
        var->StorageClass() != ast::StorageClass::kOutput) {
      continue;
    }

    // The interface variable must already have been generated (globals are
    // emitted before functions in Build()).
    uint32_t var_id = scope_stack_.Get(var->Declaration()->symbol);
    if (var_id == 0) {
      error_ = "unable to find ID for global variable: " +
               builder_.Symbols().NameFor(var->Declaration()->symbol);
      return false;
    }

    operands.push_back(Operand::Int(var_id));
  }
  push_entry_point(spv::Op::OpEntryPoint, operands);

  return true;
}
489 
GenerateExecutionModes(const ast::Function * func,uint32_t id)490 bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) {
491   auto* func_sem = builder_.Sem().Get(func);
492 
493   // WGSL fragment shader origin is upper left
494   if (func->PipelineStage() == ast::PipelineStage::kFragment) {
495     push_execution_mode(
496         spv::Op::OpExecutionMode,
497         {Operand::Int(id), Operand::Int(SpvExecutionModeOriginUpperLeft)});
498   } else if (func->PipelineStage() == ast::PipelineStage::kCompute) {
499     auto& wgsize = func_sem->WorkgroupSize();
500 
501     // Check if the workgroup_size uses pipeline-overridable constants.
502     if (wgsize[0].overridable_const || wgsize[1].overridable_const ||
503         wgsize[2].overridable_const) {
504       if (has_overridable_workgroup_size_) {
505         // Only one stage can have a pipeline-overridable workgroup size.
506         // TODO(crbug.com/tint/810): Use LocalSizeId to handle this scenario.
507         TINT_ICE(Writer, builder_.Diagnostics())
508             << "multiple stages using pipeline-overridable workgroup sizes";
509       }
510       has_overridable_workgroup_size_ = true;
511 
512       auto* vec3_u32 =
513           builder_.create<sem::Vector>(builder_.create<sem::U32>(), 3);
514       uint32_t vec3_u32_type_id = GenerateTypeIfNeeded(vec3_u32);
515       if (vec3_u32_type_id == 0) {
516         return 0;
517       }
518 
519       OperandList wgsize_ops;
520       auto wgsize_result = result_op();
521       wgsize_ops.push_back(Operand::Int(vec3_u32_type_id));
522       wgsize_ops.push_back(wgsize_result);
523 
524       // Generate OpConstant instructions for each dimension.
525       for (int i = 0; i < 3; i++) {
526         auto constant = ScalarConstant::U32(wgsize[i].value);
527         if (wgsize[i].overridable_const) {
528           // Make the constant specializable.
529           auto* sem_const = builder_.Sem().Get<sem::GlobalVariable>(
530               wgsize[i].overridable_const);
531           if (!sem_const->IsOverridable()) {
532             TINT_ICE(Writer, builder_.Diagnostics())
533                 << "expected a pipeline-overridable constant";
534           }
535           constant.is_spec_op = true;
536           constant.constant_id = sem_const->ConstantId();
537         }
538 
539         auto result = GenerateConstantIfNeeded(constant);
540         wgsize_ops.push_back(Operand::Int(result));
541       }
542 
543       // Generate the WorkgroupSize builtin.
544       push_type(spv::Op::OpSpecConstantComposite, wgsize_ops);
545       push_annot(spv::Op::OpDecorate,
546                  {wgsize_result, Operand::Int(SpvDecorationBuiltIn),
547                   Operand::Int(SpvBuiltInWorkgroupSize)});
548     } else {
549       // Not overridable, so just use OpExecutionMode LocalSize.
550       uint32_t x = wgsize[0].value;
551       uint32_t y = wgsize[1].value;
552       uint32_t z = wgsize[2].value;
553       push_execution_mode(
554           spv::Op::OpExecutionMode,
555           {Operand::Int(id), Operand::Int(SpvExecutionModeLocalSize),
556            Operand::Int(x), Operand::Int(y), Operand::Int(z)});
557     }
558   }
559 
560   for (auto builtin : func_sem->TransitivelyReferencedBuiltinVariables()) {
561     if (builtin.second->builtin == ast::Builtin::kFragDepth) {
562       push_execution_mode(
563           spv::Op::OpExecutionMode,
564           {Operand::Int(id), Operand::Int(SpvExecutionModeDepthReplacing)});
565     }
566   }
567 
568   return true;
569 }
570 
/// Dispatches expression generation based on the concrete AST node type.
/// @param expr the expression to generate
/// @returns the result id of the generated expression, or 0 on error
uint32_t Builder::GenerateExpression(const ast::Expression* expr) {
  if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
    return GenerateAccessorExpression(a);
  }
  if (auto* b = expr->As<ast::BinaryExpression>()) {
    return GenerateBinaryExpression(b);
  }
  if (auto* b = expr->As<ast::BitcastExpression>()) {
    return GenerateBitcastExpression(b);
  }
  if (auto* c = expr->As<ast::CallExpression>()) {
    return GenerateCallExpression(c);
  }
  if (auto* i = expr->As<ast::IdentifierExpression>()) {
    return GenerateIdentifierExpression(i);
  }
  if (auto* l = expr->As<ast::LiteralExpression>()) {
    // nullptr: no variable context; the literal's own type drives generation.
    return GenerateLiteralIfNeeded(nullptr, l);
  }
  if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
    // Index and member accessors share a single accessor-chain generator.
    return GenerateAccessorExpression(m);
  }
  if (auto* u = expr->As<ast::UnaryOpExpression>()) {
    return GenerateUnaryOpExpression(u);
  }

  error_ = "unknown expression type: " + std::string(expr->TypeInfo().name);
  return 0;
}
600 
/// Generates a complete SPIR-V function: type, OpFunction, parameters, body
/// statements, an implicit return if needed, and (for entry points) the
/// OpEntryPoint and execution-mode instructions.
/// @param func_ast the AST function to generate
/// @returns true on success
bool Builder::GenerateFunction(const ast::Function* func_ast) {
  auto* func = builder_.Sem().Get(func_ast);

  uint32_t func_type_id = GenerateFunctionTypeIfNeeded(func);
  if (func_type_id == 0) {
    return false;
  }

  auto func_op = result_op();
  auto func_id = func_op.to_i();

  push_debug(spv::Op::OpName,
             {Operand::Int(func_id),
              Operand::String(builder_.Symbols().NameFor(func_ast->symbol))});

  auto ret_id = GenerateTypeIfNeeded(func->ReturnType());
  if (ret_id == 0) {
    return false;
  }

  // Parameters and locals are scoped to this function; pop on every exit path.
  scope_stack_.Push();
  TINT_DEFER(scope_stack_.Pop());

  auto definition_inst = Instruction{
      spv::Op::OpFunction,
      {Operand::Int(ret_id), func_op, Operand::Int(SpvFunctionControlMaskNone),
       Operand::Int(func_type_id)}};

  InstructionList params;
  for (auto* param : func->Parameters()) {
    auto param_op = result_op();
    auto param_id = param_op.to_i();

    auto param_type_id = GenerateTypeIfNeeded(param->Type());
    if (param_type_id == 0) {
      return false;
    }

    push_debug(spv::Op::OpName, {Operand::Int(param_id),
                                 Operand::String(builder_.Symbols().NameFor(
                                     param->Declaration()->symbol))});
    params.push_back(Instruction{spv::Op::OpFunctionParameter,
                                 {Operand::Int(param_type_id), param_op}});

    // Make the parameter id resolvable from the function body.
    scope_stack_.Set(param->Declaration()->symbol, param_id);
  }

  // The extra result_op() is the id of the function's entry label.
  push_function(Function{definition_inst, result_op(), std::move(params)});

  for (auto* stmt : func_ast->body->statements) {
    if (!GenerateStatement(stmt)) {
      return false;
    }
  }

  // SPIR-V blocks must end with a terminator; synthesize a return if the
  // body does not already end with one.
  if (!LastIsTerminator(func_ast->body)) {
    if (func->ReturnType()->Is<sem::Void>()) {
      push_function_inst(spv::Op::OpReturn, {});
    } else {
      // Non-void: return the zero value of the return type.
      auto zero = GenerateConstantNullIfNeeded(func->ReturnType());
      push_function_inst(spv::Op::OpReturnValue, {Operand::Int(zero)});
    }
  }

  if (func_ast->IsEntryPoint()) {
    if (!GenerateEntryPoint(func_ast, func_id)) {
      return false;
    }
    if (!GenerateExecutionModes(func_ast, func_id)) {
      return false;
    }
  }

  func_symbol_to_id_[func_ast->symbol] = func_id;

  return true;
}
678 
/// Returns the id of the OpTypeFunction for `func`, generating it on first
/// use. Function types are de-duplicated by signature via func_sig_to_id_.
/// @param func the semantic function
/// @returns the function type id, or 0 on error
uint32_t Builder::GenerateFunctionTypeIfNeeded(const sem::Function* func) {
  return utils::GetOrCreate(
      func_sig_to_id_, func->Signature(), [&]() -> uint32_t {
        auto func_op = result_op();
        auto func_type_id = func_op.to_i();

        auto ret_id = GenerateTypeIfNeeded(func->ReturnType());
        if (ret_id == 0) {
          // 0 is cached for this signature; callers treat it as failure.
          return 0;
        }

        OperandList ops = {func_op, Operand::Int(ret_id)};
        for (auto* param : func->Parameters()) {
          auto param_type_id = GenerateTypeIfNeeded(param->Type());
          if (param_type_id == 0) {
            return 0;
          }
          ops.push_back(Operand::Int(param_type_id));
        }

        push_type(spv::Op::OpTypeFunction, std::move(ops));
        return func_type_id;
      });
}
703 
GenerateFunctionVariable(const ast::Variable * var)704 bool Builder::GenerateFunctionVariable(const ast::Variable* var) {
705   uint32_t init_id = 0;
706   if (var->constructor) {
707     init_id = GenerateExpression(var->constructor);
708     if (init_id == 0) {
709       return false;
710     }
711     auto* type = TypeOf(var->constructor);
712     if (type->Is<sem::Reference>()) {
713       init_id = GenerateLoadIfNeeded(type, init_id);
714     }
715   }
716 
717   if (var->is_const) {
718     if (!var->constructor) {
719       error_ = "missing constructor for constant";
720       return false;
721     }
722     scope_stack_.Set(var->symbol, init_id);
723     spirv_id_to_variable_[init_id] = var;
724     return true;
725   }
726 
727   auto result = result_op();
728   auto var_id = result.to_i();
729   auto sc = ast::StorageClass::kFunction;
730   auto* type = builder_.Sem().Get(var)->Type();
731   auto type_id = GenerateTypeIfNeeded(type);
732   if (type_id == 0) {
733     return false;
734   }
735 
736   push_debug(spv::Op::OpName,
737              {Operand::Int(var_id),
738               Operand::String(builder_.Symbols().NameFor(var->symbol))});
739 
740   // TODO(dsinclair) We could detect if the constructor is fully const and emit
741   // an initializer value for the variable instead of doing the OpLoad.
742   auto null_id = GenerateConstantNullIfNeeded(type->UnwrapRef());
743   if (null_id == 0) {
744     return 0;
745   }
746   push_function_var({Operand::Int(type_id), result,
747                      Operand::Int(ConvertStorageClass(sc)),
748                      Operand::Int(null_id)});
749 
750   if (var->constructor) {
751     if (!GenerateStore(var_id, init_id)) {
752       return false;
753     }
754   }
755 
756   scope_stack_.Set(var->symbol, var_id);
757   spirv_id_to_variable_[var_id] = var;
758 
759   return true;
760 }
761 
/// Emits an OpStore writing value id `from` into pointer id `to`.
/// @returns true if the instruction was emitted
bool Builder::GenerateStore(uint32_t to, uint32_t from) {
  return push_function_inst(spv::Op::OpStore,
                            {Operand::Int(to), Operand::Int(from)});
}
766 
GenerateGlobalVariable(const ast::Variable * var)767 bool Builder::GenerateGlobalVariable(const ast::Variable* var) {
768   auto* sem = builder_.Sem().Get(var);
769   auto* type = sem->Type()->UnwrapRef();
770 
771   uint32_t init_id = 0;
772   if (var->constructor) {
773     init_id = GenerateConstructorExpression(var, var->constructor);
774     if (init_id == 0) {
775       return false;
776     }
777   }
778 
779   if (var->is_const) {
780     if (!var->constructor) {
781       // Constants must have an initializer unless they have an override
782       // decoration.
783       if (!ast::HasDecoration<ast::OverrideDecoration>(var->decorations)) {
784         error_ = "missing constructor for constant";
785         return false;
786       }
787 
788       // SPIR-V requires specialization constants to have initializers.
789       if (type->Is<sem::F32>()) {
790         ast::FloatLiteralExpression l(ProgramID(), Source{}, 0.0f);
791         init_id = GenerateLiteralIfNeeded(var, &l);
792       } else if (type->Is<sem::U32>()) {
793         ast::UintLiteralExpression l(ProgramID(), Source{}, 0);
794         init_id = GenerateLiteralIfNeeded(var, &l);
795       } else if (type->Is<sem::I32>()) {
796         ast::SintLiteralExpression l(ProgramID(), Source{}, 0);
797         init_id = GenerateLiteralIfNeeded(var, &l);
798       } else if (type->Is<sem::Bool>()) {
799         ast::BoolLiteralExpression l(ProgramID(), Source{}, false);
800         init_id = GenerateLiteralIfNeeded(var, &l);
801       } else {
802         error_ = "invalid type for pipeline constant ID, must be scalar";
803         return false;
804       }
805       if (init_id == 0) {
806         return 0;
807       }
808     }
809     push_debug(spv::Op::OpName,
810                {Operand::Int(init_id),
811                 Operand::String(builder_.Symbols().NameFor(var->symbol))});
812 
813     scope_stack_.Set(var->symbol, init_id);
814     spirv_id_to_variable_[init_id] = var;
815     return true;
816   }
817 
818   auto result = result_op();
819   auto var_id = result.to_i();
820 
821   auto sc = sem->StorageClass() == ast::StorageClass::kNone
822                 ? ast::StorageClass::kPrivate
823                 : sem->StorageClass();
824 
825   auto type_id = GenerateTypeIfNeeded(sem->Type());
826   if (type_id == 0) {
827     return false;
828   }
829 
830   push_debug(spv::Op::OpName,
831              {Operand::Int(var_id),
832               Operand::String(builder_.Symbols().NameFor(var->symbol))});
833 
834   OperandList ops = {Operand::Int(type_id), result,
835                      Operand::Int(ConvertStorageClass(sc))};
836 
837   if (var->constructor) {
838     ops.push_back(Operand::Int(init_id));
839   } else {
840     auto* st = type->As<sem::StorageTexture>();
841     if (st || type->Is<sem::Struct>()) {
842       // type is a sem::Struct or a sem::StorageTexture
843       auto access = st ? st->access() : sem->Access();
844       switch (access) {
845         case ast::Access::kWrite:
846           push_annot(
847               spv::Op::OpDecorate,
848               {Operand::Int(var_id), Operand::Int(SpvDecorationNonReadable)});
849           break;
850         case ast::Access::kRead:
851           push_annot(
852               spv::Op::OpDecorate,
853               {Operand::Int(var_id), Operand::Int(SpvDecorationNonWritable)});
854           break;
855         case ast::Access::kUndefined:
856         case ast::Access::kReadWrite:
857           break;
858       }
859     }
860     if (!type->Is<sem::Sampler>()) {
861       // If we don't have a constructor and we're an Output or Private
862       // variable, then WGSL requires that we zero-initialize.
863       if (sem->StorageClass() == ast::StorageClass::kPrivate ||
864           sem->StorageClass() == ast::StorageClass::kOutput) {
865         init_id = GenerateConstantNullIfNeeded(type);
866         if (init_id == 0) {
867           return 0;
868         }
869         ops.push_back(Operand::Int(init_id));
870       }
871     }
872   }
873 
874   push_type(spv::Op::OpVariable, std::move(ops));
875 
876   for (auto* deco : var->decorations) {
877     if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
878       push_annot(spv::Op::OpDecorate,
879                  {Operand::Int(var_id), Operand::Int(SpvDecorationBuiltIn),
880                   Operand::Int(
881                       ConvertBuiltin(builtin->builtin, sem->StorageClass()))});
882     } else if (auto* location = deco->As<ast::LocationDecoration>()) {
883       push_annot(spv::Op::OpDecorate,
884                  {Operand::Int(var_id), Operand::Int(SpvDecorationLocation),
885                   Operand::Int(location->value)});
886     } else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
887       AddInterpolationDecorations(var_id, interpolate->type,
888                                   interpolate->sampling);
889     } else if (deco->Is<ast::InvariantDecoration>()) {
890       push_annot(spv::Op::OpDecorate,
891                  {Operand::Int(var_id), Operand::Int(SpvDecorationInvariant)});
892     } else if (auto* binding = deco->As<ast::BindingDecoration>()) {
893       push_annot(spv::Op::OpDecorate,
894                  {Operand::Int(var_id), Operand::Int(SpvDecorationBinding),
895                   Operand::Int(binding->value)});
896     } else if (auto* group = deco->As<ast::GroupDecoration>()) {
897       push_annot(spv::Op::OpDecorate, {Operand::Int(var_id),
898                                        Operand::Int(SpvDecorationDescriptorSet),
899                                        Operand::Int(group->value)});
900     } else if (deco->Is<ast::OverrideDecoration>()) {
901       // Spec constants are handled elsewhere
902     } else if (!deco->Is<ast::InternalDecoration>()) {
903       error_ = "unknown decoration";
904       return false;
905     }
906   }
907 
908   scope_stack_.Set(var->symbol, var_id);
909   spirv_id_to_variable_[var_id] = var;
910   return true;
911 }
912 
GenerateIndexAccessor(const ast::IndexAccessorExpression * expr,AccessorInfo * info)913 bool Builder::GenerateIndexAccessor(const ast::IndexAccessorExpression* expr,
914                                     AccessorInfo* info) {
915   auto idx_id = GenerateExpression(expr->index);
916   if (idx_id == 0) {
917     return 0;
918   }
919   auto* type = TypeOf(expr->index);
920   idx_id = GenerateLoadIfNeeded(type, idx_id);
921 
922   // If the source is a reference, we access chain into it.
923   // In the future, pointers may support access-chaining.
924   // See https://github.com/gpuweb/gpuweb/pull/1580
925   if (info->source_type->Is<sem::Reference>()) {
926     info->access_chain_indices.push_back(idx_id);
927     info->source_type = TypeOf(expr);
928     return true;
929   }
930 
931   auto result_type_id = GenerateTypeIfNeeded(TypeOf(expr));
932   if (result_type_id == 0) {
933     return false;
934   }
935 
936   // We don't have a pointer, so we can just directly extract the value.
937   auto extract = result_op();
938   auto extract_id = extract.to_i();
939 
940   // If the index is a literal, we use OpCompositeExtract.
941   if (auto* literal = expr->index->As<ast::IntLiteralExpression>()) {
942     if (!push_function_inst(spv::Op::OpCompositeExtract,
943                             {Operand::Int(result_type_id), extract,
944                              Operand::Int(info->source_id),
945                              Operand::Int(literal->ValueAsU32())})) {
946       return false;
947     }
948 
949     info->source_id = extract_id;
950     info->source_type = TypeOf(expr);
951 
952     return true;
953   }
954 
955   // If the source is a vector, we use OpVectorExtractDynamic.
956   if (info->source_type->Is<sem::Vector>()) {
957     if (!push_function_inst(
958             spv::Op::OpVectorExtractDynamic,
959             {Operand::Int(result_type_id), extract,
960              Operand::Int(info->source_id), Operand::Int(idx_id)})) {
961       return false;
962     }
963 
964     info->source_id = extract_id;
965     info->source_type = TypeOf(expr);
966 
967     return true;
968   }
969 
970   TINT_ICE(Writer, builder_.Diagnostics())
971       << "unsupported index accessor expression";
972   return false;
973 }
974 
GenerateMemberAccessor(const ast::MemberAccessorExpression * expr,AccessorInfo * info)975 bool Builder::GenerateMemberAccessor(const ast::MemberAccessorExpression* expr,
976                                      AccessorInfo* info) {
977   auto* expr_sem = builder_.Sem().Get(expr);
978   auto* expr_type = expr_sem->Type();
979 
980   if (auto* access = expr_sem->As<sem::StructMemberAccess>()) {
981     uint32_t idx = access->Member()->Index();
982 
983     if (info->source_type->Is<sem::Reference>()) {
984       auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(idx));
985       if (idx_id == 0) {
986         return 0;
987       }
988       info->access_chain_indices.push_back(idx_id);
989       info->source_type = expr_type;
990     } else {
991       auto result_type_id = GenerateTypeIfNeeded(expr_type);
992       if (result_type_id == 0) {
993         return false;
994       }
995 
996       auto extract = result_op();
997       auto extract_id = extract.to_i();
998       if (!push_function_inst(
999               spv::Op::OpCompositeExtract,
1000               {Operand::Int(result_type_id), extract,
1001                Operand::Int(info->source_id), Operand::Int(idx)})) {
1002         return false;
1003       }
1004 
1005       info->source_id = extract_id;
1006       info->source_type = expr_type;
1007     }
1008 
1009     return true;
1010   }
1011 
1012   if (auto* swizzle = expr_sem->As<sem::Swizzle>()) {
1013     // Single element swizzle is either an access chain or a composite extract
1014     auto& indices = swizzle->Indices();
1015     if (indices.size() == 1) {
1016       if (info->source_type->Is<sem::Reference>()) {
1017         auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(indices[0]));
1018         if (idx_id == 0) {
1019           return 0;
1020         }
1021         info->access_chain_indices.push_back(idx_id);
1022       } else {
1023         auto result_type_id = GenerateTypeIfNeeded(expr_type);
1024         if (result_type_id == 0) {
1025           return 0;
1026         }
1027 
1028         auto extract = result_op();
1029         auto extract_id = extract.to_i();
1030         if (!push_function_inst(
1031                 spv::Op::OpCompositeExtract,
1032                 {Operand::Int(result_type_id), extract,
1033                  Operand::Int(info->source_id), Operand::Int(indices[0])})) {
1034           return false;
1035         }
1036 
1037         info->source_id = extract_id;
1038         info->source_type = expr_type;
1039       }
1040       return true;
1041     }
1042 
1043     // Store the type away as it may change if we run the access chain
1044     auto* incoming_type = info->source_type;
1045 
1046     // Multi-item extract is a VectorShuffle. We have to emit any existing
1047     // access chain data, then load the access chain and shuffle that.
1048     if (!info->access_chain_indices.empty()) {
1049       auto result_type_id = GenerateTypeIfNeeded(info->source_type);
1050       if (result_type_id == 0) {
1051         return 0;
1052       }
1053       auto extract = result_op();
1054       auto extract_id = extract.to_i();
1055 
1056       OperandList ops = {Operand::Int(result_type_id), extract,
1057                          Operand::Int(info->source_id)};
1058       for (auto id : info->access_chain_indices) {
1059         ops.push_back(Operand::Int(id));
1060       }
1061 
1062       if (!push_function_inst(spv::Op::OpAccessChain, ops)) {
1063         return false;
1064       }
1065 
1066       info->source_id = GenerateLoadIfNeeded(expr_type, extract_id);
1067       info->source_type = expr_type->UnwrapRef();
1068       info->access_chain_indices.clear();
1069     }
1070 
1071     auto result_type_id = GenerateTypeIfNeeded(expr_type);
1072     if (result_type_id == 0) {
1073       return false;
1074     }
1075 
1076     auto vec_id = GenerateLoadIfNeeded(incoming_type, info->source_id);
1077 
1078     auto result = result_op();
1079     auto result_id = result.to_i();
1080 
1081     OperandList ops = {Operand::Int(result_type_id), result,
1082                        Operand::Int(vec_id), Operand::Int(vec_id)};
1083 
1084     for (auto idx : indices) {
1085       ops.push_back(Operand::Int(idx));
1086     }
1087 
1088     if (!push_function_inst(spv::Op::OpVectorShuffle, ops)) {
1089       return false;
1090     }
1091     info->source_id = result_id;
1092     info->source_type = expr_type;
1093     return true;
1094   }
1095 
1096   TINT_ICE(Writer, builder_.Diagnostics())
1097       << "unhandled member index type: " << expr_sem->TypeInfo().name;
1098   return false;
1099 }
1100 
GenerateAccessorExpression(const ast::Expression * expr)1101 uint32_t Builder::GenerateAccessorExpression(const ast::Expression* expr) {
1102   if (!expr->IsAnyOf<ast::IndexAccessorExpression,
1103                      ast::MemberAccessorExpression>()) {
1104     TINT_ICE(Writer, builder_.Diagnostics()) << "expression is not an accessor";
1105     return 0;
1106   }
1107 
1108   // Gather a list of all the member and index accessors that are in this chain.
1109   // The list is built in reverse order as that's the order we need to access
1110   // the chain.
1111   std::vector<const ast::Expression*> accessors;
1112   const ast::Expression* source = expr;
1113   while (true) {
1114     if (auto* array = source->As<ast::IndexAccessorExpression>()) {
1115       accessors.insert(accessors.begin(), source);
1116       source = array->object;
1117     } else if (auto* member = source->As<ast::MemberAccessorExpression>()) {
1118       accessors.insert(accessors.begin(), source);
1119       source = member->structure;
1120     } else {
1121       break;
1122     }
1123   }
1124 
1125   AccessorInfo info;
1126   info.source_id = GenerateExpression(source);
1127   if (info.source_id == 0) {
1128     return 0;
1129   }
1130   info.source_type = TypeOf(source);
1131 
1132   for (auto* accessor : accessors) {
1133     if (auto* array = accessor->As<ast::IndexAccessorExpression>()) {
1134       if (!GenerateIndexAccessor(array, &info)) {
1135         return 0;
1136       }
1137     } else if (auto* member = accessor->As<ast::MemberAccessorExpression>()) {
1138       if (!GenerateMemberAccessor(member, &info)) {
1139         return 0;
1140       }
1141 
1142     } else {
1143       error_ =
1144           "invalid accessor in list: " + std::string(accessor->TypeInfo().name);
1145       return 0;
1146     }
1147   }
1148 
1149   if (!info.access_chain_indices.empty()) {
1150     auto* type = TypeOf(expr);
1151     auto result_type_id = GenerateTypeIfNeeded(type);
1152     if (result_type_id == 0) {
1153       return 0;
1154     }
1155 
1156     auto result = result_op();
1157     auto result_id = result.to_i();
1158 
1159     OperandList ops = {Operand::Int(result_type_id), result,
1160                        Operand::Int(info.source_id)};
1161     for (auto id : info.access_chain_indices) {
1162       ops.push_back(Operand::Int(id));
1163     }
1164 
1165     if (!push_function_inst(spv::Op::OpAccessChain, ops)) {
1166       return false;
1167     }
1168     info.source_id = result_id;
1169   }
1170 
1171   return info.source_id;
1172 }
1173 
GenerateIdentifierExpression(const ast::IdentifierExpression * expr)1174 uint32_t Builder::GenerateIdentifierExpression(
1175     const ast::IdentifierExpression* expr) {
1176   uint32_t val = scope_stack_.Get(expr->symbol);
1177   if (val == 0) {
1178     error_ = "unable to find variable with identifier: " +
1179              builder_.Symbols().NameFor(expr->symbol);
1180   }
1181   return val;
1182 }
1183 
GenerateLoadIfNeeded(const sem::Type * type,uint32_t id)1184 uint32_t Builder::GenerateLoadIfNeeded(const sem::Type* type, uint32_t id) {
1185   if (auto* ref = type->As<sem::Reference>()) {
1186     type = ref->StoreType();
1187   } else {
1188     return id;
1189   }
1190 
1191   auto type_id = GenerateTypeIfNeeded(type);
1192   auto result = result_op();
1193   auto result_id = result.to_i();
1194   if (!push_function_inst(spv::Op::OpLoad,
1195                           {Operand::Int(type_id), result, Operand::Int(id)})) {
1196     return 0;
1197   }
1198   return result_id;
1199 }
1200 
GenerateUnaryOpExpression(const ast::UnaryOpExpression * expr)1201 uint32_t Builder::GenerateUnaryOpExpression(
1202     const ast::UnaryOpExpression* expr) {
1203   auto result = result_op();
1204   auto result_id = result.to_i();
1205 
1206   auto val_id = GenerateExpression(expr->expr);
1207   if (val_id == 0) {
1208     return 0;
1209   }
1210 
1211   spv::Op op = spv::Op::OpNop;
1212   switch (expr->op) {
1213     case ast::UnaryOp::kComplement:
1214       op = spv::Op::OpNot;
1215       break;
1216     case ast::UnaryOp::kNegation:
1217       if (TypeOf(expr)->is_float_scalar_or_vector()) {
1218         op = spv::Op::OpFNegate;
1219       } else {
1220         op = spv::Op::OpSNegate;
1221       }
1222       break;
1223     case ast::UnaryOp::kNot:
1224       op = spv::Op::OpLogicalNot;
1225       break;
1226     case ast::UnaryOp::kAddressOf:
1227     case ast::UnaryOp::kIndirection:
1228       // Address-of converts a reference to a pointer, and dereference converts
1229       // a pointer to a reference. These are the same thing in SPIR-V, so this
1230       // is a no-op.
1231       return val_id;
1232   }
1233 
1234   val_id = GenerateLoadIfNeeded(TypeOf(expr->expr), val_id);
1235 
1236   auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
1237   if (type_id == 0) {
1238     return 0;
1239   }
1240 
1241   if (!push_function_inst(
1242           op, {Operand::Int(type_id), result, Operand::Int(val_id)})) {
1243     return false;
1244   }
1245 
1246   return result_id;
1247 }
1248 
GetGLSLstd450Import()1249 uint32_t Builder::GetGLSLstd450Import() {
1250   auto where = import_name_to_id_.find(kGLSLstd450);
1251   if (where != import_name_to_id_.end()) {
1252     return where->second;
1253   }
1254 
1255   // It doesn't exist yet. Generate it.
1256   auto result = result_op();
1257   auto id = result.to_i();
1258 
1259   push_ext_import(spv::Op::OpExtInstImport,
1260                   {result, Operand::String(kGLSLstd450)});
1261 
1262   // Remember it for later.
1263   import_name_to_id_[kGLSLstd450] = id;
1264   return id;
1265 }
1266 
GenerateConstructorExpression(const ast::Variable * var,const ast::Expression * expr)1267 uint32_t Builder::GenerateConstructorExpression(const ast::Variable* var,
1268                                                 const ast::Expression* expr) {
1269   if (auto* literal = expr->As<ast::LiteralExpression>()) {
1270     return GenerateLiteralIfNeeded(var, literal);
1271   }
1272   if (auto* call = builder_.Sem().Get<sem::Call>(expr)) {
1273     if (call->Target()->IsAnyOf<sem::TypeConstructor, sem::TypeConversion>()) {
1274       return GenerateTypeConstructorOrConversion(call, var);
1275     }
1276   }
1277   error_ = "unknown constructor expression";
1278   return 0;
1279 }
1280 
IsConstructorConst(const ast::Expression * expr)1281 bool Builder::IsConstructorConst(const ast::Expression* expr) {
1282   bool is_const = true;
1283   ast::TraverseExpressions(expr, builder_.Diagnostics(),
1284                            [&](const ast::Expression* e) {
1285                              if (e->Is<ast::LiteralExpression>()) {
1286                                return ast::TraverseAction::Descend;
1287                              }
1288                              if (auto* ce = e->As<ast::CallExpression>()) {
1289                                auto* call = builder_.Sem().Get(ce);
1290                                if (call->Target()->Is<sem::TypeConstructor>()) {
1291                                  return ast::TraverseAction::Descend;
1292                                }
1293                              }
1294 
1295                              is_const = false;
1296                              return ast::TraverseAction::Stop;
1297                            });
1298   return is_const;
1299 }
1300 
GenerateTypeConstructorOrConversion(const sem::Call * call,const ast::Variable * var)1301 uint32_t Builder::GenerateTypeConstructorOrConversion(
1302     const sem::Call* call,
1303     const ast::Variable* var) {
1304   auto& args = call->Arguments();
1305   auto* global_var = builder_.Sem().Get<sem::GlobalVariable>(var);
1306   auto* result_type = call->Type();
1307 
1308   // Generate the zero initializer if there are no values provided.
1309   if (args.empty()) {
1310     if (global_var && global_var->IsOverridable()) {
1311       auto constant_id = global_var->ConstantId();
1312       if (result_type->Is<sem::I32>()) {
1313         return GenerateConstantIfNeeded(
1314             ScalarConstant::I32(0).AsSpecOp(constant_id));
1315       }
1316       if (result_type->Is<sem::U32>()) {
1317         return GenerateConstantIfNeeded(
1318             ScalarConstant::U32(0).AsSpecOp(constant_id));
1319       }
1320       if (result_type->Is<sem::F32>()) {
1321         return GenerateConstantIfNeeded(
1322             ScalarConstant::F32(0).AsSpecOp(constant_id));
1323       }
1324       if (result_type->Is<sem::Bool>()) {
1325         return GenerateConstantIfNeeded(
1326             ScalarConstant::Bool(false).AsSpecOp(constant_id));
1327       }
1328     }
1329     return GenerateConstantNullIfNeeded(result_type->UnwrapRef());
1330   }
1331 
1332   std::ostringstream out;
1333   out << "__const_" << result_type->FriendlyName(builder_.Symbols()) << "_";
1334 
1335   result_type = result_type->UnwrapRef();
1336   bool constructor_is_const = IsConstructorConst(call->Declaration());
1337   if (has_error()) {
1338     return 0;
1339   }
1340 
1341   bool can_cast_or_copy = result_type->is_scalar();
1342 
1343   if (auto* res_vec = result_type->As<sem::Vector>()) {
1344     if (res_vec->type()->is_scalar()) {
1345       auto* value_type = args[0]->Type()->UnwrapRef();
1346       if (auto* val_vec = value_type->As<sem::Vector>()) {
1347         if (val_vec->type()->is_scalar()) {
1348           can_cast_or_copy = res_vec->Width() == val_vec->Width();
1349         }
1350       }
1351     }
1352   }
1353 
1354   if (can_cast_or_copy) {
1355     return GenerateCastOrCopyOrPassthrough(result_type, args[0]->Declaration(),
1356                                            global_var);
1357   }
1358 
1359   auto type_id = GenerateTypeIfNeeded(result_type);
1360   if (type_id == 0) {
1361     return 0;
1362   }
1363 
1364   bool result_is_constant_composite = constructor_is_const;
1365   bool result_is_spec_composite = false;
1366 
1367   if (auto* vec = result_type->As<sem::Vector>()) {
1368     result_type = vec->type();
1369   }
1370 
1371   OperandList ops;
1372   for (auto* e : args) {
1373     uint32_t id = 0;
1374     id = GenerateExpression(e->Declaration());
1375     if (id == 0) {
1376       return 0;
1377     }
1378     id = GenerateLoadIfNeeded(e->Type(), id);
1379     if (id == 0) {
1380       return 0;
1381     }
1382 
1383     auto* value_type = e->Type()->UnwrapRef();
1384     // If the result and value types are the same we can just use the object.
1385     // If the result is not a vector then we should have validated that the
1386     // value type is a correctly sized vector so we can just use it directly.
1387     if (result_type == value_type || result_type->Is<sem::Matrix>() ||
1388         result_type->Is<sem::Array>() || result_type->Is<sem::Struct>()) {
1389       out << "_" << id;
1390 
1391       ops.push_back(Operand::Int(id));
1392       continue;
1393     }
1394 
1395     // Both scalars, but not the same type so we need to generate a conversion
1396     // of the value.
1397     if (value_type->is_scalar() && result_type->is_scalar()) {
1398       id = GenerateCastOrCopyOrPassthrough(result_type, args[0]->Declaration(),
1399                                            global_var);
1400       out << "_" << id;
1401       ops.push_back(Operand::Int(id));
1402       continue;
1403     }
1404 
1405     // When handling vectors as the values there a few cases to take into
1406     // consideration:
1407     //  1. Module scoped vec3<f32>(vec2<f32>(1, 2), 3)  -> OpSpecConstantOp
1408     //  2. Function scoped vec3<f32>(vec2<f32>(1, 2), 3) ->  OpCompositeExtract
1409     //  3. Either array<vec3<f32>, 1>(vec3<f32>(1, 2, 3))  -> use the ID.
1410     //       -> handled above
1411     //
1412     // For cases 1 and 2, if the type is different we also may need to insert
1413     // a type cast.
1414     if (auto* vec = value_type->As<sem::Vector>()) {
1415       auto* vec_type = vec->type();
1416 
1417       auto value_type_id = GenerateTypeIfNeeded(vec_type);
1418       if (value_type_id == 0) {
1419         return 0;
1420       }
1421 
1422       for (uint32_t i = 0; i < vec->Width(); ++i) {
1423         auto extract = result_op();
1424         auto extract_id = extract.to_i();
1425 
1426         if (!global_var) {
1427           // A non-global initializer. Case 2.
1428           if (!push_function_inst(spv::Op::OpCompositeExtract,
1429                                   {Operand::Int(value_type_id), extract,
1430                                    Operand::Int(id), Operand::Int(i)})) {
1431             return false;
1432           }
1433 
1434           // We no longer have a constant composite, but have to do a
1435           // composite construction as these calls are inside a function.
1436           result_is_constant_composite = false;
1437         } else {
1438           // A global initializer, must use OpSpecConstantOp. Case 1.
1439           auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(i));
1440           if (idx_id == 0) {
1441             return 0;
1442           }
1443           push_type(spv::Op::OpSpecConstantOp,
1444                     {Operand::Int(value_type_id), extract,
1445                      Operand::Int(SpvOpCompositeExtract), Operand::Int(id),
1446                      Operand::Int(idx_id)});
1447 
1448           result_is_spec_composite = true;
1449         }
1450 
1451         out << "_" << extract_id;
1452         ops.push_back(Operand::Int(extract_id));
1453       }
1454     } else {
1455       error_ = "Unhandled type cast value type";
1456       return 0;
1457     }
1458   }
1459 
1460   // For a single-value vector initializer, splat the initializer value.
1461   auto* const init_result_type = call->Type()->UnwrapRef();
1462   if (args.size() == 1 && init_result_type->is_scalar_vector() &&
1463       args[0]->Type()->UnwrapRef()->is_scalar()) {
1464     size_t vec_size = init_result_type->As<sem::Vector>()->Width();
1465     for (size_t i = 0; i < (vec_size - 1); ++i) {
1466       ops.push_back(ops[0]);
1467     }
1468   }
1469 
1470   auto str = out.str();
1471   auto val = type_constructor_to_id_.find(str);
1472   if (val != type_constructor_to_id_.end()) {
1473     return val->second;
1474   }
1475 
1476   auto result = result_op();
1477   ops.insert(ops.begin(), result);
1478   ops.insert(ops.begin(), Operand::Int(type_id));
1479 
1480   type_constructor_to_id_[str] = result.to_i();
1481 
1482   if (result_is_spec_composite) {
1483     push_type(spv::Op::OpSpecConstantComposite, ops);
1484   } else if (result_is_constant_composite) {
1485     push_type(spv::Op::OpConstantComposite, ops);
1486   } else {
1487     if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
1488       return 0;
1489     }
1490   }
1491 
1492   return result.to_i();
1493 }
1494 
GenerateCastOrCopyOrPassthrough(const sem::Type * to_type,const ast::Expression * from_expr,bool is_global_init)1495 uint32_t Builder::GenerateCastOrCopyOrPassthrough(
1496     const sem::Type* to_type,
1497     const ast::Expression* from_expr,
1498     bool is_global_init) {
1499   // This should not happen as we rely on constant folding to obviate
1500   // casts/conversions for module-scope variables
1501   if (is_global_init) {
1502     TINT_ICE(Writer, builder_.Diagnostics())
1503         << "Module-level conversions are not supported. Conversions should "
1504            "have already been constant-folded by the FoldConstants transform.";
1505     return 0;
1506   }
1507 
1508   auto elem_type_of = [](const sem::Type* t) -> const sem::Type* {
1509     if (t->is_scalar()) {
1510       return t;
1511     }
1512     if (auto* v = t->As<sem::Vector>()) {
1513       return v->type();
1514     }
1515     return nullptr;
1516   };
1517 
1518   auto result = result_op();
1519   auto result_id = result.to_i();
1520 
1521   auto result_type_id = GenerateTypeIfNeeded(to_type);
1522   if (result_type_id == 0) {
1523     return 0;
1524   }
1525 
1526   auto val_id = GenerateExpression(from_expr);
1527   if (val_id == 0) {
1528     return 0;
1529   }
1530   val_id = GenerateLoadIfNeeded(TypeOf(from_expr), val_id);
1531 
1532   auto* from_type = TypeOf(from_expr)->UnwrapRef();
1533 
1534   spv::Op op = spv::Op::OpNop;
1535   if ((from_type->Is<sem::I32>() && to_type->Is<sem::F32>()) ||
1536       (from_type->is_signed_integer_vector() && to_type->is_float_vector())) {
1537     op = spv::Op::OpConvertSToF;
1538   } else if ((from_type->Is<sem::U32>() && to_type->Is<sem::F32>()) ||
1539              (from_type->is_unsigned_integer_vector() &&
1540               to_type->is_float_vector())) {
1541     op = spv::Op::OpConvertUToF;
1542   } else if ((from_type->Is<sem::F32>() && to_type->Is<sem::I32>()) ||
1543              (from_type->is_float_vector() &&
1544               to_type->is_signed_integer_vector())) {
1545     op = spv::Op::OpConvertFToS;
1546   } else if ((from_type->Is<sem::F32>() && to_type->Is<sem::U32>()) ||
1547              (from_type->is_float_vector() &&
1548               to_type->is_unsigned_integer_vector())) {
1549     op = spv::Op::OpConvertFToU;
1550   } else if ((from_type->Is<sem::Bool>() && to_type->Is<sem::Bool>()) ||
1551              (from_type->Is<sem::U32>() && to_type->Is<sem::U32>()) ||
1552              (from_type->Is<sem::I32>() && to_type->Is<sem::I32>()) ||
1553              (from_type->Is<sem::F32>() && to_type->Is<sem::F32>()) ||
1554              (from_type->Is<sem::Vector>() && (from_type == to_type))) {
1555     return val_id;
1556   } else if ((from_type->Is<sem::I32>() && to_type->Is<sem::U32>()) ||
1557              (from_type->Is<sem::U32>() && to_type->Is<sem::I32>()) ||
1558              (from_type->is_signed_integer_vector() &&
1559               to_type->is_unsigned_integer_vector()) ||
1560              (from_type->is_unsigned_integer_vector() &&
1561               to_type->is_integer_scalar_or_vector())) {
1562     op = spv::Op::OpBitcast;
1563   } else if ((from_type->is_numeric_scalar() && to_type->Is<sem::Bool>()) ||
1564              (from_type->is_numeric_vector() && to_type->is_bool_vector())) {
1565     // Convert scalar (vector) to bool (vector)
1566 
1567     // Return the result of comparing from_expr with zero
1568     uint32_t zero = GenerateConstantNullIfNeeded(from_type);
1569     const auto* from_elem_type = elem_type_of(from_type);
1570     op = from_elem_type->is_integer_scalar() ? spv::Op::OpINotEqual
1571                                              : spv::Op::OpFUnordNotEqual;
1572     if (!push_function_inst(
1573             op, {Operand::Int(result_type_id), Operand::Int(result_id),
1574                  Operand::Int(val_id), Operand::Int(zero)})) {
1575       return 0;
1576     }
1577 
1578     return result_id;
1579   } else if (from_type->is_bool_scalar_or_vector() &&
1580              to_type->is_numeric_scalar_or_vector()) {
1581     // Convert bool scalar/vector to numeric scalar/vector.
1582     // Use the bool to select between 1 (if true) and 0 (if false).
1583 
1584     const auto* to_elem_type = elem_type_of(to_type);
1585     uint32_t one_id;
1586     uint32_t zero_id;
1587     if (to_elem_type->Is<sem::F32>()) {
1588       ast::FloatLiteralExpression one(ProgramID(), Source{}, 1.0f);
1589       ast::FloatLiteralExpression zero(ProgramID(), Source{}, 0.0f);
1590       one_id = GenerateLiteralIfNeeded(nullptr, &one);
1591       zero_id = GenerateLiteralIfNeeded(nullptr, &zero);
1592     } else if (to_elem_type->Is<sem::U32>()) {
1593       ast::UintLiteralExpression one(ProgramID(), Source{}, 1);
1594       ast::UintLiteralExpression zero(ProgramID(), Source{}, 0);
1595       one_id = GenerateLiteralIfNeeded(nullptr, &one);
1596       zero_id = GenerateLiteralIfNeeded(nullptr, &zero);
1597     } else if (to_elem_type->Is<sem::I32>()) {
1598       ast::SintLiteralExpression one(ProgramID(), Source{}, 1);
1599       ast::SintLiteralExpression zero(ProgramID(), Source{}, 0);
1600       one_id = GenerateLiteralIfNeeded(nullptr, &one);
1601       zero_id = GenerateLiteralIfNeeded(nullptr, &zero);
1602     } else {
1603       error_ = "invalid destination type for bool conversion";
1604       return false;
1605     }
1606     if (auto* to_vec = to_type->As<sem::Vector>()) {
1607       // Splat the scalars into vectors.
1608       one_id = GenerateConstantVectorSplatIfNeeded(to_vec, one_id);
1609       zero_id = GenerateConstantVectorSplatIfNeeded(to_vec, zero_id);
1610     }
1611     if (!one_id || !zero_id) {
1612       return false;
1613     }
1614 
1615     op = spv::Op::OpSelect;
1616     if (!push_function_inst(
1617             op, {Operand::Int(result_type_id), Operand::Int(result_id),
1618                  Operand::Int(val_id), Operand::Int(one_id),
1619                  Operand::Int(zero_id)})) {
1620       return 0;
1621     }
1622 
1623     return result_id;
1624   } else {
1625     TINT_ICE(Writer, builder_.Diagnostics()) << "Invalid from_type";
1626   }
1627 
1628   if (op == spv::Op::OpNop) {
1629     error_ = "unable to determine conversion type for cast, from: " +
1630              from_type->type_name() + " to: " + to_type->type_name();
1631     return 0;
1632   }
1633 
1634   if (!push_function_inst(
1635           op, {Operand::Int(result_type_id), result, Operand::Int(val_id)})) {
1636     return 0;
1637   }
1638 
1639   return result_id;
1640 }
1641 
GenerateLiteralIfNeeded(const ast::Variable * var,const ast::LiteralExpression * lit)1642 uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var,
1643                                           const ast::LiteralExpression* lit) {
1644   ScalarConstant constant;
1645 
1646   auto* global = builder_.Sem().Get<sem::GlobalVariable>(var);
1647   if (global && global->IsOverridable()) {
1648     constant.is_spec_op = true;
1649     constant.constant_id = global->ConstantId();
1650   }
1651 
1652   if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
1653     constant.kind = ScalarConstant::Kind::kBool;
1654     constant.value.b = l->value;
1655   } else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
1656     constant.kind = ScalarConstant::Kind::kI32;
1657     constant.value.i32 = sl->value;
1658   } else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
1659     constant.kind = ScalarConstant::Kind::kU32;
1660     constant.value.u32 = ul->value;
1661   } else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
1662     constant.kind = ScalarConstant::Kind::kF32;
1663     constant.value.f32 = fl->value;
1664   } else {
1665     error_ = "unknown literal type";
1666     return 0;
1667   }
1668 
1669   return GenerateConstantIfNeeded(constant);
1670 }
1671 
GenerateConstantIfNeeded(const ScalarConstant & constant)1672 uint32_t Builder::GenerateConstantIfNeeded(const ScalarConstant& constant) {
1673   auto it = const_to_id_.find(constant);
1674   if (it != const_to_id_.end()) {
1675     return it->second;
1676   }
1677 
1678   uint32_t type_id = 0;
1679 
1680   switch (constant.kind) {
1681     case ScalarConstant::Kind::kU32: {
1682       type_id = GenerateTypeIfNeeded(builder_.create<sem::U32>());
1683       break;
1684     }
1685     case ScalarConstant::Kind::kI32: {
1686       type_id = GenerateTypeIfNeeded(builder_.create<sem::I32>());
1687       break;
1688     }
1689     case ScalarConstant::Kind::kF32: {
1690       type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>());
1691       break;
1692     }
1693     case ScalarConstant::Kind::kBool: {
1694       type_id = GenerateTypeIfNeeded(builder_.create<sem::Bool>());
1695       break;
1696     }
1697   }
1698 
1699   if (type_id == 0) {
1700     return 0;
1701   }
1702 
1703   auto result = result_op();
1704   auto result_id = result.to_i();
1705 
1706   if (constant.is_spec_op) {
1707     push_annot(spv::Op::OpDecorate,
1708                {Operand::Int(result_id), Operand::Int(SpvDecorationSpecId),
1709                 Operand::Int(constant.constant_id)});
1710   }
1711 
1712   switch (constant.kind) {
1713     case ScalarConstant::Kind::kU32: {
1714       push_type(
1715           constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
1716           {Operand::Int(type_id), result, Operand::Int(constant.value.u32)});
1717       break;
1718     }
1719     case ScalarConstant::Kind::kI32: {
1720       push_type(
1721           constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
1722           {Operand::Int(type_id), result, Operand::Int(constant.value.i32)});
1723       break;
1724     }
1725     case ScalarConstant::Kind::kF32: {
1726       push_type(
1727           constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
1728           {Operand::Int(type_id), result, Operand::Float(constant.value.f32)});
1729       break;
1730     }
1731     case ScalarConstant::Kind::kBool: {
1732       if (constant.value.b) {
1733         push_type(constant.is_spec_op ? spv::Op::OpSpecConstantTrue
1734                                       : spv::Op::OpConstantTrue,
1735                   {Operand::Int(type_id), result});
1736       } else {
1737         push_type(constant.is_spec_op ? spv::Op::OpSpecConstantFalse
1738                                       : spv::Op::OpConstantFalse,
1739                   {Operand::Int(type_id), result});
1740       }
1741       break;
1742     }
1743   }
1744 
1745   const_to_id_[constant] = result_id;
1746   return result_id;
1747 }
1748 
GenerateConstantNullIfNeeded(const sem::Type * type)1749 uint32_t Builder::GenerateConstantNullIfNeeded(const sem::Type* type) {
1750   auto type_id = GenerateTypeIfNeeded(type);
1751   if (type_id == 0) {
1752     return 0;
1753   }
1754 
1755   auto name = type->type_name();
1756 
1757   auto it = const_null_to_id_.find(name);
1758   if (it != const_null_to_id_.end()) {
1759     return it->second;
1760   }
1761 
1762   auto result = result_op();
1763   auto result_id = result.to_i();
1764 
1765   push_type(spv::Op::OpConstantNull, {Operand::Int(type_id), result});
1766 
1767   const_null_to_id_[name] = result_id;
1768   return result_id;
1769 }
1770 
GenerateConstantVectorSplatIfNeeded(const sem::Vector * type,uint32_t value_id)1771 uint32_t Builder::GenerateConstantVectorSplatIfNeeded(const sem::Vector* type,
1772                                                       uint32_t value_id) {
1773   auto type_id = GenerateTypeIfNeeded(type);
1774   if (type_id == 0 || value_id == 0) {
1775     return 0;
1776   }
1777 
1778   uint64_t key = (static_cast<uint64_t>(type->Width()) << 32) + value_id;
1779   return utils::GetOrCreate(const_splat_to_id_, key, [&] {
1780     auto result = result_op();
1781     auto result_id = result.to_i();
1782 
1783     OperandList ops;
1784     ops.push_back(Operand::Int(type_id));
1785     ops.push_back(result);
1786     for (uint32_t i = 0; i < type->Width(); i++) {
1787       ops.push_back(Operand::Int(value_id));
1788     }
1789     push_type(spv::Op::OpConstantComposite, ops);
1790 
1791     const_splat_to_id_[key] = result_id;
1792     return result_id;
1793   });
1794 }
1795 
GenerateShortCircuitBinaryExpression(const ast::BinaryExpression * expr)1796 uint32_t Builder::GenerateShortCircuitBinaryExpression(
1797     const ast::BinaryExpression* expr) {
1798   auto lhs_id = GenerateExpression(expr->lhs);
1799   if (lhs_id == 0) {
1800     return false;
1801   }
1802   lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs), lhs_id);
1803 
1804   // Get the ID of the basic block where control flow will diverge. It's the
1805   // last basic block generated for the left-hand-side of the operator.
1806   auto original_label_id = current_label_id_;
1807 
1808   auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
1809   if (type_id == 0) {
1810     return 0;
1811   }
1812 
1813   auto merge_block = result_op();
1814   auto merge_block_id = merge_block.to_i();
1815 
1816   auto block = result_op();
1817   auto block_id = block.to_i();
1818 
1819   auto true_block_id = block_id;
1820   auto false_block_id = merge_block_id;
1821 
1822   // For a logical or we want to only check the RHS if the LHS is failed.
1823   if (expr->IsLogicalOr()) {
1824     std::swap(true_block_id, false_block_id);
1825   }
1826 
1827   if (!push_function_inst(spv::Op::OpSelectionMerge,
1828                           {Operand::Int(merge_block_id),
1829                            Operand::Int(SpvSelectionControlMaskNone)})) {
1830     return 0;
1831   }
1832   if (!push_function_inst(spv::Op::OpBranchConditional,
1833                           {Operand::Int(lhs_id), Operand::Int(true_block_id),
1834                            Operand::Int(false_block_id)})) {
1835     return 0;
1836   }
1837 
1838   // Output block to check the RHS
1839   if (!GenerateLabel(block_id)) {
1840     return 0;
1841   }
1842   auto rhs_id = GenerateExpression(expr->rhs);
1843   if (rhs_id == 0) {
1844     return 0;
1845   }
1846   rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs), rhs_id);
1847 
1848   // Get the block ID of the last basic block generated for the right-hand-side
1849   // expression. That block will be an immediate predecessor to the merge block.
1850   auto rhs_block_id = current_label_id_;
1851   if (!push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)})) {
1852     return 0;
1853   }
1854 
1855   // Output the merge block
1856   if (!GenerateLabel(merge_block_id)) {
1857     return 0;
1858   }
1859 
1860   auto result = result_op();
1861   auto result_id = result.to_i();
1862 
1863   if (!push_function_inst(spv::Op::OpPhi,
1864                           {Operand::Int(type_id), result, Operand::Int(lhs_id),
1865                            Operand::Int(original_label_id),
1866                            Operand::Int(rhs_id), Operand::Int(rhs_block_id)})) {
1867     return 0;
1868   }
1869 
1870   return result_id;
1871 }
1872 
GenerateSplat(uint32_t scalar_id,const sem::Type * vec_type)1873 uint32_t Builder::GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type) {
1874   // Create a new vector to splat scalar into
1875   auto splat_vector = result_op();
1876   auto* splat_vector_type = builder_.create<sem::Pointer>(
1877       vec_type, ast::StorageClass::kFunction, ast::Access::kReadWrite);
1878   push_function_var(
1879       {Operand::Int(GenerateTypeIfNeeded(splat_vector_type)), splat_vector,
1880        Operand::Int(ConvertStorageClass(ast::StorageClass::kFunction)),
1881        Operand::Int(GenerateConstantNullIfNeeded(vec_type))});
1882 
1883   // Splat scalar into vector
1884   auto splat_result = result_op();
1885   OperandList ops;
1886   ops.push_back(Operand::Int(GenerateTypeIfNeeded(vec_type)));
1887   ops.push_back(splat_result);
1888   for (size_t i = 0; i < vec_type->As<sem::Vector>()->Width(); ++i) {
1889     ops.push_back(Operand::Int(scalar_id));
1890   }
1891   if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
1892     return 0;
1893   }
1894 
1895   return splat_result.to_i();
1896 }
1897 
GenerateMatrixAddOrSub(uint32_t lhs_id,uint32_t rhs_id,const sem::Matrix * type,spv::Op op)1898 uint32_t Builder::GenerateMatrixAddOrSub(uint32_t lhs_id,
1899                                          uint32_t rhs_id,
1900                                          const sem::Matrix* type,
1901                                          spv::Op op) {
1902   // Example addition of two matrices:
1903   // %31 = OpLoad %mat3v4float %m34
1904   // %32 = OpLoad %mat3v4float %m34
1905   // %33 = OpCompositeExtract %v4float %31 0
1906   // %34 = OpCompositeExtract %v4float %32 0
1907   // %35 = OpFAdd %v4float %33 %34
1908   // %36 = OpCompositeExtract %v4float %31 1
1909   // %37 = OpCompositeExtract %v4float %32 1
1910   // %38 = OpFAdd %v4float %36 %37
1911   // %39 = OpCompositeExtract %v4float %31 2
1912   // %40 = OpCompositeExtract %v4float %32 2
1913   // %41 = OpFAdd %v4float %39 %40
1914   // %42 = OpCompositeConstruct %mat3v4float %35 %38 %41
1915 
1916   auto* column_type = builder_.create<sem::Vector>(type->type(), type->rows());
1917   auto column_type_id = GenerateTypeIfNeeded(column_type);
1918 
1919   OperandList ops;
1920 
1921   for (uint32_t i = 0; i < type->columns(); ++i) {
1922     // Extract column `i` from lhs mat
1923     auto lhs_column_id = result_op();
1924     if (!push_function_inst(spv::Op::OpCompositeExtract,
1925                             {Operand::Int(column_type_id), lhs_column_id,
1926                              Operand::Int(lhs_id), Operand::Int(i)})) {
1927       return 0;
1928     }
1929 
1930     // Extract column `i` from rhs mat
1931     auto rhs_column_id = result_op();
1932     if (!push_function_inst(spv::Op::OpCompositeExtract,
1933                             {Operand::Int(column_type_id), rhs_column_id,
1934                              Operand::Int(rhs_id), Operand::Int(i)})) {
1935       return 0;
1936     }
1937 
1938     // Add or subtract the two columns
1939     auto result = result_op();
1940     if (!push_function_inst(op, {Operand::Int(column_type_id), result,
1941                                  lhs_column_id, rhs_column_id})) {
1942       return 0;
1943     }
1944 
1945     ops.push_back(result);
1946   }
1947 
1948   // Create the result matrix from the added/subtracted column vectors
1949   auto result_mat_id = result_op();
1950   ops.insert(ops.begin(), result_mat_id);
1951   ops.insert(ops.begin(), Operand::Int(GenerateTypeIfNeeded(type)));
1952   if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
1953     return 0;
1954   }
1955 
1956   return result_mat_id.to_i();
1957 }
1958 
GenerateBinaryExpression(const ast::BinaryExpression * expr)1959 uint32_t Builder::GenerateBinaryExpression(const ast::BinaryExpression* expr) {
1960   // There is special logic for short circuiting operators.
1961   if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
1962     return GenerateShortCircuitBinaryExpression(expr);
1963   }
1964 
1965   auto lhs_id = GenerateExpression(expr->lhs);
1966   if (lhs_id == 0) {
1967     return 0;
1968   }
1969   lhs_id = GenerateLoadIfNeeded(TypeOf(expr->lhs), lhs_id);
1970 
1971   auto rhs_id = GenerateExpression(expr->rhs);
1972   if (rhs_id == 0) {
1973     return 0;
1974   }
1975   rhs_id = GenerateLoadIfNeeded(TypeOf(expr->rhs), rhs_id);
1976 
1977   auto result = result_op();
1978   auto result_id = result.to_i();
1979 
1980   auto type_id = GenerateTypeIfNeeded(TypeOf(expr));
1981   if (type_id == 0) {
1982     return 0;
1983   }
1984 
1985   // Handle int and float and the vectors of those types. Other types
1986   // should have been rejected by validation.
1987   auto* lhs_type = TypeOf(expr->lhs)->UnwrapRef();
1988   auto* rhs_type = TypeOf(expr->rhs)->UnwrapRef();
1989 
1990   // Handle matrix-matrix addition and subtraction
1991   if ((expr->IsAdd() || expr->IsSubtract()) && lhs_type->is_float_matrix() &&
1992       rhs_type->is_float_matrix()) {
1993     auto* lhs_mat = lhs_type->As<sem::Matrix>();
1994     auto* rhs_mat = rhs_type->As<sem::Matrix>();
1995 
1996     // This should already have been validated by resolver
1997     if (lhs_mat->rows() != rhs_mat->rows() ||
1998         lhs_mat->columns() != rhs_mat->columns()) {
1999       error_ = "matrices must have same dimensionality for add or subtract";
2000       return 0;
2001     }
2002 
2003     return GenerateMatrixAddOrSub(
2004         lhs_id, rhs_id, lhs_mat,
2005         expr->IsAdd() ? spv::Op::OpFAdd : spv::Op::OpFSub);
2006   }
2007 
2008   // For vector-scalar arithmetic operations, splat scalar into a vector. We
2009   // skip this for multiply as we can use OpVectorTimesScalar.
2010   const bool is_float_scalar_vector_multiply =
2011       expr->IsMultiply() &&
2012       ((lhs_type->is_float_scalar() && rhs_type->is_float_vector()) ||
2013        (lhs_type->is_float_vector() && rhs_type->is_float_scalar()));
2014 
2015   if (expr->IsArithmetic() && !is_float_scalar_vector_multiply) {
2016     if (lhs_type->Is<sem::Vector>() && rhs_type->is_numeric_scalar()) {
2017       uint32_t splat_vector_id = GenerateSplat(rhs_id, lhs_type);
2018       if (splat_vector_id == 0) {
2019         return 0;
2020       }
2021       rhs_id = splat_vector_id;
2022       rhs_type = lhs_type;
2023 
2024     } else if (lhs_type->is_numeric_scalar() && rhs_type->Is<sem::Vector>()) {
2025       uint32_t splat_vector_id = GenerateSplat(lhs_id, rhs_type);
2026       if (splat_vector_id == 0) {
2027         return 0;
2028       }
2029       lhs_id = splat_vector_id;
2030       lhs_type = rhs_type;
2031     }
2032   }
2033 
2034   bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
2035   bool lhs_is_bool_or_vec = lhs_type->is_bool_scalar_or_vector();
2036   bool lhs_is_integer_or_vec = lhs_type->is_integer_scalar_or_vector();
2037   bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();
2038 
2039   spv::Op op = spv::Op::OpNop;
2040   if (expr->IsAnd()) {
2041     if (lhs_is_integer_or_vec) {
2042       op = spv::Op::OpBitwiseAnd;
2043     } else if (lhs_is_bool_or_vec) {
2044       op = spv::Op::OpLogicalAnd;
2045     } else {
2046       error_ = "invalid and expression";
2047       return 0;
2048     }
2049   } else if (expr->IsAdd()) {
2050     op = lhs_is_float_or_vec ? spv::Op::OpFAdd : spv::Op::OpIAdd;
2051   } else if (expr->IsDivide()) {
2052     if (lhs_is_float_or_vec) {
2053       op = spv::Op::OpFDiv;
2054     } else if (lhs_is_unsigned) {
2055       op = spv::Op::OpUDiv;
2056     } else {
2057       op = spv::Op::OpSDiv;
2058     }
2059   } else if (expr->IsEqual()) {
2060     if (lhs_is_float_or_vec) {
2061       op = spv::Op::OpFOrdEqual;
2062     } else if (lhs_is_bool_or_vec) {
2063       op = spv::Op::OpLogicalEqual;
2064     } else if (lhs_is_integer_or_vec) {
2065       op = spv::Op::OpIEqual;
2066     } else {
2067       error_ = "invalid equal expression";
2068       return 0;
2069     }
2070   } else if (expr->IsGreaterThan()) {
2071     if (lhs_is_float_or_vec) {
2072       op = spv::Op::OpFOrdGreaterThan;
2073     } else if (lhs_is_unsigned) {
2074       op = spv::Op::OpUGreaterThan;
2075     } else {
2076       op = spv::Op::OpSGreaterThan;
2077     }
2078   } else if (expr->IsGreaterThanEqual()) {
2079     if (lhs_is_float_or_vec) {
2080       op = spv::Op::OpFOrdGreaterThanEqual;
2081     } else if (lhs_is_unsigned) {
2082       op = spv::Op::OpUGreaterThanEqual;
2083     } else {
2084       op = spv::Op::OpSGreaterThanEqual;
2085     }
2086   } else if (expr->IsLessThan()) {
2087     if (lhs_is_float_or_vec) {
2088       op = spv::Op::OpFOrdLessThan;
2089     } else if (lhs_is_unsigned) {
2090       op = spv::Op::OpULessThan;
2091     } else {
2092       op = spv::Op::OpSLessThan;
2093     }
2094   } else if (expr->IsLessThanEqual()) {
2095     if (lhs_is_float_or_vec) {
2096       op = spv::Op::OpFOrdLessThanEqual;
2097     } else if (lhs_is_unsigned) {
2098       op = spv::Op::OpULessThanEqual;
2099     } else {
2100       op = spv::Op::OpSLessThanEqual;
2101     }
2102   } else if (expr->IsModulo()) {
2103     if (lhs_is_float_or_vec) {
2104       op = spv::Op::OpFRem;
2105     } else if (lhs_is_unsigned) {
2106       op = spv::Op::OpUMod;
2107     } else {
2108       op = spv::Op::OpSMod;
2109     }
2110   } else if (expr->IsMultiply()) {
2111     if (lhs_type->is_integer_scalar_or_vector()) {
2112       // If the left hand side is an integer then this _has_ to be OpIMul as
2113       // there there is no other integer multiplication.
2114       op = spv::Op::OpIMul;
2115     } else if (lhs_type->is_float_scalar() && rhs_type->is_float_scalar()) {
2116       // Float scalars multiply with OpFMul
2117       op = spv::Op::OpFMul;
2118     } else if (lhs_type->is_float_vector() && rhs_type->is_float_vector()) {
2119       // Float vectors must be validated to be the same size and then use OpFMul
2120       op = spv::Op::OpFMul;
2121     } else if (lhs_type->is_float_scalar() && rhs_type->is_float_vector()) {
2122       // Scalar * Vector we need to flip lhs and rhs types
2123       // because OpVectorTimesScalar expects <vector>, <scalar>
2124       std::swap(lhs_id, rhs_id);
2125       op = spv::Op::OpVectorTimesScalar;
2126     } else if (lhs_type->is_float_vector() && rhs_type->is_float_scalar()) {
2127       // float vector * scalar
2128       op = spv::Op::OpVectorTimesScalar;
2129     } else if (lhs_type->is_float_scalar() && rhs_type->is_float_matrix()) {
2130       // Scalar * Matrix we need to flip lhs and rhs types because
2131       // OpMatrixTimesScalar expects <matrix>, <scalar>
2132       std::swap(lhs_id, rhs_id);
2133       op = spv::Op::OpMatrixTimesScalar;
2134     } else if (lhs_type->is_float_matrix() && rhs_type->is_float_scalar()) {
2135       // float matrix * scalar
2136       op = spv::Op::OpMatrixTimesScalar;
2137     } else if (lhs_type->is_float_vector() && rhs_type->is_float_matrix()) {
2138       // float vector * matrix
2139       op = spv::Op::OpVectorTimesMatrix;
2140     } else if (lhs_type->is_float_matrix() && rhs_type->is_float_vector()) {
2141       // float matrix * vector
2142       op = spv::Op::OpMatrixTimesVector;
2143     } else if (lhs_type->is_float_matrix() && rhs_type->is_float_matrix()) {
2144       // float matrix * matrix
2145       op = spv::Op::OpMatrixTimesMatrix;
2146     } else {
2147       error_ = "invalid multiply expression";
2148       return 0;
2149     }
2150   } else if (expr->IsNotEqual()) {
2151     if (lhs_is_float_or_vec) {
2152       op = spv::Op::OpFOrdNotEqual;
2153     } else if (lhs_is_bool_or_vec) {
2154       op = spv::Op::OpLogicalNotEqual;
2155     } else if (lhs_is_integer_or_vec) {
2156       op = spv::Op::OpINotEqual;
2157     } else {
2158       error_ = "invalid not-equal expression";
2159       return 0;
2160     }
2161   } else if (expr->IsOr()) {
2162     if (lhs_is_integer_or_vec) {
2163       op = spv::Op::OpBitwiseOr;
2164     } else if (lhs_is_bool_or_vec) {
2165       op = spv::Op::OpLogicalOr;
2166     } else {
2167       error_ = "invalid and expression";
2168       return 0;
2169     }
2170   } else if (expr->IsShiftLeft()) {
2171     op = spv::Op::OpShiftLeftLogical;
2172   } else if (expr->IsShiftRight() && lhs_type->is_signed_scalar_or_vector()) {
2173     // A shift right with a signed LHS is an arithmetic shift.
2174     op = spv::Op::OpShiftRightArithmetic;
2175   } else if (expr->IsShiftRight()) {
2176     op = spv::Op::OpShiftRightLogical;
2177   } else if (expr->IsSubtract()) {
2178     op = lhs_is_float_or_vec ? spv::Op::OpFSub : spv::Op::OpISub;
2179   } else if (expr->IsXor()) {
2180     op = spv::Op::OpBitwiseXor;
2181   } else {
2182     error_ = "unknown binary expression";
2183     return 0;
2184   }
2185 
2186   if (!push_function_inst(op, {Operand::Int(type_id), result,
2187                                Operand::Int(lhs_id), Operand::Int(rhs_id)})) {
2188     return 0;
2189   }
2190   return result_id;
2191 }
2192 
GenerateBlockStatement(const ast::BlockStatement * stmt)2193 bool Builder::GenerateBlockStatement(const ast::BlockStatement* stmt) {
2194   scope_stack_.Push();
2195   TINT_DEFER(scope_stack_.Pop());
2196   return GenerateBlockStatementWithoutScoping(stmt);
2197 }
2198 
GenerateBlockStatementWithoutScoping(const ast::BlockStatement * stmt)2199 bool Builder::GenerateBlockStatementWithoutScoping(
2200     const ast::BlockStatement* stmt) {
2201   for (auto* block_stmt : stmt->statements) {
2202     if (!GenerateStatement(block_stmt)) {
2203       return false;
2204     }
2205   }
2206   return true;
2207 }
2208 
GenerateCallExpression(const ast::CallExpression * expr)2209 uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) {
2210   auto* call = builder_.Sem().Get(expr);
2211   auto* target = call->Target();
2212 
2213   if (auto* func = target->As<sem::Function>()) {
2214     return GenerateFunctionCall(call, func);
2215   }
2216   if (auto* intrinsic = target->As<sem::Intrinsic>()) {
2217     return GenerateIntrinsicCall(call, intrinsic);
2218   }
2219   if (target->IsAnyOf<sem::TypeConversion, sem::TypeConstructor>()) {
2220     return GenerateTypeConstructorOrConversion(call, nullptr);
2221   }
2222   TINT_ICE(Writer, builder_.Diagnostics())
2223       << "unhandled call target: " << target->TypeInfo().name;
2224   return false;
2225 }
2226 
GenerateFunctionCall(const sem::Call * call,const sem::Function *)2227 uint32_t Builder::GenerateFunctionCall(const sem::Call* call,
2228                                        const sem::Function*) {
2229   auto* expr = call->Declaration();
2230   auto* ident = expr->target.name;
2231 
2232   auto type_id = GenerateTypeIfNeeded(call->Type());
2233   if (type_id == 0) {
2234     return 0;
2235   }
2236 
2237   auto result = result_op();
2238   auto result_id = result.to_i();
2239 
2240   OperandList ops = {Operand::Int(type_id), result};
2241 
2242   auto func_id = func_symbol_to_id_[ident->symbol];
2243   if (func_id == 0) {
2244     error_ = "unable to find called function: " +
2245              builder_.Symbols().NameFor(ident->symbol);
2246     return 0;
2247   }
2248   ops.push_back(Operand::Int(func_id));
2249 
2250   size_t arg_idx = 0;
2251   for (auto* arg : expr->args) {
2252     auto id = GenerateExpression(arg);
2253     if (id == 0) {
2254       return 0;
2255     }
2256     id = GenerateLoadIfNeeded(TypeOf(arg), id);
2257     if (id == 0) {
2258       return 0;
2259     }
2260     ops.push_back(Operand::Int(id));
2261     arg_idx++;
2262   }
2263 
2264   if (!push_function_inst(spv::Op::OpFunctionCall, std::move(ops))) {
2265     return 0;
2266   }
2267 
2268   return result_id;
2269 }
2270 
GenerateIntrinsicCall(const sem::Call * call,const sem::Intrinsic * intrinsic)2271 uint32_t Builder::GenerateIntrinsicCall(const sem::Call* call,
2272                                         const sem::Intrinsic* intrinsic) {
2273   auto result = result_op();
2274   auto result_id = result.to_i();
2275 
2276   auto result_type_id = GenerateTypeIfNeeded(intrinsic->ReturnType());
2277   if (result_type_id == 0) {
2278     return 0;
2279   }
2280 
2281   if (intrinsic->IsFineDerivative() || intrinsic->IsCoarseDerivative()) {
2282     push_capability(SpvCapabilityDerivativeControl);
2283   }
2284 
2285   if (intrinsic->IsImageQuery()) {
2286     push_capability(SpvCapabilityImageQuery);
2287   }
2288 
2289   if (intrinsic->IsTexture()) {
2290     if (!GenerateTextureIntrinsic(call, intrinsic, Operand::Int(result_type_id),
2291                                   result)) {
2292       return 0;
2293     }
2294     return result_id;
2295   }
2296 
2297   if (intrinsic->IsBarrier()) {
2298     if (!GenerateControlBarrierIntrinsic(intrinsic)) {
2299       return 0;
2300     }
2301     return result_id;
2302   }
2303 
2304   if (intrinsic->IsAtomic()) {
2305     if (!GenerateAtomicIntrinsic(call, intrinsic, Operand::Int(result_type_id),
2306                                  result)) {
2307       return 0;
2308     }
2309     return result_id;
2310   }
2311 
2312   // Generates the SPIR-V ID for the expression for the indexed call argument,
2313   // and loads it if necessary. Returns 0 on error.
2314   auto get_arg_as_value_id = [&](size_t i,
2315                                  bool generate_load = true) -> uint32_t {
2316     auto* arg = call->Arguments()[i];
2317     auto* param = intrinsic->Parameters()[i];
2318     auto val_id = GenerateExpression(arg->Declaration());
2319     if (val_id == 0) {
2320       return 0;
2321     }
2322 
2323     if (generate_load && !param->Type()->Is<sem::Pointer>()) {
2324       val_id = GenerateLoadIfNeeded(arg->Type(), val_id);
2325     }
2326     return val_id;
2327   };
2328 
2329   OperandList params = {Operand::Int(result_type_id), result};
2330   spv::Op op = spv::Op::OpNop;
2331 
2332   // Pushes the arguments for a GlslStd450 extended instruction, and sets op
2333   // to OpExtInst.
2334   auto glsl_std450 = [&](uint32_t inst_id) {
2335     auto set_id = GetGLSLstd450Import();
2336     params.push_back(Operand::Int(set_id));
2337     params.push_back(Operand::Int(inst_id));
2338     op = spv::Op::OpExtInst;
2339   };
2340 
2341   switch (intrinsic->Type()) {
2342     case IntrinsicType::kAny:
2343       if (intrinsic->Parameters()[0]->Type()->Is<sem::Bool>()) {
2344         // any(v: bool) just resolves to v.
2345         return get_arg_as_value_id(0);
2346       }
2347       op = spv::Op::OpAny;
2348       break;
2349     case IntrinsicType::kAll:
2350       if (intrinsic->Parameters()[0]->Type()->Is<sem::Bool>()) {
2351         // all(v: bool) just resolves to v.
2352         return get_arg_as_value_id(0);
2353       }
2354       op = spv::Op::OpAll;
2355       break;
2356     case IntrinsicType::kArrayLength: {
2357       auto* address_of =
2358           call->Arguments()[0]->Declaration()->As<ast::UnaryOpExpression>();
2359       if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
2360         error_ = "arrayLength() expected pointer to member access, got " +
2361                  std::string(address_of->TypeInfo().name);
2362         return 0;
2363       }
2364       auto* array_expr = address_of->expr;
2365 
2366       auto* accessor = array_expr->As<ast::MemberAccessorExpression>();
2367       if (!accessor) {
2368         error_ =
2369             "arrayLength() expected pointer to member access, got pointer to " +
2370             std::string(array_expr->TypeInfo().name);
2371         return 0;
2372       }
2373 
2374       auto struct_id = GenerateExpression(accessor->structure);
2375       if (struct_id == 0) {
2376         return 0;
2377       }
2378       params.push_back(Operand::Int(struct_id));
2379 
2380       auto* type = TypeOf(accessor->structure)->UnwrapRef();
2381       if (!type->Is<sem::Struct>()) {
2382         error_ =
2383             "invalid type (" + type->type_name() + ") for runtime array length";
2384         return 0;
2385       }
2386       // Runtime array must be the last member in the structure
2387       params.push_back(Operand::Int(uint32_t(
2388           type->As<sem::Struct>()->Declaration()->members.size() - 1)));
2389 
2390       if (!push_function_inst(spv::Op::OpArrayLength, params)) {
2391         return 0;
2392       }
2393       return result_id;
2394     }
2395     case IntrinsicType::kCountOneBits:
2396       op = spv::Op::OpBitCount;
2397       break;
2398     case IntrinsicType::kDot: {
2399       op = spv::Op::OpDot;
2400       auto* vec_ty = intrinsic->Parameters()[0]->Type()->As<sem::Vector>();
2401       if (vec_ty->type()->is_integer_scalar()) {
2402         // TODO(crbug.com/tint/1267): OpDot requires floating-point types, but
2403         // WGSL also supports integer types. SPV_KHR_integer_dot_product adds
2404         // support for integer vectors. Use it if it is available.
2405         auto el_ty = Operand::Int(GenerateTypeIfNeeded(vec_ty->type()));
2406         auto vec_a = Operand::Int(get_arg_as_value_id(0));
2407         auto vec_b = Operand::Int(get_arg_as_value_id(1));
2408         if (vec_a.to_i() == 0 || vec_b.to_i() == 0) {
2409           return 0;
2410         }
2411 
2412         auto sum = Operand::Int(0);
2413         for (uint32_t i = 0; i < vec_ty->Width(); i++) {
2414           auto a = result_op();
2415           auto b = result_op();
2416           auto mul = result_op();
2417           if (!push_function_inst(spv::Op::OpCompositeExtract,
2418                                   {el_ty, a, vec_a, Operand::Int(i)}) ||
2419               !push_function_inst(spv::Op::OpCompositeExtract,
2420                                   {el_ty, b, vec_b, Operand::Int(i)}) ||
2421               !push_function_inst(spv::Op::OpIMul, {el_ty, mul, a, b})) {
2422             return 0;
2423           }
2424           if (i == 0) {
2425             sum = mul;
2426           } else {
2427             auto prev_sum = sum;
2428             auto is_last_el = i == (vec_ty->Width() - 1);
2429             sum = is_last_el ? Operand::Int(result_id) : result_op();
2430             if (!push_function_inst(spv::Op::OpIAdd,
2431                                     {el_ty, sum, prev_sum, mul})) {
2432               return 0;
2433             }
2434           }
2435         }
2436         return result_id;
2437       }
2438       break;
2439     }
2440     case IntrinsicType::kDpdx:
2441       op = spv::Op::OpDPdx;
2442       break;
2443     case IntrinsicType::kDpdxCoarse:
2444       op = spv::Op::OpDPdxCoarse;
2445       break;
2446     case IntrinsicType::kDpdxFine:
2447       op = spv::Op::OpDPdxFine;
2448       break;
2449     case IntrinsicType::kDpdy:
2450       op = spv::Op::OpDPdy;
2451       break;
2452     case IntrinsicType::kDpdyCoarse:
2453       op = spv::Op::OpDPdyCoarse;
2454       break;
2455     case IntrinsicType::kDpdyFine:
2456       op = spv::Op::OpDPdyFine;
2457       break;
2458     case IntrinsicType::kFwidth:
2459       op = spv::Op::OpFwidth;
2460       break;
2461     case IntrinsicType::kFwidthCoarse:
2462       op = spv::Op::OpFwidthCoarse;
2463       break;
2464     case IntrinsicType::kFwidthFine:
2465       op = spv::Op::OpFwidthFine;
2466       break;
2467     case IntrinsicType::kIgnore:  // [DEPRECATED]
2468       // Evaluate the single argument, return the non-zero result_id which isn't
2469       // associated with any op (ignore returns void, so this cannot be used in
2470       // an expression).
2471       if (!get_arg_as_value_id(0, false)) {
2472         return 0;
2473       }
2474       return result_id;
2475     case IntrinsicType::kIsInf:
2476       op = spv::Op::OpIsInf;
2477       break;
2478     case IntrinsicType::kIsNan:
2479       op = spv::Op::OpIsNan;
2480       break;
2481     case IntrinsicType::kIsFinite: {
2482       // Implemented as:   not(IsInf or IsNan)
2483       auto val_id = get_arg_as_value_id(0);
2484       if (!val_id) {
2485         return 0;
2486       }
2487       auto inf_result = result_op();
2488       auto nan_result = result_op();
2489       auto or_result = result_op();
2490       if (push_function_inst(spv::Op::OpIsInf,
2491                              {Operand::Int(result_type_id), inf_result,
2492                               Operand::Int(val_id)}) &&
2493           push_function_inst(spv::Op::OpIsNan,
2494                              {Operand::Int(result_type_id), nan_result,
2495                               Operand::Int(val_id)}) &&
2496           push_function_inst(spv::Op::OpLogicalOr,
2497                              {Operand::Int(result_type_id), or_result,
2498                               Operand::Int(inf_result.to_i()),
2499                               Operand::Int(nan_result.to_i())}) &&
2500           push_function_inst(spv::Op::OpLogicalNot,
2501                              {Operand::Int(result_type_id), result,
2502                               Operand::Int(or_result.to_i())})) {
2503         return result_id;
2504       }
2505       return 0;
2506     }
2507     case IntrinsicType::kIsNormal: {
2508       // A normal number is finite, non-zero, and not subnormal.
2509       // Its exponent is neither of the extreme possible values.
2510       // Implemented as:
2511       //   exponent_bits = bitcast<u32>(f);
2512       //   clamped = uclamp(1,254,exponent_bits);
2513       //   result = (clamped == exponent_bits);
2514       //
2515       auto val_id = get_arg_as_value_id(0);
2516       if (!val_id) {
2517         return 0;
2518       }
2519 
2520       // These parameters are valid for IEEE 754 binary32
2521       const uint32_t kExponentMask = 0x7f80000;
2522       const uint32_t kMinNormalExponent = 0x0080000;
2523       const uint32_t kMaxNormalExponent = 0x7f00000;
2524 
2525       auto set_id = GetGLSLstd450Import();
2526       auto* u32 = builder_.create<sem::U32>();
2527 
2528       auto unsigned_id = GenerateTypeIfNeeded(u32);
2529       auto exponent_mask_id =
2530           GenerateConstantIfNeeded(ScalarConstant::U32(kExponentMask));
2531       auto min_exponent_id =
2532           GenerateConstantIfNeeded(ScalarConstant::U32(kMinNormalExponent));
2533       auto max_exponent_id =
2534           GenerateConstantIfNeeded(ScalarConstant::U32(kMaxNormalExponent));
2535       if (auto* fvec_ty = intrinsic->ReturnType()->As<sem::Vector>()) {
2536         // In the vector case, update the unsigned type to a vector type of the
2537         // same size, and create vector constants by replicating the scalars.
2538         // I expect backend compilers to fold these into unique constants, so
2539         // there is no loss of efficiency.
2540         auto* uvec_ty = builder_.create<sem::Vector>(u32, fvec_ty->Width());
2541         unsigned_id = GenerateTypeIfNeeded(uvec_ty);
2542         auto splat = [&](uint32_t scalar_id) -> uint32_t {
2543           auto splat_result = result_op();
2544           OperandList splat_params{Operand::Int(unsigned_id), splat_result};
2545           for (size_t i = 0; i < fvec_ty->Width(); i++) {
2546             splat_params.emplace_back(Operand::Int(scalar_id));
2547           }
2548           if (!push_function_inst(spv::Op::OpCompositeConstruct,
2549                                   std::move(splat_params))) {
2550             return 0;
2551           }
2552           return splat_result.to_i();
2553         };
2554         exponent_mask_id = splat(exponent_mask_id);
2555         min_exponent_id = splat(min_exponent_id);
2556         max_exponent_id = splat(max_exponent_id);
2557       }
2558       auto cast_result = result_op();
2559       auto exponent_bits_result = result_op();
2560       auto clamp_result = result_op();
2561 
2562       if (set_id && unsigned_id && exponent_mask_id && min_exponent_id &&
2563           max_exponent_id &&
2564           push_function_inst(
2565               spv::Op::OpBitcast,
2566               {Operand::Int(unsigned_id), cast_result, Operand::Int(val_id)}) &&
2567           push_function_inst(spv::Op::OpBitwiseAnd,
2568                              {Operand::Int(unsigned_id), exponent_bits_result,
2569                               Operand::Int(cast_result.to_i()),
2570                               Operand::Int(exponent_mask_id)}) &&
2571           push_function_inst(
2572               spv::Op::OpExtInst,
2573               {Operand::Int(unsigned_id), clamp_result, Operand::Int(set_id),
2574                Operand::Int(GLSLstd450UClamp),
2575                Operand::Int(exponent_bits_result.to_i()),
2576                Operand::Int(min_exponent_id), Operand::Int(max_exponent_id)}) &&
2577           push_function_inst(spv::Op::OpIEqual,
2578                              {Operand::Int(result_type_id), result,
2579                               Operand::Int(exponent_bits_result.to_i()),
2580                               Operand::Int(clamp_result.to_i())})) {
2581         return result_id;
2582       }
2583       return 0;
2584     }
2585     case IntrinsicType::kMix: {
2586       auto std450 = Operand::Int(GetGLSLstd450Import());
2587 
2588       auto a_id = get_arg_as_value_id(0);
2589       auto b_id = get_arg_as_value_id(1);
2590       auto f_id = get_arg_as_value_id(2);
2591       if (!a_id || !b_id || !f_id) {
2592         return 0;
2593       }
2594 
2595       // If the interpolant is scalar but the objects are vectors, we need to
2596       // splat the interpolant into a vector of the same size.
2597       auto* result_vector_type = intrinsic->ReturnType()->As<sem::Vector>();
2598       if (result_vector_type &&
2599           intrinsic->Parameters()[2]->Type()->is_scalar()) {
2600         f_id = GenerateSplat(f_id, intrinsic->Parameters()[0]->Type());
2601         if (f_id == 0) {
2602           return 0;
2603         }
2604       }
2605 
2606       if (!push_function_inst(spv::Op::OpExtInst,
2607                               {Operand::Int(result_type_id), result, std450,
2608                                Operand::Int(GLSLstd450FMix), Operand::Int(a_id),
2609                                Operand::Int(b_id), Operand::Int(f_id)})) {
2610         return 0;
2611       }
2612       return result_id;
2613     }
2614     case IntrinsicType::kReverseBits:
2615       op = spv::Op::OpBitReverse;
2616       break;
2617     case IntrinsicType::kSelect: {
2618       // Note: Argument order is different in WGSL and SPIR-V
2619       auto cond_id = get_arg_as_value_id(2);
2620       auto true_id = get_arg_as_value_id(1);
2621       auto false_id = get_arg_as_value_id(0);
2622       if (!cond_id || !true_id || !false_id) {
2623         return 0;
2624       }
2625 
2626       // If the condition is scalar but the objects are vectors, we need to
2627       // splat the condition into a vector of the same size.
2628       // TODO(jrprice): If we're targeting SPIR-V 1.4, we don't need to do this.
2629       auto* result_vector_type = intrinsic->ReturnType()->As<sem::Vector>();
2630       if (result_vector_type &&
2631           intrinsic->Parameters()[2]->Type()->is_scalar()) {
2632         auto* bool_vec_ty = builder_.create<sem::Vector>(
2633             builder_.create<sem::Bool>(), result_vector_type->Width());
2634         if (!GenerateTypeIfNeeded(bool_vec_ty)) {
2635           return 0;
2636         }
2637         cond_id = GenerateSplat(cond_id, bool_vec_ty);
2638         if (cond_id == 0) {
2639           return 0;
2640         }
2641       }
2642 
2643       if (!push_function_inst(
2644               spv::Op::OpSelect,
2645               {Operand::Int(result_type_id), result, Operand::Int(cond_id),
2646                Operand::Int(true_id), Operand::Int(false_id)})) {
2647         return 0;
2648       }
2649       return result_id;
2650     }
2651     case IntrinsicType::kTranspose:
2652       op = spv::Op::OpTranspose;
2653       break;
2654     case IntrinsicType::kAbs:
2655       if (intrinsic->ReturnType()->is_unsigned_scalar_or_vector()) {
2656         // abs() only operates on *signed* integers.
2657         // This is a no-op for unsigned integers.
2658         return get_arg_as_value_id(0);
2659       }
2660       if (intrinsic->ReturnType()->is_float_scalar_or_vector()) {
2661         glsl_std450(GLSLstd450FAbs);
2662       } else {
2663         glsl_std450(GLSLstd450SAbs);
2664       }
2665       break;
2666     default: {
2667       auto inst_id = intrinsic_to_glsl_method(intrinsic);
2668       if (inst_id == 0) {
2669         error_ = "unknown method " + std::string(intrinsic->str());
2670         return 0;
2671       }
2672       glsl_std450(inst_id);
2673       break;
2674     }
2675   }
2676 
2677   if (op == spv::Op::OpNop) {
2678     error_ =
2679         "unable to determine operator for: " + std::string(intrinsic->str());
2680     return 0;
2681   }
2682 
2683   for (size_t i = 0; i < call->Arguments().size(); i++) {
2684     if (auto val_id = get_arg_as_value_id(i)) {
2685       params.emplace_back(Operand::Int(val_id));
2686     } else {
2687       return 0;
2688     }
2689   }
2690 
2691   if (!push_function_inst(op, params)) {
2692     return 0;
2693   }
2694 
2695   return result_id;
2696 }
2697 
// Generates the SPIR-V instruction(s) implementing a WGSL texture intrinsic
// call (textureDimensions, textureNumLayers/Levels/Samples, textureLoad,
// textureStore, textureGather[Compare], and the textureSample* variants).
//
// `result_type` / `result_id` are the operand pair for the call's result.
// When the SPIR-V op's result shape differs from the WGSL return type (depth
// reads produce a vec4, size queries may carry an extra array-count element),
// an intermediate result is emitted instead and the `post_emission` callback
// patches the final value into `result_id` after the main op.
//
// Returns false on failure; `error_` is set or an internal-compiler-error
// diagnostic has been raised.
bool Builder::GenerateTextureIntrinsic(const sem::Call* call,
                                       const sem::Intrinsic* intrinsic,
                                       Operand result_type,
                                       Operand result_id) {
  using Usage = sem::ParameterUsage;

  auto& signature = intrinsic->Signature();
  auto& arguments = call->Arguments();

  // Generates the given expression, returning the operand ID.
  // Loads through a reference if needed; yields Operand::Int(0) on failure.
  auto gen = [&](const sem::Expression* expr) {
    auto val_id = GenerateExpression(expr->Declaration());
    if (val_id == 0) {
      return Operand::Int(0);
    }
    val_id = GenerateLoadIfNeeded(expr->Type(), val_id);

    return Operand::Int(val_id);
  };

  // Returns the argument with the given usage, or nullptr if the intrinsic
  // signature has no parameter with that usage.
  auto arg = [&](Usage usage) {
    int idx = signature.IndexOf(usage);
    return (idx >= 0) ? arguments[idx] : nullptr;
  };

  // Generates the argument with the given usage, returning the operand ID.
  // ICEs if the argument is missing from the signature.
  auto gen_arg = [&](Usage usage) {
    auto* argument = arg(usage);
    if (!argument) {
      TINT_ICE(Writer, builder_.Diagnostics())
          << "missing argument " << static_cast<int>(usage);
    }
    return gen(argument);
  };

  auto* texture = arg(Usage::kTexture);
  if (!texture) {
    TINT_ICE(Writer, builder_.Diagnostics()) << "missing texture argument";
  }

  auto* texture_type = texture->Type()->UnwrapRef()->As<sem::Texture>();

  // Resolved to a real opcode by the switch below; OpNop means "unhandled".
  auto op = spv::Op::OpNop;

  // Custom function to call after the texture-intrinsic op has been generated.
  std::function<bool()> post_emission = [] { return true; };

  // Populate the spirv_params with common parameters
  OperandList spirv_params;
  spirv_params.reserve(8);  // Enough to fit most parameter lists

  // Extra image operands, appended to spirv_params.
  struct ImageOperand {
    SpvImageOperandsMask mask;
    Operand operand;
  };
  std::vector<ImageOperand> image_operands;
  image_operands.reserve(4);  // Enough to fit most parameter lists

  // Appends `result_type` and `result_id` to `spirv_params`.
  // Note: moves from the two operands, so this must run at most once.
  auto append_result_type_and_id_to_spirv_params = [&]() {
    spirv_params.emplace_back(std::move(result_type));
    spirv_params.emplace_back(std::move(result_id));
  };

  // Appends a result type and id to `spirv_params`, possibly adding a
  // post_emission step.
  //
  // If the texture is a depth texture, then this function wraps the result of
  // the op with a OpCompositeExtract to evaluate to the first element of the
  // returned vector. This is done as the WGSL texture reading functions for
  // depths return a single float scalar instead of a vector.
  //
  // If the texture is not a depth texture, then this function simply delegates
  // to calling append_result_type_and_id_to_spirv_params().
  auto append_result_type_and_id_to_spirv_params_for_read = [&]() {
    if (texture_type
            ->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
      // The op itself produces a vec4<f32>; extract element 0 afterwards.
      auto* f32 = builder_.create<sem::F32>();
      auto* spirv_result_type = builder_.create<sem::Vector>(f32, 4);
      auto spirv_result = result_op();
      // Capture by value: result_type / result_id are still intact here
      // (the moving helper above is not called on this path).
      post_emission = [=] {
        return push_function_inst(
            spv::Op::OpCompositeExtract,
            {result_type, result_id, spirv_result, Operand::Int(0)});
      };
      auto spirv_result_type_id = GenerateTypeIfNeeded(spirv_result_type);
      if (spirv_result_type_id == 0) {
        return false;
      }
      spirv_params.emplace_back(Operand::Int(spirv_result_type_id));
      spirv_params.emplace_back(spirv_result);
      return true;
    }

    append_result_type_and_id_to_spirv_params();
    return true;
  };

  // Appends a result type and id to `spirv_params`, by first swizzling the
  // result of the op with `swizzle`.
  auto append_result_type_and_id_to_spirv_params_swizzled =
      [&](uint32_t spirv_result_width, std::vector<uint32_t> swizzle) {
        if (swizzle.empty()) {
          append_result_type_and_id_to_spirv_params();
        } else {
          // Assign post_emission to swizzle the result of the call to
          // OpImageQuerySize[Lod].
          auto* element_type = ElementTypeOf(call->Type());
          auto spirv_result = result_op();
          auto* spirv_result_type =
              builder_.create<sem::Vector>(element_type, spirv_result_width);
          if (swizzle.size() > 1) {
            // Multi-element swizzle: use a vector shuffle of the result
            // with itself, selecting the requested components.
            post_emission = [=] {
              OperandList operands{
                  result_type,
                  result_id,
                  spirv_result,
                  spirv_result,
              };
              for (auto idx : swizzle) {
                operands.emplace_back(Operand::Int(idx));
              }
              return push_function_inst(spv::Op::OpVectorShuffle, operands);
            };
          } else {
            // Single-element swizzle: a plain composite extract suffices.
            post_emission = [=] {
              return push_function_inst(spv::Op::OpCompositeExtract,
                                        {result_type, result_id, spirv_result,
                                         Operand::Int(swizzle[0])});
            };
          }
          auto spirv_result_type_id = GenerateTypeIfNeeded(spirv_result_type);
          if (spirv_result_type_id == 0) {
            return false;
          }
          spirv_params.emplace_back(Operand::Int(spirv_result_type_id));
          spirv_params.emplace_back(spirv_result);
        }
        return true;
      };

  // Appends the coordinates argument to `spirv_params`, packing the array
  // index into the coordinate vector when the intrinsic has one.
  auto append_coords_to_spirv_params = [&]() -> bool {
    if (auto* array_index = arg(Usage::kArrayIndex)) {
      // Array index needs to be appended to the coordinates.
      auto* packed = AppendVector(&builder_, arg(Usage::kCoords)->Declaration(),
                                  array_index->Declaration());
      auto param = GenerateExpression(packed->Declaration());
      if (param == 0) {
        return false;
      }
      spirv_params.emplace_back(Operand::Int(param));
    } else {
      spirv_params.emplace_back(gen_arg(Usage::kCoords));  // coordinates
    }
    return true;
  };

  // Appends the OpSampledImage combining texture + sampler, then the
  // coordinates, to `spirv_params` (used by the sampling/gather ops).
  auto append_image_and_coords_to_spirv_params = [&]() -> bool {
    auto sampler_param = gen_arg(Usage::kSampler);
    auto texture_param = gen_arg(Usage::kTexture);
    auto sampled_image =
        GenerateSampledImage(texture_type, texture_param, sampler_param);

    // Populate the spirv_params with the common parameters
    spirv_params.emplace_back(Operand::Int(sampled_image));  // sampled image
    return append_coords_to_spirv_params();
  };

  switch (intrinsic->Type()) {
    case IntrinsicType::kTextureDimensions: {
      // Number of returned elements from OpImageQuerySize[Lod] may not match
      // those of textureDimensions().
      // This might be due to an extra vector scalar describing the number of
      // array elements or textureDimensions() returning a vec3 for cubes
      // when only width / height is returned by OpImageQuerySize[Lod]
      // (see https://github.com/gpuweb/gpuweb/issues/1345).
      // Handle these mismatches by swizzling the returned vector.
      std::vector<uint32_t> swizzle;
      uint32_t spirv_dims = 0;
      switch (texture_type->dim()) {
        case ast::TextureDimension::kNone:
          error_ = "texture dimension is kNone";
          return false;
        case ast::TextureDimension::k1d:
        case ast::TextureDimension::k2d:
        case ast::TextureDimension::k3d:
        case ast::TextureDimension::kCube:
          break;  // No swizzle needed
        case ast::TextureDimension::kCubeArray:
        case ast::TextureDimension::k2dArray:
          swizzle = {0, 1};  // Strip array index
          spirv_dims = 3;    // [width, height, array_count]
          break;
      }

      if (!append_result_type_and_id_to_spirv_params_swizzled(spirv_dims,
                                                              swizzle)) {
        return false;
      }

      spirv_params.emplace_back(gen_arg(Usage::kTexture));
      if (texture_type->IsAnyOf<sem::MultisampledTexture,       //
                                sem::DepthMultisampledTexture,  //
                                sem::StorageTexture>()) {
        // Multisampled and storage textures have no mip levels.
        op = spv::Op::OpImageQuerySize;
      } else if (auto* level = arg(Usage::kLevel)) {
        op = spv::Op::OpImageQuerySizeLod;
        spirv_params.emplace_back(gen(level));
      } else {
        // No explicit level requested: query the size of mip level 0.
        ast::SintLiteralExpression i32_0(ProgramID(), Source{}, 0);
        op = spv::Op::OpImageQuerySizeLod;
        spirv_params.emplace_back(
            Operand::Int(GenerateLiteralIfNeeded(nullptr, &i32_0)));
      }
      break;
    }
    case IntrinsicType::kTextureNumLayers: {
      uint32_t spirv_dims = 0;
      switch (texture_type->dim()) {
        default:
          error_ = "texture is not arrayed";
          return false;
        case ast::TextureDimension::k2dArray:
        case ast::TextureDimension::kCubeArray:
          spirv_dims = 3;
          break;
      }

      // OpImageQuerySize[Lod] packs the array count as the last element of the
      // returned vector. Extract this.
      if (!append_result_type_and_id_to_spirv_params_swizzled(
              spirv_dims, {spirv_dims - 1})) {
        return false;
      }

      spirv_params.emplace_back(gen_arg(Usage::kTexture));

      if (texture_type->Is<sem::MultisampledTexture>() ||
          texture_type->Is<sem::StorageTexture>()) {
        op = spv::Op::OpImageQuerySize;
      } else {
        // Sampled textures need an explicit LOD; use level 0.
        ast::SintLiteralExpression i32_0(ProgramID(), Source{}, 0);
        op = spv::Op::OpImageQuerySizeLod;
        spirv_params.emplace_back(
            Operand::Int(GenerateLiteralIfNeeded(nullptr, &i32_0)));
      }
      break;
    }
    case IntrinsicType::kTextureNumLevels: {
      op = spv::Op::OpImageQueryLevels;
      append_result_type_and_id_to_spirv_params();
      spirv_params.emplace_back(gen_arg(Usage::kTexture));
      break;
    }
    case IntrinsicType::kTextureNumSamples: {
      op = spv::Op::OpImageQuerySamples;
      append_result_type_and_id_to_spirv_params();
      spirv_params.emplace_back(gen_arg(Usage::kTexture));
      break;
    }
    case IntrinsicType::kTextureLoad: {
      op = texture_type->Is<sem::StorageTexture>() ? spv::Op::OpImageRead
                                                   : spv::Op::OpImageFetch;
      // NOTE(review): the bool result is ignored here (and in the sample
      // cases below); a failed result-type generation is only surfaced by a
      // later failure — confirm this is intentional.
      append_result_type_and_id_to_spirv_params_for_read();
      spirv_params.emplace_back(gen_arg(Usage::kTexture));
      if (!append_coords_to_spirv_params()) {
        return false;
      }

      if (auto* level = arg(Usage::kLevel)) {
        image_operands.emplace_back(
            ImageOperand{SpvImageOperandsLodMask, gen(level)});
      }

      if (auto* sample_index = arg(Usage::kSampleIndex)) {
        image_operands.emplace_back(
            ImageOperand{SpvImageOperandsSampleMask, gen(sample_index)});
      }

      break;
    }
    case IntrinsicType::kTextureStore: {
      // OpImageWrite has no result; result_type/result_id are unused.
      op = spv::Op::OpImageWrite;
      spirv_params.emplace_back(gen_arg(Usage::kTexture));
      if (!append_coords_to_spirv_params()) {
        return false;
      }
      spirv_params.emplace_back(gen_arg(Usage::kValue));
      break;
    }
    case IntrinsicType::kTextureGather: {
      op = spv::Op::OpImageGather;
      append_result_type_and_id_to_spirv_params();
      if (!append_image_and_coords_to_spirv_params()) {
        return false;
      }
      if (signature.IndexOf(Usage::kComponent) < 0) {
        // No component argument in the WGSL overload: gather component 0.
        spirv_params.emplace_back(
            Operand::Int(GenerateConstantIfNeeded(ScalarConstant::I32(0))));
      } else {
        spirv_params.emplace_back(gen_arg(Usage::kComponent));
      }
      break;
    }
    case IntrinsicType::kTextureGatherCompare: {
      op = spv::Op::OpImageDrefGather;
      append_result_type_and_id_to_spirv_params();
      if (!append_image_and_coords_to_spirv_params()) {
        return false;
      }
      spirv_params.emplace_back(gen_arg(Usage::kDepthRef));
      break;
    }
    case IntrinsicType::kTextureSample: {
      op = spv::Op::OpImageSampleImplicitLod;
      append_result_type_and_id_to_spirv_params_for_read();
      if (!append_image_and_coords_to_spirv_params()) {
        return false;
      }
      break;
    }
    case IntrinsicType::kTextureSampleBias: {
      op = spv::Op::OpImageSampleImplicitLod;
      append_result_type_and_id_to_spirv_params_for_read();
      if (!append_image_and_coords_to_spirv_params()) {
        return false;
      }
      image_operands.emplace_back(
          ImageOperand{SpvImageOperandsBiasMask, gen_arg(Usage::kBias)});
      break;
    }
    case IntrinsicType::kTextureSampleLevel: {
      op = spv::Op::OpImageSampleExplicitLod;
      append_result_type_and_id_to_spirv_params_for_read();
      if (!append_image_and_coords_to_spirv_params()) {
        return false;
      }
      auto level = Operand::Int(0);
      if (arg(Usage::kLevel)->Type()->UnwrapRef()->Is<sem::I32>()) {
        // Depth textures have i32 parameters for the level, but SPIR-V expects
        // F32. Cast.
        auto f32_type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>());
        if (f32_type_id == 0) {
          return 0;  // NOTE(review): `0` converts to false in this bool fn.
        }
        level = result_op();
        if (!push_function_inst(
                spv::Op::OpConvertSToF,
                {Operand::Int(f32_type_id), level, gen_arg(Usage::kLevel)})) {
          return 0;
        }
      } else {
        level = gen_arg(Usage::kLevel);
      }
      image_operands.emplace_back(ImageOperand{SpvImageOperandsLodMask, level});
      break;
    }
    case IntrinsicType::kTextureSampleGrad: {
      op = spv::Op::OpImageSampleExplicitLod;
      append_result_type_and_id_to_spirv_params_for_read();
      if (!append_image_and_coords_to_spirv_params()) {
        return false;
      }
      // Grad consumes two operands (ddx then ddy) under a single mask bit.
      image_operands.emplace_back(
          ImageOperand{SpvImageOperandsGradMask, gen_arg(Usage::kDdx)});
      image_operands.emplace_back(
          ImageOperand{SpvImageOperandsGradMask, gen_arg(Usage::kDdy)});
      break;
    }
    case IntrinsicType::kTextureSampleCompare: {
      op = spv::Op::OpImageSampleDrefImplicitLod;
      append_result_type_and_id_to_spirv_params();
      if (!append_image_and_coords_to_spirv_params()) {
        return false;
      }
      spirv_params.emplace_back(gen_arg(Usage::kDepthRef));
      break;
    }
    case IntrinsicType::kTextureSampleCompareLevel: {
      op = spv::Op::OpImageSampleDrefExplicitLod;
      append_result_type_and_id_to_spirv_params();
      if (!append_image_and_coords_to_spirv_params()) {
        return false;
      }
      spirv_params.emplace_back(gen_arg(Usage::kDepthRef));

      // WGSL's "compare level" always samples mip level 0.
      ast::FloatLiteralExpression float_0(ProgramID(), Source{}, 0.0);
      image_operands.emplace_back(ImageOperand{
          SpvImageOperandsLodMask,
          Operand::Int(GenerateLiteralIfNeeded(nullptr, &float_0))});
      break;
    }
    default:
      TINT_UNREACHABLE(Writer, builder_.Diagnostics());
      return false;
  }

  // A constant offset, when present, applies to every textured op above.
  if (auto* offset = arg(Usage::kOffset)) {
    image_operands.emplace_back(
        ImageOperand{SpvImageOperandsConstOffsetMask, gen(offset)});
  }

  if (!image_operands.empty()) {
    // Emit the combined mask, then each operand in ascending mask-bit order
    // (SPIR-V requires image operands in that order).
    std::sort(image_operands.begin(), image_operands.end(),
              [](auto& a, auto& b) { return a.mask < b.mask; });
    uint32_t mask = 0;
    for (auto& image_operand : image_operands) {
      mask |= image_operand.mask;
    }
    spirv_params.emplace_back(Operand::Int(mask));
    for (auto& image_operand : image_operands) {
      spirv_params.emplace_back(image_operand.operand);
    }
  }

  if (op == spv::Op::OpNop) {
    error_ =
        "unable to determine operator for: " + std::string(intrinsic->str());
    return false;
  }

  if (!push_function_inst(op, spirv_params)) {
    return false;
  }

  // Apply any fix-up (depth extract / size-query swizzle) recorded above.
  return post_emission();
}
3127 
GenerateControlBarrierIntrinsic(const sem::Intrinsic * intrinsic)3128 bool Builder::GenerateControlBarrierIntrinsic(const sem::Intrinsic* intrinsic) {
3129   auto const op = spv::Op::OpControlBarrier;
3130   uint32_t execution = 0;
3131   uint32_t memory = 0;
3132   uint32_t semantics = 0;
3133 
3134   // TODO(crbug.com/tint/661): Combine sequential barriers to a single
3135   // instruction.
3136   if (intrinsic->Type() == sem::IntrinsicType::kWorkgroupBarrier) {
3137     execution = static_cast<uint32_t>(spv::Scope::Workgroup);
3138     memory = static_cast<uint32_t>(spv::Scope::Workgroup);
3139     semantics =
3140         static_cast<uint32_t>(spv::MemorySemanticsMask::AcquireRelease) |
3141         static_cast<uint32_t>(spv::MemorySemanticsMask::WorkgroupMemory);
3142   } else if (intrinsic->Type() == sem::IntrinsicType::kStorageBarrier) {
3143     execution = static_cast<uint32_t>(spv::Scope::Workgroup);
3144     memory = static_cast<uint32_t>(spv::Scope::Workgroup);
3145     semantics =
3146         static_cast<uint32_t>(spv::MemorySemanticsMask::AcquireRelease) |
3147         static_cast<uint32_t>(spv::MemorySemanticsMask::UniformMemory);
3148   } else {
3149     error_ = "unexpected barrier intrinsic type ";
3150     error_ += sem::str(intrinsic->Type());
3151     return false;
3152   }
3153 
3154   auto execution_id = GenerateConstantIfNeeded(ScalarConstant::U32(execution));
3155   auto memory_id = GenerateConstantIfNeeded(ScalarConstant::U32(memory));
3156   auto semantics_id = GenerateConstantIfNeeded(ScalarConstant::U32(semantics));
3157   if (execution_id == 0 || memory_id == 0 || semantics_id == 0) {
3158     return false;
3159   }
3160 
3161   return push_function_inst(op, {
3162                                     Operand::Int(execution_id),
3163                                     Operand::Int(memory_id),
3164                                     Operand::Int(semantics_id),
3165                                 });
3166 }
3167 
GenerateAtomicIntrinsic(const sem::Call * call,const sem::Intrinsic * intrinsic,Operand result_type,Operand result_id)3168 bool Builder::GenerateAtomicIntrinsic(const sem::Call* call,
3169                                       const sem::Intrinsic* intrinsic,
3170                                       Operand result_type,
3171                                       Operand result_id) {
3172   auto is_value_signed = [&] {
3173     return intrinsic->Parameters()[1]->Type()->Is<sem::I32>();
3174   };
3175 
3176   auto storage_class =
3177       intrinsic->Parameters()[0]->Type()->As<sem::Pointer>()->StorageClass();
3178 
3179   uint32_t memory_id = 0;
3180   switch (
3181       intrinsic->Parameters()[0]->Type()->As<sem::Pointer>()->StorageClass()) {
3182     case ast::StorageClass::kWorkgroup:
3183       memory_id = GenerateConstantIfNeeded(
3184           ScalarConstant::U32(static_cast<uint32_t>(spv::Scope::Workgroup)));
3185       break;
3186     case ast::StorageClass::kStorage:
3187       memory_id = GenerateConstantIfNeeded(
3188           ScalarConstant::U32(static_cast<uint32_t>(spv::Scope::Device)));
3189       break;
3190     default:
3191       TINT_UNREACHABLE(Writer, builder_.Diagnostics())
3192           << "unhandled atomic storage class " << storage_class;
3193       return false;
3194   }
3195   if (memory_id == 0) {
3196     return false;
3197   }
3198 
3199   uint32_t semantics_id = GenerateConstantIfNeeded(ScalarConstant::U32(
3200       static_cast<uint32_t>(spv::MemorySemanticsMask::MaskNone)));
3201   if (semantics_id == 0) {
3202     return false;
3203   }
3204 
3205   uint32_t pointer_id = GenerateExpression(call->Arguments()[0]->Declaration());
3206   if (pointer_id == 0) {
3207     return false;
3208   }
3209 
3210   uint32_t value_id = 0;
3211   if (call->Arguments().size() > 1) {
3212     value_id = GenerateExpression(call->Arguments().back()->Declaration());
3213     if (value_id == 0) {
3214       return false;
3215     }
3216     value_id = GenerateLoadIfNeeded(call->Arguments().back()->Type(), value_id);
3217     if (value_id == 0) {
3218       return false;
3219     }
3220   }
3221 
3222   Operand pointer = Operand::Int(pointer_id);
3223   Operand value = Operand::Int(value_id);
3224   Operand memory = Operand::Int(memory_id);
3225   Operand semantics = Operand::Int(semantics_id);
3226 
3227   switch (intrinsic->Type()) {
3228     case sem::IntrinsicType::kAtomicLoad:
3229       return push_function_inst(spv::Op::OpAtomicLoad, {
3230                                                            result_type,
3231                                                            result_id,
3232                                                            pointer,
3233                                                            memory,
3234                                                            semantics,
3235                                                        });
3236     case sem::IntrinsicType::kAtomicStore:
3237       return push_function_inst(spv::Op::OpAtomicStore, {
3238                                                             pointer,
3239                                                             memory,
3240                                                             semantics,
3241                                                             value,
3242                                                         });
3243     case sem::IntrinsicType::kAtomicAdd:
3244       return push_function_inst(spv::Op::OpAtomicIAdd, {
3245                                                            result_type,
3246                                                            result_id,
3247                                                            pointer,
3248                                                            memory,
3249                                                            semantics,
3250                                                            value,
3251                                                        });
3252     case sem::IntrinsicType::kAtomicSub:
3253       return push_function_inst(spv::Op::OpAtomicISub, {
3254                                                            result_type,
3255                                                            result_id,
3256                                                            pointer,
3257                                                            memory,
3258                                                            semantics,
3259                                                            value,
3260                                                        });
3261     case sem::IntrinsicType::kAtomicMax:
3262       return push_function_inst(
3263           is_value_signed() ? spv::Op::OpAtomicSMax : spv::Op::OpAtomicUMax,
3264           {
3265               result_type,
3266               result_id,
3267               pointer,
3268               memory,
3269               semantics,
3270               value,
3271           });
3272     case sem::IntrinsicType::kAtomicMin:
3273       return push_function_inst(
3274           is_value_signed() ? spv::Op::OpAtomicSMin : spv::Op::OpAtomicUMin,
3275           {
3276               result_type,
3277               result_id,
3278               pointer,
3279               memory,
3280               semantics,
3281               value,
3282           });
3283     case sem::IntrinsicType::kAtomicAnd:
3284       return push_function_inst(spv::Op::OpAtomicAnd, {
3285                                                           result_type,
3286                                                           result_id,
3287                                                           pointer,
3288                                                           memory,
3289                                                           semantics,
3290                                                           value,
3291                                                       });
3292     case sem::IntrinsicType::kAtomicOr:
3293       return push_function_inst(spv::Op::OpAtomicOr, {
3294                                                          result_type,
3295                                                          result_id,
3296                                                          pointer,
3297                                                          memory,
3298                                                          semantics,
3299                                                          value,
3300                                                      });
3301     case sem::IntrinsicType::kAtomicXor:
3302       return push_function_inst(spv::Op::OpAtomicXor, {
3303                                                           result_type,
3304                                                           result_id,
3305                                                           pointer,
3306                                                           memory,
3307                                                           semantics,
3308                                                           value,
3309                                                       });
3310     case sem::IntrinsicType::kAtomicExchange:
3311       return push_function_inst(spv::Op::OpAtomicExchange, {
3312                                                                result_type,
3313                                                                result_id,
3314                                                                pointer,
3315                                                                memory,
3316                                                                semantics,
3317                                                                value,
3318                                                            });
3319     case sem::IntrinsicType::kAtomicCompareExchangeWeak: {
3320       auto comparator = GenerateExpression(call->Arguments()[1]->Declaration());
3321       if (comparator == 0) {
3322         return false;
3323       }
3324 
3325       auto* value_sem_type = TypeOf(call->Arguments()[2]->Declaration());
3326 
3327       auto value_type = GenerateTypeIfNeeded(value_sem_type);
3328       if (value_type == 0) {
3329         return false;
3330       }
3331 
3332       auto* bool_sem_ty = builder_.create<sem::Bool>();
3333       auto bool_type = GenerateTypeIfNeeded(bool_sem_ty);
3334       if (bool_type == 0) {
3335         return false;
3336       }
3337 
3338       // original_value := OpAtomicCompareExchange(pointer, memory, semantics,
3339       //                                           semantics, value, comparator)
3340       auto original_value = result_op();
3341       if (!push_function_inst(spv::Op::OpAtomicCompareExchange,
3342                               {
3343                                   Operand::Int(value_type),
3344                                   original_value,
3345                                   pointer,
3346                                   memory,
3347                                   semantics,
3348                                   semantics,
3349                                   value,
3350                                   Operand::Int(comparator),
3351                               })) {
3352         return false;
3353       }
3354 
3355       // values_equal := original_value == value
3356       auto values_equal = result_op();
3357       if (!push_function_inst(spv::Op::OpIEqual, {
3358                                                      Operand::Int(bool_type),
3359                                                      values_equal,
3360                                                      original_value,
3361                                                      value,
3362                                                  })) {
3363         return false;
3364       }
3365 
3366       // zero := T(0)
3367       // one := T(1)
3368       uint32_t zero = 0;
3369       uint32_t one = 0;
3370       if (value_sem_type->Is<sem::I32>()) {
3371         zero = GenerateConstantIfNeeded(ScalarConstant::I32(0u));
3372         one = GenerateConstantIfNeeded(ScalarConstant::I32(1u));
3373       } else if (value_sem_type->Is<sem::U32>()) {
3374         zero = GenerateConstantIfNeeded(ScalarConstant::U32(0u));
3375         one = GenerateConstantIfNeeded(ScalarConstant::U32(1u));
3376       } else {
3377         TINT_UNREACHABLE(Writer, builder_.Diagnostics())
3378             << "unsupported atomic type " << value_sem_type->TypeInfo().name;
3379       }
3380       if (zero == 0 || one == 0) {
3381         return false;
3382       }
3383 
3384       // xchg_success := values_equal ? one : zero
3385       auto xchg_success = result_op();
3386       if (!push_function_inst(spv::Op::OpSelect, {
3387                                                      Operand::Int(value_type),
3388                                                      xchg_success,
3389                                                      values_equal,
3390                                                      Operand::Int(one),
3391                                                      Operand::Int(zero),
3392                                                  })) {
3393         return false;
3394       }
3395 
3396       // result := vec2<T>(original_value, xchg_success)
3397       return push_function_inst(spv::Op::OpCompositeConstruct,
3398                                 {
3399                                     result_type,
3400                                     result_id,
3401                                     original_value,
3402                                     xchg_success,
3403                                 });
3404     }
3405     default:
3406       TINT_UNREACHABLE(Writer, builder_.Diagnostics())
3407           << "unhandled atomic intrinsic " << intrinsic->Type();
3408       return false;
3409   }
3410 }
3411 
GenerateSampledImage(const sem::Type * texture_type,Operand texture_operand,Operand sampler_operand)3412 uint32_t Builder::GenerateSampledImage(const sem::Type* texture_type,
3413                                        Operand texture_operand,
3414                                        Operand sampler_operand) {
3415   uint32_t sampled_image_type_id = 0;
3416   auto val = texture_type_name_to_sampled_image_type_id_.find(
3417       texture_type->type_name());
3418   if (val != texture_type_name_to_sampled_image_type_id_.end()) {
3419     // The sampled image type is already created.
3420     sampled_image_type_id = val->second;
3421   } else {
3422     // We need to create the sampled image type and cache the result.
3423     auto sampled_image_type = result_op();
3424     sampled_image_type_id = sampled_image_type.to_i();
3425     auto texture_type_id = GenerateTypeIfNeeded(texture_type);
3426     push_type(spv::Op::OpTypeSampledImage,
3427               {sampled_image_type, Operand::Int(texture_type_id)});
3428     texture_type_name_to_sampled_image_type_id_[texture_type->type_name()] =
3429         sampled_image_type_id;
3430   }
3431 
3432   auto sampled_image = result_op();
3433   if (!push_function_inst(spv::Op::OpSampledImage,
3434                           {Operand::Int(sampled_image_type_id), sampled_image,
3435                            texture_operand, sampler_operand})) {
3436     return 0;
3437   }
3438 
3439   return sampled_image.to_i();
3440 }
3441 
GenerateBitcastExpression(const ast::BitcastExpression * expr)3442 uint32_t Builder::GenerateBitcastExpression(
3443     const ast::BitcastExpression* expr) {
3444   auto result = result_op();
3445   auto result_id = result.to_i();
3446 
3447   auto result_type_id = GenerateTypeIfNeeded(TypeOf(expr));
3448   if (result_type_id == 0) {
3449     return 0;
3450   }
3451 
3452   auto val_id = GenerateExpression(expr->expr);
3453   if (val_id == 0) {
3454     return 0;
3455   }
3456   val_id = GenerateLoadIfNeeded(TypeOf(expr->expr), val_id);
3457 
3458   // Bitcast does not allow same types, just emit a CopyObject
3459   auto* to_type = TypeOf(expr)->UnwrapRef();
3460   auto* from_type = TypeOf(expr->expr)->UnwrapRef();
3461   if (to_type->type_name() == from_type->type_name()) {
3462     if (!push_function_inst(
3463             spv::Op::OpCopyObject,
3464             {Operand::Int(result_type_id), result, Operand::Int(val_id)})) {
3465       return 0;
3466     }
3467     return result_id;
3468   }
3469 
3470   if (!push_function_inst(spv::Op::OpBitcast, {Operand::Int(result_type_id),
3471                                                result, Operand::Int(val_id)})) {
3472     return 0;
3473   }
3474 
3475   return result_id;
3476 }
3477 
// Emits one level of an if/else-if chain as a SPIR-V selection construct:
//   OpSelectionMerge %merge
//   OpBranchConditional %cond %true %false-or-merge
// else-if arms are produced by recursing with cur_else_idx advanced, so each
// arm becomes a nested selection construct inside the previous false block.
//
// @param cond         the branch condition expression
// @param true_body    the statements for the true branch
// @param cur_else_idx index into else_stmts of the next arm to emit
// @param else_stmts   the full list of else / else-if arms
// @returns true on success, false if any sub-generation failed
bool Builder::GenerateConditionalBlock(
    const ast::Expression* cond,
    const ast::BlockStatement* true_body,
    size_t cur_else_idx,
    const ast::ElseStatementList& else_stmts) {
  auto cond_id = GenerateExpression(cond);
  if (cond_id == 0) {
    return false;
  }
  // If the condition is a reference (e.g. a `var`), load its value.
  cond_id = GenerateLoadIfNeeded(TypeOf(cond), cond_id);

  auto merge_block = result_op();
  auto merge_block_id = merge_block.to_i();

  if (!push_function_inst(spv::Op::OpSelectionMerge,
                          {Operand::Int(merge_block_id),
                           Operand::Int(SpvSelectionControlMaskNone)})) {
    return false;
  }

  auto true_block = result_op();
  auto true_block_id = true_block.to_i();

  // if there are no more else statements we branch on false to the merge
  // block otherwise we branch to the false block
  auto false_block_id =
      cur_else_idx < else_stmts.size() ? next_id() : merge_block_id;

  if (!push_function_inst(spv::Op::OpBranchConditional,
                          {Operand::Int(cond_id), Operand::Int(true_block_id),
                           Operand::Int(false_block_id)})) {
    return false;
  }

  // Output true block
  if (!GenerateLabel(true_block_id)) {
    return false;
  }
  if (!GenerateBlockStatement(true_body)) {
    return false;
  }
  // We only branch if the last element of the body didn't already branch.
  if (!LastIsTerminator(true_body)) {
    if (!push_function_inst(spv::Op::OpBranch,
                            {Operand::Int(merge_block_id)})) {
      return false;
    }
  }

  // Start the false block if needed
  if (false_block_id != merge_block_id) {
    if (!GenerateLabel(false_block_id)) {
      return false;
    }

    auto* else_stmt = else_stmts[cur_else_idx];
    // Handle the else case by just outputting the statements.
    if (!else_stmt->condition) {
      if (!GenerateBlockStatement(else_stmt->body)) {
        return false;
      }
    } else {
      // An else-if: recurse to emit a nested selection construct.
      if (!GenerateConditionalBlock(else_stmt->condition, else_stmt->body,
                                    cur_else_idx + 1, else_stmts)) {
        return false;
      }
    }
    if (!LastIsTerminator(else_stmt->body)) {
      if (!push_function_inst(spv::Op::OpBranch,
                              {Operand::Int(merge_block_id)})) {
        return false;
      }
    }
  }

  // Output the merge block
  return GenerateLabel(merge_block_id);
}
3556 
GenerateIfStatement(const ast::IfStatement * stmt)3557 bool Builder::GenerateIfStatement(const ast::IfStatement* stmt) {
3558   if (!continuing_stack_.empty() &&
3559       stmt == continuing_stack_.back().last_statement->As<ast::IfStatement>()) {
3560     const ContinuingInfo& ci = continuing_stack_.back();
3561     // Match one of two patterns: the break-if and break-unless patterns.
3562     //
3563     // The break-if pattern:
3564     //  continuing { ...
3565     //    if (cond) { break; }
3566     //  }
3567     //
3568     // The break-unless pattern:
3569     //  continuing { ...
3570     //    if (cond) {} else {break;}
3571     //  }
3572     auto is_just_a_break = [](const ast::BlockStatement* block) {
3573       return block && (block->statements.size() == 1) &&
3574              block->Last()->Is<ast::BreakStatement>();
3575     };
3576     if (is_just_a_break(stmt->body) && stmt->else_statements.empty()) {
3577       // It's a break-if.
3578       TINT_ASSERT(Writer, !backedge_stack_.empty());
3579       const auto cond_id = GenerateExpression(stmt->condition);
3580       backedge_stack_.back() =
3581           Backedge(spv::Op::OpBranchConditional,
3582                    {Operand::Int(cond_id), Operand::Int(ci.break_target_id),
3583                     Operand::Int(ci.loop_header_id)});
3584       return true;
3585     } else if (stmt->body->Empty()) {
3586       const auto& es = stmt->else_statements;
3587       if (es.size() == 1 && !es.back()->condition &&
3588           is_just_a_break(es.back()->body)) {
3589         // It's a break-unless.
3590         TINT_ASSERT(Writer, !backedge_stack_.empty());
3591         const auto cond_id = GenerateExpression(stmt->condition);
3592         backedge_stack_.back() =
3593             Backedge(spv::Op::OpBranchConditional,
3594                      {Operand::Int(cond_id), Operand::Int(ci.loop_header_id),
3595                       Operand::Int(ci.break_target_id)});
3596         return true;
3597       }
3598     }
3599   }
3600 
3601   if (!GenerateConditionalBlock(stmt->condition, stmt->body, 0,
3602                                 stmt->else_statements)) {
3603     return false;
3604   }
3605   return true;
3606 }
3607 
// Emits a SPIR-V selection construct for a WGSL `switch`:
//   OpSelectionMerge %merge
//   OpSwitch %cond %default <literal> <case-label> ...
// Case bodies are emitted in source order; a `fallthrough` branches forward
// to the next case's block, otherwise each case branches to the merge block.
// @returns true on success
bool Builder::GenerateSwitchStatement(const ast::SwitchStatement* stmt) {
  auto merge_block = result_op();
  auto merge_block_id = merge_block.to_i();

  // Make the merge block the `break` target for statements nested in cases.
  merge_stack_.push_back(merge_block_id);

  auto cond_id = GenerateExpression(stmt->condition);
  if (cond_id == 0) {
    return false;
  }
  // Load the selector value if the condition is a reference.
  cond_id = GenerateLoadIfNeeded(TypeOf(stmt->condition), cond_id);

  auto default_block = result_op();
  auto default_block_id = default_block.to_i();

  // OpSwitch operands: selector, default label, then (literal, label) pairs.
  OperandList params = {Operand::Int(cond_id), Operand::Int(default_block_id)};

  // Allocate one label per case clause (the default clause reuses
  // default_block_id) while collecting the (literal, label) pairs.
  std::vector<uint32_t> case_ids;
  for (const auto* item : stmt->body) {
    if (item->IsDefault()) {
      case_ids.push_back(default_block_id);
      continue;
    }

    auto block = result_op();
    auto block_id = block.to_i();

    case_ids.push_back(block_id);
    for (auto* selector : item->selectors) {
      auto* int_literal = selector->As<ast::IntLiteralExpression>();
      if (!int_literal) {
        error_ = "expected integer literal for switch case label";
        return false;
      }

      params.push_back(Operand::Int(int_literal->ValueAsU32()));
      params.push_back(Operand::Int(block_id));
    }
  }

  if (!push_function_inst(spv::Op::OpSelectionMerge,
                          {Operand::Int(merge_block_id),
                           Operand::Int(SpvSelectionControlMaskNone)})) {
    return false;
  }
  if (!push_function_inst(spv::Op::OpSwitch, params)) {
    return false;
  }

  bool generated_default = false;
  auto& body = stmt->body;
  // We output the case statements in order they were entered in the original
  // source. Each fallthrough goes to the next case entry, so is a forward
  // branch, otherwise the branch is to the merge block which comes after
  // the switch statement.
  for (uint32_t i = 0; i < body.size(); i++) {
    auto* item = body[i];

    if (item->IsDefault()) {
      generated_default = true;
    }

    if (!GenerateLabel(case_ids[i])) {
      return false;
    }
    if (!GenerateBlockStatement(item->body)) {
      return false;
    }

    if (LastIsFallthrough(item->body)) {
      if (i == (body.size() - 1)) {
        // This case is caught by Resolver validation
        TINT_UNREACHABLE(Writer, builder_.Diagnostics());
        return false;
      }
      // Fallthrough: branch forward into the next case's block.
      if (!push_function_inst(spv::Op::OpBranch,
                              {Operand::Int(case_ids[i + 1])})) {
        return false;
      }
    } else if (!LastIsTerminator(item->body)) {
      if (!push_function_inst(spv::Op::OpBranch,
                              {Operand::Int(merge_block_id)})) {
        return false;
      }
    }
  }

  // If there was no default clause, emit an empty default block that jumps
  // straight to the merge block.
  if (!generated_default) {
    if (!GenerateLabel(default_block_id)) {
      return false;
    }
    if (!push_function_inst(spv::Op::OpBranch,
                            {Operand::Int(merge_block_id)})) {
      return false;
    }
  }

  merge_stack_.pop_back();

  return GenerateLabel(merge_block_id);
}
3709 
GenerateReturnStatement(const ast::ReturnStatement * stmt)3710 bool Builder::GenerateReturnStatement(const ast::ReturnStatement* stmt) {
3711   if (stmt->value) {
3712     auto val_id = GenerateExpression(stmt->value);
3713     if (val_id == 0) {
3714       return false;
3715     }
3716     val_id = GenerateLoadIfNeeded(TypeOf(stmt->value), val_id);
3717     if (!push_function_inst(spv::Op::OpReturnValue, {Operand::Int(val_id)})) {
3718       return false;
3719     }
3720   } else {
3721     if (!push_function_inst(spv::Op::OpReturn, {})) {
3722       return false;
3723     }
3724   }
3725 
3726   return true;
3727 }
3728 
// Emits a SPIR-V structured loop for a WGSL `loop` statement:
//   %header:   OpLoopMerge %merge %continue
//              OpBranch %body
//   %body:     ...loop body...
//   %continue: ...continuing block...
//              <backedge to %header, possibly conditional>
//   %merge:
// @returns true on success
bool Builder::GenerateLoopStatement(const ast::LoopStatement* stmt) {
  auto loop_header = result_op();
  auto loop_header_id = loop_header.to_i();
  if (!push_function_inst(spv::Op::OpBranch, {Operand::Int(loop_header_id)})) {
    return false;
  }
  if (!GenerateLabel(loop_header_id)) {
    return false;
  }

  auto merge_block = result_op();
  auto merge_block_id = merge_block.to_i();
  auto continue_block = result_op();
  auto continue_block_id = continue_block.to_i();

  auto body_block = result_op();
  auto body_block_id = body_block.to_i();

  if (!push_function_inst(
          spv::Op::OpLoopMerge,
          {Operand::Int(merge_block_id), Operand::Int(continue_block_id),
           Operand::Int(SpvLoopControlMaskNone)})) {
    return false;
  }

  // Register the `continue` and `break` targets for nested statements.
  continue_stack_.push_back(continue_block_id);
  merge_stack_.push_back(merge_block_id);

  // Usually, the backedge is a simple branch.  This will be modified if the
  // backedge block in the continuing construct has an exiting edge.
  backedge_stack_.emplace_back(spv::Op::OpBranch,
                               OperandList{Operand::Int(loop_header_id)});

  if (!push_function_inst(spv::Op::OpBranch, {Operand::Int(body_block_id)})) {
    return false;
  }
  if (!GenerateLabel(body_block_id)) {
    return false;
  }

  // We need variables from the body to be visible in the continuing block, so
  // manage scope outside of GenerateBlockStatement.
  {
    scope_stack_.Push();
    TINT_DEFER(scope_stack_.Pop());

    if (!GenerateBlockStatementWithoutScoping(stmt->body)) {
      return false;
    }

    // We only branch if the last element of the body didn't already branch.
    if (!LastIsTerminator(stmt->body)) {
      if (!push_function_inst(spv::Op::OpBranch,
                              {Operand::Int(continue_block_id)})) {
        return false;
      }
    }

    if (!GenerateLabel(continue_block_id)) {
      return false;
    }
    if (stmt->continuing && !stmt->continuing->Empty()) {
      // Record the continuing context so GenerateIfStatement can fold a
      // trailing break-if / break-unless into the backedge branch.
      continuing_stack_.emplace_back(stmt->continuing->Last(), loop_header_id,
                                     merge_block_id);
      if (!GenerateBlockStatementWithoutScoping(stmt->continuing)) {
        return false;
      }
      continuing_stack_.pop_back();
    }
  }

  // Generate the backedge.
  TINT_ASSERT(Writer, !backedge_stack_.empty());
  const Backedge& backedge = backedge_stack_.back();
  if (!push_function_inst(backedge.opcode, backedge.operands)) {
    return false;
  }
  backedge_stack_.pop_back();

  merge_stack_.pop_back();
  continue_stack_.pop_back();

  return GenerateLabel(merge_block_id);
}
3813 
GenerateStatement(const ast::Statement * stmt)3814 bool Builder::GenerateStatement(const ast::Statement* stmt) {
3815   if (auto* a = stmt->As<ast::AssignmentStatement>()) {
3816     return GenerateAssignStatement(a);
3817   }
3818   if (auto* b = stmt->As<ast::BlockStatement>()) {
3819     return GenerateBlockStatement(b);
3820   }
3821   if (auto* b = stmt->As<ast::BreakStatement>()) {
3822     return GenerateBreakStatement(b);
3823   }
3824   if (auto* c = stmt->As<ast::CallStatement>()) {
3825     return GenerateCallExpression(c->expr) != 0;
3826   }
3827   if (auto* c = stmt->As<ast::ContinueStatement>()) {
3828     return GenerateContinueStatement(c);
3829   }
3830   if (auto* d = stmt->As<ast::DiscardStatement>()) {
3831     return GenerateDiscardStatement(d);
3832   }
3833   if (stmt->Is<ast::FallthroughStatement>()) {
3834     // Do nothing here, the fallthrough gets handled by the switch code.
3835     return true;
3836   }
3837   if (auto* i = stmt->As<ast::IfStatement>()) {
3838     return GenerateIfStatement(i);
3839   }
3840   if (auto* l = stmt->As<ast::LoopStatement>()) {
3841     return GenerateLoopStatement(l);
3842   }
3843   if (auto* r = stmt->As<ast::ReturnStatement>()) {
3844     return GenerateReturnStatement(r);
3845   }
3846   if (auto* s = stmt->As<ast::SwitchStatement>()) {
3847     return GenerateSwitchStatement(s);
3848   }
3849   if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
3850     return GenerateVariableDeclStatement(v);
3851   }
3852 
3853   error_ = "Unknown statement: " + std::string(stmt->TypeInfo().name);
3854   return false;
3855 }
3856 
// Emits a function-scope variable declaration by delegating to
// GenerateFunctionVariable for the declared variable.
// @returns true on success
bool Builder::GenerateVariableDeclStatement(
    const ast::VariableDeclStatement* stmt) {
  return GenerateFunctionVariable(stmt->variable);
}
3861 
// Returns the SPIR-V id for `type`, generating and caching the OpType*
// declaration on first use. Returns 0 (with error_ set) on failure.
uint32_t Builder::GenerateTypeIfNeeded(const sem::Type* type) {
  if (type == nullptr) {
    error_ = "attempting to generate type from null type";
    return 0;
  }

  // Atomics are a type in WGSL, but aren't a distinct type in SPIR-V.
  // Just emit the type inside the atomic.
  if (auto* atomic = type->As<sem::Atomic>()) {
    return GenerateTypeIfNeeded(atomic->Type());
  }

  // Pointers and references with differing accesses should not result in a
  // different SPIR-V types, so we explicitly ignore the access.
  // Pointers and References both map to a SPIR-V pointer type.
  // Transform a Reference to a Pointer to prevent these having duplicated
  // definitions in the generated SPIR-V. Note that nested pointers and
  // references are not legal in WGSL, so only considering the top-level type is
  // fine.
  std::string type_name;
  if (auto* ptr = type->As<sem::Pointer>()) {
    // Normalize the access to kReadWrite so the cache key ignores access.
    type_name =
        sem::Pointer(ptr->StoreType(), ptr->StorageClass(), ast::kReadWrite)
            .type_name();
  } else if (auto* ref = type->As<sem::Reference>()) {
    // References share the pointer's cache key (same SPIR-V type).
    type_name =
        sem::Pointer(ref->StoreType(), ref->StorageClass(), ast::kReadWrite)
            .type_name();
  } else {
    type_name = type->type_name();
  }

  // Look up the cache; on a miss, run the lambda to emit the declaration.
  return utils::GetOrCreate(type_name_to_id_, type_name, [&]() -> uint32_t {
    auto result = result_op();
    auto id = result.to_i();
    if (auto* arr = type->As<sem::Array>()) {
      if (!GenerateArrayType(arr, result)) {
        return 0;
      }
    } else if (type->Is<sem::Bool>()) {
      push_type(spv::Op::OpTypeBool, {result});
    } else if (type->Is<sem::F32>()) {
      push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)});
    } else if (type->Is<sem::I32>()) {
      // OpTypeInt: width 32, signedness 1 (signed).
      push_type(spv::Op::OpTypeInt,
                {result, Operand::Int(32), Operand::Int(1)});
    } else if (auto* mat = type->As<sem::Matrix>()) {
      if (!GenerateMatrixType(mat, result)) {
        return 0;
      }
    } else if (auto* ptr = type->As<sem::Pointer>()) {
      if (!GeneratePointerType(ptr, result)) {
        return 0;
      }
    } else if (auto* ref = type->As<sem::Reference>()) {
      if (!GenerateReferenceType(ref, result)) {
        return 0;
      }
    } else if (auto* str = type->As<sem::Struct>()) {
      if (!GenerateStructType(str, result)) {
        return 0;
      }
    } else if (type->Is<sem::U32>()) {
      // OpTypeInt: width 32, signedness 0 (unsigned).
      push_type(spv::Op::OpTypeInt,
                {result, Operand::Int(32), Operand::Int(0)});
    } else if (auto* vec = type->As<sem::Vector>()) {
      if (!GenerateVectorType(vec, result)) {
        return 0;
      }
    } else if (type->Is<sem::Void>()) {
      push_type(spv::Op::OpTypeVoid, {result});
    } else if (auto* tex = type->As<sem::Texture>()) {
      if (!GenerateTextureType(tex, result)) {
        return 0;
      }

      if (auto* st = tex->As<sem::StorageTexture>()) {
        // Register all three access types of StorageTexture names. In SPIR-V,
        // we must output a single type, while the variable is annotated with
        // the access type. Doing this ensures we de-dupe.
        type_name_to_id_[builder_
                             .create<sem::StorageTexture>(
                                 st->dim(), st->image_format(),
                                 ast::Access::kRead, st->type())
                             ->type_name()] = id;
        type_name_to_id_[builder_
                             .create<sem::StorageTexture>(
                                 st->dim(), st->image_format(),
                                 ast::Access::kWrite, st->type())
                             ->type_name()] = id;
        type_name_to_id_[builder_
                             .create<sem::StorageTexture>(
                                 st->dim(), st->image_format(),
                                 ast::Access::kReadWrite, st->type())
                             ->type_name()] = id;
      }

    } else if (type->Is<sem::Sampler>()) {
      push_type(spv::Op::OpTypeSampler, {result});

      // Register both of the sampler type names. In SPIR-V they're the same
      // sampler type, so we need to match that when we do the dedup check.
      type_name_to_id_["__sampler_sampler"] = id;
      type_name_to_id_["__sampler_comparison"] = id;

    } else {
      error_ = "unable to convert type: " + type->type_name();
      return 0;
    }

    return id;
  });
}
3975 
GenerateTextureType(const sem::Texture * texture,const Operand & result)3976 bool Builder::GenerateTextureType(const sem::Texture* texture,
3977                                   const Operand& result) {
3978   uint32_t array_literal = 0u;
3979   const auto dim = texture->dim();
3980   if (dim == ast::TextureDimension::k2dArray ||
3981       dim == ast::TextureDimension::kCubeArray) {
3982     array_literal = 1u;
3983   }
3984 
3985   uint32_t dim_literal = SpvDim2D;
3986   if (dim == ast::TextureDimension::k1d) {
3987     dim_literal = SpvDim1D;
3988     if (texture->Is<sem::SampledTexture>()) {
3989       push_capability(SpvCapabilitySampled1D);
3990     } else if (texture->Is<sem::StorageTexture>()) {
3991       push_capability(SpvCapabilityImage1D);
3992     }
3993   }
3994   if (dim == ast::TextureDimension::k3d) {
3995     dim_literal = SpvDim3D;
3996   }
3997   if (dim == ast::TextureDimension::kCube ||
3998       dim == ast::TextureDimension::kCubeArray) {
3999     dim_literal = SpvDimCube;
4000   }
4001 
4002   uint32_t ms_literal = 0u;
4003   if (texture->IsAnyOf<sem::MultisampledTexture,
4004                        sem::DepthMultisampledTexture>()) {
4005     ms_literal = 1u;
4006   }
4007 
4008   uint32_t depth_literal = 0u;
4009   if (texture->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
4010     depth_literal = 1u;
4011   }
4012 
4013   uint32_t sampled_literal = 2u;
4014   if (texture->IsAnyOf<sem::MultisampledTexture, sem::SampledTexture,
4015                        sem::DepthTexture, sem::DepthMultisampledTexture>()) {
4016     sampled_literal = 1u;
4017   }
4018 
4019   if (dim == ast::TextureDimension::kCubeArray) {
4020     if (texture->Is<sem::SampledTexture>() ||
4021         texture->Is<sem::DepthTexture>()) {
4022       push_capability(SpvCapabilitySampledCubeArray);
4023     }
4024   }
4025 
4026   uint32_t type_id = 0u;
4027   if (texture->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
4028     type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>());
4029   } else if (auto* s = texture->As<sem::SampledTexture>()) {
4030     type_id = GenerateTypeIfNeeded(s->type());
4031   } else if (auto* ms = texture->As<sem::MultisampledTexture>()) {
4032     type_id = GenerateTypeIfNeeded(ms->type());
4033   } else if (auto* st = texture->As<sem::StorageTexture>()) {
4034     type_id = GenerateTypeIfNeeded(st->type());
4035   }
4036   if (type_id == 0u) {
4037     return false;
4038   }
4039 
4040   uint32_t format_literal = SpvImageFormat_::SpvImageFormatUnknown;
4041   if (auto* t = texture->As<sem::StorageTexture>()) {
4042     format_literal = convert_image_format_to_spv(t->image_format());
4043   }
4044 
4045   push_type(spv::Op::OpTypeImage,
4046             {result, Operand::Int(type_id), Operand::Int(dim_literal),
4047              Operand::Int(depth_literal), Operand::Int(array_literal),
4048              Operand::Int(ms_literal), Operand::Int(sampled_literal),
4049              Operand::Int(format_literal)});
4050 
4051   return true;
4052 }
4053 
GenerateArrayType(const sem::Array * ary,const Operand & result)4054 bool Builder::GenerateArrayType(const sem::Array* ary, const Operand& result) {
4055   auto elem_type = GenerateTypeIfNeeded(ary->ElemType());
4056   if (elem_type == 0) {
4057     return false;
4058   }
4059 
4060   auto result_id = result.to_i();
4061   if (ary->IsRuntimeSized()) {
4062     push_type(spv::Op::OpTypeRuntimeArray, {result, Operand::Int(elem_type)});
4063   } else {
4064     auto len_id = GenerateConstantIfNeeded(ScalarConstant::U32(ary->Count()));
4065     if (len_id == 0) {
4066       return false;
4067     }
4068 
4069     push_type(spv::Op::OpTypeArray,
4070               {result, Operand::Int(elem_type), Operand::Int(len_id)});
4071   }
4072 
4073   push_annot(spv::Op::OpDecorate,
4074              {Operand::Int(result_id), Operand::Int(SpvDecorationArrayStride),
4075               Operand::Int(ary->Stride())});
4076   return true;
4077 }
4078 
GenerateMatrixType(const sem::Matrix * mat,const Operand & result)4079 bool Builder::GenerateMatrixType(const sem::Matrix* mat,
4080                                  const Operand& result) {
4081   auto* col_type = builder_.create<sem::Vector>(mat->type(), mat->rows());
4082   auto col_type_id = GenerateTypeIfNeeded(col_type);
4083   if (has_error()) {
4084     return false;
4085   }
4086 
4087   push_type(spv::Op::OpTypeMatrix,
4088             {result, Operand::Int(col_type_id), Operand::Int(mat->columns())});
4089   return true;
4090 }
4091 
GeneratePointerType(const sem::Pointer * ptr,const Operand & result)4092 bool Builder::GeneratePointerType(const sem::Pointer* ptr,
4093                                   const Operand& result) {
4094   auto subtype_id = GenerateTypeIfNeeded(ptr->StoreType());
4095   if (subtype_id == 0) {
4096     return false;
4097   }
4098 
4099   auto stg_class = ConvertStorageClass(ptr->StorageClass());
4100   if (stg_class == SpvStorageClassMax) {
4101     error_ = "invalid storage class for pointer";
4102     return false;
4103   }
4104 
4105   push_type(spv::Op::OpTypePointer,
4106             {result, Operand::Int(stg_class), Operand::Int(subtype_id)});
4107 
4108   return true;
4109 }
4110 
GenerateReferenceType(const sem::Reference * ref,const Operand & result)4111 bool Builder::GenerateReferenceType(const sem::Reference* ref,
4112                                     const Operand& result) {
4113   auto subtype_id = GenerateTypeIfNeeded(ref->StoreType());
4114   if (subtype_id == 0) {
4115     return false;
4116   }
4117 
4118   auto stg_class = ConvertStorageClass(ref->StorageClass());
4119   if (stg_class == SpvStorageClassMax) {
4120     error_ = "invalid storage class for reference";
4121     return false;
4122   }
4123 
4124   push_type(spv::Op::OpTypePointer,
4125             {result, Operand::Int(stg_class), Operand::Int(subtype_id)});
4126 
4127   return true;
4128 }
4129 
GenerateStructType(const sem::Struct * struct_type,const Operand & result)4130 bool Builder::GenerateStructType(const sem::Struct* struct_type,
4131                                  const Operand& result) {
4132   auto struct_id = result.to_i();
4133 
4134   if (struct_type->Name().IsValid()) {
4135     push_debug(
4136         spv::Op::OpName,
4137         {Operand::Int(struct_id),
4138          Operand::String(builder_.Symbols().NameFor(struct_type->Name()))});
4139   }
4140 
4141   OperandList ops;
4142   ops.push_back(result);
4143 
4144   auto* decl = struct_type->Declaration();
4145   if (decl && decl->IsBlockDecorated()) {
4146     push_annot(spv::Op::OpDecorate,
4147                {Operand::Int(struct_id), Operand::Int(SpvDecorationBlock)});
4148   }
4149 
4150   for (uint32_t i = 0; i < struct_type->Members().size(); ++i) {
4151     auto mem_id = GenerateStructMember(struct_id, i, struct_type->Members()[i]);
4152     if (mem_id == 0) {
4153       return false;
4154     }
4155 
4156     ops.push_back(Operand::Int(mem_id));
4157   }
4158 
4159   push_type(spv::Op::OpTypeStruct, std::move(ops));
4160   return true;
4161 }
4162 
GenerateStructMember(uint32_t struct_id,uint32_t idx,const sem::StructMember * member)4163 uint32_t Builder::GenerateStructMember(uint32_t struct_id,
4164                                        uint32_t idx,
4165                                        const sem::StructMember* member) {
4166   push_debug(spv::Op::OpMemberName,
4167              {Operand::Int(struct_id), Operand::Int(idx),
4168               Operand::String(builder_.Symbols().NameFor(member->Name()))});
4169 
4170   // Note: This will generate layout annotations for *all* structs, whether or
4171   // not they are used in host-shareable variables. This is officially ok in
4172   // SPIR-V 1.0 through 1.3. If / when we migrate to using SPIR-V 1.4 we'll have
4173   // to only generate the layout info for structs used for certain storage
4174   // classes.
4175 
4176   push_annot(
4177       spv::Op::OpMemberDecorate,
4178       {Operand::Int(struct_id), Operand::Int(idx),
4179        Operand::Int(SpvDecorationOffset), Operand::Int(member->Offset())});
4180 
4181   // Infer and emit matrix layout.
4182   auto* matrix_type = GetNestedMatrixType(member->Type());
4183   if (matrix_type) {
4184     push_annot(spv::Op::OpMemberDecorate,
4185                {Operand::Int(struct_id), Operand::Int(idx),
4186                 Operand::Int(SpvDecorationColMajor)});
4187     if (!matrix_type->type()->Is<sem::F32>()) {
4188       error_ = "matrix scalar element type must be f32";
4189       return 0;
4190     }
4191     const auto scalar_elem_size = 4;
4192     const auto effective_row_count = (matrix_type->rows() == 2) ? 2 : 4;
4193     push_annot(spv::Op::OpMemberDecorate,
4194                {Operand::Int(struct_id), Operand::Int(idx),
4195                 Operand::Int(SpvDecorationMatrixStride),
4196                 Operand::Int(effective_row_count * scalar_elem_size)});
4197   }
4198 
4199   return GenerateTypeIfNeeded(member->Type());
4200 }
4201 
GenerateVectorType(const sem::Vector * vec,const Operand & result)4202 bool Builder::GenerateVectorType(const sem::Vector* vec,
4203                                  const Operand& result) {
4204   auto type_id = GenerateTypeIfNeeded(vec->type());
4205   if (has_error()) {
4206     return false;
4207   }
4208 
4209   push_type(spv::Op::OpTypeVector,
4210             {result, Operand::Int(type_id), Operand::Int(vec->Width())});
4211   return true;
4212 }
4213 
ConvertStorageClass(ast::StorageClass klass) const4214 SpvStorageClass Builder::ConvertStorageClass(ast::StorageClass klass) const {
4215   switch (klass) {
4216     case ast::StorageClass::kInvalid:
4217       return SpvStorageClassMax;
4218     case ast::StorageClass::kInput:
4219       return SpvStorageClassInput;
4220     case ast::StorageClass::kOutput:
4221       return SpvStorageClassOutput;
4222     case ast::StorageClass::kUniform:
4223       return SpvStorageClassUniform;
4224     case ast::StorageClass::kWorkgroup:
4225       return SpvStorageClassWorkgroup;
4226     case ast::StorageClass::kUniformConstant:
4227       return SpvStorageClassUniformConstant;
4228     case ast::StorageClass::kStorage:
4229       return SpvStorageClassStorageBuffer;
4230     case ast::StorageClass::kImage:
4231       return SpvStorageClassImage;
4232     case ast::StorageClass::kPrivate:
4233       return SpvStorageClassPrivate;
4234     case ast::StorageClass::kFunction:
4235       return SpvStorageClassFunction;
4236     case ast::StorageClass::kNone:
4237       break;
4238   }
4239   return SpvStorageClassMax;
4240 }
4241 
ConvertBuiltin(ast::Builtin builtin,ast::StorageClass storage)4242 SpvBuiltIn Builder::ConvertBuiltin(ast::Builtin builtin,
4243                                    ast::StorageClass storage) {
4244   switch (builtin) {
4245     case ast::Builtin::kPosition:
4246       if (storage == ast::StorageClass::kInput) {
4247         return SpvBuiltInFragCoord;
4248       } else if (storage == ast::StorageClass::kOutput) {
4249         return SpvBuiltInPosition;
4250       } else {
4251         TINT_ICE(Writer, builder_.Diagnostics())
4252             << "invalid storage class for builtin";
4253         break;
4254       }
4255     case ast::Builtin::kVertexIndex:
4256       return SpvBuiltInVertexIndex;
4257     case ast::Builtin::kInstanceIndex:
4258       return SpvBuiltInInstanceIndex;
4259     case ast::Builtin::kFrontFacing:
4260       return SpvBuiltInFrontFacing;
4261     case ast::Builtin::kFragDepth:
4262       return SpvBuiltInFragDepth;
4263     case ast::Builtin::kLocalInvocationId:
4264       return SpvBuiltInLocalInvocationId;
4265     case ast::Builtin::kLocalInvocationIndex:
4266       return SpvBuiltInLocalInvocationIndex;
4267     case ast::Builtin::kGlobalInvocationId:
4268       return SpvBuiltInGlobalInvocationId;
4269     case ast::Builtin::kPointSize:
4270       return SpvBuiltInPointSize;
4271     case ast::Builtin::kWorkgroupId:
4272       return SpvBuiltInWorkgroupId;
4273     case ast::Builtin::kNumWorkgroups:
4274       return SpvBuiltInNumWorkgroups;
4275     case ast::Builtin::kSampleIndex:
4276       push_capability(SpvCapabilitySampleRateShading);
4277       return SpvBuiltInSampleId;
4278     case ast::Builtin::kSampleMask:
4279       return SpvBuiltInSampleMask;
4280     case ast::Builtin::kNone:
4281       break;
4282   }
4283   return SpvBuiltInMax;
4284 }
4285 
AddInterpolationDecorations(uint32_t id,ast::InterpolationType type,ast::InterpolationSampling sampling)4286 void Builder::AddInterpolationDecorations(uint32_t id,
4287                                           ast::InterpolationType type,
4288                                           ast::InterpolationSampling sampling) {
4289   switch (type) {
4290     case ast::InterpolationType::kLinear:
4291       push_annot(spv::Op::OpDecorate,
4292                  {Operand::Int(id), Operand::Int(SpvDecorationNoPerspective)});
4293       break;
4294     case ast::InterpolationType::kFlat:
4295       push_annot(spv::Op::OpDecorate,
4296                  {Operand::Int(id), Operand::Int(SpvDecorationFlat)});
4297       break;
4298     case ast::InterpolationType::kPerspective:
4299       break;
4300   }
4301   switch (sampling) {
4302     case ast::InterpolationSampling::kCentroid:
4303       push_annot(spv::Op::OpDecorate,
4304                  {Operand::Int(id), Operand::Int(SpvDecorationCentroid)});
4305       break;
4306     case ast::InterpolationSampling::kSample:
4307       push_capability(SpvCapabilitySampleRateShading);
4308       push_annot(spv::Op::OpDecorate,
4309                  {Operand::Int(id), Operand::Int(SpvDecorationSample)});
4310       break;
4311     case ast::InterpolationSampling::kCenter:
4312     case ast::InterpolationSampling::kNone:
4313       break;
4314   }
4315 }
4316 
convert_image_format_to_spv(const ast::ImageFormat format)4317 SpvImageFormat Builder::convert_image_format_to_spv(
4318     const ast::ImageFormat format) {
4319   switch (format) {
4320     case ast::ImageFormat::kR8Unorm:
4321       push_capability(SpvCapabilityStorageImageExtendedFormats);
4322       return SpvImageFormatR8;
4323     case ast::ImageFormat::kR8Snorm:
4324       push_capability(SpvCapabilityStorageImageExtendedFormats);
4325       return SpvImageFormatR8Snorm;
4326     case ast::ImageFormat::kR8Uint:
4327       push_capability(SpvCapabilityStorageImageExtendedFormats);
4328       return SpvImageFormatR8ui;
4329     case ast::ImageFormat::kR8Sint:
4330       push_capability(SpvCapabilityStorageImageExtendedFormats);
4331       return SpvImageFormatR8i;
4332     case ast::ImageFormat::kR16Uint:
4333       push_capability(SpvCapabilityStorageImageExtendedFormats);
4334       return SpvImageFormatR16ui;
4335     case ast::ImageFormat::kR16Sint:
4336       push_capability(SpvCapabilityStorageImageExtendedFormats);
4337       return SpvImageFormatR16i;
4338     case ast::ImageFormat::kR16Float:
4339       push_capability(SpvCapabilityStorageImageExtendedFormats);
4340       return SpvImageFormatR16f;
4341     case ast::ImageFormat::kRg8Unorm:
4342       push_capability(SpvCapabilityStorageImageExtendedFormats);
4343       return SpvImageFormatRg8;
4344     case ast::ImageFormat::kRg8Snorm:
4345       push_capability(SpvCapabilityStorageImageExtendedFormats);
4346       return SpvImageFormatRg8Snorm;
4347     case ast::ImageFormat::kRg8Uint:
4348       push_capability(SpvCapabilityStorageImageExtendedFormats);
4349       return SpvImageFormatRg8ui;
4350     case ast::ImageFormat::kRg8Sint:
4351       push_capability(SpvCapabilityStorageImageExtendedFormats);
4352       return SpvImageFormatRg8i;
4353     case ast::ImageFormat::kR32Uint:
4354       return SpvImageFormatR32ui;
4355     case ast::ImageFormat::kR32Sint:
4356       return SpvImageFormatR32i;
4357     case ast::ImageFormat::kR32Float:
4358       return SpvImageFormatR32f;
4359     case ast::ImageFormat::kRg16Uint:
4360       push_capability(SpvCapabilityStorageImageExtendedFormats);
4361       return SpvImageFormatRg16ui;
4362     case ast::ImageFormat::kRg16Sint:
4363       push_capability(SpvCapabilityStorageImageExtendedFormats);
4364       return SpvImageFormatRg16i;
4365     case ast::ImageFormat::kRg16Float:
4366       push_capability(SpvCapabilityStorageImageExtendedFormats);
4367       return SpvImageFormatRg16f;
4368     case ast::ImageFormat::kRgba8Unorm:
4369       return SpvImageFormatRgba8;
4370     case ast::ImageFormat::kRgba8UnormSrgb:
4371       return SpvImageFormatUnknown;
4372     case ast::ImageFormat::kRgba8Snorm:
4373       return SpvImageFormatRgba8Snorm;
4374     case ast::ImageFormat::kRgba8Uint:
4375       return SpvImageFormatRgba8ui;
4376     case ast::ImageFormat::kRgba8Sint:
4377       return SpvImageFormatRgba8i;
4378     case ast::ImageFormat::kBgra8Unorm:
4379       return SpvImageFormatUnknown;
4380     case ast::ImageFormat::kBgra8UnormSrgb:
4381       return SpvImageFormatUnknown;
4382     case ast::ImageFormat::kRgb10A2Unorm:
4383       push_capability(SpvCapabilityStorageImageExtendedFormats);
4384       return SpvImageFormatRgb10A2;
4385     case ast::ImageFormat::kRg11B10Float:
4386       push_capability(SpvCapabilityStorageImageExtendedFormats);
4387       return SpvImageFormatR11fG11fB10f;
4388     case ast::ImageFormat::kRg32Uint:
4389       push_capability(SpvCapabilityStorageImageExtendedFormats);
4390       return SpvImageFormatRg32ui;
4391     case ast::ImageFormat::kRg32Sint:
4392       push_capability(SpvCapabilityStorageImageExtendedFormats);
4393       return SpvImageFormatRg32i;
4394     case ast::ImageFormat::kRg32Float:
4395       push_capability(SpvCapabilityStorageImageExtendedFormats);
4396       return SpvImageFormatRg32f;
4397     case ast::ImageFormat::kRgba16Uint:
4398       return SpvImageFormatRgba16ui;
4399     case ast::ImageFormat::kRgba16Sint:
4400       return SpvImageFormatRgba16i;
4401     case ast::ImageFormat::kRgba16Float:
4402       return SpvImageFormatRgba16f;
4403     case ast::ImageFormat::kRgba32Uint:
4404       return SpvImageFormatRgba32ui;
4405     case ast::ImageFormat::kRgba32Sint:
4406       return SpvImageFormatRgba32i;
4407     case ast::ImageFormat::kRgba32Float:
4408       return SpvImageFormatRgba32f;
4409     case ast::ImageFormat::kNone:
4410       return SpvImageFormatUnknown;
4411   }
4412   return SpvImageFormatUnknown;
4413 }
4414 
push_function_inst(spv::Op op,const OperandList & operands)4415 bool Builder::push_function_inst(spv::Op op, const OperandList& operands) {
4416   if (functions_.empty()) {
4417     std::ostringstream ss;
4418     ss << "Internal error: trying to add SPIR-V instruction " << int(op)
4419        << " outside a function";
4420     error_ = ss.str();
4421     return false;
4422   }
4423   functions_.back().push_inst(op, operands);
4424   return true;
4425 }
4426 
// Records the context of a loop's continuing block: the last statement in
// the continuing construct, plus the ids of the loop header and the
// merge/break target. All three are required to be non-null/non-zero.
Builder::ContinuingInfo::ContinuingInfo(
    const ast::Statement* the_last_statement,
    uint32_t loop_id,
    uint32_t break_id)
    : last_statement(the_last_statement),
      loop_header_id(loop_id),
      break_target_id(break_id) {
  TINT_ASSERT(Writer, last_statement != nullptr);
  TINT_ASSERT(Writer, loop_header_id != 0u);
  TINT_ASSERT(Writer, break_target_id != 0u);
}
4438 
// A deferred branch instruction (opcode + operands) to be emitted as the
// back-edge of a loop once the loop body has been generated.
Builder::Backedge::Backedge(spv::Op the_opcode, OperandList the_operands)
    : opcode(the_opcode), operands(the_operands) {}
4441 
// Backedge holds only copyable members, so the compiler-generated copy
// operations and destructor are correct (explicitly defaulted out-of-line).
Builder::Backedge::Backedge(const Builder::Backedge& other) = default;
Builder::Backedge& Builder::Backedge::operator=(
    const Builder::Backedge& other) = default;
Builder::Backedge::~Backedge() = default;
4446 
4447 }  // namespace spirv
4448 }  // namespace writer
4449 }  // namespace tint
4450