• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/reader/spirv/function.h"
16 
17 #include <algorithm>
18 #include <array>
19 
20 #include "src/ast/assignment_statement.h"
21 #include "src/ast/bitcast_expression.h"
22 #include "src/ast/break_statement.h"
23 #include "src/ast/builtin.h"
24 #include "src/ast/builtin_decoration.h"
25 #include "src/ast/call_statement.h"
26 #include "src/ast/continue_statement.h"
27 #include "src/ast/discard_statement.h"
28 #include "src/ast/fallthrough_statement.h"
29 #include "src/ast/if_statement.h"
30 #include "src/ast/loop_statement.h"
31 #include "src/ast/return_statement.h"
32 #include "src/ast/stage_decoration.h"
33 #include "src/ast/switch_statement.h"
34 #include "src/ast/unary_op_expression.h"
35 #include "src/ast/variable_decl_statement.h"
36 #include "src/sem/depth_texture_type.h"
37 #include "src/sem/intrinsic_type.h"
38 #include "src/sem/sampled_texture_type.h"
39 
40 // Terms:
41 //    CFG: the control flow graph of the function, where basic blocks are the
42 //    nodes, and branches form the directed arcs.  The function entry block is
43 //    the root of the CFG.
44 //
45 //    Suppose H is a header block (i.e. has an OpSelectionMerge or OpLoopMerge).
46 //    Then:
47 //    - Let M(H) be the merge block named by the merge instruction in H.
48 //    - If H is a loop header, i.e. has an OpLoopMerge instruction, then let
49 //      CT(H) be the continue target block named by the OpLoopMerge
50 //      instruction.
51 //    - If H is a selection construct whose header ends in
52 //      OpBranchConditional with true target %then and false target %else,
53 //      then  TT(H) = %then and FT(H) = %else
54 //
// Determining output block order:
//    The "structured post-order traversal" of the CFG is a post-order traversal
//    of the basic blocks in the CFG, where:
//      We visit the entry node of the function first.
//      When visiting a header block:
//        We next visit its merge block.
//        Then, if it's a loop header, we next visit the continue target.
//      Then we visit the block's successors (whether it's a header or not).
//        If the block ends in an OpBranchConditional, we visit the false target
//        before the true target.
//
//    The "reverse structured post-order traversal" of the CFG is the reverse
//    of the structured post-order traversal.
//    This is the order of basic blocks as they should be emitted to the WGSL
//    function. It is the order computed by ComputeBlockOrder, and stored in
//    |FunctionEmitter::block_order_|.
//    Blocks not in this ordering are ignored by the rest of the algorithm.
72 //
73 //    Note:
74 //     - A block D in the function might not appear in this order because
75 //       no block in the order branches to D.
76 //     - An unreachable block D might still be in the order because some header
77 //       block in the order names D as its continue target, or merge block,
78 //       or D is reachable from one of those otherwise-unreachable continue
79 //       targets or merge blocks.
80 //
81 // Terms:
82 //    Let Pos(B) be the index position of a block B in the computed block order.
83 //
84 // CFG intervals and valid nesting:
85 //
86 //    A correctly structured CFG satisfies nesting rules that we can check by
87 //    comparing positions of related blocks.
88 //
//    If header block H is in the block order, then the following holds:
//
//      Pos(H) < Pos(M(H))
//
//      If CT(H) exists, then:
//
//         Pos(H) <= Pos(CT(H))
//         Pos(CT(H)) < Pos(M(H))
//
//    This gives us the fundamental ordering of blocks in relation to a
//    structured construct:
//      The blocks before H in the block order are not in the construct.
//      The blocks at M(H) or later in the block order are not in the construct.
//      The blocks in a selection headed at H are in positions
//        [ Pos(H), Pos(M(H)) ).
//      The blocks in a loop construct headed at H are in positions
//        [ Pos(H), Pos(CT(H)) ).
//      The blocks in the continue construct for the loop headed at H are in
//        positions [ Pos(CT(H)), Pos(M(H)) ).
107 //
108 //      Schematically, for a selection construct headed by H, the blocks are in
109 //      order from left to right:
110 //
111 //                 ...a-b-c H d-e-f M(H) n-o-p...
112 //
113 //           where ...a-b-c: blocks before the selection construct
114 //           where H and d-e-f: blocks in the selection construct
115 //           where M(H) and n-o-p...: blocks after the selection construct
116 //
117 //      Schematically, for a loop construct headed by H that is its own
118 //      continue construct, the blocks in order from left to right:
119 //
120 //                 ...a-b-c H=CT(H) d-e-f M(H) n-o-p...
121 //
122 //           where ...a-b-c: blocks before the loop
123 //           where H is the continue construct; CT(H)=H, and the loop construct
124 //           is *empty*
125 //           where d-e-f... are other blocks in the continue construct
126 //           where M(H) and n-o-p...: blocks after the continue construct
127 //
128 //      Schematically, for a multi-block loop construct headed by H, there are
129 //      blocks in order from left to right:
130 //
131 //                 ...a-b-c H d-e-f CT(H) j-k-l M(H) n-o-p...
132 //
133 //           where ...a-b-c: blocks before the loop
134 //           where H and d-e-f: blocks in the loop construct
135 //           where CT(H) and j-k-l: blocks in the continue construct
136 //           where M(H) and n-o-p...: blocks after the loop and continue
137 //           constructs
138 //
139 
140 namespace tint {
141 namespace reader {
142 namespace spirv {
143 
144 namespace {
145 
// Maximum vector length (component count).
// NOTE(review): meaning inferred from the name; confirm against its uses.
constexpr uint32_t kMaxVectorLen{4};
147 
148 // Gets the AST unary opcode for the given SPIR-V opcode, if any
149 // @param opcode SPIR-V opcode
150 // @param ast_unary_op return parameter
151 // @returns true if it was a unary operation
GetUnaryOp(SpvOp opcode,ast::UnaryOp * ast_unary_op)152 bool GetUnaryOp(SpvOp opcode, ast::UnaryOp* ast_unary_op) {
153   switch (opcode) {
154     case SpvOpSNegate:
155     case SpvOpFNegate:
156       *ast_unary_op = ast::UnaryOp::kNegation;
157       return true;
158     case SpvOpLogicalNot:
159       *ast_unary_op = ast::UnaryOp::kNot;
160       return true;
161     case SpvOpNot:
162       *ast_unary_op = ast::UnaryOp::kComplement;
163       return true;
164     default:
165       break;
166   }
167   return false;
168 }
169 
170 /// Converts a SPIR-V opcode for a WGSL builtin function, if there is a
171 /// direct translation. Returns nullptr otherwise.
172 /// @returns the WGSL builtin function name for the given opcode, or nullptr.
GetUnaryBuiltInFunctionName(SpvOp opcode)173 const char* GetUnaryBuiltInFunctionName(SpvOp opcode) {
174   switch (opcode) {
175     case SpvOpAny:
176       return "any";
177     case SpvOpAll:
178       return "all";
179     case SpvOpIsNan:
180       return "isNan";
181     case SpvOpIsInf:
182       return "isInf";
183     case SpvOpTranspose:
184       return "transpose";
185     default:
186       break;
187   }
188   return nullptr;
189 }
190 
191 // Converts a SPIR-V opcode to its corresponding AST binary opcode, if any
192 // @param opcode SPIR-V opcode
193 // @returns the AST binary op for the given opcode, or kNone
ConvertBinaryOp(SpvOp opcode)194 ast::BinaryOp ConvertBinaryOp(SpvOp opcode) {
195   switch (opcode) {
196     case SpvOpIAdd:
197     case SpvOpFAdd:
198       return ast::BinaryOp::kAdd;
199     case SpvOpISub:
200     case SpvOpFSub:
201       return ast::BinaryOp::kSubtract;
202     case SpvOpIMul:
203     case SpvOpFMul:
204     case SpvOpVectorTimesScalar:
205     case SpvOpMatrixTimesScalar:
206     case SpvOpVectorTimesMatrix:
207     case SpvOpMatrixTimesVector:
208     case SpvOpMatrixTimesMatrix:
209       return ast::BinaryOp::kMultiply;
210     case SpvOpUDiv:
211     case SpvOpSDiv:
212     case SpvOpFDiv:
213       return ast::BinaryOp::kDivide;
214     case SpvOpUMod:
215     case SpvOpSMod:
216     case SpvOpFRem:
217       return ast::BinaryOp::kModulo;
218     case SpvOpLogicalEqual:
219     case SpvOpIEqual:
220     case SpvOpFOrdEqual:
221       return ast::BinaryOp::kEqual;
222     case SpvOpLogicalNotEqual:
223     case SpvOpINotEqual:
224     case SpvOpFOrdNotEqual:
225       return ast::BinaryOp::kNotEqual;
226     case SpvOpBitwiseAnd:
227       return ast::BinaryOp::kAnd;
228     case SpvOpBitwiseOr:
229       return ast::BinaryOp::kOr;
230     case SpvOpBitwiseXor:
231       return ast::BinaryOp::kXor;
232     case SpvOpLogicalAnd:
233       return ast::BinaryOp::kAnd;
234     case SpvOpLogicalOr:
235       return ast::BinaryOp::kOr;
236     case SpvOpUGreaterThan:
237     case SpvOpSGreaterThan:
238     case SpvOpFOrdGreaterThan:
239       return ast::BinaryOp::kGreaterThan;
240     case SpvOpUGreaterThanEqual:
241     case SpvOpSGreaterThanEqual:
242     case SpvOpFOrdGreaterThanEqual:
243       return ast::BinaryOp::kGreaterThanEqual;
244     case SpvOpULessThan:
245     case SpvOpSLessThan:
246     case SpvOpFOrdLessThan:
247       return ast::BinaryOp::kLessThan;
248     case SpvOpULessThanEqual:
249     case SpvOpSLessThanEqual:
250     case SpvOpFOrdLessThanEqual:
251       return ast::BinaryOp::kLessThanEqual;
252     default:
253       break;
254   }
255   // It's not clear what OpSMod should map to.
256   // https://bugs.chromium.org/p/tint/issues/detail?id=52
257   return ast::BinaryOp::kNone;
258 }
259 
260 // If the given SPIR-V opcode is a floating point unordered comparison,
261 // then returns the binary float comparison for which it is the negation.
262 // Othewrise returns BinaryOp::kNone.
263 // @param opcode SPIR-V opcode
264 // @returns operation corresponding to negated version of the SPIR-V opcode
NegatedFloatCompare(SpvOp opcode)265 ast::BinaryOp NegatedFloatCompare(SpvOp opcode) {
266   switch (opcode) {
267     case SpvOpFUnordEqual:
268       return ast::BinaryOp::kNotEqual;
269     case SpvOpFUnordNotEqual:
270       return ast::BinaryOp::kEqual;
271     case SpvOpFUnordLessThan:
272       return ast::BinaryOp::kGreaterThanEqual;
273     case SpvOpFUnordLessThanEqual:
274       return ast::BinaryOp::kGreaterThan;
275     case SpvOpFUnordGreaterThan:
276       return ast::BinaryOp::kLessThanEqual;
277     case SpvOpFUnordGreaterThanEqual:
278       return ast::BinaryOp::kLessThan;
279     default:
280       break;
281   }
282   return ast::BinaryOp::kNone;
283 }
284 
285 // Returns the WGSL standard library function for the given
286 // GLSL.std.450 extended instruction operation code.  Unknown
287 // and invalid opcodes map to the empty string.
288 // @returns the WGSL standard function name, or an empty string.
GetGlslStd450FuncName(uint32_t ext_opcode)289 std::string GetGlslStd450FuncName(uint32_t ext_opcode) {
290   switch (ext_opcode) {
291     case GLSLstd450FAbs:
292     case GLSLstd450SAbs:
293       return "abs";
294     case GLSLstd450Acos:
295       return "acos";
296     case GLSLstd450Asin:
297       return "asin";
298     case GLSLstd450Atan:
299       return "atan";
300     case GLSLstd450Atan2:
301       return "atan2";
302     case GLSLstd450Ceil:
303       return "ceil";
304     case GLSLstd450UClamp:
305     case GLSLstd450SClamp:
306     case GLSLstd450NClamp:
307     case GLSLstd450FClamp:  // FClamp is less prescriptive about NaN operands
308       return "clamp";
309     case GLSLstd450Cos:
310       return "cos";
311     case GLSLstd450Cosh:
312       return "cosh";
313     case GLSLstd450Cross:
314       return "cross";
315     case GLSLstd450Distance:
316       return "distance";
317     case GLSLstd450Exp:
318       return "exp";
319     case GLSLstd450Exp2:
320       return "exp2";
321     case GLSLstd450FaceForward:
322       return "faceForward";
323     case GLSLstd450Floor:
324       return "floor";
325     case GLSLstd450Fma:
326       return "fma";
327     case GLSLstd450Fract:
328       return "fract";
329     case GLSLstd450InverseSqrt:
330       return "inverseSqrt";
331     case GLSLstd450Ldexp:
332       return "ldexp";
333     case GLSLstd450Length:
334       return "length";
335     case GLSLstd450Log:
336       return "log";
337     case GLSLstd450Log2:
338       return "log2";
339     case GLSLstd450NMax:
340     case GLSLstd450FMax:  // FMax is less prescriptive about NaN operands
341     case GLSLstd450UMax:
342     case GLSLstd450SMax:
343       return "max";
344     case GLSLstd450NMin:
345     case GLSLstd450FMin:  // FMin is less prescriptive about NaN operands
346     case GLSLstd450UMin:
347     case GLSLstd450SMin:
348       return "min";
349     case GLSLstd450FMix:
350       return "mix";
351     case GLSLstd450Normalize:
352       return "normalize";
353     case GLSLstd450PackSnorm4x8:
354       return "pack4x8snorm";
355     case GLSLstd450PackUnorm4x8:
356       return "pack4x8unorm";
357     case GLSLstd450PackSnorm2x16:
358       return "pack2x16snorm";
359     case GLSLstd450PackUnorm2x16:
360       return "pack2x16unorm";
361     case GLSLstd450PackHalf2x16:
362       return "pack2x16float";
363     case GLSLstd450Pow:
364       return "pow";
365     case GLSLstd450FSign:
366       return "sign";
367     case GLSLstd450Reflect:
368       return "reflect";
369     case GLSLstd450Refract:
370       return "refract";
371     case GLSLstd450Round:
372     case GLSLstd450RoundEven:
373       return "round";
374     case GLSLstd450Sin:
375       return "sin";
376     case GLSLstd450Sinh:
377       return "sinh";
378     case GLSLstd450SmoothStep:
379       return "smoothStep";
380     case GLSLstd450Sqrt:
381       return "sqrt";
382     case GLSLstd450Step:
383       return "step";
384     case GLSLstd450Tan:
385       return "tan";
386     case GLSLstd450Tanh:
387       return "tanh";
388     case GLSLstd450Trunc:
389       return "trunc";
390     case GLSLstd450UnpackSnorm4x8:
391       return "unpack4x8snorm";
392     case GLSLstd450UnpackUnorm4x8:
393       return "unpack4x8unorm";
394     case GLSLstd450UnpackSnorm2x16:
395       return "unpack2x16snorm";
396     case GLSLstd450UnpackUnorm2x16:
397       return "unpack2x16unorm";
398     case GLSLstd450UnpackHalf2x16:
399       return "unpack2x16float";
400 
401     default:
402       // TODO(dneto) - The following are not implemented.
403       // They are grouped semantically, as in GLSL.std.450.h.
404 
405     case GLSLstd450SSign:
406 
407     case GLSLstd450Radians:
408     case GLSLstd450Degrees:
409     case GLSLstd450Asinh:
410     case GLSLstd450Acosh:
411     case GLSLstd450Atanh:
412 
413     case GLSLstd450Determinant:
414     case GLSLstd450MatrixInverse:
415 
416     case GLSLstd450Modf:
417     case GLSLstd450ModfStruct:
418     case GLSLstd450IMix:
419 
420     case GLSLstd450Frexp:
421     case GLSLstd450FrexpStruct:
422 
423     case GLSLstd450PackDouble2x32:
424     case GLSLstd450UnpackDouble2x32:
425 
426     case GLSLstd450FindILsb:
427     case GLSLstd450FindSMsb:
428     case GLSLstd450FindUMsb:
429 
430     case GLSLstd450InterpolateAtCentroid:
431     case GLSLstd450InterpolateAtSample:
432     case GLSLstd450InterpolateAtOffset:
433       break;
434   }
435   return "";
436 }
437 
438 // Returns the WGSL standard library function intrinsic for the
439 // given instruction, or sem::IntrinsicType::kNone
GetIntrinsic(SpvOp opcode)440 sem::IntrinsicType GetIntrinsic(SpvOp opcode) {
441   switch (opcode) {
442     case SpvOpBitCount:
443       return sem::IntrinsicType::kCountOneBits;
444     case SpvOpBitReverse:
445       return sem::IntrinsicType::kReverseBits;
446     case SpvOpDot:
447       return sem::IntrinsicType::kDot;
448     case SpvOpDPdx:
449       return sem::IntrinsicType::kDpdx;
450     case SpvOpDPdy:
451       return sem::IntrinsicType::kDpdy;
452     case SpvOpFwidth:
453       return sem::IntrinsicType::kFwidth;
454     case SpvOpDPdxFine:
455       return sem::IntrinsicType::kDpdxFine;
456     case SpvOpDPdyFine:
457       return sem::IntrinsicType::kDpdyFine;
458     case SpvOpFwidthFine:
459       return sem::IntrinsicType::kFwidthFine;
460     case SpvOpDPdxCoarse:
461       return sem::IntrinsicType::kDpdxCoarse;
462     case SpvOpDPdyCoarse:
463       return sem::IntrinsicType::kDpdyCoarse;
464     case SpvOpFwidthCoarse:
465       return sem::IntrinsicType::kFwidthCoarse;
466     default:
467       break;
468   }
469   return sem::IntrinsicType::kNone;
470 }
471 
472 // @param opcode a SPIR-V opcode
473 // @returns true if the given instruction is an image access instruction
474 // whose first input operand is an OpSampledImage value.
IsSampledImageAccess(SpvOp opcode)475 bool IsSampledImageAccess(SpvOp opcode) {
476   switch (opcode) {
477     case SpvOpImageSampleImplicitLod:
478     case SpvOpImageSampleExplicitLod:
479     case SpvOpImageSampleDrefImplicitLod:
480     case SpvOpImageSampleDrefExplicitLod:
481     // WGSL doesn't have *Proj* texturing; spirv reader emulates it.
482     case SpvOpImageSampleProjImplicitLod:
483     case SpvOpImageSampleProjExplicitLod:
484     case SpvOpImageSampleProjDrefImplicitLod:
485     case SpvOpImageSampleProjDrefExplicitLod:
486     case SpvOpImageGather:
487     case SpvOpImageDrefGather:
488     case SpvOpImageQueryLod:
489       return true;
490     default:
491       break;
492   }
493   return false;
494 }
495 
496 // @param opcode a SPIR-V opcode
497 // @returns true if the given instruction is an image sampling operation.
IsImageSampling(SpvOp opcode)498 bool IsImageSampling(SpvOp opcode) {
499   switch (opcode) {
500     case SpvOpImageSampleImplicitLod:
501     case SpvOpImageSampleExplicitLod:
502     case SpvOpImageSampleDrefImplicitLod:
503     case SpvOpImageSampleDrefExplicitLod:
504       // WGSL doesn't have *Proj* texturing; spirv reader emulates it.
505     case SpvOpImageSampleProjImplicitLod:
506     case SpvOpImageSampleProjExplicitLod:
507     case SpvOpImageSampleProjDrefImplicitLod:
508     case SpvOpImageSampleProjDrefExplicitLod:
509       return true;
510     default:
511       break;
512   }
513   return false;
514 }
515 
516 // @param opcode a SPIR-V opcode
517 // @returns true if the given instruction is an image access instruction
518 // whose first input operand is an OpImage value.
IsRawImageAccess(SpvOp opcode)519 bool IsRawImageAccess(SpvOp opcode) {
520   switch (opcode) {
521     case SpvOpImageRead:
522     case SpvOpImageWrite:
523     case SpvOpImageFetch:
524       return true;
525     default:
526       break;
527   }
528   return false;
529 }
530 
531 // @param opcode a SPIR-V opcode
532 // @returns true if the given instruction is an image query instruction
IsImageQuery(SpvOp opcode)533 bool IsImageQuery(SpvOp opcode) {
534   switch (opcode) {
535     case SpvOpImageQuerySize:
536     case SpvOpImageQuerySizeLod:
537     case SpvOpImageQueryLevels:
538     case SpvOpImageQuerySamples:
539     case SpvOpImageQueryLod:
540       return true;
541     default:
542       break;
543   }
544   return false;
545 }
546 
547 // @returns the merge block ID for the given basic block, or 0 if there is none.
MergeFor(const spvtools::opt::BasicBlock & bb)548 uint32_t MergeFor(const spvtools::opt::BasicBlock& bb) {
549   // Get the OpSelectionMerge or OpLoopMerge instruction, if any.
550   auto* inst = bb.GetMergeInst();
551   return inst == nullptr ? 0 : inst->GetSingleWordInOperand(0);
552 }
553 
554 // @returns the continue target ID for the given basic block, or 0 if there
555 // is none.
ContinueTargetFor(const spvtools::opt::BasicBlock & bb)556 uint32_t ContinueTargetFor(const spvtools::opt::BasicBlock& bb) {
557   // Get the OpLoopMerge instruction, if any.
558   auto* inst = bb.GetLoopMergeInst();
559   return inst == nullptr ? 0 : inst->GetSingleWordInOperand(1);
560 }
561 
562 // A structured traverser produces the reverse structured post-order of the
563 // CFG of a function.  The blocks traversed are the transitive closure (minimum
564 // fixed point) of:
565 //  - the entry block
566 //  - a block reached by a branch from another block in the set
567 //  - a block mentioned as a merge block or continue target for a block in the
568 //  set
569 class StructuredTraverser {
570  public:
StructuredTraverser(const spvtools::opt::Function & function)571   explicit StructuredTraverser(const spvtools::opt::Function& function)
572       : function_(function) {
573     for (auto& block : function_) {
574       id_to_block_[block.id()] = &block;
575     }
576   }
577 
578   // Returns the reverse postorder traversal of the CFG, where:
579   //  - a merge block always follows its associated constructs
580   //  - a continue target always follows the associated loop construct, if any
581   // @returns the IDs of blocks in reverse structured post order
ReverseStructuredPostOrder()582   std::vector<uint32_t> ReverseStructuredPostOrder() {
583     visit_order_.clear();
584     visited_.clear();
585     VisitBackward(function_.entry()->id());
586 
587     std::vector<uint32_t> order(visit_order_.rbegin(), visit_order_.rend());
588     return order;
589   }
590 
591  private:
592   // Executes a depth first search of the CFG, where right after we visit a
593   // header, we will visit its merge block, then its continue target (if any).
594   // Also records the post order ordering.
VisitBackward(uint32_t id)595   void VisitBackward(uint32_t id) {
596     if (id == 0)
597       return;
598     if (visited_.count(id))
599       return;
600     visited_.insert(id);
601 
602     const spvtools::opt::BasicBlock* bb =
603         id_to_block_[id];  // non-null for valid modules
604     VisitBackward(MergeFor(*bb));
605     VisitBackward(ContinueTargetFor(*bb));
606 
607     // Visit successors. We will naturally skip the continue target and merge
608     // blocks.
609     auto* terminator = bb->terminator();
610     auto opcode = terminator->opcode();
611     if (opcode == SpvOpBranchConditional) {
612       // Visit the false branch, then the true branch, to make them come
613       // out in the natural order for an "if".
614       VisitBackward(terminator->GetSingleWordInOperand(2));
615       VisitBackward(terminator->GetSingleWordInOperand(1));
616     } else if (opcode == SpvOpBranch) {
617       VisitBackward(terminator->GetSingleWordInOperand(0));
618     } else if (opcode == SpvOpSwitch) {
619       // TODO(dneto): Consider visiting the labels in literal-value order.
620       std::vector<uint32_t> successors;
621       bb->ForEachSuccessorLabel([&successors](const uint32_t succ_id) {
622         successors.push_back(succ_id);
623       });
624       for (auto succ_id : successors) {
625         VisitBackward(succ_id);
626       }
627     }
628 
629     visit_order_.push_back(id);
630   }
631 
632   const spvtools::opt::Function& function_;
633   std::unordered_map<uint32_t, const spvtools::opt::BasicBlock*> id_to_block_;
634   std::vector<uint32_t> visit_order_;
635   std::unordered_set<uint32_t> visited_;
636 };
637 
638 /// A StatementBuilder for ast::SwitchStatement
639 /// @see StatementBuilder
640 struct SwitchStatementBuilder
641     : public Castable<SwitchStatementBuilder, StatementBuilder> {
642   /// Constructor
643   /// @param cond the switch statement condition
SwitchStatementBuildertint::reader::spirv::__anone3e1bc1b0111::SwitchStatementBuilder644   explicit SwitchStatementBuilder(const ast::Expression* cond)
645       : condition(cond) {}
646 
647   /// @param builder the program builder
648   /// @returns the built ast::SwitchStatement
Buildtint::reader::spirv::__anone3e1bc1b0111::SwitchStatementBuilder649   const ast::SwitchStatement* Build(ProgramBuilder* builder) const override {
650     // We've listed cases in reverse order in the switch statement.
651     // Reorder them to match the presentation order in WGSL.
652     auto reversed_cases = cases;
653     std::reverse(reversed_cases.begin(), reversed_cases.end());
654 
655     return builder->create<ast::SwitchStatement>(Source{}, condition,
656                                                  reversed_cases);
657   }
658 
659   /// Switch statement condition
660   const ast::Expression* const condition;
661   /// Switch statement cases
662   ast::CaseStatementList cases;
663 };
664 
665 /// A StatementBuilder for ast::IfStatement
666 /// @see StatementBuilder
667 struct IfStatementBuilder
668     : public Castable<IfStatementBuilder, StatementBuilder> {
669   /// Constructor
670   /// @param c the if-statement condition
IfStatementBuildertint::reader::spirv::__anone3e1bc1b0111::IfStatementBuilder671   explicit IfStatementBuilder(const ast::Expression* c) : cond(c) {}
672 
673   /// @param builder the program builder
674   /// @returns the built ast::IfStatement
Buildtint::reader::spirv::__anone3e1bc1b0111::IfStatementBuilder675   const ast::IfStatement* Build(ProgramBuilder* builder) const override {
676     return builder->create<ast::IfStatement>(Source{}, cond, body, else_stmts);
677   }
678 
679   /// If-statement condition
680   const ast::Expression* const cond;
681   /// If-statement block body
682   const ast::BlockStatement* body = nullptr;
683   /// Optional if-statement else statements
684   ast::ElseStatementList else_stmts;
685 };
686 
687 /// A StatementBuilder for ast::LoopStatement
688 /// @see StatementBuilder
689 struct LoopStatementBuilder
690     : public Castable<LoopStatementBuilder, StatementBuilder> {
691   /// @param builder the program builder
692   /// @returns the built ast::LoopStatement
Buildtint::reader::spirv::__anone3e1bc1b0111::LoopStatementBuilder693   ast::LoopStatement* Build(ProgramBuilder* builder) const override {
694     return builder->create<ast::LoopStatement>(Source{}, body, continuing);
695   }
696 
697   /// Loop-statement block body
698   const ast::BlockStatement* body = nullptr;
699   /// Loop-statement continuing body
700   /// @note the mutable keyword here is required as all non-StatementBuilders
701   /// `ast::Node`s are immutable and are referenced with `const` pointers.
702   /// StatementBuilders however exist to provide mutable state while the
703   /// FunctionEmitter is building the function. All StatementBuilders are
704   /// replaced with immutable AST nodes when Finalize() is called.
705   mutable const ast::BlockStatement* continuing = nullptr;
706 };
707 
708 /// @param decos a list of parsed decorations
709 /// @returns true if the decorations include a SampleMask builtin
HasBuiltinSampleMask(const ast::DecorationList & decos)710 bool HasBuiltinSampleMask(const ast::DecorationList& decos) {
711   if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(decos)) {
712     return builtin->builtin == ast::Builtin::kSampleMask;
713   }
714   return false;
715 }
716 
717 }  // namespace
718 
BlockInfo(const spvtools::opt::BasicBlock & bb)719 BlockInfo::BlockInfo(const spvtools::opt::BasicBlock& bb)
720     : basic_block(&bb), id(bb.id()) {}
721 
722 BlockInfo::~BlockInfo() = default;
723 
DefInfo(const spvtools::opt::Instruction & def_inst,uint32_t the_block_pos,size_t the_index)724 DefInfo::DefInfo(const spvtools::opt::Instruction& def_inst,
725                  uint32_t the_block_pos,
726                  size_t the_index)
727     : inst(def_inst), block_pos(the_block_pos), index(the_index) {}
728 
729 DefInfo::~DefInfo() = default;
730 
Clone(CloneContext *) const731 ast::Node* StatementBuilder::Clone(CloneContext*) const {
732   return nullptr;
733 }
734 
FunctionEmitter(ParserImpl * pi,const spvtools::opt::Function & function,const EntryPointInfo * ep_info)735 FunctionEmitter::FunctionEmitter(ParserImpl* pi,
736                                  const spvtools::opt::Function& function,
737                                  const EntryPointInfo* ep_info)
738     : parser_impl_(*pi),
739       ty_(pi->type_manager()),
740       builder_(pi->builder()),
741       ir_context_(*(pi->ir_context())),
742       def_use_mgr_(ir_context_.get_def_use_mgr()),
743       constant_mgr_(ir_context_.get_constant_mgr()),
744       type_mgr_(ir_context_.get_type_mgr()),
745       fail_stream_(pi->fail_stream()),
746       namer_(pi->namer()),
747       function_(function),
748       sample_mask_in_id(0u),
749       sample_mask_out_id(0u),
750       ep_info_(ep_info) {
751   PushNewStatementBlock(nullptr, 0, nullptr);
752 }
753 
FunctionEmitter(ParserImpl * pi,const spvtools::opt::Function & function)754 FunctionEmitter::FunctionEmitter(ParserImpl* pi,
755                                  const spvtools::opt::Function& function)
756     : FunctionEmitter(pi, function, nullptr) {}
757 
FunctionEmitter(FunctionEmitter && other)758 FunctionEmitter::FunctionEmitter(FunctionEmitter&& other)
759     : parser_impl_(other.parser_impl_),
760       ty_(other.ty_),
761       builder_(other.builder_),
762       ir_context_(other.ir_context_),
763       def_use_mgr_(ir_context_.get_def_use_mgr()),
764       constant_mgr_(ir_context_.get_constant_mgr()),
765       type_mgr_(ir_context_.get_type_mgr()),
766       fail_stream_(other.fail_stream_),
767       namer_(other.namer_),
768       function_(other.function_),
769       sample_mask_in_id(other.sample_mask_out_id),
770       sample_mask_out_id(other.sample_mask_in_id),
771       ep_info_(other.ep_info_) {
772   other.statements_stack_.clear();
773   PushNewStatementBlock(nullptr, 0, nullptr);
774 }
775 
776 FunctionEmitter::~FunctionEmitter() = default;
777 
StatementBlock(const Construct * construct,uint32_t end_id,FunctionEmitter::CompletionAction completion_action)778 FunctionEmitter::StatementBlock::StatementBlock(
779     const Construct* construct,
780     uint32_t end_id,
781     FunctionEmitter::CompletionAction completion_action)
782     : construct_(construct),
783       end_id_(end_id),
784       completion_action_(completion_action) {}
785 
786 FunctionEmitter::StatementBlock::StatementBlock(StatementBlock&& other) =
787     default;
788 
789 FunctionEmitter::StatementBlock::~StatementBlock() = default;
790 
Finalize(ProgramBuilder * pb)791 void FunctionEmitter::StatementBlock::Finalize(ProgramBuilder* pb) {
792   TINT_ASSERT(Reader, !finalized_ /* Finalize() must only be called once */);
793 
794   for (size_t i = 0; i < statements_.size(); i++) {
795     if (auto* sb = statements_[i]->As<StatementBuilder>()) {
796       statements_[i] = sb->Build(pb);
797     }
798   }
799 
800   if (completion_action_ != nullptr) {
801     completion_action_(statements_);
802   }
803 
804   finalized_ = true;
805 }
806 
Add(const ast::Statement * statement)807 void FunctionEmitter::StatementBlock::Add(const ast::Statement* statement) {
808   TINT_ASSERT(Reader,
809               !finalized_ /* Add() must not be called after Finalize() */);
810   statements_.emplace_back(statement);
811 }
812 
PushNewStatementBlock(const Construct * construct,uint32_t end_id,CompletionAction action)813 void FunctionEmitter::PushNewStatementBlock(const Construct* construct,
814                                             uint32_t end_id,
815                                             CompletionAction action) {
816   statements_stack_.emplace_back(StatementBlock{construct, end_id, action});
817 }
818 
PushGuard(const std::string & guard_name,uint32_t end_id)819 void FunctionEmitter::PushGuard(const std::string& guard_name,
820                                 uint32_t end_id) {
821   TINT_ASSERT(Reader, !statements_stack_.empty());
822   TINT_ASSERT(Reader, !guard_name.empty());
823   // Guard control flow by the guard variable.  Introduce a new
824   // if-selection with a then-clause ending at the same block
825   // as the statement block at the top of the stack.
826   const auto& top = statements_stack_.back();
827 
828   auto* cond = create<ast::IdentifierExpression>(
829       Source{}, builder_.Symbols().Register(guard_name));
830   auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
831 
832   PushNewStatementBlock(
833       top.GetConstruct(), end_id, [=](const ast::StatementList& stmts) {
834         builder->body = create<ast::BlockStatement>(Source{}, stmts);
835       });
836 }
837 
PushTrueGuard(uint32_t end_id)838 void FunctionEmitter::PushTrueGuard(uint32_t end_id) {
839   TINT_ASSERT(Reader, !statements_stack_.empty());
840   const auto& top = statements_stack_.back();
841 
842   auto* cond = MakeTrue(Source{});
843   auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
844 
845   PushNewStatementBlock(
846       top.GetConstruct(), end_id, [=](const ast::StatementList& stmts) {
847         builder->body = create<ast::BlockStatement>(Source{}, stmts);
848       });
849 }
850 
ast_body()851 const ast::StatementList FunctionEmitter::ast_body() {
852   TINT_ASSERT(Reader, !statements_stack_.empty());
853   auto& entry = statements_stack_[0];
854   entry.Finalize(&builder_);
855   return entry.GetStatements();
856 }
857 
// Appends |statement| to the current (top) statement block, unless it is
// null.  Returns |statement| unchanged, including when it is null.
const ast::Statement* FunctionEmitter::AddStatement(
    const ast::Statement* statement) {
  TINT_ASSERT(Reader, !statements_stack_.empty());
  if (statement == nullptr) {
    return statement;
  }
  statements_stack_.back().Add(statement);
  return statement;
}
866 
LastStatement()867 const ast::Statement* FunctionEmitter::LastStatement() {
868   TINT_ASSERT(Reader, !statements_stack_.empty());
869   auto& statement_list = statements_stack_.back().GetStatements();
870   TINT_ASSERT(Reader, !statement_list.empty());
871   return statement_list.back();
872 }
873 
Emit()874 bool FunctionEmitter::Emit() {
875   if (failed()) {
876     return false;
877   }
878   // We only care about functions with bodies.
879   if (function_.cbegin() == function_.cend()) {
880     return true;
881   }
882 
883   // The function declaration, corresponding to how it's written in SPIR-V,
884   // and without regard to whether it's an entry point.
885   FunctionDeclaration decl;
886   if (!ParseFunctionDeclaration(&decl)) {
887     return false;
888   }
889 
890   bool make_body_function = true;
891   if (ep_info_) {
892     TINT_ASSERT(Reader, !ep_info_->inner_name.empty());
893     if (ep_info_->owns_inner_implementation) {
894       // This is an entry point, and we want to emit it as a wrapper around
895       // an implementation function.
896       decl.name = ep_info_->inner_name;
897     } else {
898       // This is a second entry point that shares an inner implementation
899       // function.
900       make_body_function = false;
901     }
902   }
903 
904   if (make_body_function) {
905     auto* body = MakeFunctionBody();
906     if (!body) {
907       return false;
908     }
909 
910     builder_.AST().AddFunction(create<ast::Function>(
911         decl.source, builder_.Symbols().Register(decl.name),
912         std::move(decl.params), decl.return_type->Build(builder_), body,
913         std::move(decl.decorations), ast::DecorationList{}));
914   }
915 
916   if (ep_info_ && !ep_info_->inner_name.empty()) {
917     return EmitEntryPointAsWrapper();
918   }
919 
920   return success();
921 }
922 
MakeFunctionBody()923 const ast::BlockStatement* FunctionEmitter::MakeFunctionBody() {
924   TINT_ASSERT(Reader, statements_stack_.size() == 1);
925 
926   if (!EmitBody()) {
927     return nullptr;
928   }
929 
930   // Set the body of the AST function node.
931   if (statements_stack_.size() != 1) {
932     Fail() << "internal error: statement-list stack should have 1 "
933               "element but has "
934            << statements_stack_.size();
935     return nullptr;
936   }
937 
938   statements_stack_[0].Finalize(&builder_);
939   auto& statements = statements_stack_[0].GetStatements();
940   auto* body = create<ast::BlockStatement>(Source{}, statements);
941 
942   // Maintain the invariant by repopulating the one and only element.
943   statements_stack_.clear();
944   PushNewStatementBlock(constructs_[0].get(), 0, nullptr);
945 
946   return body;
947 }
948 
EmitPipelineInput(std::string var_name,const Type * var_type,ast::DecorationList * decos,std::vector<int> index_prefix,const Type * tip_type,const Type * forced_param_type,ast::VariableList * params,ast::StatementList * statements)949 bool FunctionEmitter::EmitPipelineInput(std::string var_name,
950                                         const Type* var_type,
951                                         ast::DecorationList* decos,
952                                         std::vector<int> index_prefix,
953                                         const Type* tip_type,
954                                         const Type* forced_param_type,
955                                         ast::VariableList* params,
956                                         ast::StatementList* statements) {
957   // TODO(dneto): Handle structs where the locations are annotated on members.
958   tip_type = tip_type->UnwrapAlias();
959   if (auto* ref_type = tip_type->As<Reference>()) {
960     tip_type = ref_type->type;
961   }
962 
963   // Recursively flatten matrices, arrays, and structures.
964   if (auto* matrix_type = tip_type->As<Matrix>()) {
965     index_prefix.push_back(0);
966     const auto num_columns = static_cast<int>(matrix_type->columns);
967     const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
968     for (int col = 0; col < num_columns; col++) {
969       index_prefix.back() = col;
970       if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, vec_ty,
971                              forced_param_type, params, statements)) {
972         return false;
973       }
974     }
975     return success();
976   } else if (auto* array_type = tip_type->As<Array>()) {
977     if (array_type->size == 0) {
978       return Fail() << "runtime-size array not allowed on pipeline IO";
979     }
980     index_prefix.push_back(0);
981     const Type* elem_ty = array_type->type;
982     for (int i = 0; i < static_cast<int>(array_type->size); i++) {
983       index_prefix.back() = i;
984       if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, elem_ty,
985                              forced_param_type, params, statements)) {
986         return false;
987       }
988     }
989     return success();
990   } else if (auto* struct_type = tip_type->As<Struct>()) {
991     const auto& members = struct_type->members;
992     index_prefix.push_back(0);
993     for (int i = 0; i < static_cast<int>(members.size()); ++i) {
994       index_prefix.back() = i;
995       ast::DecorationList member_decos(*decos);
996       if (!parser_impl_.ConvertPipelineDecorations(
997               struct_type,
998               parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
999               &member_decos)) {
1000         return false;
1001       }
1002       if (!EmitPipelineInput(var_name, var_type, &member_decos, index_prefix,
1003                              members[i], forced_param_type, params,
1004                              statements)) {
1005         return false;
1006       }
1007       // Copy the location as updated by nested expansion of the member.
1008       parser_impl_.SetLocation(decos, GetLocation(member_decos));
1009     }
1010     return success();
1011   }
1012 
1013   const bool is_builtin = ast::HasDecoration<ast::BuiltinDecoration>(*decos);
1014 
1015   const Type* param_type = is_builtin ? forced_param_type : tip_type;
1016 
1017   const auto param_name = namer_.MakeDerivedName(var_name + "_param");
1018   // Create the parameter.
1019   // TODO(dneto): Note: If the parameter has non-location decorations,
1020   // then those decoration AST nodes will be reused between multiple elements
1021   // of a matrix, array, or structure.  Normally that's disallowed but currently
1022   // the SPIR-V reader will make duplicates when the entire AST is cloned
1023   // at the top level of the SPIR-V reader flow.  Consider rewriting this
1024   // to avoid this node-sharing.
1025   params->push_back(
1026       builder_.Param(param_name, param_type->Build(builder_), *decos));
1027 
1028   // Add a body statement to copy the parameter to the corresponding private
1029   // variable.
1030   const ast::Expression* param_value = builder_.Expr(param_name);
1031   const ast::Expression* store_dest = builder_.Expr(var_name);
1032 
1033   // Index into the LHS as needed.
1034   auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
1035   for (auto index : index_prefix) {
1036     if (auto* matrix_type = current_type->As<Matrix>()) {
1037       store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index));
1038       current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
1039     } else if (auto* array_type = current_type->As<Array>()) {
1040       store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index));
1041       current_type = array_type->type->UnwrapAlias();
1042     } else if (auto* struct_type = current_type->As<Struct>()) {
1043       store_dest = builder_.MemberAccessor(
1044           store_dest,
1045           builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
1046       current_type = struct_type->members[index];
1047     }
1048   }
1049 
1050   if (is_builtin && (tip_type != forced_param_type)) {
1051     // The parameter will have the WGSL type, but we need bitcast to
1052     // the variable store type.
1053     param_value =
1054         create<ast::BitcastExpression>(tip_type->Build(builder_), param_value);
1055   }
1056 
1057   statements->push_back(builder_.Assign(store_dest, param_value));
1058 
1059   // Increment the location attribute, in case more parameters will follow.
1060   IncrementLocation(decos);
1061 
1062   return success();
1063 }
1064 
IncrementLocation(ast::DecorationList * decos)1065 void FunctionEmitter::IncrementLocation(ast::DecorationList* decos) {
1066   for (auto*& deco : *decos) {
1067     if (auto* loc_deco = deco->As<ast::LocationDecoration>()) {
1068       // Replace this location decoration with a new one with one higher index.
1069       // The old one doesn't leak because it's kept in the builder's AST node
1070       // list.
1071       deco = builder_.Location(loc_deco->source, loc_deco->value + 1);
1072     }
1073   }
1074 }
1075 
// Returns the first location decoration in |decos|, or nullptr if |decos|
// has none.
const ast::Decoration* FunctionEmitter::GetLocation(
    const ast::DecorationList& decos) {
  const auto found =
      std::find_if(decos.begin(), decos.end(), [](const ast::Decoration* d) {
        return d->Is<ast::LocationDecoration>();
      });
  if (found == decos.end()) {
    return nullptr;
  }
  return *found;
}
1085 
EmitPipelineOutput(std::string var_name,const Type * var_type,ast::DecorationList * decos,std::vector<int> index_prefix,const Type * tip_type,const Type * forced_member_type,ast::StructMemberList * return_members,ast::ExpressionList * return_exprs)1086 bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
1087                                          const Type* var_type,
1088                                          ast::DecorationList* decos,
1089                                          std::vector<int> index_prefix,
1090                                          const Type* tip_type,
1091                                          const Type* forced_member_type,
1092                                          ast::StructMemberList* return_members,
1093                                          ast::ExpressionList* return_exprs) {
1094   tip_type = tip_type->UnwrapAlias();
1095   if (auto* ref_type = tip_type->As<Reference>()) {
1096     tip_type = ref_type->type;
1097   }
1098 
1099   // Recursively flatten matrices, arrays, and structures.
1100   if (auto* matrix_type = tip_type->As<Matrix>()) {
1101     index_prefix.push_back(0);
1102     const auto num_columns = static_cast<int>(matrix_type->columns);
1103     const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
1104     for (int col = 0; col < num_columns; col++) {
1105       index_prefix.back() = col;
1106       if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, vec_ty,
1107                               forced_member_type, return_members,
1108                               return_exprs)) {
1109         return false;
1110       }
1111     }
1112     return success();
1113   } else if (auto* array_type = tip_type->As<Array>()) {
1114     if (array_type->size == 0) {
1115       return Fail() << "runtime-size array not allowed on pipeline IO";
1116     }
1117     index_prefix.push_back(0);
1118     const Type* elem_ty = array_type->type;
1119     for (int i = 0; i < static_cast<int>(array_type->size); i++) {
1120       index_prefix.back() = i;
1121       if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, elem_ty,
1122                               forced_member_type, return_members,
1123                               return_exprs)) {
1124         return false;
1125       }
1126     }
1127     return success();
1128   } else if (auto* struct_type = tip_type->As<Struct>()) {
1129     const auto& members = struct_type->members;
1130     index_prefix.push_back(0);
1131     for (int i = 0; i < static_cast<int>(members.size()); ++i) {
1132       index_prefix.back() = i;
1133       ast::DecorationList member_decos(*decos);
1134       if (!parser_impl_.ConvertPipelineDecorations(
1135               struct_type,
1136               parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
1137               &member_decos)) {
1138         return false;
1139       }
1140       if (!EmitPipelineOutput(var_name, var_type, &member_decos, index_prefix,
1141                               members[i], forced_member_type, return_members,
1142                               return_exprs)) {
1143         return false;
1144       }
1145       // Copy the location as updated by nested expansion of the member.
1146       parser_impl_.SetLocation(decos, GetLocation(member_decos));
1147     }
1148     return success();
1149   }
1150 
1151   const bool is_builtin = ast::HasDecoration<ast::BuiltinDecoration>(*decos);
1152 
1153   const Type* member_type = is_builtin ? forced_member_type : tip_type;
1154   // Derive the member name directly from the variable name.  They can't
1155   // collide.
1156   const auto member_name = namer_.MakeDerivedName(var_name);
1157   // Create the member.
1158   // TODO(dneto): Note: If the parameter has non-location decorations,
1159   // then those decoration AST nodes  will be reused between multiple elements
1160   // of a matrix, array, or structure.  Normally that's disallowed but currently
1161   // the SPIR-V reader will make duplicates when the entire AST is cloned
1162   // at the top level of the SPIR-V reader flow.  Consider rewriting this
1163   // to avoid this node-sharing.
1164   return_members->push_back(
1165       builder_.Member(member_name, member_type->Build(builder_), *decos));
1166 
1167   // Create an expression to evaluate the part of the variable indexed by
1168   // the index_prefix.
1169   const ast::Expression* load_source = builder_.Expr(var_name);
1170 
1171   // Index into the variable as needed to pick out the flattened member.
1172   auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
1173   for (auto index : index_prefix) {
1174     if (auto* matrix_type = current_type->As<Matrix>()) {
1175       load_source = builder_.IndexAccessor(load_source, builder_.Expr(index));
1176       current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
1177     } else if (auto* array_type = current_type->As<Array>()) {
1178       load_source = builder_.IndexAccessor(load_source, builder_.Expr(index));
1179       current_type = array_type->type->UnwrapAlias();
1180     } else if (auto* struct_type = current_type->As<Struct>()) {
1181       load_source = builder_.MemberAccessor(
1182           load_source,
1183           builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
1184       current_type = struct_type->members[index];
1185     }
1186   }
1187 
1188   if (is_builtin && (tip_type != forced_member_type)) {
1189     // The member will have the WGSL type, but we need bitcast to
1190     // the variable store type.
1191     load_source = create<ast::BitcastExpression>(
1192         forced_member_type->Build(builder_), load_source);
1193   }
1194   return_exprs->push_back(load_source);
1195 
1196   // Increment the location attribute, in case more parameters will follow.
1197   IncrementLocation(decos);
1198 
1199   return success();
1200 }
1201 
EmitEntryPointAsWrapper()1202 bool FunctionEmitter::EmitEntryPointAsWrapper() {
1203   Source source;
1204 
1205   // The statements in the body.
1206   ast::StatementList stmts;
1207 
1208   FunctionDeclaration decl;
1209   decl.source = source;
1210   decl.name = ep_info_->name;
1211   const ast::Type* return_type = nullptr;  // Populated below.
1212 
1213   // Pipeline inputs become parameters to the wrapper function, and
1214   // their values are saved into the corresponding private variables that
1215   // have already been created.
1216   for (uint32_t var_id : ep_info_->inputs) {
1217     const auto* var = def_use_mgr_->GetDef(var_id);
1218     TINT_ASSERT(Reader, var != nullptr);
1219     TINT_ASSERT(Reader, var->opcode() == SpvOpVariable);
1220     auto* store_type = GetVariableStoreType(*var);
1221     auto* forced_param_type = store_type;
1222     ast::DecorationList param_decos;
1223     if (!parser_impl_.ConvertDecorationsForVariable(var_id, &forced_param_type,
1224                                                     &param_decos, true)) {
1225       // This occurs, and is not an error, for the PointSize builtin.
1226       if (!success()) {
1227         // But exit early if an error was logged.
1228         return false;
1229       }
1230       continue;
1231     }
1232 
1233     // We don't have to handle initializers because in Vulkan SPIR-V, Input
1234     // variables must not have them.
1235 
1236     const auto var_name = namer_.GetName(var_id);
1237 
1238     bool ok = true;
1239     if (HasBuiltinSampleMask(param_decos)) {
1240       // In Vulkan SPIR-V, the sample mask is an array. In WGSL it's a scalar.
1241       // Use the first element only.
1242       auto* sample_mask_array_type =
1243           store_type->UnwrapRef()->UnwrapAlias()->As<Array>();
1244       TINT_ASSERT(Reader, sample_mask_array_type);
1245       ok = EmitPipelineInput(var_name, store_type, &param_decos, {0},
1246                              sample_mask_array_type->type, forced_param_type,
1247                              &(decl.params), &stmts);
1248     } else {
1249       // The normal path.
1250       ok = EmitPipelineInput(var_name, store_type, &param_decos, {}, store_type,
1251                              forced_param_type, &(decl.params), &stmts);
1252     }
1253     if (!ok) {
1254       return false;
1255     }
1256   }
1257 
1258   // Call the inner function.  It has no parameters.
1259   stmts.push_back(create<ast::CallStatement>(
1260       source,
1261       create<ast::CallExpression>(
1262           source,
1263           create<ast::IdentifierExpression>(
1264               source, builder_.Symbols().Register(ep_info_->inner_name)),
1265           ast::ExpressionList{})));
1266 
1267   // Pipeline outputs are mapped to the return value.
1268   if (ep_info_->outputs.empty()) {
1269     // There is nothing to return.
1270     return_type = ty_.Void()->Build(builder_);
1271   } else {
1272     // Pipeline outputs are converted to a structure that is written
1273     // to just before returning.
1274 
1275     const auto return_struct_name =
1276         namer_.MakeDerivedName(ep_info_->name + "_out");
1277     const auto return_struct_sym =
1278         builder_.Symbols().Register(return_struct_name);
1279 
1280     // Define the structure.
1281     std::vector<const ast::StructMember*> return_members;
1282     ast::ExpressionList return_exprs;
1283 
1284     const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
1285 
1286     for (uint32_t var_id : ep_info_->outputs) {
1287       if (var_id == builtin_position_info.per_vertex_var_id) {
1288         // The SPIR-V gl_PerVertex variable has already been remapped to
1289         // a gl_Position variable.  Substitute the type.
1290         const Type* param_type = ty_.Vector(ty_.F32(), 4);
1291         ast::DecorationList out_decos{
1292             create<ast::BuiltinDecoration>(source, ast::Builtin::kPosition)};
1293 
1294         const auto var_name = namer_.GetName(var_id);
1295         return_members.push_back(
1296             builder_.Member(var_name, param_type->Build(builder_), out_decos));
1297         return_exprs.push_back(builder_.Expr(var_name));
1298 
1299       } else {
1300         const auto* var = def_use_mgr_->GetDef(var_id);
1301         TINT_ASSERT(Reader, var != nullptr);
1302         TINT_ASSERT(Reader, var->opcode() == SpvOpVariable);
1303         const Type* store_type = GetVariableStoreType(*var);
1304         const Type* forced_member_type = store_type;
1305         ast::DecorationList out_decos;
1306         if (!parser_impl_.ConvertDecorationsForVariable(
1307                 var_id, &forced_member_type, &out_decos, true)) {
1308           // This occurs, and is not an error, for the PointSize builtin.
1309           if (!success()) {
1310             // But exit early if an error was logged.
1311             return false;
1312           }
1313           continue;
1314         }
1315 
1316         const auto var_name = namer_.GetName(var_id);
1317         bool ok = true;
1318         if (HasBuiltinSampleMask(out_decos)) {
1319           // In Vulkan SPIR-V, the sample mask is an array. In WGSL it's a
1320           // scalar. Use the first element only.
1321           auto* sample_mask_array_type =
1322               store_type->UnwrapRef()->UnwrapAlias()->As<Array>();
1323           TINT_ASSERT(Reader, sample_mask_array_type);
1324           ok = EmitPipelineOutput(var_name, store_type, &out_decos, {0},
1325                                   sample_mask_array_type->type,
1326                                   forced_member_type, &return_members,
1327                                   &return_exprs);
1328         } else {
1329           // The normal path.
1330           ok = EmitPipelineOutput(var_name, store_type, &out_decos, {},
1331                                   store_type, forced_member_type,
1332                                   &return_members, &return_exprs);
1333         }
1334         if (!ok) {
1335           return false;
1336         }
1337       }
1338     }
1339 
1340     if (return_members.empty()) {
1341       // This can occur if only the PointSize member is accessed, because we
1342       // never emit it.
1343       return_type = ty_.Void()->Build(builder_);
1344     } else {
1345       // Create and register the result type.
1346       auto* str = create<ast::Struct>(Source{}, return_struct_sym,
1347                                       return_members, ast::DecorationList{});
1348       parser_impl_.AddTypeDecl(return_struct_sym, str);
1349       return_type = builder_.ty.Of(str);
1350 
1351       // Add the return-value statement.
1352       stmts.push_back(create<ast::ReturnStatement>(
1353           source,
1354           builder_.Construct(source, return_type, std::move(return_exprs))));
1355     }
1356   }
1357 
1358   auto* body = create<ast::BlockStatement>(source, stmts);
1359   ast::DecorationList fn_decos;
1360   fn_decos.emplace_back(create<ast::StageDecoration>(source, ep_info_->stage));
1361 
1362   if (ep_info_->stage == ast::PipelineStage::kCompute) {
1363     auto& size = ep_info_->workgroup_size;
1364     if (size.x != 0 && size.y != 0 && size.z != 0) {
1365       const ast::Expression* x = builder_.Expr(static_cast<int>(size.x));
1366       const ast::Expression* y =
1367           size.y ? builder_.Expr(static_cast<int>(size.y)) : nullptr;
1368       const ast::Expression* z =
1369           size.z ? builder_.Expr(static_cast<int>(size.z)) : nullptr;
1370       fn_decos.emplace_back(
1371           create<ast::WorkgroupDecoration>(Source{}, x, y, z));
1372     }
1373   }
1374 
1375   builder_.AST().AddFunction(
1376       create<ast::Function>(source, builder_.Symbols().Register(ep_info_->name),
1377                             std::move(decl.params), return_type, body,
1378                             std::move(fn_decos), ast::DecorationList{}));
1379 
1380   return true;
1381 }
1382 
ParseFunctionDeclaration(FunctionDeclaration * decl)1383 bool FunctionEmitter::ParseFunctionDeclaration(FunctionDeclaration* decl) {
1384   if (failed()) {
1385     return false;
1386   }
1387 
1388   const std::string name = namer_.Name(function_.result_id());
1389 
1390   // Surprisingly, the "type id" on an OpFunction is the result type of the
1391   // function, not the type of the function.  This is the one exceptional case
1392   // in SPIR-V where the type ID is not the type of the result ID.
1393   auto* ret_ty = parser_impl_.ConvertType(function_.type_id());
1394   if (failed()) {
1395     return false;
1396   }
1397   if (ret_ty == nullptr) {
1398     return Fail()
1399            << "internal error: unregistered return type for function with ID "
1400            << function_.result_id();
1401   }
1402 
1403   ast::VariableList ast_params;
1404   function_.ForEachParam(
1405       [this, &ast_params](const spvtools::opt::Instruction* param) {
1406         auto* type = parser_impl_.ConvertType(param->type_id());
1407         if (type != nullptr) {
1408           auto* ast_param = parser_impl_.MakeVariable(
1409               param->result_id(), ast::StorageClass::kNone, type, true, nullptr,
1410               ast::DecorationList{});
1411           // Parameters are treated as const declarations.
1412           ast_params.emplace_back(ast_param);
1413           // The value is accessible by name.
1414           identifier_types_.emplace(param->result_id(), type);
1415         } else {
1416           // We've already logged an error and emitted a diagnostic. Do nothing
1417           // here.
1418         }
1419       });
1420   if (failed()) {
1421     return false;
1422   }
1423   decl->name = name;
1424   decl->params = std::move(ast_params);
1425   decl->return_type = ret_ty;
1426   decl->decorations.clear();
1427 
1428   return success();
1429 }
1430 
GetVariableStoreType(const spvtools::opt::Instruction & var_decl_inst)1431 const Type* FunctionEmitter::GetVariableStoreType(
1432     const spvtools::opt::Instruction& var_decl_inst) {
1433   const auto type_id = var_decl_inst.type_id();
1434   // Normally we use the SPIRV-Tools optimizer to manage types.
1435   // But when two struct types have the same member types and decorations,
1436   // but differ only in member names, the two struct types will be
1437   // represented by a single common internal struct type.
1438   // So avoid the optimizer's representation and instead follow the
1439   // SPIR-V instructions themselves.
1440   const auto* ptr_ty = def_use_mgr_->GetDef(type_id);
1441   const auto store_ty_id = ptr_ty->GetSingleWordInOperand(1);
1442   const auto* result = parser_impl_.ConvertType(store_ty_id);
1443   return result;
1444 }
1445 
EmitBody()1446 bool FunctionEmitter::EmitBody() {
1447   RegisterBasicBlocks();
1448 
1449   if (!TerminatorsAreValid()) {
1450     return false;
1451   }
1452   if (!RegisterMerges()) {
1453     return false;
1454   }
1455 
1456   ComputeBlockOrderAndPositions();
1457   if (!VerifyHeaderContinueMergeOrder()) {
1458     return false;
1459   }
1460   if (!LabelControlFlowConstructs()) {
1461     return false;
1462   }
1463   if (!FindSwitchCaseHeaders()) {
1464     return false;
1465   }
1466   if (!ClassifyCFGEdges()) {
1467     return false;
1468   }
1469   if (!FindIfSelectionInternalHeaders()) {
1470     return false;
1471   }
1472 
1473   if (!RegisterSpecialBuiltInVariables()) {
1474     return false;
1475   }
1476   if (!RegisterLocallyDefinedValues()) {
1477     return false;
1478   }
1479   FindValuesNeedingNamedOrHoistedDefinition();
1480 
1481   if (!EmitFunctionVariables()) {
1482     return false;
1483   }
1484   if (!EmitFunctionBodyStatements()) {
1485     return false;
1486   }
1487   return success();
1488 }
1489 
RegisterBasicBlocks()1490 void FunctionEmitter::RegisterBasicBlocks() {
1491   for (auto& block : function_) {
1492     block_info_[block.id()] = std::make_unique<BlockInfo>(block);
1493   }
1494 }
1495 
TerminatorsAreValid()1496 bool FunctionEmitter::TerminatorsAreValid() {
1497   if (failed()) {
1498     return false;
1499   }
1500 
1501   const auto entry_id = function_.begin()->id();
1502   for (const auto& block : function_) {
1503     if (!block.terminator()) {
1504       return Fail() << "Block " << block.id() << " has no terminator";
1505     }
1506   }
1507   for (const auto& block : function_) {
1508     block.WhileEachSuccessorLabel(
1509         [this, &block, entry_id](const uint32_t succ_id) -> bool {
1510           if (succ_id == entry_id) {
1511             return Fail() << "Block " << block.id()
1512                           << " branches to function entry block " << entry_id;
1513           }
1514           if (!GetBlockInfo(succ_id)) {
1515             return Fail() << "Block " << block.id() << " in function "
1516                           << function_.DefInst().result_id() << " branches to "
1517                           << succ_id << " which is not a block in the function";
1518           }
1519           return true;
1520         });
1521   }
1522   return success();
1523 }
1524 
RegisterMerges()1525 bool FunctionEmitter::RegisterMerges() {
1526   if (failed()) {
1527     return false;
1528   }
1529 
1530   const auto entry_id = function_.begin()->id();
1531   for (const auto& block : function_) {
1532     const auto block_id = block.id();
1533     auto* block_info = GetBlockInfo(block_id);
1534     if (!block_info) {
1535       return Fail() << "internal error: block " << block_id
1536                     << " missing; blocks should already "
1537                        "have been registered";
1538     }
1539 
1540     if (const auto* inst = block.GetMergeInst()) {
1541       auto terminator_opcode = block.terminator()->opcode();
1542       switch (inst->opcode()) {
1543         case SpvOpSelectionMerge:
1544           if ((terminator_opcode != SpvOpBranchConditional) &&
1545               (terminator_opcode != SpvOpSwitch)) {
1546             return Fail() << "Selection header " << block_id
1547                           << " does not end in an OpBranchConditional or "
1548                              "OpSwitch instruction";
1549           }
1550           break;
1551         case SpvOpLoopMerge:
1552           if ((terminator_opcode != SpvOpBranchConditional) &&
1553               (terminator_opcode != SpvOpBranch)) {
1554             return Fail() << "Loop header " << block_id
1555                           << " does not end in an OpBranch or "
1556                              "OpBranchConditional instruction";
1557           }
1558           break;
1559         default:
1560           break;
1561       }
1562 
1563       const uint32_t header = block.id();
1564       auto* header_info = block_info;
1565       const uint32_t merge = inst->GetSingleWordInOperand(0);
1566       auto* merge_info = GetBlockInfo(merge);
1567       if (!merge_info) {
1568         return Fail() << "Structured header block " << header
1569                       << " declares invalid merge block " << merge;
1570       }
1571       if (merge == header) {
1572         return Fail() << "Structured header block " << header
1573                       << " cannot be its own merge block";
1574       }
1575       if (merge_info->header_for_merge) {
1576         return Fail() << "Block " << merge
1577                       << " declared as merge block for more than one header: "
1578                       << merge_info->header_for_merge << ", " << header;
1579       }
1580       merge_info->header_for_merge = header;
1581       header_info->merge_for_header = merge;
1582 
1583       if (inst->opcode() == SpvOpLoopMerge) {
1584         if (header == entry_id) {
1585           return Fail() << "Function entry block " << entry_id
1586                         << " cannot be a loop header";
1587         }
1588         const uint32_t ct = inst->GetSingleWordInOperand(1);
1589         auto* ct_info = GetBlockInfo(ct);
1590         if (!ct_info) {
1591           return Fail() << "Structured header " << header
1592                         << " declares invalid continue target " << ct;
1593         }
1594         if (ct == merge) {
1595           return Fail() << "Invalid structured header block " << header
1596                         << ": declares block " << ct
1597                         << " as both its merge block and continue target";
1598         }
1599         if (ct_info->header_for_continue) {
1600           return Fail()
1601                  << "Block " << ct
1602                  << " declared as continue target for more than one header: "
1603                  << ct_info->header_for_continue << ", " << header;
1604         }
1605         ct_info->header_for_continue = header;
1606         header_info->continue_for_header = ct;
1607       }
1608     }
1609 
1610     // Check single-block loop cases.
1611     bool is_single_block_loop = false;
1612     block_info->basic_block->ForEachSuccessorLabel(
1613         [&is_single_block_loop, block_id](const uint32_t succ) {
1614           if (block_id == succ)
1615             is_single_block_loop = true;
1616         });
1617     const auto ct = block_info->continue_for_header;
1618     block_info->is_continue_entire_loop = ct == block_id;
1619     if (is_single_block_loop && !block_info->is_continue_entire_loop) {
1620       return Fail() << "Block " << block_id
1621                     << " branches to itself but is not its own continue target";
1622     }
1623     // It's valid for a the header of a multi-block loop header to declare
1624     // itself as its own continue target.
1625   }
1626   return success();
1627 }
1628 
ComputeBlockOrderAndPositions()1629 void FunctionEmitter::ComputeBlockOrderAndPositions() {
1630   block_order_ = StructuredTraverser(function_).ReverseStructuredPostOrder();
1631 
1632   for (uint32_t i = 0; i < block_order_.size(); ++i) {
1633     GetBlockInfo(block_order_[i])->pos = i;
1634   }
1635   // The invalid block position is not the position of any block that is in the
1636   // order.
1637   assert(block_order_.size() <= kInvalidBlockPos);
1638 }
1639 
VerifyHeaderContinueMergeOrder()1640 bool FunctionEmitter::VerifyHeaderContinueMergeOrder() {
1641   // Verify interval rules for a structured header block:
1642   //
1643   //    If the CFG satisfies structured control flow rules, then:
1644   //    If header H is reachable, then the following "interval rules" hold,
1645   //    where M(H) is H's merge block, and CT(H) is H's continue target:
1646   //
1647   //      Pos(H) < Pos(M(H))
1648   //
1649   //      If CT(H) exists, then:
1650   //         Pos(H) <= Pos(CT(H))
1651   //         Pos(CT(H)) < Pos(M)
1652   //
1653   for (auto block_id : block_order_) {
1654     const auto* block_info = GetBlockInfo(block_id);
1655     const auto merge = block_info->merge_for_header;
1656     if (merge == 0) {
1657       continue;
1658     }
1659     // This is a header.
1660     const auto header = block_id;
1661     const auto* header_info = block_info;
1662     const auto header_pos = header_info->pos;
1663     const auto merge_pos = GetBlockInfo(merge)->pos;
1664 
1665     // Pos(H) < Pos(M(H))
1666     // Note: When recording merges we made sure H != M(H)
1667     if (merge_pos <= header_pos) {
1668       return Fail() << "Header " << header
1669                     << " does not strictly dominate its merge block " << merge;
1670       // TODO(dneto): Report a path from the entry block to the merge block
1671       // without going through the header block.
1672     }
1673 
1674     const auto ct = block_info->continue_for_header;
1675     if (ct == 0) {
1676       continue;
1677     }
1678     // Furthermore, this is a loop header.
1679     const auto* ct_info = GetBlockInfo(ct);
1680     const auto ct_pos = ct_info->pos;
1681     // Pos(H) <= Pos(CT(H))
1682     if (ct_pos < header_pos) {
1683       Fail() << "Loop header " << header
1684              << " does not dominate its continue target " << ct;
1685     }
1686     // Pos(CT(H)) < Pos(M(H))
1687     // Note: When recording merges we made sure CT(H) != M(H)
1688     if (merge_pos <= ct_pos) {
1689       return Fail() << "Merge block " << merge << " for loop headed at block "
1690                     << header
1691                     << " appears at or before the loop's continue "
1692                        "construct headed by "
1693                        "block "
1694                     << ct;
1695     }
1696   }
1697   return success();
1698 }
1699 
// Builds the tree of structured control flow constructs in |constructs_| and
// labels each block in |block_order_| with its nearest enclosing construct.
// Returns false (after recording an error) if the constructs do not nest
// properly.
bool FunctionEmitter::LabelControlFlowConstructs() {
  // Label each block in the block order with its nearest enclosing structured
  // control flow construct. Populates the |construct| member of BlockInfo.

  //  Keep a stack of enclosing structured control flow constructs.  Start
  //  with the synthetic construct representing the entire function.
  //
  //  Scan from left to right in the block order, and check conditions
  //  on each block in the following order:
  //
  //        a. When you reach a merge block, the top of the stack should
  //           be the associated header. Pop it off.
  //        b. When you reach a header, push it on the stack.
  //        c. When you reach a continue target, push it on the stack.
  //           (A block can be both a header and a continue target.)
  //        d. When you reach a block with an edge branching backward (in the
  //           structured order) to block T:
  //            T should be a loop header, and the top of the stack should be a
  //            continue target associated with T.
  //            This is the end of the continue construct. Pop the continue
  //            target off the stack.
  //
  //       Note: A loop header can declare itself as its own continue target.
  //
  //       Note: For a single-block loop, that block is a header, its own
  //       continue target, and its own backedge block.
  //
  //       Note: We pop the merge off first because a merge block that marks
  //       the end of one construct can be a single-block loop.  So that block
  //       is a merge, a header, a continue target, and a backedge block.
  //       But we want to finish processing of the merge before dealing with
  //       the loop.
  //
  //      In the same scan, mark each basic block with the nearest enclosing
  //      header: the most recent header for which we haven't reached its merge
  //      block. Also mark the most recent continue target for which we
  //      haven't reached the backedge block.
  //
  //  Implementation note: every construct below is pushed with a known end
  //  block (the merge block; or, for a loop construct and its synthetic
  //  if-selection, the continue target). A construct is popped when the scan
  //  reaches its end block, rather than by detecting backedges.

  TINT_ASSERT(Reader, block_order_.size() > 0);
  constructs_.clear();
  const auto entry_id = block_order_[0];

  // The stack of enclosing constructs.
  std::vector<Construct*> enclosing;

  // Creates a control flow construct and pushes it onto the stack.
  // Its parent is the top of the stack, or nullptr if the stack is empty.
  // Returns the newly created construct.
  // An |end_id| of 0 means the construct runs to the end of the block order.
  auto push_construct = [this, &enclosing](size_t depth, Construct::Kind k,
                                           uint32_t begin_id,
                                           uint32_t end_id) -> Construct* {
    const auto begin_pos = GetBlockInfo(begin_id)->pos;
    const auto end_pos =
        end_id == 0 ? uint32_t(block_order_.size()) : GetBlockInfo(end_id)->pos;
    const auto* parent = enclosing.empty() ? nullptr : enclosing.back();
    auto scope_end_pos = end_pos;
    // A loop construct is added right after its associated continue construct.
    // In that case, adjust the parent up.
    // The loop's scope extends through the end of that continue construct.
    if (k == Construct::kLoop) {
      TINT_ASSERT(Reader, parent);
      TINT_ASSERT(Reader, parent->kind == Construct::kContinue);
      scope_end_pos = parent->end_pos;
      parent = parent->parent;
    }
    constructs_.push_back(std::make_unique<Construct>(
        parent, static_cast<int>(depth), k, begin_id, end_id, begin_pos,
        end_pos, scope_end_pos));
    Construct* result = constructs_.back().get();
    enclosing.push_back(result);
    return result;
  };

  // Make a synthetic kFunction construct to enclose all blocks in the function.
  push_construct(0, Construct::kFunction, entry_id, 0);
  // The entry block can be a selection construct, so be sure to process
  // it anyway.

  for (uint32_t i = 0; i < block_order_.size(); ++i) {
    const auto block_id = block_order_[i];
    TINT_ASSERT(Reader, block_id > 0);
    auto* block_info = GetBlockInfo(block_id);
    TINT_ASSERT(Reader, block_info);

    if (enclosing.empty()) {
      return Fail() << "internal error: too many merge blocks before block "
                    << block_id;
    }
    const Construct* top = enclosing.back();

    // Pop every construct that ends at this block. Several constructs can
    // share the same end block, hence the loop.
    while (block_id == top->end_id) {
      // We've reached a predeclared end of the construct.  Pop it off the
      // stack.
      enclosing.pop_back();
      if (enclosing.empty()) {
        return Fail() << "internal error: too many merge blocks before block "
                      << block_id;
      }
      top = enclosing.back();
    }

    const auto merge = block_info->merge_for_header;
    if (merge != 0) {
      // The current block is a header.
      const auto header = block_id;
      const auto* header_info = block_info;
      const auto depth = 1 + top->depth;
      const auto ct = header_info->continue_for_header;
      if (ct != 0) {
        // The current block is a loop header.
        // We should see the continue construct after the loop construct, so
        // push the loop construct last.

        // From the interval rule, the continue construct consists of blocks
        // in the block order, starting at the continue target, until just
        // before the merge block.
        top = push_construct(depth, Construct::kContinue, ct, merge);
        // A loop header that is its own continue target will have an
        // empty loop construct. Only create a loop construct when
        // the continue target is *not* the same as the loop header.
        if (header != ct) {
          // From the interval rule, the loop construct consists of blocks
          // in the block order, starting at the header, until just
          // before the continue target.
          top = push_construct(depth, Construct::kLoop, header, ct);

          // If the loop header branches to two different blocks inside the loop
          // construct, then the loop body should be modeled as an if-selection
          // construct
          std::vector<uint32_t> targets;
          header_info->basic_block->ForEachSuccessorLabel(
              [&targets](const uint32_t target) { targets.push_back(target); });
          if ((targets.size() == 2u) && targets[0] != targets[1]) {
            const auto target0_pos = GetBlockInfo(targets[0])->pos;
            const auto target1_pos = GetBlockInfo(targets[1])->pos;
            if (top->ContainsPos(target0_pos) &&
                top->ContainsPos(target1_pos)) {
              // Insert a synthetic if-selection
              top = push_construct(depth + 1, Construct::kIfSelection, header,
                                   ct);
            }
          }
        }
      } else {
        // From the interval rule, the selection construct consists of blocks
        // in the block order, starting at the header, until just before the
        // merge block.
        // A header ending in OpBranchConditional is an if-selection; any other
        // terminator is treated as a switch-selection.
        const auto branch_opcode =
            header_info->basic_block->terminator()->opcode();
        const auto kind = (branch_opcode == SpvOpBranchConditional)
                              ? Construct::kIfSelection
                              : Construct::kSwitchSelection;
        top = push_construct(depth, kind, header, merge);
      }
    }

    // |top| is now the innermost construct containing this block.
    TINT_ASSERT(Reader, top);
    block_info->construct = top;
  }

  // At the end of the block list, we should only have the kFunction construct
  // left.
  if (enclosing.size() != 1) {
    return Fail() << "internal error: unbalanced structured constructs when "
                     "labeling structured constructs: ended with "
                  << enclosing.size() - 1 << " unterminated constructs";
  }
  const auto* top = enclosing[0];
  if (top->kind != Construct::kFunction || top->depth != 0) {
    return Fail() << "internal error: outermost construct is not a function?!";
  }

  return success();
}
1873 
// For each switch-selection construct, marks its default target and its case
// targets in their BlockInfo records. For each case target, it records the
// list of case values that select it.
// Each target must:
//  - come after the OpSwitch block in the block order (no back-edges);
//  - not lie past the construct's merge block (it may be the merge block);
//  - not be the merge block of a different header;
//  - not be targeted by another OpSwitch in the same role.
// Duplicate case values within one OpSwitch are rejected.
// Returns false (after recording an error) on the first violation.
bool FunctionEmitter::FindSwitchCaseHeaders() {
  if (failed()) {
    return false;
  }
  for (auto& construct : constructs_) {
    if (construct->kind != Construct::kSwitchSelection) {
      continue;
    }
    const auto* branch =
        GetBlockInfo(construct->begin_id)->basic_block->terminator();

    // Mark the default block
    // In-operand 1 of OpSwitch is the default target; in-operand 0 is the
    // selector.
    const auto default_id = branch->GetSingleWordInOperand(1);
    auto* default_block = GetBlockInfo(default_id);
    // A default target can't be a backedge.
    if (construct->begin_pos >= default_block->pos) {
      // An OpSwitch must dominate its cases.  Also, it can't be a self-loop
      // as that would be a backedge, and backedges can only target a loop,
      // and loops use an OpLoopMerge instruction, which can't precede an
      // OpSwitch.
      return Fail() << "Switch branch from block " << construct->begin_id
                    << " to default target block " << default_id
                    << " can't be a back-edge";
    }
    // A default target can be the merge block, but can't go past it.
    if (construct->end_pos < default_block->pos) {
      return Fail() << "Switch branch from block " << construct->begin_id
                    << " to default block " << default_id
                    << " escapes the selection construct";
    }
    if (default_block->default_head_for) {
      // An OpSwitch must dominate its cases, including the default target.
      return Fail() << "Block " << default_id
                    << " is declared as the default target for two OpSwitch "
                       "instructions, at blocks "
                    << default_block->default_head_for->begin_id << " and "
                    << construct->begin_id;
    }
    if ((default_block->header_for_merge != 0) &&
        (default_block->header_for_merge != construct->begin_id)) {
      // The switch instruction for this default block is an alternate path to
      // the merge block, and hence the merge block is not dominated by its own
      // (different) header.
      return Fail() << "Block " << default_block->id
                    << " is the default block for switch-selection header "
                    << construct->begin_id << " and also the merge block for "
                    << default_block->header_for_merge
                    << " (violates dominance rule)";
    }

    default_block->default_head_for = construct.get();
    default_block->default_is_merge = default_block->pos == construct->end_pos;

    // Map a case target to the list of values selecting that case.
    std::unordered_map<uint32_t, std::vector<uint64_t>> block_to_values;
    // Distinct case targets, in first-seen operand order.
    std::vector<uint32_t> case_targets;
    std::unordered_set<uint64_t> case_values;

    // Process case targets.
    // From in-operand 2 onward, operands come in (literal value, target) pairs.
    for (uint32_t iarg = 2; iarg + 1 < branch->NumInOperands(); iarg += 2) {
      const auto value = branch->GetInOperand(iarg).AsLiteralUint64();
      const auto case_target_id = branch->GetSingleWordInOperand(iarg + 1);

      if (case_values.count(value)) {
        return Fail() << "Duplicate case value " << value
                      << " in OpSwitch in block " << construct->begin_id;
      }
      case_values.insert(value);
      if (block_to_values.count(case_target_id) == 0) {
        case_targets.push_back(case_target_id);
      }
      block_to_values[case_target_id].push_back(value);
    }

    for (uint32_t case_target_id : case_targets) {
      auto* case_block = GetBlockInfo(case_target_id);

      case_block->case_values = std::make_unique<std::vector<uint64_t>>(
          std::move(block_to_values[case_target_id]));

      // A case target can't be a back-edge.
      if (construct->begin_pos >= case_block->pos) {
        // An OpSwitch must dominate its cases.  Also, it can't be a self-loop
        // as that would be a backedge, and backedges can only target a loop,
        // and loops use an OpLoopMerge instruction, which can't precede an
        // OpSwitch.
        return Fail() << "Switch branch from block " << construct->begin_id
                      << " to case target block " << case_target_id
                      << " can't be a back-edge";
      }
      // A case target can be the merge block, but can't go past it.
      if (construct->end_pos < case_block->pos) {
        return Fail() << "Switch branch from block " << construct->begin_id
                      << " to case target block " << case_target_id
                      << " escapes the selection construct";
      }
      if (case_block->header_for_merge != 0 &&
          case_block->header_for_merge != construct->begin_id) {
        // The switch instruction for this case block is an alternate path to
        // the merge block, and hence the merge block is not dominated by its
        // own (different) header.
        return Fail() << "Block " << case_block->id
                      << " is a case block for switch-selection header "
                      << construct->begin_id << " and also the merge block for "
                      << case_block->header_for_merge
                      << " (violates dominance rule)";
      }

      // Mark the target as a case target.
      if (case_block->case_head_for) {
        // An OpSwitch must dominate its cases.
        return Fail()
               << "Block " << case_target_id
               << " is declared as the switch case target for two OpSwitch "
                  "instructions, at blocks "
               << case_block->case_head_for->begin_id << " and "
               << construct->begin_id;
      }
      case_block->case_head_for = construct.get();
    }
  }
  return success();
}
1997 
HeaderIfBreakable(const Construct * c)1998 BlockInfo* FunctionEmitter::HeaderIfBreakable(const Construct* c) {
1999   if (c == nullptr) {
2000     return nullptr;
2001   }
2002   switch (c->kind) {
2003     case Construct::kLoop:
2004     case Construct::kSwitchSelection:
2005       return GetBlockInfo(c->begin_id);
2006     case Construct::kContinue: {
2007       const auto* continue_target = GetBlockInfo(c->begin_id);
2008       return GetBlockInfo(continue_target->header_for_continue);
2009     }
2010     default:
2011       break;
2012   }
2013   return nullptr;
2014 }
2015 
// Given a continue construct |c|, returns the loop construct of the same loop.
// Returns nullptr if |c| is null, is not a continue construct, or if the
// loop header is its own continue target (there is no separate loop
// construct in that case).
const Construct* FunctionEmitter::SiblingLoopConstruct(
    const Construct* c) const {
  const bool is_continue = c != nullptr && c->kind == Construct::kContinue;
  if (!is_continue) {
    return nullptr;
  }
  const uint32_t ct_id = c->begin_id;
  const uint32_t loop_header_id = GetBlockInfo(ct_id)->header_for_continue;
  if (loop_header_id == ct_id) {
    // The continue target is the whole loop.
    return nullptr;
  }
  // The header's nearest construct may be nested inside the loop construct,
  // e.g. a synthetic if-selection, so climb parents until a loop is found.
  const auto* walker = GetBlockInfo(loop_header_id)->construct;
  for (; walker != nullptr; walker = walker->parent) {
    if (walker->kind == Construct::kLoop) {
      break;
    }
  }
  return walker;
}
2039 
ClassifyCFGEdges()2040 bool FunctionEmitter::ClassifyCFGEdges() {
2041   if (failed()) {
2042     return false;
2043   }
2044 
2045   // Checks validity of CFG edges leaving each basic block.  This implicitly
2046   // checks dominance rules for headers and continue constructs.
2047   //
2048   // For each branch encountered, classify each edge (S,T) as:
2049   //    - a back-edge
2050   //    - a structured exit (specific ways of branching to enclosing construct)
2051   //    - a normal (forward) edge, either natural control flow or a case
2052   //    fallthrough
2053   //
2054   // If more than one block is targeted by a normal edge, then S must be a
2055   // structured header.
2056   //
2057   // Term: NEC(B) is the nearest enclosing construct for B.
2058   //
2059   // If edge (S,T) is a normal edge, and NEC(S) != NEC(T), then
2060   //    T is the header block of its NEC(T), and
2061   //    NEC(S) is the parent of NEC(T).
2062 
2063   for (const auto src : block_order_) {
2064     TINT_ASSERT(Reader, src > 0);
2065     auto* src_info = GetBlockInfo(src);
2066     TINT_ASSERT(Reader, src_info);
2067     const auto src_pos = src_info->pos;
2068     const auto& src_construct = *(src_info->construct);
2069 
2070     // Compute the ordered list of unique successors.
2071     std::vector<uint32_t> successors;
2072     {
2073       std::unordered_set<uint32_t> visited;
2074       src_info->basic_block->ForEachSuccessorLabel(
2075           [&successors, &visited](const uint32_t succ) {
2076             if (visited.count(succ) == 0) {
2077               successors.push_back(succ);
2078               visited.insert(succ);
2079             }
2080           });
2081     }
2082 
2083     // There should only be one backedge per backedge block.
2084     uint32_t num_backedges = 0;
2085 
2086     // Track destinations for normal forward edges, either kForward
2087     // or kCaseFallThrough. These count toward the need
2088     // to have a merge instruction.  We also track kIfBreak edges
2089     // because when used with normal forward edges, we'll need
2090     // to generate a flow guard variable.
2091     std::vector<uint32_t> normal_forward_edges;
2092     std::vector<uint32_t> if_break_edges;
2093 
2094     if (successors.empty() && src_construct.enclosing_continue) {
2095       // Kill and return are not allowed in a continue construct.
2096       return Fail() << "Invalid function exit at block " << src
2097                     << " from continue construct starting at "
2098                     << src_construct.enclosing_continue->begin_id;
2099     }
2100 
2101     for (const auto dest : successors) {
2102       const auto* dest_info = GetBlockInfo(dest);
2103       // We've already checked terminators are valid.
2104       TINT_ASSERT(Reader, dest_info);
2105       const auto dest_pos = dest_info->pos;
2106 
2107       // Insert the edge kind entry and keep a handle to update
2108       // its classification.
2109       EdgeKind& edge_kind = src_info->succ_edge[dest];
2110 
2111       if (src_pos >= dest_pos) {
2112         // This is a backedge.
2113         edge_kind = EdgeKind::kBack;
2114         num_backedges++;
2115         const auto* continue_construct = src_construct.enclosing_continue;
2116         if (!continue_construct) {
2117           return Fail() << "Invalid backedge (" << src << "->" << dest
2118                         << "): " << src << " is not in a continue construct";
2119         }
2120         if (src_pos != continue_construct->end_pos - 1) {
2121           return Fail() << "Invalid exit (" << src << "->" << dest
2122                         << ") from continue construct: " << src
2123                         << " is not the last block in the continue construct "
2124                            "starting at "
2125                         << src_construct.begin_id
2126                         << " (violates post-dominance rule)";
2127         }
2128         const auto* ct_info = GetBlockInfo(continue_construct->begin_id);
2129         TINT_ASSERT(Reader, ct_info);
2130         if (ct_info->header_for_continue != dest) {
2131           return Fail()
2132                  << "Invalid backedge (" << src << "->" << dest
2133                  << "): does not branch to the corresponding loop header, "
2134                     "expected "
2135                  << ct_info->header_for_continue;
2136         }
2137       } else {
2138         // This is a forward edge.
2139         // For now, classify it that way, but we might update it.
2140         edge_kind = EdgeKind::kForward;
2141 
2142         // Exit from a continue construct can only be from the last block.
2143         const auto* continue_construct = src_construct.enclosing_continue;
2144         if (continue_construct != nullptr) {
2145           if (continue_construct->ContainsPos(src_pos) &&
2146               !continue_construct->ContainsPos(dest_pos) &&
2147               (src_pos != continue_construct->end_pos - 1)) {
2148             return Fail() << "Invalid exit (" << src << "->" << dest
2149                           << ") from continue construct: " << src
2150                           << " is not the last block in the continue construct "
2151                              "starting at "
2152                           << continue_construct->begin_id
2153                           << " (violates post-dominance rule)";
2154           }
2155         }
2156 
2157         // Check valid structured exit cases.
2158 
2159         if (edge_kind == EdgeKind::kForward) {
2160           // Check for a 'break' from a loop or from a switch.
2161           const auto* breakable_header = HeaderIfBreakable(
2162               src_construct.enclosing_loop_or_continue_or_switch);
2163           if (breakable_header != nullptr) {
2164             if (dest == breakable_header->merge_for_header) {
2165               // It's a break.
2166               edge_kind = (breakable_header->construct->kind ==
2167                            Construct::kSwitchSelection)
2168                               ? EdgeKind::kSwitchBreak
2169                               : EdgeKind::kLoopBreak;
2170             }
2171           }
2172         }
2173 
2174         if (edge_kind == EdgeKind::kForward) {
2175           // Check for a 'continue' from within a loop.
2176           const auto* loop_header =
2177               HeaderIfBreakable(src_construct.enclosing_loop);
2178           if (loop_header != nullptr) {
2179             if (dest == loop_header->continue_for_header) {
2180               // It's a continue.
2181               edge_kind = EdgeKind::kLoopContinue;
2182             }
2183           }
2184         }
2185 
2186         if (edge_kind == EdgeKind::kForward) {
2187           const auto& header_info = *GetBlockInfo(src_construct.begin_id);
2188           if (dest == header_info.merge_for_header) {
2189             // Branch to construct's merge block.  The loop break and
2190             // switch break cases have already been covered.
2191             edge_kind = EdgeKind::kIfBreak;
2192           }
2193         }
2194 
2195         // A forward edge into a case construct that comes from something
2196         // other than the OpSwitch is actually a fallthrough.
2197         if (edge_kind == EdgeKind::kForward) {
2198           const auto* switch_construct =
2199               (dest_info->case_head_for ? dest_info->case_head_for
2200                                         : dest_info->default_head_for);
2201           if (switch_construct != nullptr) {
2202             if (src != switch_construct->begin_id) {
2203               edge_kind = EdgeKind::kCaseFallThrough;
2204             }
2205           }
2206         }
2207 
2208         // The edge-kind has been finalized.
2209 
2210         if ((edge_kind == EdgeKind::kForward) ||
2211             (edge_kind == EdgeKind::kCaseFallThrough)) {
2212           normal_forward_edges.push_back(dest);
2213         }
2214         if (edge_kind == EdgeKind::kIfBreak) {
2215           if_break_edges.push_back(dest);
2216         }
2217 
2218         if ((edge_kind == EdgeKind::kForward) ||
2219             (edge_kind == EdgeKind::kCaseFallThrough)) {
2220           // Check for an invalid forward exit out of this construct.
2221           if (dest_info->pos > src_construct.end_pos) {
2222             // In most cases we're bypassing the merge block for the source
2223             // construct.
2224             auto end_block = src_construct.end_id;
2225             const char* end_block_desc = "merge block";
2226             if (src_construct.kind == Construct::kLoop) {
2227               // For a loop construct, we have two valid places to go: the
2228               // continue target or the merge for the loop header, which is
2229               // further down.
2230               const auto loop_merge =
2231                   GetBlockInfo(src_construct.begin_id)->merge_for_header;
2232               if (dest_info->pos >= GetBlockInfo(loop_merge)->pos) {
2233                 // We're bypassing the loop's merge block.
2234                 end_block = loop_merge;
2235               } else {
2236                 // We're bypassing the loop's continue target, and going into
2237                 // the middle of the continue construct.
2238                 end_block_desc = "continue target";
2239               }
2240             }
2241             return Fail()
2242                    << "Branch from block " << src << " to block " << dest
2243                    << " is an invalid exit from construct starting at block "
2244                    << src_construct.begin_id << "; branch bypasses "
2245                    << end_block_desc << " " << end_block;
2246           }
2247 
2248           // Check dominance.
2249 
2250           //      Look for edges that violate the dominance condition: a branch
2251           //      from X to Y where:
2252           //        If Y is in a nearest enclosing continue construct headed by
2253           //        CT:
2254           //          Y is not CT, and
2255           //          In the structured order, X appears before CT order or
2256           //          after CT's backedge block.
2257           //        Otherwise, if Y is in a nearest enclosing construct
2258           //        headed by H:
2259           //          Y is not H, and
2260           //          In the structured order, X appears before H or after H's
2261           //          merge block.
2262 
2263           const auto& dest_construct = *(dest_info->construct);
2264           if (dest != dest_construct.begin_id &&
2265               !dest_construct.ContainsPos(src_pos)) {
2266             return Fail() << "Branch from " << src << " to " << dest
2267                           << " bypasses "
2268                           << (dest_construct.kind == Construct::kContinue
2269                                   ? "continue target "
2270                                   : "header ")
2271                           << dest_construct.begin_id
2272                           << " (dominance rule violated)";
2273           }
2274         }
2275       }  // end forward edge
2276     }    // end successor
2277 
2278     if (num_backedges > 1) {
2279       return Fail() << "Block " << src
2280                     << " has too many backedges: " << num_backedges;
2281     }
2282     if ((normal_forward_edges.size() > 1) &&
2283         (src_info->merge_for_header == 0)) {
2284       return Fail() << "Control flow diverges at block " << src << " (to "
2285                     << normal_forward_edges[0] << ", "
2286                     << normal_forward_edges[1]
2287                     << ") but it is not a structured header (it has no merge "
2288                        "instruction)";
2289     }
2290     if ((normal_forward_edges.size() + if_break_edges.size() > 1) &&
2291         (src_info->merge_for_header == 0)) {
2292       // There is a branch to the merge of an if-selection combined
2293       // with an other normal forward branch.  Control within the
2294       // if-selection needs to be gated by a flow predicate.
2295       for (auto if_break_dest : if_break_edges) {
2296         auto* head_info =
2297             GetBlockInfo(GetBlockInfo(if_break_dest)->header_for_merge);
2298         // Generate a guard name, but only once.
2299         if (head_info->flow_guard_name.empty()) {
2300           const std::string guard = "guard" + std::to_string(head_info->id);
2301           head_info->flow_guard_name = namer_.MakeDerivedName(guard);
2302         }
2303       }
2304     }
2305   }
2306 
2307   return success();
2308 }
2309 
FindIfSelectionInternalHeaders()2310 bool FunctionEmitter::FindIfSelectionInternalHeaders() {
2311   if (failed()) {
2312     return false;
2313   }
2314   for (auto& construct : constructs_) {
2315     if (construct->kind != Construct::kIfSelection) {
2316       continue;
2317     }
2318     auto* if_header_info = GetBlockInfo(construct->begin_id);
2319     const auto* branch = if_header_info->basic_block->terminator();
2320     const auto true_head = branch->GetSingleWordInOperand(1);
2321     const auto false_head = branch->GetSingleWordInOperand(2);
2322 
2323     auto* true_head_info = GetBlockInfo(true_head);
2324     auto* false_head_info = GetBlockInfo(false_head);
2325     const auto true_head_pos = true_head_info->pos;
2326     const auto false_head_pos = false_head_info->pos;
2327 
2328     const bool contains_true = construct->ContainsPos(true_head_pos);
2329     const bool contains_false = construct->ContainsPos(false_head_pos);
2330 
2331     // The cases for each edge are:
2332     //  - kBack: invalid because it's an invalid exit from the selection
2333     //  - kSwitchBreak ; record this for later special processing
2334     //  - kLoopBreak ; record this for later special processing
2335     //  - kLoopContinue ; record this for later special processing
2336     //  - kIfBreak; normal case, may require a guard variable.
2337     //  - kFallThrough; invalid exit from the selection
2338     //  - kForward; normal case
2339 
2340     if_header_info->true_kind = if_header_info->succ_edge[true_head];
2341     if_header_info->false_kind = if_header_info->succ_edge[false_head];
2342     if (contains_true) {
2343       if_header_info->true_head = true_head;
2344     }
2345     if (contains_false) {
2346       if_header_info->false_head = false_head;
2347     }
2348 
2349     if (contains_true && (true_head_info->header_for_merge != 0) &&
2350         (true_head_info->header_for_merge != construct->begin_id)) {
2351       // The OpBranchConditional instruction for the true head block is an
2352       // alternate path to the merge block of a construct nested inside the
2353       // selection, and hence the merge block is not dominated by its own
2354       // (different) header.
2355       return Fail() << "Block " << true_head
2356                     << " is the true branch for if-selection header "
2357                     << construct->begin_id
2358                     << " and also the merge block for header block "
2359                     << true_head_info->header_for_merge
2360                     << " (violates dominance rule)";
2361     }
2362     if (contains_false && (false_head_info->header_for_merge != 0) &&
2363         (false_head_info->header_for_merge != construct->begin_id)) {
2364       // The OpBranchConditional instruction for the false head block is an
2365       // alternate path to the merge block of a construct nested inside the
2366       // selection, and hence the merge block is not dominated by its own
2367       // (different) header.
2368       return Fail() << "Block " << false_head
2369                     << " is the false branch for if-selection header "
2370                     << construct->begin_id
2371                     << " and also the merge block for header block "
2372                     << false_head_info->header_for_merge
2373                     << " (violates dominance rule)";
2374     }
2375 
2376     if (contains_true && contains_false && (true_head_pos != false_head_pos)) {
2377       // This construct has both a "then" clause and an "else" clause.
2378       //
2379       // We have this structure:
2380       //
2381       //   Option 1:
2382       //
2383       //     * condbranch
2384       //        * true-head (start of then-clause)
2385       //        ...
2386       //        * end-then-clause
2387       //        * false-head (start of else-clause)
2388       //        ...
2389       //        * end-false-clause
2390       //        * premerge-head
2391       //        ...
2392       //     * selection merge
2393       //
2394       //   Option 2:
2395       //
2396       //     * condbranch
2397       //        * true-head (start of then-clause)
2398       //        ...
2399       //        * end-then-clause
2400       //        * false-head (start of else-clause) and also premerge-head
2401       //        ...
2402       //        * end-false-clause
2403       //     * selection merge
2404       //
2405       //   Option 3:
2406       //
2407       //     * condbranch
2408       //        * false-head (start of else-clause)
2409       //        ...
2410       //        * end-else-clause
2411       //        * true-head (start of then-clause) and also premerge-head
2412       //        ...
2413       //        * end-then-clause
2414       //     * selection merge
2415       //
2416       // The premerge-head exists if there is a kForward branch from the end
2417       // of the first clause to a block within the surrounding selection.
2418       // The first clause might be a then-clause or an else-clause.
2419       const auto second_head = std::max(true_head_pos, false_head_pos);
2420       const auto end_first_clause_pos = second_head - 1;
2421       TINT_ASSERT(Reader, end_first_clause_pos < block_order_.size());
2422       const auto end_first_clause = block_order_[end_first_clause_pos];
2423       uint32_t premerge_id = 0;
2424       uint32_t if_break_id = 0;
2425       for (auto& then_succ_iter : GetBlockInfo(end_first_clause)->succ_edge) {
2426         const uint32_t dest_id = then_succ_iter.first;
2427         const auto edge_kind = then_succ_iter.second;
2428         switch (edge_kind) {
2429           case EdgeKind::kIfBreak:
2430             if_break_id = dest_id;
2431             break;
2432           case EdgeKind::kForward: {
2433             if (construct->ContainsPos(GetBlockInfo(dest_id)->pos)) {
2434               // It's a premerge.
2435               if (premerge_id != 0) {
2436                 // TODO(dneto): I think this is impossible to trigger at this
2437                 // point in the flow. It would require a merge instruction to
2438                 // get past the check of "at-most-one-forward-edge".
2439                 return Fail()
2440                        << "invalid structure: then-clause headed by block "
2441                        << true_head << " ending at block " << end_first_clause
2442                        << " has two forward edges to within selection"
2443                        << " going to " << premerge_id << " and " << dest_id;
2444               }
2445               premerge_id = dest_id;
2446               auto* dest_block_info = GetBlockInfo(dest_id);
2447               if_header_info->premerge_head = dest_id;
2448               if (dest_block_info->header_for_merge != 0) {
2449                 // Premerge has two edges coming into it, from the then-clause
2450                 // and the else-clause. It's also, by construction, not the
2451                 // merge block of the if-selection.  So it must not be a merge
2452                 // block itself. The OpBranchConditional instruction for the
2453                 // false head block is an alternate path to the merge block, and
2454                 // hence the merge block is not dominated by its own (different)
2455                 // header.
2456                 return Fail()
2457                        << "Block " << premerge_id << " is the merge block for "
2458                        << dest_block_info->header_for_merge
2459                        << " but has alternate paths reaching it, starting from"
2460                        << " blocks " << true_head << " and " << false_head
2461                        << " which are the true and false branches for the"
2462                        << " if-selection header block " << construct->begin_id
2463                        << " (violates dominance rule)";
2464               }
2465             }
2466             break;
2467           }
2468           default:
2469             break;
2470         }
2471       }
2472       if (if_break_id != 0 && premerge_id != 0) {
2473         return Fail() << "Block " << end_first_clause
2474                       << " in if-selection headed at block "
2475                       << construct->begin_id
2476                       << " branches to both the merge block " << if_break_id
2477                       << " and also to block " << premerge_id
2478                       << " later in the selection";
2479       }
2480     }
2481   }
2482   return success();
2483 }
2484 
EmitFunctionVariables()2485 bool FunctionEmitter::EmitFunctionVariables() {
2486   if (failed()) {
2487     return false;
2488   }
2489   for (auto& inst : *function_.entry()) {
2490     if (inst.opcode() != SpvOpVariable) {
2491       continue;
2492     }
2493     auto* var_store_type = GetVariableStoreType(inst);
2494     if (failed()) {
2495       return false;
2496     }
2497     const ast::Expression* constructor = nullptr;
2498     if (inst.NumInOperands() > 1) {
2499       // SPIR-V initializers are always constants.
2500       // (OpenCL also allows the ID of an OpVariable, but we don't handle that
2501       // here.)
2502       constructor =
2503           parser_impl_.MakeConstantExpression(inst.GetSingleWordInOperand(1))
2504               .expr;
2505       if (!constructor) {
2506         return false;
2507       }
2508     }
2509     auto* var = parser_impl_.MakeVariable(
2510         inst.result_id(), ast::StorageClass::kNone, var_store_type, false,
2511         constructor, ast::DecorationList{});
2512     auto* var_decl_stmt = create<ast::VariableDeclStatement>(Source{}, var);
2513     AddStatement(var_decl_stmt);
2514     auto* var_type = ty_.Reference(var_store_type, ast::StorageClass::kNone);
2515     identifier_types_.emplace(inst.result_id(), var_type);
2516   }
2517   return success();
2518 }
2519 
AddressOfIfNeeded(TypedExpression expr,const spvtools::opt::Instruction * inst)2520 TypedExpression FunctionEmitter::AddressOfIfNeeded(
2521     TypedExpression expr,
2522     const spvtools::opt::Instruction* inst) {
2523   if (inst && expr) {
2524     if (auto* spirv_type = type_mgr_->GetType(inst->type_id())) {
2525       if (expr.type->Is<Reference>() && spirv_type->AsPointer()) {
2526         return AddressOf(expr);
2527       }
2528     }
2529   }
2530   return expr;
2531 }
2532 
MakeExpression(uint32_t id)2533 TypedExpression FunctionEmitter::MakeExpression(uint32_t id) {
2534   if (failed()) {
2535     return {};
2536   }
2537   switch (GetSkipReason(id)) {
2538     case SkipReason::kDontSkip:
2539       break;
2540     case SkipReason::kOpaqueObject:
2541       Fail() << "internal error: unhandled use of opaque object with ID: "
2542              << id;
2543       return {};
2544     case SkipReason::kSinkPointerIntoUse: {
2545       // Replace the pointer with its source reference expression.
2546       auto source_expr = GetDefInfo(id)->sink_pointer_source_expr;
2547       TINT_ASSERT(Reader, source_expr.type->Is<Reference>());
2548       return source_expr;
2549     }
2550     case SkipReason::kPointSizeBuiltinValue: {
2551       return {ty_.F32(), create<ast::FloatLiteralExpression>(Source{}, 1.0f)};
2552     }
2553     case SkipReason::kPointSizeBuiltinPointer:
2554       Fail() << "unhandled use of a pointer to the PointSize builtin, with ID: "
2555              << id;
2556       return {};
2557     case SkipReason::kSampleMaskInBuiltinPointer:
2558       Fail()
2559           << "unhandled use of a pointer to the SampleMask builtin, with ID: "
2560           << id;
2561       return {};
2562     case SkipReason::kSampleMaskOutBuiltinPointer: {
2563       // The result type is always u32.
2564       auto name = namer_.Name(sample_mask_out_id);
2565       return TypedExpression{ty_.U32(),
2566                              create<ast::IdentifierExpression>(
2567                                  Source{}, builder_.Symbols().Register(name))};
2568     }
2569   }
2570   auto type_it = identifier_types_.find(id);
2571   if (type_it != identifier_types_.end()) {
2572     auto name = namer_.Name(id);
2573     auto* type = type_it->second;
2574     return TypedExpression{type,
2575                            create<ast::IdentifierExpression>(
2576                                Source{}, builder_.Symbols().Register(name))};
2577   }
2578   if (parser_impl_.IsScalarSpecConstant(id)) {
2579     auto name = namer_.Name(id);
2580     return TypedExpression{
2581         parser_impl_.ConvertType(def_use_mgr_->GetDef(id)->type_id()),
2582         create<ast::IdentifierExpression>(Source{},
2583                                           builder_.Symbols().Register(name))};
2584   }
2585   if (singly_used_values_.count(id)) {
2586     auto expr = std::move(singly_used_values_[id]);
2587     singly_used_values_.erase(id);
2588     return expr;
2589   }
2590   const auto* spirv_constant = constant_mgr_->FindDeclaredConstant(id);
2591   if (spirv_constant) {
2592     return parser_impl_.MakeConstantExpression(id);
2593   }
2594   const auto* inst = def_use_mgr_->GetDef(id);
2595   if (inst == nullptr) {
2596     Fail() << "ID " << id << " does not have a defining SPIR-V instruction";
2597     return {};
2598   }
2599   switch (inst->opcode()) {
2600     case SpvOpVariable: {
2601       // This occurs for module-scope variables.
2602       auto name = namer_.Name(inst->result_id());
2603       return TypedExpression{
2604           parser_impl_.ConvertType(inst->type_id(), PtrAs::Ref),
2605           create<ast::IdentifierExpression>(Source{},
2606                                             builder_.Symbols().Register(name))};
2607     }
2608     case SpvOpUndef:
2609       // Substitute a null value for undef.
2610       // This case occurs when OpUndef appears at module scope, as if it were
2611       // a constant.
2612       return parser_impl_.MakeNullExpression(
2613           parser_impl_.ConvertType(inst->type_id()));
2614 
2615     default:
2616       break;
2617   }
2618   if (const spvtools::opt::BasicBlock* const bb =
2619           ir_context_.get_instr_block(id)) {
2620     if (auto* block = GetBlockInfo(bb->id())) {
2621       if (block->pos == kInvalidBlockPos) {
2622         // The value came from a block not in the block order.
2623         // Substitute a null value.
2624         return parser_impl_.MakeNullExpression(
2625             parser_impl_.ConvertType(inst->type_id()));
2626       }
2627     }
2628   }
2629   Fail() << "unhandled expression for ID " << id << "\n" << inst->PrettyPrint();
2630   return {};
2631 }
2632 
EmitFunctionBodyStatements()2633 bool FunctionEmitter::EmitFunctionBodyStatements() {
2634   // Dump the basic blocks in order, grouped by construct.
2635 
2636   // We maintain a stack of StatementBlock objects, where new statements
2637   // are always written to the topmost entry of the stack. By this point in
2638   // processing, we have already recorded the interesting control flow
2639   // boundaries in the BlockInfo and associated Construct objects. As we
2640   // enter a new statement grouping, we push onto the stack, and also schedule
2641   // the statement block's completion and removal at a future block's ID.
2642 
2643   // Upon entry, the statement stack has one entry representing the whole
2644   // function.
2645   TINT_ASSERT(Reader, !constructs_.empty());
2646   Construct* function_construct = constructs_[0].get();
2647   TINT_ASSERT(Reader, function_construct != nullptr);
2648   TINT_ASSERT(Reader, function_construct->kind == Construct::kFunction);
2649   // Make the first entry valid by filling in the construct field, which
2650   // had not been computed at the time the entry was first created.
2651   // TODO(dneto): refactor how the first construct is created vs.
2652   // this statements stack entry is populated.
2653   TINT_ASSERT(Reader, statements_stack_.size() == 1);
2654   statements_stack_[0].SetConstruct(function_construct);
2655 
2656   for (auto block_id : block_order()) {
2657     if (!EmitBasicBlock(*GetBlockInfo(block_id))) {
2658       return false;
2659     }
2660   }
2661   return success();
2662 }
2663 
EmitBasicBlock(const BlockInfo & block_info)2664 bool FunctionEmitter::EmitBasicBlock(const BlockInfo& block_info) {
2665   // Close off previous constructs.
2666   while (!statements_stack_.empty() &&
2667          (statements_stack_.back().GetEndId() == block_info.id)) {
2668     statements_stack_.back().Finalize(&builder_);
2669     statements_stack_.pop_back();
2670   }
2671   if (statements_stack_.empty()) {
2672     return Fail() << "internal error: statements stack empty at block "
2673                   << block_info.id;
2674   }
2675 
2676   // Enter new constructs.
2677 
2678   std::vector<const Construct*> entering_constructs;  // inner most comes first
2679   {
2680     auto* here = block_info.construct;
2681     auto* const top_construct = statements_stack_.back().GetConstruct();
2682     while (here != top_construct) {
2683       // Only enter a construct at its header block.
2684       if (here->begin_id == block_info.id) {
2685         entering_constructs.push_back(here);
2686       }
2687       here = here->parent;
2688     }
2689   }
2690   // What constructs can we have entered?
2691   // - It can't be kFunction, because there is only one of those, and it was
2692   //   already on the stack at the outermost level.
2693   // - We have at most one of kSwitchSelection, or kLoop because each of those
2694   //   is headed by a block with a merge instruction (OpLoopMerge for kLoop,
2695   //   and OpSelectionMerge for kSwitchSelection).
2696   // - When there is a kIfSelection, it can't contain another construct,
2697   //   because both would have to have their own distinct merge instructions
2698   //   and distinct terminators.
2699   // - A kContinue can contain a kContinue
2700   //   This is possible in Vulkan SPIR-V, but Tint disallows this by the rule
2701   //   that a block can be continue target for at most one header block. See
2702   //   test DISABLED_BlockIsContinueForMoreThanOneHeader. If we generalize this,
2703   //   then by a dominance argument, the inner loop continue target can only be
2704   //   a single-block loop.
2705   // TODO(dneto): Handle this case.
2706   // - If a kLoop is on the outside, its terminator is either:
2707   //   - an OpBranch, in which case there is no other construct.
2708   //   - an OpBranchConditional, in which case there is either an kIfSelection
2709   //     (when both branch targets are different and are inside the loop),
2710   //     or no other construct (because the branch targets are the same,
2711   //     or one of them is a break or continue).
2712   // - All that's left is a kContinue on the outside, and one of
2713   //   kIfSelection, kSwitchSelection, kLoop on the inside.
2714   //
2715   //   The kContinue can be the parent of the other.  For example, a selection
2716   //   starting at the first block of a continue construct.
2717   //
2718   //   The kContinue can't be the child of the other because either:
2719   //     - The other can't be kLoop because:
2720   //        - If the kLoop is for a different loop then the kContinue, then
2721   //          the kContinue must be its own loop header, and so the same
2722   //          block is two different loops. That's a contradiction.
2723   //        - If the kLoop is for a the same loop, then this is a contradiction
2724   //          because a kContinue and its kLoop have disjoint block sets.
2725   //     - The other construct can't be a selection because:
2726   //       - The kContinue construct is the entire loop, i.e. the continue
2727   //         target is its own loop header block.  But then the continue target
2728   //         has an OpLoopMerge instruction, which contradicts this block being
2729   //         a selection header.
2730   //       - The kContinue is in a multi-block loop that is has a non-empty
2731   //         kLoop; and the selection contains the kContinue block but not the
2732   //         loop block. That breaks dominance rules. That is, the continue
2733   //         target is dominated by that loop header, and so gets found by the
2734   //         block traversal on the outside before the selection is found. The
2735   //         selection is inside the outer loop.
2736   //
2737   // So we fall into one of the following cases:
2738   //  - We are entering 0 or 1 constructs, or
2739   //  - We are entering 2 constructs, with the outer one being a kContinue or
2740   //    kLoop, the inner one is not a continue.
2741   if (entering_constructs.size() > 2) {
2742     return Fail() << "internal error: bad construct nesting found";
2743   }
2744   if (entering_constructs.size() == 2) {
2745     auto inner_kind = entering_constructs[0]->kind;
2746     auto outer_kind = entering_constructs[1]->kind;
2747     if (outer_kind != Construct::kContinue && outer_kind != Construct::kLoop) {
2748       return Fail()
2749              << "internal error: bad construct nesting. Only a Continue "
2750                 "or a Loop construct can be outer construct on same block.  "
2751                 "Got outer kind "
2752              << int(outer_kind) << " inner kind " << int(inner_kind);
2753     }
2754     if (inner_kind == Construct::kContinue) {
2755       return Fail() << "internal error: unsupported construct nesting: "
2756                        "Continue around Continue";
2757     }
2758     if (inner_kind != Construct::kIfSelection &&
2759         inner_kind != Construct::kSwitchSelection &&
2760         inner_kind != Construct::kLoop) {
2761       return Fail() << "internal error: bad construct nesting. Continue around "
2762                        "something other than if, switch, or loop";
2763     }
2764   }
2765 
2766   // Enter constructs from outermost to innermost.
2767   // kLoop and kContinue push a new statement-block onto the stack before
2768   // emitting statements in the block.
2769   // kIfSelection and kSwitchSelection emit statements in the block and then
2770   // emit push a new statement-block. Only emit the statements in the block
2771   // once.
2772 
2773   // Have we emitted the statements for this block?
2774   bool emitted = false;
2775 
2776   // When entering an if-selection or switch-selection, we will emit the WGSL
2777   // construct to cause the divergent branching.  But otherwise, we will
2778   // emit a "normal" block terminator, which occurs at the end of this method.
2779   bool has_normal_terminator = true;
2780 
2781   for (auto iter = entering_constructs.rbegin();
2782        iter != entering_constructs.rend(); ++iter) {
2783     const Construct* construct = *iter;
2784 
2785     switch (construct->kind) {
2786       case Construct::kFunction:
2787         return Fail() << "internal error: nested function construct";
2788 
2789       case Construct::kLoop:
2790         if (!EmitLoopStart(construct)) {
2791           return false;
2792         }
2793         if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
2794           return false;
2795         }
2796         break;
2797 
2798       case Construct::kContinue:
2799         if (block_info.is_continue_entire_loop) {
2800           if (!EmitLoopStart(construct)) {
2801             return false;
2802           }
2803           if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
2804             return false;
2805           }
2806         } else {
2807           if (!EmitContinuingStart(construct)) {
2808             return false;
2809           }
2810         }
2811         break;
2812 
2813       case Construct::kIfSelection:
2814         if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
2815           return false;
2816         }
2817         if (!EmitIfStart(block_info)) {
2818           return false;
2819         }
2820         has_normal_terminator = false;
2821         break;
2822 
2823       case Construct::kSwitchSelection:
2824         if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
2825           return false;
2826         }
2827         if (!EmitSwitchStart(block_info)) {
2828           return false;
2829         }
2830         has_normal_terminator = false;
2831         break;
2832     }
2833   }
2834 
2835   // If we aren't starting or transitioning, then emit the normal
2836   // statements now.
2837   if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
2838     return false;
2839   }
2840 
2841   if (has_normal_terminator) {
2842     if (!EmitNormalTerminator(block_info)) {
2843       return false;
2844     }
2845   }
2846   return success();
2847 }
2848 
EmitIfStart(const BlockInfo & block_info)2849 bool FunctionEmitter::EmitIfStart(const BlockInfo& block_info) {
2850   // The block is the if-header block.  So its construct is the if construct.
2851   auto* construct = block_info.construct;
2852   TINT_ASSERT(Reader, construct->kind == Construct::kIfSelection);
2853   TINT_ASSERT(Reader, construct->begin_id == block_info.id);
2854 
2855   const uint32_t true_head = block_info.true_head;
2856   const uint32_t false_head = block_info.false_head;
2857   const uint32_t premerge_head = block_info.premerge_head;
2858 
2859   const std::string guard_name = block_info.flow_guard_name;
2860   if (!guard_name.empty()) {
2861     // Declare the guard variable just before the "if", initialized to true.
2862     auto* guard_var =
2863         builder_.Var(guard_name, builder_.ty.bool_(), MakeTrue(Source{}));
2864     auto* guard_decl = create<ast::VariableDeclStatement>(Source{}, guard_var);
2865     AddStatement(guard_decl);
2866   }
2867 
2868   const auto condition_id =
2869       block_info.basic_block->terminator()->GetSingleWordInOperand(0);
2870   auto* cond = MakeExpression(condition_id).expr;
2871   if (!cond) {
2872     return false;
2873   }
2874   // Generate the code for the condition.
2875   auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
2876 
2877   // Compute the block IDs that should end the then-clause and the else-clause.
2878 
2879   // We need to know where the *emitted* selection should end, i.e. the intended
2880   // merge block id.  That should be the current premerge block, if it exists,
2881   // or otherwise the declared merge block.
2882   //
2883   // This is another way to think about it:
2884   //   If there is a premerge, then there are three cases:
2885   //    - premerge_head is different from the true_head and false_head:
2886   //      - Premerge comes last. In effect, move the selection merge up
2887   //        to where the premerge begins.
2888   //    - premerge_head is the same as the false_head
2889   //      - This is really an if-then without an else clause.
2890   //        Move the merge up to where the premerge is.
2891   //    - premerge_head is the same as the true_head
2892   //      - This is really an if-else without an then clause.
2893   //        Emit it as:   if (cond) {} else {....}
2894   //        Move the merge up to where the premerge is.
2895   const uint32_t intended_merge =
2896       premerge_head ? premerge_head : construct->end_id;
2897 
2898   // then-clause:
2899   //   If true_head exists:
2900   //     spans from true head to the earlier of the false head (if it exists)
2901   //     or the selection merge.
2902   //   Otherwise:
2903   //     ends at from the false head (if it exists), otherwise the selection
2904   //     end.
2905   const uint32_t then_end = false_head ? false_head : intended_merge;
2906 
2907   // else-clause:
2908   //   ends at the premerge head (if it exists) or at the selection end.
2909   const uint32_t else_end = premerge_head ? premerge_head : intended_merge;
2910 
2911   const bool true_is_break = (block_info.true_kind == EdgeKind::kSwitchBreak) ||
2912                              (block_info.true_kind == EdgeKind::kLoopBreak);
2913   const bool false_is_break =
2914       (block_info.false_kind == EdgeKind::kSwitchBreak) ||
2915       (block_info.false_kind == EdgeKind::kLoopBreak);
2916   const bool true_is_continue = block_info.true_kind == EdgeKind::kLoopContinue;
2917   const bool false_is_continue =
2918       block_info.false_kind == EdgeKind::kLoopContinue;
2919 
2920   // Push statement blocks for the then-clause and the else-clause.
2921   // But make sure we do it in the right order.
2922   auto push_else = [this, builder, else_end, construct, false_is_break,
2923                     false_is_continue]() {
2924     // Push the else clause onto the stack first.
2925     PushNewStatementBlock(
2926         construct, else_end, [=](const ast::StatementList& stmts) {
2927           // Only set the else-clause if there are statements to fill it.
2928           if (!stmts.empty()) {
2929             // The "else" consists of the statement list from the top of
2930             // statements stack, without an elseif condition.
2931             auto* else_body = create<ast::BlockStatement>(Source{}, stmts);
2932             builder->else_stmts.emplace_back(
2933                 create<ast::ElseStatement>(Source{}, nullptr, else_body));
2934           }
2935         });
2936     if (false_is_break) {
2937       AddStatement(create<ast::BreakStatement>(Source{}));
2938     }
2939     if (false_is_continue) {
2940       AddStatement(create<ast::ContinueStatement>(Source{}));
2941     }
2942   };
2943 
2944   if (!true_is_break && !true_is_continue &&
2945       (GetBlockInfo(else_end)->pos < GetBlockInfo(then_end)->pos)) {
2946     // Process the else-clause first.  The then-clause will be empty so avoid
2947     // pushing onto the stack at all.
2948     push_else();
2949   } else {
2950     // Blocks for the then-clause appear before blocks for the else-clause.
2951     // So push the else-clause handling onto the stack first. The else-clause
2952     // might be empty, but this works anyway.
2953 
2954     // Handle the premerge, if it exists.
2955     if (premerge_head) {
2956       // The top of the stack is the statement block that is the parent of the
2957       // if-statement. Adding statements now will place them after that 'if'.
2958       if (guard_name.empty()) {
2959         // We won't have a flow guard for the premerge.
2960         // Insert a trivial if(true) { ... } around the blocks from the
2961         // premerge head until the end of the if-selection.  This is needed
2962         // to ensure uniform reconvergence occurs at the end of the if-selection
2963         // just like in the original SPIR-V.
2964         PushTrueGuard(construct->end_id);
2965       } else {
2966         // Add a flow guard around the blocks in the premerge area.
2967         PushGuard(guard_name, construct->end_id);
2968       }
2969     }
2970 
2971     push_else();
2972     if (true_head && false_head && !guard_name.empty()) {
2973       // There are non-trivial then and else clauses.
2974       // We have to guard the start of the else.
2975       PushGuard(guard_name, else_end);
2976     }
2977 
2978     // Push the then clause onto the stack.
2979     PushNewStatementBlock(
2980         construct, then_end, [=](const ast::StatementList& stmts) {
2981           builder->body = create<ast::BlockStatement>(Source{}, stmts);
2982         });
2983     if (true_is_break) {
2984       AddStatement(create<ast::BreakStatement>(Source{}));
2985     }
2986     if (true_is_continue) {
2987       AddStatement(create<ast::ContinueStatement>(Source{}));
2988     }
2989   }
2990 
2991   return success();
2992 }
2993 
EmitSwitchStart(const BlockInfo & block_info)2994 bool FunctionEmitter::EmitSwitchStart(const BlockInfo& block_info) {
2995   // The block is the if-header block.  So its construct is the if construct.
2996   auto* construct = block_info.construct;
2997   TINT_ASSERT(Reader, construct->kind == Construct::kSwitchSelection);
2998   TINT_ASSERT(Reader, construct->begin_id == block_info.id);
2999   const auto* branch = block_info.basic_block->terminator();
3000 
3001   const auto selector_id = branch->GetSingleWordInOperand(0);
3002   // Generate the code for the selector.
3003   auto selector = MakeExpression(selector_id);
3004   if (!selector) {
3005     return false;
3006   }
3007   // First, push the statement block for the entire switch.
3008   auto* swch = AddStatementBuilder<SwitchStatementBuilder>(selector.expr);
3009 
3010   // Grab a pointer to the case list.  It will get buried in the statement block
3011   // stack.
3012   PushNewStatementBlock(construct, construct->end_id, nullptr);
3013 
3014   // We will push statement-blocks onto the stack to gather the statements in
3015   // the default clause and cases clauses. Determine the list of blocks
3016   // that start each clause.
3017   std::vector<const BlockInfo*> clause_heads;
3018 
3019   // Collect the case clauses, even if they are just the merge block.
3020   // First the default clause.
3021   const auto default_id = branch->GetSingleWordInOperand(1);
3022   const auto* default_info = GetBlockInfo(default_id);
3023   clause_heads.push_back(default_info);
3024   // Now the case clauses.
3025   for (uint32_t iarg = 2; iarg + 1 < branch->NumInOperands(); iarg += 2) {
3026     const auto case_target_id = branch->GetSingleWordInOperand(iarg + 1);
3027     clause_heads.push_back(GetBlockInfo(case_target_id));
3028   }
3029 
3030   std::stable_sort(clause_heads.begin(), clause_heads.end(),
3031                    [](const BlockInfo* lhs, const BlockInfo* rhs) {
3032                      return lhs->pos < rhs->pos;
3033                    });
3034   // Remove duplicates
3035   {
3036     // Use read index r, and write index w.
3037     // Invariant: w <= r;
3038     size_t w = 0;
3039     for (size_t r = 0; r < clause_heads.size(); ++r) {
3040       if (clause_heads[r] != clause_heads[w]) {
3041         ++w;  // Advance the write cursor.
3042       }
3043       clause_heads[w] = clause_heads[r];
3044     }
3045     // We know it's not empty because it always has at least a default clause.
3046     TINT_ASSERT(Reader, !clause_heads.empty());
3047     clause_heads.resize(w + 1);
3048   }
3049 
3050   // Push them on in reverse order.
3051   const auto last_clause_index = clause_heads.size() - 1;
3052   for (size_t i = last_clause_index;; --i) {
3053     // Create a list of integer literals for the selector values leading to
3054     // this case clause.
3055     ast::CaseSelectorList selectors;
3056     const auto* values_ptr = clause_heads[i]->case_values.get();
3057     const bool has_selectors = (values_ptr && !values_ptr->empty());
3058     if (has_selectors) {
3059       std::vector<uint64_t> values(values_ptr->begin(), values_ptr->end());
3060       std::stable_sort(values.begin(), values.end());
3061       for (auto value : values) {
3062         // The rest of this module can handle up to 64 bit switch values.
3063         // The Tint AST handles 32-bit values.
3064         const uint32_t value32 = uint32_t(value & 0xFFFFFFFF);
3065         if (selector.type->IsUnsignedScalarOrVector()) {
3066           selectors.emplace_back(
3067               create<ast::UintLiteralExpression>(Source{}, value32));
3068         } else {
3069           selectors.emplace_back(
3070               create<ast::SintLiteralExpression>(Source{}, value32));
3071         }
3072       }
3073     }
3074 
3075     // Where does this clause end?
3076     const auto end_id = (i + 1 < clause_heads.size()) ? clause_heads[i + 1]->id
3077                                                       : construct->end_id;
3078 
3079     // Reserve the case clause slot in swch->cases, push the new statement block
3080     // for the case, and fill the case clause once the block is generated.
3081     auto case_idx = swch->cases.size();
3082     swch->cases.emplace_back(nullptr);
3083     PushNewStatementBlock(
3084         construct, end_id, [=](const ast::StatementList& stmts) {
3085           auto* body = create<ast::BlockStatement>(Source{}, stmts);
3086           swch->cases[case_idx] =
3087               create<ast::CaseStatement>(Source{}, selectors, body);
3088         });
3089 
3090     if ((default_info == clause_heads[i]) && has_selectors &&
3091         construct->ContainsPos(default_info->pos)) {
3092       // Generate a default clause with a just fallthrough.
3093       auto* stmts = create<ast::BlockStatement>(
3094           Source{}, ast::StatementList{
3095                         create<ast::FallthroughStatement>(Source{}),
3096                     });
3097       auto* case_stmt =
3098           create<ast::CaseStatement>(Source{}, ast::CaseSelectorList{}, stmts);
3099       swch->cases.emplace_back(case_stmt);
3100     }
3101 
3102     if (i == 0) {
3103       break;
3104     }
3105   }
3106 
3107   return success();
3108 }
3109 
EmitLoopStart(const Construct * construct)3110 bool FunctionEmitter::EmitLoopStart(const Construct* construct) {
3111   auto* builder = AddStatementBuilder<LoopStatementBuilder>();
3112   PushNewStatementBlock(
3113       construct, construct->end_id, [=](const ast::StatementList& stmts) {
3114         builder->body = create<ast::BlockStatement>(Source{}, stmts);
3115       });
3116   return success();
3117 }
3118 
EmitContinuingStart(const Construct * construct)3119 bool FunctionEmitter::EmitContinuingStart(const Construct* construct) {
3120   // A continue construct has the same depth as its associated loop
3121   // construct. Start a continue construct.
3122   auto* loop_candidate = LastStatement();
3123   auto* loop = loop_candidate->As<LoopStatementBuilder>();
3124   if (loop == nullptr) {
3125     return Fail() << "internal error: starting continue construct, "
3126                      "expected loop on top of stack";
3127   }
3128   PushNewStatementBlock(
3129       construct, construct->end_id, [=](const ast::StatementList& stmts) {
3130         loop->continuing = create<ast::BlockStatement>(Source{}, stmts);
3131       });
3132 
3133   return success();
3134 }
3135 
EmitNormalTerminator(const BlockInfo & block_info)3136 bool FunctionEmitter::EmitNormalTerminator(const BlockInfo& block_info) {
3137   const auto& terminator = *(block_info.basic_block->terminator());
3138   switch (terminator.opcode()) {
3139     case SpvOpReturn:
3140       AddStatement(create<ast::ReturnStatement>(Source{}));
3141       return true;
3142     case SpvOpReturnValue: {
3143       auto value = MakeExpression(terminator.GetSingleWordInOperand(0));
3144       if (!value) {
3145         return false;
3146       }
3147       AddStatement(create<ast::ReturnStatement>(Source{}, value.expr));
3148     }
3149       return true;
3150     case SpvOpKill:
3151       // For now, assume SPIR-V OpKill has same semantics as WGSL discard.
3152       // TODO(dneto): https://github.com/gpuweb/gpuweb/issues/676
3153       AddStatement(create<ast::DiscardStatement>(Source{}));
3154       return true;
3155     case SpvOpUnreachable:
3156       // Translate as if it's a return. This avoids the problem where WGSL
3157       // requires a return statement at the end of the function body.
3158       {
3159         const auto* result_type = type_mgr_->GetType(function_.type_id());
3160         if (result_type->AsVoid() != nullptr) {
3161           AddStatement(create<ast::ReturnStatement>(Source{}));
3162         } else {
3163           auto* ast_type = parser_impl_.ConvertType(function_.type_id());
3164           AddStatement(create<ast::ReturnStatement>(
3165               Source{}, parser_impl_.MakeNullValue(ast_type)));
3166         }
3167       }
3168       return true;
3169     case SpvOpBranch: {
3170       const auto dest_id = terminator.GetSingleWordInOperand(0);
3171       AddStatement(MakeBranch(block_info, *GetBlockInfo(dest_id)));
3172       return true;
3173     }
3174     case SpvOpBranchConditional: {
3175       // If both destinations are the same, then do the same as we would
3176       // for an unconditional branch (OpBranch).
3177       const auto true_dest = terminator.GetSingleWordInOperand(1);
3178       const auto false_dest = terminator.GetSingleWordInOperand(2);
3179       if (true_dest == false_dest) {
3180         // This is like an unconditional branch.
3181         AddStatement(MakeBranch(block_info, *GetBlockInfo(true_dest)));
3182         return true;
3183       }
3184 
3185       const EdgeKind true_kind = block_info.succ_edge.find(true_dest)->second;
3186       const EdgeKind false_kind = block_info.succ_edge.find(false_dest)->second;
3187       auto* const true_info = GetBlockInfo(true_dest);
3188       auto* const false_info = GetBlockInfo(false_dest);
3189       auto* cond = MakeExpression(terminator.GetSingleWordInOperand(0)).expr;
3190       if (!cond) {
3191         return false;
3192       }
3193 
3194       // We have two distinct destinations. But we only get here if this
3195       // is a normal terminator; in particular the source block is *not* the
3196       // start of an if-selection or a switch-selection.  So at most one branch
3197       // is a kForward, kCaseFallThrough, or kIfBreak.
3198 
3199       // The fallthrough case is special because WGSL requires the fallthrough
3200       // statement to be last in the case clause.
3201       if (true_kind == EdgeKind::kCaseFallThrough) {
3202         return EmitConditionalCaseFallThrough(block_info, cond, false_kind,
3203                                               *false_info, true);
3204       } else if (false_kind == EdgeKind::kCaseFallThrough) {
3205         return EmitConditionalCaseFallThrough(block_info, cond, true_kind,
3206                                               *true_info, false);
3207       }
3208 
3209       // At this point, at most one edge is kForward or kIfBreak.
3210 
3211       // Emit an 'if' statement to express the *other* branch as a conditional
3212       // break or continue.  Either or both of these could be nullptr.
3213       // (A nullptr is generated for kIfBreak, kForward, or kBack.)
3214       // Also if one of the branches is an if-break out of an if-selection
3215       // requiring a flow guard, then get that flow guard name too.  It will
3216       // come from at most one of these two branches.
3217       std::string flow_guard;
3218       auto* true_branch =
3219           MakeBranchDetailed(block_info, *true_info, false, &flow_guard);
3220       auto* false_branch =
3221           MakeBranchDetailed(block_info, *false_info, false, &flow_guard);
3222 
3223       AddStatement(MakeSimpleIf(cond, true_branch, false_branch));
3224       if (!flow_guard.empty()) {
3225         PushGuard(flow_guard, statements_stack_.back().GetEndId());
3226       }
3227       return true;
3228     }
3229     case SpvOpSwitch:
3230       // An OpSelectionMerge must precede an OpSwitch.  That is clarified
3231       // in the resolution to Khronos-internal SPIR-V issue 115.
3232       // A new enough version of the SPIR-V validator checks this case.
3233       // But issue an error in this case, as a defensive measure.
3234       return Fail() << "invalid structured control flow: found an OpSwitch "
3235                        "that is not preceded by an "
3236                        "OpSelectionMerge: "
3237                     << terminator.PrettyPrint();
3238     default:
3239       break;
3240   }
3241   return success();
3242 }
3243 
MakeBranchDetailed(const BlockInfo & src_info,const BlockInfo & dest_info,bool forced,std::string * flow_guard_name_ptr) const3244 const ast::Statement* FunctionEmitter::MakeBranchDetailed(
3245     const BlockInfo& src_info,
3246     const BlockInfo& dest_info,
3247     bool forced,
3248     std::string* flow_guard_name_ptr) const {
3249   auto kind = src_info.succ_edge.find(dest_info.id)->second;
3250   switch (kind) {
3251     case EdgeKind::kBack:
3252       // Nothing to do. The loop backedge is implicit.
3253       break;
3254     case EdgeKind::kSwitchBreak: {
3255       if (forced) {
3256         return create<ast::BreakStatement>(Source{});
3257       }
3258       // Unless forced, don't bother with a break at the end of a case/default
3259       // clause.
3260       const auto header = dest_info.header_for_merge;
3261       TINT_ASSERT(Reader, header != 0);
3262       const auto* exiting_construct = GetBlockInfo(header)->construct;
3263       TINT_ASSERT(Reader,
3264                   exiting_construct->kind == Construct::kSwitchSelection);
3265       const auto candidate_next_case_pos = src_info.pos + 1;
3266       // Leaving the last block from the last case?
3267       if (candidate_next_case_pos == dest_info.pos) {
3268         // No break needed.
3269         return nullptr;
3270       }
3271       // Leaving the last block from not-the-last-case?
3272       if (exiting_construct->ContainsPos(candidate_next_case_pos)) {
3273         const auto* candidate_next_case =
3274             GetBlockInfo(block_order_[candidate_next_case_pos]);
3275         if (candidate_next_case->case_head_for == exiting_construct ||
3276             candidate_next_case->default_head_for == exiting_construct) {
3277           // No break needed.
3278           return nullptr;
3279         }
3280       }
3281       // We need a break.
3282       return create<ast::BreakStatement>(Source{});
3283     }
3284     case EdgeKind::kLoopBreak:
3285       return create<ast::BreakStatement>(Source{});
3286     case EdgeKind::kLoopContinue:
3287       // An unconditional continue to the next block is redundant and ugly.
3288       // Skip it in that case.
3289       if (dest_info.pos == 1 + src_info.pos) {
3290         break;
3291       }
3292       // Otherwise, emit a regular continue statement.
3293       return create<ast::ContinueStatement>(Source{});
3294     case EdgeKind::kIfBreak: {
3295       const auto& flow_guard =
3296           GetBlockInfo(dest_info.header_for_merge)->flow_guard_name;
3297       if (!flow_guard.empty()) {
3298         if (flow_guard_name_ptr != nullptr) {
3299           *flow_guard_name_ptr = flow_guard;
3300         }
3301         // Signal an exit from the branch.
3302         return create<ast::AssignmentStatement>(
3303             Source{},
3304             create<ast::IdentifierExpression>(
3305                 Source{}, builder_.Symbols().Register(flow_guard)),
3306             MakeFalse(Source{}));
3307       }
3308 
3309       // For an unconditional branch, the break out to an if-selection
3310       // merge block is implicit.
3311       break;
3312     }
3313     case EdgeKind::kCaseFallThrough:
3314       return create<ast::FallthroughStatement>(Source{});
3315     case EdgeKind::kForward:
3316       // Unconditional forward branch is implicit.
3317       break;
3318   }
3319   return nullptr;
3320 }
3321 
// Builds `if (condition) { then_stmt } else { else_stmt }`. A null statement
// leaves its arm empty (a null `else_stmt` means no else clause at all).
// Returns nullptr when both statements are null.
const ast::Statement* FunctionEmitter::MakeSimpleIf(
    const ast::Expression* condition,
    const ast::Statement* then_stmt,
    const ast::Statement* else_stmt) const {
  if (!then_stmt && !else_stmt) {
    return nullptr;
  }

  ast::ElseStatementList else_clauses;
  if (else_stmt) {
    auto* else_block =
        create<ast::BlockStatement>(Source{}, ast::StatementList{else_stmt});
    else_clauses.emplace_back(
        create<ast::ElseStatement>(Source{}, nullptr, else_block));
  }

  ast::StatementList then_list;
  if (then_stmt) {
    then_list.emplace_back(then_stmt);
  }
  return create<ast::IfStatement>(
      Source{}, condition, create<ast::BlockStatement>(Source{}, then_list),
      else_clauses);
}
3345 
EmitConditionalCaseFallThrough(const BlockInfo & src_info,const ast::Expression * cond,EdgeKind other_edge_kind,const BlockInfo & other_dest,bool fall_through_is_true_branch)3346 bool FunctionEmitter::EmitConditionalCaseFallThrough(
3347     const BlockInfo& src_info,
3348     const ast::Expression* cond,
3349     EdgeKind other_edge_kind,
3350     const BlockInfo& other_dest,
3351     bool fall_through_is_true_branch) {
3352   // In WGSL, the fallthrough statement must come last in the case clause.
3353   // So we'll emit an if statement for the other branch, and then emit
3354   // the fallthrough.
3355 
3356   // We have two distinct destinations. But we only get here if this
3357   // is a normal terminator; in particular the source block is *not* the
3358   // start of an if-selection.  So at most one branch is a kForward or
3359   // kCaseFallThrough.
3360   if (other_edge_kind == EdgeKind::kForward) {
3361     return Fail()
3362            << "internal error: normal terminator OpBranchConditional has "
3363               "both forward and fallthrough edges";
3364   }
3365   if (other_edge_kind == EdgeKind::kIfBreak) {
3366     return Fail()
3367            << "internal error: normal terminator OpBranchConditional has "
3368               "both IfBreak and fallthrough edges.  Violates nesting rule";
3369   }
3370   if (other_edge_kind == EdgeKind::kBack) {
3371     return Fail()
3372            << "internal error: normal terminator OpBranchConditional has "
3373               "both backedge and fallthrough edges.  Violates nesting rule";
3374   }
3375   auto* other_branch = MakeForcedBranch(src_info, other_dest);
3376   if (other_branch == nullptr) {
3377     return Fail() << "internal error: expected a branch for edge-kind "
3378                   << int(other_edge_kind);
3379   }
3380   if (fall_through_is_true_branch) {
3381     AddStatement(MakeSimpleIf(cond, nullptr, other_branch));
3382   } else {
3383     AddStatement(MakeSimpleIf(cond, other_branch, nullptr));
3384   }
3385   AddStatement(create<ast::FallthroughStatement>(Source{}));
3386 
3387   return success();
3388 }
3389 
EmitStatementsInBasicBlock(const BlockInfo & block_info,bool * already_emitted)3390 bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info,
3391                                                  bool* already_emitted) {
3392   if (*already_emitted) {
3393     // Only emit this part of the basic block once.
3394     return true;
3395   }
3396   // Returns the given list of local definition IDs, sorted by their index.
3397   auto sorted_by_index = [this](const std::vector<uint32_t>& ids) {
3398     auto sorted = ids;
3399     std::stable_sort(sorted.begin(), sorted.end(),
3400                      [this](const uint32_t lhs, const uint32_t rhs) {
3401                        return GetDefInfo(lhs)->index < GetDefInfo(rhs)->index;
3402                      });
3403     return sorted;
3404   };
3405 
3406   // Emit declarations of hoisted variables, in index order.
3407   for (auto id : sorted_by_index(block_info.hoisted_ids)) {
3408     const auto* def_inst = def_use_mgr_->GetDef(id);
3409     TINT_ASSERT(Reader, def_inst);
3410     auto* storage_type =
3411         RemapStorageClass(parser_impl_.ConvertType(def_inst->type_id()), id);
3412     AddStatement(create<ast::VariableDeclStatement>(
3413         Source{},
3414         parser_impl_.MakeVariable(id, ast::StorageClass::kNone, storage_type,
3415                                   false, nullptr, ast::DecorationList{})));
3416     auto* type = ty_.Reference(storage_type, ast::StorageClass::kNone);
3417     identifier_types_.emplace(id, type);
3418   }
3419   // Emit declarations of phi state variables, in index order.
3420   for (auto id : sorted_by_index(block_info.phis_needing_state_vars)) {
3421     const auto* def_inst = def_use_mgr_->GetDef(id);
3422     TINT_ASSERT(Reader, def_inst);
3423     const auto phi_var_name = GetDefInfo(id)->phi_var;
3424     TINT_ASSERT(Reader, !phi_var_name.empty());
3425     auto* var = builder_.Var(
3426         phi_var_name,
3427         parser_impl_.ConvertType(def_inst->type_id())->Build(builder_));
3428     AddStatement(create<ast::VariableDeclStatement>(Source{}, var));
3429   }
3430 
3431   // Emit regular statements.
3432   const spvtools::opt::BasicBlock& bb = *(block_info.basic_block);
3433   const auto* terminator = bb.terminator();
3434   const auto* merge = bb.GetMergeInst();  // Might be nullptr
3435   for (auto& inst : bb) {
3436     if (&inst == terminator || &inst == merge || inst.opcode() == SpvOpLabel ||
3437         inst.opcode() == SpvOpVariable) {
3438       continue;
3439     }
3440     if (!EmitStatement(inst)) {
3441       return false;
3442     }
3443   }
3444 
3445   // Emit assignments to carry values to phi nodes in potential destinations.
3446   // Do it in index order.
3447   if (!block_info.phi_assignments.empty()) {
3448     auto sorted = block_info.phi_assignments;
3449     std::stable_sort(sorted.begin(), sorted.end(),
3450                      [this](const BlockInfo::PhiAssignment& lhs,
3451                             const BlockInfo::PhiAssignment& rhs) {
3452                        return GetDefInfo(lhs.phi_id)->index <
3453                               GetDefInfo(rhs.phi_id)->index;
3454                      });
3455     for (auto assignment : block_info.phi_assignments) {
3456       const auto var_name = GetDefInfo(assignment.phi_id)->phi_var;
3457       auto expr = MakeExpression(assignment.value);
3458       if (!expr) {
3459         return false;
3460       }
3461       AddStatement(create<ast::AssignmentStatement>(
3462           Source{},
3463           create<ast::IdentifierExpression>(
3464               Source{}, builder_.Symbols().Register(var_name)),
3465           expr.expr));
3466     }
3467   }
3468 
3469   *already_emitted = true;
3470   return true;
3471 }
3472 
EmitConstDefinition(const spvtools::opt::Instruction & inst,TypedExpression expr)3473 bool FunctionEmitter::EmitConstDefinition(
3474     const spvtools::opt::Instruction& inst,
3475     TypedExpression expr) {
3476   if (!expr) {
3477     return false;
3478   }
3479 
3480   // Do not generate pointers that we want to sink.
3481   if (GetDefInfo(inst.result_id())->skip == SkipReason::kSinkPointerIntoUse) {
3482     return true;
3483   }
3484 
3485   expr = AddressOfIfNeeded(expr, &inst);
3486   auto* ast_const = parser_impl_.MakeVariable(
3487       inst.result_id(), ast::StorageClass::kNone, expr.type, true, expr.expr,
3488       ast::DecorationList{});
3489   if (!ast_const) {
3490     return false;
3491   }
3492   AddStatement(create<ast::VariableDeclStatement>(Source{}, ast_const));
3493   identifier_types_.emplace(inst.result_id(), expr.type);
3494   return success();
3495 }
3496 
EmitConstDefOrWriteToHoistedVar(const spvtools::opt::Instruction & inst,TypedExpression expr)3497 bool FunctionEmitter::EmitConstDefOrWriteToHoistedVar(
3498     const spvtools::opt::Instruction& inst,
3499     TypedExpression expr) {
3500   return WriteIfHoistedVar(inst, expr) || EmitConstDefinition(inst, expr);
3501 }
3502 
WriteIfHoistedVar(const spvtools::opt::Instruction & inst,TypedExpression expr)3503 bool FunctionEmitter::WriteIfHoistedVar(const spvtools::opt::Instruction& inst,
3504                                         TypedExpression expr) {
3505   const auto result_id = inst.result_id();
3506   const auto* def_info = GetDefInfo(result_id);
3507   if (def_info && def_info->requires_hoisted_def) {
3508     auto name = namer_.Name(result_id);
3509     // Emit an assignment of the expression to the hoisted variable.
3510     AddStatement(create<ast::AssignmentStatement>(
3511         Source{},
3512         create<ast::IdentifierExpression>(Source{},
3513                                           builder_.Symbols().Register(name)),
3514         expr.expr));
3515     return true;
3516   }
3517   return false;
3518 }
3519 
EmitStatement(const spvtools::opt::Instruction & inst)3520 bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
3521   if (failed()) {
3522     return false;
3523   }
3524   const auto result_id = inst.result_id();
3525   const auto type_id = inst.type_id();
3526 
3527   if (type_id != 0) {
3528     const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
3529     if (type_id == builtin_position_info.struct_type_id) {
3530       return Fail() << "operations producing a per-vertex structure are not "
3531                        "supported: "
3532                     << inst.PrettyPrint();
3533     }
3534     if (type_id == builtin_position_info.pointer_type_id) {
3535       return Fail() << "operations producing a pointer to a per-vertex "
3536                        "structure are not "
3537                        "supported: "
3538                     << inst.PrettyPrint();
3539     }
3540   }
3541 
3542   // Handle combinatorial instructions.
3543   const auto* def_info = GetDefInfo(result_id);
3544   if (def_info) {
3545     TypedExpression combinatorial_expr;
3546     if (def_info->skip == SkipReason::kDontSkip) {
3547       combinatorial_expr = MaybeEmitCombinatorialValue(inst);
3548       if (!success()) {
3549         return false;
3550       }
3551     }
3552     // An access chain or OpCopyObject can generate a skip.
3553     if (def_info->skip != SkipReason::kDontSkip) {
3554       return true;
3555     }
3556 
3557     if (combinatorial_expr.expr != nullptr) {
3558       if (def_info->requires_hoisted_def ||
3559           def_info->requires_named_const_def || def_info->num_uses != 1) {
3560         // Generate a const definition or an assignment to a hoisted definition
3561         // now and later use the const or variable name at the uses of this
3562         // value.
3563         return EmitConstDefOrWriteToHoistedVar(inst, combinatorial_expr);
3564       }
3565       // It is harmless to defer emitting the expression until it's used.
3566       // Any supporting statements have already been emitted.
3567       singly_used_values_.insert(std::make_pair(result_id, combinatorial_expr));
3568       return success();
3569     }
3570   }
3571   if (failed()) {
3572     return false;
3573   }
3574 
3575   if (IsImageQuery(inst.opcode())) {
3576     return EmitImageQuery(inst);
3577   }
3578 
3579   if (IsSampledImageAccess(inst.opcode()) || IsRawImageAccess(inst.opcode())) {
3580     return EmitImageAccess(inst);
3581   }
3582 
3583   switch (inst.opcode()) {
3584     case SpvOpNop:
3585       return true;
3586 
3587     case SpvOpStore: {
3588       auto ptr_id = inst.GetSingleWordInOperand(0);
3589       const auto value_id = inst.GetSingleWordInOperand(1);
3590 
3591       const auto ptr_type_id = def_use_mgr_->GetDef(ptr_id)->type_id();
3592       const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
3593       if (ptr_type_id == builtin_position_info.pointer_type_id) {
3594         return Fail()
3595                << "storing to the whole per-vertex structure is not supported: "
3596                << inst.PrettyPrint();
3597       }
3598 
3599       TypedExpression rhs = MakeExpression(value_id);
3600       if (!rhs) {
3601         return false;
3602       }
3603 
3604       TypedExpression lhs;
3605 
3606       // Handle exceptional cases
3607       switch (GetSkipReason(ptr_id)) {
3608         case SkipReason::kPointSizeBuiltinPointer:
3609           if (IsFloatOne(value_id)) {
3610             // Don't store to PointSize
3611             return true;
3612           }
3613           return Fail() << "cannot store a value other than constant 1.0 to "
3614                            "PointSize builtin: "
3615                         << inst.PrettyPrint();
3616 
3617         case SkipReason::kSampleMaskOutBuiltinPointer:
3618           lhs = MakeExpression(sample_mask_out_id);
3619           if (lhs.type->Is<Pointer>()) {
3620             // LHS of an assignment must be a reference type.
3621             // Convert the LHS to a reference by dereferencing it.
3622             lhs = Dereference(lhs);
3623           }
3624           // The private variable is an array whose element type is already of
3625           // the same type as the value being stored into it.  Form the
3626           // reference into the first element.
3627           lhs.expr = create<ast::IndexAccessorExpression>(
3628               Source{}, lhs.expr, parser_impl_.MakeNullValue(ty_.I32()));
3629           if (auto* ref = lhs.type->As<Reference>()) {
3630             lhs.type = ref->type;
3631           }
3632           if (auto* arr = lhs.type->As<Array>()) {
3633             lhs.type = arr->type;
3634           }
3635           TINT_ASSERT(Reader, lhs.type);
3636           break;
3637         default:
3638           break;
3639       }
3640 
3641       // Handle an ordinary store as an assignment.
3642       if (!lhs) {
3643         lhs = MakeExpression(ptr_id);
3644       }
3645       if (!lhs) {
3646         return false;
3647       }
3648 
3649       if (lhs.type->Is<Pointer>()) {
3650         // LHS of an assignment must be a reference type.
3651         // Convert the LHS to a reference by dereferencing it.
3652         lhs = Dereference(lhs);
3653       }
3654 
3655       AddStatement(
3656           create<ast::AssignmentStatement>(Source{}, lhs.expr, rhs.expr));
3657       return success();
3658     }
3659 
3660     case SpvOpLoad: {
3661       // Memory accesses must be issued in SPIR-V program order.
3662       // So represent a load by a new const definition.
3663       const auto ptr_id = inst.GetSingleWordInOperand(0);
3664       const auto skip_reason = GetSkipReason(ptr_id);
3665 
3666       switch (skip_reason) {
3667         case SkipReason::kPointSizeBuiltinPointer:
3668           GetDefInfo(inst.result_id())->skip =
3669               SkipReason::kPointSizeBuiltinValue;
3670           return true;
3671         case SkipReason::kSampleMaskInBuiltinPointer: {
3672           auto name = namer_.Name(sample_mask_in_id);
3673           const ast::Expression* id_expr = create<ast::IdentifierExpression>(
3674               Source{}, builder_.Symbols().Register(name));
3675           // SampleMask is an array in Vulkan SPIR-V. Always access the first
3676           // element.
3677           id_expr = create<ast::IndexAccessorExpression>(
3678               Source{}, id_expr, parser_impl_.MakeNullValue(ty_.I32()));
3679 
3680           auto* loaded_type = parser_impl_.ConvertType(inst.type_id());
3681 
3682           if (!loaded_type->IsIntegerScalar()) {
3683             return Fail() << "loading the whole SampleMask input array is not "
3684                              "supported: "
3685                           << inst.PrettyPrint();
3686           }
3687 
3688           auto expr = TypedExpression{loaded_type, id_expr};
3689           return EmitConstDefinition(inst, expr);
3690         }
3691         default:
3692           break;
3693       }
3694       auto expr = MakeExpression(ptr_id);
3695       if (!expr) {
3696         return false;
3697       }
3698 
3699       // The load result type is the storage type of its operand.
3700       if (expr.type->Is<Pointer>()) {
3701         expr = Dereference(expr);
3702       } else if (auto* ref = expr.type->As<Reference>()) {
3703         expr.type = ref->type;
3704       } else {
3705         Fail() << "OpLoad expression is not a pointer or reference";
3706         return false;
3707       }
3708 
3709       return EmitConstDefOrWriteToHoistedVar(inst, expr);
3710     }
3711 
3712     case SpvOpCopyMemory: {
3713       // Generate an assignment.
3714       auto lhs = MakeOperand(inst, 0);
3715       auto rhs = MakeOperand(inst, 1);
3716       // Ignore any potential memory operands. Currently they are all for
3717       // concepts not in WGSL:
3718       //   Volatile
3719       //   Aligned
3720       //   Nontemporal
3721       //   MakePointerAvailable ; Vulkan memory model
3722       //   MakePointerVisible   ; Vulkan memory model
3723       //   NonPrivatePointer    ; Vulkan memory model
3724 
3725       if (!success()) {
3726         return false;
3727       }
3728 
3729       // LHS and RHS pointers must be reference types in WGSL.
3730       if (lhs.type->Is<Pointer>()) {
3731         lhs = Dereference(lhs);
3732       }
3733       if (rhs.type->Is<Pointer>()) {
3734         rhs = Dereference(rhs);
3735       }
3736 
3737       AddStatement(
3738           create<ast::AssignmentStatement>(Source{}, lhs.expr, rhs.expr));
3739       return success();
3740     }
3741 
3742     case SpvOpCopyObject: {
3743       // Arguably, OpCopyObject is purely combinatorial. On the other hand,
3744       // it exists to make a new name for something. So we choose to make
3745       // a new named constant definition.
3746       auto value_id = inst.GetSingleWordInOperand(0);
3747       const auto skip = GetSkipReason(value_id);
3748       if (skip != SkipReason::kDontSkip) {
3749         GetDefInfo(inst.result_id())->skip = skip;
3750         GetDefInfo(inst.result_id())->sink_pointer_source_expr =
3751             GetDefInfo(value_id)->sink_pointer_source_expr;
3752         return true;
3753       }
3754       auto expr = AddressOfIfNeeded(MakeExpression(value_id), &inst);
3755       if (!expr) {
3756         return false;
3757       }
3758       expr.type = RemapStorageClass(expr.type, result_id);
3759       return EmitConstDefOrWriteToHoistedVar(inst, expr);
3760     }
3761 
3762     case SpvOpPhi: {
3763       // Emit a read from the associated state variable.
3764       TypedExpression expr{
3765           parser_impl_.ConvertType(inst.type_id()),
3766           create<ast::IdentifierExpression>(
3767               Source{}, builder_.Symbols().Register(def_info->phi_var))};
3768       return EmitConstDefOrWriteToHoistedVar(inst, expr);
3769     }
3770 
3771     case SpvOpOuterProduct:
3772       // Synthesize an outer product expression in its own statement.
3773       return EmitConstDefOrWriteToHoistedVar(inst, MakeOuterProduct(inst));
3774 
3775     case SpvOpVectorInsertDynamic:
3776       // Synthesize a vector insertion in its own statements.
3777       return MakeVectorInsertDynamic(inst);
3778 
3779     case SpvOpCompositeInsert:
3780       // Synthesize a composite insertion in its own statements.
3781       return MakeCompositeInsert(inst);
3782 
3783     case SpvOpFunctionCall:
3784       return EmitFunctionCall(inst);
3785 
3786     case SpvOpControlBarrier:
3787       return EmitControlBarrier(inst);
3788 
3789     case SpvOpExtInst:
3790       if (parser_impl_.IsIgnoredExtendedInstruction(inst)) {
3791         return true;
3792       }
3793       break;
3794 
3795     case SpvOpIAddCarry:
3796     case SpvOpISubBorrow:
3797     case SpvOpUMulExtended:
3798     case SpvOpSMulExtended:
3799       return Fail() << "extended arithmetic is not finalized for WGSL: "
3800                        "https://github.com/gpuweb/gpuweb/issues/1565: "
3801                     << inst.PrettyPrint();
3802 
3803     default:
3804       break;
3805   }
3806   return Fail() << "unhandled instruction with opcode " << inst.opcode() << ": "
3807                 << inst.PrettyPrint();
3808 }
3809 
MakeOperand(const spvtools::opt::Instruction & inst,uint32_t operand_index)3810 TypedExpression FunctionEmitter::MakeOperand(
3811     const spvtools::opt::Instruction& inst,
3812     uint32_t operand_index) {
3813   auto expr = MakeExpression(inst.GetSingleWordInOperand(operand_index));
3814   if (!expr) {
3815     return {};
3816   }
3817   return parser_impl_.RectifyOperandSignedness(inst, std::move(expr));
3818 }
3819 
InferFunctionStorageClass(TypedExpression expr)3820 TypedExpression FunctionEmitter::InferFunctionStorageClass(
3821     TypedExpression expr) {
3822   TypedExpression result(expr);
3823   if (const auto* ref = expr.type->UnwrapAlias()->As<Reference>()) {
3824     if (ref->storage_class == ast::StorageClass::kNone) {
3825       expr.type = ty_.Reference(ref->type, ast::StorageClass::kFunction);
3826     }
3827   } else if (const auto* ptr = expr.type->UnwrapAlias()->As<Pointer>()) {
3828     if (ptr->storage_class == ast::StorageClass::kNone) {
3829       expr.type = ty_.Pointer(ptr->type, ast::StorageClass::kFunction);
3830     }
3831   }
3832   return expr;
3833 }
3834 
MaybeEmitCombinatorialValue(const spvtools::opt::Instruction & inst)3835 TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
3836     const spvtools::opt::Instruction& inst) {
3837   if (inst.result_id() == 0) {
3838     return {};
3839   }
3840 
3841   const auto opcode = inst.opcode();
3842 
3843   const Type* ast_type = nullptr;
3844   if (inst.type_id()) {
3845     ast_type = parser_impl_.ConvertType(inst.type_id());
3846     if (!ast_type) {
3847       Fail() << "couldn't convert result type for: " << inst.PrettyPrint();
3848       return {};
3849     }
3850   }
3851 
3852   auto binary_op = ConvertBinaryOp(opcode);
3853   if (binary_op != ast::BinaryOp::kNone) {
3854     auto arg0 = MakeOperand(inst, 0);
3855     auto arg1 = parser_impl_.RectifySecondOperandSignedness(
3856         inst, arg0.type, MakeOperand(inst, 1));
3857     if (!arg0 || !arg1) {
3858       return {};
3859     }
3860     auto* binary_expr = create<ast::BinaryExpression>(Source{}, binary_op,
3861                                                       arg0.expr, arg1.expr);
3862     TypedExpression result{ast_type, binary_expr};
3863     return parser_impl_.RectifyForcedResultType(result, inst, arg0.type);
3864   }
3865 
3866   auto unary_op = ast::UnaryOp::kNegation;
3867   if (GetUnaryOp(opcode, &unary_op)) {
3868     auto arg0 = MakeOperand(inst, 0);
3869     auto* unary_expr =
3870         create<ast::UnaryOpExpression>(Source{}, unary_op, arg0.expr);
3871     TypedExpression result{ast_type, unary_expr};
3872     return parser_impl_.RectifyForcedResultType(result, inst, arg0.type);
3873   }
3874 
3875   const char* unary_builtin_name = GetUnaryBuiltInFunctionName(opcode);
3876   if (unary_builtin_name != nullptr) {
3877     ast::ExpressionList params;
3878     params.emplace_back(MakeOperand(inst, 0).expr);
3879     return {ast_type,
3880             create<ast::CallExpression>(
3881                 Source{},
3882                 create<ast::IdentifierExpression>(
3883                     Source{}, builder_.Symbols().Register(unary_builtin_name)),
3884                 std::move(params))};
3885   }
3886 
3887   const auto intrinsic = GetIntrinsic(opcode);
3888   if (intrinsic != sem::IntrinsicType::kNone) {
3889     return MakeIntrinsicCall(inst);
3890   }
3891 
3892   if (opcode == SpvOpFMod) {
3893     return MakeFMod(inst);
3894   }
3895 
3896   if (opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain) {
3897     return MakeAccessChain(inst);
3898   }
3899 
3900   if (opcode == SpvOpBitcast) {
3901     return {ast_type,
3902             create<ast::BitcastExpression>(Source{}, ast_type->Build(builder_),
3903                                            MakeOperand(inst, 0).expr)};
3904   }
3905 
3906   if (opcode == SpvOpShiftLeftLogical || opcode == SpvOpShiftRightLogical ||
3907       opcode == SpvOpShiftRightArithmetic) {
3908     auto arg0 = MakeOperand(inst, 0);
3909     // The second operand must be unsigned. It's ok to wrap the shift amount
3910     // since the shift is modulo the bit width of the first operand.
3911     auto arg1 = parser_impl_.AsUnsigned(MakeOperand(inst, 1));
3912 
3913     switch (opcode) {
3914       case SpvOpShiftLeftLogical:
3915         binary_op = ast::BinaryOp::kShiftLeft;
3916         break;
3917       case SpvOpShiftRightLogical:
3918         arg0 = parser_impl_.AsUnsigned(arg0);
3919         binary_op = ast::BinaryOp::kShiftRight;
3920         break;
3921       case SpvOpShiftRightArithmetic:
3922         arg0 = parser_impl_.AsSigned(arg0);
3923         binary_op = ast::BinaryOp::kShiftRight;
3924         break;
3925       default:
3926         break;
3927     }
3928     TypedExpression result{
3929         ast_type, create<ast::BinaryExpression>(Source{}, binary_op, arg0.expr,
3930                                                 arg1.expr)};
3931     return parser_impl_.RectifyForcedResultType(result, inst, arg0.type);
3932   }
3933 
3934   auto negated_op = NegatedFloatCompare(opcode);
3935   if (negated_op != ast::BinaryOp::kNone) {
3936     auto arg0 = MakeOperand(inst, 0);
3937     auto arg1 = MakeOperand(inst, 1);
3938     auto* binary_expr = create<ast::BinaryExpression>(Source{}, negated_op,
3939                                                       arg0.expr, arg1.expr);
3940     auto* negated_expr = create<ast::UnaryOpExpression>(
3941         Source{}, ast::UnaryOp::kNot, binary_expr);
3942     return {ast_type, negated_expr};
3943   }
3944 
3945   if (opcode == SpvOpExtInst) {
3946     if (parser_impl_.IsIgnoredExtendedInstruction(inst)) {
3947       // Ignore it but don't error out.
3948       return {};
3949     }
3950     if (!parser_impl_.IsGlslExtendedInstruction(inst)) {
3951       Fail() << "unhandled extended instruction import with ID "
3952              << inst.GetSingleWordInOperand(0);
3953       return {};
3954     }
3955     return EmitGlslStd450ExtInst(inst);
3956   }
3957 
3958   if (opcode == SpvOpCompositeConstruct) {
3959     ast::ExpressionList operands;
3960     for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) {
3961       operands.emplace_back(MakeOperand(inst, iarg).expr);
3962     }
3963     return {ast_type, builder_.Construct(Source{}, ast_type->Build(builder_),
3964                                          std::move(operands))};
3965   }
3966 
3967   if (opcode == SpvOpCompositeExtract) {
3968     return MakeCompositeExtract(inst);
3969   }
3970 
3971   if (opcode == SpvOpVectorShuffle) {
3972     return MakeVectorShuffle(inst);
3973   }
3974 
3975   if (opcode == SpvOpVectorExtractDynamic) {
3976     return {ast_type, create<ast::IndexAccessorExpression>(
3977                           Source{}, MakeOperand(inst, 0).expr,
3978                           MakeOperand(inst, 1).expr)};
3979   }
3980 
3981   if (opcode == SpvOpConvertSToF || opcode == SpvOpConvertUToF ||
3982       opcode == SpvOpConvertFToS || opcode == SpvOpConvertFToU) {
3983     return MakeNumericConversion(inst);
3984   }
3985 
3986   if (opcode == SpvOpUndef) {
3987     // Replace undef with the null value.
3988     return parser_impl_.MakeNullExpression(ast_type);
3989   }
3990 
3991   if (opcode == SpvOpSelect) {
3992     return MakeSimpleSelect(inst);
3993   }
3994 
3995   if (opcode == SpvOpArrayLength) {
3996     return MakeArrayLength(inst);
3997   }
3998 
3999   // builtin readonly function
4000   // glsl.std.450 readonly function
4001 
4002   // Instructions:
4003   //    OpSatConvertSToU // Only in Kernel (OpenCL), not in WebGPU
4004   //    OpSatConvertUToS // Only in Kernel (OpenCL), not in WebGPU
4005   //    OpUConvert // Only needed when multiple widths supported
4006   //    OpSConvert // Only needed when multiple widths supported
4007   //    OpFConvert // Only needed when multiple widths supported
4008   //    OpConvertPtrToU // Not in WebGPU
4009   //    OpConvertUToPtr // Not in WebGPU
4010   //    OpPtrCastToGeneric // Not in Vulkan
4011   //    OpGenericCastToPtr // Not in Vulkan
4012   //    OpGenericCastToPtrExplicit // Not in Vulkan
4013 
4014   return {};
4015 }
4016 
EmitGlslStd450ExtInst(const spvtools::opt::Instruction & inst)4017 TypedExpression FunctionEmitter::EmitGlslStd450ExtInst(
4018     const spvtools::opt::Instruction& inst) {
4019   const auto ext_opcode = inst.GetSingleWordInOperand(1);
4020 
4021   if (ext_opcode == GLSLstd450Ldexp) {
4022     // WGSL requires the second argument to be signed.
4023     // Use a type constructor to convert it, which is the same as a bitcast.
4024     // If the value would go from very large positive to negative, then the
4025     // original result would have been infinity.  And since WGSL
4026     // implementations may assume that infinities are not present, then we
4027     // don't have to worry about that case.
4028     auto e1 = MakeOperand(inst, 2);
4029     auto e2 = ToSignedIfUnsigned(MakeOperand(inst, 3));
4030 
4031     return {e1.type, builder_.Call(Source{}, "ldexp",
4032                                    ast::ExpressionList{e1.expr, e2.expr})};
4033   }
4034 
4035   auto* result_type = parser_impl_.ConvertType(inst.type_id());
4036 
4037   if (result_type->IsScalar()) {
4038     // Some GLSLstd450 builtins have scalar forms not supported by WGSL.
4039     // Emulate them.
4040     switch (ext_opcode) {
4041       case GLSLstd450Normalize:
4042         // WGSL does not have scalar form of the normalize builtin.
4043         // The answer would be 1 anyway, so return that directly.
4044         return {ty_.F32(), builder_.Expr(1.0f)};
4045 
4046       case GLSLstd450FaceForward: {
4047         // If dot(Nref, Incident) < 0, the result is Normal, otherwise -Normal.
4048         // Also: select(-normal,normal, Incident*Nref < 0)
4049         // (The dot product of scalars is their product.)
4050         // Use a multiply instead of comparing floating point signs. It should
4051         // be among the fastest operations on a GPU.
4052         auto normal = MakeOperand(inst, 2);
4053         auto incident = MakeOperand(inst, 3);
4054         auto nref = MakeOperand(inst, 4);
4055         TINT_ASSERT(Reader, normal.type->Is<F32>());
4056         TINT_ASSERT(Reader, incident.type->Is<F32>());
4057         TINT_ASSERT(Reader, nref.type->Is<F32>());
4058         return {ty_.F32(),
4059                 builder_.Call(
4060                     Source{}, "select",
4061                     ast::ExpressionList{
4062                         create<ast::UnaryOpExpression>(
4063                             Source{}, ast::UnaryOp::kNegation, normal.expr),
4064                         normal.expr,
4065                         create<ast::BinaryExpression>(
4066                             Source{}, ast::BinaryOp::kLessThan,
4067                             builder_.Mul({}, incident.expr, nref.expr),
4068                             builder_.Expr(0.0f))})};
4069       }
4070 
4071       case GLSLstd450Reflect: {
4072         // Compute  Incident - 2 * Normal * Normal * Incident
4073         auto incident = MakeOperand(inst, 2);
4074         auto normal = MakeOperand(inst, 3);
4075         TINT_ASSERT(Reader, incident.type->Is<F32>());
4076         TINT_ASSERT(Reader, normal.type->Is<F32>());
4077         return {
4078             ty_.F32(),
4079             builder_.Sub(
4080                 incident.expr,
4081                 builder_.Mul(2.0f, builder_.Mul(normal.expr,
4082                                                 builder_.Mul(normal.expr,
4083                                                              incident.expr))))};
4084       }
4085 
4086       case GLSLstd450Refract: {
4087         // It's a complicated expression. Compute it in two dimensions, but
4088         // with a 0-valued y component in both the incident and normal vectors,
4089         // then take the x component of that result.
4090         auto incident = MakeOperand(inst, 2);
4091         auto normal = MakeOperand(inst, 3);
4092         auto eta = MakeOperand(inst, 4);
4093         TINT_ASSERT(Reader, incident.type->Is<F32>());
4094         TINT_ASSERT(Reader, normal.type->Is<F32>());
4095         TINT_ASSERT(Reader, eta.type->Is<F32>());
4096         if (!success()) {
4097           return {};
4098         }
4099         const Type* f32 = eta.type;
4100         return {f32,
4101                 builder_.MemberAccessor(
4102                     builder_.Call(
4103                         Source{}, "refract",
4104                         ast::ExpressionList{
4105                             builder_.vec2<float>(incident.expr, 0.0f),
4106                             builder_.vec2<float>(normal.expr, 0.0f), eta.expr}),
4107                     "x")};
4108       }
4109       default:
4110         break;
4111     }
4112   }
4113 
4114   // Some GLSLStd450 builtins don't have a WGSL equivalent. Polyfill them.
4115   switch (ext_opcode) {
4116     case GLSLstd450Radians: {
4117       auto degrees = MakeOperand(inst, 2);
4118       TINT_ASSERT(Reader, degrees.type->IsFloatScalarOrVector());
4119 
4120       constexpr auto kPiOver180 = static_cast<float>(3.141592653589793 / 180.0);
4121       auto* factor = builder_.Expr(kPiOver180);
4122       if (degrees.type->Is<F32>()) {
4123         return {degrees.type, builder_.Mul(degrees.expr, factor)};
4124       } else {
4125         uint32_t size = degrees.type->As<Vector>()->size;
4126         return {degrees.type,
4127                 builder_.Mul(degrees.expr,
4128                              builder_.vec(builder_.ty.f32(), size, factor))};
4129       }
4130     }
4131 
4132     case GLSLstd450Degrees: {
4133       auto radians = MakeOperand(inst, 2);
4134       TINT_ASSERT(Reader, radians.type->IsFloatScalarOrVector());
4135 
4136       constexpr auto k180OverPi = static_cast<float>(180.0 / 3.141592653589793);
4137       auto* factor = builder_.Expr(k180OverPi);
4138       if (radians.type->Is<F32>()) {
4139         return {radians.type, builder_.Mul(radians.expr, factor)};
4140       } else {
4141         uint32_t size = radians.type->As<Vector>()->size;
4142         return {radians.type,
4143                 builder_.Mul(radians.expr,
4144                              builder_.vec(builder_.ty.f32(), size, factor))};
4145       }
4146     }
4147   }
4148 
4149   const auto name = GetGlslStd450FuncName(ext_opcode);
4150   if (name.empty()) {
4151     Fail() << "unhandled GLSL.std.450 instruction " << ext_opcode;
4152     return {};
4153   }
4154 
4155   auto* func = create<ast::IdentifierExpression>(
4156       Source{}, builder_.Symbols().Register(name));
4157   ast::ExpressionList operands;
4158   const Type* first_operand_type = nullptr;
4159   // All parameters to GLSL.std.450 extended instructions are IDs.
4160   for (uint32_t iarg = 2; iarg < inst.NumInOperands(); ++iarg) {
4161     TypedExpression operand = MakeOperand(inst, iarg);
4162     if (first_operand_type == nullptr) {
4163       first_operand_type = operand.type;
4164     }
4165     operands.emplace_back(operand.expr);
4166   }
4167   auto* call = create<ast::CallExpression>(Source{}, func, std::move(operands));
4168   TypedExpression call_expr{result_type, call};
4169   return parser_impl_.RectifyForcedResultType(call_expr, inst,
4170                                               first_operand_type);
4171 }
4172 
Swizzle(uint32_t i)4173 ast::IdentifierExpression* FunctionEmitter::Swizzle(uint32_t i) {
4174   if (i >= kMaxVectorLen) {
4175     Fail() << "vector component index is larger than " << kMaxVectorLen - 1
4176            << ": " << i;
4177     return nullptr;
4178   }
4179   const char* names[] = {"x", "y", "z", "w"};
4180   return create<ast::IdentifierExpression>(
4181       Source{}, builder_.Symbols().Register(names[i & 3]));
4182 }
4183 
PrefixSwizzle(uint32_t n)4184 ast::IdentifierExpression* FunctionEmitter::PrefixSwizzle(uint32_t n) {
4185   switch (n) {
4186     case 1:
4187       return create<ast::IdentifierExpression>(
4188           Source{}, builder_.Symbols().Register("x"));
4189     case 2:
4190       return create<ast::IdentifierExpression>(
4191           Source{}, builder_.Symbols().Register("xy"));
4192     case 3:
4193       return create<ast::IdentifierExpression>(
4194           Source{}, builder_.Symbols().Register("xyz"));
4195     default:
4196       break;
4197   }
4198   Fail() << "invalid swizzle prefix count: " << n;
4199   return nullptr;
4200 }
4201 
MakeFMod(const spvtools::opt::Instruction & inst)4202 TypedExpression FunctionEmitter::MakeFMod(
4203     const spvtools::opt::Instruction& inst) {
4204   auto x = MakeOperand(inst, 0);
4205   auto y = MakeOperand(inst, 1);
4206   if (!x || !y) {
4207     return {};
4208   }
4209   // Emulated with: x - y * floor(x / y)
4210   auto* div = builder_.Div(x.expr, y.expr);
4211   auto* floor = builder_.Call("floor", div);
4212   auto* y_floor = builder_.Mul(y.expr, floor);
4213   auto* res = builder_.Sub(x.expr, y_floor);
4214   return {x.type, res};
4215 }
4216 
MakeAccessChain(const spvtools::opt::Instruction & inst)4217 TypedExpression FunctionEmitter::MakeAccessChain(
4218     const spvtools::opt::Instruction& inst) {
4219   if (inst.NumInOperands() < 1) {
4220     // Binary parsing will fail on this anyway.
4221     Fail() << "invalid access chain: has no input operands";
4222     return {};
4223   }
4224 
4225   const auto base_id = inst.GetSingleWordInOperand(0);
4226   const auto base_skip = GetSkipReason(base_id);
4227   if (base_skip != SkipReason::kDontSkip) {
4228     // This can occur for AccessChain with no indices.
4229     GetDefInfo(inst.result_id())->skip = base_skip;
4230     GetDefInfo(inst.result_id())->sink_pointer_source_expr =
4231         GetDefInfo(base_id)->sink_pointer_source_expr;
4232     return {};
4233   }
4234 
4235   auto ptr_ty_id = def_use_mgr_->GetDef(base_id)->type_id();
4236   uint32_t first_index = 1;
4237   const auto num_in_operands = inst.NumInOperands();
4238 
4239   bool sink_pointer = false;
4240   TypedExpression current_expr;
4241 
4242   // If the variable was originally gl_PerVertex, then in the AST we
4243   // have instead emitted a gl_Position variable.
4244   // If computing the pointer to the Position builtin, then emit the
4245   // pointer to the generated gl_Position variable.
4246   // If computing the pointer to the PointSize builtin, then mark the
4247   // result as skippable due to being the point-size pointer.
4248   // If computing the pointer to the ClipDistance or CullDistance builtins,
4249   // then error out.
4250   {
4251     const auto& builtin_position_info = parser_impl_.GetBuiltInPositionInfo();
4252     if (base_id == builtin_position_info.per_vertex_var_id) {
4253       // We only support the Position member.
4254       const auto* member_index_inst =
4255           def_use_mgr_->GetDef(inst.GetSingleWordInOperand(first_index));
4256       if (member_index_inst == nullptr) {
4257         Fail()
4258             << "first index of access chain does not reference an instruction: "
4259             << inst.PrettyPrint();
4260         return {};
4261       }
4262       const auto* member_index_const =
4263           constant_mgr_->GetConstantFromInst(member_index_inst);
4264       if (member_index_const == nullptr) {
4265         Fail() << "first index of access chain into per-vertex structure is "
4266                   "not a constant: "
4267                << inst.PrettyPrint();
4268         return {};
4269       }
4270       const auto* member_index_const_int = member_index_const->AsIntConstant();
4271       if (member_index_const_int == nullptr) {
4272         Fail() << "first index of access chain into per-vertex structure is "
4273                   "not a constant integer: "
4274                << inst.PrettyPrint();
4275         return {};
4276       }
4277       const auto member_index_value =
4278           member_index_const_int->GetZeroExtendedValue();
4279       if (member_index_value != builtin_position_info.position_member_index) {
4280         if (member_index_value ==
4281             builtin_position_info.pointsize_member_index) {
4282           if (auto* def_info = GetDefInfo(inst.result_id())) {
4283             def_info->skip = SkipReason::kPointSizeBuiltinPointer;
4284             return {};
4285           }
4286         } else {
4287           // TODO(dneto): Handle ClipDistance and CullDistance
4288           Fail() << "accessing per-vertex member " << member_index_value
4289                  << " is not supported. Only Position is supported, and "
4290                     "PointSize is ignored";
4291           return {};
4292         }
4293       }
4294 
4295       // Skip past the member index that gets us to Position.
4296       first_index = first_index + 1;
4297       // Replace the gl_PerVertex reference with the gl_Position reference
4298       ptr_ty_id = builtin_position_info.position_member_pointer_type_id;
4299 
4300       auto name = namer_.Name(base_id);
4301       current_expr.expr = create<ast::IdentifierExpression>(
4302           Source{}, builder_.Symbols().Register(name));
4303       current_expr.type = parser_impl_.ConvertType(ptr_ty_id, PtrAs::Ref);
4304     }
4305   }
4306 
4307   // A SPIR-V access chain is a single instruction with multiple indices
4308   // walking down into composites.  The Tint AST represents this as
4309   // ever-deeper nested indexing expressions. Start off with an expression
4310   // for the base, and then bury that inside nested indexing expressions.
4311   if (!current_expr) {
4312     current_expr = InferFunctionStorageClass(MakeOperand(inst, 0));
4313     if (current_expr.type->Is<Pointer>()) {
4314       current_expr = Dereference(current_expr);
4315     }
4316   }
4317   const auto constants = constant_mgr_->GetOperandConstants(&inst);
4318 
4319   const auto* ptr_type_inst = def_use_mgr_->GetDef(ptr_ty_id);
4320   if (!ptr_type_inst || (ptr_type_inst->opcode() != SpvOpTypePointer)) {
4321     Fail() << "Access chain %" << inst.result_id()
4322            << " base pointer is not of pointer type";
4323     return {};
4324   }
4325   SpvStorageClass storage_class =
4326       static_cast<SpvStorageClass>(ptr_type_inst->GetSingleWordInOperand(0));
4327   uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1);
4328 
4329   // Build up a nested expression for the access chain by walking down the type
4330   // hierarchy, maintaining |pointee_type_id| as the SPIR-V ID of the type of
4331   // the object pointed to after processing the previous indices.
4332   for (uint32_t index = first_index; index < num_in_operands; ++index) {
4333     const auto* index_const =
4334         constants[index] ? constants[index]->AsIntConstant() : nullptr;
4335     const int64_t index_const_val =
4336         index_const ? index_const->GetSignExtendedValue() : 0;
4337     const ast::Expression* next_expr = nullptr;
4338 
4339     const auto* pointee_type_inst = def_use_mgr_->GetDef(pointee_type_id);
4340     if (!pointee_type_inst) {
4341       Fail() << "pointee type %" << pointee_type_id
4342              << " is invalid after following " << (index - first_index)
4343              << " indices: " << inst.PrettyPrint();
4344       return {};
4345     }
4346     switch (pointee_type_inst->opcode()) {
4347       case SpvOpTypeVector:
4348         if (index_const) {
4349           // Try generating a MemberAccessor expression
4350           const auto num_elems = pointee_type_inst->GetSingleWordInOperand(1);
4351           if (index_const_val < 0 || num_elems <= index_const_val) {
4352             Fail() << "Access chain %" << inst.result_id() << " index %"
4353                    << inst.GetSingleWordInOperand(index) << " value "
4354                    << index_const_val << " is out of bounds for vector of "
4355                    << num_elems << " elements";
4356             return {};
4357           }
4358           if (uint64_t(index_const_val) >= kMaxVectorLen) {
4359             Fail() << "internal error: swizzle index " << index_const_val
4360                    << " is too big. Max handled index is " << kMaxVectorLen - 1;
4361           }
4362           next_expr = create<ast::MemberAccessorExpression>(
4363               Source{}, current_expr.expr, Swizzle(uint32_t(index_const_val)));
4364         } else {
4365           // Non-constant index. Use array syntax
4366           next_expr = create<ast::IndexAccessorExpression>(
4367               Source{}, current_expr.expr, MakeOperand(inst, index).expr);
4368         }
4369         // All vector components are the same type.
4370         pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
4371         // Sink pointers to vector components.
4372         sink_pointer = true;
4373         break;
4374       case SpvOpTypeMatrix:
4375         // Use array syntax.
4376         next_expr = create<ast::IndexAccessorExpression>(
4377             Source{}, current_expr.expr, MakeOperand(inst, index).expr);
4378         // All matrix components are the same type.
4379         pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
4380         break;
4381       case SpvOpTypeArray:
4382         next_expr = create<ast::IndexAccessorExpression>(
4383             Source{}, current_expr.expr, MakeOperand(inst, index).expr);
4384         pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
4385         break;
4386       case SpvOpTypeRuntimeArray:
4387         next_expr = create<ast::IndexAccessorExpression>(
4388             Source{}, current_expr.expr, MakeOperand(inst, index).expr);
4389         pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
4390         break;
4391       case SpvOpTypeStruct: {
4392         if (!index_const) {
4393           Fail() << "Access chain %" << inst.result_id() << " index %"
4394                  << inst.GetSingleWordInOperand(index)
4395                  << " is a non-constant index into a structure %"
4396                  << pointee_type_id;
4397           return {};
4398         }
4399         const auto num_members = pointee_type_inst->NumInOperands();
4400         if ((index_const_val < 0) || num_members <= uint64_t(index_const_val)) {
4401           Fail() << "Access chain %" << inst.result_id() << " index value "
4402                  << index_const_val << " is out of bounds for structure %"
4403                  << pointee_type_id << " having " << num_members << " members";
4404           return {};
4405         }
4406         auto name =
4407             namer_.GetMemberName(pointee_type_id, uint32_t(index_const_val));
4408         auto* member_access = create<ast::IdentifierExpression>(
4409             Source{}, builder_.Symbols().Register(name));
4410 
4411         next_expr = create<ast::MemberAccessorExpression>(
4412             Source{}, current_expr.expr, member_access);
4413         pointee_type_id = pointee_type_inst->GetSingleWordInOperand(
4414             static_cast<uint32_t>(index_const_val));
4415         break;
4416       }
4417       default:
4418         Fail() << "Access chain with unknown or invalid pointee type %"
4419                << pointee_type_id << ": " << pointee_type_inst->PrettyPrint();
4420         return {};
4421     }
4422     const auto pointer_type_id =
4423         type_mgr_->FindPointerToType(pointee_type_id, storage_class);
4424     auto* type = parser_impl_.ConvertType(pointer_type_id, PtrAs::Ref);
4425     TINT_ASSERT(Reader, type && type->Is<Reference>());
4426     current_expr = TypedExpression{type, next_expr};
4427   }
4428 
4429   if (sink_pointer) {
4430     // Capture the reference so that we can sink it into the point of use.
4431     GetDefInfo(inst.result_id())->skip = SkipReason::kSinkPointerIntoUse;
4432     GetDefInfo(inst.result_id())->sink_pointer_source_expr = current_expr;
4433   }
4434 
4435   return current_expr;
4436 }
4437 
MakeCompositeExtract(const spvtools::opt::Instruction & inst)4438 TypedExpression FunctionEmitter::MakeCompositeExtract(
4439     const spvtools::opt::Instruction& inst) {
4440   // This is structurally similar to creating an access chain, but
4441   // the SPIR-V instruction has literal indices instead of IDs for indices.
4442 
4443   auto composite_index = 0;
4444   auto first_index_position = 1;
4445   TypedExpression current_expr(MakeOperand(inst, composite_index));
4446   if (!current_expr) {
4447     return {};
4448   }
4449 
4450   const auto composite_id = inst.GetSingleWordInOperand(composite_index);
4451   auto current_type_id = def_use_mgr_->GetDef(composite_id)->type_id();
4452 
4453   return MakeCompositeValueDecomposition(inst, current_expr, current_type_id,
4454                                          first_index_position);
4455 }
4456 
MakeCompositeValueDecomposition(const spvtools::opt::Instruction & inst,TypedExpression composite,uint32_t composite_type_id,int index_start)4457 TypedExpression FunctionEmitter::MakeCompositeValueDecomposition(
4458     const spvtools::opt::Instruction& inst,
4459     TypedExpression composite,
4460     uint32_t composite_type_id,
4461     int index_start) {
4462   // This is structurally similar to creating an access chain, but
4463   // the SPIR-V instruction has literal indices instead of IDs for indices.
4464 
4465   // A SPIR-V composite extract is a single instruction with multiple
4466   // literal indices walking down into composites.
4467   // A SPIR-V composite insert is similar but also tells you what component
4468   // to inject. This function is responsible for the the walking-into part
4469   // of composite-insert.
4470   //
4471   // The Tint AST represents this as ever-deeper nested indexing expressions.
4472   // Start off with an expression for the composite, and then bury that inside
4473   // nested indexing expressions.
4474 
4475   auto current_expr = composite;
4476   auto current_type_id = composite_type_id;
4477 
4478   auto make_index = [this](uint32_t literal) {
4479     return create<ast::UintLiteralExpression>(Source{}, literal);
4480   };
4481 
4482   // Build up a nested expression for the decomposition by walking down the type
4483   // hierarchy, maintaining |current_type_id| as the SPIR-V ID of the type of
4484   // the object pointed to after processing the previous indices.
4485   const auto num_in_operands = inst.NumInOperands();
4486   for (uint32_t index = index_start; index < num_in_operands; ++index) {
4487     const uint32_t index_val = inst.GetSingleWordInOperand(index);
4488 
4489     const auto* current_type_inst = def_use_mgr_->GetDef(current_type_id);
4490     if (!current_type_inst) {
4491       Fail() << "composite type %" << current_type_id
4492              << " is invalid after following " << (index - index_start)
4493              << " indices: " << inst.PrettyPrint();
4494       return {};
4495     }
4496     const char* operation_name = nullptr;
4497     switch (inst.opcode()) {
4498       case SpvOpCompositeExtract:
4499         operation_name = "OpCompositeExtract";
4500         break;
4501       case SpvOpCompositeInsert:
4502         operation_name = "OpCompositeInsert";
4503         break;
4504       default:
4505         Fail() << "internal error: unhandled " << inst.PrettyPrint();
4506         return {};
4507     }
4508     const ast::Expression* next_expr = nullptr;
4509     switch (current_type_inst->opcode()) {
4510       case SpvOpTypeVector: {
4511         // Try generating a MemberAccessor expression. That result in something
4512         // like  "foo.z", which is more idiomatic than "foo[2]".
4513         const auto num_elems = current_type_inst->GetSingleWordInOperand(1);
4514         if (num_elems <= index_val) {
4515           Fail() << operation_name << " %" << inst.result_id()
4516                  << " index value " << index_val
4517                  << " is out of bounds for vector of " << num_elems
4518                  << " elements";
4519           return {};
4520         }
4521         if (index_val >= kMaxVectorLen) {
4522           Fail() << "internal error: swizzle index " << index_val
4523                  << " is too big. Max handled index is " << kMaxVectorLen - 1;
4524           return {};
4525         }
4526         next_expr = create<ast::MemberAccessorExpression>(
4527             Source{}, current_expr.expr, Swizzle(index_val));
4528         // All vector components are the same type.
4529         current_type_id = current_type_inst->GetSingleWordInOperand(0);
4530         break;
4531       }
4532       case SpvOpTypeMatrix: {
4533         // Check bounds
4534         const auto num_elems = current_type_inst->GetSingleWordInOperand(1);
4535         if (num_elems <= index_val) {
4536           Fail() << operation_name << " %" << inst.result_id()
4537                  << " index value " << index_val
4538                  << " is out of bounds for matrix of " << num_elems
4539                  << " elements";
4540           return {};
4541         }
4542         if (index_val >= kMaxVectorLen) {
4543           Fail() << "internal error: swizzle index " << index_val
4544                  << " is too big. Max handled index is " << kMaxVectorLen - 1;
4545         }
4546         // Use array syntax.
4547         next_expr = create<ast::IndexAccessorExpression>(
4548             Source{}, current_expr.expr, make_index(index_val));
4549         // All matrix components are the same type.
4550         current_type_id = current_type_inst->GetSingleWordInOperand(0);
4551         break;
4552       }
4553       case SpvOpTypeArray:
4554         // The array size could be a spec constant, and so it's not always
4555         // statically checkable.  Instead, rely on a runtime index clamp
4556         // or runtime check to keep this safe.
4557         next_expr = create<ast::IndexAccessorExpression>(
4558             Source{}, current_expr.expr, make_index(index_val));
4559         current_type_id = current_type_inst->GetSingleWordInOperand(0);
4560         break;
4561       case SpvOpTypeRuntimeArray:
4562         Fail() << "can't do " << operation_name
4563                << " on a runtime array: " << inst.PrettyPrint();
4564         return {};
4565       case SpvOpTypeStruct: {
4566         const auto num_members = current_type_inst->NumInOperands();
4567         if (num_members <= index_val) {
4568           Fail() << operation_name << " %" << inst.result_id()
4569                  << " index value " << index_val
4570                  << " is out of bounds for structure %" << current_type_id
4571                  << " having " << num_members << " members";
4572           return {};
4573         }
4574         auto name = namer_.GetMemberName(current_type_id, uint32_t(index_val));
4575         auto* member_access = create<ast::IdentifierExpression>(
4576             Source{}, builder_.Symbols().Register(name));
4577 
4578         next_expr = create<ast::MemberAccessorExpression>(
4579             Source{}, current_expr.expr, member_access);
4580         current_type_id = current_type_inst->GetSingleWordInOperand(index_val);
4581         break;
4582       }
4583       default:
4584         Fail() << operation_name << " with bad type %" << current_type_id
4585                << ": " << current_type_inst->PrettyPrint();
4586         return {};
4587     }
4588     current_expr =
4589         TypedExpression{parser_impl_.ConvertType(current_type_id), next_expr};
4590   }
4591   return current_expr;
4592 }
4593 
MakeTrue(const Source & source) const4594 const ast::Expression* FunctionEmitter::MakeTrue(const Source& source) const {
4595   return create<ast::BoolLiteralExpression>(source, true);
4596 }
4597 
MakeFalse(const Source & source) const4598 const ast::Expression* FunctionEmitter::MakeFalse(const Source& source) const {
4599   return create<ast::BoolLiteralExpression>(source, false);
4600 }
4601 
MakeVectorShuffle(const spvtools::opt::Instruction & inst)4602 TypedExpression FunctionEmitter::MakeVectorShuffle(
4603     const spvtools::opt::Instruction& inst) {
4604   const auto vec0_id = inst.GetSingleWordInOperand(0);
4605   const auto vec1_id = inst.GetSingleWordInOperand(1);
4606   const spvtools::opt::Instruction& vec0 = *(def_use_mgr_->GetDef(vec0_id));
4607   const spvtools::opt::Instruction& vec1 = *(def_use_mgr_->GetDef(vec1_id));
4608   const auto vec0_len =
4609       type_mgr_->GetType(vec0.type_id())->AsVector()->element_count();
4610   const auto vec1_len =
4611       type_mgr_->GetType(vec1.type_id())->AsVector()->element_count();
4612 
4613   // Idiomatic vector accessors.
4614 
4615   // Generate an ast::TypeConstructor expression.
4616   // Assume the literal indices are valid, and there is a valid number of them.
4617   auto source = GetSourceForInst(inst);
4618   const Vector* result_type =
4619       As<Vector>(parser_impl_.ConvertType(inst.type_id()));
4620   ast::ExpressionList values;
4621   for (uint32_t i = 2; i < inst.NumInOperands(); ++i) {
4622     const auto index = inst.GetSingleWordInOperand(i);
4623     if (index < vec0_len) {
4624       auto expr = MakeExpression(vec0_id);
4625       if (!expr) {
4626         return {};
4627       }
4628       values.emplace_back(create<ast::MemberAccessorExpression>(
4629           source, expr.expr, Swizzle(index)));
4630     } else if (index < vec0_len + vec1_len) {
4631       const auto sub_index = index - vec0_len;
4632       TINT_ASSERT(Reader, sub_index < kMaxVectorLen);
4633       auto expr = MakeExpression(vec1_id);
4634       if (!expr) {
4635         return {};
4636       }
4637       values.emplace_back(create<ast::MemberAccessorExpression>(
4638           source, expr.expr, Swizzle(sub_index)));
4639     } else if (index == 0xFFFFFFFF) {
4640       // By rule, this maps to OpUndef.  Instead, make it zero.
4641       values.emplace_back(parser_impl_.MakeNullValue(result_type->type));
4642     } else {
4643       Fail() << "invalid vectorshuffle ID %" << inst.result_id()
4644              << ": index too large: " << index;
4645       return {};
4646     }
4647   }
4648   return {result_type,
4649           builder_.Construct(source, result_type->Build(builder_), values)};
4650 }
4651 
RegisterSpecialBuiltInVariables()4652 bool FunctionEmitter::RegisterSpecialBuiltInVariables() {
4653   size_t index = def_info_.size();
4654   for (auto& special_var : parser_impl_.special_builtins()) {
4655     const auto id = special_var.first;
4656     const auto builtin = special_var.second;
4657     const auto* var = def_use_mgr_->GetDef(id);
4658     def_info_[id] = std::make_unique<DefInfo>(*var, 0, index);
4659     ++index;
4660     auto& def = def_info_[id];
4661     switch (builtin) {
4662       case SpvBuiltInPointSize:
4663         def->skip = SkipReason::kPointSizeBuiltinPointer;
4664         break;
4665       case SpvBuiltInSampleMask: {
4666         // Distinguish between input and output variable.
4667         const auto storage_class =
4668             static_cast<SpvStorageClass>(var->GetSingleWordInOperand(0));
4669         if (storage_class == SpvStorageClassInput) {
4670           sample_mask_in_id = id;
4671           def->skip = SkipReason::kSampleMaskInBuiltinPointer;
4672         } else {
4673           sample_mask_out_id = id;
4674           def->skip = SkipReason::kSampleMaskOutBuiltinPointer;
4675         }
4676         break;
4677       }
4678       case SpvBuiltInSampleId:
4679       case SpvBuiltInInstanceIndex:
4680       case SpvBuiltInVertexIndex:
4681       case SpvBuiltInLocalInvocationIndex:
4682       case SpvBuiltInLocalInvocationId:
4683       case SpvBuiltInGlobalInvocationId:
4684       case SpvBuiltInWorkgroupId:
4685       case SpvBuiltInNumWorkgroups:
4686         break;
4687       default:
4688         return Fail() << "unrecognized special builtin: " << int(builtin);
4689     }
4690   }
4691   return true;
4692 }
4693 
RegisterLocallyDefinedValues()4694 bool FunctionEmitter::RegisterLocallyDefinedValues() {
4695   // Create a DefInfo for each value definition in this function.
4696   size_t index = def_info_.size();
4697   for (auto block_id : block_order_) {
4698     const auto* block_info = GetBlockInfo(block_id);
4699     const auto block_pos = block_info->pos;
4700     for (const auto& inst : *(block_info->basic_block)) {
4701       const auto result_id = inst.result_id();
4702       if ((result_id == 0) || inst.opcode() == SpvOpLabel) {
4703         continue;
4704       }
4705       def_info_[result_id] = std::make_unique<DefInfo>(inst, block_pos, index);
4706       ++index;
4707       auto& info = def_info_[result_id];
4708 
4709       // Determine storage class for pointer values. Do this in order because
4710       // we might rely on the storage class for a previously-visited definition.
4711       // Logical pointers can't be transmitted through OpPhi, so remaining
4712       // pointer definitions are SSA values, and their definitions must be
4713       // visited before their uses.
4714       const auto* type = type_mgr_->GetType(inst.type_id());
4715       if (type) {
4716         if (type->AsPointer()) {
4717           if (auto* ast_type = parser_impl_.ConvertType(inst.type_id())) {
4718             if (auto* ptr = ast_type->As<Pointer>()) {
4719               info->storage_class = ptr->storage_class;
4720             }
4721           }
4722           switch (inst.opcode()) {
4723             case SpvOpUndef:
4724               return Fail()
4725                      << "undef pointer is not valid: " << inst.PrettyPrint();
4726             case SpvOpVariable:
4727               // Keep the default decision based on the result type.
4728               break;
4729             case SpvOpAccessChain:
4730             case SpvOpInBoundsAccessChain:
4731             case SpvOpCopyObject:
4732               // Inherit from the first operand. We need this so we can pick up
4733               // a remapped storage buffer.
4734               info->storage_class = GetStorageClassForPointerValue(
4735                   inst.GetSingleWordInOperand(0));
4736               break;
4737             default:
4738               return Fail()
4739                      << "pointer defined in function from unknown opcode: "
4740                      << inst.PrettyPrint();
4741           }
4742         }
4743         auto* unwrapped = type;
4744         while (auto* ptr = unwrapped->AsPointer()) {
4745           unwrapped = ptr->pointee_type();
4746         }
4747         if (unwrapped->AsSampler() || unwrapped->AsImage() ||
4748             unwrapped->AsSampledImage()) {
4749           // Defer code generation until the instruction that actually acts on
4750           // the image.
4751           info->skip = SkipReason::kOpaqueObject;
4752         }
4753       }
4754     }
4755   }
4756   return true;
4757 }
4758 
GetStorageClassForPointerValue(uint32_t id)4759 ast::StorageClass FunctionEmitter::GetStorageClassForPointerValue(uint32_t id) {
4760   auto where = def_info_.find(id);
4761   if (where != def_info_.end()) {
4762     auto candidate = where->second.get()->storage_class;
4763     if (candidate != ast::StorageClass::kInvalid) {
4764       return candidate;
4765     }
4766   }
4767   const auto type_id = def_use_mgr_->GetDef(id)->type_id();
4768   if (type_id) {
4769     auto* ast_type = parser_impl_.ConvertType(type_id);
4770     if (auto* ptr = As<Pointer>(ast_type)) {
4771       return ptr->storage_class;
4772     }
4773   }
4774   return ast::StorageClass::kInvalid;
4775 }
4776 
// If |type| is a pointer whose storage class differs from the one recorded
// for the value |result_id|, returns the same pointer type in the recorded
// storage class. This remaps an old-style storage buffer pointer to a
// new-style one. Otherwise returns |type| unchanged.
// Never remaps to ast::StorageClass::kInvalid: if no storage class can be
// determined for |result_id|, the original type is kept.
const Type* FunctionEmitter::RemapStorageClass(const Type* type,
                                               uint32_t result_id) {
  auto* ast_ptr_type = As<Pointer>(type);
  if (ast_ptr_type == nullptr) {
    return type;
  }
  const auto sc = GetStorageClassForPointerValue(result_id);
  if (sc == ast::StorageClass::kInvalid || sc == ast_ptr_type->storage_class) {
    return type;
  }
  return ty_.Pointer(ast_ptr_type->type, sc);
}
4789 
FindValuesNeedingNamedOrHoistedDefinition()4790 void FunctionEmitter::FindValuesNeedingNamedOrHoistedDefinition() {
4791   // Mark vector operands of OpVectorShuffle as needing a named definition,
4792   // but only if they are defined in this function as well.
4793   auto require_named_const_def = [&](const spvtools::opt::Instruction& inst,
4794                                      int in_operand_index) {
4795     const auto id = inst.GetSingleWordInOperand(in_operand_index);
4796     auto* const operand_def = GetDefInfo(id);
4797     if (operand_def) {
4798       operand_def->requires_named_const_def = true;
4799     }
4800   };
4801   for (auto& id_def_info_pair : def_info_) {
4802     const auto& inst = id_def_info_pair.second->inst;
4803     const auto opcode = inst.opcode();
4804     if ((opcode == SpvOpVectorShuffle) || (opcode == SpvOpOuterProduct)) {
4805       // We might access the vector operands multiple times. Make sure they
4806       // are evaluated only once.
4807       require_named_const_def(inst, 0);
4808       require_named_const_def(inst, 1);
4809     }
4810     if (parser_impl_.IsGlslExtendedInstruction(inst)) {
4811       // Some emulations of GLSLstd450 instructions evaluate certain operands
4812       // multiple times. Ensure their expressions are evaluated only once.
4813       switch (inst.GetSingleWordInOperand(1)) {
4814         case GLSLstd450FaceForward:
4815           // The "normal" operand expression is used twice in code generation.
4816           require_named_const_def(inst, 2);
4817           break;
4818         case GLSLstd450Reflect:
4819           require_named_const_def(inst, 2);  // Incident
4820           require_named_const_def(inst, 3);  // Normal
4821           break;
4822         default:
4823           break;
4824       }
4825     }
4826   }
4827 
4828   // Scan uses of locally defined IDs, in function block order.
4829   for (auto block_id : block_order_) {
4830     const auto* block_info = GetBlockInfo(block_id);
4831     const auto block_pos = block_info->pos;
4832     for (const auto& inst : *(block_info->basic_block)) {
4833       // Update bookkeeping for locally-defined IDs used by this instruction.
4834       inst.ForEachInId([this, block_pos, block_info](const uint32_t* id_ptr) {
4835         auto* def_info = GetDefInfo(*id_ptr);
4836         if (def_info) {
4837           // Update usage count.
4838           def_info->num_uses++;
4839           // Update usage span.
4840           def_info->last_use_pos = std::max(def_info->last_use_pos, block_pos);
4841 
4842           // Determine whether this ID is defined in a different construct
4843           // from this use.
4844           const auto defining_block = block_order_[def_info->block_pos];
4845           const auto* def_in_construct =
4846               GetBlockInfo(defining_block)->construct;
4847           if (def_in_construct != block_info->construct) {
4848             def_info->used_in_another_construct = true;
4849           }
4850         }
4851       });
4852 
4853       if (inst.opcode() == SpvOpPhi) {
4854         // Declare a name for the variable used to carry values to a phi.
4855         const auto phi_id = inst.result_id();
4856         auto* phi_def_info = GetDefInfo(phi_id);
4857         phi_def_info->phi_var =
4858             namer_.MakeDerivedName(namer_.Name(phi_id) + "_phi");
4859         // Track all the places where we need to mention the variable,
4860         // so we can place its declaration.  First, record the location of
4861         // the read from the variable.
4862         uint32_t first_pos = block_pos;
4863         uint32_t last_pos = block_pos;
4864         // Record the assignments that will propagate values from predecessor
4865         // blocks.
4866         for (uint32_t i = 0; i + 1 < inst.NumInOperands(); i += 2) {
4867           const uint32_t value_id = inst.GetSingleWordInOperand(i);
4868           const uint32_t pred_block_id = inst.GetSingleWordInOperand(i + 1);
4869           auto* pred_block_info = GetBlockInfo(pred_block_id);
4870           // The predecessor might not be in the block order at all, so we
4871           // need this guard.
4872           if (IsInBlockOrder(pred_block_info)) {
4873             // Record the assignment that needs to occur at the end
4874             // of the predecessor block.
4875             pred_block_info->phi_assignments.push_back({phi_id, value_id});
4876             first_pos = std::min(first_pos, pred_block_info->pos);
4877             last_pos = std::max(last_pos, pred_block_info->pos);
4878           }
4879         }
4880 
4881         // Schedule the declaration of the state variable.
4882         const auto* enclosing_construct =
4883             GetEnclosingScope(first_pos, last_pos);
4884         GetBlockInfo(enclosing_construct->begin_id)
4885             ->phis_needing_state_vars.push_back(phi_id);
4886       }
4887     }
4888   }
4889 
4890   // For an ID defined in this function, determine if its evaluation and
4891   // potential declaration needs special handling:
4892   // - Compensate for the fact that dominance does not map directly to scope.
4893   //   A definition could dominate its use, but a named definition in WGSL
4894   //   at the location of the definition could go out of scope by the time
4895   //   you reach the use.  In that case, we hoist the definition to a basic
4896   //   block at the smallest scope enclosing both the definition and all
4897   //   its uses.
4898   // - If value is used in a different construct than its definition, then it
4899   //   needs a named constant definition.  Otherwise we might sink an
4900   //   expensive computation into control flow, and hence change performance.
4901   for (auto& id_def_info_pair : def_info_) {
4902     const auto def_id = id_def_info_pair.first;
4903     auto* def_info = id_def_info_pair.second.get();
4904     if (def_info->num_uses == 0) {
4905       // There is no need to adjust the location of the declaration.
4906       continue;
4907     }
4908     // The first use must be the at the SSA definition, because block order
4909     // respects dominance.
4910     const auto first_pos = def_info->block_pos;
4911     const auto last_use_pos = def_info->last_use_pos;
4912 
4913     const auto* def_in_construct =
4914         GetBlockInfo(block_order_[first_pos])->construct;
4915     // A definition in the first block of an kIfSelection or kSwitchSelection
4916     // occurs before the branch, and so that definition should count as
4917     // having been defined at the scope of the parent construct.
4918     if (first_pos == def_in_construct->begin_pos) {
4919       if ((def_in_construct->kind == Construct::kIfSelection) ||
4920           (def_in_construct->kind == Construct::kSwitchSelection)) {
4921         def_in_construct = def_in_construct->parent;
4922       }
4923     }
4924 
4925     bool should_hoist = false;
4926     if (!def_in_construct->ContainsPos(last_use_pos)) {
4927       // To satisfy scoping, we have to hoist the definition out to an enclosing
4928       // construct.
4929       should_hoist = true;
4930     } else {
4931       // Avoid moving combinatorial values across constructs.  This is a
4932       // simple heuristic to avoid changing the cost of an operation
4933       // by moving it into or out of a loop, for example.
4934       if ((def_info->storage_class == ast::StorageClass::kInvalid) &&
4935           def_info->used_in_another_construct) {
4936         should_hoist = true;
4937       }
4938     }
4939 
4940     if (should_hoist) {
4941       const auto* enclosing_construct =
4942           GetEnclosingScope(first_pos, last_use_pos);
4943       if (enclosing_construct == def_in_construct) {
4944         // We can use a plain 'const' definition.
4945         def_info->requires_named_const_def = true;
4946       } else {
4947         // We need to make a hoisted variable definition.
4948         // TODO(dneto): Handle non-storable types, particularly pointers.
4949         def_info->requires_hoisted_def = true;
4950         auto* hoist_to_block = GetBlockInfo(enclosing_construct->begin_id);
4951         hoist_to_block->hoisted_ids.push_back(def_id);
4952       }
4953     }
4954   }
4955 }
4956 
GetEnclosingScope(uint32_t first_pos,uint32_t last_pos) const4957 const Construct* FunctionEmitter::GetEnclosingScope(uint32_t first_pos,
4958                                                     uint32_t last_pos) const {
4959   const auto* enclosing_construct =
4960       GetBlockInfo(block_order_[first_pos])->construct;
4961   TINT_ASSERT(Reader, enclosing_construct != nullptr);
4962   // Constructs are strictly nesting, so follow parent pointers
4963   while (enclosing_construct &&
4964          !enclosing_construct->ScopeContainsPos(last_pos)) {
4965     // The scope of a continue construct is enclosed in its associated loop
4966     // construct, but they are siblings in our construct tree.
4967     const auto* sibling_loop = SiblingLoopConstruct(enclosing_construct);
4968     // Go to the sibling loop if it exists, otherwise walk up to the parent.
4969     enclosing_construct =
4970         sibling_loop ? sibling_loop : enclosing_construct->parent;
4971   }
4972   // At worst, we go all the way out to the function construct.
4973   TINT_ASSERT(Reader, enclosing_construct != nullptr);
4974   return enclosing_construct;
4975 }
4976 
MakeNumericConversion(const spvtools::opt::Instruction & inst)4977 TypedExpression FunctionEmitter::MakeNumericConversion(
4978     const spvtools::opt::Instruction& inst) {
4979   const auto opcode = inst.opcode();
4980   auto* requested_type = parser_impl_.ConvertType(inst.type_id());
4981   auto arg_expr = MakeOperand(inst, 0);
4982   if (!arg_expr) {
4983     return {};
4984   }
4985   arg_expr.type = arg_expr.type->UnwrapRef();
4986 
4987   const Type* expr_type = nullptr;
4988   if ((opcode == SpvOpConvertSToF) || (opcode == SpvOpConvertUToF)) {
4989     if (arg_expr.type->IsIntegerScalarOrVector()) {
4990       expr_type = requested_type;
4991     } else {
4992       Fail() << "operand for conversion to floating point must be integral "
4993                 "scalar or vector: "
4994              << inst.PrettyPrint();
4995     }
4996   } else if (inst.opcode() == SpvOpConvertFToU) {
4997     if (arg_expr.type->IsFloatScalarOrVector()) {
4998       expr_type = parser_impl_.GetUnsignedIntMatchingShape(arg_expr.type);
4999     } else {
5000       Fail() << "operand for conversion to unsigned integer must be floating "
5001                 "point scalar or vector: "
5002              << inst.PrettyPrint();
5003     }
5004   } else if (inst.opcode() == SpvOpConvertFToS) {
5005     if (arg_expr.type->IsFloatScalarOrVector()) {
5006       expr_type = parser_impl_.GetSignedIntMatchingShape(arg_expr.type);
5007     } else {
5008       Fail() << "operand for conversion to signed integer must be floating "
5009                 "point scalar or vector: "
5010              << inst.PrettyPrint();
5011     }
5012   }
5013   if (expr_type == nullptr) {
5014     // The diagnostic has already been emitted.
5015     return {};
5016   }
5017 
5018   ast::ExpressionList params;
5019   params.push_back(arg_expr.expr);
5020   TypedExpression result{
5021       expr_type,
5022       builder_.Construct(GetSourceForInst(inst), expr_type->Build(builder_),
5023                          std::move(params))};
5024 
5025   if (requested_type == expr_type) {
5026     return result;
5027   }
5028   return {requested_type, create<ast::BitcastExpression>(
5029                               GetSourceForInst(inst),
5030                               requested_type->Build(builder_), result.expr)};
5031 }
5032 
EmitFunctionCall(const spvtools::opt::Instruction & inst)5033 bool FunctionEmitter::EmitFunctionCall(const spvtools::opt::Instruction& inst) {
5034   // We ignore function attributes such as Inline, DontInline, Pure, Const.
5035   auto name = namer_.Name(inst.GetSingleWordInOperand(0));
5036   auto* function = create<ast::IdentifierExpression>(
5037       Source{}, builder_.Symbols().Register(name));
5038 
5039   ast::ExpressionList args;
5040   for (uint32_t iarg = 1; iarg < inst.NumInOperands(); ++iarg) {
5041     auto expr = MakeOperand(inst, iarg);
5042     if (!expr) {
5043       return false;
5044     }
5045     // Functions cannot use references as parameters, so we need to pass by
5046     // pointer if the operand is of pointer type.
5047     expr = AddressOfIfNeeded(
5048         expr, def_use_mgr_->GetDef(inst.GetSingleWordInOperand(iarg)));
5049     args.emplace_back(expr.expr);
5050   }
5051   if (failed()) {
5052     return false;
5053   }
5054   auto* call_expr =
5055       create<ast::CallExpression>(Source{}, function, std::move(args));
5056   auto* result_type = parser_impl_.ConvertType(inst.type_id());
5057   if (!result_type) {
5058     return Fail() << "internal error: no mapped type result of call: "
5059                   << inst.PrettyPrint();
5060   }
5061 
5062   if (result_type->Is<Void>()) {
5063     return nullptr !=
5064            AddStatement(create<ast::CallStatement>(Source{}, call_expr));
5065   }
5066 
5067   return EmitConstDefOrWriteToHoistedVar(inst, {result_type, call_expr});
5068 }
5069 
EmitControlBarrier(const spvtools::opt::Instruction & inst)5070 bool FunctionEmitter::EmitControlBarrier(
5071     const spvtools::opt::Instruction& inst) {
5072   uint32_t operands[3];
5073   for (int i = 0; i < 3; i++) {
5074     auto id = inst.GetSingleWordInOperand(i);
5075     if (auto* constant = constant_mgr_->FindDeclaredConstant(id)) {
5076       operands[i] = constant->GetU32();
5077     } else {
5078       return Fail() << "invalid or missing operands for control barrier";
5079     }
5080   }
5081 
5082   uint32_t execution = operands[0];
5083   uint32_t memory = operands[1];
5084   uint32_t semantics = operands[2];
5085 
5086   if (execution != SpvScopeWorkgroup) {
5087     return Fail() << "unsupported control barrier execution scope: "
5088                   << "expected Workgroup (2), got: " << execution;
5089   }
5090   if (semantics & SpvMemorySemanticsAcquireReleaseMask) {
5091     semantics &= ~SpvMemorySemanticsAcquireReleaseMask;
5092   } else {
5093     return Fail() << "control barrier semantics requires acquire and release";
5094   }
5095   if (semantics & SpvMemorySemanticsWorkgroupMemoryMask) {
5096     if (memory != SpvScopeWorkgroup) {
5097       return Fail() << "workgroupBarrier requires workgroup memory scope";
5098     }
5099     AddStatement(create<ast::CallStatement>(builder_.Call("workgroupBarrier")));
5100     semantics &= ~SpvMemorySemanticsWorkgroupMemoryMask;
5101   }
5102   if (semantics & SpvMemorySemanticsUniformMemoryMask) {
5103     if (memory != SpvScopeDevice) {
5104       return Fail() << "storageBarrier requires device memory scope";
5105     }
5106     AddStatement(create<ast::CallStatement>(builder_.Call("storageBarrier")));
5107     semantics &= ~SpvMemorySemanticsUniformMemoryMask;
5108   }
5109   if (semantics) {
5110     return Fail() << "unsupported control barrier semantics: " << semantics;
5111   }
5112   return true;
5113 }
5114 
MakeIntrinsicCall(const spvtools::opt::Instruction & inst)5115 TypedExpression FunctionEmitter::MakeIntrinsicCall(
5116     const spvtools::opt::Instruction& inst) {
5117   const auto intrinsic = GetIntrinsic(inst.opcode());
5118   auto* name = sem::str(intrinsic);
5119   auto* ident = create<ast::IdentifierExpression>(
5120       Source{}, builder_.Symbols().Register(name));
5121 
5122   ast::ExpressionList params;
5123   const Type* first_operand_type = nullptr;
5124   for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) {
5125     TypedExpression operand = MakeOperand(inst, iarg);
5126     if (first_operand_type == nullptr) {
5127       first_operand_type = operand.type;
5128     }
5129     params.emplace_back(operand.expr);
5130   }
5131   auto* call_expr =
5132       create<ast::CallExpression>(Source{}, ident, std::move(params));
5133   auto* result_type = parser_impl_.ConvertType(inst.type_id());
5134   if (!result_type) {
5135     Fail() << "internal error: no mapped type result of call: "
5136            << inst.PrettyPrint();
5137     return {};
5138   }
5139   TypedExpression call{result_type, call_expr};
5140   return parser_impl_.RectifyForcedResultType(call, inst, first_operand_type);
5141 }
5142 
MakeSimpleSelect(const spvtools::opt::Instruction & inst)5143 TypedExpression FunctionEmitter::MakeSimpleSelect(
5144     const spvtools::opt::Instruction& inst) {
5145   auto condition = MakeOperand(inst, 0);
5146   auto true_value = MakeOperand(inst, 1);
5147   auto false_value = MakeOperand(inst, 2);
5148 
5149   // SPIR-V validation requires:
5150   // - the condition to be bool or bool vector, so we don't check it here.
5151   // - true_value false_value, and result type to match.
5152   // - you can't select over pointers or pointer vectors, unless you also have
5153   //   a VariablePointers* capability, which is not allowed in by WebGPU.
5154   auto* op_ty = true_value.type;
5155   if (op_ty->Is<Vector>() || op_ty->IsFloatScalar() ||
5156       op_ty->IsIntegerScalar() || op_ty->Is<Bool>()) {
5157     ast::ExpressionList params;
5158     params.push_back(false_value.expr);
5159     params.push_back(true_value.expr);
5160     // The condition goes last.
5161     params.push_back(condition.expr);
5162     return {op_ty, create<ast::CallExpression>(
5163                        Source{},
5164                        create<ast::IdentifierExpression>(
5165                            Source{}, builder_.Symbols().Register("select")),
5166                        std::move(params))};
5167   }
5168   return {};
5169 }
5170 
GetSourceForInst(const spvtools::opt::Instruction & inst) const5171 Source FunctionEmitter::GetSourceForInst(
5172     const spvtools::opt::Instruction& inst) const {
5173   return parser_impl_.GetSourceForInst(&inst);
5174 }
5175 
// Finds the memory object declaration of the image used by an image access
// or image query instruction.
//
// The first in-operand names the image or sampled image. It is resolved
// through ParserImpl::GetMemoryObjectDeclarationForHandle with `true`, which
// presumably selects the image rather than the sampler (GetSamplerExpression
// passes `false`).
//
// @param inst the image instruction
// @returns the declaring instruction, or nullptr after reporting a failure
const spvtools::opt::Instruction* FunctionEmitter::GetImage(
    const spvtools::opt::Instruction& inst) {
  const bool has_operands = inst.NumInOperands() > 0;
  if (!has_operands) {
    Fail() << "not an image access instruction: " << inst.PrettyPrint();
    return nullptr;
  }
  const auto* declaration = parser_impl_.GetMemoryObjectDeclarationForHandle(
      inst.GetSingleWordInOperand(0), true);
  if (declaration == nullptr) {
    Fail() << "internal error: couldn't find image for " << inst.PrettyPrint();
  }
  return declaration;
}
5192 
GetImageType(const spvtools::opt::Instruction & image)5193 const Texture* FunctionEmitter::GetImageType(
5194     const spvtools::opt::Instruction& image) {
5195   const Pointer* ptr_type = parser_impl_.GetTypeForHandleVar(image);
5196   if (!parser_impl_.success()) {
5197     Fail();
5198     return {};
5199   }
5200   if (!ptr_type) {
5201     Fail() << "invalid texture type for " << image.PrettyPrint();
5202     return {};
5203   }
5204   auto* result = ptr_type->type->UnwrapAll()->As<Texture>();
5205   if (!result) {
5206     Fail() << "invalid texture type for " << image.PrettyPrint();
5207     return {};
5208   }
5209   return result;
5210 }
5211 
GetImageExpression(const spvtools::opt::Instruction & inst)5212 const ast::Expression* FunctionEmitter::GetImageExpression(
5213     const spvtools::opt::Instruction& inst) {
5214   auto* image = GetImage(inst);
5215   if (!image) {
5216     return nullptr;
5217   }
5218   auto name = namer_.Name(image->result_id());
5219   return create<ast::IdentifierExpression>(GetSourceForInst(inst),
5220                                            builder_.Symbols().Register(name));
5221 }
5222 
GetSamplerExpression(const spvtools::opt::Instruction & inst)5223 const ast::Expression* FunctionEmitter::GetSamplerExpression(
5224     const spvtools::opt::Instruction& inst) {
5225   // The sampled image operand is always the first operand.
5226   const auto image_or_sampled_image_operand_id = inst.GetSingleWordInOperand(0);
5227   const auto* image = parser_impl_.GetMemoryObjectDeclarationForHandle(
5228       image_or_sampled_image_operand_id, false);
5229   if (!image) {
5230     Fail() << "internal error: couldn't find sampler for "
5231            << inst.PrettyPrint();
5232     return nullptr;
5233   }
5234   auto name = namer_.Name(image->result_id());
5235   return create<ast::IdentifierExpression>(GetSourceForInst(inst),
5236                                            builder_.Symbols().Register(name));
5237 }
5238 
EmitImageAccess(const spvtools::opt::Instruction & inst)5239 bool FunctionEmitter::EmitImageAccess(const spvtools::opt::Instruction& inst) {
5240   ast::ExpressionList params;
5241   const auto opcode = inst.opcode();
5242 
5243   // Form the texture operand.
5244   const spvtools::opt::Instruction* image = GetImage(inst);
5245   if (!image) {
5246     return false;
5247   }
5248   params.push_back(GetImageExpression(inst));
5249 
5250   if (IsSampledImageAccess(opcode)) {
5251     // Form the sampler operand.
5252     if (auto* sampler = GetSamplerExpression(inst)) {
5253       params.push_back(sampler);
5254     } else {
5255       return false;
5256     }
5257   }
5258 
5259   const Pointer* texture_ptr_type = parser_impl_.GetTypeForHandleVar(*image);
5260   if (!texture_ptr_type) {
5261     return Fail();
5262   }
5263   const Texture* texture_type =
5264       texture_ptr_type->type->UnwrapAll()->As<Texture>();
5265 
5266   if (!texture_type) {
5267     return Fail();
5268   }
5269 
5270   // This is the SPIR-V operand index.  We're done with the first operand.
5271   uint32_t arg_index = 1;
5272 
5273   // Push the coordinates operands.
5274   auto coords = MakeCoordinateOperandsForImageAccess(inst);
5275   if (coords.empty()) {
5276     return false;
5277   }
5278   params.insert(params.end(), coords.begin(), coords.end());
5279   // Skip the coordinates operand.
5280   arg_index++;
5281 
5282   const auto num_args = inst.NumInOperands();
5283 
5284   std::string builtin_name;
5285   bool use_level_of_detail_suffix = true;
5286   bool is_dref_sample = false;
5287   bool is_non_dref_sample = false;
5288   switch (opcode) {
5289     case SpvOpImageSampleImplicitLod:
5290     case SpvOpImageSampleExplicitLod:
5291     case SpvOpImageSampleProjImplicitLod:
5292     case SpvOpImageSampleProjExplicitLod:
5293       is_non_dref_sample = true;
5294       builtin_name = "textureSample";
5295       break;
5296     case SpvOpImageSampleDrefImplicitLod:
5297     case SpvOpImageSampleDrefExplicitLod:
5298     case SpvOpImageSampleProjDrefImplicitLod:
5299     case SpvOpImageSampleProjDrefExplicitLod:
5300       is_dref_sample = true;
5301       builtin_name = "textureSampleCompare";
5302       if (arg_index < num_args) {
5303         params.push_back(MakeOperand(inst, arg_index).expr);
5304         arg_index++;
5305       } else {
5306         return Fail()
5307                << "image depth-compare instruction is missing a Dref operand: "
5308                << inst.PrettyPrint();
5309       }
5310       break;
5311     case SpvOpImageGather:
5312     case SpvOpImageDrefGather:
5313       return Fail() << " image gather is not yet supported";
5314     case SpvOpImageFetch:
5315     case SpvOpImageRead:
5316       // Read a single texel from a sampled or storage image.
5317       builtin_name = "textureLoad";
5318       use_level_of_detail_suffix = false;
5319       break;
5320     case SpvOpImageWrite:
5321       builtin_name = "textureStore";
5322       use_level_of_detail_suffix = false;
5323       if (arg_index < num_args) {
5324         auto texel = MakeOperand(inst, arg_index);
5325         auto* converted_texel =
5326             ConvertTexelForStorage(inst, texel, texture_type);
5327         if (!converted_texel) {
5328           return false;
5329         }
5330 
5331         params.push_back(converted_texel);
5332         arg_index++;
5333       } else {
5334         return Fail() << "image write is missing a Texel operand: "
5335                       << inst.PrettyPrint();
5336       }
5337       break;
5338     default:
5339       return Fail() << "internal error: unrecognized image access: "
5340                     << inst.PrettyPrint();
5341   }
5342 
5343   // Loop over the image operands, looking for extra operands to the builtin.
5344   // Except we uroll the loop.
5345   uint32_t image_operands_mask = 0;
5346   if (arg_index < num_args) {
5347     image_operands_mask = inst.GetSingleWordInOperand(arg_index);
5348     arg_index++;
5349   }
5350   if (arg_index < num_args &&
5351       (image_operands_mask & SpvImageOperandsBiasMask)) {
5352     if (is_dref_sample) {
5353       return Fail() << "WGSL does not support depth-reference sampling with "
5354                        "level-of-detail bias: "
5355                     << inst.PrettyPrint();
5356     }
5357     builtin_name += "Bias";
5358     params.push_back(MakeOperand(inst, arg_index).expr);
5359     image_operands_mask ^= SpvImageOperandsBiasMask;
5360     arg_index++;
5361   }
5362   if (arg_index < num_args && (image_operands_mask & SpvImageOperandsLodMask)) {
5363     if (use_level_of_detail_suffix) {
5364       builtin_name += "Level";
5365     }
5366     if (is_dref_sample) {
5367       // Metal only supports Lod = 0 for comparison sampling without
5368       // derivatives.
5369       if (!IsFloatZero(inst.GetSingleWordInOperand(arg_index))) {
5370         return Fail() << "WGSL comparison sampling without derivatives "
5371                          "requires level-of-detail 0.0"
5372                       << inst.PrettyPrint();
5373       }
5374       // Don't generate the Lod argument.
5375     } else {
5376       // Generate the Lod argument.
5377       TypedExpression lod = MakeOperand(inst, arg_index);
5378       // When sampling from a depth texture, the Lod operand must be an I32.
5379       if (texture_type->Is<DepthTexture>()) {
5380         // Convert it to a signed integer type.
5381         lod = ToI32(lod);
5382       }
5383       params.push_back(lod.expr);
5384     }
5385 
5386     image_operands_mask ^= SpvImageOperandsLodMask;
5387     arg_index++;
5388   } else if ((opcode == SpvOpImageFetch || opcode == SpvOpImageRead) &&
5389              !texture_type
5390                   ->IsAnyOf<DepthMultisampledTexture, MultisampledTexture>()) {
5391     // textureLoad requires an explicit level-of-detail parameter for
5392     // non-multisampled texture types.
5393     params.push_back(parser_impl_.MakeNullValue(ty_.I32()));
5394   }
5395   if (arg_index + 1 < num_args &&
5396       (image_operands_mask & SpvImageOperandsGradMask)) {
5397     if (is_dref_sample) {
5398       return Fail() << "WGSL does not support depth-reference sampling with "
5399                        "explicit gradient: "
5400                     << inst.PrettyPrint();
5401     }
5402     builtin_name += "Grad";
5403     params.push_back(MakeOperand(inst, arg_index).expr);
5404     params.push_back(MakeOperand(inst, arg_index + 1).expr);
5405     image_operands_mask ^= SpvImageOperandsGradMask;
5406     arg_index += 2;
5407   }
5408   if (arg_index < num_args &&
5409       (image_operands_mask & SpvImageOperandsConstOffsetMask)) {
5410     if (!IsImageSampling(opcode)) {
5411       return Fail() << "ConstOffset is only permitted for sampling operations: "
5412                     << inst.PrettyPrint();
5413     }
5414     switch (texture_type->dims) {
5415       case ast::TextureDimension::k2d:
5416       case ast::TextureDimension::k2dArray:
5417       case ast::TextureDimension::k3d:
5418         break;
5419       default:
5420         return Fail() << "ConstOffset is only permitted for 2D, 2D Arrayed, "
5421                          "and 3D textures: "
5422                       << inst.PrettyPrint();
5423     }
5424 
5425     params.push_back(ToSignedIfUnsigned(MakeOperand(inst, arg_index)).expr);
5426     image_operands_mask ^= SpvImageOperandsConstOffsetMask;
5427     arg_index++;
5428   }
5429   if (arg_index < num_args &&
5430       (image_operands_mask & SpvImageOperandsSampleMask)) {
5431     // TODO(dneto): only permitted with ImageFetch
5432     params.push_back(ToI32(MakeOperand(inst, arg_index)).expr);
5433     image_operands_mask ^= SpvImageOperandsSampleMask;
5434     arg_index++;
5435   }
5436   if (image_operands_mask) {
5437     return Fail() << "unsupported image operands (" << image_operands_mask
5438                   << "): " << inst.PrettyPrint();
5439   }
5440 
5441   auto* ident = create<ast::IdentifierExpression>(
5442       Source{}, builder_.Symbols().Register(builtin_name));
5443   auto* call_expr =
5444       create<ast::CallExpression>(Source{}, ident, std::move(params));
5445 
5446   if (inst.type_id() != 0) {
5447     // It returns a value.
5448     const ast::Expression* value = call_expr;
5449 
5450     // The result type, derived from the SPIR-V instruction.
5451     auto* result_type = parser_impl_.ConvertType(inst.type_id());
5452     auto* result_component_type = result_type;
5453     if (auto* result_vector_type = As<Vector>(result_type)) {
5454       result_component_type = result_vector_type->type;
5455     }
5456 
5457     // For depth textures, the arity might mot match WGSL:
5458     //  Operation           SPIR-V                     WGSL
5459     //   normal sampling     vec4  ImplicitLod          f32
5460     //   normal sampling     vec4  ExplicitLod          f32
5461     //   compare sample      f32   DrefImplicitLod      f32
5462     //   compare sample      f32   DrefExplicitLod      f32
5463     //   texel load          vec4  ImageFetch           f32
5464     //   normal gather       vec4  ImageGather          vec4 TODO(dneto)
5465     //   dref gather         vec4  ImageFetch           vec4 TODO(dneto)
5466     // Construct a 4-element vector with the result from the builtin in the
5467     // first component.
5468     if (texture_type->IsAnyOf<DepthTexture, DepthMultisampledTexture>()) {
5469       if (is_non_dref_sample || (opcode == SpvOpImageFetch)) {
5470         value = builder_.Construct(
5471             Source{},
5472             result_type->Build(builder_),  // a vec4
5473             ast::ExpressionList{
5474                 value, parser_impl_.MakeNullValue(result_component_type),
5475                 parser_impl_.MakeNullValue(result_component_type),
5476                 parser_impl_.MakeNullValue(result_component_type)});
5477       }
5478     }
5479 
5480     // If necessary, convert the result to the signedness of the instruction
5481     // result type. Compare the SPIR-V image's sampled component type with the
5482     // component of the result type of the SPIR-V instruction.
5483     auto* spirv_image_type =
5484         parser_impl_.GetSpirvTypeForHandleMemoryObjectDeclaration(*image);
5485     if (!spirv_image_type || (spirv_image_type->opcode() != SpvOpTypeImage)) {
5486       return Fail() << "invalid image type for image memory object declaration "
5487                     << image->PrettyPrint();
5488     }
5489     auto* expected_component_type =
5490         parser_impl_.ConvertType(spirv_image_type->GetSingleWordInOperand(0));
5491     if (expected_component_type != result_component_type) {
5492       // This occurs if one is signed integer and the other is unsigned integer,
5493       // or vice versa. Perform a bitcast.
5494       value = create<ast::BitcastExpression>(
5495           Source{}, result_type->Build(builder_), call_expr);
5496     }
5497     if (!expected_component_type->Is<F32>() && IsSampledImageAccess(opcode)) {
5498       // WGSL permits sampled image access only on float textures.
5499       // Reject this case in the SPIR-V reader, at least until SPIR-V validation
5500       // catches up with this rule and can reject it earlier in the workflow.
5501       return Fail() << "sampled image must have float component type";
5502     }
5503 
5504     EmitConstDefOrWriteToHoistedVar(inst, {result_type, value});
5505   } else {
5506     // It's an image write. No value is returned, so make a statement out
5507     // of the call.
5508     AddStatement(create<ast::CallStatement>(Source{}, call_expr));
5509   }
5510   return success();
5511 }
5512 
EmitImageQuery(const spvtools::opt::Instruction & inst)5513 bool FunctionEmitter::EmitImageQuery(const spvtools::opt::Instruction& inst) {
5514   // TODO(dneto): Reject cases that are valid in Vulkan but invalid in WGSL.
5515   const spvtools::opt::Instruction* image = GetImage(inst);
5516   if (!image) {
5517     return false;
5518   }
5519   auto* texture_type = GetImageType(*image);
5520   if (!texture_type) {
5521     return false;
5522   }
5523 
5524   const auto opcode = inst.opcode();
5525   switch (opcode) {
5526     case SpvOpImageQuerySize:
5527     case SpvOpImageQuerySizeLod: {
5528       ast::ExpressionList exprs;
5529       // Invoke textureDimensions.
5530       // If the texture is arrayed, combine with the result from
5531       // textureNumLayers.
5532       auto* dims_ident = create<ast::IdentifierExpression>(
5533           Source{}, builder_.Symbols().Register("textureDimensions"));
5534       ast::ExpressionList dims_args{GetImageExpression(inst)};
5535       if (opcode == SpvOpImageQuerySizeLod) {
5536         dims_args.push_back(ToI32(MakeOperand(inst, 1)).expr);
5537       }
5538       const ast::Expression* dims_call =
5539           create<ast::CallExpression>(Source{}, dims_ident, dims_args);
5540       auto dims = texture_type->dims;
5541       if ((dims == ast::TextureDimension::kCube) ||
5542           (dims == ast::TextureDimension::kCubeArray)) {
5543         // textureDimension returns a 3-element vector but SPIR-V expects 2.
5544         dims_call = create<ast::MemberAccessorExpression>(Source{}, dims_call,
5545                                                           PrefixSwizzle(2));
5546       }
5547       exprs.push_back(dims_call);
5548       if (ast::IsTextureArray(dims)) {
5549         auto* layers_ident = create<ast::IdentifierExpression>(
5550             Source{}, builder_.Symbols().Register("textureNumLayers"));
5551         exprs.push_back(create<ast::CallExpression>(
5552             Source{}, layers_ident,
5553             ast::ExpressionList{GetImageExpression(inst)}));
5554       }
5555       auto* result_type = parser_impl_.ConvertType(inst.type_id());
5556       TypedExpression expr = {
5557           result_type,
5558           builder_.Construct(Source{}, result_type->Build(builder_), exprs)};
5559       return EmitConstDefOrWriteToHoistedVar(inst, expr);
5560     }
5561     case SpvOpImageQueryLod:
5562       return Fail() << "WGSL does not support querying the level of detail of "
5563                        "an image: "
5564                     << inst.PrettyPrint();
5565     case SpvOpImageQueryLevels:
5566     case SpvOpImageQuerySamples: {
5567       const auto* name = (opcode == SpvOpImageQueryLevels)
5568                              ? "textureNumLevels"
5569                              : "textureNumSamples";
5570       auto* levels_ident = create<ast::IdentifierExpression>(
5571           Source{}, builder_.Symbols().Register(name));
5572       const ast::Expression* ast_expr = create<ast::CallExpression>(
5573           Source{}, levels_ident,
5574           ast::ExpressionList{GetImageExpression(inst)});
5575       auto* result_type = parser_impl_.ConvertType(inst.type_id());
5576       // The SPIR-V result type must be integer scalar. The WGSL bulitin
5577       // returns i32. If they aren't the same then convert the result.
5578       if (!result_type->Is<I32>()) {
5579         ast_expr = builder_.Construct(Source{}, result_type->Build(builder_),
5580                                       ast::ExpressionList{ast_expr});
5581       }
5582       TypedExpression expr{result_type, ast_expr};
5583       return EmitConstDefOrWriteToHoistedVar(inst, expr);
5584     }
5585     default:
5586       break;
5587   }
5588   return Fail() << "unhandled image query: " << inst.PrettyPrint();
5589 }
5590 
MakeCoordinateOperandsForImageAccess(const spvtools::opt::Instruction & inst)5591 ast::ExpressionList FunctionEmitter::MakeCoordinateOperandsForImageAccess(
5592     const spvtools::opt::Instruction& inst) {
5593   if (!parser_impl_.success()) {
5594     Fail();
5595     return {};
5596   }
5597   const spvtools::opt::Instruction* image = GetImage(inst);
5598   if (!image) {
5599     return {};
5600   }
5601   if (inst.NumInOperands() < 1) {
5602     Fail() << "image access is missing a coordinate parameter: "
5603            << inst.PrettyPrint();
5604     return {};
5605   }
5606 
5607   // In SPIR-V for Shader, coordinates are:
5608   //  - floating point for sampling, dref sampling, gather, dref gather
5609   //  - integral for fetch, read, write
5610   // In WGSL:
5611   //  - floating point for sampling, dref sampling, gather, dref gather
5612   //  - signed integral for textureLoad, textureStore
5613   //
5614   // The only conversions we have to do for WGSL are:
5615   //  - When the coordinates are unsigned integral, convert them to signed.
5616   //  - Array index is always i32
5617 
5618   // The coordinates parameter is always in position 1.
5619   TypedExpression raw_coords(MakeOperand(inst, 1));
5620   if (!raw_coords) {
5621     return {};
5622   }
5623   const Texture* texture_type = GetImageType(*image);
5624   if (!texture_type) {
5625     return {};
5626   }
5627   ast::TextureDimension dim = texture_type->dims;
5628   // Number of regular coordinates.
5629   uint32_t num_axes = ast::NumCoordinateAxes(dim);
5630   bool is_arrayed = ast::IsTextureArray(dim);
5631   if ((num_axes == 0) || (num_axes > 3)) {
5632     Fail() << "unsupported image dimensionality for "
5633            << texture_type->TypeInfo().name << " prompted by "
5634            << inst.PrettyPrint();
5635   }
5636   bool is_proj = false;
5637   switch (inst.opcode()) {
5638     case SpvOpImageSampleProjImplicitLod:
5639     case SpvOpImageSampleProjExplicitLod:
5640     case SpvOpImageSampleProjDrefImplicitLod:
5641     case SpvOpImageSampleProjDrefExplicitLod:
5642       is_proj = true;
5643       break;
5644     default:
5645       break;
5646   }
5647 
5648   const auto num_coords_required =
5649       num_axes + (is_arrayed ? 1 : 0) + (is_proj ? 1 : 0);
5650   uint32_t num_coords_supplied = 0;
5651   auto* component_type = raw_coords.type;
5652   if (component_type->IsFloatScalar() || component_type->IsIntegerScalar()) {
5653     num_coords_supplied = 1;
5654   } else if (auto* vec_type = As<Vector>(raw_coords.type)) {
5655     component_type = vec_type->type;
5656     num_coords_supplied = vec_type->size;
5657   }
5658   if (num_coords_supplied == 0) {
5659     Fail() << "bad or unsupported coordinate type for image access: "
5660            << inst.PrettyPrint();
5661     return {};
5662   }
5663   if (num_coords_required > num_coords_supplied) {
5664     Fail() << "image access required " << num_coords_required
5665            << " coordinate components, but only " << num_coords_supplied
5666            << " provided, in: " << inst.PrettyPrint();
5667     return {};
5668   }
5669 
5670   ast::ExpressionList result;
5671 
5672   // Generates the expression for the WGSL coordinates, when it is a prefix
5673   // swizzle with num_axes.  If the result would be unsigned, also converts
5674   // it to a signed value of the same shape (scalar or vector).
5675   // Use a lambda to make it easy to only generate the expressions when we
5676   // will actually use them.
5677   auto prefix_swizzle_expr = [this, num_axes, component_type, is_proj,
5678                               raw_coords]() -> const ast::Expression* {
5679     auto* swizzle_type =
5680         (num_axes == 1) ? component_type : ty_.Vector(component_type, num_axes);
5681     auto* swizzle = create<ast::MemberAccessorExpression>(
5682         Source{}, raw_coords.expr, PrefixSwizzle(num_axes));
5683     if (is_proj) {
5684       auto* q = create<ast::MemberAccessorExpression>(Source{}, raw_coords.expr,
5685                                                       Swizzle(num_axes));
5686       auto* proj_div = builder_.Div(swizzle, q);
5687       return ToSignedIfUnsigned({swizzle_type, proj_div}).expr;
5688     } else {
5689       return ToSignedIfUnsigned({swizzle_type, swizzle}).expr;
5690     }
5691   };
5692 
5693   if (is_arrayed) {
5694     // The source must be a vector. It has at least one coordinate component
5695     // and it must have an array component.  Use a vector swizzle to get the
5696     // first `num_axes` components.
5697     result.push_back(prefix_swizzle_expr());
5698 
5699     // Now get the array index.
5700     const ast::Expression* array_index =
5701         builder_.MemberAccessor(raw_coords.expr, Swizzle(num_axes));
5702     if (component_type->IsFloatScalar()) {
5703       // When converting from a float array layer to integer, Vulkan requires
5704       // round-to-nearest, with preference for round-to-nearest-even.
5705       // But i32(f32) in WGSL has unspecified rounding mode, so we have to
5706       // explicitly specify the rounding.
5707       array_index = builder_.Call("round", array_index);
5708     }
5709     // Convert it to a signed integer type, if needed.
5710     result.push_back(ToI32({component_type, array_index}).expr);
5711   } else {
5712     if (num_coords_supplied == num_coords_required && !is_proj) {
5713       // Pass the value through, with possible unsigned->signed conversion.
5714       result.push_back(ToSignedIfUnsigned(raw_coords).expr);
5715     } else {
5716       // There are more coordinates supplied than needed. So the source type
5717       // is a vector. Use a vector swizzle to get the first `num_axes`
5718       // components.
5719       result.push_back(prefix_swizzle_expr());
5720     }
5721   }
5722   return result;
5723 }
5724 
ConvertTexelForStorage(const spvtools::opt::Instruction & inst,TypedExpression texel,const Texture * texture_type)5725 const ast::Expression* FunctionEmitter::ConvertTexelForStorage(
5726     const spvtools::opt::Instruction& inst,
5727     TypedExpression texel,
5728     const Texture* texture_type) {
5729   auto* storage_texture_type = As<StorageTexture>(texture_type);
5730   auto* src_type = texel.type;
5731   if (!storage_texture_type) {
5732     Fail() << "writing to other than storage texture: " << inst.PrettyPrint();
5733     return nullptr;
5734   }
5735   const auto format = storage_texture_type->format;
5736   auto* dest_type = parser_impl_.GetTexelTypeForFormat(format);
5737   if (!dest_type) {
5738     Fail();
5739     return nullptr;
5740   }
5741 
5742   // The texel type is always a 4-element vector.
5743   const uint32_t dest_count = 4u;
5744   TINT_ASSERT(Reader, dest_type->Is<Vector>() &&
5745                           dest_type->As<Vector>()->size == dest_count);
5746   TINT_ASSERT(Reader, dest_type->IsFloatVector() ||
5747                           dest_type->IsUnsignedIntegerVector() ||
5748                           dest_type->IsSignedIntegerVector());
5749 
5750   if (src_type == dest_type) {
5751     return texel.expr;
5752   }
5753 
5754   // Component type must match floatness, or integral signedness.
5755   if ((src_type->IsFloatScalarOrVector() != dest_type->IsFloatVector()) ||
5756       (src_type->IsUnsignedIntegerVector() !=
5757        dest_type->IsUnsignedIntegerVector()) ||
5758       (src_type->IsSignedIntegerVector() !=
5759        dest_type->IsSignedIntegerVector())) {
5760     Fail() << "invalid texel type for storage texture write: component must be "
5761               "float, signed integer, or unsigned integer "
5762               "to match the texture channel type: "
5763            << inst.PrettyPrint();
5764     return nullptr;
5765   }
5766 
5767   const auto required_count = parser_impl_.GetChannelCountForFormat(format);
5768   TINT_ASSERT(Reader, 0 < required_count && required_count <= 4);
5769 
5770   const uint32_t src_count =
5771       src_type->IsScalar() ? 1 : src_type->As<Vector>()->size;
5772   if (src_count < required_count) {
5773     Fail() << "texel has too few components for storage texture: " << src_count
5774            << " provided but " << required_count
5775            << " required, in: " << inst.PrettyPrint();
5776     return nullptr;
5777   }
5778 
5779   // It's valid for required_count < src_count. The extra components will
5780   // be written out but the textureStore will ignore them.
5781 
5782   if (src_count < dest_count) {
5783     // Expand the texel to a 4 element vector.
5784     auto* component_type =
5785         texel.type->IsScalar() ? texel.type : texel.type->As<Vector>()->type;
5786     texel.type = ty_.Vector(component_type, dest_count);
5787     ast::ExpressionList exprs;
5788     exprs.push_back(texel.expr);
5789     for (auto i = src_count; i < dest_count; i++) {
5790       exprs.push_back(parser_impl_.MakeNullExpression(component_type).expr);
5791     }
5792     texel.expr = builder_.Construct(Source{}, texel.type->Build(builder_),
5793                                     std::move(exprs));
5794   }
5795 
5796   return texel.expr;
5797 }
5798 
ToI32(TypedExpression value)5799 TypedExpression FunctionEmitter::ToI32(TypedExpression value) {
5800   if (!value || value.type->Is<I32>()) {
5801     return value;
5802   }
5803   return {ty_.I32(), builder_.Construct(Source{}, builder_.ty.i32(),
5804                                         ast::ExpressionList{value.expr})};
5805 }
5806 
ToSignedIfUnsigned(TypedExpression value)5807 TypedExpression FunctionEmitter::ToSignedIfUnsigned(TypedExpression value) {
5808   if (!value || !value.type->IsUnsignedScalarOrVector()) {
5809     return value;
5810   }
5811   if (auto* vec_type = value.type->As<Vector>()) {
5812     auto* new_type = ty_.Vector(ty_.I32(), vec_type->size);
5813     return {new_type, builder_.Construct(new_type->Build(builder_),
5814                                          ast::ExpressionList{value.expr})};
5815   }
5816   return ToI32(value);
5817 }
5818 
MakeArrayLength(const spvtools::opt::Instruction & inst)5819 TypedExpression FunctionEmitter::MakeArrayLength(
5820     const spvtools::opt::Instruction& inst) {
5821   if (inst.NumInOperands() != 2) {
5822     // Binary parsing will fail on this anyway.
5823     Fail() << "invalid array length: requires 2 operands: "
5824            << inst.PrettyPrint();
5825     return {};
5826   }
5827   const auto struct_ptr_id = inst.GetSingleWordInOperand(0);
5828   const auto field_index = inst.GetSingleWordInOperand(1);
5829   const auto struct_ptr_type_id =
5830       def_use_mgr_->GetDef(struct_ptr_id)->type_id();
5831   // Trace through the pointer type to get to the struct type.
5832   const auto struct_type_id =
5833       def_use_mgr_->GetDef(struct_ptr_type_id)->GetSingleWordInOperand(1);
5834   const auto field_name = namer_.GetMemberName(struct_type_id, field_index);
5835   if (field_name.empty()) {
5836     Fail() << "struct index out of bounds for array length: "
5837            << inst.PrettyPrint();
5838     return {};
5839   }
5840 
5841   auto member_expr = MakeExpression(struct_ptr_id);
5842   if (!member_expr) {
5843     return {};
5844   }
5845   if (member_expr.type->Is<Pointer>()) {
5846     member_expr = Dereference(member_expr);
5847   }
5848   auto* member_ident = create<ast::IdentifierExpression>(
5849       Source{}, builder_.Symbols().Register(field_name));
5850   auto* member_access = create<ast::MemberAccessorExpression>(
5851       Source{}, member_expr.expr, member_ident);
5852 
5853   // Generate the intrinsic function call.
5854   auto* call_expr =
5855       builder_.Call(Source{}, "arrayLength", builder_.AddressOf(member_access));
5856 
5857   return {parser_impl_.ConvertType(inst.type_id()), call_expr};
5858 }
5859 
MakeOuterProduct(const spvtools::opt::Instruction & inst)5860 TypedExpression FunctionEmitter::MakeOuterProduct(
5861     const spvtools::opt::Instruction& inst) {
5862   // Synthesize the result.
5863   auto col = MakeOperand(inst, 0);
5864   auto row = MakeOperand(inst, 1);
5865   auto* col_ty = As<Vector>(col.type);
5866   auto* row_ty = As<Vector>(row.type);
5867   auto* result_ty = As<Matrix>(parser_impl_.ConvertType(inst.type_id()));
5868   if (!col_ty || !col_ty || !result_ty || result_ty->type != col_ty->type ||
5869       result_ty->type != row_ty->type || result_ty->columns != row_ty->size ||
5870       result_ty->rows != col_ty->size) {
5871     Fail() << "invalid outer product instruction: bad types "
5872            << inst.PrettyPrint();
5873     return {};
5874   }
5875 
5876   // Example:
5877   //    c : vec3 column vector
5878   //    r : vec2 row vector
5879   //    OuterProduct c r : mat2x3 (2 columns, 3 rows)
5880   //    Result:
5881   //      | c.x * r.x   c.x * r.y |
5882   //      | c.y * r.x   c.y * r.y |
5883   //      | c.z * r.x   c.z * r.y |
5884 
5885   ast::ExpressionList result_columns;
5886   for (uint32_t icol = 0; icol < result_ty->columns; icol++) {
5887     ast::ExpressionList result_row;
5888     auto* row_factor = create<ast::MemberAccessorExpression>(Source{}, row.expr,
5889                                                              Swizzle(icol));
5890     for (uint32_t irow = 0; irow < result_ty->rows; irow++) {
5891       auto* column_factor = create<ast::MemberAccessorExpression>(
5892           Source{}, col.expr, Swizzle(irow));
5893       auto* elem = create<ast::BinaryExpression>(
5894           Source{}, ast::BinaryOp::kMultiply, row_factor, column_factor);
5895       result_row.push_back(elem);
5896     }
5897     result_columns.push_back(
5898         builder_.Construct(Source{}, col_ty->Build(builder_), result_row));
5899   }
5900   return {result_ty, builder_.Construct(Source{}, result_ty->Build(builder_),
5901                                         result_columns)};
5902 }
5903 
MakeVectorInsertDynamic(const spvtools::opt::Instruction & inst)5904 bool FunctionEmitter::MakeVectorInsertDynamic(
5905     const spvtools::opt::Instruction& inst) {
5906   // For
5907   //    %result = OpVectorInsertDynamic %type %src_vector %component %index
5908   // there are two cases.
5909   //
5910   // Case 1:
5911   //   The %src_vector value has already been hoisted into a variable.
5912   //   In this case, assign %src_vector to that variable, then write the
5913   //   component into the right spot:
5914   //
5915   //    hoisted = src_vector;
5916   //    hoisted[index] = component;
5917   //
5918   // Case 2:
5919   //   The %src_vector value is not hoisted. In this case, make a temporary
5920   //   variable with the %src_vector contents, then write the component,
5921   //   and then make a let-declaration that reads the value out:
5922   //
5923   //    var temp : type = src_vector;
5924   //    temp[index] = component;
5925   //    let result : type = temp;
5926   //
5927   //   Then use result everywhere the original SPIR-V id is used.  Using a const
5928   //   like this avoids constantly reloading the value many times.
5929 
5930   auto* type = parser_impl_.ConvertType(inst.type_id());
5931   auto src_vector = MakeOperand(inst, 0);
5932   auto component = MakeOperand(inst, 1);
5933   auto index = MakeOperand(inst, 2);
5934 
5935   std::string var_name;
5936   auto original_value_name = namer_.Name(inst.result_id());
5937   const bool hoisted = WriteIfHoistedVar(inst, src_vector);
5938   if (hoisted) {
5939     // The variable was already declared in an earlier block.
5940     var_name = original_value_name;
5941     // Assign the source vector value to it.
5942     builder_.Assign({}, builder_.Expr(var_name), src_vector.expr);
5943   } else {
5944     // Synthesize the temporary variable.
5945     // It doesn't correspond to a SPIR-V ID, so we don't use the ordinary
5946     // API in parser_impl_.
5947     var_name = namer_.MakeDerivedName(original_value_name);
5948 
5949     auto* temp_var = builder_.Var(var_name, type->Build(builder_),
5950                                   ast::StorageClass::kNone, src_vector.expr);
5951 
5952     AddStatement(builder_.Decl({}, temp_var));
5953   }
5954 
5955   auto* lhs = create<ast::IndexAccessorExpression>(
5956       Source{}, builder_.Expr(var_name), index.expr);
5957   if (!lhs) {
5958     return false;
5959   }
5960 
5961   AddStatement(builder_.Assign(lhs, component.expr));
5962 
5963   if (hoisted) {
5964     // The hoisted variable itself stands for this result ID.
5965     return success();
5966   }
5967   // Create a new let-declaration that is initialized by the contents
5968   // of the temporary variable.
5969   return EmitConstDefinition(inst, {type, builder_.Expr(var_name)});
5970 }
5971 
MakeCompositeInsert(const spvtools::opt::Instruction & inst)5972 bool FunctionEmitter::MakeCompositeInsert(
5973     const spvtools::opt::Instruction& inst) {
5974   // For
5975   //    %result = OpCompositeInsert %type %object %composite 1 2 3 ...
5976   // there are two cases.
5977   //
5978   // Case 1:
5979   //   The %composite value has already been hoisted into a variable.
5980   //   In this case, assign %composite to that variable, then write the
5981   //   component into the right spot:
5982   //
5983   //    hoisted = composite;
5984   //    hoisted[index].x = object;
5985   //
5986   // Case 2:
5987   //   The %composite value is not hoisted. In this case, make a temporary
5988   //   variable with the %composite contents, then write the component,
5989   //   and then make a let-declaration that reads the value out:
5990   //
5991   //    var temp : type = composite;
5992   //    temp[index].x = object;
5993   //    let result : type = temp;
5994   //
5995   //   Then use result everywhere the original SPIR-V id is used.  Using a const
5996   //   like this avoids constantly reloading the value many times.
5997   //
5998   //   This technique is a combination of:
5999   //   - making a temporary variable and constant declaration, like what we do
6000   //     for VectorInsertDynamic, and
6001   //   - building up an access-chain like access like for CompositeExtract, but
6002   //     on the left-hand side of the assignment.
6003 
6004   auto* type = parser_impl_.ConvertType(inst.type_id());
6005   auto component = MakeOperand(inst, 0);
6006   auto src_composite = MakeOperand(inst, 1);
6007 
6008   std::string var_name;
6009   auto original_value_name = namer_.Name(inst.result_id());
6010   const bool hoisted = WriteIfHoistedVar(inst, src_composite);
6011   if (hoisted) {
6012     // The variable was already declared in an earlier block.
6013     var_name = original_value_name;
6014     // Assign the source composite value to it.
6015     builder_.Assign({}, builder_.Expr(var_name), src_composite.expr);
6016   } else {
6017     // Synthesize a temporary variable.
6018     // It doesn't correspond to a SPIR-V ID, so we don't use the ordinary
6019     // API in parser_impl_.
6020     var_name = namer_.MakeDerivedName(original_value_name);
6021     auto* temp_var = builder_.Var(var_name, type->Build(builder_),
6022                                   ast::StorageClass::kNone, src_composite.expr);
6023     AddStatement(builder_.Decl({}, temp_var));
6024   }
6025 
6026   TypedExpression seed_expr{type, builder_.Expr(var_name)};
6027 
6028   // The left-hand side of the assignment *looks* like a decomposition.
6029   TypedExpression lhs =
6030       MakeCompositeValueDecomposition(inst, seed_expr, inst.type_id(), 2);
6031   if (!lhs) {
6032     return false;
6033   }
6034 
6035   AddStatement(builder_.Assign(lhs.expr, component.expr));
6036 
6037   if (hoisted) {
6038     // The hoisted variable itself stands for this result ID.
6039     return success();
6040   }
6041   // Create a new let-declaration that is initialized by the contents
6042   // of the temporary variable.
6043   return EmitConstDefinition(inst, {type, builder_.Expr(var_name)});
6044 }
6045 
AddressOf(TypedExpression expr)6046 TypedExpression FunctionEmitter::AddressOf(TypedExpression expr) {
6047   auto* ref = expr.type->As<Reference>();
6048   if (!ref) {
6049     Fail() << "AddressOf() called on non-reference type";
6050     return {};
6051   }
6052   return {
6053       ty_.Pointer(ref->type, ref->storage_class),
6054       create<ast::UnaryOpExpression>(Source{}, ast::UnaryOp::kAddressOf,
6055                                      expr.expr),
6056   };
6057 }
6058 
Dereference(TypedExpression expr)6059 TypedExpression FunctionEmitter::Dereference(TypedExpression expr) {
6060   auto* ptr = expr.type->As<Pointer>();
6061   if (!ptr) {
6062     Fail() << "Dereference() called on non-pointer type";
6063     return {};
6064   }
6065   return {
6066       ptr->type,
6067       create<ast::UnaryOpExpression>(Source{}, ast::UnaryOp::kIndirection,
6068                                      expr.expr),
6069   };
6070 }
6071 
IsFloatZero(uint32_t value_id)6072 bool FunctionEmitter::IsFloatZero(uint32_t value_id) {
6073   if (const auto* c = constant_mgr_->FindDeclaredConstant(value_id)) {
6074     if (const auto* float_const = c->AsFloatConstant()) {
6075       return 0.0f == float_const->GetFloatValue();
6076     }
6077     if (c->AsNullConstant()) {
6078       // Valid SPIR-V requires it to be a float value anyway.
6079       return true;
6080     }
6081   }
6082   return false;
6083 }
6084 
IsFloatOne(uint32_t value_id)6085 bool FunctionEmitter::IsFloatOne(uint32_t value_id) {
6086   if (const auto* c = constant_mgr_->FindDeclaredConstant(value_id)) {
6087     if (const auto* float_const = c->AsFloatConstant()) {
6088       return 1.0f == float_const->GetFloatValue();
6089     }
6090   }
6091   return false;
6092 }
6093 
// Default constructor and destructor of FunctionEmitter::FunctionDeclaration,
// explicitly defaulted out of line. NOTE(review): presumably kept in the .cc
// so the header need not define them inline; confirm against function.h.
FunctionEmitter::FunctionDeclaration::FunctionDeclaration() = default;
FunctionEmitter::FunctionDeclaration::~FunctionDeclaration() = default;
6096 
6097 }  // namespace spirv
6098 }  // namespace reader
6099 }  // namespace tint
6100 
// Instantiate the Tint type-info for the statement-builder classes used by
// this reader. NOTE(review): presumably required for Tint's Castable RTTI
// (Is<>/As<>) on these types; confirm against src/castable.h.
TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::StatementBuilder);
TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::SwitchStatementBuilder);
TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::IfStatementBuilder);
TINT_INSTANTIATE_TYPEINFO(tint::reader::spirv::LoopStatementBuilder);
6105