• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/SkSLSPIRVCodeGenerator.h"
9 
10 #include "src/sksl/GLSL.std.450.h"
11 
12 #include "src/sksl/SkSLCompiler.h"
13 #include "src/sksl/ir/SkSLExpressionStatement.h"
14 #include "src/sksl/ir/SkSLExtension.h"
15 #include "src/sksl/ir/SkSLIndexExpression.h"
16 #include "src/sksl/ir/SkSLVariableReference.h"
17 
18 #ifdef SK_VULKAN
19 #include "src/gpu/vk/GrVkCaps.h"
20 #endif
21 
22 namespace SkSL {
23 
// Generator magic number for the emitted module; presumably written into the SPIR-V header by
// code outside this chunk (TODO confirm).
// FIXME: we should probably register a magic number
static constexpr int32_t SKSL_MAGIC = 0x0;
25 
setupIntrinsics()26 void SPIRVCodeGenerator::setupIntrinsics() {
27 #define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \
28                                     GLSLstd450 ## x, GLSLstd450 ## x)
29 #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \
30                                                              GLSLstd450 ## ifFloat, \
31                                                              GLSLstd450 ## ifInt, \
32                                                              GLSLstd450 ## ifUInt, \
33                                                              SpvOpUndef)
34 #define ALL_SPIRV(x) std::make_tuple(kSPIRV_IntrinsicKind, SpvOp ## x, SpvOp ## x, SpvOp ## x, \
35                                                            SpvOp ## x)
36 #define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \
37                                    k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
38                                    k ## x ## _SpecialIntrinsic)
39     fIntrinsicMap[String("round")]         = ALL_GLSL(Round);
40     fIntrinsicMap[String("roundEven")]     = ALL_GLSL(RoundEven);
41     fIntrinsicMap[String("trunc")]         = ALL_GLSL(Trunc);
42     fIntrinsicMap[String("abs")]           = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
43     fIntrinsicMap[String("sign")]          = BY_TYPE_GLSL(FSign, SSign, SSign);
44     fIntrinsicMap[String("floor")]         = ALL_GLSL(Floor);
45     fIntrinsicMap[String("ceil")]          = ALL_GLSL(Ceil);
46     fIntrinsicMap[String("fract")]         = ALL_GLSL(Fract);
47     fIntrinsicMap[String("radians")]       = ALL_GLSL(Radians);
48     fIntrinsicMap[String("degrees")]       = ALL_GLSL(Degrees);
49     fIntrinsicMap[String("sin")]           = ALL_GLSL(Sin);
50     fIntrinsicMap[String("cos")]           = ALL_GLSL(Cos);
51     fIntrinsicMap[String("tan")]           = ALL_GLSL(Tan);
52     fIntrinsicMap[String("asin")]          = ALL_GLSL(Asin);
53     fIntrinsicMap[String("acos")]          = ALL_GLSL(Acos);
54     fIntrinsicMap[String("atan")]          = SPECIAL(Atan);
55     fIntrinsicMap[String("sinh")]          = ALL_GLSL(Sinh);
56     fIntrinsicMap[String("cosh")]          = ALL_GLSL(Cosh);
57     fIntrinsicMap[String("tanh")]          = ALL_GLSL(Tanh);
58     fIntrinsicMap[String("asinh")]         = ALL_GLSL(Asinh);
59     fIntrinsicMap[String("acosh")]         = ALL_GLSL(Acosh);
60     fIntrinsicMap[String("atanh")]         = ALL_GLSL(Atanh);
61     fIntrinsicMap[String("pow")]           = ALL_GLSL(Pow);
62     fIntrinsicMap[String("exp")]           = ALL_GLSL(Exp);
63     fIntrinsicMap[String("log")]           = ALL_GLSL(Log);
64     fIntrinsicMap[String("exp2")]          = ALL_GLSL(Exp2);
65     fIntrinsicMap[String("log2")]          = ALL_GLSL(Log2);
66     fIntrinsicMap[String("sqrt")]          = ALL_GLSL(Sqrt);
67     fIntrinsicMap[String("inverse")]       = ALL_GLSL(MatrixInverse);
68     fIntrinsicMap[String("transpose")]     = ALL_SPIRV(Transpose);
69     fIntrinsicMap[String("inversesqrt")]   = ALL_GLSL(InverseSqrt);
70     fIntrinsicMap[String("determinant")]   = ALL_GLSL(Determinant);
71     fIntrinsicMap[String("matrixInverse")] = ALL_GLSL(MatrixInverse);
72     fIntrinsicMap[String("mod")]           = SPECIAL(Mod);
73     fIntrinsicMap[String("min")]           = SPECIAL(Min);
74     fIntrinsicMap[String("max")]           = SPECIAL(Max);
75     fIntrinsicMap[String("clamp")]         = SPECIAL(Clamp);
76     fIntrinsicMap[String("saturate")]      = SPECIAL(Saturate);
77     fIntrinsicMap[String("dot")]           = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot,
78                                                              SpvOpUndef, SpvOpUndef, SpvOpUndef);
79     fIntrinsicMap[String("mix")]           = SPECIAL(Mix);
80     fIntrinsicMap[String("step")]          = ALL_GLSL(Step);
81     fIntrinsicMap[String("smoothstep")]    = ALL_GLSL(SmoothStep);
82     fIntrinsicMap[String("fma")]           = ALL_GLSL(Fma);
83     fIntrinsicMap[String("frexp")]         = ALL_GLSL(Frexp);
84     fIntrinsicMap[String("ldexp")]         = ALL_GLSL(Ldexp);
85 
86 #define PACK(type) fIntrinsicMap[String("pack" #type)] = ALL_GLSL(Pack ## type); \
87                    fIntrinsicMap[String("unpack" #type)] = ALL_GLSL(Unpack ## type)
88     PACK(Snorm4x8);
89     PACK(Unorm4x8);
90     PACK(Snorm2x16);
91     PACK(Unorm2x16);
92     PACK(Half2x16);
93     PACK(Double2x32);
94     fIntrinsicMap[String("length")]      = ALL_GLSL(Length);
95     fIntrinsicMap[String("distance")]    = ALL_GLSL(Distance);
96     fIntrinsicMap[String("cross")]       = ALL_GLSL(Cross);
97     fIntrinsicMap[String("normalize")]   = ALL_GLSL(Normalize);
98     fIntrinsicMap[String("faceForward")] = ALL_GLSL(FaceForward);
99     fIntrinsicMap[String("reflect")]     = ALL_GLSL(Reflect);
100     fIntrinsicMap[String("refract")]     = ALL_GLSL(Refract);
101     fIntrinsicMap[String("findLSB")]     = ALL_GLSL(FindILsb);
102     fIntrinsicMap[String("findMSB")]     = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
103     fIntrinsicMap[String("dFdx")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx,
104                                                            SpvOpUndef, SpvOpUndef, SpvOpUndef);
105     fIntrinsicMap[String("dFdy")]        = SPECIAL(DFdy);
106     fIntrinsicMap[String("fwidth")]      = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpFwidth,
107                                                            SpvOpUndef, SpvOpUndef, SpvOpUndef);
108     fIntrinsicMap[String("makeSampler2D")] = SPECIAL(SampledImage);
109 
110     fIntrinsicMap[String("sample")]      = SPECIAL(Texture);
111     fIntrinsicMap[String("subpassLoad")] = SPECIAL(SubpassLoad);
112 
113     fIntrinsicMap[String("any")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
114                                                                 SpvOpUndef, SpvOpUndef, SpvOpAny);
115     fIntrinsicMap[String("all")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
116                                                                 SpvOpUndef, SpvOpUndef, SpvOpAll);
117     fIntrinsicMap[String("equal")]            = std::make_tuple(kSPIRV_IntrinsicKind,
118                                                                 SpvOpFOrdEqual, SpvOpIEqual,
119                                                                 SpvOpIEqual, SpvOpLogicalEqual);
120     fIntrinsicMap[String("notEqual")]         = std::make_tuple(kSPIRV_IntrinsicKind,
121                                                                 SpvOpFOrdNotEqual, SpvOpINotEqual,
122                                                                 SpvOpINotEqual,
123                                                                 SpvOpLogicalNotEqual);
124     fIntrinsicMap[String("lessThan")]         = std::make_tuple(kSPIRV_IntrinsicKind,
125                                                                 SpvOpFOrdLessThan, SpvOpSLessThan,
126                                                                 SpvOpULessThan, SpvOpUndef);
127     fIntrinsicMap[String("lessThanEqual")]    = std::make_tuple(kSPIRV_IntrinsicKind,
128                                                                 SpvOpFOrdLessThanEqual,
129                                                                 SpvOpSLessThanEqual,
130                                                                 SpvOpULessThanEqual,
131                                                                 SpvOpUndef);
132     fIntrinsicMap[String("greaterThan")]      = std::make_tuple(kSPIRV_IntrinsicKind,
133                                                                 SpvOpFOrdGreaterThan,
134                                                                 SpvOpSGreaterThan,
135                                                                 SpvOpUGreaterThan,
136                                                                 SpvOpUndef);
137     fIntrinsicMap[String("greaterThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
138                                                                 SpvOpFOrdGreaterThanEqual,
139                                                                 SpvOpSGreaterThanEqual,
140                                                                 SpvOpUGreaterThanEqual,
141                                                                 SpvOpUndef);
142     fIntrinsicMap[String("EmitVertex")]       = ALL_SPIRV(EmitVertex);
143     fIntrinsicMap[String("EndPrimitive")]     = ALL_SPIRV(EndPrimitive);
144 // interpolateAt* not yet supported...
145 }
146 
writeWord(int32_t word,OutputStream & out)147 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
148     out.write((const char*) &word, sizeof(word));
149 }
150 
is_float(const Context & context,const Type & type)151 static bool is_float(const Context& context, const Type& type) {
152     if (type.columns() > 1) {
153         return is_float(context, type.componentType());
154     }
155     return type == *context.fFloat_Type || type == *context.fHalf_Type ||
156            type == *context.fDouble_Type;
157 }
158 
is_signed(const Context & context,const Type & type)159 static bool is_signed(const Context& context, const Type& type) {
160     if (type.kind() == Type::kVector_Kind) {
161         return is_signed(context, type.componentType());
162     }
163     return type == *context.fInt_Type || type == *context.fShort_Type ||
164            type == *context.fByte_Type;
165 }
166 
is_unsigned(const Context & context,const Type & type)167 static bool is_unsigned(const Context& context, const Type& type) {
168     if (type.kind() == Type::kVector_Kind) {
169         return is_unsigned(context, type.componentType());
170     }
171     return type == *context.fUInt_Type || type == *context.fUShort_Type ||
172            type == *context.fUByte_Type;
173 }
174 
is_bool(const Context & context,const Type & type)175 static bool is_bool(const Context& context, const Type& type) {
176     if (type.kind() == Type::kVector_Kind) {
177         return is_bool(context, type.componentType());
178     }
179     return type == *context.fBool_Type;
180 }
181 
is_out(const Variable & var)182 static bool is_out(const Variable& var) {
183     return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
184 }
185 
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)186 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
187     SkASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
188     SkASSERT(opCode != SpvOpUndef);
189     switch (opCode) {
190         case SpvOpReturn:      // fall through
191         case SpvOpReturnValue: // fall through
192         case SpvOpKill:        // fall through
193         case SpvOpBranch:      // fall through
194         case SpvOpBranchConditional:
195             SkASSERT(fCurrentBlock);
196             fCurrentBlock = 0;
197             break;
198         case SpvOpConstant:          // fall through
199         case SpvOpConstantTrue:      // fall through
200         case SpvOpConstantFalse:     // fall through
201         case SpvOpConstantComposite: // fall through
202         case SpvOpTypeVoid:          // fall through
203         case SpvOpTypeInt:           // fall through
204         case SpvOpTypeFloat:         // fall through
205         case SpvOpTypeBool:          // fall through
206         case SpvOpTypeVector:        // fall through
207         case SpvOpTypeMatrix:        // fall through
208         case SpvOpTypeArray:         // fall through
209         case SpvOpTypePointer:       // fall through
210         case SpvOpTypeFunction:      // fall through
211         case SpvOpTypeRuntimeArray:  // fall through
212         case SpvOpTypeStruct:        // fall through
213         case SpvOpTypeImage:         // fall through
214         case SpvOpTypeSampledImage:  // fall through
215         case SpvOpTypeSampler:       // fall through
216         case SpvOpVariable:          // fall through
217         case SpvOpFunction:          // fall through
218         case SpvOpFunctionParameter: // fall through
219         case SpvOpFunctionEnd:       // fall through
220         case SpvOpExecutionMode:     // fall through
221         case SpvOpMemoryModel:       // fall through
222         case SpvOpCapability:        // fall through
223         case SpvOpExtInstImport:     // fall through
224         case SpvOpEntryPoint:        // fall through
225         case SpvOpSource:            // fall through
226         case SpvOpSourceExtension:   // fall through
227         case SpvOpName:              // fall through
228         case SpvOpMemberName:        // fall through
229         case SpvOpDecorate:          // fall through
230         case SpvOpMemberDecorate:
231             break;
232         default:
233             SkASSERT(fCurrentBlock);
234     }
235     this->writeWord((length << 16) | opCode, out);
236 }
237 
writeLabel(SpvId label,OutputStream & out)238 void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) {
239     fCurrentBlock = label;
240     this->writeInstruction(SpvOpLabel, label, out);
241 }
242 
writeInstruction(SpvOp_ opCode,OutputStream & out)243 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
244     this->writeOpCode(opCode, 1, out);
245 }
246 
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)247 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
248     this->writeOpCode(opCode, 2, out);
249     this->writeWord(word1, out);
250 }
251 
writeString(const char * string,size_t length,OutputStream & out)252 void SPIRVCodeGenerator::writeString(const char* string, size_t length, OutputStream& out) {
253     out.write(string, length);
254     switch (length % 4) {
255         case 1:
256             out.write8(0);
257             // fall through
258         case 2:
259             out.write8(0);
260             // fall through
261         case 3:
262             out.write8(0);
263             break;
264         default:
265             this->writeWord(0, out);
266     }
267 }
268 
writeInstruction(SpvOp_ opCode,StringFragment string,OutputStream & out)269 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, StringFragment string, OutputStream& out) {
270     this->writeOpCode(opCode, 1 + (string.fLength + 4) / 4, out);
271     this->writeString(string.fChars, string.fLength, out);
272 }
273 
274 
writeInstruction(SpvOp_ opCode,int32_t word1,StringFragment string,OutputStream & out)275 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, StringFragment string,
276                                           OutputStream& out) {
277     this->writeOpCode(opCode, 2 + (string.fLength + 4) / 4, out);
278     this->writeWord(word1, out);
279     this->writeString(string.fChars, string.fLength, out);
280 }
281 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,StringFragment string,OutputStream & out)282 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
283                                           StringFragment string, OutputStream& out) {
284     this->writeOpCode(opCode, 3 + (string.fLength + 4) / 4, out);
285     this->writeWord(word1, out);
286     this->writeWord(word2, out);
287     this->writeString(string.fChars, string.fLength, out);
288 }
289 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)290 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
291                                           OutputStream& out) {
292     this->writeOpCode(opCode, 3, out);
293     this->writeWord(word1, out);
294     this->writeWord(word2, out);
295 }
296 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)297 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
298                                           int32_t word3, OutputStream& out) {
299     this->writeOpCode(opCode, 4, out);
300     this->writeWord(word1, out);
301     this->writeWord(word2, out);
302     this->writeWord(word3, out);
303 }
304 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)305 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
306                                           int32_t word3, int32_t word4, OutputStream& out) {
307     this->writeOpCode(opCode, 5, out);
308     this->writeWord(word1, out);
309     this->writeWord(word2, out);
310     this->writeWord(word3, out);
311     this->writeWord(word4, out);
312 }
313 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)314 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
315                                           int32_t word3, int32_t word4, int32_t word5,
316                                           OutputStream& out) {
317     this->writeOpCode(opCode, 6, out);
318     this->writeWord(word1, out);
319     this->writeWord(word2, out);
320     this->writeWord(word3, out);
321     this->writeWord(word4, out);
322     this->writeWord(word5, out);
323 }
324 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)325 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
326                                           int32_t word3, int32_t word4, int32_t word5,
327                                           int32_t word6, OutputStream& out) {
328     this->writeOpCode(opCode, 7, out);
329     this->writeWord(word1, out);
330     this->writeWord(word2, out);
331     this->writeWord(word3, out);
332     this->writeWord(word4, out);
333     this->writeWord(word5, out);
334     this->writeWord(word6, out);
335 }
336 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)337 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
338                                           int32_t word3, int32_t word4, int32_t word5,
339                                           int32_t word6, int32_t word7, OutputStream& out) {
340     this->writeOpCode(opCode, 8, out);
341     this->writeWord(word1, out);
342     this->writeWord(word2, out);
343     this->writeWord(word3, out);
344     this->writeWord(word4, out);
345     this->writeWord(word5, out);
346     this->writeWord(word6, out);
347     this->writeWord(word7, out);
348 }
349 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)350 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
351                                           int32_t word3, int32_t word4, int32_t word5,
352                                           int32_t word6, int32_t word7, int32_t word8,
353                                           OutputStream& out) {
354     this->writeOpCode(opCode, 9, out);
355     this->writeWord(word1, out);
356     this->writeWord(word2, out);
357     this->writeWord(word3, out);
358     this->writeWord(word4, out);
359     this->writeWord(word5, out);
360     this->writeWord(word6, out);
361     this->writeWord(word7, out);
362     this->writeWord(word8, out);
363 }
364 
writeCapabilities(OutputStream & out)365 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
366     for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
367         if (fCapabilities & bit) {
368             this->writeInstruction(SpvOpCapability, (SpvId) i, out);
369         }
370     }
371     if (fProgram.fKind == Program::kGeometry_Kind) {
372         this->writeInstruction(SpvOpCapability, SpvCapabilityGeometry, out);
373     }
374     else {
375         this->writeInstruction(SpvOpCapability, SpvCapabilityShader, out);
376     }
377 }
378 
nextId()379 SpvId SPIRVCodeGenerator::nextId() {
380     return fIdCount++;
381 }
382 
writeStruct(const Type & type,const MemoryLayout & memoryLayout,SpvId resultId)383 void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout,
384                                      SpvId resultId) {
385     this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer);
386     // go ahead and write all of the field types, so we don't inadvertently write them while we're
387     // in the middle of writing the struct instruction
388     std::vector<SpvId> types;
389     for (const auto& f : type.fields()) {
390         types.push_back(this->getType(*f.fType, memoryLayout));
391     }
392     this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
393     this->writeWord(resultId, fConstantBuffer);
394     for (SpvId id : types) {
395         this->writeWord(id, fConstantBuffer);
396     }
397     size_t offset = 0;
398     for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
399         const Type::Field& field = type.fields()[i];
400         size_t size = memoryLayout.size(*field.fType);
401         size_t alignment = memoryLayout.alignment(*field.fType);
402         const Layout& fieldLayout = field.fModifiers.fLayout;
403         if (fieldLayout.fOffset >= 0) {
404             if (fieldLayout.fOffset < (int) offset) {
405                 fErrors.error(type.fOffset,
406                               "offset of field '" + field.fName + "' must be at "
407                               "least " + to_string((int) offset));
408             }
409             if (fieldLayout.fOffset % alignment) {
410                 fErrors.error(type.fOffset,
411                               "offset of field '" + field.fName + "' must be a multiple"
412                               " of " + to_string((int) alignment));
413             }
414             offset = fieldLayout.fOffset;
415         } else {
416             size_t mod = offset % alignment;
417             if (mod) {
418                 offset += alignment - mod;
419             }
420         }
421         this->writeInstruction(SpvOpMemberName, resultId, i, field.fName, fNameBuffer);
422         this->writeLayout(fieldLayout, resultId, i);
423         if (field.fModifiers.fLayout.fBuiltin < 0) {
424             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
425                                    (SpvId) offset, fDecorationBuffer);
426         }
427         if (field.fType->kind() == Type::kMatrix_Kind) {
428             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
429                                    fDecorationBuffer);
430             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
431                                    (SpvId) memoryLayout.stride(*field.fType),
432                                    fDecorationBuffer);
433         }
434         if (!field.fType->highPrecision()) {
435             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i,
436                                    SpvDecorationRelaxedPrecision, fDecorationBuffer);
437         }
438         offset += size;
439         Type::Kind kind = field.fType->kind();
440         if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) {
441             offset += alignment - offset % alignment;
442         }
443     }
444 }
445 
// Normalizes `type` to the representation actually emitted in SPIR-V:
//  - any scalar reporting isFloat()/isSigned()/isUnsigned() becomes float/int/uint;
//  - vectors and matrices of half become float compounds, of short/byte become int compounds,
//    and of ushort/ubyte become uint compounds (same columns/rows);
//  - everything else is returned unchanged.
// NOTE(review): if Type::isFloat() is also true for double, doubles are narrowed to float here
// and getType's 64-bit float branch is unreachable — confirm this is intended.
Type SPIRVCodeGenerator::getActualType(const Type& type) {
    if (type.isFloat()) {
        return *fContext.fFloat_Type;
    }
    if (type.isSigned()) {
        return *fContext.fInt_Type;
    }
    if (type.isUnsigned()) {
        return *fContext.fUInt_Type;
    }
    const bool isCompound = type.kind() == Type::kMatrix_Kind ||
                            type.kind() == Type::kVector_Kind;
    if (!isCompound) {
        return type;
    }
    const Type& component = type.componentType();
    if (component == *fContext.fHalf_Type) {
        return fContext.fFloat_Type->toCompound(fContext, type.columns(), type.rows());
    }
    const bool narrowSigned = component == *fContext.fShort_Type ||
                              component == *fContext.fByte_Type;
    if (narrowSigned) {
        return fContext.fInt_Type->toCompound(fContext, type.columns(), type.rows());
    }
    const bool narrowUnsigned = component == *fContext.fUShort_Type ||
                                component == *fContext.fUByte_Type;
    if (narrowUnsigned) {
        return fContext.fUInt_Type->toCompound(fContext, type.columns(), type.rows());
    }
    return type;
}
471 
getType(const Type & type)472 SpvId SPIRVCodeGenerator::getType(const Type& type) {
473     return this->getType(type, fDefaultLayout);
474 }
475 
getType(const Type & rawType,const MemoryLayout & layout)476 SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layout) {
477     Type type = this->getActualType(rawType);
478     String key = type.name() + to_string((int) layout.fStd);
479     auto entry = fTypeMap.find(key);
480     if (entry == fTypeMap.end()) {
481         SpvId result = this->nextId();
482         switch (type.kind()) {
483             case Type::kScalar_Kind:
484                 if (type == *fContext.fBool_Type) {
485                     this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
486                 } else if (type == *fContext.fInt_Type || type == *fContext.fShort_Type ||
487                            type == *fContext.fIntLiteral_Type) {
488                     this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
489                 } else if (type == *fContext.fUInt_Type || type == *fContext.fUShort_Type) {
490                     this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
491                 } else if (type == *fContext.fFloat_Type || type == *fContext.fHalf_Type ||
492                            type == *fContext.fFloatLiteral_Type) {
493                     this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
494                 } else if (type == *fContext.fDouble_Type) {
495                     this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer);
496                 } else {
497                     SkASSERT(false);
498                 }
499                 break;
500             case Type::kVector_Kind:
501                 this->writeInstruction(SpvOpTypeVector, result,
502                                        this->getType(type.componentType(), layout),
503                                        type.columns(), fConstantBuffer);
504                 break;
505             case Type::kMatrix_Kind:
506                 this->writeInstruction(SpvOpTypeMatrix, result,
507                                        this->getType(index_type(fContext, type), layout),
508                                        type.columns(), fConstantBuffer);
509                 break;
510             case Type::kStruct_Kind:
511                 this->writeStruct(type, layout, result);
512                 break;
513             case Type::kArray_Kind: {
514                 if (type.columns() > 0) {
515                     IntLiteral count(fContext, -1, type.columns());
516                     this->writeInstruction(SpvOpTypeArray, result,
517                                            this->getType(type.componentType(), layout),
518                                            this->writeIntLiteral(count), fConstantBuffer);
519                     this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
520                                            (int32_t) layout.stride(type),
521                                            fDecorationBuffer);
522                 } else {
523                     SkASSERT(false); // we shouldn't have any runtime-sized arrays right now
524                     this->writeInstruction(SpvOpTypeRuntimeArray, result,
525                                            this->getType(type.componentType(), layout),
526                                            fConstantBuffer);
527                     this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
528                                            (int32_t) layout.stride(type),
529                                            fDecorationBuffer);
530                 }
531                 break;
532             }
533             case Type::kSampler_Kind: {
534                 SpvId image = result;
535                 if (SpvDimSubpassData != type.dimensions()) {
536                     image = this->getType(type.textureType(), layout);
537                 }
538                 if (SpvDimBuffer == type.dimensions()) {
539                     fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer);
540                 }
541                 if (SpvDimSubpassData != type.dimensions()) {
542                     this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
543                 }
544                 break;
545             }
546             case Type::kSeparateSampler_Kind: {
547                 this->writeInstruction(SpvOpTypeSampler, result, fConstantBuffer);
548                 break;
549             }
550             case Type::kTexture_Kind: {
551                 this->writeInstruction(SpvOpTypeImage, result,
552                                        this->getType(*fContext.fFloat_Type, layout),
553                                        type.dimensions(), type.isDepth(), type.isArrayed(),
554                                        type.isMultisampled(), type.isSampled() ? 1 : 2,
555                                        SpvImageFormatUnknown, fConstantBuffer);
556                 fImageTypeMap[key] = result;
557                 break;
558             }
559             default:
560                 if (type == *fContext.fVoid_Type) {
561                     this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
562                 } else {
563                     ABORT("invalid type: %s", type.description().c_str());
564                 }
565         }
566         fTypeMap[key] = result;
567         return result;
568     }
569     return entry->second;
570 }
571 
getImageType(const Type & type)572 SpvId SPIRVCodeGenerator::getImageType(const Type& type) {
573     SkASSERT(type.kind() == Type::kSampler_Kind);
574     this->getType(type);
575     String key = type.name() + to_string((int) fDefaultLayout.fStd);
576     SkASSERT(fImageTypeMap.find(key) != fImageTypeMap.end());
577     return fImageTypeMap[key];
578 }
579 
getFunctionType(const FunctionDeclaration & function)580 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
581     String key = function.fReturnType.description() + "(";
582     String separator;
583     for (size_t i = 0; i < function.fParameters.size(); i++) {
584         key += separator;
585         separator = ", ";
586         key += function.fParameters[i]->fType.description();
587     }
588     key += ")";
589     auto entry = fTypeMap.find(key);
590     if (entry == fTypeMap.end()) {
591         SpvId result = this->nextId();
592         int32_t length = 3 + (int32_t) function.fParameters.size();
593         SpvId returnType = this->getType(function.fReturnType);
594         std::vector<SpvId> parameterTypes;
595         for (size_t i = 0; i < function.fParameters.size(); i++) {
596             // glslang seems to treat all function arguments as pointers whether they need to be or
597             // not. I  was initially puzzled by this until I ran bizarre failures with certain
598             // patterns of function calls and control constructs, as exemplified by this minimal
599             // failure case:
600             //
601             // void sphere(float x) {
602             // }
603             //
604             // void map() {
605             //     sphere(1.0);
606             // }
607             //
608             // void main() {
609             //     for (int i = 0; i < 1; i++) {
610             //         map();
611             //     }
612             // }
613             //
614             // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
615             // crashes. Making it take a float* and storing the argument in a temporary variable,
616             // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
617             // the spec makes this make sense.
618 //            if (is_out(function->fParameters[i])) {
619                 parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType,
620                                                               SpvStorageClassFunction));
621 //            } else {
622 //                parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
623 //            }
624         }
625         this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
626         this->writeWord(result, fConstantBuffer);
627         this->writeWord(returnType, fConstantBuffer);
628         for (SpvId id : parameterTypes) {
629             this->writeWord(id, fConstantBuffer);
630         }
631         fTypeMap[key] = result;
632         return result;
633     }
634     return entry->second;
635 }
636 
getPointerType(const Type & type,SpvStorageClass_ storageClass)637 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
638     return this->getPointerType(type, fDefaultLayout, storageClass);
639 }
640 
getPointerType(const Type & rawType,const MemoryLayout & layout,SpvStorageClass_ storageClass)641 SpvId SPIRVCodeGenerator::getPointerType(const Type& rawType, const MemoryLayout& layout,
642                                          SpvStorageClass_ storageClass) {
643     Type type = this->getActualType(rawType);
644     String key = type.description() + "*" + to_string(layout.fStd) + to_string(storageClass);
645     auto entry = fTypeMap.find(key);
646     if (entry == fTypeMap.end()) {
647         SpvId result = this->nextId();
648         this->writeInstruction(SpvOpTypePointer, result, storageClass,
649                                this->getType(type), fConstantBuffer);
650         fTypeMap[key] = result;
651         return result;
652     }
653     return entry->second;
654 }
655 
writeExpression(const Expression & expr,OutputStream & out)656 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
657     switch (expr.fKind) {
658         case Expression::kBinary_Kind:
659             return this->writeBinaryExpression((BinaryExpression&) expr, out);
660         case Expression::kBoolLiteral_Kind:
661             return this->writeBoolLiteral((BoolLiteral&) expr);
662         case Expression::kConstructor_Kind:
663             return this->writeConstructor((Constructor&) expr, out);
664         case Expression::kIntLiteral_Kind:
665             return this->writeIntLiteral((IntLiteral&) expr);
666         case Expression::kFieldAccess_Kind:
667             return this->writeFieldAccess(((FieldAccess&) expr), out);
668         case Expression::kFloatLiteral_Kind:
669             return this->writeFloatLiteral(((FloatLiteral&) expr));
670         case Expression::kFunctionCall_Kind:
671             return this->writeFunctionCall((FunctionCall&) expr, out);
672         case Expression::kPrefix_Kind:
673             return this->writePrefixExpression((PrefixExpression&) expr, out);
674         case Expression::kPostfix_Kind:
675             return this->writePostfixExpression((PostfixExpression&) expr, out);
676         case Expression::kSwizzle_Kind:
677             return this->writeSwizzle((Swizzle&) expr, out);
678         case Expression::kVariableReference_Kind:
679             return this->writeVariableReference((VariableReference&) expr, out);
680         case Expression::kTernary_Kind:
681             return this->writeTernaryExpression((TernaryExpression&) expr, out);
682         case Expression::kIndex_Kind:
683             return this->writeIndexExpression((IndexExpression&) expr, out);
684         default:
685             ABORT("unsupported expression: %s", expr.description().c_str());
686     }
687     return -1;
688 }
689 
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)690 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
691     auto intrinsic = fIntrinsicMap.find(c.fFunction.fName);
692     SkASSERT(intrinsic != fIntrinsicMap.end());
693     int32_t intrinsicId;
694     if (c.fArguments.size() > 0) {
695         const Type& type = c.fArguments[0]->fType;
696         if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) {
697             intrinsicId = std::get<1>(intrinsic->second);
698         } else if (is_signed(fContext, type)) {
699             intrinsicId = std::get<2>(intrinsic->second);
700         } else if (is_unsigned(fContext, type)) {
701             intrinsicId = std::get<3>(intrinsic->second);
702         } else if (is_bool(fContext, type)) {
703             intrinsicId = std::get<4>(intrinsic->second);
704         } else {
705             intrinsicId = std::get<1>(intrinsic->second);
706         }
707     } else {
708         intrinsicId = std::get<1>(intrinsic->second);
709     }
710     switch (std::get<0>(intrinsic->second)) {
711         case kGLSL_STD_450_IntrinsicKind: {
712             SpvId result = this->nextId();
713             std::vector<SpvId> arguments;
714             for (size_t i = 0; i < c.fArguments.size(); i++) {
715                 if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
716                     arguments.push_back(this->getLValue(*c.fArguments[i], out)->getPointer());
717                 } else {
718                     arguments.push_back(this->writeExpression(*c.fArguments[i], out));
719                 }
720             }
721             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
722             this->writeWord(this->getType(c.fType), out);
723             this->writeWord(result, out);
724             this->writeWord(fGLSLExtendedInstructions, out);
725             this->writeWord(intrinsicId, out);
726             for (SpvId id : arguments) {
727                 this->writeWord(id, out);
728             }
729             return result;
730         }
731         case kSPIRV_IntrinsicKind: {
732             SpvId result = this->nextId();
733             std::vector<SpvId> arguments;
734             for (size_t i = 0; i < c.fArguments.size(); i++) {
735                 if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
736                     arguments.push_back(this->getLValue(*c.fArguments[i], out)->getPointer());
737                 } else {
738                     arguments.push_back(this->writeExpression(*c.fArguments[i], out));
739                 }
740             }
741             if (c.fType != *fContext.fVoid_Type) {
742                 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
743                 this->writeWord(this->getType(c.fType), out);
744                 this->writeWord(result, out);
745             } else {
746                 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
747             }
748             for (SpvId id : arguments) {
749                 this->writeWord(id, out);
750             }
751             return result;
752         }
753         case kSpecial_IntrinsicKind:
754             return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
755         default:
756             ABORT("unsupported intrinsic kind");
757     }
758 }
759 
vectorize(const std::vector<std::unique_ptr<Expression>> & args,OutputStream & out)760 std::vector<SpvId> SPIRVCodeGenerator::vectorize(
761                                                const std::vector<std::unique_ptr<Expression>>& args,
762                                                OutputStream& out) {
763     int vectorSize = 0;
764     for (const auto& a : args) {
765         if (a->fType.kind() == Type::kVector_Kind) {
766             if (vectorSize) {
767                 SkASSERT(a->fType.columns() == vectorSize);
768             }
769             else {
770                 vectorSize = a->fType.columns();
771             }
772         }
773     }
774     std::vector<SpvId> result;
775     for (const auto& a : args) {
776         SpvId raw = this->writeExpression(*a, out);
777         if (vectorSize && a->fType.kind() == Type::kScalar_Kind) {
778             SpvId vector = this->nextId();
779             this->writeOpCode(SpvOpCompositeConstruct, 3 + vectorSize, out);
780             this->writeWord(this->getType(a->fType.toCompound(fContext, vectorSize, 1)), out);
781             this->writeWord(vector, out);
782             for (int i = 0; i < vectorSize; i++) {
783                 this->writeWord(raw, out);
784             }
785             this->writePrecisionModifier(a->fType, vector);
786             result.push_back(vector);
787         } else {
788             result.push_back(raw);
789         }
790     }
791     return result;
792 }
793 
writeGLSLExtendedInstruction(const Type & type,SpvId id,SpvId floatInst,SpvId signedInst,SpvId unsignedInst,const std::vector<SpvId> & args,OutputStream & out)794 void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
795                                                       SpvId signedInst, SpvId unsignedInst,
796                                                       const std::vector<SpvId>& args,
797                                                       OutputStream& out) {
798     this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
799     this->writeWord(this->getType(type), out);
800     this->writeWord(id, out);
801     this->writeWord(fGLSLExtendedInstructions, out);
802 
803     if (is_float(fContext, type)) {
804         this->writeWord(floatInst, out);
805     } else if (is_signed(fContext, type)) {
806         this->writeWord(signedInst, out);
807     } else if (is_unsigned(fContext, type)) {
808         this->writeWord(unsignedInst, out);
809     } else {
810         SkASSERT(false);
811     }
812     for (SpvId a : args) {
813         this->writeWord(a, out);
814     }
815 }
816 
// Emits SPIR-V for intrinsics that need custom lowering, selected by `kind`, and returns the id
// of the value produced. `result` is allocated up front, and each case writes an instruction
// defining it. Exceptions:
//   - kDFdy with fFlipY set returns a separate id holding the negated derivative;
//   - kMod returns 0 (after asserting) when the operand is not float/signed/unsigned;
//   - there is no default case, so an unhandled kind returns `result` without defining it.
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)817 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
818                                                 OutputStream& out) {
819     SpvId result = this->nextId();
820     switch (kind) {
            // atan(y, x) lowers to GLSL.std.450 Atan2; single-argument atan(x) lowers to Atan.
821         case kAtan_SpecialIntrinsic: {
822             std::vector<SpvId> arguments;
823             for (size_t i = 0; i < c.fArguments.size(); i++) {
824                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
825             }
826             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
827             this->writeWord(this->getType(c.fType), out);
828             this->writeWord(result, out);
829             this->writeWord(fGLSLExtendedInstructions, out);
830             this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
831             for (SpvId id : arguments) {
832                 this->writeWord(id, out);
833             }
834             break;
835         }
            // Combines a separate image (argument 0) and sampler (argument 1) via OpSampledImage.
836         case kSampledImage_SpecialIntrinsic: {
837             SkASSERT(2 == c.fArguments.size());
838             SpvId img = this->writeExpression(*c.fArguments[0], out);
839             SpvId sampler = this->writeExpression(*c.fArguments[1], out);
840             this->writeInstruction(SpvOpSampledImage,
841                                    this->getType(c.fType),
842                                    result,
843                                    img,
844                                    sampler,
845                                    out);
846             break;
847         }
            // Reads the subpass input with OpImageRead at constant float2(0, 0) coordinates. An
            // optional second argument is passed through as the Sample image operand.
848         case kSubpassLoad_SpecialIntrinsic: {
849             SpvId img = this->writeExpression(*c.fArguments[0], out);
850             std::vector<std::unique_ptr<Expression>> args;
851             args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
852             args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
853             Constructor ctor(-1, *fContext.fFloat2_Type, std::move(args));
854             SpvId coords = this->writeConstantVector(ctor);
855             if (1 == c.fArguments.size()) {
856                 this->writeInstruction(SpvOpImageRead,
857                                        this->getType(c.fType),
858                                        result,
859                                        img,
860                                        coords,
861                                        out);
862             } else {
863                 SkASSERT(2 == c.fArguments.size());
864                 SpvId sample = this->writeExpression(*c.fArguments[1], out);
865                 this->writeInstruction(SpvOpImageRead,
866                                        this->getType(c.fType),
867                                        result,
868                                        img,
869                                        coords,
870                                        SpvImageOperandsSampleMask,
871                                        sample,
872                                        out);
873             }
874             break;
875         }
            // sample(sampler, coords [, bias]). The projective opcode is selected when the
            // coordinate has one more component than the image's dimensionality (1D/2D/3D).
876         case kTexture_SpecialIntrinsic: {
877             SpvOp_ op = SpvOpImageSampleImplicitLod;
878             switch (c.fArguments[0]->fType.dimensions()) {
879                 case SpvDim1D:
880                     if (c.fArguments[1]->fType == *fContext.fFloat2_Type) {
881                         op = SpvOpImageSampleProjImplicitLod;
882                     } else {
883                         SkASSERT(c.fArguments[1]->fType == *fContext.fFloat_Type);
884                     }
885                     break;
886                 case SpvDim2D:
887                     if (c.fArguments[1]->fType == *fContext.fFloat3_Type) {
888                         op = SpvOpImageSampleProjImplicitLod;
889                     } else {
890                         SkASSERT(c.fArguments[1]->fType == *fContext.fFloat2_Type);
891                     }
892                     break;
893                 case SpvDim3D:
894                     if (c.fArguments[1]->fType == *fContext.fFloat4_Type) {
895                         op = SpvOpImageSampleProjImplicitLod;
896                     } else {
897                         SkASSERT(c.fArguments[1]->fType == *fContext.fFloat3_Type);
898                     }
899                     break;
900                 case SpvDimCube:   // fall through
901                 case SpvDimRect:   // fall through
902                 case SpvDimBuffer: // fall through
903                 case SpvDimSubpassData:
904                     break;
905             }
906             SpvId type = this->getType(c.fType);
907             SpvId sampler = this->writeExpression(*c.fArguments[0], out);
908             SpvId uv = this->writeExpression(*c.fArguments[1], out);
                // A third argument is an explicit LOD bias. Without one, fSharpenTextures
                // applies a fixed bias of -0.5.
909             if (c.fArguments.size() == 3) {
910                 this->writeInstruction(op, type, result, sampler, uv,
911                                        SpvImageOperandsBiasMask,
912                                        this->writeExpression(*c.fArguments[2], out),
913                                        out);
914             } else {
915                 SkASSERT(c.fArguments.size() == 2);
916                 if (fProgram.fSettings.fSharpenTextures) {
917                     FloatLiteral lodBias(fContext, -1, -0.5);
918                     this->writeInstruction(op, type, result, sampler, uv,
919                                            SpvImageOperandsBiasMask,
920                                            this->writeFloatLiteral(lodBias),
921                                            out);
922                 } else {
923                     this->writeInstruction(op, type, result, sampler, uv,
924                                            out);
925                 }
926             }
927             break;
928         }
            // mod(x, y): vectorize() splats a scalar operand when the other is a vector. The
            // opcode (FMod/SMod/UMod) and the result type come from the first argument's type.
929         case kMod_SpecialIntrinsic: {
930             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
931             SkASSERT(args.size() == 2);
932             const Type& operandType = c.fArguments[0]->fType;
933             SpvOp_ op;
934             if (is_float(fContext, operandType)) {
935                 op = SpvOpFMod;
936             } else if (is_signed(fContext, operandType)) {
937                 op = SpvOpSMod;
938             } else if (is_unsigned(fContext, operandType)) {
939                 op = SpvOpUMod;
940             } else {
941                 SkASSERT(false);
942                 return 0;
943             }
944             this->writeOpCode(op, 5, out);
945             this->writeWord(this->getType(operandType), out);
946             this->writeWord(result, out);
947             this->writeWord(args[0], out);
948             this->writeWord(args[1], out);
949             break;
950         }
            // dFdy via OpDPdy. With fFlipY set, the Y axis is flipped, so the derivative is
            // negated and that negated id is returned instead of `result`.
951         case kDFdy_SpecialIntrinsic: {
952             SpvId fn = this->writeExpression(*c.fArguments[0], out);
953             this->writeOpCode(SpvOpDPdy, 4, out);
954             this->writeWord(this->getType(c.fType), out);
955             this->writeWord(result, out);
956             this->writeWord(fn, out);
957             if (fProgram.fSettings.fFlipY) {
958                 // Flipping Y also negates the Y derivatives.
959                 SpvId flipped = this->nextId();
960                 this->writeInstruction(SpvOpFNegate, this->getType(c.fType), flipped, result, out);
961                 this->writePrecisionModifier(c.fType, flipped);
962                 return flipped;
963             }
964             break;
965         }
            // clamp/max/min: vectorize the operands, then emit the float/signed/unsigned
            // GLSL.std.450 variant chosen from the call's result type.
966         case kClamp_SpecialIntrinsic: {
967             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
968             SkASSERT(args.size() == 3);
969             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FClamp, GLSLstd450SClamp,
970                                                GLSLstd450UClamp, args, out);
971             break;
972         }
973         case kMax_SpecialIntrinsic: {
974             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
975             SkASSERT(args.size() == 2);
976             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMax, GLSLstd450SMax,
977                                                GLSLstd450UMax, args, out);
978             break;
979         }
980         case kMin_SpecialIntrinsic: {
981             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
982             SkASSERT(args.size() == 2);
983             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMin, GLSLstd450SMin,
984                                                GLSLstd450UMin, args, out);
985             break;
986         }
            // mix: only a float variant (FMix) is supplied. The signed/unsigned slots pass
            // SpvOpUndef.
            // NOTE(review): a non-float result type would write SpvOpUndef as the extended
            // instruction number; confirm SkSL never routes integer/bool mix here.
987         case kMix_SpecialIntrinsic: {
988             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
989             SkASSERT(args.size() == 3);
990             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMix, SpvOpUndef,
991                                                SpvOpUndef, args, out);
992             break;
993         }
            // saturate(x) is lowered to clamp(x, 0, 1) with float literal bounds. vectorize()
            // splats the bounds when x is a vector.
994         case kSaturate_SpecialIntrinsic: {
995             SkASSERT(c.fArguments.size() == 1);
996             std::vector<std::unique_ptr<Expression>> finalArgs;
997             finalArgs.push_back(c.fArguments[0]->clone());
998             finalArgs.emplace_back(new FloatLiteral(fContext, -1, 0));
999             finalArgs.emplace_back(new FloatLiteral(fContext, -1, 1));
1000             std::vector<SpvId> spvArgs = this->vectorize(finalArgs, out);
1001             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FClamp, GLSLstd450SClamp,
1002                                                GLSLstd450UClamp, spvArgs, out);
1003             break;
1004         }
1005     }
1006     return result;
1007 }
1008 
writeFunctionCall(const FunctionCall & c,OutputStream & out)1009 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
1010     const auto& entry = fFunctionMap.find(&c.fFunction);
1011     if (entry == fFunctionMap.end()) {
1012         return this->writeIntrinsicCall(c, out);
1013     }
1014     // stores (variable, type, lvalue) pairs to extract and save after the function call is complete
1015     std::vector<std::tuple<SpvId, const Type*, std::unique_ptr<LValue>>> lvalues;
1016     std::vector<SpvId> arguments;
1017     for (size_t i = 0; i < c.fArguments.size(); i++) {
1018         // id of temporary variable that we will use to hold this argument, or 0 if it is being
1019         // passed directly
1020         SpvId tmpVar;
1021         // if we need a temporary var to store this argument, this is the value to store in the var
1022         SpvId tmpValueId;
1023         if (is_out(*c.fFunction.fParameters[i])) {
1024             std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out);
1025             SpvId ptr = lv->getPointer();
1026             if (ptr) {
1027                 arguments.push_back(ptr);
1028                 continue;
1029             } else {
1030                 // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to
1031                 // copy it into a temp, call the function, read the value out of the temp, and then
1032                 // update the lvalue.
1033                 tmpValueId = lv->load(out);
1034                 tmpVar = this->nextId();
1035                 lvalues.push_back(std::make_tuple(tmpVar, &c.fArguments[i]->fType, std::move(lv)));
1036             }
1037         } else {
1038             // see getFunctionType for an explanation of why we're always using pointer parameters
1039             tmpValueId = this->writeExpression(*c.fArguments[i], out);
1040             tmpVar = this->nextId();
1041         }
1042         this->writeInstruction(SpvOpVariable,
1043                                this->getPointerType(c.fArguments[i]->fType,
1044                                                     SpvStorageClassFunction),
1045                                tmpVar,
1046                                SpvStorageClassFunction,
1047                                fVariableBuffer);
1048         this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
1049         arguments.push_back(tmpVar);
1050     }
1051     SpvId result = this->nextId();
1052     this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out);
1053     this->writeWord(this->getType(c.fType), out);
1054     this->writeWord(result, out);
1055     this->writeWord(entry->second, out);
1056     for (SpvId id : arguments) {
1057         this->writeWord(id, out);
1058     }
1059     // now that the call is complete, we may need to update some lvalues with the new values of out
1060     // arguments
1061     for (const auto& tuple : lvalues) {
1062         SpvId load = this->nextId();
1063         this->writeInstruction(SpvOpLoad, getType(*std::get<1>(tuple)), load, std::get<0>(tuple),
1064                                out);
1065         this->writePrecisionModifier(*std::get<1>(tuple), load);
1066         std::get<2>(tuple)->store(load, out);
1067     }
1068     return result;
1069 }
1070 
writeConstantVector(const Constructor & c)1071 SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) {
1072     SkASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant());
1073     SpvId result = this->nextId();
1074     std::vector<SpvId> arguments;
1075     for (size_t i = 0; i < c.fArguments.size(); i++) {
1076         arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer));
1077     }
1078     SpvId type = this->getType(c.fType);
1079     if (c.fArguments.size() == 1) {
1080         // with a single argument, a vector will have all of its entries equal to the argument
1081         this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer);
1082         this->writeWord(type, fConstantBuffer);
1083         this->writeWord(result, fConstantBuffer);
1084         for (int i = 0; i < c.fType.columns(); i++) {
1085             this->writeWord(arguments[0], fConstantBuffer);
1086         }
1087     } else {
1088         this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(),
1089                           fConstantBuffer);
1090         this->writeWord(type, fConstantBuffer);
1091         this->writeWord(result, fConstantBuffer);
1092         for (SpvId id : arguments) {
1093             this->writeWord(id, fConstantBuffer);
1094         }
1095     }
1096     return result;
1097 }
1098 
writeFloatConstructor(const Constructor & c,OutputStream & out)1099 SpvId SPIRVCodeGenerator::writeFloatConstructor(const Constructor& c, OutputStream& out) {
1100     SkASSERT(c.fType.isFloat());
1101     SkASSERT(c.fArguments.size() == 1);
1102     SkASSERT(c.fArguments[0]->fType.isNumber());
1103     SpvId result = this->nextId();
1104     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1105     if (c.fArguments[0]->fType.isSigned()) {
1106         this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter,
1107                                out);
1108     } else {
1109         SkASSERT(c.fArguments[0]->fType.isUnsigned());
1110         this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter,
1111                                out);
1112     }
1113     return result;
1114 }
1115 
writeIntConstructor(const Constructor & c,OutputStream & out)1116 SpvId SPIRVCodeGenerator::writeIntConstructor(const Constructor& c, OutputStream& out) {
1117     SkASSERT(c.fType.isSigned());
1118     SkASSERT(c.fArguments.size() == 1);
1119     SkASSERT(c.fArguments[0]->fType.isNumber());
1120     SpvId result = this->nextId();
1121     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1122     if (c.fArguments[0]->fType.isFloat()) {
1123         this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter,
1124                                out);
1125     }
1126     else {
1127         SkASSERT(c.fArguments[0]->fType.isUnsigned());
1128         this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1129                                out);
1130     }
1131     return result;
1132 }
1133 
writeUIntConstructor(const Constructor & c,OutputStream & out)1134 SpvId SPIRVCodeGenerator::writeUIntConstructor(const Constructor& c, OutputStream& out) {
1135     SkASSERT(c.fType.isUnsigned());
1136     SkASSERT(c.fArguments.size() == 1);
1137     SkASSERT(c.fArguments[0]->fType.isNumber());
1138     SpvId result = this->nextId();
1139     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1140     if (c.fArguments[0]->fType.isFloat()) {
1141         this->writeInstruction(SpvOpConvertFToU, this->getType(c.fType), result, parameter,
1142                                out);
1143     } else {
1144         SkASSERT(c.fArguments[0]->fType.isSigned());
1145         this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1146                                out);
1147     }
1148     return result;
1149 }
1150 
writeUniformScaleMatrix(SpvId id,SpvId diagonal,const Type & type,OutputStream & out)1151 void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
1152                                                  OutputStream& out) {
1153     FloatLiteral zero(fContext, -1, 0);
1154     SpvId zeroId = this->writeFloatLiteral(zero);
1155     std::vector<SpvId> columnIds;
1156     for (int column = 0; column < type.columns(); column++) {
1157         this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
1158                           out);
1159         this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)),
1160                         out);
1161         SpvId columnId = this->nextId();
1162         this->writeWord(columnId, out);
1163         columnIds.push_back(columnId);
1164         for (int row = 0; row < type.columns(); row++) {
1165             this->writeWord(row == column ? diagonal : zeroId, out);
1166         }
1167         this->writePrecisionModifier(type, columnId);
1168     }
1169     this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
1170                       out);
1171     this->writeWord(this->getType(type), out);
1172     this->writeWord(id, out);
1173     for (SpvId id : columnIds) {
1174         this->writeWord(id, out);
1175     }
1176     this->writePrecisionModifier(type, id);
1177 }
1178 
writeMatrixCopy(SpvId id,SpvId src,const Type & srcType,const Type & dstType,OutputStream & out)1179 void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType,
1180                                          const Type& dstType, OutputStream& out) {
1181     SkASSERT(srcType.kind() == Type::kMatrix_Kind);
1182     SkASSERT(dstType.kind() == Type::kMatrix_Kind);
1183     SkASSERT(srcType.componentType() == dstType.componentType());
1184     SpvId srcColumnType = this->getType(srcType.componentType().toCompound(fContext,
1185                                                                            srcType.rows(),
1186                                                                            1));
1187     SpvId dstColumnType = this->getType(dstType.componentType().toCompound(fContext,
1188                                                                            dstType.rows(),
1189                                                                            1));
1190     SpvId zeroId;
1191     if (dstType.componentType() == *fContext.fFloat_Type) {
1192         FloatLiteral zero(fContext, -1, 0.0);
1193         zeroId = this->writeFloatLiteral(zero);
1194     } else if (dstType.componentType() == *fContext.fInt_Type) {
1195         IntLiteral zero(fContext, -1, 0);
1196         zeroId = this->writeIntLiteral(zero);
1197     } else {
1198         ABORT("unsupported matrix component type");
1199     }
1200     SpvId zeroColumn = 0;
1201     SpvId columns[4];
1202     for (int i = 0; i < dstType.columns(); i++) {
1203         if (i < srcType.columns()) {
1204             // we're still inside the src matrix, copy the column
1205             SpvId srcColumn = this->nextId();
1206             this->writeInstruction(SpvOpCompositeExtract, srcColumnType, srcColumn, src, i, out);
1207             this->writePrecisionModifier(dstType, srcColumn);
1208             SpvId dstColumn;
1209             if (srcType.rows() == dstType.rows()) {
1210                 // columns are equal size, don't need to do anything
1211                 dstColumn = srcColumn;
1212             }
1213             else if (dstType.rows() > srcType.rows()) {
1214                 // dst column is bigger, need to zero-pad it
1215                 dstColumn = this->nextId();
1216                 int delta = dstType.rows() - srcType.rows();
1217                 this->writeOpCode(SpvOpCompositeConstruct, 4 + delta, out);
1218                 this->writeWord(dstColumnType, out);
1219                 this->writeWord(dstColumn, out);
1220                 this->writeWord(srcColumn, out);
1221                 for (int i = 0; i < delta; ++i) {
1222                     this->writeWord(zeroId, out);
1223                 }
1224                 this->writePrecisionModifier(dstType, dstColumn);
1225             }
1226             else {
1227                 // dst column is smaller, need to swizzle the src column
1228                 dstColumn = this->nextId();
1229                 int count = dstType.rows();
1230                 this->writeOpCode(SpvOpVectorShuffle, 5 + count, out);
1231                 this->writeWord(dstColumnType, out);
1232                 this->writeWord(dstColumn, out);
1233                 this->writeWord(srcColumn, out);
1234                 this->writeWord(srcColumn, out);
1235                 for (int i = 0; i < count; i++) {
1236                     this->writeWord(i, out);
1237                 }
1238                 this->writePrecisionModifier(dstType, dstColumn);
1239             }
1240             columns[i] = dstColumn;
1241         } else {
1242             // we're past the end of the src matrix, need a vector of zeroes
1243             if (!zeroColumn) {
1244                 zeroColumn = this->nextId();
1245                 this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.rows(), out);
1246                 this->writeWord(dstColumnType, out);
1247                 this->writeWord(zeroColumn, out);
1248                 for (int i = 0; i < dstType.rows(); ++i) {
1249                     this->writeWord(zeroId, out);
1250                 }
1251                 this->writePrecisionModifier(dstType, zeroColumn);
1252             }
1253             columns[i] = zeroColumn;
1254         }
1255     }
1256     this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.columns(), out);
1257     this->writeWord(this->getType(dstType), out);
1258     this->writeWord(id, out);
1259     for (int i = 0; i < dstType.columns(); i++) {
1260         this->writeWord(columns[i], out);
1261     }
1262     this->writePrecisionModifier(dstType, id);
1263 }
1264 
addColumnEntry(SpvId columnType,Precision precision,std::vector<SpvId> * currentColumn,std::vector<SpvId> * columnIds,int * currentCount,int rows,SpvId entry,OutputStream & out)1265 void SPIRVCodeGenerator::addColumnEntry(SpvId columnType, Precision precision,
1266                                         std::vector<SpvId>* currentColumn,
1267                                         std::vector<SpvId>* columnIds,
1268                                         int* currentCount, int rows, SpvId entry,
1269                                         OutputStream& out) {
1270     SkASSERT(*currentCount < rows);
1271     ++(*currentCount);
1272     currentColumn->push_back(entry);
1273     if (*currentCount == rows) {
1274         *currentCount = 0;
1275         this->writeOpCode(SpvOpCompositeConstruct, 3 + currentColumn->size(), out);
1276         this->writeWord(columnType, out);
1277         SpvId columnId = this->nextId();
1278         this->writeWord(columnId, out);
1279         columnIds->push_back(columnId);
1280         for (SpvId id : *currentColumn) {
1281             this->writeWord(id, out);
1282         }
1283         currentColumn->clear();
1284         this->writePrecisionModifier(precision, columnId);
1285     }
1286 }
1287 
writeMatrixConstructor(const Constructor & c,OutputStream & out)1288 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, OutputStream& out) {
1289     SkASSERT(c.fType.kind() == Type::kMatrix_Kind);
1290     // go ahead and write the arguments so we don't try to write new instructions in the middle of
1291     // an instruction
1292     std::vector<SpvId> arguments;
1293     for (size_t i = 0; i < c.fArguments.size(); i++) {
1294         arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1295     }
1296     SpvId result = this->nextId();
1297     int rows = c.fType.rows();
1298     int columns = c.fType.columns();
1299     if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1300         this->writeUniformScaleMatrix(result, arguments[0], c.fType, out);
1301     } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) {
1302         this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out);
1303     } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kVector_Kind) {
1304         SkASSERT(c.fType.rows() == 2 && c.fType.columns() == 2);
1305         SkASSERT(c.fArguments[0]->fType.columns() == 4);
1306         SpvId componentType = this->getType(c.fType.componentType());
1307         SpvId v[4];
1308         for (int i = 0; i < 4; ++i) {
1309             v[i] = this->nextId();
1310             this->writeInstruction(SpvOpCompositeExtract, componentType, v[i], arguments[0], i, out);
1311         }
1312         SpvId columnType = this->getType(c.fType.componentType().toCompound(fContext, 2, 1));
1313         SpvId column1 = this->nextId();
1314         this->writeInstruction(SpvOpCompositeConstruct, columnType, column1, v[0], v[1], out);
1315         SpvId column2 = this->nextId();
1316         this->writeInstruction(SpvOpCompositeConstruct, columnType, column2, v[2], v[3], out);
1317         this->writeInstruction(SpvOpCompositeConstruct, this->getType(c.fType), result, column1,
1318                                column2, out);
1319     } else {
1320         SpvId columnType = this->getType(c.fType.componentType().toCompound(fContext, rows, 1));
1321         std::vector<SpvId> columnIds;
1322         // ids of vectors and scalars we have written to the current column so far
1323         std::vector<SpvId> currentColumn;
1324         // the total number of scalars represented by currentColumn's entries
1325         int currentCount = 0;
1326         Precision precision = c.fType.highPrecision() ? Precision::kHigh : Precision::kLow;
1327         for (size_t i = 0; i < arguments.size(); i++) {
1328             if (currentCount == 0 && c.fArguments[i]->fType.kind() == Type::kVector_Kind &&
1329                     c.fArguments[i]->fType.columns() == c.fType.rows()) {
1330                 // this is a complete column by itself
1331                 columnIds.push_back(arguments[i]);
1332             } else {
1333                 if (c.fArguments[i]->fType.columns() == 1) {
1334                     this->addColumnEntry(columnType, precision, &currentColumn, &columnIds,
1335                                          &currentCount, rows, arguments[i], out);
1336                 } else {
1337                     SpvId componentType = this->getType(c.fArguments[i]->fType.componentType());
1338                     for (int j = 0; j < c.fArguments[i]->fType.columns(); ++j) {
1339                         SpvId swizzle = this->nextId();
1340                         this->writeInstruction(SpvOpCompositeExtract, componentType, swizzle,
1341                                                arguments[i], j, out);
1342                         this->addColumnEntry(columnType, precision, &currentColumn, &columnIds,
1343                                              &currentCount, rows, swizzle, out);
1344                     }
1345                 }
1346             }
1347         }
1348         SkASSERT(columnIds.size() == (size_t) columns);
1349         this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
1350         this->writeWord(this->getType(c.fType), out);
1351         this->writeWord(result, out);
1352         for (SpvId id : columnIds) {
1353             this->writeWord(id, out);
1354         }
1355     }
1356     this->writePrecisionModifier(c.fType, result);
1357     return result;
1358 }
1359 
writeVectorConstructor(const Constructor & c,OutputStream & out)1360 SpvId SPIRVCodeGenerator::writeVectorConstructor(const Constructor& c, OutputStream& out) {
1361     SkASSERT(c.fType.kind() == Type::kVector_Kind);
1362     if (c.isConstant()) {
1363         return this->writeConstantVector(c);
1364     }
1365     // go ahead and write the arguments so we don't try to write new instructions in the middle of
1366     // an instruction
1367     std::vector<SpvId> arguments;
1368     for (size_t i = 0; i < c.fArguments.size(); i++) {
1369         if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) {
1370             // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to
1371             // extract the components and convert them in that case manually. On top of that,
1372             // as of this writing there's a bug in the Intel Vulkan driver where OpCreateComposite
1373             // doesn't handle vector arguments at all, so we always extract vector components and
1374             // pass them into OpCreateComposite individually.
1375             SpvId vec = this->writeExpression(*c.fArguments[i], out);
1376             SpvOp_ op = SpvOpUndef;
1377             const Type& src = c.fArguments[i]->fType.componentType();
1378             const Type& dst = c.fType.componentType();
1379             if (dst == *fContext.fFloat_Type || dst == *fContext.fHalf_Type) {
1380                 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1381                     if (c.fArguments.size() == 1) {
1382                         return vec;
1383                     }
1384                 } else if (src == *fContext.fInt_Type ||
1385                            src == *fContext.fShort_Type ||
1386                            src == *fContext.fByte_Type) {
1387                     op = SpvOpConvertSToF;
1388                 } else if (src == *fContext.fUInt_Type ||
1389                            src == *fContext.fUShort_Type ||
1390                            src == *fContext.fUByte_Type) {
1391                     op = SpvOpConvertUToF;
1392                 } else {
1393                     SkASSERT(false);
1394                 }
1395             } else if (dst == *fContext.fInt_Type ||
1396                        dst == *fContext.fShort_Type ||
1397                        dst == *fContext.fByte_Type) {
1398                 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1399                     op = SpvOpConvertFToS;
1400                 } else if (src == *fContext.fInt_Type ||
1401                            src == *fContext.fShort_Type ||
1402                            src == *fContext.fByte_Type) {
1403                     if (c.fArguments.size() == 1) {
1404                         return vec;
1405                     }
1406                 } else if (src == *fContext.fUInt_Type ||
1407                            src == *fContext.fUShort_Type ||
1408                            src == *fContext.fUByte_Type) {
1409                     op = SpvOpBitcast;
1410                 } else {
1411                     SkASSERT(false);
1412                 }
1413             } else if (dst == *fContext.fUInt_Type ||
1414                        dst == *fContext.fUShort_Type ||
1415                        dst == *fContext.fUByte_Type) {
1416                 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1417                     op = SpvOpConvertFToS;
1418                 } else if (src == *fContext.fInt_Type ||
1419                            src == *fContext.fShort_Type ||
1420                            src == *fContext.fByte_Type) {
1421                     op = SpvOpBitcast;
1422                 } else if (src == *fContext.fUInt_Type ||
1423                            src == *fContext.fUShort_Type ||
1424                            src == *fContext.fUByte_Type) {
1425                     if (c.fArguments.size() == 1) {
1426                         return vec;
1427                     }
1428                 } else {
1429                     SkASSERT(false);
1430                 }
1431             }
1432             for (int j = 0; j < c.fArguments[i]->fType.columns(); j++) {
1433                 SpvId swizzle = this->nextId();
1434                 this->writeInstruction(SpvOpCompositeExtract, this->getType(src), swizzle, vec, j,
1435                                        out);
1436                 if (op != SpvOpUndef) {
1437                     SpvId cast = this->nextId();
1438                     this->writeInstruction(op, this->getType(dst), cast, swizzle, out);
1439                     arguments.push_back(cast);
1440                 } else {
1441                     arguments.push_back(swizzle);
1442                 }
1443             }
1444         } else {
1445             arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1446         }
1447     }
1448     SpvId result = this->nextId();
1449     if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1450         this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out);
1451         this->writeWord(this->getType(c.fType), out);
1452         this->writeWord(result, out);
1453         for (int i = 0; i < c.fType.columns(); i++) {
1454             this->writeWord(arguments[0], out);
1455         }
1456     } else {
1457         SkASSERT(arguments.size() > 1);
1458         this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out);
1459         this->writeWord(this->getType(c.fType), out);
1460         this->writeWord(result, out);
1461         for (SpvId id : arguments) {
1462             this->writeWord(id, out);
1463         }
1464     }
1465     return result;
1466 }
1467 
writeArrayConstructor(const Constructor & c,OutputStream & out)1468 SpvId SPIRVCodeGenerator::writeArrayConstructor(const Constructor& c, OutputStream& out) {
1469     SkASSERT(c.fType.kind() == Type::kArray_Kind);
1470     // go ahead and write the arguments so we don't try to write new instructions in the middle of
1471     // an instruction
1472     std::vector<SpvId> arguments;
1473     for (size_t i = 0; i < c.fArguments.size(); i++) {
1474         arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1475     }
1476     SpvId result = this->nextId();
1477     this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out);
1478     this->writeWord(this->getType(c.fType), out);
1479     this->writeWord(result, out);
1480     for (SpvId id : arguments) {
1481         this->writeWord(id, out);
1482     }
1483     return result;
1484 }
1485 
writeConstructor(const Constructor & c,OutputStream & out)1486 SpvId SPIRVCodeGenerator::writeConstructor(const Constructor& c, OutputStream& out) {
1487     if (c.fArguments.size() == 1 &&
1488         this->getActualType(c.fType) == this->getActualType(c.fArguments[0]->fType)) {
1489         return this->writeExpression(*c.fArguments[0], out);
1490     }
1491     if (c.fType == *fContext.fFloat_Type || c.fType == *fContext.fHalf_Type) {
1492         return this->writeFloatConstructor(c, out);
1493     } else if (c.fType == *fContext.fInt_Type ||
1494                c.fType == *fContext.fShort_Type ||
1495                c.fType == *fContext.fByte_Type) {
1496         return this->writeIntConstructor(c, out);
1497     } else if (c.fType == *fContext.fUInt_Type ||
1498                c.fType == *fContext.fUShort_Type ||
1499                c.fType == *fContext.fUByte_Type) {
1500         return this->writeUIntConstructor(c, out);
1501     }
1502     switch (c.fType.kind()) {
1503         case Type::kVector_Kind:
1504             return this->writeVectorConstructor(c, out);
1505         case Type::kMatrix_Kind:
1506             return this->writeMatrixConstructor(c, out);
1507         case Type::kArray_Kind:
1508             return this->writeArrayConstructor(c, out);
1509         default:
1510             ABORT("unsupported constructor: %s", c.description().c_str());
1511     }
1512 }
1513 
get_storage_class(const Modifiers & modifiers)1514 SpvStorageClass_ get_storage_class(const Modifiers& modifiers) {
1515     if (modifiers.fFlags & Modifiers::kIn_Flag) {
1516         SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1517         return SpvStorageClassInput;
1518     } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
1519         SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1520         return SpvStorageClassOutput;
1521     } else if (modifiers.fFlags & Modifiers::kUniform_Flag) {
1522         if (modifiers.fLayout.fFlags & Layout::kPushConstant_Flag) {
1523             return SpvStorageClassPushConstant;
1524         }
1525         return SpvStorageClassUniform;
1526     } else {
1527         return SpvStorageClassFunction;
1528     }
1529 }
1530 
get_storage_class(const Expression & expr)1531 SpvStorageClass_ get_storage_class(const Expression& expr) {
1532     switch (expr.fKind) {
1533         case Expression::kVariableReference_Kind: {
1534             const Variable& var = ((VariableReference&) expr).fVariable;
1535             if (var.fStorage != Variable::kGlobal_Storage) {
1536                 return SpvStorageClassFunction;
1537             }
1538             SpvStorageClass_ result = get_storage_class(var.fModifiers);
1539             if (result == SpvStorageClassFunction) {
1540                 result = SpvStorageClassPrivate;
1541             }
1542             return result;
1543         }
1544         case Expression::kFieldAccess_Kind:
1545             return get_storage_class(*((FieldAccess&) expr).fBase);
1546         case Expression::kIndex_Kind:
1547             return get_storage_class(*((IndexExpression&) expr).fBase);
1548         default:
1549             return SpvStorageClassFunction;
1550     }
1551 }
1552 
getAccessChain(const Expression & expr,OutputStream & out)1553 std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
1554     std::vector<SpvId> chain;
1555     switch (expr.fKind) {
1556         case Expression::kIndex_Kind: {
1557             IndexExpression& indexExpr = (IndexExpression&) expr;
1558             chain = this->getAccessChain(*indexExpr.fBase, out);
1559             chain.push_back(this->writeExpression(*indexExpr.fIndex, out));
1560             break;
1561         }
1562         case Expression::kFieldAccess_Kind: {
1563             FieldAccess& fieldExpr = (FieldAccess&) expr;
1564             chain = this->getAccessChain(*fieldExpr.fBase, out);
1565             IntLiteral index(fContext, -1, fieldExpr.fFieldIndex);
1566             chain.push_back(this->writeIntLiteral(index));
1567             break;
1568         }
1569         default: {
1570             SpvId id = this->getLValue(expr, out)->getPointer();
1571             SkASSERT(id != 0);
1572             chain.push_back(id);
1573         }
1574     }
1575     return chain;
1576 }
1577 
1578 class PointerLValue : public SPIRVCodeGenerator::LValue {
1579 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,SpvId type,SPIRVCodeGenerator::Precision precision)1580     PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type,
1581                   SPIRVCodeGenerator::Precision precision)
1582     : fGen(gen)
1583     , fPointer(pointer)
1584     , fType(type)
1585     , fPrecision(precision) {}
1586 
getPointer()1587     virtual SpvId getPointer() override {
1588         return fPointer;
1589     }
1590 
load(OutputStream & out)1591     virtual SpvId load(OutputStream& out) override {
1592         SpvId result = fGen.nextId();
1593         fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
1594         fGen.writePrecisionModifier(fPrecision, result);
1595         return result;
1596     }
1597 
store(SpvId value,OutputStream & out)1598     virtual void store(SpvId value, OutputStream& out) override {
1599         fGen.writeInstruction(SpvOpStore, fPointer, value, out);
1600     }
1601 
1602 private:
1603     SPIRVCodeGenerator& fGen;
1604     const SpvId fPointer;
1605     const SpvId fType;
1606     const SPIRVCodeGenerator::Precision fPrecision;
1607 };
1608 
1609 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
1610 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const std::vector<int> & components,const Type & baseType,const Type & swizzleType,SPIRVCodeGenerator::Precision precision)1611     SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components,
1612                   const Type& baseType, const Type& swizzleType,
1613                   SPIRVCodeGenerator::Precision precision)
1614     : fGen(gen)
1615     , fVecPointer(vecPointer)
1616     , fComponents(components)
1617     , fBaseType(baseType)
1618     , fSwizzleType(swizzleType)
1619     , fPrecision(precision) {}
1620 
getPointer()1621     virtual SpvId getPointer() override {
1622         return 0;
1623     }
1624 
load(OutputStream & out)1625     virtual SpvId load(OutputStream& out) override {
1626         SpvId base = fGen.nextId();
1627         fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1628         fGen.writePrecisionModifier(fPrecision, base);
1629         SpvId result = fGen.nextId();
1630         fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
1631         fGen.writeWord(fGen.getType(fSwizzleType), out);
1632         fGen.writeWord(result, out);
1633         fGen.writeWord(base, out);
1634         fGen.writeWord(base, out);
1635         for (int component : fComponents) {
1636             fGen.writeWord(component, out);
1637         }
1638         fGen.writePrecisionModifier(fPrecision, result);
1639         return result;
1640     }
1641 
store(SpvId value,OutputStream & out)1642     virtual void store(SpvId value, OutputStream& out) override {
1643         // use OpVectorShuffle to mix and match the vector components. We effectively create
1644         // a virtual vector out of the concatenation of the left and right vectors, and then
1645         // select components from this virtual vector to make the result vector. For
1646         // instance, given:
1647         // float3L = ...;
1648         // float3R = ...;
1649         // L.xz = R.xy;
1650         // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
1651         // our result vector to look like (R.x, L.y, R.y), so we need to select indices
1652         // (3, 1, 4).
1653         SpvId base = fGen.nextId();
1654         fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1655         SpvId shuffle = fGen.nextId();
1656         fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out);
1657         fGen.writeWord(fGen.getType(fBaseType), out);
1658         fGen.writeWord(shuffle, out);
1659         fGen.writeWord(base, out);
1660         fGen.writeWord(value, out);
1661         for (int i = 0; i < fBaseType.columns(); i++) {
1662             // current offset into the virtual vector, defaults to pulling the unmodified
1663             // value from the left side
1664             int offset = i;
1665             // check to see if we are writing this component
1666             for (size_t j = 0; j < fComponents.size(); j++) {
1667                 if (fComponents[j] == i) {
1668                     // we're writing to this component, so adjust the offset to pull from
1669                     // the correct component of the right side instead of preserving the
1670                     // value from the left
1671                     offset = (int) (j + fBaseType.columns());
1672                     break;
1673                 }
1674             }
1675             fGen.writeWord(offset, out);
1676         }
1677         fGen.writePrecisionModifier(fPrecision, shuffle);
1678         fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
1679     }
1680 
1681 private:
1682     SPIRVCodeGenerator& fGen;
1683     const SpvId fVecPointer;
1684     const std::vector<int>& fComponents;
1685     const Type& fBaseType;
1686     const Type& fSwizzleType;
1687     const SPIRVCodeGenerator::Precision fPrecision;
1688 };
1689 
getLValue(const Expression & expr,OutputStream & out)1690 std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
1691                                                                           OutputStream& out) {
1692     Precision precision = expr.fType.highPrecision() ? Precision::kHigh : Precision::kLow;
1693     switch (expr.fKind) {
1694         case Expression::kVariableReference_Kind: {
1695             SpvId type;
1696             const Variable& var = ((VariableReference&) expr).fVariable;
1697             if (var.fModifiers.fLayout.fBuiltin == SK_IN_BUILTIN) {
1698                 type = this->getType(Type("sk_in", Type::kArray_Kind, var.fType.componentType(),
1699                                           fSkInCount));
1700             } else {
1701                 type = this->getType(expr.fType);
1702             }
1703             auto entry = fVariableMap.find(&var);
1704             SkASSERT(entry != fVariableMap.end());
1705             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(*this,
1706                                                                                  entry->second,
1707                                                                                  type,
1708                                                                                  precision));
1709         }
1710         case Expression::kIndex_Kind: // fall through
1711         case Expression::kFieldAccess_Kind: {
1712             std::vector<SpvId> chain = this->getAccessChain(expr, out);
1713             SpvId member = this->nextId();
1714             this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
1715             this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out);
1716             this->writeWord(member, out);
1717             for (SpvId idx : chain) {
1718                 this->writeWord(idx, out);
1719             }
1720             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1721                                                                         *this,
1722                                                                         member,
1723                                                                         this->getType(expr.fType),
1724                                                                         precision));
1725         }
1726         case Expression::kSwizzle_Kind: {
1727             Swizzle& swizzle = (Swizzle&) expr;
1728             size_t count = swizzle.fComponents.size();
1729             SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer();
1730             SkASSERT(base);
1731             if (count == 1) {
1732                 IntLiteral index(fContext, -1, swizzle.fComponents[0]);
1733                 SpvId member = this->nextId();
1734                 this->writeInstruction(SpvOpAccessChain,
1735                                        this->getPointerType(swizzle.fType,
1736                                                             get_storage_class(*swizzle.fBase)),
1737                                        member,
1738                                        base,
1739                                        this->writeIntLiteral(index),
1740                                        out);
1741                 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1742                                                                        *this,
1743                                                                        member,
1744                                                                        this->getType(expr.fType),
1745                                                                        precision));
1746             } else {
1747                 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue(
1748                                                                               *this,
1749                                                                               base,
1750                                                                               swizzle.fComponents,
1751                                                                               swizzle.fBase->fType,
1752                                                                               expr.fType,
1753                                                                               precision));
1754             }
1755         }
1756         case Expression::kTernary_Kind: {
1757             TernaryExpression& t = (TernaryExpression&) expr;
1758             SpvId test = this->writeExpression(*t.fTest, out);
1759             SpvId end = this->nextId();
1760             SpvId ifTrueLabel = this->nextId();
1761             SpvId ifFalseLabel = this->nextId();
1762             this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
1763             this->writeInstruction(SpvOpBranchConditional, test, ifTrueLabel, ifFalseLabel, out);
1764             this->writeLabel(ifTrueLabel, out);
1765             SpvId ifTrue = this->getLValue(*t.fIfTrue, out)->getPointer();
1766             SkASSERT(ifTrue);
1767             this->writeInstruction(SpvOpBranch, end, out);
1768             ifTrueLabel = fCurrentBlock;
1769             SpvId ifFalse = this->getLValue(*t.fIfFalse, out)->getPointer();
1770             SkASSERT(ifFalse);
1771             ifFalseLabel = fCurrentBlock;
1772             this->writeInstruction(SpvOpBranch, end, out);
1773             SpvId result = this->nextId();
1774             this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, ifTrue,
1775                        ifTrueLabel, ifFalse, ifFalseLabel, out);
1776             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1777                                                                        *this,
1778                                                                        result,
1779                                                                        this->getType(expr.fType),
1780                                                                        precision));
1781         }
1782         default:
1783             // expr isn't actually an lvalue, create a dummy variable for it. This case happens due
1784             // to the need to store values in temporary variables during function calls (see
1785             // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been
1786             // caught by IRGenerator
1787             SpvId result = this->nextId();
1788             SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction);
1789             this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction,
1790                                    fVariableBuffer);
1791             this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
1792             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1793                                                                        *this,
1794                                                                        result,
1795                                                                        this->getType(expr.fType),
1796                                                                        precision));
1797     }
1798 }
1799 
writeVariableReference(const VariableReference & ref,OutputStream & out)1800 SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
1801     SpvId result = this->nextId();
1802     auto entry = fVariableMap.find(&ref.fVariable);
1803     SkASSERT(entry != fVariableMap.end());
1804     SpvId var = entry->second;
1805     this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out);
1806     this->writePrecisionModifier(ref.fVariable.fType, result);
1807     if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN &&
1808         fProgram.fSettings.fFlipY) {
1809         // need to remap to a top-left coordinate system
1810         if (fRTHeightStructId == (SpvId) -1) {
1811             // height variable hasn't been written yet
1812             std::shared_ptr<SymbolTable> st(new SymbolTable(&fErrors));
1813             SkASSERT(fRTHeightFieldIndex == (SpvId) -1);
1814             std::vector<Type::Field> fields;
1815             SkASSERT(fProgram.fSettings.fRTHeightOffset >= 0);
1816             fields.emplace_back(Modifiers(Layout(0, -1, fProgram.fSettings.fRTHeightOffset, -1,
1817                                                  -1, -1, -1, -1, Layout::Format::kUnspecified,
1818                                                  Layout::kUnspecified_Primitive, -1, -1, "",
1819                                                  Layout::kNo_Key, Layout::CType::kDefault), 0),
1820                                 SKSL_RTHEIGHT_NAME, fContext.fFloat_Type.get());
1821             StringFragment name("sksl_synthetic_uniforms");
1822             Type intfStruct(-1, name, fields);
1823             int binding;
1824             int set;
1825 #ifdef SK_VULKAN
1826             const GrVkCaps* vkCaps = fProgram.fSettings.fVkCaps;
1827             SkASSERT(vkCaps);
1828             binding = vkCaps->getFragmentUniformBinding();
1829             set = vkCaps->getFragmentUniformSet();
1830 #else
1831             binding = 0;
1832             set = 0;
1833 #endif
1834             Layout layout(0, -1, -1, binding, -1, set, -1, -1, Layout::Format::kUnspecified,
1835                           Layout::kUnspecified_Primitive, -1, -1, "", Layout::kNo_Key,
1836                           Layout::CType::kDefault);
1837             Variable* intfVar = (Variable*) fSynthetics.takeOwnership(std::unique_ptr<Symbol>(
1838                                            new Variable(-1,
1839                                                         Modifiers(layout, Modifiers::kUniform_Flag),
1840                                                         name,
1841                                                         intfStruct,
1842                                                         Variable::kGlobal_Storage)));
1843             InterfaceBlock intf(-1, intfVar, name, String(""),
1844                                 std::vector<std::unique_ptr<Expression>>(), st);
1845             fRTHeightStructId = this->writeInterfaceBlock(intf);
1846             fRTHeightFieldIndex = 0;
1847         }
1848         SkASSERT(fRTHeightFieldIndex != (SpvId) -1);
1849         // write float4(gl_FragCoord.x, u_skRTHeight - gl_FragCoord.y, 0.0, gl_FragCoord.w)
1850         SpvId xId = this->nextId();
1851         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), xId,
1852                                result, 0, out);
1853         IntLiteral fieldIndex(fContext, -1, fRTHeightFieldIndex);
1854         SpvId fieldIndexId = this->writeIntLiteral(fieldIndex);
1855         SpvId heightPtr = this->nextId();
1856         this->writeOpCode(SpvOpAccessChain, 5, out);
1857         this->writeWord(this->getPointerType(*fContext.fFloat_Type, SpvStorageClassUniform), out);
1858         this->writeWord(heightPtr, out);
1859         this->writeWord(fRTHeightStructId, out);
1860         this->writeWord(fieldIndexId, out);
1861         SpvId heightRead = this->nextId();
1862         this->writeInstruction(SpvOpLoad, this->getType(*fContext.fFloat_Type), heightRead,
1863                                heightPtr, out);
1864         SpvId rawYId = this->nextId();
1865         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), rawYId,
1866                                result, 1, out);
1867         SpvId flippedYId = this->nextId();
1868         this->writeInstruction(SpvOpFSub, this->getType(*fContext.fFloat_Type), flippedYId,
1869                                heightRead, rawYId, out);
1870         FloatLiteral zero(fContext, -1, 0.0);
1871         SpvId zeroId = writeFloatLiteral(zero);
1872         FloatLiteral one(fContext, -1, 1.0);
1873         SpvId wId = this->nextId();
1874         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), wId,
1875                                result, 3, out);
1876         SpvId flipped = this->nextId();
1877         this->writeOpCode(SpvOpCompositeConstruct, 7, out);
1878         this->writeWord(this->getType(*fContext.fFloat4_Type), out);
1879         this->writeWord(flipped, out);
1880         this->writeWord(xId, out);
1881         this->writeWord(flippedYId, out);
1882         this->writeWord(zeroId, out);
1883         this->writeWord(wId, out);
1884         return flipped;
1885     }
1886     if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_CLOCKWISE_BUILTIN &&
1887         !fProgram.fSettings.fFlipY) {
1888         // FrontFacing in Vulkan is defined in terms of a top-down render target. In skia, we use
1889         // the default convention of "counter-clockwise face is front".
1890         SpvId inverse = this->nextId();
1891         this->writeInstruction(SpvOpLogicalNot, this->getType(*fContext.fBool_Type), inverse,
1892                                result, out);
1893         return inverse;
1894     }
1895     return result;
1896 }
1897 
writeIndexExpression(const IndexExpression & expr,OutputStream & out)1898 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
1899     if (expr.fBase->fType.kind() == Type::Kind::kVector_Kind) {
1900         SpvId base = this->writeExpression(*expr.fBase, out);
1901         SpvId index = this->writeExpression(*expr.fIndex, out);
1902         SpvId result = this->nextId();
1903         this->writeInstruction(SpvOpVectorExtractDynamic, this->getType(expr.fType), result, base,
1904                                index, out);
1905         return result;
1906     }
1907     return getLValue(expr, out)->load(out);
1908 }
1909 
writeFieldAccess(const FieldAccess & f,OutputStream & out)1910 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
1911     return getLValue(f, out)->load(out);
1912 }
1913 
writeSwizzle(const Swizzle & swizzle,OutputStream & out)1914 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
1915     SpvId base = this->writeExpression(*swizzle.fBase, out);
1916     SpvId result = this->nextId();
1917     size_t count = swizzle.fComponents.size();
1918     if (count == 1) {
1919         this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base,
1920                                swizzle.fComponents[0], out);
1921     } else {
1922         this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
1923         this->writeWord(this->getType(swizzle.fType), out);
1924         this->writeWord(result, out);
1925         this->writeWord(base, out);
1926         SpvId other;
1927         int last = swizzle.fComponents.back();
1928         if (last < 0) {
1929             if (!fConstantZeroOneVector) {
1930                 FloatLiteral zero(fContext, -1, 0);
1931                 SpvId zeroId = this->writeFloatLiteral(zero);
1932                 FloatLiteral one(fContext, -1, 1);
1933                 SpvId oneId = this->writeFloatLiteral(one);
1934                 SpvId type = this->getType(*fContext.fFloat2_Type);
1935                 fConstantZeroOneVector = this->nextId();
1936                 this->writeOpCode(SpvOpConstantComposite, 5, fConstantBuffer);
1937                 this->writeWord(type, fConstantBuffer);
1938                 this->writeWord(fConstantZeroOneVector, fConstantBuffer);
1939                 this->writeWord(zeroId, fConstantBuffer);
1940                 this->writeWord(oneId, fConstantBuffer);
1941             }
1942             other = fConstantZeroOneVector;
1943         } else {
1944             other = base;
1945         }
1946         this->writeWord(other, out);
1947         for (int component : swizzle.fComponents) {
1948             if (component == SKSL_SWIZZLE_0) {
1949                 this->writeWord(swizzle.fBase->fType.columns(), out);
1950             } else if (component == SKSL_SWIZZLE_1) {
1951                 this->writeWord(swizzle.fBase->fType.columns() + 1, out);
1952             } else {
1953                 this->writeWord(component, out);
1954             }
1955         }
1956     }
1957     return result;
1958 }
1959 
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)1960 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
1961                                                const Type& operandType, SpvId lhs,
1962                                                SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
1963                                                SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
1964     SpvId result = this->nextId();
1965     if (is_float(fContext, operandType)) {
1966         this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
1967     } else if (is_signed(fContext, operandType)) {
1968         this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
1969     } else if (is_unsigned(fContext, operandType)) {
1970         this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
1971     } else if (operandType == *fContext.fBool_Type) {
1972         this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
1973         return result; // skip RelaxedPrecision check
1974     } else {
1975         ABORT("invalid operandType: %s", operandType.description().c_str());
1976     }
1977     if (getActualType(resultType) == operandType && !resultType.highPrecision()) {
1978         this->writeInstruction(SpvOpDecorate, result, SpvDecorationRelaxedPrecision,
1979                                fDecorationBuffer);
1980     }
1981     return result;
1982 }
1983 
foldToBool(SpvId id,const Type & operandType,SpvOp op,OutputStream & out)1984 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op,
1985                                      OutputStream& out) {
1986     if (operandType.kind() == Type::kVector_Kind) {
1987         SpvId result = this->nextId();
1988         this->writeInstruction(op, this->getType(*fContext.fBool_Type), result, id, out);
1989         return result;
1990     }
1991     return id;
1992 }
1993 
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,SpvOp_ vectorMergeOperator,SpvOp_ mergeOperator,OutputStream & out)1994 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
1995                                                 SpvOp_ floatOperator, SpvOp_ intOperator,
1996                                                 SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
1997                                                 OutputStream& out) {
1998     SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator;
1999     SkASSERT(operandType.kind() == Type::kMatrix_Kind);
2000     SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
2001                                                                             operandType.rows(),
2002                                                                             1));
2003     SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext,
2004                                                                     operandType.rows(),
2005                                                                     1));
2006     SpvId boolType = this->getType(*fContext.fBool_Type);
2007     SpvId result = 0;
2008     for (int i = 0; i < operandType.columns(); i++) {
2009         SpvId columnL = this->nextId();
2010         this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
2011         SpvId columnR = this->nextId();
2012         this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
2013         SpvId compare = this->nextId();
2014         this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
2015         SpvId merge = this->nextId();
2016         this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
2017         if (result != 0) {
2018             SpvId next = this->nextId();
2019             this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
2020             result = next;
2021         }
2022         else {
2023             result = merge;
2024         }
2025     }
2026     return result;
2027 }
2028 
writeComponentwiseMatrixBinary(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,OutputStream & out)2029 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
2030                                                          SpvId rhs, SpvOp_ floatOperator,
2031                                                          SpvOp_ intOperator,
2032                                                          OutputStream& out) {
2033     SpvOp_ op = is_float(fContext, operandType) ? floatOperator : intOperator;
2034     SkASSERT(operandType.kind() == Type::kMatrix_Kind);
2035     SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
2036                                                                             operandType.rows(),
2037                                                                             1));
2038     SpvId columns[4];
2039     for (int i = 0; i < operandType.columns(); i++) {
2040         SpvId columnL = this->nextId();
2041         this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
2042         SpvId columnR = this->nextId();
2043         this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
2044         columns[i] = this->nextId();
2045         this->writeInstruction(op, columnType, columns[i], columnL, columnR, out);
2046     }
2047     SpvId result = this->nextId();
2048     this->writeOpCode(SpvOpCompositeConstruct, 3 + operandType.columns(), out);
2049     this->writeWord(this->getType(operandType), out);
2050     this->writeWord(result, out);
2051     for (int i = 0; i < operandType.columns(); i++) {
2052         this->writeWord(columns[i], out);
2053     }
2054     return result;
2055 }
2056 
create_literal_1(const Context & context,const Type & type)2057 std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
2058     if (type.isInteger()) {
2059         return std::unique_ptr<Expression>(new IntLiteral(-1, 1, &type));
2060     }
2061     else if (type.isFloat()) {
2062         return std::unique_ptr<Expression>(new FloatLiteral(-1, 1.0, &type));
2063     } else {
2064         ABORT("math is unsupported on type '%s'", type.name().c_str());
2065     }
2066 }
2067 
writeBinaryExpression(const Type & leftType,SpvId lhs,Token::Kind op,const Type & rightType,SpvId rhs,const Type & resultType,OutputStream & out)2068 SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Token::Kind op,
2069                                                 const Type& rightType, SpvId rhs,
2070                                                 const Type& resultType, OutputStream& out) {
2071     Type tmp("<invalid>");
2072     // overall type we are operating on: float2, int, uint4...
2073     const Type* operandType;
2074     // IR allows mismatched types in expressions (e.g. float2 * float), but they need special
2075     // handling in SPIR-V
2076     if (this->getActualType(leftType) != this->getActualType(rightType)) {
2077         if (leftType.kind() == Type::kVector_Kind && rightType.isNumber()) {
2078             if (op == Token::SLASH) {
2079                 SpvId one = this->writeExpression(*create_literal_1(fContext, rightType), out);
2080                 SpvId inverse = this->nextId();
2081                 this->writeInstruction(SpvOpFDiv, this->getType(rightType), inverse, one, rhs, out);
2082                 rhs = inverse;
2083                 op = Token::STAR;
2084             }
2085             if (op == Token::STAR) {
2086                 SpvId result = this->nextId();
2087                 this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
2088                                        result, lhs, rhs, out);
2089                 return result;
2090             }
2091             // promote number to vector
2092             SpvId vec = this->nextId();
2093             const Type& vecType = leftType;
2094             this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
2095             this->writeWord(this->getType(vecType), out);
2096             this->writeWord(vec, out);
2097             for (int i = 0; i < vecType.columns(); i++) {
2098                 this->writeWord(rhs, out);
2099             }
2100             rhs = vec;
2101             operandType = &leftType;
2102         } else if (rightType.kind() == Type::kVector_Kind && leftType.isNumber()) {
2103             if (op == Token::STAR) {
2104                 SpvId result = this->nextId();
2105                 this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
2106                                        result, rhs, lhs, out);
2107                 return result;
2108             }
2109             // promote number to vector
2110             SpvId vec = this->nextId();
2111             const Type& vecType = rightType;
2112             this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
2113             this->writeWord(this->getType(vecType), out);
2114             this->writeWord(vec, out);
2115             for (int i = 0; i < vecType.columns(); i++) {
2116                 this->writeWord(lhs, out);
2117             }
2118             lhs = vec;
2119             operandType = &rightType;
2120         } else if (leftType.kind() == Type::kMatrix_Kind) {
2121             SpvOp_ spvop;
2122             if (rightType.kind() == Type::kMatrix_Kind) {
2123                 spvop = SpvOpMatrixTimesMatrix;
2124             } else if (rightType.kind() == Type::kVector_Kind) {
2125                 spvop = SpvOpMatrixTimesVector;
2126             } else {
2127                 SkASSERT(rightType.kind() == Type::kScalar_Kind);
2128                 spvop = SpvOpMatrixTimesScalar;
2129             }
2130             SpvId result = this->nextId();
2131             this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
2132             return result;
2133         } else if (rightType.kind() == Type::kMatrix_Kind) {
2134             SpvId result = this->nextId();
2135             if (leftType.kind() == Type::kVector_Kind) {
2136                 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType), result,
2137                                        lhs, rhs, out);
2138             } else {
2139                 SkASSERT(leftType.kind() == Type::kScalar_Kind);
2140                 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType), result,
2141                                        rhs, lhs, out);
2142             }
2143             return result;
2144         } else {
2145             SkASSERT(false);
2146             return -1;
2147         }
2148     } else {
2149         tmp = this->getActualType(leftType);
2150         operandType = &tmp;
2151         SkASSERT(*operandType == this->getActualType(rightType));
2152     }
2153     switch (op) {
2154         case Token::EQEQ: {
2155             if (operandType->kind() == Type::kMatrix_Kind) {
2156                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
2157                                                    SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
2158             }
2159             SkASSERT(resultType == *fContext.fBool_Type);
2160             const Type* tmpType;
2161             if (operandType->kind() == Type::kVector_Kind) {
2162                 tmpType = &fContext.fBool_Type->toCompound(fContext,
2163                                                            operandType->columns(),
2164                                                            operandType->rows());
2165             } else {
2166                 tmpType = &resultType;
2167             }
2168             return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
2169                                                                SpvOpFOrdEqual, SpvOpIEqual,
2170                                                                SpvOpIEqual, SpvOpLogicalEqual, out),
2171                                     *operandType, SpvOpAll, out);
2172         }
2173         case Token::NEQ:
2174             if (operandType->kind() == Type::kMatrix_Kind) {
2175                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
2176                                                    SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
2177             }
2178             SkASSERT(resultType == *fContext.fBool_Type);
2179             const Type* tmpType;
2180             if (operandType->kind() == Type::kVector_Kind) {
2181                 tmpType = &fContext.fBool_Type->toCompound(fContext,
2182                                                            operandType->columns(),
2183                                                            operandType->rows());
2184             } else {
2185                 tmpType = &resultType;
2186             }
2187             return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
2188                                                                SpvOpFOrdNotEqual, SpvOpINotEqual,
2189                                                                SpvOpINotEqual, SpvOpLogicalNotEqual,
2190                                                                out),
2191                                     *operandType, SpvOpAny, out);
2192         case Token::GT:
2193             SkASSERT(resultType == *fContext.fBool_Type);
2194             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2195                                               SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
2196                                               SpvOpUGreaterThan, SpvOpUndef, out);
2197         case Token::LT:
2198             SkASSERT(resultType == *fContext.fBool_Type);
2199             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
2200                                               SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
2201         case Token::GTEQ:
2202             SkASSERT(resultType == *fContext.fBool_Type);
2203             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2204                                               SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
2205                                               SpvOpUGreaterThanEqual, SpvOpUndef, out);
2206         case Token::LTEQ:
2207             SkASSERT(resultType == *fContext.fBool_Type);
2208             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2209                                               SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
2210                                               SpvOpULessThanEqual, SpvOpUndef, out);
2211         case Token::PLUS:
2212             if (leftType.kind() == Type::kMatrix_Kind &&
2213                 rightType.kind() == Type::kMatrix_Kind) {
2214                 SkASSERT(leftType == rightType);
2215                 return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs,
2216                                                             SpvOpFAdd, SpvOpIAdd, out);
2217             }
2218             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2219                                               SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2220         case Token::MINUS:
2221             if (leftType.kind() == Type::kMatrix_Kind &&
2222                 rightType.kind() == Type::kMatrix_Kind) {
2223                 SkASSERT(leftType == rightType);
2224                 return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs,
2225                                                             SpvOpFSub, SpvOpISub, out);
2226             }
2227             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2228                                               SpvOpISub, SpvOpISub, SpvOpUndef, out);
2229         case Token::STAR:
2230             if (leftType.kind() == Type::kMatrix_Kind &&
2231                 rightType.kind() == Type::kMatrix_Kind) {
2232                 // matrix multiply
2233                 SpvId result = this->nextId();
2234                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2235                                        lhs, rhs, out);
2236                 return result;
2237             }
2238             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2239                                               SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2240         case Token::SLASH:
2241             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2242                                               SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2243         case Token::PERCENT:
2244             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2245                                               SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2246         case Token::SHL:
2247             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2248                                               SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
2249                                               SpvOpUndef, out);
2250         case Token::SHR:
2251             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2252                                               SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
2253                                               SpvOpUndef, out);
2254         case Token::BITWISEAND:
2255             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2256                                               SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
2257         case Token::BITWISEOR:
2258             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2259                                               SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
2260         case Token::BITWISEXOR:
2261             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2262                                               SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
2263         case Token::COMMA:
2264             return rhs;
2265         default:
2266             SkASSERT(false);
2267             return -1;
2268     }
2269 }
2270 
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)2271 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
2272     // handle cases where we don't necessarily evaluate both LHS and RHS
2273     switch (b.fOperator) {
2274         case Token::EQ: {
2275             SpvId rhs = this->writeExpression(*b.fRight, out);
2276             this->getLValue(*b.fLeft, out)->store(rhs, out);
2277             return rhs;
2278         }
2279         case Token::LOGICALAND:
2280             return this->writeLogicalAnd(b, out);
2281         case Token::LOGICALOR:
2282             return this->writeLogicalOr(b, out);
2283         default:
2284             break;
2285     }
2286 
2287     std::unique_ptr<LValue> lvalue;
2288     SpvId lhs;
2289     if (is_assignment(b.fOperator)) {
2290         lvalue = this->getLValue(*b.fLeft, out);
2291         lhs = lvalue->load(out);
2292     } else {
2293         lvalue = nullptr;
2294         lhs = this->writeExpression(*b.fLeft, out);
2295     }
2296     SpvId rhs = this->writeExpression(*b.fRight, out);
2297     SpvId result = this->writeBinaryExpression(b.fLeft->fType, lhs, remove_assignment(b.fOperator),
2298                                                b.fRight->fType, rhs, b.fType, out);
2299     if (lvalue) {
2300         lvalue->store(result, out);
2301     }
2302     return result;
2303 }
2304 
writeLogicalAnd(const BinaryExpression & a,OutputStream & out)2305 SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) {
2306     SkASSERT(a.fOperator == Token::LOGICALAND);
2307     BoolLiteral falseLiteral(fContext, -1, false);
2308     SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
2309     SpvId lhs = this->writeExpression(*a.fLeft, out);
2310     SpvId rhsLabel = this->nextId();
2311     SpvId end = this->nextId();
2312     SpvId lhsBlock = fCurrentBlock;
2313     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2314     this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
2315     this->writeLabel(rhsLabel, out);
2316     SpvId rhs = this->writeExpression(*a.fRight, out);
2317     SpvId rhsBlock = fCurrentBlock;
2318     this->writeInstruction(SpvOpBranch, end, out);
2319     this->writeLabel(end, out);
2320     SpvId result = this->nextId();
2321     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant,
2322                            lhsBlock, rhs, rhsBlock, out);
2323     return result;
2324 }
2325 
writeLogicalOr(const BinaryExpression & o,OutputStream & out)2326 SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, OutputStream& out) {
2327     SkASSERT(o.fOperator == Token::LOGICALOR);
2328     BoolLiteral trueLiteral(fContext, -1, true);
2329     SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
2330     SpvId lhs = this->writeExpression(*o.fLeft, out);
2331     SpvId rhsLabel = this->nextId();
2332     SpvId end = this->nextId();
2333     SpvId lhsBlock = fCurrentBlock;
2334     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2335     this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
2336     this->writeLabel(rhsLabel, out);
2337     SpvId rhs = this->writeExpression(*o.fRight, out);
2338     SpvId rhsBlock = fCurrentBlock;
2339     this->writeInstruction(SpvOpBranch, end, out);
2340     this->writeLabel(end, out);
2341     SpvId result = this->nextId();
2342     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant,
2343                            lhsBlock, rhs, rhsBlock, out);
2344     return result;
2345 }
2346 
writeTernaryExpression(const TernaryExpression & t,OutputStream & out)2347 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
2348     SpvId test = this->writeExpression(*t.fTest, out);
2349     if (t.fIfTrue->fType.columns() == 1 && t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) {
2350         // both true and false are constants, can just use OpSelect
2351         SpvId result = this->nextId();
2352         SpvId trueId = this->writeExpression(*t.fIfTrue, out);
2353         SpvId falseId = this->writeExpression(*t.fIfFalse, out);
2354         this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId,
2355                                out);
2356         return result;
2357     }
2358     // was originally using OpPhi to choose the result, but for some reason that is crashing on
2359     // Adreno. Switched to storing the result in a temp variable as glslang does.
2360     SpvId var = this->nextId();
2361     this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction),
2362                            var, SpvStorageClassFunction, fVariableBuffer);
2363     SpvId trueLabel = this->nextId();
2364     SpvId falseLabel = this->nextId();
2365     SpvId end = this->nextId();
2366     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2367     this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
2368     this->writeLabel(trueLabel, out);
2369     this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out);
2370     this->writeInstruction(SpvOpBranch, end, out);
2371     this->writeLabel(falseLabel, out);
2372     this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out);
2373     this->writeInstruction(SpvOpBranch, end, out);
2374     this->writeLabel(end, out);
2375     SpvId result = this->nextId();
2376     this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out);
2377     this->writePrecisionModifier(t.fType, result);
2378     return result;
2379 }
2380 
writePrefixExpression(const PrefixExpression & p,OutputStream & out)2381 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
2382     if (p.fOperator == Token::MINUS) {
2383         SpvId result = this->nextId();
2384         SpvId typeId = this->getType(p.fType);
2385         SpvId expr = this->writeExpression(*p.fOperand, out);
2386         if (is_float(fContext, p.fType)) {
2387             this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
2388         } else if (is_signed(fContext, p.fType)) {
2389             this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
2390         } else {
2391             ABORT("unsupported prefix expression %s", p.description().c_str());
2392         }
2393         this->writePrecisionModifier(p.fType, result);
2394         return result;
2395     }
2396     switch (p.fOperator) {
2397         case Token::PLUS:
2398             return this->writeExpression(*p.fOperand, out);
2399         case Token::PLUSPLUS: {
2400             std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2401             SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2402             SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2403                                                       SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
2404                                                       out);
2405             lv->store(result, out);
2406             return result;
2407         }
2408         case Token::MINUSMINUS: {
2409             std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2410             SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2411             SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2412                                                       SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef,
2413                                                       out);
2414             lv->store(result, out);
2415             return result;
2416         }
2417         case Token::LOGICALNOT: {
2418             SkASSERT(p.fOperand->fType == *fContext.fBool_Type);
2419             SpvId result = this->nextId();
2420             this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result,
2421                                    this->writeExpression(*p.fOperand, out), out);
2422             return result;
2423         }
2424         case Token::BITWISENOT: {
2425             SpvId result = this->nextId();
2426             this->writeInstruction(SpvOpNot, this->getType(p.fOperand->fType), result,
2427                                    this->writeExpression(*p.fOperand, out), out);
2428             return result;
2429         }
2430         default:
2431             ABORT("unsupported prefix expression: %s", p.description().c_str());
2432     }
2433 }
2434 
writePostfixExpression(const PostfixExpression & p,OutputStream & out)2435 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
2436     std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2437     SpvId result = lv->load(out);
2438     SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2439     switch (p.fOperator) {
2440         case Token::PLUSPLUS: {
2441             SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd,
2442                                                     SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2443             lv->store(temp, out);
2444             return result;
2445         }
2446         case Token::MINUSMINUS: {
2447             SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub,
2448                                                     SpvOpISub, SpvOpISub, SpvOpUndef, out);
2449             lv->store(temp, out);
2450             return result;
2451         }
2452         default:
2453             ABORT("unsupported postfix expression %s", p.description().c_str());
2454     }
2455 }
2456 
writeBoolLiteral(const BoolLiteral & b)2457 SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
2458     if (b.fValue) {
2459         if (fBoolTrue == 0) {
2460             fBoolTrue = this->nextId();
2461             this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue,
2462                                    fConstantBuffer);
2463         }
2464         return fBoolTrue;
2465     } else {
2466         if (fBoolFalse == 0) {
2467             fBoolFalse = this->nextId();
2468             this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse,
2469                                    fConstantBuffer);
2470         }
2471         return fBoolFalse;
2472     }
2473 }
2474 
writeIntLiteral(const IntLiteral & i)2475 SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) {
2476     ConstantType type;
2477     if (i.fType == *fContext.fInt_Type) {
2478         type = ConstantType::kInt;
2479     } else if (i.fType == *fContext.fUInt_Type) {
2480         type = ConstantType::kUInt;
2481     } else if (i.fType == *fContext.fShort_Type) {
2482         type = ConstantType::kShort;
2483     } else if (i.fType == *fContext.fUShort_Type) {
2484         type = ConstantType::kUShort;
2485     }
2486     std::pair<ConstantValue, ConstantType> key(i.fValue, type);
2487     auto entry = fNumberConstants.find(key);
2488     if (entry == fNumberConstants.end()) {
2489         SpvId result = this->nextId();
2490         this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2491                                fConstantBuffer);
2492         fNumberConstants[key] = result;
2493         return result;
2494     }
2495     return entry->second;
2496 }
2497 
writeFloatLiteral(const FloatLiteral & f)2498 SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
2499     if (f.fType != *fContext.fDouble_Type) {
2500         ConstantType type;
2501         if (f.fType == *fContext.fHalf_Type) {
2502             type = ConstantType::kHalf;
2503         } else {
2504             type = ConstantType::kFloat;
2505         }
2506         float value = (float) f.fValue;
2507         std::pair<ConstantValue, ConstantType> key(f.fValue, type);
2508         auto entry = fNumberConstants.find(key);
2509         if (entry == fNumberConstants.end()) {
2510             SpvId result = this->nextId();
2511             uint32_t bits;
2512             SkASSERT(sizeof(bits) == sizeof(value));
2513             memcpy(&bits, &value, sizeof(bits));
2514             this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits,
2515                                    fConstantBuffer);
2516             fNumberConstants[key] = result;
2517             return result;
2518         }
2519         return entry->second;
2520     } else {
2521         std::pair<ConstantValue, ConstantType> key(f.fValue, ConstantType::kDouble);
2522         auto entry = fNumberConstants.find(key);
2523         if (entry == fNumberConstants.end()) {
2524             SpvId result = this->nextId();
2525             uint64_t bits;
2526             SkASSERT(sizeof(bits) == sizeof(f.fValue));
2527             memcpy(&bits, &f.fValue, sizeof(bits));
2528             this->writeInstruction(SpvOpConstant, this->getType(f.fType), result,
2529                                    bits & 0xffffffff, bits >> 32, fConstantBuffer);
2530             fNumberConstants[key] = result;
2531             return result;
2532         }
2533         return entry->second;
2534     }
2535 }
2536 
writeFunctionStart(const FunctionDeclaration & f,OutputStream & out)2537 SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
2538     SpvId result = fFunctionMap[&f];
2539     this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result,
2540                            SpvFunctionControlMaskNone, this->getFunctionType(f), out);
2541     this->writeInstruction(SpvOpName, result, f.fName, fNameBuffer);
2542     for (size_t i = 0; i < f.fParameters.size(); i++) {
2543         SpvId id = this->nextId();
2544         fVariableMap[f.fParameters[i]] = id;
2545         SpvId type;
2546         type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction);
2547         this->writeInstruction(SpvOpFunctionParameter, type, id, out);
2548     }
2549     return result;
2550 }
2551 
writeFunction(const FunctionDefinition & f,OutputStream & out)2552 SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
2553     fVariableBuffer.reset();
2554     SpvId result = this->writeFunctionStart(f.fDeclaration, out);
2555     this->writeLabel(this->nextId(), out);
2556     StringStream bodyBuffer;
2557     this->writeBlock((Block&) *f.fBody, bodyBuffer);
2558     write_stringstream(fVariableBuffer, out);
2559     if (f.fDeclaration.fName == "main") {
2560         write_stringstream(fGlobalInitializersBuffer, out);
2561     }
2562     write_stringstream(bodyBuffer, out);
2563     if (fCurrentBlock) {
2564         if (f.fDeclaration.fReturnType == *fContext.fVoid_Type) {
2565             this->writeInstruction(SpvOpReturn, out);
2566         } else {
2567             this->writeInstruction(SpvOpUnreachable, out);
2568         }
2569     }
2570     this->writeInstruction(SpvOpFunctionEnd, out);
2571     return result;
2572 }
2573 
writeLayout(const Layout & layout,SpvId target)2574 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) {
2575     if (layout.fLocation >= 0) {
2576         this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
2577                                fDecorationBuffer);
2578     }
2579     if (layout.fBinding >= 0) {
2580         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
2581                                fDecorationBuffer);
2582     }
2583     if (layout.fIndex >= 0) {
2584         this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
2585                                fDecorationBuffer);
2586     }
2587     if (layout.fSet >= 0) {
2588         this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
2589                                fDecorationBuffer);
2590     }
2591     if (layout.fInputAttachmentIndex >= 0) {
2592         this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
2593                                layout.fInputAttachmentIndex, fDecorationBuffer);
2594         fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
2595     }
2596     if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN &&
2597         layout.fBuiltin != SK_IN_BUILTIN && layout.fBuiltin != SK_OUT_BUILTIN) {
2598         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
2599                                fDecorationBuffer);
2600     }
2601 }
2602 
writeLayout(const Layout & layout,SpvId target,int member)2603 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) {
2604     if (layout.fLocation >= 0) {
2605         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
2606                                layout.fLocation, fDecorationBuffer);
2607     }
2608     if (layout.fBinding >= 0) {
2609         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding,
2610                                layout.fBinding, fDecorationBuffer);
2611     }
2612     if (layout.fIndex >= 0) {
2613         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
2614                                layout.fIndex, fDecorationBuffer);
2615     }
2616     if (layout.fSet >= 0) {
2617         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet,
2618                                layout.fSet, fDecorationBuffer);
2619     }
2620     if (layout.fInputAttachmentIndex >= 0) {
2621         this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
2622                                layout.fInputAttachmentIndex, fDecorationBuffer);
2623     }
2624     if (layout.fBuiltin >= 0) {
2625         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
2626                                layout.fBuiltin, fDecorationBuffer);
2627     }
2628 }
2629 
update_sk_in_count(const Modifiers & m,int * outSkInCount)2630 static void update_sk_in_count(const Modifiers& m, int* outSkInCount) {
2631     switch (m.fLayout.fPrimitive) {
2632         case Layout::kPoints_Primitive:
2633             *outSkInCount = 1;
2634             break;
2635         case Layout::kLines_Primitive:
2636             *outSkInCount = 2;
2637             break;
2638         case Layout::kLinesAdjacency_Primitive:
2639             *outSkInCount = 4;
2640             break;
2641         case Layout::kTriangles_Primitive:
2642             *outSkInCount = 3;
2643             break;
2644         case Layout::kTrianglesAdjacency_Primitive:
2645             *outSkInCount = 6;
2646             break;
2647         default:
2648             return;
2649     }
2650 }
2651 
writeInterfaceBlock(const InterfaceBlock & intf)2652 SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2653     bool isBuffer = (0 != (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag));
2654     bool pushConstant = (0 != (intf.fVariable.fModifiers.fLayout.fFlags &
2655                                Layout::kPushConstant_Flag));
2656     MemoryLayout memoryLayout = (pushConstant || isBuffer) ?
2657                                 MemoryLayout(MemoryLayout::k430_Standard) :
2658                                 fDefaultLayout;
2659     SpvId result = this->nextId();
2660     const Type* type = &intf.fVariable.fType;
2661     if (fProgram.fInputs.fRTHeight) {
2662         SkASSERT(fRTHeightStructId == (SpvId) -1);
2663         SkASSERT(fRTHeightFieldIndex == (SpvId) -1);
2664         std::vector<Type::Field> fields = type->fields();
2665         fRTHeightStructId = result;
2666         fRTHeightFieldIndex = fields.size();
2667         fields.emplace_back(Modifiers(), StringFragment(SKSL_RTHEIGHT_NAME), fContext.fFloat_Type.get());
2668         type = new Type(type->fOffset, type->name(), fields);
2669     }
2670     SpvId typeId;
2671     if (intf.fVariable.fModifiers.fLayout.fBuiltin == SK_IN_BUILTIN) {
2672         for (const auto& e : fProgram) {
2673             if (e.fKind == ProgramElement::kModifiers_Kind) {
2674                 const Modifiers& m = ((ModifiersDeclaration&) e).fModifiers;
2675                 update_sk_in_count(m, &fSkInCount);
2676             }
2677         }
2678         typeId = this->getType(Type("sk_in", Type::kArray_Kind, intf.fVariable.fType.componentType(),
2679                                   fSkInCount), memoryLayout);
2680     } else {
2681         typeId = this->getType(*type, memoryLayout);
2682     }
2683     if (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag) {
2684         this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBufferBlock, fDecorationBuffer);
2685     } else if (intf.fVariable.fModifiers.fLayout.fBuiltin == -1) {
2686         this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
2687     }
2688     SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers);
2689     SpvId ptrType = this->nextId();
2690     this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
2691     this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
2692     Layout layout = intf.fVariable.fModifiers.fLayout;
2693     if (intf.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag && layout.fSet == -1) {
2694         layout.fSet = 0;
2695     }
2696     this->writeLayout(layout, result);
2697     fVariableMap[&intf.fVariable] = result;
2698     if (fProgram.fInputs.fRTHeight) {
2699         delete type;
2700     }
2701     return result;
2702 }
2703 
writePrecisionModifier(const Type & type,SpvId id)2704 void SPIRVCodeGenerator::writePrecisionModifier(const Type& type, SpvId id) {
2705     this->writePrecisionModifier(type.highPrecision() ? Precision::kHigh : Precision::kLow, id);
2706 }
2707 
writePrecisionModifier(Precision precision,SpvId id)2708 void SPIRVCodeGenerator::writePrecisionModifier(Precision precision, SpvId id) {
2709     if (precision == Precision::kLow) {
2710         this->writeInstruction(SpvOpDecorate, id, SpvDecorationRelaxedPrecision, fDecorationBuffer);
2711     }
2712 }
2713 
2714 #define BUILTIN_IGNORE 9999
writeGlobalVars(Program::Kind kind,const VarDeclarations & decl,OutputStream & out)2715 void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclarations& decl,
2716                                          OutputStream& out) {
2717     for (size_t i = 0; i < decl.fVars.size(); i++) {
2718         if (decl.fVars[i]->fKind == Statement::kNop_Kind) {
2719             continue;
2720         }
2721         const VarDeclaration& varDecl = (VarDeclaration&) *decl.fVars[i];
2722         const Variable* var = varDecl.fVar;
2723         // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2724         // in the OpenGL backend.
2725         SkASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2726                                            Modifiers::kWriteOnly_Flag |
2727                                            Modifiers::kCoherent_Flag |
2728                                            Modifiers::kVolatile_Flag |
2729                                            Modifiers::kRestrict_Flag)));
2730         if (var->fModifiers.fLayout.fBuiltin == BUILTIN_IGNORE) {
2731             continue;
2732         }
2733         if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
2734             kind != Program::kFragment_Kind) {
2735             SkASSERT(!fProgram.fSettings.fFragColorIsInOut);
2736             continue;
2737         }
2738         if (!var->fReadCount && !var->fWriteCount &&
2739                 !(var->fModifiers.fFlags & (Modifiers::kIn_Flag |
2740                                             Modifiers::kOut_Flag |
2741                                             Modifiers::kUniform_Flag |
2742                                             Modifiers::kBuffer_Flag))) {
2743             // variable is dead and not an input / output var (the Vulkan debug layers complain if
2744             // we elide an interface var, even if it's dead)
2745             continue;
2746         }
2747         SpvStorageClass_ storageClass;
2748         if (var->fModifiers.fFlags & Modifiers::kIn_Flag) {
2749             storageClass = SpvStorageClassInput;
2750         } else if (var->fModifiers.fFlags & Modifiers::kOut_Flag) {
2751             storageClass = SpvStorageClassOutput;
2752         } else if (var->fModifiers.fFlags & Modifiers::kUniform_Flag) {
2753             if (var->fType.kind() == Type::kSampler_Kind ||
2754                 var->fType.kind() == Type::kSeparateSampler_Kind ||
2755                 var->fType.kind() == Type::kTexture_Kind) {
2756                 storageClass = SpvStorageClassUniformConstant;
2757             } else {
2758                 storageClass = SpvStorageClassUniform;
2759             }
2760         } else {
2761             storageClass = SpvStorageClassPrivate;
2762         }
2763         SpvId id = this->nextId();
2764         fVariableMap[var] = id;
2765         SpvId type;
2766         if (var->fModifiers.fLayout.fBuiltin == SK_IN_BUILTIN) {
2767             type = this->getPointerType(Type("sk_in", Type::kArray_Kind,
2768                                              var->fType.componentType(), fSkInCount),
2769                                         storageClass);
2770         } else {
2771             type = this->getPointerType(var->fType, storageClass);
2772         }
2773         this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer);
2774         this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
2775         this->writePrecisionModifier(var->fType, id);
2776         if (varDecl.fValue) {
2777             SkASSERT(!fCurrentBlock);
2778             fCurrentBlock = -1;
2779             SpvId value = this->writeExpression(*varDecl.fValue, fGlobalInitializersBuffer);
2780             this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
2781             fCurrentBlock = 0;
2782         }
2783         this->writeLayout(var->fModifiers.fLayout, id);
2784         if (var->fModifiers.fFlags & Modifiers::kFlat_Flag) {
2785             this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
2786         }
2787         if (var->fModifiers.fFlags & Modifiers::kNoPerspective_Flag) {
2788             this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
2789                                    fDecorationBuffer);
2790         }
2791     }
2792 }
2793 
writeVarDeclarations(const VarDeclarations & decl,OutputStream & out)2794 void SPIRVCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, OutputStream& out) {
2795     for (const auto& stmt : decl.fVars) {
2796         SkASSERT(stmt->fKind == Statement::kVarDeclaration_Kind);
2797         VarDeclaration& varDecl = (VarDeclaration&) *stmt;
2798         const Variable* var = varDecl.fVar;
2799         // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2800         // in the OpenGL backend.
2801         SkASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2802                                            Modifiers::kWriteOnly_Flag |
2803                                            Modifiers::kCoherent_Flag |
2804                                            Modifiers::kVolatile_Flag |
2805                                            Modifiers::kRestrict_Flag)));
2806         SpvId id = this->nextId();
2807         fVariableMap[var] = id;
2808         SpvId type = this->getPointerType(var->fType, SpvStorageClassFunction);
2809         this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
2810         this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
2811         if (varDecl.fValue) {
2812             SpvId value = this->writeExpression(*varDecl.fValue, out);
2813             this->writeInstruction(SpvOpStore, id, value, out);
2814         }
2815     }
2816 }
2817 
writeStatement(const Statement & s,OutputStream & out)2818 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
2819     switch (s.fKind) {
2820         case Statement::kNop_Kind:
2821             break;
2822         case Statement::kBlock_Kind:
2823             this->writeBlock((Block&) s, out);
2824             break;
2825         case Statement::kExpression_Kind:
2826             this->writeExpression(*((ExpressionStatement&) s).fExpression, out);
2827             break;
2828         case Statement::kReturn_Kind:
2829             this->writeReturnStatement((ReturnStatement&) s, out);
2830             break;
2831         case Statement::kVarDeclarations_Kind:
2832             this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, out);
2833             break;
2834         case Statement::kIf_Kind:
2835             this->writeIfStatement((IfStatement&) s, out);
2836             break;
2837         case Statement::kFor_Kind:
2838             this->writeForStatement((ForStatement&) s, out);
2839             break;
2840         case Statement::kWhile_Kind:
2841             this->writeWhileStatement((WhileStatement&) s, out);
2842             break;
2843         case Statement::kDo_Kind:
2844             this->writeDoStatement((DoStatement&) s, out);
2845             break;
2846         case Statement::kSwitch_Kind:
2847             this->writeSwitchStatement((SwitchStatement&) s, out);
2848             break;
2849         case Statement::kBreak_Kind:
2850             this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
2851             break;
2852         case Statement::kContinue_Kind:
2853             this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
2854             break;
2855         case Statement::kDiscard_Kind:
2856             this->writeInstruction(SpvOpKill, out);
2857             break;
2858         default:
2859             ABORT("unsupported statement: %s", s.description().c_str());
2860     }
2861 }
2862 
writeBlock(const Block & b,OutputStream & out)2863 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
2864     for (size_t i = 0; i < b.fStatements.size(); i++) {
2865         this->writeStatement(*b.fStatements[i], out);
2866     }
2867 }
2868 
writeIfStatement(const IfStatement & stmt,OutputStream & out)2869 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
2870     SpvId test = this->writeExpression(*stmt.fTest, out);
2871     SpvId ifTrue = this->nextId();
2872     SpvId ifFalse = this->nextId();
2873     if (stmt.fIfFalse) {
2874         SpvId end = this->nextId();
2875         this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2876         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2877         this->writeLabel(ifTrue, out);
2878         this->writeStatement(*stmt.fIfTrue, out);
2879         if (fCurrentBlock) {
2880             this->writeInstruction(SpvOpBranch, end, out);
2881         }
2882         this->writeLabel(ifFalse, out);
2883         this->writeStatement(*stmt.fIfFalse, out);
2884         if (fCurrentBlock) {
2885             this->writeInstruction(SpvOpBranch, end, out);
2886         }
2887         this->writeLabel(end, out);
2888     } else {
2889         this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
2890         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2891         this->writeLabel(ifTrue, out);
2892         this->writeStatement(*stmt.fIfTrue, out);
2893         if (fCurrentBlock) {
2894             this->writeInstruction(SpvOpBranch, ifFalse, out);
2895         }
2896         this->writeLabel(ifFalse, out);
2897     }
2898 }
2899 
writeForStatement(const ForStatement & f,OutputStream & out)2900 void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
2901     if (f.fInitializer) {
2902         this->writeStatement(*f.fInitializer, out);
2903     }
2904     SpvId header = this->nextId();
2905     SpvId start = this->nextId();
2906     SpvId body = this->nextId();
2907     SpvId next = this->nextId();
2908     fContinueTarget.push(next);
2909     SpvId end = this->nextId();
2910     fBreakTarget.push(end);
2911     this->writeInstruction(SpvOpBranch, header, out);
2912     this->writeLabel(header, out);
2913     this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
2914     this->writeInstruction(SpvOpBranch, start, out);
2915     this->writeLabel(start, out);
2916     if (f.fTest) {
2917         SpvId test = this->writeExpression(*f.fTest, out);
2918         this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2919     }
2920     this->writeLabel(body, out);
2921     this->writeStatement(*f.fStatement, out);
2922     if (fCurrentBlock) {
2923         this->writeInstruction(SpvOpBranch, next, out);
2924     }
2925     this->writeLabel(next, out);
2926     if (f.fNext) {
2927         this->writeExpression(*f.fNext, out);
2928     }
2929     this->writeInstruction(SpvOpBranch, header, out);
2930     this->writeLabel(end, out);
2931     fBreakTarget.pop();
2932     fContinueTarget.pop();
2933 }
2934 
writeWhileStatement(const WhileStatement & w,OutputStream & out)2935 void SPIRVCodeGenerator::writeWhileStatement(const WhileStatement& w, OutputStream& out) {
2936     SpvId header = this->nextId();
2937     SpvId start = this->nextId();
2938     SpvId body = this->nextId();
2939     SpvId continueTarget = this->nextId();
2940     fContinueTarget.push(continueTarget);
2941     SpvId end = this->nextId();
2942     fBreakTarget.push(end);
2943     this->writeInstruction(SpvOpBranch, header, out);
2944     this->writeLabel(header, out);
2945     this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out);
2946     this->writeInstruction(SpvOpBranch, start, out);
2947     this->writeLabel(start, out);
2948     SpvId test = this->writeExpression(*w.fTest, out);
2949     this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2950     this->writeLabel(body, out);
2951     this->writeStatement(*w.fStatement, out);
2952     if (fCurrentBlock) {
2953         this->writeInstruction(SpvOpBranch, continueTarget, out);
2954     }
2955     this->writeLabel(continueTarget, out);
2956     this->writeInstruction(SpvOpBranch, header, out);
2957     this->writeLabel(end, out);
2958     fBreakTarget.pop();
2959     fContinueTarget.pop();
2960 }
2961 
writeDoStatement(const DoStatement & d,OutputStream & out)2962 void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
2963     // We believe the do loop code below will work, but Skia doesn't actually use them and
2964     // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2965     // the time being, we just fail with an error due to the lack of testing. If you encounter this
2966     // message, simply remove the error call below to see whether our do loop support actually
2967     // works.
2968     fErrors.error(d.fOffset, "internal error: do loop support has been disabled in SPIR-V, see "
2969                   "SkSLSPIRVCodeGenerator.cpp for details");
2970 
2971     SpvId header = this->nextId();
2972     SpvId start = this->nextId();
2973     SpvId next = this->nextId();
2974     SpvId continueTarget = this->nextId();
2975     fContinueTarget.push(continueTarget);
2976     SpvId end = this->nextId();
2977     fBreakTarget.push(end);
2978     this->writeInstruction(SpvOpBranch, header, out);
2979     this->writeLabel(header, out);
2980     this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out);
2981     this->writeInstruction(SpvOpBranch, start, out);
2982     this->writeLabel(start, out);
2983     this->writeStatement(*d.fStatement, out);
2984     if (fCurrentBlock) {
2985         this->writeInstruction(SpvOpBranch, next, out);
2986     }
2987     this->writeLabel(next, out);
2988     SpvId test = this->writeExpression(*d.fTest, out);
2989     this->writeInstruction(SpvOpBranchConditional, test, continueTarget, end, out);
2990     this->writeLabel(continueTarget, out);
2991     this->writeInstruction(SpvOpBranch, header, out);
2992     this->writeLabel(end, out);
2993     fBreakTarget.pop();
2994     fContinueTarget.pop();
2995 }
2996 
writeSwitchStatement(const SwitchStatement & s,OutputStream & out)2997 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
2998     SpvId value = this->writeExpression(*s.fValue, out);
2999     std::vector<SpvId> labels;
3000     SpvId end = this->nextId();
3001     SpvId defaultLabel = end;
3002     fBreakTarget.push(end);
3003     int size = 3;
3004     for (const auto& c : s.fCases) {
3005         SpvId label = this->nextId();
3006         labels.push_back(label);
3007         if (c->fValue) {
3008             size += 2;
3009         } else {
3010             defaultLabel = label;
3011         }
3012     }
3013     labels.push_back(end);
3014     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
3015     this->writeOpCode(SpvOpSwitch, size, out);
3016     this->writeWord(value, out);
3017     this->writeWord(defaultLabel, out);
3018     for (size_t i = 0; i < s.fCases.size(); ++i) {
3019         if (!s.fCases[i]->fValue) {
3020             continue;
3021         }
3022         SkASSERT(s.fCases[i]->fValue->fKind == Expression::kIntLiteral_Kind);
3023         this->writeWord(((IntLiteral&) *s.fCases[i]->fValue).fValue, out);
3024         this->writeWord(labels[i], out);
3025     }
3026     for (size_t i = 0; i < s.fCases.size(); ++i) {
3027         this->writeLabel(labels[i], out);
3028         for (const auto& stmt : s.fCases[i]->fStatements) {
3029             this->writeStatement(*stmt, out);
3030         }
3031         if (fCurrentBlock) {
3032             this->writeInstruction(SpvOpBranch, labels[i + 1], out);
3033         }
3034     }
3035     this->writeLabel(end, out);
3036     fBreakTarget.pop();
3037 }
3038 
writeReturnStatement(const ReturnStatement & r,OutputStream & out)3039 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
3040     if (r.fExpression) {
3041         this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out),
3042                                out);
3043     } else {
3044         this->writeInstruction(SpvOpReturn, out);
3045     }
3046 }
3047 
writeGeometryShaderExecutionMode(SpvId entryPoint,OutputStream & out)3048 void SPIRVCodeGenerator::writeGeometryShaderExecutionMode(SpvId entryPoint, OutputStream& out) {
3049     SkASSERT(fProgram.fKind == Program::kGeometry_Kind);
3050     int invocations = 1;
3051     for (const auto& e : fProgram) {
3052         if (e.fKind == ProgramElement::kModifiers_Kind) {
3053             const Modifiers& m = ((ModifiersDeclaration&) e).fModifiers;
3054             if (m.fFlags & Modifiers::kIn_Flag) {
3055                 if (m.fLayout.fInvocations != -1) {
3056                     invocations = m.fLayout.fInvocations;
3057                 }
3058                 SpvId input;
3059                 switch (m.fLayout.fPrimitive) {
3060                     case Layout::kPoints_Primitive:
3061                         input = SpvExecutionModeInputPoints;
3062                         break;
3063                     case Layout::kLines_Primitive:
3064                         input = SpvExecutionModeInputLines;
3065                         break;
3066                     case Layout::kLinesAdjacency_Primitive:
3067                         input = SpvExecutionModeInputLinesAdjacency;
3068                         break;
3069                     case Layout::kTriangles_Primitive:
3070                         input = SpvExecutionModeTriangles;
3071                         break;
3072                     case Layout::kTrianglesAdjacency_Primitive:
3073                         input = SpvExecutionModeInputTrianglesAdjacency;
3074                         break;
3075                     default:
3076                         input = 0;
3077                         break;
3078                 }
3079                 update_sk_in_count(m, &fSkInCount);
3080                 if (input) {
3081                     this->writeInstruction(SpvOpExecutionMode, entryPoint, input, out);
3082                 }
3083             } else if (m.fFlags & Modifiers::kOut_Flag) {
3084                 SpvId output;
3085                 switch (m.fLayout.fPrimitive) {
3086                     case Layout::kPoints_Primitive:
3087                         output = SpvExecutionModeOutputPoints;
3088                         break;
3089                     case Layout::kLineStrip_Primitive:
3090                         output = SpvExecutionModeOutputLineStrip;
3091                         break;
3092                     case Layout::kTriangleStrip_Primitive:
3093                         output = SpvExecutionModeOutputTriangleStrip;
3094                         break;
3095                     default:
3096                         output = 0;
3097                         break;
3098                 }
3099                 if (output) {
3100                     this->writeInstruction(SpvOpExecutionMode, entryPoint, output, out);
3101                 }
3102                 if (m.fLayout.fMaxVertices != -1) {
3103                     this->writeInstruction(SpvOpExecutionMode, entryPoint,
3104                                            SpvExecutionModeOutputVertices, m.fLayout.fMaxVertices,
3105                                            out);
3106                 }
3107             }
3108         }
3109     }
3110     this->writeInstruction(SpvOpExecutionMode, entryPoint, SpvExecutionModeInvocations,
3111                            invocations, out);
3112 }
3113 
writeInstructions(const Program & program,OutputStream & out)3114 void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
3115     fGLSLExtendedInstructions = this->nextId();
3116     StringStream body;
3117     std::set<SpvId> interfaceVars;
3118     // assign IDs to functions, determine sk_in size
3119     int skInSize = -1;
3120     for (const auto& e : program) {
3121         switch (e.fKind) {
3122             case ProgramElement::kFunction_Kind: {
3123                 FunctionDefinition& f = (FunctionDefinition&) e;
3124                 fFunctionMap[&f.fDeclaration] = this->nextId();
3125                 break;
3126             }
3127             case ProgramElement::kModifiers_Kind: {
3128                 Modifiers& m = ((ModifiersDeclaration&) e).fModifiers;
3129                 if (m.fFlags & Modifiers::kIn_Flag) {
3130                     switch (m.fLayout.fPrimitive) {
3131                         case Layout::kPoints_Primitive: // break
3132                         case Layout::kLines_Primitive:
3133                             skInSize = 1;
3134                             break;
3135                         case Layout::kLinesAdjacency_Primitive: // break
3136                             skInSize = 2;
3137                             break;
3138                         case Layout::kTriangles_Primitive: // break
3139                         case Layout::kTrianglesAdjacency_Primitive:
3140                             skInSize = 3;
3141                             break;
3142                         default:
3143                             break;
3144                     }
3145                 }
3146                 break;
3147             }
3148             default:
3149                 break;
3150         }
3151     }
3152     for (const auto& e : program) {
3153         if (e.fKind == ProgramElement::kInterfaceBlock_Kind) {
3154             InterfaceBlock& intf = (InterfaceBlock&) e;
3155             if (SK_IN_BUILTIN == intf.fVariable.fModifiers.fLayout.fBuiltin) {
3156                 SkASSERT(skInSize != -1);
3157                 intf.fSizes.emplace_back(new IntLiteral(fContext, -1, skInSize));
3158             }
3159             SpvId id = this->writeInterfaceBlock(intf);
3160             if (((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) ||
3161                 (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) &&
3162                 intf.fVariable.fModifiers.fLayout.fBuiltin == -1) {
3163                 interfaceVars.insert(id);
3164             }
3165         }
3166     }
3167     for (const auto& e : program) {
3168         if (e.fKind == ProgramElement::kVar_Kind) {
3169             this->writeGlobalVars(program.fKind, ((VarDeclarations&) e), body);
3170         }
3171     }
3172     for (const auto& e : program) {
3173         if (e.fKind == ProgramElement::kFunction_Kind) {
3174             this->writeFunction(((FunctionDefinition&) e), body);
3175         }
3176     }
3177     const FunctionDeclaration* main = nullptr;
3178     for (auto entry : fFunctionMap) {
3179         if (entry.first->fName == "main") {
3180             main = entry.first;
3181         }
3182     }
3183     if (!main) {
3184         fErrors.error(0, "program does not contain a main() function");
3185         return;
3186     }
3187     for (auto entry : fVariableMap) {
3188         const Variable* var = entry.first;
3189         if (var->fStorage == Variable::kGlobal_Storage &&
3190             ((var->fModifiers.fFlags & Modifiers::kIn_Flag) ||
3191              (var->fModifiers.fFlags & Modifiers::kOut_Flag))) {
3192             interfaceVars.insert(entry.second);
3193         }
3194     }
3195     this->writeCapabilities(out);
3196     this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
3197     this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
3198     this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (main->fName.fLength + 4) / 4) +
3199                       (int32_t) interfaceVars.size(), out);
3200     switch (program.fKind) {
3201         case Program::kVertex_Kind:
3202             this->writeWord(SpvExecutionModelVertex, out);
3203             break;
3204         case Program::kFragment_Kind:
3205             this->writeWord(SpvExecutionModelFragment, out);
3206             break;
3207         case Program::kGeometry_Kind:
3208             this->writeWord(SpvExecutionModelGeometry, out);
3209             break;
3210         default:
3211             ABORT("cannot write this kind of program to SPIR-V\n");
3212     }
3213     SpvId entryPoint = fFunctionMap[main];
3214     this->writeWord(entryPoint, out);
3215     this->writeString(main->fName.fChars, main->fName.fLength, out);
3216     for (int var : interfaceVars) {
3217         this->writeWord(var, out);
3218     }
3219     if (program.fKind == Program::kGeometry_Kind) {
3220         this->writeGeometryShaderExecutionMode(entryPoint, out);
3221     }
3222     if (program.fKind == Program::kFragment_Kind) {
3223         this->writeInstruction(SpvOpExecutionMode,
3224                                fFunctionMap[main],
3225                                SpvExecutionModeOriginUpperLeft,
3226                                out);
3227     }
3228     for (const auto& e : program) {
3229         if (e.fKind == ProgramElement::kExtension_Kind) {
3230             this->writeInstruction(SpvOpSourceExtension, ((Extension&) e).fName.c_str(), out);
3231         }
3232     }
3233 
3234     write_stringstream(fExtraGlobalsBuffer, out);
3235     write_stringstream(fNameBuffer, out);
3236     write_stringstream(fDecorationBuffer, out);
3237     write_stringstream(fConstantBuffer, out);
3238     write_stringstream(fExternalFunctionsBuffer, out);
3239     write_stringstream(body, out);
3240 }
3241 
generateCode()3242 bool SPIRVCodeGenerator::generateCode() {
3243     SkASSERT(!fErrors.errorCount());
3244     this->writeWord(SpvMagicNumber, *fOut);
3245     this->writeWord(SpvVersion, *fOut);
3246     this->writeWord(SKSL_MAGIC, *fOut);
3247     StringStream buffer;
3248     this->writeInstructions(fProgram, buffer);
3249     this->writeWord(fIdCount, *fOut);
3250     this->writeWord(0, *fOut); // reserved, always zero
3251     write_stringstream(buffer, *fOut);
3252     return 0 == fErrors.errorCount();
3253 }
3254 
3255 }
3256