1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/reader/spirv/parser_impl.h"
16
17 #include <algorithm>
18 #include <limits>
19 #include <locale>
20 #include <utility>
21
22 #include "source/opt/build_module.h"
23 #include "src/ast/bitcast_expression.h"
24 #include "src/ast/disable_validation_decoration.h"
25 #include "src/ast/interpolate_decoration.h"
26 #include "src/ast/override_decoration.h"
27 #include "src/ast/struct_block_decoration.h"
28 #include "src/ast/type_name.h"
29 #include "src/ast/unary_op_expression.h"
30 #include "src/reader/spirv/function.h"
31 #include "src/sem/depth_texture_type.h"
32 #include "src/sem/multisampled_texture_type.h"
33 #include "src/sem/sampled_texture_type.h"
34 #include "src/utils/unique_vector.h"
35
36 namespace tint {
37 namespace reader {
38 namespace spirv {
39
40 namespace {
41
// The SPIRV-Tools target environment used when processing input modules.
// Input SPIR-V needs only to conform to Vulkan 1.1 requirements.
// The combination of the SPIR-V reader and the semantics of WGSL
// tighten up the code so that the output of the SPIR-V *writer*
// will satisfy SPV_ENV_WEBGPU_0 validation.
const spv_target_env kInputEnv = SPV_ENV_VULKAN_1_1;
47
48 // A FunctionTraverser is used to compute an ordering of functions in the
49 // module such that callees precede callers.
50 class FunctionTraverser {
51 public:
FunctionTraverser(const spvtools::opt::Module & module)52 explicit FunctionTraverser(const spvtools::opt::Module& module)
53 : module_(module) {}
54
55 // @returns the functions in the modules such that callees precede callers.
TopologicallyOrderedFunctions()56 std::vector<const spvtools::opt::Function*> TopologicallyOrderedFunctions() {
57 visited_.clear();
58 ordered_.clear();
59 id_to_func_.clear();
60 for (const auto& f : module_) {
61 id_to_func_[f.result_id()] = &f;
62 }
63 for (const auto& f : module_) {
64 Visit(f);
65 }
66 return ordered_;
67 }
68
69 private:
Visit(const spvtools::opt::Function & f)70 void Visit(const spvtools::opt::Function& f) {
71 if (visited_.count(&f)) {
72 return;
73 }
74 visited_.insert(&f);
75 for (const auto& bb : f) {
76 for (const auto& inst : bb) {
77 if (inst.opcode() != SpvOpFunctionCall) {
78 continue;
79 }
80 const auto* callee = id_to_func_[inst.GetSingleWordInOperand(0)];
81 if (callee) {
82 Visit(*callee);
83 }
84 }
85 }
86 ordered_.push_back(&f);
87 }
88
89 const spvtools::opt::Module& module_;
90 std::unordered_set<const spvtools::opt::Function*> visited_;
91 std::unordered_map<uint32_t, const spvtools::opt::Function*> id_to_func_;
92 std::vector<const spvtools::opt::Function*> ordered_;
93 };
94
95 // Returns true if the opcode operates as if its operands are signed integral.
AssumesSignedOperands(SpvOp opcode)96 bool AssumesSignedOperands(SpvOp opcode) {
97 switch (opcode) {
98 case SpvOpSNegate:
99 case SpvOpSDiv:
100 case SpvOpSRem:
101 case SpvOpSMod:
102 case SpvOpSLessThan:
103 case SpvOpSLessThanEqual:
104 case SpvOpSGreaterThan:
105 case SpvOpSGreaterThanEqual:
106 case SpvOpConvertSToF:
107 return true;
108 default:
109 break;
110 }
111 return false;
112 }
113
114 // Returns true if the GLSL extended instruction expects operands to be signed.
115 // @param extended_opcode GLSL.std.450 opcode
116 // @returns true if all operands must be signed integral type
AssumesSignedOperands(GLSLstd450 extended_opcode)117 bool AssumesSignedOperands(GLSLstd450 extended_opcode) {
118 switch (extended_opcode) {
119 case GLSLstd450SAbs:
120 case GLSLstd450SSign:
121 case GLSLstd450SMin:
122 case GLSLstd450SMax:
123 case GLSLstd450SClamp:
124 return true;
125 default:
126 break;
127 }
128 return false;
129 }
130
131 // Returns true if the opcode operates as if its operands are unsigned integral.
AssumesUnsignedOperands(SpvOp opcode)132 bool AssumesUnsignedOperands(SpvOp opcode) {
133 switch (opcode) {
134 case SpvOpUDiv:
135 case SpvOpUMod:
136 case SpvOpULessThan:
137 case SpvOpULessThanEqual:
138 case SpvOpUGreaterThan:
139 case SpvOpUGreaterThanEqual:
140 case SpvOpConvertUToF:
141 return true;
142 default:
143 break;
144 }
145 return false;
146 }
147
148 // Returns true if the GLSL extended instruction expects operands to be
149 // unsigned.
150 // @param extended_opcode GLSL.std.450 opcode
151 // @returns true if all operands must be unsigned integral type
AssumesUnsignedOperands(GLSLstd450 extended_opcode)152 bool AssumesUnsignedOperands(GLSLstd450 extended_opcode) {
153 switch (extended_opcode) {
154 case GLSLstd450UMin:
155 case GLSLstd450UMax:
156 case GLSLstd450UClamp:
157 return true;
158 default:
159 break;
160 }
161 return false;
162 }
163
164 // Returns true if the corresponding WGSL operation requires
165 // the signedness of the second operand to match the signedness of the
166 // first operand, and it's not one of the OpU* or OpS* instructions.
167 // (Those are handled via MakeOperand.)
AssumesSecondOperandSignednessMatchesFirstOperand(SpvOp opcode)168 bool AssumesSecondOperandSignednessMatchesFirstOperand(SpvOp opcode) {
169 switch (opcode) {
170 // All the OpI* integer binary operations.
171 case SpvOpIAdd:
172 case SpvOpISub:
173 case SpvOpIMul:
174 case SpvOpIEqual:
175 case SpvOpINotEqual:
176 // All the bitwise integer binary operations.
177 case SpvOpBitwiseAnd:
178 case SpvOpBitwiseOr:
179 case SpvOpBitwiseXor:
180 return true;
181 default:
182 break;
183 }
184 return false;
185 }
186
187 // Returns true if the corresponding WGSL operation requires
188 // the signedness of the result to match the signedness of the first operand.
AssumesResultSignednessMatchesFirstOperand(SpvOp opcode)189 bool AssumesResultSignednessMatchesFirstOperand(SpvOp opcode) {
190 switch (opcode) {
191 case SpvOpNot:
192 case SpvOpSNegate:
193 case SpvOpBitCount:
194 case SpvOpBitReverse:
195 case SpvOpSDiv:
196 case SpvOpSMod:
197 case SpvOpSRem:
198 case SpvOpIAdd:
199 case SpvOpISub:
200 case SpvOpIMul:
201 case SpvOpBitwiseAnd:
202 case SpvOpBitwiseOr:
203 case SpvOpBitwiseXor:
204 case SpvOpShiftLeftLogical:
205 case SpvOpShiftRightLogical:
206 case SpvOpShiftRightArithmetic:
207 return true;
208 default:
209 break;
210 }
211 return false;
212 }
213
214 // Returns true if the extended instruction requires the signedness of the
215 // result to match the signedness of the first operand to the operation.
216 // @param extended_opcode GLSL.std.450 opcode
217 // @returns true if the result type must match the first operand type.
AssumesResultSignednessMatchesFirstOperand(GLSLstd450 extended_opcode)218 bool AssumesResultSignednessMatchesFirstOperand(GLSLstd450 extended_opcode) {
219 switch (extended_opcode) {
220 case GLSLstd450SAbs:
221 case GLSLstd450SSign:
222 case GLSLstd450SMin:
223 case GLSLstd450SMax:
224 case GLSLstd450SClamp:
225 case GLSLstd450UMin:
226 case GLSLstd450UMax:
227 case GLSLstd450UClamp:
228 // TODO(dneto): FindSMsb?
229 // TODO(dneto): FindUMsb?
230 return true;
231 default:
232 break;
233 }
234 return false;
235 }
236
237 // @param a SPIR-V decoration
238 // @return true when the given decoration is a pipeline decoration other than a
239 // bulitin variable.
IsPipelineDecoration(const Decoration & deco)240 bool IsPipelineDecoration(const Decoration& deco) {
241 if (deco.size() < 1) {
242 return false;
243 }
244 switch (deco[0]) {
245 case SpvDecorationLocation:
246 case SpvDecorationFlat:
247 case SpvDecorationNoPerspective:
248 case SpvDecorationCentroid:
249 case SpvDecorationSample:
250 return true;
251 default:
252 break;
253 }
254 return false;
255 }
256
257 } // namespace
258
// Default constructor (compiler-generated).
TypedExpression::TypedExpression() = default;

// Copy constructor (compiler-generated).
TypedExpression::TypedExpression(const TypedExpression&) = default;

// Copy assignment operator (compiler-generated).
TypedExpression& TypedExpression::operator=(const TypedExpression&) = default;

// Constructs a TypedExpression from a type and an expression.
// @param type_in the value stored in the `type` member
// @param expr_in the value stored in the `expr` member
TypedExpression::TypedExpression(const Type* type_in,
                                 const ast::Expression* expr_in)
    : type(type_in), expr(expr_in) {}
268
ParserImpl(const std::vector<uint32_t> & spv_binary)269 ParserImpl::ParserImpl(const std::vector<uint32_t>& spv_binary)
270 : Reader(),
271 spv_binary_(spv_binary),
272 fail_stream_(&success_, &errors_),
273 namer_(fail_stream_),
274 enum_converter_(fail_stream_),
275 tools_context_(kInputEnv) {
276 // Create a message consumer to propagate error messages from SPIRV-Tools
277 // out as our own failures.
278 message_consumer_ = [this](spv_message_level_t level, const char* /*source*/,
279 const spv_position_t& position,
280 const char* message) {
281 switch (level) {
282 // Ignore info and warning message.
283 case SPV_MSG_WARNING:
284 case SPV_MSG_INFO:
285 break;
286 // Otherwise, propagate the error.
287 default:
288 // For binary validation errors, we only have the instruction
289 // number. It's not text, so there is no column number.
290 this->Fail() << "line:" << position.index << ": " << message;
291 }
292 };
293 }
294
295 ParserImpl::~ParserImpl() = default;
296
Parse()297 bool ParserImpl::Parse() {
298 // Set up use of SPIRV-Tools utilities.
299 spvtools::SpirvTools spv_tools(kInputEnv);
300
301 // Error messages from SPIRV-Tools are forwarded as failures, including
302 // setting |success_| to false.
303 spv_tools.SetMessageConsumer(message_consumer_);
304
305 if (!success_) {
306 return false;
307 }
308
309 // Only consider modules valid for Vulkan 1.0. On failure, the message
310 // consumer will set the error status.
311 if (!spv_tools.Validate(spv_binary_)) {
312 success_ = false;
313 return false;
314 }
315 if (!BuildInternalModule()) {
316 return false;
317 }
318 if (!ParseInternalModule()) {
319 return false;
320 }
321
322 return success_;
323 }
324
program()325 Program ParserImpl::program() {
326 // TODO(dneto): Should we clear out spv_binary_ here, to reduce
327 // memory usage?
328 return tint::Program(std::move(builder_));
329 }
330
ConvertType(uint32_t type_id,PtrAs ptr_as)331 const Type* ParserImpl::ConvertType(uint32_t type_id, PtrAs ptr_as) {
332 if (!success_) {
333 return nullptr;
334 }
335
336 if (type_mgr_ == nullptr) {
337 Fail() << "ConvertType called when the internal module has not been built";
338 return nullptr;
339 }
340
341 auto* spirv_type = type_mgr_->GetType(type_id);
342 if (spirv_type == nullptr) {
343 Fail() << "ID is not a SPIR-V type: " << type_id;
344 return nullptr;
345 }
346
347 switch (spirv_type->kind()) {
348 case spvtools::opt::analysis::Type::kVoid:
349 return ty_.Void();
350 case spvtools::opt::analysis::Type::kBool:
351 return ty_.Bool();
352 case spvtools::opt::analysis::Type::kInteger:
353 return ConvertType(spirv_type->AsInteger());
354 case spvtools::opt::analysis::Type::kFloat:
355 return ConvertType(spirv_type->AsFloat());
356 case spvtools::opt::analysis::Type::kVector:
357 return ConvertType(spirv_type->AsVector());
358 case spvtools::opt::analysis::Type::kMatrix:
359 return ConvertType(spirv_type->AsMatrix());
360 case spvtools::opt::analysis::Type::kRuntimeArray:
361 return ConvertType(type_id, spirv_type->AsRuntimeArray());
362 case spvtools::opt::analysis::Type::kArray:
363 return ConvertType(type_id, spirv_type->AsArray());
364 case spvtools::opt::analysis::Type::kStruct:
365 return ConvertType(type_id, spirv_type->AsStruct());
366 case spvtools::opt::analysis::Type::kPointer:
367 return ConvertType(type_id, ptr_as, spirv_type->AsPointer());
368 case spvtools::opt::analysis::Type::kFunction:
369 // Tint doesn't have a Function type.
370 // We need to convert the result type and parameter types.
371 // But the SPIR-V defines those before defining the function
372 // type. No further work is required here.
373 return nullptr;
374 case spvtools::opt::analysis::Type::kSampler:
375 case spvtools::opt::analysis::Type::kSampledImage:
376 case spvtools::opt::analysis::Type::kImage:
377 // Fake it for sampler and texture types. These are handled in an
378 // entirely different way.
379 return ty_.Void();
380 default:
381 break;
382 }
383
384 Fail() << "unknown SPIR-V type with ID " << type_id << ": "
385 << def_use_mgr_->GetDef(type_id)->PrettyPrint();
386 return nullptr;
387 }
388
GetDecorationsFor(uint32_t id) const389 DecorationList ParserImpl::GetDecorationsFor(uint32_t id) const {
390 DecorationList result;
391 const auto& decorations = deco_mgr_->GetDecorationsFor(id, true);
392 for (const auto* inst : decorations) {
393 if (inst->opcode() != SpvOpDecorate) {
394 continue;
395 }
396 // Example: OpDecorate %struct_id Block
397 // Example: OpDecorate %array_ty ArrayStride 16
398 std::vector<uint32_t> inst_as_words;
399 inst->ToBinaryWithoutAttachedDebugInsts(&inst_as_words);
400 Decoration d(inst_as_words.begin() + 2, inst_as_words.end());
401 result.push_back(d);
402 }
403 return result;
404 }
405
GetDecorationsForMember(uint32_t id,uint32_t member_index) const406 DecorationList ParserImpl::GetDecorationsForMember(
407 uint32_t id,
408 uint32_t member_index) const {
409 DecorationList result;
410 const auto& decorations = deco_mgr_->GetDecorationsFor(id, true);
411 for (const auto* inst : decorations) {
412 if ((inst->opcode() != SpvOpMemberDecorate) ||
413 (inst->GetSingleWordInOperand(1) != member_index)) {
414 continue;
415 }
416 // Example: OpMemberDecorate %struct_id 2 Offset 24
417 std::vector<uint32_t> inst_as_words;
418 inst->ToBinaryWithoutAttachedDebugInsts(&inst_as_words);
419 Decoration d(inst_as_words.begin() + 3, inst_as_words.end());
420 result.push_back(d);
421 }
422 return result;
423 }
424
ShowType(uint32_t type_id)425 std::string ParserImpl::ShowType(uint32_t type_id) {
426 if (def_use_mgr_) {
427 const auto* type_inst = def_use_mgr_->GetDef(type_id);
428 if (type_inst) {
429 return type_inst->PrettyPrint();
430 }
431 }
432 return "SPIR-V type " + std::to_string(type_id);
433 }
434
ConvertMemberDecoration(uint32_t struct_type_id,uint32_t member_index,const Type * member_ty,const Decoration & decoration)435 ast::DecorationList ParserImpl::ConvertMemberDecoration(
436 uint32_t struct_type_id,
437 uint32_t member_index,
438 const Type* member_ty,
439 const Decoration& decoration) {
440 if (decoration.empty()) {
441 Fail() << "malformed SPIR-V decoration: it's empty";
442 return {};
443 }
444 switch (decoration[0]) {
445 case SpvDecorationOffset:
446 if (decoration.size() != 2) {
447 Fail()
448 << "malformed Offset decoration: expected 1 literal operand, has "
449 << decoration.size() - 1 << ": member " << member_index << " of "
450 << ShowType(struct_type_id);
451 return {};
452 }
453 return {
454 create<ast::StructMemberOffsetDecoration>(Source{}, decoration[1]),
455 };
456 case SpvDecorationNonReadable:
457 // WGSL doesn't have a member decoration for this. Silently drop it.
458 return {};
459 case SpvDecorationNonWritable:
460 // WGSL doesn't have a member decoration for this.
461 return {};
462 case SpvDecorationColMajor:
463 // WGSL only supports column major matrices.
464 return {};
465 case SpvDecorationRelaxedPrecision:
466 // WGSL doesn't support relaxed precision.
467 return {};
468 case SpvDecorationRowMajor:
469 Fail() << "WGSL does not support row-major matrices: can't "
470 "translate member "
471 << member_index << " of " << ShowType(struct_type_id);
472 return {};
473 case SpvDecorationMatrixStride: {
474 if (decoration.size() != 2) {
475 Fail() << "malformed MatrixStride decoration: expected 1 literal "
476 "operand, has "
477 << decoration.size() - 1 << ": member " << member_index << " of "
478 << ShowType(struct_type_id);
479 return {};
480 }
481 uint32_t stride = decoration[1];
482 auto* ty = member_ty->UnwrapAlias();
483 while (auto* arr = ty->As<Array>()) {
484 ty = arr->type->UnwrapAlias();
485 }
486 auto* mat = ty->As<Matrix>();
487 if (!mat) {
488 Fail() << "MatrixStride cannot be applied to type " << ty->String();
489 return {};
490 }
491 uint32_t natural_stride = (mat->rows == 2) ? 8 : 16;
492 if (stride == natural_stride) {
493 return {}; // Decoration matches the natural stride for the matrix
494 }
495 if (!member_ty->Is<Matrix>()) {
496 Fail() << "custom matrix strides not currently supported on array of "
497 "matrices";
498 return {};
499 }
500 return {
501 create<ast::StrideDecoration>(Source{}, decoration[1]),
502 builder_.ASTNodes().Create<ast::DisableValidationDecoration>(
503 builder_.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
504 };
505 }
506 default:
507 // TODO(dneto): Support the remaining member decorations.
508 break;
509 }
510 Fail() << "unhandled member decoration: " << decoration[0] << " on member "
511 << member_index << " of " << ShowType(struct_type_id);
512 return {};
513 }
514
BuildInternalModule()515 bool ParserImpl::BuildInternalModule() {
516 if (!success_) {
517 return false;
518 }
519
520 const spv_context& context = tools_context_.CContext();
521 ir_context_ = spvtools::BuildModule(context->target_env, context->consumer,
522 spv_binary_.data(), spv_binary_.size());
523 if (!ir_context_) {
524 return Fail() << "internal error: couldn't build the internal "
525 "representation of the module";
526 }
527 module_ = ir_context_->module();
528 def_use_mgr_ = ir_context_->get_def_use_mgr();
529 constant_mgr_ = ir_context_->get_constant_mgr();
530 type_mgr_ = ir_context_->get_type_mgr();
531 deco_mgr_ = ir_context_->get_decoration_mgr();
532
533 topologically_ordered_functions_ =
534 FunctionTraverser(*module_).TopologicallyOrderedFunctions();
535
536 return success_;
537 }
538
ResetInternalModule()539 void ParserImpl::ResetInternalModule() {
540 ir_context_.reset(nullptr);
541 module_ = nullptr;
542 def_use_mgr_ = nullptr;
543 constant_mgr_ = nullptr;
544 type_mgr_ = nullptr;
545 deco_mgr_ = nullptr;
546
547 glsl_std_450_imports_.clear();
548 }
549
ParseInternalModule()550 bool ParserImpl::ParseInternalModule() {
551 if (!success_) {
552 return false;
553 }
554 RegisterLineNumbers();
555 if (!ParseInternalModuleExceptFunctions()) {
556 return false;
557 }
558 if (!EmitFunctions()) {
559 return false;
560 }
561 return success_;
562 }
563
RegisterLineNumbers()564 void ParserImpl::RegisterLineNumbers() {
565 Source::Location instruction_number{};
566
567 // Has there been an OpLine since the last OpNoLine or start of the module?
568 bool in_op_line_scope = false;
569 // The source location provided by the most recent OpLine instruction.
570 Source::Location op_line_source{};
571 const bool run_on_debug_insts = true;
572 module_->ForEachInst(
573 [this, &in_op_line_scope, &op_line_source,
574 &instruction_number](const spvtools::opt::Instruction* inst) {
575 ++instruction_number.line;
576 switch (inst->opcode()) {
577 case SpvOpLine:
578 in_op_line_scope = true;
579 // TODO(dneto): This ignores the File ID (operand 0), since the Tint
580 // Source concept doesn't represent that.
581 op_line_source.line = inst->GetSingleWordInOperand(1);
582 op_line_source.column = inst->GetSingleWordInOperand(2);
583 break;
584 case SpvOpNoLine:
585 in_op_line_scope = false;
586 break;
587 default:
588 break;
589 }
590 this->inst_source_[inst] =
591 in_op_line_scope ? op_line_source : instruction_number;
592 },
593 run_on_debug_insts);
594 }
595
GetSourceForResultIdForTest(uint32_t id) const596 Source ParserImpl::GetSourceForResultIdForTest(uint32_t id) const {
597 return GetSourceForInst(def_use_mgr_->GetDef(id));
598 }
599
GetSourceForInst(const spvtools::opt::Instruction * inst) const600 Source ParserImpl::GetSourceForInst(
601 const spvtools::opt::Instruction* inst) const {
602 auto where = inst_source_.find(inst);
603 if (where == inst_source_.end()) {
604 return {};
605 }
606 return Source{where->second};
607 }
608
ParseInternalModuleExceptFunctions()609 bool ParserImpl::ParseInternalModuleExceptFunctions() {
610 if (!success_) {
611 return false;
612 }
613 if (!RegisterExtendedInstructionImports()) {
614 return false;
615 }
616 if (!RegisterUserAndStructMemberNames()) {
617 return false;
618 }
619 if (!RegisterWorkgroupSizeBuiltin()) {
620 return false;
621 }
622 if (!RegisterEntryPoints()) {
623 return false;
624 }
625 if (!RegisterHandleUsage()) {
626 return false;
627 }
628 if (!RegisterTypes()) {
629 return false;
630 }
631 if (!RejectInvalidPointerRoots()) {
632 return false;
633 }
634 if (!EmitScalarSpecConstants()) {
635 return false;
636 }
637 if (!EmitModuleScopeVariables()) {
638 return false;
639 }
640 return success_;
641 }
642
RegisterExtendedInstructionImports()643 bool ParserImpl::RegisterExtendedInstructionImports() {
644 for (const spvtools::opt::Instruction& import : module_->ext_inst_imports()) {
645 std::string name(
646 reinterpret_cast<const char*>(import.GetInOperand(0).words.data()));
647 // TODO(dneto): Handle other extended instruction sets when needed.
648 if (name == "GLSL.std.450") {
649 glsl_std_450_imports_.insert(import.result_id());
650 } else if (name.find("NonSemantic.") == 0) {
651 ignored_imports_.insert(import.result_id());
652 } else {
653 return Fail() << "Unrecognized extended instruction set: " << name;
654 }
655 }
656 return true;
657 }
658
IsGlslExtendedInstruction(const spvtools::opt::Instruction & inst) const659 bool ParserImpl::IsGlslExtendedInstruction(
660 const spvtools::opt::Instruction& inst) const {
661 return (inst.opcode() == SpvOpExtInst) &&
662 (glsl_std_450_imports_.count(inst.GetSingleWordInOperand(0)) > 0);
663 }
664
IsIgnoredExtendedInstruction(const spvtools::opt::Instruction & inst) const665 bool ParserImpl::IsIgnoredExtendedInstruction(
666 const spvtools::opt::Instruction& inst) const {
667 return (inst.opcode() == SpvOpExtInst) &&
668 (ignored_imports_.count(inst.GetSingleWordInOperand(0)) > 0);
669 }
670
RegisterUserAndStructMemberNames()671 bool ParserImpl::RegisterUserAndStructMemberNames() {
672 if (!success_) {
673 return false;
674 }
675 // Register entry point names. An entry point name is the point of contact
676 // between the API and the shader. It has the highest priority for
677 // preservation, so register it first.
678 for (const spvtools::opt::Instruction& entry_point :
679 module_->entry_points()) {
680 const uint32_t function_id = entry_point.GetSingleWordInOperand(1);
681 const std::string name = entry_point.GetInOperand(2).AsString();
682
683 // This translator requires the entry point to be a valid WGSL identifier.
684 // Allowing otherwise leads to difficulties in that the programmer needs
685 // to get a mapping from their original entry point name to the WGSL name,
686 // and we don't have a good mechanism for that.
687 if (!IsValidIdentifier(name)) {
688 return Fail() << "entry point name is not a valid WGSL identifier: "
689 << name;
690 }
691
692 // SPIR-V allows a single function to be the implementation for more
693 // than one entry point. In the common case, it's one-to-one, and we should
694 // try to name the function after the entry point. Otherwise, give the
695 // function a name automatically derived from the entry point name.
696 namer_.SuggestSanitizedName(function_id, name);
697
698 // There is another many-to-one relationship to take care of: In SPIR-V
699 // the same name can be used for multiple entry points, provided they are
700 // for different shader stages. Take action now to ensure we can use the
701 // entry point name later on, and not have it taken for another identifier
702 // by an accidental collision with a derived name made for a different ID.
703 if (!namer_.IsRegistered(name)) {
704 // The entry point name is "unoccupied" becase an earlier entry point
705 // grabbed the slot for the function that implements both entry points.
706 // Register this new entry point's name, to avoid accidental collisions
707 // with a future generated ID.
708 if (!namer_.RegisterWithoutId(name)) {
709 return false;
710 }
711 }
712 }
713
714 // Register names from OpName and OpMemberName
715 for (const auto& inst : module_->debugs2()) {
716 switch (inst.opcode()) {
717 case SpvOpName: {
718 const auto name = inst.GetInOperand(1).AsString();
719 if (!name.empty()) {
720 namer_.SuggestSanitizedName(inst.GetSingleWordInOperand(0), name);
721 }
722 break;
723 }
724 case SpvOpMemberName: {
725 const auto name = inst.GetInOperand(2).AsString();
726 if (!name.empty()) {
727 namer_.SuggestSanitizedMemberName(inst.GetSingleWordInOperand(0),
728 inst.GetSingleWordInOperand(1),
729 name);
730 }
731 break;
732 }
733 default:
734 break;
735 }
736 }
737
738 // Fill in struct member names, and disambiguate them.
739 for (const auto* type_inst : module_->GetTypes()) {
740 if (type_inst->opcode() == SpvOpTypeStruct) {
741 namer_.ResolveMemberNamesForStruct(type_inst->result_id(),
742 type_inst->NumInOperands());
743 }
744 }
745
746 return true;
747 }
748
IsValidIdentifier(const std::string & str)749 bool ParserImpl::IsValidIdentifier(const std::string& str) {
750 if (str.empty()) {
751 return false;
752 }
753 std::locale c_locale("C");
754 if (!std::isalpha(str[0], c_locale)) {
755 return false;
756 }
757 for (const char& ch : str) {
758 if ((ch != '_') && !std::isalnum(ch, c_locale)) {
759 return false;
760 }
761 }
762 return true;
763 }
764
RegisterWorkgroupSizeBuiltin()765 bool ParserImpl::RegisterWorkgroupSizeBuiltin() {
766 WorkgroupSizeInfo& info = workgroup_size_builtin_;
767 for (const spvtools::opt::Instruction& inst : module_->annotations()) {
768 if (inst.opcode() != SpvOpDecorate) {
769 continue;
770 }
771 if (inst.GetSingleWordInOperand(1) != SpvDecorationBuiltIn) {
772 continue;
773 }
774 if (inst.GetSingleWordInOperand(2) != SpvBuiltInWorkgroupSize) {
775 continue;
776 }
777 info.id = inst.GetSingleWordInOperand(0);
778 }
779 if (info.id == 0) {
780 return true;
781 }
782 // Gather the values.
783 const spvtools::opt::Instruction* composite_def =
784 def_use_mgr_->GetDef(info.id);
785 if (!composite_def) {
786 return Fail() << "Invalid WorkgroupSize builtin value";
787 }
788 // SPIR-V validation checks that the result is a 3-element vector of 32-bit
789 // integer scalars (signed or unsigned). Rely on validation to check the
790 // type. In theory the instruction could be OpConstantNull and still
791 // pass validation, but that would be non-sensical. Be a little more
792 // stringent here and check for specific opcodes. WGSL does not support
793 // const-expr yet, so avoid supporting OpSpecConstantOp here.
794 // TODO(dneto): See https://github.com/gpuweb/gpuweb/issues/1272 for WGSL
795 // const_expr proposals.
796 if ((composite_def->opcode() != SpvOpSpecConstantComposite &&
797 composite_def->opcode() != SpvOpConstantComposite)) {
798 return Fail() << "Invalid WorkgroupSize builtin. Expected 3-element "
799 "OpSpecConstantComposite or OpConstantComposite: "
800 << composite_def->PrettyPrint();
801 }
802 info.type_id = composite_def->type_id();
803 // Extract the component type from the vector type.
804 info.component_type_id =
805 def_use_mgr_->GetDef(info.type_id)->GetSingleWordInOperand(0);
806
807 /// Sets the ID and value of the index'th member of the composite constant.
808 /// Returns false and emits a diagnostic on error.
809 auto set_param = [this, composite_def](uint32_t* id_ptr, uint32_t* value_ptr,
810 int index) -> bool {
811 const auto id = composite_def->GetSingleWordInOperand(index);
812 const auto* def = def_use_mgr_->GetDef(id);
813 if (!def ||
814 (def->opcode() != SpvOpSpecConstant &&
815 def->opcode() != SpvOpConstant) ||
816 (def->NumInOperands() != 1)) {
817 return Fail() << "invalid component " << index << " of workgroupsize "
818 << (def ? def->PrettyPrint()
819 : std::string("no definition"));
820 }
821 *id_ptr = id;
822 // Use the default value of a spec constant.
823 *value_ptr = def->GetSingleWordInOperand(0);
824 return true;
825 };
826
827 return set_param(&info.x_id, &info.x_value, 0) &&
828 set_param(&info.y_id, &info.y_value, 1) &&
829 set_param(&info.z_id, &info.z_value, 2);
830 }
831
RegisterEntryPoints()832 bool ParserImpl::RegisterEntryPoints() {
833 // Mapping from entry point ID to GridSize computed from LocalSize
834 // decorations.
835 std::unordered_map<uint32_t, GridSize> local_size;
836 for (const spvtools::opt::Instruction& inst : module_->execution_modes()) {
837 auto mode = static_cast<SpvExecutionMode>(inst.GetSingleWordInOperand(1));
838 if (mode == SpvExecutionModeLocalSize) {
839 if (inst.NumInOperands() != 5) {
840 // This won't even get past SPIR-V binary parsing.
841 return Fail() << "invalid LocalSize execution mode: "
842 << inst.PrettyPrint();
843 }
844 uint32_t function_id = inst.GetSingleWordInOperand(0);
845 local_size[function_id] = GridSize{inst.GetSingleWordInOperand(2),
846 inst.GetSingleWordInOperand(3),
847 inst.GetSingleWordInOperand(4)};
848 }
849 }
850
851 for (const spvtools::opt::Instruction& entry_point :
852 module_->entry_points()) {
853 const auto stage = SpvExecutionModel(entry_point.GetSingleWordInOperand(0));
854 const uint32_t function_id = entry_point.GetSingleWordInOperand(1);
855
856 const std::string ep_name = entry_point.GetOperand(2).AsString();
857 if (!IsValidIdentifier(ep_name)) {
858 return Fail() << "entry point name is not a valid WGSL identifier: "
859 << ep_name;
860 }
861
862 bool owns_inner_implementation = false;
863 std::string inner_implementation_name;
864
865 auto where = function_to_ep_info_.find(function_id);
866 if (where == function_to_ep_info_.end()) {
867 // If this is the first entry point to have function_id as its
868 // implementation, then this entry point is responsible for generating
869 // the inner implementation.
870 owns_inner_implementation = true;
871 inner_implementation_name = namer_.MakeDerivedName(ep_name);
872 } else {
873 // Reuse the inner implementation owned by the first entry point.
874 inner_implementation_name = where->second[0].inner_name;
875 }
876 TINT_ASSERT(Reader, !inner_implementation_name.empty());
877 TINT_ASSERT(Reader, ep_name != inner_implementation_name);
878
879 utils::UniqueVector<uint32_t> inputs;
880 utils::UniqueVector<uint32_t> outputs;
881 for (unsigned iarg = 3; iarg < entry_point.NumInOperands(); iarg++) {
882 const uint32_t var_id = entry_point.GetSingleWordInOperand(iarg);
883 if (const auto* var_inst = def_use_mgr_->GetDef(var_id)) {
884 switch (SpvStorageClass(var_inst->GetSingleWordInOperand(0))) {
885 case SpvStorageClassInput:
886 inputs.add(var_id);
887 break;
888 case SpvStorageClassOutput:
889 outputs.add(var_id);
890 break;
891 default:
892 break;
893 }
894 }
895 }
896 // Save the lists, in ID-sorted order.
897 std::vector<uint32_t> sorted_inputs(inputs);
898 std::sort(sorted_inputs.begin(), sorted_inputs.end());
899 std::vector<uint32_t> sorted_outputs(outputs);
900 std::sort(sorted_outputs.begin(), sorted_outputs.end());
901
902 const auto ast_stage = enum_converter_.ToPipelineStage(stage);
903 GridSize wgsize;
904 if (ast_stage == ast::PipelineStage::kCompute) {
905 if (workgroup_size_builtin_.id) {
906 // Store the default values.
907 // WGSL allows specializing these, but this code doesn't support that
908 // yet. https://github.com/gpuweb/gpuweb/issues/1442
909 wgsize = GridSize{workgroup_size_builtin_.x_value,
910 workgroup_size_builtin_.y_value,
911 workgroup_size_builtin_.z_value};
912 } else {
913 // Use the LocalSize execution mode. This is the second choice.
914 auto where_local_size = local_size.find(function_id);
915 if (where_local_size != local_size.end()) {
916 wgsize = where_local_size->second;
917 }
918 }
919 }
920 function_to_ep_info_[function_id].emplace_back(
921 ep_name, ast_stage, owns_inner_implementation,
922 inner_implementation_name, std::move(sorted_inputs),
923 std::move(sorted_outputs), wgsize);
924 }
925
926 // The enum conversion could have failed, so return the existing status value.
927 return success_;
928 }
929
ConvertType(const spvtools::opt::analysis::Integer * int_ty)930 const Type* ParserImpl::ConvertType(
931 const spvtools::opt::analysis::Integer* int_ty) {
932 if (int_ty->width() == 32) {
933 return int_ty->IsSigned() ? static_cast<const Type*>(ty_.I32())
934 : static_cast<const Type*>(ty_.U32());
935 }
936 Fail() << "unhandled integer width: " << int_ty->width();
937 return nullptr;
938 }
939
ConvertType(const spvtools::opt::analysis::Float * float_ty)940 const Type* ParserImpl::ConvertType(
941 const spvtools::opt::analysis::Float* float_ty) {
942 if (float_ty->width() == 32) {
943 return ty_.F32();
944 }
945 Fail() << "unhandled float width: " << float_ty->width();
946 return nullptr;
947 }
948
ConvertType(const spvtools::opt::analysis::Vector * vec_ty)949 const Type* ParserImpl::ConvertType(
950 const spvtools::opt::analysis::Vector* vec_ty) {
951 const auto num_elem = vec_ty->element_count();
952 auto* ast_elem_ty = ConvertType(type_mgr_->GetId(vec_ty->element_type()));
953 if (ast_elem_ty == nullptr) {
954 return ast_elem_ty;
955 }
956 return ty_.Vector(ast_elem_ty, num_elem);
957 }
958
ConvertType(const spvtools::opt::analysis::Matrix * mat_ty)959 const Type* ParserImpl::ConvertType(
960 const spvtools::opt::analysis::Matrix* mat_ty) {
961 const auto* vec_ty = mat_ty->element_type()->AsVector();
962 const auto* scalar_ty = vec_ty->element_type();
963 const auto num_rows = vec_ty->element_count();
964 const auto num_columns = mat_ty->element_count();
965 auto* ast_scalar_ty = ConvertType(type_mgr_->GetId(scalar_ty));
966 if (ast_scalar_ty == nullptr) {
967 return nullptr;
968 }
969 return ty_.Matrix(ast_scalar_ty, num_columns, num_rows);
970 }
971
ConvertType(uint32_t type_id,const spvtools::opt::analysis::RuntimeArray * rtarr_ty)972 const Type* ParserImpl::ConvertType(
973 uint32_t type_id,
974 const spvtools::opt::analysis::RuntimeArray* rtarr_ty) {
975 auto* ast_elem_ty = ConvertType(type_mgr_->GetId(rtarr_ty->element_type()));
976 if (ast_elem_ty == nullptr) {
977 return nullptr;
978 }
979 uint32_t array_stride = 0;
980 if (!ParseArrayDecorations(rtarr_ty, &array_stride)) {
981 return nullptr;
982 }
983 const Type* result = ty_.Array(ast_elem_ty, 0, array_stride);
984 return MaybeGenerateAlias(type_id, rtarr_ty, result);
985 }
986
ConvertType(uint32_t type_id,const spvtools::opt::analysis::Array * arr_ty)987 const Type* ParserImpl::ConvertType(
988 uint32_t type_id,
989 const spvtools::opt::analysis::Array* arr_ty) {
990 // Get the element type. The SPIR-V optimizer's types representation
991 // deduplicates array types that have the same parameterization.
992 // We don't want that deduplication, so get the element type from
993 // the SPIR-V type directly.
994 const auto* inst = def_use_mgr_->GetDef(type_id);
995 const auto elem_type_id = inst->GetSingleWordInOperand(0);
996 auto* ast_elem_ty = ConvertType(elem_type_id);
997 if (ast_elem_ty == nullptr) {
998 return nullptr;
999 }
1000 // Get the length.
1001 const auto& length_info = arr_ty->length_info();
1002 if (length_info.words.empty()) {
1003 // The internal representation is invalid. The discriminant vector
1004 // is mal-formed.
1005 Fail() << "internal error: Array length info is invalid";
1006 return nullptr;
1007 }
1008 if (length_info.words[0] !=
1009 spvtools::opt::analysis::Array::LengthInfo::kConstant) {
1010 Fail() << "Array type " << type_mgr_->GetId(arr_ty)
1011 << " length is a specialization constant";
1012 return nullptr;
1013 }
1014 const auto* constant = constant_mgr_->FindDeclaredConstant(length_info.id);
1015 if (constant == nullptr) {
1016 Fail() << "Array type " << type_mgr_->GetId(arr_ty) << " length ID "
1017 << length_info.id << " does not name an OpConstant";
1018 return nullptr;
1019 }
1020 const uint64_t num_elem = constant->GetZeroExtendedValue();
1021 // For now, limit to only 32bits.
1022 if (num_elem > std::numeric_limits<uint32_t>::max()) {
1023 Fail() << "Array type " << type_mgr_->GetId(arr_ty)
1024 << " has too many elements (more than can fit in 32 bits): "
1025 << num_elem;
1026 return nullptr;
1027 }
1028 uint32_t array_stride = 0;
1029 if (!ParseArrayDecorations(arr_ty, &array_stride)) {
1030 return nullptr;
1031 }
1032 if (remap_buffer_block_type_.count(elem_type_id)) {
1033 remap_buffer_block_type_.insert(type_mgr_->GetId(arr_ty));
1034 }
1035 const Type* result =
1036 ty_.Array(ast_elem_ty, static_cast<uint32_t>(num_elem), array_stride);
1037 return MaybeGenerateAlias(type_id, arr_ty, result);
1038 }
1039
ParseArrayDecorations(const spvtools::opt::analysis::Type * spv_type,uint32_t * array_stride)1040 bool ParserImpl::ParseArrayDecorations(
1041 const spvtools::opt::analysis::Type* spv_type,
1042 uint32_t* array_stride) {
1043 bool has_array_stride = false;
1044 *array_stride = 0; // Implicit stride case.
1045 const auto type_id = type_mgr_->GetId(spv_type);
1046 for (auto& decoration : this->GetDecorationsFor(type_id)) {
1047 if (decoration.size() == 2 && decoration[0] == SpvDecorationArrayStride) {
1048 const auto stride = decoration[1];
1049 if (stride == 0) {
1050 return Fail() << "invalid array type ID " << type_id
1051 << ": ArrayStride can't be 0";
1052 }
1053 if (has_array_stride) {
1054 return Fail() << "invalid array type ID " << type_id
1055 << ": multiple ArrayStride decorations";
1056 }
1057 has_array_stride = true;
1058 *array_stride = stride;
1059 } else {
1060 return Fail() << "invalid array type ID " << type_id
1061 << ": unknown decoration "
1062 << (decoration.empty() ? "(empty)"
1063 : std::to_string(decoration[0]))
1064 << " with " << decoration.size() << " total words";
1065 }
1066 }
1067 return true;
1068 }
1069
ConvertType(uint32_t type_id,const spvtools::opt::analysis::Struct * struct_ty)1070 const Type* ParserImpl::ConvertType(
1071 uint32_t type_id,
1072 const spvtools::opt::analysis::Struct* struct_ty) {
1073 // Compute the struct decoration.
1074 auto struct_decorations = this->GetDecorationsFor(type_id);
1075 bool is_block_decorated = false;
1076 if (struct_decorations.size() == 1) {
1077 const auto decoration = struct_decorations[0][0];
1078 if (decoration == SpvDecorationBlock) {
1079 is_block_decorated = true;
1080 } else if (decoration == SpvDecorationBufferBlock) {
1081 is_block_decorated = true;
1082 remap_buffer_block_type_.insert(type_id);
1083 } else {
1084 Fail() << "struct with ID " << type_id
1085 << " has unrecognized decoration: " << int(decoration);
1086 }
1087 } else if (struct_decorations.size() > 1) {
1088 Fail() << "can't handle a struct with more than one decoration: struct "
1089 << type_id << " has " << struct_decorations.size();
1090 return nullptr;
1091 }
1092
1093 // Compute members
1094 ast::StructMemberList ast_members;
1095 const auto members = struct_ty->element_types();
1096 if (members.empty()) {
1097 Fail() << "WGSL does not support empty structures. can't convert type: "
1098 << def_use_mgr_->GetDef(type_id)->PrettyPrint();
1099 return nullptr;
1100 }
1101 TypeList ast_member_types;
1102 unsigned num_non_writable_members = 0;
1103 for (uint32_t member_index = 0; member_index < members.size();
1104 ++member_index) {
1105 const auto member_type_id = type_mgr_->GetId(members[member_index]);
1106 auto* ast_member_ty = ConvertType(member_type_id);
1107 if (ast_member_ty == nullptr) {
1108 // Already emitted diagnostics.
1109 return nullptr;
1110 }
1111
1112 ast_member_types.emplace_back(ast_member_ty);
1113
1114 // Scan member for built-in decorations. Some vertex built-ins are handled
1115 // specially, and should not generate a structure member.
1116 bool create_ast_member = true;
1117 for (auto& decoration : GetDecorationsForMember(type_id, member_index)) {
1118 if (decoration.empty()) {
1119 Fail() << "malformed SPIR-V decoration: it's empty";
1120 return nullptr;
1121 }
1122 if ((decoration[0] == SpvDecorationBuiltIn) && (decoration.size() > 1)) {
1123 switch (decoration[1]) {
1124 case SpvBuiltInPosition:
1125 // Record this built-in variable specially.
1126 builtin_position_.struct_type_id = type_id;
1127 builtin_position_.position_member_index = member_index;
1128 builtin_position_.position_member_type_id = member_type_id;
1129 create_ast_member = false; // Not part of the WGSL structure.
1130 break;
1131 case SpvBuiltInPointSize: // not supported in WGSL, but ignore
1132 builtin_position_.pointsize_member_index = member_index;
1133 create_ast_member = false; // Not part of the WGSL structure.
1134 break;
1135 case SpvBuiltInClipDistance: // not supported in WGSL
1136 case SpvBuiltInCullDistance: // not supported in WGSL
1137 create_ast_member = false; // Not part of the WGSL structure.
1138 break;
1139 default:
1140 Fail() << "unrecognized builtin " << decoration[1];
1141 return nullptr;
1142 }
1143 }
1144 }
1145 if (!create_ast_member) {
1146 // This member is decorated as a built-in, and is handled specially.
1147 continue;
1148 }
1149
1150 bool is_non_writable = false;
1151 ast::DecorationList ast_member_decorations;
1152 for (auto& decoration : GetDecorationsForMember(type_id, member_index)) {
1153 if (IsPipelineDecoration(decoration)) {
1154 // IO decorations are handled when emitting the entry point.
1155 continue;
1156 } else if (decoration[0] == SpvDecorationNonWritable) {
1157 // WGSL doesn't represent individual members as non-writable. Instead,
1158 // apply the ReadOnly access control to the containing struct if all
1159 // the members are non-writable.
1160 is_non_writable = true;
1161 } else {
1162 auto decos = ConvertMemberDecoration(type_id, member_index,
1163 ast_member_ty, decoration);
1164 for (auto* deco : decos) {
1165 ast_member_decorations.emplace_back(deco);
1166 }
1167 if (!success_) {
1168 return nullptr;
1169 }
1170 }
1171 }
1172
1173 if (is_non_writable) {
1174 // Count a member as non-writable only once, no matter how many
1175 // NonWritable decorations are applied to it.
1176 ++num_non_writable_members;
1177 }
1178 const auto member_name = namer_.GetMemberName(type_id, member_index);
1179 auto* ast_struct_member = create<ast::StructMember>(
1180 Source{}, builder_.Symbols().Register(member_name),
1181 ast_member_ty->Build(builder_), std::move(ast_member_decorations));
1182 ast_members.push_back(ast_struct_member);
1183 }
1184
1185 if (ast_members.empty()) {
1186 // All members were likely built-ins. Don't generate an empty AST structure.
1187 return nullptr;
1188 }
1189
1190 namer_.SuggestSanitizedName(type_id, "S");
1191
1192 auto name = namer_.GetName(type_id);
1193
1194 // Now make the struct.
1195 auto sym = builder_.Symbols().Register(name);
1196 ast::DecorationList ast_struct_decorations;
1197 if (is_block_decorated && struct_types_for_buffers_.count(type_id)) {
1198 ast_struct_decorations.emplace_back(
1199 create<ast::StructBlockDecoration>(Source{}));
1200 }
1201 auto* ast_struct = create<ast::Struct>(Source{}, sym, std::move(ast_members),
1202 std::move(ast_struct_decorations));
1203 if (num_non_writable_members == members.size()) {
1204 read_only_struct_types_.insert(ast_struct->name);
1205 }
1206 AddTypeDecl(sym, ast_struct);
1207 const auto* result = ty_.Struct(sym, std::move(ast_member_types));
1208 struct_id_for_symbol_[sym] = type_id;
1209 return result;
1210 }
1211
AddTypeDecl(Symbol name,const ast::TypeDecl * decl)1212 void ParserImpl::AddTypeDecl(Symbol name, const ast::TypeDecl* decl) {
1213 auto iter = declared_types_.insert(name);
1214 if (iter.second) {
1215 builder_.AST().AddTypeDecl(decl);
1216 }
1217 }
1218
ConvertType(uint32_t type_id,PtrAs ptr_as,const spvtools::opt::analysis::Pointer *)1219 const Type* ParserImpl::ConvertType(uint32_t type_id,
1220 PtrAs ptr_as,
1221 const spvtools::opt::analysis::Pointer*) {
1222 const auto* inst = def_use_mgr_->GetDef(type_id);
1223 const auto pointee_type_id = inst->GetSingleWordInOperand(1);
1224 const auto storage_class = SpvStorageClass(inst->GetSingleWordInOperand(0));
1225
1226 if (pointee_type_id == builtin_position_.struct_type_id) {
1227 builtin_position_.pointer_type_id = type_id;
1228 // Pipeline IO builtins map to private variables.
1229 builtin_position_.storage_class = SpvStorageClassPrivate;
1230 return nullptr;
1231 }
1232 auto* ast_elem_ty = ConvertType(pointee_type_id, PtrAs::Ptr);
1233 if (ast_elem_ty == nullptr) {
1234 Fail() << "SPIR-V pointer type with ID " << type_id
1235 << " has invalid pointee type " << pointee_type_id;
1236 return nullptr;
1237 }
1238
1239 auto ast_storage_class = enum_converter_.ToStorageClass(storage_class);
1240 if (ast_storage_class == ast::StorageClass::kInvalid) {
1241 Fail() << "SPIR-V pointer type with ID " << type_id
1242 << " has invalid storage class "
1243 << static_cast<uint32_t>(storage_class);
1244 return nullptr;
1245 }
1246 if (ast_storage_class == ast::StorageClass::kUniform &&
1247 remap_buffer_block_type_.count(pointee_type_id)) {
1248 ast_storage_class = ast::StorageClass::kStorage;
1249 remap_buffer_block_type_.insert(type_id);
1250 }
1251
1252 // Pipeline input and output variables map to private variables.
1253 if (ast_storage_class == ast::StorageClass::kInput ||
1254 ast_storage_class == ast::StorageClass::kOutput) {
1255 ast_storage_class = ast::StorageClass::kPrivate;
1256 }
1257 switch (ptr_as) {
1258 case PtrAs::Ref:
1259 return ty_.Reference(ast_elem_ty, ast_storage_class);
1260 case PtrAs::Ptr:
1261 return ty_.Pointer(ast_elem_ty, ast_storage_class);
1262 }
1263 Fail() << "invalid value for ptr_as: " << static_cast<int>(ptr_as);
1264 return nullptr;
1265 }
1266
RegisterTypes()1267 bool ParserImpl::RegisterTypes() {
1268 if (!success_) {
1269 return false;
1270 }
1271
1272 // First record the structure types that should have a `block` decoration
1273 // in WGSL. In particular, exclude user-defined pipeline IO in a
1274 // block-decorated struct.
1275 for (const auto& type_or_value : module_->types_values()) {
1276 if (type_or_value.opcode() != SpvOpVariable) {
1277 continue;
1278 }
1279 const auto& var = type_or_value;
1280 const auto spirv_storage_class =
1281 SpvStorageClass(var.GetSingleWordInOperand(0));
1282 if ((spirv_storage_class != SpvStorageClassStorageBuffer) &&
1283 (spirv_storage_class != SpvStorageClassUniform)) {
1284 continue;
1285 }
1286 const auto* ptr_type = def_use_mgr_->GetDef(var.type_id());
1287 if (ptr_type->opcode() != SpvOpTypePointer) {
1288 return Fail() << "OpVariable type expected to be a pointer: "
1289 << var.PrettyPrint();
1290 }
1291 const auto* store_type =
1292 def_use_mgr_->GetDef(ptr_type->GetSingleWordInOperand(1));
1293 if (store_type->opcode() == SpvOpTypeStruct) {
1294 struct_types_for_buffers_.insert(store_type->result_id());
1295 } else {
1296 Fail() << "WGSL does not support arrays of buffers: "
1297 << var.PrettyPrint();
1298 }
1299 }
1300
1301 // Now convert each type.
1302 for (auto& type_or_const : module_->types_values()) {
1303 const auto* type = type_mgr_->GetType(type_or_const.result_id());
1304 if (type == nullptr) {
1305 continue;
1306 }
1307 ConvertType(type_or_const.result_id());
1308 }
1309 // Manufacture a type for the gl_Position variable if we have to.
1310 if ((builtin_position_.struct_type_id != 0) &&
1311 (builtin_position_.position_member_pointer_type_id == 0)) {
1312 builtin_position_.position_member_pointer_type_id =
1313 type_mgr_->FindPointerToType(builtin_position_.position_member_type_id,
1314 builtin_position_.storage_class);
1315 ConvertType(builtin_position_.position_member_pointer_type_id);
1316 }
1317 return success_;
1318 }
1319
RejectInvalidPointerRoots()1320 bool ParserImpl::RejectInvalidPointerRoots() {
1321 if (!success_) {
1322 return false;
1323 }
1324 for (auto& inst : module_->types_values()) {
1325 if (const auto* result_type = type_mgr_->GetType(inst.type_id())) {
1326 if (result_type->AsPointer()) {
1327 switch (inst.opcode()) {
1328 case SpvOpVariable:
1329 // This is the only valid case.
1330 break;
1331 case SpvOpUndef:
1332 return Fail() << "undef pointer is not valid: "
1333 << inst.PrettyPrint();
1334 case SpvOpConstantNull:
1335 return Fail() << "null pointer is not valid: "
1336 << inst.PrettyPrint();
1337 default:
1338 return Fail() << "module-scope pointer is not valid: "
1339 << inst.PrettyPrint();
1340 }
1341 }
1342 }
1343 }
1344 return success();
1345 }
1346
EmitScalarSpecConstants()1347 bool ParserImpl::EmitScalarSpecConstants() {
1348 if (!success_) {
1349 return false;
1350 }
1351 // Generate a module-scope const declaration for each instruction
1352 // that is OpSpecConstantTrue, OpSpecConstantFalse, or OpSpecConstant.
1353 for (auto& inst : module_->types_values()) {
1354 // These will be populated for a valid scalar spec constant.
1355 const Type* ast_type = nullptr;
1356 ast::LiteralExpression* ast_expr = nullptr;
1357
1358 switch (inst.opcode()) {
1359 case SpvOpSpecConstantTrue:
1360 case SpvOpSpecConstantFalse: {
1361 ast_type = ConvertType(inst.type_id());
1362 ast_expr = create<ast::BoolLiteralExpression>(
1363 Source{}, inst.opcode() == SpvOpSpecConstantTrue);
1364 break;
1365 }
1366 case SpvOpSpecConstant: {
1367 ast_type = ConvertType(inst.type_id());
1368 const uint32_t literal_value = inst.GetSingleWordInOperand(0);
1369 if (ast_type->Is<I32>()) {
1370 ast_expr = create<ast::SintLiteralExpression>(
1371 Source{}, static_cast<int32_t>(literal_value));
1372 } else if (ast_type->Is<U32>()) {
1373 ast_expr = create<ast::UintLiteralExpression>(
1374 Source{}, static_cast<uint32_t>(literal_value));
1375 } else if (ast_type->Is<F32>()) {
1376 float float_value;
1377 // Copy the bits so we can read them as a float.
1378 std::memcpy(&float_value, &literal_value, sizeof(float_value));
1379 ast_expr = create<ast::FloatLiteralExpression>(Source{}, float_value);
1380 } else {
1381 return Fail() << " invalid result type for OpSpecConstant "
1382 << inst.PrettyPrint();
1383 }
1384 break;
1385 }
1386 default:
1387 break;
1388 }
1389 if (ast_type && ast_expr) {
1390 ast::DecorationList spec_id_decos;
1391 for (const auto& deco : GetDecorationsFor(inst.result_id())) {
1392 if ((deco.size() == 2) && (deco[0] == SpvDecorationSpecId)) {
1393 const uint32_t id = deco[1];
1394 if (id > 65535) {
1395 return Fail() << "SpecId too large. WGSL override IDs must be "
1396 "between 0 and 65535: ID %"
1397 << inst.result_id() << " has SpecId " << id;
1398 }
1399 auto* cid = create<ast::OverrideDecoration>(Source{}, id);
1400 spec_id_decos.push_back(cid);
1401 break;
1402 }
1403 }
1404 auto* ast_var =
1405 MakeVariable(inst.result_id(), ast::StorageClass::kNone, ast_type,
1406 true, ast_expr, std::move(spec_id_decos));
1407 if (ast_var) {
1408 builder_.AST().AddGlobalVariable(ast_var);
1409 scalar_spec_constants_.insert(inst.result_id());
1410 }
1411 }
1412 }
1413 return success_;
1414 }
1415
MaybeGenerateAlias(uint32_t type_id,const spvtools::opt::analysis::Type * type,const Type * ast_type)1416 const Type* ParserImpl::MaybeGenerateAlias(
1417 uint32_t type_id,
1418 const spvtools::opt::analysis::Type* type,
1419 const Type* ast_type) {
1420 if (!success_) {
1421 return nullptr;
1422 }
1423
1424 // We only care about arrays, and runtime arrays.
1425 switch (type->kind()) {
1426 case spvtools::opt::analysis::Type::kRuntimeArray:
1427 // Runtime arrays are always decorated with ArrayStride so always get a
1428 // type alias.
1429 namer_.SuggestSanitizedName(type_id, "RTArr");
1430 break;
1431 case spvtools::opt::analysis::Type::kArray:
1432 // Only make a type aliase for arrays with decorations.
1433 if (GetDecorationsFor(type_id).empty()) {
1434 return ast_type;
1435 }
1436 namer_.SuggestSanitizedName(type_id, "Arr");
1437 break;
1438 default:
1439 // Ignore constants, and any other types.
1440 return ast_type;
1441 }
1442 auto* ast_underlying_type = ast_type;
1443 if (ast_underlying_type == nullptr) {
1444 Fail() << "internal error: no type registered for SPIR-V ID: " << type_id;
1445 return nullptr;
1446 }
1447 const auto name = namer_.GetName(type_id);
1448 const auto sym = builder_.Symbols().Register(name);
1449 auto* ast_alias_type =
1450 builder_.ty.alias(sym, ast_underlying_type->Build(builder_));
1451
1452 // Record this new alias as the AST type for this SPIR-V ID.
1453 AddTypeDecl(sym, ast_alias_type);
1454
1455 return ty_.Alias(sym, ast_underlying_type);
1456 }
1457
EmitModuleScopeVariables()1458 bool ParserImpl::EmitModuleScopeVariables() {
1459 if (!success_) {
1460 return false;
1461 }
1462 for (const auto& type_or_value : module_->types_values()) {
1463 if (type_or_value.opcode() != SpvOpVariable) {
1464 continue;
1465 }
1466 const auto& var = type_or_value;
1467 const auto spirv_storage_class =
1468 SpvStorageClass(var.GetSingleWordInOperand(0));
1469
1470 uint32_t type_id = var.type_id();
1471 if ((type_id == builtin_position_.pointer_type_id) &&
1472 ((spirv_storage_class == SpvStorageClassInput) ||
1473 (spirv_storage_class == SpvStorageClassOutput))) {
1474 // Skip emitting gl_PerVertex.
1475 builtin_position_.per_vertex_var_id = var.result_id();
1476 builtin_position_.per_vertex_var_init_id =
1477 var.NumInOperands() > 1 ? var.GetSingleWordInOperand(1) : 0u;
1478 continue;
1479 }
1480 switch (enum_converter_.ToStorageClass(spirv_storage_class)) {
1481 case ast::StorageClass::kNone:
1482 case ast::StorageClass::kInput:
1483 case ast::StorageClass::kOutput:
1484 case ast::StorageClass::kUniform:
1485 case ast::StorageClass::kUniformConstant:
1486 case ast::StorageClass::kStorage:
1487 case ast::StorageClass::kImage:
1488 case ast::StorageClass::kWorkgroup:
1489 case ast::StorageClass::kPrivate:
1490 break;
1491 default:
1492 return Fail() << "invalid SPIR-V storage class "
1493 << int(spirv_storage_class)
1494 << " for module scope variable: " << var.PrettyPrint();
1495 }
1496 if (!success_) {
1497 return false;
1498 }
1499 const Type* ast_type = nullptr;
1500 if (spirv_storage_class == SpvStorageClassUniformConstant) {
1501 // These are opaque handles: samplers or textures
1502 ast_type = GetTypeForHandleVar(var);
1503 if (!ast_type) {
1504 return false;
1505 }
1506 } else {
1507 ast_type = ConvertType(type_id);
1508 if (ast_type == nullptr) {
1509 return Fail() << "internal error: failed to register Tint AST type for "
1510 "SPIR-V type with ID: "
1511 << var.type_id();
1512 }
1513 if (!ast_type->Is<Pointer>()) {
1514 return Fail() << "variable with ID " << var.result_id()
1515 << " has non-pointer type " << var.type_id();
1516 }
1517 }
1518
1519 auto* ast_store_type = ast_type->As<Pointer>()->type;
1520 auto ast_storage_class = ast_type->As<Pointer>()->storage_class;
1521 const ast::Expression* ast_constructor = nullptr;
1522 if (var.NumInOperands() > 1) {
1523 // SPIR-V initializers are always constants.
1524 // (OpenCL also allows the ID of an OpVariable, but we don't handle that
1525 // here.)
1526 ast_constructor =
1527 MakeConstantExpression(var.GetSingleWordInOperand(1)).expr;
1528 }
1529 auto* ast_var =
1530 MakeVariable(var.result_id(), ast_storage_class, ast_store_type, false,
1531 ast_constructor, ast::DecorationList{});
1532 // TODO(dneto): initializers (a.k.a. constructor expression)
1533 if (ast_var) {
1534 builder_.AST().AddGlobalVariable(ast_var);
1535 }
1536 }
1537
1538 // Emit gl_Position instead of gl_PerVertex
1539 if (builtin_position_.per_vertex_var_id) {
1540 // Make sure the variable has a name.
1541 namer_.SuggestSanitizedName(builtin_position_.per_vertex_var_id,
1542 "gl_Position");
1543 const ast::Expression* ast_constructor = nullptr;
1544 if (builtin_position_.per_vertex_var_init_id) {
1545 // The initializer is complex.
1546 const auto* init =
1547 def_use_mgr_->GetDef(builtin_position_.per_vertex_var_init_id);
1548 switch (init->opcode()) {
1549 case SpvOpConstantComposite:
1550 case SpvOpSpecConstantComposite:
1551 ast_constructor = MakeConstantExpression(
1552 init->GetSingleWordInOperand(
1553 builtin_position_.position_member_index))
1554 .expr;
1555 break;
1556 default:
1557 return Fail() << "gl_PerVertex initializer too complex. only "
1558 "OpCompositeConstruct and OpSpecConstantComposite "
1559 "are supported: "
1560 << init->PrettyPrint();
1561 }
1562 }
1563 auto* ast_var = MakeVariable(
1564 builtin_position_.per_vertex_var_id,
1565 enum_converter_.ToStorageClass(builtin_position_.storage_class),
1566 ConvertType(builtin_position_.position_member_type_id), false,
1567 ast_constructor, {});
1568
1569 builder_.AST().AddGlobalVariable(ast_var);
1570 }
1571 return success_;
1572 }
1573
1574 // @param var_id SPIR-V id of an OpVariable, assumed to be pointer
1575 // to an array
1576 // @returns the IntConstant for the size of the array, or nullptr
GetArraySize(uint32_t var_id)1577 const spvtools::opt::analysis::IntConstant* ParserImpl::GetArraySize(
1578 uint32_t var_id) {
1579 auto* var = def_use_mgr_->GetDef(var_id);
1580 if (!var || var->opcode() != SpvOpVariable) {
1581 return nullptr;
1582 }
1583 auto* ptr_type = def_use_mgr_->GetDef(var->type_id());
1584 if (!ptr_type || ptr_type->opcode() != SpvOpTypePointer) {
1585 return nullptr;
1586 }
1587 auto* array_type = def_use_mgr_->GetDef(ptr_type->GetSingleWordInOperand(1));
1588 if (!array_type || array_type->opcode() != SpvOpTypeArray) {
1589 return nullptr;
1590 }
1591 auto* size = constant_mgr_->FindDeclaredConstant(
1592 array_type->GetSingleWordInOperand(1));
1593 if (!size) {
1594 return nullptr;
1595 }
1596 return size->AsIntConstant();
1597 }
1598
MakeVariable(uint32_t id,ast::StorageClass sc,const Type * storage_type,bool is_const,const ast::Expression * constructor,ast::DecorationList decorations)1599 ast::Variable* ParserImpl::MakeVariable(uint32_t id,
1600 ast::StorageClass sc,
1601 const Type* storage_type,
1602 bool is_const,
1603 const ast::Expression* constructor,
1604 ast::DecorationList decorations) {
1605 if (storage_type == nullptr) {
1606 Fail() << "internal error: can't make ast::Variable for null type";
1607 return nullptr;
1608 }
1609
1610 ast::Access access = ast::Access::kUndefined;
1611 if (sc == ast::StorageClass::kStorage) {
1612 bool read_only = false;
1613 if (auto* tn = storage_type->As<Named>()) {
1614 read_only = read_only_struct_types_.count(tn->name) > 0;
1615 }
1616
1617 // Apply the access(read) or access(read_write) modifier.
1618 access = read_only ? ast::Access::kRead : ast::Access::kReadWrite;
1619 }
1620
1621 // Handle variables (textures and samplers) are always in the handle
1622 // storage class, so we don't mention the storage class.
1623 if (sc == ast::StorageClass::kUniformConstant) {
1624 sc = ast::StorageClass::kNone;
1625 }
1626
1627 if (!ConvertDecorationsForVariable(id, &storage_type, &decorations,
1628 sc != ast::StorageClass::kPrivate)) {
1629 return nullptr;
1630 }
1631
1632 std::string name = namer_.Name(id);
1633
1634 // Note: we're constructing the variable here with the *storage* type,
1635 // regardless of whether this is a `let` or `var` declaration.
1636 // `var` declarations will have a resolved type of ref<storage>, but at the
1637 // AST level both `var` and `let` are declared with the same type.
1638 return create<ast::Variable>(Source{}, builder_.Symbols().Register(name), sc,
1639 access, storage_type->Build(builder_), is_const,
1640 constructor, decorations);
1641 }
1642
ConvertDecorationsForVariable(uint32_t id,const Type ** store_type,ast::DecorationList * decorations,bool transfer_pipeline_io)1643 bool ParserImpl::ConvertDecorationsForVariable(uint32_t id,
1644 const Type** store_type,
1645 ast::DecorationList* decorations,
1646 bool transfer_pipeline_io) {
1647 DecorationList non_builtin_pipeline_decorations;
1648 for (auto& deco : GetDecorationsFor(id)) {
1649 if (deco.empty()) {
1650 return Fail() << "malformed decoration on ID " << id << ": it is empty";
1651 }
1652 if (deco[0] == SpvDecorationBuiltIn) {
1653 if (deco.size() == 1) {
1654 return Fail() << "malformed BuiltIn decoration on ID " << id
1655 << ": has no operand";
1656 }
1657 const auto spv_builtin = static_cast<SpvBuiltIn>(deco[1]);
1658 switch (spv_builtin) {
1659 case SpvBuiltInPointSize:
1660 special_builtins_[id] = spv_builtin;
1661 return false; // This is not an error
1662 case SpvBuiltInSampleId:
1663 case SpvBuiltInVertexIndex:
1664 case SpvBuiltInInstanceIndex:
1665 case SpvBuiltInLocalInvocationId:
1666 case SpvBuiltInLocalInvocationIndex:
1667 case SpvBuiltInGlobalInvocationId:
1668 case SpvBuiltInWorkgroupId:
1669 case SpvBuiltInNumWorkgroups:
1670 // The SPIR-V variable may signed (because GLSL requires signed for
1671 // some of these), but WGSL requires unsigned. Handle specially
1672 // so we always perform the conversion at load and store.
1673 special_builtins_[id] = spv_builtin;
1674 if (auto* forced_type = UnsignedTypeFor(*store_type)) {
1675 // Requires conversion and special handling in code generation.
1676 if (transfer_pipeline_io) {
1677 *store_type = forced_type;
1678 }
1679 }
1680 break;
1681 case SpvBuiltInSampleMask: {
1682 // In SPIR-V this is used for both input and output variable.
1683 // The SPIR-V variable has store type of array of integer scalar,
1684 // either signed or unsigned.
1685 // WGSL requires the store type to be u32.
1686 auto* size = GetArraySize(id);
1687 if (!size || size->GetZeroExtendedValue() != 1) {
1688 Fail() << "WGSL supports a sample mask of at most 32 bits. "
1689 "SampleMask must be an array of 1 element.";
1690 }
1691 special_builtins_[id] = spv_builtin;
1692 if (transfer_pipeline_io) {
1693 *store_type = ty_.U32();
1694 }
1695 break;
1696 }
1697 default:
1698 break;
1699 }
1700 auto ast_builtin = enum_converter_.ToBuiltin(spv_builtin);
1701 if (ast_builtin == ast::Builtin::kNone) {
1702 // A diagnostic has already been emitted.
1703 return false;
1704 }
1705 if (transfer_pipeline_io) {
1706 decorations->emplace_back(
1707 create<ast::BuiltinDecoration>(Source{}, ast_builtin));
1708 }
1709 }
1710 if (transfer_pipeline_io && IsPipelineDecoration(deco)) {
1711 non_builtin_pipeline_decorations.push_back(deco);
1712 }
1713 if (deco[0] == SpvDecorationDescriptorSet) {
1714 if (deco.size() == 1) {
1715 return Fail() << "malformed DescriptorSet decoration on ID " << id
1716 << ": has no operand";
1717 }
1718 decorations->emplace_back(
1719 create<ast::GroupDecoration>(Source{}, deco[1]));
1720 }
1721 if (deco[0] == SpvDecorationBinding) {
1722 if (deco.size() == 1) {
1723 return Fail() << "malformed Binding decoration on ID " << id
1724 << ": has no operand";
1725 }
1726 decorations->emplace_back(
1727 create<ast::BindingDecoration>(Source{}, deco[1]));
1728 }
1729 }
1730
1731 if (transfer_pipeline_io) {
1732 if (!ConvertPipelineDecorations(
1733 *store_type, non_builtin_pipeline_decorations, decorations)) {
1734 return false;
1735 }
1736 }
1737
1738 return success();
1739 }
1740
GetMemberPipelineDecorations(const Struct & struct_type,int member_index)1741 DecorationList ParserImpl::GetMemberPipelineDecorations(
1742 const Struct& struct_type,
1743 int member_index) {
1744 // Yes, I could have used std::copy_if or std::copy_if.
1745 DecorationList result;
1746 for (const auto& deco : GetDecorationsForMember(
1747 struct_id_for_symbol_[struct_type.name], member_index)) {
1748 if (IsPipelineDecoration(deco)) {
1749 result.emplace_back(deco);
1750 }
1751 }
1752 return result;
1753 }
1754
SetLocation(ast::DecorationList * decos,const ast::Decoration * replacement)1755 const ast::Decoration* ParserImpl::SetLocation(
1756 ast::DecorationList* decos,
1757 const ast::Decoration* replacement) {
1758 if (!replacement) {
1759 return nullptr;
1760 }
1761 for (auto*& deco : *decos) {
1762 if (deco->Is<ast::LocationDecoration>()) {
1763 // Replace this location decoration with the replacement.
1764 // The old one doesn't leak because it's kept in the builder's AST node
1765 // list.
1766 const ast::Decoration* result = nullptr;
1767 result = deco;
1768 deco = replacement;
1769 return result; // Assume there is only one such decoration.
1770 }
1771 }
1772 // The list didn't have a location. Add it.
1773 decos->push_back(replacement);
1774 return nullptr;
1775 }
1776
ConvertPipelineDecorations(const Type * store_type,const DecorationList & decorations,ast::DecorationList * ast_decos)1777 bool ParserImpl::ConvertPipelineDecorations(const Type* store_type,
1778 const DecorationList& decorations,
1779 ast::DecorationList* ast_decos) {
1780 // Vulkan defaults to perspective-correct interpolation.
1781 ast::InterpolationType type = ast::InterpolationType::kPerspective;
1782 ast::InterpolationSampling sampling = ast::InterpolationSampling::kNone;
1783
1784 for (const auto& deco : decorations) {
1785 TINT_ASSERT(Reader, deco.size() > 0);
1786 switch (deco[0]) {
1787 case SpvDecorationLocation:
1788 if (deco.size() != 2) {
1789 return Fail() << "malformed Location decoration on ID requires one "
1790 "literal operand";
1791 }
1792 SetLocation(ast_decos,
1793 create<ast::LocationDecoration>(Source{}, deco[1]));
1794 break;
1795 case SpvDecorationFlat:
1796 type = ast::InterpolationType::kFlat;
1797 break;
1798 case SpvDecorationNoPerspective:
1799 if (store_type->IsIntegerScalarOrVector()) {
1800 // This doesn't capture the array or struct case.
1801 return Fail() << "NoPerspective is invalid on integral IO";
1802 }
1803 type = ast::InterpolationType::kLinear;
1804 break;
1805 case SpvDecorationCentroid:
1806 if (store_type->IsIntegerScalarOrVector()) {
1807 // This doesn't capture the array or struct case.
1808 return Fail()
1809 << "Centroid interpolation sampling is invalid on integral IO";
1810 }
1811 sampling = ast::InterpolationSampling::kCentroid;
1812 break;
1813 case SpvDecorationSample:
1814 if (store_type->IsIntegerScalarOrVector()) {
1815 // This doesn't capture the array or struct case.
1816 return Fail()
1817 << "Sample interpolation sampling is invalid on integral IO";
1818 }
1819 sampling = ast::InterpolationSampling::kSample;
1820 break;
1821 default:
1822 break;
1823 }
1824 }
1825
1826 // Apply interpolation.
1827 if (type == ast::InterpolationType::kPerspective &&
1828 sampling == ast::InterpolationSampling::kNone) {
1829 // This is the default. Don't add a decoration.
1830 } else {
1831 ast_decos->emplace_back(create<ast::InterpolateDecoration>(type, sampling));
1832 }
1833
1834 return success();
1835 }
1836
CanMakeConstantExpression(uint32_t id)1837 bool ParserImpl::CanMakeConstantExpression(uint32_t id) {
1838 if ((id == workgroup_size_builtin_.id) ||
1839 (id == workgroup_size_builtin_.x_id) ||
1840 (id == workgroup_size_builtin_.y_id) ||
1841 (id == workgroup_size_builtin_.z_id)) {
1842 return true;
1843 }
1844 const auto* inst = def_use_mgr_->GetDef(id);
1845 if (!inst) {
1846 return false;
1847 }
1848 if (inst->opcode() == SpvOpUndef) {
1849 return true;
1850 }
1851 return nullptr != constant_mgr_->FindDeclaredConstant(id);
1852 }
1853
MakeConstantExpression(uint32_t id)1854 TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) {
1855 if (!success_) {
1856 return {};
1857 }
1858
1859 // Handle the special cases for workgroup sizing.
1860 if (id == workgroup_size_builtin_.id) {
1861 auto x = MakeConstantExpression(workgroup_size_builtin_.x_id);
1862 auto y = MakeConstantExpression(workgroup_size_builtin_.y_id);
1863 auto z = MakeConstantExpression(workgroup_size_builtin_.z_id);
1864 auto* ast_type = ty_.Vector(x.type, 3);
1865 return {ast_type,
1866 builder_.Construct(Source{}, ast_type->Build(builder_),
1867 ast::ExpressionList{x.expr, y.expr, z.expr})};
1868 } else if (id == workgroup_size_builtin_.x_id) {
1869 return MakeConstantExpressionForScalarSpirvConstant(
1870 Source{}, ConvertType(workgroup_size_builtin_.component_type_id),
1871 constant_mgr_->GetConstant(
1872 type_mgr_->GetType(workgroup_size_builtin_.component_type_id),
1873 {workgroup_size_builtin_.x_value}));
1874 } else if (id == workgroup_size_builtin_.y_id) {
1875 return MakeConstantExpressionForScalarSpirvConstant(
1876 Source{}, ConvertType(workgroup_size_builtin_.component_type_id),
1877 constant_mgr_->GetConstant(
1878 type_mgr_->GetType(workgroup_size_builtin_.component_type_id),
1879 {workgroup_size_builtin_.y_value}));
1880 } else if (id == workgroup_size_builtin_.z_id) {
1881 return MakeConstantExpressionForScalarSpirvConstant(
1882 Source{}, ConvertType(workgroup_size_builtin_.component_type_id),
1883 constant_mgr_->GetConstant(
1884 type_mgr_->GetType(workgroup_size_builtin_.component_type_id),
1885 {workgroup_size_builtin_.z_value}));
1886 }
1887
1888 // Handle the general case where a constant is already registered
1889 // with the SPIR-V optimizer's analysis framework.
1890 const auto* inst = def_use_mgr_->GetDef(id);
1891 if (inst == nullptr) {
1892 Fail() << "ID " << id << " is not a registered instruction";
1893 return {};
1894 }
1895 auto source = GetSourceForInst(inst);
1896
1897 // TODO(dneto): Handle spec constants too?
1898
1899 auto* original_ast_type = ConvertType(inst->type_id());
1900 if (original_ast_type == nullptr) {
1901 return {};
1902 }
1903
1904 switch (inst->opcode()) {
1905 case SpvOpUndef: // Remap undef to null.
1906 case SpvOpConstantNull:
1907 return {original_ast_type, MakeNullValue(original_ast_type)};
1908 case SpvOpConstantTrue:
1909 case SpvOpConstantFalse:
1910 case SpvOpConstant: {
1911 const auto* spirv_const = constant_mgr_->FindDeclaredConstant(id);
1912 if (spirv_const == nullptr) {
1913 Fail() << "ID " << id << " is not a constant";
1914 return {};
1915 }
1916 return MakeConstantExpressionForScalarSpirvConstant(
1917 source, original_ast_type, spirv_const);
1918 }
1919 case SpvOpConstantComposite: {
1920 // Handle vector, matrix, array, and struct
1921
1922 // Generate a composite from explicit components.
1923 ast::ExpressionList ast_components;
1924 if (!inst->WhileEachInId([&](const uint32_t* id_ref) -> bool {
1925 auto component = MakeConstantExpression(*id_ref);
1926 if (!component) {
1927 this->Fail() << "invalid constant with ID " << *id_ref;
1928 return false;
1929 }
1930 ast_components.emplace_back(component.expr);
1931 return true;
1932 })) {
1933 // We've already emitted a diagnostic.
1934 return {};
1935 }
1936 return {original_ast_type,
1937 builder_.Construct(source, original_ast_type->Build(builder_),
1938 std::move(ast_components))};
1939 }
1940 default:
1941 break;
1942 }
1943 Fail() << "unhandled constant instruction " << inst->PrettyPrint();
1944 return {};
1945 }
1946
MakeConstantExpressionForScalarSpirvConstant(Source source,const Type * original_ast_type,const spvtools::opt::analysis::Constant * spirv_const)1947 TypedExpression ParserImpl::MakeConstantExpressionForScalarSpirvConstant(
1948 Source source,
1949 const Type* original_ast_type,
1950 const spvtools::opt::analysis::Constant* spirv_const) {
1951 auto* ast_type = original_ast_type->UnwrapAlias();
1952
1953 // TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0.
1954 // So canonicalization should map that way too.
1955 // Currently "null<type>" is missing from the WGSL parser.
1956 // See https://bugs.chromium.org/p/tint/issues/detail?id=34
1957 if (ast_type->Is<U32>()) {
1958 return {ty_.U32(),
1959 create<ast::UintLiteralExpression>(source, spirv_const->GetU32())};
1960 }
1961 if (ast_type->Is<I32>()) {
1962 return {ty_.I32(),
1963 create<ast::SintLiteralExpression>(source, spirv_const->GetS32())};
1964 }
1965 if (ast_type->Is<F32>()) {
1966 return {ty_.F32(), create<ast::FloatLiteralExpression>(
1967 source, spirv_const->GetFloat())};
1968 }
1969 if (ast_type->Is<Bool>()) {
1970 const bool value = spirv_const->AsNullConstant()
1971 ? false
1972 : spirv_const->AsBoolConstant()->value();
1973 return {ty_.Bool(), create<ast::BoolLiteralExpression>(source, value)};
1974 }
1975 Fail() << "expected scalar constant";
1976 return {};
1977 }
1978
MakeNullValue(const Type * type)1979 const ast::Expression* ParserImpl::MakeNullValue(const Type* type) {
1980 // TODO(dneto): Use the no-operands constructor syntax when it becomes
1981 // available in Tint.
1982 // https://github.com/gpuweb/gpuweb/issues/685
1983 // https://bugs.chromium.org/p/tint/issues/detail?id=34
1984
1985 if (!type) {
1986 Fail() << "trying to create null value for a null type";
1987 return nullptr;
1988 }
1989
1990 auto* original_type = type;
1991 type = type->UnwrapAlias();
1992
1993 if (type->Is<Bool>()) {
1994 return create<ast::BoolLiteralExpression>(Source{}, false);
1995 }
1996 if (type->Is<U32>()) {
1997 return create<ast::UintLiteralExpression>(Source{}, 0u);
1998 }
1999 if (type->Is<I32>()) {
2000 return create<ast::SintLiteralExpression>(Source{}, 0);
2001 }
2002 if (type->Is<F32>()) {
2003 return create<ast::FloatLiteralExpression>(Source{}, 0.0f);
2004 }
2005 if (type->Is<Alias>()) {
2006 // TODO(amaiorano): No type constructor for TypeName (yet?)
2007 ast::ExpressionList ast_components;
2008 return builder_.Construct(Source{}, original_type->Build(builder_),
2009 std::move(ast_components));
2010 }
2011 if (auto* vec_ty = type->As<Vector>()) {
2012 ast::ExpressionList ast_components;
2013 for (size_t i = 0; i < vec_ty->size; ++i) {
2014 ast_components.emplace_back(MakeNullValue(vec_ty->type));
2015 }
2016 return builder_.Construct(Source{}, type->Build(builder_),
2017 std::move(ast_components));
2018 }
2019 if (auto* mat_ty = type->As<Matrix>()) {
2020 // Matrix components are columns
2021 auto* column_ty = ty_.Vector(mat_ty->type, mat_ty->rows);
2022 ast::ExpressionList ast_components;
2023 for (size_t i = 0; i < mat_ty->columns; ++i) {
2024 ast_components.emplace_back(MakeNullValue(column_ty));
2025 }
2026 return builder_.Construct(Source{}, type->Build(builder_),
2027 std::move(ast_components));
2028 }
2029 if (auto* arr_ty = type->As<Array>()) {
2030 ast::ExpressionList ast_components;
2031 for (size_t i = 0; i < arr_ty->size; ++i) {
2032 ast_components.emplace_back(MakeNullValue(arr_ty->type));
2033 }
2034 return builder_.Construct(Source{}, original_type->Build(builder_),
2035 std::move(ast_components));
2036 }
2037 if (auto* struct_ty = type->As<Struct>()) {
2038 ast::ExpressionList ast_components;
2039 for (auto* member : struct_ty->members) {
2040 ast_components.emplace_back(MakeNullValue(member));
2041 }
2042 return builder_.Construct(Source{}, original_type->Build(builder_),
2043 std::move(ast_components));
2044 }
2045 Fail() << "can't make null value for type: " << type->TypeInfo().name;
2046 return nullptr;
2047 }
2048
MakeNullExpression(const Type * type)2049 TypedExpression ParserImpl::MakeNullExpression(const Type* type) {
2050 return {type, MakeNullValue(type)};
2051 }
2052
// Returns the unsigned counterpart of |type| when |type| is i32 or a vector
// of i32: u32, or a vector of u32 of the same size. Returns nullptr for any
// other type.
const Type* ParserImpl::UnsignedTypeFor(const Type* type) {
  if (type->Is<I32>()) {
    return ty_.U32();
  }
  const auto* vec_ty = type->As<Vector>();
  if (vec_ty != nullptr && vec_ty->type->Is<I32>()) {
    return ty_.Vector(ty_.U32(), vec_ty->size);
  }
  return nullptr;
}
2064
// Returns the signed counterpart of |type| when |type| is u32 or a vector of
// u32: i32, or a vector of i32 of the same size. Returns nullptr for any
// other type.
const Type* ParserImpl::SignedTypeFor(const Type* type) {
  if (type->Is<U32>()) {
    return ty_.I32();
  }
  const auto* vec_ty = type->As<Vector>();
  if (vec_ty != nullptr && vec_ty->type->Is<U32>()) {
    return ty_.Vector(ty_.I32(), vec_ty->size);
  }
  return nullptr;
}
2076
RectifyOperandSignedness(const spvtools::opt::Instruction & inst,TypedExpression && expr)2077 TypedExpression ParserImpl::RectifyOperandSignedness(
2078 const spvtools::opt::Instruction& inst,
2079 TypedExpression&& expr) {
2080 bool requires_signed = false;
2081 bool requires_unsigned = false;
2082 if (IsGlslExtendedInstruction(inst)) {
2083 const auto extended_opcode =
2084 static_cast<GLSLstd450>(inst.GetSingleWordInOperand(1));
2085 requires_signed = AssumesSignedOperands(extended_opcode);
2086 requires_unsigned = AssumesUnsignedOperands(extended_opcode);
2087 } else {
2088 const auto opcode = inst.opcode();
2089 requires_signed = AssumesSignedOperands(opcode);
2090 requires_unsigned = AssumesUnsignedOperands(opcode);
2091 }
2092 if (!requires_signed && !requires_unsigned) {
2093 // No conversion is required, assuming our tables are complete.
2094 return std::move(expr);
2095 }
2096 if (!expr) {
2097 Fail() << "internal error: RectifyOperandSignedness given a null expr\n";
2098 return {};
2099 }
2100 auto* type = expr.type;
2101 if (!type) {
2102 Fail() << "internal error: unmapped type for: "
2103 << expr.expr->TypeInfo().name << "\n";
2104 return {};
2105 }
2106 if (requires_unsigned) {
2107 if (auto* unsigned_ty = UnsignedTypeFor(type)) {
2108 // Conversion is required.
2109 return {unsigned_ty,
2110 create<ast::BitcastExpression>(
2111 Source{}, unsigned_ty->Build(builder_), expr.expr)};
2112 }
2113 } else if (requires_signed) {
2114 if (auto* signed_ty = SignedTypeFor(type)) {
2115 // Conversion is required.
2116 return {signed_ty, create<ast::BitcastExpression>(
2117 Source{}, signed_ty->Build(builder_), expr.expr)};
2118 }
2119 }
2120 // We should not reach here.
2121 return std::move(expr);
2122 }
2123
RectifySecondOperandSignedness(const spvtools::opt::Instruction & inst,const Type * first_operand_type,TypedExpression && second_operand_expr)2124 TypedExpression ParserImpl::RectifySecondOperandSignedness(
2125 const spvtools::opt::Instruction& inst,
2126 const Type* first_operand_type,
2127 TypedExpression&& second_operand_expr) {
2128 if ((first_operand_type != second_operand_expr.type) &&
2129 AssumesSecondOperandSignednessMatchesFirstOperand(inst.opcode())) {
2130 // Conversion is required.
2131 return {first_operand_type,
2132 create<ast::BitcastExpression>(Source{},
2133 first_operand_type->Build(builder_),
2134 second_operand_expr.expr)};
2135 }
2136 // No conversion necessary.
2137 return std::move(second_operand_expr);
2138 }
2139
ForcedResultType(const spvtools::opt::Instruction & inst,const Type * first_operand_type)2140 const Type* ParserImpl::ForcedResultType(const spvtools::opt::Instruction& inst,
2141 const Type* first_operand_type) {
2142 const auto opcode = inst.opcode();
2143 if (AssumesResultSignednessMatchesFirstOperand(opcode)) {
2144 return first_operand_type;
2145 }
2146 if (IsGlslExtendedInstruction(inst)) {
2147 const auto extended_opcode =
2148 static_cast<GLSLstd450>(inst.GetSingleWordInOperand(1));
2149 if (AssumesResultSignednessMatchesFirstOperand(extended_opcode)) {
2150 return first_operand_type;
2151 }
2152 }
2153 return nullptr;
2154 }
2155
// Returns the signed integer type with the same shape as |other|: i32 when
// |other| is f32, u32 or i32, and a vector of i32 of the same size when
// |other| is a vector. Emits a diagnostic and returns nullptr when |other| is
// null or is any other type.
const Type* ParserImpl::GetSignedIntMatchingShape(const Type* other) {
  if (other == nullptr) {
    Fail() << "no type provided";
    // Stop here: every check below dereferences |other|.
    return nullptr;
  }
  if (other->Is<F32>() || other->Is<U32>() || other->Is<I32>()) {
    return ty_.I32();
  }
  if (auto* vec_ty = other->As<Vector>()) {
    return ty_.Vector(ty_.I32(), vec_ty->size);
  }
  Fail() << "required numeric scalar or vector, but got "
         << other->TypeInfo().name;
  return nullptr;
}
2170
// Returns the unsigned integer type with the same shape as |other|: u32 when
// |other| is f32, u32 or i32, and a vector of u32 of the same size when
// |other| is a vector. Emits a diagnostic and returns nullptr when |other| is
// null or is any other type.
const Type* ParserImpl::GetUnsignedIntMatchingShape(const Type* other) {
  if (other == nullptr) {
    Fail() << "no type provided";
    return nullptr;
  }

  const bool is_numeric_scalar =
      other->Is<F32>() || other->Is<U32>() || other->Is<I32>();
  if (is_numeric_scalar) {
    return ty_.U32();
  }

  const auto* vec_ty = other->As<Vector>();
  if (vec_ty != nullptr) {
    return ty_.Vector(ty_.U32(), vec_ty->size);
  }

  Fail() << "required numeric scalar or vector, but got "
         << other->TypeInfo().name;
  return nullptr;
}
2186
RectifyForcedResultType(TypedExpression expr,const spvtools::opt::Instruction & inst,const Type * first_operand_type)2187 TypedExpression ParserImpl::RectifyForcedResultType(
2188 TypedExpression expr,
2189 const spvtools::opt::Instruction& inst,
2190 const Type* first_operand_type) {
2191 auto* forced_result_ty = ForcedResultType(inst, first_operand_type);
2192 if ((!forced_result_ty) || (forced_result_ty == expr.type)) {
2193 return expr;
2194 }
2195 return {expr.type, create<ast::BitcastExpression>(
2196 Source{}, expr.type->Build(builder_), expr.expr)};
2197 }
2198
AsUnsigned(TypedExpression expr)2199 TypedExpression ParserImpl::AsUnsigned(TypedExpression expr) {
2200 if (expr.type && expr.type->IsSignedScalarOrVector()) {
2201 auto* new_type = GetUnsignedIntMatchingShape(expr.type);
2202 return {new_type, create<ast::BitcastExpression>(
2203 Source{}, new_type->Build(builder_), expr.expr)};
2204 }
2205 return expr;
2206 }
2207
AsSigned(TypedExpression expr)2208 TypedExpression ParserImpl::AsSigned(TypedExpression expr) {
2209 if (expr.type && expr.type->IsUnsignedScalarOrVector()) {
2210 auto* new_type = GetSignedIntMatchingShape(expr.type);
2211 return {new_type, create<ast::BitcastExpression>(
2212 Source{}, new_type->Build(builder_), expr.expr)};
2213 }
2214 return expr;
2215 }
2216
EmitFunctions()2217 bool ParserImpl::EmitFunctions() {
2218 if (!success_) {
2219 return false;
2220 }
2221 for (const auto* f : topologically_ordered_functions_) {
2222 if (!success_) {
2223 return false;
2224 }
2225
2226 auto id = f->result_id();
2227 auto it = function_to_ep_info_.find(id);
2228 if (it == function_to_ep_info_.end()) {
2229 FunctionEmitter emitter(this, *f, nullptr);
2230 success_ = emitter.Emit();
2231 } else {
2232 for (const auto& ep : it->second) {
2233 FunctionEmitter emitter(this, *f, &ep);
2234 success_ = emitter.Emit();
2235 if (!success_) {
2236 return false;
2237 }
2238 }
2239 }
2240 }
2241 return success_;
2242 }
2243
2244 const spvtools::opt::Instruction*
GetMemoryObjectDeclarationForHandle(uint32_t id,bool follow_image)2245 ParserImpl::GetMemoryObjectDeclarationForHandle(uint32_t id,
2246 bool follow_image) {
2247 auto saved_id = id;
2248 auto local_fail = [this, saved_id, id,
2249 follow_image]() -> const spvtools::opt::Instruction* {
2250 const auto* inst = def_use_mgr_->GetDef(id);
2251 Fail() << "Could not find memory object declaration for the "
2252 << (follow_image ? "image" : "sampler") << " underlying id " << id
2253 << " (from original id " << saved_id << ") "
2254 << (inst ? inst->PrettyPrint() : std::string());
2255 return nullptr;
2256 };
2257
2258 auto& memo_table =
2259 (follow_image ? mem_obj_decl_image_ : mem_obj_decl_sampler_);
2260
2261 // Use a visited set to defend against bad input which might have long
2262 // chains or even loops.
2263 std::unordered_set<uint32_t> visited;
2264
2265 // Trace backward in the SSA data flow until we hit a memory object
2266 // declaration.
2267 while (true) {
2268 auto where = memo_table.find(id);
2269 if (where != memo_table.end()) {
2270 return where->second;
2271 }
2272 // Protect against loops.
2273 auto visited_iter = visited.find(id);
2274 if (visited_iter != visited.end()) {
2275 // We've hit a loop. Mark all the visited nodes
2276 // as dead ends.
2277 for (auto iter : visited) {
2278 memo_table[iter] = nullptr;
2279 }
2280 return nullptr;
2281 }
2282 visited.insert(id);
2283
2284 const auto* inst = def_use_mgr_->GetDef(id);
2285 if (inst == nullptr) {
2286 return local_fail();
2287 }
2288 switch (inst->opcode()) {
2289 case SpvOpFunctionParameter:
2290 case SpvOpVariable:
2291 // We found the memory object declaration.
2292 // Remember it as the answer for the whole path.
2293 for (auto iter : visited) {
2294 memo_table[iter] = inst;
2295 }
2296 return inst;
2297 case SpvOpLoad:
2298 // Follow the pointer being loaded
2299 id = inst->GetSingleWordInOperand(0);
2300 break;
2301 case SpvOpCopyObject:
2302 // Follow the object being copied.
2303 id = inst->GetSingleWordInOperand(0);
2304 break;
2305 case SpvOpAccessChain:
2306 case SpvOpInBoundsAccessChain:
2307 case SpvOpPtrAccessChain:
2308 case SpvOpInBoundsPtrAccessChain:
2309 // Follow the base pointer.
2310 id = inst->GetSingleWordInOperand(0);
2311 break;
2312 case SpvOpSampledImage:
2313 // Follow the image or the sampler, depending on the follow_image
2314 // parameter.
2315 id = inst->GetSingleWordInOperand(follow_image ? 0 : 1);
2316 break;
2317 case SpvOpImage:
2318 // Follow the sampled image
2319 id = inst->GetSingleWordInOperand(0);
2320 break;
2321 default:
2322 // Can't trace further.
2323 // Remember it as the answer for the whole path.
2324 for (auto iter : visited) {
2325 memo_table[iter] = nullptr;
2326 }
2327 return nullptr;
2328 }
2329 }
2330 }
2331
2332 const spvtools::opt::Instruction*
GetSpirvTypeForHandleMemoryObjectDeclaration(const spvtools::opt::Instruction & var)2333 ParserImpl::GetSpirvTypeForHandleMemoryObjectDeclaration(
2334 const spvtools::opt::Instruction& var) {
2335 if (!success()) {
2336 return nullptr;
2337 }
2338 // The WGSL handle type is determined by looking at information from
2339 // several sources:
2340 // - the usage of the handle by image access instructions
2341 // - the SPIR-V type declaration
2342 // Each source does not have enough information to completely determine
2343 // the result.
2344
2345 // Messages are phrased in terms of images and samplers because those
2346 // are the only SPIR-V handles supported by WGSL.
2347
2348 // Get the SPIR-V handle type.
2349 const auto* ptr_type = def_use_mgr_->GetDef(var.type_id());
2350 if (!ptr_type || (ptr_type->opcode() != SpvOpTypePointer)) {
2351 Fail() << "Invalid type for variable or function parameter "
2352 << var.PrettyPrint();
2353 return nullptr;
2354 }
2355 const auto* raw_handle_type =
2356 def_use_mgr_->GetDef(ptr_type->GetSingleWordInOperand(1));
2357 if (!raw_handle_type) {
2358 Fail() << "Invalid pointer type for variable or function parameter "
2359 << var.PrettyPrint();
2360 return nullptr;
2361 }
2362 switch (raw_handle_type->opcode()) {
2363 case SpvOpTypeSampler:
2364 case SpvOpTypeImage:
2365 // The expected cases.
2366 break;
2367 case SpvOpTypeArray:
2368 case SpvOpTypeRuntimeArray:
2369 Fail()
2370 << "arrays of textures or samplers are not supported in WGSL; can't "
2371 "translate variable or function parameter: "
2372 << var.PrettyPrint();
2373 return nullptr;
2374 case SpvOpTypeSampledImage:
2375 Fail() << "WGSL does not support combined image-samplers: "
2376 << var.PrettyPrint();
2377 return nullptr;
2378 default:
2379 Fail() << "invalid type for image or sampler variable or function "
2380 "parameter: "
2381 << var.PrettyPrint();
2382 return nullptr;
2383 }
2384 return raw_handle_type;
2385 }
2386
GetTypeForHandleVar(const spvtools::opt::Instruction & var)2387 const Pointer* ParserImpl::GetTypeForHandleVar(
2388 const spvtools::opt::Instruction& var) {
2389 auto where = handle_type_.find(&var);
2390 if (where != handle_type_.end()) {
2391 return where->second;
2392 }
2393
2394 const spvtools::opt::Instruction* raw_handle_type =
2395 GetSpirvTypeForHandleMemoryObjectDeclaration(var);
2396 if (!raw_handle_type) {
2397 return nullptr;
2398 }
2399
2400 // The variable could be a sampler or image.
2401 // Where possible, determine which one it is from the usage inferred
2402 // for the variable.
2403 Usage usage = handle_usage_[&var];
2404 if (!usage.IsValid()) {
2405 Fail() << "Invalid sampler or texture usage for variable "
2406 << var.PrettyPrint() << "\n"
2407 << usage;
2408 return nullptr;
2409 }
2410 // Infer a handle type, if usage didn't already tell us.
2411 if (!usage.IsComplete()) {
2412 // In SPIR-V you could statically reference a texture or sampler without
2413 // using it in a way that gives us a clue on how to declare it. Look inside
2414 // the store type to infer a usage.
2415 if (raw_handle_type->opcode() == SpvOpTypeSampler) {
2416 usage.AddSampler();
2417 } else {
2418 // It's a texture.
2419 if (raw_handle_type->NumInOperands() != 7) {
2420 Fail() << "invalid SPIR-V image type: expected 7 operands: "
2421 << raw_handle_type->PrettyPrint();
2422 return nullptr;
2423 }
2424 const auto sampled_param = raw_handle_type->GetSingleWordInOperand(5);
2425 const auto format_param = raw_handle_type->GetSingleWordInOperand(6);
2426 // Only storage images have a format.
2427 if ((format_param != SpvImageFormatUnknown) ||
2428 sampled_param == 2 /* without sampler */) {
2429 // Get NonWritable and NonReadable attributes of the variable.
2430 bool is_nonwritable = false;
2431 bool is_nonreadable = false;
2432 for (const auto& deco : GetDecorationsFor(var.result_id())) {
2433 if (deco.size() != 1) {
2434 continue;
2435 }
2436 if (deco[0] == SpvDecorationNonWritable) {
2437 is_nonwritable = true;
2438 }
2439 if (deco[0] == SpvDecorationNonReadable) {
2440 is_nonreadable = true;
2441 }
2442 }
2443 if (is_nonwritable && is_nonreadable) {
2444 Fail() << "storage image variable is both NonWritable and NonReadable"
2445 << var.PrettyPrint();
2446 }
2447 if (!is_nonwritable && !is_nonreadable) {
2448 Fail()
2449 << "storage image variable is neither NonWritable nor NonReadable"
2450 << var.PrettyPrint();
2451 }
2452 // Let's make it one of the storage textures.
2453 if (is_nonwritable) {
2454 usage.AddStorageReadTexture();
2455 } else {
2456 usage.AddStorageWriteTexture();
2457 }
2458 } else {
2459 usage.AddSampledTexture();
2460 }
2461 }
2462 if (!usage.IsComplete()) {
2463 Fail()
2464 << "internal error: should have inferred a complete handle type. got "
2465 << usage.to_str();
2466 return nullptr;
2467 }
2468 }
2469
2470 // Construct the Tint handle type.
2471 const Type* ast_store_type = nullptr;
2472 if (usage.IsSampler()) {
2473 ast_store_type = ty_.Sampler(usage.IsComparisonSampler()
2474 ? ast::SamplerKind::kComparisonSampler
2475 : ast::SamplerKind::kSampler);
2476 } else if (usage.IsTexture()) {
2477 const spvtools::opt::analysis::Image* image_type =
2478 type_mgr_->GetType(raw_handle_type->result_id())->AsImage();
2479 if (!image_type) {
2480 Fail() << "internal error: Couldn't look up image type"
2481 << raw_handle_type->PrettyPrint();
2482 return nullptr;
2483 }
2484
2485 if (image_type->is_arrayed()) {
2486 // Give a nicer error message here, where we have the offending variable
2487 // in hand, rather than inside the enum converter.
2488 switch (image_type->dim()) {
2489 case SpvDim2D:
2490 case SpvDimCube:
2491 break;
2492 default:
2493 Fail() << "WGSL arrayed textures must be 2d_array or cube_array: "
2494 "invalid multisampled texture variable "
2495 << namer_.Name(var.result_id()) << ": " << var.PrettyPrint();
2496 return nullptr;
2497 }
2498 }
2499
2500 const ast::TextureDimension dim =
2501 enum_converter_.ToDim(image_type->dim(), image_type->is_arrayed());
2502 if (dim == ast::TextureDimension::kNone) {
2503 return nullptr;
2504 }
2505
2506 // WGSL textures are always formatted. Unformatted textures are always
2507 // sampled.
2508 if (usage.IsSampledTexture() || usage.IsStorageReadTexture() ||
2509 (image_type->format() == SpvImageFormatUnknown)) {
2510 // Make a sampled texture type.
2511 auto* ast_sampled_component_type =
2512 ConvertType(raw_handle_type->GetSingleWordInOperand(0));
2513
2514 // Vulkan ignores the depth parameter on OpImage, so pay attention to the
2515 // usage as well. That is, it's valid for a Vulkan shader to use an
2516 // OpImage variable with an OpImage*Dref* instruction. In WGSL we must
2517 // treat that as a depth texture.
2518 if (image_type->depth() || usage.IsDepthTexture()) {
2519 if (image_type->is_multisampled()) {
2520 ast_store_type = ty_.DepthMultisampledTexture(dim);
2521 } else {
2522 ast_store_type = ty_.DepthTexture(dim);
2523 }
2524 } else if (image_type->is_multisampled()) {
2525 if (dim != ast::TextureDimension::k2d) {
2526 Fail() << "WGSL multisampled textures must be 2d and non-arrayed: "
2527 "invalid multisampled texture variable "
2528 << namer_.Name(var.result_id()) << ": " << var.PrettyPrint();
2529 }
2530 // Multisampled textures are never depth textures.
2531 ast_store_type =
2532 ty_.MultisampledTexture(dim, ast_sampled_component_type);
2533 } else {
2534 ast_store_type = ty_.SampledTexture(dim, ast_sampled_component_type);
2535 }
2536 } else {
2537 const auto access = ast::Access::kWrite;
2538 const auto format = enum_converter_.ToImageFormat(image_type->format());
2539 if (format == ast::ImageFormat::kNone) {
2540 return nullptr;
2541 }
2542 ast_store_type = ty_.StorageTexture(dim, format, access);
2543 }
2544 } else {
2545 Fail() << "unsupported: UniformConstant variable is not a recognized "
2546 "sampler or texture"
2547 << var.PrettyPrint();
2548 return nullptr;
2549 }
2550
2551 // Form the pointer type.
2552 auto* result =
2553 ty_.Pointer(ast_store_type, ast::StorageClass::kUniformConstant);
2554 // Remember it for later.
2555 handle_type_[&var] = result;
2556 return result;
2557 }
2558
GetComponentTypeForFormat(ast::ImageFormat format)2559 const Type* ParserImpl::GetComponentTypeForFormat(ast::ImageFormat format) {
2560 switch (format) {
2561 case ast::ImageFormat::kR8Uint:
2562 case ast::ImageFormat::kR16Uint:
2563 case ast::ImageFormat::kRg8Uint:
2564 case ast::ImageFormat::kR32Uint:
2565 case ast::ImageFormat::kRg16Uint:
2566 case ast::ImageFormat::kRgba8Uint:
2567 case ast::ImageFormat::kRg32Uint:
2568 case ast::ImageFormat::kRgba16Uint:
2569 case ast::ImageFormat::kRgba32Uint:
2570 return ty_.U32();
2571
2572 case ast::ImageFormat::kR8Sint:
2573 case ast::ImageFormat::kR16Sint:
2574 case ast::ImageFormat::kRg8Sint:
2575 case ast::ImageFormat::kR32Sint:
2576 case ast::ImageFormat::kRg16Sint:
2577 case ast::ImageFormat::kRgba8Sint:
2578 case ast::ImageFormat::kRg32Sint:
2579 case ast::ImageFormat::kRgba16Sint:
2580 case ast::ImageFormat::kRgba32Sint:
2581 return ty_.I32();
2582
2583 case ast::ImageFormat::kR8Unorm:
2584 case ast::ImageFormat::kRg8Unorm:
2585 case ast::ImageFormat::kRgba8Unorm:
2586 case ast::ImageFormat::kRgba8UnormSrgb:
2587 case ast::ImageFormat::kBgra8Unorm:
2588 case ast::ImageFormat::kBgra8UnormSrgb:
2589 case ast::ImageFormat::kRgb10A2Unorm:
2590 case ast::ImageFormat::kR8Snorm:
2591 case ast::ImageFormat::kRg8Snorm:
2592 case ast::ImageFormat::kRgba8Snorm:
2593 case ast::ImageFormat::kR16Float:
2594 case ast::ImageFormat::kR32Float:
2595 case ast::ImageFormat::kRg16Float:
2596 case ast::ImageFormat::kRg11B10Float:
2597 case ast::ImageFormat::kRg32Float:
2598 case ast::ImageFormat::kRgba16Float:
2599 case ast::ImageFormat::kRgba32Float:
2600 return ty_.F32();
2601 default:
2602 break;
2603 }
2604 Fail() << "unknown format " << int(format);
2605 return nullptr;
2606 }
2607
GetChannelCountForFormat(ast::ImageFormat format)2608 unsigned ParserImpl::GetChannelCountForFormat(ast::ImageFormat format) {
2609 switch (format) {
2610 case ast::ImageFormat::kR16Float:
2611 case ast::ImageFormat::kR16Sint:
2612 case ast::ImageFormat::kR16Uint:
2613 case ast::ImageFormat::kR32Float:
2614 case ast::ImageFormat::kR32Sint:
2615 case ast::ImageFormat::kR32Uint:
2616 case ast::ImageFormat::kR8Sint:
2617 case ast::ImageFormat::kR8Snorm:
2618 case ast::ImageFormat::kR8Uint:
2619 case ast::ImageFormat::kR8Unorm:
2620 // One channel
2621 return 1;
2622
2623 case ast::ImageFormat::kRg11B10Float:
2624 case ast::ImageFormat::kRg16Float:
2625 case ast::ImageFormat::kRg16Sint:
2626 case ast::ImageFormat::kRg16Uint:
2627 case ast::ImageFormat::kRg32Float:
2628 case ast::ImageFormat::kRg32Sint:
2629 case ast::ImageFormat::kRg32Uint:
2630 case ast::ImageFormat::kRg8Sint:
2631 case ast::ImageFormat::kRg8Snorm:
2632 case ast::ImageFormat::kRg8Uint:
2633 case ast::ImageFormat::kRg8Unorm:
2634 // Two channels
2635 return 2;
2636
2637 case ast::ImageFormat::kBgra8Unorm:
2638 case ast::ImageFormat::kBgra8UnormSrgb:
2639 case ast::ImageFormat::kRgb10A2Unorm:
2640 case ast::ImageFormat::kRgba16Float:
2641 case ast::ImageFormat::kRgba16Sint:
2642 case ast::ImageFormat::kRgba16Uint:
2643 case ast::ImageFormat::kRgba32Float:
2644 case ast::ImageFormat::kRgba32Sint:
2645 case ast::ImageFormat::kRgba32Uint:
2646 case ast::ImageFormat::kRgba8Sint:
2647 case ast::ImageFormat::kRgba8Snorm:
2648 case ast::ImageFormat::kRgba8Uint:
2649 case ast::ImageFormat::kRgba8Unorm:
2650 case ast::ImageFormat::kRgba8UnormSrgb:
2651 // Four channels
2652 return 4;
2653
2654 default:
2655 break;
2656 }
2657 Fail() << "unknown format " << int(format);
2658 return 0;
2659 }
2660
GetTexelTypeForFormat(ast::ImageFormat format)2661 const Type* ParserImpl::GetTexelTypeForFormat(ast::ImageFormat format) {
2662 const auto* component_type = GetComponentTypeForFormat(format);
2663 if (!component_type) {
2664 return nullptr;
2665 }
2666 return ty_.Vector(component_type, 4);
2667 }
2668
RegisterHandleUsage()2669 bool ParserImpl::RegisterHandleUsage() {
2670 if (!success_) {
2671 return false;
2672 }
2673
2674 // Map a function ID to the list of its function parameter instructions, in
2675 // order.
2676 std::unordered_map<uint32_t, std::vector<const spvtools::opt::Instruction*>>
2677 function_params;
2678 for (const auto* f : topologically_ordered_functions_) {
2679 // Record the instructions defining this function's parameters.
2680 auto& params = function_params[f->result_id()];
2681 f->ForEachParam([¶ms](const spvtools::opt::Instruction* param) {
2682 params.push_back(param);
2683 });
2684 }
2685
2686 // Returns the memory object declaration for an image underlying the first
2687 // operand of the given image instruction.
2688 auto get_image = [this](const spvtools::opt::Instruction& image_inst) {
2689 return this->GetMemoryObjectDeclarationForHandle(
2690 image_inst.GetSingleWordInOperand(0), true);
2691 };
2692 // Returns the memory object declaration for a sampler underlying the first
2693 // operand of the given image instruction.
2694 auto get_sampler = [this](const spvtools::opt::Instruction& image_inst) {
2695 return this->GetMemoryObjectDeclarationForHandle(
2696 image_inst.GetSingleWordInOperand(0), false);
2697 };
2698
2699 // Scan the bodies of functions for image operations, recording their implied
2700 // usage properties on the memory object declarations (i.e. variables or
2701 // function parameters). We scan the functions in an order so that callees
2702 // precede callers. That way the usage on a function parameter is already
2703 // computed before we see the call to that function. So when we reach
2704 // a function call, we can add the usage from the callee formal parameters.
2705 for (const auto* f : topologically_ordered_functions_) {
2706 for (const auto& bb : *f) {
2707 for (const auto& inst : bb) {
2708 switch (inst.opcode()) {
2709 // Single texel reads and writes
2710
2711 case SpvOpImageRead:
2712 handle_usage_[get_image(inst)].AddStorageReadTexture();
2713 break;
2714 case SpvOpImageWrite:
2715 handle_usage_[get_image(inst)].AddStorageWriteTexture();
2716 break;
2717 case SpvOpImageFetch:
2718 handle_usage_[get_image(inst)].AddSampledTexture();
2719 break;
2720
2721 // Sampling and gathering from a sampled image.
2722
2723 case SpvOpImageSampleImplicitLod:
2724 case SpvOpImageSampleExplicitLod:
2725 case SpvOpImageSampleProjImplicitLod:
2726 case SpvOpImageSampleProjExplicitLod:
2727 case SpvOpImageGather:
2728 handle_usage_[get_image(inst)].AddSampledTexture();
2729 handle_usage_[get_sampler(inst)].AddSampler();
2730 break;
2731 case SpvOpImageSampleDrefImplicitLod:
2732 case SpvOpImageSampleDrefExplicitLod:
2733 case SpvOpImageSampleProjDrefImplicitLod:
2734 case SpvOpImageSampleProjDrefExplicitLod:
2735 case SpvOpImageDrefGather:
2736 // Depth reference access implies usage as a depth texture, which
2737 // in turn is a sampled texture.
2738 handle_usage_[get_image(inst)].AddDepthTexture();
2739 handle_usage_[get_sampler(inst)].AddComparisonSampler();
2740 break;
2741
2742 // Image queries
2743
2744 case SpvOpImageQuerySizeLod:
2745 // Vulkan requires Sampled=1 for this. SPIR-V already requires MS=0.
2746 handle_usage_[get_image(inst)].AddSampledTexture();
2747 break;
2748 case SpvOpImageQuerySize:
2749 // Applies to either MS=1 or Sampled=0 or 2.
2750 // So we can't force it to be multisampled, or storage image.
2751 break;
2752 case SpvOpImageQueryLod:
2753 handle_usage_[get_image(inst)].AddSampledTexture();
2754 handle_usage_[get_sampler(inst)].AddSampler();
2755 break;
2756 case SpvOpImageQueryLevels:
2757 // We can't tell anything more than that it's an image.
2758 handle_usage_[get_image(inst)].AddTexture();
2759 break;
2760 case SpvOpImageQuerySamples:
2761 handle_usage_[get_image(inst)].AddMultisampledTexture();
2762 break;
2763
2764 // Function calls
2765
2766 case SpvOpFunctionCall: {
2767 // Propagate handle usages from callee function formal parameters to
2768 // the matching caller parameters. This is where we rely on the
2769 // fact that callees have been processed earlier in the flow.
2770 const auto num_in_operands = inst.NumInOperands();
2771 // The first operand of the call is the function ID.
2772 // The remaining operands are the operands to the function.
2773 if (num_in_operands < 1) {
2774 return Fail() << "Call instruction must have at least one operand"
2775 << inst.PrettyPrint();
2776 }
2777 const auto function_id = inst.GetSingleWordInOperand(0);
2778 const auto& formal_params = function_params[function_id];
2779 if (formal_params.size() != (num_in_operands - 1)) {
2780 return Fail() << "Called function has " << formal_params.size()
2781 << " parameters, but function call has "
2782 << (num_in_operands - 1) << " parameters"
2783 << inst.PrettyPrint();
2784 }
2785 for (uint32_t i = 1; i < num_in_operands; ++i) {
2786 auto where = handle_usage_.find(formal_params[i - 1]);
2787 if (where == handle_usage_.end()) {
2788 // We haven't recorded any handle usage on the formal parameter.
2789 continue;
2790 }
2791 const Usage& formal_param_usage = where->second;
2792 const auto operand_id = inst.GetSingleWordInOperand(i);
2793 const auto* operand_as_sampler =
2794 GetMemoryObjectDeclarationForHandle(operand_id, false);
2795 const auto* operand_as_image =
2796 GetMemoryObjectDeclarationForHandle(operand_id, true);
2797 if (operand_as_sampler) {
2798 handle_usage_[operand_as_sampler].Add(formal_param_usage);
2799 }
2800 if (operand_as_image &&
2801 (operand_as_image != operand_as_sampler)) {
2802 handle_usage_[operand_as_image].Add(formal_param_usage);
2803 }
2804 }
2805 break;
2806 }
2807
2808 default:
2809 break;
2810 }
2811 }
2812 }
2813 }
2814 return success_;
2815 }
2816
GetHandleUsage(uint32_t id) const2817 Usage ParserImpl::GetHandleUsage(uint32_t id) const {
2818 const auto where = handle_usage_.find(def_use_mgr_->GetDef(id));
2819 if (where != handle_usage_.end()) {
2820 return where->second;
2821 }
2822 return Usage();
2823 }
2824
GetInstructionForTest(uint32_t id) const2825 const spvtools::opt::Instruction* ParserImpl::GetInstructionForTest(
2826 uint32_t id) const {
2827 return def_use_mgr_ ? def_use_mgr_->GetDef(id) : nullptr;
2828 }
2829
GetMemberName(const Struct & struct_type,int member_index)2830 std::string ParserImpl::GetMemberName(const Struct& struct_type,
2831 int member_index) {
2832 auto where = struct_id_for_symbol_.find(struct_type.name);
2833 if (where == struct_id_for_symbol_.end()) {
2834 Fail() << "no structure type registered for symbol";
2835 return "";
2836 }
2837 return namer_.GetMemberName(where->second, member_index);
2838 }
2839
// Out-of-line definition of the defaulted WorkgroupSizeInfo constructor.
WorkgroupSizeInfo::WorkgroupSizeInfo() = default;

// Out-of-line definition of the defaulted WorkgroupSizeInfo destructor.
WorkgroupSizeInfo::~WorkgroupSizeInfo() = default;
2843
2844 } // namespace spirv
2845 } // namespace reader
2846 } // namespace tint
2847