• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "SkSLSPIRVCodeGenerator.h"
9 
10 #include "GLSL.std.450.h"
11 
12 #include "ir/SkSLExpressionStatement.h"
13 #include "ir/SkSLExtension.h"
14 #include "ir/SkSLIndexExpression.h"
15 #include "ir/SkSLVariableReference.h"
16 #include "SkSLCompiler.h"
17 
18 namespace SkSL {
19 
20 #define SPIRV_DEBUG 0
21 
22 static const int32_t SKSL_MAGIC  = 0x0; // FIXME: we should probably register a magic number
23 
setupIntrinsics()24 void SPIRVCodeGenerator::setupIntrinsics() {
25 #define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \
26                                     GLSLstd450 ## x, GLSLstd450 ## x)
27 #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \
28                                                              GLSLstd450 ## ifFloat, \
29                                                              GLSLstd450 ## ifInt, \
30                                                              GLSLstd450 ## ifUInt, \
31                                                              SpvOpUndef)
32 #define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \
33                                    k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
34                                    k ## x ## _SpecialIntrinsic)
35     fIntrinsicMap[String("round")]         = ALL_GLSL(Round);
36     fIntrinsicMap[String("roundEven")]     = ALL_GLSL(RoundEven);
37     fIntrinsicMap[String("trunc")]         = ALL_GLSL(Trunc);
38     fIntrinsicMap[String("abs")]           = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
39     fIntrinsicMap[String("sign")]          = BY_TYPE_GLSL(FSign, SSign, SSign);
40     fIntrinsicMap[String("floor")]         = ALL_GLSL(Floor);
41     fIntrinsicMap[String("ceil")]          = ALL_GLSL(Ceil);
42     fIntrinsicMap[String("fract")]         = ALL_GLSL(Fract);
43     fIntrinsicMap[String("radians")]       = ALL_GLSL(Radians);
44     fIntrinsicMap[String("degrees")]       = ALL_GLSL(Degrees);
45     fIntrinsicMap[String("sin")]           = ALL_GLSL(Sin);
46     fIntrinsicMap[String("cos")]           = ALL_GLSL(Cos);
47     fIntrinsicMap[String("tan")]           = ALL_GLSL(Tan);
48     fIntrinsicMap[String("asin")]          = ALL_GLSL(Asin);
49     fIntrinsicMap[String("acos")]          = ALL_GLSL(Acos);
50     fIntrinsicMap[String("atan")]          = SPECIAL(Atan);
51     fIntrinsicMap[String("sinh")]          = ALL_GLSL(Sinh);
52     fIntrinsicMap[String("cosh")]          = ALL_GLSL(Cosh);
53     fIntrinsicMap[String("tanh")]          = ALL_GLSL(Tanh);
54     fIntrinsicMap[String("asinh")]         = ALL_GLSL(Asinh);
55     fIntrinsicMap[String("acosh")]         = ALL_GLSL(Acosh);
56     fIntrinsicMap[String("atanh")]         = ALL_GLSL(Atanh);
57     fIntrinsicMap[String("pow")]           = ALL_GLSL(Pow);
58     fIntrinsicMap[String("exp")]           = ALL_GLSL(Exp);
59     fIntrinsicMap[String("log")]           = ALL_GLSL(Log);
60     fIntrinsicMap[String("exp2")]          = ALL_GLSL(Exp2);
61     fIntrinsicMap[String("log2")]          = ALL_GLSL(Log2);
62     fIntrinsicMap[String("sqrt")]          = ALL_GLSL(Sqrt);
63     fIntrinsicMap[String("inversesqrt")]   = ALL_GLSL(InverseSqrt);
64     fIntrinsicMap[String("determinant")]   = ALL_GLSL(Determinant);
65     fIntrinsicMap[String("matrixInverse")] = ALL_GLSL(MatrixInverse);
66     fIntrinsicMap[String("mod")]           = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpFMod,
67                                                              SpvOpSMod, SpvOpUMod, SpvOpUndef);
68     fIntrinsicMap[String("min")]           = BY_TYPE_GLSL(FMin, SMin, UMin);
69     fIntrinsicMap[String("max")]           = BY_TYPE_GLSL(FMax, SMax, UMax);
70     fIntrinsicMap[String("clamp")]         = BY_TYPE_GLSL(FClamp, SClamp, UClamp);
71     fIntrinsicMap[String("dot")]           = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot,
72                                                              SpvOpUndef, SpvOpUndef, SpvOpUndef);
73     fIntrinsicMap[String("mix")]           = ALL_GLSL(FMix);
74     fIntrinsicMap[String("step")]          = ALL_GLSL(Step);
75     fIntrinsicMap[String("smoothstep")]    = ALL_GLSL(SmoothStep);
76     fIntrinsicMap[String("fma")]           = ALL_GLSL(Fma);
77     fIntrinsicMap[String("frexp")]         = ALL_GLSL(Frexp);
78     fIntrinsicMap[String("ldexp")]         = ALL_GLSL(Ldexp);
79 
80 #define PACK(type) fIntrinsicMap[String("pack" #type)] = ALL_GLSL(Pack ## type); \
81                    fIntrinsicMap[String("unpack" #type)] = ALL_GLSL(Unpack ## type)
82     PACK(Snorm4x8);
83     PACK(Unorm4x8);
84     PACK(Snorm2x16);
85     PACK(Unorm2x16);
86     PACK(Half2x16);
87     PACK(Double2x32);
88     fIntrinsicMap[String("length")]      = ALL_GLSL(Length);
89     fIntrinsicMap[String("distance")]    = ALL_GLSL(Distance);
90     fIntrinsicMap[String("cross")]       = ALL_GLSL(Cross);
91     fIntrinsicMap[String("normalize")]   = ALL_GLSL(Normalize);
92     fIntrinsicMap[String("faceForward")] = ALL_GLSL(FaceForward);
93     fIntrinsicMap[String("reflect")]     = ALL_GLSL(Reflect);
94     fIntrinsicMap[String("refract")]     = ALL_GLSL(Refract);
95     fIntrinsicMap[String("findLSB")]     = ALL_GLSL(FindILsb);
96     fIntrinsicMap[String("findMSB")]     = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
97     fIntrinsicMap[String("dFdx")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx,
98                                                            SpvOpUndef, SpvOpUndef, SpvOpUndef);
99     fIntrinsicMap[String("dFdy")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
100                                                            SpvOpUndef, SpvOpUndef, SpvOpUndef);
101     fIntrinsicMap[String("dFdy")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
102                                                            SpvOpUndef, SpvOpUndef, SpvOpUndef);
103     fIntrinsicMap[String("texture")]     = SPECIAL(Texture);
104     fIntrinsicMap[String("texelFetch")]  = SPECIAL(TexelFetch);
105     fIntrinsicMap[String("subpassLoad")] = SPECIAL(SubpassLoad);
106 
107     fIntrinsicMap[String("any")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
108                                                                 SpvOpUndef, SpvOpUndef, SpvOpAny);
109     fIntrinsicMap[String("all")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
110                                                                 SpvOpUndef, SpvOpUndef, SpvOpAll);
111     fIntrinsicMap[String("equal")]            = std::make_tuple(kSPIRV_IntrinsicKind,
112                                                                 SpvOpFOrdEqual, SpvOpIEqual,
113                                                                 SpvOpIEqual, SpvOpLogicalEqual);
114     fIntrinsicMap[String("notEqual")]         = std::make_tuple(kSPIRV_IntrinsicKind,
115                                                                 SpvOpFOrdNotEqual, SpvOpINotEqual,
116                                                                 SpvOpINotEqual,
117                                                                 SpvOpLogicalNotEqual);
118     fIntrinsicMap[String("lessThan")]         = std::make_tuple(kSPIRV_IntrinsicKind,
119                                                                 SpvOpSLessThan, SpvOpULessThan,
120                                                                 SpvOpFOrdLessThan, SpvOpUndef);
121     fIntrinsicMap[String("lessThanEqual")]    = std::make_tuple(kSPIRV_IntrinsicKind,
122                                                                 SpvOpSLessThanEqual,
123                                                                 SpvOpULessThanEqual,
124                                                                 SpvOpFOrdLessThanEqual,
125                                                                 SpvOpUndef);
126     fIntrinsicMap[String("greaterThan")]      = std::make_tuple(kSPIRV_IntrinsicKind,
127                                                                 SpvOpSGreaterThan,
128                                                                 SpvOpUGreaterThan,
129                                                                 SpvOpFOrdGreaterThan,
130                                                                 SpvOpUndef);
131     fIntrinsicMap[String("greaterThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
132                                                                 SpvOpSGreaterThanEqual,
133                                                                 SpvOpUGreaterThanEqual,
134                                                                 SpvOpFOrdGreaterThanEqual,
135                                                                 SpvOpUndef);
136 // interpolateAt* not yet supported...
137 }
138 
writeWord(int32_t word,OutputStream & out)139 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
140 #if SPIRV_DEBUG
141     out << "(" << word << ") ";
142 #else
143     out.write((const char*) &word, sizeof(word));
144 #endif
145 }
146 
is_float(const Context & context,const Type & type)147 static bool is_float(const Context& context, const Type& type) {
148     if (type.kind() == Type::kVector_Kind) {
149         return is_float(context, type.componentType());
150     }
151     return type == *context.fFloat_Type || type == *context.fDouble_Type;
152 }
153 
is_signed(const Context & context,const Type & type)154 static bool is_signed(const Context& context, const Type& type) {
155     if (type.kind() == Type::kVector_Kind) {
156         return is_signed(context, type.componentType());
157     }
158     return type == *context.fInt_Type;
159 }
160 
is_unsigned(const Context & context,const Type & type)161 static bool is_unsigned(const Context& context, const Type& type) {
162     if (type.kind() == Type::kVector_Kind) {
163         return is_unsigned(context, type.componentType());
164     }
165     return type == *context.fUInt_Type;
166 }
167 
is_bool(const Context & context,const Type & type)168 static bool is_bool(const Context& context, const Type& type) {
169     if (type.kind() == Type::kVector_Kind) {
170         return is_bool(context, type.componentType());
171     }
172     return type == *context.fBool_Type;
173 }
174 
is_out(const Variable & var)175 static bool is_out(const Variable& var) {
176     return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
177 }
178 
179 #if SPIRV_DEBUG
opcode_text(SpvOp_ opCode)180 static String opcode_text(SpvOp_ opCode) {
181     switch (opCode) {
182         case SpvOpNop:
183             return String("Nop");
184         case SpvOpUndef:
185             return String("Undef");
186         case SpvOpSourceContinued:
187             return String("SourceContinued");
188         case SpvOpSource:
189             return String("Source");
190         case SpvOpSourceExtension:
191             return String("SourceExtension");
192         case SpvOpName:
193             return String("Name");
194         case SpvOpMemberName:
195             return String("MemberName");
196         case SpvOpString:
197             return String("String");
198         case SpvOpLine:
199             return String("Line");
200         case SpvOpExtension:
201             return String("Extension");
202         case SpvOpExtInstImport:
203             return String("ExtInstImport");
204         case SpvOpExtInst:
205             return String("ExtInst");
206         case SpvOpMemoryModel:
207             return String("MemoryModel");
208         case SpvOpEntryPoint:
209             return String("EntryPoint");
210         case SpvOpExecutionMode:
211             return String("ExecutionMode");
212         case SpvOpCapability:
213             return String("Capability");
214         case SpvOpTypeVoid:
215             return String("TypeVoid");
216         case SpvOpTypeBool:
217             return String("TypeBool");
218         case SpvOpTypeInt:
219             return String("TypeInt");
220         case SpvOpTypeFloat:
221             return String("TypeFloat");
222         case SpvOpTypeVector:
223             return String("TypeVector");
224         case SpvOpTypeMatrix:
225             return String("TypeMatrix");
226         case SpvOpTypeImage:
227             return String("TypeImage");
228         case SpvOpTypeSampler:
229             return String("TypeSampler");
230         case SpvOpTypeSampledImage:
231             return String("TypeSampledImage");
232         case SpvOpTypeArray:
233             return String("TypeArray");
234         case SpvOpTypeRuntimeArray:
235             return String("TypeRuntimeArray");
236         case SpvOpTypeStruct:
237             return String("TypeStruct");
238         case SpvOpTypeOpaque:
239             return String("TypeOpaque");
240         case SpvOpTypePointer:
241             return String("TypePointer");
242         case SpvOpTypeFunction:
243             return String("TypeFunction");
244         case SpvOpTypeEvent:
245             return String("TypeEvent");
246         case SpvOpTypeDeviceEvent:
247             return String("TypeDeviceEvent");
248         case SpvOpTypeReserveId:
249             return String("TypeReserveId");
250         case SpvOpTypeQueue:
251             return String("TypeQueue");
252         case SpvOpTypePipe:
253             return String("TypePipe");
254         case SpvOpTypeForwardPointer:
255             return String("TypeForwardPointer");
256         case SpvOpConstantTrue:
257             return String("ConstantTrue");
258         case SpvOpConstantFalse:
259             return String("ConstantFalse");
260         case SpvOpConstant:
261             return String("Constant");
262         case SpvOpConstantComposite:
263             return String("ConstantComposite");
264         case SpvOpConstantSampler:
265             return String("ConstantSampler");
266         case SpvOpConstantNull:
267             return String("ConstantNull");
268         case SpvOpSpecConstantTrue:
269             return String("SpecConstantTrue");
270         case SpvOpSpecConstantFalse:
271             return String("SpecConstantFalse");
272         case SpvOpSpecConstant:
273             return String("SpecConstant");
274         case SpvOpSpecConstantComposite:
275             return String("SpecConstantComposite");
276         case SpvOpSpecConstantOp:
277             return String("SpecConstantOp");
278         case SpvOpFunction:
279             return String("Function");
280         case SpvOpFunctionParameter:
281             return String("FunctionParameter");
282         case SpvOpFunctionEnd:
283             return String("FunctionEnd");
284         case SpvOpFunctionCall:
285             return String("FunctionCall");
286         case SpvOpVariable:
287             return String("Variable");
288         case SpvOpImageTexelPointer:
289             return String("ImageTexelPointer");
290         case SpvOpLoad:
291             return String("Load");
292         case SpvOpStore:
293             return String("Store");
294         case SpvOpCopyMemory:
295             return String("CopyMemory");
296         case SpvOpCopyMemorySized:
297             return String("CopyMemorySized");
298         case SpvOpAccessChain:
299             return String("AccessChain");
300         case SpvOpInBoundsAccessChain:
301             return String("InBoundsAccessChain");
302         case SpvOpPtrAccessChain:
303             return String("PtrAccessChain");
304         case SpvOpArrayLength:
305             return String("ArrayLength");
306         case SpvOpGenericPtrMemSemantics:
307             return String("GenericPtrMemSemantics");
308         case SpvOpInBoundsPtrAccessChain:
309             return String("InBoundsPtrAccessChain");
310         case SpvOpDecorate:
311             return String("Decorate");
312         case SpvOpMemberDecorate:
313             return String("MemberDecorate");
314         case SpvOpDecorationGroup:
315             return String("DecorationGroup");
316         case SpvOpGroupDecorate:
317             return String("GroupDecorate");
318         case SpvOpGroupMemberDecorate:
319             return String("GroupMemberDecorate");
320         case SpvOpVectorExtractDynamic:
321             return String("VectorExtractDynamic");
322         case SpvOpVectorInsertDynamic:
323             return String("VectorInsertDynamic");
324         case SpvOpVectorShuffle:
325             return String("VectorShuffle");
326         case SpvOpCompositeConstruct:
327             return String("CompositeConstruct");
328         case SpvOpCompositeExtract:
329             return String("CompositeExtract");
330         case SpvOpCompositeInsert:
331             return String("CompositeInsert");
332         case SpvOpCopyObject:
333             return String("CopyObject");
334         case SpvOpTranspose:
335             return String("Transpose");
336         case SpvOpSampledImage:
337             return String("SampledImage");
338         case SpvOpImageSampleImplicitLod:
339             return String("ImageSampleImplicitLod");
340         case SpvOpImageSampleExplicitLod:
341             return String("ImageSampleExplicitLod");
342         case SpvOpImageSampleDrefImplicitLod:
343             return String("ImageSampleDrefImplicitLod");
344         case SpvOpImageSampleDrefExplicitLod:
345             return String("ImageSampleDrefExplicitLod");
346         case SpvOpImageSampleProjImplicitLod:
347             return String("ImageSampleProjImplicitLod");
348         case SpvOpImageSampleProjExplicitLod:
349             return String("ImageSampleProjExplicitLod");
350         case SpvOpImageSampleProjDrefImplicitLod:
351             return String("ImageSampleProjDrefImplicitLod");
352         case SpvOpImageSampleProjDrefExplicitLod:
353             return String("ImageSampleProjDrefExplicitLod");
354         case SpvOpImageFetch:
355             return String("ImageFetch");
356         case SpvOpImageGather:
357             return String("ImageGather");
358         case SpvOpImageDrefGather:
359             return String("ImageDrefGather");
360         case SpvOpImageRead:
361             return String("ImageRead");
362         case SpvOpImageWrite:
363             return String("ImageWrite");
364         case SpvOpImage:
365             return String("Image");
366         case SpvOpImageQueryFormat:
367             return String("ImageQueryFormat");
368         case SpvOpImageQueryOrder:
369             return String("ImageQueryOrder");
370         case SpvOpImageQuerySizeLod:
371             return String("ImageQuerySizeLod");
372         case SpvOpImageQuerySize:
373             return String("ImageQuerySize");
374         case SpvOpImageQueryLod:
375             return String("ImageQueryLod");
376         case SpvOpImageQueryLevels:
377             return String("ImageQueryLevels");
378         case SpvOpImageQuerySamples:
379             return String("ImageQuerySamples");
380         case SpvOpConvertFToU:
381             return String("ConvertFToU");
382         case SpvOpConvertFToS:
383             return String("ConvertFToS");
384         case SpvOpConvertSToF:
385             return String("ConvertSToF");
386         case SpvOpConvertUToF:
387             return String("ConvertUToF");
388         case SpvOpUConvert:
389             return String("UConvert");
390         case SpvOpSConvert:
391             return String("SConvert");
392         case SpvOpFConvert:
393             return String("FConvert");
394         case SpvOpQuantizeToF16:
395             return String("QuantizeToF16");
396         case SpvOpConvertPtrToU:
397             return String("ConvertPtrToU");
398         case SpvOpSatConvertSToU:
399             return String("SatConvertSToU");
400         case SpvOpSatConvertUToS:
401             return String("SatConvertUToS");
402         case SpvOpConvertUToPtr:
403             return String("ConvertUToPtr");
404         case SpvOpPtrCastToGeneric:
405             return String("PtrCastToGeneric");
406         case SpvOpGenericCastToPtr:
407             return String("GenericCastToPtr");
408         case SpvOpGenericCastToPtrExplicit:
409             return String("GenericCastToPtrExplicit");
410         case SpvOpBitcast:
411             return String("Bitcast");
412         case SpvOpSNegate:
413             return String("SNegate");
414         case SpvOpFNegate:
415             return String("FNegate");
416         case SpvOpIAdd:
417             return String("IAdd");
418         case SpvOpFAdd:
419             return String("FAdd");
420         case SpvOpISub:
421             return String("ISub");
422         case SpvOpFSub:
423             return String("FSub");
424         case SpvOpIMul:
425             return String("IMul");
426         case SpvOpFMul:
427             return String("FMul");
428         case SpvOpUDiv:
429             return String("UDiv");
430         case SpvOpSDiv:
431             return String("SDiv");
432         case SpvOpFDiv:
433             return String("FDiv");
434         case SpvOpUMod:
435             return String("UMod");
436         case SpvOpSRem:
437             return String("SRem");
438         case SpvOpSMod:
439             return String("SMod");
440         case SpvOpFRem:
441             return String("FRem");
442         case SpvOpFMod:
443             return String("FMod");
444         case SpvOpVectorTimesScalar:
445             return String("VectorTimesScalar");
446         case SpvOpMatrixTimesScalar:
447             return String("MatrixTimesScalar");
448         case SpvOpVectorTimesMatrix:
449             return String("VectorTimesMatrix");
450         case SpvOpMatrixTimesVector:
451             return String("MatrixTimesVector");
452         case SpvOpMatrixTimesMatrix:
453             return String("MatrixTimesMatrix");
454         case SpvOpOuterProduct:
455             return String("OuterProduct");
456         case SpvOpDot:
457             return String("Dot");
458         case SpvOpIAddCarry:
459             return String("IAddCarry");
460         case SpvOpISubBorrow:
461             return String("ISubBorrow");
462         case SpvOpUMulExtended:
463             return String("UMulExtended");
464         case SpvOpSMulExtended:
465             return String("SMulExtended");
466         case SpvOpAny:
467             return String("Any");
468         case SpvOpAll:
469             return String("All");
470         case SpvOpIsNan:
471             return String("IsNan");
472         case SpvOpIsInf:
473             return String("IsInf");
474         case SpvOpIsFinite:
475             return String("IsFinite");
476         case SpvOpIsNormal:
477             return String("IsNormal");
478         case SpvOpSignBitSet:
479             return String("SignBitSet");
480         case SpvOpLessOrGreater:
481             return String("LessOrGreater");
482         case SpvOpOrdered:
483             return String("Ordered");
484         case SpvOpUnordered:
485             return String("Unordered");
486         case SpvOpLogicalEqual:
487             return String("LogicalEqual");
488         case SpvOpLogicalNotEqual:
489             return String("LogicalNotEqual");
490         case SpvOpLogicalOr:
491             return String("LogicalOr");
492         case SpvOpLogicalAnd:
493             return String("LogicalAnd");
494         case SpvOpLogicalNot:
495             return String("LogicalNot");
496         case SpvOpSelect:
497             return String("Select");
498         case SpvOpIEqual:
499             return String("IEqual");
500         case SpvOpINotEqual:
501             return String("INotEqual");
502         case SpvOpUGreaterThan:
503             return String("UGreaterThan");
504         case SpvOpSGreaterThan:
505             return String("SGreaterThan");
506         case SpvOpUGreaterThanEqual:
507             return String("UGreaterThanEqual");
508         case SpvOpSGreaterThanEqual:
509             return String("SGreaterThanEqual");
510         case SpvOpULessThan:
511             return String("ULessThan");
512         case SpvOpSLessThan:
513             return String("SLessThan");
514         case SpvOpULessThanEqual:
515             return String("ULessThanEqual");
516         case SpvOpSLessThanEqual:
517             return String("SLessThanEqual");
518         case SpvOpFOrdEqual:
519             return String("FOrdEqual");
520         case SpvOpFUnordEqual:
521             return String("FUnordEqual");
522         case SpvOpFOrdNotEqual:
523             return String("FOrdNotEqual");
524         case SpvOpFUnordNotEqual:
525             return String("FUnordNotEqual");
526         case SpvOpFOrdLessThan:
527             return String("FOrdLessThan");
528         case SpvOpFUnordLessThan:
529             return String("FUnordLessThan");
530         case SpvOpFOrdGreaterThan:
531             return String("FOrdGreaterThan");
532         case SpvOpFUnordGreaterThan:
533             return String("FUnordGreaterThan");
534         case SpvOpFOrdLessThanEqual:
535             return String("FOrdLessThanEqual");
536         case SpvOpFUnordLessThanEqual:
537             return String("FUnordLessThanEqual");
538         case SpvOpFOrdGreaterThanEqual:
539             return String("FOrdGreaterThanEqual");
540         case SpvOpFUnordGreaterThanEqual:
541             return String("FUnordGreaterThanEqual");
542         case SpvOpShiftRightLogical:
543             return String("ShiftRightLogical");
544         case SpvOpShiftRightArithmetic:
545             return String("ShiftRightArithmetic");
546         case SpvOpShiftLeftLogical:
547             return String("ShiftLeftLogical");
548         case SpvOpBitwiseOr:
549             return String("BitwiseOr");
550         case SpvOpBitwiseXor:
551             return String("BitwiseXor");
552         case SpvOpBitwiseAnd:
553             return String("BitwiseAnd");
554         case SpvOpNot:
555             return String("Not");
556         case SpvOpBitFieldInsert:
557             return String("BitFieldInsert");
558         case SpvOpBitFieldSExtract:
559             return String("BitFieldSExtract");
560         case SpvOpBitFieldUExtract:
561             return String("BitFieldUExtract");
562         case SpvOpBitReverse:
563             return String("BitReverse");
564         case SpvOpBitCount:
565             return String("BitCount");
566         case SpvOpDPdx:
567             return String("DPdx");
568         case SpvOpDPdy:
569             return String("DPdy");
570         case SpvOpFwidth:
571             return String("Fwidth");
572         case SpvOpDPdxFine:
573             return String("DPdxFine");
574         case SpvOpDPdyFine:
575             return String("DPdyFine");
576         case SpvOpFwidthFine:
577             return String("FwidthFine");
578         case SpvOpDPdxCoarse:
579             return String("DPdxCoarse");
580         case SpvOpDPdyCoarse:
581             return String("DPdyCoarse");
582         case SpvOpFwidthCoarse:
583             return String("FwidthCoarse");
584         case SpvOpEmitVertex:
585             return String("EmitVertex");
586         case SpvOpEndPrimitive:
587             return String("EndPrimitive");
588         case SpvOpEmitStreamVertex:
589             return String("EmitStreamVertex");
590         case SpvOpEndStreamPrimitive:
591             return String("EndStreamPrimitive");
592         case SpvOpControlBarrier:
593             return String("ControlBarrier");
594         case SpvOpMemoryBarrier:
595             return String("MemoryBarrier");
596         case SpvOpAtomicLoad:
597             return String("AtomicLoad");
598         case SpvOpAtomicStore:
599             return String("AtomicStore");
600         case SpvOpAtomicExchange:
601             return String("AtomicExchange");
602         case SpvOpAtomicCompareExchange:
603             return String("AtomicCompareExchange");
604         case SpvOpAtomicCompareExchangeWeak:
605             return String("AtomicCompareExchangeWeak");
606         case SpvOpAtomicIIncrement:
607             return String("AtomicIIncrement");
608         case SpvOpAtomicIDecrement:
609             return String("AtomicIDecrement");
610         case SpvOpAtomicIAdd:
611             return String("AtomicIAdd");
612         case SpvOpAtomicISub:
613             return String("AtomicISub");
614         case SpvOpAtomicSMin:
615             return String("AtomicSMin");
616         case SpvOpAtomicUMin:
617             return String("AtomicUMin");
618         case SpvOpAtomicSMax:
619             return String("AtomicSMax");
620         case SpvOpAtomicUMax:
621             return String("AtomicUMax");
622         case SpvOpAtomicAnd:
623             return String("AtomicAnd");
624         case SpvOpAtomicOr:
625             return String("AtomicOr");
626         case SpvOpAtomicXor:
627             return String("AtomicXor");
628         case SpvOpPhi:
629             return String("Phi");
630         case SpvOpLoopMerge:
631             return String("LoopMerge");
632         case SpvOpSelectionMerge:
633             return String("SelectionMerge");
634         case SpvOpLabel:
635             return String("Label");
636         case SpvOpBranch:
637             return String("Branch");
638         case SpvOpBranchConditional:
639             return String("BranchConditional");
640         case SpvOpSwitch:
641             return String("Switch");
642         case SpvOpKill:
643             return String("Kill");
644         case SpvOpReturn:
645             return String("Return");
646         case SpvOpReturnValue:
647             return String("ReturnValue");
648         case SpvOpUnreachable:
649             return String("Unreachable");
650         case SpvOpLifetimeStart:
651             return String("LifetimeStart");
652         case SpvOpLifetimeStop:
653             return String("LifetimeStop");
654         case SpvOpGroupAsyncCopy:
655             return String("GroupAsyncCopy");
656         case SpvOpGroupWaitEvents:
657             return String("GroupWaitEvents");
658         case SpvOpGroupAll:
659             return String("GroupAll");
660         case SpvOpGroupAny:
661             return String("GroupAny");
662         case SpvOpGroupBroadcast:
663             return String("GroupBroadcast");
664         case SpvOpGroupIAdd:
665             return String("GroupIAdd");
666         case SpvOpGroupFAdd:
667             return String("GroupFAdd");
668         case SpvOpGroupFMin:
669             return String("GroupFMin");
670         case SpvOpGroupUMin:
671             return String("GroupUMin");
672         case SpvOpGroupSMin:
673             return String("GroupSMin");
674         case SpvOpGroupFMax:
675             return String("GroupFMax");
676         case SpvOpGroupUMax:
677             return String("GroupUMax");
678         case SpvOpGroupSMax:
679             return String("GroupSMax");
680         case SpvOpReadPipe:
681             return String("ReadPipe");
682         case SpvOpWritePipe:
683             return String("WritePipe");
684         case SpvOpReservedReadPipe:
685             return String("ReservedReadPipe");
686         case SpvOpReservedWritePipe:
687             return String("ReservedWritePipe");
688         case SpvOpReserveReadPipePackets:
689             return String("ReserveReadPipePackets");
690         case SpvOpReserveWritePipePackets:
691             return String("ReserveWritePipePackets");
692         case SpvOpCommitReadPipe:
693             return String("CommitReadPipe");
694         case SpvOpCommitWritePipe:
695             return String("CommitWritePipe");
696         case SpvOpIsValidReserveId:
697             return String("IsValidReserveId");
698         case SpvOpGetNumPipePackets:
699             return String("GetNumPipePackets");
700         case SpvOpGetMaxPipePackets:
701             return String("GetMaxPipePackets");
702         case SpvOpGroupReserveReadPipePackets:
703             return String("GroupReserveReadPipePackets");
704         case SpvOpGroupReserveWritePipePackets:
705             return String("GroupReserveWritePipePackets");
706         case SpvOpGroupCommitReadPipe:
707             return String("GroupCommitReadPipe");
708         case SpvOpGroupCommitWritePipe:
709             return String("GroupCommitWritePipe");
710         case SpvOpEnqueueMarker:
711             return String("EnqueueMarker");
712         case SpvOpEnqueueKernel:
713             return String("EnqueueKernel");
714         case SpvOpGetKernelNDrangeSubGroupCount:
715             return String("GetKernelNDrangeSubGroupCount");
716         case SpvOpGetKernelNDrangeMaxSubGroupSize:
717             return String("GetKernelNDrangeMaxSubGroupSize");
718         case SpvOpGetKernelWorkGroupSize:
719             return String("GetKernelWorkGroupSize");
720         case SpvOpGetKernelPreferredWorkGroupSizeMultiple:
721             return String("GetKernelPreferredWorkGroupSizeMultiple");
722         case SpvOpRetainEvent:
723             return String("RetainEvent");
724         case SpvOpReleaseEvent:
725             return String("ReleaseEvent");
726         case SpvOpCreateUserEvent:
727             return String("CreateUserEvent");
728         case SpvOpIsValidEvent:
729             return String("IsValidEvent");
730         case SpvOpSetUserEventStatus:
731             return String("SetUserEventStatus");
732         case SpvOpCaptureEventProfilingInfo:
733             return String("CaptureEventProfilingInfo");
734         case SpvOpGetDefaultQueue:
735             return String("GetDefaultQueue");
736         case SpvOpBuildNDRange:
737             return String("BuildNDRange");
738         case SpvOpImageSparseSampleImplicitLod:
739             return String("ImageSparseSampleImplicitLod");
740         case SpvOpImageSparseSampleExplicitLod:
741             return String("ImageSparseSampleExplicitLod");
742         case SpvOpImageSparseSampleDrefImplicitLod:
743             return String("ImageSparseSampleDrefImplicitLod");
744         case SpvOpImageSparseSampleDrefExplicitLod:
745             return String("ImageSparseSampleDrefExplicitLod");
746         case SpvOpImageSparseSampleProjImplicitLod:
747             return String("ImageSparseSampleProjImplicitLod");
748         case SpvOpImageSparseSampleProjExplicitLod:
749             return String("ImageSparseSampleProjExplicitLod");
750         case SpvOpImageSparseSampleProjDrefImplicitLod:
751             return String("ImageSparseSampleProjDrefImplicitLod");
752         case SpvOpImageSparseSampleProjDrefExplicitLod:
753             return String("ImageSparseSampleProjDrefExplicitLod");
754         case SpvOpImageSparseFetch:
755             return String("ImageSparseFetch");
756         case SpvOpImageSparseGather:
757             return String("ImageSparseGather");
758         case SpvOpImageSparseDrefGather:
759             return String("ImageSparseDrefGather");
760         case SpvOpImageSparseTexelsResident:
761             return String("ImageSparseTexelsResident");
762         case SpvOpNoLine:
763             return String("NoLine");
764         case SpvOpAtomicFlagTestAndSet:
765             return String("AtomicFlagTestAndSet");
766         case SpvOpAtomicFlagClear:
767             return String("AtomicFlagClear");
768         case SpvOpImageSparseRead:
769             return String("ImageSparseRead");
770         default:
771             ABORT("unsupported SPIR-V op");
772     }
773 }
774 #endif
775 
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)776 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
777     ASSERT(opCode != SpvOpUndef);
778     switch (opCode) {
779         case SpvOpReturn:      // fall through
780         case SpvOpReturnValue: // fall through
781         case SpvOpKill:        // fall through
782         case SpvOpBranch:      // fall through
783         case SpvOpBranchConditional:
784             ASSERT(fCurrentBlock);
785             fCurrentBlock = 0;
786             break;
787         case SpvOpConstant:          // fall through
788         case SpvOpConstantTrue:      // fall through
789         case SpvOpConstantFalse:     // fall through
790         case SpvOpConstantComposite: // fall through
791         case SpvOpTypeVoid:          // fall through
792         case SpvOpTypeInt:           // fall through
793         case SpvOpTypeFloat:         // fall through
794         case SpvOpTypeBool:          // fall through
795         case SpvOpTypeVector:        // fall through
796         case SpvOpTypeMatrix:        // fall through
797         case SpvOpTypeArray:         // fall through
798         case SpvOpTypePointer:       // fall through
799         case SpvOpTypeFunction:      // fall through
800         case SpvOpTypeRuntimeArray:  // fall through
801         case SpvOpTypeStruct:        // fall through
802         case SpvOpTypeImage:         // fall through
803         case SpvOpTypeSampledImage:  // fall through
804         case SpvOpVariable:          // fall through
805         case SpvOpFunction:          // fall through
806         case SpvOpFunctionParameter: // fall through
807         case SpvOpFunctionEnd:       // fall through
808         case SpvOpExecutionMode:     // fall through
809         case SpvOpMemoryModel:       // fall through
810         case SpvOpCapability:        // fall through
811         case SpvOpExtInstImport:     // fall through
812         case SpvOpEntryPoint:        // fall through
813         case SpvOpSource:            // fall through
814         case SpvOpSourceExtension:   // fall through
815         case SpvOpName:              // fall through
816         case SpvOpMemberName:        // fall through
817         case SpvOpDecorate:          // fall through
818         case SpvOpMemberDecorate:
819             break;
820         default:
821             ASSERT(fCurrentBlock);
822     }
823 #if SPIRV_DEBUG
824     out << std::endl << opcode_text(opCode) << " ";
825 #else
826     this->writeWord((length << 16) | opCode, out);
827 #endif
828 }
829 
writeLabel(SpvId label,OutputStream & out)830 void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) {
831     fCurrentBlock = label;
832     this->writeInstruction(SpvOpLabel, label, out);
833 }
834 
writeInstruction(SpvOp_ opCode,OutputStream & out)835 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
836     this->writeOpCode(opCode, 1, out);
837 }
838 
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)839 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
840     this->writeOpCode(opCode, 2, out);
841     this->writeWord(word1, out);
842 }
843 
writeString(const char * string,OutputStream & out)844 void SPIRVCodeGenerator::writeString(const char* string, OutputStream& out) {
845     size_t length = strlen(string);
846     out.write(string, length);
847     switch (length % 4) {
848         case 1:
849             out.write8(0);
850             // fall through
851         case 2:
852             out.write8(0);
853             // fall through
854         case 3:
855             out.write8(0);
856             break;
857         default:
858             this->writeWord(0, out);
859     }
860 }
861 
writeInstruction(SpvOp_ opCode,const char * string,OutputStream & out)862 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, const char* string, OutputStream& out) {
863     int32_t length = (int32_t) strlen(string);
864     this->writeOpCode(opCode, 1 + (length + 4) / 4, out);
865     this->writeString(string, out);
866 }
867 
868 
writeInstruction(SpvOp_ opCode,int32_t word1,const char * string,OutputStream & out)869 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, const char* string,
870                                           OutputStream& out) {
871     int32_t length = (int32_t) strlen(string);
872     this->writeOpCode(opCode, 2 + (length + 4) / 4, out);
873     this->writeWord(word1, out);
874     this->writeString(string, out);
875 }
876 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,const char * string,OutputStream & out)877 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
878                                           const char* string, OutputStream& out) {
879     int32_t length = (int32_t) strlen(string);
880     this->writeOpCode(opCode, 3 + (length + 4) / 4, out);
881     this->writeWord(word1, out);
882     this->writeWord(word2, out);
883     this->writeString(string, out);
884 }
885 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)886 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
887                                           OutputStream& out) {
888     this->writeOpCode(opCode, 3, out);
889     this->writeWord(word1, out);
890     this->writeWord(word2, out);
891 }
892 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)893 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
894                                           int32_t word3, OutputStream& out) {
895     this->writeOpCode(opCode, 4, out);
896     this->writeWord(word1, out);
897     this->writeWord(word2, out);
898     this->writeWord(word3, out);
899 }
900 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)901 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
902                                           int32_t word3, int32_t word4, OutputStream& out) {
903     this->writeOpCode(opCode, 5, out);
904     this->writeWord(word1, out);
905     this->writeWord(word2, out);
906     this->writeWord(word3, out);
907     this->writeWord(word4, out);
908 }
909 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)910 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
911                                           int32_t word3, int32_t word4, int32_t word5,
912                                           OutputStream& out) {
913     this->writeOpCode(opCode, 6, out);
914     this->writeWord(word1, out);
915     this->writeWord(word2, out);
916     this->writeWord(word3, out);
917     this->writeWord(word4, out);
918     this->writeWord(word5, out);
919 }
920 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)921 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
922                                           int32_t word3, int32_t word4, int32_t word5,
923                                           int32_t word6, OutputStream& out) {
924     this->writeOpCode(opCode, 7, out);
925     this->writeWord(word1, out);
926     this->writeWord(word2, out);
927     this->writeWord(word3, out);
928     this->writeWord(word4, out);
929     this->writeWord(word5, out);
930     this->writeWord(word6, out);
931 }
932 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)933 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
934                                           int32_t word3, int32_t word4, int32_t word5,
935                                           int32_t word6, int32_t word7, OutputStream& out) {
936     this->writeOpCode(opCode, 8, out);
937     this->writeWord(word1, out);
938     this->writeWord(word2, out);
939     this->writeWord(word3, out);
940     this->writeWord(word4, out);
941     this->writeWord(word5, out);
942     this->writeWord(word6, out);
943     this->writeWord(word7, out);
944 }
945 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)946 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
947                                           int32_t word3, int32_t word4, int32_t word5,
948                                           int32_t word6, int32_t word7, int32_t word8,
949                                           OutputStream& out) {
950     this->writeOpCode(opCode, 9, out);
951     this->writeWord(word1, out);
952     this->writeWord(word2, out);
953     this->writeWord(word3, out);
954     this->writeWord(word4, out);
955     this->writeWord(word5, out);
956     this->writeWord(word6, out);
957     this->writeWord(word7, out);
958     this->writeWord(word8, out);
959 }
960 
writeCapabilities(OutputStream & out)961 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
962     for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
963         if (fCapabilities & bit) {
964             this->writeInstruction(SpvOpCapability, (SpvId) i, out);
965         }
966     }
967 }
968 
nextId()969 SpvId SPIRVCodeGenerator::nextId() {
970     return fIdCount++;
971 }
972 
writeStruct(const Type & type,const MemoryLayout & memoryLayout,SpvId resultId)973 void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout,
974                                      SpvId resultId) {
975     this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer);
976     // go ahead and write all of the field types, so we don't inadvertently write them while we're
977     // in the middle of writing the struct instruction
978     std::vector<SpvId> types;
979     for (const auto& f : type.fields()) {
980         types.push_back(this->getType(*f.fType, memoryLayout));
981     }
982     this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
983     this->writeWord(resultId, fConstantBuffer);
984     for (SpvId id : types) {
985         this->writeWord(id, fConstantBuffer);
986     }
987     size_t offset = 0;
988     for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
989         size_t size = memoryLayout.size(*type.fields()[i].fType);
990         size_t alignment = memoryLayout.alignment(*type.fields()[i].fType);
991         const Layout& fieldLayout = type.fields()[i].fModifiers.fLayout;
992         if (fieldLayout.fOffset >= 0) {
993             if (fieldLayout.fOffset < (int) offset) {
994                 fErrors.error(type.fPosition,
995                               "offset of field '" + type.fields()[i].fName + "' must be at "
996                               "least " + to_string((int) offset));
997             }
998             if (fieldLayout.fOffset % alignment) {
999                 fErrors.error(type.fPosition,
1000                               "offset of field '" + type.fields()[i].fName + "' must be a multiple"
1001                               " of " + to_string((int) alignment));
1002             }
1003             offset = fieldLayout.fOffset;
1004         } else {
1005             size_t mod = offset % alignment;
1006             if (mod) {
1007                 offset += alignment - mod;
1008             }
1009         }
1010         this->writeInstruction(SpvOpMemberName, resultId, i, type.fields()[i].fName.c_str(),
1011                                fNameBuffer);
1012         this->writeLayout(fieldLayout, resultId, i);
1013         if (type.fields()[i].fModifiers.fLayout.fBuiltin < 0) {
1014             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
1015                                    (SpvId) offset, fDecorationBuffer);
1016         }
1017         if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) {
1018             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
1019                                    fDecorationBuffer);
1020             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
1021                                    (SpvId) memoryLayout.stride(*type.fields()[i].fType),
1022                                    fDecorationBuffer);
1023         }
1024         offset += size;
1025         Type::Kind kind = type.fields()[i].fType->kind();
1026         if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) {
1027             offset += alignment - offset % alignment;
1028         }
1029     }
1030 }
1031 
getType(const Type & type)1032 SpvId SPIRVCodeGenerator::getType(const Type& type) {
1033     return this->getType(type, fDefaultLayout);
1034 }
1035 
getType(const Type & type,const MemoryLayout & layout)1036 SpvId SPIRVCodeGenerator::getType(const Type& type, const MemoryLayout& layout) {
1037     String key = type.name() + to_string((int) layout.fStd);
1038     auto entry = fTypeMap.find(key);
1039     if (entry == fTypeMap.end()) {
1040         SpvId result = this->nextId();
1041         switch (type.kind()) {
1042             case Type::kScalar_Kind:
1043                 if (type == *fContext.fBool_Type) {
1044                     this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
1045                 } else if (type == *fContext.fInt_Type) {
1046                     this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
1047                 } else if (type == *fContext.fUInt_Type) {
1048                     this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
1049                 } else if (type == *fContext.fFloat_Type) {
1050                     this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
1051                 } else if (type == *fContext.fDouble_Type) {
1052                     this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer);
1053                 } else {
1054                     ASSERT(false);
1055                 }
1056                 break;
1057             case Type::kVector_Kind:
1058                 this->writeInstruction(SpvOpTypeVector, result,
1059                                        this->getType(type.componentType(), layout),
1060                                        type.columns(), fConstantBuffer);
1061                 break;
1062             case Type::kMatrix_Kind:
1063                 this->writeInstruction(SpvOpTypeMatrix, result,
1064                                        this->getType(index_type(fContext, type), layout),
1065                                        type.columns(), fConstantBuffer);
1066                 break;
1067             case Type::kStruct_Kind:
1068                 this->writeStruct(type, layout, result);
1069                 break;
1070             case Type::kArray_Kind: {
1071                 if (type.columns() > 0) {
1072                     IntLiteral count(fContext, Position(), type.columns());
1073                     this->writeInstruction(SpvOpTypeArray, result,
1074                                            this->getType(type.componentType(), layout),
1075                                            this->writeIntLiteral(count), fConstantBuffer);
1076                     this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
1077                                            (int32_t) layout.stride(type),
1078                                            fDecorationBuffer);
1079                 } else {
1080                     this->writeInstruction(SpvOpTypeRuntimeArray, result,
1081                                            this->getType(type.componentType(), layout),
1082                                            fConstantBuffer);
1083                     this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
1084                                            (int32_t) layout.stride(type),
1085                                            fDecorationBuffer);
1086                 }
1087                 break;
1088             }
1089             case Type::kSampler_Kind: {
1090                 SpvId image = result;
1091                 if (SpvDimSubpassData != type.dimensions()) {
1092                     image = this->nextId();
1093                 }
1094                 if (SpvDimBuffer == type.dimensions()) {
1095                     fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer);
1096                 }
1097                 this->writeInstruction(SpvOpTypeImage, image,
1098                                        this->getType(*fContext.fFloat_Type, layout),
1099                                        type.dimensions(), type.isDepth(), type.isArrayed(),
1100                                        type.isMultisampled(), type.isSampled() ? 1 : 2,
1101                                        SpvImageFormatUnknown, fConstantBuffer);
1102                 fImageTypeMap[key] = image;
1103                 if (SpvDimSubpassData != type.dimensions()) {
1104                     this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
1105                 }
1106                 break;
1107             }
1108             default:
1109                 if (type == *fContext.fVoid_Type) {
1110                     this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
1111                 } else {
1112                     ABORT("invalid type: %s", type.description().c_str());
1113                 }
1114         }
1115         fTypeMap[key] = result;
1116         return result;
1117     }
1118     return entry->second;
1119 }
1120 
getImageType(const Type & type)1121 SpvId SPIRVCodeGenerator::getImageType(const Type& type) {
1122     ASSERT(type.kind() == Type::kSampler_Kind);
1123     this->getType(type);
1124     String key = type.name() + to_string((int) fDefaultLayout.fStd);
1125     ASSERT(fImageTypeMap.find(key) != fImageTypeMap.end());
1126     return fImageTypeMap[key];
1127 }
1128 
getFunctionType(const FunctionDeclaration & function)1129 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
1130     String key = function.fReturnType.description() + "(";
1131     String separator;
1132     for (size_t i = 0; i < function.fParameters.size(); i++) {
1133         key += separator;
1134         separator = ", ";
1135         key += function.fParameters[i]->fType.description();
1136     }
1137     key += ")";
1138     auto entry = fTypeMap.find(key);
1139     if (entry == fTypeMap.end()) {
1140         SpvId result = this->nextId();
1141         int32_t length = 3 + (int32_t) function.fParameters.size();
1142         SpvId returnType = this->getType(function.fReturnType);
1143         std::vector<SpvId> parameterTypes;
1144         for (size_t i = 0; i < function.fParameters.size(); i++) {
1145             // glslang seems to treat all function arguments as pointers whether they need to be or
1146             // not. I  was initially puzzled by this until I ran bizarre failures with certain
1147             // patterns of function calls and control constructs, as exemplified by this minimal
1148             // failure case:
1149             //
1150             // void sphere(float x) {
1151             // }
1152             //
1153             // void map() {
1154             //     sphere(1.0);
1155             // }
1156             //
1157             // void main() {
1158             //     for (int i = 0; i < 1; i++) {
1159             //         map();
1160             //     }
1161             // }
1162             //
1163             // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
1164             // crashes. Making it take a float* and storing the argument in a temporary variable,
1165             // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
1166             // the spec makes this make sense.
1167 //            if (is_out(function->fParameters[i])) {
1168                 parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType,
1169                                                               SpvStorageClassFunction));
1170 //            } else {
1171 //                parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
1172 //            }
1173         }
1174         this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
1175         this->writeWord(result, fConstantBuffer);
1176         this->writeWord(returnType, fConstantBuffer);
1177         for (SpvId id : parameterTypes) {
1178             this->writeWord(id, fConstantBuffer);
1179         }
1180         fTypeMap[key] = result;
1181         return result;
1182     }
1183     return entry->second;
1184 }
1185 
getPointerType(const Type & type,SpvStorageClass_ storageClass)1186 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
1187     return this->getPointerType(type, fDefaultLayout, storageClass);
1188 }
1189 
getPointerType(const Type & type,const MemoryLayout & layout,SpvStorageClass_ storageClass)1190 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, const MemoryLayout& layout,
1191                                          SpvStorageClass_ storageClass) {
1192     String key = type.description() + "*" + to_string(layout.fStd) + to_string(storageClass);
1193     auto entry = fTypeMap.find(key);
1194     if (entry == fTypeMap.end()) {
1195         SpvId result = this->nextId();
1196         this->writeInstruction(SpvOpTypePointer, result, storageClass,
1197                                this->getType(type), fConstantBuffer);
1198         fTypeMap[key] = result;
1199         return result;
1200     }
1201     return entry->second;
1202 }
1203 
writeExpression(const Expression & expr,OutputStream & out)1204 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
1205     switch (expr.fKind) {
1206         case Expression::kBinary_Kind:
1207             return this->writeBinaryExpression((BinaryExpression&) expr, out);
1208         case Expression::kBoolLiteral_Kind:
1209             return this->writeBoolLiteral((BoolLiteral&) expr);
1210         case Expression::kConstructor_Kind:
1211             return this->writeConstructor((Constructor&) expr, out);
1212         case Expression::kIntLiteral_Kind:
1213             return this->writeIntLiteral((IntLiteral&) expr);
1214         case Expression::kFieldAccess_Kind:
1215             return this->writeFieldAccess(((FieldAccess&) expr), out);
1216         case Expression::kFloatLiteral_Kind:
1217             return this->writeFloatLiteral(((FloatLiteral&) expr));
1218         case Expression::kFunctionCall_Kind:
1219             return this->writeFunctionCall((FunctionCall&) expr, out);
1220         case Expression::kPrefix_Kind:
1221             return this->writePrefixExpression((PrefixExpression&) expr, out);
1222         case Expression::kPostfix_Kind:
1223             return this->writePostfixExpression((PostfixExpression&) expr, out);
1224         case Expression::kSwizzle_Kind:
1225             return this->writeSwizzle((Swizzle&) expr, out);
1226         case Expression::kVariableReference_Kind:
1227             return this->writeVariableReference((VariableReference&) expr, out);
1228         case Expression::kTernary_Kind:
1229             return this->writeTernaryExpression((TernaryExpression&) expr, out);
1230         case Expression::kIndex_Kind:
1231             return this->writeIndexExpression((IndexExpression&) expr, out);
1232         default:
1233             ABORT("unsupported expression: %s", expr.description().c_str());
1234     }
1235     return -1;
1236 }
1237 
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)1238 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
1239     auto intrinsic = fIntrinsicMap.find(c.fFunction.fName);
1240     ASSERT(intrinsic != fIntrinsicMap.end());
1241     const Type& type = c.fArguments[0]->fType;
1242     int32_t intrinsicId;
1243     if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) {
1244         intrinsicId = std::get<1>(intrinsic->second);
1245     } else if (is_signed(fContext, type)) {
1246         intrinsicId = std::get<2>(intrinsic->second);
1247     } else if (is_unsigned(fContext, type)) {
1248         intrinsicId = std::get<3>(intrinsic->second);
1249     } else if (is_bool(fContext, type)) {
1250         intrinsicId = std::get<4>(intrinsic->second);
1251     } else {
1252         ABORT("invalid call %s, cannot operate on '%s'", c.description().c_str(),
1253               type.description().c_str());
1254     }
1255     switch (std::get<0>(intrinsic->second)) {
1256         case kGLSL_STD_450_IntrinsicKind: {
1257             SpvId result = this->nextId();
1258             std::vector<SpvId> arguments;
1259             for (size_t i = 0; i < c.fArguments.size(); i++) {
1260                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1261             }
1262             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
1263             this->writeWord(this->getType(c.fType), out);
1264             this->writeWord(result, out);
1265             this->writeWord(fGLSLExtendedInstructions, out);
1266             this->writeWord(intrinsicId, out);
1267             for (SpvId id : arguments) {
1268                 this->writeWord(id, out);
1269             }
1270             return result;
1271         }
1272         case kSPIRV_IntrinsicKind: {
1273             SpvId result = this->nextId();
1274             std::vector<SpvId> arguments;
1275             for (size_t i = 0; i < c.fArguments.size(); i++) {
1276                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1277             }
1278             this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
1279             this->writeWord(this->getType(c.fType), out);
1280             this->writeWord(result, out);
1281             for (SpvId id : arguments) {
1282                 this->writeWord(id, out);
1283             }
1284             return result;
1285         }
1286         case kSpecial_IntrinsicKind:
1287             return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
1288         default:
1289             ABORT("unsupported intrinsic kind");
1290     }
1291 }
1292 
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)1293 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
1294                                                 OutputStream& out) {
1295     SpvId result = this->nextId();
1296     switch (kind) {
1297         case kAtan_SpecialIntrinsic: {
1298             std::vector<SpvId> arguments;
1299             for (size_t i = 0; i < c.fArguments.size(); i++) {
1300                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1301             }
1302             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
1303             this->writeWord(this->getType(c.fType), out);
1304             this->writeWord(result, out);
1305             this->writeWord(fGLSLExtendedInstructions, out);
1306             this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
1307             for (SpvId id : arguments) {
1308                 this->writeWord(id, out);
1309             }
1310             break;
1311         }
1312         case kSubpassLoad_SpecialIntrinsic: {
1313             SpvId img = this->writeExpression(*c.fArguments[0], out);
1314             std::vector<std::unique_ptr<Expression>> args;
1315             args.emplace_back(new FloatLiteral(fContext, Position(), 0.0));
1316             args.emplace_back(new FloatLiteral(fContext, Position(), 0.0));
1317             Constructor ctor(Position(), *fContext.fVec2_Type, std::move(args));
1318             SpvId coords = this->writeConstantVector(ctor);
1319             if (1 == c.fArguments.size()) {
1320                 this->writeInstruction(SpvOpImageRead,
1321                                        this->getType(c.fType),
1322                                        result,
1323                                        img,
1324                                        coords,
1325                                        out);
1326             } else {
1327                 ASSERT(2 == c.fArguments.size());
1328                 SpvId sample = this->writeExpression(*c.fArguments[1], out);
1329                 this->writeInstruction(SpvOpImageRead,
1330                                        this->getType(c.fType),
1331                                        result,
1332                                        img,
1333                                        coords,
1334                                        SpvImageOperandsSampleMask,
1335                                        sample,
1336                                        out);
1337             }
1338             break;
1339         }
1340         case kTexelFetch_SpecialIntrinsic: {
1341             ASSERT(c.fArguments.size() == 2);
1342             SpvId image = this->nextId();
1343             this->writeInstruction(SpvOpImage,
1344                                    this->getImageType(c.fArguments[0]->fType),
1345                                    image,
1346                                    this->writeExpression(*c.fArguments[0], out),
1347                                    out);
1348             this->writeInstruction(SpvOpImageFetch,
1349                                    this->getType(c.fType),
1350                                    result,
1351                                    image,
1352                                    this->writeExpression(*c.fArguments[1], out),
1353                                    out);
1354             break;
1355         }
1356         case kTexture_SpecialIntrinsic: {
1357             SpvOp_ op = SpvOpImageSampleImplicitLod;
1358             switch (c.fArguments[0]->fType.dimensions()) {
1359                 case SpvDim1D:
1360                     if (c.fArguments[1]->fType == *fContext.fVec2_Type) {
1361                         op = SpvOpImageSampleProjImplicitLod;
1362                     } else {
1363                         ASSERT(c.fArguments[1]->fType == *fContext.fFloat_Type);
1364                     }
1365                     break;
1366                 case SpvDim2D:
1367                     if (c.fArguments[1]->fType == *fContext.fVec3_Type) {
1368                         op = SpvOpImageSampleProjImplicitLod;
1369                     } else {
1370                         ASSERT(c.fArguments[1]->fType == *fContext.fVec2_Type);
1371                     }
1372                     break;
1373                 case SpvDim3D:
1374                     if (c.fArguments[1]->fType == *fContext.fVec4_Type) {
1375                         op = SpvOpImageSampleProjImplicitLod;
1376                     } else {
1377                         ASSERT(c.fArguments[1]->fType == *fContext.fVec3_Type);
1378                     }
1379                     break;
1380                 case SpvDimCube:   // fall through
1381                 case SpvDimRect:   // fall through
1382                 case SpvDimBuffer: // fall through
1383                 case SpvDimSubpassData:
1384                     break;
1385             }
1386             SpvId type = this->getType(c.fType);
1387             SpvId sampler = this->writeExpression(*c.fArguments[0], out);
1388             SpvId uv = this->writeExpression(*c.fArguments[1], out);
1389             if (c.fArguments.size() == 3) {
1390                 this->writeInstruction(op, type, result, sampler, uv,
1391                                        SpvImageOperandsBiasMask,
1392                                        this->writeExpression(*c.fArguments[2], out),
1393                                        out);
1394             } else {
1395                 ASSERT(c.fArguments.size() == 2);
1396                 this->writeInstruction(op, type, result, sampler, uv,
1397                                        out);
1398             }
1399             break;
1400         }
1401     }
1402     return result;
1403 }
1404 
writeFunctionCall(const FunctionCall & c,OutputStream & out)1405 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
1406     const auto& entry = fFunctionMap.find(&c.fFunction);
1407     if (entry == fFunctionMap.end()) {
1408         return this->writeIntrinsicCall(c, out);
1409     }
1410     // stores (variable, type, lvalue) pairs to extract and save after the function call is complete
1411     std::vector<std::tuple<SpvId, SpvId, std::unique_ptr<LValue>>> lvalues;
1412     std::vector<SpvId> arguments;
1413     for (size_t i = 0; i < c.fArguments.size(); i++) {
1414         // id of temporary variable that we will use to hold this argument, or 0 if it is being
1415         // passed directly
1416         SpvId tmpVar;
1417         // if we need a temporary var to store this argument, this is the value to store in the var
1418         SpvId tmpValueId;
1419         if (is_out(*c.fFunction.fParameters[i])) {
1420             std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out);
1421             SpvId ptr = lv->getPointer();
1422             if (ptr) {
1423                 arguments.push_back(ptr);
1424                 continue;
1425             } else {
1426                 // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to
1427                 // copy it into a temp, call the function, read the value out of the temp, and then
1428                 // update the lvalue.
1429                 tmpValueId = lv->load(out);
1430                 tmpVar = this->nextId();
1431                 lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType),
1432                                   std::move(lv)));
1433             }
1434         } else {
1435             // see getFunctionType for an explanation of why we're always using pointer parameters
1436             tmpValueId = this->writeExpression(*c.fArguments[i], out);
1437             tmpVar = this->nextId();
1438         }
1439         this->writeInstruction(SpvOpVariable,
1440                                this->getPointerType(c.fArguments[i]->fType,
1441                                                     SpvStorageClassFunction),
1442                                tmpVar,
1443                                SpvStorageClassFunction,
1444                                fVariableBuffer);
1445         this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
1446         arguments.push_back(tmpVar);
1447     }
1448     SpvId result = this->nextId();
1449     this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out);
1450     this->writeWord(this->getType(c.fType), out);
1451     this->writeWord(result, out);
1452     this->writeWord(entry->second, out);
1453     for (SpvId id : arguments) {
1454         this->writeWord(id, out);
1455     }
1456     // now that the call is complete, we may need to update some lvalues with the new values of out
1457     // arguments
1458     for (const auto& tuple : lvalues) {
1459         SpvId load = this->nextId();
1460         this->writeInstruction(SpvOpLoad, std::get<1>(tuple), load, std::get<0>(tuple), out);
1461         std::get<2>(tuple)->store(load, out);
1462     }
1463     return result;
1464 }
1465 
writeConstantVector(const Constructor & c)1466 SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) {
1467     ASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant());
1468     SpvId result = this->nextId();
1469     std::vector<SpvId> arguments;
1470     for (size_t i = 0; i < c.fArguments.size(); i++) {
1471         arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer));
1472     }
1473     SpvId type = this->getType(c.fType);
1474     if (c.fArguments.size() == 1) {
1475         // with a single argument, a vector will have all of its entries equal to the argument
1476         this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer);
1477         this->writeWord(type, fConstantBuffer);
1478         this->writeWord(result, fConstantBuffer);
1479         for (int i = 0; i < c.fType.columns(); i++) {
1480             this->writeWord(arguments[0], fConstantBuffer);
1481         }
1482     } else {
1483         this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(),
1484                           fConstantBuffer);
1485         this->writeWord(type, fConstantBuffer);
1486         this->writeWord(result, fConstantBuffer);
1487         for (SpvId id : arguments) {
1488             this->writeWord(id, fConstantBuffer);
1489         }
1490     }
1491     return result;
1492 }
1493 
writeFloatConstructor(const Constructor & c,OutputStream & out)1494 SpvId SPIRVCodeGenerator::writeFloatConstructor(const Constructor& c, OutputStream& out) {
1495     ASSERT(c.fType == *fContext.fFloat_Type);
1496     ASSERT(c.fArguments.size() == 1);
1497     ASSERT(c.fArguments[0]->fType.isNumber());
1498     SpvId result = this->nextId();
1499     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1500     if (c.fArguments[0]->fType == *fContext.fInt_Type) {
1501         this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter,
1502                                out);
1503     } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) {
1504         this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter,
1505                                out);
1506     } else if (c.fArguments[0]->fType == *fContext.fFloat_Type) {
1507         return parameter;
1508     }
1509     return result;
1510 }
1511 
writeIntConstructor(const Constructor & c,OutputStream & out)1512 SpvId SPIRVCodeGenerator::writeIntConstructor(const Constructor& c, OutputStream& out) {
1513     ASSERT(c.fType == *fContext.fInt_Type);
1514     ASSERT(c.fArguments.size() == 1);
1515     ASSERT(c.fArguments[0]->fType.isNumber());
1516     SpvId result = this->nextId();
1517     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1518     if (c.fArguments[0]->fType == *fContext.fFloat_Type) {
1519         this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter,
1520                                out);
1521     } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) {
1522         this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1523                                out);
1524     } else if (c.fArguments[0]->fType == *fContext.fInt_Type) {
1525         return parameter;
1526     }
1527     return result;
1528 }
1529 
writeUIntConstructor(const Constructor & c,OutputStream & out)1530 SpvId SPIRVCodeGenerator::writeUIntConstructor(const Constructor& c, OutputStream& out) {
1531     ASSERT(c.fType == *fContext.fUInt_Type);
1532     ASSERT(c.fArguments.size() == 1);
1533     ASSERT(c.fArguments[0]->fType.isNumber());
1534     SpvId result = this->nextId();
1535     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1536     if (c.fArguments[0]->fType == *fContext.fFloat_Type) {
1537         this->writeInstruction(SpvOpConvertFToU, this->getType(c.fType), result, parameter,
1538                                out);
1539     } else if (c.fArguments[0]->fType == *fContext.fInt_Type) {
1540         this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1541                                out);
1542     } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) {
1543         return parameter;
1544     }
1545     return result;
1546 }
1547 
writeUniformScaleMatrix(SpvId id,SpvId diagonal,const Type & type,OutputStream & out)1548 void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
1549                                                  OutputStream& out) {
1550     FloatLiteral zero(fContext, Position(), 0);
1551     SpvId zeroId = this->writeFloatLiteral(zero);
1552     std::vector<SpvId> columnIds;
1553     for (int column = 0; column < type.columns(); column++) {
1554         this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
1555                           out);
1556         this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)),
1557                         out);
1558         SpvId columnId = this->nextId();
1559         this->writeWord(columnId, out);
1560         columnIds.push_back(columnId);
1561         for (int row = 0; row < type.columns(); row++) {
1562             this->writeWord(row == column ? diagonal : zeroId, out);
1563         }
1564     }
1565     this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
1566                       out);
1567     this->writeWord(this->getType(type), out);
1568     this->writeWord(id, out);
1569     for (SpvId id : columnIds) {
1570         this->writeWord(id, out);
1571     }
1572 }
1573 
writeMatrixCopy(SpvId id,SpvId src,const Type & srcType,const Type & dstType,OutputStream & out)1574 void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType,
1575                                          const Type& dstType, OutputStream& out) {
1576     ABORT("unimplemented");
1577 }
1578 
writeMatrixConstructor(const Constructor & c,OutputStream & out)1579 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, OutputStream& out) {
1580     ASSERT(c.fType.kind() == Type::kMatrix_Kind);
1581     // go ahead and write the arguments so we don't try to write new instructions in the middle of
1582     // an instruction
1583     std::vector<SpvId> arguments;
1584     for (size_t i = 0; i < c.fArguments.size(); i++) {
1585         arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1586     }
1587     SpvId result = this->nextId();
1588     int rows = c.fType.rows();
1589     int columns = c.fType.columns();
1590     if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1591         this->writeUniformScaleMatrix(result, arguments[0], c.fType, out);
1592     } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) {
1593         this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out);
1594     } else {
1595         std::vector<SpvId> columnIds;
1596         int currentCount = 0;
1597         for (size_t i = 0; i < arguments.size(); i++) {
1598             if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) {
1599                 ASSERT(currentCount == 0);
1600                 columnIds.push_back(arguments[i]);
1601                 currentCount = 0;
1602             } else {
1603                 ASSERT(c.fArguments[i]->fType.kind() == Type::kScalar_Kind);
1604                 if (currentCount == 0) {
1605                     this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.rows(), out);
1606                     this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows,
1607                                                                                      1)),
1608                                     out);
1609                     SpvId id = this->nextId();
1610                     this->writeWord(id, out);
1611                     columnIds.push_back(id);
1612                 }
1613                 this->writeWord(arguments[i], out);
1614                 currentCount = (currentCount + 1) % rows;
1615             }
1616         }
1617         ASSERT(columnIds.size() == (size_t) columns);
1618         this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
1619         this->writeWord(this->getType(c.fType), out);
1620         this->writeWord(result, out);
1621         for (SpvId id : columnIds) {
1622             this->writeWord(id, out);
1623         }
1624     }
1625     return result;
1626 }
1627 
writeVectorConstructor(const Constructor & c,OutputStream & out)1628 SpvId SPIRVCodeGenerator::writeVectorConstructor(const Constructor& c, OutputStream& out) {
1629     ASSERT(c.fType.kind() == Type::kVector_Kind);
1630     if (c.isConstant()) {
1631         return this->writeConstantVector(c);
1632     }
1633     // go ahead and write the arguments so we don't try to write new instructions in the middle of
1634     // an instruction
1635     std::vector<SpvId> arguments;
1636     for (size_t i = 0; i < c.fArguments.size(); i++) {
1637         arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1638     }
1639     SpvId result = this->nextId();
1640     if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1641         this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out);
1642         this->writeWord(this->getType(c.fType), out);
1643         this->writeWord(result, out);
1644         for (int i = 0; i < c.fType.columns(); i++) {
1645             this->writeWord(arguments[0], out);
1646         }
1647     } else {
1648         this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out);
1649         this->writeWord(this->getType(c.fType), out);
1650         this->writeWord(result, out);
1651         for (SpvId id : arguments) {
1652             this->writeWord(id, out);
1653         }
1654     }
1655     return result;
1656 }
1657 
writeConstructor(const Constructor & c,OutputStream & out)1658 SpvId SPIRVCodeGenerator::writeConstructor(const Constructor& c, OutputStream& out) {
1659     if (c.fType == *fContext.fFloat_Type) {
1660         return this->writeFloatConstructor(c, out);
1661     } else if (c.fType == *fContext.fInt_Type) {
1662         return this->writeIntConstructor(c, out);
1663     } else if (c.fType == *fContext.fUInt_Type) {
1664         return this->writeUIntConstructor(c, out);
1665     }
1666     switch (c.fType.kind()) {
1667         case Type::kVector_Kind:
1668             return this->writeVectorConstructor(c, out);
1669         case Type::kMatrix_Kind:
1670             return this->writeMatrixConstructor(c, out);
1671         default:
1672             ABORT("unsupported constructor: %s", c.description().c_str());
1673     }
1674 }
1675 
get_storage_class(const Modifiers & modifiers)1676 SpvStorageClass_ get_storage_class(const Modifiers& modifiers) {
1677     if (modifiers.fFlags & Modifiers::kIn_Flag) {
1678         ASSERT(!modifiers.fLayout.fPushConstant);
1679         return SpvStorageClassInput;
1680     } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
1681         ASSERT(!modifiers.fLayout.fPushConstant);
1682         return SpvStorageClassOutput;
1683     } else if (modifiers.fFlags & Modifiers::kUniform_Flag) {
1684         if (modifiers.fLayout.fPushConstant) {
1685             return SpvStorageClassPushConstant;
1686         }
1687         return SpvStorageClassUniform;
1688     } else {
1689         return SpvStorageClassFunction;
1690     }
1691 }
1692 
get_storage_class(const Expression & expr)1693 SpvStorageClass_ get_storage_class(const Expression& expr) {
1694     switch (expr.fKind) {
1695         case Expression::kVariableReference_Kind: {
1696             const Variable& var = ((VariableReference&) expr).fVariable;
1697             if (var.fStorage != Variable::kGlobal_Storage) {
1698                 return SpvStorageClassFunction;
1699             }
1700             return get_storage_class(var.fModifiers);
1701         }
1702         case Expression::kFieldAccess_Kind:
1703             return get_storage_class(*((FieldAccess&) expr).fBase);
1704         case Expression::kIndex_Kind:
1705             return get_storage_class(*((IndexExpression&) expr).fBase);
1706         default:
1707             return SpvStorageClassFunction;
1708     }
1709 }
1710 
getAccessChain(const Expression & expr,OutputStream & out)1711 std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
1712     std::vector<SpvId> chain;
1713     switch (expr.fKind) {
1714         case Expression::kIndex_Kind: {
1715             IndexExpression& indexExpr = (IndexExpression&) expr;
1716             chain = this->getAccessChain(*indexExpr.fBase, out);
1717             chain.push_back(this->writeExpression(*indexExpr.fIndex, out));
1718             break;
1719         }
1720         case Expression::kFieldAccess_Kind: {
1721             FieldAccess& fieldExpr = (FieldAccess&) expr;
1722             chain = this->getAccessChain(*fieldExpr.fBase, out);
1723             IntLiteral index(fContext, Position(), fieldExpr.fFieldIndex);
1724             chain.push_back(this->writeIntLiteral(index));
1725             break;
1726         }
1727         default:
1728             chain.push_back(this->getLValue(expr, out)->getPointer());
1729     }
1730     return chain;
1731 }
1732 
1733 class PointerLValue : public SPIRVCodeGenerator::LValue {
1734 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,SpvId type)1735     PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type)
1736     : fGen(gen)
1737     , fPointer(pointer)
1738     , fType(type) {}
1739 
getPointer()1740     virtual SpvId getPointer() override {
1741         return fPointer;
1742     }
1743 
load(OutputStream & out)1744     virtual SpvId load(OutputStream& out) override {
1745         SpvId result = fGen.nextId();
1746         fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
1747         return result;
1748     }
1749 
store(SpvId value,OutputStream & out)1750     virtual void store(SpvId value, OutputStream& out) override {
1751         fGen.writeInstruction(SpvOpStore, fPointer, value, out);
1752     }
1753 
1754 private:
1755     SPIRVCodeGenerator& fGen;
1756     const SpvId fPointer;
1757     const SpvId fType;
1758 };
1759 
1760 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
1761 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const std::vector<int> & components,const Type & baseType,const Type & swizzleType)1762     SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components,
1763                   const Type& baseType, const Type& swizzleType)
1764     : fGen(gen)
1765     , fVecPointer(vecPointer)
1766     , fComponents(components)
1767     , fBaseType(baseType)
1768     , fSwizzleType(swizzleType) {}
1769 
getPointer()1770     virtual SpvId getPointer() override {
1771         return 0;
1772     }
1773 
load(OutputStream & out)1774     virtual SpvId load(OutputStream& out) override {
1775         SpvId base = fGen.nextId();
1776         fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1777         SpvId result = fGen.nextId();
1778         fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
1779         fGen.writeWord(fGen.getType(fSwizzleType), out);
1780         fGen.writeWord(result, out);
1781         fGen.writeWord(base, out);
1782         fGen.writeWord(base, out);
1783         for (int component : fComponents) {
1784             fGen.writeWord(component, out);
1785         }
1786         return result;
1787     }
1788 
store(SpvId value,OutputStream & out)1789     virtual void store(SpvId value, OutputStream& out) override {
1790         // use OpVectorShuffle to mix and match the vector components. We effectively create
1791         // a virtual vector out of the concatenation of the left and right vectors, and then
1792         // select components from this virtual vector to make the result vector. For
1793         // instance, given:
1794         // vec3 L = ...;
1795         // vec3 R = ...;
1796         // L.xz = R.xy;
1797         // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
1798         // our result vector to look like (R.x, L.y, R.y), so we need to select indices
1799         // (3, 1, 4).
1800         SpvId base = fGen.nextId();
1801         fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1802         SpvId shuffle = fGen.nextId();
1803         fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out);
1804         fGen.writeWord(fGen.getType(fBaseType), out);
1805         fGen.writeWord(shuffle, out);
1806         fGen.writeWord(base, out);
1807         fGen.writeWord(value, out);
1808         for (int i = 0; i < fBaseType.columns(); i++) {
1809             // current offset into the virtual vector, defaults to pulling the unmodified
1810             // value from the left side
1811             int offset = i;
1812             // check to see if we are writing this component
1813             for (size_t j = 0; j < fComponents.size(); j++) {
1814                 if (fComponents[j] == i) {
1815                     // we're writing to this component, so adjust the offset to pull from
1816                     // the correct component of the right side instead of preserving the
1817                     // value from the left
1818                     offset = (int) (j + fBaseType.columns());
1819                     break;
1820                 }
1821             }
1822             fGen.writeWord(offset, out);
1823         }
1824         fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
1825     }
1826 
1827 private:
1828     SPIRVCodeGenerator& fGen;
1829     const SpvId fVecPointer;
1830     const std::vector<int>& fComponents;
1831     const Type& fBaseType;
1832     const Type& fSwizzleType;
1833 };
1834 
getLValue(const Expression & expr,OutputStream & out)1835 std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
1836                                                                           OutputStream& out) {
1837     switch (expr.fKind) {
1838         case Expression::kVariableReference_Kind: {
1839             const Variable& var = ((VariableReference&) expr).fVariable;
1840             auto entry = fVariableMap.find(&var);
1841             ASSERT(entry != fVariableMap.end());
1842             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1843                                                                        *this,
1844                                                                        entry->second,
1845                                                                        this->getType(expr.fType)));
1846         }
1847         case Expression::kIndex_Kind: // fall through
1848         case Expression::kFieldAccess_Kind: {
1849             std::vector<SpvId> chain = this->getAccessChain(expr, out);
1850             SpvId member = this->nextId();
1851             this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
1852             this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out);
1853             this->writeWord(member, out);
1854             for (SpvId idx : chain) {
1855                 this->writeWord(idx, out);
1856             }
1857             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1858                                                                        *this,
1859                                                                        member,
1860                                                                        this->getType(expr.fType)));
1861         }
1862 
1863         case Expression::kSwizzle_Kind: {
1864             Swizzle& swizzle = (Swizzle&) expr;
1865             size_t count = swizzle.fComponents.size();
1866             SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer();
1867             ASSERT(base);
1868             if (count == 1) {
1869                 IntLiteral index(fContext, Position(), swizzle.fComponents[0]);
1870                 SpvId member = this->nextId();
1871                 this->writeInstruction(SpvOpAccessChain,
1872                                        this->getPointerType(swizzle.fType,
1873                                                             get_storage_class(*swizzle.fBase)),
1874                                        member,
1875                                        base,
1876                                        this->writeIntLiteral(index),
1877                                        out);
1878                 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1879                                                                        *this,
1880                                                                        member,
1881                                                                        this->getType(expr.fType)));
1882             } else {
1883                 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue(
1884                                                                               *this,
1885                                                                               base,
1886                                                                               swizzle.fComponents,
1887                                                                               swizzle.fBase->fType,
1888                                                                               expr.fType));
1889             }
1890         }
1891 
1892         default:
1893             // expr isn't actually an lvalue, create a dummy variable for it. This case happens due
1894             // to the need to store values in temporary variables during function calls (see
1895             // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been
1896             // caught by IRGenerator
1897             SpvId result = this->nextId();
1898             SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction);
1899             this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction,
1900                                    fVariableBuffer);
1901             this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
1902             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1903                                                                        *this,
1904                                                                        result,
1905                                                                        this->getType(expr.fType)));
1906     }
1907 }
1908 
writeVariableReference(const VariableReference & ref,OutputStream & out)1909 SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
1910     SpvId result = this->nextId();
1911     auto entry = fVariableMap.find(&ref.fVariable);
1912     ASSERT(entry != fVariableMap.end());
1913     SpvId var = entry->second;
1914     this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out);
1915     if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN &&
1916         fProgram.fSettings.fFlipY) {
1917         // need to remap to a top-left coordinate system
1918         if (fRTHeightStructId == (SpvId) -1) {
1919             // height variable hasn't been written yet
1920             std::shared_ptr<SymbolTable> st(new SymbolTable(&fErrors));
1921             ASSERT(fRTHeightFieldIndex == (SpvId) -1);
1922             std::vector<Type::Field> fields;
1923             fields.emplace_back(Modifiers(), String(SKSL_RTHEIGHT_NAME),
1924                                 fContext.fFloat_Type.get());
1925             String name("sksl_synthetic_uniforms");
1926             Type intfStruct(Position(), name, fields);
1927             Layout layout(-1, -1, 1, -1, -1, -1, -1, false, false, false,
1928                           Layout::Format::kUnspecified, false, Layout::kUnspecified_Primitive, -1,
1929                           -1, "", Layout::kNo_Key);
1930             Variable* intfVar = new Variable(Position(),
1931                                              Modifiers(layout, Modifiers::kUniform_Flag),
1932                                              name,
1933                                              intfStruct,
1934                                              Variable::kGlobal_Storage);
1935             fSynthetics.takeOwnership(intfVar);
1936             InterfaceBlock intf(Position(), intfVar, name, String(""),
1937                                 std::vector<std::unique_ptr<Expression>>(), st);
1938             fRTHeightStructId = this->writeInterfaceBlock(intf);
1939             fRTHeightFieldIndex = 0;
1940         }
1941         ASSERT(fRTHeightFieldIndex != (SpvId) -1);
1942         // write vec4(gl_FragCoord.x, u_skRTHeight - gl_FragCoord.y, 0.0, 1.0)
1943         SpvId xId = this->nextId();
1944         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), xId,
1945                                result, 0, out);
1946         IntLiteral fieldIndex(fContext, Position(), fRTHeightFieldIndex);
1947         SpvId fieldIndexId = this->writeIntLiteral(fieldIndex);
1948         SpvId heightPtr = this->nextId();
1949         this->writeOpCode(SpvOpAccessChain, 5, out);
1950         this->writeWord(this->getPointerType(*fContext.fFloat_Type, SpvStorageClassUniform), out);
1951         this->writeWord(heightPtr, out);
1952         this->writeWord(fRTHeightStructId, out);
1953         this->writeWord(fieldIndexId, out);
1954         SpvId heightRead = this->nextId();
1955         this->writeInstruction(SpvOpLoad, this->getType(*fContext.fFloat_Type), heightRead,
1956                                heightPtr, out);
1957         SpvId rawYId = this->nextId();
1958         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), rawYId,
1959                                result, 1, out);
1960         SpvId flippedYId = this->nextId();
1961         this->writeInstruction(SpvOpFSub, this->getType(*fContext.fFloat_Type), flippedYId,
1962                                heightRead, rawYId, out);
1963         FloatLiteral zero(fContext, Position(), 0.0);
1964         SpvId zeroId = writeFloatLiteral(zero);
1965         FloatLiteral one(fContext, Position(), 1.0);
1966         SpvId oneId = writeFloatLiteral(one);
1967         SpvId flipped = this->nextId();
1968         this->writeOpCode(SpvOpCompositeConstruct, 7, out);
1969         this->writeWord(this->getType(*fContext.fVec4_Type), out);
1970         this->writeWord(flipped, out);
1971         this->writeWord(xId, out);
1972         this->writeWord(flippedYId, out);
1973         this->writeWord(zeroId, out);
1974         this->writeWord(oneId, out);
1975         return flipped;
1976     }
1977     return result;
1978 }
1979 
writeIndexExpression(const IndexExpression & expr,OutputStream & out)1980 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
1981     return getLValue(expr, out)->load(out);
1982 }
1983 
writeFieldAccess(const FieldAccess & f,OutputStream & out)1984 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
1985     return getLValue(f, out)->load(out);
1986 }
1987 
writeSwizzle(const Swizzle & swizzle,OutputStream & out)1988 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
1989     SpvId base = this->writeExpression(*swizzle.fBase, out);
1990     SpvId result = this->nextId();
1991     size_t count = swizzle.fComponents.size();
1992     if (count == 1) {
1993         this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base,
1994                                swizzle.fComponents[0], out);
1995     } else {
1996         this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
1997         this->writeWord(this->getType(swizzle.fType), out);
1998         this->writeWord(result, out);
1999         this->writeWord(base, out);
2000         this->writeWord(base, out);
2001         for (int component : swizzle.fComponents) {
2002             this->writeWord(component, out);
2003         }
2004     }
2005     return result;
2006 }
2007 
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)2008 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
2009                                                const Type& operandType, SpvId lhs,
2010                                                SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
2011                                                SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
2012     SpvId result = this->nextId();
2013     if (is_float(fContext, operandType)) {
2014         this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
2015     } else if (is_signed(fContext, operandType)) {
2016         this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
2017     } else if (is_unsigned(fContext, operandType)) {
2018         this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
2019     } else if (operandType == *fContext.fBool_Type) {
2020         this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
2021     } else {
2022         ABORT("invalid operandType: %s", operandType.description().c_str());
2023     }
2024     return result;
2025 }
2026 
is_assignment(Token::Kind op)2027 bool is_assignment(Token::Kind op) {
2028     switch (op) {
2029         case Token::EQ:           // fall through
2030         case Token::PLUSEQ:       // fall through
2031         case Token::MINUSEQ:      // fall through
2032         case Token::STAREQ:       // fall through
2033         case Token::SLASHEQ:      // fall through
2034         case Token::PERCENTEQ:    // fall through
2035         case Token::SHLEQ:        // fall through
2036         case Token::SHREQ:        // fall through
2037         case Token::BITWISEOREQ:  // fall through
2038         case Token::BITWISEXOREQ: // fall through
2039         case Token::BITWISEANDEQ: // fall through
2040         case Token::LOGICALOREQ:  // fall through
2041         case Token::LOGICALXOREQ: // fall through
2042         case Token::LOGICALANDEQ:
2043             return true;
2044         default:
2045             return false;
2046     }
2047 }
2048 
foldToBool(SpvId id,const Type & operandType,OutputStream & out)2049 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, OutputStream& out) {
2050     if (operandType.kind() == Type::kVector_Kind) {
2051         SpvId result = this->nextId();
2052         this->writeInstruction(SpvOpAll, this->getType(*fContext.fBool_Type), result, id, out);
2053         return result;
2054     }
2055     return id;
2056 }
2057 
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,OutputStream & out)2058 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
2059                                                 SpvOp_ floatOperator, SpvOp_ intOperator,
2060                                                 OutputStream& out) {
2061     SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator;
2062     ASSERT(operandType.kind() == Type::kMatrix_Kind);
2063     SpvId rowType = this->getType(operandType.componentType().toCompound(fContext,
2064                                                                          operandType.columns(),
2065                                                                          1));
2066     SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext,
2067                                                                     operandType.columns(),
2068                                                                     1));
2069     SpvId boolType = this->getType(*fContext.fBool_Type);
2070     SpvId result = 0;
2071     for (int i = 0; i < operandType.rows(); i++) {
2072         SpvId rowL = this->nextId();
2073         this->writeInstruction(SpvOpCompositeExtract, rowType, rowL, lhs, 0, out);
2074         SpvId rowR = this->nextId();
2075         this->writeInstruction(SpvOpCompositeExtract, rowType, rowR, rhs, 0, out);
2076         SpvId compare = this->nextId();
2077         this->writeInstruction(compareOp, bvecType, compare, rowL, rowR, out);
2078         SpvId all = this->nextId();
2079         this->writeInstruction(SpvOpAll, boolType, all, compare, out);
2080         if (result != 0) {
2081             SpvId next = this->nextId();
2082             this->writeInstruction(SpvOpLogicalAnd, boolType, next, result, all, out);
2083             result = next;
2084         }
2085         else {
2086             result = all;
2087         }
2088     }
2089     return result;
2090 }
2091 
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)2092 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
2093     // handle cases where we don't necessarily evaluate both LHS and RHS
2094     switch (b.fOperator) {
2095         case Token::EQ: {
2096             SpvId rhs = this->writeExpression(*b.fRight, out);
2097             this->getLValue(*b.fLeft, out)->store(rhs, out);
2098             return rhs;
2099         }
2100         case Token::LOGICALAND:
2101             return this->writeLogicalAnd(b, out);
2102         case Token::LOGICALOR:
2103             return this->writeLogicalOr(b, out);
2104         default:
2105             break;
2106     }
2107 
2108     // "normal" operators
2109     const Type& resultType = b.fType;
2110     std::unique_ptr<LValue> lvalue;
2111     SpvId lhs;
2112     if (is_assignment(b.fOperator)) {
2113         lvalue = this->getLValue(*b.fLeft, out);
2114         lhs = lvalue->load(out);
2115     } else {
2116         lvalue = nullptr;
2117         lhs = this->writeExpression(*b.fLeft, out);
2118     }
2119     SpvId rhs = this->writeExpression(*b.fRight, out);
2120     if (b.fOperator == Token::COMMA) {
2121         return rhs;
2122     }
2123     // component type we are operating on: float, int, uint
2124     const Type* operandType;
2125     // IR allows mismatched types in expressions (e.g. vec2 * float), but they need special handling
2126     // in SPIR-V
2127     if (b.fLeft->fType != b.fRight->fType) {
2128         if (b.fLeft->fType.kind() == Type::kVector_Kind &&
2129             b.fRight->fType.isNumber()) {
2130             // promote number to vector
2131             SpvId vec = this->nextId();
2132             this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
2133             this->writeWord(this->getType(resultType), out);
2134             this->writeWord(vec, out);
2135             for (int i = 0; i < resultType.columns(); i++) {
2136                 this->writeWord(rhs, out);
2137             }
2138             rhs = vec;
2139             operandType = &b.fRight->fType;
2140         } else if (b.fRight->fType.kind() == Type::kVector_Kind &&
2141                    b.fLeft->fType.isNumber()) {
2142             // promote number to vector
2143             SpvId vec = this->nextId();
2144             this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
2145             this->writeWord(this->getType(resultType), out);
2146             this->writeWord(vec, out);
2147             for (int i = 0; i < resultType.columns(); i++) {
2148                 this->writeWord(lhs, out);
2149             }
2150             lhs = vec;
2151             ASSERT(!lvalue);
2152             operandType = &b.fLeft->fType;
2153         } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) {
2154             SpvOp_ op;
2155             if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
2156                 op = SpvOpMatrixTimesMatrix;
2157             } else if (b.fRight->fType.kind() == Type::kVector_Kind) {
2158                 op = SpvOpMatrixTimesVector;
2159             } else {
2160                 ASSERT(b.fRight->fType.kind() == Type::kScalar_Kind);
2161                 op = SpvOpMatrixTimesScalar;
2162             }
2163             SpvId result = this->nextId();
2164             this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out);
2165             if (b.fOperator == Token::STAREQ) {
2166                 lvalue->store(result, out);
2167             } else {
2168                 ASSERT(b.fOperator == Token::STAR);
2169             }
2170             return result;
2171         } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
2172             SpvId result = this->nextId();
2173             if (b.fLeft->fType.kind() == Type::kVector_Kind) {
2174                 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result,
2175                                        lhs, rhs, out);
2176             } else {
2177                 ASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind);
2178                 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs,
2179                                        lhs, out);
2180             }
2181             if (b.fOperator == Token::STAREQ) {
2182                 lvalue->store(result, out);
2183             } else {
2184                 ASSERT(b.fOperator == Token::STAR);
2185             }
2186             return result;
2187         } else {
2188             ABORT("unsupported binary expression: %s", b.description().c_str());
2189         }
2190     } else {
2191         operandType = &b.fLeft->fType;
2192         ASSERT(*operandType == b.fRight->fType);
2193     }
2194     switch (b.fOperator) {
2195         case Token::EQEQ: {
2196             if (operandType->kind() == Type::kMatrix_Kind) {
2197                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
2198                                                    SpvOpIEqual, out);
2199             }
2200             ASSERT(resultType == *fContext.fBool_Type);
2201             return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2202                                                                SpvOpFOrdEqual, SpvOpIEqual,
2203                                                                SpvOpIEqual, SpvOpLogicalEqual, out),
2204                                     *operandType, out);
2205         }
2206         case Token::NEQ:
2207             if (operandType->kind() == Type::kMatrix_Kind) {
2208                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
2209                                                    SpvOpINotEqual, out);
2210             }
2211             ASSERT(resultType == *fContext.fBool_Type);
2212             return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2213                                                                SpvOpFOrdNotEqual, SpvOpINotEqual,
2214                                                                SpvOpINotEqual, SpvOpLogicalNotEqual,
2215                                                                out),
2216                                     *operandType, out);
2217         case Token::GT:
2218             ASSERT(resultType == *fContext.fBool_Type);
2219             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2220                                               SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
2221                                               SpvOpUGreaterThan, SpvOpUndef, out);
2222         case Token::LT:
2223             ASSERT(resultType == *fContext.fBool_Type);
2224             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
2225                                               SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
2226         case Token::GTEQ:
2227             ASSERT(resultType == *fContext.fBool_Type);
2228             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2229                                               SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
2230                                               SpvOpUGreaterThanEqual, SpvOpUndef, out);
2231         case Token::LTEQ:
2232             ASSERT(resultType == *fContext.fBool_Type);
2233             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2234                                               SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
2235                                               SpvOpULessThanEqual, SpvOpUndef, out);
2236         case Token::PLUS:
2237             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2238                                               SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2239         case Token::MINUS:
2240             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2241                                               SpvOpISub, SpvOpISub, SpvOpUndef, out);
2242         case Token::STAR:
2243             if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2244                 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2245                 // matrix multiply
2246                 SpvId result = this->nextId();
2247                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2248                                        lhs, rhs, out);
2249                 return result;
2250             }
2251             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2252                                               SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2253         case Token::SLASH:
2254             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2255                                               SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2256         case Token::PERCENT:
2257             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2258                                               SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2259         case Token::SHL:
2260             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2261                                               SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
2262                                               SpvOpUndef, out);
2263         case Token::SHR:
2264             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2265                                               SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
2266                                               SpvOpUndef, out);
2267         case Token::BITWISEAND:
2268             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2269                                               SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
2270         case Token::BITWISEOR:
2271             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2272                                               SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
2273         case Token::BITWISEXOR:
2274             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2275                                               SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
2276         case Token::PLUSEQ: {
2277             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2278                                                       SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2279             ASSERT(lvalue);
2280             lvalue->store(result, out);
2281             return result;
2282         }
2283         case Token::MINUSEQ: {
2284             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2285                                                       SpvOpISub, SpvOpISub, SpvOpUndef, out);
2286             ASSERT(lvalue);
2287             lvalue->store(result, out);
2288             return result;
2289         }
2290         case Token::STAREQ: {
2291             if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2292                 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2293                 // matrix multiply
2294                 SpvId result = this->nextId();
2295                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2296                                        lhs, rhs, out);
2297                 ASSERT(lvalue);
2298                 lvalue->store(result, out);
2299                 return result;
2300             }
2301             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2302                                                       SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2303             ASSERT(lvalue);
2304             lvalue->store(result, out);
2305             return result;
2306         }
2307         case Token::SLASHEQ: {
2308             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2309                                                       SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2310             ASSERT(lvalue);
2311             lvalue->store(result, out);
2312             return result;
2313         }
2314         case Token::PERCENTEQ: {
2315             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2316                                                       SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2317             ASSERT(lvalue);
2318             lvalue->store(result, out);
2319             return result;
2320         }
2321         case Token::SHLEQ: {
2322             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2323                                                       SpvOpUndef, SpvOpShiftLeftLogical,
2324                                                       SpvOpShiftLeftLogical, SpvOpUndef, out);
2325             ASSERT(lvalue);
2326             lvalue->store(result, out);
2327             return result;
2328         }
2329         case Token::SHREQ: {
2330             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2331                                                       SpvOpUndef, SpvOpShiftRightArithmetic,
2332                                                       SpvOpShiftRightLogical, SpvOpUndef, out);
2333             ASSERT(lvalue);
2334             lvalue->store(result, out);
2335             return result;
2336         }
2337         case Token::BITWISEANDEQ: {
2338             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2339                                                       SpvOpUndef, SpvOpBitwiseAnd, SpvOpBitwiseAnd,
2340                                                       SpvOpUndef, out);
2341             ASSERT(lvalue);
2342             lvalue->store(result, out);
2343             return result;
2344         }
2345         case Token::BITWISEOREQ: {
2346             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2347                                                       SpvOpUndef, SpvOpBitwiseOr, SpvOpBitwiseOr,
2348                                                       SpvOpUndef, out);
2349             ASSERT(lvalue);
2350             lvalue->store(result, out);
2351             return result;
2352         }
2353         case Token::BITWISEXOREQ: {
2354             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2355                                                       SpvOpUndef, SpvOpBitwiseXor, SpvOpBitwiseXor,
2356                                                       SpvOpUndef, out);
2357             ASSERT(lvalue);
2358             lvalue->store(result, out);
2359             return result;
2360         }
2361         default:
2362             ABORT("unsupported binary expression: %s", b.description().c_str());
2363     }
2364 }
2365 
writeLogicalAnd(const BinaryExpression & a,OutputStream & out)2366 SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) {
2367     ASSERT(a.fOperator == Token::LOGICALAND);
2368     BoolLiteral falseLiteral(fContext, Position(), false);
2369     SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
2370     SpvId lhs = this->writeExpression(*a.fLeft, out);
2371     SpvId rhsLabel = this->nextId();
2372     SpvId end = this->nextId();
2373     SpvId lhsBlock = fCurrentBlock;
2374     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2375     this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
2376     this->writeLabel(rhsLabel, out);
2377     SpvId rhs = this->writeExpression(*a.fRight, out);
2378     SpvId rhsBlock = fCurrentBlock;
2379     this->writeInstruction(SpvOpBranch, end, out);
2380     this->writeLabel(end, out);
2381     SpvId result = this->nextId();
2382     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant,
2383                            lhsBlock, rhs, rhsBlock, out);
2384     return result;
2385 }
2386 
writeLogicalOr(const BinaryExpression & o,OutputStream & out)2387 SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, OutputStream& out) {
2388     ASSERT(o.fOperator == Token::LOGICALOR);
2389     BoolLiteral trueLiteral(fContext, Position(), true);
2390     SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
2391     SpvId lhs = this->writeExpression(*o.fLeft, out);
2392     SpvId rhsLabel = this->nextId();
2393     SpvId end = this->nextId();
2394     SpvId lhsBlock = fCurrentBlock;
2395     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2396     this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
2397     this->writeLabel(rhsLabel, out);
2398     SpvId rhs = this->writeExpression(*o.fRight, out);
2399     SpvId rhsBlock = fCurrentBlock;
2400     this->writeInstruction(SpvOpBranch, end, out);
2401     this->writeLabel(end, out);
2402     SpvId result = this->nextId();
2403     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant,
2404                            lhsBlock, rhs, rhsBlock, out);
2405     return result;
2406 }
2407 
writeTernaryExpression(const TernaryExpression & t,OutputStream & out)2408 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
2409     SpvId test = this->writeExpression(*t.fTest, out);
2410     if (t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) {
2411         // both true and false are constants, can just use OpSelect
2412         SpvId result = this->nextId();
2413         SpvId trueId = this->writeExpression(*t.fIfTrue, out);
2414         SpvId falseId = this->writeExpression(*t.fIfFalse, out);
2415         this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId,
2416                                out);
2417         return result;
2418     }
2419     // was originally using OpPhi to choose the result, but for some reason that is crashing on
2420     // Adreno. Switched to storing the result in a temp variable as glslang does.
2421     SpvId var = this->nextId();
2422     this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction),
2423                            var, SpvStorageClassFunction, fVariableBuffer);
2424     SpvId trueLabel = this->nextId();
2425     SpvId falseLabel = this->nextId();
2426     SpvId end = this->nextId();
2427     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2428     this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
2429     this->writeLabel(trueLabel, out);
2430     this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out);
2431     this->writeInstruction(SpvOpBranch, end, out);
2432     this->writeLabel(falseLabel, out);
2433     this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out);
2434     this->writeInstruction(SpvOpBranch, end, out);
2435     this->writeLabel(end, out);
2436     SpvId result = this->nextId();
2437     this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out);
2438     return result;
2439 }
2440 
create_literal_1(const Context & context,const Type & type)2441 std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
2442     if (type == *context.fInt_Type) {
2443         return std::unique_ptr<Expression>(new IntLiteral(context, Position(), 1));
2444     }
2445     else if (type == *context.fFloat_Type) {
2446         return std::unique_ptr<Expression>(new FloatLiteral(context, Position(), 1.0));
2447     } else {
2448         ABORT("math is unsupported on type '%s'", type.name().c_str());
2449     }
2450 }
2451 
writePrefixExpression(const PrefixExpression & p,OutputStream & out)2452 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
2453     if (p.fOperator == Token::MINUS) {
2454         SpvId result = this->nextId();
2455         SpvId typeId = this->getType(p.fType);
2456         SpvId expr = this->writeExpression(*p.fOperand, out);
2457         if (is_float(fContext, p.fType)) {
2458             this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
2459         } else if (is_signed(fContext, p.fType)) {
2460             this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
2461         } else {
2462             ABORT("unsupported prefix expression %s", p.description().c_str());
2463         };
2464         return result;
2465     }
2466     switch (p.fOperator) {
2467         case Token::PLUS:
2468             return this->writeExpression(*p.fOperand, out);
2469         case Token::PLUSPLUS: {
2470             std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2471             SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2472             SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2473                                                       SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
2474                                                       out);
2475             lv->store(result, out);
2476             return result;
2477         }
2478         case Token::MINUSMINUS: {
2479             std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2480             SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2481             SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2482                                                       SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef,
2483                                                       out);
2484             lv->store(result, out);
2485             return result;
2486         }
2487         case Token::LOGICALNOT: {
2488             ASSERT(p.fOperand->fType == *fContext.fBool_Type);
2489             SpvId result = this->nextId();
2490             this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result,
2491                                    this->writeExpression(*p.fOperand, out), out);
2492             return result;
2493         }
2494         case Token::BITWISENOT: {
2495             SpvId result = this->nextId();
2496             this->writeInstruction(SpvOpNot, this->getType(p.fOperand->fType), result,
2497                                    this->writeExpression(*p.fOperand, out), out);
2498             return result;
2499         }
2500         default:
2501             ABORT("unsupported prefix expression: %s", p.description().c_str());
2502     }
2503 }
2504 
writePostfixExpression(const PostfixExpression & p,OutputStream & out)2505 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
2506     std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2507     SpvId result = lv->load(out);
2508     SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2509     switch (p.fOperator) {
2510         case Token::PLUSPLUS: {
2511             SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd,
2512                                                     SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2513             lv->store(temp, out);
2514             return result;
2515         }
2516         case Token::MINUSMINUS: {
2517             SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub,
2518                                                     SpvOpISub, SpvOpISub, SpvOpUndef, out);
2519             lv->store(temp, out);
2520             return result;
2521         }
2522         default:
2523             ABORT("unsupported postfix expression %s", p.description().c_str());
2524     }
2525 }
2526 
writeBoolLiteral(const BoolLiteral & b)2527 SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
2528     if (b.fValue) {
2529         if (fBoolTrue == 0) {
2530             fBoolTrue = this->nextId();
2531             this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue,
2532                                    fConstantBuffer);
2533         }
2534         return fBoolTrue;
2535     } else {
2536         if (fBoolFalse == 0) {
2537             fBoolFalse = this->nextId();
2538             this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse,
2539                                    fConstantBuffer);
2540         }
2541         return fBoolFalse;
2542     }
2543 }
2544 
writeIntLiteral(const IntLiteral & i)2545 SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) {
2546     if (i.fType == *fContext.fInt_Type) {
2547         auto entry = fIntConstants.find(i.fValue);
2548         if (entry == fIntConstants.end()) {
2549             SpvId result = this->nextId();
2550             this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2551                                    fConstantBuffer);
2552             fIntConstants[i.fValue] = result;
2553             return result;
2554         }
2555         return entry->second;
2556     } else {
2557         ASSERT(i.fType == *fContext.fUInt_Type);
2558         auto entry = fUIntConstants.find(i.fValue);
2559         if (entry == fUIntConstants.end()) {
2560             SpvId result = this->nextId();
2561             this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2562                                    fConstantBuffer);
2563             fUIntConstants[i.fValue] = result;
2564             return result;
2565         }
2566         return entry->second;
2567     }
2568 }
2569 
writeFloatLiteral(const FloatLiteral & f)2570 SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
2571     if (f.fType == *fContext.fFloat_Type) {
2572         float value = (float) f.fValue;
2573         auto entry = fFloatConstants.find(value);
2574         if (entry == fFloatConstants.end()) {
2575             SpvId result = this->nextId();
2576             uint32_t bits;
2577             ASSERT(sizeof(bits) == sizeof(value));
2578             memcpy(&bits, &value, sizeof(bits));
2579             this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits,
2580                                    fConstantBuffer);
2581             fFloatConstants[value] = result;
2582             return result;
2583         }
2584         return entry->second;
2585     } else {
2586         ASSERT(f.fType == *fContext.fDouble_Type);
2587         auto entry = fDoubleConstants.find(f.fValue);
2588         if (entry == fDoubleConstants.end()) {
2589             SpvId result = this->nextId();
2590             uint64_t bits;
2591             ASSERT(sizeof(bits) == sizeof(f.fValue));
2592             memcpy(&bits, &f.fValue, sizeof(bits));
2593             this->writeInstruction(SpvOpConstant, this->getType(f.fType), result,
2594                                    bits & 0xffffffff, bits >> 32, fConstantBuffer);
2595             fDoubleConstants[f.fValue] = result;
2596             return result;
2597         }
2598         return entry->second;
2599     }
2600 }
2601 
writeFunctionStart(const FunctionDeclaration & f,OutputStream & out)2602 SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
2603     SpvId result = fFunctionMap[&f];
2604     this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result,
2605                            SpvFunctionControlMaskNone, this->getFunctionType(f), out);
2606     this->writeInstruction(SpvOpName, result, f.fName.c_str(), fNameBuffer);
2607     for (size_t i = 0; i < f.fParameters.size(); i++) {
2608         SpvId id = this->nextId();
2609         fVariableMap[f.fParameters[i]] = id;
2610         SpvId type;
2611         type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction);
2612         this->writeInstruction(SpvOpFunctionParameter, type, id, out);
2613     }
2614     return result;
2615 }
2616 
writeFunction(const FunctionDefinition & f,OutputStream & out)2617 SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
2618     fVariableBuffer.reset();
2619     SpvId result = this->writeFunctionStart(f.fDeclaration, out);
2620     this->writeLabel(this->nextId(), out);
2621     if (f.fDeclaration.fName == "main") {
2622         write_stringstream(fGlobalInitializersBuffer, out);
2623     }
2624     StringStream bodyBuffer;
2625     this->writeBlock((Block&) *f.fBody, bodyBuffer);
2626     write_stringstream(fVariableBuffer, out);
2627     write_stringstream(bodyBuffer, out);
2628     if (fCurrentBlock) {
2629         this->writeInstruction(SpvOpReturn, out);
2630     }
2631     this->writeInstruction(SpvOpFunctionEnd, out);
2632     return result;
2633 }
2634 
writeLayout(const Layout & layout,SpvId target)2635 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) {
2636     if (layout.fLocation >= 0) {
2637         this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
2638                                fDecorationBuffer);
2639     }
2640     if (layout.fBinding >= 0) {
2641         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
2642                                fDecorationBuffer);
2643     }
2644     if (layout.fIndex >= 0) {
2645         this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
2646                                fDecorationBuffer);
2647     }
2648     if (layout.fSet >= 0) {
2649         this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
2650                                fDecorationBuffer);
2651     }
2652     if (layout.fInputAttachmentIndex >= 0) {
2653         this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
2654                                layout.fInputAttachmentIndex, fDecorationBuffer);
2655     }
2656     if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN) {
2657         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
2658                                fDecorationBuffer);
2659     }
2660 }
2661 
writeLayout(const Layout & layout,SpvId target,int member)2662 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) {
2663     if (layout.fLocation >= 0) {
2664         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
2665                                layout.fLocation, fDecorationBuffer);
2666     }
2667     if (layout.fBinding >= 0) {
2668         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding,
2669                                layout.fBinding, fDecorationBuffer);
2670     }
2671     if (layout.fIndex >= 0) {
2672         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
2673                                layout.fIndex, fDecorationBuffer);
2674     }
2675     if (layout.fSet >= 0) {
2676         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet,
2677                                layout.fSet, fDecorationBuffer);
2678     }
2679     if (layout.fInputAttachmentIndex >= 0) {
2680         this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
2681                                layout.fInputAttachmentIndex, fDecorationBuffer);
2682     }
2683     if (layout.fBuiltin >= 0) {
2684         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
2685                                layout.fBuiltin, fDecorationBuffer);
2686     }
2687 }
2688 
writeInterfaceBlock(const InterfaceBlock & intf)2689 SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2690     bool isBuffer = (0 != (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag));
2691     MemoryLayout layout = (intf.fVariable.fModifiers.fLayout.fPushConstant || isBuffer) ?
2692                           MemoryLayout(MemoryLayout::k430_Standard) :
2693                           fDefaultLayout;
2694     SpvId result = this->nextId();
2695     const Type* type = &intf.fVariable.fType;
2696     if (fProgram.fInputs.fRTHeight) {
2697         ASSERT(fRTHeightStructId == (SpvId) -1);
2698         ASSERT(fRTHeightFieldIndex == (SpvId) -1);
2699         std::vector<Type::Field> fields = type->fields();
2700         fRTHeightStructId = result;
2701         fRTHeightFieldIndex = fields.size();
2702         fields.emplace_back(Modifiers(), String(SKSL_RTHEIGHT_NAME), fContext.fFloat_Type.get());
2703         type = new Type(type->fPosition, type->name(), fields);
2704     }
2705     SpvId typeId = this->getType(*type, layout);
2706     if (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag) {
2707         this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBufferBlock, fDecorationBuffer);
2708     } else {
2709         this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
2710     }
2711     SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers);
2712     SpvId ptrType = this->nextId();
2713     this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
2714     this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
2715     this->writeLayout(intf.fVariable.fModifiers.fLayout, result);
2716     fVariableMap[&intf.fVariable] = result;
2717     if (fProgram.fInputs.fRTHeight) {
2718         delete type;
2719     }
2720     return result;
2721 }
2722 
writePrecisionModifier(const Modifiers & modifiers,SpvId id)2723 void SPIRVCodeGenerator::writePrecisionModifier(const Modifiers& modifiers, SpvId id) {
2724     if ((modifiers.fFlags & Modifiers::kLowp_Flag) |
2725         (modifiers.fFlags & Modifiers::kMediump_Flag)) {
2726         this->writeInstruction(SpvOpDecorate, id, SpvDecorationRelaxedPrecision, fDecorationBuffer);
2727     }
2728 }
2729 
2730 #define BUILTIN_IGNORE 9999
writeGlobalVars(Program::Kind kind,const VarDeclarations & decl,OutputStream & out)2731 void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclarations& decl,
2732                                          OutputStream& out) {
2733     for (size_t i = 0; i < decl.fVars.size(); i++) {
2734         if (decl.fVars[i]->fKind == Statement::kNop_Kind) {
2735             continue;
2736         }
2737         const VarDeclaration& varDecl = (VarDeclaration&) *decl.fVars[i];
2738         const Variable* var = varDecl.fVar;
2739         // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2740         // in the OpenGL backend.
2741         ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2742                                            Modifiers::kWriteOnly_Flag |
2743                                            Modifiers::kCoherent_Flag |
2744                                            Modifiers::kVolatile_Flag |
2745                                            Modifiers::kRestrict_Flag)));
2746         if (var->fModifiers.fLayout.fBuiltin == BUILTIN_IGNORE) {
2747             continue;
2748         }
2749         if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
2750             kind != Program::kFragment_Kind) {
2751             continue;
2752         }
2753         if (!var->fReadCount && !var->fWriteCount &&
2754                 !(var->fModifiers.fFlags & (Modifiers::kIn_Flag |
2755                                             Modifiers::kOut_Flag |
2756                                             Modifiers::kUniform_Flag |
2757                                             Modifiers::kBuffer_Flag))) {
2758             // variable is dead and not an input / output var (the Vulkan debug layers complain if
2759             // we elide an interface var, even if it's dead)
2760             continue;
2761         }
2762         SpvStorageClass_ storageClass;
2763         if (var->fModifiers.fFlags & Modifiers::kIn_Flag) {
2764             storageClass = SpvStorageClassInput;
2765         } else if (var->fModifiers.fFlags & Modifiers::kOut_Flag) {
2766             storageClass = SpvStorageClassOutput;
2767         } else if (var->fModifiers.fFlags & Modifiers::kUniform_Flag) {
2768             if (var->fType.kind() == Type::kSampler_Kind) {
2769                 storageClass = SpvStorageClassUniformConstant;
2770             } else {
2771                 storageClass = SpvStorageClassUniform;
2772             }
2773         } else {
2774             storageClass = SpvStorageClassPrivate;
2775         }
2776         SpvId id = this->nextId();
2777         fVariableMap[var] = id;
2778         SpvId type = this->getPointerType(var->fType, storageClass);
2779         this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer);
2780         this->writeInstruction(SpvOpName, id, var->fName.c_str(), fNameBuffer);
2781         this->writePrecisionModifier(var->fModifiers, id);
2782         if (var->fType.kind() == Type::kMatrix_Kind) {
2783             this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationColMajor,
2784                                    fDecorationBuffer);
2785             this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationMatrixStride,
2786                                    (SpvId) fDefaultLayout.stride(var->fType), fDecorationBuffer);
2787         }
2788         if (varDecl.fValue) {
2789             ASSERT(!fCurrentBlock);
2790             fCurrentBlock = -1;
2791             SpvId value = this->writeExpression(*varDecl.fValue, fGlobalInitializersBuffer);
2792             this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
2793             fCurrentBlock = 0;
2794         }
2795         this->writeLayout(var->fModifiers.fLayout, id);
2796     }
2797 }
2798 
writeVarDeclarations(const VarDeclarations & decl,OutputStream & out)2799 void SPIRVCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, OutputStream& out) {
2800     for (const auto& stmt : decl.fVars) {
2801         ASSERT(stmt->fKind == Statement::kVarDeclaration_Kind);
2802         VarDeclaration& varDecl = (VarDeclaration&) *stmt;
2803         const Variable* var = varDecl.fVar;
2804         // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2805         // in the OpenGL backend.
2806         ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2807                                            Modifiers::kWriteOnly_Flag |
2808                                            Modifiers::kCoherent_Flag |
2809                                            Modifiers::kVolatile_Flag |
2810                                            Modifiers::kRestrict_Flag)));
2811         SpvId id = this->nextId();
2812         fVariableMap[var] = id;
2813         SpvId type = this->getPointerType(var->fType, SpvStorageClassFunction);
2814         this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
2815         this->writeInstruction(SpvOpName, id, var->fName.c_str(), fNameBuffer);
2816         if (varDecl.fValue) {
2817             SpvId value = this->writeExpression(*varDecl.fValue, out);
2818             this->writeInstruction(SpvOpStore, id, value, out);
2819         }
2820     }
2821 }
2822 
writeStatement(const Statement & s,OutputStream & out)2823 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
2824     switch (s.fKind) {
2825         case Statement::kNop_Kind:
2826             break;
2827         case Statement::kBlock_Kind:
2828             this->writeBlock((Block&) s, out);
2829             break;
2830         case Statement::kExpression_Kind:
2831             this->writeExpression(*((ExpressionStatement&) s).fExpression, out);
2832             break;
2833         case Statement::kReturn_Kind:
2834             this->writeReturnStatement((ReturnStatement&) s, out);
2835             break;
2836         case Statement::kVarDeclarations_Kind:
2837             this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, out);
2838             break;
2839         case Statement::kIf_Kind:
2840             this->writeIfStatement((IfStatement&) s, out);
2841             break;
2842         case Statement::kFor_Kind:
2843             this->writeForStatement((ForStatement&) s, out);
2844             break;
2845         case Statement::kWhile_Kind:
2846             this->writeWhileStatement((WhileStatement&) s, out);
2847             break;
2848         case Statement::kDo_Kind:
2849             this->writeDoStatement((DoStatement&) s, out);
2850             break;
2851         case Statement::kBreak_Kind:
2852             this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
2853             break;
2854         case Statement::kContinue_Kind:
2855             this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
2856             break;
2857         case Statement::kDiscard_Kind:
2858             this->writeInstruction(SpvOpKill, out);
2859             break;
2860         default:
2861             ABORT("unsupported statement: %s", s.description().c_str());
2862     }
2863 }
2864 
writeBlock(const Block & b,OutputStream & out)2865 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
2866     for (size_t i = 0; i < b.fStatements.size(); i++) {
2867         this->writeStatement(*b.fStatements[i], out);
2868     }
2869 }
2870 
writeIfStatement(const IfStatement & stmt,OutputStream & out)2871 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
2872     SpvId test = this->writeExpression(*stmt.fTest, out);
2873     SpvId ifTrue = this->nextId();
2874     SpvId ifFalse = this->nextId();
2875     if (stmt.fIfFalse) {
2876         SpvId end = this->nextId();
2877         this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2878         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2879         this->writeLabel(ifTrue, out);
2880         this->writeStatement(*stmt.fIfTrue, out);
2881         if (fCurrentBlock) {
2882             this->writeInstruction(SpvOpBranch, end, out);
2883         }
2884         this->writeLabel(ifFalse, out);
2885         this->writeStatement(*stmt.fIfFalse, out);
2886         if (fCurrentBlock) {
2887             this->writeInstruction(SpvOpBranch, end, out);
2888         }
2889         this->writeLabel(end, out);
2890     } else {
2891         this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
2892         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2893         this->writeLabel(ifTrue, out);
2894         this->writeStatement(*stmt.fIfTrue, out);
2895         if (fCurrentBlock) {
2896             this->writeInstruction(SpvOpBranch, ifFalse, out);
2897         }
2898         this->writeLabel(ifFalse, out);
2899     }
2900 }
2901 
writeForStatement(const ForStatement & f,OutputStream & out)2902 void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
2903     if (f.fInitializer) {
2904         this->writeStatement(*f.fInitializer, out);
2905     }
2906     SpvId header = this->nextId();
2907     SpvId start = this->nextId();
2908     SpvId body = this->nextId();
2909     SpvId next = this->nextId();
2910     fContinueTarget.push(next);
2911     SpvId end = this->nextId();
2912     fBreakTarget.push(end);
2913     this->writeInstruction(SpvOpBranch, header, out);
2914     this->writeLabel(header, out);
2915     this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
2916     this->writeInstruction(SpvOpBranch, start, out);
2917     this->writeLabel(start, out);
2918     if (f.fTest) {
2919         SpvId test = this->writeExpression(*f.fTest, out);
2920         this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2921     }
2922     this->writeLabel(body, out);
2923     this->writeStatement(*f.fStatement, out);
2924     if (fCurrentBlock) {
2925         this->writeInstruction(SpvOpBranch, next, out);
2926     }
2927     this->writeLabel(next, out);
2928     if (f.fNext) {
2929         this->writeExpression(*f.fNext, out);
2930     }
2931     this->writeInstruction(SpvOpBranch, header, out);
2932     this->writeLabel(end, out);
2933     fBreakTarget.pop();
2934     fContinueTarget.pop();
2935 }
2936 
writeWhileStatement(const WhileStatement & w,OutputStream & out)2937 void SPIRVCodeGenerator::writeWhileStatement(const WhileStatement& w, OutputStream& out) {
2938     // We believe the while loop code below will work, but Skia doesn't actually use them and
2939     // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2940     // the time being, we just fail with an error due to the lack of testing. If you encounter this
2941     // message, simply remove the error call below to see whether our while loop support actually
2942     // works.
2943     fErrors.error(w.fPosition, "internal error: while loop support has been disabled in SPIR-V, "
2944                   "see SkSLSPIRVCodeGenerator.cpp for details");
2945 
2946     SpvId header = this->nextId();
2947     SpvId start = this->nextId();
2948     SpvId body = this->nextId();
2949     fContinueTarget.push(start);
2950     SpvId end = this->nextId();
2951     fBreakTarget.push(end);
2952     this->writeInstruction(SpvOpBranch, header, out);
2953     this->writeLabel(header, out);
2954     this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2955     this->writeInstruction(SpvOpBranch, start, out);
2956     this->writeLabel(start, out);
2957     SpvId test = this->writeExpression(*w.fTest, out);
2958     this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2959     this->writeLabel(body, out);
2960     this->writeStatement(*w.fStatement, out);
2961     if (fCurrentBlock) {
2962         this->writeInstruction(SpvOpBranch, start, out);
2963     }
2964     this->writeLabel(end, out);
2965     fBreakTarget.pop();
2966     fContinueTarget.pop();
2967 }
2968 
writeDoStatement(const DoStatement & d,OutputStream & out)2969 void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
2970     // We believe the do loop code below will work, but Skia doesn't actually use them and
2971     // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2972     // the time being, we just fail with an error due to the lack of testing. If you encounter this
2973     // message, simply remove the error call below to see whether our do loop support actually
2974     // works.
2975     fErrors.error(d.fPosition, "internal error: do loop support has been disabled in SPIR-V, see "
2976                   "SkSLSPIRVCodeGenerator.cpp for details");
2977 
2978     SpvId header = this->nextId();
2979     SpvId start = this->nextId();
2980     SpvId next = this->nextId();
2981     fContinueTarget.push(next);
2982     SpvId end = this->nextId();
2983     fBreakTarget.push(end);
2984     this->writeInstruction(SpvOpBranch, header, out);
2985     this->writeLabel(header, out);
2986     this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2987     this->writeInstruction(SpvOpBranch, start, out);
2988     this->writeLabel(start, out);
2989     this->writeStatement(*d.fStatement, out);
2990     if (fCurrentBlock) {
2991         this->writeInstruction(SpvOpBranch, next, out);
2992     }
2993     this->writeLabel(next, out);
2994     SpvId test = this->writeExpression(*d.fTest, out);
2995     this->writeInstruction(SpvOpBranchConditional, test, start, end, out);
2996     this->writeLabel(end, out);
2997     fBreakTarget.pop();
2998     fContinueTarget.pop();
2999 }
3000 
writeReturnStatement(const ReturnStatement & r,OutputStream & out)3001 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
3002     if (r.fExpression) {
3003         this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out),
3004                                out);
3005     } else {
3006         this->writeInstruction(SpvOpReturn, out);
3007     }
3008 }
3009 
writeInstructions(const Program & program,OutputStream & out)3010 void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
3011     fGLSLExtendedInstructions = this->nextId();
3012     StringStream body;
3013     std::set<SpvId> interfaceVars;
3014     // assign IDs to functions
3015     for (size_t i = 0; i < program.fElements.size(); i++) {
3016         if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
3017             FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i];
3018             fFunctionMap[&f.fDeclaration] = this->nextId();
3019         }
3020     }
3021     for (size_t i = 0; i < program.fElements.size(); i++) {
3022         if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) {
3023             InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i];
3024             SpvId id = this->writeInterfaceBlock(intf);
3025             if ((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) ||
3026                 (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) {
3027                 interfaceVars.insert(id);
3028             }
3029         }
3030     }
3031     for (size_t i = 0; i < program.fElements.size(); i++) {
3032         if (program.fElements[i]->fKind == ProgramElement::kVar_Kind) {
3033             this->writeGlobalVars(program.fKind, ((VarDeclarations&) *program.fElements[i]),
3034                                   body);
3035         }
3036     }
3037     for (size_t i = 0; i < program.fElements.size(); i++) {
3038         if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
3039             this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body);
3040         }
3041     }
3042     const FunctionDeclaration* main = nullptr;
3043     for (auto entry : fFunctionMap) {
3044         if (entry.first->fName == "main") {
3045             main = entry.first;
3046         }
3047     }
3048     ASSERT(main);
3049     for (auto entry : fVariableMap) {
3050         const Variable* var = entry.first;
3051         if (var->fStorage == Variable::kGlobal_Storage &&
3052                 ((var->fModifiers.fFlags & Modifiers::kIn_Flag) ||
3053                  (var->fModifiers.fFlags & Modifiers::kOut_Flag))) {
3054             interfaceVars.insert(entry.second);
3055         }
3056     }
3057     this->writeCapabilities(out);
3058     this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
3059     this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
3060     this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (strlen(main->fName.c_str()) + 4) / 4) +
3061                       (int32_t) interfaceVars.size(), out);
3062     switch (program.fKind) {
3063         case Program::kVertex_Kind:
3064             this->writeWord(SpvExecutionModelVertex, out);
3065             break;
3066         case Program::kFragment_Kind:
3067             this->writeWord(SpvExecutionModelFragment, out);
3068             break;
3069         case Program::kGeometry_Kind:
3070             this->writeWord(SpvExecutionModelGeometry, out);
3071             break;
3072         default:
3073             ABORT("cannot write this kind of program to SPIR-V\n");
3074     }
3075     this->writeWord(fFunctionMap[main], out);
3076     this->writeString(main->fName.c_str(), out);
3077     for (int var : interfaceVars) {
3078         this->writeWord(var, out);
3079     }
3080     if (program.fKind == Program::kFragment_Kind) {
3081         this->writeInstruction(SpvOpExecutionMode,
3082                                fFunctionMap[main],
3083                                SpvExecutionModeOriginUpperLeft,
3084                                out);
3085     }
3086     for (size_t i = 0; i < program.fElements.size(); i++) {
3087         if (program.fElements[i]->fKind == ProgramElement::kExtension_Kind) {
3088             this->writeInstruction(SpvOpSourceExtension,
3089                                    ((Extension&) *program.fElements[i]).fName.c_str(),
3090                                    out);
3091         }
3092     }
3093 
3094     write_stringstream(fExtraGlobalsBuffer, out);
3095     write_stringstream(fNameBuffer, out);
3096     write_stringstream(fDecorationBuffer, out);
3097     write_stringstream(fConstantBuffer, out);
3098     write_stringstream(fExternalFunctionsBuffer, out);
3099     write_stringstream(body, out);
3100 }
3101 
generateCode()3102 bool SPIRVCodeGenerator::generateCode() {
3103     ASSERT(!fErrors.errorCount());
3104     this->writeWord(SpvMagicNumber, *fOut);
3105     this->writeWord(SpvVersion, *fOut);
3106     this->writeWord(SKSL_MAGIC, *fOut);
3107     StringStream buffer;
3108     this->writeInstructions(fProgram, buffer);
3109     this->writeWord(fIdCount, *fOut);
3110     this->writeWord(0, *fOut); // reserved, always zero
3111     write_stringstream(buffer, *fOut);
3112     return 0 == fErrors.errorCount();
3113 }
3114 
3115 }
3116