• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/codegen/SkSLSPIRVCodeGenerator.h"
9 
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/SkOpts_spi.h"
13 #include "include/private/SkSLIRNode.h"
14 #include "include/private/SkSLProgramElement.h"
15 #include "include/private/SkSLStatement.h"
16 #include "include/private/SkSLSymbol.h"
17 #include "include/private/base/SkTArray.h"
18 #include "include/sksl/DSLCore.h"
19 #include "include/sksl/DSLExpression.h"
20 #include "include/sksl/DSLType.h"
21 #include "include/sksl/DSLVar.h"
22 #include "include/sksl/SkSLErrorReporter.h"
23 #include "include/sksl/SkSLOperator.h"
24 #include "include/sksl/SkSLPosition.h"
25 #include "src/sksl/GLSL.std.450.h"
26 #include "src/sksl/SkSLAnalysis.h"
27 #include "src/sksl/SkSLBuiltinTypes.h"
28 #include "src/sksl/SkSLCompiler.h"
29 #include "src/sksl/SkSLConstantFolder.h"
30 #include "src/sksl/SkSLContext.h"
31 #include "src/sksl/SkSLIntrinsicList.h"
32 #include "src/sksl/SkSLModifiersPool.h"
33 #include "src/sksl/SkSLOutputStream.h"
34 #include "src/sksl/SkSLPool.h"
35 #include "src/sksl/SkSLProgramSettings.h"
36 #include "src/sksl/SkSLThreadContext.h"
37 #include "src/sksl/SkSLUtil.h"
38 #include "src/sksl/analysis/SkSLProgramUsage.h"
39 #include "src/sksl/ir/SkSLBinaryExpression.h"
40 #include "src/sksl/ir/SkSLBlock.h"
41 #include "src/sksl/ir/SkSLConstructor.h"
42 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
43 #include "src/sksl/ir/SkSLConstructorCompound.h"
44 #include "src/sksl/ir/SkSLConstructorCompoundCast.h"
45 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
46 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
47 #include "src/sksl/ir/SkSLConstructorScalarCast.h"
48 #include "src/sksl/ir/SkSLConstructorSplat.h"
49 #include "src/sksl/ir/SkSLDoStatement.h"
50 #include "src/sksl/ir/SkSLExpression.h"
51 #include "src/sksl/ir/SkSLExpressionStatement.h"
52 #include "src/sksl/ir/SkSLExtension.h"
53 #include "src/sksl/ir/SkSLField.h"
54 #include "src/sksl/ir/SkSLFieldAccess.h"
55 #include "src/sksl/ir/SkSLForStatement.h"
56 #include "src/sksl/ir/SkSLFunctionCall.h"
57 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
58 #include "src/sksl/ir/SkSLFunctionDefinition.h"
59 #include "src/sksl/ir/SkSLIfStatement.h"
60 #include "src/sksl/ir/SkSLIndexExpression.h"
61 #include "src/sksl/ir/SkSLInterfaceBlock.h"
62 #include "src/sksl/ir/SkSLLiteral.h"
63 #include "src/sksl/ir/SkSLPostfixExpression.h"
64 #include "src/sksl/ir/SkSLPrefixExpression.h"
65 #include "src/sksl/ir/SkSLProgram.h"
66 #include "src/sksl/ir/SkSLReturnStatement.h"
67 #include "src/sksl/ir/SkSLSetting.h"
68 #include "src/sksl/ir/SkSLSwitchCase.h"
69 #include "src/sksl/ir/SkSLSwitchStatement.h"
70 #include "src/sksl/ir/SkSLSwizzle.h"
71 #include "src/sksl/ir/SkSLTernaryExpression.h"
72 #include "src/sksl/ir/SkSLVarDeclarations.h"
73 #include "src/sksl/ir/SkSLVariableReference.h"
74 #include "src/utils/SkBitSet.h"
75 
76 #include <cstring>
77 #include <set>
78 #include <string>
79 #include <utility>
80 
// Alias for the SPIR-V capability enumerant treated as the last one this generator knows about.
// (The name suggests it bounds iteration over capabilities -- confirm at the use site.)
#define kLast_Capability SpvCapabilityMultiViewport

// Negative sentinel builtin IDs, chosen so they cannot collide with real (non-negative) SPIR-V
// builtin values; presumably used for the device-space sk_FragCoord / sk_Clockwise builtins.
static constexpr int DEVICE_FRAGCOORDS_BUILTIN = -1000;
static constexpr int DEVICE_CLOCKWISE_BUILTIN  = -1001;
85 
86 namespace SkSL {
87 
88 // Equality and hash operators for Instructions.
operator ==(const SPIRVCodeGenerator::Instruction & that) const89 bool SPIRVCodeGenerator::Instruction::operator==(const SPIRVCodeGenerator::Instruction& that) const {
90     return fOp         == that.fOp &&
91            fResultKind == that.fResultKind &&
92            fWords      == that.fWords;
93 }
94 
95 struct SPIRVCodeGenerator::Instruction::Hash {
operator ()SkSL::SPIRVCodeGenerator::Instruction::Hash96     uint32_t operator()(const SPIRVCodeGenerator::Instruction& key) const {
97         uint32_t hash = key.fResultKind;
98         hash = SkOpts::hash_fn(&key.fOp, sizeof(key.fOp), hash);
99         hash = SkOpts::hash_fn(key.fWords.data(), key.fWords.size() * sizeof(int32_t), hash);
100         return hash;
101     }
102 };
103 
104 // This class is used to pass values and result placeholder slots to writeInstruction.
105 struct SPIRVCodeGenerator::Word {
106     enum Kind {
107         kNone,  // intended for use as a sentinel, not part of any Instruction
108         kSpvId,
109         kNumber,
110         kDefaultPrecisionResult,
111         kRelaxedPrecisionResult,
112         kUniqueResult,
113         kKeyedResult,
114     };
115 
WordSkSL::SPIRVCodeGenerator::Word116     Word(SpvId id) : fValue(id), fKind(Kind::kSpvId) {}
WordSkSL::SPIRVCodeGenerator::Word117     Word(int32_t val, Kind kind) : fValue(val), fKind(kind) {}
118 
NumberSkSL::SPIRVCodeGenerator::Word119     static Word Number(int32_t val) {
120         return Word{val, Kind::kNumber};
121     }
122 
ResultSkSL::SPIRVCodeGenerator::Word123     static Word Result(const Type& type) {
124         return (type.hasPrecision() && !type.highPrecision()) ? RelaxedResult() : Result();
125     }
126 
RelaxedResultSkSL::SPIRVCodeGenerator::Word127     static Word RelaxedResult() {
128         return Word{(int32_t)NA, kRelaxedPrecisionResult};
129     }
130 
UniqueResultSkSL::SPIRVCodeGenerator::Word131     static Word UniqueResult() {
132         return Word{(int32_t)NA, kUniqueResult};
133     }
134 
ResultSkSL::SPIRVCodeGenerator::Word135     static Word Result() {
136         return Word{(int32_t)NA, kDefaultPrecisionResult};
137     }
138 
139     // Unlike a Result (where the result ID is always deduplicated to its first instruction) or a
140     // UniqueResult (which always produces a new instruction), a KeyedResult allows an instruction
141     // to be deduplicated among those that share the same `key`.
KeyedResultSkSL::SPIRVCodeGenerator::Word142     static Word KeyedResult(int32_t key) { return Word{key, Kind::kKeyedResult}; }
143 
isResultSkSL::SPIRVCodeGenerator::Word144     bool isResult() const { return fKind >= Kind::kDefaultPrecisionResult; }
145 
146     int32_t fValue;
147     Kind fKind;
148 };
149 
// SPIR-V generator magic number: Skia's registered tool ID (31) lives in the upper 16 bits. The
// lower 16 bits are left free so the SkSL generator could be versioned if we ever want to.
// https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/spir-v.xml#L84
static constexpr int32_t SKSL_MAGIC = 31 << 16;
154 
getIntrinsic(IntrinsicKind ik) const155 SPIRVCodeGenerator::Intrinsic SPIRVCodeGenerator::getIntrinsic(IntrinsicKind ik) const {
156 
157 #define ALL_GLSL(x) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, GLSLstd450 ## x, \
158                               GLSLstd450 ## x, GLSLstd450 ## x, GLSLstd450 ## x}
159 #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, \
160                                                        GLSLstd450 ## ifFloat,             \
161                                                        GLSLstd450 ## ifInt,               \
162                                                        GLSLstd450 ## ifUInt,              \
163                                                        SpvOpUndef}
164 #define ALL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
165                                SpvOp ## x, SpvOp ## x, SpvOp ## x, SpvOp ## x}
166 #define BOOL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
167                                 SpvOpUndef, SpvOpUndef, SpvOpUndef, SpvOp ## x}
168 #define FLOAT_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
169                                  SpvOp ## x, SpvOpUndef, SpvOpUndef, SpvOpUndef}
170 #define SPECIAL(x) Intrinsic{kSpecial_IntrinsicOpcodeKind, k ## x ## _SpecialIntrinsic, \
171                              k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic,  \
172                              k ## x ## _SpecialIntrinsic}
173 
174     switch (ik) {
175         case k_round_IntrinsicKind:          return ALL_GLSL(Round);
176         case k_roundEven_IntrinsicKind:      return ALL_GLSL(RoundEven);
177         case k_trunc_IntrinsicKind:          return ALL_GLSL(Trunc);
178         case k_abs_IntrinsicKind:            return BY_TYPE_GLSL(FAbs, SAbs, SAbs);
179         case k_sign_IntrinsicKind:           return BY_TYPE_GLSL(FSign, SSign, SSign);
180         case k_floor_IntrinsicKind:          return ALL_GLSL(Floor);
181         case k_ceil_IntrinsicKind:           return ALL_GLSL(Ceil);
182         case k_fract_IntrinsicKind:          return ALL_GLSL(Fract);
183         case k_radians_IntrinsicKind:        return ALL_GLSL(Radians);
184         case k_degrees_IntrinsicKind:        return ALL_GLSL(Degrees);
185         case k_sin_IntrinsicKind:            return ALL_GLSL(Sin);
186         case k_cos_IntrinsicKind:            return ALL_GLSL(Cos);
187         case k_tan_IntrinsicKind:            return ALL_GLSL(Tan);
188         case k_asin_IntrinsicKind:           return ALL_GLSL(Asin);
189         case k_acos_IntrinsicKind:           return ALL_GLSL(Acos);
190         case k_atan_IntrinsicKind:           return SPECIAL(Atan);
191         case k_sinh_IntrinsicKind:           return ALL_GLSL(Sinh);
192         case k_cosh_IntrinsicKind:           return ALL_GLSL(Cosh);
193         case k_tanh_IntrinsicKind:           return ALL_GLSL(Tanh);
194         case k_asinh_IntrinsicKind:          return ALL_GLSL(Asinh);
195         case k_acosh_IntrinsicKind:          return ALL_GLSL(Acosh);
196         case k_atanh_IntrinsicKind:          return ALL_GLSL(Atanh);
197         case k_pow_IntrinsicKind:            return ALL_GLSL(Pow);
198         case k_exp_IntrinsicKind:            return ALL_GLSL(Exp);
199         case k_log_IntrinsicKind:            return ALL_GLSL(Log);
200         case k_exp2_IntrinsicKind:           return ALL_GLSL(Exp2);
201         case k_log2_IntrinsicKind:           return ALL_GLSL(Log2);
202         case k_sqrt_IntrinsicKind:           return ALL_GLSL(Sqrt);
203         case k_inverse_IntrinsicKind:        return ALL_GLSL(MatrixInverse);
204         case k_outerProduct_IntrinsicKind:   return ALL_SPIRV(OuterProduct);
205         case k_transpose_IntrinsicKind:      return ALL_SPIRV(Transpose);
206         case k_isinf_IntrinsicKind:          return ALL_SPIRV(IsInf);
207         case k_isnan_IntrinsicKind:          return ALL_SPIRV(IsNan);
208         case k_inversesqrt_IntrinsicKind:    return ALL_GLSL(InverseSqrt);
209         case k_determinant_IntrinsicKind:    return ALL_GLSL(Determinant);
210         case k_matrixCompMult_IntrinsicKind: return SPECIAL(MatrixCompMult);
211         case k_matrixInverse_IntrinsicKind:  return ALL_GLSL(MatrixInverse);
212         case k_mod_IntrinsicKind:            return SPECIAL(Mod);
213         case k_modf_IntrinsicKind:           return ALL_GLSL(Modf);
214         case k_min_IntrinsicKind:            return SPECIAL(Min);
215         case k_max_IntrinsicKind:            return SPECIAL(Max);
216         case k_clamp_IntrinsicKind:          return SPECIAL(Clamp);
217         case k_saturate_IntrinsicKind:       return SPECIAL(Saturate);
218         case k_dot_IntrinsicKind:            return FLOAT_SPIRV(Dot);
219         case k_mix_IntrinsicKind:            return SPECIAL(Mix);
220         case k_step_IntrinsicKind:           return SPECIAL(Step);
221         case k_smoothstep_IntrinsicKind:     return SPECIAL(SmoothStep);
222         case k_fma_IntrinsicKind:            return ALL_GLSL(Fma);
223         case k_frexp_IntrinsicKind:          return ALL_GLSL(Frexp);
224         case k_ldexp_IntrinsicKind:          return ALL_GLSL(Ldexp);
225 
226 #define PACK(type) case k_pack##type##_IntrinsicKind:   return ALL_GLSL(Pack##type); \
227                    case k_unpack##type##_IntrinsicKind: return ALL_GLSL(Unpack##type)
228         PACK(Snorm4x8);
229         PACK(Unorm4x8);
230         PACK(Snorm2x16);
231         PACK(Unorm2x16);
232         PACK(Half2x16);
233         PACK(Double2x32);
234 #undef PACK
235 
236         case k_length_IntrinsicKind:        return ALL_GLSL(Length);
237         case k_distance_IntrinsicKind:      return ALL_GLSL(Distance);
238         case k_cross_IntrinsicKind:         return ALL_GLSL(Cross);
239         case k_normalize_IntrinsicKind:     return ALL_GLSL(Normalize);
240         case k_faceforward_IntrinsicKind:   return ALL_GLSL(FaceForward);
241         case k_reflect_IntrinsicKind:       return ALL_GLSL(Reflect);
242         case k_refract_IntrinsicKind:       return ALL_GLSL(Refract);
243         case k_bitCount_IntrinsicKind:      return ALL_SPIRV(BitCount);
244         case k_findLSB_IntrinsicKind:       return ALL_GLSL(FindILsb);
245         case k_findMSB_IntrinsicKind:       return BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
246         case k_dFdx_IntrinsicKind:          return FLOAT_SPIRV(DPdx);
247         case k_dFdy_IntrinsicKind:          return SPECIAL(DFdy);
248         case k_fwidth_IntrinsicKind:        return FLOAT_SPIRV(Fwidth);
249         case k_makeSampler2D_IntrinsicKind: return SPECIAL(SampledImage);
250 
251         case k_sample_IntrinsicKind:      return SPECIAL(Texture);
252         case k_sampleGrad_IntrinsicKind:  return SPECIAL(TextureGrad);
253         case k_sampleLod_IntrinsicKind:   return SPECIAL(TextureLod);
254         case k_subpassLoad_IntrinsicKind: return SPECIAL(SubpassLoad);
255 
256         case k_floatBitsToInt_IntrinsicKind:  return ALL_SPIRV(Bitcast);
257         case k_floatBitsToUint_IntrinsicKind: return ALL_SPIRV(Bitcast);
258         case k_intBitsToFloat_IntrinsicKind:  return ALL_SPIRV(Bitcast);
259         case k_uintBitsToFloat_IntrinsicKind: return ALL_SPIRV(Bitcast);
260 
261         case k_any_IntrinsicKind:   return BOOL_SPIRV(Any);
262         case k_all_IntrinsicKind:   return BOOL_SPIRV(All);
263         case k_not_IntrinsicKind:   return BOOL_SPIRV(LogicalNot);
264 
265         case k_equal_IntrinsicKind:
266             return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
267                              SpvOpFOrdEqual,
268                              SpvOpIEqual,
269                              SpvOpIEqual,
270                              SpvOpLogicalEqual};
271         case k_notEqual_IntrinsicKind:
272             return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
273                              SpvOpFUnordNotEqual,
274                              SpvOpINotEqual,
275                              SpvOpINotEqual,
276                              SpvOpLogicalNotEqual};
277         case k_lessThan_IntrinsicKind:
278             return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
279                              SpvOpFOrdLessThan,
280                              SpvOpSLessThan,
281                              SpvOpULessThan,
282                              SpvOpUndef};
283         case k_lessThanEqual_IntrinsicKind:
284             return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
285                              SpvOpFOrdLessThanEqual,
286                              SpvOpSLessThanEqual,
287                              SpvOpULessThanEqual,
288                              SpvOpUndef};
289         case k_greaterThan_IntrinsicKind:
290             return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
291                              SpvOpFOrdGreaterThan,
292                              SpvOpSGreaterThan,
293                              SpvOpUGreaterThan,
294                              SpvOpUndef};
295         case k_greaterThanEqual_IntrinsicKind:
296             return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
297                              SpvOpFOrdGreaterThanEqual,
298                              SpvOpSGreaterThanEqual,
299                              SpvOpUGreaterThanEqual,
300                              SpvOpUndef};
301         default:
302             return Intrinsic{kInvalid_IntrinsicOpcodeKind, 0, 0, 0, 0};
303     }
304 }
305 
writeWord(int32_t word,OutputStream & out)306 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
307     out.write((const char*) &word, sizeof(word));
308 }
309 
is_float(const Type & type)310 static bool is_float(const Type& type) {
311     return (type.isScalar() || type.isVector() || type.isMatrix()) &&
312            type.componentType().isFloat();
313 }
314 
is_signed(const Type & type)315 static bool is_signed(const Type& type) {
316     return (type.isScalar() || type.isVector()) && type.componentType().isSigned();
317 }
318 
is_unsigned(const Type & type)319 static bool is_unsigned(const Type& type) {
320     return (type.isScalar() || type.isVector()) && type.componentType().isUnsigned();
321 }
322 
is_bool(const Type & type)323 static bool is_bool(const Type& type) {
324     return (type.isScalar() || type.isVector()) && type.componentType().isBoolean();
325 }
326 
327 template <typename T>
pick_by_type(const Type & type,T ifFloat,T ifInt,T ifUInt,T ifBool)328 static T pick_by_type(const Type& type, T ifFloat, T ifInt, T ifUInt, T ifBool) {
329     if (is_float(type)) {
330         return ifFloat;
331     }
332     if (is_signed(type)) {
333         return ifInt;
334     }
335     if (is_unsigned(type)) {
336         return ifUInt;
337     }
338     if (is_bool(type)) {
339         return ifBool;
340     }
341     SkDEBUGFAIL("unrecognized type");
342     return ifFloat;
343 }
344 
is_out(const Modifiers & m)345 static bool is_out(const Modifiers& m) {
346     return (m.fFlags & Modifiers::kOut_Flag) != 0;
347 }
348 
is_in(const Modifiers & m)349 static bool is_in(const Modifiers& m) {
350     switch (m.fFlags & (Modifiers::kOut_Flag | Modifiers::kIn_Flag)) {
351         case Modifiers::kOut_Flag:                       // out
352             return false;
353 
354         case 0:                                          // implicit in
355         case Modifiers::kIn_Flag:                        // explicit in
356         case Modifiers::kOut_Flag | Modifiers::kIn_Flag: // inout
357             return true;
358 
359         default: SkUNREACHABLE;
360     }
361 }
362 
is_control_flow_op(SpvOp_ op)363 static bool is_control_flow_op(SpvOp_ op) {
364     switch (op) {
365         case SpvOpReturn:
366         case SpvOpReturnValue:
367         case SpvOpKill:
368         case SpvOpSwitch:
369         case SpvOpBranch:
370         case SpvOpBranchConditional:
371             return true;
372         default:
373             return false;
374     }
375 }
376 
is_globally_reachable_op(SpvOp_ op)377 static bool is_globally_reachable_op(SpvOp_ op) {
378     switch (op) {
379         case SpvOpConstant:
380         case SpvOpConstantTrue:
381         case SpvOpConstantFalse:
382         case SpvOpConstantComposite:
383         case SpvOpTypeVoid:
384         case SpvOpTypeInt:
385         case SpvOpTypeFloat:
386         case SpvOpTypeBool:
387         case SpvOpTypeVector:
388         case SpvOpTypeMatrix:
389         case SpvOpTypeArray:
390         case SpvOpTypePointer:
391         case SpvOpTypeFunction:
392         case SpvOpTypeRuntimeArray:
393         case SpvOpTypeStruct:
394         case SpvOpTypeImage:
395         case SpvOpTypeSampledImage:
396         case SpvOpTypeSampler:
397         case SpvOpVariable:
398         case SpvOpFunction:
399         case SpvOpFunctionParameter:
400         case SpvOpFunctionEnd:
401         case SpvOpExecutionMode:
402         case SpvOpMemoryModel:
403         case SpvOpCapability:
404         case SpvOpExtInstImport:
405         case SpvOpEntryPoint:
406         case SpvOpSource:
407         case SpvOpSourceExtension:
408         case SpvOpName:
409         case SpvOpMemberName:
410         case SpvOpDecorate:
411         case SpvOpMemberDecorate:
412             return true;
413         default:
414             return false;
415     }
416 }
417 
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)418 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
419     SkASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
420     SkASSERT(opCode != SpvOpUndef);
421     bool foundDeadCode = false;
422     if (is_control_flow_op(opCode)) {
423         // This instruction causes us to leave the current block.
424         foundDeadCode = (fCurrentBlock == 0);
425         fCurrentBlock = 0;
426     } else if (!is_globally_reachable_op(opCode)) {
427         foundDeadCode = (fCurrentBlock == 0);
428     }
429 
430     if (foundDeadCode) {
431         // We just encountered dead code--an instruction that don't have an associated block.
432         // Synthesize a label if this happens; this is necessary to satisfy the validator.
433         this->writeLabel(this->nextId(nullptr), kBranchlessBlock, out);
434     }
435 
436     this->writeWord((length << 16) | opCode, out);
437 }
438 
writeLabel(SpvId label,StraightLineLabelType,OutputStream & out)439 void SPIRVCodeGenerator::writeLabel(SpvId label, StraightLineLabelType, OutputStream& out) {
440     // The straight-line label type is not important; in any case, no caches are invalidated.
441     SkASSERT(!fCurrentBlock);
442     fCurrentBlock = label;
443     this->writeInstruction(SpvOpLabel, label, out);
444 }
445 
writeLabel(SpvId label,BranchingLabelType type,ConditionalOpCounts ops,OutputStream & out)446 void SPIRVCodeGenerator::writeLabel(SpvId label, BranchingLabelType type,
447                                     ConditionalOpCounts ops, OutputStream& out) {
448     switch (type) {
449         case kBranchIsBelow:
450         case kBranchesOnBothSides:
451             // With a backward or bidirectional branch, we haven't seen the code between the label
452             // and the branch yet, so any stored value is potentially suspect. Without scanning
453             // ahead to check, the only safe option is to ditch the store cache entirely.
454             fStoreCache.reset();
455             [[fallthrough]];
456 
457         case kBranchIsAbove:
458             // With a forward branch, we can rely on stores that we had cached at the start of the
459             // statement/expression, if they haven't been touched yet. Anything newer than that is
460             // pruned.
461             this->pruneConditionalOps(ops);
462             break;
463     }
464 
465     // Emit the label.
466     this->writeLabel(label, kBranchlessBlock, out);
467 }
468 
writeInstruction(SpvOp_ opCode,OutputStream & out)469 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
470     this->writeOpCode(opCode, 1, out);
471 }
472 
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)473 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
474     this->writeOpCode(opCode, 2, out);
475     this->writeWord(word1, out);
476 }
477 
writeString(std::string_view s,OutputStream & out)478 void SPIRVCodeGenerator::writeString(std::string_view s, OutputStream& out) {
479     out.write(s.data(), s.length());
480     switch (s.length() % 4) {
481         case 1:
482             out.write8(0);
483             [[fallthrough]];
484         case 2:
485             out.write8(0);
486             [[fallthrough]];
487         case 3:
488             out.write8(0);
489             break;
490         default:
491             this->writeWord(0, out);
492             break;
493     }
494 }
495 
writeInstruction(SpvOp_ opCode,std::string_view string,OutputStream & out)496 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, std::string_view string,
497                                           OutputStream& out) {
498     this->writeOpCode(opCode, 1 + (string.length() + 4) / 4, out);
499     this->writeString(string, out);
500 }
501 
writeInstruction(SpvOp_ opCode,int32_t word1,std::string_view string,OutputStream & out)502 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, std::string_view string,
503                                           OutputStream& out) {
504     this->writeOpCode(opCode, 2 + (string.length() + 4) / 4, out);
505     this->writeWord(word1, out);
506     this->writeString(string, out);
507 }
508 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,std::string_view string,OutputStream & out)509 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
510                                           std::string_view string, OutputStream& out) {
511     this->writeOpCode(opCode, 3 + (string.length() + 4) / 4, out);
512     this->writeWord(word1, out);
513     this->writeWord(word2, out);
514     this->writeString(string, out);
515 }
516 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)517 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
518                                           OutputStream& out) {
519     this->writeOpCode(opCode, 3, out);
520     this->writeWord(word1, out);
521     this->writeWord(word2, out);
522 }
523 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)524 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
525                                           int32_t word3, OutputStream& out) {
526     this->writeOpCode(opCode, 4, out);
527     this->writeWord(word1, out);
528     this->writeWord(word2, out);
529     this->writeWord(word3, out);
530 }
531 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)532 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
533                                           int32_t word3, int32_t word4, OutputStream& out) {
534     this->writeOpCode(opCode, 5, out);
535     this->writeWord(word1, out);
536     this->writeWord(word2, out);
537     this->writeWord(word3, out);
538     this->writeWord(word4, out);
539 }
540 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)541 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
542                                           int32_t word3, int32_t word4, int32_t word5,
543                                           OutputStream& out) {
544     this->writeOpCode(opCode, 6, out);
545     this->writeWord(word1, out);
546     this->writeWord(word2, out);
547     this->writeWord(word3, out);
548     this->writeWord(word4, out);
549     this->writeWord(word5, out);
550 }
551 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)552 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
553                                           int32_t word3, int32_t word4, int32_t word5,
554                                           int32_t word6, OutputStream& out) {
555     this->writeOpCode(opCode, 7, out);
556     this->writeWord(word1, out);
557     this->writeWord(word2, out);
558     this->writeWord(word3, out);
559     this->writeWord(word4, out);
560     this->writeWord(word5, out);
561     this->writeWord(word6, out);
562 }
563 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)564 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
565                                           int32_t word3, int32_t word4, int32_t word5,
566                                           int32_t word6, int32_t word7, OutputStream& out) {
567     this->writeOpCode(opCode, 8, out);
568     this->writeWord(word1, out);
569     this->writeWord(word2, out);
570     this->writeWord(word3, out);
571     this->writeWord(word4, out);
572     this->writeWord(word5, out);
573     this->writeWord(word6, out);
574     this->writeWord(word7, out);
575 }
576 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)577 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
578                                           int32_t word3, int32_t word4, int32_t word5,
579                                           int32_t word6, int32_t word7, int32_t word8,
580                                           OutputStream& out) {
581     this->writeOpCode(opCode, 9, out);
582     this->writeWord(word1, out);
583     this->writeWord(word2, out);
584     this->writeWord(word3, out);
585     this->writeWord(word4, out);
586     this->writeWord(word5, out);
587     this->writeWord(word6, out);
588     this->writeWord(word7, out);
589     this->writeWord(word8, out);
590 }
591 
BuildInstructionKey(SpvOp_ opCode,const SkTArray<Word> & words)592 SPIRVCodeGenerator::Instruction SPIRVCodeGenerator::BuildInstructionKey(
593         SpvOp_ opCode, const SkTArray<Word>& words) {
594     // Assemble a cache key for this instruction.
595     Instruction key;
596     key.fOp = opCode;
597     key.fWords.resize(words.size());
598     key.fResultKind = Word::Kind::kNone;
599 
600     for (int index = 0; index < words.size(); ++index) {
601         const Word& word = words[index];
602         key.fWords[index] = word.fValue;
603         if (word.isResult()) {
604             SkASSERT(key.fResultKind == Word::Kind::kNone);
605             key.fResultKind = word.fKind;
606         }
607     }
608 
609     return key;
610 }
611 
writeInstruction(SpvOp_ opCode,const SkTArray<Word> & words,OutputStream & out)612 SpvId SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode,
613                                            const SkTArray<Word>& words,
614                                            OutputStream& out) {
615     // writeOpLoad and writeOpStore have dedicated code.
616     SkASSERT(opCode != SpvOpLoad);
617     SkASSERT(opCode != SpvOpStore);
618 
619     // If this instruction exists in our op cache, return the cached SpvId.
620     Instruction key = BuildInstructionKey(opCode, words);
621     if (SpvId* cachedOp = fOpCache.find(key)) {
622         return *cachedOp;
623     }
624 
625     SpvId result = NA;
626     Precision precision = Precision::kDefault;
627 
628     switch (key.fResultKind) {
629         case Word::Kind::kUniqueResult:
630             // The instruction returns a SpvId, but we do not want deduplication.
631             result = this->nextId(Precision::kDefault);
632             fSpvIdCache.set(result, key);
633             break;
634 
635         case Word::Kind::kNone:
636             // The instruction doesn't return a SpvId, but we can still cache and deduplicate it.
637             fOpCache.set(key, result);
638             break;
639 
640         case Word::Kind::kRelaxedPrecisionResult:
641             precision = Precision::kRelaxed;
642             [[fallthrough]];
643 
644         case Word::Kind::kKeyedResult:
645             [[fallthrough]];
646 
647         case Word::Kind::kDefaultPrecisionResult:
648             // Consume a new SpvId.
649             result = this->nextId(precision);
650             fOpCache.set(key, result);
651             fSpvIdCache.set(result, key);
652 
653             // Globally-reachable ops are not subject to the whims of flow control.
654             if (!is_globally_reachable_op(opCode)) {
655                 fReachableOps.push_back(result);
656             }
657             break;
658 
659         default:
660             SkDEBUGFAIL("unexpected result kind");
661             break;
662     }
663 
664     // Write the requested instruction.
665     this->writeOpCode(opCode, words.size() + 1, out);
666     for (const Word& word : words) {
667         if (word.isResult()) {
668             SkASSERT(result != NA);
669             this->writeWord(result, out);
670         } else {
671             this->writeWord(word.fValue, out);
672         }
673     }
674 
675     // Return the result.
676     return result;
677 }
678 
writeOpLoad(SpvId type,Precision precision,SpvId pointer,OutputStream & out)679 SpvId SPIRVCodeGenerator::writeOpLoad(SpvId type,
680                                       Precision precision,
681                                       SpvId pointer,
682                                       OutputStream& out) {
683     // Look for this pointer in our load-cache.
684     if (SpvId* cachedOp = fStoreCache.find(pointer)) {
685         return *cachedOp;
686     }
687 
688     // Write the requested OpLoad instruction.
689     SpvId result = this->nextId(precision);
690     this->writeInstruction(SpvOpLoad, type, result, pointer, out);
691     return result;
692 }
693 
writeOpStore(SpvStorageClass_ storageClass,SpvId pointer,SpvId value,OutputStream & out)694 void SPIRVCodeGenerator::writeOpStore(SpvStorageClass_ storageClass,
695                                       SpvId pointer,
696                                       SpvId value,
697                                       OutputStream& out) {
698     // Write the uncached SpvOpStore directly.
699     this->writeInstruction(SpvOpStore, pointer, value, out);
700 
701     if (storageClass == SpvStorageClassFunction) {
702         // Insert a pointer-to-SpvId mapping into the load cache. A writeOpLoad to this pointer will
703         // return the cached value as-is.
704         fStoreCache.set(pointer, value);
705         fStoreOps.push_back(pointer);
706     }
707 }
708 
writeOpConstantTrue(const Type & type)709 SpvId SPIRVCodeGenerator::writeOpConstantTrue(const Type& type) {
710     return this->writeInstruction(SpvOpConstantTrue,
711                                   Words{this->getType(type), Word::Result()},
712                                   fConstantBuffer);
713 }
714 
writeOpConstantFalse(const Type & type)715 SpvId SPIRVCodeGenerator::writeOpConstantFalse(const Type& type) {
716     return this->writeInstruction(SpvOpConstantFalse,
717                                   Words{this->getType(type), Word::Result()},
718                                   fConstantBuffer);
719 }
720 
writeOpConstant(const Type & type,int32_t valueBits)721 SpvId SPIRVCodeGenerator::writeOpConstant(const Type& type, int32_t valueBits) {
722     return this->writeInstruction(
723             SpvOpConstant,
724             Words{this->getType(type), Word::Result(), Word::Number(valueBits)},
725             fConstantBuffer);
726 }
727 
writeOpConstantComposite(const Type & type,const SkTArray<SpvId> & values)728 SpvId SPIRVCodeGenerator::writeOpConstantComposite(const Type& type,
729                                                    const SkTArray<SpvId>& values) {
730     SkASSERT(values.size() == (type.isStruct() ? (int)type.fields().size() : type.columns()));
731 
732     Words words;
733     words.push_back(this->getType(type));
734     words.push_back(Word::Result());
735     for (SpvId value : values) {
736         words.push_back(value);
737     }
738     return this->writeInstruction(SpvOpConstantComposite, words, fConstantBuffer);
739 }
740 
toConstants(SpvId value,SkTArray<SpvId> * constants)741 bool SPIRVCodeGenerator::toConstants(SpvId value, SkTArray<SpvId>* constants) {
742     Instruction* instr = fSpvIdCache.find(value);
743     if (!instr) {
744         return false;
745     }
746     switch (instr->fOp) {
747         case SpvOpConstant:
748         case SpvOpConstantTrue:
749         case SpvOpConstantFalse:
750             constants->push_back(value);
751             return true;
752 
753         case SpvOpConstantComposite: // OpConstantComposite ResultType ResultID Constituents...
754             // Start at word 2 to skip past ResultType and ResultID.
755             for (int i = 2; i < instr->fWords.size(); ++i) {
756                 if (!this->toConstants(instr->fWords[i], constants)) {
757                     return false;
758                 }
759             }
760             return true;
761 
762         default:
763             return false;
764     }
765 }
766 
toConstants(SkSpan<const SpvId> values,SkTArray<SpvId> * constants)767 bool SPIRVCodeGenerator::toConstants(SkSpan<const SpvId> values, SkTArray<SpvId>* constants) {
768     for (SpvId value : values) {
769         if (!this->toConstants(value, constants)) {
770             return false;
771         }
772     }
773     return true;
774 }
775 
writeOpCompositeConstruct(const Type & type,const SkTArray<SpvId> & values,OutputStream & out)776 SpvId SPIRVCodeGenerator::writeOpCompositeConstruct(const Type& type,
777                                                     const SkTArray<SpvId>& values,
778                                                     OutputStream& out) {
779     // If this is a vector composed entirely of literals, write a constant-composite instead.
780     if (type.isVector()) {
781         SkSTArray<4, SpvId> constants;
782         if (this->toConstants(SkSpan(values), &constants)) {
783             // Create a vector from literals.
784             return this->writeOpConstantComposite(type, constants);
785         }
786     }
787 
788     // If this is a matrix composed entirely of literals, constant-composite them instead.
789     if (type.isMatrix()) {
790         SkSTArray<16, SpvId> constants;
791         if (this->toConstants(SkSpan(values), &constants)) {
792             // Create each matrix column.
793             SkASSERT(type.isMatrix());
794             const Type& vecType = type.componentType().toCompound(fContext,
795                                                                   /*columns=*/type.rows(),
796                                                                   /*rows=*/1);
797             SkSTArray<4, SpvId> columnIDs;
798             for (int index=0; index < type.columns(); ++index) {
799                 SkSTArray<4, SpvId> columnConstants(&constants[index * type.rows()],
800                                                     type.rows());
801                 columnIDs.push_back(this->writeOpConstantComposite(vecType, columnConstants));
802             }
803             // Compose the matrix from its columns.
804             return this->writeOpConstantComposite(type, columnIDs);
805         }
806     }
807 
808     Words words;
809     words.push_back(this->getType(type));
810     words.push_back(Word::Result(type));
811     for (SpvId value : values) {
812         words.push_back(value);
813     }
814 
815     return this->writeInstruction(SpvOpCompositeConstruct, words, out);
816 }
817 
resultTypeForInstruction(const Instruction & instr)818 SPIRVCodeGenerator::Instruction* SPIRVCodeGenerator::resultTypeForInstruction(
819         const Instruction& instr) {
820     // This list should contain every op that we cache that has a result and result-type.
821     // (If one is missing, we will not find some optimization opportunities.)
822     // Generally, the result type of an op is in the 0th word, but I'm not sure if this is
823     // universally true, so it's configurable on a per-op basis.
824     int resultTypeWord;
825     switch (instr.fOp) {
826         case SpvOpConstant:
827         case SpvOpConstantTrue:
828         case SpvOpConstantFalse:
829         case SpvOpConstantComposite:
830         case SpvOpCompositeConstruct:
831         case SpvOpCompositeExtract:
832         case SpvOpLoad:
833             resultTypeWord = 0;
834             break;
835 
836         default:
837             return nullptr;
838     }
839 
840     Instruction* typeInstr = fSpvIdCache.find(instr.fWords[resultTypeWord]);
841     SkASSERT(typeInstr);
842     return typeInstr;
843 }
844 
numComponentsForVecInstruction(const Instruction & instr)845 int SPIRVCodeGenerator::numComponentsForVecInstruction(const Instruction& instr) {
846     // If an instruction is in the op cache, its type should be as well.
847     Instruction* typeInstr = this->resultTypeForInstruction(instr);
848     SkASSERT(typeInstr);
849     SkASSERT(typeInstr->fOp == SpvOpTypeVector || typeInstr->fOp == SpvOpTypeFloat ||
850              typeInstr->fOp == SpvOpTypeInt || typeInstr->fOp == SpvOpTypeBool);
851 
852     // For vectors, extract their column count. Scalars have one component by definition.
853     //   SpvOpTypeVector ResultID ComponentType NumComponents
854     return (typeInstr->fOp == SpvOpTypeVector) ? typeInstr->fWords[2]
855                                                : 1;
856 }
857 
toComponent(SpvId id,int component)858 SpvId SPIRVCodeGenerator::toComponent(SpvId id, int component) {
859     Instruction* instr = fSpvIdCache.find(id);
860     if (!instr) {
861         return NA;
862     }
863     if (instr->fOp == SpvOpConstantComposite) {
864         // SpvOpConstantComposite ResultType ResultID [components...]
865         // Add 2 to the component index to skip past ResultType and ResultID.
866         return instr->fWords[2 + component];
867     }
868     if (instr->fOp == SpvOpCompositeConstruct) {
869         // SpvOpCompositeConstruct ResultType ResultID [components...]
870         // Vectors have special rules; check to see if we are composing a vector.
871         Instruction* composedType = fSpvIdCache.find(instr->fWords[0]);
872         SkASSERT(composedType);
873 
874         // When composing a non-vector, each instruction word maps 1:1 to the component index.
875         // We can just extract out the associated component directly.
876         if (composedType->fOp != SpvOpTypeVector) {
877             return instr->fWords[2 + component];
878         }
879 
880         // When composing a vector, components can be either scalars or vectors.
881         // This means we need to check the op type on each component. (+2 to skip ResultType/Result)
882         for (int index = 2; index < instr->fWords.size(); ++index) {
883             int32_t currentWord = instr->fWords[index];
884 
885             // Retrieve the sub-instruction pointed to by OpCompositeConstruct.
886             Instruction* subinstr = fSpvIdCache.find(currentWord);
887             if (!subinstr) {
888                 return NA;
889             }
890             // If this subinstruction contains the component we're looking for...
891             int numComponents = this->numComponentsForVecInstruction(*subinstr);
892             if (component < numComponents) {
893                 if (numComponents == 1) {
894                     // ... it's a scalar. Return it.
895                     SkASSERT(component == 0);
896                     return currentWord;
897                 } else {
898                     // ... it's a vector. Recurse into it.
899                     return this->toComponent(currentWord, component);
900                 }
901             }
902             // This sub-instruction doesn't contain our component. Keep walking forward.
903             component -= numComponents;
904         }
905         SkDEBUGFAIL("component index goes past the end of this composite value");
906         return NA;
907     }
908     return NA;
909 }
910 
writeOpCompositeExtract(const Type & type,SpvId base,int component,OutputStream & out)911 SpvId SPIRVCodeGenerator::writeOpCompositeExtract(const Type& type,
912                                                   SpvId base,
913                                                   int component,
914                                                   OutputStream& out) {
915     // If the base op is a composite, we can extract from it directly.
916     SpvId result = this->toComponent(base, component);
917     if (result != NA) {
918         return result;
919     }
920     return this->writeInstruction(
921             SpvOpCompositeExtract,
922             {this->getType(type), Word::Result(type), base, Word::Number(component)},
923             out);
924 }
925 
writeOpCompositeExtract(const Type & type,SpvId base,int componentA,int componentB,OutputStream & out)926 SpvId SPIRVCodeGenerator::writeOpCompositeExtract(const Type& type,
927                                                   SpvId base,
928                                                   int componentA,
929                                                   int componentB,
930                                                   OutputStream& out) {
931     // If the base op is a composite, we can extract from it directly.
932     SpvId result = this->toComponent(base, componentA);
933     if (result != NA) {
934         return this->writeOpCompositeExtract(type, result, componentB, out);
935     }
936     return this->writeInstruction(SpvOpCompositeExtract,
937                                   {this->getType(type),
938                                    Word::Result(type),
939                                    base,
940                                    Word::Number(componentA),
941                                    Word::Number(componentB)},
942                                   out);
943 }
944 
writeCapabilities(OutputStream & out)945 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
946     for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
947         if (fCapabilities & bit) {
948             this->writeInstruction(SpvOpCapability, (SpvId) i, out);
949         }
950     }
951     this->writeInstruction(SpvOpCapability, SpvCapabilityShader, out);
952 }
953 
nextId(const Type * type)954 SpvId SPIRVCodeGenerator::nextId(const Type* type) {
955     return this->nextId(type && type->hasPrecision() && !type->highPrecision()
956                 ? Precision::kRelaxed
957                 : Precision::kDefault);
958 }
959 
nextId(Precision precision)960 SpvId SPIRVCodeGenerator::nextId(Precision precision) {
961     if (precision == Precision::kRelaxed && !fProgram.fConfig->fSettings.fForceHighPrecision) {
962         this->writeInstruction(SpvOpDecorate, fIdCount, SpvDecorationRelaxedPrecision,
963                                fDecorationBuffer);
964     }
965     return fIdCount++;
966 }
967 
writeStruct(const Type & type,const MemoryLayout & memoryLayout)968 SpvId SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout) {
969     // If we've already written out this struct, return its existing SpvId.
970     if (SpvId* cachedStructId = fStructMap.find(&type)) {
971         return *cachedStructId;
972     }
973 
974     // Write all of the field types first, so we don't inadvertently write them while we're in the
975     // middle of writing the struct instruction.
976     Words words;
977     words.push_back(Word::UniqueResult());
978     for (const auto& f : type.fields()) {
979         words.push_back(this->getType(*f.fType, memoryLayout));
980     }
981     SpvId resultId = this->writeInstruction(SpvOpTypeStruct, words, fConstantBuffer);
982     this->writeInstruction(SpvOpName, resultId, type.name(), fNameBuffer);
983     fStructMap.set(&type, resultId);
984 
985     size_t offset = 0;
986     for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
987         const Type::Field& field = type.fields()[i];
988         if (!memoryLayout.isSupported(*field.fType)) {
989             fContext.fErrors->error(type.fPosition, "type '" + field.fType->displayName() +
990                                                     "' is not permitted here");
991             return resultId;
992         }
993         size_t size = memoryLayout.size(*field.fType);
994         size_t alignment = memoryLayout.alignment(*field.fType);
995         const Layout& fieldLayout = field.fModifiers.fLayout;
996         if (fieldLayout.fOffset >= 0) {
997             if (fieldLayout.fOffset < (int) offset) {
998                 fContext.fErrors->error(field.fPosition, "offset of field '" +
999                         std::string(field.fName) + "' must be at least " + std::to_string(offset));
1000             }
1001             if (fieldLayout.fOffset % alignment) {
1002                 fContext.fErrors->error(field.fPosition,
1003                                         "offset of field '" + std::string(field.fName) +
1004                                         "' must be a multiple of " + std::to_string(alignment));
1005             }
1006             offset = fieldLayout.fOffset;
1007         } else {
1008             size_t mod = offset % alignment;
1009             if (mod) {
1010                 offset += alignment - mod;
1011             }
1012         }
1013         this->writeInstruction(SpvOpMemberName, resultId, i, field.fName, fNameBuffer);
1014         this->writeFieldLayout(fieldLayout, resultId, i);
1015         if (field.fModifiers.fLayout.fBuiltin < 0) {
1016             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
1017                                    (SpvId) offset, fDecorationBuffer);
1018         }
1019         if (field.fType->isMatrix()) {
1020             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
1021                                    fDecorationBuffer);
1022             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
1023                                    (SpvId) memoryLayout.stride(*field.fType),
1024                                    fDecorationBuffer);
1025         }
1026         if (!field.fType->highPrecision()) {
1027             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i,
1028                                    SpvDecorationRelaxedPrecision, fDecorationBuffer);
1029         }
1030         offset += size;
1031         if ((field.fType->isArray() || field.fType->isStruct()) && offset % alignment != 0) {
1032             offset += alignment - offset % alignment;
1033         }
1034     }
1035 
1036     return resultId;
1037 }
1038 
getType(const Type & type)1039 SpvId SPIRVCodeGenerator::getType(const Type& type) {
1040     return this->getType(type, fDefaultLayout);
1041 }
1042 
getType(const Type & rawType,const MemoryLayout & layout)1043 SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layout) {
1044     const Type* type = &rawType;
1045 
1046     switch (type->typeKind()) {
1047         case Type::TypeKind::kVoid: {
1048             return this->writeInstruction(SpvOpTypeVoid, Words{Word::Result()}, fConstantBuffer);
1049         }
1050         case Type::TypeKind::kScalar:
1051         case Type::TypeKind::kLiteral: {
1052             if (type->isBoolean()) {
1053                 return this->writeInstruction(SpvOpTypeBool, {Word::Result()}, fConstantBuffer);
1054             }
1055             if (type->isSigned()) {
1056                 return this->writeInstruction(
1057                         SpvOpTypeInt,
1058                         Words{Word::Result(), Word::Number(32), Word::Number(1)},
1059                         fConstantBuffer);
1060             }
1061             if (type->isUnsigned()) {
1062                 return this->writeInstruction(
1063                         SpvOpTypeInt,
1064                         Words{Word::Result(), Word::Number(32), Word::Number(0)},
1065                         fConstantBuffer);
1066             }
1067             if (type->isFloat()) {
1068                 return this->writeInstruction(
1069                         SpvOpTypeFloat,
1070                         Words{Word::Result(), Word::Number(32)},
1071                         fConstantBuffer);
1072             }
1073             SkDEBUGFAILF("unrecognized scalar type '%s'", type->description().c_str());
1074             return (SpvId)-1;
1075         }
1076         case Type::TypeKind::kVector: {
1077             SpvId scalarTypeId = this->getType(type->componentType(), layout);
1078             return this->writeInstruction(
1079                     SpvOpTypeVector,
1080                     Words{Word::Result(), scalarTypeId, Word::Number(type->columns())},
1081                     fConstantBuffer);
1082         }
1083         case Type::TypeKind::kMatrix: {
1084             SpvId vectorTypeId = this->getType(IndexExpression::IndexType(fContext, *type), layout);
1085             return this->writeInstruction(
1086                     SpvOpTypeMatrix,
1087                     Words{Word::Result(), vectorTypeId, Word::Number(type->columns())},
1088                     fConstantBuffer);
1089         }
1090         case Type::TypeKind::kArray: {
1091             if (!layout.isSupported(*type)) {
1092                 fContext.fErrors->error(type->fPosition, "type '" + type->displayName() +
1093                                                          "' is not permitted here");
1094                 return NA;
1095             }
1096             if (type->columns() == 0) {
1097                 // We do not support runtime-sized arrays.
1098                 fContext.fErrors->error(type->fPosition, "runtime-sized arrays are not supported");
1099                 return NA;
1100             }
1101             size_t stride = layout.stride(*type);
1102             SpvId typeId = this->getType(type->componentType(), layout);
1103             SpvId countId = this->writeLiteral(type->columns(), *fContext.fTypes.fInt);
1104             SpvId result = this->writeInstruction(SpvOpTypeArray,
1105                                                   Words{Word::KeyedResult(stride), typeId, countId},
1106                                                   fConstantBuffer);
1107             this->writeInstruction(SpvOpDecorate,
1108                                    {result, SpvDecorationArrayStride, Word::Number(stride)},
1109                                    fDecorationBuffer);
1110             return result;
1111         }
1112         case Type::TypeKind::kStruct: {
1113             return this->writeStruct(*type, layout);
1114         }
1115         case Type::TypeKind::kSeparateSampler: {
1116             return this->writeInstruction(SpvOpTypeSampler, Words{Word::Result()}, fConstantBuffer);
1117         }
1118         case Type::TypeKind::kSampler: {
1119             // Subpass inputs should use the Texture type, not a Sampler.
1120             SkASSERT(type->dimensions() != SpvDimSubpassData);
1121             if (SpvDimBuffer == type->dimensions()) {
1122                 fCapabilities |= 1ULL << SpvCapabilitySampledBuffer;
1123             }
1124             SpvId imageTypeId = this->getType(type->textureType(), layout);
1125             return this->writeInstruction(SpvOpTypeSampledImage,
1126                                           Words{Word::Result(), imageTypeId},
1127                                           fConstantBuffer);
1128         }
1129         case Type::TypeKind::kTexture: {
1130             SpvId floatTypeId = this->getType(*fContext.fTypes.fFloat, layout);
1131             int sampled = (type->textureAccess() == Type::TextureAccess::kSample) ? 1 : 2;
1132             return this->writeInstruction(SpvOpTypeImage,
1133                                           Words{Word::Result(),
1134                                                 floatTypeId,
1135                                                 Word::Number(type->dimensions()),
1136                                                 Word::Number(type->isDepth()),
1137                                                 Word::Number(type->isArrayedTexture()),
1138                                                 Word::Number(type->isMultisampled()),
1139                                                 Word::Number(sampled),
1140                                                 SpvImageFormatUnknown},
1141                                           fConstantBuffer);
1142         }
1143         default: {
1144             SkDEBUGFAILF("invalid type: %s", type->description().c_str());
1145             return NA;
1146         }
1147     }
1148 }
1149 
getFunctionType(const FunctionDeclaration & function)1150 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
1151     Words words;
1152     words.push_back(Word::Result());
1153     words.push_back(this->getType(function.returnType()));
1154     for (const Variable* parameter : function.parameters()) {
1155         if (parameter->type().typeKind() == Type::TypeKind::kSampler &&
1156             fProgram.fConfig->fSettings.fSPIRVDawnCompatMode) {
1157             words.push_back(this->getFunctionParameterType(parameter->type().textureType()));
1158             words.push_back(this->getFunctionParameterType(*fContext.fTypes.fSampler));
1159         } else {
1160             words.push_back(this->getFunctionParameterType(parameter->type()));
1161         }
1162     }
1163     return this->writeInstruction(SpvOpTypeFunction, words, fConstantBuffer);
1164 }
1165 
getFunctionParameterType(const Type & parameterType)1166 SpvId SPIRVCodeGenerator::getFunctionParameterType(const Type& parameterType) {
1167     // glslang treats all function arguments as pointers whether they need to be or
1168     // not. I was initially puzzled by this until I ran bizarre failures with certain
1169     // patterns of function calls and control constructs, as exemplified by this minimal
1170     // failure case:
1171     //
1172     // void sphere(float x) {
1173     // }
1174     //
1175     // void map() {
1176     //     sphere(1.0);
1177     // }
1178     //
1179     // void main() {
1180     //     for (int i = 0; i < 1; i++) {
1181     //         map();
1182     //     }
1183     // }
1184     //
1185     // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
1186     // crashes. Making it take a float* and storing the argument in a temporary variable,
1187     // as glslang does, fixes it.
1188     //
1189     // The consensus among shader compiler authors seems to be that GPU driver generally don't
1190     // handle value-based parameters consistently. It is highly likely that they fit their
1191     // implementations to conform to glslang. We take care to do so ourselves.
1192     //
1193     // Our implementation first stores every parameter value into a function storage-class pointer
1194     // before calling a function. The exception is for opaque handle types (samplers and textures)
1195     // which must be stored in a pointer with UniformConstant storage-class. This prevents
1196     // unnecessary temporaries (becuase opaque handles are always rooted in a pointer variable),
1197     // matches glslang's behavior, and translates into WGSL more easily when targeting Dawn.
1198     SpvStorageClass_ storageClass;
1199     if (parameterType.typeKind() == Type::TypeKind::kSampler ||
1200         parameterType.typeKind() == Type::TypeKind::kSeparateSampler ||
1201         parameterType.typeKind() == Type::TypeKind::kTexture) {
1202         storageClass = SpvStorageClassUniformConstant;
1203     } else {
1204         storageClass = SpvStorageClassFunction;
1205     }
1206     return this->getPointerType(parameterType, storageClass);
1207 }
1208 
getPointerType(const Type & type,SpvStorageClass_ storageClass)1209 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
1210     return this->getPointerType(
1211             type, this->memoryLayoutForStorageClass(storageClass), storageClass);
1212 }
1213 
getPointerType(const Type & type,const MemoryLayout & layout,SpvStorageClass_ storageClass)1214 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, const MemoryLayout& layout,
1215                                          SpvStorageClass_ storageClass) {
1216     return this->writeInstruction(
1217             SpvOpTypePointer,
1218             Words{Word::Result(), Word::Number(storageClass), this->getType(type, layout)},
1219             fConstantBuffer);
1220 }
1221 
writeExpression(const Expression & expr,OutputStream & out)1222 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
1223     switch (expr.kind()) {
1224         case Expression::Kind::kBinary:
1225             return this->writeBinaryExpression(expr.as<BinaryExpression>(), out);
1226         case Expression::Kind::kConstructorArrayCast:
1227             return this->writeExpression(*expr.as<ConstructorArrayCast>().argument(), out);
1228         case Expression::Kind::kConstructorArray:
1229         case Expression::Kind::kConstructorStruct:
1230             return this->writeCompositeConstructor(expr.asAnyConstructor(), out);
1231         case Expression::Kind::kConstructorDiagonalMatrix:
1232             return this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(), out);
1233         case Expression::Kind::kConstructorMatrixResize:
1234             return this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(), out);
1235         case Expression::Kind::kConstructorScalarCast:
1236             return this->writeConstructorScalarCast(expr.as<ConstructorScalarCast>(), out);
1237         case Expression::Kind::kConstructorSplat:
1238             return this->writeConstructorSplat(expr.as<ConstructorSplat>(), out);
1239         case Expression::Kind::kConstructorCompound:
1240             return this->writeConstructorCompound(expr.as<ConstructorCompound>(), out);
1241         case Expression::Kind::kConstructorCompoundCast:
1242             return this->writeConstructorCompoundCast(expr.as<ConstructorCompoundCast>(), out);
1243         case Expression::Kind::kFieldAccess:
1244             return this->writeFieldAccess(expr.as<FieldAccess>(), out);
1245         case Expression::Kind::kFunctionCall:
1246             return this->writeFunctionCall(expr.as<FunctionCall>(), out);
1247         case Expression::Kind::kLiteral:
1248             return this->writeLiteral(expr.as<Literal>());
1249         case Expression::Kind::kPrefix:
1250             return this->writePrefixExpression(expr.as<PrefixExpression>(), out);
1251         case Expression::Kind::kPostfix:
1252             return this->writePostfixExpression(expr.as<PostfixExpression>(), out);
1253         case Expression::Kind::kSwizzle:
1254             return this->writeSwizzle(expr.as<Swizzle>(), out);
1255         case Expression::Kind::kVariableReference:
1256             return this->writeVariableReference(expr.as<VariableReference>(), out);
1257         case Expression::Kind::kTernary:
1258             return this->writeTernaryExpression(expr.as<TernaryExpression>(), out);
1259         case Expression::Kind::kIndex:
1260             return this->writeIndexExpression(expr.as<IndexExpression>(), out);
1261         case Expression::Kind::kSetting:
1262             return this->writeExpression(*expr.as<Setting>().toLiteral(fContext), out);
1263         default:
1264             SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
1265             break;
1266     }
1267     return NA;
1268 }
1269 
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)1270 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
1271     const FunctionDeclaration& function = c.function();
1272     Intrinsic intrinsic = this->getIntrinsic(function.intrinsicKind());
1273     if (intrinsic.opKind == kInvalid_IntrinsicOpcodeKind) {
1274         fContext.fErrors->error(c.fPosition, "unsupported intrinsic '" + function.description() +
1275                 "'");
1276         return NA;
1277     }
1278     const ExpressionArray& arguments = c.arguments();
1279     int32_t intrinsicId = intrinsic.floatOp;
1280     if (arguments.size() > 0) {
1281         const Type& type = arguments[0]->type();
1282         if (intrinsic.opKind == kSpecial_IntrinsicOpcodeKind) {
1283             // Keep the default float op.
1284         } else {
1285             intrinsicId = pick_by_type(type, intrinsic.floatOp, intrinsic.signedOp,
1286                                        intrinsic.unsignedOp, intrinsic.boolOp);
1287         }
1288     }
1289     switch (intrinsic.opKind) {
1290         case kGLSL_STD_450_IntrinsicOpcodeKind: {
1291             SpvId result = this->nextId(&c.type());
1292             SkTArray<SpvId> argumentIds;
1293             std::vector<TempVar> tempVars;
1294             argumentIds.reserve_back(arguments.size());
1295             for (int i = 0; i < arguments.size(); i++) {
1296                 argumentIds.push_back(this->writeFunctionCallArgument(c, i, &tempVars, out));
1297             }
1298             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
1299             this->writeWord(this->getType(c.type()), out);
1300             this->writeWord(result, out);
1301             this->writeWord(fGLSLExtendedInstructions, out);
1302             this->writeWord(intrinsicId, out);
1303             for (SpvId id : argumentIds) {
1304                 this->writeWord(id, out);
1305             }
1306             this->copyBackTempVars(tempVars, out);
1307             return result;
1308         }
1309         case kSPIRV_IntrinsicOpcodeKind: {
1310             // GLSL supports dot(float, float), but SPIR-V does not. Convert it to FMul
1311             if (intrinsicId == SpvOpDot && arguments[0]->type().isScalar()) {
1312                 intrinsicId = SpvOpFMul;
1313             }
1314             SpvId result = this->nextId(&c.type());
1315             SkTArray<SpvId> argumentIds;
1316             std::vector<TempVar> tempVars;
1317             argumentIds.reserve_back(arguments.size());
1318             for (int i = 0; i < arguments.size(); i++) {
1319                 argumentIds.push_back(this->writeFunctionCallArgument(c, i, &tempVars, out));
1320             }
1321             if (!c.type().isVoid()) {
1322                 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
1323                 this->writeWord(this->getType(c.type()), out);
1324                 this->writeWord(result, out);
1325             } else {
1326                 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
1327             }
1328             for (SpvId id : argumentIds) {
1329                 this->writeWord(id, out);
1330             }
1331             this->copyBackTempVars(tempVars, out);
1332             return result;
1333         }
1334         case kSpecial_IntrinsicOpcodeKind:
1335             return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
1336         default:
1337             fContext.fErrors->error(c.fPosition, "unsupported intrinsic '" +
1338                     function.description() + "'");
1339             return NA;
1340     }
1341 }
1342 
vectorize(const Expression & arg,int vectorSize,OutputStream & out)1343 SpvId SPIRVCodeGenerator::vectorize(const Expression& arg, int vectorSize, OutputStream& out) {
1344     SkASSERT(vectorSize >= 1 && vectorSize <= 4);
1345     const Type& argType = arg.type();
1346     if (argType.isScalar() && vectorSize > 1) {
1347         ConstructorSplat splat{arg.fPosition,
1348                                argType.toCompound(fContext, vectorSize, /*rows=*/1),
1349                                arg.clone()};
1350         return this->writeConstructorSplat(splat, out);
1351     }
1352 
1353     SkASSERT(vectorSize == argType.columns());
1354     return this->writeExpression(arg, out);
1355 }
1356 
vectorize(const ExpressionArray & args,OutputStream & out)1357 SkTArray<SpvId> SPIRVCodeGenerator::vectorize(const ExpressionArray& args, OutputStream& out) {
1358     int vectorSize = 1;
1359     for (const auto& a : args) {
1360         if (a->type().isVector()) {
1361             if (vectorSize > 1) {
1362                 SkASSERT(a->type().columns() == vectorSize);
1363             } else {
1364                 vectorSize = a->type().columns();
1365             }
1366         }
1367     }
1368     SkTArray<SpvId> result;
1369     result.reserve_back(args.size());
1370     for (const auto& arg : args) {
1371         result.push_back(this->vectorize(*arg, vectorSize, out));
1372     }
1373     return result;
1374 }
1375 
writeGLSLExtendedInstruction(const Type & type,SpvId id,SpvId floatInst,SpvId signedInst,SpvId unsignedInst,const SkTArray<SpvId> & args,OutputStream & out)1376 void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
1377                                                       SpvId signedInst, SpvId unsignedInst,
1378                                                       const SkTArray<SpvId>& args,
1379                                                       OutputStream& out) {
1380     this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
1381     this->writeWord(this->getType(type), out);
1382     this->writeWord(id, out);
1383     this->writeWord(fGLSLExtendedInstructions, out);
1384     this->writeWord(pick_by_type(type, floatInst, signedInst, unsignedInst, NA), out);
1385     for (SpvId a : args) {
1386         this->writeWord(a, out);
1387     }
1388 }
1389 
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)1390 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
1391                                                 OutputStream& out) {
1392     const ExpressionArray& arguments = c.arguments();
1393     const Type& callType = c.type();
1394     SpvId result = this->nextId(nullptr);
1395     switch (kind) {
1396         case kAtan_SpecialIntrinsic: {
1397             SkSTArray<2, SpvId> argumentIds;
1398             for (const std::unique_ptr<Expression>& arg : arguments) {
1399                 argumentIds.push_back(this->writeExpression(*arg, out));
1400             }
1401             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
1402             this->writeWord(this->getType(callType), out);
1403             this->writeWord(result, out);
1404             this->writeWord(fGLSLExtendedInstructions, out);
1405             this->writeWord(argumentIds.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
1406             for (SpvId id : argumentIds) {
1407                 this->writeWord(id, out);
1408             }
1409             break;
1410         }
1411         case kSampledImage_SpecialIntrinsic: {
1412             SkASSERT(arguments.size() == 2);
1413             SpvId img = this->writeExpression(*arguments[0], out);
1414             SpvId sampler = this->writeExpression(*arguments[1], out);
1415             this->writeInstruction(SpvOpSampledImage,
1416                                    this->getType(callType),
1417                                    result,
1418                                    img,
1419                                    sampler,
1420                                    out);
1421             break;
1422         }
1423         case kSubpassLoad_SpecialIntrinsic: {
1424             SpvId img = this->writeExpression(*arguments[0], out);
1425             ExpressionArray args;
1426             args.reserve_back(2);
1427             args.push_back(Literal::MakeInt(fContext, Position(), /*value=*/0));
1428             args.push_back(Literal::MakeInt(fContext, Position(), /*value=*/0));
1429             ConstructorCompound ctor(Position(), *fContext.fTypes.fInt2, std::move(args));
1430             SpvId coords = this->writeExpression(ctor, out);
1431             if (arguments.size() == 1) {
1432                 this->writeInstruction(SpvOpImageRead,
1433                                        this->getType(callType),
1434                                        result,
1435                                        img,
1436                                        coords,
1437                                        out);
1438             } else {
1439                 SkASSERT(arguments.size() == 2);
1440                 SpvId sample = this->writeExpression(*arguments[1], out);
1441                 this->writeInstruction(SpvOpImageRead,
1442                                        this->getType(callType),
1443                                        result,
1444                                        img,
1445                                        coords,
1446                                        SpvImageOperandsSampleMask,
1447                                        sample,
1448                                        out);
1449             }
1450             break;
1451         }
1452         case kTexture_SpecialIntrinsic: {
1453             SpvOp_ op = SpvOpImageSampleImplicitLod;
1454             const Type& arg1Type = arguments[1]->type();
1455             switch (arguments[0]->type().dimensions()) {
1456                 case SpvDim1D:
1457                     if (arg1Type.matches(*fContext.fTypes.fFloat2)) {
1458                         op = SpvOpImageSampleProjImplicitLod;
1459                     } else {
1460                         SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat));
1461                     }
1462                     break;
1463                 case SpvDim2D:
1464                     if (arg1Type.matches(*fContext.fTypes.fFloat3)) {
1465                         op = SpvOpImageSampleProjImplicitLod;
1466                     } else {
1467                         SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat2));
1468                     }
1469                     break;
1470                 case SpvDim3D:
1471                     if (arg1Type.matches(*fContext.fTypes.fFloat4)) {
1472                         op = SpvOpImageSampleProjImplicitLod;
1473                     } else {
1474                         SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat3));
1475                     }
1476                     break;
1477                 case SpvDimCube:   // fall through
1478                 case SpvDimRect:   // fall through
1479                 case SpvDimBuffer: // fall through
1480                 case SpvDimSubpassData:
1481                     break;
1482             }
1483             SpvId type = this->getType(callType);
1484             SpvId sampler = this->writeExpression(*arguments[0], out);
1485             SpvId uv = this->writeExpression(*arguments[1], out);
1486             if (arguments.size() == 3) {
1487                 this->writeInstruction(op, type, result, sampler, uv,
1488                                        SpvImageOperandsBiasMask,
1489                                        this->writeExpression(*arguments[2], out),
1490                                        out);
1491             } else {
1492                 SkASSERT(arguments.size() == 2);
1493                 if (fProgram.fConfig->fSettings.fSharpenTextures) {
1494                     SpvId lodBias = this->writeLiteral(kSharpenTexturesBias,
1495                                                        *fContext.fTypes.fFloat);
1496                     this->writeInstruction(op, type, result, sampler, uv,
1497                                            SpvImageOperandsBiasMask, lodBias, out);
1498                 } else {
1499                     this->writeInstruction(op, type, result, sampler, uv,
1500                                            out);
1501                 }
1502             }
1503             break;
1504         }
1505         case kTextureGrad_SpecialIntrinsic: {
1506             SpvOp_ op = SpvOpImageSampleExplicitLod;
1507             SkASSERT(arguments.size() == 4);
1508             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
1509             SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fFloat2));
1510             SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fFloat2));
1511             SkASSERT(arguments[3]->type().matches(*fContext.fTypes.fFloat2));
1512             SpvId type = this->getType(callType);
1513             SpvId sampler = this->writeExpression(*arguments[0], out);
1514             SpvId uv = this->writeExpression(*arguments[1], out);
1515             SpvId dPdx = this->writeExpression(*arguments[2], out);
1516             SpvId dPdy = this->writeExpression(*arguments[3], out);
1517             this->writeInstruction(op, type, result, sampler, uv, SpvImageOperandsGradMask,
1518                                    dPdx, dPdy, out);
1519             break;
1520         }
1521         case kTextureLod_SpecialIntrinsic: {
1522             SpvOp_ op = SpvOpImageSampleExplicitLod;
1523             SkASSERT(arguments.size() == 3);
1524             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
1525             SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fFloat));
1526             const Type& arg1Type = arguments[1]->type();
1527             if (arg1Type.matches(*fContext.fTypes.fFloat3)) {
1528                 op = SpvOpImageSampleProjExplicitLod;
1529             } else {
1530                 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat2));
1531             }
1532             SpvId type = this->getType(callType);
1533             SpvId sampler = this->writeExpression(*arguments[0], out);
1534             SpvId uv = this->writeExpression(*arguments[1], out);
1535             this->writeInstruction(op, type, result, sampler, uv,
1536                                    SpvImageOperandsLodMask,
1537                                    this->writeExpression(*arguments[2], out),
1538                                    out);
1539             break;
1540         }
1541         case kMod_SpecialIntrinsic: {
1542             SkTArray<SpvId> args = this->vectorize(arguments, out);
1543             SkASSERT(args.size() == 2);
1544             const Type& operandType = arguments[0]->type();
1545             SpvOp_ op = pick_by_type(operandType, SpvOpFMod, SpvOpSMod, SpvOpUMod, SpvOpUndef);
1546             SkASSERT(op != SpvOpUndef);
1547             this->writeOpCode(op, 5, out);
1548             this->writeWord(this->getType(operandType), out);
1549             this->writeWord(result, out);
1550             this->writeWord(args[0], out);
1551             this->writeWord(args[1], out);
1552             break;
1553         }
1554         case kDFdy_SpecialIntrinsic: {
1555             SpvId fn = this->writeExpression(*arguments[0], out);
1556             this->writeOpCode(SpvOpDPdy, 4, out);
1557             this->writeWord(this->getType(callType), out);
1558             this->writeWord(result, out);
1559             this->writeWord(fn, out);
1560             if (!fProgram.fConfig->fSettings.fForceNoRTFlip) {
1561                 this->addRTFlipUniform(c.fPosition);
1562                 using namespace dsl;
1563                 DSLExpression rtFlip(
1564                         ThreadContext::Compiler().convertIdentifier(Position(), SKSL_RTFLIP_NAME));
1565                 SpvId rtFlipY = this->vectorize(*rtFlip.y().release(), callType.columns(), out);
1566                 SpvId flipped = this->nextId(&callType);
1567                 this->writeInstruction(
1568                         SpvOpFMul, this->getType(callType), flipped, result, rtFlipY, out);
1569                 result = flipped;
1570             }
1571             break;
1572         }
1573         case kClamp_SpecialIntrinsic: {
1574             SkTArray<SpvId> args = this->vectorize(arguments, out);
1575             SkASSERT(args.size() == 3);
1576             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
1577                                                GLSLstd450UClamp, args, out);
1578             break;
1579         }
1580         case kMax_SpecialIntrinsic: {
1581             SkTArray<SpvId> args = this->vectorize(arguments, out);
1582             SkASSERT(args.size() == 2);
1583             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMax, GLSLstd450SMax,
1584                                                GLSLstd450UMax, args, out);
1585             break;
1586         }
1587         case kMin_SpecialIntrinsic: {
1588             SkTArray<SpvId> args = this->vectorize(arguments, out);
1589             SkASSERT(args.size() == 2);
1590             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMin, GLSLstd450SMin,
1591                                                GLSLstd450UMin, args, out);
1592             break;
1593         }
1594         case kMix_SpecialIntrinsic: {
1595             SkTArray<SpvId> args = this->vectorize(arguments, out);
1596             SkASSERT(args.size() == 3);
1597             if (arguments[2]->type().componentType().isBoolean()) {
1598                 // Use OpSelect to implement Boolean mix().
1599                 SpvId falseId     = this->writeExpression(*arguments[0], out);
1600                 SpvId trueId      = this->writeExpression(*arguments[1], out);
1601                 SpvId conditionId = this->writeExpression(*arguments[2], out);
1602                 this->writeInstruction(SpvOpSelect, this->getType(arguments[0]->type()), result,
1603                                        conditionId, trueId, falseId, out);
1604             } else {
1605                 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMix, SpvOpUndef,
1606                                                    SpvOpUndef, args, out);
1607             }
1608             break;
1609         }
1610         case kSaturate_SpecialIntrinsic: {
1611             SkASSERT(arguments.size() == 1);
1612             ExpressionArray finalArgs;
1613             finalArgs.reserve_back(3);
1614             finalArgs.push_back(arguments[0]->clone());
1615             finalArgs.push_back(Literal::MakeFloat(fContext, Position(), /*value=*/0));
1616             finalArgs.push_back(Literal::MakeFloat(fContext, Position(), /*value=*/1));
1617             SkTArray<SpvId> spvArgs = this->vectorize(finalArgs, out);
1618             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
1619                                                GLSLstd450UClamp, spvArgs, out);
1620             break;
1621         }
1622         case kSmoothStep_SpecialIntrinsic: {
1623             SkTArray<SpvId> args = this->vectorize(arguments, out);
1624             SkASSERT(args.size() == 3);
1625             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450SmoothStep, SpvOpUndef,
1626                                                SpvOpUndef, args, out);
1627             break;
1628         }
1629         case kStep_SpecialIntrinsic: {
1630             SkTArray<SpvId> args = this->vectorize(arguments, out);
1631             SkASSERT(args.size() == 2);
1632             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450Step, SpvOpUndef,
1633                                                SpvOpUndef, args, out);
1634             break;
1635         }
1636         case kMatrixCompMult_SpecialIntrinsic: {
1637             SkASSERT(arguments.size() == 2);
1638             SpvId lhs = this->writeExpression(*arguments[0], out);
1639             SpvId rhs = this->writeExpression(*arguments[1], out);
1640             result = this->writeComponentwiseMatrixBinary(callType, lhs, rhs, SpvOpFMul, out);
1641             break;
1642         }
1643     }
1644     return result;
1645 }
1646 
writeFunctionCallArgument(const FunctionCall & call,int argIndex,std::vector<TempVar> * tempVars,OutputStream & out,SpvId * outSynthesizedSamplerId)1647 SpvId SPIRVCodeGenerator::writeFunctionCallArgument(const FunctionCall& call,
1648                                                     int argIndex,
1649                                                     std::vector<TempVar>* tempVars,
1650                                                     OutputStream& out,
1651                                                     SpvId* outSynthesizedSamplerId) {
1652     const FunctionDeclaration& funcDecl = call.function();
1653     const Expression& arg = *call.arguments()[argIndex];
1654     const Modifiers& paramModifiers = funcDecl.parameters()[argIndex]->modifiers();
1655 
1656     // ID of temporary variable that we will use to hold this argument, or 0 if it is being
1657     // passed directly
1658     SpvId tmpVar;
1659     // if we need a temporary var to store this argument, this is the value to store in the var
1660     SpvId tmpValueId = NA;
1661 
1662     if (is_out(paramModifiers)) {
1663         std::unique_ptr<LValue> lv = this->getLValue(arg, out);
1664         // We handle out params with a temp var that we copy back to the original variable at the
1665         // end of the call. GLSL guarantees that the original variable will be unchanged until the
1666         // end of the call, and also that out params are written back to their original variables in
1667         // a specific order (left-to-right), so it's unsafe to pass a pointer to the original value.
1668         if (is_in(paramModifiers)) {
1669             tmpValueId = lv->load(out);
1670         }
1671         tmpVar = this->nextId(&arg.type());
1672         tempVars->push_back(TempVar{tmpVar, &arg.type(), std::move(lv)});
1673     } else if (funcDecl.isIntrinsic()) {
1674         // Unlike user function calls, non-out intrinsic arguments don't need pointer parameters.
1675         return this->writeExpression(arg, out);
1676     } else if (arg.is<VariableReference>() &&
1677                (arg.type().typeKind() == Type::TypeKind::kSampler ||
1678                 arg.type().typeKind() == Type::TypeKind::kSeparateSampler ||
1679                 arg.type().typeKind() == Type::TypeKind::kTexture)) {
1680         // Opaque handle (sampler/texture) arguments are always declared as pointers but never
1681         // stored in intermediates when calling user-defined functions.
1682         //
1683         // The case for intrinsics (which take opaque arguments by value) is handled above just like
1684         // regular pointers.
1685         //
1686         // See getFunctionParameterType for further explanation.
1687         const Variable* var = arg.as<VariableReference>().variable();
1688 
1689         // In Dawn-mode the texture and sampler arguments are forwarded to the helper function.
1690         if (const auto* p = fSynthesizedSamplerMap.find(var)) {
1691             SkASSERT(fProgram.fConfig->fSettings.fSPIRVDawnCompatMode);
1692             SkASSERT(arg.type().typeKind() == Type::TypeKind::kSampler);
1693             SkASSERT(outSynthesizedSamplerId);
1694 
1695             SpvId* img = fVariableMap.find((*p)->fTexture.get());
1696             SpvId* sampler = fVariableMap.find((*p)->fSampler.get());
1697             SkASSERT(img);
1698             SkASSERT(sampler);
1699 
1700             *outSynthesizedSamplerId = *sampler;
1701             return *img;
1702         }
1703 
1704         SpvId* entry = fVariableMap.find(var);
1705         SkASSERTF(entry, "%s", arg.description().c_str());
1706         return *entry;
1707     } else {
1708         // We always use pointer parameters when calling user functions.
1709         // See getFunctionParameterType for further explanation.
1710         tmpValueId = this->writeExpression(arg, out);
1711         tmpVar = this->nextId(nullptr);
1712     }
1713     this->writeInstruction(SpvOpVariable,
1714                            this->getPointerType(arg.type(), SpvStorageClassFunction),
1715                            tmpVar,
1716                            SpvStorageClassFunction,
1717                            fVariableBuffer);
1718     if (tmpValueId != NA) {
1719         this->writeOpStore(SpvStorageClassFunction, tmpVar, tmpValueId, out);
1720     }
1721     return tmpVar;
1722 }
1723 
copyBackTempVars(const std::vector<TempVar> & tempVars,OutputStream & out)1724 void SPIRVCodeGenerator::copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out) {
1725     for (const TempVar& tempVar : tempVars) {
1726         SpvId load = this->nextId(tempVar.type);
1727         this->writeInstruction(SpvOpLoad, this->getType(*tempVar.type), load, tempVar.spvId, out);
1728         tempVar.lvalue->store(load, out);
1729     }
1730 }
1731 
writeFunctionCall(const FunctionCall & c,OutputStream & out)1732 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
1733     const FunctionDeclaration& function = c.function();
1734     if (function.isIntrinsic() && !function.definition()) {
1735         return this->writeIntrinsicCall(c, out);
1736     }
1737     const ExpressionArray& arguments = c.arguments();
1738     SpvId* entry = fFunctionMap.find(&function);
1739     if (!entry) {
1740         fContext.fErrors->error(c.fPosition, "function '" + function.description() +
1741                 "' is not defined");
1742         return NA;
1743     }
1744     // Temp variables are used to write back out-parameters after the function call is complete.
1745     std::vector<TempVar> tempVars;
1746     SkTArray<SpvId> argumentIds;
1747     argumentIds.reserve_back(arguments.size());
1748     for (int i = 0; i < arguments.size(); i++) {
1749         SpvId samplerId = NA;
1750         argumentIds.push_back(this->writeFunctionCallArgument(c, i, &tempVars, out, &samplerId));
1751         if (samplerId != NA) {
1752             argumentIds.push_back(samplerId);
1753         }
1754     }
1755     SpvId result = this->nextId(nullptr);
1756     this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t)argumentIds.size(), out);
1757     this->writeWord(this->getType(c.type()), out);
1758     this->writeWord(result, out);
1759     this->writeWord(*entry, out);
1760     for (SpvId id : argumentIds) {
1761         this->writeWord(id, out);
1762     }
1763     // Now that the call is complete, we copy temp out-variables back to their real lvalues.
1764     this->copyBackTempVars(tempVars, out);
1765     return result;
1766 }
1767 
castScalarToType(SpvId inputExprId,const Type & inputType,const Type & outputType,OutputStream & out)1768 SpvId SPIRVCodeGenerator::castScalarToType(SpvId inputExprId,
1769                                            const Type& inputType,
1770                                            const Type& outputType,
1771                                            OutputStream& out) {
1772     if (outputType.isFloat()) {
1773         return this->castScalarToFloat(inputExprId, inputType, outputType, out);
1774     }
1775     if (outputType.isSigned()) {
1776         return this->castScalarToSignedInt(inputExprId, inputType, outputType, out);
1777     }
1778     if (outputType.isUnsigned()) {
1779         return this->castScalarToUnsignedInt(inputExprId, inputType, outputType, out);
1780     }
1781     if (outputType.isBoolean()) {
1782         return this->castScalarToBoolean(inputExprId, inputType, outputType, out);
1783     }
1784 
1785     fContext.fErrors->error(Position(), "unsupported cast: " + inputType.description() + " to " +
1786             outputType.description());
1787     return inputExprId;
1788 }
1789 
writeFloatConstructor(const AnyConstructor & c,OutputStream & out)1790 SpvId SPIRVCodeGenerator::writeFloatConstructor(const AnyConstructor& c, OutputStream& out) {
1791     SkASSERT(c.argumentSpan().size() == 1);
1792     SkASSERT(c.type().isFloat());
1793     const Expression& ctorExpr = *c.argumentSpan().front();
1794     SpvId expressionId = this->writeExpression(ctorExpr, out);
1795     return this->castScalarToFloat(expressionId, ctorExpr.type(), c.type(), out);
1796 }
1797 
castScalarToFloat(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)1798 SpvId SPIRVCodeGenerator::castScalarToFloat(SpvId inputId, const Type& inputType,
1799                                             const Type& outputType, OutputStream& out) {
1800     // Casting a float to float is a no-op.
1801     if (inputType.isFloat()) {
1802         return inputId;
1803     }
1804 
1805     // Given the input type, generate the appropriate instruction to cast to float.
1806     SpvId result = this->nextId(&outputType);
1807     if (inputType.isBoolean()) {
1808         // Use OpSelect to convert the boolean argument to a literal 1.0 or 0.0.
1809         const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fFloat);
1810         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
1811         this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
1812                                inputId, oneID, zeroID, out);
1813     } else if (inputType.isSigned()) {
1814         this->writeInstruction(SpvOpConvertSToF, this->getType(outputType), result, inputId, out);
1815     } else if (inputType.isUnsigned()) {
1816         this->writeInstruction(SpvOpConvertUToF, this->getType(outputType), result, inputId, out);
1817     } else {
1818         SkDEBUGFAILF("unsupported type for float typecast: %s", inputType.description().c_str());
1819         return NA;
1820     }
1821     return result;
1822 }
1823 
writeIntConstructor(const AnyConstructor & c,OutputStream & out)1824 SpvId SPIRVCodeGenerator::writeIntConstructor(const AnyConstructor& c, OutputStream& out) {
1825     SkASSERT(c.argumentSpan().size() == 1);
1826     SkASSERT(c.type().isSigned());
1827     const Expression& ctorExpr = *c.argumentSpan().front();
1828     SpvId expressionId = this->writeExpression(ctorExpr, out);
1829     return this->castScalarToSignedInt(expressionId, ctorExpr.type(), c.type(), out);
1830 }
1831 
castScalarToSignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)1832 SpvId SPIRVCodeGenerator::castScalarToSignedInt(SpvId inputId, const Type& inputType,
1833                                                 const Type& outputType, OutputStream& out) {
1834     // Casting a signed int to signed int is a no-op.
1835     if (inputType.isSigned()) {
1836         return inputId;
1837     }
1838 
1839     // Given the input type, generate the appropriate instruction to cast to signed int.
1840     SpvId result = this->nextId(&outputType);
1841     if (inputType.isBoolean()) {
1842         // Use OpSelect to convert the boolean argument to a literal 1 or 0.
1843         const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fInt);
1844         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
1845         this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
1846                                inputId, oneID, zeroID, out);
1847     } else if (inputType.isFloat()) {
1848         this->writeInstruction(SpvOpConvertFToS, this->getType(outputType), result, inputId, out);
1849     } else if (inputType.isUnsigned()) {
1850         this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
1851     } else {
1852         SkDEBUGFAILF("unsupported type for signed int typecast: %s",
1853                      inputType.description().c_str());
1854         return NA;
1855     }
1856     return result;
1857 }
1858 
writeUIntConstructor(const AnyConstructor & c,OutputStream & out)1859 SpvId SPIRVCodeGenerator::writeUIntConstructor(const AnyConstructor& c, OutputStream& out) {
1860     SkASSERT(c.argumentSpan().size() == 1);
1861     SkASSERT(c.type().isUnsigned());
1862     const Expression& ctorExpr = *c.argumentSpan().front();
1863     SpvId expressionId = this->writeExpression(ctorExpr, out);
1864     return this->castScalarToUnsignedInt(expressionId, ctorExpr.type(), c.type(), out);
1865 }
1866 
castScalarToUnsignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)1867 SpvId SPIRVCodeGenerator::castScalarToUnsignedInt(SpvId inputId, const Type& inputType,
1868                                                   const Type& outputType, OutputStream& out) {
1869     // Casting an unsigned int to unsigned int is a no-op.
1870     if (inputType.isUnsigned()) {
1871         return inputId;
1872     }
1873 
1874     // Given the input type, generate the appropriate instruction to cast to unsigned int.
1875     SpvId result = this->nextId(&outputType);
1876     if (inputType.isBoolean()) {
1877         // Use OpSelect to convert the boolean argument to a literal 1u or 0u.
1878         const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fUInt);
1879         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
1880         this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
1881                                inputId, oneID, zeroID, out);
1882     } else if (inputType.isFloat()) {
1883         this->writeInstruction(SpvOpConvertFToU, this->getType(outputType), result, inputId, out);
1884     } else if (inputType.isSigned()) {
1885         this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
1886     } else {
1887         SkDEBUGFAILF("unsupported type for unsigned int typecast: %s",
1888                      inputType.description().c_str());
1889         return NA;
1890     }
1891     return result;
1892 }
1893 
writeBooleanConstructor(const AnyConstructor & c,OutputStream & out)1894 SpvId SPIRVCodeGenerator::writeBooleanConstructor(const AnyConstructor& c, OutputStream& out) {
1895     SkASSERT(c.argumentSpan().size() == 1);
1896     SkASSERT(c.type().isBoolean());
1897     const Expression& ctorExpr = *c.argumentSpan().front();
1898     SpvId expressionId = this->writeExpression(ctorExpr, out);
1899     return this->castScalarToBoolean(expressionId, ctorExpr.type(), c.type(), out);
1900 }
1901 
castScalarToBoolean(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)1902 SpvId SPIRVCodeGenerator::castScalarToBoolean(SpvId inputId, const Type& inputType,
1903                                               const Type& outputType, OutputStream& out) {
1904     // Casting a bool to bool is a no-op.
1905     if (inputType.isBoolean()) {
1906         return inputId;
1907     }
1908 
1909     // Given the input type, generate the appropriate instruction to cast to bool.
1910     SpvId result = this->nextId(nullptr);
1911     if (inputType.isSigned()) {
1912         // Synthesize a boolean result by comparing the input against a signed zero literal.
1913         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
1914         this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
1915                                inputId, zeroID, out);
1916     } else if (inputType.isUnsigned()) {
1917         // Synthesize a boolean result by comparing the input against an unsigned zero literal.
1918         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
1919         this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
1920                                inputId, zeroID, out);
1921     } else if (inputType.isFloat()) {
1922         // Synthesize a boolean result by comparing the input against a floating-point zero literal.
1923         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
1924         this->writeInstruction(SpvOpFUnordNotEqual, this->getType(outputType), result,
1925                                inputId, zeroID, out);
1926     } else {
1927         SkDEBUGFAILF("unsupported type for boolean typecast: %s", inputType.description().c_str());
1928         return NA;
1929     }
1930     return result;
1931 }
1932 
writeMatrixCopy(SpvId src,const Type & srcType,const Type & dstType,OutputStream & out)1933 SpvId SPIRVCodeGenerator::writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType,
1934                                           OutputStream& out) {
1935     SkASSERT(srcType.isMatrix());
1936     SkASSERT(dstType.isMatrix());
1937     SkASSERT(srcType.componentType().matches(dstType.componentType()));
1938     const Type& srcColumnType = srcType.componentType().toCompound(fContext, srcType.rows(), 1);
1939     const Type& dstColumnType = dstType.componentType().toCompound(fContext, dstType.rows(), 1);
1940     SkASSERT(dstType.componentType().isFloat());
1941     SpvId dstColumnTypeId = this->getType(dstColumnType);
1942     const SpvId zeroId = this->writeLiteral(0.0, dstType.componentType());
1943     const SpvId oneId = this->writeLiteral(1.0, dstType.componentType());
1944 
1945     SkSTArray<4, SpvId> columns;
1946     for (int i = 0; i < dstType.columns(); i++) {
1947         if (i < srcType.columns()) {
1948             // we're still inside the src matrix, copy the column
1949             SpvId srcColumn = this->writeOpCompositeExtract(srcColumnType, src, i, out);
1950             SpvId dstColumn;
1951             if (srcType.rows() == dstType.rows()) {
1952                 // columns are equal size, don't need to do anything
1953                 dstColumn = srcColumn;
1954             }
1955             else if (dstType.rows() > srcType.rows()) {
1956                 // dst column is bigger, need to zero-pad it
1957                 SkSTArray<4, SpvId> values;
1958                 values.push_back(srcColumn);
1959                 for (int j = srcType.rows(); j < dstType.rows(); ++j) {
1960                     values.push_back((i == j) ? oneId : zeroId);
1961                 }
1962                 dstColumn = this->writeOpCompositeConstruct(dstColumnType, values, out);
1963             }
1964             else {
1965                 // dst column is smaller, need to swizzle the src column
1966                 dstColumn = this->nextId(&dstType);
1967                 this->writeOpCode(SpvOpVectorShuffle, 5 + dstType.rows(), out);
1968                 this->writeWord(dstColumnTypeId, out);
1969                 this->writeWord(dstColumn, out);
1970                 this->writeWord(srcColumn, out);
1971                 this->writeWord(srcColumn, out);
1972                 for (int j = 0; j < dstType.rows(); j++) {
1973                     this->writeWord(j, out);
1974                 }
1975             }
1976             columns.push_back(dstColumn);
1977         } else {
1978             // we're past the end of the src matrix, need to synthesize an identity-matrix column
1979             SkSTArray<4, SpvId> values;
1980             for (int j = 0; j < dstType.rows(); ++j) {
1981                 values.push_back((i == j) ? oneId : zeroId);
1982             }
1983             columns.push_back(this->writeOpCompositeConstruct(dstColumnType, values, out));
1984         }
1985     }
1986 
1987     return this->writeOpCompositeConstruct(dstType, columns, out);
1988 }
1989 
addColumnEntry(const Type & columnType,SkTArray<SpvId> * currentColumn,SkTArray<SpvId> * columnIds,int rows,SpvId entry,OutputStream & out)1990 void SPIRVCodeGenerator::addColumnEntry(const Type& columnType,
1991                                         SkTArray<SpvId>* currentColumn,
1992                                         SkTArray<SpvId>* columnIds,
1993                                         int rows,
1994                                         SpvId entry,
1995                                         OutputStream& out) {
1996     SkASSERT(currentColumn->size() < rows);
1997     currentColumn->push_back(entry);
1998     if (currentColumn->size() == rows) {
1999         // Synthesize this column into a vector.
2000         SpvId columnId = this->writeOpCompositeConstruct(columnType, *currentColumn, out);
2001         columnIds->push_back(columnId);
2002         currentColumn->clear();
2003     }
2004 }
2005 
writeMatrixConstructor(const ConstructorCompound & c,OutputStream & out)2006 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out) {
2007     const Type& type = c.type();
2008     SkASSERT(type.isMatrix());
2009     SkASSERT(!c.arguments().empty());
2010     const Type& arg0Type = c.arguments()[0]->type();
2011     // go ahead and write the arguments so we don't try to write new instructions in the middle of
2012     // an instruction
2013     SkSTArray<16, SpvId> arguments;
2014     for (const std::unique_ptr<Expression>& arg : c.arguments()) {
2015         arguments.push_back(this->writeExpression(*arg, out));
2016     }
2017 
2018     if (arguments.size() == 1 && arg0Type.isVector()) {
2019         // Special-case handling of float4 -> mat2x2.
2020         SkASSERT(type.rows() == 2 && type.columns() == 2);
2021         SkASSERT(arg0Type.columns() == 4);
2022         SpvId v[4];
2023         for (int i = 0; i < 4; ++i) {
2024             v[i] = this->writeOpCompositeExtract(type.componentType(), arguments[0], i, out);
2025         }
2026         const Type& vecType = type.componentType().toCompound(fContext, /*columns=*/2, /*rows=*/1);
2027         SpvId v0v1 = this->writeOpCompositeConstruct(vecType, {v[0], v[1]}, out);
2028         SpvId v2v3 = this->writeOpCompositeConstruct(vecType, {v[2], v[3]}, out);
2029         return this->writeOpCompositeConstruct(type, {v0v1, v2v3}, out);
2030     }
2031 
2032     int rows = type.rows();
2033     const Type& columnType = type.componentType().toCompound(fContext,
2034                                                              /*columns=*/rows, /*rows=*/1);
2035     // SpvIds of completed columns of the matrix.
2036     SkSTArray<4, SpvId> columnIds;
2037     // SpvIds of scalars we have written to the current column so far.
2038     SkSTArray<4, SpvId> currentColumn;
2039     for (int i = 0; i < arguments.size(); i++) {
2040         const Type& argType = c.arguments()[i]->type();
2041         if (currentColumn.empty() && argType.isVector() && argType.columns() == rows) {
2042             // This vector is a complete matrix column by itself and can be used as-is.
2043             columnIds.push_back(arguments[i]);
2044         } else if (argType.columns() == 1) {
2045             // This argument is a lone scalar and can be added to the current column as-is.
2046             this->addColumnEntry(columnType, &currentColumn, &columnIds, rows, arguments[i], out);
2047         } else {
2048             // This argument needs to be decomposed into its constituent scalars.
2049             for (int j = 0; j < argType.columns(); ++j) {
2050                 SpvId swizzle = this->writeOpCompositeExtract(argType.componentType(),
2051                                                               arguments[i], j, out);
2052                 this->addColumnEntry(columnType, &currentColumn, &columnIds, rows, swizzle, out);
2053             }
2054         }
2055     }
2056     SkASSERT(columnIds.size() == type.columns());
2057     return this->writeOpCompositeConstruct(type, columnIds, out);
2058 }
2059 
writeConstructorCompound(const ConstructorCompound & c,OutputStream & out)2060 SpvId SPIRVCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
2061                                                    OutputStream& out) {
2062     return c.type().isMatrix() ? this->writeMatrixConstructor(c, out)
2063                                : this->writeVectorConstructor(c, out);
2064 }
2065 
writeVectorConstructor(const ConstructorCompound & c,OutputStream & out)2066 SpvId SPIRVCodeGenerator::writeVectorConstructor(const ConstructorCompound& c, OutputStream& out) {
2067     const Type& type = c.type();
2068     const Type& componentType = type.componentType();
2069     SkASSERT(type.isVector());
2070 
2071     SkSTArray<4, SpvId> arguments;
2072     for (int i = 0; i < c.arguments().size(); i++) {
2073         const Type& argType = c.arguments()[i]->type();
2074         SkASSERT(componentType.numberKind() == argType.componentType().numberKind());
2075 
2076         SpvId arg = this->writeExpression(*c.arguments()[i], out);
2077         if (argType.isMatrix()) {
2078             // CompositeConstruct cannot take a 2x2 matrix as an input, so we need to extract out
2079             // each scalar separately.
2080             SkASSERT(argType.rows() == 2);
2081             SkASSERT(argType.columns() == 2);
2082             for (int j = 0; j < 4; ++j) {
2083                 arguments.push_back(this->writeOpCompositeExtract(componentType, arg,
2084                                                                   j / 2, j % 2, out));
2085             }
2086         } else if (argType.isVector()) {
2087             // There's a bug in the Intel Vulkan driver where OpCompositeConstruct doesn't handle
2088             // vector arguments at all, so we always extract each vector component and pass them
2089             // into OpCompositeConstruct individually.
2090             for (int j = 0; j < argType.columns(); j++) {
2091                 arguments.push_back(this->writeOpCompositeExtract(componentType, arg, j, out));
2092             }
2093         } else {
2094             arguments.push_back(arg);
2095         }
2096     }
2097 
2098     return this->writeOpCompositeConstruct(type, arguments, out);
2099 }
2100 
writeConstructorSplat(const ConstructorSplat & c,OutputStream & out)2101 SpvId SPIRVCodeGenerator::writeConstructorSplat(const ConstructorSplat& c, OutputStream& out) {
2102     // Write the splat argument.
2103     SpvId argument = this->writeExpression(*c.argument(), out);
2104 
2105     // Generate a OpCompositeConstruct which repeats the argument N times.
2106     SkSTArray<4, SpvId> values;
2107     values.push_back_n(/*n=*/c.type().columns(), /*t=*/argument);
2108     return this->writeOpCompositeConstruct(c.type(), values, out);
2109 }
2110 
writeCompositeConstructor(const AnyConstructor & c,OutputStream & out)2111 SpvId SPIRVCodeGenerator::writeCompositeConstructor(const AnyConstructor& c, OutputStream& out) {
2112     SkASSERT(c.type().isArray() || c.type().isStruct());
2113     auto ctorArgs = c.argumentSpan();
2114 
2115     SkSTArray<4, SpvId> arguments;
2116     for (const std::unique_ptr<Expression>& arg : ctorArgs) {
2117         arguments.push_back(this->writeExpression(*arg, out));
2118     }
2119 
2120     return this->writeOpCompositeConstruct(c.type(), arguments, out);
2121 }
2122 
writeConstructorScalarCast(const ConstructorScalarCast & c,OutputStream & out)2123 SpvId SPIRVCodeGenerator::writeConstructorScalarCast(const ConstructorScalarCast& c,
2124                                                      OutputStream& out) {
2125     const Type& type = c.type();
2126     if (type.componentType().numberKind() == c.argument()->type().componentType().numberKind()) {
2127         return this->writeExpression(*c.argument(), out);
2128     }
2129 
2130     const Expression& ctorExpr = *c.argument();
2131     SpvId expressionId = this->writeExpression(ctorExpr, out);
2132     return this->castScalarToType(expressionId, ctorExpr.type(), type, out);
2133 }
2134 
writeConstructorCompoundCast(const ConstructorCompoundCast & c,OutputStream & out)2135 SpvId SPIRVCodeGenerator::writeConstructorCompoundCast(const ConstructorCompoundCast& c,
2136                                                        OutputStream& out) {
2137     const Type& ctorType = c.type();
2138     const Type& argType = c.argument()->type();
2139     SkASSERT(ctorType.isVector() || ctorType.isMatrix());
2140 
2141     // Write the composite that we are casting. If the actual type matches, we are done.
2142     SpvId compositeId = this->writeExpression(*c.argument(), out);
2143     if (ctorType.componentType().numberKind() == argType.componentType().numberKind()) {
2144         return compositeId;
2145     }
2146 
2147     // writeMatrixCopy can cast matrices to a different type.
2148     if (ctorType.isMatrix()) {
2149         return this->writeMatrixCopy(compositeId, argType, ctorType, out);
2150     }
2151 
2152     // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to extract the
2153     // components and convert each one manually.
2154     const Type& srcType = argType.componentType();
2155     const Type& dstType = ctorType.componentType();
2156 
2157     SkSTArray<4, SpvId> arguments;
2158     for (int index = 0; index < argType.columns(); ++index) {
2159         SpvId componentId = this->writeOpCompositeExtract(srcType, compositeId, index, out);
2160         arguments.push_back(this->castScalarToType(componentId, srcType, dstType, out));
2161     }
2162 
2163     return this->writeOpCompositeConstruct(ctorType, arguments, out);
2164 }
2165 
writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix & c,OutputStream & out)2166 SpvId SPIRVCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
2167                                                          OutputStream& out) {
2168     const Type& type = c.type();
2169     SkASSERT(type.isMatrix());
2170     SkASSERT(c.argument()->type().isScalar());
2171 
2172     // Write out the scalar argument.
2173     SpvId diagonal = this->writeExpression(*c.argument(), out);
2174 
2175     // Build the diagonal matrix.
2176     SpvId zeroId = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
2177 
2178     const Type& vecType = type.componentType().toCompound(fContext,
2179                                                           /*columns=*/type.rows(),
2180                                                           /*rows=*/1);
2181     SkSTArray<4, SpvId> columnIds;
2182     SkSTArray<4, SpvId> arguments;
2183     arguments.resize(type.rows());
2184     for (int column = 0; column < type.columns(); column++) {
2185         for (int row = 0; row < type.rows(); row++) {
2186             arguments[row] = (row == column) ? diagonal : zeroId;
2187         }
2188         columnIds.push_back(this->writeOpCompositeConstruct(vecType, arguments, out));
2189     }
2190     return this->writeOpCompositeConstruct(type, columnIds, out);
2191 }
2192 
writeConstructorMatrixResize(const ConstructorMatrixResize & c,OutputStream & out)2193 SpvId SPIRVCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
2194                                                        OutputStream& out) {
2195     // Write the input matrix.
2196     SpvId argument = this->writeExpression(*c.argument(), out);
2197 
2198     // Use matrix-copy to resize the input matrix to its new size.
2199     return this->writeMatrixCopy(argument, c.argument()->type(), c.type(), out);
2200 }
2201 
get_storage_class_for_global_variable(const Variable & var,SpvStorageClass_ fallbackStorageClass)2202 static SpvStorageClass_ get_storage_class_for_global_variable(
2203         const Variable& var, SpvStorageClass_ fallbackStorageClass) {
2204     SkASSERT(var.storage() == Variable::Storage::kGlobal);
2205 
2206     const Modifiers& modifiers = var.modifiers();
2207     if (modifiers.fFlags & Modifiers::kIn_Flag) {
2208         SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
2209         return SpvStorageClassInput;
2210     }
2211     if (modifiers.fFlags & Modifiers::kOut_Flag) {
2212         SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
2213         return SpvStorageClassOutput;
2214     }
2215     if (modifiers.fFlags & Modifiers::kUniform_Flag) {
2216         if (modifiers.fLayout.fFlags & Layout::kPushConstant_Flag) {
2217             return SpvStorageClassPushConstant;
2218         }
2219         if (var.type().typeKind() == Type::TypeKind::kSampler ||
2220             var.type().typeKind() == Type::TypeKind::kSeparateSampler ||
2221             var.type().typeKind() == Type::TypeKind::kTexture) {
2222             return SpvStorageClassUniformConstant;
2223         }
2224         return SpvStorageClassUniform;
2225     }
2226     return fallbackStorageClass;
2227 }
2228 
get_storage_class(const Expression & expr)2229 static SpvStorageClass_ get_storage_class(const Expression& expr) {
2230     switch (expr.kind()) {
2231         case Expression::Kind::kVariableReference: {
2232             const Variable& var = *expr.as<VariableReference>().variable();
2233             if (var.storage() != Variable::Storage::kGlobal) {
2234                 return SpvStorageClassFunction;
2235             }
2236             return get_storage_class_for_global_variable(var, SpvStorageClassPrivate);
2237         }
2238         case Expression::Kind::kFieldAccess:
2239             return get_storage_class(*expr.as<FieldAccess>().base());
2240         case Expression::Kind::kIndex:
2241             return get_storage_class(*expr.as<IndexExpression>().base());
2242         default:
2243             return SpvStorageClassFunction;
2244     }
2245 }
2246 
getAccessChain(const Expression & expr,OutputStream & out)2247 SkTArray<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
2248     switch (expr.kind()) {
2249         case Expression::Kind::kIndex: {
2250             const IndexExpression& indexExpr = expr.as<IndexExpression>();
2251             SkTArray<SpvId> chain = this->getAccessChain(*indexExpr.base(), out);
2252             chain.push_back(this->writeExpression(*indexExpr.index(), out));
2253             return chain;
2254         }
2255         case Expression::Kind::kFieldAccess: {
2256             const FieldAccess& fieldExpr = expr.as<FieldAccess>();
2257             SkTArray<SpvId> chain = this->getAccessChain(*fieldExpr.base(), out);
2258             chain.push_back(this->writeLiteral(fieldExpr.fieldIndex(), *fContext.fTypes.fInt));
2259             return chain;
2260         }
2261         default: {
2262             SpvId id = this->getLValue(expr, out)->getPointer();
2263             SkASSERT(id != NA);
2264             return SkTArray<SpvId>{id};
2265         }
2266     }
2267     SkUNREACHABLE;
2268 }
2269 
2270 class PointerLValue : public SPIRVCodeGenerator::LValue {
2271 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,bool isMemoryObject,SpvId type,SPIRVCodeGenerator::Precision precision,SpvStorageClass_ storageClass)2272     PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, bool isMemoryObject, SpvId type,
2273                   SPIRVCodeGenerator::Precision precision, SpvStorageClass_ storageClass)
2274     : fGen(gen)
2275     , fPointer(pointer)
2276     , fIsMemoryObject(isMemoryObject)
2277     , fType(type)
2278     , fPrecision(precision)
2279     , fStorageClass(storageClass) {}
2280 
getPointer()2281     SpvId getPointer() override {
2282         return fPointer;
2283     }
2284 
isMemoryObjectPointer() const2285     bool isMemoryObjectPointer() const override {
2286         return fIsMemoryObject;
2287     }
2288 
load(OutputStream & out)2289     SpvId load(OutputStream& out) override {
2290         return fGen.writeOpLoad(fType, fPrecision, fPointer, out);
2291     }
2292 
store(SpvId value,OutputStream & out)2293     void store(SpvId value, OutputStream& out) override {
2294         if (!fIsMemoryObject) {
2295             // We are going to write into an access chain; this could represent one component of a
2296             // vector, or one element of an array. This has the potential to invalidate other,
2297             // *unknown* elements of our store cache. (e.g. if the store cache holds `%50 = myVec4`,
2298             // and we store `%60 = myVec4.z`, this invalidates the cached value for %50.) To avoid
2299             // relying on stale data, reset the store cache entirely when this happens.
2300             fGen.fStoreCache.reset();
2301         }
2302 
2303         fGen.writeOpStore(fStorageClass, fPointer, value, out);
2304     }
2305 
2306 private:
2307     SPIRVCodeGenerator& fGen;
2308     const SpvId fPointer;
2309     const bool fIsMemoryObject;
2310     const SpvId fType;
2311     const SPIRVCodeGenerator::Precision fPrecision;
2312     const SpvStorageClass_ fStorageClass;
2313 };
2314 
2315 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
2316 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const ComponentArray & components,const Type & baseType,const Type & swizzleType,SpvStorageClass_ storageClass)2317     SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const ComponentArray& components,
2318                   const Type& baseType, const Type& swizzleType, SpvStorageClass_ storageClass)
2319     : fGen(gen)
2320     , fVecPointer(vecPointer)
2321     , fComponents(components)
2322     , fBaseType(&baseType)
2323     , fSwizzleType(&swizzleType)
2324     , fStorageClass(storageClass) {}
2325 
applySwizzle(const ComponentArray & components,const Type & newType)2326     bool applySwizzle(const ComponentArray& components, const Type& newType) override {
2327         ComponentArray updatedSwizzle;
2328         for (int8_t component : components) {
2329             if (component < 0 || component >= fComponents.size()) {
2330                 SkDEBUGFAILF("swizzle accessed nonexistent component %d", (int)component);
2331                 return false;
2332             }
2333             updatedSwizzle.push_back(fComponents[component]);
2334         }
2335         fComponents = updatedSwizzle;
2336         fSwizzleType = &newType;
2337         return true;
2338     }
2339 
load(OutputStream & out)2340     SpvId load(OutputStream& out) override {
2341         SpvId base = fGen.nextId(fBaseType);
2342         fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
2343         SpvId result = fGen.nextId(fBaseType);
2344         fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
2345         fGen.writeWord(fGen.getType(*fSwizzleType), out);
2346         fGen.writeWord(result, out);
2347         fGen.writeWord(base, out);
2348         fGen.writeWord(base, out);
2349         for (int component : fComponents) {
2350             fGen.writeWord(component, out);
2351         }
2352         return result;
2353     }
2354 
store(SpvId value,OutputStream & out)2355     void store(SpvId value, OutputStream& out) override {
2356         // use OpVectorShuffle to mix and match the vector components. We effectively create
2357         // a virtual vector out of the concatenation of the left and right vectors, and then
2358         // select components from this virtual vector to make the result vector. For
2359         // instance, given:
2360         // float3L = ...;
2361         // float3R = ...;
2362         // L.xz = R.xy;
2363         // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
2364         // our result vector to look like (R.x, L.y, R.y), so we need to select indices
2365         // (3, 1, 4).
2366         SpvId base = fGen.nextId(fBaseType);
2367         fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
2368         SpvId shuffle = fGen.nextId(fBaseType);
2369         fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType->columns(), out);
2370         fGen.writeWord(fGen.getType(*fBaseType), out);
2371         fGen.writeWord(shuffle, out);
2372         fGen.writeWord(base, out);
2373         fGen.writeWord(value, out);
2374         for (int i = 0; i < fBaseType->columns(); i++) {
2375             // current offset into the virtual vector, defaults to pulling the unmodified
2376             // value from the left side
2377             int offset = i;
2378             // check to see if we are writing this component
2379             for (int j = 0; j < fComponents.size(); j++) {
2380                 if (fComponents[j] == i) {
2381                     // we're writing to this component, so adjust the offset to pull from
2382                     // the correct component of the right side instead of preserving the
2383                     // value from the left
2384                     offset = (int) (j + fBaseType->columns());
2385                     break;
2386                 }
2387             }
2388             fGen.writeWord(offset, out);
2389         }
2390         fGen.writeOpStore(fStorageClass, fVecPointer, shuffle, out);
2391     }
2392 
2393 private:
2394     SPIRVCodeGenerator& fGen;
2395     const SpvId fVecPointer;
2396     ComponentArray fComponents;
2397     const Type* fBaseType;
2398     const Type* fSwizzleType;
2399     const SpvStorageClass_ fStorageClass;
2400 };
2401 
findUniformFieldIndex(const Variable & var) const2402 int SPIRVCodeGenerator::findUniformFieldIndex(const Variable& var) const {
2403     int* fieldIndex = fTopLevelUniformMap.find(&var);
2404     return fieldIndex ? *fieldIndex : -1;
2405 }
2406 
getLValue(const Expression & expr,OutputStream & out)2407 std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
2408                                                                           OutputStream& out) {
2409     const Type& type = expr.type();
2410     Precision precision = type.highPrecision() ? Precision::kDefault : Precision::kRelaxed;
2411     switch (expr.kind()) {
2412         case Expression::Kind::kVariableReference: {
2413             const Variable& var = *expr.as<VariableReference>().variable();
2414             int uniformIdx = this->findUniformFieldIndex(var);
2415             if (uniformIdx >= 0) {
2416                 SpvId memberId = this->nextId(nullptr);
2417                 SpvId typeId = this->getPointerType(type, SpvStorageClassUniform);
2418                 SpvId uniformIdxId = this->writeLiteral((double)uniformIdx, *fContext.fTypes.fInt);
2419                 this->writeInstruction(SpvOpAccessChain, typeId, memberId, fUniformBufferId,
2420                                        uniformIdxId, out);
2421                 return std::make_unique<PointerLValue>(
2422                         *this,
2423                         memberId,
2424                         /*isMemoryObjectPointer=*/true,
2425                         this->getType(type, this->memoryLayoutForVariable(var)),
2426                         precision,
2427                         SpvStorageClassUniform);
2428             }
2429             SpvId typeId = this->getType(type, this->memoryLayoutForVariable(var));
2430             SpvId* entry = fVariableMap.find(&var);
2431             SkASSERTF(entry, "%s", expr.description().c_str());
2432             return std::make_unique<PointerLValue>(*this, *entry,
2433                                                    /*isMemoryObjectPointer=*/true,
2434                                                    typeId, precision, get_storage_class(expr));
2435         }
2436         case Expression::Kind::kIndex: // fall through
2437         case Expression::Kind::kFieldAccess: {
2438             SkTArray<SpvId> chain = this->getAccessChain(expr, out);
2439             SpvId member = this->nextId(nullptr);
2440             SpvStorageClass_ storageClass = get_storage_class(expr);
2441             this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
2442             this->writeWord(this->getPointerType(type, storageClass), out);
2443             this->writeWord(member, out);
2444             for (SpvId idx : chain) {
2445                 this->writeWord(idx, out);
2446             }
2447             return std::make_unique<PointerLValue>(
2448                     *this,
2449                     member,
2450                     /*isMemoryObjectPointer=*/false,
2451                     this->getType(type, this->memoryLayoutForStorageClass(storageClass)),
2452                     precision,
2453                     storageClass);
2454         }
2455         case Expression::Kind::kSwizzle: {
2456             const Swizzle& swizzle = expr.as<Swizzle>();
2457             std::unique_ptr<LValue> lvalue = this->getLValue(*swizzle.base(), out);
2458             if (lvalue->applySwizzle(swizzle.components(), type)) {
2459                 return lvalue;
2460             }
2461             SpvId base = lvalue->getPointer();
2462             if (base == NA) {
2463                 fContext.fErrors->error(swizzle.fPosition,
2464                         "unable to retrieve lvalue from swizzle");
2465             }
2466             SpvStorageClass_ storageClass = get_storage_class(*swizzle.base());
2467             if (swizzle.components().size() == 1) {
2468                 SpvId member = this->nextId(nullptr);
2469                 SpvId typeId = this->getPointerType(type, storageClass);
2470                 SpvId indexId = this->writeLiteral(swizzle.components()[0], *fContext.fTypes.fInt);
2471                 this->writeInstruction(SpvOpAccessChain, typeId, member, base, indexId, out);
2472                 return std::make_unique<PointerLValue>(*this, member,
2473                                                        /*isMemoryObjectPointer=*/false,
2474                                                        this->getType(type),
2475                                                        precision, storageClass);
2476             } else {
2477                 return std::make_unique<SwizzleLValue>(*this, base, swizzle.components(),
2478                                                        swizzle.base()->type(), type, storageClass);
2479             }
2480         }
2481         default: {
2482             // expr isn't actually an lvalue, create a placeholder variable for it. This case
2483             // happens due to the need to store values in temporary variables during function
2484             // calls (see comments in getFunctionParameterType); erroneous uses of rvalues as
2485             // lvalues should have been caught before code generation.
2486             //
2487             // This is with the exception of opaque handle types (textures/samplers) which are
2488             // always defined as UniformConstant pointers and don't need to be explicitly stored
2489             // into a temporary (which is handled explicitly in writeFunctionCallArgument).
2490             SpvId result = this->nextId(nullptr);
2491             SpvId pointerType = this->getPointerType(type, SpvStorageClassFunction);
2492             this->writeInstruction(SpvOpVariable, pointerType, result, SpvStorageClassFunction,
2493                                    fVariableBuffer);
2494             this->writeOpStore(SpvStorageClassFunction, result, this->writeExpression(expr, out),
2495                                out);
2496             return std::make_unique<PointerLValue>(*this, result, /*isMemoryObjectPointer=*/true,
2497                                                    this->getType(type), precision,
2498                                                    SpvStorageClassFunction);
2499         }
2500     }
2501 }
2502 
writeVariableReference(const VariableReference & ref,OutputStream & out)2503 SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
2504     const Variable* variable = ref.variable();
2505     switch (variable->modifiers().fLayout.fBuiltin) {
2506         case DEVICE_FRAGCOORDS_BUILTIN: {
2507             // Down below, we rewrite raw references to sk_FragCoord with expressions that reference
2508             // DEVICE_FRAGCOORDS_BUILTIN. This is a fake variable that means we need to directly
2509             // access the fragcoord; do so now.
2510             dsl::DSLGlobalVar fragCoord("sk_FragCoord");
2511             return this->getLValue(*dsl::DSLExpression(fragCoord).release(), out)->load(out);
2512         }
2513         case DEVICE_CLOCKWISE_BUILTIN: {
2514             // Down below, we rewrite raw references to sk_Clockwise with expressions that reference
2515             // DEVICE_CLOCKWISE_BUILTIN. This is a fake variable that means we need to directly
2516             // access front facing; do so now.
2517             dsl::DSLGlobalVar clockwise("sk_Clockwise");
2518             return this->getLValue(*dsl::DSLExpression(clockwise).release(), out)->load(out);
2519         }
2520         case SK_SECONDARYFRAGCOLOR_BUILTIN: {
2521             // sk_SecondaryFragColor corresponds to gl_SecondaryFragColorEXT, which isn't supposed
2522             // to appear in a SPIR-V program (it's only valid in ES2). Report an error.
2523             fContext.fErrors->error(ref.fPosition,
2524                     "sk_SecondaryFragColor is not allowed in SPIR-V");
2525             return NA;
2526         }
2527         case SK_FRAGCOORD_BUILTIN: {
2528             if (fProgram.fConfig->fSettings.fForceNoRTFlip) {
2529                 dsl::DSLGlobalVar fragCoord("sk_FragCoord");
2530                 return this->getLValue(*dsl::DSLExpression(fragCoord).release(), out)->load(out);
2531             }
2532 
2533             // Handle inserting use of uniform to flip y when referencing sk_FragCoord.
2534             this->addRTFlipUniform(ref.fPosition);
2535             // Use sk_RTAdjust to compute the flipped coordinate
2536             using namespace dsl;
2537             const char* DEVICE_COORDS_NAME = "$device_FragCoords";
2538             SymbolTable& symbols = *ThreadContext::SymbolTable();
2539             // Use a uniform to flip the Y coordinate. The new expression will be written in
2540             // terms of $device_FragCoords, which is a fake variable that means "access the
2541             // underlying fragcoords directly without flipping it".
2542             DSLExpression rtFlip(ThreadContext::Compiler().convertIdentifier(Position(),
2543                     SKSL_RTFLIP_NAME));
2544             if (!symbols.find(DEVICE_COORDS_NAME)) {
2545                 AutoAttachPoolToThread attach(fProgram.fPool.get());
2546                 Modifiers modifiers;
2547                 modifiers.fLayout.fBuiltin = DEVICE_FRAGCOORDS_BUILTIN;
2548                 auto coordsVar = std::make_unique<Variable>(/*pos=*/Position(),
2549                                                             /*modifiersPosition=*/Position(),
2550                                                             fContext.fModifiersPool->add(modifiers),
2551                                                             DEVICE_COORDS_NAME,
2552                                                             fContext.fTypes.fFloat4.get(),
2553                                                             /*builtin=*/true,
2554                                                             Variable::Storage::kGlobal);
2555                 fSPIRVBonusVariables.add(coordsVar.get());
2556                 symbols.add(std::move(coordsVar));
2557             }
2558             DSLGlobalVar deviceCoord(DEVICE_COORDS_NAME);
2559             std::unique_ptr<Expression> rtFlipSkSLExpr = rtFlip.release();
2560             DSLExpression x = DSLExpression(rtFlipSkSLExpr->clone()).x();
2561             DSLExpression y = DSLExpression(std::move(rtFlipSkSLExpr)).y();
2562             return this->writeExpression(*dsl::Float4(deviceCoord.x(),
2563                                                       std::move(x) + std::move(y) * deviceCoord.y(),
2564                                                       deviceCoord.z(),
2565                                                       deviceCoord.w()).release(),
2566                                          out);
2567         }
2568         case SK_CLOCKWISE_BUILTIN: {
2569             if (fProgram.fConfig->fSettings.fForceNoRTFlip) {
2570                 dsl::DSLGlobalVar clockwise("sk_Clockwise");
2571                 return this->getLValue(*dsl::DSLExpression(clockwise).release(), out)->load(out);
2572             }
2573 
2574             // Handle flipping sk_Clockwise.
2575             this->addRTFlipUniform(ref.fPosition);
2576             using namespace dsl;
2577             const char* DEVICE_CLOCKWISE_NAME = "$device_Clockwise";
2578             SymbolTable& symbols = *ThreadContext::SymbolTable();
2579             // Use a uniform to flip the Y coordinate. The new expression will be written in
2580             // terms of $device_Clockwise, which is a fake variable that means "access the
2581             // underlying FrontFacing directly".
2582             DSLExpression rtFlip(ThreadContext::Compiler().convertIdentifier(Position(),
2583                     SKSL_RTFLIP_NAME));
2584             if (!symbols.find(DEVICE_CLOCKWISE_NAME)) {
2585                 AutoAttachPoolToThread attach(fProgram.fPool.get());
2586                 Modifiers modifiers;
2587                 modifiers.fLayout.fBuiltin = DEVICE_CLOCKWISE_BUILTIN;
2588                 auto clockwiseVar = std::make_unique<Variable>(/*pos=*/Position(),
2589                         /*modifiersPosition=*/Position(),
2590                         fContext.fModifiersPool->add(modifiers),
2591                         DEVICE_CLOCKWISE_NAME,
2592                         fContext.fTypes.fBool.get(),
2593                         /*builtin=*/true,
2594                         Variable::Storage::kGlobal);
2595                 fSPIRVBonusVariables.add(clockwiseVar.get());
2596                 symbols.add(std::move(clockwiseVar));
2597             }
2598             DSLGlobalVar deviceClockwise(DEVICE_CLOCKWISE_NAME);
2599             // FrontFacing in Vulkan is defined in terms of a top-down render target. In skia,
2600             // we use the default convention of "counter-clockwise face is front".
2601             return this->writeExpression(*dsl::Bool(Select(rtFlip.y() > 0,
2602                                                            !deviceClockwise,
2603                                                            deviceClockwise)).release(),
2604                                          out);
2605         }
2606         default: {
2607             // Constant-propagate variables that have a known compile-time value.
2608             if (const Expression* expr = ConstantFolder::GetConstantValueOrNullForVariable(ref)) {
2609                 return this->writeExpression(*expr, out);
2610             }
2611 
2612             // A reference to a sampler variable at global scope with synthesized texture/sampler
2613             // backing should construct a function-scope combined image-sampler from the synthesized
2614             // constituents. This is the case in which a sample intrinsic was invoked.
2615             //
2616             // Variable references to opaque handles (texture/sampler) that appear as the argument
2617             // of a user-defined function call are explicitly handled in writeFunctionCallArgument.
2618             if (const auto* p = fSynthesizedSamplerMap.find(variable)) {
2619                 SkASSERT(fProgram.fConfig->fSettings.fSPIRVDawnCompatMode);
2620 
2621                 SpvId* imgPtr = fVariableMap.find((*p)->fTexture.get());
2622                 SpvId* samplerPtr = fVariableMap.find((*p)->fSampler.get());
2623                 SkASSERT(imgPtr);
2624                 SkASSERT(samplerPtr);
2625 
2626                 SpvId img = this->writeOpLoad(
2627                         this->getType((*p)->fTexture->type()), Precision::kDefault, *imgPtr, out);
2628                 SpvId sampler = this->writeOpLoad(this->getType((*p)->fSampler->type()),
2629                                                   Precision::kDefault,
2630                                                   *samplerPtr,
2631                                                   out);
2632 
2633                 SpvId result = this->nextId(nullptr);
2634                 this->writeInstruction(SpvOpSampledImage,
2635                                        this->getType(variable->type()),
2636                                        result,
2637                                        img,
2638                                        sampler,
2639                                        out);
2640 
2641                 return result;
2642             }
2643 
2644             return this->getLValue(ref, out)->load(out);
2645         }
2646     }
2647 }
2648 
writeIndexExpression(const IndexExpression & expr,OutputStream & out)2649 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
2650     if (expr.base()->type().isVector()) {
2651         SpvId base = this->writeExpression(*expr.base(), out);
2652         SpvId index = this->writeExpression(*expr.index(), out);
2653         SpvId result = this->nextId(nullptr);
2654         this->writeInstruction(SpvOpVectorExtractDynamic, this->getType(expr.type()), result, base,
2655                                index, out);
2656         return result;
2657     }
2658     return getLValue(expr, out)->load(out);
2659 }
2660 
writeFieldAccess(const FieldAccess & f,OutputStream & out)2661 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
2662     return getLValue(f, out)->load(out);
2663 }
2664 
writeSwizzle(const Swizzle & swizzle,OutputStream & out)2665 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
2666     SpvId base = this->writeExpression(*swizzle.base(), out);
2667     size_t count = swizzle.components().size();
2668     if (count == 1) {
2669         return this->writeOpCompositeExtract(swizzle.type(), base, swizzle.components()[0], out);
2670     }
2671 
2672     SpvId result = this->nextId(&swizzle.type());
2673     this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
2674     this->writeWord(this->getType(swizzle.type()), out);
2675     this->writeWord(result, out);
2676     this->writeWord(base, out);
2677     this->writeWord(base, out);
2678     for (int component : swizzle.components()) {
2679         this->writeWord(component, out);
2680     }
2681     return result;
2682 }
2683 
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)2684 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
2685                                                const Type& operandType, SpvId lhs,
2686                                                SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
2687                                                SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
2688     SpvId result = this->nextId(&resultType);
2689     SpvOp_ op = pick_by_type(operandType, ifFloat, ifInt, ifUInt, ifBool);
2690     if (op == SpvOpUndef) {
2691         fContext.fErrors->error(operandType.fPosition,
2692                 "unsupported operand for binary expression: " + operandType.description());
2693         return NA;
2694     }
2695     this->writeInstruction(op, this->getType(resultType), result, lhs, rhs, out);
2696     return result;
2697 }
2698 
foldToBool(SpvId id,const Type & operandType,SpvOp op,OutputStream & out)2699 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op,
2700                                      OutputStream& out) {
2701     if (operandType.isVector()) {
2702         SpvId result = this->nextId(nullptr);
2703         this->writeInstruction(op, this->getType(*fContext.fTypes.fBool), result, id, out);
2704         return result;
2705     }
2706     return id;
2707 }
2708 
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,SpvOp_ vectorMergeOperator,SpvOp_ mergeOperator,OutputStream & out)2709 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
2710                                                 SpvOp_ floatOperator, SpvOp_ intOperator,
2711                                                 SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
2712                                                 OutputStream& out) {
2713     SpvOp_ compareOp = is_float(operandType) ? floatOperator : intOperator;
2714     SkASSERT(operandType.isMatrix());
2715     const Type& columnType = operandType.componentType().toCompound(fContext,
2716                                                                     operandType.rows(),
2717                                                                     1);
2718     SpvId bvecType = this->getType(fContext.fTypes.fBool->toCompound(fContext,
2719                                                                      operandType.rows(),
2720                                                                      1));
2721     SpvId boolType = this->getType(*fContext.fTypes.fBool);
2722     SpvId result = 0;
2723     for (int i = 0; i < operandType.columns(); i++) {
2724         SpvId columnL = this->writeOpCompositeExtract(columnType, lhs, i, out);
2725         SpvId columnR = this->writeOpCompositeExtract(columnType, rhs, i, out);
2726         SpvId compare = this->nextId(&operandType);
2727         this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
2728         SpvId merge = this->nextId(nullptr);
2729         this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
2730         if (result != 0) {
2731             SpvId next = this->nextId(nullptr);
2732             this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
2733             result = next;
2734         } else {
2735             result = merge;
2736         }
2737     }
2738     return result;
2739 }
2740 
writeComponentwiseMatrixUnary(const Type & operandType,SpvId operand,SpvOp_ op,OutputStream & out)2741 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixUnary(const Type& operandType,
2742                                                         SpvId operand,
2743                                                         SpvOp_ op,
2744                                                         OutputStream& out) {
2745     SkASSERT(operandType.isMatrix());
2746     const Type& columnType = operandType.componentType().toCompound(fContext,
2747                                                                     /*columns=*/operandType.rows(),
2748                                                                     /*rows=*/1);
2749     SpvId columnTypeId = this->getType(columnType);
2750 
2751     SkSTArray<4, SpvId> columns;
2752     for (int i = 0; i < operandType.columns(); i++) {
2753         SpvId srcColumn = this->writeOpCompositeExtract(columnType, operand, i, out);
2754         SpvId dstColumn = this->nextId(&operandType);
2755         this->writeInstruction(op, columnTypeId, dstColumn, srcColumn, out);
2756         columns.push_back(dstColumn);
2757     }
2758 
2759     return this->writeOpCompositeConstruct(operandType, columns, out);
2760 }
2761 
writeComponentwiseMatrixBinary(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ op,OutputStream & out)2762 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
2763                                                          SpvId rhs, SpvOp_ op, OutputStream& out) {
2764     SkASSERT(operandType.isMatrix());
2765     const Type& columnType = operandType.componentType().toCompound(fContext,
2766                                                                     /*columns=*/operandType.rows(),
2767                                                                     /*rows=*/1);
2768     SpvId columnTypeId = this->getType(columnType);
2769 
2770     SkSTArray<4, SpvId> columns;
2771     for (int i = 0; i < operandType.columns(); i++) {
2772         SpvId columnL = this->writeOpCompositeExtract(columnType, lhs, i, out);
2773         SpvId columnR = this->writeOpCompositeExtract(columnType, rhs, i, out);
2774         columns.push_back(this->nextId(&operandType));
2775         this->writeInstruction(op, columnTypeId, columns[i], columnL, columnR, out);
2776     }
2777     return this->writeOpCompositeConstruct(operandType, columns, out);
2778 }
2779 
writeReciprocal(const Type & type,SpvId value,OutputStream & out)2780 SpvId SPIRVCodeGenerator::writeReciprocal(const Type& type, SpvId value, OutputStream& out) {
2781     SkASSERT(type.isFloat());
2782     SpvId one = this->writeLiteral(1.0, type);
2783     SpvId reciprocal = this->nextId(&type);
2784     this->writeInstruction(SpvOpFDiv, this->getType(type), reciprocal, one, value, out);
2785     return reciprocal;
2786 }
2787 
writeScalarToMatrixSplat(const Type & matrixType,SpvId scalarId,OutputStream & out)2788 SpvId SPIRVCodeGenerator::writeScalarToMatrixSplat(const Type& matrixType,
2789                                                    SpvId scalarId,
2790                                                    OutputStream& out) {
2791     // Splat the scalar into a vector.
2792     const Type& vectorType = matrixType.componentType().toCompound(fContext,
2793                                                                    /*columns=*/matrixType.rows(),
2794                                                                    /*rows=*/1);
2795     SkSTArray<4, SpvId> vecArguments;
2796     vecArguments.push_back_n(/*n=*/matrixType.rows(), /*t=*/scalarId);
2797     SpvId vectorId = this->writeOpCompositeConstruct(vectorType, vecArguments, out);
2798 
2799     // Splat the vector into a matrix.
2800     SkSTArray<4, SpvId> matArguments;
2801     matArguments.push_back_n(/*n=*/matrixType.columns(), /*t=*/vectorId);
2802     return this->writeOpCompositeConstruct(matrixType, matArguments, out);
2803 }
2804 
types_match(const Type & a,const Type & b)2805 static bool types_match(const Type& a, const Type& b) {
2806     if (a.matches(b)) {
2807         return true;
2808     }
2809     return (a.typeKind() == b.typeKind()) &&
2810            (a.isScalar() || a.isVector() || a.isMatrix()) &&
2811            (a.columns() == b.columns() && a.rows() == b.rows()) &&
2812            a.componentType().numberKind() == b.componentType().numberKind();
2813 }
2814 
writeBinaryExpression(const Type & leftType,SpvId lhs,Operator op,const Type & rightType,SpvId rhs,const Type & resultType,OutputStream & out)2815 SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
2816                                                 const Type& rightType, SpvId rhs,
2817                                                 const Type& resultType, OutputStream& out) {
2818     // The comma operator ignores the type of the left-hand side entirely.
2819     if (op.kind() == Operator::Kind::COMMA) {
2820         return rhs;
2821     }
2822     // overall type we are operating on: float2, int, uint4...
2823     const Type* operandType;
2824     if (types_match(leftType, rightType)) {
2825         operandType = &leftType;
2826     } else {
2827         // IR allows mismatched types in expressions (e.g. float2 * float), but they need special
2828         // handling in SPIR-V
2829         if (leftType.isVector() && rightType.isNumber()) {
2830             if (resultType.componentType().isFloat()) {
2831                 switch (op.kind()) {
2832                     case Operator::Kind::SLASH: {
2833                         rhs = this->writeReciprocal(rightType, rhs, out);
2834                         [[fallthrough]];
2835                     }
2836                     case Operator::Kind::STAR: {
2837                         SpvId result = this->nextId(&resultType);
2838                         this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
2839                                                result, lhs, rhs, out);
2840                         return result;
2841                     }
2842                     default:
2843                         break;
2844                 }
2845             }
2846             // Vectorize the right-hand side.
2847             SkSTArray<4, SpvId> arguments;
2848             arguments.push_back_n(/*n=*/leftType.columns(), /*t=*/rhs);
2849             rhs = this->writeOpCompositeConstruct(leftType, arguments, out);
2850             operandType = &leftType;
2851         } else if (rightType.isVector() && leftType.isNumber()) {
2852             if (resultType.componentType().isFloat()) {
2853                 if (op.kind() == Operator::Kind::STAR) {
2854                     SpvId result = this->nextId(&resultType);
2855                     this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
2856                                            result, rhs, lhs, out);
2857                     return result;
2858                 }
2859             }
2860             // Vectorize the left-hand side.
2861             SkSTArray<4, SpvId> arguments;
2862             arguments.push_back_n(/*n=*/rightType.columns(), /*t=*/lhs);
2863             lhs = this->writeOpCompositeConstruct(rightType, arguments, out);
2864             operandType = &rightType;
2865         } else if (leftType.isMatrix()) {
2866             if (op.kind() == Operator::Kind::STAR) {
2867                 // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
2868                 SpvOp_ spvop;
2869                 if (rightType.isMatrix()) {
2870                     spvop = SpvOpMatrixTimesMatrix;
2871                 } else if (rightType.isVector()) {
2872                     spvop = SpvOpMatrixTimesVector;
2873                 } else {
2874                     SkASSERT(rightType.isScalar());
2875                     spvop = SpvOpMatrixTimesScalar;
2876                 }
2877                 SpvId result = this->nextId(&resultType);
2878                 this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
2879                 return result;
2880             } else {
2881                 // Matrix-op-vector is not supported in GLSL/SkSL for non-multiplication ops; we
2882                 // expect to have a scalar here.
2883                 SkASSERT(rightType.isScalar());
2884 
2885                 // Splat rhs across an entire matrix so we can reuse the matrix-op-matrix path.
2886                 SpvId rhsMatrix = this->writeScalarToMatrixSplat(leftType, rhs, out);
2887 
2888                 // Perform this operation as matrix-op-matrix.
2889                 return this->writeBinaryExpression(leftType, lhs, op, leftType, rhsMatrix,
2890                                                    resultType, out);
2891             }
2892         } else if (rightType.isMatrix()) {
2893             if (op.kind() == Operator::Kind::STAR) {
2894                 // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
2895                 SpvId result = this->nextId(&resultType);
2896                 if (leftType.isVector()) {
2897                     this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType),
2898                                            result, lhs, rhs, out);
2899                 } else {
2900                     SkASSERT(leftType.isScalar());
2901                     this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType),
2902                                            result, rhs, lhs, out);
2903                 }
2904                 return result;
2905             } else {
2906                 // Vector-op-matrix is not supported in GLSL/SkSL for non-multiplication ops; we
2907                 // expect to have a scalar here.
2908                 SkASSERT(leftType.isScalar());
2909 
2910                 // Splat lhs across an entire matrix so we can reuse the matrix-op-matrix path.
2911                 SpvId lhsMatrix = this->writeScalarToMatrixSplat(rightType, lhs, out);
2912 
2913                 // Perform this operation as matrix-op-matrix.
2914                 return this->writeBinaryExpression(rightType, lhsMatrix, op, rightType, rhs,
2915                                                    resultType, out);
2916             }
2917         } else {
2918             fContext.fErrors->error(leftType.fPosition, "unsupported mixed-type expression");
2919             return NA;
2920         }
2921     }
2922 
2923     switch (op.kind()) {
2924         case Operator::Kind::EQEQ: {
2925             if (operandType->isMatrix()) {
2926                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
2927                                                    SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
2928             }
2929             if (operandType->isStruct()) {
2930                 return this->writeStructComparison(*operandType, lhs, op, rhs, out);
2931             }
2932             if (operandType->isArray()) {
2933                 return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
2934             }
2935             SkASSERT(resultType.isBoolean());
2936             const Type* tmpType;
2937             if (operandType->isVector()) {
2938                 tmpType = &fContext.fTypes.fBool->toCompound(fContext,
2939                                                              operandType->columns(),
2940                                                              operandType->rows());
2941             } else {
2942                 tmpType = &resultType;
2943             }
2944             if (lhs == rhs) {
2945                 // This ignores the effects of NaN.
2946                 return this->writeOpConstantTrue(*fContext.fTypes.fBool);
2947             }
2948             return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
2949                                                                SpvOpFOrdEqual, SpvOpIEqual,
2950                                                                SpvOpIEqual, SpvOpLogicalEqual, out),
2951                                     *operandType, SpvOpAll, out);
2952         }
2953         case Operator::Kind::NEQ:
2954             if (operandType->isMatrix()) {
2955                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFUnordNotEqual,
2956                                                    SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
2957             }
2958             if (operandType->isStruct()) {
2959                 return this->writeStructComparison(*operandType, lhs, op, rhs, out);
2960             }
2961             if (operandType->isArray()) {
2962                 return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
2963             }
2964             [[fallthrough]];
2965         case Operator::Kind::LOGICALXOR:
2966             SkASSERT(resultType.isBoolean());
2967             const Type* tmpType;
2968             if (operandType->isVector()) {
2969                 tmpType = &fContext.fTypes.fBool->toCompound(fContext,
2970                                                              operandType->columns(),
2971                                                              operandType->rows());
2972             } else {
2973                 tmpType = &resultType;
2974             }
2975             if (lhs == rhs) {
2976                 // This ignores the effects of NaN.
2977                 return this->writeOpConstantFalse(*fContext.fTypes.fBool);
2978             }
2979             return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
2980                                                                SpvOpFUnordNotEqual, SpvOpINotEqual,
2981                                                                SpvOpINotEqual, SpvOpLogicalNotEqual,
2982                                                                out),
2983                                     *operandType, SpvOpAny, out);
2984         case Operator::Kind::GT:
2985             SkASSERT(resultType.isBoolean());
2986             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2987                                               SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
2988                                               SpvOpUGreaterThan, SpvOpUndef, out);
2989         case Operator::Kind::LT:
2990             SkASSERT(resultType.isBoolean());
2991             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
2992                                               SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
2993         case Operator::Kind::GTEQ:
2994             SkASSERT(resultType.isBoolean());
2995             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2996                                               SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
2997                                               SpvOpUGreaterThanEqual, SpvOpUndef, out);
2998         case Operator::Kind::LTEQ:
2999             SkASSERT(resultType.isBoolean());
3000             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
3001                                               SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
3002                                               SpvOpULessThanEqual, SpvOpUndef, out);
3003         case Operator::Kind::PLUS:
3004             if (leftType.isMatrix() && rightType.isMatrix()) {
3005                 SkASSERT(leftType.matches(rightType));
3006                 return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs, SpvOpFAdd, out);
3007             }
3008             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
3009                                               SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
3010         case Operator::Kind::MINUS:
3011             if (leftType.isMatrix() && rightType.isMatrix()) {
3012                 SkASSERT(leftType.matches(rightType));
3013                 return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs, SpvOpFSub, out);
3014             }
3015             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
3016                                               SpvOpISub, SpvOpISub, SpvOpUndef, out);
3017         case Operator::Kind::STAR:
3018             if (leftType.isMatrix() && rightType.isMatrix()) {
3019                 // matrix multiply
3020                 SpvId result = this->nextId(&resultType);
3021                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
3022                                        lhs, rhs, out);
3023                 return result;
3024             }
3025             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
3026                                               SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
3027         case Operator::Kind::SLASH:
3028             if (leftType.isMatrix() && rightType.isMatrix()) {
3029                 SkASSERT(leftType.matches(rightType));
3030                 return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs, SpvOpFDiv, out);
3031             }
3032             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
3033                                               SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
3034         case Operator::Kind::PERCENT:
3035             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
3036                                               SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
3037         case Operator::Kind::SHL:
3038             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
3039                                               SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
3040                                               SpvOpUndef, out);
3041         case Operator::Kind::SHR:
3042             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
3043                                               SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
3044                                               SpvOpUndef, out);
3045         case Operator::Kind::BITWISEAND:
3046             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
3047                                               SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
3048         case Operator::Kind::BITWISEOR:
3049             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
3050                                               SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
3051         case Operator::Kind::BITWISEXOR:
3052             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
3053                                               SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
3054         default:
3055             fContext.fErrors->error(Position(), "unsupported token");
3056             return NA;
3057     }
3058 }
3059 
writeArrayComparison(const Type & arrayType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)3060 SpvId SPIRVCodeGenerator::writeArrayComparison(const Type& arrayType, SpvId lhs, Operator op,
3061                                                SpvId rhs, OutputStream& out) {
3062     // The inputs must be arrays, and the op must be == or !=.
3063     SkASSERT(op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ);
3064     SkASSERT(arrayType.isArray());
3065     const Type& componentType = arrayType.componentType();
3066     const int arraySize = arrayType.columns();
3067     SkASSERT(arraySize > 0);
3068 
3069     // Synthesize equality checks for each item in the array.
3070     const Type& boolType = *fContext.fTypes.fBool;
3071     SpvId allComparisons = NA;
3072     for (int index = 0; index < arraySize; ++index) {
3073         // Get the left and right item in the array.
3074         SpvId itemL = this->writeOpCompositeExtract(componentType, lhs, index, out);
3075         SpvId itemR = this->writeOpCompositeExtract(componentType, rhs, index, out);
3076         // Use `writeBinaryExpression` with the requested == or != operator on these items.
3077         SpvId comparison = this->writeBinaryExpression(componentType, itemL, op,
3078                                                        componentType, itemR, boolType, out);
3079         // Merge this comparison result with all the other comparisons we've done.
3080         allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
3081     }
3082     return allComparisons;
3083 }
3084 
writeStructComparison(const Type & structType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)3085 SpvId SPIRVCodeGenerator::writeStructComparison(const Type& structType, SpvId lhs, Operator op,
3086                                                 SpvId rhs, OutputStream& out) {
3087     // The inputs must be structs containing fields, and the op must be == or !=.
3088     SkASSERT(op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ);
3089     SkASSERT(structType.isStruct());
3090     const std::vector<Type::Field>& fields = structType.fields();
3091     SkASSERT(!fields.empty());
3092 
3093     // Synthesize equality checks for each field in the struct.
3094     const Type& boolType = *fContext.fTypes.fBool;
3095     SpvId allComparisons = NA;
3096     for (int index = 0; index < (int)fields.size(); ++index) {
3097         // Get the left and right versions of this field.
3098         const Type& fieldType = *fields[index].fType;
3099 
3100         SpvId fieldL = this->writeOpCompositeExtract(fieldType, lhs, index, out);
3101         SpvId fieldR = this->writeOpCompositeExtract(fieldType, rhs, index, out);
3102         // Use `writeBinaryExpression` with the requested == or != operator on these fields.
3103         SpvId comparison = this->writeBinaryExpression(fieldType, fieldL, op, fieldType, fieldR,
3104                                                        boolType, out);
3105         // Merge this comparison result with all the other comparisons we've done.
3106         allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
3107     }
3108     return allComparisons;
3109 }
3110 
mergeComparisons(SpvId comparison,SpvId allComparisons,Operator op,OutputStream & out)3111 SpvId SPIRVCodeGenerator::mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op,
3112                                            OutputStream& out) {
3113     // If this is the first entry, we don't need to merge comparison results with anything.
3114     if (allComparisons == NA) {
3115         return comparison;
3116     }
3117     // Use LogicalAnd or LogicalOr to combine the comparison with all the other comparisons.
3118     const Type& boolType = *fContext.fTypes.fBool;
3119     SpvId boolTypeId = this->getType(boolType);
3120     SpvId logicalOp = this->nextId(&boolType);
3121     switch (op.kind()) {
3122         case Operator::Kind::EQEQ:
3123             this->writeInstruction(SpvOpLogicalAnd, boolTypeId, logicalOp,
3124                                    comparison, allComparisons, out);
3125             break;
3126         case Operator::Kind::NEQ:
3127             this->writeInstruction(SpvOpLogicalOr, boolTypeId, logicalOp,
3128                                    comparison, allComparisons, out);
3129             break;
3130         default:
3131             SkDEBUGFAILF("mergeComparisons only supports == and !=, not %s", op.operatorName());
3132             return NA;
3133     }
3134     return logicalOp;
3135 }
3136 
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)3137 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
3138     const Expression* left = b.left().get();
3139     const Expression* right = b.right().get();
3140     Operator op = b.getOperator();
3141 
3142     switch (op.kind()) {
3143         case Operator::Kind::EQ: {
3144             // Handles assignment.
3145             SpvId rhs = this->writeExpression(*right, out);
3146             this->getLValue(*left, out)->store(rhs, out);
3147             return rhs;
3148         }
3149         case Operator::Kind::LOGICALAND:
3150             // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
3151             return this->writeLogicalAnd(*b.left(), *b.right(), out);
3152 
3153         case Operator::Kind::LOGICALOR:
3154             // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
3155             return this->writeLogicalOr(*b.left(), *b.right(), out);
3156 
3157         default:
3158             break;
3159     }
3160 
3161     std::unique_ptr<LValue> lvalue;
3162     SpvId lhs;
3163     if (op.isAssignment()) {
3164         lvalue = this->getLValue(*left, out);
3165         lhs = lvalue->load(out);
3166     } else {
3167         lvalue = nullptr;
3168         lhs = this->writeExpression(*left, out);
3169     }
3170 
3171     SpvId rhs = this->writeExpression(*right, out);
3172     SpvId result = this->writeBinaryExpression(left->type(), lhs, op.removeAssignment(),
3173                                                right->type(), rhs, b.type(), out);
3174     if (lvalue) {
3175         lvalue->store(result, out);
3176     }
3177     return result;
3178 }
3179 
writeLogicalAnd(const Expression & left,const Expression & right,OutputStream & out)3180 SpvId SPIRVCodeGenerator::writeLogicalAnd(const Expression& left, const Expression& right,
3181                                           OutputStream& out) {
3182     SpvId falseConstant = this->writeLiteral(0.0, *fContext.fTypes.fBool);
3183     SpvId lhs = this->writeExpression(left, out);
3184 
3185     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
3186 
3187     SpvId rhsLabel = this->nextId(nullptr);
3188     SpvId end = this->nextId(nullptr);
3189     SpvId lhsBlock = fCurrentBlock;
3190     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
3191     this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
3192     this->writeLabel(rhsLabel, kBranchIsOnPreviousLine, out);
3193     SpvId rhs = this->writeExpression(right, out);
3194     SpvId rhsBlock = fCurrentBlock;
3195     this->writeInstruction(SpvOpBranch, end, out);
3196     this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
3197     SpvId result = this->nextId(nullptr);
3198     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, falseConstant,
3199                            lhsBlock, rhs, rhsBlock, out);
3200 
3201     return result;
3202 }
3203 
writeLogicalOr(const Expression & left,const Expression & right,OutputStream & out)3204 SpvId SPIRVCodeGenerator::writeLogicalOr(const Expression& left, const Expression& right,
3205                                          OutputStream& out) {
3206     SpvId trueConstant = this->writeLiteral(1.0, *fContext.fTypes.fBool);
3207     SpvId lhs = this->writeExpression(left, out);
3208 
3209     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
3210 
3211     SpvId rhsLabel = this->nextId(nullptr);
3212     SpvId end = this->nextId(nullptr);
3213     SpvId lhsBlock = fCurrentBlock;
3214     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
3215     this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
3216     this->writeLabel(rhsLabel, kBranchIsOnPreviousLine, out);
3217     SpvId rhs = this->writeExpression(right, out);
3218     SpvId rhsBlock = fCurrentBlock;
3219     this->writeInstruction(SpvOpBranch, end, out);
3220     this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
3221     SpvId result = this->nextId(nullptr);
3222     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, trueConstant,
3223                            lhsBlock, rhs, rhsBlock, out);
3224 
3225     return result;
3226 }
3227 
// Emits `test ? ifTrue : ifFalse` and returns the id of the selected value.
//
// Fast path: if ifTrue's type has columns() == 1 and both branches are compile-time constants,
// both branches are evaluated unconditionally and an OpSelect picks one. This is harmless because
// constants have no side effects.
//
// General path: a structured selection (OpSelectionMerge + OpBranchConditional) evaluates only the
// chosen branch. That branch stores its value into a Function-storage temporary, declared in
// fVariableBuffer, and the temporary is loaded at the merge block.
SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
    const Type& type = t.type();
    SpvId test = this->writeExpression(*t.test(), out);
    if (t.ifTrue()->type().columns() == 1 &&
        Analysis::IsCompileTimeConstant(*t.ifTrue()) &&
        Analysis::IsCompileTimeConstant(*t.ifFalse())) {
        // both true and false are constants, can just use OpSelect
        SpvId result = this->nextId(nullptr);
        SpvId trueId = this->writeExpression(*t.ifTrue(), out);
        SpvId falseId = this->writeExpression(*t.ifFalse(), out);
        this->writeInstruction(SpvOpSelect, this->getType(type), result, test, trueId, falseId,
                               out);
        return result;
    }

    // Snapshot of the conditional-op bookkeeping. It is handed to writeLabel for the false-branch
    // and merge labels. NOTE(review): presumably this discards ops cached inside the preceding
    // branch so they aren't reused where that branch didn't execute; confirm in writeLabel.
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    // was originally using OpPhi to choose the result, but for some reason that is crashing on
    // Adreno. Switched to storing the result in a temp variable as glslang does.
    SpvId var = this->nextId(nullptr);
    this->writeInstruction(SpvOpVariable, this->getPointerType(type, SpvStorageClassFunction),
                           var, SpvStorageClassFunction, fVariableBuffer);
    SpvId trueLabel = this->nextId(nullptr);
    SpvId falseLabel = this->nextId(nullptr);
    SpvId end = this->nextId(nullptr);
    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
    this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
    // True branch: evaluate ifTrue, store it into the temporary, jump to the merge block.
    this->writeLabel(trueLabel, kBranchIsOnPreviousLine, out);
    this->writeOpStore(SpvStorageClassFunction, var, this->writeExpression(*t.ifTrue(), out), out);
    this->writeInstruction(SpvOpBranch, end, out);
    // False branch: evaluate ifFalse, store it into the temporary, jump to the merge block.
    this->writeLabel(falseLabel, kBranchIsAbove, conditionalOps, out);
    this->writeOpStore(SpvStorageClassFunction, var, this->writeExpression(*t.ifFalse(), out), out);
    this->writeInstruction(SpvOpBranch, end, out);
    // Merge block: read back whichever value was stored.
    this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    SpvId result = this->nextId(&type);
    this->writeInstruction(SpvOpLoad, this->getType(type), result, var, out);

    return result;
}
3267 
writePrefixExpression(const PrefixExpression & p,OutputStream & out)3268 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
3269     const Type& type = p.type();
3270     if (p.getOperator().kind() == Operator::Kind::MINUS) {
3271         SpvOp_ negateOp = pick_by_type(type, SpvOpFNegate, SpvOpSNegate, SpvOpSNegate, SpvOpUndef);
3272         SkASSERT(negateOp != SpvOpUndef);
3273         SpvId expr = this->writeExpression(*p.operand(), out);
3274         if (type.isMatrix()) {
3275             return this->writeComponentwiseMatrixUnary(type, expr, negateOp, out);
3276         }
3277         SpvId result = this->nextId(&type);
3278         SpvId typeId = this->getType(type);
3279         this->writeInstruction(negateOp, typeId, result, expr, out);
3280         return result;
3281     }
3282     switch (p.getOperator().kind()) {
3283         case Operator::Kind::PLUS:
3284             return this->writeExpression(*p.operand(), out);
3285         case Operator::Kind::PLUSPLUS: {
3286             std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
3287             SpvId one = this->writeLiteral(1.0, type);
3288             SpvId result = this->writeBinaryOperation(type, type, lv->load(out), one,
3289                                                       SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
3290                                                       out);
3291             lv->store(result, out);
3292             return result;
3293         }
3294         case Operator::Kind::MINUSMINUS: {
3295             std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
3296             SpvId one = this->writeLiteral(1.0, type);
3297             SpvId result = this->writeBinaryOperation(type, type, lv->load(out), one, SpvOpFSub,
3298                                                       SpvOpISub, SpvOpISub, SpvOpUndef, out);
3299             lv->store(result, out);
3300             return result;
3301         }
3302         case Operator::Kind::LOGICALNOT: {
3303             SkASSERT(p.operand()->type().isBoolean());
3304             SpvId result = this->nextId(nullptr);
3305             this->writeInstruction(SpvOpLogicalNot, this->getType(type), result,
3306                                    this->writeExpression(*p.operand(), out), out);
3307             return result;
3308         }
3309         case Operator::Kind::BITWISENOT: {
3310             SpvId result = this->nextId(nullptr);
3311             this->writeInstruction(SpvOpNot, this->getType(type), result,
3312                                    this->writeExpression(*p.operand(), out), out);
3313             return result;
3314         }
3315         default:
3316             SkDEBUGFAILF("unsupported prefix expression: %s",
3317                          p.description(OperatorPrecedence::kTopLevel).c_str());
3318             return NA;
3319     }
3320 }
3321 
writePostfixExpression(const PostfixExpression & p,OutputStream & out)3322 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
3323     const Type& type = p.type();
3324     std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
3325     SpvId result = lv->load(out);
3326     SpvId one = this->writeLiteral(1.0, type);
3327     switch (p.getOperator().kind()) {
3328         case Operator::Kind::PLUSPLUS: {
3329             SpvId temp = this->writeBinaryOperation(type, type, result, one, SpvOpFAdd,
3330                                                     SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
3331             lv->store(temp, out);
3332             return result;
3333         }
3334         case Operator::Kind::MINUSMINUS: {
3335             SpvId temp = this->writeBinaryOperation(type, type, result, one, SpvOpFSub,
3336                                                     SpvOpISub, SpvOpISub, SpvOpUndef, out);
3337             lv->store(temp, out);
3338             return result;
3339         }
3340         default:
3341             SkDEBUGFAILF("unsupported postfix expression %s",
3342                          p.description(OperatorPrecedence::kTopLevel).c_str());
3343             return NA;
3344     }
3345 }
3346 
writeLiteral(const Literal & l)3347 SpvId SPIRVCodeGenerator::writeLiteral(const Literal& l) {
3348     return this->writeLiteral(l.value(), l.type());
3349 }
3350 
writeLiteral(double value,const Type & type)3351 SpvId SPIRVCodeGenerator::writeLiteral(double value, const Type& type) {
3352     switch (type.numberKind()) {
3353         case Type::NumberKind::kFloat: {
3354             float floatVal = value;
3355             int32_t valueBits;
3356             memcpy(&valueBits, &floatVal, sizeof(valueBits));
3357             return this->writeOpConstant(type, valueBits);
3358         }
3359         case Type::NumberKind::kBoolean: {
3360             return value ? this->writeOpConstantTrue(type)
3361                          : this->writeOpConstantFalse(type);
3362         }
3363         default: {
3364             return this->writeOpConstant(type, (SKSL_INT)value);
3365         }
3366     }
3367 }
3368 
writeFunctionStart(const FunctionDeclaration & f,OutputStream & out)3369 SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
3370     SpvId result = fFunctionMap[&f];
3371     SpvId returnTypeId = this->getType(f.returnType());
3372     SpvId functionTypeId = this->getFunctionType(f);
3373     this->writeInstruction(SpvOpFunction, returnTypeId, result,
3374                            SpvFunctionControlMaskNone, functionTypeId, out);
3375     std::string mangledName = f.mangledName();
3376     this->writeInstruction(SpvOpName,
3377                            result,
3378                            std::string_view(mangledName.c_str(), mangledName.size()),
3379                            fNameBuffer);
3380     for (const Variable* parameter : f.parameters()) {
3381         if (parameter->type().typeKind() == Type::TypeKind::kSampler &&
3382             fProgram.fConfig->fSettings.fSPIRVDawnCompatMode) {
3383             auto [texture, sampler] = this->synthesizeTextureAndSampler(*parameter);
3384 
3385             SpvId textureId = this->nextId(nullptr);
3386             SpvId samplerId = this->nextId(nullptr);
3387             fVariableMap.set(texture, textureId);
3388             fVariableMap.set(sampler, samplerId);
3389 
3390             SpvId textureType = this->getFunctionParameterType(texture->type());
3391             SpvId samplerType = this->getFunctionParameterType(sampler->type());
3392 
3393             this->writeInstruction(SpvOpFunctionParameter, textureType, textureId, out);
3394             this->writeInstruction(SpvOpFunctionParameter, samplerType, samplerId, out);
3395         } else {
3396             SpvId id = this->nextId(nullptr);
3397             fVariableMap.set(parameter, id);
3398 
3399             SpvId type = this->getFunctionParameterType(parameter->type());
3400             this->writeInstruction(SpvOpFunctionParameter, type, id, out);
3401         }
3402     }
3403     return result;
3404 }
3405 
writeFunction(const FunctionDefinition & f,OutputStream & out)3406 SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
3407     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
3408 
3409     fVariableBuffer.reset();
3410     SpvId result = this->writeFunctionStart(f.declaration(), out);
3411     fCurrentBlock = 0;
3412     this->writeLabel(this->nextId(nullptr), kBranchlessBlock, out);
3413     StringStream bodyBuffer;
3414     this->writeBlock(f.body()->as<Block>(), bodyBuffer);
3415     write_stringstream(fVariableBuffer, out);
3416     if (f.declaration().isMain()) {
3417         write_stringstream(fGlobalInitializersBuffer, out);
3418     }
3419     write_stringstream(bodyBuffer, out);
3420     if (fCurrentBlock) {
3421         if (f.declaration().returnType().isVoid()) {
3422             this->writeInstruction(SpvOpReturn, out);
3423         } else {
3424             this->writeInstruction(SpvOpUnreachable, out);
3425         }
3426     }
3427     this->writeInstruction(SpvOpFunctionEnd, out);
3428     this->pruneConditionalOps(conditionalOps);
3429     return result;
3430 }
3431 
writeLayout(const Layout & layout,SpvId target,Position pos)3432 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, Position pos) {
3433     bool isPushConstant = (layout.fFlags & Layout::kPushConstant_Flag);
3434     if (layout.fLocation >= 0) {
3435         this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
3436                                fDecorationBuffer);
3437     }
3438     if (layout.fBinding >= 0) {
3439         if (isPushConstant) {
3440             fContext.fErrors->error(pos, "Can't apply 'binding' to push constants");
3441         } else {
3442             this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
3443                                    fDecorationBuffer);
3444         }
3445     }
3446     if (layout.fIndex >= 0) {
3447         this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
3448                                fDecorationBuffer);
3449     }
3450     if (layout.fSet >= 0) {
3451         if (isPushConstant) {
3452             fContext.fErrors->error(pos, "Can't apply 'set' to push constants");
3453         } else {
3454             this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
3455                                    fDecorationBuffer);
3456         }
3457     }
3458     if (layout.fInputAttachmentIndex >= 0) {
3459         this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
3460                                layout.fInputAttachmentIndex, fDecorationBuffer);
3461         fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
3462     }
3463     if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN) {
3464         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
3465                                fDecorationBuffer);
3466     }
3467 }
3468 
writeFieldLayout(const Layout & layout,SpvId target,int member)3469 void SPIRVCodeGenerator::writeFieldLayout(const Layout& layout, SpvId target, int member) {
3470     // 'binding' and 'set' can not be applied to struct members
3471     SkASSERT(layout.fBinding == -1);
3472     SkASSERT(layout.fSet == -1);
3473     if (layout.fLocation >= 0) {
3474         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
3475                                layout.fLocation, fDecorationBuffer);
3476     }
3477     if (layout.fIndex >= 0) {
3478         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
3479                                layout.fIndex, fDecorationBuffer);
3480     }
3481     if (layout.fInputAttachmentIndex >= 0) {
3482         this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
3483                                layout.fInputAttachmentIndex, fDecorationBuffer);
3484     }
3485     if (layout.fBuiltin >= 0) {
3486         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
3487                                layout.fBuiltin, fDecorationBuffer);
3488     }
3489 }
3490 
memoryLayoutForStorageClass(SpvStorageClass_ storageClass)3491 MemoryLayout SPIRVCodeGenerator::memoryLayoutForStorageClass(SpvStorageClass_ storageClass) {
3492     return storageClass == SpvStorageClassPushConstant ? MemoryLayout(MemoryLayout::Standard::k430)
3493                                                        : fDefaultLayout;
3494 }
3495 
memoryLayoutForVariable(const Variable & v) const3496 MemoryLayout SPIRVCodeGenerator::memoryLayoutForVariable(const Variable& v) const {
3497     bool pushConstant = ((v.modifiers().fLayout.fFlags & Layout::kPushConstant_Flag) != 0);
3498     return pushConstant ? MemoryLayout(MemoryLayout::Standard::k430) : fDefaultLayout;
3499 }
3500 
// Declares the OpVariable for interface block `intf`, records its id in fVariableMap, and returns
// that id.
//
// When the RTFlip uniform is in use, `appendRTFlip` is set, and the block's type is a struct, a
// copy of the block with an extra float2 rtflip member is synthesized. That copy is written
// instead, via a recursive call with appendRTFlip=false.
//
// If the block's type is not supported by the chosen memory layout, an error is reported and a
// fresh id is returned.
SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip) {
    MemoryLayout memoryLayout = this->memoryLayoutForVariable(*intf.var());
    // NOTE(review): this id is allocated before validation. It goes unused on the error path
    // (which returns a different fresh id) and on the RTFlip path (overwritten by the recursive
    // call's result).
    SpvId result = this->nextId(nullptr);
    const Variable& intfVar = *intf.var();
    const Type& type = intfVar.type();
    if (!memoryLayout.isSupported(type)) {
        fContext.fErrors->error(type.fPosition, "type '" + type.displayName() +
                                                "' is not permitted here");
        return this->nextId(nullptr);
    }
    SpvStorageClass_ storageClass =
            get_storage_class_for_global_variable(intfVar, SpvStorageClassFunction);
    if (fProgram.fInputs.fUseFlipRTUniform && appendRTFlip && type.isStruct()) {
        // We can only have one interface block (because we use push_constant and that is limited
        // to one per program), so we need to append rtflip to this one rather than synthesize an
        // entirely new block when the variable is referenced. And we can't modify the existing
        // block, so we instead create a modified copy of it and write that.
        std::vector<Type::Field> fields = type.fields();
        // The unlabeled third Layout argument is presumably the member offset
        // (fRTFlipOffset); confirm against the Layout constructor.
        fields.emplace_back(Position(),
                            Modifiers(Layout(/*flags=*/0,
                                             /*location=*/-1,
                                             fProgram.fConfig->fSettings.fRTFlipOffset,
                                             /*binding=*/-1,
                                             /*index=*/-1,
                                             /*set=*/-1,
                                             /*builtin=*/-1,
                                             /*inputAttachmentIndex=*/-1),
                                      /*flags=*/0),
                            SKSL_RTFLIP_NAME,
                            fContext.fTypes.fFloat2.get());
        {
            // The new struct type, variable, and Field symbol are created with the program's
            // pool attached and handed to fProgram.fSymbols for ownership.
            AutoAttachPoolToThread attach(fProgram.fPool.get());
            const Type* rtFlipStructType = fProgram.fSymbols->takeOwnershipOfSymbol(
                    Type::MakeStructType(fContext,
                                         type.fPosition,
                                         type.name(),
                                         std::move(fields),
                                         /*interfaceBlock=*/true));
            InterfaceBlockVariable* modifiedVar = fProgram.fSymbols->takeOwnershipOfSymbol(
                    std::make_unique<InterfaceBlockVariable>(intfVar.fPosition,
                                                             intfVar.modifiersPosition(),
                                                             &intfVar.modifiers(),
                                                             intfVar.name(),
                                                             rtFlipStructType,
                                                             intfVar.isBuiltin(),
                                                             intfVar.storage()));
            // The synthesized variable is not tracked by ProgramUsage; registering it here
            // keeps isDead() from treating it as dead.
            fSPIRVBonusVariables.add(modifiedVar);
            InterfaceBlock modifiedCopy(intf.fPosition, modifiedVar, intf.typeOwner());
            result = this->writeInterfaceBlock(modifiedCopy, /*appendRTFlip=*/false);
            // Adds a Field symbol naming the last (rtflip) member of the synthesized block.
            // NOTE(review): presumably this is how later references to rtflip resolve; confirm.
            fProgram.fSymbols->add(std::make_unique<Field>(
                    Position(), modifiedVar, rtFlipStructType->fields().size() - 1));
        }
        // The original variable maps to the synthesized block's OpVariable.
        fVariableMap.set(&intfVar, result);
        fWroteRTFlip = true;
        return result;
    }
    const Modifiers& intfModifiers = intfVar.modifiers();
    SpvId typeId = this->getType(type, memoryLayout);
    // Only non-builtin blocks get the Block decoration on their struct type.
    if (intfModifiers.fLayout.fBuiltin == -1) {
        this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
    }
    // The pointer type and the variable are declared directly in fConstantBuffer.
    SpvId ptrType = this->nextId(nullptr);
    this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
    this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
    Layout layout = intfModifiers.fLayout;
    // Uniform blocks without an explicit set fall back to the configured default uniform set.
    if (storageClass == SpvStorageClassUniform && layout.fSet < 0) {
        layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
    }
    this->writeLayout(layout, result, intfVar.fPosition);
    fVariableMap.set(&intfVar, result);
    return result;
}
3573 
isDead(const Variable & var) const3574 bool SPIRVCodeGenerator::isDead(const Variable& var) const {
3575     // During SPIR-V code generation, we synthesize some extra bonus variables that don't actually
3576     // exist in the Program at all and aren't tracked by the ProgramUsage. They aren't dead, though.
3577     if (fSPIRVBonusVariables.contains(&var)) {
3578         return false;
3579     }
3580     ProgramUsage::VariableCounts counts = fProgram.usage()->get(var);
3581     if (counts.fRead || counts.fWrite) {
3582         return false;
3583     }
3584     // It's not entirely clear what the rules are for eliding interface variables. Generally, it
3585     // causes problems to elide them, even when they're dead.
3586     return !(var.modifiers().fFlags &
3587              (Modifiers::kIn_Flag | Modifiers::kOut_Flag | Modifiers::kUniform_Flag));
3588 }
3589 
3590 // This function determines whether to skip an OpVariable (of pointer type) declaration for
3591 // compile-time constant scalars and vectors which we turn into OpConstant/OpConstantComposite and
3592 // always reference by value.
3593 //
3594 // Accessing a matrix or array member with a dynamic index requires the use of OpAccessChain which
3595 // requires a base operand of pointer type. However, a vector can always be accessed by value using
3596 // OpVectorExtractDynamic (see writeIndexExpression).
3597 //
3598 // This is why we always emit an OpVariable for all non-scalar and non-vector types in case they get
3599 // accessed via a dynamic index.
is_vardecl_compile_time_constant(const VarDeclaration & varDecl)3600 static bool is_vardecl_compile_time_constant(const VarDeclaration& varDecl) {
3601     return varDecl.var()->modifiers().fFlags & Modifiers::kConst_Flag &&
3602            (varDecl.var()->type().isScalar() || varDecl.var()->type().isVector()) &&
3603            (ConstantFolder::GetConstantValueOrNullForVariable(*varDecl.value()) ||
3604             Analysis::IsCompileTimeConstant(*varDecl.value()));
3605 }
3606 
writeGlobalVarDeclaration(ProgramKind kind,const VarDeclaration & varDecl)3607 bool SPIRVCodeGenerator::writeGlobalVarDeclaration(ProgramKind kind,
3608                                                    const VarDeclaration& varDecl) {
3609     const Variable* var = varDecl.var();
3610     const bool inDawnMode = fProgram.fConfig->fSettings.fSPIRVDawnCompatMode;
3611     const int backendFlags = var->modifiers().fLayout.fFlags & Layout::kAllBackendFlagsMask;
3612     const int permittedBackendFlags = Layout::kSPIRV_Flag | (inDawnMode ? Layout::kWGSL_Flag : 0);
3613     if (backendFlags & ~permittedBackendFlags) {
3614         fContext.fErrors->error(var->fPosition, "incompatible backend flag in SPIR-V codegen");
3615         return false;
3616     }
3617 
3618     // If this global variable is a compile-time constant then we'll emit OpConstant or
3619     // OpConstantComposite later when the variable is referenced. Avoid declaring an OpVariable now.
3620     if (is_vardecl_compile_time_constant(varDecl)) {
3621         return true;
3622     }
3623 
3624     SpvStorageClass_ storageClass =
3625             get_storage_class_for_global_variable(*var, SpvStorageClassPrivate);
3626     if (storageClass == SpvStorageClassUniform) {
3627         // Top-level uniforms are emitted in writeUniformBuffer.
3628         fTopLevelUniforms.push_back(&varDecl);
3629         return true;
3630     }
3631 
3632     if (this->isDead(*var)) {
3633         return true;
3634     }
3635 
3636     if (var->type().typeKind() == Type::TypeKind::kSampler && inDawnMode) {
3637         if (var->modifiers().fLayout.fTexture == -1 || var->modifiers().fLayout.fSampler == -1 ||
3638             !(var->modifiers().fLayout.fFlags & Layout::kWGSL_Flag)) {
3639             fContext.fErrors->error(var->fPosition,
3640                                     "SPIR-V dawn compatibility mode requires an explicit texture "
3641                                     "and sampler index");
3642             return false;
3643         }
3644         SkASSERT(storageClass == SpvStorageClassUniformConstant);
3645 
3646         auto [texture, sampler] = this->synthesizeTextureAndSampler(*var);
3647         this->writeGlobalVar(kind, storageClass, *texture);
3648         this->writeGlobalVar(kind, storageClass, *sampler);
3649 
3650         return true;
3651     }
3652 
3653     SpvId id = this->writeGlobalVar(kind, storageClass, *var);
3654     if (id != NA && varDecl.value()) {
3655         SkASSERT(!fCurrentBlock);
3656         fCurrentBlock = NA;
3657         SpvId value = this->writeExpression(*varDecl.value(), fGlobalInitializersBuffer);
3658         this->writeOpStore(storageClass, id, value, fGlobalInitializersBuffer);
3659         fCurrentBlock = 0;
3660     }
3661     return true;
3662 }
3663 
writeGlobalVar(ProgramKind kind,SpvStorageClass_ storageClass,const Variable & var)3664 SpvId SPIRVCodeGenerator::writeGlobalVar(ProgramKind kind,
3665                                          SpvStorageClass_ storageClass,
3666                                          const Variable& var) {
3667     if (var.modifiers().fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
3668         !ProgramConfig::IsFragment(kind)) {
3669         SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
3670         return NA;
3671     }
3672 
3673     // Add this global to the variable map.
3674     const Type& type = var.type();
3675     SpvId id = this->nextId(&type);
3676     fVariableMap.set(&var, id);
3677 
3678     Layout layout = var.modifiers().fLayout;
3679     if (layout.fSet < 0 && storageClass == SpvStorageClassUniformConstant) {
3680         layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
3681     }
3682 
3683     SpvId typeId = this->getPointerType(type, storageClass);
3684     this->writeInstruction(SpvOpVariable, typeId, id, storageClass, fConstantBuffer);
3685     this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer);
3686     this->writeLayout(layout, id, var.fPosition);
3687     if (var.modifiers().fFlags & Modifiers::kFlat_Flag) {
3688         this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
3689     }
3690     if (var.modifiers().fFlags & Modifiers::kNoPerspective_Flag) {
3691         this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
3692                                fDecorationBuffer);
3693     }
3694 
3695     return id;
3696 }
3697 
writeVarDeclaration(const VarDeclaration & varDecl,OutputStream & out)3698 void SPIRVCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl, OutputStream& out) {
3699     // If this variable is a compile-time constant then we'll emit OpConstant or
3700     // OpConstantComposite later when the variable is referenced. Avoid declaring an OpVariable now.
3701     if (is_vardecl_compile_time_constant(varDecl)) {
3702         return;
3703     }
3704 
3705     const Variable* var = varDecl.var();
3706     SpvId id = this->nextId(&var->type());
3707     fVariableMap.set(var, id);
3708     SpvId type = this->getPointerType(var->type(), SpvStorageClassFunction);
3709     this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
3710     this->writeInstruction(SpvOpName, id, var->name(), fNameBuffer);
3711     if (varDecl.value()) {
3712         SpvId value = this->writeExpression(*varDecl.value(), out);
3713         this->writeOpStore(SpvStorageClassFunction, id, value, out);
3714     }
3715 }
3716 
writeStatement(const Statement & s,OutputStream & out)3717 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
3718     switch (s.kind()) {
3719         case Statement::Kind::kNop:
3720             break;
3721         case Statement::Kind::kBlock:
3722             this->writeBlock(s.as<Block>(), out);
3723             break;
3724         case Statement::Kind::kExpression:
3725             this->writeExpression(*s.as<ExpressionStatement>().expression(), out);
3726             break;
3727         case Statement::Kind::kReturn:
3728             this->writeReturnStatement(s.as<ReturnStatement>(), out);
3729             break;
3730         case Statement::Kind::kVarDeclaration:
3731             this->writeVarDeclaration(s.as<VarDeclaration>(), out);
3732             break;
3733         case Statement::Kind::kIf:
3734             this->writeIfStatement(s.as<IfStatement>(), out);
3735             break;
3736         case Statement::Kind::kFor:
3737             this->writeForStatement(s.as<ForStatement>(), out);
3738             break;
3739         case Statement::Kind::kDo:
3740             this->writeDoStatement(s.as<DoStatement>(), out);
3741             break;
3742         case Statement::Kind::kSwitch:
3743             this->writeSwitchStatement(s.as<SwitchStatement>(), out);
3744             break;
3745         case Statement::Kind::kBreak:
3746             this->writeInstruction(SpvOpBranch, fBreakTarget.back(), out);
3747             break;
3748         case Statement::Kind::kContinue:
3749             this->writeInstruction(SpvOpBranch, fContinueTarget.back(), out);
3750             break;
3751         case Statement::Kind::kDiscard:
3752             this->writeInstruction(SpvOpKill, out);
3753             break;
3754         default:
3755             SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
3756             break;
3757     }
3758 }
3759 
writeBlock(const Block & b,OutputStream & out)3760 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
3761     for (const std::unique_ptr<Statement>& stmt : b.children()) {
3762         this->writeStatement(*stmt, out);
3763     }
3764 }
3765 
getConditionalOpCounts()3766 SPIRVCodeGenerator::ConditionalOpCounts SPIRVCodeGenerator::getConditionalOpCounts() {
3767     return {fReachableOps.size(), fStoreOps.size()};
3768 }
3769 
pruneConditionalOps(ConditionalOpCounts ops)3770 void SPIRVCodeGenerator::pruneConditionalOps(ConditionalOpCounts ops) {
3771     // Remove ops which are no longer reachable.
3772     while (fReachableOps.size() > ops.numReachableOps) {
3773         SpvId prunableSpvId = fReachableOps.back();
3774         const Instruction* prunableOp = fSpvIdCache.find(prunableSpvId);
3775 
3776         if (prunableOp) {
3777             fOpCache.remove(*prunableOp);
3778             fSpvIdCache.remove(prunableSpvId);
3779         } else {
3780             SkDEBUGFAIL("reachable-op list contains unrecognized SpvId");
3781         }
3782 
3783         fReachableOps.pop_back();
3784     }
3785 
3786     // Remove any cached stores that occurred during the conditional block.
3787     while (fStoreOps.size() > ops.numStoreOps) {
3788         if (fStoreCache.find(fStoreOps.back())) {
3789             fStoreCache.remove(fStoreOps.back());
3790         }
3791         fStoreOps.pop_back();
3792     }
3793 }
3794 
writeIfStatement(const IfStatement & stmt,OutputStream & out)3795 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
3796     SpvId test = this->writeExpression(*stmt.test(), out);
3797     SpvId ifTrue = this->nextId(nullptr);
3798     SpvId ifFalse = this->nextId(nullptr);
3799 
3800     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
3801 
3802     if (stmt.ifFalse()) {
3803         SpvId end = this->nextId(nullptr);
3804         this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
3805         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
3806         this->writeLabel(ifTrue, kBranchIsOnPreviousLine, out);
3807         this->writeStatement(*stmt.ifTrue(), out);
3808         if (fCurrentBlock) {
3809             this->writeInstruction(SpvOpBranch, end, out);
3810         }
3811         this->writeLabel(ifFalse, kBranchIsAbove, conditionalOps, out);
3812         this->writeStatement(*stmt.ifFalse(), out);
3813         if (fCurrentBlock) {
3814             this->writeInstruction(SpvOpBranch, end, out);
3815         }
3816         this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
3817     } else {
3818         this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
3819         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
3820         this->writeLabel(ifTrue, kBranchIsOnPreviousLine, out);
3821         this->writeStatement(*stmt.ifTrue(), out);
3822         if (fCurrentBlock) {
3823             this->writeInstruction(SpvOpBranch, ifFalse, out);
3824         }
3825         this->writeLabel(ifFalse, kBranchIsAbove, conditionalOps, out);
3826     }
3827 }
3828 
writeForStatement(const ForStatement & f,OutputStream & out)3829 void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
3830     if (f.initializer()) {
3831         this->writeStatement(*f.initializer(), out);
3832     }
3833 
3834     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
3835 
3836     // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
3837     // in the context of linear straight-line execution. If we wanted to be more clever, we could
3838     // only invalidate store cache entries for variables affected by the loop body, but for now we
3839     // simply clear the entire cache whenever branching occurs.
3840     SpvId header = this->nextId(nullptr);
3841     SpvId start = this->nextId(nullptr);
3842     SpvId body = this->nextId(nullptr);
3843     SpvId next = this->nextId(nullptr);
3844     fContinueTarget.push_back(next);
3845     SpvId end = this->nextId(nullptr);
3846     fBreakTarget.push_back(end);
3847     this->writeInstruction(SpvOpBranch, header, out);
3848     this->writeLabel(header, kBranchIsBelow, conditionalOps, out);
3849     this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
3850     this->writeInstruction(SpvOpBranch, start, out);
3851     this->writeLabel(start, kBranchIsOnPreviousLine, out);
3852     if (f.test()) {
3853         SpvId test = this->writeExpression(*f.test(), out);
3854         this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
3855     } else {
3856         this->writeInstruction(SpvOpBranch, body, out);
3857     }
3858     this->writeLabel(body, kBranchIsOnPreviousLine, out);
3859     this->writeStatement(*f.statement(), out);
3860     if (fCurrentBlock) {
3861         this->writeInstruction(SpvOpBranch, next, out);
3862     }
3863     this->writeLabel(next, kBranchIsAbove, conditionalOps, out);
3864     if (f.next()) {
3865         this->writeExpression(*f.next(), out);
3866     }
3867     this->writeInstruction(SpvOpBranch, header, out);
3868     this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
3869     fBreakTarget.pop_back();
3870     fContinueTarget.pop_back();
3871 }
3872 
writeDoStatement(const DoStatement & d,OutputStream & out)3873 void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
3874     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
3875 
3876     // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
3877     // in the context of linear straight-line execution. If we wanted to be more clever, we could
3878     // only invalidate store cache entries for variables affected by the loop body, but for now we
3879     // simply clear the entire cache whenever branching occurs.
3880     SpvId header = this->nextId(nullptr);
3881     SpvId start = this->nextId(nullptr);
3882     SpvId next = this->nextId(nullptr);
3883     SpvId continueTarget = this->nextId(nullptr);
3884     fContinueTarget.push_back(continueTarget);
3885     SpvId end = this->nextId(nullptr);
3886     fBreakTarget.push_back(end);
3887     this->writeInstruction(SpvOpBranch, header, out);
3888     this->writeLabel(header, kBranchIsBelow, conditionalOps, out);
3889     this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out);
3890     this->writeInstruction(SpvOpBranch, start, out);
3891     this->writeLabel(start, kBranchIsOnPreviousLine, out);
3892     this->writeStatement(*d.statement(), out);
3893     if (fCurrentBlock) {
3894         this->writeInstruction(SpvOpBranch, next, out);
3895         this->writeLabel(next, kBranchIsOnPreviousLine, out);
3896         this->writeInstruction(SpvOpBranch, continueTarget, out);
3897     }
3898     this->writeLabel(continueTarget, kBranchIsAbove, conditionalOps, out);
3899     SpvId test = this->writeExpression(*d.test(), out);
3900     this->writeInstruction(SpvOpBranchConditional, test, header, end, out);
3901     this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
3902     fBreakTarget.pop_back();
3903     fContinueTarget.pop_back();
3904 }
3905 
writeSwitchStatement(const SwitchStatement & s,OutputStream & out)3906 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
3907     SpvId value = this->writeExpression(*s.value(), out);
3908 
3909     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
3910 
3911     // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
3912     // in the context of linear straight-line execution. If we wanted to be more clever, we could
3913     // only invalidate store cache entries for variables affected by the switch body, but for now we
3914     // simply clear the entire cache whenever branching occurs.
3915     SkTArray<SpvId> labels;
3916     SpvId end = this->nextId(nullptr);
3917     SpvId defaultLabel = end;
3918     fBreakTarget.push_back(end);
3919     int size = 3;
3920     const StatementArray& cases = s.cases();
3921     for (const std::unique_ptr<Statement>& stmt : cases) {
3922         const SwitchCase& c = stmt->as<SwitchCase>();
3923         SpvId label = this->nextId(nullptr);
3924         labels.push_back(label);
3925         if (!c.isDefault()) {
3926             size += 2;
3927         } else {
3928             defaultLabel = label;
3929         }
3930     }
3931 
3932     // We should have exactly one label for each case.
3933     SkASSERT(labels.size() == cases.size());
3934 
3935     // Collapse adjacent switch-cases into one; that is, reduce `case 1: case 2: case 3:` into a
3936     // single OpLabel. The Tint SPIR-V reader does not support switch-case fallthrough, but it
3937     // does support multiple switch-cases branching to the same label.
3938     SkBitSet caseIsCollapsed(cases.size());
3939     for (int index = cases.size() - 2; index >= 0; index--) {
3940         if (cases[index]->as<SwitchCase>().statement()->isEmpty()) {
3941             caseIsCollapsed.set(index);
3942             labels[index] = labels[index + 1];
3943         }
3944     }
3945 
3946     labels.push_back(end);
3947 
3948     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
3949     this->writeOpCode(SpvOpSwitch, size, out);
3950     this->writeWord(value, out);
3951     this->writeWord(defaultLabel, out);
3952     for (int i = 0; i < cases.size(); ++i) {
3953         const SwitchCase& c = cases[i]->as<SwitchCase>();
3954         if (c.isDefault()) {
3955             continue;
3956         }
3957         this->writeWord(c.value(), out);
3958         this->writeWord(labels[i], out);
3959     }
3960     for (int i = 0; i < cases.size(); ++i) {
3961         if (caseIsCollapsed.test(i)) {
3962             continue;
3963         }
3964         const SwitchCase& c = cases[i]->as<SwitchCase>();
3965         if (i == 0) {
3966             this->writeLabel(labels[i], kBranchIsOnPreviousLine, out);
3967         } else {
3968             this->writeLabel(labels[i], kBranchIsAbove, conditionalOps, out);
3969         }
3970         this->writeStatement(*c.statement(), out);
3971         if (fCurrentBlock) {
3972             this->writeInstruction(SpvOpBranch, labels[i + 1], out);
3973         }
3974     }
3975     this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
3976     fBreakTarget.pop_back();
3977 }
3978 
writeReturnStatement(const ReturnStatement & r,OutputStream & out)3979 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
3980     if (r.expression()) {
3981         this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.expression(), out),
3982                                out);
3983     } else {
3984         this->writeInstruction(SpvOpReturn, out);
3985     }
3986 }
3987 
3988 // Given any function, returns the top-level symbol table (OUTSIDE of the function's scope).
get_top_level_symbol_table(const FunctionDeclaration & anyFunc)3989 static std::shared_ptr<SymbolTable> get_top_level_symbol_table(const FunctionDeclaration& anyFunc) {
3990     return anyFunc.definition()->body()->as<Block>().symbolTable()->fParent;
3991 }
3992 
writeEntrypointAdapter(const FunctionDeclaration & main)3993 SPIRVCodeGenerator::EntrypointAdapter SPIRVCodeGenerator::writeEntrypointAdapter(
3994         const FunctionDeclaration& main) {
3995     // Our goal is to synthesize a tiny helper function which looks like this:
3996     //     void _entrypoint() { sk_FragColor = main(); }
3997 
3998     // Fish a symbol table out of main().
3999     std::shared_ptr<SymbolTable> symbolTable = get_top_level_symbol_table(main);
4000 
4001     // Get `sk_FragColor` as a writable reference.
4002     const Symbol* skFragColorSymbol = symbolTable->find("sk_FragColor");
4003     SkASSERT(skFragColorSymbol);
4004     const Variable& skFragColorVar = skFragColorSymbol->as<Variable>();
4005     auto skFragColorRef = std::make_unique<VariableReference>(Position(), &skFragColorVar,
4006                                                               VariableReference::RefKind::kWrite);
4007     // Synthesize a call to the `main()` function.
4008     if (!main.returnType().matches(skFragColorRef->type())) {
4009         fContext.fErrors->error(main.fPosition, "SPIR-V does not support returning '" +
4010                 main.returnType().description() + "' from main()");
4011         return {};
4012     }
4013     ExpressionArray args;
4014     if (main.parameters().size() == 1) {
4015         if (!main.parameters()[0]->type().matches(*fContext.fTypes.fFloat2)) {
4016             fContext.fErrors->error(main.fPosition,
4017                     "SPIR-V does not support parameter of type '" +
4018                     main.parameters()[0]->type().description() + "' to main()");
4019             return {};
4020         }
4021         args.push_back(dsl::Float2(0).release());
4022     }
4023     auto callMainFn = std::make_unique<FunctionCall>(Position(), &main.returnType(), &main,
4024                                                      std::move(args));
4025 
4026     // Synthesize `skFragColor = main()` as a BinaryExpression.
4027     auto assignmentStmt = std::make_unique<ExpressionStatement>(std::make_unique<BinaryExpression>(
4028             Position(),
4029             std::move(skFragColorRef),
4030             Operator::Kind::EQ,
4031             std::move(callMainFn),
4032             &main.returnType()));
4033 
4034     // Function bodies are always wrapped in a Block.
4035     StatementArray entrypointStmts;
4036     entrypointStmts.push_back(std::move(assignmentStmt));
4037     auto entrypointBlock = Block::Make(Position(), std::move(entrypointStmts),
4038                                        Block::Kind::kBracedScope, symbolTable);
4039     // Declare an entrypoint function.
4040     EntrypointAdapter adapter;
4041     adapter.fLayout = {};
4042     adapter.fModifiers = Modifiers{adapter.fLayout, Modifiers::kNo_Flag};
4043     adapter.entrypointDecl =
4044             std::make_unique<FunctionDeclaration>(Position(),
4045                                                   &adapter.fModifiers,
4046                                                   "_entrypoint",
4047                                                   /*parameters=*/std::vector<Variable*>{},
4048                                                   /*returnType=*/fContext.fTypes.fVoid.get(),
4049                                                   /*builtin=*/false);
4050     // Define it.
4051     adapter.entrypointDef = FunctionDefinition::Convert(fContext,
4052                                                         Position(),
4053                                                         *adapter.entrypointDecl,
4054                                                         std::move(entrypointBlock),
4055                                                         /*builtin=*/false);
4056 
4057     adapter.entrypointDecl->setDefinition(adapter.entrypointDef.get());
4058     return adapter;
4059 }
4060 
writeUniformBuffer(std::shared_ptr<SymbolTable> topLevelSymbolTable)4061 void SPIRVCodeGenerator::writeUniformBuffer(std::shared_ptr<SymbolTable> topLevelSymbolTable) {
4062     SkASSERT(!fTopLevelUniforms.empty());
4063     static constexpr char kUniformBufferName[] = "_UniformBuffer";
4064 
4065     // Convert the list of top-level uniforms into a matching struct named _UniformBuffer, and build
4066     // a lookup table of variables to UniformBuffer field indices.
4067     std::vector<Type::Field> fields;
4068     fields.reserve(fTopLevelUniforms.size());
4069     for (const VarDeclaration* topLevelUniform : fTopLevelUniforms) {
4070         const Variable* var = topLevelUniform->var();
4071         fTopLevelUniformMap.set(var, (int)fields.size());
4072         Modifiers modifiers = var->modifiers();
4073         modifiers.fFlags &= ~Modifiers::kUniform_Flag;
4074         fields.emplace_back(var->fPosition, modifiers, var->name(), &var->type());
4075     }
4076     fUniformBuffer.fStruct = Type::MakeStructType(fContext,
4077                                                   Position(),
4078                                                   kUniformBufferName,
4079                                                   std::move(fields),
4080                                                   /*interfaceBlock=*/true);
4081 
4082     // Create a global variable to contain this struct.
4083     Layout layout;
4084     layout.fBinding = fProgram.fConfig->fSettings.fDefaultUniformBinding;
4085     layout.fSet     = fProgram.fConfig->fSettings.fDefaultUniformSet;
4086     Modifiers modifiers{layout, Modifiers::kUniform_Flag};
4087 
4088     fUniformBuffer.fInnerVariable = std::make_unique<InterfaceBlockVariable>(
4089             /*pos=*/Position(), /*modifiersPosition=*/Position(),
4090             fContext.fModifiersPool->add(modifiers), kUniformBufferName,
4091             fUniformBuffer.fStruct.get(), /*builtin=*/false, Variable::Storage::kGlobal);
4092 
4093     // Create an interface block object for this global variable.
4094     fUniformBuffer.fInterfaceBlock =
4095             std::make_unique<InterfaceBlock>(Position(),
4096                                              fUniformBuffer.fInnerVariable.get(),
4097                                              topLevelSymbolTable);
4098 
4099     // Generate an interface block and hold onto its ID.
4100     fUniformBufferId = this->writeInterfaceBlock(*fUniformBuffer.fInterfaceBlock);
4101 }
4102 
addRTFlipUniform(Position pos)4103 void SPIRVCodeGenerator::addRTFlipUniform(Position pos) {
4104     SkASSERT(!fProgram.fConfig->fSettings.fForceNoRTFlip);
4105 
4106     if (fWroteRTFlip) {
4107         return;
4108     }
4109     // Flip variable hasn't been written yet. This means we don't have an existing
4110     // interface block, so we're free to just synthesize one.
4111     fWroteRTFlip = true;
4112     std::vector<Type::Field> fields;
4113     if (fProgram.fConfig->fSettings.fRTFlipOffset < 0) {
4114         fContext.fErrors->error(pos, "RTFlipOffset is negative");
4115     }
4116     fields.emplace_back(pos,
4117                         Modifiers(Layout(/*flags=*/0,
4118                                          /*location=*/-1,
4119                                          fProgram.fConfig->fSettings.fRTFlipOffset,
4120                                          /*binding=*/-1,
4121                                          /*index=*/-1,
4122                                          /*set=*/-1,
4123                                          /*builtin=*/-1,
4124                                          /*inputAttachmentIndex=*/-1),
4125                                   /*flags=*/0),
4126                         SKSL_RTFLIP_NAME,
4127                         fContext.fTypes.fFloat2.get());
4128     std::string_view name = "sksl_synthetic_uniforms";
4129     const Type* intfStruct = fSynthetics.takeOwnershipOfSymbol(
4130             Type::MakeStructType(fContext, Position(), name, fields, /*interfaceBlock=*/true));
4131     bool usePushConstants = fProgram.fConfig->fSettings.fUsePushConstants;
4132     int binding = -1, set = -1;
4133     if (!usePushConstants) {
4134         binding = fProgram.fConfig->fSettings.fRTFlipBinding;
4135         if (binding == -1) {
4136             fContext.fErrors->error(pos, "layout(binding=...) is required in SPIR-V");
4137         }
4138         set = fProgram.fConfig->fSettings.fRTFlipSet;
4139         if (set == -1) {
4140             fContext.fErrors->error(pos, "layout(set=...) is required in SPIR-V");
4141         }
4142     }
4143     int flags = usePushConstants ? Layout::Flag::kPushConstant_Flag : 0;
4144     const Modifiers* modsPtr;
4145     {
4146         AutoAttachPoolToThread attach(fProgram.fPool.get());
4147         Modifiers modifiers(Layout(flags,
4148                                    /*location=*/-1,
4149                                    /*offset=*/-1,
4150                                    binding,
4151                                    /*index=*/-1,
4152                                    set,
4153                                    /*builtin=*/-1,
4154                                    /*inputAttachmentIndex=*/-1),
4155                             Modifiers::kUniform_Flag);
4156         modsPtr = fContext.fModifiersPool->add(modifiers);
4157     }
4158     InterfaceBlockVariable* intfVar = fSynthetics.takeOwnershipOfSymbol(
4159             std::make_unique<InterfaceBlockVariable>(/*pos=*/Position(),
4160                                                      /*modifiersPosition=*/Position(),
4161                                                      modsPtr,
4162                                                      name,
4163                                                      intfStruct,
4164                                                      /*builtin=*/false,
4165                                                      Variable::Storage::kGlobal));
4166     fSPIRVBonusVariables.add(intfVar);
4167     {
4168         AutoAttachPoolToThread attach(fProgram.fPool.get());
4169         fProgram.fSymbols->add(std::make_unique<Field>(Position(), intfVar, /*field=*/0));
4170     }
4171     InterfaceBlock intf(Position(), intfVar, std::make_shared<SymbolTable>(/*builtin=*/false));
4172     this->writeInterfaceBlock(intf, false);
4173 }
4174 
synthesizeTextureAndSampler(const Variable & combinedSampler)4175 std::tuple<const Variable*, const Variable*> SPIRVCodeGenerator::synthesizeTextureAndSampler(
4176         const Variable& combinedSampler) {
4177     SkASSERT(fProgram.fConfig->fSettings.fSPIRVDawnCompatMode);
4178     SkASSERT(combinedSampler.type().typeKind() == Type::TypeKind::kSampler);
4179 
4180     const Modifiers& modifiers = combinedSampler.modifiers();
4181 
4182     auto data = std::make_unique<SynthesizedTextureSamplerPair>();
4183 
4184     Modifiers texModifiers = modifiers;
4185     texModifiers.fLayout.fBinding = modifiers.fLayout.fTexture;
4186     data->fTextureName = std::string(combinedSampler.name()) + "_texture";
4187     auto texture = std::make_unique<Variable>(/*pos=*/Position(),
4188                                               /*modifierPosition=*/Position(),
4189                                               fContext.fModifiersPool->add(texModifiers),
4190                                               data->fTextureName,
4191                                               &combinedSampler.type().textureType(),
4192                                               /*builtin=*/false,
4193                                               Variable::Storage::kGlobal);
4194 
4195     Modifiers samplerModifiers = modifiers;
4196     samplerModifiers.fLayout.fBinding = modifiers.fLayout.fSampler;
4197     data->fSamplerName = std::string(combinedSampler.name()) + "_sampler";
4198     auto sampler = std::make_unique<Variable>(/*pos=*/Position(),
4199                                               /*modifierPosition=*/Position(),
4200                                               fContext.fModifiersPool->add(samplerModifiers),
4201                                               data->fSamplerName,
4202                                               fContext.fTypes.fSampler.get(),
4203                                               /*builtin=*/false,
4204                                               Variable::Storage::kGlobal);
4205 
4206     const Variable* t = texture.get();
4207     const Variable* s = sampler.get();
4208     data->fTexture = std::move(texture);
4209     data->fSampler = std::move(sampler);
4210     fSynthesizedSamplerMap.set(&combinedSampler, std::move(data));
4211 
4212     return {t, s};
4213 }
4214 
writeInstructions(const Program & program,OutputStream & out)4215 void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
4216     fGLSLExtendedInstructions = this->nextId(nullptr);
4217     StringStream body;
4218     // Assign SpvIds to functions.
4219     const FunctionDeclaration* main = nullptr;
4220     for (const ProgramElement* e : program.elements()) {
4221         if (e->is<FunctionDefinition>()) {
4222             const FunctionDefinition& funcDef = e->as<FunctionDefinition>();
4223             const FunctionDeclaration& funcDecl = funcDef.declaration();
4224             fFunctionMap.set(&funcDecl, this->nextId(nullptr));
4225             if (funcDecl.isMain()) {
4226                 main = &funcDecl;
4227             }
4228         }
4229     }
4230     // Make sure we have a main() function.
4231     if (!main) {
4232         fContext.fErrors->error(Position(), "program does not contain a main() function");
4233         return;
4234     }
4235     // Emit interface blocks.
4236     std::set<SpvId> interfaceVars;
4237     for (const ProgramElement* e : program.elements()) {
4238         if (e->is<InterfaceBlock>()) {
4239             const InterfaceBlock& intf = e->as<InterfaceBlock>();
4240             SpvId id = this->writeInterfaceBlock(intf);
4241 
4242             const Modifiers& modifiers = intf.var()->modifiers();
4243             if ((modifiers.fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag)) &&
4244                 modifiers.fLayout.fBuiltin == -1 && !this->isDead(*intf.var())) {
4245                 interfaceVars.insert(id);
4246             }
4247         }
4248     }
4249     // Emit global variable declarations.
4250     for (const ProgramElement* e : program.elements()) {
4251         if (e->is<GlobalVarDeclaration>()) {
4252             if (!this->writeGlobalVarDeclaration(program.fConfig->fKind,
4253                                                  e->as<GlobalVarDeclaration>().varDeclaration())) {
4254                 return;
4255             }
4256         }
4257     }
4258     // Emit top-level uniforms into a dedicated uniform buffer.
4259     if (!fTopLevelUniforms.empty()) {
4260         this->writeUniformBuffer(get_top_level_symbol_table(*main));
4261     }
4262     // If main() returns a half4, synthesize a tiny entrypoint function which invokes the real
4263     // main() and stores the result into sk_FragColor.
4264     EntrypointAdapter adapter;
4265     if (main->returnType().matches(*fContext.fTypes.fHalf4)) {
4266         adapter = this->writeEntrypointAdapter(*main);
4267         if (adapter.entrypointDecl) {
4268             fFunctionMap.set(adapter.entrypointDecl.get(), this->nextId(nullptr));
4269             this->writeFunction(*adapter.entrypointDef, body);
4270             main = adapter.entrypointDecl.get();
4271         }
4272     }
4273     // Emit all the functions.
4274     for (const ProgramElement* e : program.elements()) {
4275         if (e->is<FunctionDefinition>()) {
4276             this->writeFunction(e->as<FunctionDefinition>(), body);
4277         }
4278     }
4279     // Add global in/out variables to the list of interface variables.
4280     for (const auto& [var, spvId] : fVariableMap) {
4281         if (var->storage() == Variable::Storage::kGlobal &&
4282             (var->modifiers().fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag)) &&
4283             !this->isDead(*var)) {
4284             interfaceVars.insert(spvId);
4285         }
4286     }
4287     this->writeCapabilities(out);
4288     this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
4289     this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
4290     this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (main->name().length() + 4) / 4) +
4291                       (int32_t) interfaceVars.size(), out);
4292     if (ProgramConfig::IsVertex(program.fConfig->fKind)) {
4293         this->writeWord(SpvExecutionModelVertex, out);
4294     } else if (ProgramConfig::IsFragment(program.fConfig->fKind)) {
4295         this->writeWord(SpvExecutionModelFragment, out);
4296     } else {
4297         SK_ABORT("cannot write this kind of program to SPIR-V\n");
4298     }
4299     SpvId entryPoint = fFunctionMap[main];
4300     this->writeWord(entryPoint, out);
4301     this->writeString(main->name(), out);
4302     for (int var : interfaceVars) {
4303         this->writeWord(var, out);
4304     }
4305     if (ProgramConfig::IsFragment(program.fConfig->fKind)) {
4306         this->writeInstruction(SpvOpExecutionMode,
4307                                fFunctionMap[main],
4308                                SpvExecutionModeOriginUpperLeft,
4309                                out);
4310     }
4311     for (const ProgramElement* e : program.elements()) {
4312         if (e->is<Extension>()) {
4313             this->writeInstruction(SpvOpSourceExtension, e->as<Extension>().name(), out);
4314         }
4315     }
4316 
4317     write_stringstream(fNameBuffer, out);
4318     write_stringstream(fDecorationBuffer, out);
4319     write_stringstream(fConstantBuffer, out);
4320     write_stringstream(body, out);
4321 }
4322 
generateCode()4323 bool SPIRVCodeGenerator::generateCode() {
4324     SkASSERT(!fContext.fErrors->errorCount());
4325     this->writeWord(SpvMagicNumber, *fOut);
4326     this->writeWord(SpvVersion, *fOut);
4327     this->writeWord(SKSL_MAGIC, *fOut);
4328     StringStream buffer;
4329     this->writeInstructions(fProgram, buffer);
4330     this->writeWord(fIdCount, *fOut);
4331     this->writeWord(0, *fOut); // reserved, always zero
4332     write_stringstream(buffer, *fOut);
4333     return fContext.fErrors->errorCount() == 0;
4334 }
4335 
4336 }  // namespace SkSL
4337