1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/reader/spirv/function.h"
16
17 #include <algorithm>
18 #include <array>
19
20 #include "src/ast/assignment_statement.h"
21 #include "src/ast/bitcast_expression.h"
22 #include "src/ast/break_statement.h"
23 #include "src/ast/builtin.h"
24 #include "src/ast/builtin_decoration.h"
25 #include "src/ast/call_statement.h"
26 #include "src/ast/continue_statement.h"
27 #include "src/ast/discard_statement.h"
28 #include "src/ast/fallthrough_statement.h"
29 #include "src/ast/if_statement.h"
30 #include "src/ast/loop_statement.h"
31 #include "src/ast/return_statement.h"
32 #include "src/ast/stage_decoration.h"
33 #include "src/ast/switch_statement.h"
34 #include "src/ast/unary_op_expression.h"
35 #include "src/ast/variable_decl_statement.h"
36 #include "src/sem/depth_texture_type.h"
37 #include "src/sem/intrinsic_type.h"
38 #include "src/sem/sampled_texture_type.h"
39
40 // Terms:
41 // CFG: the control flow graph of the function, where basic blocks are the
42 // nodes, and branches form the directed arcs. The function entry block is
43 // the root of the CFG.
44 //
45 // Suppose H is a header block (i.e. has an OpSelectionMerge or OpLoopMerge).
46 // Then:
47 // - Let M(H) be the merge block named by the merge instruction in H.
48 // - If H is a loop header, i.e. has an OpLoopMerge instruction, then let
49 // CT(H) be the continue target block named by the OpLoopMerge
50 // instruction.
51 // - If H is a selection construct whose header ends in
52 // OpBranchConditional with true target %then and false target %else,
53 // then TT(H) = %then and FT(H) = %else
54 //
55 // Determining output block order:
56 // The "structured post-order traversal" of the CFG is a post-order traversal
57 // of the basic blocks in the CFG, where:
58 // We visit the entry node of the function first.
59 // When visiting a header block:
60 // We next visit its merge block
61 // Then if it's a loop header, we next visit the continue target,
62 // Then we visit the block's successors (whether it's a header or not)
63 // If the block ends in an OpBranchConditional, we visit the false target
64 // before the true target.
65 //
66 // The "reverse structured post-order traversal" of the CFG is the reverse
67 // of the structured post-order traversal.
68 // This is the order of basic blocks as they should be emitted to the WGSL
69 // function. It is the order computed by ComputeBlockOrder, and stored in
// the |FunctionEmitter::block_order_|.
71 // Blocks not in this ordering are ignored by the rest of the algorithm.
72 //
73 // Note:
74 // - A block D in the function might not appear in this order because
75 // no block in the order branches to D.
76 // - An unreachable block D might still be in the order because some header
77 // block in the order names D as its continue target, or merge block,
78 // or D is reachable from one of those otherwise-unreachable continue
79 // targets or merge blocks.
80 //
81 // Terms:
82 // Let Pos(B) be the index position of a block B in the computed block order.
83 //
84 // CFG intervals and valid nesting:
85 //
86 // A correctly structured CFG satisfies nesting rules that we can check by
87 // comparing positions of related blocks.
88 //
89 // If header block H is in the block order, then the following holds:
90 //
91 // Pos(H) < Pos(M(H))
92 //
93 // If CT(H) exists, then:
94 //
95 // Pos(H) <= Pos(CT(H))
//      Pos(CT(H)) < Pos(M(H))
97 //
98 // This gives us the fundamental ordering of blocks in relation to a
99 // structured construct:
100 // The blocks before H in the block order, are not in the construct
101 // The blocks at M(H) or later in the block order, are not in the construct
//    The blocks in a selection construct headed at H are in positions
//      [ Pos(H), Pos(M(H)) )
//    The blocks in a loop construct headed at H are in positions
//      [ Pos(H), Pos(CT(H)) )
//    The blocks in the continue construct for the loop headed at H are in
//      positions [ Pos(CT(H)), Pos(M(H)) )
107 //
108 // Schematically, for a selection construct headed by H, the blocks are in
109 // order from left to right:
110 //
111 // ...a-b-c H d-e-f M(H) n-o-p...
112 //
113 // where ...a-b-c: blocks before the selection construct
114 // where H and d-e-f: blocks in the selection construct
115 // where M(H) and n-o-p...: blocks after the selection construct
116 //
117 // Schematically, for a loop construct headed by H that is its own
118 // continue construct, the blocks in order from left to right:
119 //
120 // ...a-b-c H=CT(H) d-e-f M(H) n-o-p...
121 //
122 // where ...a-b-c: blocks before the loop
123 // where H is the continue construct; CT(H)=H, and the loop construct
124 // is *empty*
125 // where d-e-f... are other blocks in the continue construct
126 // where M(H) and n-o-p...: blocks after the continue construct
127 //
128 // Schematically, for a multi-block loop construct headed by H, there are
129 // blocks in order from left to right:
130 //
131 // ...a-b-c H d-e-f CT(H) j-k-l M(H) n-o-p...
132 //
133 // where ...a-b-c: blocks before the loop
134 // where H and d-e-f: blocks in the loop construct
135 // where CT(H) and j-k-l: blocks in the continue construct
136 // where M(H) and n-o-p...: blocks after the loop and continue
137 // constructs
138 //
139
140 namespace tint {
141 namespace reader {
142 namespace spirv {
143
144 namespace {
145
146 constexpr uint32_t kMaxVectorLen = 4;
147
148 // Gets the AST unary opcode for the given SPIR-V opcode, if any
149 // @param opcode SPIR-V opcode
150 // @param ast_unary_op return parameter
151 // @returns true if it was a unary operation
GetUnaryOp(SpvOp opcode,ast::UnaryOp * ast_unary_op)152 bool GetUnaryOp(SpvOp opcode, ast::UnaryOp* ast_unary_op) {
153 switch (opcode) {
154 case SpvOpSNegate:
155 case SpvOpFNegate:
156 *ast_unary_op = ast::UnaryOp::kNegation;
157 return true;
158 case SpvOpLogicalNot:
159 *ast_unary_op = ast::UnaryOp::kNot;
160 return true;
161 case SpvOpNot:
162 *ast_unary_op = ast::UnaryOp::kComplement;
163 return true;
164 default:
165 break;
166 }
167 return false;
168 }
169
170 /// Converts a SPIR-V opcode for a WGSL builtin function, if there is a
171 /// direct translation. Returns nullptr otherwise.
172 /// @returns the WGSL builtin function name for the given opcode, or nullptr.
GetUnaryBuiltInFunctionName(SpvOp opcode)173 const char* GetUnaryBuiltInFunctionName(SpvOp opcode) {
174 switch (opcode) {
175 case SpvOpAny:
176 return "any";
177 case SpvOpAll:
178 return "all";
179 case SpvOpIsNan:
180 return "isNan";
181 case SpvOpIsInf:
182 return "isInf";
183 case SpvOpTranspose:
184 return "transpose";
185 default:
186 break;
187 }
188 return nullptr;
189 }
190
191 // Converts a SPIR-V opcode to its corresponding AST binary opcode, if any
192 // @param opcode SPIR-V opcode
193 // @returns the AST binary op for the given opcode, or kNone
ConvertBinaryOp(SpvOp opcode)194 ast::BinaryOp ConvertBinaryOp(SpvOp opcode) {
195 switch (opcode) {
196 case SpvOpIAdd:
197 case SpvOpFAdd:
198 return ast::BinaryOp::kAdd;
199 case SpvOpISub:
200 case SpvOpFSub:
201 return ast::BinaryOp::kSubtract;
202 case SpvOpIMul:
203 case SpvOpFMul:
204 case SpvOpVectorTimesScalar:
205 case SpvOpMatrixTimesScalar:
206 case SpvOpVectorTimesMatrix:
207 case SpvOpMatrixTimesVector:
208 case SpvOpMatrixTimesMatrix:
209 return ast::BinaryOp::kMultiply;
210 case SpvOpUDiv:
211 case SpvOpSDiv:
212 case SpvOpFDiv:
213 return ast::BinaryOp::kDivide;
214 case SpvOpUMod:
215 case SpvOpSMod:
216 case SpvOpFRem:
217 return ast::BinaryOp::kModulo;
218 case SpvOpLogicalEqual:
219 case SpvOpIEqual:
220 case SpvOpFOrdEqual:
221 return ast::BinaryOp::kEqual;
222 case SpvOpLogicalNotEqual:
223 case SpvOpINotEqual:
224 case SpvOpFOrdNotEqual:
225 return ast::BinaryOp::kNotEqual;
226 case SpvOpBitwiseAnd:
227 return ast::BinaryOp::kAnd;
228 case SpvOpBitwiseOr:
229 return ast::BinaryOp::kOr;
230 case SpvOpBitwiseXor:
231 return ast::BinaryOp::kXor;
232 case SpvOpLogicalAnd:
233 return ast::BinaryOp::kAnd;
234 case SpvOpLogicalOr:
235 return ast::BinaryOp::kOr;
236 case SpvOpUGreaterThan:
237 case SpvOpSGreaterThan:
238 case SpvOpFOrdGreaterThan:
239 return ast::BinaryOp::kGreaterThan;
240 case SpvOpUGreaterThanEqual:
241 case SpvOpSGreaterThanEqual:
242 case SpvOpFOrdGreaterThanEqual:
243 return ast::BinaryOp::kGreaterThanEqual;
244 case SpvOpULessThan:
245 case SpvOpSLessThan:
246 case SpvOpFOrdLessThan:
247 return ast::BinaryOp::kLessThan;
248 case SpvOpULessThanEqual:
249 case SpvOpSLessThanEqual:
250 case SpvOpFOrdLessThanEqual:
251 return ast::BinaryOp::kLessThanEqual;
252 default:
253 break;
254 }
255 // It's not clear what OpSMod should map to.
256 // https://bugs.chromium.org/p/tint/issues/detail?id=52
257 return ast::BinaryOp::kNone;
258 }
259
260 // If the given SPIR-V opcode is a floating point unordered comparison,
261 // then returns the binary float comparison for which it is the negation.
262 // Othewrise returns BinaryOp::kNone.
263 // @param opcode SPIR-V opcode
264 // @returns operation corresponding to negated version of the SPIR-V opcode
NegatedFloatCompare(SpvOp opcode)265 ast::BinaryOp NegatedFloatCompare(SpvOp opcode) {
266 switch (opcode) {
267 case SpvOpFUnordEqual:
268 return ast::BinaryOp::kNotEqual;
269 case SpvOpFUnordNotEqual:
270 return ast::BinaryOp::kEqual;
271 case SpvOpFUnordLessThan:
272 return ast::BinaryOp::kGreaterThanEqual;
273 case SpvOpFUnordLessThanEqual:
274 return ast::BinaryOp::kGreaterThan;
275 case SpvOpFUnordGreaterThan:
276 return ast::BinaryOp::kLessThanEqual;
277 case SpvOpFUnordGreaterThanEqual:
278 return ast::BinaryOp::kLessThan;
279 default:
280 break;
281 }
282 return ast::BinaryOp::kNone;
283 }
284
285 // Returns the WGSL standard library function for the given
286 // GLSL.std.450 extended instruction operation code. Unknown
287 // and invalid opcodes map to the empty string.
288 // @returns the WGSL standard function name, or an empty string.
GetGlslStd450FuncName(uint32_t ext_opcode)289 std::string GetGlslStd450FuncName(uint32_t ext_opcode) {
290 switch (ext_opcode) {
291 case GLSLstd450FAbs:
292 case GLSLstd450SAbs:
293 return "abs";
294 case GLSLstd450Acos:
295 return "acos";
296 case GLSLstd450Asin:
297 return "asin";
298 case GLSLstd450Atan:
299 return "atan";
300 case GLSLstd450Atan2:
301 return "atan2";
302 case GLSLstd450Ceil:
303 return "ceil";
304 case GLSLstd450UClamp:
305 case GLSLstd450SClamp:
306 case GLSLstd450NClamp:
307 case GLSLstd450FClamp: // FClamp is less prescriptive about NaN operands
308 return "clamp";
309 case GLSLstd450Cos:
310 return "cos";
311 case GLSLstd450Cosh:
312 return "cosh";
313 case GLSLstd450Cross:
314 return "cross";
315 case GLSLstd450Distance:
316 return "distance";
317 case GLSLstd450Exp:
318 return "exp";
319 case GLSLstd450Exp2:
320 return "exp2";
321 case GLSLstd450FaceForward:
322 return "faceForward";
323 case GLSLstd450Floor:
324 return "floor";
325 case GLSLstd450Fma:
326 return "fma";
327 case GLSLstd450Fract:
328 return "fract";
329 case GLSLstd450InverseSqrt:
330 return "inverseSqrt";
331 case GLSLstd450Ldexp:
332 return "ldexp";
333 case GLSLstd450Length:
334 return "length";
335 case GLSLstd450Log:
336 return "log";
337 case GLSLstd450Log2:
338 return "log2";
339 case GLSLstd450NMax:
340 case GLSLstd450FMax: // FMax is less prescriptive about NaN operands
341 case GLSLstd450UMax:
342 case GLSLstd450SMax:
343 return "max";
344 case GLSLstd450NMin:
345 case GLSLstd450FMin: // FMin is less prescriptive about NaN operands
346 case GLSLstd450UMin:
347 case GLSLstd450SMin:
348 return "min";
349 case GLSLstd450FMix:
350 return "mix";
351 case GLSLstd450Normalize:
352 return "normalize";
353 case GLSLstd450PackSnorm4x8:
354 return "pack4x8snorm";
355 case GLSLstd450PackUnorm4x8:
356 return "pack4x8unorm";
357 case GLSLstd450PackSnorm2x16:
358 return "pack2x16snorm";
359 case GLSLstd450PackUnorm2x16:
360 return "pack2x16unorm";
361 case GLSLstd450PackHalf2x16:
362 return "pack2x16float";
363 case GLSLstd450Pow:
364 return "pow";
365 case GLSLstd450FSign:
366 return "sign";
367 case GLSLstd450Reflect:
368 return "reflect";
369 case GLSLstd450Refract:
370 return "refract";
371 case GLSLstd450Round:
372 case GLSLstd450RoundEven:
373 return "round";
374 case GLSLstd450Sin:
375 return "sin";
376 case GLSLstd450Sinh:
377 return "sinh";
378 case GLSLstd450SmoothStep:
379 return "smoothStep";
380 case GLSLstd450Sqrt:
381 return "sqrt";
382 case GLSLstd450Step:
383 return "step";
384 case GLSLstd450Tan:
385 return "tan";
386 case GLSLstd450Tanh:
387 return "tanh";
388 case GLSLstd450Trunc:
389 return "trunc";
390 case GLSLstd450UnpackSnorm4x8:
391 return "unpack4x8snorm";
392 case GLSLstd450UnpackUnorm4x8:
393 return "unpack4x8unorm";
394 case GLSLstd450UnpackSnorm2x16:
395 return "unpack2x16snorm";
396 case GLSLstd450UnpackUnorm2x16:
397 return "unpack2x16unorm";
398 case GLSLstd450UnpackHalf2x16:
399 return "unpack2x16float";
400
401 default:
402 // TODO(dneto) - The following are not implemented.
403 // They are grouped semantically, as in GLSL.std.450.h.
404
405 case GLSLstd450SSign:
406
407 case GLSLstd450Radians:
408 case GLSLstd450Degrees:
409 case GLSLstd450Asinh:
410 case GLSLstd450Acosh:
411 case GLSLstd450Atanh:
412
413 case GLSLstd450Determinant:
414 case GLSLstd450MatrixInverse:
415
416 case GLSLstd450Modf:
417 case GLSLstd450ModfStruct:
418 case GLSLstd450IMix:
419
420 case GLSLstd450Frexp:
421 case GLSLstd450FrexpStruct:
422
423 case GLSLstd450PackDouble2x32:
424 case GLSLstd450UnpackDouble2x32:
425
426 case GLSLstd450FindILsb:
427 case GLSLstd450FindSMsb:
428 case GLSLstd450FindUMsb:
429
430 case GLSLstd450InterpolateAtCentroid:
431 case GLSLstd450InterpolateAtSample:
432 case GLSLstd450InterpolateAtOffset:
433 break;
434 }
435 return "";
436 }
437
438 // Returns the WGSL standard library function intrinsic for the
439 // given instruction, or sem::IntrinsicType::kNone
GetIntrinsic(SpvOp opcode)440 sem::IntrinsicType GetIntrinsic(SpvOp opcode) {
441 switch (opcode) {
442 case SpvOpBitCount:
443 return sem::IntrinsicType::kCountOneBits;
444 case SpvOpBitReverse:
445 return sem::IntrinsicType::kReverseBits;
446 case SpvOpDot:
447 return sem::IntrinsicType::kDot;
448 case SpvOpDPdx:
449 return sem::IntrinsicType::kDpdx;
450 case SpvOpDPdy:
451 return sem::IntrinsicType::kDpdy;
452 case SpvOpFwidth:
453 return sem::IntrinsicType::kFwidth;
454 case SpvOpDPdxFine:
455 return sem::IntrinsicType::kDpdxFine;
456 case SpvOpDPdyFine:
457 return sem::IntrinsicType::kDpdyFine;
458 case SpvOpFwidthFine:
459 return sem::IntrinsicType::kFwidthFine;
460 case SpvOpDPdxCoarse:
461 return sem::IntrinsicType::kDpdxCoarse;
462 case SpvOpDPdyCoarse:
463 return sem::IntrinsicType::kDpdyCoarse;
464 case SpvOpFwidthCoarse:
465 return sem::IntrinsicType::kFwidthCoarse;
466 default:
467 break;
468 }
469 return sem::IntrinsicType::kNone;
470 }
471
472 // @param opcode a SPIR-V opcode
473 // @returns true if the given instruction is an image access instruction
474 // whose first input operand is an OpSampledImage value.
IsSampledImageAccess(SpvOp opcode)475 bool IsSampledImageAccess(SpvOp opcode) {
476 switch (opcode) {
477 case SpvOpImageSampleImplicitLod:
478 case SpvOpImageSampleExplicitLod:
479 case SpvOpImageSampleDrefImplicitLod:
480 case SpvOpImageSampleDrefExplicitLod:
481 // WGSL doesn't have *Proj* texturing; spirv reader emulates it.
482 case SpvOpImageSampleProjImplicitLod:
483 case SpvOpImageSampleProjExplicitLod:
484 case SpvOpImageSampleProjDrefImplicitLod:
485 case SpvOpImageSampleProjDrefExplicitLod:
486 case SpvOpImageGather:
487 case SpvOpImageDrefGather:
488 case SpvOpImageQueryLod:
489 return true;
490 default:
491 break;
492 }
493 return false;
494 }
495
496 // @param opcode a SPIR-V opcode
497 // @returns true if the given instruction is an image sampling operation.
IsImageSampling(SpvOp opcode)498 bool IsImageSampling(SpvOp opcode) {
499 switch (opcode) {
500 case SpvOpImageSampleImplicitLod:
501 case SpvOpImageSampleExplicitLod:
502 case SpvOpImageSampleDrefImplicitLod:
503 case SpvOpImageSampleDrefExplicitLod:
504 // WGSL doesn't have *Proj* texturing; spirv reader emulates it.
505 case SpvOpImageSampleProjImplicitLod:
506 case SpvOpImageSampleProjExplicitLod:
507 case SpvOpImageSampleProjDrefImplicitLod:
508 case SpvOpImageSampleProjDrefExplicitLod:
509 return true;
510 default:
511 break;
512 }
513 return false;
514 }
515
516 // @param opcode a SPIR-V opcode
517 // @returns true if the given instruction is an image access instruction
518 // whose first input operand is an OpImage value.
IsRawImageAccess(SpvOp opcode)519 bool IsRawImageAccess(SpvOp opcode) {
520 switch (opcode) {
521 case SpvOpImageRead:
522 case SpvOpImageWrite:
523 case SpvOpImageFetch:
524 return true;
525 default:
526 break;
527 }
528 return false;
529 }
530
531 // @param opcode a SPIR-V opcode
532 // @returns true if the given instruction is an image query instruction
IsImageQuery(SpvOp opcode)533 bool IsImageQuery(SpvOp opcode) {
534 switch (opcode) {
535 case SpvOpImageQuerySize:
536 case SpvOpImageQuerySizeLod:
537 case SpvOpImageQueryLevels:
538 case SpvOpImageQuerySamples:
539 case SpvOpImageQueryLod:
540 return true;
541 default:
542 break;
543 }
544 return false;
545 }
546
547 // @returns the merge block ID for the given basic block, or 0 if there is none.
MergeFor(const spvtools::opt::BasicBlock & bb)548 uint32_t MergeFor(const spvtools::opt::BasicBlock& bb) {
549 // Get the OpSelectionMerge or OpLoopMerge instruction, if any.
550 auto* inst = bb.GetMergeInst();
551 return inst == nullptr ? 0 : inst->GetSingleWordInOperand(0);
552 }
553
554 // @returns the continue target ID for the given basic block, or 0 if there
555 // is none.
ContinueTargetFor(const spvtools::opt::BasicBlock & bb)556 uint32_t ContinueTargetFor(const spvtools::opt::BasicBlock& bb) {
557 // Get the OpLoopMerge instruction, if any.
558 auto* inst = bb.GetLoopMergeInst();
559 return inst == nullptr ? 0 : inst->GetSingleWordInOperand(1);
560 }
561
562 // A structured traverser produces the reverse structured post-order of the
563 // CFG of a function. The blocks traversed are the transitive closure (minimum
564 // fixed point) of:
565 // - the entry block
566 // - a block reached by a branch from another block in the set
567 // - a block mentioned as a merge block or continue target for a block in the
568 // set
569 class StructuredTraverser {
570 public:
StructuredTraverser(const spvtools::opt::Function & function)571 explicit StructuredTraverser(const spvtools::opt::Function& function)
572 : function_(function) {
573 for (auto& block : function_) {
574 id_to_block_[block.id()] = █
575 }
576 }
577
578 // Returns the reverse postorder traversal of the CFG, where:
579 // - a merge block always follows its associated constructs
580 // - a continue target always follows the associated loop construct, if any
581 // @returns the IDs of blocks in reverse structured post order
ReverseStructuredPostOrder()582 std::vector<uint32_t> ReverseStructuredPostOrder() {
583 visit_order_.clear();
584 visited_.clear();
585 VisitBackward(function_.entry()->id());
586
587 std::vector<uint32_t> order(visit_order_.rbegin(), visit_order_.rend());
588 return order;
589 }
590
591 private:
592 // Executes a depth first search of the CFG, where right after we visit a
593 // header, we will visit its merge block, then its continue target (if any).
594 // Also records the post order ordering.
VisitBackward(uint32_t id)595 void VisitBackward(uint32_t id) {
596 if (id == 0)
597 return;
598 if (visited_.count(id))
599 return;
600 visited_.insert(id);
601
602 const spvtools::opt::BasicBlock* bb =
603 id_to_block_[id]; // non-null for valid modules
604 VisitBackward(MergeFor(*bb));
605 VisitBackward(ContinueTargetFor(*bb));
606
607 // Visit successors. We will naturally skip the continue target and merge
608 // blocks.
609 auto* terminator = bb->terminator();
610 auto opcode = terminator->opcode();
611 if (opcode == SpvOpBranchConditional) {
612 // Visit the false branch, then the true branch, to make them come
613 // out in the natural order for an "if".
614 VisitBackward(terminator->GetSingleWordInOperand(2));
615 VisitBackward(terminator->GetSingleWordInOperand(1));
616 } else if (opcode == SpvOpBranch) {
617 VisitBackward(terminator->GetSingleWordInOperand(0));
618 } else if (opcode == SpvOpSwitch) {
619 // TODO(dneto): Consider visiting the labels in literal-value order.
620 std::vector<uint32_t> successors;
621 bb->ForEachSuccessorLabel([&successors](const uint32_t succ_id) {
622 successors.push_back(succ_id);
623 });
624 for (auto succ_id : successors) {
625 VisitBackward(succ_id);
626 }
627 }
628
629 visit_order_.push_back(id);
630 }
631
632 const spvtools::opt::Function& function_;
633 std::unordered_map<uint32_t, const spvtools::opt::BasicBlock*> id_to_block_;
634 std::vector<uint32_t> visit_order_;
635 std::unordered_set<uint32_t> visited_;
636 };
637
638 /// A StatementBuilder for ast::SwitchStatement
639 /// @see StatementBuilder
640 struct SwitchStatementBuilder
641 : public Castable<SwitchStatementBuilder, StatementBuilder> {
642 /// Constructor
643 /// @param cond the switch statement condition
SwitchStatementBuildertint::reader::spirv::__anone3e1bc1b0111::SwitchStatementBuilder644 explicit SwitchStatementBuilder(const ast::Expression* cond)
645 : condition(cond) {}
646
647 /// @param builder the program builder
648 /// @returns the built ast::SwitchStatement
Buildtint::reader::spirv::__anone3e1bc1b0111::SwitchStatementBuilder649 const ast::SwitchStatement* Build(ProgramBuilder* builder) const override {
650 // We've listed cases in reverse order in the switch statement.
651 // Reorder them to match the presentation order in WGSL.
652 auto reversed_cases = cases;
653 std::reverse(reversed_cases.begin(), reversed_cases.end());
654
655 return builder->create<ast::SwitchStatement>(Source{}, condition,
656 reversed_cases);
657 }
658
659 /// Switch statement condition
660 const ast::Expression* const condition;
661 /// Switch statement cases
662 ast::CaseStatementList cases;
663 };
664
665 /// A StatementBuilder for ast::IfStatement
666 /// @see StatementBuilder
667 struct IfStatementBuilder
668 : public Castable<IfStatementBuilder, StatementBuilder> {
669 /// Constructor
670 /// @param c the if-statement condition
IfStatementBuildertint::reader::spirv::__anone3e1bc1b0111::IfStatementBuilder671 explicit IfStatementBuilder(const ast::Expression* c) : cond(c) {}
672
673 /// @param builder the program builder
674 /// @returns the built ast::IfStatement
Buildtint::reader::spirv::__anone3e1bc1b0111::IfStatementBuilder675 const ast::IfStatement* Build(ProgramBuilder* builder) const override {
676 return builder->create<ast::IfStatement>(Source{}, cond, body, else_stmts);
677 }
678
679 /// If-statement condition
680 const ast::Expression* const cond;
681 /// If-statement block body
682 const ast::BlockStatement* body = nullptr;
683 /// Optional if-statement else statements
684 ast::ElseStatementList else_stmts;
685 };
686
687 /// A StatementBuilder for ast::LoopStatement
688 /// @see StatementBuilder
689 struct LoopStatementBuilder
690 : public Castable<LoopStatementBuilder, StatementBuilder> {
691 /// @param builder the program builder
692 /// @returns the built ast::LoopStatement
Buildtint::reader::spirv::__anone3e1bc1b0111::LoopStatementBuilder693 ast::LoopStatement* Build(ProgramBuilder* builder) const override {
694 return builder->create<ast::LoopStatement>(Source{}, body, continuing);
695 }
696
697 /// Loop-statement block body
698 const ast::BlockStatement* body = nullptr;
699 /// Loop-statement continuing body
700 /// @note the mutable keyword here is required as all non-StatementBuilders
701 /// `ast::Node`s are immutable and are referenced with `const` pointers.
702 /// StatementBuilders however exist to provide mutable state while the
703 /// FunctionEmitter is building the function. All StatementBuilders are
704 /// replaced with immutable AST nodes when Finalize() is called.
705 mutable const ast::BlockStatement* continuing = nullptr;
706 };
707
708 /// @param decos a list of parsed decorations
709 /// @returns true if the decorations include a SampleMask builtin
HasBuiltinSampleMask(const ast::DecorationList & decos)710 bool HasBuiltinSampleMask(const ast::DecorationList& decos) {
711 if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(decos)) {
712 return builtin->builtin == ast::Builtin::kSampleMask;
713 }
714 return false;
715 }
716
717 } // namespace
718
/// Creates the bookkeeping record for a SPIR-V basic block, capturing the
/// block pointer and its result ID.
/// @param bb the basic block to track
BlockInfo::BlockInfo(const spvtools::opt::BasicBlock& bb)
    : basic_block(&bb), id(bb.id()) {}

BlockInfo::~BlockInfo() = default;
723
/// Creates the bookkeeping record for a SPIR-V result ID definition.
/// @param def_inst the instruction that defines the value
/// @param the_block_pos position of the defining block in the block order
/// @param the_index an ordering index for this definition
DefInfo::DefInfo(const spvtools::opt::Instruction& def_inst,
                 uint32_t the_block_pos,
                 size_t the_index)
    : inst(def_inst), block_pos(the_block_pos), index(the_index) {}

DefInfo::~DefInfo() = default;
730
// StatementBuilders are transient scaffolding used only while emitting a
// function; they are replaced by real AST nodes in Finalize() and are never
// cloned, so Clone() deliberately returns nullptr.
ast::Node* StatementBuilder::Clone(CloneContext*) const {
  return nullptr;
}
734
/// Creates a FunctionEmitter for the given SPIR-V function.
/// @param pi the parent parser implementation; must outlive this emitter,
/// which caches references to its managers and builder
/// @param function the SPIR-V function to translate
/// @param ep_info entry point information, or nullptr when the function is
/// not being emitted as an entry point
FunctionEmitter::FunctionEmitter(ParserImpl* pi,
                                 const spvtools::opt::Function& function,
                                 const EntryPointInfo* ep_info)
    : parser_impl_(*pi),
      ty_(pi->type_manager()),
      builder_(pi->builder()),
      ir_context_(*(pi->ir_context())),
      def_use_mgr_(ir_context_.get_def_use_mgr()),
      constant_mgr_(ir_context_.get_constant_mgr()),
      type_mgr_(ir_context_.get_type_mgr()),
      fail_stream_(pi->fail_stream()),
      namer_(pi->namer()),
      function_(function),
      sample_mask_in_id(0u),
      sample_mask_out_id(0u),
      ep_info_(ep_info) {
  // Seed the statement stack with a single root block; Emit() fills it in.
  PushNewStatementBlock(nullptr, 0, nullptr);
}
753
/// Creates a FunctionEmitter for a function that is not an entry point.
/// @param pi the parent parser implementation; must outlive this emitter
/// @param function the SPIR-V function to translate
FunctionEmitter::FunctionEmitter(ParserImpl* pi,
                                 const spvtools::opt::Function& function)
    : FunctionEmitter(pi, function, nullptr) {}
757
FunctionEmitter(FunctionEmitter && other)758 FunctionEmitter::FunctionEmitter(FunctionEmitter&& other)
759 : parser_impl_(other.parser_impl_),
760 ty_(other.ty_),
761 builder_(other.builder_),
762 ir_context_(other.ir_context_),
763 def_use_mgr_(ir_context_.get_def_use_mgr()),
764 constant_mgr_(ir_context_.get_constant_mgr()),
765 type_mgr_(ir_context_.get_type_mgr()),
766 fail_stream_(other.fail_stream_),
767 namer_(other.namer_),
768 function_(other.function_),
769 sample_mask_in_id(other.sample_mask_out_id),
770 sample_mask_out_id(other.sample_mask_in_id),
771 ep_info_(other.ep_info_) {
772 other.statements_stack_.clear();
773 PushNewStatementBlock(nullptr, 0, nullptr);
774 }
775
// Default destructor: all owned state is managed by member destructors.
FunctionEmitter::~FunctionEmitter() = default;
777
/// Creates a statement block for accumulating statements of a construct.
/// @param construct the construct this statement block belongs to (may be
/// null for the function's root block)
/// @param end_id the ID of the block at which this statement block ends
/// @param completion_action callback run over the accumulated statements
/// when the block is finalized (may be null)
FunctionEmitter::StatementBlock::StatementBlock(
    const Construct* construct,
    uint32_t end_id,
    FunctionEmitter::CompletionAction completion_action)
    : construct_(construct),
      end_id_(end_id),
      completion_action_(completion_action) {}

FunctionEmitter::StatementBlock::StatementBlock(StatementBlock&& other) =
    default;

FunctionEmitter::StatementBlock::~StatementBlock() = default;
790
Finalize(ProgramBuilder * pb)791 void FunctionEmitter::StatementBlock::Finalize(ProgramBuilder* pb) {
792 TINT_ASSERT(Reader, !finalized_ /* Finalize() must only be called once */);
793
794 for (size_t i = 0; i < statements_.size(); i++) {
795 if (auto* sb = statements_[i]->As<StatementBuilder>()) {
796 statements_[i] = sb->Build(pb);
797 }
798 }
799
800 if (completion_action_ != nullptr) {
801 completion_action_(statements_);
802 }
803
804 finalized_ = true;
805 }
806
Add(const ast::Statement * statement)807 void FunctionEmitter::StatementBlock::Add(const ast::Statement* statement) {
808 TINT_ASSERT(Reader,
809 !finalized_ /* Add() must not be called after Finalize() */);
810 statements_.emplace_back(statement);
811 }
812
PushNewStatementBlock(const Construct * construct,uint32_t end_id,CompletionAction action)813 void FunctionEmitter::PushNewStatementBlock(const Construct* construct,
814 uint32_t end_id,
815 CompletionAction action) {
816 statements_stack_.emplace_back(StatementBlock{construct, end_id, action});
817 }
818
PushGuard(const std::string & guard_name,uint32_t end_id)819 void FunctionEmitter::PushGuard(const std::string& guard_name,
820 uint32_t end_id) {
821 TINT_ASSERT(Reader, !statements_stack_.empty());
822 TINT_ASSERT(Reader, !guard_name.empty());
823 // Guard control flow by the guard variable. Introduce a new
824 // if-selection with a then-clause ending at the same block
825 // as the statement block at the top of the stack.
826 const auto& top = statements_stack_.back();
827
828 auto* cond = create<ast::IdentifierExpression>(
829 Source{}, builder_.Symbols().Register(guard_name));
830 auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
831
832 PushNewStatementBlock(
833 top.GetConstruct(), end_id, [=](const ast::StatementList& stmts) {
834 builder->body = create<ast::BlockStatement>(Source{}, stmts);
835 });
836 }
837
PushTrueGuard(uint32_t end_id)838 void FunctionEmitter::PushTrueGuard(uint32_t end_id) {
839 TINT_ASSERT(Reader, !statements_stack_.empty());
840 const auto& top = statements_stack_.back();
841
842 auto* cond = MakeTrue(Source{});
843 auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
844
845 PushNewStatementBlock(
846 top.GetConstruct(), end_id, [=](const ast::StatementList& stmts) {
847 builder->body = create<ast::BlockStatement>(Source{}, stmts);
848 });
849 }
850
ast_body()851 const ast::StatementList FunctionEmitter::ast_body() {
852 TINT_ASSERT(Reader, !statements_stack_.empty());
853 auto& entry = statements_stack_[0];
854 entry.Finalize(&builder_);
855 return entry.GetStatements();
856 }
857
// Appends a statement to the top of the statement stack, ignoring null.
// @param statement the statement to append, or nullptr
// @returns the given statement (possibly nullptr), for chaining
const ast::Statement* FunctionEmitter::AddStatement(
    const ast::Statement* statement) {
  TINT_ASSERT(Reader, !statements_stack_.empty());
  if (statement == nullptr) {
    return nullptr;
  }
  statements_stack_.back().Add(statement);
  return statement;
}
866
LastStatement()867 const ast::Statement* FunctionEmitter::LastStatement() {
868 TINT_ASSERT(Reader, !statements_stack_.empty());
869 auto& statement_list = statements_stack_.back().GetStatements();
870 TINT_ASSERT(Reader, !statement_list.empty());
871 return statement_list.back();
872 }
873
Emit()874 bool FunctionEmitter::Emit() {
875 if (failed()) {
876 return false;
877 }
878 // We only care about functions with bodies.
879 if (function_.cbegin() == function_.cend()) {
880 return true;
881 }
882
883 // The function declaration, corresponding to how it's written in SPIR-V,
884 // and without regard to whether it's an entry point.
885 FunctionDeclaration decl;
886 if (!ParseFunctionDeclaration(&decl)) {
887 return false;
888 }
889
890 bool make_body_function = true;
891 if (ep_info_) {
892 TINT_ASSERT(Reader, !ep_info_->inner_name.empty());
893 if (ep_info_->owns_inner_implementation) {
894 // This is an entry point, and we want to emit it as a wrapper around
895 // an implementation function.
896 decl.name = ep_info_->inner_name;
897 } else {
898 // This is a second entry point that shares an inner implementation
899 // function.
900 make_body_function = false;
901 }
902 }
903
904 if (make_body_function) {
905 auto* body = MakeFunctionBody();
906 if (!body) {
907 return false;
908 }
909
910 builder_.AST().AddFunction(create<ast::Function>(
911 decl.source, builder_.Symbols().Register(decl.name),
912 std::move(decl.params), decl.return_type->Build(builder_), body,
913 std::move(decl.decorations), ast::DecorationList{}));
914 }
915
916 if (ep_info_ && !ep_info_->inner_name.empty()) {
917 return EmitEntryPointAsWrapper();
918 }
919
920 return success();
921 }
922
MakeFunctionBody()923 const ast::BlockStatement* FunctionEmitter::MakeFunctionBody() {
924 TINT_ASSERT(Reader, statements_stack_.size() == 1);
925
926 if (!EmitBody()) {
927 return nullptr;
928 }
929
930 // Set the body of the AST function node.
931 if (statements_stack_.size() != 1) {
932 Fail() << "internal error: statement-list stack should have 1 "
933 "element but has "
934 << statements_stack_.size();
935 return nullptr;
936 }
937
938 statements_stack_[0].Finalize(&builder_);
939 auto& statements = statements_stack_[0].GetStatements();
940 auto* body = create<ast::BlockStatement>(Source{}, statements);
941
942 // Maintain the invariant by repopulating the one and only element.
943 statements_stack_.clear();
944 PushNewStatementBlock(constructs_[0].get(), 0, nullptr);
945
946 return body;
947 }
948
// Emits the wrapper-function plumbing for one (sub)object of a pipeline
// input variable: recursively flattens matrices, arrays, and structures
// into scalar/vector leaves; for each leaf, adds a wrapper parameter and a
// body statement that copies the parameter into the corresponding
// module-scope private variable.
//
// @param var_name name of the module-scope private variable being populated
// @param var_type store type of that variable (used to re-derive indexing)
// @param decos decorations for the current (sub)object; the location
//   decoration is advanced in place as leaves are emitted
// @param index_prefix indices from the variable root down to |tip_type|
// @param tip_type type of the current (sub)object being flattened
// @param forced_param_type parameter type to use when the leaf is a builtin
// @param params output: wrapper-function parameter list
// @param statements output: wrapper-function body statements
// @returns false if emission failed
bool FunctionEmitter::EmitPipelineInput(std::string var_name,
                                        const Type* var_type,
                                        ast::DecorationList* decos,
                                        std::vector<int> index_prefix,
                                        const Type* tip_type,
                                        const Type* forced_param_type,
                                        ast::VariableList* params,
                                        ast::StatementList* statements) {
  // TODO(dneto): Handle structs where the locations are annotated on members.
  tip_type = tip_type->UnwrapAlias();
  if (auto* ref_type = tip_type->As<Reference>()) {
    tip_type = ref_type->type;
  }

  // Recursively flatten matrices, arrays, and structures.
  if (auto* matrix_type = tip_type->As<Matrix>()) {
    // Flatten a matrix into its column vectors.
    index_prefix.push_back(0);
    const auto num_columns = static_cast<int>(matrix_type->columns);
    const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
    for (int col = 0; col < num_columns; col++) {
      index_prefix.back() = col;
      if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, vec_ty,
                             forced_param_type, params, statements)) {
        return false;
      }
    }
    return success();
  } else if (auto* array_type = tip_type->As<Array>()) {
    // Flatten an array into its elements. A size of 0 denotes a
    // runtime-sized array, which is invalid for pipeline IO.
    if (array_type->size == 0) {
      return Fail() << "runtime-size array not allowed on pipeline IO";
    }
    index_prefix.push_back(0);
    const Type* elem_ty = array_type->type;
    for (int i = 0; i < static_cast<int>(array_type->size); i++) {
      index_prefix.back() = i;
      if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, elem_ty,
                             forced_param_type, params, statements)) {
        return false;
      }
    }
    return success();
  } else if (auto* struct_type = tip_type->As<Struct>()) {
    // Flatten a structure into its members, converting each member's
    // pipeline decorations as we go.
    const auto& members = struct_type->members;
    index_prefix.push_back(0);
    for (int i = 0; i < static_cast<int>(members.size()); ++i) {
      index_prefix.back() = i;
      ast::DecorationList member_decos(*decos);
      if (!parser_impl_.ConvertPipelineDecorations(
              struct_type,
              parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
              &member_decos)) {
        return false;
      }
      if (!EmitPipelineInput(var_name, var_type, &member_decos, index_prefix,
                             members[i], forced_param_type, params,
                             statements)) {
        return false;
      }
      // Copy the location as updated by nested expansion of the member.
      parser_impl_.SetLocation(decos, GetLocation(member_decos));
    }
    return success();
  }

  // Base case: a scalar or vector leaf.
  const bool is_builtin = ast::HasDecoration<ast::BuiltinDecoration>(*decos);

  // Builtins must use the WGSL-mandated parameter type, which may differ
  // from the SPIR-V store type.
  const Type* param_type = is_builtin ? forced_param_type : tip_type;

  const auto param_name = namer_.MakeDerivedName(var_name + "_param");
  // Create the parameter.
  // TODO(dneto): Note: If the parameter has non-location decorations,
  // then those decoration AST nodes will be reused between multiple elements
  // of a matrix, array, or structure. Normally that's disallowed but currently
  // the SPIR-V reader will make duplicates when the entire AST is cloned
  // at the top level of the SPIR-V reader flow. Consider rewriting this
  // to avoid this node-sharing.
  params->push_back(
      builder_.Param(param_name, param_type->Build(builder_), *decos));

  // Add a body statement to copy the parameter to the corresponding private
  // variable.
  const ast::Expression* param_value = builder_.Expr(param_name);
  const ast::Expression* store_dest = builder_.Expr(var_name);

  // Index into the LHS as needed, replaying |index_prefix| against the
  // variable's store type to build the access chain.
  auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
  for (auto index : index_prefix) {
    if (auto* matrix_type = current_type->As<Matrix>()) {
      store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index));
      current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
    } else if (auto* array_type = current_type->As<Array>()) {
      store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index));
      current_type = array_type->type->UnwrapAlias();
    } else if (auto* struct_type = current_type->As<Struct>()) {
      store_dest = builder_.MemberAccessor(
          store_dest,
          builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
      current_type = struct_type->members[index];
    }
  }

  if (is_builtin && (tip_type != forced_param_type)) {
    // The parameter will have the WGSL type, but we need bitcast to
    // the variable store type.
    param_value =
        create<ast::BitcastExpression>(tip_type->Build(builder_), param_value);
  }

  statements->push_back(builder_.Assign(store_dest, param_value));

  // Increment the location attribute, in case more parameters will follow.
  IncrementLocation(decos);

  return success();
}
1064
IncrementLocation(ast::DecorationList * decos)1065 void FunctionEmitter::IncrementLocation(ast::DecorationList* decos) {
1066 for (auto*& deco : *decos) {
1067 if (auto* loc_deco = deco->As<ast::LocationDecoration>()) {
1068 // Replace this location decoration with a new one with one higher index.
1069 // The old one doesn't leak because it's kept in the builder's AST node
1070 // list.
1071 deco = builder_.Location(loc_deco->source, loc_deco->value + 1);
1072 }
1073 }
1074 }
1075
// Returns the first location decoration in |decos|, or nullptr if there is
// none.
const ast::Decoration* FunctionEmitter::GetLocation(
    const ast::DecorationList& decos) {
  auto it = std::find_if(
      decos.begin(), decos.end(), [](const ast::Decoration* deco) {
        return deco->Is<ast::LocationDecoration>();
      });
  return (it == decos.end()) ? nullptr : *it;
}
1085
// Emits the wrapper-function plumbing for one (sub)object of a pipeline
// output variable: recursively flattens matrices, arrays, and structures
// into scalar/vector leaves; for each leaf, adds a member to the wrapper's
// return structure and an expression that reads the leaf from the
// module-scope private variable.
//
// @param var_name name of the module-scope private variable being read
// @param var_type store type of that variable (used to re-derive indexing)
// @param decos decorations for the current (sub)object; the location
//   decoration is advanced in place as leaves are emitted
// @param index_prefix indices from the variable root down to |tip_type|
// @param tip_type type of the current (sub)object being flattened
// @param forced_member_type member type to use when the leaf is a builtin
// @param return_members output: members of the wrapper's return struct
// @param return_exprs output: value expressions for the return struct
// @returns false if emission failed
bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
                                         const Type* var_type,
                                         ast::DecorationList* decos,
                                         std::vector<int> index_prefix,
                                         const Type* tip_type,
                                         const Type* forced_member_type,
                                         ast::StructMemberList* return_members,
                                         ast::ExpressionList* return_exprs) {
  tip_type = tip_type->UnwrapAlias();
  if (auto* ref_type = tip_type->As<Reference>()) {
    tip_type = ref_type->type;
  }

  // Recursively flatten matrices, arrays, and structures.
  if (auto* matrix_type = tip_type->As<Matrix>()) {
    // Flatten a matrix into its column vectors.
    index_prefix.push_back(0);
    const auto num_columns = static_cast<int>(matrix_type->columns);
    const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
    for (int col = 0; col < num_columns; col++) {
      index_prefix.back() = col;
      if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, vec_ty,
                              forced_member_type, return_members,
                              return_exprs)) {
        return false;
      }
    }
    return success();
  } else if (auto* array_type = tip_type->As<Array>()) {
    // Flatten an array into its elements. A size of 0 denotes a
    // runtime-sized array, which is invalid for pipeline IO.
    if (array_type->size == 0) {
      return Fail() << "runtime-size array not allowed on pipeline IO";
    }
    index_prefix.push_back(0);
    const Type* elem_ty = array_type->type;
    for (int i = 0; i < static_cast<int>(array_type->size); i++) {
      index_prefix.back() = i;
      if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, elem_ty,
                              forced_member_type, return_members,
                              return_exprs)) {
        return false;
      }
    }
    return success();
  } else if (auto* struct_type = tip_type->As<Struct>()) {
    // Flatten a structure into its members, converting each member's
    // pipeline decorations as we go.
    const auto& members = struct_type->members;
    index_prefix.push_back(0);
    for (int i = 0; i < static_cast<int>(members.size()); ++i) {
      index_prefix.back() = i;
      ast::DecorationList member_decos(*decos);
      if (!parser_impl_.ConvertPipelineDecorations(
              struct_type,
              parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
              &member_decos)) {
        return false;
      }
      if (!EmitPipelineOutput(var_name, var_type, &member_decos, index_prefix,
                              members[i], forced_member_type, return_members,
                              return_exprs)) {
        return false;
      }
      // Copy the location as updated by nested expansion of the member.
      parser_impl_.SetLocation(decos, GetLocation(member_decos));
    }
    return success();
  }

  // Base case: a scalar or vector leaf.
  const bool is_builtin = ast::HasDecoration<ast::BuiltinDecoration>(*decos);

  // Builtins must use the WGSL-mandated member type, which may differ from
  // the SPIR-V store type.
  const Type* member_type = is_builtin ? forced_member_type : tip_type;
  // Derive the member name directly from the variable name. They can't
  // collide.
  const auto member_name = namer_.MakeDerivedName(var_name);
  // Create the member.
  // TODO(dneto): Note: If the parameter has non-location decorations,
  // then those decoration AST nodes will be reused between multiple elements
  // of a matrix, array, or structure. Normally that's disallowed but currently
  // the SPIR-V reader will make duplicates when the entire AST is cloned
  // at the top level of the SPIR-V reader flow. Consider rewriting this
  // to avoid this node-sharing.
  return_members->push_back(
      builder_.Member(member_name, member_type->Build(builder_), *decos));

  // Create an expression to evaluate the part of the variable indexed by
  // the index_prefix.
  const ast::Expression* load_source = builder_.Expr(var_name);

  // Index into the variable as needed to pick out the flattened member,
  // replaying |index_prefix| against the variable's store type.
  auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
  for (auto index : index_prefix) {
    if (auto* matrix_type = current_type->As<Matrix>()) {
      load_source = builder_.IndexAccessor(load_source, builder_.Expr(index));
      current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
    } else if (auto* array_type = current_type->As<Array>()) {
      load_source = builder_.IndexAccessor(load_source, builder_.Expr(index));
      current_type = array_type->type->UnwrapAlias();
    } else if (auto* struct_type = current_type->As<Struct>()) {
      load_source = builder_.MemberAccessor(
          load_source,
          builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
      current_type = struct_type->members[index];
    }
  }

  if (is_builtin && (tip_type != forced_member_type)) {
    // The member will have the WGSL type, but we need bitcast to
    // the variable store type.
    load_source = create<ast::BitcastExpression>(
        forced_member_type->Build(builder_), load_source);
  }
  return_exprs->push_back(load_source);

  // Increment the location attribute, in case more parameters will follow.
  IncrementLocation(decos);

  return success();
}
1201
// Emits the entry-point function as a thin wrapper around the inner
// implementation function.
//
// Pipeline inputs become wrapper parameters whose values are copied into
// the corresponding module-scope private variables; the inner function is
// then called with no arguments; finally, pipeline outputs are read from
// their private variables into a structure returned by the wrapper.
// @returns false if emission failed
bool FunctionEmitter::EmitEntryPointAsWrapper() {
  Source source;

  // The statements in the body.
  ast::StatementList stmts;

  FunctionDeclaration decl;
  decl.source = source;
  decl.name = ep_info_->name;
  const ast::Type* return_type = nullptr;  // Populated below.

  // Pipeline inputs become parameters to the wrapper function, and
  // their values are saved into the corresponding private variables that
  // have already been created.
  for (uint32_t var_id : ep_info_->inputs) {
    const auto* var = def_use_mgr_->GetDef(var_id);
    TINT_ASSERT(Reader, var != nullptr);
    TINT_ASSERT(Reader, var->opcode() == SpvOpVariable);
    auto* store_type = GetVariableStoreType(*var);
    auto* forced_param_type = store_type;
    ast::DecorationList param_decos;
    if (!parser_impl_.ConvertDecorationsForVariable(var_id, &forced_param_type,
                                                    &param_decos, true)) {
      // This occurs, and is not an error, for the PointSize builtin.
      if (!success()) {
        // But exit early if an error was logged.
        return false;
      }
      continue;
    }

    // We don't have to handle initializers because in Vulkan SPIR-V, Input
    // variables must not have them.

    const auto var_name = namer_.GetName(var_id);

    bool ok = true;
    if (HasBuiltinSampleMask(param_decos)) {
      // In Vulkan SPIR-V, the sample mask is an array. In WGSL it's a scalar.
      // Use the first element only.
      auto* sample_mask_array_type =
          store_type->UnwrapRef()->UnwrapAlias()->As<Array>();
      TINT_ASSERT(Reader, sample_mask_array_type);
      ok = EmitPipelineInput(var_name, store_type, &param_decos, {0},
                             sample_mask_array_type->type, forced_param_type,
                             &(decl.params), &stmts);
    } else {
      // The normal path.
      ok = EmitPipelineInput(var_name, store_type, &param_decos, {}, store_type,
                             forced_param_type, &(decl.params), &stmts);
    }
    if (!ok) {
      return false;
    }
  }

  // Call the inner function. It has no parameters.
  stmts.push_back(create<ast::CallStatement>(
      source,
      create<ast::CallExpression>(
          source,
          create<ast::IdentifierExpression>(
              source, builder_.Symbols().Register(ep_info_->inner_name)),
          ast::ExpressionList{})));

  // Pipeline outputs are mapped to the return value.
  if (ep_info_->outputs.empty()) {
    // There is nothing to return.
    return_type = ty_.Void()->Build(builder_);
  } else {
    // Pipeline outputs are converted to a structure that is written
    // to just before returning.

    const auto return_struct_name =
        namer_.MakeDerivedName(ep_info_->name + "_out");
    const auto return_struct_sym =
        builder_.Symbols().Register(return_struct_name);

    // Define the structure.
    std::vector<const ast::StructMember*> return_members;
    ast::ExpressionList return_exprs;

    const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();

    for (uint32_t var_id : ep_info_->outputs) {
      if (var_id == builtin_position_info.per_vertex_var_id) {
        // The SPIR-V gl_PerVertex variable has already been remapped to
        // a gl_Position variable. Substitute the type.
        const Type* param_type = ty_.Vector(ty_.F32(), 4);
        ast::DecorationList out_decos{
            create<ast::BuiltinDecoration>(source, ast::Builtin::kPosition)};

        const auto var_name = namer_.GetName(var_id);
        return_members.push_back(
            builder_.Member(var_name, param_type->Build(builder_), out_decos));
        return_exprs.push_back(builder_.Expr(var_name));

      } else {
        const auto* var = def_use_mgr_->GetDef(var_id);
        TINT_ASSERT(Reader, var != nullptr);
        TINT_ASSERT(Reader, var->opcode() == SpvOpVariable);
        const Type* store_type = GetVariableStoreType(*var);
        const Type* forced_member_type = store_type;
        ast::DecorationList out_decos;
        if (!parser_impl_.ConvertDecorationsForVariable(
                var_id, &forced_member_type, &out_decos, true)) {
          // This occurs, and is not an error, for the PointSize builtin.
          if (!success()) {
            // But exit early if an error was logged.
            return false;
          }
          continue;
        }

        const auto var_name = namer_.GetName(var_id);
        bool ok = true;
        if (HasBuiltinSampleMask(out_decos)) {
          // In Vulkan SPIR-V, the sample mask is an array. In WGSL it's a
          // scalar. Use the first element only.
          auto* sample_mask_array_type =
              store_type->UnwrapRef()->UnwrapAlias()->As<Array>();
          TINT_ASSERT(Reader, sample_mask_array_type);
          ok = EmitPipelineOutput(var_name, store_type, &out_decos, {0},
                                  sample_mask_array_type->type,
                                  forced_member_type, &return_members,
                                  &return_exprs);
        } else {
          // The normal path.
          ok = EmitPipelineOutput(var_name, store_type, &out_decos, {},
                                  store_type, forced_member_type,
                                  &return_members, &return_exprs);
        }
        if (!ok) {
          return false;
        }
      }
    }

    if (return_members.empty()) {
      // This can occur if only the PointSize member is accessed, because we
      // never emit it.
      return_type = ty_.Void()->Build(builder_);
    } else {
      // Create and register the result type.
      auto* str = create<ast::Struct>(Source{}, return_struct_sym,
                                      return_members, ast::DecorationList{});
      parser_impl_.AddTypeDecl(return_struct_sym, str);
      return_type = builder_.ty.Of(str);

      // Add the return-value statement.
      stmts.push_back(create<ast::ReturnStatement>(
          source,
          builder_.Construct(source, return_type, std::move(return_exprs))));
    }
  }

  auto* body = create<ast::BlockStatement>(source, stmts);
  ast::DecorationList fn_decos;
  fn_decos.emplace_back(create<ast::StageDecoration>(source, ep_info_->stage));

  if (ep_info_->stage == ast::PipelineStage::kCompute) {
    // Emit a workgroup_size decoration only when all dimensions are known.
    auto& size = ep_info_->workgroup_size;
    if (size.x != 0 && size.y != 0 && size.z != 0) {
      const ast::Expression* x = builder_.Expr(static_cast<int>(size.x));
      const ast::Expression* y =
          size.y ? builder_.Expr(static_cast<int>(size.y)) : nullptr;
      const ast::Expression* z =
          size.z ? builder_.Expr(static_cast<int>(size.z)) : nullptr;
      fn_decos.emplace_back(
          create<ast::WorkgroupDecoration>(Source{}, x, y, z));
    }
  }

  builder_.AST().AddFunction(
      create<ast::Function>(source, builder_.Symbols().Register(ep_info_->name),
                            std::move(decl.params), return_type, body,
                            std::move(fn_decos), ast::DecorationList{}));

  return true;
}
1382
ParseFunctionDeclaration(FunctionDeclaration * decl)1383 bool FunctionEmitter::ParseFunctionDeclaration(FunctionDeclaration* decl) {
1384 if (failed()) {
1385 return false;
1386 }
1387
1388 const std::string name = namer_.Name(function_.result_id());
1389
1390 // Surprisingly, the "type id" on an OpFunction is the result type of the
1391 // function, not the type of the function. This is the one exceptional case
1392 // in SPIR-V where the type ID is not the type of the result ID.
1393 auto* ret_ty = parser_impl_.ConvertType(function_.type_id());
1394 if (failed()) {
1395 return false;
1396 }
1397 if (ret_ty == nullptr) {
1398 return Fail()
1399 << "internal error: unregistered return type for function with ID "
1400 << function_.result_id();
1401 }
1402
1403 ast::VariableList ast_params;
1404 function_.ForEachParam(
1405 [this, &ast_params](const spvtools::opt::Instruction* param) {
1406 auto* type = parser_impl_.ConvertType(param->type_id());
1407 if (type != nullptr) {
1408 auto* ast_param = parser_impl_.MakeVariable(
1409 param->result_id(), ast::StorageClass::kNone, type, true, nullptr,
1410 ast::DecorationList{});
1411 // Parameters are treated as const declarations.
1412 ast_params.emplace_back(ast_param);
1413 // The value is accessible by name.
1414 identifier_types_.emplace(param->result_id(), type);
1415 } else {
1416 // We've already logged an error and emitted a diagnostic. Do nothing
1417 // here.
1418 }
1419 });
1420 if (failed()) {
1421 return false;
1422 }
1423 decl->name = name;
1424 decl->params = std::move(ast_params);
1425 decl->return_type = ret_ty;
1426 decl->decorations.clear();
1427
1428 return success();
1429 }
1430
GetVariableStoreType(const spvtools::opt::Instruction & var_decl_inst)1431 const Type* FunctionEmitter::GetVariableStoreType(
1432 const spvtools::opt::Instruction& var_decl_inst) {
1433 const auto type_id = var_decl_inst.type_id();
1434 // Normally we use the SPIRV-Tools optimizer to manage types.
1435 // But when two struct types have the same member types and decorations,
1436 // but differ only in member names, the two struct types will be
1437 // represented by a single common internal struct type.
1438 // So avoid the optimizer's representation and instead follow the
1439 // SPIR-V instructions themselves.
1440 const auto* ptr_ty = def_use_mgr_->GetDef(type_id);
1441 const auto store_ty_id = ptr_ty->GetSingleWordInOperand(1);
1442 const auto* result = parser_impl_.ConvertType(store_ty_id);
1443 return result;
1444 }
1445
EmitBody()1446 bool FunctionEmitter::EmitBody() {
1447 RegisterBasicBlocks();
1448
1449 if (!TerminatorsAreValid()) {
1450 return false;
1451 }
1452 if (!RegisterMerges()) {
1453 return false;
1454 }
1455
1456 ComputeBlockOrderAndPositions();
1457 if (!VerifyHeaderContinueMergeOrder()) {
1458 return false;
1459 }
1460 if (!LabelControlFlowConstructs()) {
1461 return false;
1462 }
1463 if (!FindSwitchCaseHeaders()) {
1464 return false;
1465 }
1466 if (!ClassifyCFGEdges()) {
1467 return false;
1468 }
1469 if (!FindIfSelectionInternalHeaders()) {
1470 return false;
1471 }
1472
1473 if (!RegisterSpecialBuiltInVariables()) {
1474 return false;
1475 }
1476 if (!RegisterLocallyDefinedValues()) {
1477 return false;
1478 }
1479 FindValuesNeedingNamedOrHoistedDefinition();
1480
1481 if (!EmitFunctionVariables()) {
1482 return false;
1483 }
1484 if (!EmitFunctionBodyStatements()) {
1485 return false;
1486 }
1487 return success();
1488 }
1489
RegisterBasicBlocks()1490 void FunctionEmitter::RegisterBasicBlocks() {
1491 for (auto& block : function_) {
1492 block_info_[block.id()] = std::make_unique<BlockInfo>(block);
1493 }
1494 }
1495
TerminatorsAreValid()1496 bool FunctionEmitter::TerminatorsAreValid() {
1497 if (failed()) {
1498 return false;
1499 }
1500
1501 const auto entry_id = function_.begin()->id();
1502 for (const auto& block : function_) {
1503 if (!block.terminator()) {
1504 return Fail() << "Block " << block.id() << " has no terminator";
1505 }
1506 }
1507 for (const auto& block : function_) {
1508 block.WhileEachSuccessorLabel(
1509 [this, &block, entry_id](const uint32_t succ_id) -> bool {
1510 if (succ_id == entry_id) {
1511 return Fail() << "Block " << block.id()
1512 << " branches to function entry block " << entry_id;
1513 }
1514 if (!GetBlockInfo(succ_id)) {
1515 return Fail() << "Block " << block.id() << " in function "
1516 << function_.DefInst().result_id() << " branches to "
1517 << succ_id << " which is not a block in the function";
1518 }
1519 return true;
1520 });
1521 }
1522 return success();
1523 }
1524
// Validates and records the relationships declared by OpSelectionMerge and
// OpLoopMerge instructions.
//
// For each header block, records its merge block (and, for loops, its
// continue target) in the BlockInfo entries, cross-linking merge -> header
// and continue-target -> header. Fails on structurally invalid
// declarations: wrong terminator opcode for the header kind, a header that
// is its own merge block, a block claimed as merge or continue target by
// more than one header, a loop header at function entry, a merge block that
// doubles as the continue target, or a single-block loop that is not its
// own continue target.
// @returns false if validation failed
bool FunctionEmitter::RegisterMerges() {
  if (failed()) {
    return false;
  }

  const auto entry_id = function_.begin()->id();
  for (const auto& block : function_) {
    const auto block_id = block.id();
    auto* block_info = GetBlockInfo(block_id);
    if (!block_info) {
      return Fail() << "internal error: block " << block_id
                    << " missing; blocks should already "
                       "have been registered";
    }

    if (const auto* inst = block.GetMergeInst()) {
      auto terminator_opcode = block.terminator()->opcode();
      switch (inst->opcode()) {
        case SpvOpSelectionMerge:
          // A selection header must end in a conditional branch or a switch.
          if ((terminator_opcode != SpvOpBranchConditional) &&
              (terminator_opcode != SpvOpSwitch)) {
            return Fail() << "Selection header " << block_id
                          << " does not end in an OpBranchConditional or "
                             "OpSwitch instruction";
          }
          break;
        case SpvOpLoopMerge:
          // A loop header must end in a branch or a conditional branch.
          if ((terminator_opcode != SpvOpBranchConditional) &&
              (terminator_opcode != SpvOpBranch)) {
            return Fail() << "Loop header " << block_id
                          << " does not end in an OpBranch or "
                             "OpBranchConditional instruction";
          }
          break;
        default:
          break;
      }

      const uint32_t header = block.id();
      auto* header_info = block_info;
      // Input operand 0 of both merge instructions names the merge block.
      const uint32_t merge = inst->GetSingleWordInOperand(0);
      auto* merge_info = GetBlockInfo(merge);
      if (!merge_info) {
        return Fail() << "Structured header block " << header
                      << " declares invalid merge block " << merge;
      }
      if (merge == header) {
        return Fail() << "Structured header block " << header
                      << " cannot be its own merge block";
      }
      if (merge_info->header_for_merge) {
        // Each block can be the merge block for at most one header.
        return Fail() << "Block " << merge
                      << " declared as merge block for more than one header: "
                      << merge_info->header_for_merge << ", " << header;
      }
      merge_info->header_for_merge = header;
      header_info->merge_for_header = merge;

      if (inst->opcode() == SpvOpLoopMerge) {
        if (header == entry_id) {
          return Fail() << "Function entry block " << entry_id
                        << " cannot be a loop header";
        }
        // Input operand 1 of OpLoopMerge names the continue target.
        const uint32_t ct = inst->GetSingleWordInOperand(1);
        auto* ct_info = GetBlockInfo(ct);
        if (!ct_info) {
          return Fail() << "Structured header " << header
                        << " declares invalid continue target " << ct;
        }
        if (ct == merge) {
          return Fail() << "Invalid structured header block " << header
                        << ": declares block " << ct
                        << " as both its merge block and continue target";
        }
        if (ct_info->header_for_continue) {
          // Each block can be the continue target for at most one header.
          return Fail()
                 << "Block " << ct
                 << " declared as continue target for more than one header: "
                 << ct_info->header_for_continue << ", " << header;
        }
        ct_info->header_for_continue = header;
        header_info->continue_for_header = ct;
      }
    }

    // Check single-block loop cases.
    bool is_single_block_loop = false;
    block_info->basic_block->ForEachSuccessorLabel(
        [&is_single_block_loop, block_id](const uint32_t succ) {
          if (block_id == succ)
            is_single_block_loop = true;
        });
    const auto ct = block_info->continue_for_header;
    block_info->is_continue_entire_loop = ct == block_id;
    if (is_single_block_loop && !block_info->is_continue_entire_loop) {
      return Fail() << "Block " << block_id
                    << " branches to itself but is not its own continue target";
    }
    // It's valid for the header of a multi-block loop to declare
    // itself as its own continue target.
  }
  return success();
}
1628
ComputeBlockOrderAndPositions()1629 void FunctionEmitter::ComputeBlockOrderAndPositions() {
1630 block_order_ = StructuredTraverser(function_).ReverseStructuredPostOrder();
1631
1632 for (uint32_t i = 0; i < block_order_.size(); ++i) {
1633 GetBlockInfo(block_order_[i])->pos = i;
1634 }
1635 // The invalid block position is not the position of any block that is in the
1636 // order.
1637 assert(block_order_.size() <= kInvalidBlockPos);
1638 }
1639
VerifyHeaderContinueMergeOrder()1640 bool FunctionEmitter::VerifyHeaderContinueMergeOrder() {
1641 // Verify interval rules for a structured header block:
1642 //
1643 // If the CFG satisfies structured control flow rules, then:
1644 // If header H is reachable, then the following "interval rules" hold,
1645 // where M(H) is H's merge block, and CT(H) is H's continue target:
1646 //
1647 // Pos(H) < Pos(M(H))
1648 //
1649 // If CT(H) exists, then:
1650 // Pos(H) <= Pos(CT(H))
1651 // Pos(CT(H)) < Pos(M)
1652 //
1653 for (auto block_id : block_order_) {
1654 const auto* block_info = GetBlockInfo(block_id);
1655 const auto merge = block_info->merge_for_header;
1656 if (merge == 0) {
1657 continue;
1658 }
1659 // This is a header.
1660 const auto header = block_id;
1661 const auto* header_info = block_info;
1662 const auto header_pos = header_info->pos;
1663 const auto merge_pos = GetBlockInfo(merge)->pos;
1664
1665 // Pos(H) < Pos(M(H))
1666 // Note: When recording merges we made sure H != M(H)
1667 if (merge_pos <= header_pos) {
1668 return Fail() << "Header " << header
1669 << " does not strictly dominate its merge block " << merge;
1670 // TODO(dneto): Report a path from the entry block to the merge block
1671 // without going through the header block.
1672 }
1673
1674 const auto ct = block_info->continue_for_header;
1675 if (ct == 0) {
1676 continue;
1677 }
1678 // Furthermore, this is a loop header.
1679 const auto* ct_info = GetBlockInfo(ct);
1680 const auto ct_pos = ct_info->pos;
1681 // Pos(H) <= Pos(CT(H))
1682 if (ct_pos < header_pos) {
1683 Fail() << "Loop header " << header
1684 << " does not dominate its continue target " << ct;
1685 }
1686 // Pos(CT(H)) < Pos(M(H))
1687 // Note: When recording merges we made sure CT(H) != M(H)
1688 if (merge_pos <= ct_pos) {
1689 return Fail() << "Merge block " << merge << " for loop headed at block "
1690 << header
1691 << " appears at or before the loop's continue "
1692 "construct headed by "
1693 "block "
1694 << ct;
1695 }
1696 }
1697 return success();
1698 }
1699
LabelControlFlowConstructs()1700 bool FunctionEmitter::LabelControlFlowConstructs() {
1701 // Label each block in the block order with its nearest enclosing structured
1702 // control flow construct. Populates the |construct| member of BlockInfo.
1703
1704 // Keep a stack of enclosing structured control flow constructs. Start
1705 // with the synthetic construct representing the entire function.
1706 //
1707 // Scan from left to right in the block order, and check conditions
1708 // on each block in the following order:
1709 //
1710 // a. When you reach a merge block, the top of the stack should
1711 // be the associated header. Pop it off.
1712 // b. When you reach a header, push it on the stack.
1713 // c. When you reach a continue target, push it on the stack.
1714 // (A block can be both a header and a continue target.)
1715 // c. When you reach a block with an edge branching backward (in the
1716 // structured order) to block T:
1717 // T should be a loop header, and the top of the stack should be a
1718 // continue target associated with T.
1719 // This is the end of the continue construct. Pop the continue
1720 // target off the stack.
1721 //
1722 // Note: A loop header can declare itself as its own continue target.
1723 //
1724 // Note: For a single-block loop, that block is a header, its own
1725 // continue target, and its own backedge block.
1726 //
1727 // Note: We pop the merge off first because a merge block that marks
1728 // the end of one construct can be a single-block loop. So that block
1729 // is a merge, a header, a continue target, and a backedge block.
1730 // But we want to finish processing of the merge before dealing with
1731 // the loop.
1732 //
1733 // In the same scan, mark each basic block with the nearest enclosing
1734 // header: the most recent header for which we haven't reached its merge
1735 // block. Also mark the the most recent continue target for which we
1736 // haven't reached the backedge block.
1737
1738 TINT_ASSERT(Reader, block_order_.size() > 0);
1739 constructs_.clear();
1740 const auto entry_id = block_order_[0];
1741
1742 // The stack of enclosing constructs.
1743 std::vector<Construct*> enclosing;
1744
1745 // Creates a control flow construct and pushes it onto the stack.
1746 // Its parent is the top of the stack, or nullptr if the stack is empty.
1747 // Returns the newly created construct.
1748 auto push_construct = [this, &enclosing](size_t depth, Construct::Kind k,
1749 uint32_t begin_id,
1750 uint32_t end_id) -> Construct* {
1751 const auto begin_pos = GetBlockInfo(begin_id)->pos;
1752 const auto end_pos =
1753 end_id == 0 ? uint32_t(block_order_.size()) : GetBlockInfo(end_id)->pos;
1754 const auto* parent = enclosing.empty() ? nullptr : enclosing.back();
1755 auto scope_end_pos = end_pos;
1756 // A loop construct is added right after its associated continue construct.
1757 // In that case, adjust the parent up.
1758 if (k == Construct::kLoop) {
1759 TINT_ASSERT(Reader, parent);
1760 TINT_ASSERT(Reader, parent->kind == Construct::kContinue);
1761 scope_end_pos = parent->end_pos;
1762 parent = parent->parent;
1763 }
1764 constructs_.push_back(std::make_unique<Construct>(
1765 parent, static_cast<int>(depth), k, begin_id, end_id, begin_pos,
1766 end_pos, scope_end_pos));
1767 Construct* result = constructs_.back().get();
1768 enclosing.push_back(result);
1769 return result;
1770 };
1771
1772 // Make a synthetic kFunction construct to enclose all blocks in the function.
1773 push_construct(0, Construct::kFunction, entry_id, 0);
1774 // The entry block can be a selection construct, so be sure to process
1775 // it anyway.
1776
1777 for (uint32_t i = 0; i < block_order_.size(); ++i) {
1778 const auto block_id = block_order_[i];
1779 TINT_ASSERT(Reader, block_id > 0);
1780 auto* block_info = GetBlockInfo(block_id);
1781 TINT_ASSERT(Reader, block_info);
1782
1783 if (enclosing.empty()) {
1784 return Fail() << "internal error: too many merge blocks before block "
1785 << block_id;
1786 }
1787 const Construct* top = enclosing.back();
1788
1789 while (block_id == top->end_id) {
1790 // We've reached a predeclared end of the construct. Pop it off the
1791 // stack.
1792 enclosing.pop_back();
1793 if (enclosing.empty()) {
1794 return Fail() << "internal error: too many merge blocks before block "
1795 << block_id;
1796 }
1797 top = enclosing.back();
1798 }
1799
1800 const auto merge = block_info->merge_for_header;
1801 if (merge != 0) {
1802 // The current block is a header.
1803 const auto header = block_id;
1804 const auto* header_info = block_info;
1805 const auto depth = 1 + top->depth;
1806 const auto ct = header_info->continue_for_header;
1807 if (ct != 0) {
1808 // The current block is a loop header.
1809 // We should see the continue construct after the loop construct, so
1810 // push the loop construct last.
1811
1812 // From the interval rule, the continue construct consists of blocks
1813 // in the block order, starting at the continue target, until just
1814 // before the merge block.
1815 top = push_construct(depth, Construct::kContinue, ct, merge);
1816 // A loop header that is its own continue target will have an
1817 // empty loop construct. Only create a loop construct when
1818 // the continue target is *not* the same as the loop header.
1819 if (header != ct) {
1820 // From the interval rule, the loop construct consists of blocks
1821 // in the block order, starting at the header, until just
1822 // before the continue target.
1823 top = push_construct(depth, Construct::kLoop, header, ct);
1824
1825 // If the loop header branches to two different blocks inside the loop
1826 // construct, then the loop body should be modeled as an if-selection
1827 // construct
1828 std::vector<uint32_t> targets;
1829 header_info->basic_block->ForEachSuccessorLabel(
1830 [&targets](const uint32_t target) { targets.push_back(target); });
1831 if ((targets.size() == 2u) && targets[0] != targets[1]) {
1832 const auto target0_pos = GetBlockInfo(targets[0])->pos;
1833 const auto target1_pos = GetBlockInfo(targets[1])->pos;
1834 if (top->ContainsPos(target0_pos) &&
1835 top->ContainsPos(target1_pos)) {
1836 // Insert a synthetic if-selection
1837 top = push_construct(depth + 1, Construct::kIfSelection, header,
1838 ct);
1839 }
1840 }
1841 }
1842 } else {
1843 // From the interval rule, the selection construct consists of blocks
1844 // in the block order, starting at the header, until just before the
1845 // merge block.
1846 const auto branch_opcode =
1847 header_info->basic_block->terminator()->opcode();
1848 const auto kind = (branch_opcode == SpvOpBranchConditional)
1849 ? Construct::kIfSelection
1850 : Construct::kSwitchSelection;
1851 top = push_construct(depth, kind, header, merge);
1852 }
1853 }
1854
1855 TINT_ASSERT(Reader, top);
1856 block_info->construct = top;
1857 }
1858
1859 // At the end of the block list, we should only have the kFunction construct
1860 // left.
1861 if (enclosing.size() != 1) {
1862 return Fail() << "internal error: unbalanced structured constructs when "
1863 "labeling structured constructs: ended with "
1864 << enclosing.size() - 1 << " unterminated constructs";
1865 }
1866 const auto* top = enclosing[0];
1867 if (top->kind != Construct::kFunction || top->depth != 0) {
1868 return Fail() << "internal error: outermost construct is not a function?!";
1869 }
1870
1871 return success();
1872 }
1873
// Validates the targets of each OpSwitch in the function and records
// case-header bookkeeping on the target blocks:
//  - default_head_for / default_is_merge on the default target, and
//  - case_head_for and the list of selector values on each case target.
// Returns false (via Fail()) when a target violates structured control
// flow: a back-edge target, a target past the selection's merge block,
// a duplicate case value, a block used as a target by two different
// OpSwitch instructions, or a target that is the merge block of a
// different header (dominance violation).
bool FunctionEmitter::FindSwitchCaseHeaders() {
  if (failed()) {
    return false;
  }
  for (auto& construct : constructs_) {
    if (construct->kind != Construct::kSwitchSelection) {
      continue;
    }
    // The switch-selection header's terminator is the OpSwitch instruction.
    const auto* branch =
        GetBlockInfo(construct->begin_id)->basic_block->terminator();

    // Mark the default block
    // OpSwitch in-operands: selector, default target, then
    // (literal value, target) pairs.
    const auto default_id = branch->GetSingleWordInOperand(1);
    auto* default_block = GetBlockInfo(default_id);
    // A default target can't be a backedge.
    if (construct->begin_pos >= default_block->pos) {
      // An OpSwitch must dominate its cases. Also, it can't be a self-loop
      // as that would be a backedge, and backedges can only target a loop,
      // and loops use an OpLoopMerge instruction, which can't precede an
      // OpSwitch.
      return Fail() << "Switch branch from block " << construct->begin_id
                    << " to default target block " << default_id
                    << " can't be a back-edge";
    }
    // A default target can be the merge block, but can't go past it.
    if (construct->end_pos < default_block->pos) {
      return Fail() << "Switch branch from block " << construct->begin_id
                    << " to default block " << default_id
                    << " escapes the selection construct";
    }
    if (default_block->default_head_for) {
      // An OpSwitch must dominate its cases, including the default target.
      return Fail() << "Block " << default_id
                    << " is declared as the default target for two OpSwitch "
                       "instructions, at blocks "
                    << default_block->default_head_for->begin_id << " and "
                    << construct->begin_id;
    }
    if ((default_block->header_for_merge != 0) &&
        (default_block->header_for_merge != construct->begin_id)) {
      // The switch instruction for this default block is an alternate path to
      // the merge block, and hence the merge block is not dominated by its own
      // (different) header.
      return Fail() << "Block " << default_block->id
                    << " is the default block for switch-selection header "
                    << construct->begin_id << " and also the merge block for "
                    << default_block->header_for_merge
                    << " (violates dominance rule)";
    }

    default_block->default_head_for = construct.get();
    default_block->default_is_merge = default_block->pos == construct->end_pos;

    // Map a case target to the list of values selecting that case.
    std::unordered_map<uint32_t, std::vector<uint64_t>> block_to_values;
    std::vector<uint32_t> case_targets;        // in first-appearance order
    std::unordered_set<uint64_t> case_values;  // for duplicate detection

    // Process case targets.
    for (uint32_t iarg = 2; iarg + 1 < branch->NumInOperands(); iarg += 2) {
      const auto value = branch->GetInOperand(iarg).AsLiteralUint64();
      const auto case_target_id = branch->GetSingleWordInOperand(iarg + 1);

      if (case_values.count(value)) {
        return Fail() << "Duplicate case value " << value
                      << " in OpSwitch in block " << construct->begin_id;
      }
      case_values.insert(value);
      if (block_to_values.count(case_target_id) == 0) {
        case_targets.push_back(case_target_id);
      }
      // Several selector values may share the same target block.
      block_to_values[case_target_id].push_back(value);
    }

    for (uint32_t case_target_id : case_targets) {
      auto* case_block = GetBlockInfo(case_target_id);

      case_block->case_values = std::make_unique<std::vector<uint64_t>>(
          std::move(block_to_values[case_target_id]));

      // A case target can't be a back-edge.
      if (construct->begin_pos >= case_block->pos) {
        // An OpSwitch must dominate its cases. Also, it can't be a self-loop
        // as that would be a backedge, and backedges can only target a loop,
        // and loops use an OpLoopMerge instruction, which can't precede an
        // OpSwitch.
        return Fail() << "Switch branch from block " << construct->begin_id
                      << " to case target block " << case_target_id
                      << " can't be a back-edge";
      }
      // A case target can be the merge block, but can't go past it.
      if (construct->end_pos < case_block->pos) {
        return Fail() << "Switch branch from block " << construct->begin_id
                      << " to case target block " << case_target_id
                      << " escapes the selection construct";
      }
      if (case_block->header_for_merge != 0 &&
          case_block->header_for_merge != construct->begin_id) {
        // The switch instruction for this case block is an alternate path to
        // the merge block, and hence the merge block is not dominated by its
        // own (different) header.
        return Fail() << "Block " << case_block->id
                      << " is a case block for switch-selection header "
                      << construct->begin_id << " and also the merge block for "
                      << case_block->header_for_merge
                      << " (violates dominance rule)";
      }

      // Mark the target as a case target.
      if (case_block->case_head_for) {
        // An OpSwitch must dominate its cases.
        return Fail()
               << "Block " << case_target_id
               << " is declared as the switch case target for two OpSwitch "
                  "instructions, at blocks "
               << case_block->case_head_for->begin_id << " and "
               << construct->begin_id;
      }
      case_block->case_head_for = construct.get();
    }
  }
  return success();
}
1997
HeaderIfBreakable(const Construct * c)1998 BlockInfo* FunctionEmitter::HeaderIfBreakable(const Construct* c) {
1999 if (c == nullptr) {
2000 return nullptr;
2001 }
2002 switch (c->kind) {
2003 case Construct::kLoop:
2004 case Construct::kSwitchSelection:
2005 return GetBlockInfo(c->begin_id);
2006 case Construct::kContinue: {
2007 const auto* continue_target = GetBlockInfo(c->begin_id);
2008 return GetBlockInfo(continue_target->header_for_continue);
2009 }
2010 default:
2011 break;
2012 }
2013 return nullptr;
2014 }
2015
// Given a continue construct |c|, returns the loop construct for the
// loop whose continue target begins |c|, or nullptr when:
//  - |c| is null or not a continue construct, or
//  - the continue target is the loop header itself (the continue
//    construct covers the whole loop, so there is no sibling loop).
const Construct* FunctionEmitter::SiblingLoopConstruct(
    const Construct* c) const {
  if ((c == nullptr) || (c->kind != Construct::kContinue)) {
    return nullptr;
  }
  const auto ct_id = c->begin_id;
  const auto header_id = GetBlockInfo(ct_id)->header_for_continue;
  if (ct_id == header_id) {
    // The continue target is the whole loop.
    return nullptr;
  }
  // Walk up the construct tree from the loop header's construct until
  // we hit the loop. In future we might handle the corner case where
  // the same block is both a loop header and a selection header, e.g.
  // when the loop header has a conditional branch going to distinct
  // targets inside the loop body.
  const Construct* walk = GetBlockInfo(header_id)->construct;
  while ((walk != nullptr) && (walk->kind != Construct::kLoop)) {
    walk = walk->parent;
  }
  return walk;
}
2039
// Classifies every CFG edge (S,T) leaving each basic block as one of
// the EdgeKind values (kBack, kSwitchBreak, kLoopBreak, kLoopContinue,
// kIfBreak, kCaseFallThrough, or kForward), recording the result in the
// source block's succ_edge map. Along the way enforces the structured
// control flow rules that depend on edge classification: at most one
// backedge per block, continue constructs exited only from their last
// block, no forward edge bypassing a construct's merge block or continue
// target, divergence only at structured headers, and the dominance
// conditions below. Also allocates a flow-guard variable name on an
// if-selection header whose merge is reached by an alternate path.
// Returns false (via Fail()) on any violation.
bool FunctionEmitter::ClassifyCFGEdges() {
  if (failed()) {
    return false;
  }

  // Checks validity of CFG edges leaving each basic block. This implicitly
  // checks dominance rules for headers and continue constructs.
  //
  // For each branch encountered, classify each edge (S,T) as:
  //    - a back-edge
  //    - a structured exit (specific ways of branching to enclosing construct)
  //    - a normal (forward) edge, either natural control flow or a case
  //    fallthrough
  //
  // If more than one block is targeted by a normal edge, then S must be a
  // structured header.
  //
  // Term: NEC(B) is the nearest enclosing construct for B.
  //
  // If edge (S,T) is a normal edge, and NEC(S) != NEC(T), then
  //    T is the header block of its NEC(T), and
  //    NEC(S) is the parent of NEC(T).

  for (const auto src : block_order_) {
    TINT_ASSERT(Reader, src > 0);
    auto* src_info = GetBlockInfo(src);
    TINT_ASSERT(Reader, src_info);
    const auto src_pos = src_info->pos;
    const auto& src_construct = *(src_info->construct);

    // Compute the ordered list of unique successors.
    std::vector<uint32_t> successors;
    {
      std::unordered_set<uint32_t> visited;
      src_info->basic_block->ForEachSuccessorLabel(
          [&successors, &visited](const uint32_t succ) {
            if (visited.count(succ) == 0) {
              successors.push_back(succ);
              visited.insert(succ);
            }
          });
    }

    // There should only be one backedge per backedge block.
    uint32_t num_backedges = 0;

    // Track destinations for normal forward edges, either kForward
    // or kCaseFallThrough. These count toward the need
    // to have a merge instruction.  We also track kIfBreak edges
    // because when used with normal forward edges, we'll need
    // to generate a flow guard variable.
    std::vector<uint32_t> normal_forward_edges;
    std::vector<uint32_t> if_break_edges;

    if (successors.empty() && src_construct.enclosing_continue) {
      // Kill and return are not allowed in a continue construct.
      return Fail() << "Invalid function exit at block " << src
                    << " from continue construct starting at "
                    << src_construct.enclosing_continue->begin_id;
    }

    for (const auto dest : successors) {
      const auto* dest_info = GetBlockInfo(dest);
      // We've already checked terminators are valid.
      TINT_ASSERT(Reader, dest_info);
      const auto dest_pos = dest_info->pos;

      // Insert the edge kind entry and keep a handle to update
      // its classification.
      EdgeKind& edge_kind = src_info->succ_edge[dest];

      if (src_pos >= dest_pos) {
        // This is a backedge. The target must be the header of the loop
        // whose continue construct contains the source block.
        edge_kind = EdgeKind::kBack;
        num_backedges++;
        const auto* continue_construct = src_construct.enclosing_continue;
        if (!continue_construct) {
          return Fail() << "Invalid backedge (" << src << "->" << dest
                        << "): " << src << " is not in a continue construct";
        }
        if (src_pos != continue_construct->end_pos - 1) {
          return Fail() << "Invalid exit (" << src << "->" << dest
                        << ") from continue construct: " << src
                        << " is not the last block in the continue construct "
                           "starting at "
                        << src_construct.begin_id
                        << " (violates post-dominance rule)";
        }
        const auto* ct_info = GetBlockInfo(continue_construct->begin_id);
        TINT_ASSERT(Reader, ct_info);
        if (ct_info->header_for_continue != dest) {
          return Fail()
                 << "Invalid backedge (" << src << "->" << dest
                 << "): does not branch to the corresponding loop header, "
                    "expected "
                 << ct_info->header_for_continue;
        }
      } else {
        // This is a forward edge.
        // For now, classify it that way, but we might update it.
        edge_kind = EdgeKind::kForward;

        // Exit from a continue construct can only be from the last block.
        const auto* continue_construct = src_construct.enclosing_continue;
        if (continue_construct != nullptr) {
          if (continue_construct->ContainsPos(src_pos) &&
              !continue_construct->ContainsPos(dest_pos) &&
              (src_pos != continue_construct->end_pos - 1)) {
            return Fail() << "Invalid exit (" << src << "->" << dest
                          << ") from continue construct: " << src
                          << " is not the last block in the continue construct "
                             "starting at "
                          << continue_construct->begin_id
                          << " (violates post-dominance rule)";
          }
        }

        // Check valid structured exit cases.
        // Each check below refines the classification; once the edge is no
        // longer kForward, later checks are skipped.

        if (edge_kind == EdgeKind::kForward) {
          // Check for a 'break' from a loop or from a switch.
          const auto* breakable_header = HeaderIfBreakable(
              src_construct.enclosing_loop_or_continue_or_switch);
          if (breakable_header != nullptr) {
            if (dest == breakable_header->merge_for_header) {
              // It's a break.
              edge_kind = (breakable_header->construct->kind ==
                           Construct::kSwitchSelection)
                              ? EdgeKind::kSwitchBreak
                              : EdgeKind::kLoopBreak;
            }
          }
        }

        if (edge_kind == EdgeKind::kForward) {
          // Check for a 'continue' from within a loop.
          const auto* loop_header =
              HeaderIfBreakable(src_construct.enclosing_loop);
          if (loop_header != nullptr) {
            if (dest == loop_header->continue_for_header) {
              // It's a continue.
              edge_kind = EdgeKind::kLoopContinue;
            }
          }
        }

        if (edge_kind == EdgeKind::kForward) {
          const auto& header_info = *GetBlockInfo(src_construct.begin_id);
          if (dest == header_info.merge_for_header) {
            // Branch to construct's merge block.  The loop break and
            // switch break cases have already been covered.
            edge_kind = EdgeKind::kIfBreak;
          }
        }

        // A forward edge into a case construct that comes from something
        // other than the OpSwitch is actually a fallthrough.
        if (edge_kind == EdgeKind::kForward) {
          const auto* switch_construct =
              (dest_info->case_head_for ? dest_info->case_head_for
                                        : dest_info->default_head_for);
          if (switch_construct != nullptr) {
            if (src != switch_construct->begin_id) {
              edge_kind = EdgeKind::kCaseFallThrough;
            }
          }
        }

        // The edge-kind has been finalized.

        if ((edge_kind == EdgeKind::kForward) ||
            (edge_kind == EdgeKind::kCaseFallThrough)) {
          normal_forward_edges.push_back(dest);
        }
        if (edge_kind == EdgeKind::kIfBreak) {
          if_break_edges.push_back(dest);
        }

        if ((edge_kind == EdgeKind::kForward) ||
            (edge_kind == EdgeKind::kCaseFallThrough)) {
          // Check for an invalid forward exit out of this construct.
          if (dest_info->pos > src_construct.end_pos) {
            // In most cases we're bypassing the merge block for the source
            // construct.
            auto end_block = src_construct.end_id;
            const char* end_block_desc = "merge block";
            if (src_construct.kind == Construct::kLoop) {
              // For a loop construct, we have two valid places to go: the
              // continue target or the merge for the loop header, which is
              // further down.
              const auto loop_merge =
                  GetBlockInfo(src_construct.begin_id)->merge_for_header;
              if (dest_info->pos >= GetBlockInfo(loop_merge)->pos) {
                // We're bypassing the loop's merge block.
                end_block = loop_merge;
              } else {
                // We're bypassing the loop's continue target, and going into
                // the middle of the continue construct.
                end_block_desc = "continue target";
              }
            }
            return Fail()
                   << "Branch from block " << src << " to block " << dest
                   << " is an invalid exit from construct starting at block "
                   << src_construct.begin_id << "; branch bypasses "
                   << end_block_desc << " " << end_block;
          }

          // Check dominance.

          // Look for edges that violate the dominance condition: a branch
          // from X to Y where:
          //   If Y is in a nearest enclosing continue construct headed by
          //   CT:
          //     Y is not CT, and
          //     In the structured order, X appears before CT order or
          //     after CT's backedge block.
          //   Otherwise, if Y is in a nearest enclosing construct
          //   headed by H:
          //     Y is not H, and
          //     In the structured order, X appears before H or after H's
          //     merge block.

          const auto& dest_construct = *(dest_info->construct);
          if (dest != dest_construct.begin_id &&
              !dest_construct.ContainsPos(src_pos)) {
            return Fail() << "Branch from " << src << " to " << dest
                          << " bypasses "
                          << (dest_construct.kind == Construct::kContinue
                                  ? "continue target "
                                  : "header ")
                          << dest_construct.begin_id
                          << " (dominance rule violated)";
          }
        }
      }  // end forward edge
    }    // end successor

    if (num_backedges > 1) {
      return Fail() << "Block " << src
                    << " has too many backedges: " << num_backedges;
    }
    if ((normal_forward_edges.size() > 1) &&
        (src_info->merge_for_header == 0)) {
      return Fail() << "Control flow diverges at block " << src << " (to "
                    << normal_forward_edges[0] << ", "
                    << normal_forward_edges[1]
                    << ") but it is not a structured header (it has no merge "
                       "instruction)";
    }
    if ((normal_forward_edges.size() + if_break_edges.size() > 1) &&
        (src_info->merge_for_header == 0)) {
      // There is a branch to the merge of an if-selection combined
      // with an other normal forward branch. Control within the
      // if-selection needs to be gated by a flow predicate.
      for (auto if_break_dest : if_break_edges) {
        auto* head_info =
            GetBlockInfo(GetBlockInfo(if_break_dest)->header_for_merge);
        // Generate a guard name, but only once.
        if (head_info->flow_guard_name.empty()) {
          const std::string guard = "guard" + std::to_string(head_info->id);
          head_info->flow_guard_name = namer_.MakeDerivedName(guard);
        }
      }
    }
  }

  return success();
}
2309
// For each if-selection construct, records on the header's BlockInfo the
// true/false branch heads and their edge kinds, and locates the
// "premerge" block where the then-clause and else-clause rejoin before
// the selection's merge block. Returns false (via Fail()) when a branch
// target is the merge block of a different header (dominance violation),
// when a premerge candidate is someone else's merge block, or when the
// end of the first clause branches both to the merge block and to a
// premerge block.
bool FunctionEmitter::FindIfSelectionInternalHeaders() {
  if (failed()) {
    return false;
  }
  for (auto& construct : constructs_) {
    if (construct->kind != Construct::kIfSelection) {
      continue;
    }
    auto* if_header_info = GetBlockInfo(construct->begin_id);
    // The if-selection header's terminator is the OpBranchConditional:
    // operands are condition, true target, false target.
    const auto* branch = if_header_info->basic_block->terminator();
    const auto true_head = branch->GetSingleWordInOperand(1);
    const auto false_head = branch->GetSingleWordInOperand(2);

    auto* true_head_info = GetBlockInfo(true_head);
    auto* false_head_info = GetBlockInfo(false_head);
    const auto true_head_pos = true_head_info->pos;
    const auto false_head_pos = false_head_info->pos;

    // A branch target is only an internal head when it lies inside the
    // selection; a break/continue-style edge exits the construct.
    const bool contains_true = construct->ContainsPos(true_head_pos);
    const bool contains_false = construct->ContainsPos(false_head_pos);

    // The cases for each edge are:
    //  - kBack: invalid because it's an invalid exit from the selection
    //  - kSwitchBreak ; record this for later special processing
    //  - kLoopBreak ; record this for later special processing
    //  - kLoopContinue ; record this for later special processing
    //  - kIfBreak; normal case, may require a guard variable.
    //  - kFallThrough; invalid exit from the selection
    //  - kForward; normal case

    if_header_info->true_kind = if_header_info->succ_edge[true_head];
    if_header_info->false_kind = if_header_info->succ_edge[false_head];
    if (contains_true) {
      if_header_info->true_head = true_head;
    }
    if (contains_false) {
      if_header_info->false_head = false_head;
    }

    if (contains_true && (true_head_info->header_for_merge != 0) &&
        (true_head_info->header_for_merge != construct->begin_id)) {
      // The OpBranchConditional instruction for the true head block is an
      // alternate path to the merge block of a construct nested inside the
      // selection, and hence the merge block is not dominated by its own
      // (different) header.
      return Fail() << "Block " << true_head
                    << " is the true branch for if-selection header "
                    << construct->begin_id
                    << " and also the merge block for header block "
                    << true_head_info->header_for_merge
                    << " (violates dominance rule)";
    }
    if (contains_false && (false_head_info->header_for_merge != 0) &&
        (false_head_info->header_for_merge != construct->begin_id)) {
      // The OpBranchConditional instruction for the false head block is an
      // alternate path to the merge block of a construct nested inside the
      // selection, and hence the merge block is not dominated by its own
      // (different) header.
      return Fail() << "Block " << false_head
                    << " is the false branch for if-selection header "
                    << construct->begin_id
                    << " and also the merge block for header block "
                    << false_head_info->header_for_merge
                    << " (violates dominance rule)";
    }

    if (contains_true && contains_false && (true_head_pos != false_head_pos)) {
      // This construct has both a "then" clause and an "else" clause.
      //
      // We have this structure:
      //
      //   Option 1:
      //
      //     * condbranch
      //     * true-head (start of then-clause)
      //     ...
      //     * end-then-clause
      //     * false-head (start of else-clause)
      //     ...
      //     * end-false-clause
      //     * premerge-head
      //     ...
      //     * selection merge
      //
      //   Option 2:
      //
      //     * condbranch
      //     * true-head (start of then-clause)
      //     ...
      //     * end-then-clause
      //     * false-head (start of else-clause) and also premerge-head
      //     ...
      //     * end-false-clause
      //     * selection merge
      //
      //   Option 3:
      //
      //     * condbranch
      //     * false-head (start of else-clause)
      //     ...
      //     * end-else-clause
      //     * true-head (start of then-clause) and also premerge-head
      //     ...
      //     * end-then-clause
      //     * selection merge
      //
      // The premerge-head exists if there is a kForward branch from the end
      // of the first clause to a block within the surrounding selection.
      // The first clause might be a then-clause or an else-clause.
      const auto second_head = std::max(true_head_pos, false_head_pos);
      const auto end_first_clause_pos = second_head - 1;
      TINT_ASSERT(Reader, end_first_clause_pos < block_order_.size());
      const auto end_first_clause = block_order_[end_first_clause_pos];
      uint32_t premerge_id = 0;
      uint32_t if_break_id = 0;
      // Scan the successors of the first clause's last block for an
      // if-break edge and/or a forward edge into the selection.
      for (auto& then_succ_iter : GetBlockInfo(end_first_clause)->succ_edge) {
        const uint32_t dest_id = then_succ_iter.first;
        const auto edge_kind = then_succ_iter.second;
        switch (edge_kind) {
          case EdgeKind::kIfBreak:
            if_break_id = dest_id;
            break;
          case EdgeKind::kForward: {
            if (construct->ContainsPos(GetBlockInfo(dest_id)->pos)) {
              // It's a premerge.
              if (premerge_id != 0) {
                // TODO(dneto): I think this is impossible to trigger at this
                // point in the flow. It would require a merge instruction to
                // get past the check of "at-most-one-forward-edge".
                return Fail()
                       << "invalid structure: then-clause headed by block "
                       << true_head << " ending at block " << end_first_clause
                       << " has two forward edges to within selection"
                       << " going to " << premerge_id << " and " << dest_id;
              }
              premerge_id = dest_id;
              auto* dest_block_info = GetBlockInfo(dest_id);
              if_header_info->premerge_head = dest_id;
              if (dest_block_info->header_for_merge != 0) {
                // Premerge has two edges coming into it, from the then-clause
                // and the else-clause. It's also, by construction, not the
                // merge block of the if-selection. So it must not be a merge
                // block itself. The OpBranchConditional instruction for the
                // false head block is an alternate path to the merge block, and
                // hence the merge block is not dominated by its own (different)
                // header.
                return Fail()
                       << "Block " << premerge_id << " is the merge block for "
                       << dest_block_info->header_for_merge
                       << " but has alternate paths reaching it, starting from"
                       << " blocks " << true_head << " and " << false_head
                       << " which are the true and false branches for the"
                       << " if-selection header block " << construct->begin_id
                       << " (violates dominance rule)";
              }
            }
            break;
          }
          default:
            break;
        }
      }
      if (if_break_id != 0 && premerge_id != 0) {
        return Fail() << "Block " << end_first_clause
                      << " in if-selection headed at block "
                      << construct->begin_id
                      << " branches to both the merge block " << if_break_id
                      << " and also to block " << premerge_id
                      << " later in the selection";
      }
    }
  }
  return success();
}
2484
EmitFunctionVariables()2485 bool FunctionEmitter::EmitFunctionVariables() {
2486 if (failed()) {
2487 return false;
2488 }
2489 for (auto& inst : *function_.entry()) {
2490 if (inst.opcode() != SpvOpVariable) {
2491 continue;
2492 }
2493 auto* var_store_type = GetVariableStoreType(inst);
2494 if (failed()) {
2495 return false;
2496 }
2497 const ast::Expression* constructor = nullptr;
2498 if (inst.NumInOperands() > 1) {
2499 // SPIR-V initializers are always constants.
2500 // (OpenCL also allows the ID of an OpVariable, but we don't handle that
2501 // here.)
2502 constructor =
2503 parser_impl_.MakeConstantExpression(inst.GetSingleWordInOperand(1))
2504 .expr;
2505 if (!constructor) {
2506 return false;
2507 }
2508 }
2509 auto* var = parser_impl_.MakeVariable(
2510 inst.result_id(), ast::StorageClass::kNone, var_store_type, false,
2511 constructor, ast::DecorationList{});
2512 auto* var_decl_stmt = create<ast::VariableDeclStatement>(Source{}, var);
2513 AddStatement(var_decl_stmt);
2514 auto* var_type = ty_.Reference(var_store_type, ast::StorageClass::kNone);
2515 identifier_types_.emplace(inst.result_id(), var_type);
2516 }
2517 return success();
2518 }
2519
AddressOfIfNeeded(TypedExpression expr,const spvtools::opt::Instruction * inst)2520 TypedExpression FunctionEmitter::AddressOfIfNeeded(
2521 TypedExpression expr,
2522 const spvtools::opt::Instruction* inst) {
2523 if (inst && expr) {
2524 if (auto* spirv_type = type_mgr_->GetType(inst->type_id())) {
2525 if (expr.type->Is<Reference>() && spirv_type->AsPointer()) {
2526 return AddressOf(expr);
2527 }
2528 }
2529 }
2530 return expr;
2531 }
2532
// Returns a TypedExpression for the value with the given SPIR-V result
// ID, or an invalid/empty expression on failure (with the error recorded
// via Fail()). Resolution order:
//  1. Special skip reasons: opaque objects (error), sunk pointers
//     (replaced by their source reference expression), and the
//     PointSize/SampleMask builtin substitutions.
//  2. Locally declared identifiers recorded in identifier_types_.
//  3. Scalar spec constants.
//  4. Singly-used values previously sunk into singly_used_values_
//     (consumed: the entry is erased).
//  5. Declared SPIR-V constants.
//  6. Module-scope OpVariable and module-scope OpUndef.
//  7. Values defined in blocks outside the block order, which get a
//     null value of the result type.
TypedExpression FunctionEmitter::MakeExpression(uint32_t id) {
  if (failed()) {
    return {};
  }
  switch (GetSkipReason(id)) {
    case SkipReason::kDontSkip:
      break;
    case SkipReason::kOpaqueObject:
      Fail() << "internal error: unhandled use of opaque object with ID: "
             << id;
      return {};
    case SkipReason::kSinkPointerIntoUse: {
      // Replace the pointer with its source reference expression.
      auto source_expr = GetDefInfo(id)->sink_pointer_source_expr;
      TINT_ASSERT(Reader, source_expr.type->Is<Reference>());
      return source_expr;
    }
    case SkipReason::kPointSizeBuiltinValue: {
      // The PointSize builtin always reads as 1.0.
      return {ty_.F32(), create<ast::FloatLiteralExpression>(Source{}, 1.0f)};
    }
    case SkipReason::kPointSizeBuiltinPointer:
      Fail() << "unhandled use of a pointer to the PointSize builtin, with ID: "
             << id;
      return {};
    case SkipReason::kSampleMaskInBuiltinPointer:
      Fail()
          << "unhandled use of a pointer to the SampleMask builtin, with ID: "
          << id;
      return {};
    case SkipReason::kSampleMaskOutBuiltinPointer: {
      // The result type is always u32.
      auto name = namer_.Name(sample_mask_out_id);
      return TypedExpression{ty_.U32(),
                             create<ast::IdentifierExpression>(
                                 Source{}, builder_.Symbols().Register(name))};
    }
  }
  auto type_it = identifier_types_.find(id);
  if (type_it != identifier_types_.end()) {
    // A locally declared identifier (e.g. a function variable or a
    // hoisted value).
    auto name = namer_.Name(id);
    auto* type = type_it->second;
    return TypedExpression{type,
                           create<ast::IdentifierExpression>(
                               Source{}, builder_.Symbols().Register(name))};
  }
  if (parser_impl_.IsScalarSpecConstant(id)) {
    auto name = namer_.Name(id);
    return TypedExpression{
        parser_impl_.ConvertType(def_use_mgr_->GetDef(id)->type_id()),
        create<ast::IdentifierExpression>(Source{},
                                          builder_.Symbols().Register(name))};
  }
  if (singly_used_values_.count(id)) {
    // Consume the expression that was sunk into this single use.
    auto expr = std::move(singly_used_values_[id]);
    singly_used_values_.erase(id);
    return expr;
  }
  const auto* spirv_constant = constant_mgr_->FindDeclaredConstant(id);
  if (spirv_constant) {
    return parser_impl_.MakeConstantExpression(id);
  }
  const auto* inst = def_use_mgr_->GetDef(id);
  if (inst == nullptr) {
    Fail() << "ID " << id << " does not have a defining SPIR-V instruction";
    return {};
  }
  switch (inst->opcode()) {
    case SpvOpVariable: {
      // This occurs for module-scope variables.
      auto name = namer_.Name(inst->result_id());
      return TypedExpression{
          parser_impl_.ConvertType(inst->type_id(), PtrAs::Ref),
          create<ast::IdentifierExpression>(Source{},
                                            builder_.Symbols().Register(name))};
    }
    case SpvOpUndef:
      // Substitute a null value for undef.
      // This case occurs when OpUndef appears at module scope, as if it were
      // a constant.
      return parser_impl_.MakeNullExpression(
          parser_impl_.ConvertType(inst->type_id()));

    default:
      break;
  }
  if (const spvtools::opt::BasicBlock* const bb =
          ir_context_.get_instr_block(id)) {
    if (auto* block = GetBlockInfo(bb->id())) {
      if (block->pos == kInvalidBlockPos) {
        // The value came from a block not in the block order.
        // Substitute a null value.
        return parser_impl_.MakeNullExpression(
            parser_impl_.ConvertType(inst->type_id()));
      }
    }
  }
  Fail() << "unhandled expression for ID " << id << "\n" << inst->PrettyPrint();
  return {};
}
2632
EmitFunctionBodyStatements()2633 bool FunctionEmitter::EmitFunctionBodyStatements() {
2634 // Dump the basic blocks in order, grouped by construct.
2635
2636 // We maintain a stack of StatementBlock objects, where new statements
2637 // are always written to the topmost entry of the stack. By this point in
2638 // processing, we have already recorded the interesting control flow
2639 // boundaries in the BlockInfo and associated Construct objects. As we
2640 // enter a new statement grouping, we push onto the stack, and also schedule
2641 // the statement block's completion and removal at a future block's ID.
2642
2643 // Upon entry, the statement stack has one entry representing the whole
2644 // function.
2645 TINT_ASSERT(Reader, !constructs_.empty());
2646 Construct* function_construct = constructs_[0].get();
2647 TINT_ASSERT(Reader, function_construct != nullptr);
2648 TINT_ASSERT(Reader, function_construct->kind == Construct::kFunction);
2649 // Make the first entry valid by filling in the construct field, which
2650 // had not been computed at the time the entry was first created.
2651 // TODO(dneto): refactor how the first construct is created vs.
2652 // this statements stack entry is populated.
2653 TINT_ASSERT(Reader, statements_stack_.size() == 1);
2654 statements_stack_[0].SetConstruct(function_construct);
2655
2656 for (auto block_id : block_order()) {
2657 if (!EmitBasicBlock(*GetBlockInfo(block_id))) {
2658 return false;
2659 }
2660 }
2661 return success();
2662 }
2663
// Emits one basic block:
//  - finalizes and pops statement blocks for constructs that end here,
//  - pushes statement blocks for constructs headed by this block,
//  - emits the block's ordinary statements exactly once, and
//  - emits a normal terminator, unless entering an if/switch selection
//    already produced the divergent branching.
bool FunctionEmitter::EmitBasicBlock(const BlockInfo& block_info) {
  // Close off previous constructs.
  while (!statements_stack_.empty() &&
         (statements_stack_.back().GetEndId() == block_info.id)) {
    statements_stack_.back().Finalize(&builder_);
    statements_stack_.pop_back();
  }
  if (statements_stack_.empty()) {
    return Fail() << "internal error: statements stack empty at block "
                  << block_info.id;
  }

  // Enter new constructs.

  std::vector<const Construct*> entering_constructs;  // inner most comes first
  {
    auto* here = block_info.construct;
    auto* const top_construct = statements_stack_.back().GetConstruct();
    while (here != top_construct) {
      // Only enter a construct at its header block.
      if (here->begin_id == block_info.id) {
        entering_constructs.push_back(here);
      }
      here = here->parent;
    }
  }
  // What constructs can we have entered?
  // - It can't be kFunction, because there is only one of those, and it was
  //   already on the stack at the outermost level.
  // - We have at most one of kSwitchSelection, or kLoop because each of those
  //   is headed by a block with a merge instruction (OpLoopMerge for kLoop,
  //   and OpSelectionMerge for kSwitchSelection).
  // - When there is a kIfSelection, it can't contain another construct,
  //   because both would have to have their own distinct merge instructions
  //   and distinct terminators.
  // - A kContinue can contain a kContinue
  //   This is possible in Vulkan SPIR-V, but Tint disallows this by the rule
  //   that a block can be continue target for at most one header block. See
  //   test DISABLED_BlockIsContinueForMoreThanOneHeader. If we generalize this,
  //   then by a dominance argument, the inner loop continue target can only be
  //   a single-block loop.
  //   TODO(dneto): Handle this case.
  // - If a kLoop is on the outside, its terminator is either:
  //   - an OpBranch, in which case there is no other construct.
  //   - an OpBranchConditional, in which case there is either an kIfSelection
  //     (when both branch targets are different and are inside the loop),
  //     or no other construct (because the branch targets are the same,
  //     or one of them is a break or continue).
  // - All that's left is a kContinue on the outside, and one of
  //   kIfSelection, kSwitchSelection, kLoop on the inside.
  //
  //   The kContinue can be the parent of the other.  For example, a selection
  //   starting at the first block of a continue construct.
  //
  //   The kContinue can't be the child of the other because either:
  //     - The other can't be kLoop because:
  //       - If the kLoop is for a different loop then the kContinue, then
  //         the kContinue must be its own loop header, and so the same
  //         block is two different loops. That's a contradiction.
  //       - If the kLoop is for the same loop, then this is a contradiction
  //         because a kContinue and its kLoop have disjoint block sets.
  //     - The other construct can't be a selection because:
  //       - The kContinue construct is the entire loop, i.e. the continue
  //         target is its own loop header block. But then the continue target
  //         has an OpLoopMerge instruction, which contradicts this block being
  //         a selection header.
  //       - The kContinue is in a multi-block loop that has a non-empty
  //         kLoop; and the selection contains the kContinue block but not the
  //         loop block. That breaks dominance rules. That is, the continue
  //         target is dominated by that loop header, and so gets found by the
  //         block traversal on the outside before the selection is found. The
  //         selection is inside the outer loop.
  //
  // So we fall into one of the following cases:
  //  - We are entering 0 or 1 constructs, or
  //  - We are entering 2 constructs, with the outer one being a kContinue or
  //    kLoop, the inner one is not a continue.
  if (entering_constructs.size() > 2) {
    return Fail() << "internal error: bad construct nesting found";
  }
  if (entering_constructs.size() == 2) {
    auto inner_kind = entering_constructs[0]->kind;
    auto outer_kind = entering_constructs[1]->kind;
    if (outer_kind != Construct::kContinue && outer_kind != Construct::kLoop) {
      return Fail()
             << "internal error: bad construct nesting. Only a Continue "
                "or a Loop construct can be outer construct on same block. "
                "Got outer kind "
             << int(outer_kind) << " inner kind " << int(inner_kind);
    }
    if (inner_kind == Construct::kContinue) {
      return Fail() << "internal error: unsupported construct nesting: "
                       "Continue around Continue";
    }
    if (inner_kind != Construct::kIfSelection &&
        inner_kind != Construct::kSwitchSelection &&
        inner_kind != Construct::kLoop) {
      return Fail() << "internal error: bad construct nesting. Continue around "
                       "something other than if, switch, or loop";
    }
  }

  // Enter constructs from outermost to innermost.
  // kLoop and kContinue push a new statement-block onto the stack before
  // emitting statements in the block.
  // kIfSelection and kSwitchSelection emit statements in the block and then
  // push a new statement-block. Only emit the statements in the block
  // once.

  // Have we emitted the statements for this block?
  bool emitted = false;

  // When entering an if-selection or switch-selection, we will emit the WGSL
  // construct to cause the divergent branching. But otherwise, we will
  // emit a "normal" block terminator, which occurs at the end of this method.
  bool has_normal_terminator = true;

  for (auto iter = entering_constructs.rbegin();
       iter != entering_constructs.rend(); ++iter) {
    const Construct* construct = *iter;

    switch (construct->kind) {
      case Construct::kFunction:
        return Fail() << "internal error: nested function construct";

      case Construct::kLoop:
        if (!EmitLoopStart(construct)) {
          return false;
        }
        if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
          return false;
        }
        break;

      case Construct::kContinue:
        if (block_info.is_continue_entire_loop) {
          // The continue construct is the entire loop (the continue target
          // is its own loop header), so start it as a loop.
          if (!EmitLoopStart(construct)) {
            return false;
          }
          if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
            return false;
          }
        } else {
          if (!EmitContinuingStart(construct)) {
            return false;
          }
        }
        break;

      case Construct::kIfSelection:
        if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
          return false;
        }
        if (!EmitIfStart(block_info)) {
          return false;
        }
        has_normal_terminator = false;
        break;

      case Construct::kSwitchSelection:
        if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
          return false;
        }
        if (!EmitSwitchStart(block_info)) {
          return false;
        }
        has_normal_terminator = false;
        break;
    }
  }

  // If we aren't starting or transitioning, then emit the normal
  // statements now.
  if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
    return false;
  }

  if (has_normal_terminator) {
    if (!EmitNormalTerminator(block_info)) {
      return false;
    }
  }
  return success();
}
2848
// Emits the WGSL 'if' statement for an if-selection header block, and pushes
// statement blocks that will collect the then-clause, the else-clause, and
// (when present) a guarded premerge region.
bool FunctionEmitter::EmitIfStart(const BlockInfo& block_info) {
  // The block is the if-header block. So its construct is the if construct.
  auto* construct = block_info.construct;
  TINT_ASSERT(Reader, construct->kind == Construct::kIfSelection);
  TINT_ASSERT(Reader, construct->begin_id == block_info.id);

  const uint32_t true_head = block_info.true_head;
  const uint32_t false_head = block_info.false_head;
  const uint32_t premerge_head = block_info.premerge_head;

  const std::string guard_name = block_info.flow_guard_name;
  if (!guard_name.empty()) {
    // Declare the guard variable just before the "if", initialized to true.
    auto* guard_var =
        builder_.Var(guard_name, builder_.ty.bool_(), MakeTrue(Source{}));
    auto* guard_decl = create<ast::VariableDeclStatement>(Source{}, guard_var);
    AddStatement(guard_decl);
  }

  // The condition is the first operand of the OpBranchConditional terminator.
  const auto condition_id =
      block_info.basic_block->terminator()->GetSingleWordInOperand(0);
  auto* cond = MakeExpression(condition_id).expr;
  if (!cond) {
    return false;
  }
  // Generate the code for the condition.
  auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);

  // Compute the block IDs that should end the then-clause and the else-clause.

  // We need to know where the *emitted* selection should end, i.e. the intended
  // merge block id. That should be the current premerge block, if it exists,
  // or otherwise the declared merge block.
  //
  // This is another way to think about it:
  //   If there is a premerge, then there are three cases:
  //    - premerge_head is different from the true_head and false_head:
  //      - Premerge comes last. In effect, move the selection merge up
  //        to where the premerge begins.
  //    - premerge_head is the same as the false_head
  //      - This is really an if-then without an else clause.
  //        Move the merge up to where the premerge is.
  //    - premerge_head is the same as the true_head
  //      - This is really an if-else without a then clause.
  //        Emit it as:   if (cond) {} else {....}
  //        Move the merge up to where the premerge is.
  const uint32_t intended_merge =
      premerge_head ? premerge_head : construct->end_id;

  // then-clause:
  //   If true_head exists:
  //     spans from true head to the earlier of the false head (if it exists)
  //     or the selection merge.
  //   Otherwise:
  //     ends at the false head (if it exists), otherwise the selection
  //     end.
  const uint32_t then_end = false_head ? false_head : intended_merge;

  // else-clause:
  //   ends at the premerge head (if it exists) or at the selection end.
  const uint32_t else_end = premerge_head ? premerge_head : intended_merge;

  // An edge out of the selection may need an explicit break or continue
  // instead of falling into a clause body.
  const bool true_is_break = (block_info.true_kind == EdgeKind::kSwitchBreak) ||
                             (block_info.true_kind == EdgeKind::kLoopBreak);
  const bool false_is_break =
      (block_info.false_kind == EdgeKind::kSwitchBreak) ||
      (block_info.false_kind == EdgeKind::kLoopBreak);
  const bool true_is_continue = block_info.true_kind == EdgeKind::kLoopContinue;
  const bool false_is_continue =
      block_info.false_kind == EdgeKind::kLoopContinue;

  // Push statement blocks for the then-clause and the else-clause.
  // But make sure we do it in the right order.
  auto push_else = [this, builder, else_end, construct, false_is_break,
                    false_is_continue]() {
    // Push the else clause onto the stack first.
    PushNewStatementBlock(
        construct, else_end, [=](const ast::StatementList& stmts) {
          // Only set the else-clause if there are statements to fill it.
          if (!stmts.empty()) {
            // The "else" consists of the statement list from the top of
            // statements stack, without an elseif condition.
            auto* else_body = create<ast::BlockStatement>(Source{}, stmts);
            builder->else_stmts.emplace_back(
                create<ast::ElseStatement>(Source{}, nullptr, else_body));
          }
        });
    if (false_is_break) {
      AddStatement(create<ast::BreakStatement>(Source{}));
    }
    if (false_is_continue) {
      AddStatement(create<ast::ContinueStatement>(Source{}));
    }
  };

  if (!true_is_break && !true_is_continue &&
      (GetBlockInfo(else_end)->pos < GetBlockInfo(then_end)->pos)) {
    // Process the else-clause first. The then-clause will be empty so avoid
    // pushing onto the stack at all.
    push_else();
  } else {
    // Blocks for the then-clause appear before blocks for the else-clause.
    // So push the else-clause handling onto the stack first. The else-clause
    // might be empty, but this works anyway.

    // Handle the premerge, if it exists.
    if (premerge_head) {
      // The top of the stack is the statement block that is the parent of the
      // if-statement. Adding statements now will place them after that 'if'.
      if (guard_name.empty()) {
        // We won't have a flow guard for the premerge.
        // Insert a trivial if(true) { ... } around the blocks from the
        // premerge head until the end of the if-selection.  This is needed
        // to ensure uniform reconvergence occurs at the end of the if-selection
        // just like in the original SPIR-V.
        PushTrueGuard(construct->end_id);
      } else {
        // Add a flow guard around the blocks in the premerge area.
        PushGuard(guard_name, construct->end_id);
      }
    }

    push_else();
    if (true_head && false_head && !guard_name.empty()) {
      // There are non-trivial then and else clauses.
      // We have to guard the start of the else.
      PushGuard(guard_name, else_end);
    }

    // Push the then clause onto the stack.
    PushNewStatementBlock(
        construct, then_end, [=](const ast::StatementList& stmts) {
          builder->body = create<ast::BlockStatement>(Source{}, stmts);
        });
    if (true_is_break) {
      AddStatement(create<ast::BreakStatement>(Source{}));
    }
    if (true_is_continue) {
      AddStatement(create<ast::ContinueStatement>(Source{}));
    }
  }

  return success();
}
2993
// Emits the WGSL 'switch' statement for a switch-selection header block, and
// pushes statement blocks (in reverse block order) that will collect the
// statements for the default clause and each case clause.
bool FunctionEmitter::EmitSwitchStart(const BlockInfo& block_info) {
  // The block is the switch-header block. So its construct is the switch
  // construct.
  auto* construct = block_info.construct;
  TINT_ASSERT(Reader, construct->kind == Construct::kSwitchSelection);
  TINT_ASSERT(Reader, construct->begin_id == block_info.id);
  const auto* branch = block_info.basic_block->terminator();

  // The selector is the first operand of the OpSwitch terminator.
  const auto selector_id = branch->GetSingleWordInOperand(0);
  // Generate the code for the selector.
  auto selector = MakeExpression(selector_id);
  if (!selector) {
    return false;
  }
  // First, push the statement block for the entire switch.
  auto* swch = AddStatementBuilder<SwitchStatementBuilder>(selector.expr);

  // Grab a pointer to the case list.  It will get buried in the statement
  // block stack.
  PushNewStatementBlock(construct, construct->end_id, nullptr);

  // We will push statement-blocks onto the stack to gather the statements in
  // the default clause and cases clauses. Determine the list of blocks
  // that start each clause.
  std::vector<const BlockInfo*> clause_heads;

  // Collect the case clauses, even if they are just the merge block.
  // First the default clause.
  const auto default_id = branch->GetSingleWordInOperand(1);
  const auto* default_info = GetBlockInfo(default_id);
  clause_heads.push_back(default_info);
  // Now the case clauses.  OpSwitch operands after the selector and default
  // alternate between a literal value and its target block ID.
  for (uint32_t iarg = 2; iarg + 1 < branch->NumInOperands(); iarg += 2) {
    const auto case_target_id = branch->GetSingleWordInOperand(iarg + 1);
    clause_heads.push_back(GetBlockInfo(case_target_id));
  }

  // Sort the clauses by their position in the block order.
  std::stable_sort(clause_heads.begin(), clause_heads.end(),
                   [](const BlockInfo* lhs, const BlockInfo* rhs) {
                     return lhs->pos < rhs->pos;
                   });
  // Remove duplicates
  {
    // Use read index r, and write index w.
    // Invariant: w <= r;
    size_t w = 0;
    for (size_t r = 0; r < clause_heads.size(); ++r) {
      if (clause_heads[r] != clause_heads[w]) {
        ++w;  // Advance the write cursor.
      }
      clause_heads[w] = clause_heads[r];
    }
    // We know it's not empty because it always has at least a default clause.
    TINT_ASSERT(Reader, !clause_heads.empty());
    clause_heads.resize(w + 1);
  }

  // Push them on in reverse order.
  const auto last_clause_index = clause_heads.size() - 1;
  for (size_t i = last_clause_index;; --i) {
    // Create a list of integer literals for the selector values leading to
    // this case clause.
    ast::CaseSelectorList selectors;
    const auto* values_ptr = clause_heads[i]->case_values.get();
    const bool has_selectors = (values_ptr && !values_ptr->empty());
    if (has_selectors) {
      std::vector<uint64_t> values(values_ptr->begin(), values_ptr->end());
      std::stable_sort(values.begin(), values.end());
      for (auto value : values) {
        // The rest of this module can handle up to 64 bit switch values.
        // The Tint AST handles 32-bit values.
        const uint32_t value32 = uint32_t(value & 0xFFFFFFFF);
        if (selector.type->IsUnsignedScalarOrVector()) {
          selectors.emplace_back(
              create<ast::UintLiteralExpression>(Source{}, value32));
        } else {
          selectors.emplace_back(
              create<ast::SintLiteralExpression>(Source{}, value32));
        }
      }
    }

    // Where does this clause end?
    const auto end_id = (i + 1 < clause_heads.size()) ? clause_heads[i + 1]->id
                                                      : construct->end_id;

    // Reserve the case clause slot in swch->cases, push the new statement
    // block for the case, and fill the case clause once the block is
    // generated.
    auto case_idx = swch->cases.size();
    swch->cases.emplace_back(nullptr);
    PushNewStatementBlock(
        construct, end_id, [=](const ast::StatementList& stmts) {
          auto* body = create<ast::BlockStatement>(Source{}, stmts);
          swch->cases[case_idx] =
              create<ast::CaseStatement>(Source{}, selectors, body);
        });

    if ((default_info == clause_heads[i]) && has_selectors &&
        construct->ContainsPos(default_info->pos)) {
      // The default block is also the target of a case: generate a default
      // clause containing just a fallthrough into that case.
      auto* stmts = create<ast::BlockStatement>(
          Source{}, ast::StatementList{
                        create<ast::FallthroughStatement>(Source{}),
                    });
      auto* case_stmt =
          create<ast::CaseStatement>(Source{}, ast::CaseSelectorList{}, stmts);
      swch->cases.emplace_back(case_stmt);
    }

    // i is unsigned, so guard the decrement past zero explicitly.
    if (i == 0) {
      break;
    }
  }

  return success();
}
3109
EmitLoopStart(const Construct * construct)3110 bool FunctionEmitter::EmitLoopStart(const Construct* construct) {
3111 auto* builder = AddStatementBuilder<LoopStatementBuilder>();
3112 PushNewStatementBlock(
3113 construct, construct->end_id, [=](const ast::StatementList& stmts) {
3114 builder->body = create<ast::BlockStatement>(Source{}, stmts);
3115 });
3116 return success();
3117 }
3118
EmitContinuingStart(const Construct * construct)3119 bool FunctionEmitter::EmitContinuingStart(const Construct* construct) {
3120 // A continue construct has the same depth as its associated loop
3121 // construct. Start a continue construct.
3122 auto* loop_candidate = LastStatement();
3123 auto* loop = loop_candidate->As<LoopStatementBuilder>();
3124 if (loop == nullptr) {
3125 return Fail() << "internal error: starting continue construct, "
3126 "expected loop on top of stack";
3127 }
3128 PushNewStatementBlock(
3129 construct, construct->end_id, [=](const ast::StatementList& stmts) {
3130 loop->continuing = create<ast::BlockStatement>(Source{}, stmts);
3131 });
3132
3133 return success();
3134 }
3135
// Emits the statement(s) for a "normal" block terminator, i.e. one that was
// not already expressed by starting an if-selection or switch-selection:
// returns, kill, unreachable, and unconditional/conditional branches.
// OpSwitch is rejected here as a defensive measure.
bool FunctionEmitter::EmitNormalTerminator(const BlockInfo& block_info) {
  const auto& terminator = *(block_info.basic_block->terminator());
  switch (terminator.opcode()) {
    case SpvOpReturn:
      AddStatement(create<ast::ReturnStatement>(Source{}));
      return true;
    case SpvOpReturnValue: {
      auto value = MakeExpression(terminator.GetSingleWordInOperand(0));
      if (!value) {
        return false;
      }
      AddStatement(create<ast::ReturnStatement>(Source{}, value.expr));
    }
      return true;
    case SpvOpKill:
      // For now, assume SPIR-V OpKill has same semantics as WGSL discard.
      // TODO(dneto): https://github.com/gpuweb/gpuweb/issues/676
      AddStatement(create<ast::DiscardStatement>(Source{}));
      return true;
    case SpvOpUnreachable:
      // Translate as if it's a return. This avoids the problem where WGSL
      // requires a return statement at the end of the function body.
      {
        const auto* result_type = type_mgr_->GetType(function_.type_id());
        if (result_type->AsVoid() != nullptr) {
          AddStatement(create<ast::ReturnStatement>(Source{}));
        } else {
          // Return a null value of the function's return type.
          auto* ast_type = parser_impl_.ConvertType(function_.type_id());
          AddStatement(create<ast::ReturnStatement>(
              Source{}, parser_impl_.MakeNullValue(ast_type)));
        }
      }
      return true;
    case SpvOpBranch: {
      const auto dest_id = terminator.GetSingleWordInOperand(0);
      AddStatement(MakeBranch(block_info, *GetBlockInfo(dest_id)));
      return true;
    }
    case SpvOpBranchConditional: {
      // If both destinations are the same, then do the same as we would
      // for an unconditional branch (OpBranch).
      const auto true_dest = terminator.GetSingleWordInOperand(1);
      const auto false_dest = terminator.GetSingleWordInOperand(2);
      if (true_dest == false_dest) {
        // This is like an unconditional branch.
        AddStatement(MakeBranch(block_info, *GetBlockInfo(true_dest)));
        return true;
      }

      const EdgeKind true_kind = block_info.succ_edge.find(true_dest)->second;
      const EdgeKind false_kind = block_info.succ_edge.find(false_dest)->second;
      auto* const true_info = GetBlockInfo(true_dest);
      auto* const false_info = GetBlockInfo(false_dest);
      auto* cond = MakeExpression(terminator.GetSingleWordInOperand(0)).expr;
      if (!cond) {
        return false;
      }

      // We have two distinct destinations. But we only get here if this
      // is a normal terminator; in particular the source block is *not* the
      // start of an if-selection or a switch-selection.  So at most one branch
      // is a kForward, kCaseFallThrough, or kIfBreak.

      // The fallthrough case is special because WGSL requires the fallthrough
      // statement to be last in the case clause.
      if (true_kind == EdgeKind::kCaseFallThrough) {
        return EmitConditionalCaseFallThrough(block_info, cond, false_kind,
                                              *false_info, true);
      } else if (false_kind == EdgeKind::kCaseFallThrough) {
        return EmitConditionalCaseFallThrough(block_info, cond, true_kind,
                                              *true_info, false);
      }

      // At this point, at most one edge is kForward or kIfBreak.

      // Emit an 'if' statement to express the *other* branch as a conditional
      // break or continue.  Either or both of these could be nullptr.
      // (A nullptr is generated for kIfBreak, kForward, or kBack.)
      // Also if one of the branches is an if-break out of an if-selection
      // requiring a flow guard, then get that flow guard name too.  It will
      // come from at most one of these two branches.
      std::string flow_guard;
      auto* true_branch =
          MakeBranchDetailed(block_info, *true_info, false, &flow_guard);
      auto* false_branch =
          MakeBranchDetailed(block_info, *false_info, false, &flow_guard);

      AddStatement(MakeSimpleIf(cond, true_branch, false_branch));
      if (!flow_guard.empty()) {
        PushGuard(flow_guard, statements_stack_.back().GetEndId());
      }
      return true;
    }
    case SpvOpSwitch:
      // An OpSelectionMerge must precede an OpSwitch.  That is clarified
      // in the resolution to Khronos-internal SPIR-V issue 115.
      // A new enough version of the SPIR-V validator checks this case.
      // But issue an error in this case, as a defensive measure.
      return Fail() << "invalid structured control flow: found an OpSwitch "
                       "that is not preceded by an "
                       "OpSelectionMerge: "
                    << terminator.PrettyPrint();
    default:
      break;
  }
  return success();
}
3243
// Returns the statement, if any, to express the branch from src_info's block
// to dest_info's block, based on the recorded edge kind.  Returns nullptr
// when no statement is needed (implicit backedge, implicit forward branch,
// or an elidable break).  When `forced` is true, a switch-break always emits
// a break statement.  When the branch is an if-break that needs a flow
// guard, the guard variable name is written to *flow_guard_name_ptr (if
// non-null) and a guard-clearing assignment is returned instead of a break.
const ast::Statement* FunctionEmitter::MakeBranchDetailed(
    const BlockInfo& src_info,
    const BlockInfo& dest_info,
    bool forced,
    std::string* flow_guard_name_ptr) const {
  auto kind = src_info.succ_edge.find(dest_info.id)->second;
  switch (kind) {
    case EdgeKind::kBack:
      // Nothing to do. The loop backedge is implicit.
      break;
    case EdgeKind::kSwitchBreak: {
      if (forced) {
        return create<ast::BreakStatement>(Source{});
      }
      // Unless forced, don't bother with a break at the end of a case/default
      // clause.
      const auto header = dest_info.header_for_merge;
      TINT_ASSERT(Reader, header != 0);
      const auto* exiting_construct = GetBlockInfo(header)->construct;
      TINT_ASSERT(Reader,
                  exiting_construct->kind == Construct::kSwitchSelection);
      const auto candidate_next_case_pos = src_info.pos + 1;
      // Leaving the last block from the last case?
      if (candidate_next_case_pos == dest_info.pos) {
        // No break needed.
        return nullptr;
      }
      // Leaving the last block from not-the-last-case?
      if (exiting_construct->ContainsPos(candidate_next_case_pos)) {
        const auto* candidate_next_case =
            GetBlockInfo(block_order_[candidate_next_case_pos]);
        if (candidate_next_case->case_head_for == exiting_construct ||
            candidate_next_case->default_head_for == exiting_construct) {
          // No break needed.
          return nullptr;
        }
      }
      // We need a break.
      return create<ast::BreakStatement>(Source{});
    }
    case EdgeKind::kLoopBreak:
      return create<ast::BreakStatement>(Source{});
    case EdgeKind::kLoopContinue:
      // An unconditional continue to the next block is redundant and ugly.
      // Skip it in that case.
      if (dest_info.pos == 1 + src_info.pos) {
        break;
      }
      // Otherwise, emit a regular continue statement.
      return create<ast::ContinueStatement>(Source{});
    case EdgeKind::kIfBreak: {
      const auto& flow_guard =
          GetBlockInfo(dest_info.header_for_merge)->flow_guard_name;
      if (!flow_guard.empty()) {
        if (flow_guard_name_ptr != nullptr) {
          *flow_guard_name_ptr = flow_guard;
        }
        // Signal an exit from the branch.
        return create<ast::AssignmentStatement>(
            Source{},
            create<ast::IdentifierExpression>(
                Source{}, builder_.Symbols().Register(flow_guard)),
            MakeFalse(Source{}));
      }

      // For an unconditional branch, the break out to an if-selection
      // merge block is implicit.
      break;
    }
    case EdgeKind::kCaseFallThrough:
      return create<ast::FallthroughStatement>(Source{});
    case EdgeKind::kForward:
      // Unconditional forward branch is implicit.
      break;
  }
  return nullptr;
}
3321
// Builds an if statement whose then-clause and else-clause each hold at most
// one statement.  Either clause statement may be nullptr; when both are
// nullptr there is nothing to emit and nullptr is returned.
const ast::Statement* FunctionEmitter::MakeSimpleIf(
    const ast::Expression* condition,
    const ast::Statement* then_stmt,
    const ast::Statement* else_stmt) const {
  if (!then_stmt && !else_stmt) {
    return nullptr;
  }

  // An else-clause, when present, is a plain "else" with no condition.
  ast::ElseStatementList else_clauses;
  if (else_stmt != nullptr) {
    auto* else_block = create<ast::BlockStatement>(
        Source{}, ast::StatementList{else_stmt});
    else_clauses.emplace_back(
        create<ast::ElseStatement>(Source{}, nullptr, else_block));
  }

  // The then-clause body holds the single statement, if any.
  ast::StatementList then_stmts;
  if (then_stmt != nullptr) {
    then_stmts.emplace_back(then_stmt);
  }
  auto* then_block = create<ast::BlockStatement>(Source{}, then_stmts);

  return create<ast::IfStatement>(Source{}, condition, then_block,
                                  else_clauses);
}
3345
EmitConditionalCaseFallThrough(const BlockInfo & src_info,const ast::Expression * cond,EdgeKind other_edge_kind,const BlockInfo & other_dest,bool fall_through_is_true_branch)3346 bool FunctionEmitter::EmitConditionalCaseFallThrough(
3347 const BlockInfo& src_info,
3348 const ast::Expression* cond,
3349 EdgeKind other_edge_kind,
3350 const BlockInfo& other_dest,
3351 bool fall_through_is_true_branch) {
3352 // In WGSL, the fallthrough statement must come last in the case clause.
3353 // So we'll emit an if statement for the other branch, and then emit
3354 // the fallthrough.
3355
3356 // We have two distinct destinations. But we only get here if this
3357 // is a normal terminator; in particular the source block is *not* the
3358 // start of an if-selection. So at most one branch is a kForward or
3359 // kCaseFallThrough.
3360 if (other_edge_kind == EdgeKind::kForward) {
3361 return Fail()
3362 << "internal error: normal terminator OpBranchConditional has "
3363 "both forward and fallthrough edges";
3364 }
3365 if (other_edge_kind == EdgeKind::kIfBreak) {
3366 return Fail()
3367 << "internal error: normal terminator OpBranchConditional has "
3368 "both IfBreak and fallthrough edges. Violates nesting rule";
3369 }
3370 if (other_edge_kind == EdgeKind::kBack) {
3371 return Fail()
3372 << "internal error: normal terminator OpBranchConditional has "
3373 "both backedge and fallthrough edges. Violates nesting rule";
3374 }
3375 auto* other_branch = MakeForcedBranch(src_info, other_dest);
3376 if (other_branch == nullptr) {
3377 return Fail() << "internal error: expected a branch for edge-kind "
3378 << int(other_edge_kind);
3379 }
3380 if (fall_through_is_true_branch) {
3381 AddStatement(MakeSimpleIf(cond, nullptr, other_branch));
3382 } else {
3383 AddStatement(MakeSimpleIf(cond, other_branch, nullptr));
3384 }
3385 AddStatement(create<ast::FallthroughStatement>(Source{}));
3386
3387 return success();
3388 }
3389
EmitStatementsInBasicBlock(const BlockInfo & block_info,bool * already_emitted)3390 bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info,
3391 bool* already_emitted) {
3392 if (*already_emitted) {
3393 // Only emit this part of the basic block once.
3394 return true;
3395 }
3396 // Returns the given list of local definition IDs, sorted by their index.
3397 auto sorted_by_index = [this](const std::vector<uint32_t>& ids) {
3398 auto sorted = ids;
3399 std::stable_sort(sorted.begin(), sorted.end(),
3400 [this](const uint32_t lhs, const uint32_t rhs) {
3401 return GetDefInfo(lhs)->index < GetDefInfo(rhs)->index;
3402 });
3403 return sorted;
3404 };
3405
3406 // Emit declarations of hoisted variables, in index order.
3407 for (auto id : sorted_by_index(block_info.hoisted_ids)) {
3408 const auto* def_inst = def_use_mgr_->GetDef(id);
3409 TINT_ASSERT(Reader, def_inst);
3410 auto* storage_type =
3411 RemapStorageClass(parser_impl_.ConvertType(def_inst->type_id()), id);
3412 AddStatement(create<ast::VariableDeclStatement>(
3413 Source{},
3414 parser_impl_.MakeVariable(id, ast::StorageClass::kNone, storage_type,
3415 false, nullptr, ast::DecorationList{})));
3416 auto* type = ty_.Reference(storage_type, ast::StorageClass::kNone);
3417 identifier_types_.emplace(id, type);
3418 }
3419 // Emit declarations of phi state variables, in index order.
3420 for (auto id : sorted_by_index(block_info.phis_needing_state_vars)) {
3421 const auto* def_inst = def_use_mgr_->GetDef(id);
3422 TINT_ASSERT(Reader, def_inst);
3423 const auto phi_var_name = GetDefInfo(id)->phi_var;
3424 TINT_ASSERT(Reader, !phi_var_name.empty());
3425 auto* var = builder_.Var(
3426 phi_var_name,
3427 parser_impl_.ConvertType(def_inst->type_id())->Build(builder_));
3428 AddStatement(create<ast::VariableDeclStatement>(Source{}, var));
3429 }
3430
3431 // Emit regular statements.
3432 const spvtools::opt::BasicBlock& bb = *(block_info.basic_block);
3433 const auto* terminator = bb.terminator();
3434 const auto* merge = bb.GetMergeInst(); // Might be nullptr
3435 for (auto& inst : bb) {
3436 if (&inst == terminator || &inst == merge || inst.opcode() == SpvOpLabel ||
3437 inst.opcode() == SpvOpVariable) {
3438 continue;
3439 }
3440 if (!EmitStatement(inst)) {
3441 return false;
3442 }
3443 }
3444
3445 // Emit assignments to carry values to phi nodes in potential destinations.
3446 // Do it in index order.
3447 if (!block_info.phi_assignments.empty()) {
3448 auto sorted = block_info.phi_assignments;
3449 std::stable_sort(sorted.begin(), sorted.end(),
3450 [this](const BlockInfo::PhiAssignment& lhs,
3451 const BlockInfo::PhiAssignment& rhs) {
3452 return GetDefInfo(lhs.phi_id)->index <
3453 GetDefInfo(rhs.phi_id)->index;
3454 });
3455 for (auto assignment : block_info.phi_assignments) {
3456 const auto var_name = GetDefInfo(assignment.phi_id)->phi_var;
3457 auto expr = MakeExpression(assignment.value);
3458 if (!expr) {
3459 return false;
3460 }
3461 AddStatement(create<ast::AssignmentStatement>(
3462 Source{},
3463 create<ast::IdentifierExpression>(
3464 Source{}, builder_.Symbols().Register(var_name)),
3465 expr.expr));
3466 }
3467 }
3468
3469 *already_emitted = true;
3470 return true;
3471 }
3472
EmitConstDefinition(const spvtools::opt::Instruction & inst,TypedExpression expr)3473 bool FunctionEmitter::EmitConstDefinition(
3474 const spvtools::opt::Instruction& inst,
3475 TypedExpression expr) {
3476 if (!expr) {
3477 return false;
3478 }
3479
3480 // Do not generate pointers that we want to sink.
3481 if (GetDefInfo(inst.result_id())->skip == SkipReason::kSinkPointerIntoUse) {
3482 return true;
3483 }
3484
3485 expr = AddressOfIfNeeded(expr, &inst);
3486 auto* ast_const = parser_impl_.MakeVariable(
3487 inst.result_id(), ast::StorageClass::kNone, expr.type, true, expr.expr,
3488 ast::DecorationList{});
3489 if (!ast_const) {
3490 return false;
3491 }
3492 AddStatement(create<ast::VariableDeclStatement>(Source{}, ast_const));
3493 identifier_types_.emplace(inst.result_id(), expr.type);
3494 return success();
3495 }
3496
EmitConstDefOrWriteToHoistedVar(const spvtools::opt::Instruction & inst,TypedExpression expr)3497 bool FunctionEmitter::EmitConstDefOrWriteToHoistedVar(
3498 const spvtools::opt::Instruction& inst,
3499 TypedExpression expr) {
3500 return WriteIfHoistedVar(inst, expr) || EmitConstDefinition(inst, expr);
3501 }
3502
WriteIfHoistedVar(const spvtools::opt::Instruction & inst,TypedExpression expr)3503 bool FunctionEmitter::WriteIfHoistedVar(const spvtools::opt::Instruction& inst,
3504 TypedExpression expr) {
3505 const auto result_id = inst.result_id();
3506 const auto* def_info = GetDefInfo(result_id);
3507 if (def_info && def_info->requires_hoisted_def) {
3508 auto name = namer_.Name(result_id);
3509 // Emit an assignment of the expression to the hoisted variable.
3510 AddStatement(create<ast::AssignmentStatement>(
3511 Source{},
3512 create<ast::IdentifierExpression>(Source{},
3513 builder_.Symbols().Register(name)),
3514 expr.expr));
3515 return true;
3516 }
3517 return false;
3518 }
3519
// Emits the WGSL statement(s) for a single SPIR-V instruction in a basic
// block (excluding labels, variables, terminators, and merges, which the
// caller filters out). Combinatorial values used exactly once may instead be
// deferred via singly_used_values_ and inlined at their use. Returns false
// (after registering a failure) for unsupported or malformed instructions.
bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
  if (failed()) {
    return false;
  }
  const auto result_id = inst.result_id();
  const auto type_id = inst.type_id();

  // Reject operations producing (pointers to) the gl_PerVertex structure;
  // only its Position member is supported (handled elsewhere).
  if (type_id != 0) {
    const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
    if (type_id == builtin_position_info.struct_type_id) {
      return Fail() << "operations producing a per-vertex structure are not "
                       "supported: "
                    << inst.PrettyPrint();
    }
    if (type_id == builtin_position_info.pointer_type_id) {
      return Fail() << "operations producing a pointer to a per-vertex "
                       "structure are not "
                       "supported: "
                    << inst.PrettyPrint();
    }
  }

  // Handle combinatorial instructions.
  const auto* def_info = GetDefInfo(result_id);
  if (def_info) {
    TypedExpression combinatorial_expr;
    if (def_info->skip == SkipReason::kDontSkip) {
      combinatorial_expr = MaybeEmitCombinatorialValue(inst);
      if (!success()) {
        return false;
      }
    }
    // An access chain or OpCopyObject can generate a skip.
    if (def_info->skip != SkipReason::kDontSkip) {
      return true;
    }

    if (combinatorial_expr.expr != nullptr) {
      // Note: a single use is inlined only when no hoisting or named-const
      // definition is required; otherwise materialize it now.
      if (def_info->requires_hoisted_def ||
          def_info->requires_named_const_def || def_info->num_uses != 1) {
        // Generate a const definition or an assignment to a hoisted definition
        // now and later use the const or variable name at the uses of this
        // value.
        return EmitConstDefOrWriteToHoistedVar(inst, combinatorial_expr);
      }
      // It is harmless to defer emitting the expression until it's used.
      // Any supporting statements have already been emitted.
      singly_used_values_.insert(std::make_pair(result_id, combinatorial_expr));
      return success();
    }
  }
  if (failed()) {
    return false;
  }

  // Image queries and accesses have their own dedicated emission paths.
  if (IsImageQuery(inst.opcode())) {
    return EmitImageQuery(inst);
  }

  if (IsSampledImageAccess(inst.opcode()) || IsRawImageAccess(inst.opcode())) {
    return EmitImageAccess(inst);
  }

  // Remaining cases: instructions with side effects or ordering constraints.
  switch (inst.opcode()) {
    case SpvOpNop:
      return true;

    case SpvOpStore: {
      auto ptr_id = inst.GetSingleWordInOperand(0);
      const auto value_id = inst.GetSingleWordInOperand(1);

      const auto ptr_type_id = def_use_mgr_->GetDef(ptr_id)->type_id();
      const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
      if (ptr_type_id == builtin_position_info.pointer_type_id) {
        return Fail()
               << "storing to the whole per-vertex structure is not supported: "
               << inst.PrettyPrint();
      }

      TypedExpression rhs = MakeExpression(value_id);
      if (!rhs) {
        return false;
      }

      TypedExpression lhs;

      // Handle exceptional cases
      switch (GetSkipReason(ptr_id)) {
        case SkipReason::kPointSizeBuiltinPointer:
          // WGSL has no PointSize builtin; only the implicit value 1.0 is
          // representable, so a store of constant 1.0 is silently dropped.
          if (IsFloatOne(value_id)) {
            // Don't store to PointSize
            return true;
          }
          return Fail() << "cannot store a value other than constant 1.0 to "
                           "PointSize builtin: "
                        << inst.PrettyPrint();

        case SkipReason::kSampleMaskOutBuiltinPointer:
          // Stores to SampleMask are redirected to the synthesized private
          // variable, indexing its first element.
          lhs = MakeExpression(sample_mask_out_id);
          if (lhs.type->Is<Pointer>()) {
            // LHS of an assignment must be a reference type.
            // Convert the LHS to a reference by dereferencing it.
            lhs = Dereference(lhs);
          }
          // The private variable is an array whose element type is already of
          // the same type as the value being stored into it. Form the
          // reference into the first element.
          lhs.expr = create<ast::IndexAccessorExpression>(
              Source{}, lhs.expr, parser_impl_.MakeNullValue(ty_.I32()));
          // Peel the reference and array layers to get the element type.
          if (auto* ref = lhs.type->As<Reference>()) {
            lhs.type = ref->type;
          }
          if (auto* arr = lhs.type->As<Array>()) {
            lhs.type = arr->type;
          }
          TINT_ASSERT(Reader, lhs.type);
          break;
        default:
          break;
      }

      // Handle an ordinary store as an assignment.
      if (!lhs) {
        lhs = MakeExpression(ptr_id);
      }
      if (!lhs) {
        return false;
      }

      if (lhs.type->Is<Pointer>()) {
        // LHS of an assignment must be a reference type.
        // Convert the LHS to a reference by dereferencing it.
        lhs = Dereference(lhs);
      }

      AddStatement(
          create<ast::AssignmentStatement>(Source{}, lhs.expr, rhs.expr));
      return success();
    }

    case SpvOpLoad: {
      // Memory accesses must be issued in SPIR-V program order.
      // So represent a load by a new const definition.
      const auto ptr_id = inst.GetSingleWordInOperand(0);
      const auto skip_reason = GetSkipReason(ptr_id);

      switch (skip_reason) {
        case SkipReason::kPointSizeBuiltinPointer:
          // Loading PointSize yields the (dropped) builtin value; propagate
          // the skip so uses of this result are also elided.
          GetDefInfo(inst.result_id())->skip =
              SkipReason::kPointSizeBuiltinValue;
          return true;
        case SkipReason::kSampleMaskInBuiltinPointer: {
          // Loads of SampleMask read the synthesized input variable's
          // first element.
          auto name = namer_.Name(sample_mask_in_id);
          const ast::Expression* id_expr = create<ast::IdentifierExpression>(
              Source{}, builder_.Symbols().Register(name));
          // SampleMask is an array in Vulkan SPIR-V. Always access the first
          // element.
          id_expr = create<ast::IndexAccessorExpression>(
              Source{}, id_expr, parser_impl_.MakeNullValue(ty_.I32()));

          auto* loaded_type = parser_impl_.ConvertType(inst.type_id());

          if (!loaded_type->IsIntegerScalar()) {
            return Fail() << "loading the whole SampleMask input array is not "
                             "supported: "
                          << inst.PrettyPrint();
          }

          auto expr = TypedExpression{loaded_type, id_expr};
          return EmitConstDefinition(inst, expr);
        }
        default:
          break;
      }
      auto expr = MakeExpression(ptr_id);
      if (!expr) {
        return false;
      }

      // The load result type is the storage type of its operand.
      if (expr.type->Is<Pointer>()) {
        expr = Dereference(expr);
      } else if (auto* ref = expr.type->As<Reference>()) {
        expr.type = ref->type;
      } else {
        Fail() << "OpLoad expression is not a pointer or reference";
        return false;
      }

      return EmitConstDefOrWriteToHoistedVar(inst, expr);
    }

    case SpvOpCopyMemory: {
      // Generate an assignment.
      auto lhs = MakeOperand(inst, 0);
      auto rhs = MakeOperand(inst, 1);
      // Ignore any potential memory operands. Currently they are all for
      // concepts not in WGSL:
      //   Volatile
      //   Aligned
      //   Nontemporal
      //   MakePointerAvailable ; Vulkan memory model
      //   MakePointerVisible   ; Vulkan memory model
      //   NonPrivatePointer    ; Vulkan memory model

      if (!success()) {
        return false;
      }

      // LHS and RHS pointers must be reference types in WGSL.
      if (lhs.type->Is<Pointer>()) {
        lhs = Dereference(lhs);
      }
      if (rhs.type->Is<Pointer>()) {
        rhs = Dereference(rhs);
      }

      AddStatement(
          create<ast::AssignmentStatement>(Source{}, lhs.expr, rhs.expr));
      return success();
    }

    case SpvOpCopyObject: {
      // Arguably, OpCopyObject is purely combinatorial. On the other hand,
      // it exists to make a new name for something. So we choose to make
      // a new named constant definition.
      auto value_id = inst.GetSingleWordInOperand(0);
      const auto skip = GetSkipReason(value_id);
      if (skip != SkipReason::kDontSkip) {
        // Propagate the skip (and any sunk-pointer source) to the copy.
        GetDefInfo(inst.result_id())->skip = skip;
        GetDefInfo(inst.result_id())->sink_pointer_source_expr =
            GetDefInfo(value_id)->sink_pointer_source_expr;
        return true;
      }
      auto expr = AddressOfIfNeeded(MakeExpression(value_id), &inst);
      if (!expr) {
        return false;
      }
      expr.type = RemapStorageClass(expr.type, result_id);
      return EmitConstDefOrWriteToHoistedVar(inst, expr);
    }

    case SpvOpPhi: {
      // Emit a read from the associated state variable.
      TypedExpression expr{
          parser_impl_.ConvertType(inst.type_id()),
          create<ast::IdentifierExpression>(
              Source{}, builder_.Symbols().Register(def_info->phi_var))};
      return EmitConstDefOrWriteToHoistedVar(inst, expr);
    }

    case SpvOpOuterProduct:
      // Synthesize an outer product expression in its own statement.
      return EmitConstDefOrWriteToHoistedVar(inst, MakeOuterProduct(inst));

    case SpvOpVectorInsertDynamic:
      // Synthesize a vector insertion in its own statements.
      return MakeVectorInsertDynamic(inst);

    case SpvOpCompositeInsert:
      // Synthesize a composite insertion in its own statements.
      return MakeCompositeInsert(inst);

    case SpvOpFunctionCall:
      return EmitFunctionCall(inst);

    case SpvOpControlBarrier:
      return EmitControlBarrier(inst);

    case SpvOpExtInst:
      // Ignorable extended instructions (e.g. debug info) produce nothing;
      // others fall through to the generic failure below.
      if (parser_impl_.IsIgnoredExtendedInstruction(inst)) {
        return true;
      }
      break;

    case SpvOpIAddCarry:
    case SpvOpISubBorrow:
    case SpvOpUMulExtended:
    case SpvOpSMulExtended:
      return Fail() << "extended arithmetic is not finalized for WGSL: "
                       "https://github.com/gpuweb/gpuweb/issues/1565: "
                    << inst.PrettyPrint();

    default:
      break;
  }
  return Fail() << "unhandled instruction with opcode " << inst.opcode() << ": "
                << inst.PrettyPrint();
}
3809
MakeOperand(const spvtools::opt::Instruction & inst,uint32_t operand_index)3810 TypedExpression FunctionEmitter::MakeOperand(
3811 const spvtools::opt::Instruction& inst,
3812 uint32_t operand_index) {
3813 auto expr = MakeExpression(inst.GetSingleWordInOperand(operand_index));
3814 if (!expr) {
3815 return {};
3816 }
3817 return parser_impl_.RectifyOperandSignedness(inst, std::move(expr));
3818 }
3819
InferFunctionStorageClass(TypedExpression expr)3820 TypedExpression FunctionEmitter::InferFunctionStorageClass(
3821 TypedExpression expr) {
3822 TypedExpression result(expr);
3823 if (const auto* ref = expr.type->UnwrapAlias()->As<Reference>()) {
3824 if (ref->storage_class == ast::StorageClass::kNone) {
3825 expr.type = ty_.Reference(ref->type, ast::StorageClass::kFunction);
3826 }
3827 } else if (const auto* ptr = expr.type->UnwrapAlias()->As<Pointer>()) {
3828 if (ptr->storage_class == ast::StorageClass::kNone) {
3829 expr.type = ty_.Pointer(ptr->type, ast::StorageClass::kFunction);
3830 }
3831 }
3832 return expr;
3833 }
3834
// Returns the WGSL expression for a purely combinatorial (side-effect-free)
// SPIR-V instruction, or an invalid TypedExpression when the instruction is
// not combinatorial (or an error occurred; check success()/failed()).
// Dispatch order matters: binary ops, unary ops, unary builtins, intrinsics,
// then individual opcodes.
TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
    const spvtools::opt::Instruction& inst) {
  if (inst.result_id() == 0) {
    // No result means nothing to express.
    return {};
  }

  const auto opcode = inst.opcode();

  // Convert the result type up front, when there is one.
  const Type* ast_type = nullptr;
  if (inst.type_id()) {
    ast_type = parser_impl_.ConvertType(inst.type_id());
    if (!ast_type) {
      Fail() << "couldn't convert result type for: " << inst.PrettyPrint();
      return {};
    }
  }

  // Case: simple binary operator.
  auto binary_op = ConvertBinaryOp(opcode);
  if (binary_op != ast::BinaryOp::kNone) {
    auto arg0 = MakeOperand(inst, 0);
    // The second operand's signedness may need to match the first's.
    auto arg1 = parser_impl_.RectifySecondOperandSignedness(
        inst, arg0.type, MakeOperand(inst, 1));
    if (!arg0 || !arg1) {
      return {};
    }
    auto* binary_expr = create<ast::BinaryExpression>(Source{}, binary_op,
                                                      arg0.expr, arg1.expr);
    TypedExpression result{ast_type, binary_expr};
    return parser_impl_.RectifyForcedResultType(result, inst, arg0.type);
  }

  // Case: simple unary operator.
  auto unary_op = ast::UnaryOp::kNegation;
  if (GetUnaryOp(opcode, &unary_op)) {
    auto arg0 = MakeOperand(inst, 0);
    auto* unary_expr =
        create<ast::UnaryOpExpression>(Source{}, unary_op, arg0.expr);
    TypedExpression result{ast_type, unary_expr};
    return parser_impl_.RectifyForcedResultType(result, inst, arg0.type);
  }

  // Case: single-argument builtin function call.
  const char* unary_builtin_name = GetUnaryBuiltInFunctionName(opcode);
  if (unary_builtin_name != nullptr) {
    ast::ExpressionList params;
    params.emplace_back(MakeOperand(inst, 0).expr);
    return {ast_type,
            create<ast::CallExpression>(
                Source{},
                create<ast::IdentifierExpression>(
                    Source{}, builder_.Symbols().Register(unary_builtin_name)),
                std::move(params))};
  }

  // Case: opcode maps directly to a WGSL intrinsic.
  const auto intrinsic = GetIntrinsic(opcode);
  if (intrinsic != sem::IntrinsicType::kNone) {
    return MakeIntrinsicCall(inst);
  }

  if (opcode == SpvOpFMod) {
    // WGSL's % is OpFRem; OpFMod needs emulation.
    return MakeFMod(inst);
  }

  if (opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain) {
    return MakeAccessChain(inst);
  }

  if (opcode == SpvOpBitcast) {
    return {ast_type,
            create<ast::BitcastExpression>(Source{}, ast_type->Build(builder_),
                                           MakeOperand(inst, 0).expr)};
  }

  if (opcode == SpvOpShiftLeftLogical || opcode == SpvOpShiftRightLogical ||
      opcode == SpvOpShiftRightArithmetic) {
    auto arg0 = MakeOperand(inst, 0);
    // The second operand must be unsigned. It's ok to wrap the shift amount
    // since the shift is modulo the bit width of the first operand.
    auto arg1 = parser_impl_.AsUnsigned(MakeOperand(inst, 1));

    // WGSL's >> is logical for unsigned and arithmetic for signed, so
    // coerce the first operand's signedness to pick the right behavior.
    switch (opcode) {
      case SpvOpShiftLeftLogical:
        binary_op = ast::BinaryOp::kShiftLeft;
        break;
      case SpvOpShiftRightLogical:
        arg0 = parser_impl_.AsUnsigned(arg0);
        binary_op = ast::BinaryOp::kShiftRight;
        break;
      case SpvOpShiftRightArithmetic:
        arg0 = parser_impl_.AsSigned(arg0);
        binary_op = ast::BinaryOp::kShiftRight;
        break;
      default:
        break;
    }
    TypedExpression result{
        ast_type, create<ast::BinaryExpression>(Source{}, binary_op, arg0.expr,
                                                arg1.expr)};
    return parser_impl_.RectifyForcedResultType(result, inst, arg0.type);
  }

  // Case: float comparisons expressed as the negation of the opposite
  // comparison, to preserve NaN-handling semantics.
  auto negated_op = NegatedFloatCompare(opcode);
  if (negated_op != ast::BinaryOp::kNone) {
    auto arg0 = MakeOperand(inst, 0);
    auto arg1 = MakeOperand(inst, 1);
    auto* binary_expr = create<ast::BinaryExpression>(Source{}, negated_op,
                                                      arg0.expr, arg1.expr);
    auto* negated_expr = create<ast::UnaryOpExpression>(
        Source{}, ast::UnaryOp::kNot, binary_expr);
    return {ast_type, negated_expr};
  }

  if (opcode == SpvOpExtInst) {
    if (parser_impl_.IsIgnoredExtendedInstruction(inst)) {
      // Ignore it but don't error out.
      return {};
    }
    if (!parser_impl_.IsGlslExtendedInstruction(inst)) {
      Fail() << "unhandled extended instruction import with ID "
             << inst.GetSingleWordInOperand(0);
      return {};
    }
    return EmitGlslStd450ExtInst(inst);
  }

  if (opcode == SpvOpCompositeConstruct) {
    ast::ExpressionList operands;
    for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) {
      operands.emplace_back(MakeOperand(inst, iarg).expr);
    }
    return {ast_type, builder_.Construct(Source{}, ast_type->Build(builder_),
                                         std::move(operands))};
  }

  if (opcode == SpvOpCompositeExtract) {
    return MakeCompositeExtract(inst);
  }

  if (opcode == SpvOpVectorShuffle) {
    return MakeVectorShuffle(inst);
  }

  if (opcode == SpvOpVectorExtractDynamic) {
    return {ast_type, create<ast::IndexAccessorExpression>(
                          Source{}, MakeOperand(inst, 0).expr,
                          MakeOperand(inst, 1).expr)};
  }

  if (opcode == SpvOpConvertSToF || opcode == SpvOpConvertUToF ||
      opcode == SpvOpConvertFToS || opcode == SpvOpConvertFToU) {
    return MakeNumericConversion(inst);
  }

  if (opcode == SpvOpUndef) {
    // Replace undef with the null value.
    return parser_impl_.MakeNullExpression(ast_type);
  }

  if (opcode == SpvOpSelect) {
    return MakeSimpleSelect(inst);
  }

  if (opcode == SpvOpArrayLength) {
    return MakeArrayLength(inst);
  }

  // builtin readonly function
  // glsl.std.450 readonly function

  // Instructions:
  //    OpSatConvertSToU       // Only in Kernel (OpenCL), not in WebGPU
  //    OpSatConvertUToS       // Only in Kernel (OpenCL), not in WebGPU
  //    OpUConvert             // Only needed when multiple widths supported
  //    OpSConvert             // Only needed when multiple widths supported
  //    OpFConvert             // Only needed when multiple widths supported
  //    OpConvertPtrToU        // Not in WebGPU
  //    OpConvertUToPtr        // Not in WebGPU
  //    OpPtrCastToGeneric     // Not in Vulkan
  //    OpGenericCastToPtr     // Not in Vulkan
  //    OpGenericCastToPtrExplicit // Not in Vulkan

  // Not combinatorial: the caller will try statement-level emission.
  return {};
}
4016
// Returns the WGSL expression for a GLSL.std.450 extended instruction.
// Most map directly to a WGSL builtin call; a few need special handling:
// Ldexp's exponent must be made signed, some builtins lack scalar forms in
// WGSL and are emulated, and Radians/Degrees have no WGSL equivalent at all
// and are polyfilled. Returns an invalid expression on failure.
TypedExpression FunctionEmitter::EmitGlslStd450ExtInst(
    const spvtools::opt::Instruction& inst) {
  // In-operand 0 is the extended instruction set ID; 1 is the opcode within
  // that set; value arguments start at in-operand 2.
  const auto ext_opcode = inst.GetSingleWordInOperand(1);

  if (ext_opcode == GLSLstd450Ldexp) {
    // WGSL requires the second argument to be signed.
    // Use a type constructor to convert it, which is the same as a bitcast.
    // If the value would go from very large positive to negative, then the
    // original result would have been infinity.  And since WGSL
    // implementations may assume that infinities are not present, then we
    // don't have to worry about that case.
    auto e1 = MakeOperand(inst, 2);
    auto e2 = ToSignedIfUnsigned(MakeOperand(inst, 3));

    return {e1.type, builder_.Call(Source{}, "ldexp",
                                   ast::ExpressionList{e1.expr, e2.expr})};
  }

  auto* result_type = parser_impl_.ConvertType(inst.type_id());

  if (result_type->IsScalar()) {
    // Some GLSLstd450 builtins have scalar forms not supported by WGSL.
    // Emulate them.
    switch (ext_opcode) {
      case GLSLstd450Normalize:
        // WGSL does not have scalar form of the normalize builtin.
        // The answer would be 1 anyway, so return that directly.
        return {ty_.F32(), builder_.Expr(1.0f)};

      case GLSLstd450FaceForward: {
        // If dot(Nref, Incident) < 0, the result is Normal, otherwise -Normal.
        // Also: select(-normal,normal, Incident*Nref < 0)
        // (The dot product of scalars is their product.)
        // Use a multiply instead of comparing floating point signs. It should
        // be among the fastest operations on a GPU.
        auto normal = MakeOperand(inst, 2);
        auto incident = MakeOperand(inst, 3);
        auto nref = MakeOperand(inst, 4);
        TINT_ASSERT(Reader, normal.type->Is<F32>());
        TINT_ASSERT(Reader, incident.type->Is<F32>());
        TINT_ASSERT(Reader, nref.type->Is<F32>());
        return {ty_.F32(),
                builder_.Call(
                    Source{}, "select",
                    ast::ExpressionList{
                        create<ast::UnaryOpExpression>(
                            Source{}, ast::UnaryOp::kNegation, normal.expr),
                        normal.expr,
                        create<ast::BinaryExpression>(
                            Source{}, ast::BinaryOp::kLessThan,
                            builder_.Mul({}, incident.expr, nref.expr),
                            builder_.Expr(0.0f))})};
      }

      case GLSLstd450Reflect: {
        // Compute  Incident - 2 * Normal * Normal * Incident
        auto incident = MakeOperand(inst, 2);
        auto normal = MakeOperand(inst, 3);
        TINT_ASSERT(Reader, incident.type->Is<F32>());
        TINT_ASSERT(Reader, normal.type->Is<F32>());
        return {
            ty_.F32(),
            builder_.Sub(
                incident.expr,
                builder_.Mul(2.0f, builder_.Mul(normal.expr,
                                                builder_.Mul(normal.expr,
                                                             incident.expr))))};
      }

      case GLSLstd450Refract: {
        // It's a complicated expression. Compute it in two dimensions, but
        // with a 0-valued y component in both the incident and normal vectors,
        // then take the x component of that result.
        auto incident = MakeOperand(inst, 2);
        auto normal = MakeOperand(inst, 3);
        auto eta = MakeOperand(inst, 4);
        TINT_ASSERT(Reader, incident.type->Is<F32>());
        TINT_ASSERT(Reader, normal.type->Is<F32>());
        TINT_ASSERT(Reader, eta.type->Is<F32>());
        if (!success()) {
          return {};
        }
        const Type* f32 = eta.type;
        return {f32,
                builder_.MemberAccessor(
                    builder_.Call(
                        Source{}, "refract",
                        ast::ExpressionList{
                            builder_.vec2<float>(incident.expr, 0.0f),
                            builder_.vec2<float>(normal.expr, 0.0f), eta.expr}),
                    "x")};
      }
      default:
        // Fall through to the generic paths below.
        break;
    }
  }

  // Some GLSLStd450 builtins don't have a WGSL equivalent. Polyfill them.
  switch (ext_opcode) {
    case GLSLstd450Radians: {
      // radians(degrees) == degrees * pi/180.
      auto degrees = MakeOperand(inst, 2);
      TINT_ASSERT(Reader, degrees.type->IsFloatScalarOrVector());

      constexpr auto kPiOver180 = static_cast<float>(3.141592653589793 / 180.0);
      auto* factor = builder_.Expr(kPiOver180);
      if (degrees.type->Is<F32>()) {
        return {degrees.type, builder_.Mul(degrees.expr, factor)};
      } else {
        // Splat the factor across the vector width.
        uint32_t size = degrees.type->As<Vector>()->size;
        return {degrees.type,
                builder_.Mul(degrees.expr,
                             builder_.vec(builder_.ty.f32(), size, factor))};
      }
    }

    case GLSLstd450Degrees: {
      // degrees(radians) == radians * 180/pi.
      auto radians = MakeOperand(inst, 2);
      TINT_ASSERT(Reader, radians.type->IsFloatScalarOrVector());

      constexpr auto k180OverPi = static_cast<float>(180.0 / 3.141592653589793);
      auto* factor = builder_.Expr(k180OverPi);
      if (radians.type->Is<F32>()) {
        return {radians.type, builder_.Mul(radians.expr, factor)};
      } else {
        // Splat the factor across the vector width.
        uint32_t size = radians.type->As<Vector>()->size;
        return {radians.type,
                builder_.Mul(radians.expr,
                             builder_.vec(builder_.ty.f32(), size, factor))};
      }
    }
  }

  // Generic case: call the WGSL builtin of the mapped name with the
  // instruction's value arguments.
  const auto name = GetGlslStd450FuncName(ext_opcode);
  if (name.empty()) {
    Fail() << "unhandled GLSL.std.450 instruction " << ext_opcode;
    return {};
  }

  auto* func = create<ast::IdentifierExpression>(
      Source{}, builder_.Symbols().Register(name));
  ast::ExpressionList operands;
  const Type* first_operand_type = nullptr;
  // All parameters to GLSL.std.450 extended instructions are IDs.
  for (uint32_t iarg = 2; iarg < inst.NumInOperands(); ++iarg) {
    TypedExpression operand = MakeOperand(inst, iarg);
    if (first_operand_type == nullptr) {
      first_operand_type = operand.type;
    }
    operands.emplace_back(operand.expr);
  }
  auto* call = create<ast::CallExpression>(Source{}, func, std::move(operands));
  TypedExpression call_expr{result_type, call};
  return parser_impl_.RectifyForcedResultType(call_expr, inst,
                                              first_operand_type);
}
4172
Swizzle(uint32_t i)4173 ast::IdentifierExpression* FunctionEmitter::Swizzle(uint32_t i) {
4174 if (i >= kMaxVectorLen) {
4175 Fail() << "vector component index is larger than " << kMaxVectorLen - 1
4176 << ": " << i;
4177 return nullptr;
4178 }
4179 const char* names[] = {"x", "y", "z", "w"};
4180 return create<ast::IdentifierExpression>(
4181 Source{}, builder_.Symbols().Register(names[i & 3]));
4182 }
4183
PrefixSwizzle(uint32_t n)4184 ast::IdentifierExpression* FunctionEmitter::PrefixSwizzle(uint32_t n) {
4185 switch (n) {
4186 case 1:
4187 return create<ast::IdentifierExpression>(
4188 Source{}, builder_.Symbols().Register("x"));
4189 case 2:
4190 return create<ast::IdentifierExpression>(
4191 Source{}, builder_.Symbols().Register("xy"));
4192 case 3:
4193 return create<ast::IdentifierExpression>(
4194 Source{}, builder_.Symbols().Register("xyz"));
4195 default:
4196 break;
4197 }
4198 Fail() << "invalid swizzle prefix count: " << n;
4199 return nullptr;
4200 }
4201
MakeFMod(const spvtools::opt::Instruction & inst)4202 TypedExpression FunctionEmitter::MakeFMod(
4203 const spvtools::opt::Instruction& inst) {
4204 auto x = MakeOperand(inst, 0);
4205 auto y = MakeOperand(inst, 1);
4206 if (!x || !y) {
4207 return {};
4208 }
4209 // Emulated with: x - y * floor(x / y)
4210 auto* div = builder_.Div(x.expr, y.expr);
4211 auto* floor = builder_.Call("floor", div);
4212 auto* y_floor = builder_.Mul(y.expr, floor);
4213 auto* res = builder_.Sub(x.expr, y_floor);
4214 return {x.type, res};
4215 }
4216
MakeAccessChain(const spvtools::opt::Instruction & inst)4217 TypedExpression FunctionEmitter::MakeAccessChain(
4218 const spvtools::opt::Instruction& inst) {
4219 if (inst.NumInOperands() < 1) {
4220 // Binary parsing will fail on this anyway.
4221 Fail() << "invalid access chain: has no input operands";
4222 return {};
4223 }
4224
4225 const auto base_id = inst.GetSingleWordInOperand(0);
4226 const auto base_skip = GetSkipReason(base_id);
4227 if (base_skip != SkipReason::kDontSkip) {
4228 // This can occur for AccessChain with no indices.
4229 GetDefInfo(inst.result_id())->skip = base_skip;
4230 GetDefInfo(inst.result_id())->sink_pointer_source_expr =
4231 GetDefInfo(base_id)->sink_pointer_source_expr;
4232 return {};
4233 }
4234
4235 auto ptr_ty_id = def_use_mgr_->GetDef(base_id)->type_id();
4236 uint32_t first_index = 1;
4237 const auto num_in_operands = inst.NumInOperands();
4238
4239 bool sink_pointer = false;
4240 TypedExpression current_expr;
4241
4242 // If the variable was originally gl_PerVertex, then in the AST we
4243 // have instead emitted a gl_Position variable.
4244 // If computing the pointer to the Position builtin, then emit the
4245 // pointer to the generated gl_Position variable.
4246 // If computing the pointer to the PointSize builtin, then mark the
4247 // result as skippable due to being the point-size pointer.
4248 // If computing the pointer to the ClipDistance or CullDistance builtins,
4249 // then error out.
4250 {
4251 const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
4252 if (base_id == builtin_position_info.per_vertex_var_id) {
4253 // We only support the Position member.
4254 const auto* member_index_inst =
4255 def_use_mgr_->GetDef(inst.GetSingleWordInOperand(first_index));
4256 if (member_index_inst == nullptr) {
4257 Fail()
4258 << "first index of access chain does not reference an instruction: "
4259 << inst.PrettyPrint();
4260 return {};
4261 }
4262 const auto* member_index_const =
4263 constant_mgr_->GetConstantFromInst(member_index_inst);
4264 if (member_index_const == nullptr) {
4265 Fail() << "first index of access chain into per-vertex structure is "
4266 "not a constant: "
4267 << inst.PrettyPrint();
4268 return {};
4269 }
4270 const auto* member_index_const_int = member_index_const->AsIntConstant();
4271 if (member_index_const_int == nullptr) {
4272 Fail() << "first index of access chain into per-vertex structure is "
4273 "not a constant integer: "
4274 << inst.PrettyPrint();
4275 return {};
4276 }
4277 const auto member_index_value =
4278 member_index_const_int->GetZeroExtendedValue();
4279 if (member_index_value != builtin_position_info.position_member_index) {
4280 if (member_index_value ==
4281 builtin_position_info.pointsize_member_index) {
4282 if (auto* def_info = GetDefInfo(inst.result_id())) {
4283 def_info->skip = SkipReason::kPointSizeBuiltinPointer;
4284 return {};
4285 }
4286 } else {
4287 // TODO(dneto): Handle ClipDistance and CullDistance
4288 Fail() << "accessing per-vertex member " << member_index_value
4289 << " is not supported. Only Position is supported, and "
4290 "PointSize is ignored";
4291 return {};
4292 }
4293 }
4294
4295 // Skip past the member index that gets us to Position.
4296 first_index = first_index + 1;
4297 // Replace the gl_PerVertex reference with the gl_Position reference
4298 ptr_ty_id = builtin_position_info.position_member_pointer_type_id;
4299
4300 auto name = namer_.Name(base_id);
4301 current_expr.expr = create<ast::IdentifierExpression>(
4302 Source{}, builder_.Symbols().Register(name));
4303 current_expr.type = parser_impl_.ConvertType(ptr_ty_id, PtrAs::Ref);
4304 }
4305 }
4306
4307 // A SPIR-V access chain is a single instruction with multiple indices
4308 // walking down into composites. The Tint AST represents this as
4309 // ever-deeper nested indexing expressions. Start off with an expression
4310 // for the base, and then bury that inside nested indexing expressions.
4311 if (!current_expr) {
4312 current_expr = InferFunctionStorageClass(MakeOperand(inst, 0));
4313 if (current_expr.type->Is<Pointer>()) {
4314 current_expr = Dereference(current_expr);
4315 }
4316 }
4317 const auto constants = constant_mgr_->GetOperandConstants(&inst);
4318
4319 const auto* ptr_type_inst = def_use_mgr_->GetDef(ptr_ty_id);
4320 if (!ptr_type_inst || (ptr_type_inst->opcode() != SpvOpTypePointer)) {
4321 Fail() << "Access chain %" << inst.result_id()
4322 << " base pointer is not of pointer type";
4323 return {};
4324 }
4325 SpvStorageClass storage_class =
4326 static_cast<SpvStorageClass>(ptr_type_inst->GetSingleWordInOperand(0));
4327 uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1);
4328
4329 // Build up a nested expression for the access chain by walking down the type
4330 // hierarchy, maintaining |pointee_type_id| as the SPIR-V ID of the type of
4331 // the object pointed to after processing the previous indices.
4332 for (uint32_t index = first_index; index < num_in_operands; ++index) {
4333 const auto* index_const =
4334 constants[index] ? constants[index]->AsIntConstant() : nullptr;
4335 const int64_t index_const_val =
4336 index_const ? index_const->GetSignExtendedValue() : 0;
4337 const ast::Expression* next_expr = nullptr;
4338
4339 const auto* pointee_type_inst = def_use_mgr_->GetDef(pointee_type_id);
4340 if (!pointee_type_inst) {
4341 Fail() << "pointee type %" << pointee_type_id
4342 << " is invalid after following " << (index - first_index)
4343 << " indices: " << inst.PrettyPrint();
4344 return {};
4345 }
4346 switch (pointee_type_inst->opcode()) {
4347 case SpvOpTypeVector:
4348 if (index_const) {
4349 // Try generating a MemberAccessor expression
4350 const auto num_elems = pointee_type_inst->GetSingleWordInOperand(1);
4351 if (index_const_val < 0 || num_elems <= index_const_val) {
4352 Fail() << "Access chain %" << inst.result_id() << " index %"
4353 << inst.GetSingleWordInOperand(index) << " value "
4354 << index_const_val << " is out of bounds for vector of "
4355 << num_elems << " elements";
4356 return {};
4357 }
4358 if (uint64_t(index_const_val) >= kMaxVectorLen) {
4359 Fail() << "internal error: swizzle index " << index_const_val
4360 << " is too big. Max handled index is " << kMaxVectorLen - 1;
4361 }
4362 next_expr = create<ast::MemberAccessorExpression>(
4363 Source{}, current_expr.expr, Swizzle(uint32_t(index_const_val)));
4364 } else {
4365 // Non-constant index. Use array syntax
4366 next_expr = create<ast::IndexAccessorExpression>(
4367 Source{}, current_expr.expr, MakeOperand(inst, index).expr);
4368 }
4369 // All vector components are the same type.
4370 pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
4371 // Sink pointers to vector components.
4372 sink_pointer = true;
4373 break;
4374 case SpvOpTypeMatrix:
4375 // Use array syntax.
4376 next_expr = create<ast::IndexAccessorExpression>(
4377 Source{}, current_expr.expr, MakeOperand(inst, index).expr);
4378 // All matrix components are the same type.
4379 pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
4380 break;
4381 case SpvOpTypeArray:
4382 next_expr = create<ast::IndexAccessorExpression>(
4383 Source{}, current_expr.expr, MakeOperand(inst, index).expr);
4384 pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
4385 break;
4386 case SpvOpTypeRuntimeArray:
4387 next_expr = create<ast::IndexAccessorExpression>(
4388 Source{}, current_expr.expr, MakeOperand(inst, index).expr);
4389 pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
4390 break;
4391 case SpvOpTypeStruct: {
4392 if (!index_const) {
4393 Fail() << "Access chain %" << inst.result_id() << " index %"
4394 << inst.GetSingleWordInOperand(index)
4395 << " is a non-constant index into a structure %"
4396 << pointee_type_id;
4397 return {};
4398 }
4399 const auto num_members = pointee_type_inst->NumInOperands();
4400 if ((index_const_val < 0) || num_members <= uint64_t(index_const_val)) {
4401 Fail() << "Access chain %" << inst.result_id() << " index value "
4402 << index_const_val << " is out of bounds for structure %"
4403 << pointee_type_id << " having " << num_members << " members";
4404 return {};
4405 }
4406 auto name =
4407 namer_.GetMemberName(pointee_type_id, uint32_t(index_const_val));
4408 auto* member_access = create<ast::IdentifierExpression>(
4409 Source{}, builder_.Symbols().Register(name));
4410
4411 next_expr = create<ast::MemberAccessorExpression>(
4412 Source{}, current_expr.expr, member_access);
4413 pointee_type_id = pointee_type_inst->GetSingleWordInOperand(
4414 static_cast<uint32_t>(index_const_val));
4415 break;
4416 }
4417 default:
4418 Fail() << "Access chain with unknown or invalid pointee type %"
4419 << pointee_type_id << ": " << pointee_type_inst->PrettyPrint();
4420 return {};
4421 }
4422 const auto pointer_type_id =
4423 type_mgr_->FindPointerToType(pointee_type_id, storage_class);
4424 auto* type = parser_impl_.ConvertType(pointer_type_id, PtrAs::Ref);
4425 TINT_ASSERT(Reader, type && type->Is<Reference>());
4426 current_expr = TypedExpression{type, next_expr};
4427 }
4428
4429 if (sink_pointer) {
4430 // Capture the reference so that we can sink it into the point of use.
4431 GetDefInfo(inst.result_id())->skip = SkipReason::kSinkPointerIntoUse;
4432 GetDefInfo(inst.result_id())->sink_pointer_source_expr = current_expr;
4433 }
4434
4435 return current_expr;
4436 }
4437
MakeCompositeExtract(const spvtools::opt::Instruction & inst)4438 TypedExpression FunctionEmitter::MakeCompositeExtract(
4439 const spvtools::opt::Instruction& inst) {
4440 // This is structurally similar to creating an access chain, but
4441 // the SPIR-V instruction has literal indices instead of IDs for indices.
4442
4443 auto composite_index = 0;
4444 auto first_index_position = 1;
4445 TypedExpression current_expr(MakeOperand(inst, composite_index));
4446 if (!current_expr) {
4447 return {};
4448 }
4449
4450 const auto composite_id = inst.GetSingleWordInOperand(composite_index);
4451 auto current_type_id = def_use_mgr_->GetDef(composite_id)->type_id();
4452
4453 return MakeCompositeValueDecomposition(inst, current_expr, current_type_id,
4454 first_index_position);
4455 }
4456
MakeCompositeValueDecomposition(const spvtools::opt::Instruction & inst,TypedExpression composite,uint32_t composite_type_id,int index_start)4457 TypedExpression FunctionEmitter::MakeCompositeValueDecomposition(
4458 const spvtools::opt::Instruction& inst,
4459 TypedExpression composite,
4460 uint32_t composite_type_id,
4461 int index_start) {
4462 // This is structurally similar to creating an access chain, but
4463 // the SPIR-V instruction has literal indices instead of IDs for indices.
4464
4465 // A SPIR-V composite extract is a single instruction with multiple
4466 // literal indices walking down into composites.
4467 // A SPIR-V composite insert is similar but also tells you what component
4468 // to inject. This function is responsible for the the walking-into part
4469 // of composite-insert.
4470 //
4471 // The Tint AST represents this as ever-deeper nested indexing expressions.
4472 // Start off with an expression for the composite, and then bury that inside
4473 // nested indexing expressions.
4474
4475 auto current_expr = composite;
4476 auto current_type_id = composite_type_id;
4477
4478 auto make_index = [this](uint32_t literal) {
4479 return create<ast::UintLiteralExpression>(Source{}, literal);
4480 };
4481
4482 // Build up a nested expression for the decomposition by walking down the type
4483 // hierarchy, maintaining |current_type_id| as the SPIR-V ID of the type of
4484 // the object pointed to after processing the previous indices.
4485 const auto num_in_operands = inst.NumInOperands();
4486 for (uint32_t index = index_start; index < num_in_operands; ++index) {
4487 const uint32_t index_val = inst.GetSingleWordInOperand(index);
4488
4489 const auto* current_type_inst = def_use_mgr_->GetDef(current_type_id);
4490 if (!current_type_inst) {
4491 Fail() << "composite type %" << current_type_id
4492 << " is invalid after following " << (index - index_start)
4493 << " indices: " << inst.PrettyPrint();
4494 return {};
4495 }
4496 const char* operation_name = nullptr;
4497 switch (inst.opcode()) {
4498 case SpvOpCompositeExtract:
4499 operation_name = "OpCompositeExtract";
4500 break;
4501 case SpvOpCompositeInsert:
4502 operation_name = "OpCompositeInsert";
4503 break;
4504 default:
4505 Fail() << "internal error: unhandled " << inst.PrettyPrint();
4506 return {};
4507 }
4508 const ast::Expression* next_expr = nullptr;
4509 switch (current_type_inst->opcode()) {
4510 case SpvOpTypeVector: {
4511 // Try generating a MemberAccessor expression. That result in something
4512 // like "foo.z", which is more idiomatic than "foo[2]".
4513 const auto num_elems = current_type_inst->GetSingleWordInOperand(1);
4514 if (num_elems <= index_val) {
4515 Fail() << operation_name << " %" << inst.result_id()
4516 << " index value " << index_val
4517 << " is out of bounds for vector of " << num_elems
4518 << " elements";
4519 return {};
4520 }
4521 if (index_val >= kMaxVectorLen) {
4522 Fail() << "internal error: swizzle index " << index_val
4523 << " is too big. Max handled index is " << kMaxVectorLen - 1;
4524 return {};
4525 }
4526 next_expr = create<ast::MemberAccessorExpression>(
4527 Source{}, current_expr.expr, Swizzle(index_val));
4528 // All vector components are the same type.
4529 current_type_id = current_type_inst->GetSingleWordInOperand(0);
4530 break;
4531 }
4532 case SpvOpTypeMatrix: {
4533 // Check bounds
4534 const auto num_elems = current_type_inst->GetSingleWordInOperand(1);
4535 if (num_elems <= index_val) {
4536 Fail() << operation_name << " %" << inst.result_id()
4537 << " index value " << index_val
4538 << " is out of bounds for matrix of " << num_elems
4539 << " elements";
4540 return {};
4541 }
4542 if (index_val >= kMaxVectorLen) {
4543 Fail() << "internal error: swizzle index " << index_val
4544 << " is too big. Max handled index is " << kMaxVectorLen - 1;
4545 }
4546 // Use array syntax.
4547 next_expr = create<ast::IndexAccessorExpression>(
4548 Source{}, current_expr.expr, make_index(index_val));
4549 // All matrix components are the same type.
4550 current_type_id = current_type_inst->GetSingleWordInOperand(0);
4551 break;
4552 }
4553 case SpvOpTypeArray:
4554 // The array size could be a spec constant, and so it's not always
4555 // statically checkable. Instead, rely on a runtime index clamp
4556 // or runtime check to keep this safe.
4557 next_expr = create<ast::IndexAccessorExpression>(
4558 Source{}, current_expr.expr, make_index(index_val));
4559 current_type_id = current_type_inst->GetSingleWordInOperand(0);
4560 break;
4561 case SpvOpTypeRuntimeArray:
4562 Fail() << "can't do " << operation_name
4563 << " on a runtime array: " << inst.PrettyPrint();
4564 return {};
4565 case SpvOpTypeStruct: {
4566 const auto num_members = current_type_inst->NumInOperands();
4567 if (num_members <= index_val) {
4568 Fail() << operation_name << " %" << inst.result_id()
4569 << " index value " << index_val
4570 << " is out of bounds for structure %" << current_type_id
4571 << " having " << num_members << " members";
4572 return {};
4573 }
4574 auto name = namer_.GetMemberName(current_type_id, uint32_t(index_val));
4575 auto* member_access = create<ast::IdentifierExpression>(
4576 Source{}, builder_.Symbols().Register(name));
4577
4578 next_expr = create<ast::MemberAccessorExpression>(
4579 Source{}, current_expr.expr, member_access);
4580 current_type_id = current_type_inst->GetSingleWordInOperand(index_val);
4581 break;
4582 }
4583 default:
4584 Fail() << operation_name << " with bad type %" << current_type_id
4585 << ": " << current_type_inst->PrettyPrint();
4586 return {};
4587 }
4588 current_expr =
4589 TypedExpression{parser_impl_.ConvertType(current_type_id), next_expr};
4590 }
4591 return current_expr;
4592 }
4593
const ast::Expression* FunctionEmitter::MakeTrue(const Source& source) const {
  // Returns an AST expression for the boolean literal `true`, attributed
  // to the given source location.
  return create<ast::BoolLiteralExpression>(source, true);
}
4597
const ast::Expression* FunctionEmitter::MakeFalse(const Source& source) const {
  // Returns an AST expression for the boolean literal `false`, attributed
  // to the given source location.
  return create<ast::BoolLiteralExpression>(source, false);
}
4601
MakeVectorShuffle(const spvtools::opt::Instruction & inst)4602 TypedExpression FunctionEmitter::MakeVectorShuffle(
4603 const spvtools::opt::Instruction& inst) {
4604 const auto vec0_id = inst.GetSingleWordInOperand(0);
4605 const auto vec1_id = inst.GetSingleWordInOperand(1);
4606 const spvtools::opt::Instruction& vec0 = *(def_use_mgr_->GetDef(vec0_id));
4607 const spvtools::opt::Instruction& vec1 = *(def_use_mgr_->GetDef(vec1_id));
4608 const auto vec0_len =
4609 type_mgr_->GetType(vec0.type_id())->AsVector()->element_count();
4610 const auto vec1_len =
4611 type_mgr_->GetType(vec1.type_id())->AsVector()->element_count();
4612
4613 // Idiomatic vector accessors.
4614
4615 // Generate an ast::TypeConstructor expression.
4616 // Assume the literal indices are valid, and there is a valid number of them.
4617 auto source = GetSourceForInst(inst);
4618 const Vector* result_type =
4619 As<Vector>(parser_impl_.ConvertType(inst.type_id()));
4620 ast::ExpressionList values;
4621 for (uint32_t i = 2; i < inst.NumInOperands(); ++i) {
4622 const auto index = inst.GetSingleWordInOperand(i);
4623 if (index < vec0_len) {
4624 auto expr = MakeExpression(vec0_id);
4625 if (!expr) {
4626 return {};
4627 }
4628 values.emplace_back(create<ast::MemberAccessorExpression>(
4629 source, expr.expr, Swizzle(index)));
4630 } else if (index < vec0_len + vec1_len) {
4631 const auto sub_index = index - vec0_len;
4632 TINT_ASSERT(Reader, sub_index < kMaxVectorLen);
4633 auto expr = MakeExpression(vec1_id);
4634 if (!expr) {
4635 return {};
4636 }
4637 values.emplace_back(create<ast::MemberAccessorExpression>(
4638 source, expr.expr, Swizzle(sub_index)));
4639 } else if (index == 0xFFFFFFFF) {
4640 // By rule, this maps to OpUndef. Instead, make it zero.
4641 values.emplace_back(parser_impl_.MakeNullValue(result_type->type));
4642 } else {
4643 Fail() << "invalid vectorshuffle ID %" << inst.result_id()
4644 << ": index too large: " << index;
4645 return {};
4646 }
4647 }
4648 return {result_type,
4649 builder_.Construct(source, result_type->Build(builder_), values)};
4650 }
4651
RegisterSpecialBuiltInVariables()4652 bool FunctionEmitter::RegisterSpecialBuiltInVariables() {
4653 size_t index = def_info_.size();
4654 for (auto& special_var : parser_impl_.special_builtins()) {
4655 const auto id = special_var.first;
4656 const auto builtin = special_var.second;
4657 const auto* var = def_use_mgr_->GetDef(id);
4658 def_info_[id] = std::make_unique<DefInfo>(*var, 0, index);
4659 ++index;
4660 auto& def = def_info_[id];
4661 switch (builtin) {
4662 case SpvBuiltInPointSize:
4663 def->skip = SkipReason::kPointSizeBuiltinPointer;
4664 break;
4665 case SpvBuiltInSampleMask: {
4666 // Distinguish between input and output variable.
4667 const auto storage_class =
4668 static_cast<SpvStorageClass>(var->GetSingleWordInOperand(0));
4669 if (storage_class == SpvStorageClassInput) {
4670 sample_mask_in_id = id;
4671 def->skip = SkipReason::kSampleMaskInBuiltinPointer;
4672 } else {
4673 sample_mask_out_id = id;
4674 def->skip = SkipReason::kSampleMaskOutBuiltinPointer;
4675 }
4676 break;
4677 }
4678 case SpvBuiltInSampleId:
4679 case SpvBuiltInInstanceIndex:
4680 case SpvBuiltInVertexIndex:
4681 case SpvBuiltInLocalInvocationIndex:
4682 case SpvBuiltInLocalInvocationId:
4683 case SpvBuiltInGlobalInvocationId:
4684 case SpvBuiltInWorkgroupId:
4685 case SpvBuiltInNumWorkgroups:
4686 break;
4687 default:
4688 return Fail() << "unrecognized special builtin: " << int(builtin);
4689 }
4690 }
4691 return true;
4692 }
4693
bool FunctionEmitter::RegisterLocallyDefinedValues() {
  // Create a DefInfo for each value definition in this function.
  // Returns false (and emits a diagnostic) on an invalid pointer definition.
  size_t index = def_info_.size();
  for (auto block_id : block_order_) {
    const auto* block_info = GetBlockInfo(block_id);
    const auto block_pos = block_info->pos;
    for (const auto& inst : *(block_info->basic_block)) {
      const auto result_id = inst.result_id();
      if ((result_id == 0) || inst.opcode() == SpvOpLabel) {
        // Only instructions that produce a value get a DefInfo;
        // labels are tracked by block bookkeeping instead.
        continue;
      }
      def_info_[result_id] = std::make_unique<DefInfo>(inst, block_pos, index);
      ++index;
      auto& info = def_info_[result_id];

      // Determine storage class for pointer values. Do this in order because
      // we might rely on the storage class for a previously-visited definition.
      // Logical pointers can't be transmitted through OpPhi, so remaining
      // pointer definitions are SSA values, and their definitions must be
      // visited before their uses.
      const auto* type = type_mgr_->GetType(inst.type_id());
      if (type) {
        if (type->AsPointer()) {
          // Default the storage class from the AST form of the result type,
          // when it converts to an AST pointer.
          if (auto* ast_type = parser_impl_.ConvertType(inst.type_id())) {
            if (auto* ptr = ast_type->As<Pointer>()) {
              info->storage_class = ptr->storage_class;
            }
          }
          switch (inst.opcode()) {
            case SpvOpUndef:
              return Fail()
                     << "undef pointer is not valid: " << inst.PrettyPrint();
            case SpvOpVariable:
              // Keep the default decision based on the result type.
              break;
            case SpvOpAccessChain:
            case SpvOpInBoundsAccessChain:
            case SpvOpCopyObject:
              // Inherit from the first operand. We need this so we can pick up
              // a remapped storage buffer.
              info->storage_class = GetStorageClassForPointerValue(
                  inst.GetSingleWordInOperand(0));
              break;
            default:
              return Fail()
                     << "pointer defined in function from unknown opcode: "
                     << inst.PrettyPrint();
          }
        }
        // Strip pointer layers to inspect the underlying object type.
        auto* unwrapped = type;
        while (auto* ptr = unwrapped->AsPointer()) {
          unwrapped = ptr->pointee_type();
        }
        if (unwrapped->AsSampler() || unwrapped->AsImage() ||
            unwrapped->AsSampledImage()) {
          // Defer code generation until the instruction that actually acts on
          // the image.
          info->skip = SkipReason::kOpaqueObject;
        }
      }
    }
  }
  return true;
}
4758
GetStorageClassForPointerValue(uint32_t id)4759 ast::StorageClass FunctionEmitter::GetStorageClassForPointerValue(uint32_t id) {
4760 auto where = def_info_.find(id);
4761 if (where != def_info_.end()) {
4762 auto candidate = where->second.get()->storage_class;
4763 if (candidate != ast::StorageClass::kInvalid) {
4764 return candidate;
4765 }
4766 }
4767 const auto type_id = def_use_mgr_->GetDef(id)->type_id();
4768 if (type_id) {
4769 auto* ast_type = parser_impl_.ConvertType(type_id);
4770 if (auto* ptr = As<Pointer>(ast_type)) {
4771 return ptr->storage_class;
4772 }
4773 }
4774 return ast::StorageClass::kInvalid;
4775 }
4776
const Type* FunctionEmitter::RemapStorageClass(const Type* type,
                                               uint32_t result_id) {
  // Only pointer types are subject to remapping.
  auto* ptr = As<Pointer>(type);
  if (ptr == nullptr) {
    return type;
  }
  // Remap an old-style storage buffer pointer to a new-style storage
  // buffer pointer, using the storage class computed for this value.
  const auto remapped_sc = GetStorageClassForPointerValue(result_id);
  if (ptr->storage_class == remapped_sc) {
    return type;
  }
  return ty_.Pointer(ptr->type, remapped_sc);
}
4789
void FunctionEmitter::FindValuesNeedingNamedOrHoistedDefinition() {
  // Decides, for each locally-defined value, whether it needs a named
  // 'const' definition or a hoisted variable definition, and records
  // phi bookkeeping (state variables and predecessor assignments).

  // Mark vector operands of OpVectorShuffle as needing a named definition,
  // but only if they are defined in this function as well.
  auto require_named_const_def = [&](const spvtools::opt::Instruction& inst,
                                     int in_operand_index) {
    const auto id = inst.GetSingleWordInOperand(in_operand_index);
    auto* const operand_def = GetDefInfo(id);
    if (operand_def) {
      operand_def->requires_named_const_def = true;
    }
  };
  for (auto& id_def_info_pair : def_info_) {
    const auto& inst = id_def_info_pair.second->inst;
    const auto opcode = inst.opcode();
    if ((opcode == SpvOpVectorShuffle) || (opcode == SpvOpOuterProduct)) {
      // We might access the vector operands multiple times. Make sure they
      // are evaluated only once.
      require_named_const_def(inst, 0);
      require_named_const_def(inst, 1);
    }
    if (parser_impl_.IsGlslExtendedInstruction(inst)) {
      // Some emulations of GLSLstd450 instructions evaluate certain operands
      // multiple times. Ensure their expressions are evaluated only once.
      switch (inst.GetSingleWordInOperand(1)) {
        case GLSLstd450FaceForward:
          // The "normal" operand expression is used twice in code generation.
          require_named_const_def(inst, 2);
          break;
        case GLSLstd450Reflect:
          require_named_const_def(inst, 2);  // Incident
          require_named_const_def(inst, 3);  // Normal
          break;
        default:
          break;
      }
    }
  }

  // Scan uses of locally defined IDs, in function block order.
  for (auto block_id : block_order_) {
    const auto* block_info = GetBlockInfo(block_id);
    const auto block_pos = block_info->pos;
    for (const auto& inst : *(block_info->basic_block)) {
      // Update bookkeeping for locally-defined IDs used by this instruction.
      inst.ForEachInId([this, block_pos, block_info](const uint32_t* id_ptr) {
        auto* def_info = GetDefInfo(*id_ptr);
        if (def_info) {
          // Update usage count.
          def_info->num_uses++;
          // Update usage span.
          def_info->last_use_pos = std::max(def_info->last_use_pos, block_pos);

          // Determine whether this ID is defined in a different construct
          // from this use.
          const auto defining_block = block_order_[def_info->block_pos];
          const auto* def_in_construct =
              GetBlockInfo(defining_block)->construct;
          if (def_in_construct != block_info->construct) {
            def_info->used_in_another_construct = true;
          }
        }
      });

      if (inst.opcode() == SpvOpPhi) {
        // Declare a name for the variable used to carry values to a phi.
        const auto phi_id = inst.result_id();
        auto* phi_def_info = GetDefInfo(phi_id);
        phi_def_info->phi_var =
            namer_.MakeDerivedName(namer_.Name(phi_id) + "_phi");
        // Track all the places where we need to mention the variable,
        // so we can place its declaration. First, record the location of
        // the read from the variable.
        uint32_t first_pos = block_pos;
        uint32_t last_pos = block_pos;
        // Record the assignments that will propagate values from predecessor
        // blocks. Phi operands come in (value, predecessor-block) pairs.
        for (uint32_t i = 0; i + 1 < inst.NumInOperands(); i += 2) {
          const uint32_t value_id = inst.GetSingleWordInOperand(i);
          const uint32_t pred_block_id = inst.GetSingleWordInOperand(i + 1);
          auto* pred_block_info = GetBlockInfo(pred_block_id);
          // The predecessor might not be in the block order at all, so we
          // need this guard.
          if (IsInBlockOrder(pred_block_info)) {
            // Record the assignment that needs to occur at the end
            // of the predecessor block.
            pred_block_info->phi_assignments.push_back({phi_id, value_id});
            first_pos = std::min(first_pos, pred_block_info->pos);
            last_pos = std::max(last_pos, pred_block_info->pos);
          }
        }

        // Schedule the declaration of the state variable in the construct
        // that encloses both the phi and all contributing predecessors.
        const auto* enclosing_construct =
            GetEnclosingScope(first_pos, last_pos);
        GetBlockInfo(enclosing_construct->begin_id)
            ->phis_needing_state_vars.push_back(phi_id);
      }
    }
  }

  // For an ID defined in this function, determine if its evaluation and
  // potential declaration needs special handling:
  // - Compensate for the fact that dominance does not map directly to scope.
  //   A definition could dominate its use, but a named definition in WGSL
  //   at the location of the definition could go out of scope by the time
  //   you reach the use. In that case, we hoist the definition to a basic
  //   block at the smallest scope enclosing both the definition and all
  //   its uses.
  // - If value is used in a different construct than its definition, then it
  //   needs a named constant definition. Otherwise we might sink an
  //   expensive computation into control flow, and hence change performance.
  for (auto& id_def_info_pair : def_info_) {
    const auto def_id = id_def_info_pair.first;
    auto* def_info = id_def_info_pair.second.get();
    if (def_info->num_uses == 0) {
      // There is no need to adjust the location of the declaration.
      continue;
    }
    // The first use must be at the SSA definition, because block order
    // respects dominance.
    const auto first_pos = def_info->block_pos;
    const auto last_use_pos = def_info->last_use_pos;

    const auto* def_in_construct =
        GetBlockInfo(block_order_[first_pos])->construct;
    // A definition in the first block of an kIfSelection or kSwitchSelection
    // occurs before the branch, and so that definition should count as
    // having been defined at the scope of the parent construct.
    if (first_pos == def_in_construct->begin_pos) {
      if ((def_in_construct->kind == Construct::kIfSelection) ||
          (def_in_construct->kind == Construct::kSwitchSelection)) {
        def_in_construct = def_in_construct->parent;
      }
    }

    bool should_hoist = false;
    if (!def_in_construct->ContainsPos(last_use_pos)) {
      // To satisfy scoping, we have to hoist the definition out to an enclosing
      // construct.
      should_hoist = true;
    } else {
      // Avoid moving combinatorial values across constructs. This is a
      // simple heuristic to avoid changing the cost of an operation
      // by moving it into or out of a loop, for example.
      if ((def_info->storage_class == ast::StorageClass::kInvalid) &&
          def_info->used_in_another_construct) {
        should_hoist = true;
      }
    }

    if (should_hoist) {
      const auto* enclosing_construct =
          GetEnclosingScope(first_pos, last_use_pos);
      if (enclosing_construct == def_in_construct) {
        // We can use a plain 'const' definition.
        def_info->requires_named_const_def = true;
      } else {
        // We need to make a hoisted variable definition.
        // TODO(dneto): Handle non-storable types, particularly pointers.
        def_info->requires_hoisted_def = true;
        auto* hoist_to_block = GetBlockInfo(enclosing_construct->begin_id);
        hoist_to_block->hoisted_ids.push_back(def_id);
      }
    }
  }
}
4956
GetEnclosingScope(uint32_t first_pos,uint32_t last_pos) const4957 const Construct* FunctionEmitter::GetEnclosingScope(uint32_t first_pos,
4958 uint32_t last_pos) const {
4959 const auto* enclosing_construct =
4960 GetBlockInfo(block_order_[first_pos])->construct;
4961 TINT_ASSERT(Reader, enclosing_construct != nullptr);
4962 // Constructs are strictly nesting, so follow parent pointers
4963 while (enclosing_construct &&
4964 !enclosing_construct->ScopeContainsPos(last_pos)) {
4965 // The scope of a continue construct is enclosed in its associated loop
4966 // construct, but they are siblings in our construct tree.
4967 const auto* sibling_loop = SiblingLoopConstruct(enclosing_construct);
4968 // Go to the sibling loop if it exists, otherwise walk up to the parent.
4969 enclosing_construct =
4970 sibling_loop ? sibling_loop : enclosing_construct->parent;
4971 }
4972 // At worst, we go all the way out to the function construct.
4973 TINT_ASSERT(Reader, enclosing_construct != nullptr);
4974 return enclosing_construct;
4975 }
4976
MakeNumericConversion(const spvtools::opt::Instruction & inst)4977 TypedExpression FunctionEmitter::MakeNumericConversion(
4978 const spvtools::opt::Instruction& inst) {
4979 const auto opcode = inst.opcode();
4980 auto* requested_type = parser_impl_.ConvertType(inst.type_id());
4981 auto arg_expr = MakeOperand(inst, 0);
4982 if (!arg_expr) {
4983 return {};
4984 }
4985 arg_expr.type = arg_expr.type->UnwrapRef();
4986
4987 const Type* expr_type = nullptr;
4988 if ((opcode == SpvOpConvertSToF) || (opcode == SpvOpConvertUToF)) {
4989 if (arg_expr.type->IsIntegerScalarOrVector()) {
4990 expr_type = requested_type;
4991 } else {
4992 Fail() << "operand for conversion to floating point must be integral "
4993 "scalar or vector: "
4994 << inst.PrettyPrint();
4995 }
4996 } else if (inst.opcode() == SpvOpConvertFToU) {
4997 if (arg_expr.type->IsFloatScalarOrVector()) {
4998 expr_type = parser_impl_.GetUnsignedIntMatchingShape(arg_expr.type);
4999 } else {
5000 Fail() << "operand for conversion to unsigned integer must be floating "
5001 "point scalar or vector: "
5002 << inst.PrettyPrint();
5003 }
5004 } else if (inst.opcode() == SpvOpConvertFToS) {
5005 if (arg_expr.type->IsFloatScalarOrVector()) {
5006 expr_type = parser_impl_.GetSignedIntMatchingShape(arg_expr.type);
5007 } else {
5008 Fail() << "operand for conversion to signed integer must be floating "
5009 "point scalar or vector: "
5010 << inst.PrettyPrint();
5011 }
5012 }
5013 if (expr_type == nullptr) {
5014 // The diagnostic has already been emitted.
5015 return {};
5016 }
5017
5018 ast::ExpressionList params;
5019 params.push_back(arg_expr.expr);
5020 TypedExpression result{
5021 expr_type,
5022 builder_.Construct(GetSourceForInst(inst), expr_type->Build(builder_),
5023 std::move(params))};
5024
5025 if (requested_type == expr_type) {
5026 return result;
5027 }
5028 return {requested_type, create<ast::BitcastExpression>(
5029 GetSourceForInst(inst),
5030 requested_type->Build(builder_), result.expr)};
5031 }
5032
bool FunctionEmitter::EmitFunctionCall(const spvtools::opt::Instruction& inst) {
  // Emits an AST call for this instruction: a call statement when the
  // result type is void, otherwise a const definition (or a write to the
  // hoisted variable) holding the call expression. Returns false on error.
  // We ignore function attributes such as Inline, DontInline, Pure, Const.
  auto name = namer_.Name(inst.GetSingleWordInOperand(0));
  auto* function = create<ast::IdentifierExpression>(
      Source{}, builder_.Symbols().Register(name));

  ast::ExpressionList args;
  for (uint32_t iarg = 1; iarg < inst.NumInOperands(); ++iarg) {
    auto expr = MakeOperand(inst, iarg);
    if (!expr) {
      return false;
    }
    // Functions cannot use references as parameters, so we need to pass by
    // pointer if the operand is of pointer type.
    expr = AddressOfIfNeeded(
        expr, def_use_mgr_->GetDef(inst.GetSingleWordInOperand(iarg)));
    args.emplace_back(expr.expr);
  }
  if (failed()) {
    return false;
  }
  auto* call_expr =
      create<ast::CallExpression>(Source{}, function, std::move(args));
  auto* result_type = parser_impl_.ConvertType(inst.type_id());
  if (!result_type) {
    return Fail() << "internal error: no mapped type result of call: "
                  << inst.PrettyPrint();
  }

  if (result_type->Is<Void>()) {
    // No result value to name: emit the call as a standalone statement.
    return nullptr !=
           AddStatement(create<ast::CallStatement>(Source{}, call_expr));
  }

  return EmitConstDefOrWriteToHoistedVar(inst, {result_type, call_expr});
}
5069
EmitControlBarrier(const spvtools::opt::Instruction & inst)5070 bool FunctionEmitter::EmitControlBarrier(
5071 const spvtools::opt::Instruction& inst) {
5072 uint32_t operands[3];
5073 for (int i = 0; i < 3; i++) {
5074 auto id = inst.GetSingleWordInOperand(i);
5075 if (auto* constant = constant_mgr_->FindDeclaredConstant(id)) {
5076 operands[i] = constant->GetU32();
5077 } else {
5078 return Fail() << "invalid or missing operands for control barrier";
5079 }
5080 }
5081
5082 uint32_t execution = operands[0];
5083 uint32_t memory = operands[1];
5084 uint32_t semantics = operands[2];
5085
5086 if (execution != SpvScopeWorkgroup) {
5087 return Fail() << "unsupported control barrier execution scope: "
5088 << "expected Workgroup (2), got: " << execution;
5089 }
5090 if (semantics & SpvMemorySemanticsAcquireReleaseMask) {
5091 semantics &= ~SpvMemorySemanticsAcquireReleaseMask;
5092 } else {
5093 return Fail() << "control barrier semantics requires acquire and release";
5094 }
5095 if (semantics & SpvMemorySemanticsWorkgroupMemoryMask) {
5096 if (memory != SpvScopeWorkgroup) {
5097 return Fail() << "workgroupBarrier requires workgroup memory scope";
5098 }
5099 AddStatement(create<ast::CallStatement>(builder_.Call("workgroupBarrier")));
5100 semantics &= ~SpvMemorySemanticsWorkgroupMemoryMask;
5101 }
5102 if (semantics & SpvMemorySemanticsUniformMemoryMask) {
5103 if (memory != SpvScopeDevice) {
5104 return Fail() << "storageBarrier requires device memory scope";
5105 }
5106 AddStatement(create<ast::CallStatement>(builder_.Call("storageBarrier")));
5107 semantics &= ~SpvMemorySemanticsUniformMemoryMask;
5108 }
5109 if (semantics) {
5110 return Fail() << "unsupported control barrier semantics: " << semantics;
5111 }
5112 return true;
5113 }
5114
MakeIntrinsicCall(const spvtools::opt::Instruction & inst)5115 TypedExpression FunctionEmitter::MakeIntrinsicCall(
5116 const spvtools::opt::Instruction& inst) {
5117 const auto intrinsic = GetIntrinsic(inst.opcode());
5118 auto* name = sem::str(intrinsic);
5119 auto* ident = create<ast::IdentifierExpression>(
5120 Source{}, builder_.Symbols().Register(name));
5121
5122 ast::ExpressionList params;
5123 const Type* first_operand_type = nullptr;
5124 for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) {
5125 TypedExpression operand = MakeOperand(inst, iarg);
5126 if (first_operand_type == nullptr) {
5127 first_operand_type = operand.type;
5128 }
5129 params.emplace_back(operand.expr);
5130 }
5131 auto* call_expr =
5132 create<ast::CallExpression>(Source{}, ident, std::move(params));
5133 auto* result_type = parser_impl_.ConvertType(inst.type_id());
5134 if (!result_type) {
5135 Fail() << "internal error: no mapped type result of call: "
5136 << inst.PrettyPrint();
5137 return {};
5138 }
5139 TypedExpression call{result_type, call_expr};
5140 return parser_impl_.RectifyForcedResultType(call, inst, first_operand_type);
5141 }
5142
MakeSimpleSelect(const spvtools::opt::Instruction & inst)5143 TypedExpression FunctionEmitter::MakeSimpleSelect(
5144 const spvtools::opt::Instruction& inst) {
5145 auto condition = MakeOperand(inst, 0);
5146 auto true_value = MakeOperand(inst, 1);
5147 auto false_value = MakeOperand(inst, 2);
5148
5149 // SPIR-V validation requires:
5150 // - the condition to be bool or bool vector, so we don't check it here.
5151 // - true_value false_value, and result type to match.
5152 // - you can't select over pointers or pointer vectors, unless you also have
5153 // a VariablePointers* capability, which is not allowed in by WebGPU.
5154 auto* op_ty = true_value.type;
5155 if (op_ty->Is<Vector>() || op_ty->IsFloatScalar() ||
5156 op_ty->IsIntegerScalar() || op_ty->Is<Bool>()) {
5157 ast::ExpressionList params;
5158 params.push_back(false_value.expr);
5159 params.push_back(true_value.expr);
5160 // The condition goes last.
5161 params.push_back(condition.expr);
5162 return {op_ty, create<ast::CallExpression>(
5163 Source{},
5164 create<ast::IdentifierExpression>(
5165 Source{}, builder_.Symbols().Register("select")),
5166 std::move(params))};
5167 }
5168 return {};
5169 }
5170
GetSourceForInst(const spvtools::opt::Instruction & inst) const5171 Source FunctionEmitter::GetSourceForInst(
5172 const spvtools::opt::Instruction& inst) const {
5173 return parser_impl_.GetSourceForInst(&inst);
5174 }
5175
// Returns the memory object declaration for the image (or sampled image)
// used by an image access instruction, or nullptr on failure after emitting
// a diagnostic.
const spvtools::opt::Instruction* FunctionEmitter::GetImage(
    const spvtools::opt::Instruction& inst) {
  // The image or sampled image operand is always the first operand.
  if (inst.NumInOperands() == 0) {
    Fail() << "not an image access instruction: " << inst.PrettyPrint();
    return nullptr;
  }
  const auto operand_id = inst.GetSingleWordInOperand(0);
  if (const auto* image =
          parser_impl_.GetMemoryObjectDeclarationForHandle(operand_id, true)) {
    return image;
  }
  Fail() << "internal error: couldn't find image for " << inst.PrettyPrint();
  return nullptr;
}
5192
GetImageType(const spvtools::opt::Instruction & image)5193 const Texture* FunctionEmitter::GetImageType(
5194 const spvtools::opt::Instruction& image) {
5195 const Pointer* ptr_type = parser_impl_.GetTypeForHandleVar(image);
5196 if (!parser_impl_.success()) {
5197 Fail();
5198 return {};
5199 }
5200 if (!ptr_type) {
5201 Fail() << "invalid texture type for " << image.PrettyPrint();
5202 return {};
5203 }
5204 auto* result = ptr_type->type->UnwrapAll()->As<Texture>();
5205 if (!result) {
5206 Fail() << "invalid texture type for " << image.PrettyPrint();
5207 return {};
5208 }
5209 return result;
5210 }
5211
GetImageExpression(const spvtools::opt::Instruction & inst)5212 const ast::Expression* FunctionEmitter::GetImageExpression(
5213 const spvtools::opt::Instruction& inst) {
5214 auto* image = GetImage(inst);
5215 if (!image) {
5216 return nullptr;
5217 }
5218 auto name = namer_.Name(image->result_id());
5219 return create<ast::IdentifierExpression>(GetSourceForInst(inst),
5220 builder_.Symbols().Register(name));
5221 }
5222
GetSamplerExpression(const spvtools::opt::Instruction & inst)5223 const ast::Expression* FunctionEmitter::GetSamplerExpression(
5224 const spvtools::opt::Instruction& inst) {
5225 // The sampled image operand is always the first operand.
5226 const auto image_or_sampled_image_operand_id = inst.GetSingleWordInOperand(0);
5227 const auto* image = parser_impl_.GetMemoryObjectDeclarationForHandle(
5228 image_or_sampled_image_operand_id, false);
5229 if (!image) {
5230 Fail() << "internal error: couldn't find sampler for "
5231 << inst.PrettyPrint();
5232 return nullptr;
5233 }
5234 auto name = namer_.Name(image->result_id());
5235 return create<ast::IdentifierExpression>(GetSourceForInst(inst),
5236 builder_.Symbols().Register(name));
5237 }
5238
// Emits the WGSL equivalent of a SPIR-V image access instruction: sampling
// (with or without depth-compare), texel fetch/read, or texel write. The
// builtin name is chosen from the opcode, then suffixed (Bias/Level/Grad)
// and given extra arguments as dictated by the Image Operands bitmask.
// Returns false on failure, after a diagnostic has been emitted.
bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) {
  ast::ExpressionList params;
  const auto opcode = inst.opcode();

  // Form the texture operand.
  const spvtools::opt::Instruction* image = GetImage(inst);
  if (!image) {
    return false;
  }
  params.push_back(GetImageExpression(inst));

  if (IsSampledImageAccess(opcode)) {
    // Form the sampler operand.
    if (auto* sampler = GetSamplerExpression(inst)) {
      params.push_back(sampler);
    } else {
      return false;
    }
  }

  // Determine the texture type from the handle variable's pointer type.
  const Pointer* texture_ptr_type = parser_impl_.GetTypeForHandleVar(*image);
  if (!texture_ptr_type) {
    return Fail();
  }
  const Texture* texture_type =
      texture_ptr_type->type->UnwrapAll()->As<Texture>();

  if (!texture_type) {
    return Fail();
  }

  // This is the SPIR-V operand index. We're done with the first operand.
  uint32_t arg_index = 1;

  // Push the coordinates operands.
  auto coords = MakeCoordinateOperandsForImageAccess(inst);
  if (coords.empty()) {
    return false;
  }
  params.insert(params.end(), coords.begin(), coords.end());
  // Skip the coordinates operand.
  arg_index++;

  const auto num_args = inst.NumInOperands();

  // Pick the WGSL builtin from the opcode, and consume any opcode-specific
  // operands (the Dref for depth-compare, or the texel for image writes).
  std::string builtin_name;
  // When true, an explicit Lod image operand maps to a "...Level" builtin
  // suffix; textureLoad/textureStore instead take the level as an argument.
  bool use_level_of_detail_suffix = true;
  bool is_dref_sample = false;
  bool is_non_dref_sample = false;
  switch (opcode) {
    case SpvOpImageSampleImplicitLod:
    case SpvOpImageSampleExplicitLod:
    case SpvOpImageSampleProjImplicitLod:
    case SpvOpImageSampleProjExplicitLod:
      is_non_dref_sample = true;
      builtin_name = "textureSample";
      break;
    case SpvOpImageSampleDrefImplicitLod:
    case SpvOpImageSampleDrefExplicitLod:
    case SpvOpImageSampleProjDrefImplicitLod:
    case SpvOpImageSampleProjDrefExplicitLod:
      is_dref_sample = true;
      builtin_name = "textureSampleCompare";
      // The depth-reference (Dref) operand follows the coordinates.
      if (arg_index < num_args) {
        params.push_back(MakeOperand(inst, arg_index).expr);
        arg_index++;
      } else {
        return Fail()
               << "image depth-compare instruction is missing a Dref operand: "
               << inst.PrettyPrint();
      }
      break;
    case SpvOpImageGather:
    case SpvOpImageDrefGather:
      return Fail() << " image gather is not yet supported";
    case SpvOpImageFetch:
    case SpvOpImageRead:
      // Read a single texel from a sampled or storage image.
      builtin_name = "textureLoad";
      use_level_of_detail_suffix = false;
      break;
    case SpvOpImageWrite:
      builtin_name = "textureStore";
      use_level_of_detail_suffix = false;
      // The texel operand follows the coordinates, and may need conversion
      // to match the storage texture's texel format.
      if (arg_index < num_args) {
        auto texel = MakeOperand(inst, arg_index);
        auto* converted_texel =
            ConvertTexelForStorage(inst, texel, texture_type);
        if (!converted_texel) {
          return false;
        }

        params.push_back(converted_texel);
        arg_index++;
      } else {
        return Fail() << "image write is missing a Texel operand: "
                      << inst.PrettyPrint();
      }
      break;
    default:
      return Fail() << "internal error: unrecognized image access: "
                    << inst.PrettyPrint();
  }

  // Loop over the image operands, looking for extra operands to the builtin.
  // Except we unroll the loop: each recognized bit is handled in the fixed
  // order SPIR-V specifies for Image Operands, and cleared from the mask.
  uint32_t image_operands_mask = 0;
  if (arg_index < num_args) {
    image_operands_mask = inst.GetSingleWordInOperand(arg_index);
    arg_index++;
  }
  if (arg_index < num_args &&
      (image_operands_mask & SpvImageOperandsBiasMask)) {
    if (is_dref_sample) {
      return Fail() << "WGSL does not support depth-reference sampling with "
                       "level-of-detail bias: "
                    << inst.PrettyPrint();
    }
    builtin_name += "Bias";
    params.push_back(MakeOperand(inst, arg_index).expr);
    image_operands_mask ^= SpvImageOperandsBiasMask;
    arg_index++;
  }
  if (arg_index < num_args && (image_operands_mask & SpvImageOperandsLodMask)) {
    if (use_level_of_detail_suffix) {
      builtin_name += "Level";
    }
    if (is_dref_sample) {
      // Metal only supports Lod = 0 for comparison sampling without
      // derivatives.
      if (!IsFloatZero(inst.GetSingleWordInOperand(arg_index))) {
        return Fail() << "WGSL comparison sampling without derivatives "
                         "requires level-of-detail 0.0"
                      << inst.PrettyPrint();
      }
      // Don't generate the Lod argument.
    } else {
      // Generate the Lod argument.
      TypedExpression lod = MakeOperand(inst, arg_index);
      // When sampling from a depth texture, the Lod operand must be an I32.
      if (texture_type->Is<DepthTexture>()) {
        // Convert it to a signed integer type.
        lod = ToI32(lod);
      }
      params.push_back(lod.expr);
    }

    image_operands_mask ^= SpvImageOperandsLodMask;
    arg_index++;
  } else if ((opcode == SpvOpImageFetch || opcode == SpvOpImageRead) &&
             !texture_type
                  ->IsAnyOf<DepthMultisampledTexture, MultisampledTexture>()) {
    // textureLoad requires an explicit level-of-detail parameter for
    // non-multisampled texture types.
    params.push_back(parser_impl_.MakeNullValue(ty_.I32()));
  }
  // Grad consumes two operands (ddx, ddy), hence the arg_index + 1 bound.
  if (arg_index + 1 < num_args &&
      (image_operands_mask & SpvImageOperandsGradMask)) {
    if (is_dref_sample) {
      return Fail() << "WGSL does not support depth-reference sampling with "
                       "explicit gradient: "
                    << inst.PrettyPrint();
    }
    builtin_name += "Grad";
    params.push_back(MakeOperand(inst, arg_index).expr);
    params.push_back(MakeOperand(inst, arg_index + 1).expr);
    image_operands_mask ^= SpvImageOperandsGradMask;
    arg_index += 2;
  }
  if (arg_index < num_args &&
      (image_operands_mask & SpvImageOperandsConstOffsetMask)) {
    if (!IsImageSampling(opcode)) {
      return Fail() << "ConstOffset is only permitted for sampling operations: "
                    << inst.PrettyPrint();
    }
    switch (texture_type->dims) {
      case ast::TextureDimension::k2d:
      case ast::TextureDimension::k2dArray:
      case ast::TextureDimension::k3d:
        break;
      default:
        return Fail() << "ConstOffset is only permitted for 2D, 2D Arrayed, "
                         "and 3D textures: "
                      << inst.PrettyPrint();
    }

    // WGSL offsets are signed.
    params.push_back(ToSignedIfUnsigned(MakeOperand(inst, arg_index)).expr);
    image_operands_mask ^= SpvImageOperandsConstOffsetMask;
    arg_index++;
  }
  if (arg_index < num_args &&
      (image_operands_mask & SpvImageOperandsSampleMask)) {
    // TODO(dneto): only permitted with ImageFetch
    params.push_back(ToI32(MakeOperand(inst, arg_index)).expr);
    image_operands_mask ^= SpvImageOperandsSampleMask;
    arg_index++;
  }
  // Anything left in the mask was not handled above.
  if (image_operands_mask) {
    return Fail() << "unsupported image operands (" << image_operands_mask
                  << "): " << inst.PrettyPrint();
  }

  auto* ident = create<ast::IdentifierExpression>(
      Source{}, builder_.Symbols().Register(builtin_name));
  auto* call_expr =
      create<ast::CallExpression>(Source{}, ident, std::move(params));

  if (inst.type_id() != 0) {
    // It returns a value.
    const ast::Expression* value = call_expr;

    // The result type, derived from the SPIR-V instruction.
    auto* result_type = parser_impl_.ConvertType(inst.type_id());
    auto* result_component_type = result_type;
    if (auto* result_vector_type = As<Vector>(result_type)) {
      result_component_type = result_vector_type->type;
    }

    // For depth textures, the arity might not match WGSL:
    //  Operation           SPIR-V                     WGSL
    //   normal sampling     vec4  ImplicitLod          f32
    //   normal sampling     vec4  ExplicitLod          f32
    //   compare sample      f32   DrefImplicitLod      f32
    //   compare sample      f32   DrefExplicitLod      f32
    //   texel load          vec4  ImageFetch           f32
    //   normal gather       vec4  ImageGather          vec4 TODO(dneto)
    //   dref gather         vec4  ImageFetch           vec4 TODO(dneto)
    // Construct a 4-element vector with the result from the builtin in the
    // first component.
    if (texture_type->IsAnyOf<DepthTexture, DepthMultisampledTexture>()) {
      if (is_non_dref_sample || (opcode == SpvOpImageFetch)) {
        value = builder_.Construct(
            Source{},
            result_type->Build(builder_),  // a vec4
            ast::ExpressionList{
                value, parser_impl_.MakeNullValue(result_component_type),
                parser_impl_.MakeNullValue(result_component_type),
                parser_impl_.MakeNullValue(result_component_type)});
      }
    }

    // If necessary, convert the result to the signedness of the instruction
    // result type. Compare the SPIR-V image's sampled component type with the
    // component of the result type of the SPIR-V instruction.
    auto* spirv_image_type =
        parser_impl_.GetSpirvTypeForHandleMemoryObjectDeclaration(*image);
    if (!spirv_image_type || (spirv_image_type->opcode() != SpvOpTypeImage)) {
      return Fail() << "invalid image type for image memory object declaration "
                    << image->PrettyPrint();
    }
    auto* expected_component_type =
        parser_impl_.ConvertType(spirv_image_type->GetSingleWordInOperand(0));
    if (expected_component_type != result_component_type) {
      // This occurs if one is signed integer and the other is unsigned
      // integer, or vice versa. Perform a bitcast.
      value = create<ast::BitcastExpression>(
          Source{}, result_type->Build(builder_), call_expr);
    }
    if (!expected_component_type->Is<F32>() && IsSampledImageAccess(opcode)) {
      // WGSL permits sampled image access only on float textures.
      // Reject this case in the SPIR-V reader, at least until SPIR-V
      // validation catches up with this rule and can reject it earlier in
      // the workflow.
      return Fail() << "sampled image must have float component type";
    }

    EmitConstDefOrWriteToHoistedVar(inst, {result_type, value});
  } else {
    // It's an image write. No value is returned, so make a statement out
    // of the call.
    AddStatement(create<ast::CallStatement>(Source{}, call_expr));
  }
  return success();
}
5512
// Emits the WGSL equivalent of a SPIR-V image query instruction
// (OpImageQuerySize, OpImageQuerySizeLod, OpImageQueryLevels,
// OpImageQuerySamples). OpImageQueryLod is rejected: WGSL has no equivalent.
// Returns false on failure, after a diagnostic has been emitted.
bool FunctionEmitter::EmitImageQuery(const spvtools::opt::Instruction& inst) {
  // TODO(dneto): Reject cases that are valid in Vulkan but invalid in WGSL.
  const spvtools::opt::Instruction* image = GetImage(inst);
  if (!image) {
    return false;
  }
  auto* texture_type = GetImageType(*image);
  if (!texture_type) {
    return false;
  }

  const auto opcode = inst.opcode();
  switch (opcode) {
    case SpvOpImageQuerySize:
    case SpvOpImageQuerySizeLod: {
      ast::ExpressionList exprs;
      // Invoke textureDimensions.
      // If the texture is arrayed, combine with the result from
      // textureNumLayers.
      auto* dims_ident = create<ast::IdentifierExpression>(
          Source{}, builder_.Symbols().Register("textureDimensions"));
      ast::ExpressionList dims_args{GetImageExpression(inst)};
      if (opcode == SpvOpImageQuerySizeLod) {
        // The level-of-detail operand must be an i32 in WGSL.
        dims_args.push_back(ToI32(MakeOperand(inst, 1)).expr);
      }
      const ast::Expression* dims_call =
          create<ast::CallExpression>(Source{}, dims_ident, dims_args);
      auto dims = texture_type->dims;
      if ((dims == ast::TextureDimension::kCube) ||
          (dims == ast::TextureDimension::kCubeArray)) {
        // textureDimension returns a 3-element vector but SPIR-V expects 2.
        dims_call = create<ast::MemberAccessorExpression>(Source{}, dims_call,
                                                          PrefixSwizzle(2));
      }
      exprs.push_back(dims_call);
      if (ast::IsTextureArray(dims)) {
        // Append the layer count as the final component of the result.
        auto* layers_ident = create<ast::IdentifierExpression>(
            Source{}, builder_.Symbols().Register("textureNumLayers"));
        exprs.push_back(create<ast::CallExpression>(
            Source{}, layers_ident,
            ast::ExpressionList{GetImageExpression(inst)}));
      }
      // NOTE(review): result_type is dereferenced without a null check,
      // unlike EmitFunctionCall; presumably ConvertType cannot fail for a
      // validated module here — confirm.
      auto* result_type = parser_impl_.ConvertType(inst.type_id());
      TypedExpression expr = {
          result_type,
          builder_.Construct(Source{}, result_type->Build(builder_), exprs)};
      return EmitConstDefOrWriteToHoistedVar(inst, expr);
    }
    case SpvOpImageQueryLod:
      return Fail() << "WGSL does not support querying the level of detail of "
                       "an image: "
                    << inst.PrettyPrint();
    case SpvOpImageQueryLevels:
    case SpvOpImageQuerySamples: {
      const auto* name = (opcode == SpvOpImageQueryLevels)
                             ? "textureNumLevels"
                             : "textureNumSamples";
      auto* levels_ident = create<ast::IdentifierExpression>(
          Source{}, builder_.Symbols().Register(name));
      const ast::Expression* ast_expr = create<ast::CallExpression>(
          Source{}, levels_ident,
          ast::ExpressionList{GetImageExpression(inst)});
      auto* result_type = parser_impl_.ConvertType(inst.type_id());
      // The SPIR-V result type must be integer scalar. The WGSL builtin
      // returns i32. If they aren't the same then convert the result.
      if (!result_type->Is<I32>()) {
        ast_expr = builder_.Construct(Source{}, result_type->Build(builder_),
                                      ast::ExpressionList{ast_expr});
      }
      TypedExpression expr{result_type, ast_expr};
      return EmitConstDefOrWriteToHoistedVar(inst, expr);
    }
    default:
      break;
  }
  return Fail() << "unhandled image query: " << inst.PrettyPrint();
}
5590
MakeCoordinateOperandsForImageAccess(const spvtools::opt::Instruction & inst)5591 ast::ExpressionList FunctionEmitter::MakeCoordinateOperandsForImageAccess(
5592 const spvtools::opt::Instruction& inst) {
5593 if (!parser_impl_.success()) {
5594 Fail();
5595 return {};
5596 }
5597 const spvtools::opt::Instruction* image = GetImage(inst);
5598 if (!image) {
5599 return {};
5600 }
5601 if (inst.NumInOperands() < 1) {
5602 Fail() << "image access is missing a coordinate parameter: "
5603 << inst.PrettyPrint();
5604 return {};
5605 }
5606
5607 // In SPIR-V for Shader, coordinates are:
5608 // - floating point for sampling, dref sampling, gather, dref gather
5609 // - integral for fetch, read, write
5610 // In WGSL:
5611 // - floating point for sampling, dref sampling, gather, dref gather
5612 // - signed integral for textureLoad, textureStore
5613 //
5614 // The only conversions we have to do for WGSL are:
5615 // - When the coordinates are unsigned integral, convert them to signed.
5616 // - Array index is always i32
5617
5618 // The coordinates parameter is always in position 1.
5619 TypedExpression raw_coords(MakeOperand(inst, 1));
5620 if (!raw_coords) {
5621 return {};
5622 }
5623 const Texture* texture_type = GetImageType(*image);
5624 if (!texture_type) {
5625 return {};
5626 }
5627 ast::TextureDimension dim = texture_type->dims;
5628 // Number of regular coordinates.
5629 uint32_t num_axes = ast::NumCoordinateAxes(dim);
5630 bool is_arrayed = ast::IsTextureArray(dim);
5631 if ((num_axes == 0) || (num_axes > 3)) {
5632 Fail() << "unsupported image dimensionality for "
5633 << texture_type->TypeInfo().name << " prompted by "
5634 << inst.PrettyPrint();
5635 }
5636 bool is_proj = false;
5637 switch (inst.opcode()) {
5638 case SpvOpImageSampleProjImplicitLod:
5639 case SpvOpImageSampleProjExplicitLod:
5640 case SpvOpImageSampleProjDrefImplicitLod:
5641 case SpvOpImageSampleProjDrefExplicitLod:
5642 is_proj = true;
5643 break;
5644 default:
5645 break;
5646 }
5647
5648 const auto num_coords_required =
5649 num_axes + (is_arrayed ? 1 : 0) + (is_proj ? 1 : 0);
5650 uint32_t num_coords_supplied = 0;
5651 auto* component_type = raw_coords.type;
5652 if (component_type->IsFloatScalar() || component_type->IsIntegerScalar()) {
5653 num_coords_supplied = 1;
5654 } else if (auto* vec_type = As<Vector>(raw_coords.type)) {
5655 component_type = vec_type->type;
5656 num_coords_supplied = vec_type->size;
5657 }
5658 if (num_coords_supplied == 0) {
5659 Fail() << "bad or unsupported coordinate type for image access: "
5660 << inst.PrettyPrint();
5661 return {};
5662 }
5663 if (num_coords_required > num_coords_supplied) {
5664 Fail() << "image access required " << num_coords_required
5665 << " coordinate components, but only " << num_coords_supplied
5666 << " provided, in: " << inst.PrettyPrint();
5667 return {};
5668 }
5669
5670 ast::ExpressionList result;
5671
5672 // Generates the expression for the WGSL coordinates, when it is a prefix
5673 // swizzle with num_axes. If the result would be unsigned, also converts
5674 // it to a signed value of the same shape (scalar or vector).
5675 // Use a lambda to make it easy to only generate the expressions when we
5676 // will actually use them.
5677 auto prefix_swizzle_expr = [this, num_axes, component_type, is_proj,
5678 raw_coords]() -> const ast::Expression* {
5679 auto* swizzle_type =
5680 (num_axes == 1) ? component_type : ty_.Vector(component_type, num_axes);
5681 auto* swizzle = create<ast::MemberAccessorExpression>(
5682 Source{}, raw_coords.expr, PrefixSwizzle(num_axes));
5683 if (is_proj) {
5684 auto* q = create<ast::MemberAccessorExpression>(Source{}, raw_coords.expr,
5685 Swizzle(num_axes));
5686 auto* proj_div = builder_.Div(swizzle, q);
5687 return ToSignedIfUnsigned({swizzle_type, proj_div}).expr;
5688 } else {
5689 return ToSignedIfUnsigned({swizzle_type, swizzle}).expr;
5690 }
5691 };
5692
5693 if (is_arrayed) {
5694 // The source must be a vector. It has at least one coordinate component
5695 // and it must have an array component. Use a vector swizzle to get the
5696 // first `num_axes` components.
5697 result.push_back(prefix_swizzle_expr());
5698
5699 // Now get the array index.
5700 const ast::Expression* array_index =
5701 builder_.MemberAccessor(raw_coords.expr, Swizzle(num_axes));
5702 if (component_type->IsFloatScalar()) {
5703 // When converting from a float array layer to integer, Vulkan requires
5704 // round-to-nearest, with preference for round-to-nearest-even.
5705 // But i32(f32) in WGSL has unspecified rounding mode, so we have to
5706 // explicitly specify the rounding.
5707 array_index = builder_.Call("round", array_index);
5708 }
5709 // Convert it to a signed integer type, if needed.
5710 result.push_back(ToI32({component_type, array_index}).expr);
5711 } else {
5712 if (num_coords_supplied == num_coords_required && !is_proj) {
5713 // Pass the value through, with possible unsigned->signed conversion.
5714 result.push_back(ToSignedIfUnsigned(raw_coords).expr);
5715 } else {
5716 // There are more coordinates supplied than needed. So the source type
5717 // is a vector. Use a vector swizzle to get the first `num_axes`
5718 // components.
5719 result.push_back(prefix_swizzle_expr());
5720 }
5721 }
5722 return result;
5723 }
5724
ConvertTexelForStorage(const spvtools::opt::Instruction & inst,TypedExpression texel,const Texture * texture_type)5725 const ast::Expression* FunctionEmitter::ConvertTexelForStorage(
5726 const spvtools::opt::Instruction& inst,
5727 TypedExpression texel,
5728 const Texture* texture_type) {
5729 auto* storage_texture_type = As<StorageTexture>(texture_type);
5730 auto* src_type = texel.type;
5731 if (!storage_texture_type) {
5732 Fail() << "writing to other than storage texture: " << inst.PrettyPrint();
5733 return nullptr;
5734 }
5735 const auto format = storage_texture_type->format;
5736 auto* dest_type = parser_impl_.GetTexelTypeForFormat(format);
5737 if (!dest_type) {
5738 Fail();
5739 return nullptr;
5740 }
5741
5742 // The texel type is always a 4-element vector.
5743 const uint32_t dest_count = 4u;
5744 TINT_ASSERT(Reader, dest_type->Is<Vector>() &&
5745 dest_type->As<Vector>()->size == dest_count);
5746 TINT_ASSERT(Reader, dest_type->IsFloatVector() ||
5747 dest_type->IsUnsignedIntegerVector() ||
5748 dest_type->IsSignedIntegerVector());
5749
5750 if (src_type == dest_type) {
5751 return texel.expr;
5752 }
5753
5754 // Component type must match floatness, or integral signedness.
5755 if ((src_type->IsFloatScalarOrVector() != dest_type->IsFloatVector()) ||
5756 (src_type->IsUnsignedIntegerVector() !=
5757 dest_type->IsUnsignedIntegerVector()) ||
5758 (src_type->IsSignedIntegerVector() !=
5759 dest_type->IsSignedIntegerVector())) {
5760 Fail() << "invalid texel type for storage texture write: component must be "
5761 "float, signed integer, or unsigned integer "
5762 "to match the texture channel type: "
5763 << inst.PrettyPrint();
5764 return nullptr;
5765 }
5766
5767 const auto required_count = parser_impl_.GetChannelCountForFormat(format);
5768 TINT_ASSERT(Reader, 0 < required_count && required_count <= 4);
5769
5770 const uint32_t src_count =
5771 src_type->IsScalar() ? 1 : src_type->As<Vector>()->size;
5772 if (src_count < required_count) {
5773 Fail() << "texel has too few components for storage texture: " << src_count
5774 << " provided but " << required_count
5775 << " required, in: " << inst.PrettyPrint();
5776 return nullptr;
5777 }
5778
5779 // It's valid for required_count < src_count. The extra components will
5780 // be written out but the textureStore will ignore them.
5781
5782 if (src_count < dest_count) {
5783 // Expand the texel to a 4 element vector.
5784 auto* component_type =
5785 texel.type->IsScalar() ? texel.type : texel.type->As<Vector>()->type;
5786 texel.type = ty_.Vector(component_type, dest_count);
5787 ast::ExpressionList exprs;
5788 exprs.push_back(texel.expr);
5789 for (auto i = src_count; i < dest_count; i++) {
5790 exprs.push_back(parser_impl_.MakeNullExpression(component_type).expr);
5791 }
5792 texel.expr = builder_.Construct(Source{}, texel.type->Build(builder_),
5793 std::move(exprs));
5794 }
5795
5796 return texel.expr;
5797 }
5798
ToI32(TypedExpression value)5799 TypedExpression FunctionEmitter::ToI32(TypedExpression value) {
5800 if (!value || value.type->Is<I32>()) {
5801 return value;
5802 }
5803 return {ty_.I32(), builder_.Construct(Source{}, builder_.ty.i32(),
5804 ast::ExpressionList{value.expr})};
5805 }
5806
ToSignedIfUnsigned(TypedExpression value)5807 TypedExpression FunctionEmitter::ToSignedIfUnsigned(TypedExpression value) {
5808 if (!value || !value.type->IsUnsignedScalarOrVector()) {
5809 return value;
5810 }
5811 if (auto* vec_type = value.type->As<Vector>()) {
5812 auto* new_type = ty_.Vector(ty_.I32(), vec_type->size);
5813 return {new_type, builder_.Construct(new_type->Build(builder_),
5814 ast::ExpressionList{value.expr})};
5815 }
5816 return ToI32(value);
5817 }
5818
MakeArrayLength(const spvtools::opt::Instruction & inst)5819 TypedExpression FunctionEmitter::MakeArrayLength(
5820 const spvtools::opt::Instruction& inst) {
5821 if (inst.NumInOperands() != 2) {
5822 // Binary parsing will fail on this anyway.
5823 Fail() << "invalid array length: requires 2 operands: "
5824 << inst.PrettyPrint();
5825 return {};
5826 }
5827 const auto struct_ptr_id = inst.GetSingleWordInOperand(0);
5828 const auto field_index = inst.GetSingleWordInOperand(1);
5829 const auto struct_ptr_type_id =
5830 def_use_mgr_->GetDef(struct_ptr_id)->type_id();
5831 // Trace through the pointer type to get to the struct type.
5832 const auto struct_type_id =
5833 def_use_mgr_->GetDef(struct_ptr_type_id)->GetSingleWordInOperand(1);
5834 const auto field_name = namer_.GetMemberName(struct_type_id, field_index);
5835 if (field_name.empty()) {
5836 Fail() << "struct index out of bounds for array length: "
5837 << inst.PrettyPrint();
5838 return {};
5839 }
5840
5841 auto member_expr = MakeExpression(struct_ptr_id);
5842 if (!member_expr) {
5843 return {};
5844 }
5845 if (member_expr.type->Is<Pointer>()) {
5846 member_expr = Dereference(member_expr);
5847 }
5848 auto* member_ident = create<ast::IdentifierExpression>(
5849 Source{}, builder_.Symbols().Register(field_name));
5850 auto* member_access = create<ast::MemberAccessorExpression>(
5851 Source{}, member_expr.expr, member_ident);
5852
5853 // Generate the intrinsic function call.
5854 auto* call_expr =
5855 builder_.Call(Source{}, "arrayLength", builder_.AddressOf(member_access));
5856
5857 return {parser_impl_.ConvertType(inst.type_id()), call_expr};
5858 }
5859
MakeOuterProduct(const spvtools::opt::Instruction & inst)5860 TypedExpression FunctionEmitter::MakeOuterProduct(
5861 const spvtools::opt::Instruction& inst) {
5862 // Synthesize the result.
5863 auto col = MakeOperand(inst, 0);
5864 auto row = MakeOperand(inst, 1);
5865 auto* col_ty = As<Vector>(col.type);
5866 auto* row_ty = As<Vector>(row.type);
5867 auto* result_ty = As<Matrix>(parser_impl_.ConvertType(inst.type_id()));
5868 if (!col_ty || !col_ty || !result_ty || result_ty->type != col_ty->type ||
5869 result_ty->type != row_ty->type || result_ty->columns != row_ty->size ||
5870 result_ty->rows != col_ty->size) {
5871 Fail() << "invalid outer product instruction: bad types "
5872 << inst.PrettyPrint();
5873 return {};
5874 }
5875
5876 // Example:
5877 // c : vec3 column vector
5878 // r : vec2 row vector
5879 // OuterProduct c r : mat2x3 (2 columns, 3 rows)
5880 // Result:
5881 // | c.x * r.x c.x * r.y |
5882 // | c.y * r.x c.y * r.y |
5883 // | c.z * r.x c.z * r.y |
5884
5885 ast::ExpressionList result_columns;
5886 for (uint32_t icol = 0; icol < result_ty->columns; icol++) {
5887 ast::ExpressionList result_row;
5888 auto* row_factor = create<ast::MemberAccessorExpression>(Source{}, row.expr,
5889 Swizzle(icol));
5890 for (uint32_t irow = 0; irow < result_ty->rows; irow++) {
5891 auto* column_factor = create<ast::MemberAccessorExpression>(
5892 Source{}, col.expr, Swizzle(irow));
5893 auto* elem = create<ast::BinaryExpression>(
5894 Source{}, ast::BinaryOp::kMultiply, row_factor, column_factor);
5895 result_row.push_back(elem);
5896 }
5897 result_columns.push_back(
5898 builder_.Construct(Source{}, col_ty->Build(builder_), result_row));
5899 }
5900 return {result_ty, builder_.Construct(Source{}, result_ty->Build(builder_),
5901 result_columns)};
5902 }
5903
MakeVectorInsertDynamic(const spvtools::opt::Instruction & inst)5904 bool FunctionEmitter::MakeVectorInsertDynamic(
5905 const spvtools::opt::Instruction& inst) {
5906 // For
5907 // %result = OpVectorInsertDynamic %type %src_vector %component %index
5908 // there are two cases.
5909 //
5910 // Case 1:
5911 // The %src_vector value has already been hoisted into a variable.
5912 // In this case, assign %src_vector to that variable, then write the
5913 // component into the right spot:
5914 //
5915 // hoisted = src_vector;
5916 // hoisted[index] = component;
5917 //
5918 // Case 2:
5919 // The %src_vector value is not hoisted. In this case, make a temporary
5920 // variable with the %src_vector contents, then write the component,
5921 // and then make a let-declaration that reads the value out:
5922 //
5923 // var temp : type = src_vector;
5924 // temp[index] = component;
5925 // let result : type = temp;
5926 //
5927 // Then use result everywhere the original SPIR-V id is used. Using a const
5928 // like this avoids constantly reloading the value many times.
5929
5930 auto* type = parser_impl_.ConvertType(inst.type_id());
5931 auto src_vector = MakeOperand(inst, 0);
5932 auto component = MakeOperand(inst, 1);
5933 auto index = MakeOperand(inst, 2);
5934
5935 std::string var_name;
5936 auto original_value_name = namer_.Name(inst.result_id());
5937 const bool hoisted = WriteIfHoistedVar(inst, src_vector);
5938 if (hoisted) {
5939 // The variable was already declared in an earlier block.
5940 var_name = original_value_name;
5941 // Assign the source vector value to it.
5942 builder_.Assign({}, builder_.Expr(var_name), src_vector.expr);
5943 } else {
5944 // Synthesize the temporary variable.
5945 // It doesn't correspond to a SPIR-V ID, so we don't use the ordinary
5946 // API in parser_impl_.
5947 var_name = namer_.MakeDerivedName(original_value_name);
5948
5949 auto* temp_var = builder_.Var(var_name, type->Build(builder_),
5950 ast::StorageClass::kNone, src_vector.expr);
5951
5952 AddStatement(builder_.Decl({}, temp_var));
5953 }
5954
5955 auto* lhs = create<ast::IndexAccessorExpression>(
5956 Source{}, builder_.Expr(var_name), index.expr);
5957 if (!lhs) {
5958 return false;
5959 }
5960
5961 AddStatement(builder_.Assign(lhs, component.expr));
5962
5963 if (hoisted) {
5964 // The hoisted variable itself stands for this result ID.
5965 return success();
5966 }
5967 // Create a new let-declaration that is initialized by the contents
5968 // of the temporary variable.
5969 return EmitConstDefinition(inst, {type, builder_.Expr(var_name)});
5970 }
5971
MakeCompositeInsert(const spvtools::opt::Instruction & inst)5972 bool FunctionEmitter::MakeCompositeInsert(
5973 const spvtools::opt::Instruction& inst) {
5974 // For
5975 // %result = OpCompositeInsert %type %object %composite 1 2 3 ...
5976 // there are two cases.
5977 //
5978 // Case 1:
5979 // The %composite value has already been hoisted into a variable.
5980 // In this case, assign %composite to that variable, then write the
5981 // component into the right spot:
5982 //
5983 // hoisted = composite;
5984 // hoisted[index].x = object;
5985 //
5986 // Case 2:
5987 // The %composite value is not hoisted. In this case, make a temporary
5988 // variable with the %composite contents, then write the component,
5989 // and then make a let-declaration that reads the value out:
5990 //
5991 // var temp : type = composite;
5992 // temp[index].x = object;
5993 // let result : type = temp;
5994 //
5995 // Then use result everywhere the original SPIR-V id is used. Using a const
5996 // like this avoids constantly reloading the value many times.
5997 //
5998 // This technique is a combination of:
5999 // - making a temporary variable and constant declaration, like what we do
6000 // for VectorInsertDynamic, and
6001 // - building up an access-chain like access like for CompositeExtract, but
6002 // on the left-hand side of the assignment.
6003
6004 auto* type = parser_impl_.ConvertType(inst.type_id());
6005 auto component = MakeOperand(inst, 0);
6006 auto src_composite = MakeOperand(inst, 1);
6007
6008 std::string var_name;
6009 auto original_value_name = namer_.Name(inst.result_id());
6010 const bool hoisted = WriteIfHoistedVar(inst, src_composite);
6011 if (hoisted) {
6012 // The variable was already declared in an earlier block.
6013 var_name = original_value_name;
6014 // Assign the source composite value to it.
6015 builder_.Assign({}, builder_.Expr(var_name), src_composite.expr);
6016 } else {
6017 // Synthesize a temporary variable.
6018 // It doesn't correspond to a SPIR-V ID, so we don't use the ordinary
6019 // API in parser_impl_.
6020 var_name = namer_.MakeDerivedName(original_value_name);
6021 auto* temp_var = builder_.Var(var_name, type->Build(builder_),
6022 ast::StorageClass::kNone, src_composite.expr);
6023 AddStatement(builder_.Decl({}, temp_var));
6024 }
6025
6026 TypedExpression seed_expr{type, builder_.Expr(var_name)};
6027
6028 // The left-hand side of the assignment *looks* like a decomposition.
6029 TypedExpression lhs =
6030 MakeCompositeValueDecomposition(inst, seed_expr, inst.type_id(), 2);
6031 if (!lhs) {
6032 return false;
6033 }
6034
6035 AddStatement(builder_.Assign(lhs.expr, component.expr));
6036
6037 if (hoisted) {
6038 // The hoisted variable itself stands for this result ID.
6039 return success();
6040 }
6041 // Create a new let-declaration that is initialized by the contents
6042 // of the temporary variable.
6043 return EmitConstDefinition(inst, {type, builder_.Expr(var_name)});
6044 }
6045
AddressOf(TypedExpression expr)6046 TypedExpression FunctionEmitter::AddressOf(TypedExpression expr) {
6047 auto* ref = expr.type->As<Reference>();
6048 if (!ref) {
6049 Fail() << "AddressOf() called on non-reference type";
6050 return {};
6051 }
6052 return {
6053 ty_.Pointer(ref->type, ref->storage_class),
6054 create<ast::UnaryOpExpression>(Source{}, ast::UnaryOp::kAddressOf,
6055 expr.expr),
6056 };
6057 }
6058
Dereference(TypedExpression expr)6059 TypedExpression FunctionEmitter::Dereference(TypedExpression expr) {
6060 auto* ptr = expr.type->As<Pointer>();
6061 if (!ptr) {
6062 Fail() << "Dereference() called on non-pointer type";
6063 return {};
6064 }
6065 return {
6066 ptr->type,
6067 create<ast::UnaryOpExpression>(Source{}, ast::UnaryOp::kIndirection,
6068 expr.expr),
6069 };
6070 }
6071
IsFloatZero(uint32_t value_id)6072 bool FunctionEmitter::IsFloatZero(uint32_t value_id) {
6073 if (const auto* c = constant_mgr_->FindDeclaredConstant(value_id)) {
6074 if (const auto* float_const = c->AsFloatConstant()) {
6075 return 0.0f == float_const->GetFloatValue();
6076 }
6077 if (c->AsNullConstant()) {
6078 // Valid SPIR-V requires it to be a float value anyway.
6079 return true;
6080 }
6081 }
6082 return false;
6083 }
6084
IsFloatOne(uint32_t value_id)6085 bool FunctionEmitter::IsFloatOne(uint32_t value_id) {
6086 if (const auto* c = constant_mgr_->FindDeclaredConstant(value_id)) {
6087 if (const auto* float_const = c->AsFloatConstant()) {
6088 return 1.0f == float_const->GetFloatValue();
6089 }
6090 }
6091 return false;
6092 }
6093
// Out-of-line defaulted special members for FunctionDeclaration.
FunctionEmitter::FunctionDeclaration::FunctionDeclaration() = default;
FunctionEmitter::FunctionDeclaration::~FunctionDeclaration() = default;
6096
6097 } // namespace spirv
6098 } // namespace reader
6099 } // namespace tint
6100
// Instantiate Tint's TypeInfo machinery for the statement-builder classes,
// outside the tint namespaces as the macro requires fully-qualified names.
TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::StatementBuilder);
TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::SwitchStatementBuilder);
TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::IfStatementBuilder);
TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::LoopStatementBuilder);
6105