• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/codegen/SkSLSPIRVCodeGenerator.h"
9 
10 #include "src/sksl/GLSL.std.450.h"
11 
12 #include "src/sksl/SkSLCompiler.h"
13 #include "src/sksl/SkSLOperators.h"
14 #include "src/sksl/ir/SkSLBlock.h"
15 #include "src/sksl/ir/SkSLExpressionStatement.h"
16 #include "src/sksl/ir/SkSLExtension.h"
17 #include "src/sksl/ir/SkSLIndexExpression.h"
18 #include "src/sksl/ir/SkSLVariableReference.h"
19 
20 #ifdef SK_VULKAN
21 #include "src/gpu/vk/GrVkCaps.h"
22 #endif
23 
24 #define kLast_Capability SpvCapabilityMultiViewport
25 
26 namespace SkSL {
27 
// Generator magic number used by this backend; presumably emitted in the SPIR-V module
// header (confirm at the use site).
// FIXME: we should probably register a magic number
static constexpr int32_t SKSL_MAGIC = 0x0;
29 
setupIntrinsics()30 void SPIRVCodeGenerator::setupIntrinsics() {
31 #define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicOpcodeKind, GLSLstd450 ## x, \
32                                     GLSLstd450 ## x, GLSLstd450 ## x, GLSLstd450 ## x)
33 #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicOpcodeKind, \
34                                                              GLSLstd450 ## ifFloat,             \
35                                                              GLSLstd450 ## ifInt,               \
36                                                              GLSLstd450 ## ifUInt,              \
37                                                              SpvOpUndef)
38 #define ALL_SPIRV(x) std::make_tuple(kSPIRV_IntrinsicOpcodeKind, \
39                                      SpvOp ## x, SpvOp ## x, SpvOp ## x, SpvOp ## x)
40 #define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicOpcodeKind, k ## x ## _SpecialIntrinsic, \
41                                    k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic,  \
42                                    k ## x ## _SpecialIntrinsic)
43     fIntrinsicMap[k_round_IntrinsicKind]         = ALL_GLSL(Round);
44     fIntrinsicMap[k_roundEven_IntrinsicKind]     = ALL_GLSL(RoundEven);
45     fIntrinsicMap[k_trunc_IntrinsicKind]         = ALL_GLSL(Trunc);
46     fIntrinsicMap[k_abs_IntrinsicKind]           = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
47     fIntrinsicMap[k_sign_IntrinsicKind]          = BY_TYPE_GLSL(FSign, SSign, SSign);
48     fIntrinsicMap[k_floor_IntrinsicKind]         = ALL_GLSL(Floor);
49     fIntrinsicMap[k_ceil_IntrinsicKind]          = ALL_GLSL(Ceil);
50     fIntrinsicMap[k_fract_IntrinsicKind]         = ALL_GLSL(Fract);
51     fIntrinsicMap[k_radians_IntrinsicKind]       = ALL_GLSL(Radians);
52     fIntrinsicMap[k_degrees_IntrinsicKind]       = ALL_GLSL(Degrees);
53     fIntrinsicMap[k_sin_IntrinsicKind]           = ALL_GLSL(Sin);
54     fIntrinsicMap[k_cos_IntrinsicKind]           = ALL_GLSL(Cos);
55     fIntrinsicMap[k_tan_IntrinsicKind]           = ALL_GLSL(Tan);
56     fIntrinsicMap[k_asin_IntrinsicKind]          = ALL_GLSL(Asin);
57     fIntrinsicMap[k_acos_IntrinsicKind]          = ALL_GLSL(Acos);
58     fIntrinsicMap[k_atan_IntrinsicKind]          = SPECIAL(Atan);
59     fIntrinsicMap[k_sinh_IntrinsicKind]          = ALL_GLSL(Sinh);
60     fIntrinsicMap[k_cosh_IntrinsicKind]          = ALL_GLSL(Cosh);
61     fIntrinsicMap[k_tanh_IntrinsicKind]          = ALL_GLSL(Tanh);
62     fIntrinsicMap[k_asinh_IntrinsicKind]         = ALL_GLSL(Asinh);
63     fIntrinsicMap[k_acosh_IntrinsicKind]         = ALL_GLSL(Acosh);
64     fIntrinsicMap[k_atanh_IntrinsicKind]         = ALL_GLSL(Atanh);
65     fIntrinsicMap[k_pow_IntrinsicKind]           = ALL_GLSL(Pow);
66     fIntrinsicMap[k_exp_IntrinsicKind]           = ALL_GLSL(Exp);
67     fIntrinsicMap[k_log_IntrinsicKind]           = ALL_GLSL(Log);
68     fIntrinsicMap[k_exp2_IntrinsicKind]          = ALL_GLSL(Exp2);
69     fIntrinsicMap[k_log2_IntrinsicKind]          = ALL_GLSL(Log2);
70     fIntrinsicMap[k_sqrt_IntrinsicKind]          = ALL_GLSL(Sqrt);
71     fIntrinsicMap[k_inverse_IntrinsicKind]       = ALL_GLSL(MatrixInverse);
72     fIntrinsicMap[k_outerProduct_IntrinsicKind]  = ALL_SPIRV(OuterProduct);
73     fIntrinsicMap[k_transpose_IntrinsicKind]     = ALL_SPIRV(Transpose);
74     fIntrinsicMap[k_isinf_IntrinsicKind]         = ALL_SPIRV(IsInf);
75     fIntrinsicMap[k_isnan_IntrinsicKind]         = ALL_SPIRV(IsNan);
76     fIntrinsicMap[k_inversesqrt_IntrinsicKind]   = ALL_GLSL(InverseSqrt);
77     fIntrinsicMap[k_determinant_IntrinsicKind]   = ALL_GLSL(Determinant);
78     fIntrinsicMap[k_matrixCompMult_IntrinsicKind] = SPECIAL(MatrixCompMult);
79     fIntrinsicMap[k_matrixInverse_IntrinsicKind] = ALL_GLSL(MatrixInverse);
80     fIntrinsicMap[k_mod_IntrinsicKind]           = SPECIAL(Mod);
81     fIntrinsicMap[k_modf_IntrinsicKind]          = ALL_GLSL(Modf);
82     fIntrinsicMap[k_min_IntrinsicKind]           = SPECIAL(Min);
83     fIntrinsicMap[k_max_IntrinsicKind]           = SPECIAL(Max);
84     fIntrinsicMap[k_clamp_IntrinsicKind]         = SPECIAL(Clamp);
85     fIntrinsicMap[k_saturate_IntrinsicKind]      = SPECIAL(Saturate);
86     fIntrinsicMap[k_dot_IntrinsicKind]           = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
87                                                       SpvOpDot, SpvOpUndef, SpvOpUndef, SpvOpUndef);
88     fIntrinsicMap[k_mix_IntrinsicKind]           = SPECIAL(Mix);
89     fIntrinsicMap[k_step_IntrinsicKind]          = SPECIAL(Step);
90     fIntrinsicMap[k_smoothstep_IntrinsicKind]    = SPECIAL(SmoothStep);
91     fIntrinsicMap[k_fma_IntrinsicKind]           = ALL_GLSL(Fma);
92     fIntrinsicMap[k_frexp_IntrinsicKind]         = ALL_GLSL(Frexp);
93     fIntrinsicMap[k_ldexp_IntrinsicKind]         = ALL_GLSL(Ldexp);
94 
95 #define PACK(type) fIntrinsicMap[k_pack##type##_IntrinsicKind] = ALL_GLSL(Pack##type); \
96                    fIntrinsicMap[k_unpack##type##_IntrinsicKind] = ALL_GLSL(Unpack##type)
97     PACK(Snorm4x8);
98     PACK(Unorm4x8);
99     PACK(Snorm2x16);
100     PACK(Unorm2x16);
101     PACK(Half2x16);
102     PACK(Double2x32);
103 #undef PACK
104     fIntrinsicMap[k_length_IntrinsicKind]      = ALL_GLSL(Length);
105     fIntrinsicMap[k_distance_IntrinsicKind]    = ALL_GLSL(Distance);
106     fIntrinsicMap[k_cross_IntrinsicKind]       = ALL_GLSL(Cross);
107     fIntrinsicMap[k_normalize_IntrinsicKind]   = ALL_GLSL(Normalize);
108     fIntrinsicMap[k_faceforward_IntrinsicKind] = ALL_GLSL(FaceForward);
109     fIntrinsicMap[k_reflect_IntrinsicKind]     = ALL_GLSL(Reflect);
110     fIntrinsicMap[k_refract_IntrinsicKind]     = ALL_GLSL(Refract);
111     fIntrinsicMap[k_bitCount_IntrinsicKind]    = ALL_SPIRV(BitCount);
112     fIntrinsicMap[k_findLSB_IntrinsicKind]     = ALL_GLSL(FindILsb);
113     fIntrinsicMap[k_findMSB_IntrinsicKind]     = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
114     fIntrinsicMap[k_dFdx_IntrinsicKind]        = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
115                                                                  SpvOpDPdx, SpvOpUndef,
116                                                                  SpvOpUndef, SpvOpUndef);
117     fIntrinsicMap[k_dFdy_IntrinsicKind]        = SPECIAL(DFdy);
118     fIntrinsicMap[k_fwidth_IntrinsicKind]      = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
119                                                                  SpvOpFwidth, SpvOpUndef,
120                                                                  SpvOpUndef, SpvOpUndef);
121     fIntrinsicMap[k_makeSampler2D_IntrinsicKind] = SPECIAL(SampledImage);
122 
123     fIntrinsicMap[k_sample_IntrinsicKind]      = SPECIAL(Texture);
124     fIntrinsicMap[k_subpassLoad_IntrinsicKind] = SPECIAL(SubpassLoad);
125 
126     fIntrinsicMap[k_floatBitsToInt_IntrinsicKind]  = ALL_SPIRV(Bitcast);
127     fIntrinsicMap[k_floatBitsToUint_IntrinsicKind] = ALL_SPIRV(Bitcast);
128     fIntrinsicMap[k_intBitsToFloat_IntrinsicKind]  = ALL_SPIRV(Bitcast);
129     fIntrinsicMap[k_uintBitsToFloat_IntrinsicKind] = ALL_SPIRV(Bitcast);
130 
131     fIntrinsicMap[k_any_IntrinsicKind]        = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
132                                                                 SpvOpUndef, SpvOpUndef,
133                                                                 SpvOpUndef, SpvOpAny);
134     fIntrinsicMap[k_all_IntrinsicKind]        = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
135                                                                 SpvOpUndef, SpvOpUndef,
136                                                                 SpvOpUndef, SpvOpAll);
137     fIntrinsicMap[k_not_IntrinsicKind]        = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
138                                                                 SpvOpUndef, SpvOpUndef, SpvOpUndef,
139                                                                 SpvOpLogicalNot);
140     fIntrinsicMap[k_equal_IntrinsicKind]      = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
141                                                                 SpvOpFOrdEqual, SpvOpIEqual,
142                                                                 SpvOpIEqual, SpvOpLogicalEqual);
143     fIntrinsicMap[k_notEqual_IntrinsicKind]   = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
144                                                                 SpvOpFOrdNotEqual, SpvOpINotEqual,
145                                                                 SpvOpINotEqual,
146                                                                 SpvOpLogicalNotEqual);
147     fIntrinsicMap[k_lessThan_IntrinsicKind]         = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
148                                                                       SpvOpFOrdLessThan,
149                                                                       SpvOpSLessThan,
150                                                                       SpvOpULessThan,
151                                                                       SpvOpUndef);
152     fIntrinsicMap[k_lessThanEqual_IntrinsicKind]    = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
153                                                                       SpvOpFOrdLessThanEqual,
154                                                                       SpvOpSLessThanEqual,
155                                                                       SpvOpULessThanEqual,
156                                                                       SpvOpUndef);
157     fIntrinsicMap[k_greaterThan_IntrinsicKind]      = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
158                                                                       SpvOpFOrdGreaterThan,
159                                                                       SpvOpSGreaterThan,
160                                                                       SpvOpUGreaterThan,
161                                                                       SpvOpUndef);
162     fIntrinsicMap[k_greaterThanEqual_IntrinsicKind] = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
163                                                                       SpvOpFOrdGreaterThanEqual,
164                                                                       SpvOpSGreaterThanEqual,
165                                                                       SpvOpUGreaterThanEqual,
166                                                                       SpvOpUndef);
167     fIntrinsicMap[k_EmitVertex_IntrinsicKind]       = ALL_SPIRV(EmitVertex);
168     fIntrinsicMap[k_EndPrimitive_IntrinsicKind]     = ALL_SPIRV(EndPrimitive);
169 // interpolateAt* not yet supported...
170 }
171 
writeWord(int32_t word,OutputStream & out)172 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
173     out.write((const char*) &word, sizeof(word));
174 }
175 
is_float(const Context & context,const Type & type)176 static bool is_float(const Context& context, const Type& type) {
177     return (type.isScalar() || type.isVector() || type.isMatrix()) &&
178            type.componentType().isFloat();
179 }
180 
is_signed(const Context & context,const Type & type)181 static bool is_signed(const Context& context, const Type& type) {
182     return type.isEnum() ||
183            ((type.isScalar() || type.isVector()) && type.componentType().isSigned());
184 }
185 
is_unsigned(const Context & context,const Type & type)186 static bool is_unsigned(const Context& context, const Type& type) {
187     return (type.isScalar() || type.isVector()) && type.componentType().isUnsigned();
188 }
189 
is_bool(const Context & context,const Type & type)190 static bool is_bool(const Context& context, const Type& type) {
191     return (type.isScalar() || type.isVector()) && type.componentType().isBoolean();
192 }
193 
is_out(const Variable & var)194 static bool is_out(const Variable& var) {
195     return (var.modifiers().fFlags & Modifiers::kOut_Flag) != 0;
196 }
197 
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)198 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
199     SkASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
200     SkASSERT(opCode != SpvOpUndef);
201     switch (opCode) {
202         case SpvOpReturn:      // fall through
203         case SpvOpReturnValue: // fall through
204         case SpvOpKill:        // fall through
205         case SpvOpSwitch:      // fall through
206         case SpvOpBranch:      // fall through
207         case SpvOpBranchConditional:
208             SkASSERT(fCurrentBlock);
209             fCurrentBlock = 0;
210             break;
211         case SpvOpConstant:          // fall through
212         case SpvOpConstantTrue:      // fall through
213         case SpvOpConstantFalse:     // fall through
214         case SpvOpConstantComposite: // fall through
215         case SpvOpTypeVoid:          // fall through
216         case SpvOpTypeInt:           // fall through
217         case SpvOpTypeFloat:         // fall through
218         case SpvOpTypeBool:          // fall through
219         case SpvOpTypeVector:        // fall through
220         case SpvOpTypeMatrix:        // fall through
221         case SpvOpTypeArray:         // fall through
222         case SpvOpTypePointer:       // fall through
223         case SpvOpTypeFunction:      // fall through
224         case SpvOpTypeRuntimeArray:  // fall through
225         case SpvOpTypeStruct:        // fall through
226         case SpvOpTypeImage:         // fall through
227         case SpvOpTypeSampledImage:  // fall through
228         case SpvOpTypeSampler:       // fall through
229         case SpvOpVariable:          // fall through
230         case SpvOpFunction:          // fall through
231         case SpvOpFunctionParameter: // fall through
232         case SpvOpFunctionEnd:       // fall through
233         case SpvOpExecutionMode:     // fall through
234         case SpvOpMemoryModel:       // fall through
235         case SpvOpCapability:        // fall through
236         case SpvOpExtInstImport:     // fall through
237         case SpvOpEntryPoint:        // fall through
238         case SpvOpSource:            // fall through
239         case SpvOpSourceExtension:   // fall through
240         case SpvOpName:              // fall through
241         case SpvOpMemberName:        // fall through
242         case SpvOpDecorate:          // fall through
243         case SpvOpMemberDecorate:
244             break;
245         default:
246             // We may find ourselves with dead code--instructions that don't have an associated
247             // block. This should be a rare event, but if it happens, synthesize a label; this is
248             // necessary to satisfy the validator.
249             if (fCurrentBlock == 0) {
250                 this->writeLabel(this->nextId(nullptr), out);
251             }
252             break;
253     }
254     this->writeWord((length << 16) | opCode, out);
255 }
256 
writeLabel(SpvId label,OutputStream & out)257 void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) {
258     SkASSERT(!fCurrentBlock);
259     fCurrentBlock = label;
260     this->writeInstruction(SpvOpLabel, label, out);
261 }
262 
writeInstruction(SpvOp_ opCode,OutputStream & out)263 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
264     this->writeOpCode(opCode, 1, out);
265 }
266 
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)267 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
268     this->writeOpCode(opCode, 2, out);
269     this->writeWord(word1, out);
270 }
271 
writeString(const char * string,size_t length,OutputStream & out)272 void SPIRVCodeGenerator::writeString(const char* string, size_t length, OutputStream& out) {
273     out.write(string, length);
274     switch (length % 4) {
275         case 1:
276             out.write8(0);
277             [[fallthrough]];
278         case 2:
279             out.write8(0);
280             [[fallthrough]];
281         case 3:
282             out.write8(0);
283             break;
284         default:
285             this->writeWord(0, out);
286     }
287 }
288 
writeInstruction(SpvOp_ opCode,StringFragment string,OutputStream & out)289 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, StringFragment string, OutputStream& out) {
290     this->writeOpCode(opCode, 1 + (string.fLength + 4) / 4, out);
291     this->writeString(string.fChars, string.fLength, out);
292 }
293 
294 
writeInstruction(SpvOp_ opCode,int32_t word1,StringFragment string,OutputStream & out)295 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, StringFragment string,
296                                           OutputStream& out) {
297     this->writeOpCode(opCode, 2 + (string.fLength + 4) / 4, out);
298     this->writeWord(word1, out);
299     this->writeString(string.fChars, string.fLength, out);
300 }
301 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,StringFragment string,OutputStream & out)302 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
303                                           StringFragment string, OutputStream& out) {
304     this->writeOpCode(opCode, 3 + (string.fLength + 4) / 4, out);
305     this->writeWord(word1, out);
306     this->writeWord(word2, out);
307     this->writeString(string.fChars, string.fLength, out);
308 }
309 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)310 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
311                                           OutputStream& out) {
312     this->writeOpCode(opCode, 3, out);
313     this->writeWord(word1, out);
314     this->writeWord(word2, out);
315 }
316 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)317 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
318                                           int32_t word3, OutputStream& out) {
319     this->writeOpCode(opCode, 4, out);
320     this->writeWord(word1, out);
321     this->writeWord(word2, out);
322     this->writeWord(word3, out);
323 }
324 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)325 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
326                                           int32_t word3, int32_t word4, OutputStream& out) {
327     this->writeOpCode(opCode, 5, out);
328     this->writeWord(word1, out);
329     this->writeWord(word2, out);
330     this->writeWord(word3, out);
331     this->writeWord(word4, out);
332 }
333 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)334 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
335                                           int32_t word3, int32_t word4, int32_t word5,
336                                           OutputStream& out) {
337     this->writeOpCode(opCode, 6, out);
338     this->writeWord(word1, out);
339     this->writeWord(word2, out);
340     this->writeWord(word3, out);
341     this->writeWord(word4, out);
342     this->writeWord(word5, out);
343 }
344 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)345 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
346                                           int32_t word3, int32_t word4, int32_t word5,
347                                           int32_t word6, OutputStream& out) {
348     this->writeOpCode(opCode, 7, out);
349     this->writeWord(word1, out);
350     this->writeWord(word2, out);
351     this->writeWord(word3, out);
352     this->writeWord(word4, out);
353     this->writeWord(word5, out);
354     this->writeWord(word6, out);
355 }
356 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)357 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
358                                           int32_t word3, int32_t word4, int32_t word5,
359                                           int32_t word6, int32_t word7, OutputStream& out) {
360     this->writeOpCode(opCode, 8, out);
361     this->writeWord(word1, out);
362     this->writeWord(word2, out);
363     this->writeWord(word3, out);
364     this->writeWord(word4, out);
365     this->writeWord(word5, out);
366     this->writeWord(word6, out);
367     this->writeWord(word7, out);
368 }
369 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)370 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
371                                           int32_t word3, int32_t word4, int32_t word5,
372                                           int32_t word6, int32_t word7, int32_t word8,
373                                           OutputStream& out) {
374     this->writeOpCode(opCode, 9, out);
375     this->writeWord(word1, out);
376     this->writeWord(word2, out);
377     this->writeWord(word3, out);
378     this->writeWord(word4, out);
379     this->writeWord(word5, out);
380     this->writeWord(word6, out);
381     this->writeWord(word7, out);
382     this->writeWord(word8, out);
383 }
384 
writeCapabilities(OutputStream & out)385 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
386     for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
387         if (fCapabilities & bit) {
388             this->writeInstruction(SpvOpCapability, (SpvId) i, out);
389         }
390     }
391     if (fProgram.fConfig->fKind == ProgramKind::kGeometry) {
392         this->writeInstruction(SpvOpCapability, SpvCapabilityGeometry, out);
393     }
394     else {
395         this->writeInstruction(SpvOpCapability, SpvCapabilityShader, out);
396     }
397 }
398 
nextId(const Type * type)399 SpvId SPIRVCodeGenerator::nextId(const Type* type) {
400     return this->nextId(type && type->hasPrecision() && !type->highPrecision()
401                 ? Precision::kRelaxed
402                 : Precision::kDefault);
403 }
404 
nextId(Precision precision)405 SpvId SPIRVCodeGenerator::nextId(Precision precision) {
406     if (precision == Precision::kRelaxed && !fProgram.fConfig->fSettings.fForceHighPrecision) {
407         this->writeInstruction(SpvOpDecorate, fIdCount, SpvDecorationRelaxedPrecision,
408                                fDecorationBuffer);
409     }
410     return fIdCount++;
411 }
412 
writeStruct(const Type & type,const MemoryLayout & memoryLayout,SpvId resultId)413 void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout,
414                                      SpvId resultId) {
415     this->writeInstruction(SpvOpName, resultId, String(type.name()).c_str(), fNameBuffer);
416     // go ahead and write all of the field types, so we don't inadvertently write them while we're
417     // in the middle of writing the struct instruction
418     std::vector<SpvId> types;
419     for (const auto& f : type.fields()) {
420         types.push_back(this->getType(*f.fType, memoryLayout));
421     }
422     this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
423     this->writeWord(resultId, fConstantBuffer);
424     for (SpvId id : types) {
425         this->writeWord(id, fConstantBuffer);
426     }
427     size_t offset = 0;
428     for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
429         const Type::Field& field = type.fields()[i];
430         if (!MemoryLayout::LayoutIsSupported(*field.fType)) {
431             fErrors.error(type.fOffset, "type '" + field.fType->name() + "' is not permitted here");
432             return;
433         }
434         size_t size = memoryLayout.size(*field.fType);
435         size_t alignment = memoryLayout.alignment(*field.fType);
436         const Layout& fieldLayout = field.fModifiers.fLayout;
437         if (fieldLayout.fOffset >= 0) {
438             if (fieldLayout.fOffset < (int) offset) {
439                 fErrors.error(type.fOffset,
440                               "offset of field '" + field.fName + "' must be at "
441                               "least " + to_string((int) offset));
442             }
443             if (fieldLayout.fOffset % alignment) {
444                 fErrors.error(type.fOffset,
445                               "offset of field '" + field.fName + "' must be a multiple"
446                               " of " + to_string((int) alignment));
447             }
448             offset = fieldLayout.fOffset;
449         } else {
450             size_t mod = offset % alignment;
451             if (mod) {
452                 offset += alignment - mod;
453             }
454         }
455         this->writeInstruction(SpvOpMemberName, resultId, i, field.fName, fNameBuffer);
456         this->writeLayout(fieldLayout, resultId, i);
457         if (field.fModifiers.fLayout.fBuiltin < 0) {
458             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
459                                    (SpvId) offset, fDecorationBuffer);
460         }
461         if (field.fType->isMatrix()) {
462             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
463                                    fDecorationBuffer);
464             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
465                                    (SpvId) memoryLayout.stride(*field.fType),
466                                    fDecorationBuffer);
467         }
468         if (!field.fType->highPrecision()) {
469             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i,
470                                    SpvDecorationRelaxedPrecision, fDecorationBuffer);
471         }
472         offset += size;
473         if ((field.fType->isArray() || field.fType->isStruct()) && offset % alignment != 0) {
474             offset += alignment - offset % alignment;
475         }
476     }
477 }
478 
const Type& SPIRVCodeGenerator::getActualType(const Type& type) {
    // Maps an SkSL type onto the type actually emitted in SPIR-V:
    //   - types reporting isFloat() become float;
    //   - signed or enum types become int;
    //   - unsigned types become uint;
    //   - vectors/matrices of half, short or ushort are rebuilt with float, int or uint components.
    // Any other type is returned unchanged.
    if (type.isFloat()) {
        return *fContext.fTypes.fFloat;
    }
    if (type.isSigned() || type.isEnum()) {
        return *fContext.fTypes.fInt;
    }
    if (type.isUnsigned()) {
        return *fContext.fTypes.fUInt;
    }
    if (!type.isMatrix() && !type.isVector()) {
        return type;
    }
    const Type& component = type.componentType();
    if (component == *fContext.fTypes.fHalf) {
        return fContext.fTypes.fFloat->toCompound(fContext, type.columns(), type.rows());
    }
    if (component == *fContext.fTypes.fShort) {
        return fContext.fTypes.fInt->toCompound(fContext, type.columns(), type.rows());
    }
    if (component == *fContext.fTypes.fUShort) {
        return fContext.fTypes.fUInt->toCompound(fContext, type.columns(), type.rows());
    }
    return type;
}
502 
getType(const Type & type)503 SpvId SPIRVCodeGenerator::getType(const Type& type) {
504     return this->getType(type, fDefaultLayout);
505 }
506 
getType(const Type & rawType,const MemoryLayout & layout)507 SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layout) {
508     const Type& type = this->getActualType(rawType);
509     String key = type.name();
510     if (type.isStruct() || type.isArray()) {
511         key += to_string((int)layout.fStd);
512 #ifdef SK_DEBUG
513         SkASSERT(layout.fStd == MemoryLayout::Standard::k140_Standard ||
514                  layout.fStd == MemoryLayout::Standard::k430_Standard);
515         MemoryLayout::Standard otherStd = layout.fStd == MemoryLayout::Standard::k140_Standard
516                                                   ? MemoryLayout::Standard::k430_Standard
517                                                   : MemoryLayout::Standard::k140_Standard;
518         String otherKey = type.name() + to_string((int)otherStd);
519         SkASSERT(fTypeMap.find(otherKey) == fTypeMap.end());
520 #endif
521     }
522     auto entry = fTypeMap.find(key);
523     if (entry == fTypeMap.end()) {
524         SpvId result = this->nextId(nullptr);
525         switch (type.typeKind()) {
526             case Type::TypeKind::kScalar:
527                 if (type.isBoolean()) {
528                     this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
529                 } else if (type == *fContext.fTypes.fInt || type == *fContext.fTypes.fShort ||
530                            type == *fContext.fTypes.fIntLiteral) {
531                     this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
532                 } else if (type == *fContext.fTypes.fUInt || type == *fContext.fTypes.fUShort) {
533                     this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
534                 } else if (type == *fContext.fTypes.fFloat || type == *fContext.fTypes.fHalf ||
535                            type == *fContext.fTypes.fFloatLiteral) {
536                     this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
537                 } else {
538                     SkASSERT(false);
539                 }
540                 break;
541             case Type::TypeKind::kEnum:
542                 this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
543                 break;
544             case Type::TypeKind::kVector:
545                 this->writeInstruction(SpvOpTypeVector, result,
546                                        this->getType(type.componentType(), layout),
547                                        type.columns(), fConstantBuffer);
548                 break;
549             case Type::TypeKind::kMatrix:
550                 this->writeInstruction(
551                         SpvOpTypeMatrix,
552                         result,
553                         this->getType(IndexExpression::IndexType(fContext, type), layout),
554                         type.columns(),
555                         fConstantBuffer);
556                 break;
557             case Type::TypeKind::kStruct:
558                 this->writeStruct(type, layout, result);
559                 break;
560             case Type::TypeKind::kArray: {
561                 if (!MemoryLayout::LayoutIsSupported(type)) {
562                     fErrors.error(type.fOffset, "type '" + type.name() + "' is not permitted here");
563                     return this->nextId(nullptr);
564                 }
565                 if (type.columns() > 0) {
566                     SpvId typeId = this->getType(type.componentType(), layout);
567                     IntLiteral countLiteral(/*offset=*/-1, type.columns(),
568                                             fContext.fTypes.fInt.get());
569                     SpvId countId = this->writeIntLiteral(countLiteral);
570                     this->writeInstruction(SpvOpTypeArray, result, typeId, countId,
571                                            fConstantBuffer);
572                     this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
573                                            (int32_t) layout.stride(type),
574                                            fDecorationBuffer);
575                 } else {
576                     // We shouldn't have any runtime-sized arrays right now
577                     fErrors.error(type.fOffset, "runtime-sized arrays are not supported in SPIR-V");
578                     this->writeInstruction(SpvOpTypeRuntimeArray, result,
579                                            this->getType(type.componentType(), layout),
580                                            fConstantBuffer);
581                     this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
582                                            (int32_t) layout.stride(type),
583                                            fDecorationBuffer);
584                 }
585                 break;
586             }
587             case Type::TypeKind::kSampler: {
588                 SpvId image = result;
589                 if (SpvDimSubpassData != type.dimensions()) {
590                     image = this->getType(type.textureType(), layout);
591                 }
592                 if (SpvDimBuffer == type.dimensions()) {
593                     fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer);
594                 }
595                 if (SpvDimSubpassData != type.dimensions()) {
596                     this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
597                 }
598                 break;
599             }
600             case Type::TypeKind::kSeparateSampler: {
601                 this->writeInstruction(SpvOpTypeSampler, result, fConstantBuffer);
602                 break;
603             }
604             case Type::TypeKind::kTexture: {
605                 this->writeInstruction(SpvOpTypeImage, result,
606                                        this->getType(*fContext.fTypes.fFloat, layout),
607                                        type.dimensions(), type.isDepth(), type.isArrayedTexture(),
608                                        type.isMultisampled(), type.isSampled() ? 1 : 2,
609                                        SpvImageFormatUnknown, fConstantBuffer);
610                 fImageTypeMap[key] = result;
611                 break;
612             }
613             default:
614                 if (type.isVoid()) {
615                     this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
616                 } else {
617                     SkDEBUGFAILF("invalid type: %s", type.description().c_str());
618                 }
619         }
620         fTypeMap[key] = result;
621         return result;
622     }
623     return entry->second;
624 }
625 
getImageType(const Type & type)626 SpvId SPIRVCodeGenerator::getImageType(const Type& type) {
627     SkASSERT(type.typeKind() == Type::TypeKind::kSampler);
628     this->getType(type);
629     String key = type.name() + to_string((int) fDefaultLayout.fStd);
630     SkASSERT(fImageTypeMap.find(key) != fImageTypeMap.end());
631     return fImageTypeMap[key];
632 }
633 
getFunctionType(const FunctionDeclaration & function)634 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
635     String key = to_string(this->getType(function.returnType())) + "(";
636     String separator;
637     const std::vector<const Variable*>& parameters = function.parameters();
638     for (size_t i = 0; i < parameters.size(); i++) {
639         key += separator;
640         separator = ", ";
641         key += to_string(this->getType(parameters[i]->type()));
642     }
643     key += ")";
644     auto entry = fTypeMap.find(key);
645     if (entry == fTypeMap.end()) {
646         SpvId result = this->nextId(nullptr);
647         int32_t length = 3 + (int32_t) parameters.size();
648         SpvId returnType = this->getType(function.returnType());
649         std::vector<SpvId> parameterTypes;
650         for (size_t i = 0; i < parameters.size(); i++) {
651             // glslang seems to treat all function arguments as pointers whether they need to be or
652             // not. I  was initially puzzled by this until I ran bizarre failures with certain
653             // patterns of function calls and control constructs, as exemplified by this minimal
654             // failure case:
655             //
656             // void sphere(float x) {
657             // }
658             //
659             // void map() {
660             //     sphere(1.0);
661             // }
662             //
663             // void main() {
664             //     for (int i = 0; i < 1; i++) {
665             //         map();
666             //     }
667             // }
668             //
669             // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
670             // crashes. Making it take a float* and storing the argument in a temporary variable,
671             // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
672             // the spec makes this make sense.
673 //            if (is_out(function->fParameters[i])) {
674                 parameterTypes.push_back(this->getPointerType(parameters[i]->type(),
675                                                               SpvStorageClassFunction));
676 //            } else {
677 //                parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
678 //            }
679         }
680         this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
681         this->writeWord(result, fConstantBuffer);
682         this->writeWord(returnType, fConstantBuffer);
683         for (SpvId id : parameterTypes) {
684             this->writeWord(id, fConstantBuffer);
685         }
686         fTypeMap[key] = result;
687         return result;
688     }
689     return entry->second;
690 }
691 
getPointerType(const Type & type,SpvStorageClass_ storageClass)692 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
693     return this->getPointerType(type, fDefaultLayout, storageClass);
694 }
695 
getPointerType(const Type & rawType,const MemoryLayout & layout,SpvStorageClass_ storageClass)696 SpvId SPIRVCodeGenerator::getPointerType(const Type& rawType, const MemoryLayout& layout,
697                                          SpvStorageClass_ storageClass) {
698     const Type& type = this->getActualType(rawType);
699     String key = type.displayName() + "*" + to_string(layout.fStd) + to_string(storageClass);
700     auto entry = fTypeMap.find(key);
701     if (entry == fTypeMap.end()) {
702         SpvId result = this->nextId(nullptr);
703         this->writeInstruction(SpvOpTypePointer, result, storageClass,
704                                this->getType(type), fConstantBuffer);
705         fTypeMap[key] = result;
706         return result;
707     }
708     return entry->second;
709 }
710 
writeExpression(const Expression & expr,OutputStream & out)711 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
712     switch (expr.kind()) {
713         case Expression::Kind::kBinary:
714             return this->writeBinaryExpression(expr.as<BinaryExpression>(), out);
715         case Expression::Kind::kBoolLiteral:
716             return this->writeBoolLiteral(expr.as<BoolLiteral>());
717         case Expression::Kind::kConstructorArray:
718         case Expression::Kind::kConstructorStruct:
719             return this->writeCompositeConstructor(expr.asAnyConstructor(), out);
720         case Expression::Kind::kConstructorDiagonalMatrix:
721             return this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(), out);
722         case Expression::Kind::kConstructorMatrixResize:
723             return this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(), out);
724         case Expression::Kind::kConstructorScalarCast:
725             return this->writeConstructorScalarCast(expr.as<ConstructorScalarCast>(), out);
726         case Expression::Kind::kConstructorSplat:
727             return this->writeConstructorSplat(expr.as<ConstructorSplat>(), out);
728         case Expression::Kind::kConstructorCompound:
729             return this->writeConstructorCompound(expr.as<ConstructorCompound>(), out);
730         case Expression::Kind::kConstructorCompoundCast:
731             return this->writeConstructorCompoundCast(expr.as<ConstructorCompoundCast>(), out);
732         case Expression::Kind::kIntLiteral:
733             return this->writeIntLiteral(expr.as<IntLiteral>());
734         case Expression::Kind::kFieldAccess:
735             return this->writeFieldAccess(expr.as<FieldAccess>(), out);
736         case Expression::Kind::kFloatLiteral:
737             return this->writeFloatLiteral(expr.as<FloatLiteral>());
738         case Expression::Kind::kFunctionCall:
739             return this->writeFunctionCall(expr.as<FunctionCall>(), out);
740         case Expression::Kind::kPrefix:
741             return this->writePrefixExpression(expr.as<PrefixExpression>(), out);
742         case Expression::Kind::kPostfix:
743             return this->writePostfixExpression(expr.as<PostfixExpression>(), out);
744         case Expression::Kind::kSwizzle:
745             return this->writeSwizzle(expr.as<Swizzle>(), out);
746         case Expression::Kind::kVariableReference:
747             return this->writeVariableReference(expr.as<VariableReference>(), out);
748         case Expression::Kind::kTernary:
749             return this->writeTernaryExpression(expr.as<TernaryExpression>(), out);
750         case Expression::Kind::kIndex:
751             return this->writeIndexExpression(expr.as<IndexExpression>(), out);
752         default:
753             SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
754             break;
755     }
756     return -1;
757 }
758 
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)759 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
760     const FunctionDeclaration& function = c.function();
761     auto intrinsic = fIntrinsicMap.find(function.intrinsicKind());
762     if (intrinsic == fIntrinsicMap.end()) {
763         fErrors.error(c.fOffset, "unsupported intrinsic '" + function.description() + "'");
764         return -1;
765     }
766     int32_t intrinsicId;
767     const ExpressionArray& arguments = c.arguments();
768     if (arguments.size() > 0) {
769         const Type& type = arguments[0]->type();
770         if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicOpcodeKind ||
771             is_float(fContext, type)) {
772             intrinsicId = std::get<1>(intrinsic->second);
773         } else if (is_signed(fContext, type)) {
774             intrinsicId = std::get<2>(intrinsic->second);
775         } else if (is_unsigned(fContext, type)) {
776             intrinsicId = std::get<3>(intrinsic->second);
777         } else if (is_bool(fContext, type)) {
778             intrinsicId = std::get<4>(intrinsic->second);
779         } else {
780             intrinsicId = std::get<1>(intrinsic->second);
781         }
782     } else {
783         intrinsicId = std::get<1>(intrinsic->second);
784     }
785     switch (std::get<0>(intrinsic->second)) {
786         case kGLSL_STD_450_IntrinsicOpcodeKind: {
787             SpvId result = this->nextId(&c.type());
788             std::vector<SpvId> argumentIds;
789             for (size_t i = 0; i < arguments.size(); i++) {
790                 if (function.parameters()[i]->modifiers().fFlags & Modifiers::kOut_Flag) {
791                     // TODO(skia:11052): swizzled lvalues won't work with getPointer()
792                     argumentIds.push_back(this->getLValue(*arguments[i], out)->getPointer());
793                 } else {
794                     argumentIds.push_back(this->writeExpression(*arguments[i], out));
795                 }
796             }
797             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
798             this->writeWord(this->getType(c.type()), out);
799             this->writeWord(result, out);
800             this->writeWord(fGLSLExtendedInstructions, out);
801             this->writeWord(intrinsicId, out);
802             for (SpvId id : argumentIds) {
803                 this->writeWord(id, out);
804             }
805             return result;
806         }
807         case kSPIRV_IntrinsicOpcodeKind: {
808             // GLSL supports dot(float, float), but SPIR-V does not. Convert it to FMul
809             if (intrinsicId == SpvOpDot && arguments[0]->type().isScalar()) {
810                 intrinsicId = SpvOpFMul;
811             }
812             SpvId result = this->nextId(&c.type());
813             std::vector<SpvId> argumentIds;
814             for (size_t i = 0; i < arguments.size(); i++) {
815                 if (function.parameters()[i]->modifiers().fFlags & Modifiers::kOut_Flag) {
816                     // TODO(skia:11052): swizzled lvalues won't work with getPointer()
817                     argumentIds.push_back(this->getLValue(*arguments[i], out)->getPointer());
818                 } else {
819                     argumentIds.push_back(this->writeExpression(*arguments[i], out));
820                 }
821             }
822             if (!c.type().isVoid()) {
823                 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
824                 this->writeWord(this->getType(c.type()), out);
825                 this->writeWord(result, out);
826             } else {
827                 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
828             }
829             for (SpvId id : argumentIds) {
830                 this->writeWord(id, out);
831             }
832             return result;
833         }
834         case kSpecial_IntrinsicOpcodeKind:
835             return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
836         default:
837             fErrors.error(c.fOffset, "unsupported intrinsic '" + function.description() + "'");
838             return -1;
839     }
840 }
841 
vectorize(const ExpressionArray & args,OutputStream & out)842 std::vector<SpvId> SPIRVCodeGenerator::vectorize(const ExpressionArray& args, OutputStream& out) {
843     int vectorSize = 0;
844     for (const auto& a : args) {
845         if (a->type().isVector()) {
846             if (vectorSize) {
847                 SkASSERT(a->type().columns() == vectorSize);
848             }
849             else {
850                 vectorSize = a->type().columns();
851             }
852         }
853     }
854     std::vector<SpvId> result;
855     result.reserve(args.size());
856     for (const auto& arg : args) {
857         const Type& argType = arg->type();
858         SpvId raw = this->writeExpression(*arg, out);
859         if (vectorSize && argType.isScalar()) {
860             SpvId vector = this->nextId(&arg->type());
861             this->writeOpCode(SpvOpCompositeConstruct, 3 + vectorSize, out);
862             this->writeWord(this->getType(argType.toCompound(fContext, vectorSize, 1)), out);
863             this->writeWord(vector, out);
864             for (int i = 0; i < vectorSize; i++) {
865                 this->writeWord(raw, out);
866             }
867             result.push_back(vector);
868         } else {
869             result.push_back(raw);
870         }
871     }
872     return result;
873 }
874 
writeGLSLExtendedInstruction(const Type & type,SpvId id,SpvId floatInst,SpvId signedInst,SpvId unsignedInst,const std::vector<SpvId> & args,OutputStream & out)875 void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
876                                                       SpvId signedInst, SpvId unsignedInst,
877                                                       const std::vector<SpvId>& args,
878                                                       OutputStream& out) {
879     this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
880     this->writeWord(this->getType(type), out);
881     this->writeWord(id, out);
882     this->writeWord(fGLSLExtendedInstructions, out);
883 
884     if (is_float(fContext, type)) {
885         this->writeWord(floatInst, out);
886     } else if (is_signed(fContext, type)) {
887         this->writeWord(signedInst, out);
888     } else if (is_unsigned(fContext, type)) {
889         this->writeWord(unsignedInst, out);
890     } else {
891         SkASSERT(false);
892     }
893     for (SpvId a : args) {
894         this->writeWord(a, out);
895     }
896 }
897 
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)898 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
899                                                 OutputStream& out) {
900     const ExpressionArray& arguments = c.arguments();
901     const Type& callType = c.type();
902     SpvId result = this->nextId(nullptr);
903     switch (kind) {
904         case kAtan_SpecialIntrinsic: {
905             std::vector<SpvId> argumentIds;
906             for (const std::unique_ptr<Expression>& arg : arguments) {
907                 argumentIds.push_back(this->writeExpression(*arg, out));
908             }
909             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
910             this->writeWord(this->getType(callType), out);
911             this->writeWord(result, out);
912             this->writeWord(fGLSLExtendedInstructions, out);
913             this->writeWord(argumentIds.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
914             for (SpvId id : argumentIds) {
915                 this->writeWord(id, out);
916             }
917             break;
918         }
919         case kSampledImage_SpecialIntrinsic: {
920             SkASSERT(arguments.size() == 2);
921             SpvId img = this->writeExpression(*arguments[0], out);
922             SpvId sampler = this->writeExpression(*arguments[1], out);
923             this->writeInstruction(SpvOpSampledImage,
924                                    this->getType(callType),
925                                    result,
926                                    img,
927                                    sampler,
928                                    out);
929             break;
930         }
931         case kSubpassLoad_SpecialIntrinsic: {
932             SpvId img = this->writeExpression(*arguments[0], out);
933             ExpressionArray args;
934             args.reserve_back(2);
935             args.push_back(IntLiteral::Make(fContext, /*offset=*/-1, /*value=*/0));
936             args.push_back(IntLiteral::Make(fContext, /*offset=*/-1, /*value=*/0));
937             ConstructorCompound ctor(/*offset=*/-1, *fContext.fTypes.fInt2, std::move(args));
938             SpvId coords = this->writeConstantVector(ctor);
939             if (arguments.size() == 1) {
940                 this->writeInstruction(SpvOpImageRead,
941                                        this->getType(callType),
942                                        result,
943                                        img,
944                                        coords,
945                                        out);
946             } else {
947                 SkASSERT(arguments.size() == 2);
948                 SpvId sample = this->writeExpression(*arguments[1], out);
949                 this->writeInstruction(SpvOpImageRead,
950                                        this->getType(callType),
951                                        result,
952                                        img,
953                                        coords,
954                                        SpvImageOperandsSampleMask,
955                                        sample,
956                                        out);
957             }
958             break;
959         }
960         case kTexture_SpecialIntrinsic: {
961             SpvOp_ op = SpvOpImageSampleImplicitLod;
962             const Type& arg1Type = arguments[1]->type();
963             switch (arguments[0]->type().dimensions()) {
964                 case SpvDim1D:
965                     if (arg1Type == *fContext.fTypes.fFloat2) {
966                         op = SpvOpImageSampleProjImplicitLod;
967                     } else {
968                         SkASSERT(arg1Type == *fContext.fTypes.fFloat);
969                     }
970                     break;
971                 case SpvDim2D:
972                     if (arg1Type == *fContext.fTypes.fFloat3) {
973                         op = SpvOpImageSampleProjImplicitLod;
974                     } else {
975                         SkASSERT(arg1Type == *fContext.fTypes.fFloat2);
976                     }
977                     break;
978                 case SpvDim3D:
979                     if (arg1Type == *fContext.fTypes.fFloat4) {
980                         op = SpvOpImageSampleProjImplicitLod;
981                     } else {
982                         SkASSERT(arg1Type == *fContext.fTypes.fFloat3);
983                     }
984                     break;
985                 case SpvDimCube:   // fall through
986                 case SpvDimRect:   // fall through
987                 case SpvDimBuffer: // fall through
988                 case SpvDimSubpassData:
989                     break;
990             }
991             SpvId type = this->getType(callType);
992             SpvId sampler = this->writeExpression(*arguments[0], out);
993             SpvId uv = this->writeExpression(*arguments[1], out);
994             if (arguments.size() == 3) {
995                 this->writeInstruction(op, type, result, sampler, uv,
996                                        SpvImageOperandsBiasMask,
997                                        this->writeExpression(*arguments[2], out),
998                                        out);
999             } else {
1000                 SkASSERT(arguments.size() == 2);
1001                 if (fProgram.fConfig->fSettings.fSharpenTextures) {
1002                     FloatLiteral lodBias(/*offset=*/-1, /*value=*/-0.5,
1003                                          fContext.fTypes.fFloat.get());
1004                     this->writeInstruction(op, type, result, sampler, uv,
1005                                            SpvImageOperandsBiasMask,
1006                                            this->writeFloatLiteral(lodBias),
1007                                            out);
1008                 } else {
1009                     this->writeInstruction(op, type, result, sampler, uv,
1010                                            out);
1011                 }
1012             }
1013             break;
1014         }
1015         case kMod_SpecialIntrinsic: {
1016             std::vector<SpvId> args = this->vectorize(arguments, out);
1017             SkASSERT(args.size() == 2);
1018             const Type& operandType = arguments[0]->type();
1019             SpvOp_ op;
1020             if (is_float(fContext, operandType)) {
1021                 op = SpvOpFMod;
1022             } else if (is_signed(fContext, operandType)) {
1023                 op = SpvOpSMod;
1024             } else if (is_unsigned(fContext, operandType)) {
1025                 op = SpvOpUMod;
1026             } else {
1027                 SkASSERT(false);
1028                 return 0;
1029             }
1030             this->writeOpCode(op, 5, out);
1031             this->writeWord(this->getType(operandType), out);
1032             this->writeWord(result, out);
1033             this->writeWord(args[0], out);
1034             this->writeWord(args[1], out);
1035             break;
1036         }
1037         case kDFdy_SpecialIntrinsic: {
1038             SpvId fn = this->writeExpression(*arguments[0], out);
1039             this->writeOpCode(SpvOpDPdy, 4, out);
1040             this->writeWord(this->getType(callType), out);
1041             this->writeWord(result, out);
1042             this->writeWord(fn, out);
1043             if (fProgram.fConfig->fSettings.fFlipY) {
1044                 // Flipping Y also negates the Y derivatives.
1045                 SpvId flipped = this->nextId(&callType);
1046                 this->writeInstruction(SpvOpFNegate, this->getType(callType), flipped, result,
1047                                        out);
1048                 result = flipped;
1049             }
1050             break;
1051         }
1052         case kClamp_SpecialIntrinsic: {
1053             std::vector<SpvId> args = this->vectorize(arguments, out);
1054             SkASSERT(args.size() == 3);
1055             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
1056                                                GLSLstd450UClamp, args, out);
1057             break;
1058         }
1059         case kMax_SpecialIntrinsic: {
1060             std::vector<SpvId> args = this->vectorize(arguments, out);
1061             SkASSERT(args.size() == 2);
1062             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMax, GLSLstd450SMax,
1063                                                GLSLstd450UMax, args, out);
1064             break;
1065         }
1066         case kMin_SpecialIntrinsic: {
1067             std::vector<SpvId> args = this->vectorize(arguments, out);
1068             SkASSERT(args.size() == 2);
1069             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMin, GLSLstd450SMin,
1070                                                GLSLstd450UMin, args, out);
1071             break;
1072         }
1073         case kMix_SpecialIntrinsic: {
1074             std::vector<SpvId> args = this->vectorize(arguments, out);
1075             SkASSERT(args.size() == 3);
1076             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMix, SpvOpUndef,
1077                                                SpvOpUndef, args, out);
1078             break;
1079         }
1080         case kSaturate_SpecialIntrinsic: {
1081             SkASSERT(arguments.size() == 1);
1082             ExpressionArray finalArgs;
1083             finalArgs.reserve_back(3);
1084             finalArgs.push_back(arguments[0]->clone());
1085             finalArgs.push_back(FloatLiteral::Make(fContext, /*offset=*/-1, /*value=*/0));
1086             finalArgs.push_back(FloatLiteral::Make(fContext, /*offset=*/-1, /*value=*/1));
1087             std::vector<SpvId> spvArgs = this->vectorize(finalArgs, out);
1088             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
1089                                                GLSLstd450UClamp, spvArgs, out);
1090             break;
1091         }
1092         case kSmoothStep_SpecialIntrinsic: {
1093             std::vector<SpvId> args = this->vectorize(arguments, out);
1094             SkASSERT(args.size() == 3);
1095             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450SmoothStep, SpvOpUndef,
1096                                                SpvOpUndef, args, out);
1097             break;
1098         }
1099         case kStep_SpecialIntrinsic: {
1100             std::vector<SpvId> args = this->vectorize(arguments, out);
1101             SkASSERT(args.size() == 2);
1102             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450Step, SpvOpUndef,
1103                                                SpvOpUndef, args, out);
1104             break;
1105         }
1106         case kMatrixCompMult_SpecialIntrinsic: {
1107             SkASSERT(arguments.size() == 2);
1108             SpvId lhs = this->writeExpression(*arguments[0], out);
1109             SpvId rhs = this->writeExpression(*arguments[1], out);
1110             result = this->writeComponentwiseMatrixBinary(callType, lhs, rhs, SpvOpFMul, SpvOpUndef,
1111                                                           out);
1112             break;
1113         }
1114     }
1115     return result;
1116 }
1117 
1118 namespace {
1119 struct TempVar {
1120     SpvId spvId;
1121     const Type* type;
1122     std::unique_ptr<SPIRVCodeGenerator::LValue> lvalue;
1123 };
1124 }
1125 
writeFunctionCall(const FunctionCall & c,OutputStream & out)1126 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
1127     const FunctionDeclaration& function = c.function();
1128     if (function.isIntrinsic() && !function.definition()) {
1129         return this->writeIntrinsicCall(c, out);
1130     }
1131     const ExpressionArray& arguments = c.arguments();
1132     const auto& entry = fFunctionMap.find(&function);
1133     if (entry == fFunctionMap.end()) {
1134         fErrors.error(c.fOffset, "function '" + function.description() + "' is not defined");
1135         return -1;
1136     }
1137     // Temp variables are used to write back out-parameters after the function call is complete.
1138     std::vector<TempVar> tempVars;
1139     std::vector<SpvId> argumentIds;
1140     for (size_t i = 0; i < arguments.size(); i++) {
1141         // id of temporary variable that we will use to hold this argument, or 0 if it is being
1142         // passed directly
1143         SpvId tmpVar;
1144         // if we need a temporary var to store this argument, this is the value to store in the var
1145         SpvId tmpValueId;
1146         if (is_out(*function.parameters()[i])) {
1147             std::unique_ptr<LValue> lv = this->getLValue(*arguments[i], out);
1148             SpvId ptr = lv->getPointer();
1149             if (ptr != (SpvId) -1 && lv->isMemoryObjectPointer()) {
1150                 argumentIds.push_back(ptr);
1151                 continue;
1152             } else {
1153                 // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to
1154                 // copy it into a temp, call the function, read the value out of the temp, and then
1155                 // update the lvalue.
1156                 tmpValueId = lv->load(out);
1157                 tmpVar = this->nextId(nullptr);
1158                 tempVars.push_back(TempVar{tmpVar, &arguments[i]->type(), std::move(lv)});
1159             }
1160         } else {
1161             // See getFunctionType for an explanation of why we're always using pointer parameters.
1162             tmpValueId = this->writeExpression(*arguments[i], out);
1163             tmpVar = this->nextId(nullptr);
1164         }
1165         this->writeInstruction(SpvOpVariable,
1166                                this->getPointerType(arguments[i]->type(), SpvStorageClassFunction),
1167                                tmpVar,
1168                                SpvStorageClassFunction,
1169                                fVariableBuffer);
1170         this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
1171         argumentIds.push_back(tmpVar);
1172     }
1173     SpvId result = this->nextId(nullptr);
1174     this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) arguments.size(), out);
1175     this->writeWord(this->getType(c.type()), out);
1176     this->writeWord(result, out);
1177     this->writeWord(entry->second, out);
1178     for (SpvId id : argumentIds) {
1179         this->writeWord(id, out);
1180     }
1181     // Now that the call is complete, we copy temp out-variables back to their real lvalues.
1182     for (const TempVar& tempVar : tempVars) {
1183         SpvId load = this->nextId(tempVar.type);
1184         this->writeInstruction(SpvOpLoad, getType(*tempVar.type), load, tempVar.spvId, out);
1185         tempVar.lvalue->store(load, out);
1186     }
1187     return result;
1188 }
1189 
writeConstantVector(const AnyConstructor & c)1190 SpvId SPIRVCodeGenerator::writeConstantVector(const AnyConstructor& c) {
1191     const Type& type = c.type();
1192     SkASSERT(type.isVector() && c.isCompileTimeConstant());
1193 
1194     // Get each of the constructor components as SPIR-V constants.
1195     SPIRVVectorConstant key{this->getType(type),
1196                             /*fValueId=*/{SpvId(-1), SpvId(-1), SpvId(-1), SpvId(-1)}};
1197 
1198     for (int n = 0; n < type.columns(); n++) {
1199         const Expression* expr = c.getConstantSubexpression(n);
1200         if (!expr) {
1201             SkDEBUGFAILF("writeConstantVector: %s not actually constant", c.description().c_str());
1202             return (SpvId)-1;
1203         }
1204         key.fValueId[n] = this->writeExpression(*expr, fConstantBuffer);
1205     }
1206 
1207     // Check to see if we've already synthesized this vector constant.
1208     auto [iter, newlyCreated] = fVectorConstants.insert({key, (SpvId)-1});
1209     if (newlyCreated) {
1210         // Emit an OpConstantComposite instruction for this constant.
1211         SpvId result = this->nextId(&type);
1212         this->writeOpCode(SpvOpConstantComposite, 3 + type.columns(), fConstantBuffer);
1213         this->writeWord(key.fTypeId, fConstantBuffer);
1214         this->writeWord(result, fConstantBuffer);
1215         for (int i = 0; i < type.columns(); i++) {
1216             this->writeWord(key.fValueId[i], fConstantBuffer);
1217         }
1218         iter->second = result;
1219     }
1220     return iter->second;
1221 }
1222 
castScalarToType(SpvId inputExprId,const Type & inputType,const Type & outputType,OutputStream & out)1223 SpvId SPIRVCodeGenerator::castScalarToType(SpvId inputExprId,
1224                                            const Type& inputType,
1225                                            const Type& outputType,
1226                                            OutputStream& out) {
1227     if (outputType.isFloat()) {
1228         return this->castScalarToFloat(inputExprId, inputType, outputType, out);
1229     }
1230     if (outputType.isSigned()) {
1231         return this->castScalarToSignedInt(inputExprId, inputType, outputType, out);
1232     }
1233     if (outputType.isUnsigned()) {
1234         return this->castScalarToUnsignedInt(inputExprId, inputType, outputType, out);
1235     }
1236     if (outputType.isBoolean()) {
1237         return this->castScalarToBoolean(inputExprId, inputType, outputType, out);
1238     }
1239 
1240     fErrors.error(-1, "unsupported cast: " + inputType.description() +
1241                       " to " + outputType.description());
1242     return inputExprId;
1243 }
1244 
writeFloatConstructor(const AnyConstructor & c,OutputStream & out)1245 SpvId SPIRVCodeGenerator::writeFloatConstructor(const AnyConstructor& c, OutputStream& out) {
1246     SkASSERT(c.argumentSpan().size() == 1);
1247     SkASSERT(c.type().isFloat());
1248     const Expression& ctorExpr = *c.argumentSpan().front();
1249     SpvId expressionId = this->writeExpression(ctorExpr, out);
1250     return this->castScalarToFloat(expressionId, ctorExpr.type(), c.type(), out);
1251 }
1252 
castScalarToFloat(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)1253 SpvId SPIRVCodeGenerator::castScalarToFloat(SpvId inputId, const Type& inputType,
1254                                             const Type& outputType, OutputStream& out) {
1255     // Casting a float to float is a no-op.
1256     if (inputType.isFloat()) {
1257         return inputId;
1258     }
1259 
1260     // Given the input type, generate the appropriate instruction to cast to float.
1261     SpvId result = this->nextId(&outputType);
1262     if (inputType.isBoolean()) {
1263         // Use OpSelect to convert the boolean argument to a literal 1.0 or 0.0.
1264         FloatLiteral one(/*offset=*/-1, /*value=*/1, fContext.fTypes.fFloat.get());
1265         SpvId        oneID = this->writeFloatLiteral(one);
1266         FloatLiteral zero(/*offset=*/-1, /*value=*/0, fContext.fTypes.fFloat.get());
1267         SpvId        zeroID = this->writeFloatLiteral(zero);
1268         this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
1269                                inputId, oneID, zeroID, out);
1270     } else if (inputType.isSigned()) {
1271         this->writeInstruction(SpvOpConvertSToF, this->getType(outputType), result, inputId, out);
1272     } else if (inputType.isUnsigned()) {
1273         this->writeInstruction(SpvOpConvertUToF, this->getType(outputType), result, inputId, out);
1274     } else {
1275         SkDEBUGFAILF("unsupported type for float typecast: %s", inputType.description().c_str());
1276         return (SpvId)-1;
1277     }
1278     return result;
1279 }
1280 
writeIntConstructor(const AnyConstructor & c,OutputStream & out)1281 SpvId SPIRVCodeGenerator::writeIntConstructor(const AnyConstructor& c, OutputStream& out) {
1282     SkASSERT(c.argumentSpan().size() == 1);
1283     SkASSERT(c.type().isSigned());
1284     const Expression& ctorExpr = *c.argumentSpan().front();
1285     SpvId expressionId = this->writeExpression(ctorExpr, out);
1286     return this->castScalarToSignedInt(expressionId, ctorExpr.type(), c.type(), out);
1287 }
1288 
castScalarToSignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)1289 SpvId SPIRVCodeGenerator::castScalarToSignedInt(SpvId inputId, const Type& inputType,
1290                                                 const Type& outputType, OutputStream& out) {
1291     // Casting a signed int to signed int is a no-op.
1292     if (inputType.isSigned()) {
1293         return inputId;
1294     }
1295 
1296     // Given the input type, generate the appropriate instruction to cast to signed int.
1297     SpvId result = this->nextId(&outputType);
1298     if (inputType.isBoolean()) {
1299         // Use OpSelect to convert the boolean argument to a literal 1 or 0.
1300         IntLiteral one(/*offset=*/-1, /*value=*/1, fContext.fTypes.fInt.get());
1301         SpvId      oneID = this->writeIntLiteral(one);
1302         IntLiteral zero(/*offset=*/-1, /*value=*/0, fContext.fTypes.fInt.get());
1303         SpvId      zeroID = this->writeIntLiteral(zero);
1304         this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
1305                                inputId, oneID, zeroID, out);
1306     } else if (inputType.isFloat()) {
1307         this->writeInstruction(SpvOpConvertFToS, this->getType(outputType), result, inputId, out);
1308     } else if (inputType.isUnsigned()) {
1309         this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
1310     } else {
1311         SkDEBUGFAILF("unsupported type for signed int typecast: %s",
1312                      inputType.description().c_str());
1313         return (SpvId)-1;
1314     }
1315     return result;
1316 }
1317 
writeUIntConstructor(const AnyConstructor & c,OutputStream & out)1318 SpvId SPIRVCodeGenerator::writeUIntConstructor(const AnyConstructor& c, OutputStream& out) {
1319     SkASSERT(c.argumentSpan().size() == 1);
1320     SkASSERT(c.type().isUnsigned());
1321     const Expression& ctorExpr = *c.argumentSpan().front();
1322     SpvId expressionId = this->writeExpression(ctorExpr, out);
1323     return this->castScalarToUnsignedInt(expressionId, ctorExpr.type(), c.type(), out);
1324 }
1325 
castScalarToUnsignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)1326 SpvId SPIRVCodeGenerator::castScalarToUnsignedInt(SpvId inputId, const Type& inputType,
1327                                                   const Type& outputType, OutputStream& out) {
1328     // Casting an unsigned int to unsigned int is a no-op.
1329     if (inputType.isUnsigned()) {
1330         return inputId;
1331     }
1332 
1333     // Given the input type, generate the appropriate instruction to cast to unsigned int.
1334     SpvId result = this->nextId(&outputType);
1335     if (inputType.isBoolean()) {
1336         // Use OpSelect to convert the boolean argument to a literal 1u or 0u.
1337         IntLiteral one(/*offset=*/-1, /*value=*/1, fContext.fTypes.fUInt.get());
1338         SpvId      oneID = this->writeIntLiteral(one);
1339         IntLiteral zero(/*offset=*/-1, /*value=*/0, fContext.fTypes.fUInt.get());
1340         SpvId      zeroID = this->writeIntLiteral(zero);
1341         this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
1342                                inputId, oneID, zeroID, out);
1343     } else if (inputType.isFloat()) {
1344         this->writeInstruction(SpvOpConvertFToU, this->getType(outputType), result, inputId, out);
1345     } else if (inputType.isSigned()) {
1346         this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
1347     } else {
1348         SkDEBUGFAILF("unsupported type for unsigned int typecast: %s",
1349                      inputType.description().c_str());
1350         return (SpvId)-1;
1351     }
1352     return result;
1353 }
1354 
writeBooleanConstructor(const AnyConstructor & c,OutputStream & out)1355 SpvId SPIRVCodeGenerator::writeBooleanConstructor(const AnyConstructor& c, OutputStream& out) {
1356     SkASSERT(c.argumentSpan().size() == 1);
1357     SkASSERT(c.type().isBoolean());
1358     const Expression& ctorExpr = *c.argumentSpan().front();
1359     SpvId expressionId = this->writeExpression(ctorExpr, out);
1360     return this->castScalarToBoolean(expressionId, ctorExpr.type(), c.type(), out);
1361 }
1362 
castScalarToBoolean(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)1363 SpvId SPIRVCodeGenerator::castScalarToBoolean(SpvId inputId, const Type& inputType,
1364                                               const Type& outputType, OutputStream& out) {
1365     // Casting a bool to bool is a no-op.
1366     if (inputType.isBoolean()) {
1367         return inputId;
1368     }
1369 
1370     // Given the input type, generate the appropriate instruction to cast to bool.
1371     SpvId result = this->nextId(nullptr);
1372     if (inputType.isSigned()) {
1373         // Synthesize a boolean result by comparing the input against a signed zero literal.
1374         IntLiteral zero(/*offset=*/-1, /*value=*/0, fContext.fTypes.fInt.get());
1375         SpvId      zeroID = this->writeIntLiteral(zero);
1376         this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
1377                                inputId, zeroID, out);
1378     } else if (inputType.isUnsigned()) {
1379         // Synthesize a boolean result by comparing the input against an unsigned zero literal.
1380         IntLiteral zero(/*offset=*/-1, /*value=*/0, fContext.fTypes.fUInt.get());
1381         SpvId      zeroID = this->writeIntLiteral(zero);
1382         this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
1383                                inputId, zeroID, out);
1384     } else if (inputType.isFloat()) {
1385         // Synthesize a boolean result by comparing the input against a floating-point zero literal.
1386         FloatLiteral zero(/*offset=*/-1, /*value=*/0, fContext.fTypes.fFloat.get());
1387         SpvId        zeroID = this->writeFloatLiteral(zero);
1388         this->writeInstruction(SpvOpFUnordNotEqual, this->getType(outputType), result,
1389                                inputId, zeroID, out);
1390     } else {
1391         SkDEBUGFAILF("unsupported type for boolean typecast: %s", inputType.description().c_str());
1392         return (SpvId)-1;
1393     }
1394     return result;
1395 }
1396 
writeUniformScaleMatrix(SpvId id,SpvId diagonal,const Type & type,OutputStream & out)1397 void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
1398                                                  OutputStream& out) {
1399     FloatLiteral zero(/*offset=*/-1, /*value=*/0, fContext.fTypes.fFloat.get());
1400     SpvId zeroId = this->writeFloatLiteral(zero);
1401     std::vector<SpvId> columnIds;
1402     columnIds.reserve(type.columns());
1403     for (int column = 0; column < type.columns(); column++) {
1404         this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
1405                           out);
1406         this->writeWord(this->getType(type.componentType().toCompound(
1407                                 fContext, /*columns=*/type.rows(), /*rows=*/1)),
1408                         out);
1409         SpvId columnId = this->nextId(&type);
1410         this->writeWord(columnId, out);
1411         columnIds.push_back(columnId);
1412         for (int row = 0; row < type.rows(); row++) {
1413             this->writeWord(row == column ? diagonal : zeroId, out);
1414         }
1415     }
1416     this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
1417                       out);
1418     this->writeWord(this->getType(type), out);
1419     this->writeWord(id, out);
1420     for (SpvId columnId : columnIds) {
1421         this->writeWord(columnId, out);
1422     }
1423 }
1424 
writeMatrixCopy(SpvId src,const Type & srcType,const Type & dstType,OutputStream & out)1425 SpvId SPIRVCodeGenerator::writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType,
1426                                           OutputStream& out) {
1427     SkASSERT(srcType.isMatrix());
1428     SkASSERT(dstType.isMatrix());
1429     SkASSERT(srcType.componentType() == dstType.componentType());
1430     SpvId id = this->nextId(&dstType);
1431     SpvId srcColumnType = this->getType(srcType.componentType().toCompound(fContext,
1432                                                                            srcType.rows(),
1433                                                                            1));
1434     SpvId dstColumnType = this->getType(dstType.componentType().toCompound(fContext,
1435                                                                            dstType.rows(),
1436                                                                            1));
1437     SkASSERT(dstType.componentType().isFloat());
1438     FloatLiteral zero(/*offset=*/-1, /*value=*/0.0, &dstType.componentType());
1439     const SpvId zeroId = this->writeFloatLiteral(zero);
1440     FloatLiteral one(/*offset=*/-1, /*value=*/1.0, &dstType.componentType());
1441     const SpvId oneId = this->writeFloatLiteral(one);
1442 
1443     SpvId columns[4];
1444     for (int i = 0; i < dstType.columns(); i++) {
1445         if (i < srcType.columns()) {
1446             // we're still inside the src matrix, copy the column
1447             SpvId srcColumn = this->nextId(&dstType);
1448             this->writeInstruction(SpvOpCompositeExtract, srcColumnType, srcColumn, src, i, out);
1449             SpvId dstColumn;
1450             if (srcType.rows() == dstType.rows()) {
1451                 // columns are equal size, don't need to do anything
1452                 dstColumn = srcColumn;
1453             }
1454             else if (dstType.rows() > srcType.rows()) {
1455                 // dst column is bigger, need to zero-pad it
1456                 dstColumn = this->nextId(&dstType);
1457                 int delta = dstType.rows() - srcType.rows();
1458                 this->writeOpCode(SpvOpCompositeConstruct, 4 + delta, out);
1459                 this->writeWord(dstColumnType, out);
1460                 this->writeWord(dstColumn, out);
1461                 this->writeWord(srcColumn, out);
1462                 for (int j = srcType.rows(); j < dstType.rows(); ++j) {
1463                     this->writeWord((i == j) ? oneId : zeroId, out);
1464                 }
1465             }
1466             else {
1467                 // dst column is smaller, need to swizzle the src column
1468                 dstColumn = this->nextId(&dstType);
1469                 this->writeOpCode(SpvOpVectorShuffle, 5 + dstType.rows(), out);
1470                 this->writeWord(dstColumnType, out);
1471                 this->writeWord(dstColumn, out);
1472                 this->writeWord(srcColumn, out);
1473                 this->writeWord(srcColumn, out);
1474                 for (int j = 0; j < dstType.rows(); j++) {
1475                     this->writeWord(j, out);
1476                 }
1477             }
1478             columns[i] = dstColumn;
1479         } else {
1480             // we're past the end of the src matrix, need to synthesize an identity-matrix column
1481             SpvId identityColumn = this->nextId(&dstType);
1482             this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.rows(), out);
1483             this->writeWord(dstColumnType, out);
1484             this->writeWord(identityColumn, out);
1485             for (int j = 0; j < dstType.rows(); ++j) {
1486                 this->writeWord((i == j) ? oneId : zeroId, out);
1487             }
1488             columns[i] = identityColumn;
1489         }
1490     }
1491     this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.columns(), out);
1492     this->writeWord(this->getType(dstType), out);
1493     this->writeWord(id, out);
1494     for (int i = 0; i < dstType.columns(); i++) {
1495         this->writeWord(columns[i], out);
1496     }
1497     return id;
1498 }
1499 
addColumnEntry(SpvId columnType,Precision precision,std::vector<SpvId> * currentColumn,std::vector<SpvId> * columnIds,int * currentCount,int rows,SpvId entry,OutputStream & out)1500 void SPIRVCodeGenerator::addColumnEntry(SpvId columnType, Precision precision,
1501                                         std::vector<SpvId>* currentColumn,
1502                                         std::vector<SpvId>* columnIds,
1503                                         int* currentCount, int rows, SpvId entry,
1504                                         OutputStream& out) {
1505     SkASSERT(*currentCount < rows);
1506     ++(*currentCount);
1507     currentColumn->push_back(entry);
1508     if (*currentCount == rows) {
1509         *currentCount = 0;
1510         this->writeOpCode(SpvOpCompositeConstruct, 3 + currentColumn->size(), out);
1511         this->writeWord(columnType, out);
1512         SpvId columnId = this->nextId(precision);
1513         this->writeWord(columnId, out);
1514         columnIds->push_back(columnId);
1515         for (SpvId id : *currentColumn) {
1516             this->writeWord(id, out);
1517         }
1518         currentColumn->clear();
1519     }
1520 }
1521 
writeMatrixConstructor(const ConstructorCompound & c,OutputStream & out)1522 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out) {
1523     const Type& type = c.type();
1524     SkASSERT(type.isMatrix());
1525     SkASSERT(!c.arguments().empty());
1526     const Type& arg0Type = c.arguments()[0]->type();
1527     // go ahead and write the arguments so we don't try to write new instructions in the middle of
1528     // an instruction
1529     std::vector<SpvId> arguments;
1530     arguments.reserve(c.arguments().size());
1531     for (const std::unique_ptr<Expression>& arg : c.arguments()) {
1532         arguments.push_back(this->writeExpression(*arg, out));
1533     }
1534     SpvId result = this->nextId(&type);
1535     int rows = type.rows();
1536     int columns = type.columns();
1537     if (arguments.size() == 1 && arg0Type.isVector()) {
1538         // Special-case handling of float4 -> mat2x2.
1539         SkASSERT(type.rows() == 2 && type.columns() == 2);
1540         SkASSERT(arg0Type.columns() == 4);
1541         SpvId componentType = this->getType(type.componentType());
1542         SpvId v[4];
1543         for (int i = 0; i < 4; ++i) {
1544             v[i] = this->nextId(&type);
1545             this->writeInstruction(SpvOpCompositeExtract, componentType, v[i], arguments[0], i,
1546                                    out);
1547         }
1548         SpvId columnType = this->getType(type.componentType().toCompound(fContext, 2, 1));
1549         SpvId column1 = this->nextId(&type);
1550         this->writeInstruction(SpvOpCompositeConstruct, columnType, column1, v[0], v[1], out);
1551         SpvId column2 = this->nextId(&type);
1552         this->writeInstruction(SpvOpCompositeConstruct, columnType, column2, v[2], v[3], out);
1553         this->writeInstruction(SpvOpCompositeConstruct, this->getType(type), result, column1,
1554                                column2, out);
1555     } else {
1556         SpvId columnType = this->getType(type.componentType().toCompound(fContext, rows, 1));
1557         std::vector<SpvId> columnIds;
1558         // ids of vectors and scalars we have written to the current column so far
1559         std::vector<SpvId> currentColumn;
1560         // the total number of scalars represented by currentColumn's entries
1561         int currentCount = 0;
1562         Precision precision = type.highPrecision() ? Precision::kDefault : Precision::kRelaxed;
1563         for (size_t i = 0; i < arguments.size(); i++) {
1564             const Type& argType = c.arguments()[i]->type();
1565             if (currentCount == 0 && argType.isVector() &&
1566                 argType.columns() == type.rows()) {
1567                 // this is a complete column by itself
1568                 columnIds.push_back(arguments[i]);
1569             } else {
1570                 if (argType.columns() == 1) {
1571                     this->addColumnEntry(columnType, precision, &currentColumn, &columnIds,
1572                                          &currentCount, rows, arguments[i], out);
1573                 } else {
1574                     SpvId componentType = this->getType(argType.componentType());
1575                     for (int j = 0; j < argType.columns(); ++j) {
1576                         SpvId swizzle = this->nextId(&argType);
1577                         this->writeInstruction(SpvOpCompositeExtract, componentType, swizzle,
1578                                                arguments[i], j, out);
1579                         this->addColumnEntry(columnType, precision, &currentColumn, &columnIds,
1580                                              &currentCount, rows, swizzle, out);
1581                     }
1582                 }
1583             }
1584         }
1585         SkASSERT(columnIds.size() == (size_t) columns);
1586         this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
1587         this->writeWord(this->getType(type), out);
1588         this->writeWord(result, out);
1589         for (SpvId id : columnIds) {
1590             this->writeWord(id, out);
1591         }
1592     }
1593     return result;
1594 }
1595 
writeConstructorCompound(const ConstructorCompound & c,OutputStream & out)1596 SpvId SPIRVCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1597                                                    OutputStream& out) {
1598     return c.type().isMatrix() ? this->writeMatrixConstructor(c, out)
1599                                : this->writeVectorConstructor(c, out);
1600 }
1601 
writeVectorConstructor(const ConstructorCompound & c,OutputStream & out)1602 SpvId SPIRVCodeGenerator::writeVectorConstructor(const ConstructorCompound& c, OutputStream& out) {
1603     const Type& type = c.type();
1604     const Type& componentType = type.componentType();
1605     SkASSERT(type.isVector());
1606 
1607     if (c.isCompileTimeConstant()) {
1608         return this->writeConstantVector(c);
1609     }
1610 
1611     std::vector<SpvId> arguments;
1612     for (size_t i = 0; i < c.arguments().size(); i++) {
1613         const Type& argType = c.arguments()[i]->type();
1614         SkASSERT(componentType == argType.componentType());
1615 
1616         if (argType.isVector()) {
1617             // There's a bug in the Intel Vulkan driver where OpCompositeConstruct doesn't handle
1618             // vector arguments at all, so we always extract each vector component and pass them
1619             // into OpCompositeConstruct individually.
1620             SpvId vec = this->writeExpression(*c.arguments()[i], out);
1621             for (int j = 0; j < argType.columns(); j++) {
1622                 SpvId componentId = this->nextId(&componentType);
1623                 this->writeInstruction(SpvOpCompositeExtract, this->getType(componentType),
1624                                        componentId, vec, j, out);
1625                 arguments.push_back(componentId);
1626             }
1627         } else {
1628             arguments.push_back(this->writeExpression(*c.arguments()[i], out));
1629         }
1630     }
1631 
1632     return this->writeComposite(arguments, type, out);
1633 }
1634 
writeComposite(const std::vector<SpvId> & arguments,const Type & type,OutputStream & out)1635 SpvId SPIRVCodeGenerator::writeComposite(const std::vector<SpvId>& arguments,
1636                                          const Type& type,
1637                                          OutputStream& out) {
1638     SkASSERT(arguments.size() == (type.isStruct() ? type.fields().size() : (size_t)type.columns()));
1639 
1640     SpvId result = this->nextId(&type);
1641     this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out);
1642     this->writeWord(this->getType(type), out);
1643     this->writeWord(result, out);
1644     for (SpvId id : arguments) {
1645         this->writeWord(id, out);
1646     }
1647     return result;
1648 }
1649 
writeConstructorSplat(const ConstructorSplat & c,OutputStream & out)1650 SpvId SPIRVCodeGenerator::writeConstructorSplat(const ConstructorSplat& c, OutputStream& out) {
1651     // Use writeConstantVector to deduplicate constant splats.
1652     if (c.isCompileTimeConstant()) {
1653         return this->writeConstantVector(c);
1654     }
1655 
1656     // Write the splat argument.
1657     SpvId argument = this->writeExpression(*c.argument(), out);
1658 
1659     // Generate a OpCompositeConstruct which repeats the argument N times.
1660     std::vector<SpvId> arguments(/*count*/ c.type().columns(), /*value*/ argument);
1661     return this->writeComposite(arguments, c.type(), out);
1662 }
1663 
1664 
writeCompositeConstructor(const AnyConstructor & c,OutputStream & out)1665 SpvId SPIRVCodeGenerator::writeCompositeConstructor(const AnyConstructor& c, OutputStream& out) {
1666     SkASSERT(c.type().isArray() || c.type().isStruct());
1667     auto ctorArgs = c.argumentSpan();
1668 
1669     std::vector<SpvId> arguments;
1670     arguments.reserve(ctorArgs.size());
1671     for (const std::unique_ptr<Expression>& arg : ctorArgs) {
1672         arguments.push_back(this->writeExpression(*arg, out));
1673     }
1674 
1675     return this->writeComposite(arguments, c.type(), out);
1676 }
1677 
writeConstructorScalarCast(const ConstructorScalarCast & c,OutputStream & out)1678 SpvId SPIRVCodeGenerator::writeConstructorScalarCast(const ConstructorScalarCast& c,
1679                                                      OutputStream& out) {
1680     const Type& type = c.type();
1681     if (this->getActualType(type) == this->getActualType(c.argument()->type())) {
1682         return this->writeExpression(*c.argument(), out);
1683     }
1684 
1685     const Expression& ctorExpr = *c.argument();
1686     SpvId expressionId = this->writeExpression(ctorExpr, out);
1687     return this->castScalarToType(expressionId, ctorExpr.type(), type, out);
1688 }
1689 
writeConstructorCompoundCast(const ConstructorCompoundCast & c,OutputStream & out)1690 SpvId SPIRVCodeGenerator::writeConstructorCompoundCast(const ConstructorCompoundCast& c,
1691                                                        OutputStream& out) {
1692     const Type& ctorType = c.type();
1693     const Type& argType = c.argument()->type();
1694     SkASSERT(ctorType.isVector() || ctorType.isMatrix());
1695 
1696     // Write the composite that we are casting. If the actual type matches, we are done.
1697     SpvId compositeId = this->writeExpression(*c.argument(), out);
1698     if (this->getActualType(ctorType) == this->getActualType(argType)) {
1699         return compositeId;
1700     }
1701 
1702     // writeMatrixCopy can cast matrices to a different type.
1703     if (ctorType.isMatrix()) {
1704         return this->writeMatrixCopy(compositeId, argType, ctorType, out);
1705     }
1706 
1707     // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to extract the
1708     // components and convert each one manually.
1709     const Type& srcType = argType.componentType();
1710     const Type& dstType = ctorType.componentType();
1711 
1712     std::vector<SpvId> arguments;
1713     arguments.reserve(argType.columns());
1714     for (int index = 0; index < argType.columns(); ++index) {
1715         SpvId componentId = this->nextId(&srcType);
1716         this->writeInstruction(SpvOpCompositeExtract, this->getType(srcType), componentId,
1717                                compositeId, index, out);
1718         arguments.push_back(this->castScalarToType(componentId, srcType, dstType, out));
1719     }
1720 
1721     return this->writeComposite(arguments, ctorType, out);
1722 }
1723 
writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix & c,OutputStream & out)1724 SpvId SPIRVCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
1725                                                          OutputStream& out) {
1726     const Type& type = c.type();
1727     SkASSERT(type.isMatrix());
1728     SkASSERT(c.argument()->type().isScalar());
1729 
1730     // Write out the scalar argument.
1731     SpvId argument = this->writeExpression(*c.argument(), out);
1732 
1733     // Build the diagonal matrix.
1734     SpvId result = this->nextId(&type);
1735     this->writeUniformScaleMatrix(result, argument, type, out);
1736     return result;
1737 }
1738 
writeConstructorMatrixResize(const ConstructorMatrixResize & c,OutputStream & out)1739 SpvId SPIRVCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
1740                                                        OutputStream& out) {
1741     // Write the input matrix.
1742     SpvId argument = this->writeExpression(*c.argument(), out);
1743 
1744     // Use matrix-copy to resize the input matrix to its new size.
1745     return this->writeMatrixCopy(argument, c.argument()->type(), c.type(), out);
1746 }
1747 
get_storage_class(const Variable & var,SpvStorageClass_ fallbackStorageClass)1748 static SpvStorageClass_ get_storage_class(const Variable& var,
1749                                           SpvStorageClass_ fallbackStorageClass) {
1750     const Modifiers& modifiers = var.modifiers();
1751     if (modifiers.fFlags & Modifiers::kIn_Flag) {
1752         SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1753         return SpvStorageClassInput;
1754     }
1755     if (modifiers.fFlags & Modifiers::kOut_Flag) {
1756         SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1757         return SpvStorageClassOutput;
1758     }
1759     if (modifiers.fFlags & Modifiers::kUniform_Flag) {
1760         if (modifiers.fLayout.fFlags & Layout::kPushConstant_Flag) {
1761             return SpvStorageClassPushConstant;
1762         }
1763         if (var.type().typeKind() == Type::TypeKind::kSampler ||
1764             var.type().typeKind() == Type::TypeKind::kSeparateSampler ||
1765             var.type().typeKind() == Type::TypeKind::kTexture) {
1766             return SpvStorageClassUniformConstant;
1767         }
1768         return SpvStorageClassUniform;
1769     }
1770     return fallbackStorageClass;
1771 }
1772 
get_storage_class(const Expression & expr)1773 static SpvStorageClass_ get_storage_class(const Expression& expr) {
1774     switch (expr.kind()) {
1775         case Expression::Kind::kVariableReference: {
1776             const Variable& var = *expr.as<VariableReference>().variable();
1777             if (var.storage() != Variable::Storage::kGlobal) {
1778                 return SpvStorageClassFunction;
1779             }
1780             return get_storage_class(var, SpvStorageClassPrivate);
1781         }
1782         case Expression::Kind::kFieldAccess:
1783             return get_storage_class(*expr.as<FieldAccess>().base());
1784         case Expression::Kind::kIndex:
1785             return get_storage_class(*expr.as<IndexExpression>().base());
1786         default:
1787             return SpvStorageClassFunction;
1788     }
1789 }
1790 
getAccessChain(const Expression & expr,OutputStream & out)1791 std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
1792     std::vector<SpvId> chain;
1793     switch (expr.kind()) {
1794         case Expression::Kind::kIndex: {
1795             const IndexExpression& indexExpr = expr.as<IndexExpression>();
1796             chain = this->getAccessChain(*indexExpr.base(), out);
1797             chain.push_back(this->writeExpression(*indexExpr.index(), out));
1798             break;
1799         }
1800         case Expression::Kind::kFieldAccess: {
1801             const FieldAccess& fieldExpr = expr.as<FieldAccess>();
1802             chain = this->getAccessChain(*fieldExpr.base(), out);
1803             IntLiteral index(/*offset=*/-1, fieldExpr.fieldIndex(), fContext.fTypes.fInt.get());
1804             chain.push_back(this->writeIntLiteral(index));
1805             break;
1806         }
1807         default: {
1808             SpvId id = this->getLValue(expr, out)->getPointer();
1809             SkASSERT(id != (SpvId) -1);
1810             chain.push_back(id);
1811         }
1812     }
1813     return chain;
1814 }
1815 
1816 class PointerLValue : public SPIRVCodeGenerator::LValue {
1817 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,bool isMemoryObject,SpvId type,SPIRVCodeGenerator::Precision precision)1818     PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, bool isMemoryObject, SpvId type,
1819                   SPIRVCodeGenerator::Precision precision)
1820     : fGen(gen)
1821     , fPointer(pointer)
1822     , fIsMemoryObject(isMemoryObject)
1823     , fType(type)
1824     , fPrecision(precision) {}
1825 
getPointer()1826     SpvId getPointer() override {
1827         return fPointer;
1828     }
1829 
isMemoryObjectPointer() const1830     bool isMemoryObjectPointer() const override {
1831         return fIsMemoryObject;
1832     }
1833 
load(OutputStream & out)1834     SpvId load(OutputStream& out) override {
1835         SpvId result = fGen.nextId(fPrecision);
1836         fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
1837         return result;
1838     }
1839 
store(SpvId value,OutputStream & out)1840     void store(SpvId value, OutputStream& out) override {
1841         fGen.writeInstruction(SpvOpStore, fPointer, value, out);
1842     }
1843 
1844 private:
1845     SPIRVCodeGenerator& fGen;
1846     const SpvId fPointer;
1847     const bool fIsMemoryObject;
1848     const SpvId fType;
1849     const SPIRVCodeGenerator::Precision fPrecision;
1850 };
1851 
1852 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
1853 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const ComponentArray & components,const Type & baseType,const Type & swizzleType)1854     SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const ComponentArray& components,
1855                   const Type& baseType, const Type& swizzleType)
1856     : fGen(gen)
1857     , fVecPointer(vecPointer)
1858     , fComponents(components)
1859     , fBaseType(&baseType)
1860     , fSwizzleType(&swizzleType) {}
1861 
applySwizzle(const ComponentArray & components,const Type & newType)1862     bool applySwizzle(const ComponentArray& components, const Type& newType) override {
1863         ComponentArray updatedSwizzle;
1864         for (int8_t component : components) {
1865             if (component < 0 || component >= fComponents.count()) {
1866                 SkDEBUGFAILF("swizzle accessed nonexistent component %d", (int)component);
1867                 return false;
1868             }
1869             updatedSwizzle.push_back(fComponents[component]);
1870         }
1871         fComponents = updatedSwizzle;
1872         fSwizzleType = &newType;
1873         return true;
1874     }
1875 
load(OutputStream & out)1876     SpvId load(OutputStream& out) override {
1877         SpvId base = fGen.nextId(fBaseType);
1878         fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
1879         SpvId result = fGen.nextId(fBaseType);
1880         fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
1881         fGen.writeWord(fGen.getType(*fSwizzleType), out);
1882         fGen.writeWord(result, out);
1883         fGen.writeWord(base, out);
1884         fGen.writeWord(base, out);
1885         for (int component : fComponents) {
1886             fGen.writeWord(component, out);
1887         }
1888         return result;
1889     }
1890 
store(SpvId value,OutputStream & out)1891     void store(SpvId value, OutputStream& out) override {
1892         // use OpVectorShuffle to mix and match the vector components. We effectively create
1893         // a virtual vector out of the concatenation of the left and right vectors, and then
1894         // select components from this virtual vector to make the result vector. For
1895         // instance, given:
1896         // float3L = ...;
1897         // float3R = ...;
1898         // L.xz = R.xy;
1899         // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
1900         // our result vector to look like (R.x, L.y, R.y), so we need to select indices
1901         // (3, 1, 4).
1902         SpvId base = fGen.nextId(fBaseType);
1903         fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
1904         SpvId shuffle = fGen.nextId(fBaseType);
1905         fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType->columns(), out);
1906         fGen.writeWord(fGen.getType(*fBaseType), out);
1907         fGen.writeWord(shuffle, out);
1908         fGen.writeWord(base, out);
1909         fGen.writeWord(value, out);
1910         for (int i = 0; i < fBaseType->columns(); i++) {
1911             // current offset into the virtual vector, defaults to pulling the unmodified
1912             // value from the left side
1913             int offset = i;
1914             // check to see if we are writing this component
1915             for (size_t j = 0; j < fComponents.size(); j++) {
1916                 if (fComponents[j] == i) {
1917                     // we're writing to this component, so adjust the offset to pull from
1918                     // the correct component of the right side instead of preserving the
1919                     // value from the left
1920                     offset = (int) (j + fBaseType->columns());
1921                     break;
1922                 }
1923             }
1924             fGen.writeWord(offset, out);
1925         }
1926         fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
1927     }
1928 
1929 private:
1930     SPIRVCodeGenerator& fGen;
1931     const SpvId fVecPointer;
1932     ComponentArray fComponents;
1933     const Type* fBaseType;
1934     const Type* fSwizzleType;
1935 };
1936 
findUniformFieldIndex(const Variable & var) const1937 int SPIRVCodeGenerator::findUniformFieldIndex(const Variable& var) const {
1938     auto iter = fTopLevelUniformMap.find(&var);
1939     return (iter != fTopLevelUniformMap.end()) ? iter->second : -1;
1940 }
1941 
getLValue(const Expression & expr,OutputStream & out)1942 std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
1943                                                                           OutputStream& out) {
1944     const Type& type = expr.type();
1945     Precision precision = type.highPrecision() ? Precision::kDefault : Precision::kRelaxed;
1946     switch (expr.kind()) {
1947         case Expression::Kind::kVariableReference: {
1948             const Variable& var = *expr.as<VariableReference>().variable();
1949             int uniformIdx = this->findUniformFieldIndex(var);
1950             if (uniformIdx >= 0) {
1951                 IntLiteral uniformIdxLiteral{/*offset=*/-1, uniformIdx, fContext.fTypes.fInt.get()};
1952                 SpvId memberId = this->nextId(nullptr);
1953                 SpvId typeId = this->getPointerType(type, SpvStorageClassUniform);
1954                 SpvId uniformIdxId = this->writeIntLiteral(uniformIdxLiteral);
1955                 this->writeInstruction(SpvOpAccessChain, typeId, memberId, fUniformBufferId,
1956                                        uniformIdxId, out);
1957                 return std::make_unique<PointerLValue>(*this, memberId,
1958                                                        /*isMemoryObjectPointer=*/true,
1959                                                        this->getType(type), precision);
1960             }
1961             SpvId typeId;
1962             if (var.modifiers().fLayout.fBuiltin == SK_IN_BUILTIN) {
1963                 typeId = this->getType(*Type::MakeArrayType("sk_in", var.type().componentType(),
1964                                                             fSkInCount));
1965             } else {
1966                 typeId = this->getType(type, this->memoryLayoutForVariable(var));
1967             }
1968             auto entry = fVariableMap.find(&var);
1969             SkASSERT(entry != fVariableMap.end());
1970             return std::make_unique<PointerLValue>(*this, entry->second,
1971                                                    /*isMemoryObjectPointer=*/true,
1972                                                    typeId, precision);
1973         }
1974         case Expression::Kind::kIndex: // fall through
1975         case Expression::Kind::kFieldAccess: {
1976             std::vector<SpvId> chain = this->getAccessChain(expr, out);
1977             SpvId member = this->nextId(nullptr);
1978             this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
1979             this->writeWord(this->getPointerType(type, get_storage_class(expr)), out);
1980             this->writeWord(member, out);
1981             for (SpvId idx : chain) {
1982                 this->writeWord(idx, out);
1983             }
1984             return std::make_unique<PointerLValue>(*this, member, /*isMemoryObjectPointer=*/false,
1985                                                    this->getType(type), precision);
1986         }
1987         case Expression::Kind::kSwizzle: {
1988             const Swizzle& swizzle = expr.as<Swizzle>();
1989             std::unique_ptr<LValue> lvalue = this->getLValue(*swizzle.base(), out);
1990             if (lvalue->applySwizzle(swizzle.components(), type)) {
1991                 return lvalue;
1992             }
1993             SpvId base = lvalue->getPointer();
1994             if (base == (SpvId) -1) {
1995                 fErrors.error(swizzle.fOffset, "unable to retrieve lvalue from swizzle");
1996             }
1997             if (swizzle.components().size() == 1) {
1998                 SpvId member = this->nextId(nullptr);
1999                 SpvId typeId = this->getPointerType(type, get_storage_class(*swizzle.base()));
2000                 IntLiteral index(/*offset=*/-1, swizzle.components()[0],
2001                                  fContext.fTypes.fInt.get());
2002                 SpvId indexId = this->writeIntLiteral(index);
2003                 this->writeInstruction(SpvOpAccessChain, typeId, member, base, indexId, out);
2004                 return std::make_unique<PointerLValue>(*this,
2005                                                        member,
2006                                                        /*isMemoryObjectPointer=*/false,
2007                                                        this->getType(type),
2008                                                        precision);
2009             } else {
2010                 return std::make_unique<SwizzleLValue>(*this, base, swizzle.components(),
2011                                                        swizzle.base()->type(), type);
2012             }
2013         }
2014         default: {
2015             // expr isn't actually an lvalue, create a dummy variable for it. This case happens due
2016             // to the need to store values in temporary variables during function calls (see
2017             // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been
2018             // caught by IRGenerator
2019             SpvId result = this->nextId(nullptr);
2020             SpvId pointerType = this->getPointerType(type, SpvStorageClassFunction);
2021             this->writeInstruction(SpvOpVariable, pointerType, result, SpvStorageClassFunction,
2022                                    fVariableBuffer);
2023             this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
2024             return std::make_unique<PointerLValue>(*this, result, /*isMemoryObjectPointer=*/true,
2025                                                    this->getType(type), precision);
2026         }
2027     }
2028 }
2029 
writeVariableReference(const VariableReference & ref,OutputStream & out)2030 SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
2031     SpvId result = this->getLValue(ref, out)->load(out);
2032 
2033     // Handle the "flipY" setting when reading sk_FragCoord.
2034     const Variable* variable = ref.variable();
2035     if (variable->modifiers().fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN &&
2036         fProgram.fConfig->fSettings.fFlipY) {
2037         // The x component never changes, so just grab it
2038         SpvId xId = this->nextId(Precision::kDefault);
2039         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fTypes.fFloat), xId,
2040                                result, 0, out);
2041 
2042         // Calculate the y component which may need to be flipped
2043         SpvId rawYId = this->nextId(nullptr);
2044         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fTypes.fFloat),
2045                                rawYId, result, 1, out);
2046         SpvId flippedYId = 0;
2047         if (fProgram.fConfig->fSettings.fFlipY) {
2048             // need to remap to a top-left coordinate system
2049             if (fRTHeightStructId == (SpvId)-1) {
2050                 // height variable hasn't been written yet
2051                 SkASSERT(fRTHeightFieldIndex == (SpvId)-1);
2052                 std::vector<Type::Field> fields;
2053                 if (fProgram.fConfig->fSettings.fRTHeightOffset < 0) {
2054                     fErrors.error(ref.fOffset, "RTHeightOffset is negative");
2055                 }
2056                 fields.emplace_back(
2057                         Modifiers(Layout(/*flags=*/0, /*location=*/-1,
2058                                          fProgram.fConfig->fSettings.fRTHeightOffset,
2059                                          /*binding=*/-1, /*index=*/-1, /*set=*/-1, /*builtin=*/-1,
2060                                          /*inputAttachmentIndex=*/-1,
2061                                          Layout::kUnspecified_Primitive, /*maxVertices=*/1,
2062                                          /*invocations=*/-1, /*when=*/"", Layout::CType::kDefault),
2063                                   /*flags=*/0),
2064                         SKSL_RTHEIGHT_NAME, fContext.fTypes.fFloat.get());
2065                 StringFragment name("sksl_synthetic_uniforms");
2066                 std::unique_ptr<Type> intfStruct = Type::MakeStructType(/*offset=*/-1, name,
2067                                                                         fields);
2068                 int binding = fProgram.fConfig->fSettings.fRTHeightBinding;
2069                 if (binding == -1) {
2070                     fErrors.error(ref.fOffset, "layout(binding=...) is required in SPIR-V");
2071                 }
2072                 int set = fProgram.fConfig->fSettings.fRTHeightSet;
2073                 if (set == -1) {
2074                     fErrors.error(ref.fOffset, "layout(set=...) is required in SPIR-V");
2075                 }
2076                 bool usePushConstants = fProgram.fConfig->fSettings.fUsePushConstants;
2077                 int flags = usePushConstants ? Layout::Flag::kPushConstant_Flag : 0;
2078                 Modifiers modifiers(
2079                         Layout(flags, /*location=*/-1, /*offset=*/-1, binding, /*index=*/-1,
2080                                set, /*builtin=*/-1, /*inputAttachmentIndex=*/-1,
2081                                Layout::kUnspecified_Primitive,
2082                                /*maxVertices=*/-1, /*invocations=*/-1, /*when=*/"",
2083                                Layout::CType::kDefault),
2084                         Modifiers::kUniform_Flag);
2085                 const Variable* intfVar = fSynthetics.takeOwnershipOfSymbol(
2086                         std::make_unique<Variable>(/*offset=*/-1,
2087                                                    fProgram.fModifiers->add(modifiers),
2088                                                    name,
2089                                                    intfStruct.get(),
2090                                                    /*builtin=*/false,
2091                                                    Variable::Storage::kGlobal));
2092                 InterfaceBlock intf(/*offset=*/-1, intfVar, name,
2093                                     /*instanceName=*/"", /*arraySize=*/0,
2094                                     std::make_shared<SymbolTable>(&fErrors, /*builtin=*/false));
2095 
2096                 fRTHeightStructId = this->writeInterfaceBlock(intf, false);
2097                 fRTHeightFieldIndex = 0;
2098                 fRTHeightStorageClass = usePushConstants ? SpvStorageClassPushConstant
2099                                                          : SpvStorageClassUniform;
2100             }
2101             SkASSERT(fRTHeightFieldIndex != (SpvId)-1);
2102 
2103             IntLiteral fieldIndex(/*offset=*/-1, fRTHeightFieldIndex, fContext.fTypes.fInt.get());
2104             SpvId fieldIndexId = this->writeIntLiteral(fieldIndex);
2105             SpvId heightPtr = this->nextId(nullptr);
2106             this->writeOpCode(SpvOpAccessChain, 5, out);
2107             this->writeWord(this->getPointerType(*fContext.fTypes.fFloat, fRTHeightStorageClass),
2108                             out);
2109             this->writeWord(heightPtr, out);
2110             this->writeWord(fRTHeightStructId, out);
2111             this->writeWord(fieldIndexId, out);
2112             SpvId heightRead = this->nextId(nullptr);
2113             this->writeInstruction(SpvOpLoad, this->getType(*fContext.fTypes.fFloat), heightRead,
2114                                    heightPtr, out);
2115 
2116             flippedYId = this->nextId(nullptr);
2117             this->writeInstruction(SpvOpFSub, this->getType(*fContext.fTypes.fFloat), flippedYId,
2118                                    heightRead, rawYId, out);
2119         }
2120 
2121         // The z component will always be zero so we just get an id to the 0 literal
2122         FloatLiteral zero(/*offset=*/-1, /*value=*/0.0, fContext.fTypes.fFloat.get());
2123         SpvId zeroId = writeFloatLiteral(zero);
2124 
2125         // Calculate the w component
2126         SpvId rawWId = this->nextId(nullptr);
2127         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fTypes.fFloat),
2128                                rawWId, result, 3, out);
2129 
2130         // Fill in the new fragcoord with the components from above
2131         SpvId adjusted = this->nextId(nullptr);
2132         this->writeOpCode(SpvOpCompositeConstruct, 7, out);
2133         this->writeWord(this->getType(*fContext.fTypes.fFloat4), out);
2134         this->writeWord(adjusted, out);
2135         this->writeWord(xId, out);
2136         if (fProgram.fConfig->fSettings.fFlipY) {
2137             this->writeWord(flippedYId, out);
2138         } else {
2139             this->writeWord(rawYId, out);
2140         }
2141         this->writeWord(zeroId, out);
2142         this->writeWord(rawWId, out);
2143 
2144         return adjusted;
2145     }
2146 
2147     // Handle the "flipY" setting when reading sk_Clockwise.
2148     if (variable->modifiers().fLayout.fBuiltin == SK_CLOCKWISE_BUILTIN &&
2149         !fProgram.fConfig->fSettings.fFlipY) {
2150         // FrontFacing in Vulkan is defined in terms of a top-down render target. In skia, we use
2151         // the default convention of "counter-clockwise face is front".
2152         SpvId inverse = this->nextId(nullptr);
2153         this->writeInstruction(SpvOpLogicalNot, this->getType(*fContext.fTypes.fBool), inverse,
2154                                result, out);
2155         return inverse;
2156     }
2157 
2158     return result;
2159 }
2160 
writeIndexExpression(const IndexExpression & expr,OutputStream & out)2161 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
2162     if (expr.base()->type().isVector()) {
2163         SpvId base = this->writeExpression(*expr.base(), out);
2164         SpvId index = this->writeExpression(*expr.index(), out);
2165         SpvId result = this->nextId(nullptr);
2166         this->writeInstruction(SpvOpVectorExtractDynamic, this->getType(expr.type()), result, base,
2167                                index, out);
2168         return result;
2169     }
2170     return getLValue(expr, out)->load(out);
2171 }
2172 
writeFieldAccess(const FieldAccess & f,OutputStream & out)2173 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
2174     return getLValue(f, out)->load(out);
2175 }
2176 
writeSwizzle(const Swizzle & swizzle,OutputStream & out)2177 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
2178     SpvId base = this->writeExpression(*swizzle.base(), out);
2179     SpvId result = this->nextId(&swizzle.type());
2180     size_t count = swizzle.components().size();
2181     if (count == 1) {
2182         this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.type()), result, base,
2183                                swizzle.components()[0], out);
2184     } else {
2185         this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
2186         this->writeWord(this->getType(swizzle.type()), out);
2187         this->writeWord(result, out);
2188         this->writeWord(base, out);
2189         this->writeWord(base, out);
2190         for (int component : swizzle.components()) {
2191             this->writeWord(component, out);
2192         }
2193     }
2194     return result;
2195 }
2196 
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)2197 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
2198                                                const Type& operandType, SpvId lhs,
2199                                                SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
2200                                                SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
2201     SpvId result = this->nextId(&resultType);
2202     if (is_float(fContext, operandType)) {
2203         this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
2204     } else if (is_signed(fContext, operandType)) {
2205         this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
2206     } else if (is_unsigned(fContext, operandType)) {
2207         this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
2208     } else if (is_bool(fContext, operandType)) {
2209         this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
2210     } else {
2211         fErrors.error(operandType.fOffset,
2212                       "unsupported operand for binary expression: " + operandType.description());
2213     }
2214     return result;
2215 }
2216 
foldToBool(SpvId id,const Type & operandType,SpvOp op,OutputStream & out)2217 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op,
2218                                      OutputStream& out) {
2219     if (operandType.isVector()) {
2220         SpvId result = this->nextId(nullptr);
2221         this->writeInstruction(op, this->getType(*fContext.fTypes.fBool), result, id, out);
2222         return result;
2223     }
2224     return id;
2225 }
2226 
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,SpvOp_ vectorMergeOperator,SpvOp_ mergeOperator,OutputStream & out)2227 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
2228                                                 SpvOp_ floatOperator, SpvOp_ intOperator,
2229                                                 SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
2230                                                 OutputStream& out) {
2231     SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator;
2232     SkASSERT(operandType.isMatrix());
2233     SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
2234                                                                             operandType.rows(),
2235                                                                             1));
2236     SpvId bvecType = this->getType(fContext.fTypes.fBool->toCompound(fContext,
2237                                                                     operandType.rows(),
2238                                                                     1));
2239     SpvId boolType = this->getType(*fContext.fTypes.fBool);
2240     SpvId result = 0;
2241     for (int i = 0; i < operandType.columns(); i++) {
2242         SpvId columnL = this->nextId(&operandType);
2243         this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
2244         SpvId columnR = this->nextId(&operandType);
2245         this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
2246         SpvId compare = this->nextId(&operandType);
2247         this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
2248         SpvId merge = this->nextId(nullptr);
2249         this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
2250         if (result != 0) {
2251             SpvId next = this->nextId(nullptr);
2252             this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
2253             result = next;
2254         }
2255         else {
2256             result = merge;
2257         }
2258     }
2259     return result;
2260 }
2261 
writeComponentwiseMatrixBinary(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,OutputStream & out)2262 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
2263                                                          SpvId rhs, SpvOp_ floatOperator,
2264                                                          SpvOp_ intOperator,
2265                                                          OutputStream& out) {
2266     SpvOp_ op = is_float(fContext, operandType) ? floatOperator : intOperator;
2267     SkASSERT(operandType.isMatrix());
2268     SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
2269                                                                             operandType.rows(),
2270                                                                             1));
2271     SpvId columns[4];
2272     for (int i = 0; i < operandType.columns(); i++) {
2273         SpvId columnL = this->nextId(&operandType);
2274         this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
2275         SpvId columnR = this->nextId(&operandType);
2276         this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
2277         columns[i] = this->nextId(&operandType);
2278         this->writeInstruction(op, columnType, columns[i], columnL, columnR, out);
2279     }
2280     SpvId result = this->nextId(&operandType);
2281     this->writeOpCode(SpvOpCompositeConstruct, 3 + operandType.columns(), out);
2282     this->writeWord(this->getType(operandType), out);
2283     this->writeWord(result, out);
2284     for (int i = 0; i < operandType.columns(); i++) {
2285         this->writeWord(columns[i], out);
2286     }
2287     return result;
2288 }
2289 
create_literal_1(const Context & context,const Type & type)2290 static std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
2291     if (type.isInteger()) {
2292         return IntLiteral::Make(/*offset=*/-1, /*value=*/1, &type);
2293     }
2294     else if (type.isFloat()) {
2295         return FloatLiteral::Make(/*offset=*/-1, /*value=*/1.0, &type);
2296     } else {
2297         SK_ABORT("math is unsupported on type '%s'", String(type.name()).c_str());
2298     }
2299 }
2300 
writeReciprocal(const Type & type,SpvId value,OutputStream & out)2301 SpvId SPIRVCodeGenerator::writeReciprocal(const Type& type, SpvId value, OutputStream& out) {
2302     SkASSERT(type.isFloat());
2303     SpvId one = this->writeFloatLiteral({/*offset=*/-1, /*value=*/1, &type});
2304     SpvId reciprocal = this->nextId(&type);
2305     this->writeInstruction(SpvOpFDiv, this->getType(type), reciprocal, one, value, out);
2306     return reciprocal;
2307 }
2308 
writeBinaryExpression(const Type & leftType,SpvId lhs,Operator op,const Type & rightType,SpvId rhs,const Type & resultType,OutputStream & out)2309 SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
2310                                                 const Type& rightType, SpvId rhs,
2311                                                 const Type& resultType, OutputStream& out) {
2312     // The comma operator ignores the type of the left-hand side entirely.
2313     if (op.kind() == Token::Kind::TK_COMMA) {
2314         return rhs;
2315     }
2316     // overall type we are operating on: float2, int, uint4...
2317     const Type* operandType;
2318     // IR allows mismatched types in expressions (e.g. float2 * float), but they need special
2319     // handling in SPIR-V
2320     if (this->getActualType(leftType) != this->getActualType(rightType)) {
2321         if (leftType.isVector() && rightType.isNumber()) {
2322             if (resultType.componentType().isFloat()) {
2323                 switch (op.kind()) {
2324                     case Token::Kind::TK_SLASH: {
2325                         rhs = this->writeReciprocal(rightType, rhs, out);
2326                         [[fallthrough]];
2327                     }
2328                     case Token::Kind::TK_STAR: {
2329                         SpvId result = this->nextId(&resultType);
2330                         this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
2331                                                result, lhs, rhs, out);
2332                         return result;
2333                     }
2334                     default:
2335                         break;
2336                 }
2337             }
2338             // promote number to vector
2339             const Type& vecType = leftType;
2340             SpvId vec = this->nextId(&vecType);
2341             this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
2342             this->writeWord(this->getType(vecType), out);
2343             this->writeWord(vec, out);
2344             for (int i = 0; i < vecType.columns(); i++) {
2345                 this->writeWord(rhs, out);
2346             }
2347             rhs = vec;
2348             operandType = &leftType;
2349         } else if (rightType.isVector() && leftType.isNumber()) {
2350             if (resultType.componentType().isFloat()) {
2351                 if (op.kind() == Token::Kind::TK_STAR) {
2352                     SpvId result = this->nextId(&resultType);
2353                     this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
2354                                            result, rhs, lhs, out);
2355                     return result;
2356                 }
2357             }
2358             // promote number to vector
2359             const Type& vecType = rightType;
2360             SpvId vec = this->nextId(&vecType);
2361             this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
2362             this->writeWord(this->getType(vecType), out);
2363             this->writeWord(vec, out);
2364             for (int i = 0; i < vecType.columns(); i++) {
2365                 this->writeWord(lhs, out);
2366             }
2367             lhs = vec;
2368             operandType = &rightType;
2369         } else if (leftType.isMatrix()) {
2370             SpvOp_ spvop;
2371             if (rightType.isMatrix()) {
2372                 spvop = SpvOpMatrixTimesMatrix;
2373             } else if (rightType.isVector()) {
2374                 spvop = SpvOpMatrixTimesVector;
2375             } else {
2376                 SkASSERT(rightType.isScalar());
2377                 spvop = SpvOpMatrixTimesScalar;
2378             }
2379             SpvId result = this->nextId(&resultType);
2380             this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
2381             return result;
2382         } else if (rightType.isMatrix()) {
2383             SpvId result = this->nextId(&resultType);
2384             if (leftType.isVector()) {
2385                 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType), result,
2386                                        lhs, rhs, out);
2387             } else {
2388                 SkASSERT(leftType.isScalar());
2389                 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType), result,
2390                                        rhs, lhs, out);
2391             }
2392             return result;
2393         } else {
2394             fErrors.error(leftType.fOffset, "unsupported mixed-type expression");
2395             return -1;
2396         }
2397     } else {
2398         operandType = &this->getActualType(leftType);
2399         SkASSERT(*operandType == this->getActualType(rightType));
2400     }
2401     switch (op.kind()) {
2402         case Token::Kind::TK_EQEQ: {
2403             if (operandType->isMatrix()) {
2404                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
2405                                                    SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
2406             }
2407             if (operandType->isStruct()) {
2408                 return this->writeStructComparison(*operandType, lhs, op, rhs, out);
2409             }
2410             if (operandType->isArray()) {
2411                 return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
2412             }
2413             SkASSERT(resultType.isBoolean());
2414             const Type* tmpType;
2415             if (operandType->isVector()) {
2416                 tmpType = &fContext.fTypes.fBool->toCompound(fContext,
2417                                                              operandType->columns(),
2418                                                              operandType->rows());
2419             } else {
2420                 tmpType = &resultType;
2421             }
2422             return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
2423                                                                SpvOpFOrdEqual, SpvOpIEqual,
2424                                                                SpvOpIEqual, SpvOpLogicalEqual, out),
2425                                     *operandType, SpvOpAll, out);
2426         }
2427         case Token::Kind::TK_NEQ:
2428             if (operandType->isMatrix()) {
2429                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
2430                                                    SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
2431             }
2432             if (operandType->isStruct()) {
2433                 return this->writeStructComparison(*operandType, lhs, op, rhs, out);
2434             }
2435             if (operandType->isArray()) {
2436                 return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
2437             }
2438             [[fallthrough]];
2439         case Token::Kind::TK_LOGICALXOR:
2440             SkASSERT(resultType.isBoolean());
2441             const Type* tmpType;
2442             if (operandType->isVector()) {
2443                 tmpType = &fContext.fTypes.fBool->toCompound(fContext,
2444                                                              operandType->columns(),
2445                                                              operandType->rows());
2446             } else {
2447                 tmpType = &resultType;
2448             }
2449             return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
2450                                                                SpvOpFOrdNotEqual, SpvOpINotEqual,
2451                                                                SpvOpINotEqual, SpvOpLogicalNotEqual,
2452                                                                out),
2453                                     *operandType, SpvOpAny, out);
2454         case Token::Kind::TK_GT:
2455             SkASSERT(resultType.isBoolean());
2456             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2457                                               SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
2458                                               SpvOpUGreaterThan, SpvOpUndef, out);
2459         case Token::Kind::TK_LT:
2460             SkASSERT(resultType.isBoolean());
2461             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
2462                                               SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
2463         case Token::Kind::TK_GTEQ:
2464             SkASSERT(resultType.isBoolean());
2465             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2466                                               SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
2467                                               SpvOpUGreaterThanEqual, SpvOpUndef, out);
2468         case Token::Kind::TK_LTEQ:
2469             SkASSERT(resultType.isBoolean());
2470             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2471                                               SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
2472                                               SpvOpULessThanEqual, SpvOpUndef, out);
2473         case Token::Kind::TK_PLUS:
2474             if (leftType.isMatrix() && rightType.isMatrix()) {
2475                 SkASSERT(leftType == rightType);
2476                 return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs,
2477                                                             SpvOpFAdd, SpvOpIAdd, out);
2478             }
2479             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2480                                               SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2481         case Token::Kind::TK_MINUS:
2482             if (leftType.isMatrix() && rightType.isMatrix()) {
2483                 SkASSERT(leftType == rightType);
2484                 return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs,
2485                                                             SpvOpFSub, SpvOpISub, out);
2486             }
2487             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2488                                               SpvOpISub, SpvOpISub, SpvOpUndef, out);
2489         case Token::Kind::TK_STAR:
2490             if (leftType.isMatrix() && rightType.isMatrix()) {
2491                 // matrix multiply
2492                 SpvId result = this->nextId(&resultType);
2493                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2494                                        lhs, rhs, out);
2495                 return result;
2496             }
2497             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2498                                               SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2499         case Token::Kind::TK_SLASH:
2500             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2501                                               SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2502         case Token::Kind::TK_PERCENT:
2503             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2504                                               SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2505         case Token::Kind::TK_SHL:
2506             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2507                                               SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
2508                                               SpvOpUndef, out);
2509         case Token::Kind::TK_SHR:
2510             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2511                                               SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
2512                                               SpvOpUndef, out);
2513         case Token::Kind::TK_BITWISEAND:
2514             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2515                                               SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
2516         case Token::Kind::TK_BITWISEOR:
2517             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2518                                               SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
2519         case Token::Kind::TK_BITWISEXOR:
2520             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2521                                               SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
2522         default:
2523             fErrors.error(0, "unsupported token");
2524             return -1;
2525     }
2526 }
2527 
writeArrayComparison(const Type & arrayType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)2528 SpvId SPIRVCodeGenerator::writeArrayComparison(const Type& arrayType, SpvId lhs, Operator op,
2529                                                SpvId rhs, OutputStream& out) {
2530     // The inputs must be arrays, and the op must be == or !=.
2531     SkASSERT(op.kind() == Token::Kind::TK_EQEQ || op.kind() == Token::Kind::TK_NEQ);
2532     SkASSERT(arrayType.isArray());
2533     const Type& componentType = arrayType.componentType();
2534     const SpvId componentTypeId = this->getType(componentType);
2535     const int arraySize = arrayType.columns();
2536     SkASSERT(arraySize > 0);
2537 
2538     // Synthesize equality checks for each item in the array.
2539     const Type& boolType = *fContext.fTypes.fBool;
2540     SpvId allComparisons = (SpvId)-1;
2541     for (int index = 0; index < arraySize; ++index) {
2542         // Get the left and right item in the array.
2543         SpvId itemL = this->nextId(&componentType);
2544         this->writeInstruction(SpvOpCompositeExtract, componentTypeId, itemL, lhs, index, out);
2545         SpvId itemR = this->nextId(&componentType);
2546         this->writeInstruction(SpvOpCompositeExtract, componentTypeId, itemR, rhs, index, out);
2547         // Use `writeBinaryExpression` with the requested == or != operator on these items.
2548         SpvId comparison = this->writeBinaryExpression(componentType, itemL, op,
2549                                                        componentType, itemR, boolType, out);
2550         // Merge this comparison result with all the other comparisons we've done.
2551         allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
2552     }
2553     return allComparisons;
2554 }
2555 
writeStructComparison(const Type & structType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)2556 SpvId SPIRVCodeGenerator::writeStructComparison(const Type& structType, SpvId lhs, Operator op,
2557                                                 SpvId rhs, OutputStream& out) {
2558     // The inputs must be structs containing fields, and the op must be == or !=.
2559     SkASSERT(op.kind() == Token::Kind::TK_EQEQ || op.kind() == Token::Kind::TK_NEQ);
2560     SkASSERT(structType.isStruct());
2561     const std::vector<Type::Field>& fields = structType.fields();
2562     SkASSERT(!fields.empty());
2563 
2564     // Synthesize equality checks for each field in the struct.
2565     const Type& boolType = *fContext.fTypes.fBool;
2566     SpvId allComparisons = (SpvId)-1;
2567     for (int index = 0; index < (int)fields.size(); ++index) {
2568         // Get the left and right versions of this field.
2569         const Type& fieldType = *fields[index].fType;
2570         const SpvId fieldTypeId = this->getType(fieldType);
2571 
2572         SpvId fieldL = this->nextId(&fieldType);
2573         this->writeInstruction(SpvOpCompositeExtract, fieldTypeId, fieldL, lhs, index, out);
2574         SpvId fieldR = this->nextId(&fieldType);
2575         this->writeInstruction(SpvOpCompositeExtract, fieldTypeId, fieldR, rhs, index, out);
2576         // Use `writeBinaryExpression` with the requested == or != operator on these fields.
2577         SpvId comparison = this->writeBinaryExpression(fieldType, fieldL, op, fieldType, fieldR,
2578                                                        boolType, out);
2579         // Merge this comparison result with all the other comparisons we've done.
2580         allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
2581     }
2582     return allComparisons;
2583 }
2584 
mergeComparisons(SpvId comparison,SpvId allComparisons,Operator op,OutputStream & out)2585 SpvId SPIRVCodeGenerator::mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op,
2586                                            OutputStream& out) {
2587     // If this is the first entry, we don't need to merge comparison results with anything.
2588     if (allComparisons == (SpvId)-1) {
2589         return comparison;
2590     }
2591     // Use LogicalAnd or LogicalOr to combine the comparison with all the other comparisons.
2592     const Type& boolType = *fContext.fTypes.fBool;
2593     SpvId boolTypeId = this->getType(boolType);
2594     SpvId logicalOp = this->nextId(&boolType);
2595     switch (op.kind()) {
2596         case Token::Kind::TK_EQEQ:
2597             this->writeInstruction(SpvOpLogicalAnd, boolTypeId, logicalOp,
2598                                    comparison, allComparisons, out);
2599             break;
2600         case Token::Kind::TK_NEQ:
2601             this->writeInstruction(SpvOpLogicalOr, boolTypeId, logicalOp,
2602                                    comparison, allComparisons, out);
2603             break;
2604         default:
2605             SkDEBUGFAILF("mergeComparisons only supports == and !=, not %s", op.operatorName());
2606             return (SpvId)-1;
2607     }
2608     return logicalOp;
2609 }
2610 
division_by_literal_value(Operator op,const Expression & right)2611 static float division_by_literal_value(Operator op, const Expression& right) {
2612     // If this is a division by a literal value, returns that literal value. Otherwise, returns 0.
2613     if (op.kind() == Token::Kind::TK_SLASH && right.is<FloatLiteral>()) {
2614         float rhsValue = right.as<FloatLiteral>().value();
2615         if (std::isfinite(rhsValue)) {
2616             return rhsValue;
2617         }
2618     }
2619     return 0.0f;
2620 }
2621 
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)2622 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
2623     const Expression* left = b.left().get();
2624     const Expression* right = b.right().get();
2625     Operator op = b.getOperator();
2626 
2627     switch (op.kind()) {
2628         case Token::Kind::TK_EQ: {
2629             // Handles assignment.
2630             SpvId rhs = this->writeExpression(*right, out);
2631             this->getLValue(*left, out)->store(rhs, out);
2632             return rhs;
2633         }
2634         case Token::Kind::TK_LOGICALAND:
2635             // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
2636             return this->writeLogicalAnd(*b.left(), *b.right(), out);
2637 
2638         case Token::Kind::TK_LOGICALOR:
2639             // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
2640             return this->writeLogicalOr(*b.left(), *b.right(), out);
2641 
2642         default:
2643             break;
2644     }
2645 
2646     std::unique_ptr<LValue> lvalue;
2647     SpvId lhs;
2648     if (op.isAssignment()) {
2649         lvalue = this->getLValue(*left, out);
2650         lhs = lvalue->load(out);
2651     } else {
2652         lvalue = nullptr;
2653         lhs = this->writeExpression(*left, out);
2654     }
2655 
2656     SpvId rhs;
2657     float rhsValue = division_by_literal_value(op, *right);
2658     if (rhsValue != 0.0f) {
2659         // Rewrite floating-point division by a literal into multiplication by the reciprocal.
2660         // This converts `expr / 2` into `expr * 0.5`
2661         // This improves codegen, especially for certain types of divides (e.g. vector/scalar).
2662         op = Operator(Token::Kind::TK_STAR);
2663         FloatLiteral reciprocal{right->fOffset, 1.0f / rhsValue, &right->type()};
2664         rhs = this->writeExpression(reciprocal, out);
2665     } else {
2666         // Write the right-hand side expression normally.
2667         rhs = this->writeExpression(*right, out);
2668     }
2669 
2670     SpvId result = this->writeBinaryExpression(left->type(), lhs, op.removeAssignment(),
2671                                                right->type(), rhs, b.type(), out);
2672     if (lvalue) {
2673         lvalue->store(result, out);
2674     }
2675     return result;
2676 }
2677 
writeLogicalAnd(const Expression & left,const Expression & right,OutputStream & out)2678 SpvId SPIRVCodeGenerator::writeLogicalAnd(const Expression& left, const Expression& right,
2679                                           OutputStream& out) {
2680     BoolLiteral falseLiteral(/*offset=*/-1, /*value=*/false, fContext.fTypes.fBool.get());
2681     SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
2682     SpvId lhs = this->writeExpression(left, out);
2683     SpvId rhsLabel = this->nextId(nullptr);
2684     SpvId end = this->nextId(nullptr);
2685     SpvId lhsBlock = fCurrentBlock;
2686     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2687     this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
2688     this->writeLabel(rhsLabel, out);
2689     SpvId rhs = this->writeExpression(right, out);
2690     SpvId rhsBlock = fCurrentBlock;
2691     this->writeInstruction(SpvOpBranch, end, out);
2692     this->writeLabel(end, out);
2693     SpvId result = this->nextId(nullptr);
2694     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, falseConstant,
2695                            lhsBlock, rhs, rhsBlock, out);
2696     return result;
2697 }
2698 
writeLogicalOr(const Expression & left,const Expression & right,OutputStream & out)2699 SpvId SPIRVCodeGenerator::writeLogicalOr(const Expression& left, const Expression& right,
2700                                          OutputStream& out) {
2701     BoolLiteral trueLiteral(/*offset=*/-1, /*value=*/true, fContext.fTypes.fBool.get());
2702     SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
2703     SpvId lhs = this->writeExpression(left, out);
2704     SpvId rhsLabel = this->nextId(nullptr);
2705     SpvId end = this->nextId(nullptr);
2706     SpvId lhsBlock = fCurrentBlock;
2707     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2708     this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
2709     this->writeLabel(rhsLabel, out);
2710     SpvId rhs = this->writeExpression(right, out);
2711     SpvId rhsBlock = fCurrentBlock;
2712     this->writeInstruction(SpvOpBranch, end, out);
2713     this->writeLabel(end, out);
2714     SpvId result = this->nextId(nullptr);
2715     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, trueConstant,
2716                            lhsBlock, rhs, rhsBlock, out);
2717     return result;
2718 }
2719 
writeTernaryExpression(const TernaryExpression & t,OutputStream & out)2720 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
2721     const Type& type = t.type();
2722     SpvId test = this->writeExpression(*t.test(), out);
2723     if (t.ifTrue()->type().columns() == 1 &&
2724         t.ifTrue()->isCompileTimeConstant() &&
2725         t.ifFalse()->isCompileTimeConstant()) {
2726         // both true and false are constants, can just use OpSelect
2727         SpvId result = this->nextId(nullptr);
2728         SpvId trueId = this->writeExpression(*t.ifTrue(), out);
2729         SpvId falseId = this->writeExpression(*t.ifFalse(), out);
2730         this->writeInstruction(SpvOpSelect, this->getType(type), result, test, trueId, falseId,
2731                                out);
2732         return result;
2733     }
2734     // was originally using OpPhi to choose the result, but for some reason that is crashing on
2735     // Adreno. Switched to storing the result in a temp variable as glslang does.
2736     SpvId var = this->nextId(nullptr);
2737     this->writeInstruction(SpvOpVariable, this->getPointerType(type, SpvStorageClassFunction),
2738                            var, SpvStorageClassFunction, fVariableBuffer);
2739     SpvId trueLabel = this->nextId(nullptr);
2740     SpvId falseLabel = this->nextId(nullptr);
2741     SpvId end = this->nextId(nullptr);
2742     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2743     this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
2744     this->writeLabel(trueLabel, out);
2745     this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.ifTrue(), out), out);
2746     this->writeInstruction(SpvOpBranch, end, out);
2747     this->writeLabel(falseLabel, out);
2748     this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.ifFalse(), out), out);
2749     this->writeInstruction(SpvOpBranch, end, out);
2750     this->writeLabel(end, out);
2751     SpvId result = this->nextId(&type);
2752     this->writeInstruction(SpvOpLoad, this->getType(type), result, var, out);
2753     return result;
2754 }
2755 
writePrefixExpression(const PrefixExpression & p,OutputStream & out)2756 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
2757     const Type& type = p.type();
2758     if (p.getOperator().kind() == Token::Kind::TK_MINUS) {
2759         SpvId result = this->nextId(&type);
2760         SpvId typeId = this->getType(type);
2761         SpvId expr = this->writeExpression(*p.operand(), out);
2762         if (is_float(fContext, type)) {
2763             this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
2764         } else if (is_signed(fContext, type)) {
2765             this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
2766         } else {
2767             SkDEBUGFAILF("unsupported prefix expression %s", p.description().c_str());
2768         }
2769         return result;
2770     }
2771     switch (p.getOperator().kind()) {
2772         case Token::Kind::TK_PLUS:
2773             return this->writeExpression(*p.operand(), out);
2774         case Token::Kind::TK_PLUSPLUS: {
2775             std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
2776             SpvId one = this->writeExpression(*create_literal_1(fContext, type), out);
2777             SpvId result = this->writeBinaryOperation(type, type, lv->load(out), one,
2778                                                       SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
2779                                                       out);
2780             lv->store(result, out);
2781             return result;
2782         }
2783         case Token::Kind::TK_MINUSMINUS: {
2784             std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
2785             SpvId one = this->writeExpression(*create_literal_1(fContext, type), out);
2786             SpvId result = this->writeBinaryOperation(type, type, lv->load(out), one, SpvOpFSub,
2787                                                       SpvOpISub, SpvOpISub, SpvOpUndef, out);
2788             lv->store(result, out);
2789             return result;
2790         }
2791         case Token::Kind::TK_LOGICALNOT: {
2792             SkASSERT(p.operand()->type().isBoolean());
2793             SpvId result = this->nextId(nullptr);
2794             this->writeInstruction(SpvOpLogicalNot, this->getType(type), result,
2795                                    this->writeExpression(*p.operand(), out), out);
2796             return result;
2797         }
2798         case Token::Kind::TK_BITWISENOT: {
2799             SpvId result = this->nextId(nullptr);
2800             this->writeInstruction(SpvOpNot, this->getType(type), result,
2801                                    this->writeExpression(*p.operand(), out), out);
2802             return result;
2803         }
2804         default:
2805             SkDEBUGFAILF("unsupported prefix expression: %s", p.description().c_str());
2806             return -1;
2807     }
2808 }
2809 
writePostfixExpression(const PostfixExpression & p,OutputStream & out)2810 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
2811     const Type& type = p.type();
2812     std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
2813     SpvId result = lv->load(out);
2814     SpvId one = this->writeExpression(*create_literal_1(fContext, type), out);
2815     switch (p.getOperator().kind()) {
2816         case Token::Kind::TK_PLUSPLUS: {
2817             SpvId temp = this->writeBinaryOperation(type, type, result, one, SpvOpFAdd,
2818                                                     SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2819             lv->store(temp, out);
2820             return result;
2821         }
2822         case Token::Kind::TK_MINUSMINUS: {
2823             SpvId temp = this->writeBinaryOperation(type, type, result, one, SpvOpFSub,
2824                                                     SpvOpISub, SpvOpISub, SpvOpUndef, out);
2825             lv->store(temp, out);
2826             return result;
2827         }
2828         default:
2829             SkDEBUGFAILF("unsupported postfix expression %s", p.description().c_str());
2830             return -1;
2831     }
2832 }
2833 
writeBoolLiteral(const BoolLiteral & b)2834 SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
2835     if (b.value()) {
2836         if (fBoolTrue == 0) {
2837             fBoolTrue = this->nextId(nullptr);
2838             this->writeInstruction(SpvOpConstantTrue, this->getType(b.type()), fBoolTrue,
2839                                    fConstantBuffer);
2840         }
2841         return fBoolTrue;
2842     } else {
2843         if (fBoolFalse == 0) {
2844             fBoolFalse = this->nextId(nullptr);
2845             this->writeInstruction(SpvOpConstantFalse, this->getType(b.type()), fBoolFalse,
2846                                    fConstantBuffer);
2847         }
2848         return fBoolFalse;
2849     }
2850 }
2851 
writeIntLiteral(const IntLiteral & i)2852 SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) {
2853     SPIRVNumberConstant key{i.value(), i.type().numberKind()};
2854     auto [iter, newlyCreated] = fNumberConstants.insert({key, (SpvId)-1});
2855     if (newlyCreated) {
2856         SpvId result = this->nextId(nullptr);
2857         this->writeInstruction(SpvOpConstant, this->getType(i.type()), result, (SpvId) i.value(),
2858                                fConstantBuffer);
2859         iter->second = result;
2860     }
2861     return iter->second;
2862 }
2863 
writeFloatLiteral(const FloatLiteral & f)2864 SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
2865     // Convert the float literal into its bit-representation.
2866     float value = f.value();
2867     uint32_t valueBits;
2868     static_assert(sizeof(valueBits) == sizeof(value));
2869     memcpy(&valueBits, &value, sizeof(value));
2870 
2871     SPIRVNumberConstant key{valueBits, f.type().numberKind()};
2872     auto [iter, newlyCreated] = fNumberConstants.insert({key, (SpvId)-1});
2873     if (newlyCreated) {
2874         SpvId result = this->nextId(nullptr);
2875         this->writeInstruction(SpvOpConstant, this->getType(f.type()), result, (SpvId) valueBits,
2876                                fConstantBuffer);
2877         iter->second = result;
2878     }
2879     return iter->second;
2880 }
2881 
writeFunctionStart(const FunctionDeclaration & f,OutputStream & out)2882 SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
2883     SpvId result = fFunctionMap[&f];
2884     SpvId returnTypeId = this->getType(f.returnType());
2885     SpvId functionTypeId = this->getFunctionType(f);
2886     this->writeInstruction(SpvOpFunction, returnTypeId, result,
2887                            SpvFunctionControlMaskNone, functionTypeId, out);
2888     String mangledName = f.mangledName();
2889     this->writeInstruction(SpvOpName,
2890                            result,
2891                            StringFragment(mangledName.c_str(), mangledName.size()),
2892                            fNameBuffer);
2893     for (const Variable* parameter : f.parameters()) {
2894         SpvId id = this->nextId(nullptr);
2895         fVariableMap[parameter] = id;
2896         SpvId type = this->getPointerType(parameter->type(), SpvStorageClassFunction);
2897         this->writeInstruction(SpvOpFunctionParameter, type, id, out);
2898     }
2899     return result;
2900 }
2901 
writeFunction(const FunctionDefinition & f,OutputStream & out)2902 SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
2903     fVariableBuffer.reset();
2904     SpvId result = this->writeFunctionStart(f.declaration(), out);
2905     fCurrentBlock = 0;
2906     this->writeLabel(this->nextId(nullptr), out);
2907     StringStream bodyBuffer;
2908     this->writeBlock(f.body()->as<Block>(), bodyBuffer);
2909     write_stringstream(fVariableBuffer, out);
2910     if (f.declaration().isMain()) {
2911         write_stringstream(fGlobalInitializersBuffer, out);
2912     }
2913     write_stringstream(bodyBuffer, out);
2914     if (fCurrentBlock) {
2915         if (f.declaration().returnType().isVoid()) {
2916             this->writeInstruction(SpvOpReturn, out);
2917         } else {
2918             this->writeInstruction(SpvOpUnreachable, out);
2919         }
2920     }
2921     this->writeInstruction(SpvOpFunctionEnd, out);
2922     return result;
2923 }
2924 
writeLayout(const Layout & layout,SpvId target)2925 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) {
2926     if (layout.fLocation >= 0) {
2927         this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
2928                                fDecorationBuffer);
2929     }
2930     if (layout.fBinding >= 0) {
2931         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
2932                                fDecorationBuffer);
2933     }
2934     if (layout.fIndex >= 0) {
2935         this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
2936                                fDecorationBuffer);
2937     }
2938     if (layout.fSet >= 0) {
2939         this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
2940                                fDecorationBuffer);
2941     }
2942     if (layout.fInputAttachmentIndex >= 0) {
2943         this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
2944                                layout.fInputAttachmentIndex, fDecorationBuffer);
2945         fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
2946     }
2947     if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN &&
2948         layout.fBuiltin != SK_IN_BUILTIN && layout.fBuiltin != SK_OUT_BUILTIN) {
2949         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
2950                                fDecorationBuffer);
2951     }
2952 }
2953 
writeLayout(const Layout & layout,SpvId target,int member)2954 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) {
2955     if (layout.fLocation >= 0) {
2956         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
2957                                layout.fLocation, fDecorationBuffer);
2958     }
2959     if (layout.fBinding >= 0) {
2960         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding,
2961                                layout.fBinding, fDecorationBuffer);
2962     }
2963     if (layout.fIndex >= 0) {
2964         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
2965                                layout.fIndex, fDecorationBuffer);
2966     }
2967     if (layout.fSet >= 0) {
2968         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet,
2969                                layout.fSet, fDecorationBuffer);
2970     }
2971     if (layout.fInputAttachmentIndex >= 0) {
2972         this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
2973                                layout.fInputAttachmentIndex, fDecorationBuffer);
2974     }
2975     if (layout.fBuiltin >= 0) {
2976         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
2977                                layout.fBuiltin, fDecorationBuffer);
2978     }
2979 }
2980 
memoryLayoutForVariable(const Variable & v) const2981 MemoryLayout SPIRVCodeGenerator::memoryLayoutForVariable(const Variable& v) const {
2982     bool pushConstant = ((v.modifiers().fLayout.fFlags & Layout::kPushConstant_Flag) != 0);
2983     return pushConstant ? MemoryLayout(MemoryLayout::k430_Standard) : fDefaultLayout;
2984 }
2985 
update_sk_in_count(const Modifiers & m,int * outSkInCount)2986 static void update_sk_in_count(const Modifiers& m, int* outSkInCount) {
2987     switch (m.fLayout.fPrimitive) {
2988         case Layout::kPoints_Primitive:
2989             *outSkInCount = 1;
2990             break;
2991         case Layout::kLines_Primitive:
2992             *outSkInCount = 2;
2993             break;
2994         case Layout::kLinesAdjacency_Primitive:
2995             *outSkInCount = 4;
2996             break;
2997         case Layout::kTriangles_Primitive:
2998             *outSkInCount = 3;
2999             break;
3000         case Layout::kTrianglesAdjacency_Primitive:
3001             *outSkInCount = 6;
3002             break;
3003         default:
3004             return;
3005     }
3006 }
3007 
writeInterfaceBlock(const InterfaceBlock & intf,bool appendRTHeight)3008 SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTHeight) {
3009     MemoryLayout memoryLayout = this->memoryLayoutForVariable(intf.variable());
3010     SpvId result = this->nextId(nullptr);
3011     std::unique_ptr<Type> rtHeightStructType;
3012     const Type* type = &intf.variable().type();
3013     if (!MemoryLayout::LayoutIsSupported(*type)) {
3014         fErrors.error(type->fOffset, "type '" + type->name() + "' is not permitted here");
3015         return this->nextId(nullptr);
3016     }
3017     SpvStorageClass_ storageClass = get_storage_class(intf.variable(), SpvStorageClassFunction);
3018     if (fProgram.fInputs.fRTHeight && appendRTHeight) {
3019         SkASSERT(fRTHeightStructId == (SpvId) -1);
3020         SkASSERT(fRTHeightFieldIndex == (SpvId) -1);
3021         std::vector<Type::Field> fields = type->fields();
3022         fRTHeightStructId = result;
3023         fRTHeightFieldIndex = fields.size();
3024         fRTHeightStorageClass = storageClass;
3025         fields.emplace_back(Modifiers(), StringFragment(SKSL_RTHEIGHT_NAME),
3026                             fContext.fTypes.fFloat.get());
3027         rtHeightStructType = Type::MakeStructType(type->fOffset, type->name(), std::move(fields));
3028         type = rtHeightStructType.get();
3029     }
3030     SpvId typeId;
3031     const Modifiers& intfModifiers = intf.variable().modifiers();
3032     if (intfModifiers.fLayout.fBuiltin == SK_IN_BUILTIN) {
3033         for (const ProgramElement* e : fProgram.elements()) {
3034             if (e->is<ModifiersDeclaration>()) {
3035                 const Modifiers& m = e->as<ModifiersDeclaration>().modifiers();
3036                 update_sk_in_count(m, &fSkInCount);
3037             }
3038         }
3039         typeId = this->getType(
3040                 *Type::MakeArrayType("sk_in", intf.variable().type().componentType(), fSkInCount),
3041                 memoryLayout);
3042     } else {
3043         typeId = this->getType(*type, memoryLayout);
3044     }
3045     if (intfModifiers.fLayout.fBuiltin == -1) {
3046         this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
3047     }
3048     SpvId ptrType = this->nextId(nullptr);
3049     this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
3050     this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
3051     Layout layout = intfModifiers.fLayout;
3052     if (intfModifiers.fFlags & Modifiers::kUniform_Flag && layout.fSet == -1) {
3053         layout.fSet = 0;
3054     }
3055     this->writeLayout(layout, result);
3056     fVariableMap[&intf.variable()] = result;
3057     return result;
3058 }
3059 
is_dead(const Variable & var,const ProgramUsage * usage)3060 static bool is_dead(const Variable& var, const ProgramUsage* usage) {
3061     ProgramUsage::VariableCounts counts = usage->get(var);
3062     if (counts.fRead || counts.fWrite) {
3063         return false;
3064     }
3065     // It's not entirely clear what the rules are for eliding interface variables. Generally, it
3066     // causes problems to elide them, even when they're dead.
3067     return !(var.modifiers().fFlags &
3068              (Modifiers::kIn_Flag | Modifiers::kOut_Flag | Modifiers::kUniform_Flag));
3069 }
3070 
writeGlobalVar(ProgramKind kind,const VarDeclaration & varDecl)3071 void SPIRVCodeGenerator::writeGlobalVar(ProgramKind kind, const VarDeclaration& varDecl) {
3072     const Variable& var = varDecl.var();
3073     // 9999 is a sentinel value used in our built-in modules that causes us to ignore these
3074     // declarations, beyond adding them to the symbol table.
3075     constexpr int kBuiltinIgnore = 9999;
3076     if (var.modifiers().fLayout.fBuiltin == kBuiltinIgnore) {
3077         return;
3078     }
3079     if (var.modifiers().fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
3080         kind != ProgramKind::kFragment) {
3081         SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
3082         return;
3083     }
3084     if (is_dead(var, fProgram.fUsage.get())) {
3085         return;
3086     }
3087     SpvStorageClass_ storageClass = get_storage_class(var, SpvStorageClassPrivate);
3088     if (storageClass == SpvStorageClassUniform) {
3089         // Top-level uniforms are emitted in writeUniformBuffer.
3090         fTopLevelUniforms.push_back(&varDecl);
3091         return;
3092     }
3093     const Type& type = var.type();
3094     Layout layout = var.modifiers().fLayout;
3095     if (layout.fSet < 0 && storageClass == SpvStorageClassUniformConstant) {
3096         layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
3097     }
3098     SpvId id = this->nextId(&type);
3099     fVariableMap[&var] = id;
3100     SpvId typeId;
3101     if (var.modifiers().fLayout.fBuiltin == SK_IN_BUILTIN) {
3102         typeId = this->getPointerType(
3103                 *Type::MakeArrayType("sk_in", type.componentType(), fSkInCount),
3104                 storageClass);
3105     } else {
3106         typeId = this->getPointerType(type, storageClass);
3107     }
3108     this->writeInstruction(SpvOpVariable, typeId, id, storageClass, fConstantBuffer);
3109     this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer);
3110     if (varDecl.value()) {
3111         SkASSERT(!fCurrentBlock);
3112         fCurrentBlock = -1;
3113         SpvId value = this->writeExpression(*varDecl.value(), fGlobalInitializersBuffer);
3114         this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
3115         fCurrentBlock = 0;
3116     }
3117     this->writeLayout(layout, id);
3118     if (var.modifiers().fFlags & Modifiers::kFlat_Flag) {
3119         this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
3120     }
3121     if (var.modifiers().fFlags & Modifiers::kNoPerspective_Flag) {
3122         this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
3123                                 fDecorationBuffer);
3124     }
3125 }
3126 
writeVarDeclaration(const VarDeclaration & varDecl,OutputStream & out)3127 void SPIRVCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl, OutputStream& out) {
3128     const Variable& var = varDecl.var();
3129     SpvId id = this->nextId(&var.type());
3130     fVariableMap[&var] = id;
3131     SpvId type = this->getPointerType(var.type(), SpvStorageClassFunction);
3132     this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
3133     this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer);
3134     if (varDecl.value()) {
3135         SpvId value = this->writeExpression(*varDecl.value(), out);
3136         this->writeInstruction(SpvOpStore, id, value, out);
3137     }
3138 }
3139 
writeStatement(const Statement & s,OutputStream & out)3140 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
3141     switch (s.kind()) {
3142         case Statement::Kind::kInlineMarker:
3143         case Statement::Kind::kNop:
3144             break;
3145         case Statement::Kind::kBlock:
3146             this->writeBlock(s.as<Block>(), out);
3147             break;
3148         case Statement::Kind::kExpression:
3149             this->writeExpression(*s.as<ExpressionStatement>().expression(), out);
3150             break;
3151         case Statement::Kind::kReturn:
3152             this->writeReturnStatement(s.as<ReturnStatement>(), out);
3153             break;
3154         case Statement::Kind::kVarDeclaration:
3155             this->writeVarDeclaration(s.as<VarDeclaration>(), out);
3156             break;
3157         case Statement::Kind::kIf:
3158             this->writeIfStatement(s.as<IfStatement>(), out);
3159             break;
3160         case Statement::Kind::kFor:
3161             this->writeForStatement(s.as<ForStatement>(), out);
3162             break;
3163         case Statement::Kind::kDo:
3164             this->writeDoStatement(s.as<DoStatement>(), out);
3165             break;
3166         case Statement::Kind::kSwitch:
3167             this->writeSwitchStatement(s.as<SwitchStatement>(), out);
3168             break;
3169         case Statement::Kind::kBreak:
3170             this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
3171             break;
3172         case Statement::Kind::kContinue:
3173             this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
3174             break;
3175         case Statement::Kind::kDiscard:
3176             this->writeInstruction(SpvOpKill, out);
3177             break;
3178         default:
3179             SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
3180             break;
3181     }
3182 }
3183 
writeBlock(const Block & b,OutputStream & out)3184 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
3185     for (const std::unique_ptr<Statement>& stmt : b.children()) {
3186         this->writeStatement(*stmt, out);
3187     }
3188 }
3189 
writeIfStatement(const IfStatement & stmt,OutputStream & out)3190 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
3191     SpvId test = this->writeExpression(*stmt.test(), out);
3192     SpvId ifTrue = this->nextId(nullptr);
3193     SpvId ifFalse = this->nextId(nullptr);
3194     if (stmt.ifFalse()) {
3195         SpvId end = this->nextId(nullptr);
3196         this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
3197         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
3198         this->writeLabel(ifTrue, out);
3199         this->writeStatement(*stmt.ifTrue(), out);
3200         if (fCurrentBlock) {
3201             this->writeInstruction(SpvOpBranch, end, out);
3202         }
3203         this->writeLabel(ifFalse, out);
3204         this->writeStatement(*stmt.ifFalse(), out);
3205         if (fCurrentBlock) {
3206             this->writeInstruction(SpvOpBranch, end, out);
3207         }
3208         this->writeLabel(end, out);
3209     } else {
3210         this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
3211         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
3212         this->writeLabel(ifTrue, out);
3213         this->writeStatement(*stmt.ifTrue(), out);
3214         if (fCurrentBlock) {
3215             this->writeInstruction(SpvOpBranch, ifFalse, out);
3216         }
3217         this->writeLabel(ifFalse, out);
3218     }
3219 }
3220 
writeForStatement(const ForStatement & f,OutputStream & out)3221 void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
3222     if (f.initializer()) {
3223         this->writeStatement(*f.initializer(), out);
3224     }
3225     SpvId header = this->nextId(nullptr);
3226     SpvId start = this->nextId(nullptr);
3227     SpvId body = this->nextId(nullptr);
3228     SpvId next = this->nextId(nullptr);
3229     fContinueTarget.push(next);
3230     SpvId end = this->nextId(nullptr);
3231     fBreakTarget.push(end);
3232     this->writeInstruction(SpvOpBranch, header, out);
3233     this->writeLabel(header, out);
3234     this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
3235     this->writeInstruction(SpvOpBranch, start, out);
3236     this->writeLabel(start, out);
3237     if (f.test()) {
3238         SpvId test = this->writeExpression(*f.test(), out);
3239         this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
3240     } else {
3241         this->writeInstruction(SpvOpBranch, body, out);
3242     }
3243     this->writeLabel(body, out);
3244     this->writeStatement(*f.statement(), out);
3245     if (fCurrentBlock) {
3246         this->writeInstruction(SpvOpBranch, next, out);
3247     }
3248     this->writeLabel(next, out);
3249     if (f.next()) {
3250         this->writeExpression(*f.next(), out);
3251     }
3252     this->writeInstruction(SpvOpBranch, header, out);
3253     this->writeLabel(end, out);
3254     fBreakTarget.pop();
3255     fContinueTarget.pop();
3256 }
3257 
writeDoStatement(const DoStatement & d,OutputStream & out)3258 void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
3259     SpvId header = this->nextId(nullptr);
3260     SpvId start = this->nextId(nullptr);
3261     SpvId next = this->nextId(nullptr);
3262     SpvId continueTarget = this->nextId(nullptr);
3263     fContinueTarget.push(continueTarget);
3264     SpvId end = this->nextId(nullptr);
3265     fBreakTarget.push(end);
3266     this->writeInstruction(SpvOpBranch, header, out);
3267     this->writeLabel(header, out);
3268     this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out);
3269     this->writeInstruction(SpvOpBranch, start, out);
3270     this->writeLabel(start, out);
3271     this->writeStatement(*d.statement(), out);
3272     if (fCurrentBlock) {
3273         this->writeInstruction(SpvOpBranch, next, out);
3274     }
3275     this->writeLabel(next, out);
3276     this->writeInstruction(SpvOpBranch, continueTarget, out);
3277     this->writeLabel(continueTarget, out);
3278     SpvId test = this->writeExpression(*d.test(), out);
3279     this->writeInstruction(SpvOpBranchConditional, test, header, end, out);
3280     this->writeLabel(end, out);
3281     fBreakTarget.pop();
3282     fContinueTarget.pop();
3283 }
3284 
writeSwitchStatement(const SwitchStatement & s,OutputStream & out)3285 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
3286     SpvId value = this->writeExpression(*s.value(), out);
3287     std::vector<SpvId> labels;
3288     SpvId end = this->nextId(nullptr);
3289     SpvId defaultLabel = end;
3290     fBreakTarget.push(end);
3291     int size = 3;
3292     auto& cases = s.cases();
3293     for (const std::unique_ptr<Statement>& stmt : cases) {
3294         const SwitchCase& c = stmt->as<SwitchCase>();
3295         SpvId label = this->nextId(nullptr);
3296         labels.push_back(label);
3297         if (c.value()) {
3298             size += 2;
3299         } else {
3300             defaultLabel = label;
3301         }
3302     }
3303     labels.push_back(end);
3304     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
3305     this->writeOpCode(SpvOpSwitch, size, out);
3306     this->writeWord(value, out);
3307     this->writeWord(defaultLabel, out);
3308     for (size_t i = 0; i < cases.size(); ++i) {
3309         const SwitchCase& c = cases[i]->as<SwitchCase>();
3310         if (!c.value()) {
3311             continue;
3312         }
3313         this->writeWord(c.value()->as<IntLiteral>().value(), out);
3314         this->writeWord(labels[i], out);
3315     }
3316     for (size_t i = 0; i < cases.size(); ++i) {
3317         const SwitchCase& c = cases[i]->as<SwitchCase>();
3318         this->writeLabel(labels[i], out);
3319         this->writeStatement(*c.statement(), out);
3320         if (fCurrentBlock) {
3321             this->writeInstruction(SpvOpBranch, labels[i + 1], out);
3322         }
3323     }
3324     this->writeLabel(end, out);
3325     fBreakTarget.pop();
3326 }
3327 
writeReturnStatement(const ReturnStatement & r,OutputStream & out)3328 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
3329     if (r.expression()) {
3330         this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.expression(), out),
3331                                out);
3332     } else {
3333         this->writeInstruction(SpvOpReturn, out);
3334     }
3335 }
3336 
writeGeometryShaderExecutionMode(SpvId entryPoint,OutputStream & out)3337 void SPIRVCodeGenerator::writeGeometryShaderExecutionMode(SpvId entryPoint, OutputStream& out) {
3338     SkASSERT(fProgram.fConfig->fKind == ProgramKind::kGeometry);
3339     int invocations = 1;
3340     for (const ProgramElement* e : fProgram.elements()) {
3341         if (e->is<ModifiersDeclaration>()) {
3342             const Modifiers& m = e->as<ModifiersDeclaration>().modifiers();
3343             if (m.fFlags & Modifiers::kIn_Flag) {
3344                 if (m.fLayout.fInvocations != -1) {
3345                     invocations = m.fLayout.fInvocations;
3346                 }
3347                 SpvId input;
3348                 switch (m.fLayout.fPrimitive) {
3349                     case Layout::kPoints_Primitive:
3350                         input = SpvExecutionModeInputPoints;
3351                         break;
3352                     case Layout::kLines_Primitive:
3353                         input = SpvExecutionModeInputLines;
3354                         break;
3355                     case Layout::kLinesAdjacency_Primitive:
3356                         input = SpvExecutionModeInputLinesAdjacency;
3357                         break;
3358                     case Layout::kTriangles_Primitive:
3359                         input = SpvExecutionModeTriangles;
3360                         break;
3361                     case Layout::kTrianglesAdjacency_Primitive:
3362                         input = SpvExecutionModeInputTrianglesAdjacency;
3363                         break;
3364                     default:
3365                         input = 0;
3366                         break;
3367                 }
3368                 update_sk_in_count(m, &fSkInCount);
3369                 if (input) {
3370                     this->writeInstruction(SpvOpExecutionMode, entryPoint, input, out);
3371                 }
3372             } else if (m.fFlags & Modifiers::kOut_Flag) {
3373                 SpvId output;
3374                 switch (m.fLayout.fPrimitive) {
3375                     case Layout::kPoints_Primitive:
3376                         output = SpvExecutionModeOutputPoints;
3377                         break;
3378                     case Layout::kLineStrip_Primitive:
3379                         output = SpvExecutionModeOutputLineStrip;
3380                         break;
3381                     case Layout::kTriangleStrip_Primitive:
3382                         output = SpvExecutionModeOutputTriangleStrip;
3383                         break;
3384                     default:
3385                         output = 0;
3386                         break;
3387                 }
3388                 if (output) {
3389                     this->writeInstruction(SpvOpExecutionMode, entryPoint, output, out);
3390                 }
3391                 if (m.fLayout.fMaxVertices != -1) {
3392                     this->writeInstruction(SpvOpExecutionMode, entryPoint,
3393                                            SpvExecutionModeOutputVertices, m.fLayout.fMaxVertices,
3394                                            out);
3395                 }
3396             }
3397         }
3398     }
3399     this->writeInstruction(SpvOpExecutionMode, entryPoint, SpvExecutionModeInvocations,
3400                            invocations, out);
3401 }
3402 
3403 // Given any function, returns the top-level symbol table (OUTSIDE of the function's scope).
get_top_level_symbol_table(const FunctionDeclaration & anyFunc)3404 static std::shared_ptr<SymbolTable> get_top_level_symbol_table(const FunctionDeclaration& anyFunc) {
3405     return anyFunc.definition()->body()->as<Block>().symbolTable()->fParent;
3406 }
3407 
writeEntrypointAdapter(const FunctionDeclaration & main)3408 SPIRVCodeGenerator::EntrypointAdapter SPIRVCodeGenerator::writeEntrypointAdapter(
3409         const FunctionDeclaration& main) {
3410     // Our goal is to synthesize a tiny helper function which looks like this:
3411     //     void _entrypoint() { sk_FragColor = main(); }
3412 
3413     // Fish a symbol table out of main().
3414     std::shared_ptr<SymbolTable> symbolTable = get_top_level_symbol_table(main);
3415 
3416     // Get `sk_FragColor` as a writable reference.
3417     const Symbol* skFragColorSymbol = (*symbolTable)["sk_FragColor"];
3418     SkASSERT(skFragColorSymbol);
3419     const Variable& skFragColorVar = skFragColorSymbol->as<Variable>();
3420     auto skFragColorRef = std::make_unique<VariableReference>(/*offset=*/-1, &skFragColorVar,
3421                                                               VariableReference::RefKind::kWrite);
3422     // Synthesize a call to the `main()` function.
3423     if (main.returnType() != skFragColorRef->type()) {
3424         fErrors.error(main.fOffset, "SPIR-V does not support returning '" +
3425                                     main.returnType().description() + "' from main()");
3426         return {};
3427     }
3428     ExpressionArray args;
3429     if (main.parameters().size() == 1) {
3430         if (main.parameters()[0]->type() != *fContext.fTypes.fFloat2) {
3431             fErrors.error(main.fOffset,
3432                           "SPIR-V does not support parameter of type '" +
3433                                   main.parameters()[0]->type().description() + "' to main()");
3434             return {};
3435         }
3436         auto zero = std::make_unique<FloatLiteral>(
3437                 /*offset=*/-1, 0.0f, fContext.fTypes.fFloatLiteral.get());
3438         auto zeros = std::make_unique<ConstructorSplat>(
3439                 /*offset=*/-1, *fContext.fTypes.fFloat2, std::move(zero));
3440         args.push_back(std::move(zeros));
3441     }
3442     auto callMainFn = std::make_unique<FunctionCall>(/*offset=*/-1, &main.returnType(), &main,
3443                                                      std::move(args));
3444 
3445     // Synthesize `skFragColor = main()` as a BinaryExpression.
3446     auto assignmentStmt = std::make_unique<ExpressionStatement>(std::make_unique<BinaryExpression>(
3447             /*offset=*/-1,
3448             std::move(skFragColorRef),
3449             Token::Kind::TK_EQ,
3450             std::move(callMainFn),
3451             &main.returnType()));
3452 
3453     // Function bodies are always wrapped in a Block.
3454     StatementArray entrypointStmts;
3455     entrypointStmts.push_back(std::move(assignmentStmt));
3456     auto entrypointBlock = Block::Make(/*offset=*/-1, std::move(entrypointStmts),
3457                                        symbolTable, /*isScope=*/true);
3458     // Declare an entrypoint function.
3459     EntrypointAdapter adapter;
3460     adapter.fLayout = {};
3461     adapter.fModifiers = Modifiers{adapter.fLayout, Modifiers::kHasSideEffects_Flag};
3462     adapter.entrypointDecl =
3463             std::make_unique<FunctionDeclaration>(/*offset=*/-1,
3464                                                   &adapter.fModifiers,
3465                                                   "_entrypoint",
3466                                                   /*parameters=*/std::vector<const Variable*>{},
3467                                                   /*returnType=*/fContext.fTypes.fVoid.get(),
3468                                                   /*builtin=*/false);
3469     // Define it.
3470     adapter.entrypointDef =
3471             std::make_unique<FunctionDefinition>(/*offset=*/-1, adapter.entrypointDecl.get(),
3472                                                  /*builtin=*/false,
3473                                                  /*body=*/std::move(entrypointBlock));
3474 
3475     adapter.entrypointDecl->setDefinition(adapter.entrypointDef.get());
3476     return adapter;
3477 }
3478 
writeUniformBuffer(std::shared_ptr<SymbolTable> topLevelSymbolTable)3479 void SPIRVCodeGenerator::writeUniformBuffer(std::shared_ptr<SymbolTable> topLevelSymbolTable) {
3480     SkASSERT(!fTopLevelUniforms.empty());
3481     static constexpr char kUniformBufferName[] = "_UniformBuffer";
3482 
3483     // Convert the list of top-level uniforms into a matching struct named _UniformBuffer, and build
3484     // a lookup table of variables to UniformBuffer field indices.
3485     std::vector<Type::Field> fields;
3486     fields.reserve(fTopLevelUniforms.size());
3487     fTopLevelUniformMap.reserve(fTopLevelUniforms.size());
3488     for (const VarDeclaration* topLevelUniform : fTopLevelUniforms) {
3489         const Variable* var = &topLevelUniform->var();
3490         fTopLevelUniformMap[var] = (int)fields.size();
3491         fields.emplace_back(var->modifiers(), var->name(), &var->type());
3492     }
3493     fUniformBuffer.fStruct = Type::MakeStructType(/*offset=*/-1, kUniformBufferName,
3494                                                  std::move(fields));
3495 
3496     // Create a global variable to contain this struct.
3497     Layout layout;
3498     layout.fBinding = fProgram.fConfig->fSettings.fDefaultUniformBinding;
3499     layout.fSet     = fProgram.fConfig->fSettings.fDefaultUniformSet;
3500     Modifiers modifiers{layout, Modifiers::kUniform_Flag};
3501 
3502     fUniformBuffer.fInnerVariable = std::make_unique<Variable>(
3503             /*offset=*/-1, fProgram.fModifiers->add(modifiers), kUniformBufferName,
3504             fUniformBuffer.fStruct.get(), /*builtin=*/false, Variable::Storage::kGlobal);
3505 
3506     // Create an interface block object for this global variable.
3507     fUniformBuffer.fInterfaceBlock = std::make_unique<InterfaceBlock>(
3508             /*offset=*/-1, fUniformBuffer.fInnerVariable.get(), kUniformBufferName,
3509             kUniformBufferName, /*arraySize=*/0, topLevelSymbolTable);
3510 
3511     // Generate an interface block and hold onto its ID.
3512     fUniformBufferId = this->writeInterfaceBlock(*fUniformBuffer.fInterfaceBlock);
3513 }
3514 
writeInstructions(const Program & program,OutputStream & out)3515 void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
3516     fGLSLExtendedInstructions = this->nextId(nullptr);
3517     StringStream body;
3518     // Assign SpvIds to functions.
3519     const FunctionDeclaration* main = nullptr;
3520     for (const ProgramElement* e : program.elements()) {
3521         if (e->is<FunctionDefinition>()) {
3522             const FunctionDefinition& funcDef = e->as<FunctionDefinition>();
3523             const FunctionDeclaration& funcDecl = funcDef.declaration();
3524             fFunctionMap[&funcDecl] = this->nextId(nullptr);
3525             if (funcDecl.isMain()) {
3526                 main = &funcDecl;
3527             }
3528         }
3529     }
3530     // Make sure we have a main() function.
3531     if (!main) {
3532         fErrors.error(/*offset=*/0, "program does not contain a main() function");
3533         return;
3534     }
3535     // Emit interface blocks.
3536     std::set<SpvId> interfaceVars;
3537     for (const ProgramElement* e : program.elements()) {
3538         if (e->is<InterfaceBlock>()) {
3539             const InterfaceBlock& intf = e->as<InterfaceBlock>();
3540             SpvId id = this->writeInterfaceBlock(intf);
3541 
3542             const Modifiers& modifiers = intf.variable().modifiers();
3543             if ((modifiers.fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag)) &&
3544                 modifiers.fLayout.fBuiltin == -1 &&
3545                 !is_dead(intf.variable(), fProgram.fUsage.get())) {
3546                 interfaceVars.insert(id);
3547             }
3548         }
3549     }
3550     // Emit global variable declarations.
3551     for (const ProgramElement* e : program.elements()) {
3552         if (e->is<GlobalVarDeclaration>()) {
3553             this->writeGlobalVar(program.fConfig->fKind,
3554                                  e->as<GlobalVarDeclaration>().declaration()->as<VarDeclaration>());
3555         }
3556     }
3557     // Emit top-level uniforms into a dedicated uniform buffer.
3558     if (!fTopLevelUniforms.empty()) {
3559         this->writeUniformBuffer(get_top_level_symbol_table(*main));
3560     }
3561     // If main() returns a half4, synthesize a tiny entrypoint function which invokes the real
3562     // main() and stores the result into sk_FragColor.
3563     EntrypointAdapter adapter;
3564     if (main->returnType() == *fContext.fTypes.fHalf4) {
3565         adapter = this->writeEntrypointAdapter(*main);
3566         if (adapter.entrypointDecl) {
3567             fFunctionMap[adapter.entrypointDecl.get()] = this->nextId(nullptr);
3568             this->writeFunction(*adapter.entrypointDef, body);
3569             main = adapter.entrypointDecl.get();
3570         }
3571     }
3572     // Emit all the functions.
3573     for (const ProgramElement* e : program.elements()) {
3574         if (e->is<FunctionDefinition>()) {
3575             this->writeFunction(e->as<FunctionDefinition>(), body);
3576         }
3577     }
3578     // Add global in/out variables to the list of interface variables.
3579     for (auto entry : fVariableMap) {
3580         const Variable* var = entry.first;
3581         if (var->storage() == Variable::Storage::kGlobal &&
3582             (var->modifiers().fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag)) &&
3583             !is_dead(*var, fProgram.fUsage.get())) {
3584             interfaceVars.insert(entry.second);
3585         }
3586     }
3587     this->writeCapabilities(out);
3588     this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
3589     this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
3590     this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (main->name().fLength + 4) / 4) +
3591                       (int32_t) interfaceVars.size(), out);
3592     switch (program.fConfig->fKind) {
3593         case ProgramKind::kVertex:
3594             this->writeWord(SpvExecutionModelVertex, out);
3595             break;
3596         case ProgramKind::kFragment:
3597             this->writeWord(SpvExecutionModelFragment, out);
3598             break;
3599         case ProgramKind::kGeometry:
3600             this->writeWord(SpvExecutionModelGeometry, out);
3601             break;
3602         default:
3603             SK_ABORT("cannot write this kind of program to SPIR-V\n");
3604     }
3605     SpvId entryPoint = fFunctionMap[main];
3606     this->writeWord(entryPoint, out);
3607     this->writeString(main->name().fChars, main->name().fLength, out);
3608     for (int var : interfaceVars) {
3609         this->writeWord(var, out);
3610     }
3611     if (program.fConfig->fKind == ProgramKind::kGeometry) {
3612         this->writeGeometryShaderExecutionMode(entryPoint, out);
3613     }
3614     if (program.fConfig->fKind == ProgramKind::kFragment) {
3615         this->writeInstruction(SpvOpExecutionMode,
3616                                fFunctionMap[main],
3617                                SpvExecutionModeOriginUpperLeft,
3618                                out);
3619     }
3620     for (const ProgramElement* e : program.elements()) {
3621         if (e->is<Extension>()) {
3622             this->writeInstruction(SpvOpSourceExtension, e->as<Extension>().name().c_str(), out);
3623         }
3624     }
3625 
3626     write_stringstream(fExtraGlobalsBuffer, out);
3627     write_stringstream(fNameBuffer, out);
3628     write_stringstream(fDecorationBuffer, out);
3629     write_stringstream(fConstantBuffer, out);
3630     write_stringstream(fExternalFunctionsBuffer, out);
3631     write_stringstream(body, out);
3632 }
3633 
generateCode()3634 bool SPIRVCodeGenerator::generateCode() {
3635     SkASSERT(!fErrors.errorCount());
3636     this->writeWord(SpvMagicNumber, *fOut);
3637     this->writeWord(SpvVersion, *fOut);
3638     this->writeWord(SKSL_MAGIC, *fOut);
3639     StringStream buffer;
3640     this->writeInstructions(fProgram, buffer);
3641     this->writeWord(fIdCount, *fOut);
3642     this->writeWord(0, *fOut); // reserved, always zero
3643     write_stringstream(buffer, *fOut);
3644     return 0 == fErrors.errorCount();
3645 }
3646 
3647 }  // namespace SkSL
3648