• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/reader/spirv/parser_impl.h"
16 
17 #include <algorithm>
18 #include <limits>
19 #include <locale>
20 #include <utility>
21 
22 #include "source/opt/build_module.h"
23 #include "src/ast/bitcast_expression.h"
24 #include "src/ast/disable_validation_decoration.h"
25 #include "src/ast/interpolate_decoration.h"
26 #include "src/ast/override_decoration.h"
27 #include "src/ast/struct_block_decoration.h"
28 #include "src/ast/type_name.h"
29 #include "src/ast/unary_op_expression.h"
30 #include "src/reader/spirv/function.h"
31 #include "src/sem/depth_texture_type.h"
32 #include "src/sem/multisampled_texture_type.h"
33 #include "src/sem/sampled_texture_type.h"
34 #include "src/utils/unique_vector.h"
35 
36 namespace tint {
37 namespace reader {
38 namespace spirv {
39 
40 namespace {
41 
42 // Input SPIR-V needs only to conform to Vulkan 1.1 requirements.
43 // The combination of the SPIR-V reader and the semantics of WGSL
44 // tighten up the code so that the output of the SPIR-V *writer*
45 // will satisfy SPV_ENV_WEBGPU_0 validation.
46 const spv_target_env kInputEnv = SPV_ENV_VULKAN_1_1;
47 
48 // A FunctionTraverser is used to compute an ordering of functions in the
49 // module such that callees precede callers.
50 class FunctionTraverser {
51  public:
FunctionTraverser(const spvtools::opt::Module & module)52   explicit FunctionTraverser(const spvtools::opt::Module& module)
53       : module_(module) {}
54 
55   // @returns the functions in the modules such that callees precede callers.
TopologicallyOrderedFunctions()56   std::vector<const spvtools::opt::Function*> TopologicallyOrderedFunctions() {
57     visited_.clear();
58     ordered_.clear();
59     id_to_func_.clear();
60     for (const auto& f : module_) {
61       id_to_func_[f.result_id()] = &f;
62     }
63     for (const auto& f : module_) {
64       Visit(f);
65     }
66     return ordered_;
67   }
68 
69  private:
Visit(const spvtools::opt::Function & f)70   void Visit(const spvtools::opt::Function& f) {
71     if (visited_.count(&f)) {
72       return;
73     }
74     visited_.insert(&f);
75     for (const auto& bb : f) {
76       for (const auto& inst : bb) {
77         if (inst.opcode() != SpvOpFunctionCall) {
78           continue;
79         }
80         const auto* callee = id_to_func_[inst.GetSingleWordInOperand(0)];
81         if (callee) {
82           Visit(*callee);
83         }
84       }
85     }
86     ordered_.push_back(&f);
87   }
88 
89   const spvtools::opt::Module& module_;
90   std::unordered_set<const spvtools::opt::Function*> visited_;
91   std::unordered_map<uint32_t, const spvtools::opt::Function*> id_to_func_;
92   std::vector<const spvtools::opt::Function*> ordered_;
93 };
94 
95 // Returns true if the opcode operates as if its operands are signed integral.
AssumesSignedOperands(SpvOp opcode)96 bool AssumesSignedOperands(SpvOp opcode) {
97   switch (opcode) {
98     case SpvOpSNegate:
99     case SpvOpSDiv:
100     case SpvOpSRem:
101     case SpvOpSMod:
102     case SpvOpSLessThan:
103     case SpvOpSLessThanEqual:
104     case SpvOpSGreaterThan:
105     case SpvOpSGreaterThanEqual:
106     case SpvOpConvertSToF:
107       return true;
108     default:
109       break;
110   }
111   return false;
112 }
113 
114 // Returns true if the GLSL extended instruction expects operands to be signed.
115 // @param extended_opcode GLSL.std.450 opcode
116 // @returns true if all operands must be signed integral type
AssumesSignedOperands(GLSLstd450 extended_opcode)117 bool AssumesSignedOperands(GLSLstd450 extended_opcode) {
118   switch (extended_opcode) {
119     case GLSLstd450SAbs:
120     case GLSLstd450SSign:
121     case GLSLstd450SMin:
122     case GLSLstd450SMax:
123     case GLSLstd450SClamp:
124       return true;
125     default:
126       break;
127   }
128   return false;
129 }
130 
131 // Returns true if the opcode operates as if its operands are unsigned integral.
AssumesUnsignedOperands(SpvOp opcode)132 bool AssumesUnsignedOperands(SpvOp opcode) {
133   switch (opcode) {
134     case SpvOpUDiv:
135     case SpvOpUMod:
136     case SpvOpULessThan:
137     case SpvOpULessThanEqual:
138     case SpvOpUGreaterThan:
139     case SpvOpUGreaterThanEqual:
140     case SpvOpConvertUToF:
141       return true;
142     default:
143       break;
144   }
145   return false;
146 }
147 
148 // Returns true if the GLSL extended instruction expects operands to be
149 // unsigned.
150 // @param extended_opcode GLSL.std.450 opcode
151 // @returns true if all operands must be unsigned integral type
AssumesUnsignedOperands(GLSLstd450 extended_opcode)152 bool AssumesUnsignedOperands(GLSLstd450 extended_opcode) {
153   switch (extended_opcode) {
154     case GLSLstd450UMin:
155     case GLSLstd450UMax:
156     case GLSLstd450UClamp:
157       return true;
158     default:
159       break;
160   }
161   return false;
162 }
163 
164 // Returns true if the corresponding WGSL operation requires
165 // the signedness of the second operand to match the signedness of the
166 // first operand, and it's not one of the OpU* or OpS* instructions.
167 // (Those are handled via MakeOperand.)
AssumesSecondOperandSignednessMatchesFirstOperand(SpvOp opcode)168 bool AssumesSecondOperandSignednessMatchesFirstOperand(SpvOp opcode) {
169   switch (opcode) {
170     // All the OpI* integer binary operations.
171     case SpvOpIAdd:
172     case SpvOpISub:
173     case SpvOpIMul:
174     case SpvOpIEqual:
175     case SpvOpINotEqual:
176     // All the bitwise integer binary operations.
177     case SpvOpBitwiseAnd:
178     case SpvOpBitwiseOr:
179     case SpvOpBitwiseXor:
180       return true;
181     default:
182       break;
183   }
184   return false;
185 }
186 
187 // Returns true if the corresponding WGSL operation requires
188 // the signedness of the result to match the signedness of the first operand.
AssumesResultSignednessMatchesFirstOperand(SpvOp opcode)189 bool AssumesResultSignednessMatchesFirstOperand(SpvOp opcode) {
190   switch (opcode) {
191     case SpvOpNot:
192     case SpvOpSNegate:
193     case SpvOpBitCount:
194     case SpvOpBitReverse:
195     case SpvOpSDiv:
196     case SpvOpSMod:
197     case SpvOpSRem:
198     case SpvOpIAdd:
199     case SpvOpISub:
200     case SpvOpIMul:
201     case SpvOpBitwiseAnd:
202     case SpvOpBitwiseOr:
203     case SpvOpBitwiseXor:
204     case SpvOpShiftLeftLogical:
205     case SpvOpShiftRightLogical:
206     case SpvOpShiftRightArithmetic:
207       return true;
208     default:
209       break;
210   }
211   return false;
212 }
213 
214 // Returns true if the extended instruction requires the signedness of the
215 // result to match the signedness of the first operand to the operation.
216 // @param extended_opcode GLSL.std.450 opcode
217 // @returns true if the result type must match the first operand type.
AssumesResultSignednessMatchesFirstOperand(GLSLstd450 extended_opcode)218 bool AssumesResultSignednessMatchesFirstOperand(GLSLstd450 extended_opcode) {
219   switch (extended_opcode) {
220     case GLSLstd450SAbs:
221     case GLSLstd450SSign:
222     case GLSLstd450SMin:
223     case GLSLstd450SMax:
224     case GLSLstd450SClamp:
225     case GLSLstd450UMin:
226     case GLSLstd450UMax:
227     case GLSLstd450UClamp:
228       // TODO(dneto): FindSMsb?
229       // TODO(dneto): FindUMsb?
230       return true;
231     default:
232       break;
233   }
234   return false;
235 }
236 
237 // @param a SPIR-V decoration
238 // @return true when the given decoration is a pipeline decoration other than a
239 // bulitin variable.
IsPipelineDecoration(const Decoration & deco)240 bool IsPipelineDecoration(const Decoration& deco) {
241   if (deco.size() < 1) {
242     return false;
243   }
244   switch (deco[0]) {
245     case SpvDecorationLocation:
246     case SpvDecorationFlat:
247     case SpvDecorationNoPerspective:
248     case SpvDecorationCentroid:
249     case SpvDecorationSample:
250       return true;
251     default:
252       break;
253   }
254   return false;
255 }
256 
257 }  // namespace
258 
259 TypedExpression::TypedExpression() = default;
260 
261 TypedExpression::TypedExpression(const TypedExpression&) = default;
262 
263 TypedExpression& TypedExpression::operator=(const TypedExpression&) = default;
264 
TypedExpression(const Type * type_in,const ast::Expression * expr_in)265 TypedExpression::TypedExpression(const Type* type_in,
266                                  const ast::Expression* expr_in)
267     : type(type_in), expr(expr_in) {}
268 
ParserImpl(const std::vector<uint32_t> & spv_binary)269 ParserImpl::ParserImpl(const std::vector<uint32_t>& spv_binary)
270     : Reader(),
271       spv_binary_(spv_binary),
272       fail_stream_(&success_, &errors_),
273       namer_(fail_stream_),
274       enum_converter_(fail_stream_),
275       tools_context_(kInputEnv) {
276   // Create a message consumer to propagate error messages from SPIRV-Tools
277   // out as our own failures.
278   message_consumer_ = [this](spv_message_level_t level, const char* /*source*/,
279                              const spv_position_t& position,
280                              const char* message) {
281     switch (level) {
282       // Ignore info and warning message.
283       case SPV_MSG_WARNING:
284       case SPV_MSG_INFO:
285         break;
286       // Otherwise, propagate the error.
287       default:
288         // For binary validation errors, we only have the instruction
289         // number.  It's not text, so there is no column number.
290         this->Fail() << "line:" << position.index << ": " << message;
291     }
292   };
293 }
294 
295 ParserImpl::~ParserImpl() = default;
296 
Parse()297 bool ParserImpl::Parse() {
298   // Set up use of SPIRV-Tools utilities.
299   spvtools::SpirvTools spv_tools(kInputEnv);
300 
301   // Error messages from SPIRV-Tools are forwarded as failures, including
302   // setting |success_| to false.
303   spv_tools.SetMessageConsumer(message_consumer_);
304 
305   if (!success_) {
306     return false;
307   }
308 
309   // Only consider modules valid for Vulkan 1.0.  On failure, the message
310   // consumer will set the error status.
311   if (!spv_tools.Validate(spv_binary_)) {
312     success_ = false;
313     return false;
314   }
315   if (!BuildInternalModule()) {
316     return false;
317   }
318   if (!ParseInternalModule()) {
319     return false;
320   }
321 
322   return success_;
323 }
324 
program()325 Program ParserImpl::program() {
326   // TODO(dneto): Should we clear out spv_binary_ here, to reduce
327   // memory usage?
328   return tint::Program(std::move(builder_));
329 }
330 
ConvertType(uint32_t type_id,PtrAs ptr_as)331 const Type* ParserImpl::ConvertType(uint32_t type_id, PtrAs ptr_as) {
332   if (!success_) {
333     return nullptr;
334   }
335 
336   if (type_mgr_ == nullptr) {
337     Fail() << "ConvertType called when the internal module has not been built";
338     return nullptr;
339   }
340 
341   auto* spirv_type = type_mgr_->GetType(type_id);
342   if (spirv_type == nullptr) {
343     Fail() << "ID is not a SPIR-V type: " << type_id;
344     return nullptr;
345   }
346 
347   switch (spirv_type->kind()) {
348     case spvtools::opt::analysis::Type::kVoid:
349       return ty_.Void();
350     case spvtools::opt::analysis::Type::kBool:
351       return ty_.Bool();
352     case spvtools::opt::analysis::Type::kInteger:
353       return ConvertType(spirv_type->AsInteger());
354     case spvtools::opt::analysis::Type::kFloat:
355       return ConvertType(spirv_type->AsFloat());
356     case spvtools::opt::analysis::Type::kVector:
357       return ConvertType(spirv_type->AsVector());
358     case spvtools::opt::analysis::Type::kMatrix:
359       return ConvertType(spirv_type->AsMatrix());
360     case spvtools::opt::analysis::Type::kRuntimeArray:
361       return ConvertType(type_id, spirv_type->AsRuntimeArray());
362     case spvtools::opt::analysis::Type::kArray:
363       return ConvertType(type_id, spirv_type->AsArray());
364     case spvtools::opt::analysis::Type::kStruct:
365       return ConvertType(type_id, spirv_type->AsStruct());
366     case spvtools::opt::analysis::Type::kPointer:
367       return ConvertType(type_id, ptr_as, spirv_type->AsPointer());
368     case spvtools::opt::analysis::Type::kFunction:
369       // Tint doesn't have a Function type.
370       // We need to convert the result type and parameter types.
371       // But the SPIR-V defines those before defining the function
372       // type.  No further work is required here.
373       return nullptr;
374     case spvtools::opt::analysis::Type::kSampler:
375     case spvtools::opt::analysis::Type::kSampledImage:
376     case spvtools::opt::analysis::Type::kImage:
377       // Fake it for sampler and texture types.  These are handled in an
378       // entirely different way.
379       return ty_.Void();
380     default:
381       break;
382   }
383 
384   Fail() << "unknown SPIR-V type with ID " << type_id << ": "
385          << def_use_mgr_->GetDef(type_id)->PrettyPrint();
386   return nullptr;
387 }
388 
GetDecorationsFor(uint32_t id) const389 DecorationList ParserImpl::GetDecorationsFor(uint32_t id) const {
390   DecorationList result;
391   const auto& decorations = deco_mgr_->GetDecorationsFor(id, true);
392   for (const auto* inst : decorations) {
393     if (inst->opcode() != SpvOpDecorate) {
394       continue;
395     }
396     // Example: OpDecorate %struct_id Block
397     // Example: OpDecorate %array_ty ArrayStride 16
398     std::vector<uint32_t> inst_as_words;
399     inst->ToBinaryWithoutAttachedDebugInsts(&inst_as_words);
400     Decoration d(inst_as_words.begin() + 2, inst_as_words.end());
401     result.push_back(d);
402   }
403   return result;
404 }
405 
GetDecorationsForMember(uint32_t id,uint32_t member_index) const406 DecorationList ParserImpl::GetDecorationsForMember(
407     uint32_t id,
408     uint32_t member_index) const {
409   DecorationList result;
410   const auto& decorations = deco_mgr_->GetDecorationsFor(id, true);
411   for (const auto* inst : decorations) {
412     if ((inst->opcode() != SpvOpMemberDecorate) ||
413         (inst->GetSingleWordInOperand(1) != member_index)) {
414       continue;
415     }
416     // Example: OpMemberDecorate %struct_id 2 Offset 24
417     std::vector<uint32_t> inst_as_words;
418     inst->ToBinaryWithoutAttachedDebugInsts(&inst_as_words);
419     Decoration d(inst_as_words.begin() + 3, inst_as_words.end());
420     result.push_back(d);
421   }
422   return result;
423 }
424 
ShowType(uint32_t type_id)425 std::string ParserImpl::ShowType(uint32_t type_id) {
426   if (def_use_mgr_) {
427     const auto* type_inst = def_use_mgr_->GetDef(type_id);
428     if (type_inst) {
429       return type_inst->PrettyPrint();
430     }
431   }
432   return "SPIR-V type " + std::to_string(type_id);
433 }
434 
ConvertMemberDecoration(uint32_t struct_type_id,uint32_t member_index,const Type * member_ty,const Decoration & decoration)435 ast::DecorationList ParserImpl::ConvertMemberDecoration(
436     uint32_t struct_type_id,
437     uint32_t member_index,
438     const Type* member_ty,
439     const Decoration& decoration) {
440   if (decoration.empty()) {
441     Fail() << "malformed SPIR-V decoration: it's empty";
442     return {};
443   }
444   switch (decoration[0]) {
445     case SpvDecorationOffset:
446       if (decoration.size() != 2) {
447         Fail()
448             << "malformed Offset decoration: expected 1 literal operand, has "
449             << decoration.size() - 1 << ": member " << member_index << " of "
450             << ShowType(struct_type_id);
451         return {};
452       }
453       return {
454           create<ast::StructMemberOffsetDecoration>(Source{}, decoration[1]),
455       };
456     case SpvDecorationNonReadable:
457       // WGSL doesn't have a member decoration for this.  Silently drop it.
458       return {};
459     case SpvDecorationNonWritable:
460       // WGSL doesn't have a member decoration for this.
461       return {};
462     case SpvDecorationColMajor:
463       // WGSL only supports column major matrices.
464       return {};
465     case SpvDecorationRelaxedPrecision:
466       // WGSL doesn't support relaxed precision.
467       return {};
468     case SpvDecorationRowMajor:
469       Fail() << "WGSL does not support row-major matrices: can't "
470                 "translate member "
471              << member_index << " of " << ShowType(struct_type_id);
472       return {};
473     case SpvDecorationMatrixStride: {
474       if (decoration.size() != 2) {
475         Fail() << "malformed MatrixStride decoration: expected 1 literal "
476                   "operand, has "
477                << decoration.size() - 1 << ": member " << member_index << " of "
478                << ShowType(struct_type_id);
479         return {};
480       }
481       uint32_t stride = decoration[1];
482       auto* ty = member_ty->UnwrapAlias();
483       while (auto* arr = ty->As<Array>()) {
484         ty = arr->type->UnwrapAlias();
485       }
486       auto* mat = ty->As<Matrix>();
487       if (!mat) {
488         Fail() << "MatrixStride cannot be applied to type " << ty->String();
489         return {};
490       }
491       uint32_t natural_stride = (mat->rows == 2) ? 8 : 16;
492       if (stride == natural_stride) {
493         return {};  // Decoration matches the natural stride for the matrix
494       }
495       if (!member_ty->Is<Matrix>()) {
496         Fail() << "custom matrix strides not currently supported on array of "
497                   "matrices";
498         return {};
499       }
500       return {
501           create<ast::StrideDecoration>(Source{}, decoration[1]),
502           builder_.ASTNodes().Create<ast::DisableValidationDecoration>(
503               builder_.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
504       };
505     }
506     default:
507       // TODO(dneto): Support the remaining member decorations.
508       break;
509   }
510   Fail() << "unhandled member decoration: " << decoration[0] << " on member "
511          << member_index << " of " << ShowType(struct_type_id);
512   return {};
513 }
514 
BuildInternalModule()515 bool ParserImpl::BuildInternalModule() {
516   if (!success_) {
517     return false;
518   }
519 
520   const spv_context& context = tools_context_.CContext();
521   ir_context_ = spvtools::BuildModule(context->target_env, context->consumer,
522                                       spv_binary_.data(), spv_binary_.size());
523   if (!ir_context_) {
524     return Fail() << "internal error: couldn't build the internal "
525                      "representation of the module";
526   }
527   module_ = ir_context_->module();
528   def_use_mgr_ = ir_context_->get_def_use_mgr();
529   constant_mgr_ = ir_context_->get_constant_mgr();
530   type_mgr_ = ir_context_->get_type_mgr();
531   deco_mgr_ = ir_context_->get_decoration_mgr();
532 
533   topologically_ordered_functions_ =
534       FunctionTraverser(*module_).TopologicallyOrderedFunctions();
535 
536   return success_;
537 }
538 
ResetInternalModule()539 void ParserImpl::ResetInternalModule() {
540   ir_context_.reset(nullptr);
541   module_ = nullptr;
542   def_use_mgr_ = nullptr;
543   constant_mgr_ = nullptr;
544   type_mgr_ = nullptr;
545   deco_mgr_ = nullptr;
546 
547   glsl_std_450_imports_.clear();
548 }
549 
ParseInternalModule()550 bool ParserImpl::ParseInternalModule() {
551   if (!success_) {
552     return false;
553   }
554   RegisterLineNumbers();
555   if (!ParseInternalModuleExceptFunctions()) {
556     return false;
557   }
558   if (!EmitFunctions()) {
559     return false;
560   }
561   return success_;
562 }
563 
RegisterLineNumbers()564 void ParserImpl::RegisterLineNumbers() {
565   Source::Location instruction_number{};
566 
567   // Has there been an OpLine since the last OpNoLine or start of the module?
568   bool in_op_line_scope = false;
569   // The source location provided by the most recent OpLine instruction.
570   Source::Location op_line_source{};
571   const bool run_on_debug_insts = true;
572   module_->ForEachInst(
573       [this, &in_op_line_scope, &op_line_source,
574        &instruction_number](const spvtools::opt::Instruction* inst) {
575         ++instruction_number.line;
576         switch (inst->opcode()) {
577           case SpvOpLine:
578             in_op_line_scope = true;
579             // TODO(dneto): This ignores the File ID (operand 0), since the Tint
580             // Source concept doesn't represent that.
581             op_line_source.line = inst->GetSingleWordInOperand(1);
582             op_line_source.column = inst->GetSingleWordInOperand(2);
583             break;
584           case SpvOpNoLine:
585             in_op_line_scope = false;
586             break;
587           default:
588             break;
589         }
590         this->inst_source_[inst] =
591             in_op_line_scope ? op_line_source : instruction_number;
592       },
593       run_on_debug_insts);
594 }
595 
GetSourceForResultIdForTest(uint32_t id) const596 Source ParserImpl::GetSourceForResultIdForTest(uint32_t id) const {
597   return GetSourceForInst(def_use_mgr_->GetDef(id));
598 }
599 
GetSourceForInst(const spvtools::opt::Instruction * inst) const600 Source ParserImpl::GetSourceForInst(
601     const spvtools::opt::Instruction* inst) const {
602   auto where = inst_source_.find(inst);
603   if (where == inst_source_.end()) {
604     return {};
605   }
606   return Source{where->second};
607 }
608 
ParseInternalModuleExceptFunctions()609 bool ParserImpl::ParseInternalModuleExceptFunctions() {
610   if (!success_) {
611     return false;
612   }
613   if (!RegisterExtendedInstructionImports()) {
614     return false;
615   }
616   if (!RegisterUserAndStructMemberNames()) {
617     return false;
618   }
619   if (!RegisterWorkgroupSizeBuiltin()) {
620     return false;
621   }
622   if (!RegisterEntryPoints()) {
623     return false;
624   }
625   if (!RegisterHandleUsage()) {
626     return false;
627   }
628   if (!RegisterTypes()) {
629     return false;
630   }
631   if (!RejectInvalidPointerRoots()) {
632     return false;
633   }
634   if (!EmitScalarSpecConstants()) {
635     return false;
636   }
637   if (!EmitModuleScopeVariables()) {
638     return false;
639   }
640   return success_;
641 }
642 
RegisterExtendedInstructionImports()643 bool ParserImpl::RegisterExtendedInstructionImports() {
644   for (const spvtools::opt::Instruction& import : module_->ext_inst_imports()) {
645     std::string name(
646         reinterpret_cast<const char*>(import.GetInOperand(0).words.data()));
647     // TODO(dneto): Handle other extended instruction sets when needed.
648     if (name == "GLSL.std.450") {
649       glsl_std_450_imports_.insert(import.result_id());
650     } else if (name.find("NonSemantic.") == 0) {
651       ignored_imports_.insert(import.result_id());
652     } else {
653       return Fail() << "Unrecognized extended instruction set: " << name;
654     }
655   }
656   return true;
657 }
658 
IsGlslExtendedInstruction(const spvtools::opt::Instruction & inst) const659 bool ParserImpl::IsGlslExtendedInstruction(
660     const spvtools::opt::Instruction& inst) const {
661   return (inst.opcode() == SpvOpExtInst) &&
662          (glsl_std_450_imports_.count(inst.GetSingleWordInOperand(0)) > 0);
663 }
664 
IsIgnoredExtendedInstruction(const spvtools::opt::Instruction & inst) const665 bool ParserImpl::IsIgnoredExtendedInstruction(
666     const spvtools::opt::Instruction& inst) const {
667   return (inst.opcode() == SpvOpExtInst) &&
668          (ignored_imports_.count(inst.GetSingleWordInOperand(0)) > 0);
669 }
670 
RegisterUserAndStructMemberNames()671 bool ParserImpl::RegisterUserAndStructMemberNames() {
672   if (!success_) {
673     return false;
674   }
675   // Register entry point names. An entry point name is the point of contact
676   // between the API and the shader. It has the highest priority for
677   // preservation, so register it first.
678   for (const spvtools::opt::Instruction& entry_point :
679        module_->entry_points()) {
680     const uint32_t function_id = entry_point.GetSingleWordInOperand(1);
681     const std::string name = entry_point.GetInOperand(2).AsString();
682 
683     // This translator requires the entry point to be a valid WGSL identifier.
684     // Allowing otherwise leads to difficulties in that the programmer needs
685     // to get a mapping from their original entry point name to the WGSL name,
686     // and we don't have a good mechanism for that.
687     if (!IsValidIdentifier(name)) {
688       return Fail() << "entry point name is not a valid WGSL identifier: "
689                     << name;
690     }
691 
692     // SPIR-V allows a single function to be the implementation for more
693     // than one entry point.  In the common case, it's one-to-one, and we should
694     // try to name the function after the entry point.  Otherwise, give the
695     // function a name automatically derived from the entry point name.
696     namer_.SuggestSanitizedName(function_id, name);
697 
698     // There is another many-to-one relationship to take care of:  In SPIR-V
699     // the same name can be used for multiple entry points, provided they are
700     // for different shader stages. Take action now to ensure we can use the
701     // entry point name later on, and not have it taken for another identifier
702     // by an accidental collision with a derived name made for a different ID.
703     if (!namer_.IsRegistered(name)) {
704       // The entry point name is "unoccupied" becase an earlier entry point
705       // grabbed the slot for the function that implements both entry points.
706       // Register this new entry point's name, to avoid accidental collisions
707       // with a future generated ID.
708       if (!namer_.RegisterWithoutId(name)) {
709         return false;
710       }
711     }
712   }
713 
714   // Register names from OpName and OpMemberName
715   for (const auto& inst : module_->debugs2()) {
716     switch (inst.opcode()) {
717       case SpvOpName: {
718         const auto name = inst.GetInOperand(1).AsString();
719         if (!name.empty()) {
720           namer_.SuggestSanitizedName(inst.GetSingleWordInOperand(0), name);
721         }
722         break;
723       }
724       case SpvOpMemberName: {
725         const auto name = inst.GetInOperand(2).AsString();
726         if (!name.empty()) {
727           namer_.SuggestSanitizedMemberName(inst.GetSingleWordInOperand(0),
728                                             inst.GetSingleWordInOperand(1),
729                                             name);
730         }
731         break;
732       }
733       default:
734         break;
735     }
736   }
737 
738   // Fill in struct member names, and disambiguate them.
739   for (const auto* type_inst : module_->GetTypes()) {
740     if (type_inst->opcode() == SpvOpTypeStruct) {
741       namer_.ResolveMemberNamesForStruct(type_inst->result_id(),
742                                          type_inst->NumInOperands());
743     }
744   }
745 
746   return true;
747 }
748 
IsValidIdentifier(const std::string & str)749 bool ParserImpl::IsValidIdentifier(const std::string& str) {
750   if (str.empty()) {
751     return false;
752   }
753   std::locale c_locale("C");
754   if (!std::isalpha(str[0], c_locale)) {
755     return false;
756   }
757   for (const char& ch : str) {
758     if ((ch != '_') && !std::isalnum(ch, c_locale)) {
759       return false;
760     }
761   }
762   return true;
763 }
764 
RegisterWorkgroupSizeBuiltin()765 bool ParserImpl::RegisterWorkgroupSizeBuiltin() {
766   WorkgroupSizeInfo& info = workgroup_size_builtin_;
767   for (const spvtools::opt::Instruction& inst : module_->annotations()) {
768     if (inst.opcode() != SpvOpDecorate) {
769       continue;
770     }
771     if (inst.GetSingleWordInOperand(1) != SpvDecorationBuiltIn) {
772       continue;
773     }
774     if (inst.GetSingleWordInOperand(2) != SpvBuiltInWorkgroupSize) {
775       continue;
776     }
777     info.id = inst.GetSingleWordInOperand(0);
778   }
779   if (info.id == 0) {
780     return true;
781   }
782   // Gather the values.
783   const spvtools::opt::Instruction* composite_def =
784       def_use_mgr_->GetDef(info.id);
785   if (!composite_def) {
786     return Fail() << "Invalid WorkgroupSize builtin value";
787   }
788   // SPIR-V validation checks that the result is a 3-element vector of 32-bit
789   // integer scalars (signed or unsigned).  Rely on validation to check the
790   // type.  In theory the instruction could be OpConstantNull and still
791   // pass validation, but that would be non-sensical.  Be a little more
792   // stringent here and check for specific opcodes.  WGSL does not support
793   // const-expr yet, so avoid supporting OpSpecConstantOp here.
794   // TODO(dneto): See https://github.com/gpuweb/gpuweb/issues/1272 for WGSL
795   // const_expr proposals.
796   if ((composite_def->opcode() != SpvOpSpecConstantComposite &&
797        composite_def->opcode() != SpvOpConstantComposite)) {
798     return Fail() << "Invalid WorkgroupSize builtin.  Expected 3-element "
799                      "OpSpecConstantComposite or OpConstantComposite:  "
800                   << composite_def->PrettyPrint();
801   }
802   info.type_id = composite_def->type_id();
803   // Extract the component type from the vector type.
804   info.component_type_id =
805       def_use_mgr_->GetDef(info.type_id)->GetSingleWordInOperand(0);
806 
807   /// Sets the ID and value of the index'th member of the composite constant.
808   /// Returns false and emits a diagnostic on error.
809   auto set_param = [this, composite_def](uint32_t* id_ptr, uint32_t* value_ptr,
810                                          int index) -> bool {
811     const auto id = composite_def->GetSingleWordInOperand(index);
812     const auto* def = def_use_mgr_->GetDef(id);
813     if (!def ||
814         (def->opcode() != SpvOpSpecConstant &&
815          def->opcode() != SpvOpConstant) ||
816         (def->NumInOperands() != 1)) {
817       return Fail() << "invalid component " << index << " of workgroupsize "
818                     << (def ? def->PrettyPrint()
819                             : std::string("no definition"));
820     }
821     *id_ptr = id;
822     // Use the default value of a spec constant.
823     *value_ptr = def->GetSingleWordInOperand(0);
824     return true;
825   };
826 
827   return set_param(&info.x_id, &info.x_value, 0) &&
828          set_param(&info.y_id, &info.y_value, 1) &&
829          set_param(&info.z_id, &info.z_value, 2);
830 }
831 
RegisterEntryPoints()832 bool ParserImpl::RegisterEntryPoints() {
833   // Mapping from entry point ID to GridSize computed from LocalSize
834   // decorations.
835   std::unordered_map<uint32_t, GridSize> local_size;
836   for (const spvtools::opt::Instruction& inst : module_->execution_modes()) {
837     auto mode = static_cast<SpvExecutionMode>(inst.GetSingleWordInOperand(1));
838     if (mode == SpvExecutionModeLocalSize) {
839       if (inst.NumInOperands() != 5) {
840         // This won't even get past SPIR-V binary parsing.
841         return Fail() << "invalid LocalSize execution mode: "
842                       << inst.PrettyPrint();
843       }
844       uint32_t function_id = inst.GetSingleWordInOperand(0);
845       local_size[function_id] = GridSize{inst.GetSingleWordInOperand(2),
846                                          inst.GetSingleWordInOperand(3),
847                                          inst.GetSingleWordInOperand(4)};
848     }
849   }
850 
851   for (const spvtools::opt::Instruction& entry_point :
852        module_->entry_points()) {
853     const auto stage = SpvExecutionModel(entry_point.GetSingleWordInOperand(0));
854     const uint32_t function_id = entry_point.GetSingleWordInOperand(1);
855 
856     const std::string ep_name = entry_point.GetOperand(2).AsString();
857     if (!IsValidIdentifier(ep_name)) {
858       return Fail() << "entry point name is not a valid WGSL identifier: "
859                     << ep_name;
860     }
861 
862     bool owns_inner_implementation = false;
863     std::string inner_implementation_name;
864 
865     auto where = function_to_ep_info_.find(function_id);
866     if (where == function_to_ep_info_.end()) {
867       // If this is the first entry point to have function_id as its
868       // implementation, then this entry point is responsible for generating
869       // the inner implementation.
870       owns_inner_implementation = true;
871       inner_implementation_name = namer_.MakeDerivedName(ep_name);
872     } else {
873       // Reuse the inner implementation owned by the first entry point.
874       inner_implementation_name = where->second[0].inner_name;
875     }
876     TINT_ASSERT(Reader, !inner_implementation_name.empty());
877     TINT_ASSERT(Reader, ep_name != inner_implementation_name);
878 
879     utils::UniqueVector<uint32_t> inputs;
880     utils::UniqueVector<uint32_t> outputs;
881     for (unsigned iarg = 3; iarg < entry_point.NumInOperands(); iarg++) {
882       const uint32_t var_id = entry_point.GetSingleWordInOperand(iarg);
883       if (const auto* var_inst = def_use_mgr_->GetDef(var_id)) {
884         switch (SpvStorageClass(var_inst->GetSingleWordInOperand(0))) {
885           case SpvStorageClassInput:
886             inputs.add(var_id);
887             break;
888           case SpvStorageClassOutput:
889             outputs.add(var_id);
890             break;
891           default:
892             break;
893         }
894       }
895     }
896     // Save the lists, in ID-sorted order.
897     std::vector<uint32_t> sorted_inputs(inputs);
898     std::sort(sorted_inputs.begin(), sorted_inputs.end());
899     std::vector<uint32_t> sorted_outputs(outputs);
900     std::sort(sorted_outputs.begin(), sorted_outputs.end());
901 
902     const auto ast_stage = enum_converter_.ToPipelineStage(stage);
903     GridSize wgsize;
904     if (ast_stage == ast::PipelineStage::kCompute) {
905       if (workgroup_size_builtin_.id) {
906         // Store the default values.
907         // WGSL allows specializing these, but this code doesn't support that
908         // yet. https://github.com/gpuweb/gpuweb/issues/1442
909         wgsize = GridSize{workgroup_size_builtin_.x_value,
910                           workgroup_size_builtin_.y_value,
911                           workgroup_size_builtin_.z_value};
912       } else {
913         // Use the LocalSize execution mode.  This is the second choice.
914         auto where_local_size = local_size.find(function_id);
915         if (where_local_size != local_size.end()) {
916           wgsize = where_local_size->second;
917         }
918       }
919     }
920     function_to_ep_info_[function_id].emplace_back(
921         ep_name, ast_stage, owns_inner_implementation,
922         inner_implementation_name, std::move(sorted_inputs),
923         std::move(sorted_outputs), wgsize);
924   }
925 
926   // The enum conversion could have failed, so return the existing status value.
927   return success_;
928 }
929 
ConvertType(const spvtools::opt::analysis::Integer * int_ty)930 const Type* ParserImpl::ConvertType(
931     const spvtools::opt::analysis::Integer* int_ty) {
932   if (int_ty->width() == 32) {
933     return int_ty->IsSigned() ? static_cast<const Type*>(ty_.I32())
934                               : static_cast<const Type*>(ty_.U32());
935   }
936   Fail() << "unhandled integer width: " << int_ty->width();
937   return nullptr;
938 }
939 
ConvertType(const spvtools::opt::analysis::Float * float_ty)940 const Type* ParserImpl::ConvertType(
941     const spvtools::opt::analysis::Float* float_ty) {
942   if (float_ty->width() == 32) {
943     return ty_.F32();
944   }
945   Fail() << "unhandled float width: " << float_ty->width();
946   return nullptr;
947 }
948 
ConvertType(const spvtools::opt::analysis::Vector * vec_ty)949 const Type* ParserImpl::ConvertType(
950     const spvtools::opt::analysis::Vector* vec_ty) {
951   const auto num_elem = vec_ty->element_count();
952   auto* ast_elem_ty = ConvertType(type_mgr_->GetId(vec_ty->element_type()));
953   if (ast_elem_ty == nullptr) {
954     return ast_elem_ty;
955   }
956   return ty_.Vector(ast_elem_ty, num_elem);
957 }
958 
ConvertType(const spvtools::opt::analysis::Matrix * mat_ty)959 const Type* ParserImpl::ConvertType(
960     const spvtools::opt::analysis::Matrix* mat_ty) {
961   const auto* vec_ty = mat_ty->element_type()->AsVector();
962   const auto* scalar_ty = vec_ty->element_type();
963   const auto num_rows = vec_ty->element_count();
964   const auto num_columns = mat_ty->element_count();
965   auto* ast_scalar_ty = ConvertType(type_mgr_->GetId(scalar_ty));
966   if (ast_scalar_ty == nullptr) {
967     return nullptr;
968   }
969   return ty_.Matrix(ast_scalar_ty, num_columns, num_rows);
970 }
971 
ConvertType(uint32_t type_id,const spvtools::opt::analysis::RuntimeArray * rtarr_ty)972 const Type* ParserImpl::ConvertType(
973     uint32_t type_id,
974     const spvtools::opt::analysis::RuntimeArray* rtarr_ty) {
975   auto* ast_elem_ty = ConvertType(type_mgr_->GetId(rtarr_ty->element_type()));
976   if (ast_elem_ty == nullptr) {
977     return nullptr;
978   }
979   uint32_t array_stride = 0;
980   if (!ParseArrayDecorations(rtarr_ty, &array_stride)) {
981     return nullptr;
982   }
983   const Type* result = ty_.Array(ast_elem_ty, 0, array_stride);
984   return MaybeGenerateAlias(type_id, rtarr_ty, result);
985 }
986 
ConvertType(uint32_t type_id,const spvtools::opt::analysis::Array * arr_ty)987 const Type* ParserImpl::ConvertType(
988     uint32_t type_id,
989     const spvtools::opt::analysis::Array* arr_ty) {
990   // Get the element type. The SPIR-V optimizer's types representation
991   // deduplicates array types that have the same parameterization.
992   // We don't want that deduplication, so get the element type from
993   // the SPIR-V type directly.
994   const auto* inst = def_use_mgr_->GetDef(type_id);
995   const auto elem_type_id = inst->GetSingleWordInOperand(0);
996   auto* ast_elem_ty = ConvertType(elem_type_id);
997   if (ast_elem_ty == nullptr) {
998     return nullptr;
999   }
1000   // Get the length.
1001   const auto& length_info = arr_ty->length_info();
1002   if (length_info.words.empty()) {
1003     // The internal representation is invalid. The discriminant vector
1004     // is mal-formed.
1005     Fail() << "internal error: Array length info is invalid";
1006     return nullptr;
1007   }
1008   if (length_info.words[0] !=
1009       spvtools::opt::analysis::Array::LengthInfo::kConstant) {
1010     Fail() << "Array type " << type_mgr_->GetId(arr_ty)
1011            << " length is a specialization constant";
1012     return nullptr;
1013   }
1014   const auto* constant = constant_mgr_->FindDeclaredConstant(length_info.id);
1015   if (constant == nullptr) {
1016     Fail() << "Array type " << type_mgr_->GetId(arr_ty) << " length ID "
1017            << length_info.id << " does not name an OpConstant";
1018     return nullptr;
1019   }
1020   const uint64_t num_elem = constant->GetZeroExtendedValue();
1021   // For now, limit to only 32bits.
1022   if (num_elem > std::numeric_limits<uint32_t>::max()) {
1023     Fail() << "Array type " << type_mgr_->GetId(arr_ty)
1024            << " has too many elements (more than can fit in 32 bits): "
1025            << num_elem;
1026     return nullptr;
1027   }
1028   uint32_t array_stride = 0;
1029   if (!ParseArrayDecorations(arr_ty, &array_stride)) {
1030     return nullptr;
1031   }
1032   if (remap_buffer_block_type_.count(elem_type_id)) {
1033     remap_buffer_block_type_.insert(type_mgr_->GetId(arr_ty));
1034   }
1035   const Type* result =
1036       ty_.Array(ast_elem_ty, static_cast<uint32_t>(num_elem), array_stride);
1037   return MaybeGenerateAlias(type_id, arr_ty, result);
1038 }
1039 
ParseArrayDecorations(const spvtools::opt::analysis::Type * spv_type,uint32_t * array_stride)1040 bool ParserImpl::ParseArrayDecorations(
1041     const spvtools::opt::analysis::Type* spv_type,
1042     uint32_t* array_stride) {
1043   bool has_array_stride = false;
1044   *array_stride = 0;  // Implicit stride case.
1045   const auto type_id = type_mgr_->GetId(spv_type);
1046   for (auto& decoration : this->GetDecorationsFor(type_id)) {
1047     if (decoration.size() == 2 && decoration[0] == SpvDecorationArrayStride) {
1048       const auto stride = decoration[1];
1049       if (stride == 0) {
1050         return Fail() << "invalid array type ID " << type_id
1051                       << ": ArrayStride can't be 0";
1052       }
1053       if (has_array_stride) {
1054         return Fail() << "invalid array type ID " << type_id
1055                       << ": multiple ArrayStride decorations";
1056       }
1057       has_array_stride = true;
1058       *array_stride = stride;
1059     } else {
1060       return Fail() << "invalid array type ID " << type_id
1061                     << ": unknown decoration "
1062                     << (decoration.empty() ? "(empty)"
1063                                            : std::to_string(decoration[0]))
1064                     << " with " << decoration.size() << " total words";
1065     }
1066   }
1067   return true;
1068 }
1069 
ConvertType(uint32_t type_id,const spvtools::opt::analysis::Struct * struct_ty)1070 const Type* ParserImpl::ConvertType(
1071     uint32_t type_id,
1072     const spvtools::opt::analysis::Struct* struct_ty) {
1073   // Compute the struct decoration.
1074   auto struct_decorations = this->GetDecorationsFor(type_id);
1075   bool is_block_decorated = false;
1076   if (struct_decorations.size() == 1) {
1077     const auto decoration = struct_decorations[0][0];
1078     if (decoration == SpvDecorationBlock) {
1079       is_block_decorated = true;
1080     } else if (decoration == SpvDecorationBufferBlock) {
1081       is_block_decorated = true;
1082       remap_buffer_block_type_.insert(type_id);
1083     } else {
1084       Fail() << "struct with ID " << type_id
1085              << " has unrecognized decoration: " << int(decoration);
1086     }
1087   } else if (struct_decorations.size() > 1) {
1088     Fail() << "can't handle a struct with more than one decoration: struct "
1089            << type_id << " has " << struct_decorations.size();
1090     return nullptr;
1091   }
1092 
1093   // Compute members
1094   ast::StructMemberList ast_members;
1095   const auto members = struct_ty->element_types();
1096   if (members.empty()) {
1097     Fail() << "WGSL does not support empty structures. can't convert type: "
1098            << def_use_mgr_->GetDef(type_id)->PrettyPrint();
1099     return nullptr;
1100   }
1101   TypeList ast_member_types;
1102   unsigned num_non_writable_members = 0;
1103   for (uint32_t member_index = 0; member_index < members.size();
1104        ++member_index) {
1105     const auto member_type_id = type_mgr_->GetId(members[member_index]);
1106     auto* ast_member_ty = ConvertType(member_type_id);
1107     if (ast_member_ty == nullptr) {
1108       // Already emitted diagnostics.
1109       return nullptr;
1110     }
1111 
1112     ast_member_types.emplace_back(ast_member_ty);
1113 
1114     // Scan member for built-in decorations. Some vertex built-ins are handled
1115     // specially, and should not generate a structure member.
1116     bool create_ast_member = true;
1117     for (auto& decoration : GetDecorationsForMember(type_id, member_index)) {
1118       if (decoration.empty()) {
1119         Fail() << "malformed SPIR-V decoration: it's empty";
1120         return nullptr;
1121       }
1122       if ((decoration[0] == SpvDecorationBuiltIn) && (decoration.size() > 1)) {
1123         switch (decoration[1]) {
1124           case SpvBuiltInPosition:
1125             // Record this built-in variable specially.
1126             builtin_position_.struct_type_id = type_id;
1127             builtin_position_.position_member_index = member_index;
1128             builtin_position_.position_member_type_id = member_type_id;
1129             create_ast_member = false;  // Not part of the WGSL structure.
1130             break;
1131           case SpvBuiltInPointSize:  // not supported in WGSL, but ignore
1132             builtin_position_.pointsize_member_index = member_index;
1133             create_ast_member = false;  // Not part of the WGSL structure.
1134             break;
1135           case SpvBuiltInClipDistance:  // not supported in WGSL
1136           case SpvBuiltInCullDistance:  // not supported in WGSL
1137             create_ast_member = false;  // Not part of the WGSL structure.
1138             break;
1139           default:
1140             Fail() << "unrecognized builtin " << decoration[1];
1141             return nullptr;
1142         }
1143       }
1144     }
1145     if (!create_ast_member) {
1146       // This member is decorated as a built-in, and is handled specially.
1147       continue;
1148     }
1149 
1150     bool is_non_writable = false;
1151     ast::DecorationList ast_member_decorations;
1152     for (auto& decoration : GetDecorationsForMember(type_id, member_index)) {
1153       if (IsPipelineDecoration(decoration)) {
1154         // IO decorations are handled when emitting the entry point.
1155         continue;
1156       } else if (decoration[0] == SpvDecorationNonWritable) {
1157         // WGSL doesn't represent individual members as non-writable. Instead,
1158         // apply the ReadOnly access control to the containing struct if all
1159         // the members are non-writable.
1160         is_non_writable = true;
1161       } else {
1162         auto decos = ConvertMemberDecoration(type_id, member_index,
1163                                              ast_member_ty, decoration);
1164         for (auto* deco : decos) {
1165           ast_member_decorations.emplace_back(deco);
1166         }
1167         if (!success_) {
1168           return nullptr;
1169         }
1170       }
1171     }
1172 
1173     if (is_non_writable) {
1174       // Count a member as non-writable only once, no matter how many
1175       // NonWritable decorations are applied to it.
1176       ++num_non_writable_members;
1177     }
1178     const auto member_name = namer_.GetMemberName(type_id, member_index);
1179     auto* ast_struct_member = create<ast::StructMember>(
1180         Source{}, builder_.Symbols().Register(member_name),
1181         ast_member_ty->Build(builder_), std::move(ast_member_decorations));
1182     ast_members.push_back(ast_struct_member);
1183   }
1184 
1185   if (ast_members.empty()) {
1186     // All members were likely built-ins. Don't generate an empty AST structure.
1187     return nullptr;
1188   }
1189 
1190   namer_.SuggestSanitizedName(type_id, "S");
1191 
1192   auto name = namer_.GetName(type_id);
1193 
1194   // Now make the struct.
1195   auto sym = builder_.Symbols().Register(name);
1196   ast::DecorationList ast_struct_decorations;
1197   if (is_block_decorated && struct_types_for_buffers_.count(type_id)) {
1198     ast_struct_decorations.emplace_back(
1199         create<ast::StructBlockDecoration>(Source{}));
1200   }
1201   auto* ast_struct = create<ast::Struct>(Source{}, sym, std::move(ast_members),
1202                                          std::move(ast_struct_decorations));
1203   if (num_non_writable_members == members.size()) {
1204     read_only_struct_types_.insert(ast_struct->name);
1205   }
1206   AddTypeDecl(sym, ast_struct);
1207   const auto* result = ty_.Struct(sym, std::move(ast_member_types));
1208   struct_id_for_symbol_[sym] = type_id;
1209   return result;
1210 }
1211 
AddTypeDecl(Symbol name,const ast::TypeDecl * decl)1212 void ParserImpl::AddTypeDecl(Symbol name, const ast::TypeDecl* decl) {
1213   auto iter = declared_types_.insert(name);
1214   if (iter.second) {
1215     builder_.AST().AddTypeDecl(decl);
1216   }
1217 }
1218 
ConvertType(uint32_t type_id,PtrAs ptr_as,const spvtools::opt::analysis::Pointer *)1219 const Type* ParserImpl::ConvertType(uint32_t type_id,
1220                                     PtrAs ptr_as,
1221                                     const spvtools::opt::analysis::Pointer*) {
1222   const auto* inst = def_use_mgr_->GetDef(type_id);
1223   const auto pointee_type_id = inst->GetSingleWordInOperand(1);
1224   const auto storage_class = SpvStorageClass(inst->GetSingleWordInOperand(0));
1225 
1226   if (pointee_type_id == builtin_position_.struct_type_id) {
1227     builtin_position_.pointer_type_id = type_id;
1228     // Pipeline IO builtins map to private variables.
1229     builtin_position_.storage_class = SpvStorageClassPrivate;
1230     return nullptr;
1231   }
1232   auto* ast_elem_ty = ConvertType(pointee_type_id, PtrAs::Ptr);
1233   if (ast_elem_ty == nullptr) {
1234     Fail() << "SPIR-V pointer type with ID " << type_id
1235            << " has invalid pointee type " << pointee_type_id;
1236     return nullptr;
1237   }
1238 
1239   auto ast_storage_class = enum_converter_.ToStorageClass(storage_class);
1240   if (ast_storage_class == ast::StorageClass::kInvalid) {
1241     Fail() << "SPIR-V pointer type with ID " << type_id
1242            << " has invalid storage class "
1243            << static_cast<uint32_t>(storage_class);
1244     return nullptr;
1245   }
1246   if (ast_storage_class == ast::StorageClass::kUniform &&
1247       remap_buffer_block_type_.count(pointee_type_id)) {
1248     ast_storage_class = ast::StorageClass::kStorage;
1249     remap_buffer_block_type_.insert(type_id);
1250   }
1251 
1252   // Pipeline input and output variables map to private variables.
1253   if (ast_storage_class == ast::StorageClass::kInput ||
1254       ast_storage_class == ast::StorageClass::kOutput) {
1255     ast_storage_class = ast::StorageClass::kPrivate;
1256   }
1257   switch (ptr_as) {
1258     case PtrAs::Ref:
1259       return ty_.Reference(ast_elem_ty, ast_storage_class);
1260     case PtrAs::Ptr:
1261       return ty_.Pointer(ast_elem_ty, ast_storage_class);
1262   }
1263   Fail() << "invalid value for ptr_as: " << static_cast<int>(ptr_as);
1264   return nullptr;
1265 }
1266 
RegisterTypes()1267 bool ParserImpl::RegisterTypes() {
1268   if (!success_) {
1269     return false;
1270   }
1271 
1272   // First record the structure types that should have a `block` decoration
1273   // in WGSL. In particular, exclude user-defined pipeline IO in a
1274   // block-decorated struct.
1275   for (const auto& type_or_value : module_->types_values()) {
1276     if (type_or_value.opcode() != SpvOpVariable) {
1277       continue;
1278     }
1279     const auto& var = type_or_value;
1280     const auto spirv_storage_class =
1281         SpvStorageClass(var.GetSingleWordInOperand(0));
1282     if ((spirv_storage_class != SpvStorageClassStorageBuffer) &&
1283         (spirv_storage_class != SpvStorageClassUniform)) {
1284       continue;
1285     }
1286     const auto* ptr_type = def_use_mgr_->GetDef(var.type_id());
1287     if (ptr_type->opcode() != SpvOpTypePointer) {
1288       return Fail() << "OpVariable type expected to be a pointer: "
1289                     << var.PrettyPrint();
1290     }
1291     const auto* store_type =
1292         def_use_mgr_->GetDef(ptr_type->GetSingleWordInOperand(1));
1293     if (store_type->opcode() == SpvOpTypeStruct) {
1294       struct_types_for_buffers_.insert(store_type->result_id());
1295     } else {
1296       Fail() << "WGSL does not support arrays of buffers: "
1297              << var.PrettyPrint();
1298     }
1299   }
1300 
1301   // Now convert each type.
1302   for (auto& type_or_const : module_->types_values()) {
1303     const auto* type = type_mgr_->GetType(type_or_const.result_id());
1304     if (type == nullptr) {
1305       continue;
1306     }
1307     ConvertType(type_or_const.result_id());
1308   }
1309   // Manufacture a type for the gl_Position variable if we have to.
1310   if ((builtin_position_.struct_type_id != 0) &&
1311       (builtin_position_.position_member_pointer_type_id == 0)) {
1312     builtin_position_.position_member_pointer_type_id =
1313         type_mgr_->FindPointerToType(builtin_position_.position_member_type_id,
1314                                      builtin_position_.storage_class);
1315     ConvertType(builtin_position_.position_member_pointer_type_id);
1316   }
1317   return success_;
1318 }
1319 
RejectInvalidPointerRoots()1320 bool ParserImpl::RejectInvalidPointerRoots() {
1321   if (!success_) {
1322     return false;
1323   }
1324   for (auto& inst : module_->types_values()) {
1325     if (const auto* result_type = type_mgr_->GetType(inst.type_id())) {
1326       if (result_type->AsPointer()) {
1327         switch (inst.opcode()) {
1328           case SpvOpVariable:
1329             // This is the only valid case.
1330             break;
1331           case SpvOpUndef:
1332             return Fail() << "undef pointer is not valid: "
1333                           << inst.PrettyPrint();
1334           case SpvOpConstantNull:
1335             return Fail() << "null pointer is not valid: "
1336                           << inst.PrettyPrint();
1337           default:
1338             return Fail() << "module-scope pointer is not valid: "
1339                           << inst.PrettyPrint();
1340         }
1341       }
1342     }
1343   }
1344   return success();
1345 }
1346 
EmitScalarSpecConstants()1347 bool ParserImpl::EmitScalarSpecConstants() {
1348   if (!success_) {
1349     return false;
1350   }
1351   // Generate a module-scope const declaration for each instruction
1352   // that is OpSpecConstantTrue, OpSpecConstantFalse, or OpSpecConstant.
1353   for (auto& inst : module_->types_values()) {
1354     // These will be populated for a valid scalar spec constant.
1355     const Type* ast_type = nullptr;
1356     ast::LiteralExpression* ast_expr = nullptr;
1357 
1358     switch (inst.opcode()) {
1359       case SpvOpSpecConstantTrue:
1360       case SpvOpSpecConstantFalse: {
1361         ast_type = ConvertType(inst.type_id());
1362         ast_expr = create<ast::BoolLiteralExpression>(
1363             Source{}, inst.opcode() == SpvOpSpecConstantTrue);
1364         break;
1365       }
1366       case SpvOpSpecConstant: {
1367         ast_type = ConvertType(inst.type_id());
1368         const uint32_t literal_value = inst.GetSingleWordInOperand(0);
1369         if (ast_type->Is<I32>()) {
1370           ast_expr = create<ast::SintLiteralExpression>(
1371               Source{}, static_cast<int32_t>(literal_value));
1372         } else if (ast_type->Is<U32>()) {
1373           ast_expr = create<ast::UintLiteralExpression>(
1374               Source{}, static_cast<uint32_t>(literal_value));
1375         } else if (ast_type->Is<F32>()) {
1376           float float_value;
1377           // Copy the bits so we can read them as a float.
1378           std::memcpy(&float_value, &literal_value, sizeof(float_value));
1379           ast_expr = create<ast::FloatLiteralExpression>(Source{}, float_value);
1380         } else {
1381           return Fail() << " invalid result type for OpSpecConstant "
1382                         << inst.PrettyPrint();
1383         }
1384         break;
1385       }
1386       default:
1387         break;
1388     }
1389     if (ast_type && ast_expr) {
1390       ast::DecorationList spec_id_decos;
1391       for (const auto& deco : GetDecorationsFor(inst.result_id())) {
1392         if ((deco.size() == 2) && (deco[0] == SpvDecorationSpecId)) {
1393           const uint32_t id = deco[1];
1394           if (id > 65535) {
1395             return Fail() << "SpecId too large. WGSL override IDs must be "
1396                              "between 0 and 65535: ID %"
1397                           << inst.result_id() << " has SpecId " << id;
1398           }
1399           auto* cid = create<ast::OverrideDecoration>(Source{}, id);
1400           spec_id_decos.push_back(cid);
1401           break;
1402         }
1403       }
1404       auto* ast_var =
1405           MakeVariable(inst.result_id(), ast::StorageClass::kNone, ast_type,
1406                        true, ast_expr, std::move(spec_id_decos));
1407       if (ast_var) {
1408         builder_.AST().AddGlobalVariable(ast_var);
1409         scalar_spec_constants_.insert(inst.result_id());
1410       }
1411     }
1412   }
1413   return success_;
1414 }
1415 
MaybeGenerateAlias(uint32_t type_id,const spvtools::opt::analysis::Type * type,const Type * ast_type)1416 const Type* ParserImpl::MaybeGenerateAlias(
1417     uint32_t type_id,
1418     const spvtools::opt::analysis::Type* type,
1419     const Type* ast_type) {
1420   if (!success_) {
1421     return nullptr;
1422   }
1423 
1424   // We only care about arrays, and runtime arrays.
1425   switch (type->kind()) {
1426     case spvtools::opt::analysis::Type::kRuntimeArray:
1427       // Runtime arrays are always decorated with ArrayStride so always get a
1428       // type alias.
1429       namer_.SuggestSanitizedName(type_id, "RTArr");
1430       break;
1431     case spvtools::opt::analysis::Type::kArray:
1432       // Only make a type aliase for arrays with decorations.
1433       if (GetDecorationsFor(type_id).empty()) {
1434         return ast_type;
1435       }
1436       namer_.SuggestSanitizedName(type_id, "Arr");
1437       break;
1438     default:
1439       // Ignore constants, and any other types.
1440       return ast_type;
1441   }
1442   auto* ast_underlying_type = ast_type;
1443   if (ast_underlying_type == nullptr) {
1444     Fail() << "internal error: no type registered for SPIR-V ID: " << type_id;
1445     return nullptr;
1446   }
1447   const auto name = namer_.GetName(type_id);
1448   const auto sym = builder_.Symbols().Register(name);
1449   auto* ast_alias_type =
1450       builder_.ty.alias(sym, ast_underlying_type->Build(builder_));
1451 
1452   // Record this new alias as the AST type for this SPIR-V ID.
1453   AddTypeDecl(sym, ast_alias_type);
1454 
1455   return ty_.Alias(sym, ast_underlying_type);
1456 }
1457 
EmitModuleScopeVariables()1458 bool ParserImpl::EmitModuleScopeVariables() {
1459   if (!success_) {
1460     return false;
1461   }
1462   for (const auto& type_or_value : module_->types_values()) {
1463     if (type_or_value.opcode() != SpvOpVariable) {
1464       continue;
1465     }
1466     const auto& var = type_or_value;
1467     const auto spirv_storage_class =
1468         SpvStorageClass(var.GetSingleWordInOperand(0));
1469 
1470     uint32_t type_id = var.type_id();
1471     if ((type_id == builtin_position_.pointer_type_id) &&
1472         ((spirv_storage_class == SpvStorageClassInput) ||
1473          (spirv_storage_class == SpvStorageClassOutput))) {
1474       // Skip emitting gl_PerVertex.
1475       builtin_position_.per_vertex_var_id = var.result_id();
1476       builtin_position_.per_vertex_var_init_id =
1477           var.NumInOperands() > 1 ? var.GetSingleWordInOperand(1) : 0u;
1478       continue;
1479     }
1480     switch (enum_converter_.ToStorageClass(spirv_storage_class)) {
1481       case ast::StorageClass::kNone:
1482       case ast::StorageClass::kInput:
1483       case ast::StorageClass::kOutput:
1484       case ast::StorageClass::kUniform:
1485       case ast::StorageClass::kUniformConstant:
1486       case ast::StorageClass::kStorage:
1487       case ast::StorageClass::kImage:
1488       case ast::StorageClass::kWorkgroup:
1489       case ast::StorageClass::kPrivate:
1490         break;
1491       default:
1492         return Fail() << "invalid SPIR-V storage class "
1493                       << int(spirv_storage_class)
1494                       << " for module scope variable: " << var.PrettyPrint();
1495     }
1496     if (!success_) {
1497       return false;
1498     }
1499     const Type* ast_type = nullptr;
1500     if (spirv_storage_class == SpvStorageClassUniformConstant) {
1501       // These are opaque handles: samplers or textures
1502       ast_type = GetTypeForHandleVar(var);
1503       if (!ast_type) {
1504         return false;
1505       }
1506     } else {
1507       ast_type = ConvertType(type_id);
1508       if (ast_type == nullptr) {
1509         return Fail() << "internal error: failed to register Tint AST type for "
1510                          "SPIR-V type with ID: "
1511                       << var.type_id();
1512       }
1513       if (!ast_type->Is<Pointer>()) {
1514         return Fail() << "variable with ID " << var.result_id()
1515                       << " has non-pointer type " << var.type_id();
1516       }
1517     }
1518 
1519     auto* ast_store_type = ast_type->As<Pointer>()->type;
1520     auto ast_storage_class = ast_type->As<Pointer>()->storage_class;
1521     const ast::Expression* ast_constructor = nullptr;
1522     if (var.NumInOperands() > 1) {
1523       // SPIR-V initializers are always constants.
1524       // (OpenCL also allows the ID of an OpVariable, but we don't handle that
1525       // here.)
1526       ast_constructor =
1527           MakeConstantExpression(var.GetSingleWordInOperand(1)).expr;
1528     }
1529     auto* ast_var =
1530         MakeVariable(var.result_id(), ast_storage_class, ast_store_type, false,
1531                      ast_constructor, ast::DecorationList{});
1532     // TODO(dneto): initializers (a.k.a. constructor expression)
1533     if (ast_var) {
1534       builder_.AST().AddGlobalVariable(ast_var);
1535     }
1536   }
1537 
1538   // Emit gl_Position instead of gl_PerVertex
1539   if (builtin_position_.per_vertex_var_id) {
1540     // Make sure the variable has a name.
1541     namer_.SuggestSanitizedName(builtin_position_.per_vertex_var_id,
1542                                 "gl_Position");
1543     const ast::Expression* ast_constructor = nullptr;
1544     if (builtin_position_.per_vertex_var_init_id) {
1545       // The initializer is complex.
1546       const auto* init =
1547           def_use_mgr_->GetDef(builtin_position_.per_vertex_var_init_id);
1548       switch (init->opcode()) {
1549         case SpvOpConstantComposite:
1550         case SpvOpSpecConstantComposite:
1551           ast_constructor = MakeConstantExpression(
1552                                 init->GetSingleWordInOperand(
1553                                     builtin_position_.position_member_index))
1554                                 .expr;
1555           break;
1556         default:
1557           return Fail() << "gl_PerVertex initializer too complex. only "
1558                            "OpCompositeConstruct and OpSpecConstantComposite "
1559                            "are supported: "
1560                         << init->PrettyPrint();
1561       }
1562     }
1563     auto* ast_var = MakeVariable(
1564         builtin_position_.per_vertex_var_id,
1565         enum_converter_.ToStorageClass(builtin_position_.storage_class),
1566         ConvertType(builtin_position_.position_member_type_id), false,
1567         ast_constructor, {});
1568 
1569     builder_.AST().AddGlobalVariable(ast_var);
1570   }
1571   return success_;
1572 }
1573 
1574 // @param var_id SPIR-V id of an OpVariable, assumed to be pointer
1575 // to an array
1576 // @returns the IntConstant for the size of the array, or nullptr
GetArraySize(uint32_t var_id)1577 const spvtools::opt::analysis::IntConstant* ParserImpl::GetArraySize(
1578     uint32_t var_id) {
1579   auto* var = def_use_mgr_->GetDef(var_id);
1580   if (!var || var->opcode() != SpvOpVariable) {
1581     return nullptr;
1582   }
1583   auto* ptr_type = def_use_mgr_->GetDef(var->type_id());
1584   if (!ptr_type || ptr_type->opcode() != SpvOpTypePointer) {
1585     return nullptr;
1586   }
1587   auto* array_type = def_use_mgr_->GetDef(ptr_type->GetSingleWordInOperand(1));
1588   if (!array_type || array_type->opcode() != SpvOpTypeArray) {
1589     return nullptr;
1590   }
1591   auto* size = constant_mgr_->FindDeclaredConstant(
1592       array_type->GetSingleWordInOperand(1));
1593   if (!size) {
1594     return nullptr;
1595   }
1596   return size->AsIntConstant();
1597 }
1598 
MakeVariable(uint32_t id,ast::StorageClass sc,const Type * storage_type,bool is_const,const ast::Expression * constructor,ast::DecorationList decorations)1599 ast::Variable* ParserImpl::MakeVariable(uint32_t id,
1600                                         ast::StorageClass sc,
1601                                         const Type* storage_type,
1602                                         bool is_const,
1603                                         const ast::Expression* constructor,
1604                                         ast::DecorationList decorations) {
1605   if (storage_type == nullptr) {
1606     Fail() << "internal error: can't make ast::Variable for null type";
1607     return nullptr;
1608   }
1609 
1610   ast::Access access = ast::Access::kUndefined;
1611   if (sc == ast::StorageClass::kStorage) {
1612     bool read_only = false;
1613     if (auto* tn = storage_type->As<Named>()) {
1614       read_only = read_only_struct_types_.count(tn->name) > 0;
1615     }
1616 
1617     // Apply the access(read) or access(read_write) modifier.
1618     access = read_only ? ast::Access::kRead : ast::Access::kReadWrite;
1619   }
1620 
1621   // Handle variables (textures and samplers) are always in the handle
1622   // storage class, so we don't mention the storage class.
1623   if (sc == ast::StorageClass::kUniformConstant) {
1624     sc = ast::StorageClass::kNone;
1625   }
1626 
1627   if (!ConvertDecorationsForVariable(id, &storage_type, &decorations,
1628                                      sc != ast::StorageClass::kPrivate)) {
1629     return nullptr;
1630   }
1631 
1632   std::string name = namer_.Name(id);
1633 
1634   // Note: we're constructing the variable here with the *storage* type,
1635   // regardless of whether this is a `let` or `var` declaration.
1636   // `var` declarations will have a resolved type of ref<storage>, but at the
1637   // AST level both `var` and `let` are declared with the same type.
1638   return create<ast::Variable>(Source{}, builder_.Symbols().Register(name), sc,
1639                                access, storage_type->Build(builder_), is_const,
1640                                constructor, decorations);
1641 }
1642 
ConvertDecorationsForVariable(uint32_t id,const Type ** store_type,ast::DecorationList * decorations,bool transfer_pipeline_io)1643 bool ParserImpl::ConvertDecorationsForVariable(uint32_t id,
1644                                                const Type** store_type,
1645                                                ast::DecorationList* decorations,
1646                                                bool transfer_pipeline_io) {
1647   DecorationList non_builtin_pipeline_decorations;
1648   for (auto& deco : GetDecorationsFor(id)) {
1649     if (deco.empty()) {
1650       return Fail() << "malformed decoration on ID " << id << ": it is empty";
1651     }
1652     if (deco[0] == SpvDecorationBuiltIn) {
1653       if (deco.size() == 1) {
1654         return Fail() << "malformed BuiltIn decoration on ID " << id
1655                       << ": has no operand";
1656       }
1657       const auto spv_builtin = static_cast<SpvBuiltIn>(deco[1]);
1658       switch (spv_builtin) {
1659         case SpvBuiltInPointSize:
1660           special_builtins_[id] = spv_builtin;
1661           return false;  // This is not an error
1662         case SpvBuiltInSampleId:
1663         case SpvBuiltInVertexIndex:
1664         case SpvBuiltInInstanceIndex:
1665         case SpvBuiltInLocalInvocationId:
1666         case SpvBuiltInLocalInvocationIndex:
1667         case SpvBuiltInGlobalInvocationId:
1668         case SpvBuiltInWorkgroupId:
1669         case SpvBuiltInNumWorkgroups:
1670           // The SPIR-V variable may signed (because GLSL requires signed for
1671           // some of these), but WGSL requires unsigned.  Handle specially
1672           // so we always perform the conversion at load and store.
1673           special_builtins_[id] = spv_builtin;
1674           if (auto* forced_type = UnsignedTypeFor(*store_type)) {
1675             // Requires conversion and special handling in code generation.
1676             if (transfer_pipeline_io) {
1677               *store_type = forced_type;
1678             }
1679           }
1680           break;
1681         case SpvBuiltInSampleMask: {
1682           // In SPIR-V this is used for both input and output variable.
1683           // The SPIR-V variable has store type of array of integer scalar,
1684           // either signed or unsigned.
1685           // WGSL requires the store type to be u32.
1686           auto* size = GetArraySize(id);
1687           if (!size || size->GetZeroExtendedValue() != 1) {
1688             Fail() << "WGSL supports a sample mask of at most 32 bits. "
1689                       "SampleMask must be an array of 1 element.";
1690           }
1691           special_builtins_[id] = spv_builtin;
1692           if (transfer_pipeline_io) {
1693             *store_type = ty_.U32();
1694           }
1695           break;
1696         }
1697         default:
1698           break;
1699       }
1700       auto ast_builtin = enum_converter_.ToBuiltin(spv_builtin);
1701       if (ast_builtin == ast::Builtin::kNone) {
1702         // A diagnostic has already been emitted.
1703         return false;
1704       }
1705       if (transfer_pipeline_io) {
1706         decorations->emplace_back(
1707             create<ast::BuiltinDecoration>(Source{}, ast_builtin));
1708       }
1709     }
1710     if (transfer_pipeline_io && IsPipelineDecoration(deco)) {
1711       non_builtin_pipeline_decorations.push_back(deco);
1712     }
1713     if (deco[0] == SpvDecorationDescriptorSet) {
1714       if (deco.size() == 1) {
1715         return Fail() << "malformed DescriptorSet decoration on ID " << id
1716                       << ": has no operand";
1717       }
1718       decorations->emplace_back(
1719           create<ast::GroupDecoration>(Source{}, deco[1]));
1720     }
1721     if (deco[0] == SpvDecorationBinding) {
1722       if (deco.size() == 1) {
1723         return Fail() << "malformed Binding decoration on ID " << id
1724                       << ": has no operand";
1725       }
1726       decorations->emplace_back(
1727           create<ast::BindingDecoration>(Source{}, deco[1]));
1728     }
1729   }
1730 
1731   if (transfer_pipeline_io) {
1732     if (!ConvertPipelineDecorations(
1733             *store_type, non_builtin_pipeline_decorations, decorations)) {
1734       return false;
1735     }
1736   }
1737 
1738   return success();
1739 }
1740 
GetMemberPipelineDecorations(const Struct & struct_type,int member_index)1741 DecorationList ParserImpl::GetMemberPipelineDecorations(
1742     const Struct& struct_type,
1743     int member_index) {
1744   // Yes, I could have used std::copy_if or std::copy_if.
1745   DecorationList result;
1746   for (const auto& deco : GetDecorationsForMember(
1747            struct_id_for_symbol_[struct_type.name], member_index)) {
1748     if (IsPipelineDecoration(deco)) {
1749       result.emplace_back(deco);
1750     }
1751   }
1752   return result;
1753 }
1754 
SetLocation(ast::DecorationList * decos,const ast::Decoration * replacement)1755 const ast::Decoration* ParserImpl::SetLocation(
1756     ast::DecorationList* decos,
1757     const ast::Decoration* replacement) {
1758   if (!replacement) {
1759     return nullptr;
1760   }
1761   for (auto*& deco : *decos) {
1762     if (deco->Is<ast::LocationDecoration>()) {
1763       // Replace this location decoration with the replacement.
1764       // The old one doesn't leak because it's kept in the builder's AST node
1765       // list.
1766       const ast::Decoration* result = nullptr;
1767       result = deco;
1768       deco = replacement;
1769       return result;  // Assume there is only one such decoration.
1770     }
1771   }
1772   // The list didn't have a location. Add it.
1773   decos->push_back(replacement);
1774   return nullptr;
1775 }
1776 
ConvertPipelineDecorations(const Type * store_type,const DecorationList & decorations,ast::DecorationList * ast_decos)1777 bool ParserImpl::ConvertPipelineDecorations(const Type* store_type,
1778                                             const DecorationList& decorations,
1779                                             ast::DecorationList* ast_decos) {
1780   // Vulkan defaults to perspective-correct interpolation.
1781   ast::InterpolationType type = ast::InterpolationType::kPerspective;
1782   ast::InterpolationSampling sampling = ast::InterpolationSampling::kNone;
1783 
1784   for (const auto& deco : decorations) {
1785     TINT_ASSERT(Reader, deco.size() > 0);
1786     switch (deco[0]) {
1787       case SpvDecorationLocation:
1788         if (deco.size() != 2) {
1789           return Fail() << "malformed Location decoration on ID requires one "
1790                            "literal operand";
1791         }
1792         SetLocation(ast_decos,
1793                     create<ast::LocationDecoration>(Source{}, deco[1]));
1794         break;
1795       case SpvDecorationFlat:
1796         type = ast::InterpolationType::kFlat;
1797         break;
1798       case SpvDecorationNoPerspective:
1799         if (store_type->IsIntegerScalarOrVector()) {
1800           // This doesn't capture the array or struct case.
1801           return Fail() << "NoPerspective is invalid on integral IO";
1802         }
1803         type = ast::InterpolationType::kLinear;
1804         break;
1805       case SpvDecorationCentroid:
1806         if (store_type->IsIntegerScalarOrVector()) {
1807           // This doesn't capture the array or struct case.
1808           return Fail()
1809                  << "Centroid interpolation sampling is invalid on integral IO";
1810         }
1811         sampling = ast::InterpolationSampling::kCentroid;
1812         break;
1813       case SpvDecorationSample:
1814         if (store_type->IsIntegerScalarOrVector()) {
1815           // This doesn't capture the array or struct case.
1816           return Fail()
1817                  << "Sample interpolation sampling is invalid on integral IO";
1818         }
1819         sampling = ast::InterpolationSampling::kSample;
1820         break;
1821       default:
1822         break;
1823     }
1824   }
1825 
1826   // Apply interpolation.
1827   if (type == ast::InterpolationType::kPerspective &&
1828       sampling == ast::InterpolationSampling::kNone) {
1829     // This is the default. Don't add a decoration.
1830   } else {
1831     ast_decos->emplace_back(create<ast::InterpolateDecoration>(type, sampling));
1832   }
1833 
1834   return success();
1835 }
1836 
CanMakeConstantExpression(uint32_t id)1837 bool ParserImpl::CanMakeConstantExpression(uint32_t id) {
1838   if ((id == workgroup_size_builtin_.id) ||
1839       (id == workgroup_size_builtin_.x_id) ||
1840       (id == workgroup_size_builtin_.y_id) ||
1841       (id == workgroup_size_builtin_.z_id)) {
1842     return true;
1843   }
1844   const auto* inst = def_use_mgr_->GetDef(id);
1845   if (!inst) {
1846     return false;
1847   }
1848   if (inst->opcode() == SpvOpUndef) {
1849     return true;
1850   }
1851   return nullptr != constant_mgr_->FindDeclaredConstant(id);
1852 }
1853 
MakeConstantExpression(uint32_t id)1854 TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) {
1855   if (!success_) {
1856     return {};
1857   }
1858 
1859   // Handle the special cases for workgroup sizing.
1860   if (id == workgroup_size_builtin_.id) {
1861     auto x = MakeConstantExpression(workgroup_size_builtin_.x_id);
1862     auto y = MakeConstantExpression(workgroup_size_builtin_.y_id);
1863     auto z = MakeConstantExpression(workgroup_size_builtin_.z_id);
1864     auto* ast_type = ty_.Vector(x.type, 3);
1865     return {ast_type,
1866             builder_.Construct(Source{}, ast_type->Build(builder_),
1867                                ast::ExpressionList{x.expr, y.expr, z.expr})};
1868   } else if (id == workgroup_size_builtin_.x_id) {
1869     return MakeConstantExpressionForScalarSpirvConstant(
1870         Source{}, ConvertType(workgroup_size_builtin_.component_type_id),
1871         constant_mgr_->GetConstant(
1872             type_mgr_->GetType(workgroup_size_builtin_.component_type_id),
1873             {workgroup_size_builtin_.x_value}));
1874   } else if (id == workgroup_size_builtin_.y_id) {
1875     return MakeConstantExpressionForScalarSpirvConstant(
1876         Source{}, ConvertType(workgroup_size_builtin_.component_type_id),
1877         constant_mgr_->GetConstant(
1878             type_mgr_->GetType(workgroup_size_builtin_.component_type_id),
1879             {workgroup_size_builtin_.y_value}));
1880   } else if (id == workgroup_size_builtin_.z_id) {
1881     return MakeConstantExpressionForScalarSpirvConstant(
1882         Source{}, ConvertType(workgroup_size_builtin_.component_type_id),
1883         constant_mgr_->GetConstant(
1884             type_mgr_->GetType(workgroup_size_builtin_.component_type_id),
1885             {workgroup_size_builtin_.z_value}));
1886   }
1887 
1888   // Handle the general case where a constant is already registered
1889   // with the SPIR-V optimizer's analysis framework.
1890   const auto* inst = def_use_mgr_->GetDef(id);
1891   if (inst == nullptr) {
1892     Fail() << "ID " << id << " is not a registered instruction";
1893     return {};
1894   }
1895   auto source = GetSourceForInst(inst);
1896 
1897   // TODO(dneto): Handle spec constants too?
1898 
1899   auto* original_ast_type = ConvertType(inst->type_id());
1900   if (original_ast_type == nullptr) {
1901     return {};
1902   }
1903 
1904   switch (inst->opcode()) {
1905     case SpvOpUndef:  // Remap undef to null.
1906     case SpvOpConstantNull:
1907       return {original_ast_type, MakeNullValue(original_ast_type)};
1908     case SpvOpConstantTrue:
1909     case SpvOpConstantFalse:
1910     case SpvOpConstant: {
1911       const auto* spirv_const = constant_mgr_->FindDeclaredConstant(id);
1912       if (spirv_const == nullptr) {
1913         Fail() << "ID " << id << " is not a constant";
1914         return {};
1915       }
1916       return MakeConstantExpressionForScalarSpirvConstant(
1917           source, original_ast_type, spirv_const);
1918     }
1919     case SpvOpConstantComposite: {
1920       // Handle vector, matrix, array, and struct
1921 
1922       // Generate a composite from explicit components.
1923       ast::ExpressionList ast_components;
1924       if (!inst->WhileEachInId([&](const uint32_t* id_ref) -> bool {
1925             auto component = MakeConstantExpression(*id_ref);
1926             if (!component) {
1927               this->Fail() << "invalid constant with ID " << *id_ref;
1928               return false;
1929             }
1930             ast_components.emplace_back(component.expr);
1931             return true;
1932           })) {
1933         // We've already emitted a diagnostic.
1934         return {};
1935       }
1936       return {original_ast_type,
1937               builder_.Construct(source, original_ast_type->Build(builder_),
1938                                  std::move(ast_components))};
1939     }
1940     default:
1941       break;
1942   }
1943   Fail() << "unhandled constant instruction " << inst->PrettyPrint();
1944   return {};
1945 }
1946 
MakeConstantExpressionForScalarSpirvConstant(Source source,const Type * original_ast_type,const spvtools::opt::analysis::Constant * spirv_const)1947 TypedExpression ParserImpl::MakeConstantExpressionForScalarSpirvConstant(
1948     Source source,
1949     const Type* original_ast_type,
1950     const spvtools::opt::analysis::Constant* spirv_const) {
1951   auto* ast_type = original_ast_type->UnwrapAlias();
1952 
1953   // TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0.
1954   // So canonicalization should map that way too.
1955   // Currently "null<type>" is missing from the WGSL parser.
1956   // See https://bugs.chromium.org/p/tint/issues/detail?id=34
1957   if (ast_type->Is<U32>()) {
1958     return {ty_.U32(),
1959             create<ast::UintLiteralExpression>(source, spirv_const->GetU32())};
1960   }
1961   if (ast_type->Is<I32>()) {
1962     return {ty_.I32(),
1963             create<ast::SintLiteralExpression>(source, spirv_const->GetS32())};
1964   }
1965   if (ast_type->Is<F32>()) {
1966     return {ty_.F32(), create<ast::FloatLiteralExpression>(
1967                            source, spirv_const->GetFloat())};
1968   }
1969   if (ast_type->Is<Bool>()) {
1970     const bool value = spirv_const->AsNullConstant()
1971                            ? false
1972                            : spirv_const->AsBoolConstant()->value();
1973     return {ty_.Bool(), create<ast::BoolLiteralExpression>(source, value)};
1974   }
1975   Fail() << "expected scalar constant";
1976   return {};
1977 }
1978 
MakeNullValue(const Type * type)1979 const ast::Expression* ParserImpl::MakeNullValue(const Type* type) {
1980   // TODO(dneto): Use the no-operands constructor syntax when it becomes
1981   // available in Tint.
1982   // https://github.com/gpuweb/gpuweb/issues/685
1983   // https://bugs.chromium.org/p/tint/issues/detail?id=34
1984 
1985   if (!type) {
1986     Fail() << "trying to create null value for a null type";
1987     return nullptr;
1988   }
1989 
1990   auto* original_type = type;
1991   type = type->UnwrapAlias();
1992 
1993   if (type->Is<Bool>()) {
1994     return create<ast::BoolLiteralExpression>(Source{}, false);
1995   }
1996   if (type->Is<U32>()) {
1997     return create<ast::UintLiteralExpression>(Source{}, 0u);
1998   }
1999   if (type->Is<I32>()) {
2000     return create<ast::SintLiteralExpression>(Source{}, 0);
2001   }
2002   if (type->Is<F32>()) {
2003     return create<ast::FloatLiteralExpression>(Source{}, 0.0f);
2004   }
2005   if (type->Is<Alias>()) {
2006     // TODO(amaiorano): No type constructor for TypeName (yet?)
2007     ast::ExpressionList ast_components;
2008     return builder_.Construct(Source{}, original_type->Build(builder_),
2009                               std::move(ast_components));
2010   }
2011   if (auto* vec_ty = type->As<Vector>()) {
2012     ast::ExpressionList ast_components;
2013     for (size_t i = 0; i < vec_ty->size; ++i) {
2014       ast_components.emplace_back(MakeNullValue(vec_ty->type));
2015     }
2016     return builder_.Construct(Source{}, type->Build(builder_),
2017                               std::move(ast_components));
2018   }
2019   if (auto* mat_ty = type->As<Matrix>()) {
2020     // Matrix components are columns
2021     auto* column_ty = ty_.Vector(mat_ty->type, mat_ty->rows);
2022     ast::ExpressionList ast_components;
2023     for (size_t i = 0; i < mat_ty->columns; ++i) {
2024       ast_components.emplace_back(MakeNullValue(column_ty));
2025     }
2026     return builder_.Construct(Source{}, type->Build(builder_),
2027                               std::move(ast_components));
2028   }
2029   if (auto* arr_ty = type->As<Array>()) {
2030     ast::ExpressionList ast_components;
2031     for (size_t i = 0; i < arr_ty->size; ++i) {
2032       ast_components.emplace_back(MakeNullValue(arr_ty->type));
2033     }
2034     return builder_.Construct(Source{}, original_type->Build(builder_),
2035                               std::move(ast_components));
2036   }
2037   if (auto* struct_ty = type->As<Struct>()) {
2038     ast::ExpressionList ast_components;
2039     for (auto* member : struct_ty->members) {
2040       ast_components.emplace_back(MakeNullValue(member));
2041     }
2042     return builder_.Construct(Source{}, original_type->Build(builder_),
2043                               std::move(ast_components));
2044   }
2045   Fail() << "can't make null value for type: " << type->TypeInfo().name;
2046   return nullptr;
2047 }
2048 
MakeNullExpression(const Type * type)2049 TypedExpression ParserImpl::MakeNullExpression(const Type* type) {
2050   return {type, MakeNullValue(type)};
2051 }
2052 
// Finds the unsigned counterpart of a signed integer type.
// @param type the type to examine; must not be null
// @returns u32 for i32, a vector of u32 of the same size for a vector of
// i32, and nullptr for any other type
const Type* ParserImpl::UnsignedTypeFor(const Type* type) {
  if (type->Is<I32>()) {
    return ty_.U32();
  }
  const auto* vec = type->As<Vector>();
  if (vec && vec->type->Is<I32>()) {
    return ty_.Vector(ty_.U32(), vec->size);
  }
  return nullptr;
}
2064 
// Finds the signed counterpart of an unsigned integer type.
// @param type the type to examine; must not be null
// @returns i32 for u32, a vector of i32 of the same size for a vector of
// u32, and nullptr for any other type
const Type* ParserImpl::SignedTypeFor(const Type* type) {
  if (type->Is<U32>()) {
    return ty_.I32();
  }
  const auto* vec = type->As<Vector>();
  if (vec && vec->type->Is<U32>()) {
    return ty_.Vector(ty_.I32(), vec->size);
  }
  return nullptr;
}
2076 
RectifyOperandSignedness(const spvtools::opt::Instruction & inst,TypedExpression && expr)2077 TypedExpression ParserImpl::RectifyOperandSignedness(
2078     const spvtools::opt::Instruction& inst,
2079     TypedExpression&& expr) {
2080   bool requires_signed = false;
2081   bool requires_unsigned = false;
2082   if (IsGlslExtendedInstruction(inst)) {
2083     const auto extended_opcode =
2084         static_cast<GLSLstd450>(inst.GetSingleWordInOperand(1));
2085     requires_signed = AssumesSignedOperands(extended_opcode);
2086     requires_unsigned = AssumesUnsignedOperands(extended_opcode);
2087   } else {
2088     const auto opcode = inst.opcode();
2089     requires_signed = AssumesSignedOperands(opcode);
2090     requires_unsigned = AssumesUnsignedOperands(opcode);
2091   }
2092   if (!requires_signed && !requires_unsigned) {
2093     // No conversion is required, assuming our tables are complete.
2094     return std::move(expr);
2095   }
2096   if (!expr) {
2097     Fail() << "internal error: RectifyOperandSignedness given a null expr\n";
2098     return {};
2099   }
2100   auto* type = expr.type;
2101   if (!type) {
2102     Fail() << "internal error: unmapped type for: "
2103            << expr.expr->TypeInfo().name << "\n";
2104     return {};
2105   }
2106   if (requires_unsigned) {
2107     if (auto* unsigned_ty = UnsignedTypeFor(type)) {
2108       // Conversion is required.
2109       return {unsigned_ty,
2110               create<ast::BitcastExpression>(
2111                   Source{}, unsigned_ty->Build(builder_), expr.expr)};
2112     }
2113   } else if (requires_signed) {
2114     if (auto* signed_ty = SignedTypeFor(type)) {
2115       // Conversion is required.
2116       return {signed_ty, create<ast::BitcastExpression>(
2117                              Source{}, signed_ty->Build(builder_), expr.expr)};
2118     }
2119   }
2120   // We should not reach here.
2121   return std::move(expr);
2122 }
2123 
RectifySecondOperandSignedness(const spvtools::opt::Instruction & inst,const Type * first_operand_type,TypedExpression && second_operand_expr)2124 TypedExpression ParserImpl::RectifySecondOperandSignedness(
2125     const spvtools::opt::Instruction& inst,
2126     const Type* first_operand_type,
2127     TypedExpression&& second_operand_expr) {
2128   if ((first_operand_type != second_operand_expr.type) &&
2129       AssumesSecondOperandSignednessMatchesFirstOperand(inst.opcode())) {
2130     // Conversion is required.
2131     return {first_operand_type,
2132             create<ast::BitcastExpression>(Source{},
2133                                            first_operand_type->Build(builder_),
2134                                            second_operand_expr.expr)};
2135   }
2136   // No conversion necessary.
2137   return std::move(second_operand_expr);
2138 }
2139 
ForcedResultType(const spvtools::opt::Instruction & inst,const Type * first_operand_type)2140 const Type* ParserImpl::ForcedResultType(const spvtools::opt::Instruction& inst,
2141                                          const Type* first_operand_type) {
2142   const auto opcode = inst.opcode();
2143   if (AssumesResultSignednessMatchesFirstOperand(opcode)) {
2144     return first_operand_type;
2145   }
2146   if (IsGlslExtendedInstruction(inst)) {
2147     const auto extended_opcode =
2148         static_cast<GLSLstd450>(inst.GetSingleWordInOperand(1));
2149     if (AssumesResultSignednessMatchesFirstOperand(extended_opcode)) {
2150       return first_operand_type;
2151     }
2152   }
2153   return nullptr;
2154 }
2155 
// Finds the signed integer type with the same shape as a numeric type.
// @param other the type whose shape is to be matched
// @returns i32 when `other` is f32, u32 or i32, a vector of i32 of the same
// size when `other` is a vector, or nullptr (after emitting an error) when
// `other` is null or is any other type
const Type* ParserImpl::GetSignedIntMatchingShape(const Type* other) {
  if (other == nullptr) {
    Fail() << "no type provided";
    // Stop here: the checks below dereference `other`.
    return nullptr;
  }
  if (other->Is<F32>() || other->Is<U32>() || other->Is<I32>()) {
    return ty_.I32();
  }
  if (auto* vec_ty = other->As<Vector>()) {
    return ty_.Vector(ty_.I32(), vec_ty->size);
  }
  Fail() << "required numeric scalar or vector, but got "
         << other->TypeInfo().name;
  return nullptr;
}
2170 
// Finds the unsigned integer type with the same shape as a numeric type.
// @param other the type whose shape is to be matched
// @returns u32 when `other` is f32, u32 or i32, a vector of u32 of the same
// size when `other` is a vector, or nullptr (after emitting an error) when
// `other` is null or is any other type
const Type* ParserImpl::GetUnsignedIntMatchingShape(const Type* other) {
  if (!other) {
    Fail() << "no type provided";
    return nullptr;
  }

  const bool is_numeric_scalar =
      other->Is<F32>() || other->Is<U32>() || other->Is<I32>();
  if (is_numeric_scalar) {
    return ty_.U32();
  }

  const auto* vec = other->As<Vector>();
  if (vec != nullptr) {
    return ty_.Vector(ty_.U32(), vec->size);
  }

  Fail() << "required numeric scalar or vector, but got "
         << other->TypeInfo().name;
  return nullptr;
}
2186 
RectifyForcedResultType(TypedExpression expr,const spvtools::opt::Instruction & inst,const Type * first_operand_type)2187 TypedExpression ParserImpl::RectifyForcedResultType(
2188     TypedExpression expr,
2189     const spvtools::opt::Instruction& inst,
2190     const Type* first_operand_type) {
2191   auto* forced_result_ty = ForcedResultType(inst, first_operand_type);
2192   if ((!forced_result_ty) || (forced_result_ty == expr.type)) {
2193     return expr;
2194   }
2195   return {expr.type, create<ast::BitcastExpression>(
2196                          Source{}, expr.type->Build(builder_), expr.expr)};
2197 }
2198 
AsUnsigned(TypedExpression expr)2199 TypedExpression ParserImpl::AsUnsigned(TypedExpression expr) {
2200   if (expr.type && expr.type->IsSignedScalarOrVector()) {
2201     auto* new_type = GetUnsignedIntMatchingShape(expr.type);
2202     return {new_type, create<ast::BitcastExpression>(
2203                           Source{}, new_type->Build(builder_), expr.expr)};
2204   }
2205   return expr;
2206 }
2207 
AsSigned(TypedExpression expr)2208 TypedExpression ParserImpl::AsSigned(TypedExpression expr) {
2209   if (expr.type && expr.type->IsUnsignedScalarOrVector()) {
2210     auto* new_type = GetSignedIntMatchingShape(expr.type);
2211     return {new_type, create<ast::BitcastExpression>(
2212                           Source{}, new_type->Build(builder_), expr.expr)};
2213   }
2214   return expr;
2215 }
2216 
EmitFunctions()2217 bool ParserImpl::EmitFunctions() {
2218   if (!success_) {
2219     return false;
2220   }
2221   for (const auto* f : topologically_ordered_functions_) {
2222     if (!success_) {
2223       return false;
2224     }
2225 
2226     auto id = f->result_id();
2227     auto it = function_to_ep_info_.find(id);
2228     if (it == function_to_ep_info_.end()) {
2229       FunctionEmitter emitter(this, *f, nullptr);
2230       success_ = emitter.Emit();
2231     } else {
2232       for (const auto& ep : it->second) {
2233         FunctionEmitter emitter(this, *f, &ep);
2234         success_ = emitter.Emit();
2235         if (!success_) {
2236           return false;
2237         }
2238       }
2239     }
2240   }
2241   return success_;
2242 }
2243 
2244 const spvtools::opt::Instruction*
GetMemoryObjectDeclarationForHandle(uint32_t id,bool follow_image)2245 ParserImpl::GetMemoryObjectDeclarationForHandle(uint32_t id,
2246                                                 bool follow_image) {
2247   auto saved_id = id;
2248   auto local_fail = [this, saved_id, id,
2249                      follow_image]() -> const spvtools::opt::Instruction* {
2250     const auto* inst = def_use_mgr_->GetDef(id);
2251     Fail() << "Could not find memory object declaration for the "
2252            << (follow_image ? "image" : "sampler") << " underlying id " << id
2253            << " (from original id " << saved_id << ") "
2254            << (inst ? inst->PrettyPrint() : std::string());
2255     return nullptr;
2256   };
2257 
2258   auto& memo_table =
2259       (follow_image ? mem_obj_decl_image_ : mem_obj_decl_sampler_);
2260 
2261   // Use a visited set to defend against bad input which might have long
2262   // chains or even loops.
2263   std::unordered_set<uint32_t> visited;
2264 
2265   // Trace backward in the SSA data flow until we hit a memory object
2266   // declaration.
2267   while (true) {
2268     auto where = memo_table.find(id);
2269     if (where != memo_table.end()) {
2270       return where->second;
2271     }
2272     // Protect against loops.
2273     auto visited_iter = visited.find(id);
2274     if (visited_iter != visited.end()) {
2275       // We've hit a loop. Mark all the visited nodes
2276       // as dead ends.
2277       for (auto iter : visited) {
2278         memo_table[iter] = nullptr;
2279       }
2280       return nullptr;
2281     }
2282     visited.insert(id);
2283 
2284     const auto* inst = def_use_mgr_->GetDef(id);
2285     if (inst == nullptr) {
2286       return local_fail();
2287     }
2288     switch (inst->opcode()) {
2289       case SpvOpFunctionParameter:
2290       case SpvOpVariable:
2291         // We found the memory object declaration.
2292         // Remember it as the answer for the whole path.
2293         for (auto iter : visited) {
2294           memo_table[iter] = inst;
2295         }
2296         return inst;
2297       case SpvOpLoad:
2298         // Follow the pointer being loaded
2299         id = inst->GetSingleWordInOperand(0);
2300         break;
2301       case SpvOpCopyObject:
2302         // Follow the object being copied.
2303         id = inst->GetSingleWordInOperand(0);
2304         break;
2305       case SpvOpAccessChain:
2306       case SpvOpInBoundsAccessChain:
2307       case SpvOpPtrAccessChain:
2308       case SpvOpInBoundsPtrAccessChain:
2309         // Follow the base pointer.
2310         id = inst->GetSingleWordInOperand(0);
2311         break;
2312       case SpvOpSampledImage:
2313         // Follow the image or the sampler, depending on the follow_image
2314         // parameter.
2315         id = inst->GetSingleWordInOperand(follow_image ? 0 : 1);
2316         break;
2317       case SpvOpImage:
2318         // Follow the sampled image
2319         id = inst->GetSingleWordInOperand(0);
2320         break;
2321       default:
2322         // Can't trace further.
2323         // Remember it as the answer for the whole path.
2324         for (auto iter : visited) {
2325           memo_table[iter] = nullptr;
2326         }
2327         return nullptr;
2328     }
2329   }
2330 }
2331 
2332 const spvtools::opt::Instruction*
GetSpirvTypeForHandleMemoryObjectDeclaration(const spvtools::opt::Instruction & var)2333 ParserImpl::GetSpirvTypeForHandleMemoryObjectDeclaration(
2334     const spvtools::opt::Instruction& var) {
2335   if (!success()) {
2336     return nullptr;
2337   }
2338   // The WGSL handle type is determined by looking at information from
2339   // several sources:
2340   //    - the usage of the handle by image access instructions
2341   //    - the SPIR-V type declaration
2342   // Each source does not have enough information to completely determine
2343   // the result.
2344 
2345   // Messages are phrased in terms of images and samplers because those
2346   // are the only SPIR-V handles supported by WGSL.
2347 
2348   // Get the SPIR-V handle type.
2349   const auto* ptr_type = def_use_mgr_->GetDef(var.type_id());
2350   if (!ptr_type || (ptr_type->opcode() != SpvOpTypePointer)) {
2351     Fail() << "Invalid type for variable or function parameter "
2352            << var.PrettyPrint();
2353     return nullptr;
2354   }
2355   const auto* raw_handle_type =
2356       def_use_mgr_->GetDef(ptr_type->GetSingleWordInOperand(1));
2357   if (!raw_handle_type) {
2358     Fail() << "Invalid pointer type for variable or function parameter "
2359            << var.PrettyPrint();
2360     return nullptr;
2361   }
2362   switch (raw_handle_type->opcode()) {
2363     case SpvOpTypeSampler:
2364     case SpvOpTypeImage:
2365       // The expected cases.
2366       break;
2367     case SpvOpTypeArray:
2368     case SpvOpTypeRuntimeArray:
2369       Fail()
2370           << "arrays of textures or samplers are not supported in WGSL; can't "
2371              "translate variable or function parameter: "
2372           << var.PrettyPrint();
2373       return nullptr;
2374     case SpvOpTypeSampledImage:
2375       Fail() << "WGSL does not support combined image-samplers: "
2376              << var.PrettyPrint();
2377       return nullptr;
2378     default:
2379       Fail() << "invalid type for image or sampler variable or function "
2380                 "parameter: "
2381              << var.PrettyPrint();
2382       return nullptr;
2383   }
2384   return raw_handle_type;
2385 }
2386 
GetTypeForHandleVar(const spvtools::opt::Instruction & var)2387 const Pointer* ParserImpl::GetTypeForHandleVar(
2388     const spvtools::opt::Instruction& var) {
2389   auto where = handle_type_.find(&var);
2390   if (where != handle_type_.end()) {
2391     return where->second;
2392   }
2393 
2394   const spvtools::opt::Instruction* raw_handle_type =
2395       GetSpirvTypeForHandleMemoryObjectDeclaration(var);
2396   if (!raw_handle_type) {
2397     return nullptr;
2398   }
2399 
2400   // The variable could be a sampler or image.
2401   // Where possible, determine which one it is from the usage inferred
2402   // for the variable.
2403   Usage usage = handle_usage_[&var];
2404   if (!usage.IsValid()) {
2405     Fail() << "Invalid sampler or texture usage for variable "
2406            << var.PrettyPrint() << "\n"
2407            << usage;
2408     return nullptr;
2409   }
2410   // Infer a handle type, if usage didn't already tell us.
2411   if (!usage.IsComplete()) {
2412     // In SPIR-V you could statically reference a texture or sampler without
2413     // using it in a way that gives us a clue on how to declare it.  Look inside
2414     // the store type to infer a usage.
2415     if (raw_handle_type->opcode() == SpvOpTypeSampler) {
2416       usage.AddSampler();
2417     } else {
2418       // It's a texture.
2419       if (raw_handle_type->NumInOperands() != 7) {
2420         Fail() << "invalid SPIR-V image type: expected 7 operands: "
2421                << raw_handle_type->PrettyPrint();
2422         return nullptr;
2423       }
2424       const auto sampled_param = raw_handle_type->GetSingleWordInOperand(5);
2425       const auto format_param = raw_handle_type->GetSingleWordInOperand(6);
2426       // Only storage images have a format.
2427       if ((format_param != SpvImageFormatUnknown) ||
2428           sampled_param == 2 /* without sampler */) {
2429         // Get NonWritable and NonReadable attributes of the variable.
2430         bool is_nonwritable = false;
2431         bool is_nonreadable = false;
2432         for (const auto& deco : GetDecorationsFor(var.result_id())) {
2433           if (deco.size() != 1) {
2434             continue;
2435           }
2436           if (deco[0] == SpvDecorationNonWritable) {
2437             is_nonwritable = true;
2438           }
2439           if (deco[0] == SpvDecorationNonReadable) {
2440             is_nonreadable = true;
2441           }
2442         }
2443         if (is_nonwritable && is_nonreadable) {
2444           Fail() << "storage image variable is both NonWritable and NonReadable"
2445                  << var.PrettyPrint();
2446         }
2447         if (!is_nonwritable && !is_nonreadable) {
2448           Fail()
2449               << "storage image variable is neither NonWritable nor NonReadable"
2450               << var.PrettyPrint();
2451         }
2452         // Let's make it one of the storage textures.
2453         if (is_nonwritable) {
2454           usage.AddStorageReadTexture();
2455         } else {
2456           usage.AddStorageWriteTexture();
2457         }
2458       } else {
2459         usage.AddSampledTexture();
2460       }
2461     }
2462     if (!usage.IsComplete()) {
2463       Fail()
2464           << "internal error: should have inferred a complete handle type. got "
2465           << usage.to_str();
2466       return nullptr;
2467     }
2468   }
2469 
2470   // Construct the Tint handle type.
2471   const Type* ast_store_type = nullptr;
2472   if (usage.IsSampler()) {
2473     ast_store_type = ty_.Sampler(usage.IsComparisonSampler()
2474                                      ? ast::SamplerKind::kComparisonSampler
2475                                      : ast::SamplerKind::kSampler);
2476   } else if (usage.IsTexture()) {
2477     const spvtools::opt::analysis::Image* image_type =
2478         type_mgr_->GetType(raw_handle_type->result_id())->AsImage();
2479     if (!image_type) {
2480       Fail() << "internal error: Couldn't look up image type"
2481              << raw_handle_type->PrettyPrint();
2482       return nullptr;
2483     }
2484 
2485     if (image_type->is_arrayed()) {
2486       // Give a nicer error message here, where we have the offending variable
2487       // in hand, rather than inside the enum converter.
2488       switch (image_type->dim()) {
2489         case SpvDim2D:
2490         case SpvDimCube:
2491           break;
2492         default:
2493           Fail() << "WGSL arrayed textures must be 2d_array or cube_array: "
2494                     "invalid multisampled texture variable "
2495                  << namer_.Name(var.result_id()) << ": " << var.PrettyPrint();
2496           return nullptr;
2497       }
2498     }
2499 
2500     const ast::TextureDimension dim =
2501         enum_converter_.ToDim(image_type->dim(), image_type->is_arrayed());
2502     if (dim == ast::TextureDimension::kNone) {
2503       return nullptr;
2504     }
2505 
2506     // WGSL textures are always formatted.  Unformatted textures are always
2507     // sampled.
2508     if (usage.IsSampledTexture() || usage.IsStorageReadTexture() ||
2509         (image_type->format() == SpvImageFormatUnknown)) {
2510       // Make a sampled texture type.
2511       auto* ast_sampled_component_type =
2512           ConvertType(raw_handle_type->GetSingleWordInOperand(0));
2513 
2514       // Vulkan ignores the depth parameter on OpImage, so pay attention to the
2515       // usage as well.  That is, it's valid for a Vulkan shader to use an
2516       // OpImage variable with an OpImage*Dref* instruction.  In WGSL we must
2517       // treat that as a depth texture.
2518       if (image_type->depth() || usage.IsDepthTexture()) {
2519         if (image_type->is_multisampled()) {
2520           ast_store_type = ty_.DepthMultisampledTexture(dim);
2521         } else {
2522           ast_store_type = ty_.DepthTexture(dim);
2523         }
2524       } else if (image_type->is_multisampled()) {
2525         if (dim != ast::TextureDimension::k2d) {
2526           Fail() << "WGSL multisampled textures must be 2d and non-arrayed: "
2527                     "invalid multisampled texture variable "
2528                  << namer_.Name(var.result_id()) << ": " << var.PrettyPrint();
2529         }
2530         // Multisampled textures are never depth textures.
2531         ast_store_type =
2532             ty_.MultisampledTexture(dim, ast_sampled_component_type);
2533       } else {
2534         ast_store_type = ty_.SampledTexture(dim, ast_sampled_component_type);
2535       }
2536     } else {
2537       const auto access = ast::Access::kWrite;
2538       const auto format = enum_converter_.ToImageFormat(image_type->format());
2539       if (format == ast::ImageFormat::kNone) {
2540         return nullptr;
2541       }
2542       ast_store_type = ty_.StorageTexture(dim, format, access);
2543     }
2544   } else {
2545     Fail() << "unsupported: UniformConstant variable is not a recognized "
2546               "sampler or texture"
2547            << var.PrettyPrint();
2548     return nullptr;
2549   }
2550 
2551   // Form the pointer type.
2552   auto* result =
2553       ty_.Pointer(ast_store_type, ast::StorageClass::kUniformConstant);
2554   // Remember it for later.
2555   handle_type_[&var] = result;
2556   return result;
2557 }
2558 
GetComponentTypeForFormat(ast::ImageFormat format)2559 const Type* ParserImpl::GetComponentTypeForFormat(ast::ImageFormat format) {
2560   switch (format) {
2561     case ast::ImageFormat::kR8Uint:
2562     case ast::ImageFormat::kR16Uint:
2563     case ast::ImageFormat::kRg8Uint:
2564     case ast::ImageFormat::kR32Uint:
2565     case ast::ImageFormat::kRg16Uint:
2566     case ast::ImageFormat::kRgba8Uint:
2567     case ast::ImageFormat::kRg32Uint:
2568     case ast::ImageFormat::kRgba16Uint:
2569     case ast::ImageFormat::kRgba32Uint:
2570       return ty_.U32();
2571 
2572     case ast::ImageFormat::kR8Sint:
2573     case ast::ImageFormat::kR16Sint:
2574     case ast::ImageFormat::kRg8Sint:
2575     case ast::ImageFormat::kR32Sint:
2576     case ast::ImageFormat::kRg16Sint:
2577     case ast::ImageFormat::kRgba8Sint:
2578     case ast::ImageFormat::kRg32Sint:
2579     case ast::ImageFormat::kRgba16Sint:
2580     case ast::ImageFormat::kRgba32Sint:
2581       return ty_.I32();
2582 
2583     case ast::ImageFormat::kR8Unorm:
2584     case ast::ImageFormat::kRg8Unorm:
2585     case ast::ImageFormat::kRgba8Unorm:
2586     case ast::ImageFormat::kRgba8UnormSrgb:
2587     case ast::ImageFormat::kBgra8Unorm:
2588     case ast::ImageFormat::kBgra8UnormSrgb:
2589     case ast::ImageFormat::kRgb10A2Unorm:
2590     case ast::ImageFormat::kR8Snorm:
2591     case ast::ImageFormat::kRg8Snorm:
2592     case ast::ImageFormat::kRgba8Snorm:
2593     case ast::ImageFormat::kR16Float:
2594     case ast::ImageFormat::kR32Float:
2595     case ast::ImageFormat::kRg16Float:
2596     case ast::ImageFormat::kRg11B10Float:
2597     case ast::ImageFormat::kRg32Float:
2598     case ast::ImageFormat::kRgba16Float:
2599     case ast::ImageFormat::kRgba32Float:
2600       return ty_.F32();
2601     default:
2602       break;
2603   }
2604   Fail() << "unknown format " << int(format);
2605   return nullptr;
2606 }
2607 
GetChannelCountForFormat(ast::ImageFormat format)2608 unsigned ParserImpl::GetChannelCountForFormat(ast::ImageFormat format) {
2609   switch (format) {
2610     case ast::ImageFormat::kR16Float:
2611     case ast::ImageFormat::kR16Sint:
2612     case ast::ImageFormat::kR16Uint:
2613     case ast::ImageFormat::kR32Float:
2614     case ast::ImageFormat::kR32Sint:
2615     case ast::ImageFormat::kR32Uint:
2616     case ast::ImageFormat::kR8Sint:
2617     case ast::ImageFormat::kR8Snorm:
2618     case ast::ImageFormat::kR8Uint:
2619     case ast::ImageFormat::kR8Unorm:
2620       // One channel
2621       return 1;
2622 
2623     case ast::ImageFormat::kRg11B10Float:
2624     case ast::ImageFormat::kRg16Float:
2625     case ast::ImageFormat::kRg16Sint:
2626     case ast::ImageFormat::kRg16Uint:
2627     case ast::ImageFormat::kRg32Float:
2628     case ast::ImageFormat::kRg32Sint:
2629     case ast::ImageFormat::kRg32Uint:
2630     case ast::ImageFormat::kRg8Sint:
2631     case ast::ImageFormat::kRg8Snorm:
2632     case ast::ImageFormat::kRg8Uint:
2633     case ast::ImageFormat::kRg8Unorm:
2634       // Two channels
2635       return 2;
2636 
2637     case ast::ImageFormat::kBgra8Unorm:
2638     case ast::ImageFormat::kBgra8UnormSrgb:
2639     case ast::ImageFormat::kRgb10A2Unorm:
2640     case ast::ImageFormat::kRgba16Float:
2641     case ast::ImageFormat::kRgba16Sint:
2642     case ast::ImageFormat::kRgba16Uint:
2643     case ast::ImageFormat::kRgba32Float:
2644     case ast::ImageFormat::kRgba32Sint:
2645     case ast::ImageFormat::kRgba32Uint:
2646     case ast::ImageFormat::kRgba8Sint:
2647     case ast::ImageFormat::kRgba8Snorm:
2648     case ast::ImageFormat::kRgba8Uint:
2649     case ast::ImageFormat::kRgba8Unorm:
2650     case ast::ImageFormat::kRgba8UnormSrgb:
2651       // Four channels
2652       return 4;
2653 
2654     default:
2655       break;
2656   }
2657   Fail() << "unknown format " << int(format);
2658   return 0;
2659 }
2660 
GetTexelTypeForFormat(ast::ImageFormat format)2661 const Type* ParserImpl::GetTexelTypeForFormat(ast::ImageFormat format) {
2662   const auto* component_type = GetComponentTypeForFormat(format);
2663   if (!component_type) {
2664     return nullptr;
2665   }
2666   return ty_.Vector(component_type, 4);
2667 }
2668 
RegisterHandleUsage()2669 bool ParserImpl::RegisterHandleUsage() {
2670   if (!success_) {
2671     return false;
2672   }
2673 
2674   // Map a function ID to the list of its function parameter instructions, in
2675   // order.
2676   std::unordered_map<uint32_t, std::vector<const spvtools::opt::Instruction*>>
2677       function_params;
2678   for (const auto* f : topologically_ordered_functions_) {
2679     // Record the instructions defining this function's parameters.
2680     auto& params = function_params[f->result_id()];
2681     f->ForEachParam([&params](const spvtools::opt::Instruction* param) {
2682       params.push_back(param);
2683     });
2684   }
2685 
2686   // Returns the memory object declaration for an image underlying the first
2687   // operand of the given image instruction.
2688   auto get_image = [this](const spvtools::opt::Instruction& image_inst) {
2689     return this->GetMemoryObjectDeclarationForHandle(
2690         image_inst.GetSingleWordInOperand(0), true);
2691   };
2692   // Returns the memory object declaration for a sampler underlying the first
2693   // operand of the given image instruction.
2694   auto get_sampler = [this](const spvtools::opt::Instruction& image_inst) {
2695     return this->GetMemoryObjectDeclarationForHandle(
2696         image_inst.GetSingleWordInOperand(0), false);
2697   };
2698 
2699   // Scan the bodies of functions for image operations, recording their implied
2700   // usage properties on the memory object declarations (i.e. variables or
2701   // function parameters).  We scan the functions in an order so that callees
2702   // precede callers. That way the usage on a function parameter is already
2703   // computed before we see the call to that function.  So when we reach
2704   // a function call, we can add the usage from the callee formal parameters.
2705   for (const auto* f : topologically_ordered_functions_) {
2706     for (const auto& bb : *f) {
2707       for (const auto& inst : bb) {
2708         switch (inst.opcode()) {
2709             // Single texel reads and writes
2710 
2711           case SpvOpImageRead:
2712             handle_usage_[get_image(inst)].AddStorageReadTexture();
2713             break;
2714           case SpvOpImageWrite:
2715             handle_usage_[get_image(inst)].AddStorageWriteTexture();
2716             break;
2717           case SpvOpImageFetch:
2718             handle_usage_[get_image(inst)].AddSampledTexture();
2719             break;
2720 
2721             // Sampling and gathering from a sampled image.
2722 
2723           case SpvOpImageSampleImplicitLod:
2724           case SpvOpImageSampleExplicitLod:
2725           case SpvOpImageSampleProjImplicitLod:
2726           case SpvOpImageSampleProjExplicitLod:
2727           case SpvOpImageGather:
2728             handle_usage_[get_image(inst)].AddSampledTexture();
2729             handle_usage_[get_sampler(inst)].AddSampler();
2730             break;
2731           case SpvOpImageSampleDrefImplicitLod:
2732           case SpvOpImageSampleDrefExplicitLod:
2733           case SpvOpImageSampleProjDrefImplicitLod:
2734           case SpvOpImageSampleProjDrefExplicitLod:
2735           case SpvOpImageDrefGather:
2736             // Depth reference access implies usage as a depth texture, which
2737             // in turn is a sampled texture.
2738             handle_usage_[get_image(inst)].AddDepthTexture();
2739             handle_usage_[get_sampler(inst)].AddComparisonSampler();
2740             break;
2741 
2742             // Image queries
2743 
2744           case SpvOpImageQuerySizeLod:
2745             // Vulkan requires Sampled=1 for this. SPIR-V already requires MS=0.
2746             handle_usage_[get_image(inst)].AddSampledTexture();
2747             break;
2748           case SpvOpImageQuerySize:
2749             // Applies to either MS=1 or Sampled=0 or 2.
2750             // So we can't force it to be multisampled, or storage image.
2751             break;
2752           case SpvOpImageQueryLod:
2753             handle_usage_[get_image(inst)].AddSampledTexture();
2754             handle_usage_[get_sampler(inst)].AddSampler();
2755             break;
2756           case SpvOpImageQueryLevels:
2757             // We can't tell anything more than that it's an image.
2758             handle_usage_[get_image(inst)].AddTexture();
2759             break;
2760           case SpvOpImageQuerySamples:
2761             handle_usage_[get_image(inst)].AddMultisampledTexture();
2762             break;
2763 
2764             // Function calls
2765 
2766           case SpvOpFunctionCall: {
2767             // Propagate handle usages from callee function formal parameters to
2768             // the matching caller parameters.  This is where we rely on the
2769             // fact that callees have been processed earlier in the flow.
2770             const auto num_in_operands = inst.NumInOperands();
2771             // The first operand of the call is the function ID.
2772             // The remaining operands are the operands to the function.
2773             if (num_in_operands < 1) {
2774               return Fail() << "Call instruction must have at least one operand"
2775                             << inst.PrettyPrint();
2776             }
2777             const auto function_id = inst.GetSingleWordInOperand(0);
2778             const auto& formal_params = function_params[function_id];
2779             if (formal_params.size() != (num_in_operands - 1)) {
2780               return Fail() << "Called function has " << formal_params.size()
2781                             << " parameters, but function call has "
2782                             << (num_in_operands - 1) << " parameters"
2783                             << inst.PrettyPrint();
2784             }
2785             for (uint32_t i = 1; i < num_in_operands; ++i) {
2786               auto where = handle_usage_.find(formal_params[i - 1]);
2787               if (where == handle_usage_.end()) {
2788                 // We haven't recorded any handle usage on the formal parameter.
2789                 continue;
2790               }
2791               const Usage& formal_param_usage = where->second;
2792               const auto operand_id = inst.GetSingleWordInOperand(i);
2793               const auto* operand_as_sampler =
2794                   GetMemoryObjectDeclarationForHandle(operand_id, false);
2795               const auto* operand_as_image =
2796                   GetMemoryObjectDeclarationForHandle(operand_id, true);
2797               if (operand_as_sampler) {
2798                 handle_usage_[operand_as_sampler].Add(formal_param_usage);
2799               }
2800               if (operand_as_image &&
2801                   (operand_as_image != operand_as_sampler)) {
2802                 handle_usage_[operand_as_image].Add(formal_param_usage);
2803               }
2804             }
2805             break;
2806           }
2807 
2808           default:
2809             break;
2810         }
2811       }
2812     }
2813   }
2814   return success_;
2815 }
2816 
GetHandleUsage(uint32_t id) const2817 Usage ParserImpl::GetHandleUsage(uint32_t id) const {
2818   const auto where = handle_usage_.find(def_use_mgr_->GetDef(id));
2819   if (where != handle_usage_.end()) {
2820     return where->second;
2821   }
2822   return Usage();
2823 }
2824 
GetInstructionForTest(uint32_t id) const2825 const spvtools::opt::Instruction* ParserImpl::GetInstructionForTest(
2826     uint32_t id) const {
2827   return def_use_mgr_ ? def_use_mgr_->GetDef(id) : nullptr;
2828 }
2829 
GetMemberName(const Struct & struct_type,int member_index)2830 std::string ParserImpl::GetMemberName(const Struct& struct_type,
2831                                       int member_index) {
2832   auto where = struct_id_for_symbol_.find(struct_type.name);
2833   if (where == struct_id_for_symbol_.end()) {
2834     Fail() << "no structure type registered for symbol";
2835     return "";
2836   }
2837   return namer_.GetMemberName(where->second, member_index);
2838 }
2839 
// Out-of-line definitions of WorkgroupSizeInfo's constructor and destructor.
// Both are explicitly defaulted, so they do whatever the compiler-generated
// versions would do.
2840 WorkgroupSizeInfo::WorkgroupSizeInfo() = default;
2841 
2842 WorkgroupSizeInfo::~WorkgroupSizeInfo() = default;
2843 
2844 }  // namespace spirv
2845 }  // namespace reader
2846 }  // namespace tint
2847