• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "SkSLSPIRVCodeGenerator.h"
9 
10 #include "GLSL.std.450.h"
11 
12 #include "ir/SkSLExpressionStatement.h"
13 #include "ir/SkSLExtension.h"
14 #include "ir/SkSLIndexExpression.h"
15 #include "ir/SkSLVariableReference.h"
16 #include "SkSLCompiler.h"
17 
18 namespace SkSL {
19 
// Generator magic number for SPIR-V produced by SkSL. Zero is a placeholder until a real value
// is registered (FIXME: we should probably register a magic number).
static constexpr int32_t SKSL_MAGIC = 0x0;
21 
setupIntrinsics()22 void SPIRVCodeGenerator::setupIntrinsics() {
23 #define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \
24                                     GLSLstd450 ## x, GLSLstd450 ## x)
25 #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \
26                                                              GLSLstd450 ## ifFloat, \
27                                                              GLSLstd450 ## ifInt, \
28                                                              GLSLstd450 ## ifUInt, \
29                                                              SpvOpUndef)
30 #define ALL_SPIRV(x) std::make_tuple(kSPIRV_IntrinsicKind, SpvOp ## x, SpvOp ## x, SpvOp ## x, \
31                                                            SpvOp ## x)
32 #define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \
33                                    k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
34                                    k ## x ## _SpecialIntrinsic)
35     fIntrinsicMap[String("round")]         = ALL_GLSL(Round);
36     fIntrinsicMap[String("roundEven")]     = ALL_GLSL(RoundEven);
37     fIntrinsicMap[String("trunc")]         = ALL_GLSL(Trunc);
38     fIntrinsicMap[String("abs")]           = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
39     fIntrinsicMap[String("sign")]          = BY_TYPE_GLSL(FSign, SSign, SSign);
40     fIntrinsicMap[String("floor")]         = ALL_GLSL(Floor);
41     fIntrinsicMap[String("ceil")]          = ALL_GLSL(Ceil);
42     fIntrinsicMap[String("fract")]         = ALL_GLSL(Fract);
43     fIntrinsicMap[String("radians")]       = ALL_GLSL(Radians);
44     fIntrinsicMap[String("degrees")]       = ALL_GLSL(Degrees);
45     fIntrinsicMap[String("sin")]           = ALL_GLSL(Sin);
46     fIntrinsicMap[String("cos")]           = ALL_GLSL(Cos);
47     fIntrinsicMap[String("tan")]           = ALL_GLSL(Tan);
48     fIntrinsicMap[String("asin")]          = ALL_GLSL(Asin);
49     fIntrinsicMap[String("acos")]          = ALL_GLSL(Acos);
50     fIntrinsicMap[String("atan")]          = SPECIAL(Atan);
51     fIntrinsicMap[String("sinh")]          = ALL_GLSL(Sinh);
52     fIntrinsicMap[String("cosh")]          = ALL_GLSL(Cosh);
53     fIntrinsicMap[String("tanh")]          = ALL_GLSL(Tanh);
54     fIntrinsicMap[String("asinh")]         = ALL_GLSL(Asinh);
55     fIntrinsicMap[String("acosh")]         = ALL_GLSL(Acosh);
56     fIntrinsicMap[String("atanh")]         = ALL_GLSL(Atanh);
57     fIntrinsicMap[String("pow")]           = ALL_GLSL(Pow);
58     fIntrinsicMap[String("exp")]           = ALL_GLSL(Exp);
59     fIntrinsicMap[String("log")]           = ALL_GLSL(Log);
60     fIntrinsicMap[String("exp2")]          = ALL_GLSL(Exp2);
61     fIntrinsicMap[String("log2")]          = ALL_GLSL(Log2);
62     fIntrinsicMap[String("sqrt")]          = ALL_GLSL(Sqrt);
63     fIntrinsicMap[String("inverse")]       = ALL_GLSL(MatrixInverse);
64     fIntrinsicMap[String("transpose")]     = ALL_SPIRV(Transpose);
65     fIntrinsicMap[String("inversesqrt")]   = ALL_GLSL(InverseSqrt);
66     fIntrinsicMap[String("determinant")]   = ALL_GLSL(Determinant);
67     fIntrinsicMap[String("matrixInverse")] = ALL_GLSL(MatrixInverse);
68     fIntrinsicMap[String("mod")]           = SPECIAL(Mod);
69     fIntrinsicMap[String("min")]           = SPECIAL(Min);
70     fIntrinsicMap[String("max")]           = SPECIAL(Max);
71     fIntrinsicMap[String("clamp")]         = SPECIAL(Clamp);
72     fIntrinsicMap[String("saturate")]      = SPECIAL(Saturate);
73     fIntrinsicMap[String("dot")]           = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot,
74                                                              SpvOpUndef, SpvOpUndef, SpvOpUndef);
75     fIntrinsicMap[String("mix")]           = SPECIAL(Mix);
76     fIntrinsicMap[String("step")]          = ALL_GLSL(Step);
77     fIntrinsicMap[String("smoothstep")]    = ALL_GLSL(SmoothStep);
78     fIntrinsicMap[String("fma")]           = ALL_GLSL(Fma);
79     fIntrinsicMap[String("frexp")]         = ALL_GLSL(Frexp);
80     fIntrinsicMap[String("ldexp")]         = ALL_GLSL(Ldexp);
81 
82 #define PACK(type) fIntrinsicMap[String("pack" #type)] = ALL_GLSL(Pack ## type); \
83                    fIntrinsicMap[String("unpack" #type)] = ALL_GLSL(Unpack ## type)
84     PACK(Snorm4x8);
85     PACK(Unorm4x8);
86     PACK(Snorm2x16);
87     PACK(Unorm2x16);
88     PACK(Half2x16);
89     PACK(Double2x32);
90     fIntrinsicMap[String("length")]      = ALL_GLSL(Length);
91     fIntrinsicMap[String("distance")]    = ALL_GLSL(Distance);
92     fIntrinsicMap[String("cross")]       = ALL_GLSL(Cross);
93     fIntrinsicMap[String("normalize")]   = ALL_GLSL(Normalize);
94     fIntrinsicMap[String("faceForward")] = ALL_GLSL(FaceForward);
95     fIntrinsicMap[String("reflect")]     = ALL_GLSL(Reflect);
96     fIntrinsicMap[String("refract")]     = ALL_GLSL(Refract);
97     fIntrinsicMap[String("findLSB")]     = ALL_GLSL(FindILsb);
98     fIntrinsicMap[String("findMSB")]     = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
99     fIntrinsicMap[String("dFdx")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx,
100                                                            SpvOpUndef, SpvOpUndef, SpvOpUndef);
101     fIntrinsicMap[String("dFdy")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
102                                                            SpvOpUndef, SpvOpUndef, SpvOpUndef);
103     fIntrinsicMap[String("fwidth")]      = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpFwidth,
104                                                            SpvOpUndef, SpvOpUndef, SpvOpUndef);
105     fIntrinsicMap[String("texture")]     = SPECIAL(Texture);
106     fIntrinsicMap[String("subpassLoad")] = SPECIAL(SubpassLoad);
107 
108     fIntrinsicMap[String("any")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
109                                                                 SpvOpUndef, SpvOpUndef, SpvOpAny);
110     fIntrinsicMap[String("all")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
111                                                                 SpvOpUndef, SpvOpUndef, SpvOpAll);
112     fIntrinsicMap[String("equal")]            = std::make_tuple(kSPIRV_IntrinsicKind,
113                                                                 SpvOpFOrdEqual, SpvOpIEqual,
114                                                                 SpvOpIEqual, SpvOpLogicalEqual);
115     fIntrinsicMap[String("notEqual")]         = std::make_tuple(kSPIRV_IntrinsicKind,
116                                                                 SpvOpFOrdNotEqual, SpvOpINotEqual,
117                                                                 SpvOpINotEqual,
118                                                                 SpvOpLogicalNotEqual);
119     fIntrinsicMap[String("lessThan")]         = std::make_tuple(kSPIRV_IntrinsicKind,
120                                                                 SpvOpFOrdLessThan, SpvOpSLessThan,
121                                                                 SpvOpULessThan, SpvOpUndef);
122     fIntrinsicMap[String("lessThanEqual")]    = std::make_tuple(kSPIRV_IntrinsicKind,
123                                                                 SpvOpFOrdLessThanEqual,
124                                                                 SpvOpSLessThanEqual,
125                                                                 SpvOpULessThanEqual,
126                                                                 SpvOpUndef);
127     fIntrinsicMap[String("greaterThan")]      = std::make_tuple(kSPIRV_IntrinsicKind,
128                                                                 SpvOpFOrdGreaterThan,
129                                                                 SpvOpSGreaterThan,
130                                                                 SpvOpUGreaterThan,
131                                                                 SpvOpUndef);
132     fIntrinsicMap[String("greaterThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
133                                                                 SpvOpFOrdGreaterThanEqual,
134                                                                 SpvOpSGreaterThanEqual,
135                                                                 SpvOpUGreaterThanEqual,
136                                                                 SpvOpUndef);
137     fIntrinsicMap[String("EmitVertex")]       = ALL_SPIRV(EmitVertex);
138     fIntrinsicMap[String("EndPrimitive")]     = ALL_SPIRV(EndPrimitive);
139 // interpolateAt* not yet supported...
140 }
141 
writeWord(int32_t word,OutputStream & out)142 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
143     out.write((const char*) &word, sizeof(word));
144 }
145 
is_float(const Context & context,const Type & type)146 static bool is_float(const Context& context, const Type& type) {
147     if (type.columns() > 1) {
148         return is_float(context, type.componentType());
149     }
150     return type == *context.fFloat_Type || type == *context.fHalf_Type ||
151            type == *context.fDouble_Type;
152 }
153 
is_signed(const Context & context,const Type & type)154 static bool is_signed(const Context& context, const Type& type) {
155     if (type.kind() == Type::kVector_Kind) {
156         return is_signed(context, type.componentType());
157     }
158     return type == *context.fInt_Type || type == *context.fShort_Type ||
159            type == *context.fByte_Type;
160 }
161 
is_unsigned(const Context & context,const Type & type)162 static bool is_unsigned(const Context& context, const Type& type) {
163     if (type.kind() == Type::kVector_Kind) {
164         return is_unsigned(context, type.componentType());
165     }
166     return type == *context.fUInt_Type || type == *context.fUShort_Type ||
167            type == *context.fUByte_Type;
168 }
169 
is_bool(const Context & context,const Type & type)170 static bool is_bool(const Context& context, const Type& type) {
171     if (type.kind() == Type::kVector_Kind) {
172         return is_bool(context, type.componentType());
173     }
174     return type == *context.fBool_Type;
175 }
176 
is_out(const Variable & var)177 static bool is_out(const Variable& var) {
178     return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
179 }
180 
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)181 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
182     SkASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
183     SkASSERT(opCode != SpvOpUndef);
184     switch (opCode) {
185         case SpvOpReturn:      // fall through
186         case SpvOpReturnValue: // fall through
187         case SpvOpKill:        // fall through
188         case SpvOpBranch:      // fall through
189         case SpvOpBranchConditional:
190             SkASSERT(fCurrentBlock);
191             fCurrentBlock = 0;
192             break;
193         case SpvOpConstant:          // fall through
194         case SpvOpConstantTrue:      // fall through
195         case SpvOpConstantFalse:     // fall through
196         case SpvOpConstantComposite: // fall through
197         case SpvOpTypeVoid:          // fall through
198         case SpvOpTypeInt:           // fall through
199         case SpvOpTypeFloat:         // fall through
200         case SpvOpTypeBool:          // fall through
201         case SpvOpTypeVector:        // fall through
202         case SpvOpTypeMatrix:        // fall through
203         case SpvOpTypeArray:         // fall through
204         case SpvOpTypePointer:       // fall through
205         case SpvOpTypeFunction:      // fall through
206         case SpvOpTypeRuntimeArray:  // fall through
207         case SpvOpTypeStruct:        // fall through
208         case SpvOpTypeImage:         // fall through
209         case SpvOpTypeSampledImage:  // fall through
210         case SpvOpVariable:          // fall through
211         case SpvOpFunction:          // fall through
212         case SpvOpFunctionParameter: // fall through
213         case SpvOpFunctionEnd:       // fall through
214         case SpvOpExecutionMode:     // fall through
215         case SpvOpMemoryModel:       // fall through
216         case SpvOpCapability:        // fall through
217         case SpvOpExtInstImport:     // fall through
218         case SpvOpEntryPoint:        // fall through
219         case SpvOpSource:            // fall through
220         case SpvOpSourceExtension:   // fall through
221         case SpvOpName:              // fall through
222         case SpvOpMemberName:        // fall through
223         case SpvOpDecorate:          // fall through
224         case SpvOpMemberDecorate:
225             break;
226         default:
227             SkASSERT(fCurrentBlock);
228     }
229     this->writeWord((length << 16) | opCode, out);
230 }
231 
writeLabel(SpvId label,OutputStream & out)232 void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) {
233     fCurrentBlock = label;
234     this->writeInstruction(SpvOpLabel, label, out);
235 }
236 
writeInstruction(SpvOp_ opCode,OutputStream & out)237 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
238     this->writeOpCode(opCode, 1, out);
239 }
240 
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)241 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
242     this->writeOpCode(opCode, 2, out);
243     this->writeWord(word1, out);
244 }
245 
writeString(const char * string,size_t length,OutputStream & out)246 void SPIRVCodeGenerator::writeString(const char* string, size_t length, OutputStream& out) {
247     out.write(string, length);
248     switch (length % 4) {
249         case 1:
250             out.write8(0);
251             // fall through
252         case 2:
253             out.write8(0);
254             // fall through
255         case 3:
256             out.write8(0);
257             break;
258         default:
259             this->writeWord(0, out);
260     }
261 }
262 
writeInstruction(SpvOp_ opCode,StringFragment string,OutputStream & out)263 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, StringFragment string, OutputStream& out) {
264     this->writeOpCode(opCode, 1 + (string.fLength + 4) / 4, out);
265     this->writeString(string.fChars, string.fLength, out);
266 }
267 
268 
writeInstruction(SpvOp_ opCode,int32_t word1,StringFragment string,OutputStream & out)269 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, StringFragment string,
270                                           OutputStream& out) {
271     this->writeOpCode(opCode, 2 + (string.fLength + 4) / 4, out);
272     this->writeWord(word1, out);
273     this->writeString(string.fChars, string.fLength, out);
274 }
275 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,StringFragment string,OutputStream & out)276 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
277                                           StringFragment string, OutputStream& out) {
278     this->writeOpCode(opCode, 3 + (string.fLength + 4) / 4, out);
279     this->writeWord(word1, out);
280     this->writeWord(word2, out);
281     this->writeString(string.fChars, string.fLength, out);
282 }
283 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)284 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
285                                           OutputStream& out) {
286     this->writeOpCode(opCode, 3, out);
287     this->writeWord(word1, out);
288     this->writeWord(word2, out);
289 }
290 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)291 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
292                                           int32_t word3, OutputStream& out) {
293     this->writeOpCode(opCode, 4, out);
294     this->writeWord(word1, out);
295     this->writeWord(word2, out);
296     this->writeWord(word3, out);
297 }
298 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)299 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
300                                           int32_t word3, int32_t word4, OutputStream& out) {
301     this->writeOpCode(opCode, 5, out);
302     this->writeWord(word1, out);
303     this->writeWord(word2, out);
304     this->writeWord(word3, out);
305     this->writeWord(word4, out);
306 }
307 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)308 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
309                                           int32_t word3, int32_t word4, int32_t word5,
310                                           OutputStream& out) {
311     this->writeOpCode(opCode, 6, out);
312     this->writeWord(word1, out);
313     this->writeWord(word2, out);
314     this->writeWord(word3, out);
315     this->writeWord(word4, out);
316     this->writeWord(word5, out);
317 }
318 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)319 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
320                                           int32_t word3, int32_t word4, int32_t word5,
321                                           int32_t word6, OutputStream& out) {
322     this->writeOpCode(opCode, 7, out);
323     this->writeWord(word1, out);
324     this->writeWord(word2, out);
325     this->writeWord(word3, out);
326     this->writeWord(word4, out);
327     this->writeWord(word5, out);
328     this->writeWord(word6, out);
329 }
330 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)331 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
332                                           int32_t word3, int32_t word4, int32_t word5,
333                                           int32_t word6, int32_t word7, OutputStream& out) {
334     this->writeOpCode(opCode, 8, out);
335     this->writeWord(word1, out);
336     this->writeWord(word2, out);
337     this->writeWord(word3, out);
338     this->writeWord(word4, out);
339     this->writeWord(word5, out);
340     this->writeWord(word6, out);
341     this->writeWord(word7, out);
342 }
343 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)344 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
345                                           int32_t word3, int32_t word4, int32_t word5,
346                                           int32_t word6, int32_t word7, int32_t word8,
347                                           OutputStream& out) {
348     this->writeOpCode(opCode, 9, out);
349     this->writeWord(word1, out);
350     this->writeWord(word2, out);
351     this->writeWord(word3, out);
352     this->writeWord(word4, out);
353     this->writeWord(word5, out);
354     this->writeWord(word6, out);
355     this->writeWord(word7, out);
356     this->writeWord(word8, out);
357 }
358 
writeCapabilities(OutputStream & out)359 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
360     for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
361         if (fCapabilities & bit) {
362             this->writeInstruction(SpvOpCapability, (SpvId) i, out);
363         }
364     }
365     if (fProgram.fKind == Program::kGeometry_Kind) {
366         this->writeInstruction(SpvOpCapability, SpvCapabilityGeometry, out);
367     }
368     else {
369         this->writeInstruction(SpvOpCapability, SpvCapabilityShader, out);
370     }
371 }
372 
nextId()373 SpvId SPIRVCodeGenerator::nextId() {
374     return fIdCount++;
375 }
376 
writeStruct(const Type & type,const MemoryLayout & memoryLayout,SpvId resultId)377 void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout,
378                                      SpvId resultId) {
379     this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer);
380     // go ahead and write all of the field types, so we don't inadvertently write them while we're
381     // in the middle of writing the struct instruction
382     std::vector<SpvId> types;
383     for (const auto& f : type.fields()) {
384         types.push_back(this->getType(*f.fType, memoryLayout));
385     }
386     this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
387     this->writeWord(resultId, fConstantBuffer);
388     for (SpvId id : types) {
389         this->writeWord(id, fConstantBuffer);
390     }
391     size_t offset = 0;
392     for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
393         size_t size = memoryLayout.size(*type.fields()[i].fType);
394         size_t alignment = memoryLayout.alignment(*type.fields()[i].fType);
395         const Layout& fieldLayout = type.fields()[i].fModifiers.fLayout;
396         if (fieldLayout.fOffset >= 0) {
397             if (fieldLayout.fOffset < (int) offset) {
398                 fErrors.error(type.fOffset,
399                               "offset of field '" + type.fields()[i].fName + "' must be at "
400                               "least " + to_string((int) offset));
401             }
402             if (fieldLayout.fOffset % alignment) {
403                 fErrors.error(type.fOffset,
404                               "offset of field '" + type.fields()[i].fName + "' must be a multiple"
405                               " of " + to_string((int) alignment));
406             }
407             offset = fieldLayout.fOffset;
408         } else {
409             size_t mod = offset % alignment;
410             if (mod) {
411                 offset += alignment - mod;
412             }
413         }
414         this->writeInstruction(SpvOpMemberName, resultId, i, type.fields()[i].fName, fNameBuffer);
415         this->writeLayout(fieldLayout, resultId, i);
416         if (type.fields()[i].fModifiers.fLayout.fBuiltin < 0) {
417             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
418                                    (SpvId) offset, fDecorationBuffer);
419         }
420         if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) {
421             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
422                                    fDecorationBuffer);
423             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
424                                    (SpvId) memoryLayout.stride(*type.fields()[i].fType),
425                                    fDecorationBuffer);
426         }
427         offset += size;
428         Type::Kind kind = type.fields()[i].fType->kind();
429         if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) {
430             offset += alignment - offset % alignment;
431         }
432     }
433 }
434 
// Returns the type that SPIR-V is actually generated for. SkSL's reduced-precision and
// small-integer types are widened (half -> float, short/byte -> int, ushort/ubyte -> uint),
// both as scalars and as the component type of vectors and matrices. Every other type is
// returned unchanged.
Type SPIRVCodeGenerator::getActualType(const Type& type) {
    const bool isCompound = type.kind() == Type::kMatrix_Kind ||
                            type.kind() == Type::kVector_Kind;
    // Vectors and matrices are judged by their component type; anything else by itself.
    const Type& base = isCompound ? type.componentType() : type;
    if (base == *fContext.fHalf_Type) {
        return isCompound
                ? fContext.fFloat_Type->toCompound(fContext, type.columns(), type.rows())
                : *fContext.fFloat_Type;
    }
    if (base == *fContext.fShort_Type || base == *fContext.fByte_Type) {
        return isCompound
                ? fContext.fInt_Type->toCompound(fContext, type.columns(), type.rows())
                : *fContext.fInt_Type;
    }
    if (base == *fContext.fUShort_Type || base == *fContext.fUByte_Type) {
        return isCompound
                ? fContext.fUInt_Type->toCompound(fContext, type.columns(), type.rows())
                : *fContext.fUInt_Type;
    }
    return type;
}
460 
getType(const Type & type)461 SpvId SPIRVCodeGenerator::getType(const Type& type) {
462     return this->getType(type, fDefaultLayout);
463 }
464 
getType(const Type & rawType,const MemoryLayout & layout)465 SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layout) {
466     Type type = this->getActualType(rawType);
467     String key = type.name() + to_string((int) layout.fStd);
468     auto entry = fTypeMap.find(key);
469     if (entry == fTypeMap.end()) {
470         SpvId result = this->nextId();
471         switch (type.kind()) {
472             case Type::kScalar_Kind:
473                 if (type == *fContext.fBool_Type) {
474                     this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
475                 } else if (type == *fContext.fInt_Type) {
476                     this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
477                 } else if (type == *fContext.fUInt_Type) {
478                     this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
479                 } else if (type == *fContext.fFloat_Type) {
480                     this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
481                 } else if (type == *fContext.fDouble_Type) {
482                     this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer);
483                 } else {
484                     SkASSERT(false);
485                 }
486                 break;
487             case Type::kVector_Kind:
488                 this->writeInstruction(SpvOpTypeVector, result,
489                                        this->getType(type.componentType(), layout),
490                                        type.columns(), fConstantBuffer);
491                 break;
492             case Type::kMatrix_Kind:
493                 this->writeInstruction(SpvOpTypeMatrix, result,
494                                        this->getType(index_type(fContext, type), layout),
495                                        type.columns(), fConstantBuffer);
496                 break;
497             case Type::kStruct_Kind:
498                 this->writeStruct(type, layout, result);
499                 break;
500             case Type::kArray_Kind: {
501                 if (type.columns() > 0) {
502                     IntLiteral count(fContext, -1, type.columns());
503                     this->writeInstruction(SpvOpTypeArray, result,
504                                            this->getType(type.componentType(), layout),
505                                            this->writeIntLiteral(count), fConstantBuffer);
506                     this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
507                                            (int32_t) layout.stride(type),
508                                            fDecorationBuffer);
509                 } else {
510                     SkASSERT(false); // we shouldn't have any runtime-sized arrays right now
511                     this->writeInstruction(SpvOpTypeRuntimeArray, result,
512                                            this->getType(type.componentType(), layout),
513                                            fConstantBuffer);
514                     this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
515                                            (int32_t) layout.stride(type),
516                                            fDecorationBuffer);
517                 }
518                 break;
519             }
520             case Type::kSampler_Kind: {
521                 SpvId image = result;
522                 if (SpvDimSubpassData != type.dimensions()) {
523                     image = this->nextId();
524                 }
525                 if (SpvDimBuffer == type.dimensions()) {
526                     fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer);
527                 }
528                 this->writeInstruction(SpvOpTypeImage, image,
529                                        this->getType(*fContext.fFloat_Type, layout),
530                                        type.dimensions(), type.isDepth(), type.isArrayed(),
531                                        type.isMultisampled(), type.isSampled() ? 1 : 2,
532                                        SpvImageFormatUnknown, fConstantBuffer);
533                 fImageTypeMap[key] = image;
534                 if (SpvDimSubpassData != type.dimensions()) {
535                     this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
536                 }
537                 break;
538             }
539             default:
540                 if (type == *fContext.fVoid_Type) {
541                     this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
542                 } else {
543                     ABORT("invalid type: %s", type.description().c_str());
544                 }
545         }
546         fTypeMap[key] = result;
547         return result;
548     }
549     return entry->second;
550 }
551 
getImageType(const Type & type)552 SpvId SPIRVCodeGenerator::getImageType(const Type& type) {
553     SkASSERT(type.kind() == Type::kSampler_Kind);
554     this->getType(type);
555     String key = type.name() + to_string((int) fDefaultLayout.fStd);
556     SkASSERT(fImageTypeMap.find(key) != fImageTypeMap.end());
557     return fImageTypeMap[key];
558 }
559 
getFunctionType(const FunctionDeclaration & function)560 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
561     String key = function.fReturnType.description() + "(";
562     String separator;
563     for (size_t i = 0; i < function.fParameters.size(); i++) {
564         key += separator;
565         separator = ", ";
566         key += function.fParameters[i]->fType.description();
567     }
568     key += ")";
569     auto entry = fTypeMap.find(key);
570     if (entry == fTypeMap.end()) {
571         SpvId result = this->nextId();
572         int32_t length = 3 + (int32_t) function.fParameters.size();
573         SpvId returnType = this->getType(function.fReturnType);
574         std::vector<SpvId> parameterTypes;
575         for (size_t i = 0; i < function.fParameters.size(); i++) {
576             // glslang seems to treat all function arguments as pointers whether they need to be or
577             // not. I  was initially puzzled by this until I ran bizarre failures with certain
578             // patterns of function calls and control constructs, as exemplified by this minimal
579             // failure case:
580             //
581             // void sphere(float x) {
582             // }
583             //
584             // void map() {
585             //     sphere(1.0);
586             // }
587             //
588             // void main() {
589             //     for (int i = 0; i < 1; i++) {
590             //         map();
591             //     }
592             // }
593             //
594             // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
595             // crashes. Making it take a float* and storing the argument in a temporary variable,
596             // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
597             // the spec makes this make sense.
598 //            if (is_out(function->fParameters[i])) {
599                 parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType,
600                                                               SpvStorageClassFunction));
601 //            } else {
602 //                parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
603 //            }
604         }
605         this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
606         this->writeWord(result, fConstantBuffer);
607         this->writeWord(returnType, fConstantBuffer);
608         for (SpvId id : parameterTypes) {
609             this->writeWord(id, fConstantBuffer);
610         }
611         fTypeMap[key] = result;
612         return result;
613     }
614     return entry->second;
615 }
616 
getPointerType(const Type & type,SpvStorageClass_ storageClass)617 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
618     return this->getPointerType(type, fDefaultLayout, storageClass);
619 }
620 
getPointerType(const Type & rawType,const MemoryLayout & layout,SpvStorageClass_ storageClass)621 SpvId SPIRVCodeGenerator::getPointerType(const Type& rawType, const MemoryLayout& layout,
622                                          SpvStorageClass_ storageClass) {
623     Type type = this->getActualType(rawType);
624     String key = type.description() + "*" + to_string(layout.fStd) + to_string(storageClass);
625     auto entry = fTypeMap.find(key);
626     if (entry == fTypeMap.end()) {
627         SpvId result = this->nextId();
628         this->writeInstruction(SpvOpTypePointer, result, storageClass,
629                                this->getType(type), fConstantBuffer);
630         fTypeMap[key] = result;
631         return result;
632     }
633     return entry->second;
634 }
635 
writeExpression(const Expression & expr,OutputStream & out)636 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
637     switch (expr.fKind) {
638         case Expression::kBinary_Kind:
639             return this->writeBinaryExpression((BinaryExpression&) expr, out);
640         case Expression::kBoolLiteral_Kind:
641             return this->writeBoolLiteral((BoolLiteral&) expr);
642         case Expression::kConstructor_Kind:
643             return this->writeConstructor((Constructor&) expr, out);
644         case Expression::kIntLiteral_Kind:
645             return this->writeIntLiteral((IntLiteral&) expr);
646         case Expression::kFieldAccess_Kind:
647             return this->writeFieldAccess(((FieldAccess&) expr), out);
648         case Expression::kFloatLiteral_Kind:
649             return this->writeFloatLiteral(((FloatLiteral&) expr));
650         case Expression::kFunctionCall_Kind:
651             return this->writeFunctionCall((FunctionCall&) expr, out);
652         case Expression::kPrefix_Kind:
653             return this->writePrefixExpression((PrefixExpression&) expr, out);
654         case Expression::kPostfix_Kind:
655             return this->writePostfixExpression((PostfixExpression&) expr, out);
656         case Expression::kSwizzle_Kind:
657             return this->writeSwizzle((Swizzle&) expr, out);
658         case Expression::kVariableReference_Kind:
659             return this->writeVariableReference((VariableReference&) expr, out);
660         case Expression::kTernary_Kind:
661             return this->writeTernaryExpression((TernaryExpression&) expr, out);
662         case Expression::kIndex_Kind:
663             return this->writeIndexExpression((IndexExpression&) expr, out);
664         default:
665             ABORT("unsupported expression: %s", expr.description().c_str());
666     }
667     return -1;
668 }
669 
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)670 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
671     auto intrinsic = fIntrinsicMap.find(c.fFunction.fName);
672     SkASSERT(intrinsic != fIntrinsicMap.end());
673     int32_t intrinsicId;
674     if (c.fArguments.size() > 0) {
675         const Type& type = c.fArguments[0]->fType;
676         if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) {
677             intrinsicId = std::get<1>(intrinsic->second);
678         } else if (is_signed(fContext, type)) {
679             intrinsicId = std::get<2>(intrinsic->second);
680         } else if (is_unsigned(fContext, type)) {
681             intrinsicId = std::get<3>(intrinsic->second);
682         } else if (is_bool(fContext, type)) {
683             intrinsicId = std::get<4>(intrinsic->second);
684         } else {
685             intrinsicId = std::get<1>(intrinsic->second);
686         }
687     } else {
688         intrinsicId = std::get<1>(intrinsic->second);
689     }
690     switch (std::get<0>(intrinsic->second)) {
691         case kGLSL_STD_450_IntrinsicKind: {
692             SpvId result = this->nextId();
693             std::vector<SpvId> arguments;
694             for (size_t i = 0; i < c.fArguments.size(); i++) {
695                 if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
696                     arguments.push_back(this->getLValue(*c.fArguments[i], out)->getPointer());
697                 } else {
698                     arguments.push_back(this->writeExpression(*c.fArguments[i], out));
699                 }
700             }
701             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
702             this->writeWord(this->getType(c.fType), out);
703             this->writeWord(result, out);
704             this->writeWord(fGLSLExtendedInstructions, out);
705             this->writeWord(intrinsicId, out);
706             for (SpvId id : arguments) {
707                 this->writeWord(id, out);
708             }
709             return result;
710         }
711         case kSPIRV_IntrinsicKind: {
712             SpvId result = this->nextId();
713             std::vector<SpvId> arguments;
714             for (size_t i = 0; i < c.fArguments.size(); i++) {
715                 if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
716                     arguments.push_back(this->getLValue(*c.fArguments[i], out)->getPointer());
717                 } else {
718                     arguments.push_back(this->writeExpression(*c.fArguments[i], out));
719                 }
720             }
721             if (c.fType != *fContext.fVoid_Type) {
722                 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
723                 this->writeWord(this->getType(c.fType), out);
724                 this->writeWord(result, out);
725             } else {
726                 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
727             }
728             for (SpvId id : arguments) {
729                 this->writeWord(id, out);
730             }
731             return result;
732         }
733         case kSpecial_IntrinsicKind:
734             return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
735         default:
736             ABORT("unsupported intrinsic kind");
737     }
738 }
739 
vectorize(const std::vector<std::unique_ptr<Expression>> & args,OutputStream & out)740 std::vector<SpvId> SPIRVCodeGenerator::vectorize(
741                                                const std::vector<std::unique_ptr<Expression>>& args,
742                                                OutputStream& out) {
743     int vectorSize = 0;
744     for (const auto& a : args) {
745         if (a->fType.kind() == Type::kVector_Kind) {
746             if (vectorSize) {
747                 SkASSERT(a->fType.columns() == vectorSize);
748             }
749             else {
750                 vectorSize = a->fType.columns();
751             }
752         }
753     }
754     std::vector<SpvId> result;
755     for (const auto& a : args) {
756         SpvId raw = this->writeExpression(*a, out);
757         if (vectorSize && a->fType.kind() == Type::kScalar_Kind) {
758             SpvId vector = this->nextId();
759             this->writeOpCode(SpvOpCompositeConstruct, 3 + vectorSize, out);
760             this->writeWord(this->getType(a->fType.toCompound(fContext, vectorSize, 1)), out);
761             this->writeWord(vector, out);
762             for (int i = 0; i < vectorSize; i++) {
763                 this->writeWord(raw, out);
764             }
765             result.push_back(vector);
766         } else {
767             result.push_back(raw);
768         }
769     }
770     return result;
771 }
772 
writeGLSLExtendedInstruction(const Type & type,SpvId id,SpvId floatInst,SpvId signedInst,SpvId unsignedInst,const std::vector<SpvId> & args,OutputStream & out)773 void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
774                                                       SpvId signedInst, SpvId unsignedInst,
775                                                       const std::vector<SpvId>& args,
776                                                       OutputStream& out) {
777     this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
778     this->writeWord(this->getType(type), out);
779     this->writeWord(id, out);
780     this->writeWord(fGLSLExtendedInstructions, out);
781 
782     if (is_float(fContext, type)) {
783         this->writeWord(floatInst, out);
784     } else if (is_signed(fContext, type)) {
785         this->writeWord(signedInst, out);
786     } else if (is_unsigned(fContext, type)) {
787         this->writeWord(unsignedInst, out);
788     } else {
789         SkASSERT(false);
790     }
791     for (SpvId a : args) {
792         this->writeWord(a, out);
793     }
794 }
795 
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)796 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
797                                                 OutputStream& out) {
798     SpvId result = this->nextId();
799     switch (kind) {
800         case kAtan_SpecialIntrinsic: {
801             std::vector<SpvId> arguments;
802             for (size_t i = 0; i < c.fArguments.size(); i++) {
803                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
804             }
805             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
806             this->writeWord(this->getType(c.fType), out);
807             this->writeWord(result, out);
808             this->writeWord(fGLSLExtendedInstructions, out);
809             this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
810             for (SpvId id : arguments) {
811                 this->writeWord(id, out);
812             }
813             break;
814         }
815         case kSubpassLoad_SpecialIntrinsic: {
816             SpvId img = this->writeExpression(*c.fArguments[0], out);
817             std::vector<std::unique_ptr<Expression>> args;
818             args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
819             args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
820             Constructor ctor(-1, *fContext.fFloat2_Type, std::move(args));
821             SpvId coords = this->writeConstantVector(ctor);
822             if (1 == c.fArguments.size()) {
823                 this->writeInstruction(SpvOpImageRead,
824                                        this->getType(c.fType),
825                                        result,
826                                        img,
827                                        coords,
828                                        out);
829             } else {
830                 SkASSERT(2 == c.fArguments.size());
831                 SpvId sample = this->writeExpression(*c.fArguments[1], out);
832                 this->writeInstruction(SpvOpImageRead,
833                                        this->getType(c.fType),
834                                        result,
835                                        img,
836                                        coords,
837                                        SpvImageOperandsSampleMask,
838                                        sample,
839                                        out);
840             }
841             break;
842         }
843         case kTexture_SpecialIntrinsic: {
844             SpvOp_ op = SpvOpImageSampleImplicitLod;
845             switch (c.fArguments[0]->fType.dimensions()) {
846                 case SpvDim1D:
847                     if (c.fArguments[1]->fType == *fContext.fFloat2_Type) {
848                         op = SpvOpImageSampleProjImplicitLod;
849                     } else {
850                         SkASSERT(c.fArguments[1]->fType == *fContext.fFloat_Type);
851                     }
852                     break;
853                 case SpvDim2D:
854                     if (c.fArguments[1]->fType == *fContext.fFloat3_Type) {
855                         op = SpvOpImageSampleProjImplicitLod;
856                     } else {
857                         SkASSERT(c.fArguments[1]->fType == *fContext.fFloat2_Type);
858                     }
859                     break;
860                 case SpvDim3D:
861                     if (c.fArguments[1]->fType == *fContext.fFloat4_Type) {
862                         op = SpvOpImageSampleProjImplicitLod;
863                     } else {
864                         SkASSERT(c.fArguments[1]->fType == *fContext.fFloat3_Type);
865                     }
866                     break;
867                 case SpvDimCube:   // fall through
868                 case SpvDimRect:   // fall through
869                 case SpvDimBuffer: // fall through
870                 case SpvDimSubpassData:
871                     break;
872             }
873             SpvId type = this->getType(c.fType);
874             SpvId sampler = this->writeExpression(*c.fArguments[0], out);
875             SpvId uv = this->writeExpression(*c.fArguments[1], out);
876             if (c.fArguments.size() == 3) {
877                 this->writeInstruction(op, type, result, sampler, uv,
878                                        SpvImageOperandsBiasMask,
879                                        this->writeExpression(*c.fArguments[2], out),
880                                        out);
881             } else {
882                 SkASSERT(c.fArguments.size() == 2);
883                 if (fProgram.fSettings.fSharpenTextures) {
884                     FloatLiteral lodBias(fContext, -1, -0.5);
885                     this->writeInstruction(op, type, result, sampler, uv,
886                                            SpvImageOperandsBiasMask,
887                                            this->writeFloatLiteral(lodBias),
888                                            out);
889                 } else {
890                     this->writeInstruction(op, type, result, sampler, uv,
891                                            out);
892                 }
893             }
894             break;
895         }
896         case kMod_SpecialIntrinsic: {
897             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
898             SkASSERT(args.size() == 2);
899             const Type& operandType = c.fArguments[0]->fType;
900             SpvOp_ op;
901             if (is_float(fContext, operandType)) {
902                 op = SpvOpFMod;
903             } else if (is_signed(fContext, operandType)) {
904                 op = SpvOpSMod;
905             } else if (is_unsigned(fContext, operandType)) {
906                 op = SpvOpUMod;
907             } else {
908                 SkASSERT(false);
909                 return 0;
910             }
911             this->writeOpCode(op, 5, out);
912             this->writeWord(this->getType(operandType), out);
913             this->writeWord(result, out);
914             this->writeWord(args[0], out);
915             this->writeWord(args[1], out);
916             break;
917         }
918         case kClamp_SpecialIntrinsic: {
919             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
920             SkASSERT(args.size() == 3);
921             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FClamp, GLSLstd450SClamp,
922                                                GLSLstd450UClamp, args, out);
923             break;
924         }
925         case kMax_SpecialIntrinsic: {
926             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
927             SkASSERT(args.size() == 2);
928             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMax, GLSLstd450SMax,
929                                                GLSLstd450UMax, args, out);
930             break;
931         }
932         case kMin_SpecialIntrinsic: {
933             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
934             SkASSERT(args.size() == 2);
935             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMin, GLSLstd450SMin,
936                                                GLSLstd450UMin, args, out);
937             break;
938         }
939         case kMix_SpecialIntrinsic: {
940             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
941             SkASSERT(args.size() == 3);
942             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMix, SpvOpUndef,
943                                                SpvOpUndef, args, out);
944             break;
945         }
946         case kSaturate_SpecialIntrinsic: {
947             SkASSERT(c.fArguments.size() == 1);
948             std::vector<std::unique_ptr<Expression>> finalArgs;
949             finalArgs.push_back(c.fArguments[0]->clone());
950             finalArgs.emplace_back(new FloatLiteral(fContext, -1, 0));
951             finalArgs.emplace_back(new FloatLiteral(fContext, -1, 1));
952             std::vector<SpvId> spvArgs = this->vectorize(finalArgs, out);
953             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FClamp, GLSLstd450SClamp,
954                                                GLSLstd450UClamp, spvArgs, out);
955             break;
956         }
957     }
958     return result;
959 }
960 
writeFunctionCall(const FunctionCall & c,OutputStream & out)961 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
962     const auto& entry = fFunctionMap.find(&c.fFunction);
963     if (entry == fFunctionMap.end()) {
964         return this->writeIntrinsicCall(c, out);
965     }
966     // stores (variable, type, lvalue) pairs to extract and save after the function call is complete
967     std::vector<std::tuple<SpvId, SpvId, std::unique_ptr<LValue>>> lvalues;
968     std::vector<SpvId> arguments;
969     for (size_t i = 0; i < c.fArguments.size(); i++) {
970         // id of temporary variable that we will use to hold this argument, or 0 if it is being
971         // passed directly
972         SpvId tmpVar;
973         // if we need a temporary var to store this argument, this is the value to store in the var
974         SpvId tmpValueId;
975         if (is_out(*c.fFunction.fParameters[i])) {
976             std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out);
977             SpvId ptr = lv->getPointer();
978             if (ptr) {
979                 arguments.push_back(ptr);
980                 continue;
981             } else {
982                 // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to
983                 // copy it into a temp, call the function, read the value out of the temp, and then
984                 // update the lvalue.
985                 tmpValueId = lv->load(out);
986                 tmpVar = this->nextId();
987                 lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType),
988                                   std::move(lv)));
989             }
990         } else {
991             // see getFunctionType for an explanation of why we're always using pointer parameters
992             tmpValueId = this->writeExpression(*c.fArguments[i], out);
993             tmpVar = this->nextId();
994         }
995         this->writeInstruction(SpvOpVariable,
996                                this->getPointerType(c.fArguments[i]->fType,
997                                                     SpvStorageClassFunction),
998                                tmpVar,
999                                SpvStorageClassFunction,
1000                                fVariableBuffer);
1001         this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
1002         arguments.push_back(tmpVar);
1003     }
1004     SpvId result = this->nextId();
1005     this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out);
1006     this->writeWord(this->getType(c.fType), out);
1007     this->writeWord(result, out);
1008     this->writeWord(entry->second, out);
1009     for (SpvId id : arguments) {
1010         this->writeWord(id, out);
1011     }
1012     // now that the call is complete, we may need to update some lvalues with the new values of out
1013     // arguments
1014     for (const auto& tuple : lvalues) {
1015         SpvId load = this->nextId();
1016         this->writeInstruction(SpvOpLoad, std::get<1>(tuple), load, std::get<0>(tuple), out);
1017         std::get<2>(tuple)->store(load, out);
1018     }
1019     return result;
1020 }
1021 
writeConstantVector(const Constructor & c)1022 SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) {
1023     SkASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant());
1024     SpvId result = this->nextId();
1025     std::vector<SpvId> arguments;
1026     for (size_t i = 0; i < c.fArguments.size(); i++) {
1027         arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer));
1028     }
1029     SpvId type = this->getType(c.fType);
1030     if (c.fArguments.size() == 1) {
1031         // with a single argument, a vector will have all of its entries equal to the argument
1032         this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer);
1033         this->writeWord(type, fConstantBuffer);
1034         this->writeWord(result, fConstantBuffer);
1035         for (int i = 0; i < c.fType.columns(); i++) {
1036             this->writeWord(arguments[0], fConstantBuffer);
1037         }
1038     } else {
1039         this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(),
1040                           fConstantBuffer);
1041         this->writeWord(type, fConstantBuffer);
1042         this->writeWord(result, fConstantBuffer);
1043         for (SpvId id : arguments) {
1044             this->writeWord(id, fConstantBuffer);
1045         }
1046     }
1047     return result;
1048 }
1049 
writeFloatConstructor(const Constructor & c,OutputStream & out)1050 SpvId SPIRVCodeGenerator::writeFloatConstructor(const Constructor& c, OutputStream& out) {
1051     SkASSERT(c.fType.isFloat());
1052     SkASSERT(c.fArguments.size() == 1);
1053     SkASSERT(c.fArguments[0]->fType.isNumber());
1054     SpvId result = this->nextId();
1055     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1056     if (c.fArguments[0]->fType.isSigned()) {
1057         this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter,
1058                                out);
1059     } else {
1060         SkASSERT(c.fArguments[0]->fType.isUnsigned());
1061         this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter,
1062                                out);
1063     }
1064     return result;
1065 }
1066 
writeIntConstructor(const Constructor & c,OutputStream & out)1067 SpvId SPIRVCodeGenerator::writeIntConstructor(const Constructor& c, OutputStream& out) {
1068     SkASSERT(c.fType.isSigned());
1069     SkASSERT(c.fArguments.size() == 1);
1070     SkASSERT(c.fArguments[0]->fType.isNumber());
1071     SpvId result = this->nextId();
1072     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1073     if (c.fArguments[0]->fType.isFloat()) {
1074         this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter,
1075                                out);
1076     }
1077     else {
1078         SkASSERT(c.fArguments[0]->fType.isUnsigned());
1079         this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1080                                out);
1081     }
1082     return result;
1083 }
1084 
writeUIntConstructor(const Constructor & c,OutputStream & out)1085 SpvId SPIRVCodeGenerator::writeUIntConstructor(const Constructor& c, OutputStream& out) {
1086     SkASSERT(c.fType.isUnsigned());
1087     SkASSERT(c.fArguments.size() == 1);
1088     SkASSERT(c.fArguments[0]->fType.isNumber());
1089     SpvId result = this->nextId();
1090     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1091     if (c.fArguments[0]->fType.isFloat()) {
1092         this->writeInstruction(SpvOpConvertFToU, this->getType(c.fType), result, parameter,
1093                                out);
1094     } else {
1095         SkASSERT(c.fArguments[0]->fType.isSigned());
1096         this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1097                                out);
1098     }
1099     return result;
1100 }
1101 
writeUniformScaleMatrix(SpvId id,SpvId diagonal,const Type & type,OutputStream & out)1102 void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
1103                                                  OutputStream& out) {
1104     FloatLiteral zero(fContext, -1, 0);
1105     SpvId zeroId = this->writeFloatLiteral(zero);
1106     std::vector<SpvId> columnIds;
1107     for (int column = 0; column < type.columns(); column++) {
1108         this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
1109                           out);
1110         this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)),
1111                         out);
1112         SpvId columnId = this->nextId();
1113         this->writeWord(columnId, out);
1114         columnIds.push_back(columnId);
1115         for (int row = 0; row < type.columns(); row++) {
1116             this->writeWord(row == column ? diagonal : zeroId, out);
1117         }
1118     }
1119     this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
1120                       out);
1121     this->writeWord(this->getType(type), out);
1122     this->writeWord(id, out);
1123     for (SpvId id : columnIds) {
1124         this->writeWord(id, out);
1125     }
1126 }
1127 
writeMatrixCopy(SpvId id,SpvId src,const Type & srcType,const Type & dstType,OutputStream & out)1128 void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType,
1129                                          const Type& dstType, OutputStream& out) {
1130     SkASSERT(srcType.kind() == Type::kMatrix_Kind);
1131     SkASSERT(dstType.kind() == Type::kMatrix_Kind);
1132     SkASSERT(srcType.componentType() == dstType.componentType());
1133     SpvId srcColumnType = this->getType(srcType.componentType().toCompound(fContext,
1134                                                                            srcType.rows(),
1135                                                                            1));
1136     SpvId dstColumnType = this->getType(dstType.componentType().toCompound(fContext,
1137                                                                            dstType.rows(),
1138                                                                            1));
1139     SpvId zeroId;
1140     if (dstType.componentType() == *fContext.fFloat_Type) {
1141         FloatLiteral zero(fContext, -1, 0.0);
1142         zeroId = this->writeFloatLiteral(zero);
1143     } else if (dstType.componentType() == *fContext.fInt_Type) {
1144         IntLiteral zero(fContext, -1, 0);
1145         zeroId = this->writeIntLiteral(zero);
1146     } else {
1147         ABORT("unsupported matrix component type");
1148     }
1149     SpvId zeroColumn = 0;
1150     SpvId columns[4];
1151     for (int i = 0; i < dstType.columns(); i++) {
1152         if (i < srcType.columns()) {
1153             // we're still inside the src matrix, copy the column
1154             SpvId srcColumn = this->nextId();
1155             this->writeInstruction(SpvOpCompositeExtract, srcColumnType, srcColumn, src, i, out);
1156             SpvId dstColumn;
1157             if (srcType.rows() == dstType.rows()) {
1158                 // columns are equal size, don't need to do anything
1159                 dstColumn = srcColumn;
1160             }
1161             else if (dstType.rows() > srcType.rows()) {
1162                 // dst column is bigger, need to zero-pad it
1163                 dstColumn = this->nextId();
1164                 int delta = dstType.rows() - srcType.rows();
1165                 this->writeOpCode(SpvOpCompositeConstruct, 4 + delta, out);
1166                 this->writeWord(dstColumnType, out);
1167                 this->writeWord(dstColumn, out);
1168                 this->writeWord(srcColumn, out);
1169                 for (int i = 0; i < delta; ++i) {
1170                     this->writeWord(zeroId, out);
1171                 }
1172             }
1173             else {
1174                 // dst column is smaller, need to swizzle the src column
1175                 dstColumn = this->nextId();
1176                 int count = dstType.rows();
1177                 this->writeOpCode(SpvOpVectorShuffle, 5 + count, out);
1178                 this->writeWord(dstColumnType, out);
1179                 this->writeWord(dstColumn, out);
1180                 this->writeWord(srcColumn, out);
1181                 this->writeWord(srcColumn, out);
1182                 for (int i = 0; i < count; i++) {
1183                     this->writeWord(i, out);
1184                 }
1185             }
1186             columns[i] = dstColumn;
1187         } else {
1188             // we're past the end of the src matrix, need a vector of zeroes
1189             if (!zeroColumn) {
1190                 zeroColumn = this->nextId();
1191                 this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.rows(), out);
1192                 this->writeWord(dstColumnType, out);
1193                 this->writeWord(zeroColumn, out);
1194                 for (int i = 0; i < dstType.rows(); ++i) {
1195                     this->writeWord(zeroId, out);
1196                 }
1197             }
1198             columns[i] = zeroColumn;
1199         }
1200     }
1201     this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.columns(), out);
1202     this->writeWord(this->getType(dstType), out);
1203     this->writeWord(id, out);
1204     for (int i = 0; i < dstType.columns(); i++) {
1205         this->writeWord(columns[i], out);
1206     }
1207 }
1208 
writeMatrixConstructor(const Constructor & c,OutputStream & out)1209 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, OutputStream& out) {
1210     SkASSERT(c.fType.kind() == Type::kMatrix_Kind);
1211     // go ahead and write the arguments so we don't try to write new instructions in the middle of
1212     // an instruction
1213     std::vector<SpvId> arguments;
1214     for (size_t i = 0; i < c.fArguments.size(); i++) {
1215         arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1216     }
1217     SpvId result = this->nextId();
1218     int rows = c.fType.rows();
1219     int columns = c.fType.columns();
1220     if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1221         this->writeUniformScaleMatrix(result, arguments[0], c.fType, out);
1222     } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) {
1223         this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out);
1224     } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kVector_Kind) {
1225         SkASSERT(c.fType.rows() == 2 && c.fType.columns() == 2);
1226         SkASSERT(c.fArguments[0]->fType.columns() == 4);
1227         SpvId componentType = this->getType(c.fType.componentType());
1228         SpvId v[4];
1229         for (int i = 0; i < 4; ++i) {
1230             v[i] = this->nextId();
1231             this->writeInstruction(SpvOpCompositeExtract, componentType, v[i], arguments[0], i, out);
1232         }
1233         SpvId columnType = this->getType(c.fType.componentType().toCompound(fContext, 2, 1));
1234         SpvId column1 = this->nextId();
1235         this->writeInstruction(SpvOpCompositeConstruct, columnType, column1, v[0], v[1], out);
1236         SpvId column2 = this->nextId();
1237         this->writeInstruction(SpvOpCompositeConstruct, columnType, column2, v[2], v[3], out);
1238         this->writeInstruction(SpvOpCompositeConstruct, this->getType(c.fType), result, column1,
1239                                column2, out);
1240     } else {
1241         std::vector<SpvId> columnIds;
1242         // ids of vectors and scalars we have written to the current column so far
1243         std::vector<SpvId> currentColumn;
1244         // the total number of scalars represented by currentColumn's entries
1245         int currentCount = 0;
1246         for (size_t i = 0; i < arguments.size(); i++) {
1247             if (c.fArguments[i]->fType.kind() == Type::kVector_Kind &&
1248                     c.fArguments[i]->fType.columns() == c.fType.rows()) {
1249                 // this is a complete column by itself
1250                 SkASSERT(currentCount == 0);
1251                 columnIds.push_back(arguments[i]);
1252             } else {
1253                 if (c.fArguments[i]->fType.columns() == 1) {
1254                     currentColumn.push_back(arguments[i]);
1255                 } else {
1256                     SpvId componentType = this->getType(c.fArguments[i]->fType.componentType());
1257                     for (int j = 0; j < c.fArguments[i]->fType.columns(); ++j) {
1258                         SpvId swizzle = this->nextId();
1259                         this->writeInstruction(SpvOpCompositeExtract, componentType, swizzle,
1260                                                arguments[i], j, out);
1261                         currentColumn.push_back(swizzle);
1262                     }
1263                 }
1264                 currentCount += c.fArguments[i]->fType.columns();
1265                 if (currentCount == rows) {
1266                     currentCount = 0;
1267                     this->writeOpCode(SpvOpCompositeConstruct, 3 + currentColumn.size(), out);
1268                     this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows,
1269                                                                                      1)),
1270                                     out);
1271                     SpvId columnId = this->nextId();
1272                     this->writeWord(columnId, out);
1273                     columnIds.push_back(columnId);
1274                     for (SpvId id : currentColumn) {
1275                         this->writeWord(id, out);
1276                     }
1277                     currentColumn.clear();
1278                 }
1279                 SkASSERT(currentCount < rows);
1280             }
1281         }
1282         SkASSERT(columnIds.size() == (size_t) columns);
1283         this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
1284         this->writeWord(this->getType(c.fType), out);
1285         this->writeWord(result, out);
1286         for (SpvId id : columnIds) {
1287             this->writeWord(id, out);
1288         }
1289     }
1290     return result;
1291 }
1292 
writeVectorConstructor(const Constructor & c,OutputStream & out)1293 SpvId SPIRVCodeGenerator::writeVectorConstructor(const Constructor& c, OutputStream& out) {
1294     SkASSERT(c.fType.kind() == Type::kVector_Kind);
1295     if (c.isConstant()) {
1296         return this->writeConstantVector(c);
1297     }
1298     // go ahead and write the arguments so we don't try to write new instructions in the middle of
1299     // an instruction
1300     std::vector<SpvId> arguments;
1301     for (size_t i = 0; i < c.fArguments.size(); i++) {
1302         if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) {
1303             // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to
1304             // extract the components and convert them in that case manually. On top of that,
1305             // as of this writing there's a bug in the Intel Vulkan driver where OpCreateComposite
1306             // doesn't handle vector arguments at all, so we always extract vector components and
1307             // pass them into OpCreateComposite individually.
1308             SpvId vec = this->writeExpression(*c.fArguments[i], out);
1309             SpvOp_ op = SpvOpUndef;
1310             const Type& src = c.fArguments[i]->fType.componentType();
1311             const Type& dst = c.fType.componentType();
1312             if (dst == *fContext.fFloat_Type || dst == *fContext.fHalf_Type) {
1313                 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1314                     if (c.fArguments.size() == 1) {
1315                         return vec;
1316                     }
1317                 } else if (src == *fContext.fInt_Type ||
1318                            src == *fContext.fShort_Type ||
1319                            src == *fContext.fByte_Type) {
1320                     op = SpvOpConvertSToF;
1321                 } else if (src == *fContext.fUInt_Type ||
1322                            src == *fContext.fUShort_Type ||
1323                            src == *fContext.fUByte_Type) {
1324                     op = SpvOpConvertUToF;
1325                 } else {
1326                     SkASSERT(false);
1327                 }
1328             } else if (dst == *fContext.fInt_Type ||
1329                        dst == *fContext.fShort_Type ||
1330                        dst == *fContext.fByte_Type) {
1331                 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1332                     op = SpvOpConvertFToS;
1333                 } else if (src == *fContext.fInt_Type ||
1334                            src == *fContext.fShort_Type ||
1335                            src == *fContext.fByte_Type) {
1336                     if (c.fArguments.size() == 1) {
1337                         return vec;
1338                     }
1339                 } else if (src == *fContext.fUInt_Type ||
1340                            src == *fContext.fUShort_Type ||
1341                            src == *fContext.fUByte_Type) {
1342                     op = SpvOpBitcast;
1343                 } else {
1344                     SkASSERT(false);
1345                 }
1346             } else if (dst == *fContext.fUInt_Type ||
1347                        dst == *fContext.fUShort_Type ||
1348                        dst == *fContext.fUByte_Type) {
1349                 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1350                     op = SpvOpConvertFToS;
1351                 } else if (src == *fContext.fInt_Type ||
1352                            src == *fContext.fShort_Type ||
1353                            src == *fContext.fByte_Type) {
1354                     op = SpvOpBitcast;
1355                 } else if (src == *fContext.fUInt_Type ||
1356                            src == *fContext.fUShort_Type ||
1357                            src == *fContext.fUByte_Type) {
1358                     if (c.fArguments.size() == 1) {
1359                         return vec;
1360                     }
1361                 } else {
1362                     SkASSERT(false);
1363                 }
1364             }
1365             for (int j = 0; j < c.fArguments[i]->fType.columns(); j++) {
1366                 SpvId swizzle = this->nextId();
1367                 this->writeInstruction(SpvOpCompositeExtract, this->getType(src), swizzle, vec, j,
1368                                        out);
1369                 if (op != SpvOpUndef) {
1370                     SpvId cast = this->nextId();
1371                     this->writeInstruction(op, this->getType(dst), cast, swizzle, out);
1372                     arguments.push_back(cast);
1373                 } else {
1374                     arguments.push_back(swizzle);
1375                 }
1376             }
1377         } else {
1378             arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1379         }
1380     }
1381     SpvId result = this->nextId();
1382     if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1383         this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out);
1384         this->writeWord(this->getType(c.fType), out);
1385         this->writeWord(result, out);
1386         for (int i = 0; i < c.fType.columns(); i++) {
1387             this->writeWord(arguments[0], out);
1388         }
1389     } else {
1390         SkASSERT(arguments.size() > 1);
1391         this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out);
1392         this->writeWord(this->getType(c.fType), out);
1393         this->writeWord(result, out);
1394         for (SpvId id : arguments) {
1395             this->writeWord(id, out);
1396         }
1397     }
1398     return result;
1399 }
1400 
writeArrayConstructor(const Constructor & c,OutputStream & out)1401 SpvId SPIRVCodeGenerator::writeArrayConstructor(const Constructor& c, OutputStream& out) {
1402     SkASSERT(c.fType.kind() == Type::kArray_Kind);
1403     // go ahead and write the arguments so we don't try to write new instructions in the middle of
1404     // an instruction
1405     std::vector<SpvId> arguments;
1406     for (size_t i = 0; i < c.fArguments.size(); i++) {
1407         arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1408     }
1409     SpvId result = this->nextId();
1410     this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out);
1411     this->writeWord(this->getType(c.fType), out);
1412     this->writeWord(result, out);
1413     for (SpvId id : arguments) {
1414         this->writeWord(id, out);
1415     }
1416     return result;
1417 }
1418 
writeConstructor(const Constructor & c,OutputStream & out)1419 SpvId SPIRVCodeGenerator::writeConstructor(const Constructor& c, OutputStream& out) {
1420     if (c.fArguments.size() == 1 &&
1421         this->getActualType(c.fType) == this->getActualType(c.fArguments[0]->fType)) {
1422         return this->writeExpression(*c.fArguments[0], out);
1423     }
1424     if (c.fType == *fContext.fFloat_Type || c.fType == *fContext.fHalf_Type) {
1425         return this->writeFloatConstructor(c, out);
1426     } else if (c.fType == *fContext.fInt_Type ||
1427                c.fType == *fContext.fShort_Type ||
1428                c.fType == *fContext.fByte_Type) {
1429         return this->writeIntConstructor(c, out);
1430     } else if (c.fType == *fContext.fUInt_Type ||
1431                c.fType == *fContext.fUShort_Type ||
1432                c.fType == *fContext.fUByte_Type) {
1433         return this->writeUIntConstructor(c, out);
1434     }
1435     switch (c.fType.kind()) {
1436         case Type::kVector_Kind:
1437             return this->writeVectorConstructor(c, out);
1438         case Type::kMatrix_Kind:
1439             return this->writeMatrixConstructor(c, out);
1440         case Type::kArray_Kind:
1441             return this->writeArrayConstructor(c, out);
1442         default:
1443             ABORT("unsupported constructor: %s", c.description().c_str());
1444     }
1445 }
1446 
get_storage_class(const Modifiers & modifiers)1447 SpvStorageClass_ get_storage_class(const Modifiers& modifiers) {
1448     if (modifiers.fFlags & Modifiers::kIn_Flag) {
1449         SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1450         return SpvStorageClassInput;
1451     } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
1452         SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1453         return SpvStorageClassOutput;
1454     } else if (modifiers.fFlags & Modifiers::kUniform_Flag) {
1455         if (modifiers.fLayout.fFlags & Layout::kPushConstant_Flag) {
1456             return SpvStorageClassPushConstant;
1457         }
1458         return SpvStorageClassUniform;
1459     } else {
1460         return SpvStorageClassFunction;
1461     }
1462 }
1463 
get_storage_class(const Expression & expr)1464 SpvStorageClass_ get_storage_class(const Expression& expr) {
1465     switch (expr.fKind) {
1466         case Expression::kVariableReference_Kind: {
1467             const Variable& var = ((VariableReference&) expr).fVariable;
1468             if (var.fStorage != Variable::kGlobal_Storage) {
1469                 return SpvStorageClassFunction;
1470             }
1471             SpvStorageClass_ result = get_storage_class(var.fModifiers);
1472             if (result == SpvStorageClassFunction) {
1473                 result = SpvStorageClassPrivate;
1474             }
1475             return result;
1476         }
1477         case Expression::kFieldAccess_Kind:
1478             return get_storage_class(*((FieldAccess&) expr).fBase);
1479         case Expression::kIndex_Kind:
1480             return get_storage_class(*((IndexExpression&) expr).fBase);
1481         default:
1482             return SpvStorageClassFunction;
1483     }
1484 }
1485 
getAccessChain(const Expression & expr,OutputStream & out)1486 std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
1487     std::vector<SpvId> chain;
1488     switch (expr.fKind) {
1489         case Expression::kIndex_Kind: {
1490             IndexExpression& indexExpr = (IndexExpression&) expr;
1491             chain = this->getAccessChain(*indexExpr.fBase, out);
1492             chain.push_back(this->writeExpression(*indexExpr.fIndex, out));
1493             break;
1494         }
1495         case Expression::kFieldAccess_Kind: {
1496             FieldAccess& fieldExpr = (FieldAccess&) expr;
1497             chain = this->getAccessChain(*fieldExpr.fBase, out);
1498             IntLiteral index(fContext, -1, fieldExpr.fFieldIndex);
1499             chain.push_back(this->writeIntLiteral(index));
1500             break;
1501         }
1502         default: {
1503             SpvId id = this->getLValue(expr, out)->getPointer();
1504             SkASSERT(id != 0);
1505             chain.push_back(id);
1506         }
1507     }
1508     return chain;
1509 }
1510 
1511 class PointerLValue : public SPIRVCodeGenerator::LValue {
1512 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,SpvId type)1513     PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type)
1514     : fGen(gen)
1515     , fPointer(pointer)
1516     , fType(type) {}
1517 
getPointer()1518     virtual SpvId getPointer() override {
1519         return fPointer;
1520     }
1521 
load(OutputStream & out)1522     virtual SpvId load(OutputStream& out) override {
1523         SpvId result = fGen.nextId();
1524         fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
1525         return result;
1526     }
1527 
store(SpvId value,OutputStream & out)1528     virtual void store(SpvId value, OutputStream& out) override {
1529         fGen.writeInstruction(SpvOpStore, fPointer, value, out);
1530     }
1531 
1532 private:
1533     SPIRVCodeGenerator& fGen;
1534     const SpvId fPointer;
1535     const SpvId fType;
1536 };
1537 
1538 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
1539 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const std::vector<int> & components,const Type & baseType,const Type & swizzleType)1540     SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components,
1541                   const Type& baseType, const Type& swizzleType)
1542     : fGen(gen)
1543     , fVecPointer(vecPointer)
1544     , fComponents(components)
1545     , fBaseType(baseType)
1546     , fSwizzleType(swizzleType) {}
1547 
getPointer()1548     virtual SpvId getPointer() override {
1549         return 0;
1550     }
1551 
load(OutputStream & out)1552     virtual SpvId load(OutputStream& out) override {
1553         SpvId base = fGen.nextId();
1554         fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1555         SpvId result = fGen.nextId();
1556         fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
1557         fGen.writeWord(fGen.getType(fSwizzleType), out);
1558         fGen.writeWord(result, out);
1559         fGen.writeWord(base, out);
1560         fGen.writeWord(base, out);
1561         for (int component : fComponents) {
1562             fGen.writeWord(component, out);
1563         }
1564         return result;
1565     }
1566 
store(SpvId value,OutputStream & out)1567     virtual void store(SpvId value, OutputStream& out) override {
1568         // use OpVectorShuffle to mix and match the vector components. We effectively create
1569         // a virtual vector out of the concatenation of the left and right vectors, and then
1570         // select components from this virtual vector to make the result vector. For
1571         // instance, given:
1572         // float3L = ...;
1573         // float3R = ...;
1574         // L.xz = R.xy;
1575         // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
1576         // our result vector to look like (R.x, L.y, R.y), so we need to select indices
1577         // (3, 1, 4).
1578         SpvId base = fGen.nextId();
1579         fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1580         SpvId shuffle = fGen.nextId();
1581         fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out);
1582         fGen.writeWord(fGen.getType(fBaseType), out);
1583         fGen.writeWord(shuffle, out);
1584         fGen.writeWord(base, out);
1585         fGen.writeWord(value, out);
1586         for (int i = 0; i < fBaseType.columns(); i++) {
1587             // current offset into the virtual vector, defaults to pulling the unmodified
1588             // value from the left side
1589             int offset = i;
1590             // check to see if we are writing this component
1591             for (size_t j = 0; j < fComponents.size(); j++) {
1592                 if (fComponents[j] == i) {
1593                     // we're writing to this component, so adjust the offset to pull from
1594                     // the correct component of the right side instead of preserving the
1595                     // value from the left
1596                     offset = (int) (j + fBaseType.columns());
1597                     break;
1598                 }
1599             }
1600             fGen.writeWord(offset, out);
1601         }
1602         fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
1603     }
1604 
1605 private:
1606     SPIRVCodeGenerator& fGen;
1607     const SpvId fVecPointer;
1608     const std::vector<int>& fComponents;
1609     const Type& fBaseType;
1610     const Type& fSwizzleType;
1611 };
1612 
getLValue(const Expression & expr,OutputStream & out)1613 std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
1614                                                                           OutputStream& out) {
1615     switch (expr.fKind) {
1616         case Expression::kVariableReference_Kind: {
1617             SpvId type;
1618             const Variable& var = ((VariableReference&) expr).fVariable;
1619             if (var.fModifiers.fLayout.fBuiltin == SK_IN_BUILTIN) {
1620                 type = this->getType(Type("sk_in", Type::kArray_Kind, var.fType.componentType(),
1621                                           fSkInCount));
1622             } else {
1623                 type = this->getType(expr.fType);
1624             }
1625             auto entry = fVariableMap.find(&var);
1626             SkASSERT(entry != fVariableMap.end());
1627             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(*this,
1628                                                                                  entry->second,
1629                                                                                  type));
1630         }
1631         case Expression::kIndex_Kind: // fall through
1632         case Expression::kFieldAccess_Kind: {
1633             std::vector<SpvId> chain = this->getAccessChain(expr, out);
1634             SpvId member = this->nextId();
1635             this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
1636             this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out);
1637             this->writeWord(member, out);
1638             for (SpvId idx : chain) {
1639                 this->writeWord(idx, out);
1640             }
1641             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1642                                                                         *this,
1643                                                                         member,
1644                                                                         this->getType(expr.fType)));
1645         }
1646         case Expression::kSwizzle_Kind: {
1647             Swizzle& swizzle = (Swizzle&) expr;
1648             size_t count = swizzle.fComponents.size();
1649             SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer();
1650             SkASSERT(base);
1651             if (count == 1) {
1652                 IntLiteral index(fContext, -1, swizzle.fComponents[0]);
1653                 SpvId member = this->nextId();
1654                 this->writeInstruction(SpvOpAccessChain,
1655                                        this->getPointerType(swizzle.fType,
1656                                                             get_storage_class(*swizzle.fBase)),
1657                                        member,
1658                                        base,
1659                                        this->writeIntLiteral(index),
1660                                        out);
1661                 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1662                                                                        *this,
1663                                                                        member,
1664                                                                        this->getType(expr.fType)));
1665             } else {
1666                 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue(
1667                                                                               *this,
1668                                                                               base,
1669                                                                               swizzle.fComponents,
1670                                                                               swizzle.fBase->fType,
1671                                                                               expr.fType));
1672             }
1673         }
1674         case Expression::kTernary_Kind: {
1675             TernaryExpression& t = (TernaryExpression&) expr;
1676             SpvId test = this->writeExpression(*t.fTest, out);
1677             SpvId end = this->nextId();
1678             SpvId ifTrueLabel = this->nextId();
1679             SpvId ifFalseLabel = this->nextId();
1680             this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
1681             this->writeInstruction(SpvOpBranchConditional, test, ifTrueLabel, ifFalseLabel, out);
1682             this->writeLabel(ifTrueLabel, out);
1683             SpvId ifTrue = this->getLValue(*t.fIfTrue, out)->getPointer();
1684             SkASSERT(ifTrue);
1685             this->writeInstruction(SpvOpBranch, end, out);
1686             ifTrueLabel = fCurrentBlock;
1687             SpvId ifFalse = this->getLValue(*t.fIfFalse, out)->getPointer();
1688             SkASSERT(ifFalse);
1689             ifFalseLabel = fCurrentBlock;
1690             this->writeInstruction(SpvOpBranch, end, out);
1691             SpvId result = this->nextId();
1692             this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, ifTrue,
1693                        ifTrueLabel, ifFalse, ifFalseLabel, out);
1694             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1695                                                                        *this,
1696                                                                        result,
1697                                                                        this->getType(expr.fType)));
1698         }
1699         default:
1700             // expr isn't actually an lvalue, create a dummy variable for it. This case happens due
1701             // to the need to store values in temporary variables during function calls (see
1702             // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been
1703             // caught by IRGenerator
1704             SpvId result = this->nextId();
1705             SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction);
1706             this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction,
1707                                    fVariableBuffer);
1708             this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
1709             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1710                                                                        *this,
1711                                                                        result,
1712                                                                        this->getType(expr.fType)));
1713     }
1714 }
1715 
writeVariableReference(const VariableReference & ref,OutputStream & out)1716 SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
1717     SpvId result = this->nextId();
1718     auto entry = fVariableMap.find(&ref.fVariable);
1719     SkASSERT(entry != fVariableMap.end());
1720     SpvId var = entry->second;
1721     this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out);
1722     if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN &&
1723         fProgram.fSettings.fFlipY) {
1724         // need to remap to a top-left coordinate system
1725         if (fRTHeightStructId == (SpvId) -1) {
1726             // height variable hasn't been written yet
1727             std::shared_ptr<SymbolTable> st(new SymbolTable(&fErrors));
1728             SkASSERT(fRTHeightFieldIndex == (SpvId) -1);
1729             std::vector<Type::Field> fields;
1730             fields.emplace_back(Modifiers(), SKSL_RTHEIGHT_NAME, fContext.fFloat_Type.get());
1731             StringFragment name("sksl_synthetic_uniforms");
1732             Type intfStruct(-1, name, fields);
1733             Layout layout(0, -1, -1, 1, -1, -1, -1, -1, Layout::Format::kUnspecified,
1734                           Layout::kUnspecified_Primitive, -1, -1, "", Layout::kNo_Key,
1735                           Layout::CType::kDefault);
1736             Variable* intfVar = new Variable(-1,
1737                                              Modifiers(layout, Modifiers::kUniform_Flag),
1738                                              name,
1739                                              intfStruct,
1740                                              Variable::kGlobal_Storage);
1741             fSynthetics.takeOwnership(intfVar);
1742             InterfaceBlock intf(-1, intfVar, name, String(""),
1743                                 std::vector<std::unique_ptr<Expression>>(), st);
1744             fRTHeightStructId = this->writeInterfaceBlock(intf);
1745             fRTHeightFieldIndex = 0;
1746         }
1747         SkASSERT(fRTHeightFieldIndex != (SpvId) -1);
1748         // write float4(gl_FragCoord.x, u_skRTHeight - gl_FragCoord.y, 0.0, gl_FragCoord.w)
1749         SpvId xId = this->nextId();
1750         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), xId,
1751                                result, 0, out);
1752         IntLiteral fieldIndex(fContext, -1, fRTHeightFieldIndex);
1753         SpvId fieldIndexId = this->writeIntLiteral(fieldIndex);
1754         SpvId heightPtr = this->nextId();
1755         this->writeOpCode(SpvOpAccessChain, 5, out);
1756         this->writeWord(this->getPointerType(*fContext.fFloat_Type, SpvStorageClassUniform), out);
1757         this->writeWord(heightPtr, out);
1758         this->writeWord(fRTHeightStructId, out);
1759         this->writeWord(fieldIndexId, out);
1760         SpvId heightRead = this->nextId();
1761         this->writeInstruction(SpvOpLoad, this->getType(*fContext.fFloat_Type), heightRead,
1762                                heightPtr, out);
1763         SpvId rawYId = this->nextId();
1764         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), rawYId,
1765                                result, 1, out);
1766         SpvId flippedYId = this->nextId();
1767         this->writeInstruction(SpvOpFSub, this->getType(*fContext.fFloat_Type), flippedYId,
1768                                heightRead, rawYId, out);
1769         FloatLiteral zero(fContext, -1, 0.0);
1770         SpvId zeroId = writeFloatLiteral(zero);
1771         FloatLiteral one(fContext, -1, 1.0);
1772         SpvId wId = this->nextId();
1773         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), wId,
1774                                result, 3, out);
1775         SpvId flipped = this->nextId();
1776         this->writeOpCode(SpvOpCompositeConstruct, 7, out);
1777         this->writeWord(this->getType(*fContext.fFloat4_Type), out);
1778         this->writeWord(flipped, out);
1779         this->writeWord(xId, out);
1780         this->writeWord(flippedYId, out);
1781         this->writeWord(zeroId, out);
1782         this->writeWord(wId, out);
1783         return flipped;
1784     }
1785     if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_CLOCKWISE_BUILTIN &&
1786         !fProgram.fSettings.fFlipY) {
1787         // FrontFacing in Vulkan is defined in terms of a top-down render target. In skia, we use
1788         // the default convention of "counter-clockwise face is front".
1789         SpvId inverse = this->nextId();
1790         this->writeInstruction(SpvOpLogicalNot, this->getType(*fContext.fBool_Type), inverse,
1791                                result, out);
1792         return inverse;
1793     }
1794     return result;
1795 }
1796 
writeIndexExpression(const IndexExpression & expr,OutputStream & out)1797 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
1798     if (expr.fBase->fType.kind() == Type::Kind::kVector_Kind) {
1799         SpvId base = this->writeExpression(*expr.fBase, out);
1800         SpvId index = this->writeExpression(*expr.fIndex, out);
1801         SpvId result = this->nextId();
1802         this->writeInstruction(SpvOpVectorExtractDynamic, this->getType(expr.fType), result, base,
1803                                index, out);
1804         return result;
1805     }
1806     return getLValue(expr, out)->load(out);
1807 }
1808 
writeFieldAccess(const FieldAccess & f,OutputStream & out)1809 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
1810     return getLValue(f, out)->load(out);
1811 }
1812 
writeSwizzle(const Swizzle & swizzle,OutputStream & out)1813 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
1814     SpvId base = this->writeExpression(*swizzle.fBase, out);
1815     SpvId result = this->nextId();
1816     size_t count = swizzle.fComponents.size();
1817     if (count == 1) {
1818         this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base,
1819                                swizzle.fComponents[0], out);
1820     } else {
1821         this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
1822         this->writeWord(this->getType(swizzle.fType), out);
1823         this->writeWord(result, out);
1824         this->writeWord(base, out);
1825         this->writeWord(base, out);
1826         for (int component : swizzle.fComponents) {
1827             this->writeWord(component, out);
1828         }
1829     }
1830     return result;
1831 }
1832 
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)1833 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
1834                                                const Type& operandType, SpvId lhs,
1835                                                SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
1836                                                SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
1837     SpvId result = this->nextId();
1838     if (is_float(fContext, operandType)) {
1839         this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
1840     } else if (is_signed(fContext, operandType)) {
1841         this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
1842     } else if (is_unsigned(fContext, operandType)) {
1843         this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
1844     } else if (operandType == *fContext.fBool_Type) {
1845         this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
1846     } else {
1847         ABORT("invalid operandType: %s", operandType.description().c_str());
1848     }
1849     return result;
1850 }
1851 
is_assignment(Token::Kind op)1852 bool is_assignment(Token::Kind op) {
1853     switch (op) {
1854         case Token::EQ:           // fall through
1855         case Token::PLUSEQ:       // fall through
1856         case Token::MINUSEQ:      // fall through
1857         case Token::STAREQ:       // fall through
1858         case Token::SLASHEQ:      // fall through
1859         case Token::PERCENTEQ:    // fall through
1860         case Token::SHLEQ:        // fall through
1861         case Token::SHREQ:        // fall through
1862         case Token::BITWISEOREQ:  // fall through
1863         case Token::BITWISEXOREQ: // fall through
1864         case Token::BITWISEANDEQ: // fall through
1865         case Token::LOGICALOREQ:  // fall through
1866         case Token::LOGICALXOREQ: // fall through
1867         case Token::LOGICALANDEQ:
1868             return true;
1869         default:
1870             return false;
1871     }
1872 }
1873 
foldToBool(SpvId id,const Type & operandType,SpvOp op,OutputStream & out)1874 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op,
1875                                      OutputStream& out) {
1876     if (operandType.kind() == Type::kVector_Kind) {
1877         SpvId result = this->nextId();
1878         this->writeInstruction(op, this->getType(*fContext.fBool_Type), result, id, out);
1879         return result;
1880     }
1881     return id;
1882 }
1883 
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,SpvOp_ vectorMergeOperator,SpvOp_ mergeOperator,OutputStream & out)1884 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
1885                                                 SpvOp_ floatOperator, SpvOp_ intOperator,
1886                                                 SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
1887                                                 OutputStream& out) {
1888     SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator;
1889     SkASSERT(operandType.kind() == Type::kMatrix_Kind);
1890     SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
1891                                                                             operandType.rows(),
1892                                                                             1));
1893     SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext,
1894                                                                     operandType.rows(),
1895                                                                     1));
1896     SpvId boolType = this->getType(*fContext.fBool_Type);
1897     SpvId result = 0;
1898     for (int i = 0; i < operandType.columns(); i++) {
1899         SpvId columnL = this->nextId();
1900         this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
1901         SpvId columnR = this->nextId();
1902         this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
1903         SpvId compare = this->nextId();
1904         this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
1905         SpvId merge = this->nextId();
1906         this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
1907         if (result != 0) {
1908             SpvId next = this->nextId();
1909             this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
1910             result = next;
1911         }
1912         else {
1913             result = merge;
1914         }
1915     }
1916     return result;
1917 }
1918 
writeComponentwiseMatrixBinary(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,OutputStream & out)1919 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
1920                                                          SpvId rhs, SpvOp_ floatOperator,
1921                                                          SpvOp_ intOperator,
1922                                                          OutputStream& out) {
1923     SpvOp_ op = is_float(fContext, operandType) ? floatOperator : intOperator;
1924     SkASSERT(operandType.kind() == Type::kMatrix_Kind);
1925     SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
1926                                                                             operandType.rows(),
1927                                                                             1));
1928     SpvId columns[4];
1929     for (int i = 0; i < operandType.columns(); i++) {
1930         SpvId columnL = this->nextId();
1931         this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
1932         SpvId columnR = this->nextId();
1933         this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
1934         columns[i] = this->nextId();
1935         this->writeInstruction(op, columnType, columns[i], columnL, columnR, out);
1936     }
1937     SpvId result = this->nextId();
1938     this->writeOpCode(SpvOpCompositeConstruct, 3 + operandType.columns(), out);
1939     this->writeWord(this->getType(operandType), out);
1940     this->writeWord(result, out);
1941     for (int i = 0; i < operandType.columns(); i++) {
1942         this->writeWord(columns[i], out);
1943     }
1944     return result;
1945 }
1946 
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)1947 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
1948     // handle cases where we don't necessarily evaluate both LHS and RHS
1949     switch (b.fOperator) {
1950         case Token::EQ: {
1951             SpvId rhs = this->writeExpression(*b.fRight, out);
1952             this->getLValue(*b.fLeft, out)->store(rhs, out);
1953             return rhs;
1954         }
1955         case Token::LOGICALAND:
1956             return this->writeLogicalAnd(b, out);
1957         case Token::LOGICALOR:
1958             return this->writeLogicalOr(b, out);
1959         default:
1960             break;
1961     }
1962 
1963     // "normal" operators
1964     const Type& resultType = b.fType;
1965     std::unique_ptr<LValue> lvalue;
1966     SpvId lhs;
1967     if (is_assignment(b.fOperator)) {
1968         lvalue = this->getLValue(*b.fLeft, out);
1969         lhs = lvalue->load(out);
1970     } else {
1971         lvalue = nullptr;
1972         lhs = this->writeExpression(*b.fLeft, out);
1973     }
1974     SpvId rhs = this->writeExpression(*b.fRight, out);
1975     if (b.fOperator == Token::COMMA) {
1976         return rhs;
1977     }
1978     Type tmp("<invalid>");
1979     // overall type we are operating on: float2, int, uint4...
1980     const Type* operandType;
1981     // IR allows mismatched types in expressions (e.g. float2 * float), but they need special
1982     // handling in SPIR-V
1983     if (this->getActualType(b.fLeft->fType) != this->getActualType(b.fRight->fType)) {
1984         if (b.fLeft->fType.kind() == Type::kVector_Kind &&
1985             b.fRight->fType.isNumber()) {
1986             // promote number to vector
1987             SpvId vec = this->nextId();
1988             const Type& vecType = b.fLeft->fType;
1989             this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
1990             this->writeWord(this->getType(vecType), out);
1991             this->writeWord(vec, out);
1992             for (int i = 0; i < vecType.columns(); i++) {
1993                 this->writeWord(rhs, out);
1994             }
1995             rhs = vec;
1996             operandType = &b.fLeft->fType;
1997         } else if (b.fRight->fType.kind() == Type::kVector_Kind &&
1998                    b.fLeft->fType.isNumber()) {
1999             // promote number to vector
2000             SpvId vec = this->nextId();
2001             const Type& vecType = b.fRight->fType;
2002             this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
2003             this->writeWord(this->getType(vecType), out);
2004             this->writeWord(vec, out);
2005             for (int i = 0; i < vecType.columns(); i++) {
2006                 this->writeWord(lhs, out);
2007             }
2008             lhs = vec;
2009             SkASSERT(!lvalue);
2010             operandType = &b.fRight->fType;
2011         } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) {
2012             SpvOp_ op;
2013             if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
2014                 op = SpvOpMatrixTimesMatrix;
2015             } else if (b.fRight->fType.kind() == Type::kVector_Kind) {
2016                 op = SpvOpMatrixTimesVector;
2017             } else {
2018                 SkASSERT(b.fRight->fType.kind() == Type::kScalar_Kind);
2019                 op = SpvOpMatrixTimesScalar;
2020             }
2021             SpvId result = this->nextId();
2022             this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out);
2023             if (b.fOperator == Token::STAREQ) {
2024                 lvalue->store(result, out);
2025             } else {
2026                 SkASSERT(b.fOperator == Token::STAR);
2027             }
2028             return result;
2029         } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
2030             SpvId result = this->nextId();
2031             if (b.fLeft->fType.kind() == Type::kVector_Kind) {
2032                 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result,
2033                                        lhs, rhs, out);
2034             } else {
2035                 SkASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind);
2036                 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs,
2037                                        lhs, out);
2038             }
2039             if (b.fOperator == Token::STAREQ) {
2040                 lvalue->store(result, out);
2041             } else {
2042                 SkASSERT(b.fOperator == Token::STAR);
2043             }
2044             return result;
2045         } else {
2046             ABORT("unsupported binary expression: %s", b.description().c_str());
2047         }
2048     } else {
2049         tmp = this->getActualType(b.fLeft->fType);
2050         operandType = &tmp;
2051         SkASSERT(*operandType == this->getActualType(b.fRight->fType));
2052     }
2053     switch (b.fOperator) {
2054         case Token::EQEQ: {
2055             if (operandType->kind() == Type::kMatrix_Kind) {
2056                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
2057                                                    SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
2058             }
2059             SkASSERT(resultType == *fContext.fBool_Type);
2060             const Type* tmpType;
2061             if (operandType->kind() == Type::kVector_Kind) {
2062                 tmpType = &fContext.fBool_Type->toCompound(fContext,
2063                                                            operandType->columns(),
2064                                                            operandType->rows());
2065             } else {
2066                 tmpType = &resultType;
2067             }
2068             return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
2069                                                                SpvOpFOrdEqual, SpvOpIEqual,
2070                                                                SpvOpIEqual, SpvOpLogicalEqual, out),
2071                                     *operandType, SpvOpAll, out);
2072         }
2073         case Token::NEQ:
2074             if (operandType->kind() == Type::kMatrix_Kind) {
2075                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
2076                                                    SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
2077             }
2078             SkASSERT(resultType == *fContext.fBool_Type);
2079             const Type* tmpType;
2080             if (operandType->kind() == Type::kVector_Kind) {
2081                 tmpType = &fContext.fBool_Type->toCompound(fContext,
2082                                                            operandType->columns(),
2083                                                            operandType->rows());
2084             } else {
2085                 tmpType = &resultType;
2086             }
2087             return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
2088                                                                SpvOpFOrdNotEqual, SpvOpINotEqual,
2089                                                                SpvOpINotEqual, SpvOpLogicalNotEqual,
2090                                                                out),
2091                                     *operandType, SpvOpAny, out);
2092         case Token::GT:
2093             SkASSERT(resultType == *fContext.fBool_Type);
2094             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2095                                               SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
2096                                               SpvOpUGreaterThan, SpvOpUndef, out);
2097         case Token::LT:
2098             SkASSERT(resultType == *fContext.fBool_Type);
2099             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
2100                                               SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
2101         case Token::GTEQ:
2102             SkASSERT(resultType == *fContext.fBool_Type);
2103             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2104                                               SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
2105                                               SpvOpUGreaterThanEqual, SpvOpUndef, out);
2106         case Token::LTEQ:
2107             SkASSERT(resultType == *fContext.fBool_Type);
2108             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2109                                               SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
2110                                               SpvOpULessThanEqual, SpvOpUndef, out);
2111         case Token::PLUS:
2112             if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2113                 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2114                 SkASSERT(b.fLeft->fType == b.fRight->fType);
2115                 return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
2116                                                             SpvOpFAdd, SpvOpIAdd, out);
2117             }
2118             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2119                                               SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2120         case Token::MINUS:
2121             if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2122                 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2123                 SkASSERT(b.fLeft->fType == b.fRight->fType);
2124                 return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
2125                                                             SpvOpFSub, SpvOpISub, out);
2126             }
2127             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2128                                               SpvOpISub, SpvOpISub, SpvOpUndef, out);
2129         case Token::STAR:
2130             if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2131                 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2132                 // matrix multiply
2133                 SpvId result = this->nextId();
2134                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2135                                        lhs, rhs, out);
2136                 return result;
2137             }
2138             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2139                                               SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2140         case Token::SLASH:
2141             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2142                                               SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2143         case Token::PERCENT:
2144             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2145                                               SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2146         case Token::SHL:
2147             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2148                                               SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
2149                                               SpvOpUndef, out);
2150         case Token::SHR:
2151             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2152                                               SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
2153                                               SpvOpUndef, out);
2154         case Token::BITWISEAND:
2155             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2156                                               SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
2157         case Token::BITWISEOR:
2158             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2159                                               SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
2160         case Token::BITWISEXOR:
2161             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2162                                               SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
2163         case Token::PLUSEQ: {
2164             SpvId result;
2165             if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2166                 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2167                 SkASSERT(b.fLeft->fType == b.fRight->fType);
2168                 result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
2169                                                             SpvOpFAdd, SpvOpIAdd, out);
2170             }
2171             else {
2172                 result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2173                                                       SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2174             }
2175             SkASSERT(lvalue);
2176             lvalue->store(result, out);
2177             return result;
2178         }
2179         case Token::MINUSEQ: {
2180             SpvId result;
2181             if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2182                 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2183                 SkASSERT(b.fLeft->fType == b.fRight->fType);
2184                 result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
2185                                                             SpvOpFSub, SpvOpISub, out);
2186             }
2187             else {
2188                 result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2189                                                       SpvOpISub, SpvOpISub, SpvOpUndef, out);
2190             }
2191             SkASSERT(lvalue);
2192             lvalue->store(result, out);
2193             return result;
2194         }
2195         case Token::STAREQ: {
2196             if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2197                 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2198                 // matrix multiply
2199                 SpvId result = this->nextId();
2200                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2201                                        lhs, rhs, out);
2202                 SkASSERT(lvalue);
2203                 lvalue->store(result, out);
2204                 return result;
2205             }
2206             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2207                                                       SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2208             SkASSERT(lvalue);
2209             lvalue->store(result, out);
2210             return result;
2211         }
2212         case Token::SLASHEQ: {
2213             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2214                                                       SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2215             SkASSERT(lvalue);
2216             lvalue->store(result, out);
2217             return result;
2218         }
2219         case Token::PERCENTEQ: {
2220             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2221                                                       SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2222             SkASSERT(lvalue);
2223             lvalue->store(result, out);
2224             return result;
2225         }
2226         case Token::SHLEQ: {
2227             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2228                                                       SpvOpUndef, SpvOpShiftLeftLogical,
2229                                                       SpvOpShiftLeftLogical, SpvOpUndef, out);
2230             SkASSERT(lvalue);
2231             lvalue->store(result, out);
2232             return result;
2233         }
2234         case Token::SHREQ: {
2235             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2236                                                       SpvOpUndef, SpvOpShiftRightArithmetic,
2237                                                       SpvOpShiftRightLogical, SpvOpUndef, out);
2238             SkASSERT(lvalue);
2239             lvalue->store(result, out);
2240             return result;
2241         }
2242         case Token::BITWISEANDEQ: {
2243             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2244                                                       SpvOpUndef, SpvOpBitwiseAnd, SpvOpBitwiseAnd,
2245                                                       SpvOpUndef, out);
2246             SkASSERT(lvalue);
2247             lvalue->store(result, out);
2248             return result;
2249         }
2250         case Token::BITWISEOREQ: {
2251             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2252                                                       SpvOpUndef, SpvOpBitwiseOr, SpvOpBitwiseOr,
2253                                                       SpvOpUndef, out);
2254             SkASSERT(lvalue);
2255             lvalue->store(result, out);
2256             return result;
2257         }
2258         case Token::BITWISEXOREQ: {
2259             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2260                                                       SpvOpUndef, SpvOpBitwiseXor, SpvOpBitwiseXor,
2261                                                       SpvOpUndef, out);
2262             SkASSERT(lvalue);
2263             lvalue->store(result, out);
2264             return result;
2265         }
2266         default:
2267             ABORT("unsupported binary expression: %s", b.description().c_str());
2268     }
2269 }
2270 
writeLogicalAnd(const BinaryExpression & a,OutputStream & out)2271 SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) {
2272     SkASSERT(a.fOperator == Token::LOGICALAND);
2273     BoolLiteral falseLiteral(fContext, -1, false);
2274     SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
2275     SpvId lhs = this->writeExpression(*a.fLeft, out);
2276     SpvId rhsLabel = this->nextId();
2277     SpvId end = this->nextId();
2278     SpvId lhsBlock = fCurrentBlock;
2279     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2280     this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
2281     this->writeLabel(rhsLabel, out);
2282     SpvId rhs = this->writeExpression(*a.fRight, out);
2283     SpvId rhsBlock = fCurrentBlock;
2284     this->writeInstruction(SpvOpBranch, end, out);
2285     this->writeLabel(end, out);
2286     SpvId result = this->nextId();
2287     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant,
2288                            lhsBlock, rhs, rhsBlock, out);
2289     return result;
2290 }
2291 
writeLogicalOr(const BinaryExpression & o,OutputStream & out)2292 SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, OutputStream& out) {
2293     SkASSERT(o.fOperator == Token::LOGICALOR);
2294     BoolLiteral trueLiteral(fContext, -1, true);
2295     SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
2296     SpvId lhs = this->writeExpression(*o.fLeft, out);
2297     SpvId rhsLabel = this->nextId();
2298     SpvId end = this->nextId();
2299     SpvId lhsBlock = fCurrentBlock;
2300     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2301     this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
2302     this->writeLabel(rhsLabel, out);
2303     SpvId rhs = this->writeExpression(*o.fRight, out);
2304     SpvId rhsBlock = fCurrentBlock;
2305     this->writeInstruction(SpvOpBranch, end, out);
2306     this->writeLabel(end, out);
2307     SpvId result = this->nextId();
2308     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant,
2309                            lhsBlock, rhs, rhsBlock, out);
2310     return result;
2311 }
2312 
writeTernaryExpression(const TernaryExpression & t,OutputStream & out)2313 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
2314     SpvId test = this->writeExpression(*t.fTest, out);
2315     if (t.fIfTrue->fType.columns() == 1 && t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) {
2316         // both true and false are constants, can just use OpSelect
2317         SpvId result = this->nextId();
2318         SpvId trueId = this->writeExpression(*t.fIfTrue, out);
2319         SpvId falseId = this->writeExpression(*t.fIfFalse, out);
2320         this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId,
2321                                out);
2322         return result;
2323     }
2324     // was originally using OpPhi to choose the result, but for some reason that is crashing on
2325     // Adreno. Switched to storing the result in a temp variable as glslang does.
2326     SpvId var = this->nextId();
2327     this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction),
2328                            var, SpvStorageClassFunction, fVariableBuffer);
2329     SpvId trueLabel = this->nextId();
2330     SpvId falseLabel = this->nextId();
2331     SpvId end = this->nextId();
2332     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2333     this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
2334     this->writeLabel(trueLabel, out);
2335     this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out);
2336     this->writeInstruction(SpvOpBranch, end, out);
2337     this->writeLabel(falseLabel, out);
2338     this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out);
2339     this->writeInstruction(SpvOpBranch, end, out);
2340     this->writeLabel(end, out);
2341     SpvId result = this->nextId();
2342     this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out);
2343     return result;
2344 }
2345 
create_literal_1(const Context & context,const Type & type)2346 std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
2347     if (type.isInteger()) {
2348         return std::unique_ptr<Expression>(new IntLiteral(-1, 1, &type));
2349     }
2350     else if (type.isFloat()) {
2351         return std::unique_ptr<Expression>(new FloatLiteral(-1, 1.0, &type));
2352     } else {
2353         ABORT("math is unsupported on type '%s'", type.name().c_str());
2354     }
2355 }
2356 
writePrefixExpression(const PrefixExpression & p,OutputStream & out)2357 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
2358     if (p.fOperator == Token::MINUS) {
2359         SpvId result = this->nextId();
2360         SpvId typeId = this->getType(p.fType);
2361         SpvId expr = this->writeExpression(*p.fOperand, out);
2362         if (is_float(fContext, p.fType)) {
2363             this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
2364         } else if (is_signed(fContext, p.fType)) {
2365             this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
2366         } else {
2367             ABORT("unsupported prefix expression %s", p.description().c_str());
2368         }
2369         return result;
2370     }
2371     switch (p.fOperator) {
2372         case Token::PLUS:
2373             return this->writeExpression(*p.fOperand, out);
2374         case Token::PLUSPLUS: {
2375             std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2376             SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2377             SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2378                                                       SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
2379                                                       out);
2380             lv->store(result, out);
2381             return result;
2382         }
2383         case Token::MINUSMINUS: {
2384             std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2385             SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2386             SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2387                                                       SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef,
2388                                                       out);
2389             lv->store(result, out);
2390             return result;
2391         }
2392         case Token::LOGICALNOT: {
2393             SkASSERT(p.fOperand->fType == *fContext.fBool_Type);
2394             SpvId result = this->nextId();
2395             this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result,
2396                                    this->writeExpression(*p.fOperand, out), out);
2397             return result;
2398         }
2399         case Token::BITWISENOT: {
2400             SpvId result = this->nextId();
2401             this->writeInstruction(SpvOpNot, this->getType(p.fOperand->fType), result,
2402                                    this->writeExpression(*p.fOperand, out), out);
2403             return result;
2404         }
2405         default:
2406             ABORT("unsupported prefix expression: %s", p.description().c_str());
2407     }
2408 }
2409 
writePostfixExpression(const PostfixExpression & p,OutputStream & out)2410 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
2411     std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2412     SpvId result = lv->load(out);
2413     SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2414     switch (p.fOperator) {
2415         case Token::PLUSPLUS: {
2416             SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd,
2417                                                     SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2418             lv->store(temp, out);
2419             return result;
2420         }
2421         case Token::MINUSMINUS: {
2422             SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub,
2423                                                     SpvOpISub, SpvOpISub, SpvOpUndef, out);
2424             lv->store(temp, out);
2425             return result;
2426         }
2427         default:
2428             ABORT("unsupported postfix expression %s", p.description().c_str());
2429     }
2430 }
2431 
writeBoolLiteral(const BoolLiteral & b)2432 SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
2433     if (b.fValue) {
2434         if (fBoolTrue == 0) {
2435             fBoolTrue = this->nextId();
2436             this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue,
2437                                    fConstantBuffer);
2438         }
2439         return fBoolTrue;
2440     } else {
2441         if (fBoolFalse == 0) {
2442             fBoolFalse = this->nextId();
2443             this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse,
2444                                    fConstantBuffer);
2445         }
2446         return fBoolFalse;
2447     }
2448 }
2449 
writeIntLiteral(const IntLiteral & i)2450 SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) {
2451     if (i.fType == *fContext.fInt_Type) {
2452         auto entry = fIntConstants.find(i.fValue);
2453         if (entry == fIntConstants.end()) {
2454             SpvId result = this->nextId();
2455             this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2456                                    fConstantBuffer);
2457             fIntConstants[i.fValue] = result;
2458             return result;
2459         }
2460         return entry->second;
2461     } else {
2462         SkASSERT(i.fType == *fContext.fUInt_Type);
2463         auto entry = fUIntConstants.find(i.fValue);
2464         if (entry == fUIntConstants.end()) {
2465             SpvId result = this->nextId();
2466             this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2467                                    fConstantBuffer);
2468             fUIntConstants[i.fValue] = result;
2469             return result;
2470         }
2471         return entry->second;
2472     }
2473 }
2474 
writeFloatLiteral(const FloatLiteral & f)2475 SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
2476     if (f.fType == *fContext.fFloat_Type || f.fType == *fContext.fHalf_Type) {
2477         float value = (float) f.fValue;
2478         auto entry = fFloatConstants.find(value);
2479         if (entry == fFloatConstants.end()) {
2480             SpvId result = this->nextId();
2481             uint32_t bits;
2482             SkASSERT(sizeof(bits) == sizeof(value));
2483             memcpy(&bits, &value, sizeof(bits));
2484             this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits,
2485                                    fConstantBuffer);
2486             fFloatConstants[value] = result;
2487             return result;
2488         }
2489         return entry->second;
2490     } else {
2491         SkASSERT(f.fType == *fContext.fDouble_Type);
2492         auto entry = fDoubleConstants.find(f.fValue);
2493         if (entry == fDoubleConstants.end()) {
2494             SpvId result = this->nextId();
2495             uint64_t bits;
2496             SkASSERT(sizeof(bits) == sizeof(f.fValue));
2497             memcpy(&bits, &f.fValue, sizeof(bits));
2498             this->writeInstruction(SpvOpConstant, this->getType(f.fType), result,
2499                                    bits & 0xffffffff, bits >> 32, fConstantBuffer);
2500             fDoubleConstants[f.fValue] = result;
2501             return result;
2502         }
2503         return entry->second;
2504     }
2505 }
2506 
writeFunctionStart(const FunctionDeclaration & f,OutputStream & out)2507 SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
2508     SpvId result = fFunctionMap[&f];
2509     this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result,
2510                            SpvFunctionControlMaskNone, this->getFunctionType(f), out);
2511     this->writeInstruction(SpvOpName, result, f.fName, fNameBuffer);
2512     for (size_t i = 0; i < f.fParameters.size(); i++) {
2513         SpvId id = this->nextId();
2514         fVariableMap[f.fParameters[i]] = id;
2515         SpvId type;
2516         type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction);
2517         this->writeInstruction(SpvOpFunctionParameter, type, id, out);
2518     }
2519     return result;
2520 }
2521 
writeFunction(const FunctionDefinition & f,OutputStream & out)2522 SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
2523     fVariableBuffer.reset();
2524     SpvId result = this->writeFunctionStart(f.fDeclaration, out);
2525     this->writeLabel(this->nextId(), out);
2526     StringStream bodyBuffer;
2527     this->writeBlock((Block&) *f.fBody, bodyBuffer);
2528     write_stringstream(fVariableBuffer, out);
2529     if (f.fDeclaration.fName == "main") {
2530         write_stringstream(fGlobalInitializersBuffer, out);
2531     }
2532     write_stringstream(bodyBuffer, out);
2533     if (fCurrentBlock) {
2534         if (f.fDeclaration.fReturnType == *fContext.fVoid_Type) {
2535             this->writeInstruction(SpvOpReturn, out);
2536         } else {
2537             this->writeInstruction(SpvOpUnreachable, out);
2538         }
2539     }
2540     this->writeInstruction(SpvOpFunctionEnd, out);
2541     return result;
2542 }
2543 
writeLayout(const Layout & layout,SpvId target)2544 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) {
2545     if (layout.fLocation >= 0) {
2546         this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
2547                                fDecorationBuffer);
2548     }
2549     if (layout.fBinding >= 0) {
2550         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
2551                                fDecorationBuffer);
2552     }
2553     if (layout.fIndex >= 0) {
2554         this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
2555                                fDecorationBuffer);
2556     }
2557     if (layout.fSet >= 0) {
2558         this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
2559                                fDecorationBuffer);
2560     }
2561     if (layout.fInputAttachmentIndex >= 0) {
2562         this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
2563                                layout.fInputAttachmentIndex, fDecorationBuffer);
2564         fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
2565     }
2566     if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN &&
2567         layout.fBuiltin != SK_IN_BUILTIN && layout.fBuiltin != SK_OUT_BUILTIN) {
2568         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
2569                                fDecorationBuffer);
2570     }
2571 }
2572 
writeLayout(const Layout & layout,SpvId target,int member)2573 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) {
2574     if (layout.fLocation >= 0) {
2575         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
2576                                layout.fLocation, fDecorationBuffer);
2577     }
2578     if (layout.fBinding >= 0) {
2579         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding,
2580                                layout.fBinding, fDecorationBuffer);
2581     }
2582     if (layout.fIndex >= 0) {
2583         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
2584                                layout.fIndex, fDecorationBuffer);
2585     }
2586     if (layout.fSet >= 0) {
2587         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet,
2588                                layout.fSet, fDecorationBuffer);
2589     }
2590     if (layout.fInputAttachmentIndex >= 0) {
2591         this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
2592                                layout.fInputAttachmentIndex, fDecorationBuffer);
2593     }
2594     if (layout.fBuiltin >= 0) {
2595         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
2596                                layout.fBuiltin, fDecorationBuffer);
2597     }
2598 }
2599 
update_sk_in_count(const Modifiers & m,int * outSkInCount)2600 static void update_sk_in_count(const Modifiers& m, int* outSkInCount) {
2601     switch (m.fLayout.fPrimitive) {
2602         case Layout::kPoints_Primitive:
2603             *outSkInCount = 1;
2604             break;
2605         case Layout::kLines_Primitive:
2606             *outSkInCount = 2;
2607             break;
2608         case Layout::kLinesAdjacency_Primitive:
2609             *outSkInCount = 4;
2610             break;
2611         case Layout::kTriangles_Primitive:
2612             *outSkInCount = 3;
2613             break;
2614         case Layout::kTrianglesAdjacency_Primitive:
2615             *outSkInCount = 6;
2616             break;
2617         default:
2618             return;
2619     }
2620 }
2621 
writeInterfaceBlock(const InterfaceBlock & intf)2622 SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2623     bool isBuffer = (0 != (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag));
2624     bool pushConstant = (0 != (intf.fVariable.fModifiers.fLayout.fFlags &
2625                                Layout::kPushConstant_Flag));
2626     MemoryLayout memoryLayout = (pushConstant || isBuffer) ?
2627                                 MemoryLayout(MemoryLayout::k430_Standard) :
2628                                 fDefaultLayout;
2629     SpvId result = this->nextId();
2630     const Type* type = &intf.fVariable.fType;
2631     if (fProgram.fInputs.fRTHeight) {
2632         SkASSERT(fRTHeightStructId == (SpvId) -1);
2633         SkASSERT(fRTHeightFieldIndex == (SpvId) -1);
2634         std::vector<Type::Field> fields = type->fields();
2635         fRTHeightStructId = result;
2636         fRTHeightFieldIndex = fields.size();
2637         fields.emplace_back(Modifiers(), StringFragment(SKSL_RTHEIGHT_NAME), fContext.fFloat_Type.get());
2638         type = new Type(type->fOffset, type->name(), fields);
2639     }
2640     SpvId typeId;
2641     if (intf.fVariable.fModifiers.fLayout.fBuiltin == SK_IN_BUILTIN) {
2642         for (const auto& e : fProgram) {
2643             if (e.fKind == ProgramElement::kModifiers_Kind) {
2644                 const Modifiers& m = ((ModifiersDeclaration&) e).fModifiers;
2645                 update_sk_in_count(m, &fSkInCount);
2646             }
2647         }
2648         typeId = this->getType(Type("sk_in", Type::kArray_Kind, intf.fVariable.fType.componentType(),
2649                                   fSkInCount), memoryLayout);
2650     } else {
2651         typeId = this->getType(*type, memoryLayout);
2652     }
2653     if (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag) {
2654         this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBufferBlock, fDecorationBuffer);
2655     } else if (intf.fVariable.fModifiers.fLayout.fBuiltin == -1) {
2656         this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
2657     }
2658     SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers);
2659     SpvId ptrType = this->nextId();
2660     this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
2661     this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
2662     Layout layout = intf.fVariable.fModifiers.fLayout;
2663     if (intf.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag && layout.fSet == -1) {
2664         layout.fSet = 0;
2665     }
2666     this->writeLayout(layout, result);
2667     fVariableMap[&intf.fVariable] = result;
2668     if (fProgram.fInputs.fRTHeight) {
2669         delete type;
2670     }
2671     return result;
2672 }
2673 
writePrecisionModifier(const Modifiers & modifiers,SpvId id)2674 void SPIRVCodeGenerator::writePrecisionModifier(const Modifiers& modifiers, SpvId id) {
2675     if ((modifiers.fFlags & Modifiers::kLowp_Flag) |
2676         (modifiers.fFlags & Modifiers::kMediump_Flag)) {
2677         this->writeInstruction(SpvOpDecorate, id, SpvDecorationRelaxedPrecision, fDecorationBuffer);
2678     }
2679 }
2680 
2681 #define BUILTIN_IGNORE 9999
writeGlobalVars(Program::Kind kind,const VarDeclarations & decl,OutputStream & out)2682 void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclarations& decl,
2683                                          OutputStream& out) {
2684     for (size_t i = 0; i < decl.fVars.size(); i++) {
2685         if (decl.fVars[i]->fKind == Statement::kNop_Kind) {
2686             continue;
2687         }
2688         const VarDeclaration& varDecl = (VarDeclaration&) *decl.fVars[i];
2689         const Variable* var = varDecl.fVar;
2690         // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2691         // in the OpenGL backend.
2692         SkASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2693                                            Modifiers::kWriteOnly_Flag |
2694                                            Modifiers::kCoherent_Flag |
2695                                            Modifiers::kVolatile_Flag |
2696                                            Modifiers::kRestrict_Flag)));
2697         if (var->fModifiers.fLayout.fBuiltin == BUILTIN_IGNORE) {
2698             continue;
2699         }
2700         if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
2701             kind != Program::kFragment_Kind) {
2702             SkASSERT(!fProgram.fSettings.fFragColorIsInOut);
2703             continue;
2704         }
2705         if (!var->fReadCount && !var->fWriteCount &&
2706                 !(var->fModifiers.fFlags & (Modifiers::kIn_Flag |
2707                                             Modifiers::kOut_Flag |
2708                                             Modifiers::kUniform_Flag |
2709                                             Modifiers::kBuffer_Flag))) {
2710             // variable is dead and not an input / output var (the Vulkan debug layers complain if
2711             // we elide an interface var, even if it's dead)
2712             continue;
2713         }
2714         SpvStorageClass_ storageClass;
2715         if (var->fModifiers.fFlags & Modifiers::kIn_Flag) {
2716             storageClass = SpvStorageClassInput;
2717         } else if (var->fModifiers.fFlags & Modifiers::kOut_Flag) {
2718             storageClass = SpvStorageClassOutput;
2719         } else if (var->fModifiers.fFlags & Modifiers::kUniform_Flag) {
2720             if (var->fType.kind() == Type::kSampler_Kind) {
2721                 storageClass = SpvStorageClassUniformConstant;
2722             } else {
2723                 storageClass = SpvStorageClassUniform;
2724             }
2725         } else {
2726             storageClass = SpvStorageClassPrivate;
2727         }
2728         SpvId id = this->nextId();
2729         fVariableMap[var] = id;
2730         SpvId type;
2731         if (var->fModifiers.fLayout.fBuiltin == SK_IN_BUILTIN) {
2732             type = this->getPointerType(Type("sk_in", Type::kArray_Kind,
2733                                              var->fType.componentType(), fSkInCount),
2734                                         storageClass);
2735         } else {
2736             type = this->getPointerType(var->fType, storageClass);
2737         }
2738         this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer);
2739         this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
2740         this->writePrecisionModifier(var->fModifiers, id);
2741         if (varDecl.fValue) {
2742             SkASSERT(!fCurrentBlock);
2743             fCurrentBlock = -1;
2744             SpvId value = this->writeExpression(*varDecl.fValue, fGlobalInitializersBuffer);
2745             this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
2746             fCurrentBlock = 0;
2747         }
2748         this->writeLayout(var->fModifiers.fLayout, id);
2749         if (var->fModifiers.fFlags & Modifiers::kFlat_Flag) {
2750             this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
2751         }
2752         if (var->fModifiers.fFlags & Modifiers::kNoPerspective_Flag) {
2753             this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
2754                                    fDecorationBuffer);
2755         }
2756     }
2757 }
2758 
writeVarDeclarations(const VarDeclarations & decl,OutputStream & out)2759 void SPIRVCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, OutputStream& out) {
2760     for (const auto& stmt : decl.fVars) {
2761         SkASSERT(stmt->fKind == Statement::kVarDeclaration_Kind);
2762         VarDeclaration& varDecl = (VarDeclaration&) *stmt;
2763         const Variable* var = varDecl.fVar;
2764         // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2765         // in the OpenGL backend.
2766         SkASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2767                                            Modifiers::kWriteOnly_Flag |
2768                                            Modifiers::kCoherent_Flag |
2769                                            Modifiers::kVolatile_Flag |
2770                                            Modifiers::kRestrict_Flag)));
2771         SpvId id = this->nextId();
2772         fVariableMap[var] = id;
2773         SpvId type = this->getPointerType(var->fType, SpvStorageClassFunction);
2774         this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
2775         this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
2776         if (varDecl.fValue) {
2777             SpvId value = this->writeExpression(*varDecl.fValue, out);
2778             this->writeInstruction(SpvOpStore, id, value, out);
2779         }
2780     }
2781 }
2782 
writeStatement(const Statement & s,OutputStream & out)2783 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
2784     switch (s.fKind) {
2785         case Statement::kNop_Kind:
2786             break;
2787         case Statement::kBlock_Kind:
2788             this->writeBlock((Block&) s, out);
2789             break;
2790         case Statement::kExpression_Kind:
2791             this->writeExpression(*((ExpressionStatement&) s).fExpression, out);
2792             break;
2793         case Statement::kReturn_Kind:
2794             this->writeReturnStatement((ReturnStatement&) s, out);
2795             break;
2796         case Statement::kVarDeclarations_Kind:
2797             this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, out);
2798             break;
2799         case Statement::kIf_Kind:
2800             this->writeIfStatement((IfStatement&) s, out);
2801             break;
2802         case Statement::kFor_Kind:
2803             this->writeForStatement((ForStatement&) s, out);
2804             break;
2805         case Statement::kWhile_Kind:
2806             this->writeWhileStatement((WhileStatement&) s, out);
2807             break;
2808         case Statement::kDo_Kind:
2809             this->writeDoStatement((DoStatement&) s, out);
2810             break;
2811         case Statement::kSwitch_Kind:
2812             this->writeSwitchStatement((SwitchStatement&) s, out);
2813             break;
2814         case Statement::kBreak_Kind:
2815             this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
2816             break;
2817         case Statement::kContinue_Kind:
2818             this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
2819             break;
2820         case Statement::kDiscard_Kind:
2821             this->writeInstruction(SpvOpKill, out);
2822             break;
2823         default:
2824             ABORT("unsupported statement: %s", s.description().c_str());
2825     }
2826 }
2827 
writeBlock(const Block & b,OutputStream & out)2828 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
2829     for (size_t i = 0; i < b.fStatements.size(); i++) {
2830         this->writeStatement(*b.fStatements[i], out);
2831     }
2832 }
2833 
writeIfStatement(const IfStatement & stmt,OutputStream & out)2834 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
2835     SpvId test = this->writeExpression(*stmt.fTest, out);
2836     SpvId ifTrue = this->nextId();
2837     SpvId ifFalse = this->nextId();
2838     if (stmt.fIfFalse) {
2839         SpvId end = this->nextId();
2840         this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2841         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2842         this->writeLabel(ifTrue, out);
2843         this->writeStatement(*stmt.fIfTrue, out);
2844         if (fCurrentBlock) {
2845             this->writeInstruction(SpvOpBranch, end, out);
2846         }
2847         this->writeLabel(ifFalse, out);
2848         this->writeStatement(*stmt.fIfFalse, out);
2849         if (fCurrentBlock) {
2850             this->writeInstruction(SpvOpBranch, end, out);
2851         }
2852         this->writeLabel(end, out);
2853     } else {
2854         this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
2855         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2856         this->writeLabel(ifTrue, out);
2857         this->writeStatement(*stmt.fIfTrue, out);
2858         if (fCurrentBlock) {
2859             this->writeInstruction(SpvOpBranch, ifFalse, out);
2860         }
2861         this->writeLabel(ifFalse, out);
2862     }
2863 }
2864 
writeForStatement(const ForStatement & f,OutputStream & out)2865 void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
2866     if (f.fInitializer) {
2867         this->writeStatement(*f.fInitializer, out);
2868     }
2869     SpvId header = this->nextId();
2870     SpvId start = this->nextId();
2871     SpvId body = this->nextId();
2872     SpvId next = this->nextId();
2873     fContinueTarget.push(next);
2874     SpvId end = this->nextId();
2875     fBreakTarget.push(end);
2876     this->writeInstruction(SpvOpBranch, header, out);
2877     this->writeLabel(header, out);
2878     this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
2879     this->writeInstruction(SpvOpBranch, start, out);
2880     this->writeLabel(start, out);
2881     if (f.fTest) {
2882         SpvId test = this->writeExpression(*f.fTest, out);
2883         this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2884     }
2885     this->writeLabel(body, out);
2886     this->writeStatement(*f.fStatement, out);
2887     if (fCurrentBlock) {
2888         this->writeInstruction(SpvOpBranch, next, out);
2889     }
2890     this->writeLabel(next, out);
2891     if (f.fNext) {
2892         this->writeExpression(*f.fNext, out);
2893     }
2894     this->writeInstruction(SpvOpBranch, header, out);
2895     this->writeLabel(end, out);
2896     fBreakTarget.pop();
2897     fContinueTarget.pop();
2898 }
2899 
writeWhileStatement(const WhileStatement & w,OutputStream & out)2900 void SPIRVCodeGenerator::writeWhileStatement(const WhileStatement& w, OutputStream& out) {
2901     // We believe the while loop code below will work, but Skia doesn't actually use them and
2902     // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2903     // the time being, we just fail with an error due to the lack of testing. If you encounter this
2904     // message, simply remove the error call below to see whether our while loop support actually
2905     // works.
2906     fErrors.error(w.fOffset, "internal error: while loop support has been disabled in SPIR-V, "
2907                   "see SkSLSPIRVCodeGenerator.cpp for details");
2908 
2909     SpvId header = this->nextId();
2910     SpvId start = this->nextId();
2911     SpvId body = this->nextId();
2912     fContinueTarget.push(start);
2913     SpvId end = this->nextId();
2914     fBreakTarget.push(end);
2915     this->writeInstruction(SpvOpBranch, header, out);
2916     this->writeLabel(header, out);
2917     this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2918     this->writeInstruction(SpvOpBranch, start, out);
2919     this->writeLabel(start, out);
2920     SpvId test = this->writeExpression(*w.fTest, out);
2921     this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2922     this->writeLabel(body, out);
2923     this->writeStatement(*w.fStatement, out);
2924     if (fCurrentBlock) {
2925         this->writeInstruction(SpvOpBranch, start, out);
2926     }
2927     this->writeLabel(end, out);
2928     fBreakTarget.pop();
2929     fContinueTarget.pop();
2930 }
2931 
writeDoStatement(const DoStatement & d,OutputStream & out)2932 void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
2933     // We believe the do loop code below will work, but Skia doesn't actually use them and
2934     // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2935     // the time being, we just fail with an error due to the lack of testing. If you encounter this
2936     // message, simply remove the error call below to see whether our do loop support actually
2937     // works.
2938     fErrors.error(d.fOffset, "internal error: do loop support has been disabled in SPIR-V, see "
2939                   "SkSLSPIRVCodeGenerator.cpp for details");
2940 
2941     SpvId header = this->nextId();
2942     SpvId start = this->nextId();
2943     SpvId next = this->nextId();
2944     fContinueTarget.push(next);
2945     SpvId end = this->nextId();
2946     fBreakTarget.push(end);
2947     this->writeInstruction(SpvOpBranch, header, out);
2948     this->writeLabel(header, out);
2949     this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2950     this->writeInstruction(SpvOpBranch, start, out);
2951     this->writeLabel(start, out);
2952     this->writeStatement(*d.fStatement, out);
2953     if (fCurrentBlock) {
2954         this->writeInstruction(SpvOpBranch, next, out);
2955     }
2956     this->writeLabel(next, out);
2957     SpvId test = this->writeExpression(*d.fTest, out);
2958     this->writeInstruction(SpvOpBranchConditional, test, start, end, out);
2959     this->writeLabel(end, out);
2960     fBreakTarget.pop();
2961     fContinueTarget.pop();
2962 }
2963 
writeSwitchStatement(const SwitchStatement & s,OutputStream & out)2964 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
2965     SpvId value = this->writeExpression(*s.fValue, out);
2966     std::vector<SpvId> labels;
2967     SpvId end = this->nextId();
2968     SpvId defaultLabel = end;
2969     fBreakTarget.push(end);
2970     int size = 3;
2971     for (const auto& c : s.fCases) {
2972         SpvId label = this->nextId();
2973         labels.push_back(label);
2974         if (c->fValue) {
2975             size += 2;
2976         } else {
2977             defaultLabel = label;
2978         }
2979     }
2980     labels.push_back(end);
2981     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2982     this->writeOpCode(SpvOpSwitch, size, out);
2983     this->writeWord(value, out);
2984     this->writeWord(defaultLabel, out);
2985     for (size_t i = 0; i < s.fCases.size(); ++i) {
2986         if (!s.fCases[i]->fValue) {
2987             continue;
2988         }
2989         SkASSERT(s.fCases[i]->fValue->fKind == Expression::kIntLiteral_Kind);
2990         this->writeWord(((IntLiteral&) *s.fCases[i]->fValue).fValue, out);
2991         this->writeWord(labels[i], out);
2992     }
2993     for (size_t i = 0; i < s.fCases.size(); ++i) {
2994         this->writeLabel(labels[i], out);
2995         for (const auto& stmt : s.fCases[i]->fStatements) {
2996             this->writeStatement(*stmt, out);
2997         }
2998         if (fCurrentBlock) {
2999             this->writeInstruction(SpvOpBranch, labels[i + 1], out);
3000         }
3001     }
3002     this->writeLabel(end, out);
3003     fBreakTarget.pop();
3004 }
3005 
writeReturnStatement(const ReturnStatement & r,OutputStream & out)3006 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
3007     if (r.fExpression) {
3008         this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out),
3009                                out);
3010     } else {
3011         this->writeInstruction(SpvOpReturn, out);
3012     }
3013 }
3014 
writeGeometryShaderExecutionMode(SpvId entryPoint,OutputStream & out)3015 void SPIRVCodeGenerator::writeGeometryShaderExecutionMode(SpvId entryPoint, OutputStream& out) {
3016     SkASSERT(fProgram.fKind == Program::kGeometry_Kind);
3017     int invocations = 1;
3018     for (const auto& e : fProgram) {
3019         if (e.fKind == ProgramElement::kModifiers_Kind) {
3020             const Modifiers& m = ((ModifiersDeclaration&) e).fModifiers;
3021             if (m.fFlags & Modifiers::kIn_Flag) {
3022                 if (m.fLayout.fInvocations != -1) {
3023                     invocations = m.fLayout.fInvocations;
3024                 }
3025                 SpvId input;
3026                 switch (m.fLayout.fPrimitive) {
3027                     case Layout::kPoints_Primitive:
3028                         input = SpvExecutionModeInputPoints;
3029                         break;
3030                     case Layout::kLines_Primitive:
3031                         input = SpvExecutionModeInputLines;
3032                         break;
3033                     case Layout::kLinesAdjacency_Primitive:
3034                         input = SpvExecutionModeInputLinesAdjacency;
3035                         break;
3036                     case Layout::kTriangles_Primitive:
3037                         input = SpvExecutionModeTriangles;
3038                         break;
3039                     case Layout::kTrianglesAdjacency_Primitive:
3040                         input = SpvExecutionModeInputTrianglesAdjacency;
3041                         break;
3042                     default:
3043                         input = 0;
3044                         break;
3045                 }
3046                 update_sk_in_count(m, &fSkInCount);
3047                 if (input) {
3048                     this->writeInstruction(SpvOpExecutionMode, entryPoint, input, out);
3049                 }
3050             } else if (m.fFlags & Modifiers::kOut_Flag) {
3051                 SpvId output;
3052                 switch (m.fLayout.fPrimitive) {
3053                     case Layout::kPoints_Primitive:
3054                         output = SpvExecutionModeOutputPoints;
3055                         break;
3056                     case Layout::kLineStrip_Primitive:
3057                         output = SpvExecutionModeOutputLineStrip;
3058                         break;
3059                     case Layout::kTriangleStrip_Primitive:
3060                         output = SpvExecutionModeOutputTriangleStrip;
3061                         break;
3062                     default:
3063                         output = 0;
3064                         break;
3065                 }
3066                 if (output) {
3067                     this->writeInstruction(SpvOpExecutionMode, entryPoint, output, out);
3068                 }
3069                 if (m.fLayout.fMaxVertices != -1) {
3070                     this->writeInstruction(SpvOpExecutionMode, entryPoint,
3071                                            SpvExecutionModeOutputVertices, m.fLayout.fMaxVertices,
3072                                            out);
3073                 }
3074             }
3075         }
3076     }
3077     this->writeInstruction(SpvOpExecutionMode, entryPoint, SpvExecutionModeInvocations,
3078                            invocations, out);
3079 }
3080 
writeInstructions(const Program & program,OutputStream & out)3081 void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
3082     fGLSLExtendedInstructions = this->nextId();
3083     StringStream body;
3084     std::set<SpvId> interfaceVars;
3085     // assign IDs to functions, determine sk_in size
3086     int skInSize = -1;
3087     for (const auto& e : program) {
3088         switch (e.fKind) {
3089             case ProgramElement::kFunction_Kind: {
3090                 FunctionDefinition& f = (FunctionDefinition&) e;
3091                 fFunctionMap[&f.fDeclaration] = this->nextId();
3092                 break;
3093             }
3094             case ProgramElement::kModifiers_Kind: {
3095                 Modifiers& m = ((ModifiersDeclaration&) e).fModifiers;
3096                 if (m.fFlags & Modifiers::kIn_Flag) {
3097                     switch (m.fLayout.fPrimitive) {
3098                         case Layout::kPoints_Primitive: // break
3099                         case Layout::kLines_Primitive:
3100                             skInSize = 1;
3101                             break;
3102                         case Layout::kLinesAdjacency_Primitive: // break
3103                             skInSize = 2;
3104                             break;
3105                         case Layout::kTriangles_Primitive: // break
3106                         case Layout::kTrianglesAdjacency_Primitive:
3107                             skInSize = 3;
3108                             break;
3109                         default:
3110                             break;
3111                     }
3112                 }
3113                 break;
3114             }
3115             default:
3116                 break;
3117         }
3118     }
3119     for (const auto& e : program) {
3120         if (e.fKind == ProgramElement::kInterfaceBlock_Kind) {
3121             InterfaceBlock& intf = (InterfaceBlock&) e;
3122             if (SK_IN_BUILTIN == intf.fVariable.fModifiers.fLayout.fBuiltin) {
3123                 SkASSERT(skInSize != -1);
3124                 intf.fSizes.emplace_back(new IntLiteral(fContext, -1, skInSize));
3125             }
3126             SpvId id = this->writeInterfaceBlock(intf);
3127             if (((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) ||
3128                 (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) &&
3129                 intf.fVariable.fModifiers.fLayout.fBuiltin == -1) {
3130                 interfaceVars.insert(id);
3131             }
3132         }
3133     }
3134     for (const auto& e : program) {
3135         if (e.fKind == ProgramElement::kVar_Kind) {
3136             this->writeGlobalVars(program.fKind, ((VarDeclarations&) e), body);
3137         }
3138     }
3139     for (const auto& e : program) {
3140         if (e.fKind == ProgramElement::kFunction_Kind) {
3141             this->writeFunction(((FunctionDefinition&) e), body);
3142         }
3143     }
3144     const FunctionDeclaration* main = nullptr;
3145     for (auto entry : fFunctionMap) {
3146         if (entry.first->fName == "main") {
3147             main = entry.first;
3148         }
3149     }
3150     SkASSERT(main);
3151     for (auto entry : fVariableMap) {
3152         const Variable* var = entry.first;
3153         if (var->fStorage == Variable::kGlobal_Storage &&
3154             ((var->fModifiers.fFlags & Modifiers::kIn_Flag) ||
3155              (var->fModifiers.fFlags & Modifiers::kOut_Flag))) {
3156             interfaceVars.insert(entry.second);
3157         }
3158     }
3159     this->writeCapabilities(out);
3160     this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
3161     this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
3162     this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (main->fName.fLength + 4) / 4) +
3163                       (int32_t) interfaceVars.size(), out);
3164     switch (program.fKind) {
3165         case Program::kVertex_Kind:
3166             this->writeWord(SpvExecutionModelVertex, out);
3167             break;
3168         case Program::kFragment_Kind:
3169             this->writeWord(SpvExecutionModelFragment, out);
3170             break;
3171         case Program::kGeometry_Kind:
3172             this->writeWord(SpvExecutionModelGeometry, out);
3173             break;
3174         default:
3175             ABORT("cannot write this kind of program to SPIR-V\n");
3176     }
3177     SpvId entryPoint = fFunctionMap[main];
3178     this->writeWord(entryPoint, out);
3179     this->writeString(main->fName.fChars, main->fName.fLength, out);
3180     for (int var : interfaceVars) {
3181         this->writeWord(var, out);
3182     }
3183     if (program.fKind == Program::kGeometry_Kind) {
3184         this->writeGeometryShaderExecutionMode(entryPoint, out);
3185     }
3186     if (program.fKind == Program::kFragment_Kind) {
3187         this->writeInstruction(SpvOpExecutionMode,
3188                                fFunctionMap[main],
3189                                SpvExecutionModeOriginUpperLeft,
3190                                out);
3191     }
3192     for (const auto& e : program) {
3193         if (e.fKind == ProgramElement::kExtension_Kind) {
3194             this->writeInstruction(SpvOpSourceExtension, ((Extension&) e).fName.c_str(), out);
3195         }
3196     }
3197 
3198     write_stringstream(fExtraGlobalsBuffer, out);
3199     write_stringstream(fNameBuffer, out);
3200     write_stringstream(fDecorationBuffer, out);
3201     write_stringstream(fConstantBuffer, out);
3202     write_stringstream(fExternalFunctionsBuffer, out);
3203     write_stringstream(body, out);
3204 }
3205 
generateCode()3206 bool SPIRVCodeGenerator::generateCode() {
3207     SkASSERT(!fErrors.errorCount());
3208     this->writeWord(SpvMagicNumber, *fOut);
3209     this->writeWord(SpvVersion, *fOut);
3210     this->writeWord(SKSL_MAGIC, *fOut);
3211     StringStream buffer;
3212     this->writeInstructions(fProgram, buffer);
3213     this->writeWord(fIdCount, *fOut);
3214     this->writeWord(0, *fOut); // reserved, always zero
3215     write_stringstream(buffer, *fOut);
3216     return 0 == fErrors.errorCount();
3217 }
3218 
3219 }
3220