• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "SkSLSPIRVCodeGenerator.h"
9 
10 #include "GLSL.std.450.h"
11 
12 #include "ir/SkSLExpressionStatement.h"
13 #include "ir/SkSLExtension.h"
14 #include "ir/SkSLIndexExpression.h"
15 #include "ir/SkSLVariableReference.h"
16 #include "SkSLCompiler.h"
17 
18 namespace SkSL {
19 
// Generator magic number identifying SkSL as the producer of the module.
// FIXME: we should probably register a real magic number (presumably with the Khronos SPIR-V
// registry) instead of using zero.
static constexpr int32_t SKSL_MAGIC = 0;
21 
setupIntrinsics()22 void SPIRVCodeGenerator::setupIntrinsics() {
23 #define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \
24                                     GLSLstd450 ## x, GLSLstd450 ## x)
25 #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \
26                                                              GLSLstd450 ## ifFloat, \
27                                                              GLSLstd450 ## ifInt, \
28                                                              GLSLstd450 ## ifUInt, \
29                                                              SpvOpUndef)
30 #define ALL_SPIRV(x) std::make_tuple(kSPIRV_IntrinsicKind, SpvOp ## x, SpvOp ## x, SpvOp ## x, \
31                                                            SpvOp ## x)
32 #define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \
33                                    k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
34                                    k ## x ## _SpecialIntrinsic)
35     fIntrinsicMap[String("round")]         = ALL_GLSL(Round);
36     fIntrinsicMap[String("roundEven")]     = ALL_GLSL(RoundEven);
37     fIntrinsicMap[String("trunc")]         = ALL_GLSL(Trunc);
38     fIntrinsicMap[String("abs")]           = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
39     fIntrinsicMap[String("sign")]          = BY_TYPE_GLSL(FSign, SSign, SSign);
40     fIntrinsicMap[String("floor")]         = ALL_GLSL(Floor);
41     fIntrinsicMap[String("ceil")]          = ALL_GLSL(Ceil);
42     fIntrinsicMap[String("fract")]         = ALL_GLSL(Fract);
43     fIntrinsicMap[String("radians")]       = ALL_GLSL(Radians);
44     fIntrinsicMap[String("degrees")]       = ALL_GLSL(Degrees);
45     fIntrinsicMap[String("sin")]           = ALL_GLSL(Sin);
46     fIntrinsicMap[String("cos")]           = ALL_GLSL(Cos);
47     fIntrinsicMap[String("tan")]           = ALL_GLSL(Tan);
48     fIntrinsicMap[String("asin")]          = ALL_GLSL(Asin);
49     fIntrinsicMap[String("acos")]          = ALL_GLSL(Acos);
50     fIntrinsicMap[String("atan")]          = SPECIAL(Atan);
51     fIntrinsicMap[String("sinh")]          = ALL_GLSL(Sinh);
52     fIntrinsicMap[String("cosh")]          = ALL_GLSL(Cosh);
53     fIntrinsicMap[String("tanh")]          = ALL_GLSL(Tanh);
54     fIntrinsicMap[String("asinh")]         = ALL_GLSL(Asinh);
55     fIntrinsicMap[String("acosh")]         = ALL_GLSL(Acosh);
56     fIntrinsicMap[String("atanh")]         = ALL_GLSL(Atanh);
57     fIntrinsicMap[String("pow")]           = ALL_GLSL(Pow);
58     fIntrinsicMap[String("exp")]           = ALL_GLSL(Exp);
59     fIntrinsicMap[String("log")]           = ALL_GLSL(Log);
60     fIntrinsicMap[String("exp2")]          = ALL_GLSL(Exp2);
61     fIntrinsicMap[String("log2")]          = ALL_GLSL(Log2);
62     fIntrinsicMap[String("sqrt")]          = ALL_GLSL(Sqrt);
63     fIntrinsicMap[String("inverse")]       = ALL_GLSL(MatrixInverse);
64     fIntrinsicMap[String("transpose")]     = ALL_SPIRV(Transpose);
65     fIntrinsicMap[String("inversesqrt")]   = ALL_GLSL(InverseSqrt);
66     fIntrinsicMap[String("determinant")]   = ALL_GLSL(Determinant);
67     fIntrinsicMap[String("matrixInverse")] = ALL_GLSL(MatrixInverse);
68     fIntrinsicMap[String("mod")]           = SPECIAL(Mod);
69     fIntrinsicMap[String("min")]           = SPECIAL(Min);
70     fIntrinsicMap[String("max")]           = SPECIAL(Max);
71     fIntrinsicMap[String("clamp")]         = SPECIAL(Clamp);
72     fIntrinsicMap[String("dot")]           = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot,
73                                                              SpvOpUndef, SpvOpUndef, SpvOpUndef);
74     fIntrinsicMap[String("mix")]           = SPECIAL(Mix);
75     fIntrinsicMap[String("step")]          = ALL_GLSL(Step);
76     fIntrinsicMap[String("smoothstep")]    = ALL_GLSL(SmoothStep);
77     fIntrinsicMap[String("fma")]           = ALL_GLSL(Fma);
78     fIntrinsicMap[String("frexp")]         = ALL_GLSL(Frexp);
79     fIntrinsicMap[String("ldexp")]         = ALL_GLSL(Ldexp);
80 
81 #define PACK(type) fIntrinsicMap[String("pack" #type)] = ALL_GLSL(Pack ## type); \
82                    fIntrinsicMap[String("unpack" #type)] = ALL_GLSL(Unpack ## type)
83     PACK(Snorm4x8);
84     PACK(Unorm4x8);
85     PACK(Snorm2x16);
86     PACK(Unorm2x16);
87     PACK(Half2x16);
88     PACK(Double2x32);
89     fIntrinsicMap[String("length")]      = ALL_GLSL(Length);
90     fIntrinsicMap[String("distance")]    = ALL_GLSL(Distance);
91     fIntrinsicMap[String("cross")]       = ALL_GLSL(Cross);
92     fIntrinsicMap[String("normalize")]   = ALL_GLSL(Normalize);
93     fIntrinsicMap[String("faceForward")] = ALL_GLSL(FaceForward);
94     fIntrinsicMap[String("reflect")]     = ALL_GLSL(Reflect);
95     fIntrinsicMap[String("refract")]     = ALL_GLSL(Refract);
96     fIntrinsicMap[String("findLSB")]     = ALL_GLSL(FindILsb);
97     fIntrinsicMap[String("findMSB")]     = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
98     fIntrinsicMap[String("dFdx")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx,
99                                                            SpvOpUndef, SpvOpUndef, SpvOpUndef);
100     fIntrinsicMap[String("dFdy")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
101                                                            SpvOpUndef, SpvOpUndef, SpvOpUndef);
102     fIntrinsicMap[String("dFdy")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
103                                                            SpvOpUndef, SpvOpUndef, SpvOpUndef);
104     fIntrinsicMap[String("texture")]     = SPECIAL(Texture);
105     fIntrinsicMap[String("texelFetch")]  = SPECIAL(TexelFetch);
106     fIntrinsicMap[String("subpassLoad")] = SPECIAL(SubpassLoad);
107 
108     fIntrinsicMap[String("any")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
109                                                                 SpvOpUndef, SpvOpUndef, SpvOpAny);
110     fIntrinsicMap[String("all")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
111                                                                 SpvOpUndef, SpvOpUndef, SpvOpAll);
112     fIntrinsicMap[String("equal")]            = std::make_tuple(kSPIRV_IntrinsicKind,
113                                                                 SpvOpFOrdEqual, SpvOpIEqual,
114                                                                 SpvOpIEqual, SpvOpLogicalEqual);
115     fIntrinsicMap[String("notEqual")]         = std::make_tuple(kSPIRV_IntrinsicKind,
116                                                                 SpvOpFOrdNotEqual, SpvOpINotEqual,
117                                                                 SpvOpINotEqual,
118                                                                 SpvOpLogicalNotEqual);
119     fIntrinsicMap[String("lessThan")]         = std::make_tuple(kSPIRV_IntrinsicKind,
120                                                                 SpvOpFOrdLessThan, SpvOpSLessThan,
121                                                                 SpvOpULessThan, SpvOpUndef);
122     fIntrinsicMap[String("lessThanEqual")]    = std::make_tuple(kSPIRV_IntrinsicKind,
123                                                                 SpvOpFOrdLessThanEqual,
124                                                                 SpvOpSLessThanEqual,
125                                                                 SpvOpULessThanEqual,
126                                                                 SpvOpUndef);
127     fIntrinsicMap[String("greaterThan")]      = std::make_tuple(kSPIRV_IntrinsicKind,
128                                                                 SpvOpFOrdGreaterThan,
129                                                                 SpvOpSGreaterThan,
130                                                                 SpvOpUGreaterThan,
131                                                                 SpvOpUndef);
132     fIntrinsicMap[String("greaterThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
133                                                                 SpvOpFOrdGreaterThanEqual,
134                                                                 SpvOpSGreaterThanEqual,
135                                                                 SpvOpUGreaterThanEqual,
136                                                                 SpvOpUndef);
137     fIntrinsicMap[String("EmitVertex")]       = ALL_SPIRV(EmitVertex);
138     fIntrinsicMap[String("EndPrimitive")]     = ALL_SPIRV(EndPrimitive);
139 // interpolateAt* not yet supported...
140 }
141 
writeWord(int32_t word,OutputStream & out)142 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
143     out.write((const char*) &word, sizeof(word));
144 }
145 
is_float(const Context & context,const Type & type)146 static bool is_float(const Context& context, const Type& type) {
147     if (type.kind() == Type::kVector_Kind) {
148         return is_float(context, type.componentType());
149     }
150     return type == *context.fFloat_Type || type == *context.fHalf_Type ||
151            type == *context.fDouble_Type;
152 }
153 
is_signed(const Context & context,const Type & type)154 static bool is_signed(const Context& context, const Type& type) {
155     if (type.kind() == Type::kVector_Kind) {
156         return is_signed(context, type.componentType());
157     }
158     return type == *context.fInt_Type || type == *context.fShort_Type;
159 }
160 
is_unsigned(const Context & context,const Type & type)161 static bool is_unsigned(const Context& context, const Type& type) {
162     if (type.kind() == Type::kVector_Kind) {
163         return is_unsigned(context, type.componentType());
164     }
165     return type == *context.fUInt_Type || type == *context.fUShort_Type;
166 }
167 
is_bool(const Context & context,const Type & type)168 static bool is_bool(const Context& context, const Type& type) {
169     if (type.kind() == Type::kVector_Kind) {
170         return is_bool(context, type.componentType());
171     }
172     return type == *context.fBool_Type;
173 }
174 
is_out(const Variable & var)175 static bool is_out(const Variable& var) {
176     return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
177 }
178 
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)179 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
180     ASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
181     ASSERT(opCode != SpvOpUndef);
182     switch (opCode) {
183         case SpvOpReturn:      // fall through
184         case SpvOpReturnValue: // fall through
185         case SpvOpKill:        // fall through
186         case SpvOpBranch:      // fall through
187         case SpvOpBranchConditional:
188             ASSERT(fCurrentBlock);
189             fCurrentBlock = 0;
190             break;
191         case SpvOpConstant:          // fall through
192         case SpvOpConstantTrue:      // fall through
193         case SpvOpConstantFalse:     // fall through
194         case SpvOpConstantComposite: // fall through
195         case SpvOpTypeVoid:          // fall through
196         case SpvOpTypeInt:           // fall through
197         case SpvOpTypeFloat:         // fall through
198         case SpvOpTypeBool:          // fall through
199         case SpvOpTypeVector:        // fall through
200         case SpvOpTypeMatrix:        // fall through
201         case SpvOpTypeArray:         // fall through
202         case SpvOpTypePointer:       // fall through
203         case SpvOpTypeFunction:      // fall through
204         case SpvOpTypeRuntimeArray:  // fall through
205         case SpvOpTypeStruct:        // fall through
206         case SpvOpTypeImage:         // fall through
207         case SpvOpTypeSampledImage:  // fall through
208         case SpvOpVariable:          // fall through
209         case SpvOpFunction:          // fall through
210         case SpvOpFunctionParameter: // fall through
211         case SpvOpFunctionEnd:       // fall through
212         case SpvOpExecutionMode:     // fall through
213         case SpvOpMemoryModel:       // fall through
214         case SpvOpCapability:        // fall through
215         case SpvOpExtInstImport:     // fall through
216         case SpvOpEntryPoint:        // fall through
217         case SpvOpSource:            // fall through
218         case SpvOpSourceExtension:   // fall through
219         case SpvOpName:              // fall through
220         case SpvOpMemberName:        // fall through
221         case SpvOpDecorate:          // fall through
222         case SpvOpMemberDecorate:
223             break;
224         default:
225             ASSERT(fCurrentBlock);
226     }
227     this->writeWord((length << 16) | opCode, out);
228 }
229 
writeLabel(SpvId label,OutputStream & out)230 void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) {
231     fCurrentBlock = label;
232     this->writeInstruction(SpvOpLabel, label, out);
233 }
234 
writeInstruction(SpvOp_ opCode,OutputStream & out)235 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
236     this->writeOpCode(opCode, 1, out);
237 }
238 
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)239 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
240     this->writeOpCode(opCode, 2, out);
241     this->writeWord(word1, out);
242 }
243 
writeString(const char * string,size_t length,OutputStream & out)244 void SPIRVCodeGenerator::writeString(const char* string, size_t length, OutputStream& out) {
245     out.write(string, length);
246     switch (length % 4) {
247         case 1:
248             out.write8(0);
249             // fall through
250         case 2:
251             out.write8(0);
252             // fall through
253         case 3:
254             out.write8(0);
255             break;
256         default:
257             this->writeWord(0, out);
258     }
259 }
260 
writeInstruction(SpvOp_ opCode,StringFragment string,OutputStream & out)261 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, StringFragment string, OutputStream& out) {
262     this->writeOpCode(opCode, 1 + (string.fLength + 4) / 4, out);
263     this->writeString(string.fChars, string.fLength, out);
264 }
265 
266 
writeInstruction(SpvOp_ opCode,int32_t word1,StringFragment string,OutputStream & out)267 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, StringFragment string,
268                                           OutputStream& out) {
269     this->writeOpCode(opCode, 2 + (string.fLength + 4) / 4, out);
270     this->writeWord(word1, out);
271     this->writeString(string.fChars, string.fLength, out);
272 }
273 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,StringFragment string,OutputStream & out)274 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
275                                           StringFragment string, OutputStream& out) {
276     this->writeOpCode(opCode, 3 + (string.fLength + 4) / 4, out);
277     this->writeWord(word1, out);
278     this->writeWord(word2, out);
279     this->writeString(string.fChars, string.fLength, out);
280 }
281 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)282 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
283                                           OutputStream& out) {
284     this->writeOpCode(opCode, 3, out);
285     this->writeWord(word1, out);
286     this->writeWord(word2, out);
287 }
288 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)289 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
290                                           int32_t word3, OutputStream& out) {
291     this->writeOpCode(opCode, 4, out);
292     this->writeWord(word1, out);
293     this->writeWord(word2, out);
294     this->writeWord(word3, out);
295 }
296 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)297 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
298                                           int32_t word3, int32_t word4, OutputStream& out) {
299     this->writeOpCode(opCode, 5, out);
300     this->writeWord(word1, out);
301     this->writeWord(word2, out);
302     this->writeWord(word3, out);
303     this->writeWord(word4, out);
304 }
305 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)306 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
307                                           int32_t word3, int32_t word4, int32_t word5,
308                                           OutputStream& out) {
309     this->writeOpCode(opCode, 6, out);
310     this->writeWord(word1, out);
311     this->writeWord(word2, out);
312     this->writeWord(word3, out);
313     this->writeWord(word4, out);
314     this->writeWord(word5, out);
315 }
316 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)317 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
318                                           int32_t word3, int32_t word4, int32_t word5,
319                                           int32_t word6, OutputStream& out) {
320     this->writeOpCode(opCode, 7, out);
321     this->writeWord(word1, out);
322     this->writeWord(word2, out);
323     this->writeWord(word3, out);
324     this->writeWord(word4, out);
325     this->writeWord(word5, out);
326     this->writeWord(word6, out);
327 }
328 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)329 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
330                                           int32_t word3, int32_t word4, int32_t word5,
331                                           int32_t word6, int32_t word7, OutputStream& out) {
332     this->writeOpCode(opCode, 8, out);
333     this->writeWord(word1, out);
334     this->writeWord(word2, out);
335     this->writeWord(word3, out);
336     this->writeWord(word4, out);
337     this->writeWord(word5, out);
338     this->writeWord(word6, out);
339     this->writeWord(word7, out);
340 }
341 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)342 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
343                                           int32_t word3, int32_t word4, int32_t word5,
344                                           int32_t word6, int32_t word7, int32_t word8,
345                                           OutputStream& out) {
346     this->writeOpCode(opCode, 9, out);
347     this->writeWord(word1, out);
348     this->writeWord(word2, out);
349     this->writeWord(word3, out);
350     this->writeWord(word4, out);
351     this->writeWord(word5, out);
352     this->writeWord(word6, out);
353     this->writeWord(word7, out);
354     this->writeWord(word8, out);
355 }
356 
writeCapabilities(OutputStream & out)357 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
358     for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
359         if (fCapabilities & bit) {
360             this->writeInstruction(SpvOpCapability, (SpvId) i, out);
361         }
362     }
363     if (fProgram.fKind == Program::kGeometry_Kind) {
364         this->writeInstruction(SpvOpCapability, SpvCapabilityGeometry, out);
365     }
366 }
367 
nextId()368 SpvId SPIRVCodeGenerator::nextId() {
369     return fIdCount++;
370 }
371 
writeStruct(const Type & type,const MemoryLayout & memoryLayout,SpvId resultId)372 void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout,
373                                      SpvId resultId) {
374     this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer);
375     // go ahead and write all of the field types, so we don't inadvertently write them while we're
376     // in the middle of writing the struct instruction
377     std::vector<SpvId> types;
378     for (const auto& f : type.fields()) {
379         types.push_back(this->getType(*f.fType, memoryLayout));
380     }
381     this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
382     this->writeWord(resultId, fConstantBuffer);
383     for (SpvId id : types) {
384         this->writeWord(id, fConstantBuffer);
385     }
386     size_t offset = 0;
387     for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
388         size_t size = memoryLayout.size(*type.fields()[i].fType);
389         size_t alignment = memoryLayout.alignment(*type.fields()[i].fType);
390         const Layout& fieldLayout = type.fields()[i].fModifiers.fLayout;
391         if (fieldLayout.fOffset >= 0) {
392             if (fieldLayout.fOffset < (int) offset) {
393                 fErrors.error(type.fOffset,
394                               "offset of field '" + type.fields()[i].fName + "' must be at "
395                               "least " + to_string((int) offset));
396             }
397             if (fieldLayout.fOffset % alignment) {
398                 fErrors.error(type.fOffset,
399                               "offset of field '" + type.fields()[i].fName + "' must be a multiple"
400                               " of " + to_string((int) alignment));
401             }
402             offset = fieldLayout.fOffset;
403         } else {
404             size_t mod = offset % alignment;
405             if (mod) {
406                 offset += alignment - mod;
407             }
408         }
409         this->writeInstruction(SpvOpMemberName, resultId, i, type.fields()[i].fName, fNameBuffer);
410         this->writeLayout(fieldLayout, resultId, i);
411         if (type.fields()[i].fModifiers.fLayout.fBuiltin < 0) {
412             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
413                                    (SpvId) offset, fDecorationBuffer);
414         }
415         if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) {
416             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
417                                    fDecorationBuffer);
418             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
419                                    (SpvId) memoryLayout.stride(*type.fields()[i].fType),
420                                    fDecorationBuffer);
421         }
422         offset += size;
423         Type::Kind kind = type.fields()[i].fType->kind();
424         if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) {
425             offset += alignment - offset % alignment;
426         }
427     }
428 }
429 
// Replaces reduced-precision types with the types this generator actually emits:
// half -> float, short -> int, ushort -> uint. Applies to scalars and to the component type of
// vectors and matrices (dimensions preserved). Every other type is returned unchanged.
Type SPIRVCodeGenerator::getActualType(const Type& type) {
    if (type == *fContext.fHalf_Type) {
        return *fContext.fFloat_Type;
    }
    if (type == *fContext.fShort_Type) {
        return *fContext.fInt_Type;
    }
    if (type == *fContext.fUShort_Type) {
        return *fContext.fUInt_Type;
    }
    const bool isCompound = type.kind() == Type::kVector_Kind ||
                            type.kind() == Type::kMatrix_Kind;
    if (!isCompound) {
        return type;
    }
    const Type& component = type.componentType();
    if (component == *fContext.fHalf_Type) {
        return fContext.fFloat_Type->toCompound(fContext, type.columns(), type.rows());
    }
    if (component == *fContext.fShort_Type) {
        return fContext.fInt_Type->toCompound(fContext, type.columns(), type.rows());
    }
    if (component == *fContext.fUShort_Type) {
        return fContext.fUInt_Type->toCompound(fContext, type.columns(), type.rows());
    }
    return type;
}
453 
getType(const Type & type)454 SpvId SPIRVCodeGenerator::getType(const Type& type) {
455     return this->getType(type, fDefaultLayout);
456 }
457 
getType(const Type & rawType,const MemoryLayout & layout)458 SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layout) {
459     Type type = this->getActualType(rawType);
460     String key = type.name() + to_string((int) layout.fStd);
461     auto entry = fTypeMap.find(key);
462     if (entry == fTypeMap.end()) {
463         SpvId result = this->nextId();
464         switch (type.kind()) {
465             case Type::kScalar_Kind:
466                 if (type == *fContext.fBool_Type) {
467                     this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
468                 } else if (type == *fContext.fInt_Type) {
469                     this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
470                 } else if (type == *fContext.fUInt_Type) {
471                     this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
472                 } else if (type == *fContext.fFloat_Type) {
473                     this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
474                 } else if (type == *fContext.fDouble_Type) {
475                     this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer);
476                 } else {
477                     ASSERT(false);
478                 }
479                 break;
480             case Type::kVector_Kind:
481                 this->writeInstruction(SpvOpTypeVector, result,
482                                        this->getType(type.componentType(), layout),
483                                        type.columns(), fConstantBuffer);
484                 break;
485             case Type::kMatrix_Kind:
486                 this->writeInstruction(SpvOpTypeMatrix, result,
487                                        this->getType(index_type(fContext, type), layout),
488                                        type.columns(), fConstantBuffer);
489                 break;
490             case Type::kStruct_Kind:
491                 this->writeStruct(type, layout, result);
492                 break;
493             case Type::kArray_Kind: {
494                 if (type.columns() > 0) {
495                     IntLiteral count(fContext, -1, type.columns());
496                     this->writeInstruction(SpvOpTypeArray, result,
497                                            this->getType(type.componentType(), layout),
498                                            this->writeIntLiteral(count), fConstantBuffer);
499                     this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
500                                            (int32_t) layout.stride(type),
501                                            fDecorationBuffer);
502                 } else {
503                     this->writeInstruction(SpvOpTypeRuntimeArray, result,
504                                            this->getType(type.componentType(), layout),
505                                            fConstantBuffer);
506                     this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
507                                            (int32_t) layout.stride(type),
508                                            fDecorationBuffer);
509                 }
510                 break;
511             }
512             case Type::kSampler_Kind: {
513                 SpvId image = result;
514                 if (SpvDimSubpassData != type.dimensions()) {
515                     image = this->nextId();
516                 }
517                 if (SpvDimBuffer == type.dimensions()) {
518                     fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer);
519                 }
520                 this->writeInstruction(SpvOpTypeImage, image,
521                                        this->getType(*fContext.fFloat_Type, layout),
522                                        type.dimensions(), type.isDepth(), type.isArrayed(),
523                                        type.isMultisampled(), type.isSampled() ? 1 : 2,
524                                        SpvImageFormatUnknown, fConstantBuffer);
525                 fImageTypeMap[key] = image;
526                 if (SpvDimSubpassData != type.dimensions()) {
527                     this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
528                 }
529                 break;
530             }
531             default:
532                 if (type == *fContext.fVoid_Type) {
533                     this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
534                 } else {
535                     ABORT("invalid type: %s", type.description().c_str());
536                 }
537         }
538         fTypeMap[key] = result;
539         return result;
540     }
541     return entry->second;
542 }
543 
getImageType(const Type & type)544 SpvId SPIRVCodeGenerator::getImageType(const Type& type) {
545     ASSERT(type.kind() == Type::kSampler_Kind);
546     this->getType(type);
547     String key = type.name() + to_string((int) fDefaultLayout.fStd);
548     ASSERT(fImageTypeMap.find(key) != fImageTypeMap.end());
549     return fImageTypeMap[key];
550 }
551 
getFunctionType(const FunctionDeclaration & function)552 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
553     String key = function.fReturnType.description() + "(";
554     String separator;
555     for (size_t i = 0; i < function.fParameters.size(); i++) {
556         key += separator;
557         separator = ", ";
558         key += function.fParameters[i]->fType.description();
559     }
560     key += ")";
561     auto entry = fTypeMap.find(key);
562     if (entry == fTypeMap.end()) {
563         SpvId result = this->nextId();
564         int32_t length = 3 + (int32_t) function.fParameters.size();
565         SpvId returnType = this->getType(function.fReturnType);
566         std::vector<SpvId> parameterTypes;
567         for (size_t i = 0; i < function.fParameters.size(); i++) {
568             // glslang seems to treat all function arguments as pointers whether they need to be or
569             // not. I  was initially puzzled by this until I ran bizarre failures with certain
570             // patterns of function calls and control constructs, as exemplified by this minimal
571             // failure case:
572             //
573             // void sphere(float x) {
574             // }
575             //
576             // void map() {
577             //     sphere(1.0);
578             // }
579             //
580             // void main() {
581             //     for (int i = 0; i < 1; i++) {
582             //         map();
583             //     }
584             // }
585             //
586             // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
587             // crashes. Making it take a float* and storing the argument in a temporary variable,
588             // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
589             // the spec makes this make sense.
590 //            if (is_out(function->fParameters[i])) {
591                 parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType,
592                                                               SpvStorageClassFunction));
593 //            } else {
594 //                parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
595 //            }
596         }
597         this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
598         this->writeWord(result, fConstantBuffer);
599         this->writeWord(returnType, fConstantBuffer);
600         for (SpvId id : parameterTypes) {
601             this->writeWord(id, fConstantBuffer);
602         }
603         fTypeMap[key] = result;
604         return result;
605     }
606     return entry->second;
607 }
608 
getPointerType(const Type & type,SpvStorageClass_ storageClass)609 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
610     return this->getPointerType(type, fDefaultLayout, storageClass);
611 }
612 
getPointerType(const Type & rawType,const MemoryLayout & layout,SpvStorageClass_ storageClass)613 SpvId SPIRVCodeGenerator::getPointerType(const Type& rawType, const MemoryLayout& layout,
614                                          SpvStorageClass_ storageClass) {
615     Type type = this->getActualType(rawType);
616     String key = type.description() + "*" + to_string(layout.fStd) + to_string(storageClass);
617     auto entry = fTypeMap.find(key);
618     if (entry == fTypeMap.end()) {
619         SpvId result = this->nextId();
620         this->writeInstruction(SpvOpTypePointer, result, storageClass,
621                                this->getType(type), fConstantBuffer);
622         fTypeMap[key] = result;
623         return result;
624     }
625     return entry->second;
626 }
627 
writeExpression(const Expression & expr,OutputStream & out)628 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
629     switch (expr.fKind) {
630         case Expression::kBinary_Kind:
631             return this->writeBinaryExpression((BinaryExpression&) expr, out);
632         case Expression::kBoolLiteral_Kind:
633             return this->writeBoolLiteral((BoolLiteral&) expr);
634         case Expression::kConstructor_Kind:
635             return this->writeConstructor((Constructor&) expr, out);
636         case Expression::kIntLiteral_Kind:
637             return this->writeIntLiteral((IntLiteral&) expr);
638         case Expression::kFieldAccess_Kind:
639             return this->writeFieldAccess(((FieldAccess&) expr), out);
640         case Expression::kFloatLiteral_Kind:
641             return this->writeFloatLiteral(((FloatLiteral&) expr));
642         case Expression::kFunctionCall_Kind:
643             return this->writeFunctionCall((FunctionCall&) expr, out);
644         case Expression::kPrefix_Kind:
645             return this->writePrefixExpression((PrefixExpression&) expr, out);
646         case Expression::kPostfix_Kind:
647             return this->writePostfixExpression((PostfixExpression&) expr, out);
648         case Expression::kSwizzle_Kind:
649             return this->writeSwizzle((Swizzle&) expr, out);
650         case Expression::kVariableReference_Kind:
651             return this->writeVariableReference((VariableReference&) expr, out);
652         case Expression::kTernary_Kind:
653             return this->writeTernaryExpression((TernaryExpression&) expr, out);
654         case Expression::kIndex_Kind:
655             return this->writeIndexExpression((IndexExpression&) expr, out);
656         default:
657             ABORT("unsupported expression: %s", expr.description().c_str());
658     }
659     return -1;
660 }
661 
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)662 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
663     auto intrinsic = fIntrinsicMap.find(c.fFunction.fName);
664     ASSERT(intrinsic != fIntrinsicMap.end());
665     int32_t intrinsicId;
666     if (c.fArguments.size() > 0) {
667         const Type& type = c.fArguments[0]->fType;
668         if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) {
669             intrinsicId = std::get<1>(intrinsic->second);
670         } else if (is_signed(fContext, type)) {
671             intrinsicId = std::get<2>(intrinsic->second);
672         } else if (is_unsigned(fContext, type)) {
673             intrinsicId = std::get<3>(intrinsic->second);
674         } else if (is_bool(fContext, type)) {
675             intrinsicId = std::get<4>(intrinsic->second);
676         } else {
677             intrinsicId = std::get<1>(intrinsic->second);
678         }
679     } else {
680         intrinsicId = std::get<1>(intrinsic->second);
681     }
682     switch (std::get<0>(intrinsic->second)) {
683         case kGLSL_STD_450_IntrinsicKind: {
684             SpvId result = this->nextId();
685             std::vector<SpvId> arguments;
686             for (size_t i = 0; i < c.fArguments.size(); i++) {
687                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
688             }
689             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
690             this->writeWord(this->getType(c.fType), out);
691             this->writeWord(result, out);
692             this->writeWord(fGLSLExtendedInstructions, out);
693             this->writeWord(intrinsicId, out);
694             for (SpvId id : arguments) {
695                 this->writeWord(id, out);
696             }
697             return result;
698         }
699         case kSPIRV_IntrinsicKind: {
700             SpvId result = this->nextId();
701             std::vector<SpvId> arguments;
702             for (size_t i = 0; i < c.fArguments.size(); i++) {
703                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
704             }
705             if (c.fType != *fContext.fVoid_Type) {
706                 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
707                 this->writeWord(this->getType(c.fType), out);
708                 this->writeWord(result, out);
709             } else {
710                 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
711             }
712             for (SpvId id : arguments) {
713                 this->writeWord(id, out);
714             }
715             return result;
716         }
717         case kSpecial_IntrinsicKind:
718             return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
719         default:
720             ABORT("unsupported intrinsic kind");
721     }
722 }
723 
vectorize(const std::vector<std::unique_ptr<Expression>> & args,OutputStream & out)724 std::vector<SpvId> SPIRVCodeGenerator::vectorize(
725                                                const std::vector<std::unique_ptr<Expression>>& args,
726                                                OutputStream& out) {
727     int vectorSize = 0;
728     for (const auto& a : args) {
729         if (a->fType.kind() == Type::kVector_Kind) {
730             if (vectorSize) {
731                 ASSERT(a->fType.columns() == vectorSize);
732             }
733             else {
734                 vectorSize = a->fType.columns();
735             }
736         }
737     }
738     std::vector<SpvId> result;
739     for (const auto& a : args) {
740         SpvId raw = this->writeExpression(*a, out);
741         if (vectorSize && a->fType.kind() == Type::kScalar_Kind) {
742             SpvId vector = this->nextId();
743             this->writeOpCode(SpvOpCompositeConstruct, 3 + vectorSize, out);
744             this->writeWord(this->getType(a->fType.toCompound(fContext, vectorSize, 1)), out);
745             this->writeWord(vector, out);
746             for (int i = 0; i < vectorSize; i++) {
747                 this->writeWord(raw, out);
748             }
749             result.push_back(vector);
750         } else {
751             result.push_back(raw);
752         }
753     }
754     return result;
755 }
756 
writeGLSLExtendedInstruction(const Type & type,SpvId id,SpvId floatInst,SpvId signedInst,SpvId unsignedInst,const std::vector<SpvId> & args,OutputStream & out)757 void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
758                                                       SpvId signedInst, SpvId unsignedInst,
759                                                       const std::vector<SpvId>& args,
760                                                       OutputStream& out) {
761     this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
762     this->writeWord(this->getType(type), out);
763     this->writeWord(id, out);
764     this->writeWord(fGLSLExtendedInstructions, out);
765 
766     if (is_float(fContext, type)) {
767         this->writeWord(floatInst, out);
768     } else if (is_signed(fContext, type)) {
769         this->writeWord(signedInst, out);
770     } else if (is_unsigned(fContext, type)) {
771         this->writeWord(unsignedInst, out);
772     } else {
773         ASSERT(false);
774     }
775     for (SpvId a : args) {
776         this->writeWord(a, out);
777     }
778 }
779 
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)780 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
781                                                 OutputStream& out) {
782     SpvId result = this->nextId();
783     switch (kind) {
784         case kAtan_SpecialIntrinsic: {
785             std::vector<SpvId> arguments;
786             for (size_t i = 0; i < c.fArguments.size(); i++) {
787                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
788             }
789             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
790             this->writeWord(this->getType(c.fType), out);
791             this->writeWord(result, out);
792             this->writeWord(fGLSLExtendedInstructions, out);
793             this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
794             for (SpvId id : arguments) {
795                 this->writeWord(id, out);
796             }
797             break;
798         }
799         case kSubpassLoad_SpecialIntrinsic: {
800             SpvId img = this->writeExpression(*c.fArguments[0], out);
801             std::vector<std::unique_ptr<Expression>> args;
802             args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
803             args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
804             Constructor ctor(-1, *fContext.fFloat2_Type, std::move(args));
805             SpvId coords = this->writeConstantVector(ctor);
806             if (1 == c.fArguments.size()) {
807                 this->writeInstruction(SpvOpImageRead,
808                                        this->getType(c.fType),
809                                        result,
810                                        img,
811                                        coords,
812                                        out);
813             } else {
814                 ASSERT(2 == c.fArguments.size());
815                 SpvId sample = this->writeExpression(*c.fArguments[1], out);
816                 this->writeInstruction(SpvOpImageRead,
817                                        this->getType(c.fType),
818                                        result,
819                                        img,
820                                        coords,
821                                        SpvImageOperandsSampleMask,
822                                        sample,
823                                        out);
824             }
825             break;
826         }
827         case kTexelFetch_SpecialIntrinsic: {
828             ASSERT(c.fArguments.size() == 2);
829             SpvId image = this->nextId();
830             this->writeInstruction(SpvOpImage,
831                                    this->getImageType(c.fArguments[0]->fType),
832                                    image,
833                                    this->writeExpression(*c.fArguments[0], out),
834                                    out);
835             this->writeInstruction(SpvOpImageFetch,
836                                    this->getType(c.fType),
837                                    result,
838                                    image,
839                                    this->writeExpression(*c.fArguments[1], out),
840                                    out);
841             break;
842         }
843         case kTexture_SpecialIntrinsic: {
844             SpvOp_ op = SpvOpImageSampleImplicitLod;
845             switch (c.fArguments[0]->fType.dimensions()) {
846                 case SpvDim1D:
847                     if (c.fArguments[1]->fType == *fContext.fFloat2_Type) {
848                         op = SpvOpImageSampleProjImplicitLod;
849                     } else {
850                         ASSERT(c.fArguments[1]->fType == *fContext.fFloat_Type);
851                     }
852                     break;
853                 case SpvDim2D:
854                     if (c.fArguments[1]->fType == *fContext.fFloat3_Type) {
855                         op = SpvOpImageSampleProjImplicitLod;
856                     } else {
857                         ASSERT(c.fArguments[1]->fType == *fContext.fFloat2_Type);
858                     }
859                     break;
860                 case SpvDim3D:
861                     if (c.fArguments[1]->fType == *fContext.fFloat4_Type) {
862                         op = SpvOpImageSampleProjImplicitLod;
863                     } else {
864                         ASSERT(c.fArguments[1]->fType == *fContext.fFloat3_Type);
865                     }
866                     break;
867                 case SpvDimCube:   // fall through
868                 case SpvDimRect:   // fall through
869                 case SpvDimBuffer: // fall through
870                 case SpvDimSubpassData:
871                     break;
872             }
873             SpvId type = this->getType(c.fType);
874             SpvId sampler = this->writeExpression(*c.fArguments[0], out);
875             SpvId uv = this->writeExpression(*c.fArguments[1], out);
876             if (c.fArguments.size() == 3) {
877                 this->writeInstruction(op, type, result, sampler, uv,
878                                        SpvImageOperandsBiasMask,
879                                        this->writeExpression(*c.fArguments[2], out),
880                                        out);
881             } else {
882                 ASSERT(c.fArguments.size() == 2);
883                 this->writeInstruction(op, type, result, sampler, uv,
884                                        out);
885             }
886             break;
887         }
888         case kMod_SpecialIntrinsic: {
889             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
890             ASSERT(args.size() == 2);
891             const Type& operandType = c.fArguments[0]->fType;
892             SpvOp_ op;
893             if (is_float(fContext, operandType)) {
894                 op = SpvOpFMod;
895             } else if (is_signed(fContext, operandType)) {
896                 op = SpvOpSMod;
897             } else if (is_unsigned(fContext, operandType)) {
898                 op = SpvOpUMod;
899             } else {
900                 ASSERT(false);
901                 return 0;
902             }
903             this->writeOpCode(op, 5, out);
904             this->writeWord(this->getType(operandType), out);
905             this->writeWord(result, out);
906             this->writeWord(args[0], out);
907             this->writeWord(args[1], out);
908             break;
909         }
910         case kClamp_SpecialIntrinsic: {
911             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
912             ASSERT(args.size() == 3);
913             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FClamp, GLSLstd450SClamp,
914                                                GLSLstd450UClamp, args, out);
915             break;
916         }
917         case kMax_SpecialIntrinsic: {
918             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
919             ASSERT(args.size() == 2);
920             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMax, GLSLstd450SMax,
921                                                GLSLstd450UMax, args, out);
922             break;
923         }
924         case kMin_SpecialIntrinsic: {
925             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
926             ASSERT(args.size() == 2);
927             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMin, GLSLstd450SMin,
928                                                GLSLstd450UMin, args, out);
929             break;
930         }
931         case kMix_SpecialIntrinsic: {
932             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
933             ASSERT(args.size() == 3);
934             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMix, SpvOpUndef,
935                                                SpvOpUndef, args, out);
936             break;
937         }
938     }
939     return result;
940 }
941 
writeFunctionCall(const FunctionCall & c,OutputStream & out)942 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
943     const auto& entry = fFunctionMap.find(&c.fFunction);
944     if (entry == fFunctionMap.end()) {
945         return this->writeIntrinsicCall(c, out);
946     }
947     // stores (variable, type, lvalue) pairs to extract and save after the function call is complete
948     std::vector<std::tuple<SpvId, SpvId, std::unique_ptr<LValue>>> lvalues;
949     std::vector<SpvId> arguments;
950     for (size_t i = 0; i < c.fArguments.size(); i++) {
951         // id of temporary variable that we will use to hold this argument, or 0 if it is being
952         // passed directly
953         SpvId tmpVar;
954         // if we need a temporary var to store this argument, this is the value to store in the var
955         SpvId tmpValueId;
956         if (is_out(*c.fFunction.fParameters[i])) {
957             std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out);
958             SpvId ptr = lv->getPointer();
959             if (ptr) {
960                 arguments.push_back(ptr);
961                 continue;
962             } else {
963                 // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to
964                 // copy it into a temp, call the function, read the value out of the temp, and then
965                 // update the lvalue.
966                 tmpValueId = lv->load(out);
967                 tmpVar = this->nextId();
968                 lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType),
969                                   std::move(lv)));
970             }
971         } else {
972             // see getFunctionType for an explanation of why we're always using pointer parameters
973             tmpValueId = this->writeExpression(*c.fArguments[i], out);
974             tmpVar = this->nextId();
975         }
976         this->writeInstruction(SpvOpVariable,
977                                this->getPointerType(c.fArguments[i]->fType,
978                                                     SpvStorageClassFunction),
979                                tmpVar,
980                                SpvStorageClassFunction,
981                                fVariableBuffer);
982         this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
983         arguments.push_back(tmpVar);
984     }
985     SpvId result = this->nextId();
986     this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out);
987     this->writeWord(this->getType(c.fType), out);
988     this->writeWord(result, out);
989     this->writeWord(entry->second, out);
990     for (SpvId id : arguments) {
991         this->writeWord(id, out);
992     }
993     // now that the call is complete, we may need to update some lvalues with the new values of out
994     // arguments
995     for (const auto& tuple : lvalues) {
996         SpvId load = this->nextId();
997         this->writeInstruction(SpvOpLoad, std::get<1>(tuple), load, std::get<0>(tuple), out);
998         std::get<2>(tuple)->store(load, out);
999     }
1000     return result;
1001 }
1002 
writeConstantVector(const Constructor & c)1003 SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) {
1004     ASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant());
1005     SpvId result = this->nextId();
1006     std::vector<SpvId> arguments;
1007     for (size_t i = 0; i < c.fArguments.size(); i++) {
1008         arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer));
1009     }
1010     SpvId type = this->getType(c.fType);
1011     if (c.fArguments.size() == 1) {
1012         // with a single argument, a vector will have all of its entries equal to the argument
1013         this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer);
1014         this->writeWord(type, fConstantBuffer);
1015         this->writeWord(result, fConstantBuffer);
1016         for (int i = 0; i < c.fType.columns(); i++) {
1017             this->writeWord(arguments[0], fConstantBuffer);
1018         }
1019     } else {
1020         this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(),
1021                           fConstantBuffer);
1022         this->writeWord(type, fConstantBuffer);
1023         this->writeWord(result, fConstantBuffer);
1024         for (SpvId id : arguments) {
1025             this->writeWord(id, fConstantBuffer);
1026         }
1027     }
1028     return result;
1029 }
1030 
writeFloatConstructor(const Constructor & c,OutputStream & out)1031 SpvId SPIRVCodeGenerator::writeFloatConstructor(const Constructor& c, OutputStream& out) {
1032     ASSERT(c.fType.isFloat());
1033     ASSERT(c.fArguments.size() == 1);
1034     ASSERT(c.fArguments[0]->fType.isNumber());
1035     SpvId result = this->nextId();
1036     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1037     if (c.fArguments[0]->fType.isSigned()) {
1038         this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter,
1039                                out);
1040     } else {
1041         ASSERT(c.fArguments[0]->fType.isUnsigned());
1042         this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter,
1043                                out);
1044     }
1045     return result;
1046 }
1047 
writeIntConstructor(const Constructor & c,OutputStream & out)1048 SpvId SPIRVCodeGenerator::writeIntConstructor(const Constructor& c, OutputStream& out) {
1049     ASSERT(c.fType.isSigned());
1050     ASSERT(c.fArguments.size() == 1);
1051     ASSERT(c.fArguments[0]->fType.isNumber());
1052     SpvId result = this->nextId();
1053     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1054     if (c.fArguments[0]->fType.isFloat()) {
1055         this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter,
1056                                out);
1057     }
1058     else {
1059         ASSERT(c.fArguments[0]->fType.isUnsigned());
1060         this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1061                                out);
1062     }
1063     return result;
1064 }
1065 
writeUIntConstructor(const Constructor & c,OutputStream & out)1066 SpvId SPIRVCodeGenerator::writeUIntConstructor(const Constructor& c, OutputStream& out) {
1067     ASSERT(c.fType.isUnsigned());
1068     ASSERT(c.fArguments.size() == 1);
1069     ASSERT(c.fArguments[0]->fType.isNumber());
1070     SpvId result = this->nextId();
1071     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1072     if (c.fArguments[0]->fType.isFloat()) {
1073         this->writeInstruction(SpvOpConvertFToU, this->getType(c.fType), result, parameter,
1074                                out);
1075     } else {
1076         ASSERT(c.fArguments[0]->fType.isSigned());
1077         this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1078                                out);
1079     }
1080     return result;
1081 }
1082 
writeUniformScaleMatrix(SpvId id,SpvId diagonal,const Type & type,OutputStream & out)1083 void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
1084                                                  OutputStream& out) {
1085     FloatLiteral zero(fContext, -1, 0);
1086     SpvId zeroId = this->writeFloatLiteral(zero);
1087     std::vector<SpvId> columnIds;
1088     for (int column = 0; column < type.columns(); column++) {
1089         this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
1090                           out);
1091         this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)),
1092                         out);
1093         SpvId columnId = this->nextId();
1094         this->writeWord(columnId, out);
1095         columnIds.push_back(columnId);
1096         for (int row = 0; row < type.columns(); row++) {
1097             this->writeWord(row == column ? diagonal : zeroId, out);
1098         }
1099     }
1100     this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
1101                       out);
1102     this->writeWord(this->getType(type), out);
1103     this->writeWord(id, out);
1104     for (SpvId id : columnIds) {
1105         this->writeWord(id, out);
1106     }
1107 }
1108 
writeMatrixCopy(SpvId id,SpvId src,const Type & srcType,const Type & dstType,OutputStream & out)1109 void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType,
1110                                          const Type& dstType, OutputStream& out) {
1111     ASSERT(srcType.kind() == Type::kMatrix_Kind);
1112     ASSERT(dstType.kind() == Type::kMatrix_Kind);
1113     ASSERT(srcType.componentType() == dstType.componentType());
1114     SpvId srcColumnType = this->getType(srcType.componentType().toCompound(fContext,
1115                                                                            srcType.rows(),
1116                                                                            1));
1117     SpvId dstColumnType = this->getType(dstType.componentType().toCompound(fContext,
1118                                                                            dstType.rows(),
1119                                                                            1));
1120     SpvId zeroId;
1121     if (dstType.componentType() == *fContext.fFloat_Type) {
1122         FloatLiteral zero(fContext, -1, 0.0);
1123         zeroId = this->writeFloatLiteral(zero);
1124     } else if (dstType.componentType() == *fContext.fInt_Type) {
1125         IntLiteral zero(fContext, -1, 0);
1126         zeroId = this->writeIntLiteral(zero);
1127     } else {
1128         ABORT("unsupported matrix component type");
1129     }
1130     SpvId zeroColumn = 0;
1131     SpvId columns[4];
1132     for (int i = 0; i < dstType.columns(); i++) {
1133         if (i < srcType.columns()) {
1134             // we're still inside the src matrix, copy the column
1135             SpvId srcColumn = this->nextId();
1136             this->writeInstruction(SpvOpCompositeExtract, srcColumnType, srcColumn, src, i, out);
1137             SpvId dstColumn;
1138             if (srcType.rows() == dstType.rows()) {
1139                 // columns are equal size, don't need to do anything
1140                 dstColumn = srcColumn;
1141             }
1142             else if (dstType.rows() > srcType.rows()) {
1143                 // dst column is bigger, need to zero-pad it
1144                 dstColumn = this->nextId();
1145                 int delta = dstType.rows() - srcType.rows();
1146                 this->writeOpCode(SpvOpCompositeConstruct, 4 + delta, out);
1147                 this->writeWord(dstColumnType, out);
1148                 this->writeWord(dstColumn, out);
1149                 this->writeWord(srcColumn, out);
1150                 for (int i = 0; i < delta; ++i) {
1151                     this->writeWord(zeroId, out);
1152                 }
1153             }
1154             else {
1155                 // dst column is smaller, need to swizzle the src column
1156                 dstColumn = this->nextId();
1157                 int count = dstType.rows();
1158                 this->writeOpCode(SpvOpVectorShuffle, 5 + count, out);
1159                 this->writeWord(dstColumnType, out);
1160                 this->writeWord(dstColumn, out);
1161                 this->writeWord(srcColumn, out);
1162                 this->writeWord(srcColumn, out);
1163                 for (int i = 0; i < count; i++) {
1164                     this->writeWord(i, out);
1165                 }
1166             }
1167             columns[i] = dstColumn;
1168         } else {
1169             // we're past the end of the src matrix, need a vector of zeroes
1170             if (!zeroColumn) {
1171                 zeroColumn = this->nextId();
1172                 this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.rows(), out);
1173                 this->writeWord(dstColumnType, out);
1174                 this->writeWord(zeroColumn, out);
1175                 for (int i = 0; i < dstType.rows(); ++i) {
1176                     this->writeWord(zeroId, out);
1177                 }
1178             }
1179             columns[i] = zeroColumn;
1180         }
1181     }
1182     this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.columns(), out);
1183     this->writeWord(this->getType(dstType), out);
1184     this->writeWord(id, out);
1185     for (int i = 0; i < dstType.columns(); i++) {
1186         this->writeWord(columns[i], out);
1187     }
1188 }
1189 
writeMatrixConstructor(const Constructor & c,OutputStream & out)1190 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, OutputStream& out) {
1191     ASSERT(c.fType.kind() == Type::kMatrix_Kind);
1192     // go ahead and write the arguments so we don't try to write new instructions in the middle of
1193     // an instruction
1194     std::vector<SpvId> arguments;
1195     for (size_t i = 0; i < c.fArguments.size(); i++) {
1196         arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1197     }
1198     SpvId result = this->nextId();
1199     int rows = c.fType.rows();
1200     int columns = c.fType.columns();
1201     if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1202         this->writeUniformScaleMatrix(result, arguments[0], c.fType, out);
1203     } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) {
1204         this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out);
1205     } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kVector_Kind) {
1206         ASSERT(c.fType.rows() == 2 && c.fType.columns() == 2);
1207         ASSERT(c.fArguments[0]->fType.columns() == 4);
1208         SpvId componentType = this->getType(c.fType.componentType());
1209         SpvId v[4];
1210         for (int i = 0; i < 4; ++i) {
1211             v[i] = this->nextId();
1212             this->writeInstruction(SpvOpCompositeExtract, componentType, v[i], arguments[0], i, out);
1213         }
1214         SpvId columnType = this->getType(c.fType.componentType().toCompound(fContext, 2, 1));
1215         SpvId column1 = this->nextId();
1216         this->writeInstruction(SpvOpCompositeConstruct, columnType, column1, v[0], v[1], out);
1217         SpvId column2 = this->nextId();
1218         this->writeInstruction(SpvOpCompositeConstruct, columnType, column2, v[2], v[3], out);
1219         this->writeInstruction(SpvOpCompositeConstruct, this->getType(c.fType), result, column1,
1220                                column2, out);
1221     } else {
1222         std::vector<SpvId> columnIds;
1223         // ids of vectors and scalars we have written to the current column so far
1224         std::vector<SpvId> currentColumn;
1225         // the total number of scalars represented by currentColumn's entries
1226         int currentCount = 0;
1227         for (size_t i = 0; i < arguments.size(); i++) {
1228             if (c.fArguments[i]->fType.kind() == Type::kVector_Kind &&
1229                     c.fArguments[i]->fType.columns() == c.fType.rows()) {
1230                 // this is a complete column by itself
1231                 ASSERT(currentCount == 0);
1232                 columnIds.push_back(arguments[i]);
1233             } else {
1234                 currentColumn.push_back(arguments[i]);
1235                 currentCount += c.fArguments[i]->fType.columns();
1236                 if (currentCount == rows) {
1237                     currentCount = 0;
1238                     this->writeOpCode(SpvOpCompositeConstruct, 3 + currentColumn.size(), out);
1239                     this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows,
1240                                                                                      1)),
1241                                     out);
1242                     SpvId columnId = this->nextId();
1243                     this->writeWord(columnId, out);
1244                     columnIds.push_back(columnId);
1245                     for (SpvId id : currentColumn) {
1246                         this->writeWord(id, out);
1247                     }
1248                     currentColumn.clear();
1249                 }
1250                 ASSERT(currentCount < rows);
1251             }
1252         }
1253         ASSERT(columnIds.size() == (size_t) columns);
1254         this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
1255         this->writeWord(this->getType(c.fType), out);
1256         this->writeWord(result, out);
1257         for (SpvId id : columnIds) {
1258             this->writeWord(id, out);
1259         }
1260     }
1261     return result;
1262 }
1263 
writeVectorConstructor(const Constructor & c,OutputStream & out)1264 SpvId SPIRVCodeGenerator::writeVectorConstructor(const Constructor& c, OutputStream& out) {
1265     ASSERT(c.fType.kind() == Type::kVector_Kind);
1266     if (c.isConstant()) {
1267         return this->writeConstantVector(c);
1268     }
1269     // go ahead and write the arguments so we don't try to write new instructions in the middle of
1270     // an instruction
1271     std::vector<SpvId> arguments;
1272     for (size_t i = 0; i < c.fArguments.size(); i++) {
1273         if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) {
1274             // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to
1275             // extract the components and convert them in that case manually. On top of that,
1276             // as of this writing there's a bug in the Intel Vulkan driver where OpCreateComposite
1277             // doesn't handle vector arguments at all, so we always extract vector components and
1278             // pass them into OpCreateComposite individually.
1279             SpvId vec = this->writeExpression(*c.fArguments[i], out);
1280             SpvOp_ op = SpvOpUndef;
1281             const Type& src = c.fArguments[i]->fType.componentType();
1282             const Type& dst = c.fType.componentType();
1283             if (dst == *fContext.fFloat_Type || dst == *fContext.fHalf_Type) {
1284                 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1285                     if (c.fArguments.size() == 1) {
1286                         return vec;
1287                     }
1288                 } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) {
1289                     op = SpvOpConvertSToF;
1290                 } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) {
1291                     op = SpvOpConvertUToF;
1292                 } else {
1293                     ASSERT(false);
1294                 }
1295             } else if (dst == *fContext.fInt_Type || dst == *fContext.fShort_Type) {
1296                 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1297                     op = SpvOpConvertFToS;
1298                 } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) {
1299                     if (c.fArguments.size() == 1) {
1300                         return vec;
1301                     }
1302                 } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) {
1303                     op = SpvOpBitcast;
1304                 } else {
1305                     ASSERT(false);
1306                 }
1307             } else if (dst == *fContext.fUInt_Type || dst == *fContext.fUShort_Type) {
1308                 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1309                     op = SpvOpConvertFToS;
1310                 } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) {
1311                     op = SpvOpBitcast;
1312                 } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) {
1313                     if (c.fArguments.size() == 1) {
1314                         return vec;
1315                     }
1316                 } else {
1317                     ASSERT(false);
1318                 }
1319             }
1320             for (int j = 0; j < c.fArguments[i]->fType.columns(); j++) {
1321                 SpvId swizzle = this->nextId();
1322                 this->writeInstruction(SpvOpCompositeExtract, this->getType(src), swizzle, vec, j,
1323                                        out);
1324                 if (op != SpvOpUndef) {
1325                     SpvId cast = this->nextId();
1326                     this->writeInstruction(op, this->getType(dst), cast, swizzle, out);
1327                     arguments.push_back(cast);
1328                 } else {
1329                     arguments.push_back(swizzle);
1330                 }
1331             }
1332         } else {
1333             arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1334         }
1335     }
1336     SpvId result = this->nextId();
1337     if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1338         this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out);
1339         this->writeWord(this->getType(c.fType), out);
1340         this->writeWord(result, out);
1341         for (int i = 0; i < c.fType.columns(); i++) {
1342             this->writeWord(arguments[0], out);
1343         }
1344     } else {
1345         ASSERT(arguments.size() > 1);
1346         this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out);
1347         this->writeWord(this->getType(c.fType), out);
1348         this->writeWord(result, out);
1349         for (SpvId id : arguments) {
1350             this->writeWord(id, out);
1351         }
1352     }
1353     return result;
1354 }
1355 
writeArrayConstructor(const Constructor & c,OutputStream & out)1356 SpvId SPIRVCodeGenerator::writeArrayConstructor(const Constructor& c, OutputStream& out) {
1357     ASSERT(c.fType.kind() == Type::kArray_Kind);
1358     // go ahead and write the arguments so we don't try to write new instructions in the middle of
1359     // an instruction
1360     std::vector<SpvId> arguments;
1361     for (size_t i = 0; i < c.fArguments.size(); i++) {
1362         arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1363     }
1364     SpvId result = this->nextId();
1365     this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out);
1366     this->writeWord(this->getType(c.fType), out);
1367     this->writeWord(result, out);
1368     for (SpvId id : arguments) {
1369         this->writeWord(id, out);
1370     }
1371     return result;
1372 }
1373 
writeConstructor(const Constructor & c,OutputStream & out)1374 SpvId SPIRVCodeGenerator::writeConstructor(const Constructor& c, OutputStream& out) {
1375     if (c.fArguments.size() == 1 &&
1376         this->getActualType(c.fType) == this->getActualType(c.fArguments[0]->fType)) {
1377         return this->writeExpression(*c.fArguments[0], out);
1378     }
1379     if (c.fType == *fContext.fFloat_Type || c.fType == *fContext.fHalf_Type) {
1380         return this->writeFloatConstructor(c, out);
1381     } else if (c.fType == *fContext.fInt_Type || c.fType == *fContext.fShort_Type) {
1382         return this->writeIntConstructor(c, out);
1383     } else if (c.fType == *fContext.fUInt_Type || c.fType == *fContext.fUShort_Type) {
1384         return this->writeUIntConstructor(c, out);
1385     }
1386     switch (c.fType.kind()) {
1387         case Type::kVector_Kind:
1388             return this->writeVectorConstructor(c, out);
1389         case Type::kMatrix_Kind:
1390             return this->writeMatrixConstructor(c, out);
1391         case Type::kArray_Kind:
1392             return this->writeArrayConstructor(c, out);
1393         default:
1394             ABORT("unsupported constructor: %s", c.description().c_str());
1395     }
1396 }
1397 
get_storage_class(const Modifiers & modifiers)1398 SpvStorageClass_ get_storage_class(const Modifiers& modifiers) {
1399     if (modifiers.fFlags & Modifiers::kIn_Flag) {
1400         ASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1401         return SpvStorageClassInput;
1402     } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
1403         ASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1404         return SpvStorageClassOutput;
1405     } else if (modifiers.fFlags & Modifiers::kUniform_Flag) {
1406         if (modifiers.fLayout.fFlags & Layout::kPushConstant_Flag) {
1407             return SpvStorageClassPushConstant;
1408         }
1409         return SpvStorageClassUniform;
1410     } else {
1411         return SpvStorageClassFunction;
1412     }
1413 }
1414 
get_storage_class(const Expression & expr)1415 SpvStorageClass_ get_storage_class(const Expression& expr) {
1416     switch (expr.fKind) {
1417         case Expression::kVariableReference_Kind: {
1418             const Variable& var = ((VariableReference&) expr).fVariable;
1419             if (var.fStorage != Variable::kGlobal_Storage) {
1420                 return SpvStorageClassFunction;
1421             }
1422             SpvStorageClass_ result = get_storage_class(var.fModifiers);
1423             if (result == SpvStorageClassFunction) {
1424                 result = SpvStorageClassPrivate;
1425             }
1426             return result;
1427         }
1428         case Expression::kFieldAccess_Kind:
1429             return get_storage_class(*((FieldAccess&) expr).fBase);
1430         case Expression::kIndex_Kind:
1431             return get_storage_class(*((IndexExpression&) expr).fBase);
1432         default:
1433             return SpvStorageClassFunction;
1434     }
1435 }
1436 
getAccessChain(const Expression & expr,OutputStream & out)1437 std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
1438     std::vector<SpvId> chain;
1439     switch (expr.fKind) {
1440         case Expression::kIndex_Kind: {
1441             IndexExpression& indexExpr = (IndexExpression&) expr;
1442             chain = this->getAccessChain(*indexExpr.fBase, out);
1443             chain.push_back(this->writeExpression(*indexExpr.fIndex, out));
1444             break;
1445         }
1446         case Expression::kFieldAccess_Kind: {
1447             FieldAccess& fieldExpr = (FieldAccess&) expr;
1448             chain = this->getAccessChain(*fieldExpr.fBase, out);
1449             IntLiteral index(fContext, -1, fieldExpr.fFieldIndex);
1450             chain.push_back(this->writeIntLiteral(index));
1451             break;
1452         }
1453         default:
1454             chain.push_back(this->getLValue(expr, out)->getPointer());
1455     }
1456     return chain;
1457 }
1458 
1459 class PointerLValue : public SPIRVCodeGenerator::LValue {
1460 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,SpvId type)1461     PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type)
1462     : fGen(gen)
1463     , fPointer(pointer)
1464     , fType(type) {}
1465 
getPointer()1466     virtual SpvId getPointer() override {
1467         return fPointer;
1468     }
1469 
load(OutputStream & out)1470     virtual SpvId load(OutputStream& out) override {
1471         SpvId result = fGen.nextId();
1472         fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
1473         return result;
1474     }
1475 
store(SpvId value,OutputStream & out)1476     virtual void store(SpvId value, OutputStream& out) override {
1477         fGen.writeInstruction(SpvOpStore, fPointer, value, out);
1478     }
1479 
1480 private:
1481     SPIRVCodeGenerator& fGen;
1482     const SpvId fPointer;
1483     const SpvId fType;
1484 };
1485 
1486 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
1487 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const std::vector<int> & components,const Type & baseType,const Type & swizzleType)1488     SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components,
1489                   const Type& baseType, const Type& swizzleType)
1490     : fGen(gen)
1491     , fVecPointer(vecPointer)
1492     , fComponents(components)
1493     , fBaseType(baseType)
1494     , fSwizzleType(swizzleType) {}
1495 
getPointer()1496     virtual SpvId getPointer() override {
1497         return 0;
1498     }
1499 
load(OutputStream & out)1500     virtual SpvId load(OutputStream& out) override {
1501         SpvId base = fGen.nextId();
1502         fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1503         SpvId result = fGen.nextId();
1504         fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
1505         fGen.writeWord(fGen.getType(fSwizzleType), out);
1506         fGen.writeWord(result, out);
1507         fGen.writeWord(base, out);
1508         fGen.writeWord(base, out);
1509         for (int component : fComponents) {
1510             fGen.writeWord(component, out);
1511         }
1512         return result;
1513     }
1514 
store(SpvId value,OutputStream & out)1515     virtual void store(SpvId value, OutputStream& out) override {
1516         // use OpVectorShuffle to mix and match the vector components. We effectively create
1517         // a virtual vector out of the concatenation of the left and right vectors, and then
1518         // select components from this virtual vector to make the result vector. For
1519         // instance, given:
1520         // float3L = ...;
1521         // float3R = ...;
1522         // L.xz = R.xy;
1523         // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
1524         // our result vector to look like (R.x, L.y, R.y), so we need to select indices
1525         // (3, 1, 4).
1526         SpvId base = fGen.nextId();
1527         fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1528         SpvId shuffle = fGen.nextId();
1529         fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out);
1530         fGen.writeWord(fGen.getType(fBaseType), out);
1531         fGen.writeWord(shuffle, out);
1532         fGen.writeWord(base, out);
1533         fGen.writeWord(value, out);
1534         for (int i = 0; i < fBaseType.columns(); i++) {
1535             // current offset into the virtual vector, defaults to pulling the unmodified
1536             // value from the left side
1537             int offset = i;
1538             // check to see if we are writing this component
1539             for (size_t j = 0; j < fComponents.size(); j++) {
1540                 if (fComponents[j] == i) {
1541                     // we're writing to this component, so adjust the offset to pull from
1542                     // the correct component of the right side instead of preserving the
1543                     // value from the left
1544                     offset = (int) (j + fBaseType.columns());
1545                     break;
1546                 }
1547             }
1548             fGen.writeWord(offset, out);
1549         }
1550         fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
1551     }
1552 
1553 private:
1554     SPIRVCodeGenerator& fGen;
1555     const SpvId fVecPointer;
1556     const std::vector<int>& fComponents;
1557     const Type& fBaseType;
1558     const Type& fSwizzleType;
1559 };
1560 
getLValue(const Expression & expr,OutputStream & out)1561 std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
1562                                                                           OutputStream& out) {
1563     switch (expr.fKind) {
1564         case Expression::kVariableReference_Kind: {
1565             const Variable& var = ((VariableReference&) expr).fVariable;
1566             auto entry = fVariableMap.find(&var);
1567             ASSERT(entry != fVariableMap.end());
1568             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1569                                                                        *this,
1570                                                                        entry->second,
1571                                                                        this->getType(expr.fType)));
1572         }
1573         case Expression::kIndex_Kind: // fall through
1574         case Expression::kFieldAccess_Kind: {
1575             std::vector<SpvId> chain = this->getAccessChain(expr, out);
1576             SpvId member = this->nextId();
1577             this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
1578             this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out);
1579             this->writeWord(member, out);
1580             for (SpvId idx : chain) {
1581                 this->writeWord(idx, out);
1582             }
1583             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1584                                                                        *this,
1585                                                                        member,
1586                                                                        this->getType(expr.fType)));
1587         }
1588         case Expression::kSwizzle_Kind: {
1589             Swizzle& swizzle = (Swizzle&) expr;
1590             size_t count = swizzle.fComponents.size();
1591             SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer();
1592             ASSERT(base);
1593             if (count == 1) {
1594                 IntLiteral index(fContext, -1, swizzle.fComponents[0]);
1595                 SpvId member = this->nextId();
1596                 this->writeInstruction(SpvOpAccessChain,
1597                                        this->getPointerType(swizzle.fType,
1598                                                             get_storage_class(*swizzle.fBase)),
1599                                        member,
1600                                        base,
1601                                        this->writeIntLiteral(index),
1602                                        out);
1603                 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1604                                                                        *this,
1605                                                                        member,
1606                                                                        this->getType(expr.fType)));
1607             } else {
1608                 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue(
1609                                                                               *this,
1610                                                                               base,
1611                                                                               swizzle.fComponents,
1612                                                                               swizzle.fBase->fType,
1613                                                                               expr.fType));
1614             }
1615         }
1616         case Expression::kTernary_Kind: {
1617             TernaryExpression& t = (TernaryExpression&) expr;
1618             SpvId test = this->writeExpression(*t.fTest, out);
1619             SpvId end = this->nextId();
1620             SpvId ifTrueLabel = this->nextId();
1621             SpvId ifFalseLabel = this->nextId();
1622             this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
1623             this->writeInstruction(SpvOpBranchConditional, test, ifTrueLabel, ifFalseLabel, out);
1624             this->writeLabel(ifTrueLabel, out);
1625             SpvId ifTrue = this->getLValue(*t.fIfTrue, out)->getPointer();
1626             ASSERT(ifTrue);
1627             this->writeInstruction(SpvOpBranch, end, out);
1628             ifTrueLabel = fCurrentBlock;
1629             SpvId ifFalse = this->getLValue(*t.fIfFalse, out)->getPointer();
1630             ASSERT(ifFalse);
1631             ifFalseLabel = fCurrentBlock;
1632             this->writeInstruction(SpvOpBranch, end, out);
1633             SpvId result = this->nextId();
1634             this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, ifTrue,
1635                        ifTrueLabel, ifFalse, ifFalseLabel, out);
1636             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1637                                                                        *this,
1638                                                                        result,
1639                                                                        this->getType(expr.fType)));
1640         }
1641         default:
1642             // expr isn't actually an lvalue, create a dummy variable for it. This case happens due
1643             // to the need to store values in temporary variables during function calls (see
1644             // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been
1645             // caught by IRGenerator
1646             SpvId result = this->nextId();
1647             SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction);
1648             this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction,
1649                                    fVariableBuffer);
1650             this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
1651             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1652                                                                        *this,
1653                                                                        result,
1654                                                                        this->getType(expr.fType)));
1655     }
1656 }
1657 
writeVariableReference(const VariableReference & ref,OutputStream & out)1658 SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
1659     SpvId result = this->nextId();
1660     auto entry = fVariableMap.find(&ref.fVariable);
1661     ASSERT(entry != fVariableMap.end());
1662     SpvId var = entry->second;
1663     this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out);
1664     if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN &&
1665         fProgram.fSettings.fFlipY) {
1666         // need to remap to a top-left coordinate system
1667         if (fRTHeightStructId == (SpvId) -1) {
1668             // height variable hasn't been written yet
1669             std::shared_ptr<SymbolTable> st(new SymbolTable(&fErrors));
1670             ASSERT(fRTHeightFieldIndex == (SpvId) -1);
1671             std::vector<Type::Field> fields;
1672             fields.emplace_back(Modifiers(), SKSL_RTHEIGHT_NAME, fContext.fFloat_Type.get());
1673             StringFragment name("sksl_synthetic_uniforms");
1674             Type intfStruct(-1, name, fields);
1675             Layout layout(0, -1, -1, 1, -1, -1, -1, -1, Layout::Format::kUnspecified,
1676                           Layout::kUnspecified_Primitive, -1, -1, "", Layout::kNo_Key,
1677                           StringFragment());
1678             Variable* intfVar = new Variable(-1,
1679                                              Modifiers(layout, Modifiers::kUniform_Flag),
1680                                              name,
1681                                              intfStruct,
1682                                              Variable::kGlobal_Storage);
1683             fSynthetics.takeOwnership(intfVar);
1684             InterfaceBlock intf(-1, intfVar, name, String(""),
1685                                 std::vector<std::unique_ptr<Expression>>(), st);
1686             fRTHeightStructId = this->writeInterfaceBlock(intf);
1687             fRTHeightFieldIndex = 0;
1688         }
1689         ASSERT(fRTHeightFieldIndex != (SpvId) -1);
1690         // write float4(gl_FragCoord.x, u_skRTHeight - gl_FragCoord.y, 0.0, 1.0)
1691         SpvId xId = this->nextId();
1692         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), xId,
1693                                result, 0, out);
1694         IntLiteral fieldIndex(fContext, -1, fRTHeightFieldIndex);
1695         SpvId fieldIndexId = this->writeIntLiteral(fieldIndex);
1696         SpvId heightPtr = this->nextId();
1697         this->writeOpCode(SpvOpAccessChain, 5, out);
1698         this->writeWord(this->getPointerType(*fContext.fFloat_Type, SpvStorageClassUniform), out);
1699         this->writeWord(heightPtr, out);
1700         this->writeWord(fRTHeightStructId, out);
1701         this->writeWord(fieldIndexId, out);
1702         SpvId heightRead = this->nextId();
1703         this->writeInstruction(SpvOpLoad, this->getType(*fContext.fFloat_Type), heightRead,
1704                                heightPtr, out);
1705         SpvId rawYId = this->nextId();
1706         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), rawYId,
1707                                result, 1, out);
1708         SpvId flippedYId = this->nextId();
1709         this->writeInstruction(SpvOpFSub, this->getType(*fContext.fFloat_Type), flippedYId,
1710                                heightRead, rawYId, out);
1711         FloatLiteral zero(fContext, -1, 0.0);
1712         SpvId zeroId = writeFloatLiteral(zero);
1713         FloatLiteral one(fContext, -1, 1.0);
1714         SpvId oneId = writeFloatLiteral(one);
1715         SpvId flipped = this->nextId();
1716         this->writeOpCode(SpvOpCompositeConstruct, 7, out);
1717         this->writeWord(this->getType(*fContext.fFloat4_Type), out);
1718         this->writeWord(flipped, out);
1719         this->writeWord(xId, out);
1720         this->writeWord(flippedYId, out);
1721         this->writeWord(zeroId, out);
1722         this->writeWord(oneId, out);
1723         return flipped;
1724     }
1725     return result;
1726 }
1727 
writeIndexExpression(const IndexExpression & expr,OutputStream & out)1728 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
1729     return getLValue(expr, out)->load(out);
1730 }
1731 
writeFieldAccess(const FieldAccess & f,OutputStream & out)1732 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
1733     return getLValue(f, out)->load(out);
1734 }
1735 
writeSwizzle(const Swizzle & swizzle,OutputStream & out)1736 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
1737     SpvId base = this->writeExpression(*swizzle.fBase, out);
1738     SpvId result = this->nextId();
1739     size_t count = swizzle.fComponents.size();
1740     if (count == 1) {
1741         this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base,
1742                                swizzle.fComponents[0], out);
1743     } else {
1744         this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
1745         this->writeWord(this->getType(swizzle.fType), out);
1746         this->writeWord(result, out);
1747         this->writeWord(base, out);
1748         this->writeWord(base, out);
1749         for (int component : swizzle.fComponents) {
1750             this->writeWord(component, out);
1751         }
1752     }
1753     return result;
1754 }
1755 
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)1756 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
1757                                                const Type& operandType, SpvId lhs,
1758                                                SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
1759                                                SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
1760     SpvId result = this->nextId();
1761     if (is_float(fContext, operandType)) {
1762         this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
1763     } else if (is_signed(fContext, operandType)) {
1764         this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
1765     } else if (is_unsigned(fContext, operandType)) {
1766         this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
1767     } else if (operandType == *fContext.fBool_Type) {
1768         this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
1769     } else {
1770         ABORT("invalid operandType: %s", operandType.description().c_str());
1771     }
1772     return result;
1773 }
1774 
is_assignment(Token::Kind op)1775 bool is_assignment(Token::Kind op) {
1776     switch (op) {
1777         case Token::EQ:           // fall through
1778         case Token::PLUSEQ:       // fall through
1779         case Token::MINUSEQ:      // fall through
1780         case Token::STAREQ:       // fall through
1781         case Token::SLASHEQ:      // fall through
1782         case Token::PERCENTEQ:    // fall through
1783         case Token::SHLEQ:        // fall through
1784         case Token::SHREQ:        // fall through
1785         case Token::BITWISEOREQ:  // fall through
1786         case Token::BITWISEXOREQ: // fall through
1787         case Token::BITWISEANDEQ: // fall through
1788         case Token::LOGICALOREQ:  // fall through
1789         case Token::LOGICALXOREQ: // fall through
1790         case Token::LOGICALANDEQ:
1791             return true;
1792         default:
1793             return false;
1794     }
1795 }
1796 
foldToBool(SpvId id,const Type & operandType,OutputStream & out)1797 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, OutputStream& out) {
1798     if (operandType.kind() == Type::kVector_Kind) {
1799         SpvId result = this->nextId();
1800         this->writeInstruction(SpvOpAll, this->getType(*fContext.fBool_Type), result, id, out);
1801         return result;
1802     }
1803     return id;
1804 }
1805 
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,OutputStream & out)1806 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
1807                                                 SpvOp_ floatOperator, SpvOp_ intOperator,
1808                                                 OutputStream& out) {
1809     SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator;
1810     ASSERT(operandType.kind() == Type::kMatrix_Kind);
1811     SpvId rowType = this->getType(operandType.componentType().toCompound(fContext,
1812                                                                          operandType.columns(),
1813                                                                          1));
1814     SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext,
1815                                                                     operandType.columns(),
1816                                                                     1));
1817     SpvId boolType = this->getType(*fContext.fBool_Type);
1818     SpvId result = 0;
1819     for (int i = 0; i < operandType.rows(); i++) {
1820         SpvId rowL = this->nextId();
1821         this->writeInstruction(SpvOpCompositeExtract, rowType, rowL, lhs, 0, out);
1822         SpvId rowR = this->nextId();
1823         this->writeInstruction(SpvOpCompositeExtract, rowType, rowR, rhs, 0, out);
1824         SpvId compare = this->nextId();
1825         this->writeInstruction(compareOp, bvecType, compare, rowL, rowR, out);
1826         SpvId all = this->nextId();
1827         this->writeInstruction(SpvOpAll, boolType, all, compare, out);
1828         if (result != 0) {
1829             SpvId next = this->nextId();
1830             this->writeInstruction(SpvOpLogicalAnd, boolType, next, result, all, out);
1831             result = next;
1832         }
1833         else {
1834             result = all;
1835         }
1836     }
1837     return result;
1838 }
1839 
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)1840 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
1841     // handle cases where we don't necessarily evaluate both LHS and RHS
1842     switch (b.fOperator) {
1843         case Token::EQ: {
1844             SpvId rhs = this->writeExpression(*b.fRight, out);
1845             this->getLValue(*b.fLeft, out)->store(rhs, out);
1846             return rhs;
1847         }
1848         case Token::LOGICALAND:
1849             return this->writeLogicalAnd(b, out);
1850         case Token::LOGICALOR:
1851             return this->writeLogicalOr(b, out);
1852         default:
1853             break;
1854     }
1855 
1856     // "normal" operators
1857     const Type& resultType = b.fType;
1858     std::unique_ptr<LValue> lvalue;
1859     SpvId lhs;
1860     if (is_assignment(b.fOperator)) {
1861         lvalue = this->getLValue(*b.fLeft, out);
1862         lhs = lvalue->load(out);
1863     } else {
1864         lvalue = nullptr;
1865         lhs = this->writeExpression(*b.fLeft, out);
1866     }
1867     SpvId rhs = this->writeExpression(*b.fRight, out);
1868     if (b.fOperator == Token::COMMA) {
1869         return rhs;
1870     }
1871     Type tmp("<invalid>");
1872     // component type we are operating on: float, int, uint
1873     const Type* operandType;
1874     // IR allows mismatched types in expressions (e.g. float2* float), but they need special handling
1875     // in SPIR-V
1876     if (this->getActualType(b.fLeft->fType) != this->getActualType(b.fRight->fType)) {
1877         if (b.fLeft->fType.kind() == Type::kVector_Kind &&
1878             b.fRight->fType.isNumber()) {
1879             // promote number to vector
1880             SpvId vec = this->nextId();
1881             this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
1882             this->writeWord(this->getType(resultType), out);
1883             this->writeWord(vec, out);
1884             for (int i = 0; i < resultType.columns(); i++) {
1885                 this->writeWord(rhs, out);
1886             }
1887             rhs = vec;
1888             operandType = &b.fRight->fType;
1889         } else if (b.fRight->fType.kind() == Type::kVector_Kind &&
1890                    b.fLeft->fType.isNumber()) {
1891             // promote number to vector
1892             SpvId vec = this->nextId();
1893             this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
1894             this->writeWord(this->getType(resultType), out);
1895             this->writeWord(vec, out);
1896             for (int i = 0; i < resultType.columns(); i++) {
1897                 this->writeWord(lhs, out);
1898             }
1899             lhs = vec;
1900             ASSERT(!lvalue);
1901             operandType = &b.fLeft->fType;
1902         } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) {
1903             SpvOp_ op;
1904             if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
1905                 op = SpvOpMatrixTimesMatrix;
1906             } else if (b.fRight->fType.kind() == Type::kVector_Kind) {
1907                 op = SpvOpMatrixTimesVector;
1908             } else {
1909                 ASSERT(b.fRight->fType.kind() == Type::kScalar_Kind);
1910                 op = SpvOpMatrixTimesScalar;
1911             }
1912             SpvId result = this->nextId();
1913             this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out);
1914             if (b.fOperator == Token::STAREQ) {
1915                 lvalue->store(result, out);
1916             } else {
1917                 ASSERT(b.fOperator == Token::STAR);
1918             }
1919             return result;
1920         } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
1921             SpvId result = this->nextId();
1922             if (b.fLeft->fType.kind() == Type::kVector_Kind) {
1923                 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result,
1924                                        lhs, rhs, out);
1925             } else {
1926                 ASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind);
1927                 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs,
1928                                        lhs, out);
1929             }
1930             if (b.fOperator == Token::STAREQ) {
1931                 lvalue->store(result, out);
1932             } else {
1933                 ASSERT(b.fOperator == Token::STAR);
1934             }
1935             return result;
1936         } else {
1937             ABORT("unsupported binary expression: %s", b.description().c_str());
1938         }
1939     } else {
1940         tmp = this->getActualType(b.fLeft->fType);
1941         operandType = &tmp;
1942         ASSERT(*operandType == this->getActualType(b.fRight->fType));
1943     }
1944     switch (b.fOperator) {
1945         case Token::EQEQ: {
1946             if (operandType->kind() == Type::kMatrix_Kind) {
1947                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
1948                                                    SpvOpIEqual, out);
1949             }
1950             ASSERT(resultType == *fContext.fBool_Type);
1951             return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1952                                                                SpvOpFOrdEqual, SpvOpIEqual,
1953                                                                SpvOpIEqual, SpvOpLogicalEqual, out),
1954                                     *operandType, out);
1955         }
1956         case Token::NEQ:
1957             if (operandType->kind() == Type::kMatrix_Kind) {
1958                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
1959                                                    SpvOpINotEqual, out);
1960             }
1961             ASSERT(resultType == *fContext.fBool_Type);
1962             return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1963                                                                SpvOpFOrdNotEqual, SpvOpINotEqual,
1964                                                                SpvOpINotEqual, SpvOpLogicalNotEqual,
1965                                                                out),
1966                                     *operandType, out);
1967         case Token::GT:
1968             ASSERT(resultType == *fContext.fBool_Type);
1969             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1970                                               SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
1971                                               SpvOpUGreaterThan, SpvOpUndef, out);
1972         case Token::LT:
1973             ASSERT(resultType == *fContext.fBool_Type);
1974             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
1975                                               SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
1976         case Token::GTEQ:
1977             ASSERT(resultType == *fContext.fBool_Type);
1978             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1979                                               SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
1980                                               SpvOpUGreaterThanEqual, SpvOpUndef, out);
1981         case Token::LTEQ:
1982             ASSERT(resultType == *fContext.fBool_Type);
1983             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1984                                               SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
1985                                               SpvOpULessThanEqual, SpvOpUndef, out);
1986         case Token::PLUS:
1987             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
1988                                               SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
1989         case Token::MINUS:
1990             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
1991                                               SpvOpISub, SpvOpISub, SpvOpUndef, out);
1992         case Token::STAR:
1993             if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
1994                 b.fRight->fType.kind() == Type::kMatrix_Kind) {
1995                 // matrix multiply
1996                 SpvId result = this->nextId();
1997                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
1998                                        lhs, rhs, out);
1999                 return result;
2000             }
2001             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2002                                               SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2003         case Token::SLASH:
2004             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2005                                               SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2006         case Token::PERCENT:
2007             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2008                                               SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2009         case Token::SHL:
2010             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2011                                               SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
2012                                               SpvOpUndef, out);
2013         case Token::SHR:
2014             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2015                                               SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
2016                                               SpvOpUndef, out);
2017         case Token::BITWISEAND:
2018             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2019                                               SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
2020         case Token::BITWISEOR:
2021             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2022                                               SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
2023         case Token::BITWISEXOR:
2024             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2025                                               SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
2026         case Token::PLUSEQ: {
2027             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2028                                                       SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2029             ASSERT(lvalue);
2030             lvalue->store(result, out);
2031             return result;
2032         }
2033         case Token::MINUSEQ: {
2034             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2035                                                       SpvOpISub, SpvOpISub, SpvOpUndef, out);
2036             ASSERT(lvalue);
2037             lvalue->store(result, out);
2038             return result;
2039         }
2040         case Token::STAREQ: {
2041             if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2042                 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2043                 // matrix multiply
2044                 SpvId result = this->nextId();
2045                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2046                                        lhs, rhs, out);
2047                 ASSERT(lvalue);
2048                 lvalue->store(result, out);
2049                 return result;
2050             }
2051             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2052                                                       SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2053             ASSERT(lvalue);
2054             lvalue->store(result, out);
2055             return result;
2056         }
2057         case Token::SLASHEQ: {
2058             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2059                                                       SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2060             ASSERT(lvalue);
2061             lvalue->store(result, out);
2062             return result;
2063         }
2064         case Token::PERCENTEQ: {
2065             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2066                                                       SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2067             ASSERT(lvalue);
2068             lvalue->store(result, out);
2069             return result;
2070         }
2071         case Token::SHLEQ: {
2072             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2073                                                       SpvOpUndef, SpvOpShiftLeftLogical,
2074                                                       SpvOpShiftLeftLogical, SpvOpUndef, out);
2075             ASSERT(lvalue);
2076             lvalue->store(result, out);
2077             return result;
2078         }
2079         case Token::SHREQ: {
2080             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2081                                                       SpvOpUndef, SpvOpShiftRightArithmetic,
2082                                                       SpvOpShiftRightLogical, SpvOpUndef, out);
2083             ASSERT(lvalue);
2084             lvalue->store(result, out);
2085             return result;
2086         }
2087         case Token::BITWISEANDEQ: {
2088             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2089                                                       SpvOpUndef, SpvOpBitwiseAnd, SpvOpBitwiseAnd,
2090                                                       SpvOpUndef, out);
2091             ASSERT(lvalue);
2092             lvalue->store(result, out);
2093             return result;
2094         }
2095         case Token::BITWISEOREQ: {
2096             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2097                                                       SpvOpUndef, SpvOpBitwiseOr, SpvOpBitwiseOr,
2098                                                       SpvOpUndef, out);
2099             ASSERT(lvalue);
2100             lvalue->store(result, out);
2101             return result;
2102         }
2103         case Token::BITWISEXOREQ: {
2104             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2105                                                       SpvOpUndef, SpvOpBitwiseXor, SpvOpBitwiseXor,
2106                                                       SpvOpUndef, out);
2107             ASSERT(lvalue);
2108             lvalue->store(result, out);
2109             return result;
2110         }
2111         default:
2112             ABORT("unsupported binary expression: %s", b.description().c_str());
2113     }
2114 }
2115 
writeLogicalAnd(const BinaryExpression & a,OutputStream & out)2116 SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) {
2117     ASSERT(a.fOperator == Token::LOGICALAND);
2118     BoolLiteral falseLiteral(fContext, -1, false);
2119     SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
2120     SpvId lhs = this->writeExpression(*a.fLeft, out);
2121     SpvId rhsLabel = this->nextId();
2122     SpvId end = this->nextId();
2123     SpvId lhsBlock = fCurrentBlock;
2124     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2125     this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
2126     this->writeLabel(rhsLabel, out);
2127     SpvId rhs = this->writeExpression(*a.fRight, out);
2128     SpvId rhsBlock = fCurrentBlock;
2129     this->writeInstruction(SpvOpBranch, end, out);
2130     this->writeLabel(end, out);
2131     SpvId result = this->nextId();
2132     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant,
2133                            lhsBlock, rhs, rhsBlock, out);
2134     return result;
2135 }
2136 
writeLogicalOr(const BinaryExpression & o,OutputStream & out)2137 SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, OutputStream& out) {
2138     ASSERT(o.fOperator == Token::LOGICALOR);
2139     BoolLiteral trueLiteral(fContext, -1, true);
2140     SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
2141     SpvId lhs = this->writeExpression(*o.fLeft, out);
2142     SpvId rhsLabel = this->nextId();
2143     SpvId end = this->nextId();
2144     SpvId lhsBlock = fCurrentBlock;
2145     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2146     this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
2147     this->writeLabel(rhsLabel, out);
2148     SpvId rhs = this->writeExpression(*o.fRight, out);
2149     SpvId rhsBlock = fCurrentBlock;
2150     this->writeInstruction(SpvOpBranch, end, out);
2151     this->writeLabel(end, out);
2152     SpvId result = this->nextId();
2153     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant,
2154                            lhsBlock, rhs, rhsBlock, out);
2155     return result;
2156 }
2157 
writeTernaryExpression(const TernaryExpression & t,OutputStream & out)2158 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
2159     SpvId test = this->writeExpression(*t.fTest, out);
2160     if (t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) {
2161         // both true and false are constants, can just use OpSelect
2162         SpvId result = this->nextId();
2163         SpvId trueId = this->writeExpression(*t.fIfTrue, out);
2164         SpvId falseId = this->writeExpression(*t.fIfFalse, out);
2165         this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId,
2166                                out);
2167         return result;
2168     }
2169     // was originally using OpPhi to choose the result, but for some reason that is crashing on
2170     // Adreno. Switched to storing the result in a temp variable as glslang does.
2171     SpvId var = this->nextId();
2172     this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction),
2173                            var, SpvStorageClassFunction, fVariableBuffer);
2174     SpvId trueLabel = this->nextId();
2175     SpvId falseLabel = this->nextId();
2176     SpvId end = this->nextId();
2177     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2178     this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
2179     this->writeLabel(trueLabel, out);
2180     this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out);
2181     this->writeInstruction(SpvOpBranch, end, out);
2182     this->writeLabel(falseLabel, out);
2183     this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out);
2184     this->writeInstruction(SpvOpBranch, end, out);
2185     this->writeLabel(end, out);
2186     SpvId result = this->nextId();
2187     this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out);
2188     return result;
2189 }
2190 
create_literal_1(const Context & context,const Type & type)2191 std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
2192     if (type.isInteger()) {
2193         return std::unique_ptr<Expression>(new IntLiteral(context, -1, 1, &type));
2194     }
2195     else if (type.isFloat()) {
2196         return std::unique_ptr<Expression>(new FloatLiteral(context, -1, 1.0, &type));
2197     } else {
2198         ABORT("math is unsupported on type '%s'", type.name().c_str());
2199     }
2200 }
2201 
writePrefixExpression(const PrefixExpression & p,OutputStream & out)2202 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
2203     if (p.fOperator == Token::MINUS) {
2204         SpvId result = this->nextId();
2205         SpvId typeId = this->getType(p.fType);
2206         SpvId expr = this->writeExpression(*p.fOperand, out);
2207         if (is_float(fContext, p.fType)) {
2208             this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
2209         } else if (is_signed(fContext, p.fType)) {
2210             this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
2211         } else {
2212             ABORT("unsupported prefix expression %s", p.description().c_str());
2213         };
2214         return result;
2215     }
2216     switch (p.fOperator) {
2217         case Token::PLUS:
2218             return this->writeExpression(*p.fOperand, out);
2219         case Token::PLUSPLUS: {
2220             std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2221             SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2222             SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2223                                                       SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
2224                                                       out);
2225             lv->store(result, out);
2226             return result;
2227         }
2228         case Token::MINUSMINUS: {
2229             std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2230             SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2231             SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2232                                                       SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef,
2233                                                       out);
2234             lv->store(result, out);
2235             return result;
2236         }
2237         case Token::LOGICALNOT: {
2238             ASSERT(p.fOperand->fType == *fContext.fBool_Type);
2239             SpvId result = this->nextId();
2240             this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result,
2241                                    this->writeExpression(*p.fOperand, out), out);
2242             return result;
2243         }
2244         case Token::BITWISENOT: {
2245             SpvId result = this->nextId();
2246             this->writeInstruction(SpvOpNot, this->getType(p.fOperand->fType), result,
2247                                    this->writeExpression(*p.fOperand, out), out);
2248             return result;
2249         }
2250         default:
2251             ABORT("unsupported prefix expression: %s", p.description().c_str());
2252     }
2253 }
2254 
writePostfixExpression(const PostfixExpression & p,OutputStream & out)2255 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
2256     std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2257     SpvId result = lv->load(out);
2258     SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2259     switch (p.fOperator) {
2260         case Token::PLUSPLUS: {
2261             SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd,
2262                                                     SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2263             lv->store(temp, out);
2264             return result;
2265         }
2266         case Token::MINUSMINUS: {
2267             SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub,
2268                                                     SpvOpISub, SpvOpISub, SpvOpUndef, out);
2269             lv->store(temp, out);
2270             return result;
2271         }
2272         default:
2273             ABORT("unsupported postfix expression %s", p.description().c_str());
2274     }
2275 }
2276 
writeBoolLiteral(const BoolLiteral & b)2277 SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
2278     if (b.fValue) {
2279         if (fBoolTrue == 0) {
2280             fBoolTrue = this->nextId();
2281             this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue,
2282                                    fConstantBuffer);
2283         }
2284         return fBoolTrue;
2285     } else {
2286         if (fBoolFalse == 0) {
2287             fBoolFalse = this->nextId();
2288             this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse,
2289                                    fConstantBuffer);
2290         }
2291         return fBoolFalse;
2292     }
2293 }
2294 
writeIntLiteral(const IntLiteral & i)2295 SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) {
2296     if (i.fType == *fContext.fInt_Type) {
2297         auto entry = fIntConstants.find(i.fValue);
2298         if (entry == fIntConstants.end()) {
2299             SpvId result = this->nextId();
2300             this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2301                                    fConstantBuffer);
2302             fIntConstants[i.fValue] = result;
2303             return result;
2304         }
2305         return entry->second;
2306     } else {
2307         ASSERT(i.fType == *fContext.fUInt_Type);
2308         auto entry = fUIntConstants.find(i.fValue);
2309         if (entry == fUIntConstants.end()) {
2310             SpvId result = this->nextId();
2311             this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2312                                    fConstantBuffer);
2313             fUIntConstants[i.fValue] = result;
2314             return result;
2315         }
2316         return entry->second;
2317     }
2318 }
2319 
writeFloatLiteral(const FloatLiteral & f)2320 SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
2321     if (f.fType == *fContext.fFloat_Type || f.fType == *fContext.fHalf_Type) {
2322         float value = (float) f.fValue;
2323         auto entry = fFloatConstants.find(value);
2324         if (entry == fFloatConstants.end()) {
2325             SpvId result = this->nextId();
2326             uint32_t bits;
2327             ASSERT(sizeof(bits) == sizeof(value));
2328             memcpy(&bits, &value, sizeof(bits));
2329             this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits,
2330                                    fConstantBuffer);
2331             fFloatConstants[value] = result;
2332             return result;
2333         }
2334         return entry->second;
2335     } else {
2336         ASSERT(f.fType == *fContext.fDouble_Type);
2337         auto entry = fDoubleConstants.find(f.fValue);
2338         if (entry == fDoubleConstants.end()) {
2339             SpvId result = this->nextId();
2340             uint64_t bits;
2341             ASSERT(sizeof(bits) == sizeof(f.fValue));
2342             memcpy(&bits, &f.fValue, sizeof(bits));
2343             this->writeInstruction(SpvOpConstant, this->getType(f.fType), result,
2344                                    bits & 0xffffffff, bits >> 32, fConstantBuffer);
2345             fDoubleConstants[f.fValue] = result;
2346             return result;
2347         }
2348         return entry->second;
2349     }
2350 }
2351 
writeFunctionStart(const FunctionDeclaration & f,OutputStream & out)2352 SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
2353     SpvId result = fFunctionMap[&f];
2354     this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result,
2355                            SpvFunctionControlMaskNone, this->getFunctionType(f), out);
2356     this->writeInstruction(SpvOpName, result, f.fName, fNameBuffer);
2357     for (size_t i = 0; i < f.fParameters.size(); i++) {
2358         SpvId id = this->nextId();
2359         fVariableMap[f.fParameters[i]] = id;
2360         SpvId type;
2361         type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction);
2362         this->writeInstruction(SpvOpFunctionParameter, type, id, out);
2363     }
2364     return result;
2365 }
2366 
writeFunction(const FunctionDefinition & f,OutputStream & out)2367 SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
2368     fVariableBuffer.reset();
2369     SpvId result = this->writeFunctionStart(f.fDeclaration, out);
2370     this->writeLabel(this->nextId(), out);
2371     if (f.fDeclaration.fName == "main") {
2372         write_stringstream(fGlobalInitializersBuffer, out);
2373     }
2374     StringStream bodyBuffer;
2375     this->writeBlock((Block&) *f.fBody, bodyBuffer);
2376     write_stringstream(fVariableBuffer, out);
2377     write_stringstream(bodyBuffer, out);
2378     if (fCurrentBlock) {
2379         if (f.fDeclaration.fReturnType == *fContext.fVoid_Type) {
2380             this->writeInstruction(SpvOpReturn, out);
2381         } else {
2382             this->writeInstruction(SpvOpUnreachable, out);
2383         }
2384     }
2385     this->writeInstruction(SpvOpFunctionEnd, out);
2386     return result;
2387 }
2388 
writeLayout(const Layout & layout,SpvId target)2389 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) {
2390     if (layout.fLocation >= 0) {
2391         this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
2392                                fDecorationBuffer);
2393     }
2394     if (layout.fBinding >= 0) {
2395         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
2396                                fDecorationBuffer);
2397     }
2398     if (layout.fIndex >= 0) {
2399         this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
2400                                fDecorationBuffer);
2401     }
2402     if (layout.fSet >= 0) {
2403         this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
2404                                fDecorationBuffer);
2405     }
2406     if (layout.fInputAttachmentIndex >= 0) {
2407         this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
2408                                layout.fInputAttachmentIndex, fDecorationBuffer);
2409         fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
2410     }
2411     if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN &&
2412         layout.fBuiltin != SK_IN_BUILTIN) {
2413         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
2414                                fDecorationBuffer);
2415     }
2416 }
2417 
writeLayout(const Layout & layout,SpvId target,int member)2418 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) {
2419     if (layout.fLocation >= 0) {
2420         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
2421                                layout.fLocation, fDecorationBuffer);
2422     }
2423     if (layout.fBinding >= 0) {
2424         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding,
2425                                layout.fBinding, fDecorationBuffer);
2426     }
2427     if (layout.fIndex >= 0) {
2428         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
2429                                layout.fIndex, fDecorationBuffer);
2430     }
2431     if (layout.fSet >= 0) {
2432         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet,
2433                                layout.fSet, fDecorationBuffer);
2434     }
2435     if (layout.fInputAttachmentIndex >= 0) {
2436         this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
2437                                layout.fInputAttachmentIndex, fDecorationBuffer);
2438     }
2439     if (layout.fBuiltin >= 0) {
2440         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
2441                                layout.fBuiltin, fDecorationBuffer);
2442     }
2443 }
2444 
writeInterfaceBlock(const InterfaceBlock & intf)2445 SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2446     bool isBuffer = (0 != (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag));
2447     bool pushConstant = (0 != (intf.fVariable.fModifiers.fLayout.fFlags &
2448                                Layout::kPushConstant_Flag));
2449     MemoryLayout memoryLayout = (pushConstant || isBuffer) ?
2450                                 MemoryLayout(MemoryLayout::k430_Standard) :
2451                                 fDefaultLayout;
2452     SpvId result = this->nextId();
2453     const Type* type = &intf.fVariable.fType;
2454     if (fProgram.fInputs.fRTHeight) {
2455         ASSERT(fRTHeightStructId == (SpvId) -1);
2456         ASSERT(fRTHeightFieldIndex == (SpvId) -1);
2457         std::vector<Type::Field> fields = type->fields();
2458         fRTHeightStructId = result;
2459         fRTHeightFieldIndex = fields.size();
2460         fields.emplace_back(Modifiers(), StringFragment(SKSL_RTHEIGHT_NAME), fContext.fFloat_Type.get());
2461         type = new Type(type->fOffset, type->name(), fields);
2462     }
2463     SpvId typeId = this->getType(*type, memoryLayout);
2464     if (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag) {
2465         this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBufferBlock, fDecorationBuffer);
2466     } else {
2467         this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
2468     }
2469     SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers);
2470     SpvId ptrType = this->nextId();
2471     this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
2472     this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
2473     Layout layout = intf.fVariable.fModifiers.fLayout;
2474     if (intf.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag && layout.fSet == -1) {
2475         layout.fSet = 0;
2476     }
2477     this->writeLayout(layout, result);
2478     fVariableMap[&intf.fVariable] = result;
2479     if (fProgram.fInputs.fRTHeight) {
2480         delete type;
2481     }
2482     return result;
2483 }
2484 
writePrecisionModifier(const Modifiers & modifiers,SpvId id)2485 void SPIRVCodeGenerator::writePrecisionModifier(const Modifiers& modifiers, SpvId id) {
2486     if ((modifiers.fFlags & Modifiers::kLowp_Flag) |
2487         (modifiers.fFlags & Modifiers::kMediump_Flag)) {
2488         this->writeInstruction(SpvOpDecorate, id, SpvDecorationRelaxedPrecision, fDecorationBuffer);
2489     }
2490 }
2491 
2492 #define BUILTIN_IGNORE 9999
writeGlobalVars(Program::Kind kind,const VarDeclarations & decl,OutputStream & out)2493 void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclarations& decl,
2494                                          OutputStream& out) {
2495     for (size_t i = 0; i < decl.fVars.size(); i++) {
2496         if (decl.fVars[i]->fKind == Statement::kNop_Kind) {
2497             continue;
2498         }
2499         const VarDeclaration& varDecl = (VarDeclaration&) *decl.fVars[i];
2500         const Variable* var = varDecl.fVar;
2501         // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2502         // in the OpenGL backend.
2503         ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2504                                            Modifiers::kWriteOnly_Flag |
2505                                            Modifiers::kCoherent_Flag |
2506                                            Modifiers::kVolatile_Flag |
2507                                            Modifiers::kRestrict_Flag)));
2508         if (var->fModifiers.fLayout.fBuiltin == BUILTIN_IGNORE) {
2509             continue;
2510         }
2511         if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
2512             kind != Program::kFragment_Kind) {
2513             ASSERT(!fProgram.fSettings.fFragColorIsInOut);
2514             continue;
2515         }
2516         if (!var->fReadCount && !var->fWriteCount &&
2517                 !(var->fModifiers.fFlags & (Modifiers::kIn_Flag |
2518                                             Modifiers::kOut_Flag |
2519                                             Modifiers::kUniform_Flag |
2520                                             Modifiers::kBuffer_Flag))) {
2521             // variable is dead and not an input / output var (the Vulkan debug layers complain if
2522             // we elide an interface var, even if it's dead)
2523             continue;
2524         }
2525         SpvStorageClass_ storageClass;
2526         if (var->fModifiers.fFlags & Modifiers::kIn_Flag) {
2527             storageClass = SpvStorageClassInput;
2528         } else if (var->fModifiers.fFlags & Modifiers::kOut_Flag) {
2529             storageClass = SpvStorageClassOutput;
2530         } else if (var->fModifiers.fFlags & Modifiers::kUniform_Flag) {
2531             if (var->fType.kind() == Type::kSampler_Kind) {
2532                 storageClass = SpvStorageClassUniformConstant;
2533             } else {
2534                 storageClass = SpvStorageClassUniform;
2535             }
2536         } else {
2537             storageClass = SpvStorageClassPrivate;
2538         }
2539         SpvId id = this->nextId();
2540         fVariableMap[var] = id;
2541         SpvId type = this->getPointerType(var->fType, storageClass);
2542         this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer);
2543         this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
2544         this->writePrecisionModifier(var->fModifiers, id);
2545         if (varDecl.fValue) {
2546             ASSERT(!fCurrentBlock);
2547             fCurrentBlock = -1;
2548             SpvId value = this->writeExpression(*varDecl.fValue, fGlobalInitializersBuffer);
2549             this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
2550             fCurrentBlock = 0;
2551         }
2552         this->writeLayout(var->fModifiers.fLayout, id);
2553         if (var->fModifiers.fFlags & Modifiers::kFlat_Flag) {
2554             this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
2555         }
2556         if (var->fModifiers.fFlags & Modifiers::kNoPerspective_Flag) {
2557             this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
2558                                    fDecorationBuffer);
2559         }
2560     }
2561 }
2562 
writeVarDeclarations(const VarDeclarations & decl,OutputStream & out)2563 void SPIRVCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, OutputStream& out) {
2564     for (const auto& stmt : decl.fVars) {
2565         ASSERT(stmt->fKind == Statement::kVarDeclaration_Kind);
2566         VarDeclaration& varDecl = (VarDeclaration&) *stmt;
2567         const Variable* var = varDecl.fVar;
2568         // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2569         // in the OpenGL backend.
2570         ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2571                                            Modifiers::kWriteOnly_Flag |
2572                                            Modifiers::kCoherent_Flag |
2573                                            Modifiers::kVolatile_Flag |
2574                                            Modifiers::kRestrict_Flag)));
2575         SpvId id = this->nextId();
2576         fVariableMap[var] = id;
2577         SpvId type = this->getPointerType(var->fType, SpvStorageClassFunction);
2578         this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
2579         this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
2580         if (varDecl.fValue) {
2581             SpvId value = this->writeExpression(*varDecl.fValue, out);
2582             this->writeInstruction(SpvOpStore, id, value, out);
2583         }
2584     }
2585 }
2586 
writeStatement(const Statement & s,OutputStream & out)2587 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
2588     switch (s.fKind) {
2589         case Statement::kNop_Kind:
2590             break;
2591         case Statement::kBlock_Kind:
2592             this->writeBlock((Block&) s, out);
2593             break;
2594         case Statement::kExpression_Kind:
2595             this->writeExpression(*((ExpressionStatement&) s).fExpression, out);
2596             break;
2597         case Statement::kReturn_Kind:
2598             this->writeReturnStatement((ReturnStatement&) s, out);
2599             break;
2600         case Statement::kVarDeclarations_Kind:
2601             this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, out);
2602             break;
2603         case Statement::kIf_Kind:
2604             this->writeIfStatement((IfStatement&) s, out);
2605             break;
2606         case Statement::kFor_Kind:
2607             this->writeForStatement((ForStatement&) s, out);
2608             break;
2609         case Statement::kWhile_Kind:
2610             this->writeWhileStatement((WhileStatement&) s, out);
2611             break;
2612         case Statement::kDo_Kind:
2613             this->writeDoStatement((DoStatement&) s, out);
2614             break;
2615         case Statement::kSwitch_Kind:
2616             this->writeSwitchStatement((SwitchStatement&) s, out);
2617             break;
2618         case Statement::kBreak_Kind:
2619             this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
2620             break;
2621         case Statement::kContinue_Kind:
2622             this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
2623             break;
2624         case Statement::kDiscard_Kind:
2625             this->writeInstruction(SpvOpKill, out);
2626             break;
2627         default:
2628             ABORT("unsupported statement: %s", s.description().c_str());
2629     }
2630 }
2631 
writeBlock(const Block & b,OutputStream & out)2632 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
2633     for (size_t i = 0; i < b.fStatements.size(); i++) {
2634         this->writeStatement(*b.fStatements[i], out);
2635     }
2636 }
2637 
writeIfStatement(const IfStatement & stmt,OutputStream & out)2638 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
2639     SpvId test = this->writeExpression(*stmt.fTest, out);
2640     SpvId ifTrue = this->nextId();
2641     SpvId ifFalse = this->nextId();
2642     if (stmt.fIfFalse) {
2643         SpvId end = this->nextId();
2644         this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2645         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2646         this->writeLabel(ifTrue, out);
2647         this->writeStatement(*stmt.fIfTrue, out);
2648         if (fCurrentBlock) {
2649             this->writeInstruction(SpvOpBranch, end, out);
2650         }
2651         this->writeLabel(ifFalse, out);
2652         this->writeStatement(*stmt.fIfFalse, out);
2653         if (fCurrentBlock) {
2654             this->writeInstruction(SpvOpBranch, end, out);
2655         }
2656         this->writeLabel(end, out);
2657     } else {
2658         this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
2659         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2660         this->writeLabel(ifTrue, out);
2661         this->writeStatement(*stmt.fIfTrue, out);
2662         if (fCurrentBlock) {
2663             this->writeInstruction(SpvOpBranch, ifFalse, out);
2664         }
2665         this->writeLabel(ifFalse, out);
2666     }
2667 }
2668 
writeForStatement(const ForStatement & f,OutputStream & out)2669 void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
2670     if (f.fInitializer) {
2671         this->writeStatement(*f.fInitializer, out);
2672     }
2673     SpvId header = this->nextId();
2674     SpvId start = this->nextId();
2675     SpvId body = this->nextId();
2676     SpvId next = this->nextId();
2677     fContinueTarget.push(next);
2678     SpvId end = this->nextId();
2679     fBreakTarget.push(end);
2680     this->writeInstruction(SpvOpBranch, header, out);
2681     this->writeLabel(header, out);
2682     this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
2683     this->writeInstruction(SpvOpBranch, start, out);
2684     this->writeLabel(start, out);
2685     if (f.fTest) {
2686         SpvId test = this->writeExpression(*f.fTest, out);
2687         this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2688     }
2689     this->writeLabel(body, out);
2690     this->writeStatement(*f.fStatement, out);
2691     if (fCurrentBlock) {
2692         this->writeInstruction(SpvOpBranch, next, out);
2693     }
2694     this->writeLabel(next, out);
2695     if (f.fNext) {
2696         this->writeExpression(*f.fNext, out);
2697     }
2698     this->writeInstruction(SpvOpBranch, header, out);
2699     this->writeLabel(end, out);
2700     fBreakTarget.pop();
2701     fContinueTarget.pop();
2702 }
2703 
writeWhileStatement(const WhileStatement & w,OutputStream & out)2704 void SPIRVCodeGenerator::writeWhileStatement(const WhileStatement& w, OutputStream& out) {
2705     // We believe the while loop code below will work, but Skia doesn't actually use them and
2706     // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2707     // the time being, we just fail with an error due to the lack of testing. If you encounter this
2708     // message, simply remove the error call below to see whether our while loop support actually
2709     // works.
2710     fErrors.error(w.fOffset, "internal error: while loop support has been disabled in SPIR-V, "
2711                   "see SkSLSPIRVCodeGenerator.cpp for details");
2712 
2713     SpvId header = this->nextId();
2714     SpvId start = this->nextId();
2715     SpvId body = this->nextId();
2716     fContinueTarget.push(start);
2717     SpvId end = this->nextId();
2718     fBreakTarget.push(end);
2719     this->writeInstruction(SpvOpBranch, header, out);
2720     this->writeLabel(header, out);
2721     this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2722     this->writeInstruction(SpvOpBranch, start, out);
2723     this->writeLabel(start, out);
2724     SpvId test = this->writeExpression(*w.fTest, out);
2725     this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2726     this->writeLabel(body, out);
2727     this->writeStatement(*w.fStatement, out);
2728     if (fCurrentBlock) {
2729         this->writeInstruction(SpvOpBranch, start, out);
2730     }
2731     this->writeLabel(end, out);
2732     fBreakTarget.pop();
2733     fContinueTarget.pop();
2734 }
2735 
writeDoStatement(const DoStatement & d,OutputStream & out)2736 void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
2737     // We believe the do loop code below will work, but Skia doesn't actually use them and
2738     // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2739     // the time being, we just fail with an error due to the lack of testing. If you encounter this
2740     // message, simply remove the error call below to see whether our do loop support actually
2741     // works.
2742     fErrors.error(d.fOffset, "internal error: do loop support has been disabled in SPIR-V, see "
2743                   "SkSLSPIRVCodeGenerator.cpp for details");
2744 
2745     SpvId header = this->nextId();
2746     SpvId start = this->nextId();
2747     SpvId next = this->nextId();
2748     fContinueTarget.push(next);
2749     SpvId end = this->nextId();
2750     fBreakTarget.push(end);
2751     this->writeInstruction(SpvOpBranch, header, out);
2752     this->writeLabel(header, out);
2753     this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2754     this->writeInstruction(SpvOpBranch, start, out);
2755     this->writeLabel(start, out);
2756     this->writeStatement(*d.fStatement, out);
2757     if (fCurrentBlock) {
2758         this->writeInstruction(SpvOpBranch, next, out);
2759     }
2760     this->writeLabel(next, out);
2761     SpvId test = this->writeExpression(*d.fTest, out);
2762     this->writeInstruction(SpvOpBranchConditional, test, start, end, out);
2763     this->writeLabel(end, out);
2764     fBreakTarget.pop();
2765     fContinueTarget.pop();
2766 }
2767 
writeSwitchStatement(const SwitchStatement & s,OutputStream & out)2768 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
2769     SpvId value = this->writeExpression(*s.fValue, out);
2770     std::vector<SpvId> labels;
2771     SpvId end = this->nextId();
2772     SpvId defaultLabel = end;
2773     fBreakTarget.push(end);
2774     int size = 3;
2775     for (const auto& c : s.fCases) {
2776         SpvId label = this->nextId();
2777         labels.push_back(label);
2778         if (c->fValue) {
2779             size += 2;
2780         } else {
2781             defaultLabel = label;
2782         }
2783     }
2784     labels.push_back(end);
2785     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2786     this->writeOpCode(SpvOpSwitch, size, out);
2787     this->writeWord(value, out);
2788     this->writeWord(defaultLabel, out);
2789     for (size_t i = 0; i < s.fCases.size(); ++i) {
2790         if (!s.fCases[i]->fValue) {
2791             continue;
2792         }
2793         ASSERT(s.fCases[i]->fValue->fKind == Expression::kIntLiteral_Kind);
2794         this->writeWord(((IntLiteral&) *s.fCases[i]->fValue).fValue, out);
2795         this->writeWord(labels[i], out);
2796     }
2797     for (size_t i = 0; i < s.fCases.size(); ++i) {
2798         this->writeLabel(labels[i], out);
2799         for (const auto& stmt : s.fCases[i]->fStatements) {
2800             this->writeStatement(*stmt, out);
2801         }
2802         if (fCurrentBlock) {
2803             this->writeInstruction(SpvOpBranch, labels[i + 1], out);
2804         }
2805     }
2806     this->writeLabel(end, out);
2807     fBreakTarget.pop();
2808 }
2809 
writeReturnStatement(const ReturnStatement & r,OutputStream & out)2810 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
2811     if (r.fExpression) {
2812         this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out),
2813                                out);
2814     } else {
2815         this->writeInstruction(SpvOpReturn, out);
2816     }
2817 }
2818 
writeGeometryShaderExecutionMode(SpvId entryPoint,OutputStream & out)2819 void SPIRVCodeGenerator::writeGeometryShaderExecutionMode(SpvId entryPoint, OutputStream& out) {
2820     ASSERT(fProgram.fKind == Program::kGeometry_Kind);
2821     int invocations = 1;
2822     for (size_t i = 0; i < fProgram.fElements.size(); i++) {
2823         if (fProgram.fElements[i]->fKind == ProgramElement::kModifiers_Kind) {
2824             const Modifiers& m = ((ModifiersDeclaration&) *fProgram.fElements[i]).fModifiers;
2825             if (m.fFlags & Modifiers::kIn_Flag) {
2826                 if (m.fLayout.fInvocations != -1) {
2827                     invocations = m.fLayout.fInvocations;
2828                 }
2829                 SpvId input;
2830                 switch (m.fLayout.fPrimitive) {
2831                     case Layout::kPoints_Primitive:
2832                         input = SpvExecutionModeInputPoints;
2833                         break;
2834                     case Layout::kLines_Primitive:
2835                         input = SpvExecutionModeInputLines;
2836                         break;
2837                     case Layout::kLinesAdjacency_Primitive:
2838                         input = SpvExecutionModeInputLinesAdjacency;
2839                         break;
2840                     case Layout::kTriangles_Primitive:
2841                         input = SpvExecutionModeTriangles;
2842                         break;
2843                     case Layout::kTrianglesAdjacency_Primitive:
2844                         input = SpvExecutionModeInputTrianglesAdjacency;
2845                         break;
2846                     default:
2847                         input = 0;
2848                         break;
2849                 }
2850                 if (input) {
2851                     this->writeInstruction(SpvOpExecutionMode, entryPoint, input, out);
2852                 }
2853             } else if (m.fFlags & Modifiers::kOut_Flag) {
2854                 SpvId output;
2855                 switch (m.fLayout.fPrimitive) {
2856                     case Layout::kPoints_Primitive:
2857                         output = SpvExecutionModeOutputPoints;
2858                         break;
2859                     case Layout::kLineStrip_Primitive:
2860                         output = SpvExecutionModeOutputLineStrip;
2861                         break;
2862                     case Layout::kTriangleStrip_Primitive:
2863                         output = SpvExecutionModeOutputTriangleStrip;
2864                         break;
2865                     default:
2866                         output = 0;
2867                         break;
2868                 }
2869                 if (output) {
2870                     this->writeInstruction(SpvOpExecutionMode, entryPoint, output, out);
2871                 }
2872                 if (m.fLayout.fMaxVertices != -1) {
2873                     this->writeInstruction(SpvOpExecutionMode, entryPoint,
2874                                            SpvExecutionModeOutputVertices, m.fLayout.fMaxVertices,
2875                                            out);
2876                 }
2877             }
2878         }
2879     }
2880     this->writeInstruction(SpvOpExecutionMode, entryPoint, SpvExecutionModeInvocations,
2881                            invocations, out);
2882 }
2883 
writeInstructions(const Program & program,OutputStream & out)2884 void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
2885     fGLSLExtendedInstructions = this->nextId();
2886     StringStream body;
2887     std::set<SpvId> interfaceVars;
2888     // assign IDs to functions, determine sk_in size
2889     int skInSize = -1;
2890     for (size_t i = 0; i < program.fElements.size(); i++) {
2891         switch (program.fElements[i]->fKind) {
2892             case ProgramElement::kFunction_Kind: {
2893                 FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i];
2894                 fFunctionMap[&f.fDeclaration] = this->nextId();
2895                 break;
2896             }
2897             case ProgramElement::kModifiers_Kind: {
2898                 Modifiers& m = ((ModifiersDeclaration&) *program.fElements[i]).fModifiers;
2899                 if (m.fFlags & Modifiers::kIn_Flag) {
2900                     switch (m.fLayout.fPrimitive) {
2901                         case Layout::kPoints_Primitive: // break
2902                         case Layout::kLines_Primitive:
2903                             skInSize = 1;
2904                             break;
2905                         case Layout::kLinesAdjacency_Primitive: // break
2906                             skInSize = 2;
2907                             break;
2908                         case Layout::kTriangles_Primitive: // break
2909                         case Layout::kTrianglesAdjacency_Primitive:
2910                             skInSize = 3;
2911                             break;
2912                         default:
2913                             break;
2914                     }
2915                 }
2916                 break;
2917             }
2918             default:
2919                 break;
2920         }
2921     }
2922     for (size_t i = 0; i < program.fElements.size(); i++) {
2923         if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) {
2924             InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i];
2925             if (SK_IN_BUILTIN == intf.fVariable.fModifiers.fLayout.fBuiltin) {
2926                 ASSERT(skInSize != -1);
2927                 intf.fSizes.emplace_back(new IntLiteral(fContext, -1, skInSize));
2928             }
2929             SpvId id = this->writeInterfaceBlock(intf);
2930             if ((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) ||
2931                 (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) {
2932                 interfaceVars.insert(id);
2933             }
2934         }
2935     }
2936     for (size_t i = 0; i < program.fElements.size(); i++) {
2937         if (program.fElements[i]->fKind == ProgramElement::kVar_Kind) {
2938             this->writeGlobalVars(program.fKind, ((VarDeclarations&) *program.fElements[i]),
2939                                   body);
2940         }
2941     }
2942     for (size_t i = 0; i < program.fElements.size(); i++) {
2943         if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
2944             this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body);
2945         }
2946     }
2947     const FunctionDeclaration* main = nullptr;
2948     for (auto entry : fFunctionMap) {
2949         if (entry.first->fName == "main") {
2950             main = entry.first;
2951         }
2952     }
2953     ASSERT(main);
2954     for (auto entry : fVariableMap) {
2955         const Variable* var = entry.first;
2956         if (var->fStorage == Variable::kGlobal_Storage &&
2957                 ((var->fModifiers.fFlags & Modifiers::kIn_Flag) ||
2958                  (var->fModifiers.fFlags & Modifiers::kOut_Flag))) {
2959             interfaceVars.insert(entry.second);
2960         }
2961     }
2962     this->writeCapabilities(out);
2963     this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
2964     this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
2965     this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (main->fName.fLength + 4) / 4) +
2966                       (int32_t) interfaceVars.size(), out);
2967     switch (program.fKind) {
2968         case Program::kVertex_Kind:
2969             this->writeWord(SpvExecutionModelVertex, out);
2970             break;
2971         case Program::kFragment_Kind:
2972             this->writeWord(SpvExecutionModelFragment, out);
2973             break;
2974         case Program::kGeometry_Kind:
2975             this->writeWord(SpvExecutionModelGeometry, out);
2976             break;
2977         default:
2978             ABORT("cannot write this kind of program to SPIR-V\n");
2979     }
2980     SpvId entryPoint = fFunctionMap[main];
2981     this->writeWord(entryPoint, out);
2982     this->writeString(main->fName.fChars, main->fName.fLength, out);
2983     for (int var : interfaceVars) {
2984         this->writeWord(var, out);
2985     }
2986     if (program.fKind == Program::kGeometry_Kind) {
2987         this->writeGeometryShaderExecutionMode(entryPoint, out);
2988     }
2989     if (program.fKind == Program::kFragment_Kind) {
2990         this->writeInstruction(SpvOpExecutionMode,
2991                                fFunctionMap[main],
2992                                SpvExecutionModeOriginUpperLeft,
2993                                out);
2994     }
2995     for (size_t i = 0; i < program.fElements.size(); i++) {
2996         if (program.fElements[i]->fKind == ProgramElement::kExtension_Kind) {
2997             this->writeInstruction(SpvOpSourceExtension,
2998                                    ((Extension&) *program.fElements[i]).fName.c_str(),
2999                                    out);
3000         }
3001     }
3002 
3003     write_stringstream(fExtraGlobalsBuffer, out);
3004     write_stringstream(fNameBuffer, out);
3005     write_stringstream(fDecorationBuffer, out);
3006     write_stringstream(fConstantBuffer, out);
3007     write_stringstream(fExternalFunctionsBuffer, out);
3008     write_stringstream(body, out);
3009 }
3010 
generateCode()3011 bool SPIRVCodeGenerator::generateCode() {
3012     ASSERT(!fErrors.errorCount());
3013     this->writeWord(SpvMagicNumber, *fOut);
3014     this->writeWord(SpvVersion, *fOut);
3015     this->writeWord(SKSL_MAGIC, *fOut);
3016     StringStream buffer;
3017     this->writeInstructions(fProgram, buffer);
3018     this->writeWord(fIdCount, *fOut);
3019     this->writeWord(0, *fOut); // reserved, always zero
3020     write_stringstream(buffer, *fOut);
3021     return 0 == fErrors.errorCount();
3022 }
3023 
3024 }
3025