• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "SkSLSPIRVCodeGenerator.h"
9 
10 #include "string.h"
11 
12 #include "GLSL.std.450.h"
13 
14 #include "ir/SkSLExpressionStatement.h"
15 #include "ir/SkSLExtension.h"
16 #include "ir/SkSLIndexExpression.h"
17 #include "ir/SkSLVariableReference.h"
18 #include "SkSLCompiler.h"
19 
20 namespace SkSL {
21 
// When non-zero, writeWord() prints each word as text "(word) " instead of emitting binary,
// and the opcode_text() debugging helper is compiled in.
#define SPIRV_DEBUG 0

// Magic value emitted by this generator; presumably written into the SPIR-V module header
// (the writer is not visible in this chunk — confirm).
// FIXME: we should probably register a magic number
static const int32_t SKSL_MAGIC = 0x0;
25 
setupIntrinsics()26 void SPIRVCodeGenerator::setupIntrinsics() {
27 #define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \
28                                     GLSLstd450 ## x, GLSLstd450 ## x)
29 #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \
30                                                              GLSLstd450 ## ifFloat, \
31                                                              GLSLstd450 ## ifInt, \
32                                                              GLSLstd450 ## ifUInt, \
33                                                              SpvOpUndef)
34 #define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \
35                                    k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
36                                    k ## x ## _SpecialIntrinsic)
37     fIntrinsicMap[SkString("round")]         = ALL_GLSL(Round);
38     fIntrinsicMap[SkString("roundEven")]     = ALL_GLSL(RoundEven);
39     fIntrinsicMap[SkString("trunc")]         = ALL_GLSL(Trunc);
40     fIntrinsicMap[SkString("abs")]           = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
41     fIntrinsicMap[SkString("sign")]          = BY_TYPE_GLSL(FSign, SSign, SSign);
42     fIntrinsicMap[SkString("floor")]         = ALL_GLSL(Floor);
43     fIntrinsicMap[SkString("ceil")]          = ALL_GLSL(Ceil);
44     fIntrinsicMap[SkString("fract")]         = ALL_GLSL(Fract);
45     fIntrinsicMap[SkString("radians")]       = ALL_GLSL(Radians);
46     fIntrinsicMap[SkString("degrees")]       = ALL_GLSL(Degrees);
47     fIntrinsicMap[SkString("sin")]           = ALL_GLSL(Sin);
48     fIntrinsicMap[SkString("cos")]           = ALL_GLSL(Cos);
49     fIntrinsicMap[SkString("tan")]           = ALL_GLSL(Tan);
50     fIntrinsicMap[SkString("asin")]          = ALL_GLSL(Asin);
51     fIntrinsicMap[SkString("acos")]          = ALL_GLSL(Acos);
52     fIntrinsicMap[SkString("atan")]          = SPECIAL(Atan);
53     fIntrinsicMap[SkString("sinh")]          = ALL_GLSL(Sinh);
54     fIntrinsicMap[SkString("cosh")]          = ALL_GLSL(Cosh);
55     fIntrinsicMap[SkString("tanh")]          = ALL_GLSL(Tanh);
56     fIntrinsicMap[SkString("asinh")]         = ALL_GLSL(Asinh);
57     fIntrinsicMap[SkString("acosh")]         = ALL_GLSL(Acosh);
58     fIntrinsicMap[SkString("atanh")]         = ALL_GLSL(Atanh);
59     fIntrinsicMap[SkString("pow")]           = ALL_GLSL(Pow);
60     fIntrinsicMap[SkString("exp")]           = ALL_GLSL(Exp);
61     fIntrinsicMap[SkString("log")]           = ALL_GLSL(Log);
62     fIntrinsicMap[SkString("exp2")]          = ALL_GLSL(Exp2);
63     fIntrinsicMap[SkString("log2")]          = ALL_GLSL(Log2);
64     fIntrinsicMap[SkString("sqrt")]          = ALL_GLSL(Sqrt);
65     fIntrinsicMap[SkString("inversesqrt")]   = ALL_GLSL(InverseSqrt);
66     fIntrinsicMap[SkString("determinant")]   = ALL_GLSL(Determinant);
67     fIntrinsicMap[SkString("matrixInverse")] = ALL_GLSL(MatrixInverse);
68     fIntrinsicMap[SkString("mod")]           = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpFMod,
69                                                                SpvOpSMod, SpvOpUMod, SpvOpUndef);
70     fIntrinsicMap[SkString("min")]           = BY_TYPE_GLSL(FMin, SMin, UMin);
71     fIntrinsicMap[SkString("max")]           = BY_TYPE_GLSL(FMax, SMax, UMax);
72     fIntrinsicMap[SkString("clamp")]         = BY_TYPE_GLSL(FClamp, SClamp, UClamp);
73     fIntrinsicMap[SkString("dot")]           = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot,
74                                                                SpvOpUndef, SpvOpUndef, SpvOpUndef);
75     fIntrinsicMap[SkString("mix")]           = ALL_GLSL(FMix);
76     fIntrinsicMap[SkString("step")]          = ALL_GLSL(Step);
77     fIntrinsicMap[SkString("smoothstep")]    = ALL_GLSL(SmoothStep);
78     fIntrinsicMap[SkString("fma")]           = ALL_GLSL(Fma);
79     fIntrinsicMap[SkString("frexp")]         = ALL_GLSL(Frexp);
80     fIntrinsicMap[SkString("ldexp")]         = ALL_GLSL(Ldexp);
81 
82 #define PACK(type) fIntrinsicMap[SkString("pack" #type)] = ALL_GLSL(Pack ## type); \
83                    fIntrinsicMap[SkString("unpack" #type)] = ALL_GLSL(Unpack ## type)
84     PACK(Snorm4x8);
85     PACK(Unorm4x8);
86     PACK(Snorm2x16);
87     PACK(Unorm2x16);
88     PACK(Half2x16);
89     PACK(Double2x32);
90     fIntrinsicMap[SkString("length")]      = ALL_GLSL(Length);
91     fIntrinsicMap[SkString("distance")]    = ALL_GLSL(Distance);
92     fIntrinsicMap[SkString("cross")]       = ALL_GLSL(Cross);
93     fIntrinsicMap[SkString("normalize")]   = ALL_GLSL(Normalize);
94     fIntrinsicMap[SkString("faceForward")] = ALL_GLSL(FaceForward);
95     fIntrinsicMap[SkString("reflect")]     = ALL_GLSL(Reflect);
96     fIntrinsicMap[SkString("refract")]     = ALL_GLSL(Refract);
97     fIntrinsicMap[SkString("findLSB")]     = ALL_GLSL(FindILsb);
98     fIntrinsicMap[SkString("findMSB")]     = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
99     fIntrinsicMap[SkString("dFdx")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx,
100                                                              SpvOpUndef, SpvOpUndef, SpvOpUndef);
101     fIntrinsicMap[SkString("dFdy")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
102                                                              SpvOpUndef, SpvOpUndef, SpvOpUndef);
103     fIntrinsicMap[SkString("dFdy")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
104                                                              SpvOpUndef, SpvOpUndef, SpvOpUndef);
105     fIntrinsicMap[SkString("texture")]     = SPECIAL(Texture);
106 
107     fIntrinsicMap[SkString("subpassLoad")] = SPECIAL(SubpassLoad);
108 
109     fIntrinsicMap[SkString("any")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
110                                                                   SpvOpUndef, SpvOpUndef, SpvOpAny);
111     fIntrinsicMap[SkString("all")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
112                                                                   SpvOpUndef, SpvOpUndef, SpvOpAll);
113     fIntrinsicMap[SkString("equal")]            = std::make_tuple(kSPIRV_IntrinsicKind,
114                                                                   SpvOpFOrdEqual, SpvOpIEqual,
115                                                                   SpvOpIEqual, SpvOpLogicalEqual);
116     fIntrinsicMap[SkString("notEqual")]         = std::make_tuple(kSPIRV_IntrinsicKind,
117                                                                   SpvOpFOrdNotEqual, SpvOpINotEqual,
118                                                                   SpvOpINotEqual,
119                                                                   SpvOpLogicalNotEqual);
120     fIntrinsicMap[SkString("lessThan")]         = std::make_tuple(kSPIRV_IntrinsicKind,
121                                                                   SpvOpSLessThan, SpvOpULessThan,
122                                                                   SpvOpFOrdLessThan, SpvOpUndef);
123     fIntrinsicMap[SkString("lessThanEqual")]    = std::make_tuple(kSPIRV_IntrinsicKind,
124                                                                   SpvOpSLessThanEqual,
125                                                                   SpvOpULessThanEqual,
126                                                                   SpvOpFOrdLessThanEqual,
127                                                                   SpvOpUndef);
128     fIntrinsicMap[SkString("greaterThan")]      = std::make_tuple(kSPIRV_IntrinsicKind,
129                                                                   SpvOpSGreaterThan,
130                                                                   SpvOpUGreaterThan,
131                                                                   SpvOpFOrdGreaterThan,
132                                                                   SpvOpUndef);
133     fIntrinsicMap[SkString("greaterThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
134                                                                   SpvOpSGreaterThanEqual,
135                                                                   SpvOpUGreaterThanEqual,
136                                                                   SpvOpFOrdGreaterThanEqual,
137                                                                   SpvOpUndef);
138 
139 // interpolateAt* not yet supported...
140 }
141 
writeWord(int32_t word,SkWStream & out)142 void SPIRVCodeGenerator::writeWord(int32_t word, SkWStream& out) {
143 #if SPIRV_DEBUG
144     out << "(" << word << ") ";
145 #else
146     out.write((const char*) &word, sizeof(word));
147 #endif
148 }
149 
is_float(const Context & context,const Type & type)150 static bool is_float(const Context& context, const Type& type) {
151     if (type.kind() == Type::kVector_Kind) {
152         return is_float(context, type.componentType());
153     }
154     return type == *context.fFloat_Type || type == *context.fDouble_Type;
155 }
156 
is_signed(const Context & context,const Type & type)157 static bool is_signed(const Context& context, const Type& type) {
158     if (type.kind() == Type::kVector_Kind) {
159         return is_signed(context, type.componentType());
160     }
161     return type == *context.fInt_Type;
162 }
163 
is_unsigned(const Context & context,const Type & type)164 static bool is_unsigned(const Context& context, const Type& type) {
165     if (type.kind() == Type::kVector_Kind) {
166         return is_unsigned(context, type.componentType());
167     }
168     return type == *context.fUInt_Type;
169 }
170 
is_bool(const Context & context,const Type & type)171 static bool is_bool(const Context& context, const Type& type) {
172     if (type.kind() == Type::kVector_Kind) {
173         return is_bool(context, type.componentType());
174     }
175     return type == *context.fBool_Type;
176 }
177 
is_out(const Variable & var)178 static bool is_out(const Variable& var) {
179     return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
180 }
181 
182 #if SPIRV_DEBUG
opcode_text(SpvOp_ opCode)183 static SkString opcode_text(SpvOp_ opCode) {
184     switch (opCode) {
185         case SpvOpNop:
186             return SkString("Nop");
187         case SpvOpUndef:
188             return SkString("Undef");
189         case SpvOpSourceContinued:
190             return SkString("SourceContinued");
191         case SpvOpSource:
192             return SkString("Source");
193         case SpvOpSourceExtension:
194             return SkString("SourceExtension");
195         case SpvOpName:
196             return SkString("Name");
197         case SpvOpMemberName:
198             return SkString("MemberName");
199         case SpvOpString:
200             return SkString("String");
201         case SpvOpLine:
202             return SkString("Line");
203         case SpvOpExtension:
204             return SkString("Extension");
205         case SpvOpExtInstImport:
206             return SkString("ExtInstImport");
207         case SpvOpExtInst:
208             return SkString("ExtInst");
209         case SpvOpMemoryModel:
210             return SkString("MemoryModel");
211         case SpvOpEntryPoint:
212             return SkString("EntryPoint");
213         case SpvOpExecutionMode:
214             return SkString("ExecutionMode");
215         case SpvOpCapability:
216             return SkString("Capability");
217         case SpvOpTypeVoid:
218             return SkString("TypeVoid");
219         case SpvOpTypeBool:
220             return SkString("TypeBool");
221         case SpvOpTypeInt:
222             return SkString("TypeInt");
223         case SpvOpTypeFloat:
224             return SkString("TypeFloat");
225         case SpvOpTypeVector:
226             return SkString("TypeVector");
227         case SpvOpTypeMatrix:
228             return SkString("TypeMatrix");
229         case SpvOpTypeImage:
230             return SkString("TypeImage");
231         case SpvOpTypeSampler:
232             return SkString("TypeSampler");
233         case SpvOpTypeSampledImage:
234             return SkString("TypeSampledImage");
235         case SpvOpTypeArray:
236             return SkString("TypeArray");
237         case SpvOpTypeRuntimeArray:
238             return SkString("TypeRuntimeArray");
239         case SpvOpTypeStruct:
240             return SkString("TypeStruct");
241         case SpvOpTypeOpaque:
242             return SkString("TypeOpaque");
243         case SpvOpTypePointer:
244             return SkString("TypePointer");
245         case SpvOpTypeFunction:
246             return SkString("TypeFunction");
247         case SpvOpTypeEvent:
248             return SkString("TypeEvent");
249         case SpvOpTypeDeviceEvent:
250             return SkString("TypeDeviceEvent");
251         case SpvOpTypeReserveId:
252             return SkString("TypeReserveId");
253         case SpvOpTypeQueue:
254             return SkString("TypeQueue");
255         case SpvOpTypePipe:
256             return SkString("TypePipe");
257         case SpvOpTypeForwardPointer:
258             return SkString("TypeForwardPointer");
259         case SpvOpConstantTrue:
260             return SkString("ConstantTrue");
261         case SpvOpConstantFalse:
262             return SkString("ConstantFalse");
263         case SpvOpConstant:
264             return SkString("Constant");
265         case SpvOpConstantComposite:
266             return SkString("ConstantComposite");
267         case SpvOpConstantSampler:
268             return SkString("ConstantSampler");
269         case SpvOpConstantNull:
270             return SkString("ConstantNull");
271         case SpvOpSpecConstantTrue:
272             return SkString("SpecConstantTrue");
273         case SpvOpSpecConstantFalse:
274             return SkString("SpecConstantFalse");
275         case SpvOpSpecConstant:
276             return SkString("SpecConstant");
277         case SpvOpSpecConstantComposite:
278             return SkString("SpecConstantComposite");
279         case SpvOpSpecConstantOp:
280             return SkString("SpecConstantOp");
281         case SpvOpFunction:
282             return SkString("Function");
283         case SpvOpFunctionParameter:
284             return SkString("FunctionParameter");
285         case SpvOpFunctionEnd:
286             return SkString("FunctionEnd");
287         case SpvOpFunctionCall:
288             return SkString("FunctionCall");
289         case SpvOpVariable:
290             return SkString("Variable");
291         case SpvOpImageTexelPointer:
292             return SkString("ImageTexelPointer");
293         case SpvOpLoad:
294             return SkString("Load");
295         case SpvOpStore:
296             return SkString("Store");
297         case SpvOpCopyMemory:
298             return SkString("CopyMemory");
299         case SpvOpCopyMemorySized:
300             return SkString("CopyMemorySized");
301         case SpvOpAccessChain:
302             return SkString("AccessChain");
303         case SpvOpInBoundsAccessChain:
304             return SkString("InBoundsAccessChain");
305         case SpvOpPtrAccessChain:
306             return SkString("PtrAccessChain");
307         case SpvOpArrayLength:
308             return SkString("ArrayLength");
309         case SpvOpGenericPtrMemSemantics:
310             return SkString("GenericPtrMemSemantics");
311         case SpvOpInBoundsPtrAccessChain:
312             return SkString("InBoundsPtrAccessChain");
313         case SpvOpDecorate:
314             return SkString("Decorate");
315         case SpvOpMemberDecorate:
316             return SkString("MemberDecorate");
317         case SpvOpDecorationGroup:
318             return SkString("DecorationGroup");
319         case SpvOpGroupDecorate:
320             return SkString("GroupDecorate");
321         case SpvOpGroupMemberDecorate:
322             return SkString("GroupMemberDecorate");
323         case SpvOpVectorExtractDynamic:
324             return SkString("VectorExtractDynamic");
325         case SpvOpVectorInsertDynamic:
326             return SkString("VectorInsertDynamic");
327         case SpvOpVectorShuffle:
328             return SkString("VectorShuffle");
329         case SpvOpCompositeConstruct:
330             return SkString("CompositeConstruct");
331         case SpvOpCompositeExtract:
332             return SkString("CompositeExtract");
333         case SpvOpCompositeInsert:
334             return SkString("CompositeInsert");
335         case SpvOpCopyObject:
336             return SkString("CopyObject");
337         case SpvOpTranspose:
338             return SkString("Transpose");
339         case SpvOpSampledImage:
340             return SkString("SampledImage");
341         case SpvOpImageSampleImplicitLod:
342             return SkString("ImageSampleImplicitLod");
343         case SpvOpImageSampleExplicitLod:
344             return SkString("ImageSampleExplicitLod");
345         case SpvOpImageSampleDrefImplicitLod:
346             return SkString("ImageSampleDrefImplicitLod");
347         case SpvOpImageSampleDrefExplicitLod:
348             return SkString("ImageSampleDrefExplicitLod");
349         case SpvOpImageSampleProjImplicitLod:
350             return SkString("ImageSampleProjImplicitLod");
351         case SpvOpImageSampleProjExplicitLod:
352             return SkString("ImageSampleProjExplicitLod");
353         case SpvOpImageSampleProjDrefImplicitLod:
354             return SkString("ImageSampleProjDrefImplicitLod");
355         case SpvOpImageSampleProjDrefExplicitLod:
356             return SkString("ImageSampleProjDrefExplicitLod");
357         case SpvOpImageFetch:
358             return SkString("ImageFetch");
359         case SpvOpImageGather:
360             return SkString("ImageGather");
361         case SpvOpImageDrefGather:
362             return SkString("ImageDrefGather");
363         case SpvOpImageRead:
364             return SkString("ImageRead");
365         case SpvOpImageWrite:
366             return SkString("ImageWrite");
367         case SpvOpImage:
368             return SkString("Image");
369         case SpvOpImageQueryFormat:
370             return SkString("ImageQueryFormat");
371         case SpvOpImageQueryOrder:
372             return SkString("ImageQueryOrder");
373         case SpvOpImageQuerySizeLod:
374             return SkString("ImageQuerySizeLod");
375         case SpvOpImageQuerySize:
376             return SkString("ImageQuerySize");
377         case SpvOpImageQueryLod:
378             return SkString("ImageQueryLod");
379         case SpvOpImageQueryLevels:
380             return SkString("ImageQueryLevels");
381         case SpvOpImageQuerySamples:
382             return SkString("ImageQuerySamples");
383         case SpvOpConvertFToU:
384             return SkString("ConvertFToU");
385         case SpvOpConvertFToS:
386             return SkString("ConvertFToS");
387         case SpvOpConvertSToF:
388             return SkString("ConvertSToF");
389         case SpvOpConvertUToF:
390             return SkString("ConvertUToF");
391         case SpvOpUConvert:
392             return SkString("UConvert");
393         case SpvOpSConvert:
394             return SkString("SConvert");
395         case SpvOpFConvert:
396             return SkString("FConvert");
397         case SpvOpQuantizeToF16:
398             return SkString("QuantizeToF16");
399         case SpvOpConvertPtrToU:
400             return SkString("ConvertPtrToU");
401         case SpvOpSatConvertSToU:
402             return SkString("SatConvertSToU");
403         case SpvOpSatConvertUToS:
404             return SkString("SatConvertUToS");
405         case SpvOpConvertUToPtr:
406             return SkString("ConvertUToPtr");
407         case SpvOpPtrCastToGeneric:
408             return SkString("PtrCastToGeneric");
409         case SpvOpGenericCastToPtr:
410             return SkString("GenericCastToPtr");
411         case SpvOpGenericCastToPtrExplicit:
412             return SkString("GenericCastToPtrExplicit");
413         case SpvOpBitcast:
414             return SkString("Bitcast");
415         case SpvOpSNegate:
416             return SkString("SNegate");
417         case SpvOpFNegate:
418             return SkString("FNegate");
419         case SpvOpIAdd:
420             return SkString("IAdd");
421         case SpvOpFAdd:
422             return SkString("FAdd");
423         case SpvOpISub:
424             return SkString("ISub");
425         case SpvOpFSub:
426             return SkString("FSub");
427         case SpvOpIMul:
428             return SkString("IMul");
429         case SpvOpFMul:
430             return SkString("FMul");
431         case SpvOpUDiv:
432             return SkString("UDiv");
433         case SpvOpSDiv:
434             return SkString("SDiv");
435         case SpvOpFDiv:
436             return SkString("FDiv");
437         case SpvOpUMod:
438             return SkString("UMod");
439         case SpvOpSRem:
440             return SkString("SRem");
441         case SpvOpSMod:
442             return SkString("SMod");
443         case SpvOpFRem:
444             return SkString("FRem");
445         case SpvOpFMod:
446             return SkString("FMod");
447         case SpvOpVectorTimesScalar:
448             return SkString("VectorTimesScalar");
449         case SpvOpMatrixTimesScalar:
450             return SkString("MatrixTimesScalar");
451         case SpvOpVectorTimesMatrix:
452             return SkString("VectorTimesMatrix");
453         case SpvOpMatrixTimesVector:
454             return SkString("MatrixTimesVector");
455         case SpvOpMatrixTimesMatrix:
456             return SkString("MatrixTimesMatrix");
457         case SpvOpOuterProduct:
458             return SkString("OuterProduct");
459         case SpvOpDot:
460             return SkString("Dot");
461         case SpvOpIAddCarry:
462             return SkString("IAddCarry");
463         case SpvOpISubBorrow:
464             return SkString("ISubBorrow");
465         case SpvOpUMulExtended:
466             return SkString("UMulExtended");
467         case SpvOpSMulExtended:
468             return SkString("SMulExtended");
469         case SpvOpAny:
470             return SkString("Any");
471         case SpvOpAll:
472             return SkString("All");
473         case SpvOpIsNan:
474             return SkString("IsNan");
475         case SpvOpIsInf:
476             return SkString("IsInf");
477         case SpvOpIsFinite:
478             return SkString("IsFinite");
479         case SpvOpIsNormal:
480             return SkString("IsNormal");
481         case SpvOpSignBitSet:
482             return SkString("SignBitSet");
483         case SpvOpLessOrGreater:
484             return SkString("LessOrGreater");
485         case SpvOpOrdered:
486             return SkString("Ordered");
487         case SpvOpUnordered:
488             return SkString("Unordered");
489         case SpvOpLogicalEqual:
490             return SkString("LogicalEqual");
491         case SpvOpLogicalNotEqual:
492             return SkString("LogicalNotEqual");
493         case SpvOpLogicalOr:
494             return SkString("LogicalOr");
495         case SpvOpLogicalAnd:
496             return SkString("LogicalAnd");
497         case SpvOpLogicalNot:
498             return SkString("LogicalNot");
499         case SpvOpSelect:
500             return SkString("Select");
501         case SpvOpIEqual:
502             return SkString("IEqual");
503         case SpvOpINotEqual:
504             return SkString("INotEqual");
505         case SpvOpUGreaterThan:
506             return SkString("UGreaterThan");
507         case SpvOpSGreaterThan:
508             return SkString("SGreaterThan");
509         case SpvOpUGreaterThanEqual:
510             return SkString("UGreaterThanEqual");
511         case SpvOpSGreaterThanEqual:
512             return SkString("SGreaterThanEqual");
513         case SpvOpULessThan:
514             return SkString("ULessThan");
515         case SpvOpSLessThan:
516             return SkString("SLessThan");
517         case SpvOpULessThanEqual:
518             return SkString("ULessThanEqual");
519         case SpvOpSLessThanEqual:
520             return SkString("SLessThanEqual");
521         case SpvOpFOrdEqual:
522             return SkString("FOrdEqual");
523         case SpvOpFUnordEqual:
524             return SkString("FUnordEqual");
525         case SpvOpFOrdNotEqual:
526             return SkString("FOrdNotEqual");
527         case SpvOpFUnordNotEqual:
528             return SkString("FUnordNotEqual");
529         case SpvOpFOrdLessThan:
530             return SkString("FOrdLessThan");
531         case SpvOpFUnordLessThan:
532             return SkString("FUnordLessThan");
533         case SpvOpFOrdGreaterThan:
534             return SkString("FOrdGreaterThan");
535         case SpvOpFUnordGreaterThan:
536             return SkString("FUnordGreaterThan");
537         case SpvOpFOrdLessThanEqual:
538             return SkString("FOrdLessThanEqual");
539         case SpvOpFUnordLessThanEqual:
540             return SkString("FUnordLessThanEqual");
541         case SpvOpFOrdGreaterThanEqual:
542             return SkString("FOrdGreaterThanEqual");
543         case SpvOpFUnordGreaterThanEqual:
544             return SkString("FUnordGreaterThanEqual");
545         case SpvOpShiftRightLogical:
546             return SkString("ShiftRightLogical");
547         case SpvOpShiftRightArithmetic:
548             return SkString("ShiftRightArithmetic");
549         case SpvOpShiftLeftLogical:
550             return SkString("ShiftLeftLogical");
551         case SpvOpBitwiseOr:
552             return SkString("BitwiseOr");
553         case SpvOpBitwiseXor:
554             return SkString("BitwiseXor");
555         case SpvOpBitwiseAnd:
556             return SkString("BitwiseAnd");
557         case SpvOpNot:
558             return SkString("Not");
559         case SpvOpBitFieldInsert:
560             return SkString("BitFieldInsert");
561         case SpvOpBitFieldSExtract:
562             return SkString("BitFieldSExtract");
563         case SpvOpBitFieldUExtract:
564             return SkString("BitFieldUExtract");
565         case SpvOpBitReverse:
566             return SkString("BitReverse");
567         case SpvOpBitCount:
568             return SkString("BitCount");
569         case SpvOpDPdx:
570             return SkString("DPdx");
571         case SpvOpDPdy:
572             return SkString("DPdy");
573         case SpvOpFwidth:
574             return SkString("Fwidth");
575         case SpvOpDPdxFine:
576             return SkString("DPdxFine");
577         case SpvOpDPdyFine:
578             return SkString("DPdyFine");
579         case SpvOpFwidthFine:
580             return SkString("FwidthFine");
581         case SpvOpDPdxCoarse:
582             return SkString("DPdxCoarse");
583         case SpvOpDPdyCoarse:
584             return SkString("DPdyCoarse");
585         case SpvOpFwidthCoarse:
586             return SkString("FwidthCoarse");
587         case SpvOpEmitVertex:
588             return SkString("EmitVertex");
589         case SpvOpEndPrimitive:
590             return SkString("EndPrimitive");
591         case SpvOpEmitStreamVertex:
592             return SkString("EmitStreamVertex");
593         case SpvOpEndStreamPrimitive:
594             return SkString("EndStreamPrimitive");
595         case SpvOpControlBarrier:
596             return SkString("ControlBarrier");
597         case SpvOpMemoryBarrier:
598             return SkString("MemoryBarrier");
599         case SpvOpAtomicLoad:
600             return SkString("AtomicLoad");
601         case SpvOpAtomicStore:
602             return SkString("AtomicStore");
603         case SpvOpAtomicExchange:
604             return SkString("AtomicExchange");
605         case SpvOpAtomicCompareExchange:
606             return SkString("AtomicCompareExchange");
607         case SpvOpAtomicCompareExchangeWeak:
608             return SkString("AtomicCompareExchangeWeak");
609         case SpvOpAtomicIIncrement:
610             return SkString("AtomicIIncrement");
611         case SpvOpAtomicIDecrement:
612             return SkString("AtomicIDecrement");
613         case SpvOpAtomicIAdd:
614             return SkString("AtomicIAdd");
615         case SpvOpAtomicISub:
616             return SkString("AtomicISub");
617         case SpvOpAtomicSMin:
618             return SkString("AtomicSMin");
619         case SpvOpAtomicUMin:
620             return SkString("AtomicUMin");
621         case SpvOpAtomicSMax:
622             return SkString("AtomicSMax");
623         case SpvOpAtomicUMax:
624             return SkString("AtomicUMax");
625         case SpvOpAtomicAnd:
626             return SkString("AtomicAnd");
627         case SpvOpAtomicOr:
628             return SkString("AtomicOr");
629         case SpvOpAtomicXor:
630             return SkString("AtomicXor");
631         case SpvOpPhi:
632             return SkString("Phi");
633         case SpvOpLoopMerge:
634             return SkString("LoopMerge");
635         case SpvOpSelectionMerge:
636             return SkString("SelectionMerge");
637         case SpvOpLabel:
638             return SkString("Label");
639         case SpvOpBranch:
640             return SkString("Branch");
641         case SpvOpBranchConditional:
642             return SkString("BranchConditional");
643         case SpvOpSwitch:
644             return SkString("Switch");
645         case SpvOpKill:
646             return SkString("Kill");
647         case SpvOpReturn:
648             return SkString("Return");
649         case SpvOpReturnValue:
650             return SkString("ReturnValue");
651         case SpvOpUnreachable:
652             return SkString("Unreachable");
653         case SpvOpLifetimeStart:
654             return SkString("LifetimeStart");
655         case SpvOpLifetimeStop:
656             return SkString("LifetimeStop");
657         case SpvOpGroupAsyncCopy:
658             return SkString("GroupAsyncCopy");
659         case SpvOpGroupWaitEvents:
660             return SkString("GroupWaitEvents");
661         case SpvOpGroupAll:
662             return SkString("GroupAll");
663         case SpvOpGroupAny:
664             return SkString("GroupAny");
665         case SpvOpGroupBroadcast:
666             return SkString("GroupBroadcast");
667         case SpvOpGroupIAdd:
668             return SkString("GroupIAdd");
669         case SpvOpGroupFAdd:
670             return SkString("GroupFAdd");
671         case SpvOpGroupFMin:
672             return SkString("GroupFMin");
673         case SpvOpGroupUMin:
674             return SkString("GroupUMin");
675         case SpvOpGroupSMin:
676             return SkString("GroupSMin");
677         case SpvOpGroupFMax:
678             return SkString("GroupFMax");
679         case SpvOpGroupUMax:
680             return SkString("GroupUMax");
681         case SpvOpGroupSMax:
682             return SkString("GroupSMax");
683         case SpvOpReadPipe:
684             return SkString("ReadPipe");
685         case SpvOpWritePipe:
686             return SkString("WritePipe");
687         case SpvOpReservedReadPipe:
688             return SkString("ReservedReadPipe");
689         case SpvOpReservedWritePipe:
690             return SkString("ReservedWritePipe");
691         case SpvOpReserveReadPipePackets:
692             return SkString("ReserveReadPipePackets");
693         case SpvOpReserveWritePipePackets:
694             return SkString("ReserveWritePipePackets");
695         case SpvOpCommitReadPipe:
696             return SkString("CommitReadPipe");
697         case SpvOpCommitWritePipe:
698             return SkString("CommitWritePipe");
699         case SpvOpIsValidReserveId:
700             return SkString("IsValidReserveId");
701         case SpvOpGetNumPipePackets:
702             return SkString("GetNumPipePackets");
703         case SpvOpGetMaxPipePackets:
704             return SkString("GetMaxPipePackets");
705         case SpvOpGroupReserveReadPipePackets:
706             return SkString("GroupReserveReadPipePackets");
707         case SpvOpGroupReserveWritePipePackets:
708             return SkString("GroupReserveWritePipePackets");
709         case SpvOpGroupCommitReadPipe:
710             return SkString("GroupCommitReadPipe");
711         case SpvOpGroupCommitWritePipe:
712             return SkString("GroupCommitWritePipe");
713         case SpvOpEnqueueMarker:
714             return SkString("EnqueueMarker");
715         case SpvOpEnqueueKernel:
716             return SkString("EnqueueKernel");
717         case SpvOpGetKernelNDrangeSubGroupCount:
718             return SkString("GetKernelNDrangeSubGroupCount");
719         case SpvOpGetKernelNDrangeMaxSubGroupSize:
720             return SkString("GetKernelNDrangeMaxSubGroupSize");
721         case SpvOpGetKernelWorkGroupSize:
722             return SkString("GetKernelWorkGroupSize");
723         case SpvOpGetKernelPreferredWorkGroupSizeMultiple:
724             return SkString("GetKernelPreferredWorkGroupSizeMultiple");
725         case SpvOpRetainEvent:
726             return SkString("RetainEvent");
727         case SpvOpReleaseEvent:
728             return SkString("ReleaseEvent");
729         case SpvOpCreateUserEvent:
730             return SkString("CreateUserEvent");
731         case SpvOpIsValidEvent:
732             return SkString("IsValidEvent");
733         case SpvOpSetUserEventStatus:
734             return SkString("SetUserEventStatus");
735         case SpvOpCaptureEventProfilingInfo:
736             return SkString("CaptureEventProfilingInfo");
737         case SpvOpGetDefaultQueue:
738             return SkString("GetDefaultQueue");
739         case SpvOpBuildNDRange:
740             return SkString("BuildNDRange");
741         case SpvOpImageSparseSampleImplicitLod:
742             return SkString("ImageSparseSampleImplicitLod");
743         case SpvOpImageSparseSampleExplicitLod:
744             return SkString("ImageSparseSampleExplicitLod");
745         case SpvOpImageSparseSampleDrefImplicitLod:
746             return SkString("ImageSparseSampleDrefImplicitLod");
747         case SpvOpImageSparseSampleDrefExplicitLod:
748             return SkString("ImageSparseSampleDrefExplicitLod");
749         case SpvOpImageSparseSampleProjImplicitLod:
750             return SkString("ImageSparseSampleProjImplicitLod");
751         case SpvOpImageSparseSampleProjExplicitLod:
752             return SkString("ImageSparseSampleProjExplicitLod");
753         case SpvOpImageSparseSampleProjDrefImplicitLod:
754             return SkString("ImageSparseSampleProjDrefImplicitLod");
755         case SpvOpImageSparseSampleProjDrefExplicitLod:
756             return SkString("ImageSparseSampleProjDrefExplicitLod");
757         case SpvOpImageSparseFetch:
758             return SkString("ImageSparseFetch");
759         case SpvOpImageSparseGather:
760             return SkString("ImageSparseGather");
761         case SpvOpImageSparseDrefGather:
762             return SkString("ImageSparseDrefGather");
763         case SpvOpImageSparseTexelsResident:
764             return SkString("ImageSparseTexelsResident");
765         case SpvOpNoLine:
766             return SkString("NoLine");
767         case SpvOpAtomicFlagTestAndSet:
768             return SkString("AtomicFlagTestAndSet");
769         case SpvOpAtomicFlagClear:
770             return SkString("AtomicFlagClear");
771         case SpvOpImageSparseRead:
772             return SkString("ImageSparseRead");
773         default:
774             ABORT("unsupported SPIR-V op");
775     }
776 }
777 #endif
778 
writeOpCode(SpvOp_ opCode,int length,SkWStream & out)779 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, SkWStream& out) {
780     ASSERT(opCode != SpvOpUndef);
781     switch (opCode) {
782         case SpvOpReturn:      // fall through
783         case SpvOpReturnValue: // fall through
784         case SpvOpKill:        // fall through
785         case SpvOpBranch:      // fall through
786         case SpvOpBranchConditional:
787             ASSERT(fCurrentBlock);
788             fCurrentBlock = 0;
789             break;
790         case SpvOpConstant:          // fall through
791         case SpvOpConstantTrue:      // fall through
792         case SpvOpConstantFalse:     // fall through
793         case SpvOpConstantComposite: // fall through
794         case SpvOpTypeVoid:          // fall through
795         case SpvOpTypeInt:           // fall through
796         case SpvOpTypeFloat:         // fall through
797         case SpvOpTypeBool:          // fall through
798         case SpvOpTypeVector:        // fall through
799         case SpvOpTypeMatrix:        // fall through
800         case SpvOpTypeArray:         // fall through
801         case SpvOpTypePointer:       // fall through
802         case SpvOpTypeFunction:      // fall through
803         case SpvOpTypeRuntimeArray:  // fall through
804         case SpvOpTypeStruct:        // fall through
805         case SpvOpTypeImage:         // fall through
806         case SpvOpTypeSampledImage:  // fall through
807         case SpvOpVariable:          // fall through
808         case SpvOpFunction:          // fall through
809         case SpvOpFunctionParameter: // fall through
810         case SpvOpFunctionEnd:       // fall through
811         case SpvOpExecutionMode:     // fall through
812         case SpvOpMemoryModel:       // fall through
813         case SpvOpCapability:        // fall through
814         case SpvOpExtInstImport:     // fall through
815         case SpvOpEntryPoint:        // fall through
816         case SpvOpSource:            // fall through
817         case SpvOpSourceExtension:   // fall through
818         case SpvOpName:              // fall through
819         case SpvOpMemberName:        // fall through
820         case SpvOpDecorate:          // fall through
821         case SpvOpMemberDecorate:
822             break;
823         default:
824             ASSERT(fCurrentBlock);
825     }
826 #if SPIRV_DEBUG
827     out << std::endl << opcode_text(opCode) << " ";
828 #else
829     this->writeWord((length << 16) | opCode, out);
830 #endif
831 }
832 
writeLabel(SpvId label,SkWStream & out)833 void SPIRVCodeGenerator::writeLabel(SpvId label, SkWStream& out) {
834     fCurrentBlock = label;
835     this->writeInstruction(SpvOpLabel, label, out);
836 }
837 
writeInstruction(SpvOp_ opCode,SkWStream & out)838 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, SkWStream& out) {
839     this->writeOpCode(opCode, 1, out);
840 }
841 
writeInstruction(SpvOp_ opCode,int32_t word1,SkWStream & out)842 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, SkWStream& out) {
843     this->writeOpCode(opCode, 2, out);
844     this->writeWord(word1, out);
845 }
846 
writeString(const char * string,SkWStream & out)847 void SPIRVCodeGenerator::writeString(const char* string, SkWStream& out) {
848     size_t length = strlen(string);
849     out.writeText(string);
850     switch (length % 4) {
851         case 1:
852             out.write8(0);
853             // fall through
854         case 2:
855             out.write8(0);
856             // fall through
857         case 3:
858             out.write8(0);
859             break;
860         default:
861             this->writeWord(0, out);
862     }
863 }
864 
writeInstruction(SpvOp_ opCode,const char * string,SkWStream & out)865 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, const char* string, SkWStream& out) {
866     int32_t length = (int32_t) strlen(string);
867     this->writeOpCode(opCode, 1 + (length + 4) / 4, out);
868     this->writeString(string, out);
869 }
870 
871 
writeInstruction(SpvOp_ opCode,int32_t word1,const char * string,SkWStream & out)872 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, const char* string,
873                                           SkWStream& out) {
874     int32_t length = (int32_t) strlen(string);
875     this->writeOpCode(opCode, 2 + (length + 4) / 4, out);
876     this->writeWord(word1, out);
877     this->writeString(string, out);
878 }
879 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,const char * string,SkWStream & out)880 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
881                                           const char* string, SkWStream& out) {
882     int32_t length = (int32_t) strlen(string);
883     this->writeOpCode(opCode, 3 + (length + 4) / 4, out);
884     this->writeWord(word1, out);
885     this->writeWord(word2, out);
886     this->writeString(string, out);
887 }
888 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,SkWStream & out)889 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
890                                           SkWStream& out) {
891     this->writeOpCode(opCode, 3, out);
892     this->writeWord(word1, out);
893     this->writeWord(word2, out);
894 }
895 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,SkWStream & out)896 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
897                                           int32_t word3, SkWStream& out) {
898     this->writeOpCode(opCode, 4, out);
899     this->writeWord(word1, out);
900     this->writeWord(word2, out);
901     this->writeWord(word3, out);
902 }
903 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,SkWStream & out)904 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
905                                           int32_t word3, int32_t word4, SkWStream& out) {
906     this->writeOpCode(opCode, 5, out);
907     this->writeWord(word1, out);
908     this->writeWord(word2, out);
909     this->writeWord(word3, out);
910     this->writeWord(word4, out);
911 }
912 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,SkWStream & out)913 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
914                                           int32_t word3, int32_t word4, int32_t word5,
915                                           SkWStream& out) {
916     this->writeOpCode(opCode, 6, out);
917     this->writeWord(word1, out);
918     this->writeWord(word2, out);
919     this->writeWord(word3, out);
920     this->writeWord(word4, out);
921     this->writeWord(word5, out);
922 }
923 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,SkWStream & out)924 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
925                                           int32_t word3, int32_t word4, int32_t word5,
926                                           int32_t word6, SkWStream& out) {
927     this->writeOpCode(opCode, 7, out);
928     this->writeWord(word1, out);
929     this->writeWord(word2, out);
930     this->writeWord(word3, out);
931     this->writeWord(word4, out);
932     this->writeWord(word5, out);
933     this->writeWord(word6, out);
934 }
935 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,SkWStream & out)936 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
937                                           int32_t word3, int32_t word4, int32_t word5,
938                                           int32_t word6, int32_t word7, SkWStream& out) {
939     this->writeOpCode(opCode, 8, out);
940     this->writeWord(word1, out);
941     this->writeWord(word2, out);
942     this->writeWord(word3, out);
943     this->writeWord(word4, out);
944     this->writeWord(word5, out);
945     this->writeWord(word6, out);
946     this->writeWord(word7, out);
947 }
948 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,SkWStream & out)949 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
950                                           int32_t word3, int32_t word4, int32_t word5,
951                                           int32_t word6, int32_t word7, int32_t word8,
952                                           SkWStream& out) {
953     this->writeOpCode(opCode, 9, out);
954     this->writeWord(word1, out);
955     this->writeWord(word2, out);
956     this->writeWord(word3, out);
957     this->writeWord(word4, out);
958     this->writeWord(word5, out);
959     this->writeWord(word6, out);
960     this->writeWord(word7, out);
961     this->writeWord(word8, out);
962 }
963 
writeCapabilities(SkWStream & out)964 void SPIRVCodeGenerator::writeCapabilities(SkWStream& out) {
965     for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
966         if (fCapabilities & bit) {
967             this->writeInstruction(SpvOpCapability, (SpvId) i, out);
968         }
969     }
970 }
971 
nextId()972 SpvId SPIRVCodeGenerator::nextId() {
973     return fIdCount++;
974 }
975 
writeStruct(const Type & type,const MemoryLayout & memoryLayout,SpvId resultId)976 void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout,
977                                      SpvId resultId) {
978     this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer);
979     // go ahead and write all of the field types, so we don't inadvertently write them while we're
980     // in the middle of writing the struct instruction
981     std::vector<SpvId> types;
982     for (const auto& f : type.fields()) {
983         types.push_back(this->getType(*f.fType, memoryLayout));
984     }
985     this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
986     this->writeWord(resultId, fConstantBuffer);
987     for (SpvId id : types) {
988         this->writeWord(id, fConstantBuffer);
989     }
990     size_t offset = 0;
991     for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
992         size_t size = memoryLayout.size(*type.fields()[i].fType);
993         size_t alignment = memoryLayout.alignment(*type.fields()[i].fType);
994         const Layout& fieldLayout = type.fields()[i].fModifiers.fLayout;
995         if (fieldLayout.fOffset >= 0) {
996             if (fieldLayout.fOffset < (int) offset) {
997                 fErrors.error(type.fPosition,
998                               "offset of field '" + type.fields()[i].fName + "' must be at "
999                               "least " + to_string((int) offset));
1000             }
1001             if (fieldLayout.fOffset % alignment) {
1002                 fErrors.error(type.fPosition,
1003                               "offset of field '" + type.fields()[i].fName + "' must be a multiple"
1004                               " of " + to_string((int) alignment));
1005             }
1006             offset = fieldLayout.fOffset;
1007         } else {
1008             size_t mod = offset % alignment;
1009             if (mod) {
1010                 offset += alignment - mod;
1011             }
1012         }
1013         this->writeInstruction(SpvOpMemberName, resultId, i, type.fields()[i].fName.c_str(),
1014                                fNameBuffer);
1015         this->writeLayout(fieldLayout, resultId, i);
1016         if (type.fields()[i].fModifiers.fLayout.fBuiltin < 0) {
1017             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
1018                                    (SpvId) offset, fDecorationBuffer);
1019         }
1020         if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) {
1021             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
1022                                    fDecorationBuffer);
1023             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
1024                                    (SpvId) memoryLayout.stride(*type.fields()[i].fType),
1025                                    fDecorationBuffer);
1026         }
1027         offset += size;
1028         Type::Kind kind = type.fields()[i].fType->kind();
1029         if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) {
1030             offset += alignment - offset % alignment;
1031         }
1032     }
1033 }
1034 
getType(const Type & type)1035 SpvId SPIRVCodeGenerator::getType(const Type& type) {
1036     return this->getType(type, fDefaultLayout);
1037 }
1038 
getType(const Type & type,const MemoryLayout & layout)1039 SpvId SPIRVCodeGenerator::getType(const Type& type, const MemoryLayout& layout) {
1040     SkString key = type.name() + to_string((int) layout.fStd);
1041     auto entry = fTypeMap.find(key);
1042     if (entry == fTypeMap.end()) {
1043         SpvId result = this->nextId();
1044         switch (type.kind()) {
1045             case Type::kScalar_Kind:
1046                 if (type == *fContext.fBool_Type) {
1047                     this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
1048                 } else if (type == *fContext.fInt_Type) {
1049                     this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
1050                 } else if (type == *fContext.fUInt_Type) {
1051                     this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
1052                 } else if (type == *fContext.fFloat_Type) {
1053                     this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
1054                 } else if (type == *fContext.fDouble_Type) {
1055                     this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer);
1056                 } else {
1057                     ASSERT(false);
1058                 }
1059                 break;
1060             case Type::kVector_Kind:
1061                 this->writeInstruction(SpvOpTypeVector, result,
1062                                        this->getType(type.componentType(), layout),
1063                                        type.columns(), fConstantBuffer);
1064                 break;
1065             case Type::kMatrix_Kind:
1066                 this->writeInstruction(SpvOpTypeMatrix, result,
1067                                        this->getType(index_type(fContext, type), layout),
1068                                        type.columns(), fConstantBuffer);
1069                 break;
1070             case Type::kStruct_Kind:
1071                 this->writeStruct(type, layout, result);
1072                 break;
1073             case Type::kArray_Kind: {
1074                 if (type.columns() > 0) {
1075                     IntLiteral count(fContext, Position(), type.columns());
1076                     this->writeInstruction(SpvOpTypeArray, result,
1077                                            this->getType(type.componentType(), layout),
1078                                            this->writeIntLiteral(count), fConstantBuffer);
1079                     this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
1080                                            (int32_t) layout.stride(type),
1081                                            fDecorationBuffer);
1082                 } else {
1083                     ABORT("runtime-sized arrays are not yet supported");
1084                     this->writeInstruction(SpvOpTypeRuntimeArray, result,
1085                                            this->getType(type.componentType(), layout),
1086                                            fConstantBuffer);
1087                 }
1088                 break;
1089             }
1090             case Type::kSampler_Kind: {
1091                 SpvId image = result;
1092                 if (SpvDimSubpassData != type.dimensions()) {
1093                     image = this->nextId();
1094                 }
1095                 this->writeInstruction(SpvOpTypeImage, image,
1096                                        this->getType(*fContext.fFloat_Type, layout),
1097                                        type.dimensions(), type.isDepth(), type.isArrayed(),
1098                                        type.isMultisampled(), type.isSampled() ? 1 : 2,
1099                                        SpvImageFormatUnknown, fConstantBuffer);
1100                 if (SpvDimSubpassData != type.dimensions()) {
1101                     this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
1102                 }
1103                 break;
1104             }
1105             default:
1106                 if (type == *fContext.fVoid_Type) {
1107                     this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
1108                 } else {
1109                     ABORT("invalid type: %s", type.description().c_str());
1110                 }
1111         }
1112         fTypeMap[key] = result;
1113         return result;
1114     }
1115     return entry->second;
1116 }
1117 
getFunctionType(const FunctionDeclaration & function)1118 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
1119     SkString key = function.fReturnType.description() + "(";
1120     SkString separator;
1121     for (size_t i = 0; i < function.fParameters.size(); i++) {
1122         key += separator;
1123         separator = ", ";
1124         key += function.fParameters[i]->fType.description();
1125     }
1126     key += ")";
1127     auto entry = fTypeMap.find(key);
1128     if (entry == fTypeMap.end()) {
1129         SpvId result = this->nextId();
1130         int32_t length = 3 + (int32_t) function.fParameters.size();
1131         SpvId returnType = this->getType(function.fReturnType);
1132         std::vector<SpvId> parameterTypes;
1133         for (size_t i = 0; i < function.fParameters.size(); i++) {
1134             // glslang seems to treat all function arguments as pointers whether they need to be or
1135             // not. I  was initially puzzled by this until I ran bizarre failures with certain
1136             // patterns of function calls and control constructs, as exemplified by this minimal
1137             // failure case:
1138             //
1139             // void sphere(float x) {
1140             // }
1141             //
1142             // void map() {
1143             //     sphere(1.0);
1144             // }
1145             //
1146             // void main() {
1147             //     for (int i = 0; i < 1; i++) {
1148             //         map();
1149             //     }
1150             // }
1151             //
1152             // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
1153             // crashes. Making it take a float* and storing the argument in a temporary variable,
1154             // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
1155             // the spec makes this make sense.
1156 //            if (is_out(function->fParameters[i])) {
1157                 parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType,
1158                                                               SpvStorageClassFunction));
1159 //            } else {
1160 //                parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
1161 //            }
1162         }
1163         this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
1164         this->writeWord(result, fConstantBuffer);
1165         this->writeWord(returnType, fConstantBuffer);
1166         for (SpvId id : parameterTypes) {
1167             this->writeWord(id, fConstantBuffer);
1168         }
1169         fTypeMap[key] = result;
1170         return result;
1171     }
1172     return entry->second;
1173 }
1174 
getPointerType(const Type & type,SpvStorageClass_ storageClass)1175 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
1176     return this->getPointerType(type, fDefaultLayout, storageClass);
1177 }
1178 
getPointerType(const Type & type,const MemoryLayout & layout,SpvStorageClass_ storageClass)1179 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, const MemoryLayout& layout,
1180                                          SpvStorageClass_ storageClass) {
1181     SkString key = type.description() + "*" + to_string(layout.fStd) + to_string(storageClass);
1182     auto entry = fTypeMap.find(key);
1183     if (entry == fTypeMap.end()) {
1184         SpvId result = this->nextId();
1185         this->writeInstruction(SpvOpTypePointer, result, storageClass,
1186                                this->getType(type), fConstantBuffer);
1187         fTypeMap[key] = result;
1188         return result;
1189     }
1190     return entry->second;
1191 }
1192 
writeExpression(const Expression & expr,SkWStream & out)1193 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, SkWStream& out) {
1194     switch (expr.fKind) {
1195         case Expression::kBinary_Kind:
1196             return this->writeBinaryExpression((BinaryExpression&) expr, out);
1197         case Expression::kBoolLiteral_Kind:
1198             return this->writeBoolLiteral((BoolLiteral&) expr);
1199         case Expression::kConstructor_Kind:
1200             return this->writeConstructor((Constructor&) expr, out);
1201         case Expression::kIntLiteral_Kind:
1202             return this->writeIntLiteral((IntLiteral&) expr);
1203         case Expression::kFieldAccess_Kind:
1204             return this->writeFieldAccess(((FieldAccess&) expr), out);
1205         case Expression::kFloatLiteral_Kind:
1206             return this->writeFloatLiteral(((FloatLiteral&) expr));
1207         case Expression::kFunctionCall_Kind:
1208             return this->writeFunctionCall((FunctionCall&) expr, out);
1209         case Expression::kPrefix_Kind:
1210             return this->writePrefixExpression((PrefixExpression&) expr, out);
1211         case Expression::kPostfix_Kind:
1212             return this->writePostfixExpression((PostfixExpression&) expr, out);
1213         case Expression::kSwizzle_Kind:
1214             return this->writeSwizzle((Swizzle&) expr, out);
1215         case Expression::kVariableReference_Kind:
1216             return this->writeVariableReference((VariableReference&) expr, out);
1217         case Expression::kTernary_Kind:
1218             return this->writeTernaryExpression((TernaryExpression&) expr, out);
1219         case Expression::kIndex_Kind:
1220             return this->writeIndexExpression((IndexExpression&) expr, out);
1221         default:
1222             ABORT("unsupported expression: %s", expr.description().c_str());
1223     }
1224     return -1;
1225 }
1226 
writeIntrinsicCall(const FunctionCall & c,SkWStream & out)1227 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, SkWStream& out) {
1228     auto intrinsic = fIntrinsicMap.find(c.fFunction.fName);
1229     ASSERT(intrinsic != fIntrinsicMap.end());
1230     const Type& type = c.fArguments[0]->fType;
1231     int32_t intrinsicId;
1232     if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) {
1233         intrinsicId = std::get<1>(intrinsic->second);
1234     } else if (is_signed(fContext, type)) {
1235         intrinsicId = std::get<2>(intrinsic->second);
1236     } else if (is_unsigned(fContext, type)) {
1237         intrinsicId = std::get<3>(intrinsic->second);
1238     } else if (is_bool(fContext, type)) {
1239         intrinsicId = std::get<4>(intrinsic->second);
1240     } else {
1241         ABORT("invalid call %s, cannot operate on '%s'", c.description().c_str(),
1242               type.description().c_str());
1243     }
1244     switch (std::get<0>(intrinsic->second)) {
1245         case kGLSL_STD_450_IntrinsicKind: {
1246             SpvId result = this->nextId();
1247             std::vector<SpvId> arguments;
1248             for (size_t i = 0; i < c.fArguments.size(); i++) {
1249                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1250             }
1251             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
1252             this->writeWord(this->getType(c.fType), out);
1253             this->writeWord(result, out);
1254             this->writeWord(fGLSLExtendedInstructions, out);
1255             this->writeWord(intrinsicId, out);
1256             for (SpvId id : arguments) {
1257                 this->writeWord(id, out);
1258             }
1259             return result;
1260         }
1261         case kSPIRV_IntrinsicKind: {
1262             SpvId result = this->nextId();
1263             std::vector<SpvId> arguments;
1264             for (size_t i = 0; i < c.fArguments.size(); i++) {
1265                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1266             }
1267             this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
1268             this->writeWord(this->getType(c.fType), out);
1269             this->writeWord(result, out);
1270             for (SpvId id : arguments) {
1271                 this->writeWord(id, out);
1272             }
1273             return result;
1274         }
1275         case kSpecial_IntrinsicKind:
1276             return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
1277         default:
1278             ABORT("unsupported intrinsic kind");
1279     }
1280 }
1281 
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,SkWStream & out)1282 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
1283                                                 SkWStream& out) {
1284     SpvId result = this->nextId();
1285     switch (kind) {
1286         case kAtan_SpecialIntrinsic: {
1287             std::vector<SpvId> arguments;
1288             for (size_t i = 0; i < c.fArguments.size(); i++) {
1289                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1290             }
1291             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
1292             this->writeWord(this->getType(c.fType), out);
1293             this->writeWord(result, out);
1294             this->writeWord(fGLSLExtendedInstructions, out);
1295             this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
1296             for (SpvId id : arguments) {
1297                 this->writeWord(id, out);
1298             }
1299             return result;
1300         }
1301         case kTexture_SpecialIntrinsic: {
1302             SpvOp_ op = SpvOpImageSampleImplicitLod;
1303             switch (c.fArguments[0]->fType.dimensions()) {
1304                 case SpvDim1D:
1305                     if (c.fArguments[1]->fType == *fContext.fVec2_Type) {
1306                         op = SpvOpImageSampleProjImplicitLod;
1307                     } else {
1308                         ASSERT(c.fArguments[1]->fType == *fContext.fFloat_Type);
1309                     }
1310                     break;
1311                 case SpvDim2D:
1312                     if (c.fArguments[1]->fType == *fContext.fVec3_Type) {
1313                         op = SpvOpImageSampleProjImplicitLod;
1314                     } else {
1315                         ASSERT(c.fArguments[1]->fType == *fContext.fVec2_Type);
1316                     }
1317                     break;
1318                 case SpvDim3D:
1319                     if (c.fArguments[1]->fType == *fContext.fVec4_Type) {
1320                         op = SpvOpImageSampleProjImplicitLod;
1321                     } else {
1322                         ASSERT(c.fArguments[1]->fType == *fContext.fVec3_Type);
1323                     }
1324                     break;
1325                 case SpvDimCube:   // fall through
1326                 case SpvDimRect:   // fall through
1327                 case SpvDimBuffer: // fall through
1328                 case SpvDimSubpassData:
1329                     break;
1330             }
1331             SpvId type = this->getType(c.fType);
1332             SpvId sampler = this->writeExpression(*c.fArguments[0], out);
1333             SpvId uv = this->writeExpression(*c.fArguments[1], out);
1334             if (c.fArguments.size() == 3) {
1335                 this->writeInstruction(op, type, result, sampler, uv,
1336                                        SpvImageOperandsBiasMask,
1337                                        this->writeExpression(*c.fArguments[2], out),
1338                                        out);
1339             } else {
1340                 ASSERT(c.fArguments.size() == 2);
1341                 this->writeInstruction(op, type, result, sampler, uv,
1342                                        out);
1343             }
1344             break;
1345         }
1346         case kSubpassLoad_SpecialIntrinsic: {
1347             SpvId img = this->writeExpression(*c.fArguments[0], out);
1348             std::vector<std::unique_ptr<Expression>> args;
1349             args.emplace_back(new FloatLiteral(fContext, Position(), 0.0));
1350             args.emplace_back(new FloatLiteral(fContext, Position(), 0.0));
1351             Constructor ctor(Position(), *fContext.fVec2_Type, std::move(args));
1352             SpvId coords = this->writeConstantVector(ctor);
1353             if (1 == c.fArguments.size()) {
1354                 this->writeInstruction(SpvOpImageRead,
1355                                        this->getType(c.fType),
1356                                        result,
1357                                        img,
1358                                        coords,
1359                                        out);
1360             } else {
1361                 SkASSERT(2 == c.fArguments.size());
1362                 SpvId sample = this->writeExpression(*c.fArguments[1], out);
1363                 this->writeInstruction(SpvOpImageRead,
1364                                        this->getType(c.fType),
1365                                        result,
1366                                        img,
1367                                        coords,
1368                                        SpvImageOperandsSampleMask,
1369                                        sample,
1370                                        out);
1371             }
1372             break;
1373         }
1374     }
1375     return result;
1376 }
1377 
writeFunctionCall(const FunctionCall & c,SkWStream & out)1378 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, SkWStream& out) {
1379     const auto& entry = fFunctionMap.find(&c.fFunction);
1380     if (entry == fFunctionMap.end()) {
1381         return this->writeIntrinsicCall(c, out);
1382     }
1383     // stores (variable, type, lvalue) pairs to extract and save after the function call is complete
1384     std::vector<std::tuple<SpvId, SpvId, std::unique_ptr<LValue>>> lvalues;
1385     std::vector<SpvId> arguments;
1386     for (size_t i = 0; i < c.fArguments.size(); i++) {
1387         // id of temporary variable that we will use to hold this argument, or 0 if it is being
1388         // passed directly
1389         SpvId tmpVar;
1390         // if we need a temporary var to store this argument, this is the value to store in the var
1391         SpvId tmpValueId;
1392         if (is_out(*c.fFunction.fParameters[i])) {
1393             std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out);
1394             SpvId ptr = lv->getPointer();
1395             if (ptr) {
1396                 arguments.push_back(ptr);
1397                 continue;
1398             } else {
1399                 // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to
1400                 // copy it into a temp, call the function, read the value out of the temp, and then
1401                 // update the lvalue.
1402                 tmpValueId = lv->load(out);
1403                 tmpVar = this->nextId();
1404                 lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType),
1405                                   std::move(lv)));
1406             }
1407         } else {
1408             // see getFunctionType for an explanation of why we're always using pointer parameters
1409             tmpValueId = this->writeExpression(*c.fArguments[i], out);
1410             tmpVar = this->nextId();
1411         }
1412         this->writeInstruction(SpvOpVariable,
1413                                this->getPointerType(c.fArguments[i]->fType,
1414                                                     SpvStorageClassFunction),
1415                                tmpVar,
1416                                SpvStorageClassFunction,
1417                                fVariableBuffer);
1418         this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
1419         arguments.push_back(tmpVar);
1420     }
1421     SpvId result = this->nextId();
1422     this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out);
1423     this->writeWord(this->getType(c.fType), out);
1424     this->writeWord(result, out);
1425     this->writeWord(entry->second, out);
1426     for (SpvId id : arguments) {
1427         this->writeWord(id, out);
1428     }
1429     // now that the call is complete, we may need to update some lvalues with the new values of out
1430     // arguments
1431     for (const auto& tuple : lvalues) {
1432         SpvId load = this->nextId();
1433         this->writeInstruction(SpvOpLoad, std::get<1>(tuple), load, std::get<0>(tuple), out);
1434         std::get<2>(tuple)->store(load, out);
1435     }
1436     return result;
1437 }
1438 
writeConstantVector(const Constructor & c)1439 SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) {
1440     ASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant());
1441     SpvId result = this->nextId();
1442     std::vector<SpvId> arguments;
1443     for (size_t i = 0; i < c.fArguments.size(); i++) {
1444         arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer));
1445     }
1446     SpvId type = this->getType(c.fType);
1447     if (c.fArguments.size() == 1) {
1448         // with a single argument, a vector will have all of its entries equal to the argument
1449         this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer);
1450         this->writeWord(type, fConstantBuffer);
1451         this->writeWord(result, fConstantBuffer);
1452         for (int i = 0; i < c.fType.columns(); i++) {
1453             this->writeWord(arguments[0], fConstantBuffer);
1454         }
1455     } else {
1456         this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(),
1457                           fConstantBuffer);
1458         this->writeWord(type, fConstantBuffer);
1459         this->writeWord(result, fConstantBuffer);
1460         for (SpvId id : arguments) {
1461             this->writeWord(id, fConstantBuffer);
1462         }
1463     }
1464     return result;
1465 }
1466 
writeFloatConstructor(const Constructor & c,SkWStream & out)1467 SpvId SPIRVCodeGenerator::writeFloatConstructor(const Constructor& c, SkWStream& out) {
1468     ASSERT(c.fType == *fContext.fFloat_Type);
1469     ASSERT(c.fArguments.size() == 1);
1470     ASSERT(c.fArguments[0]->fType.isNumber());
1471     SpvId result = this->nextId();
1472     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1473     if (c.fArguments[0]->fType == *fContext.fInt_Type) {
1474         this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter,
1475                                out);
1476     } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) {
1477         this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter,
1478                                out);
1479     } else if (c.fArguments[0]->fType == *fContext.fFloat_Type) {
1480         return parameter;
1481     }
1482     return result;
1483 }
1484 
writeIntConstructor(const Constructor & c,SkWStream & out)1485 SpvId SPIRVCodeGenerator::writeIntConstructor(const Constructor& c, SkWStream& out) {
1486     ASSERT(c.fType == *fContext.fInt_Type);
1487     ASSERT(c.fArguments.size() == 1);
1488     ASSERT(c.fArguments[0]->fType.isNumber());
1489     SpvId result = this->nextId();
1490     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1491     if (c.fArguments[0]->fType == *fContext.fFloat_Type) {
1492         this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter,
1493                                out);
1494     } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) {
1495         this->writeInstruction(SpvOpSatConvertUToS, this->getType(c.fType), result, parameter,
1496                                out);
1497     } else if (c.fArguments[0]->fType == *fContext.fInt_Type) {
1498         return parameter;
1499     }
1500     return result;
1501 }
1502 
writeUniformScaleMatrix(SpvId id,SpvId diagonal,const Type & type,SkWStream & out)1503 void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
1504                                                  SkWStream& out) {
1505     FloatLiteral zero(fContext, Position(), 0);
1506     SpvId zeroId = this->writeFloatLiteral(zero);
1507     std::vector<SpvId> columnIds;
1508     for (int column = 0; column < type.columns(); column++) {
1509         this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
1510                           out);
1511         this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)),
1512                         out);
1513         SpvId columnId = this->nextId();
1514         this->writeWord(columnId, out);
1515         columnIds.push_back(columnId);
1516         for (int row = 0; row < type.columns(); row++) {
1517             this->writeWord(row == column ? diagonal : zeroId, out);
1518         }
1519     }
1520     this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
1521                       out);
1522     this->writeWord(this->getType(type), out);
1523     this->writeWord(id, out);
1524     for (SpvId id : columnIds) {
1525         this->writeWord(id, out);
1526     }
1527 }
1528 
writeMatrixCopy(SpvId id,SpvId src,const Type & srcType,const Type & dstType,SkWStream & out)1529 void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType,
1530                                          const Type& dstType, SkWStream& out) {
1531     ABORT("unimplemented");
1532 }
1533 
writeMatrixConstructor(const Constructor & c,SkWStream & out)1534 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, SkWStream& out) {
1535     ASSERT(c.fType.kind() == Type::kMatrix_Kind);
1536     // go ahead and write the arguments so we don't try to write new instructions in the middle of
1537     // an instruction
1538     std::vector<SpvId> arguments;
1539     for (size_t i = 0; i < c.fArguments.size(); i++) {
1540         arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1541     }
1542     SpvId result = this->nextId();
1543     int rows = c.fType.rows();
1544     int columns = c.fType.columns();
1545     if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1546         this->writeUniformScaleMatrix(result, arguments[0], c.fType, out);
1547     } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) {
1548         this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out);
1549     } else {
1550         std::vector<SpvId> columnIds;
1551         int currentCount = 0;
1552         for (size_t i = 0; i < arguments.size(); i++) {
1553             if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) {
1554                 ASSERT(currentCount == 0);
1555                 columnIds.push_back(arguments[i]);
1556                 currentCount = 0;
1557             } else {
1558                 ASSERT(c.fArguments[i]->fType.kind() == Type::kScalar_Kind);
1559                 if (currentCount == 0) {
1560                     this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.rows(), out);
1561                     this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows,
1562                                                                                      1)),
1563                                     out);
1564                     SpvId id = this->nextId();
1565                     this->writeWord(id, out);
1566                     columnIds.push_back(id);
1567                 }
1568                 this->writeWord(arguments[i], out);
1569                 currentCount = (currentCount + 1) % rows;
1570             }
1571         }
1572         ASSERT(columnIds.size() == (size_t) columns);
1573         this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
1574         this->writeWord(this->getType(c.fType), out);
1575         this->writeWord(result, out);
1576         for (SpvId id : columnIds) {
1577             this->writeWord(id, out);
1578         }
1579     }
1580     return result;
1581 }
1582 
writeVectorConstructor(const Constructor & c,SkWStream & out)1583 SpvId SPIRVCodeGenerator::writeVectorConstructor(const Constructor& c, SkWStream& out) {
1584     ASSERT(c.fType.kind() == Type::kVector_Kind);
1585     if (c.isConstant()) {
1586         return this->writeConstantVector(c);
1587     }
1588     // go ahead and write the arguments so we don't try to write new instructions in the middle of
1589     // an instruction
1590     std::vector<SpvId> arguments;
1591     for (size_t i = 0; i < c.fArguments.size(); i++) {
1592         arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1593     }
1594     SpvId result = this->nextId();
1595     if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1596         this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out);
1597         this->writeWord(this->getType(c.fType), out);
1598         this->writeWord(result, out);
1599         for (int i = 0; i < c.fType.columns(); i++) {
1600             this->writeWord(arguments[0], out);
1601         }
1602     } else {
1603         this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out);
1604         this->writeWord(this->getType(c.fType), out);
1605         this->writeWord(result, out);
1606         for (SpvId id : arguments) {
1607             this->writeWord(id, out);
1608         }
1609     }
1610     return result;
1611 }
1612 
writeConstructor(const Constructor & c,SkWStream & out)1613 SpvId SPIRVCodeGenerator::writeConstructor(const Constructor& c, SkWStream& out) {
1614     if (c.fType == *fContext.fFloat_Type) {
1615         return this->writeFloatConstructor(c, out);
1616     } else if (c.fType == *fContext.fInt_Type) {
1617         return this->writeIntConstructor(c, out);
1618     }
1619     switch (c.fType.kind()) {
1620         case Type::kVector_Kind:
1621             return this->writeVectorConstructor(c, out);
1622         case Type::kMatrix_Kind:
1623             return this->writeMatrixConstructor(c, out);
1624         default:
1625             ABORT("unsupported constructor: %s", c.description().c_str());
1626     }
1627 }
1628 
get_storage_class(const Modifiers & modifiers)1629 SpvStorageClass_ get_storage_class(const Modifiers& modifiers) {
1630     if (modifiers.fFlags & Modifiers::kIn_Flag) {
1631         ASSERT(!modifiers.fLayout.fPushConstant);
1632         return SpvStorageClassInput;
1633     } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
1634         ASSERT(!modifiers.fLayout.fPushConstant);
1635         return SpvStorageClassOutput;
1636     } else if (modifiers.fFlags & Modifiers::kUniform_Flag) {
1637         if (modifiers.fLayout.fPushConstant) {
1638             return SpvStorageClassPushConstant;
1639         }
1640         return SpvStorageClassUniform;
1641     } else {
1642         return SpvStorageClassFunction;
1643     }
1644 }
1645 
get_storage_class(const Expression & expr)1646 SpvStorageClass_ get_storage_class(const Expression& expr) {
1647     switch (expr.fKind) {
1648         case Expression::kVariableReference_Kind:
1649             return get_storage_class(((VariableReference&) expr).fVariable.fModifiers);
1650         case Expression::kFieldAccess_Kind:
1651             return get_storage_class(*((FieldAccess&) expr).fBase);
1652         case Expression::kIndex_Kind:
1653             return get_storage_class(*((IndexExpression&) expr).fBase);
1654         default:
1655             return SpvStorageClassFunction;
1656     }
1657 }
1658 
getAccessChain(const Expression & expr,SkWStream & out)1659 std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, SkWStream& out) {
1660     std::vector<SpvId> chain;
1661     switch (expr.fKind) {
1662         case Expression::kIndex_Kind: {
1663             IndexExpression& indexExpr = (IndexExpression&) expr;
1664             chain = this->getAccessChain(*indexExpr.fBase, out);
1665             chain.push_back(this->writeExpression(*indexExpr.fIndex, out));
1666             break;
1667         }
1668         case Expression::kFieldAccess_Kind: {
1669             FieldAccess& fieldExpr = (FieldAccess&) expr;
1670             chain = this->getAccessChain(*fieldExpr.fBase, out);
1671             IntLiteral index(fContext, Position(), fieldExpr.fFieldIndex);
1672             chain.push_back(this->writeIntLiteral(index));
1673             break;
1674         }
1675         default:
1676             chain.push_back(this->getLValue(expr, out)->getPointer());
1677     }
1678     return chain;
1679 }
1680 
1681 class PointerLValue : public SPIRVCodeGenerator::LValue {
1682 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,SpvId type)1683     PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type)
1684     : fGen(gen)
1685     , fPointer(pointer)
1686     , fType(type) {}
1687 
getPointer()1688     virtual SpvId getPointer() override {
1689         return fPointer;
1690     }
1691 
load(SkWStream & out)1692     virtual SpvId load(SkWStream& out) override {
1693         SpvId result = fGen.nextId();
1694         fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
1695         return result;
1696     }
1697 
store(SpvId value,SkWStream & out)1698     virtual void store(SpvId value, SkWStream& out) override {
1699         fGen.writeInstruction(SpvOpStore, fPointer, value, out);
1700     }
1701 
1702 private:
1703     SPIRVCodeGenerator& fGen;
1704     const SpvId fPointer;
1705     const SpvId fType;
1706 };
1707 
1708 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
1709 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const std::vector<int> & components,const Type & baseType,const Type & swizzleType)1710     SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components,
1711                   const Type& baseType, const Type& swizzleType)
1712     : fGen(gen)
1713     , fVecPointer(vecPointer)
1714     , fComponents(components)
1715     , fBaseType(baseType)
1716     , fSwizzleType(swizzleType) {}
1717 
getPointer()1718     virtual SpvId getPointer() override {
1719         return 0;
1720     }
1721 
load(SkWStream & out)1722     virtual SpvId load(SkWStream& out) override {
1723         SpvId base = fGen.nextId();
1724         fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1725         SpvId result = fGen.nextId();
1726         fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
1727         fGen.writeWord(fGen.getType(fSwizzleType), out);
1728         fGen.writeWord(result, out);
1729         fGen.writeWord(base, out);
1730         fGen.writeWord(base, out);
1731         for (int component : fComponents) {
1732             fGen.writeWord(component, out);
1733         }
1734         return result;
1735     }
1736 
store(SpvId value,SkWStream & out)1737     virtual void store(SpvId value, SkWStream& out) override {
1738         // use OpVectorShuffle to mix and match the vector components. We effectively create
1739         // a virtual vector out of the concatenation of the left and right vectors, and then
1740         // select components from this virtual vector to make the result vector. For
1741         // instance, given:
1742         // vec3 L = ...;
1743         // vec3 R = ...;
1744         // L.xz = R.xy;
1745         // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
1746         // our result vector to look like (R.x, L.y, R.y), so we need to select indices
1747         // (3, 1, 4).
1748         SpvId base = fGen.nextId();
1749         fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1750         SpvId shuffle = fGen.nextId();
1751         fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out);
1752         fGen.writeWord(fGen.getType(fBaseType), out);
1753         fGen.writeWord(shuffle, out);
1754         fGen.writeWord(base, out);
1755         fGen.writeWord(value, out);
1756         for (int i = 0; i < fBaseType.columns(); i++) {
1757             // current offset into the virtual vector, defaults to pulling the unmodified
1758             // value from the left side
1759             int offset = i;
1760             // check to see if we are writing this component
1761             for (size_t j = 0; j < fComponents.size(); j++) {
1762                 if (fComponents[j] == i) {
1763                     // we're writing to this component, so adjust the offset to pull from
1764                     // the correct component of the right side instead of preserving the
1765                     // value from the left
1766                     offset = (int) (j + fBaseType.columns());
1767                     break;
1768                 }
1769             }
1770             fGen.writeWord(offset, out);
1771         }
1772         fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
1773     }
1774 
1775 private:
1776     SPIRVCodeGenerator& fGen;
1777     const SpvId fVecPointer;
1778     const std::vector<int>& fComponents;
1779     const Type& fBaseType;
1780     const Type& fSwizzleType;
1781 };
1782 
getLValue(const Expression & expr,SkWStream & out)1783 std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
1784                                                                           SkWStream& out) {
1785     switch (expr.fKind) {
1786         case Expression::kVariableReference_Kind: {
1787             const Variable& var = ((VariableReference&) expr).fVariable;
1788             auto entry = fVariableMap.find(&var);
1789             ASSERT(entry != fVariableMap.end());
1790             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1791                                                                        *this,
1792                                                                        entry->second,
1793                                                                        this->getType(expr.fType)));
1794         }
1795         case Expression::kIndex_Kind: // fall through
1796         case Expression::kFieldAccess_Kind: {
1797             std::vector<SpvId> chain = this->getAccessChain(expr, out);
1798             SpvId member = this->nextId();
1799             this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
1800             this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out);
1801             this->writeWord(member, out);
1802             for (SpvId idx : chain) {
1803                 this->writeWord(idx, out);
1804             }
1805             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1806                                                                        *this,
1807                                                                        member,
1808                                                                        this->getType(expr.fType)));
1809         }
1810 
1811         case Expression::kSwizzle_Kind: {
1812             Swizzle& swizzle = (Swizzle&) expr;
1813             size_t count = swizzle.fComponents.size();
1814             SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer();
1815             ASSERT(base);
1816             if (count == 1) {
1817                 IntLiteral index(fContext, Position(), swizzle.fComponents[0]);
1818                 SpvId member = this->nextId();
1819                 this->writeInstruction(SpvOpAccessChain,
1820                                        this->getPointerType(swizzle.fType,
1821                                                             get_storage_class(*swizzle.fBase)),
1822                                        member,
1823                                        base,
1824                                        this->writeIntLiteral(index),
1825                                        out);
1826                 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1827                                                                        *this,
1828                                                                        member,
1829                                                                        this->getType(expr.fType)));
1830             } else {
1831                 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue(
1832                                                                               *this,
1833                                                                               base,
1834                                                                               swizzle.fComponents,
1835                                                                               swizzle.fBase->fType,
1836                                                                               expr.fType));
1837             }
1838         }
1839 
1840         default:
1841             // expr isn't actually an lvalue, create a dummy variable for it. This case happens due
1842             // to the need to store values in temporary variables during function calls (see
1843             // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been
1844             // caught by IRGenerator
1845             SpvId result = this->nextId();
1846             SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction);
1847             this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction,
1848                                    fVariableBuffer);
1849             this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
1850             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1851                                                                        *this,
1852                                                                        result,
1853                                                                        this->getType(expr.fType)));
1854     }
1855 }
1856 
writeVariableReference(const VariableReference & ref,SkWStream & out)1857 SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, SkWStream& out) {
1858     SpvId result = this->nextId();
1859     auto entry = fVariableMap.find(&ref.fVariable);
1860     ASSERT(entry != fVariableMap.end());
1861     SpvId var = entry->second;
1862     this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out);
1863     if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN &&
1864         fProgram.fSettings.fFlipY) {
1865         // need to remap to a top-left coordinate system
1866         if (fRTHeightStructId == (SpvId) -1) {
1867             // height variable hasn't been written yet
1868             std::shared_ptr<SymbolTable> st(new SymbolTable(fErrors));
1869             ASSERT(fRTHeightFieldIndex == (SpvId) -1);
1870             std::vector<Type::Field> fields;
1871             fields.emplace_back(Modifiers(), SkString(SKSL_RTHEIGHT_NAME),
1872                                 fContext.fFloat_Type.get());
1873             SkString name("sksl_synthetic_uniforms");
1874             Type intfStruct(Position(), name, fields);
1875             Layout layout(-1, -1, 1, -1, -1, -1, -1, false, false, false, Layout::Format::kUnspecified,
1876                           false, Layout::kUnspecified_Primitive, -1, -1);
1877             Variable intfVar(Position(), Modifiers(layout, Modifiers::kUniform_Flag), name,
1878                              intfStruct, Variable::kGlobal_Storage);
1879             InterfaceBlock intf(Position(), intfVar, name, SkString(""),
1880                                 std::vector<std::unique_ptr<Expression>>(), st);
1881             fRTHeightStructId = this->writeInterfaceBlock(intf);
1882             fRTHeightFieldIndex = 0;
1883         }
1884         ASSERT(fRTHeightFieldIndex != (SpvId) -1);
1885         // write vec4(gl_FragCoord.x, u_skRTHeight - gl_FragCoord.y, 0.0, 1.0)
1886         SpvId xId = this->nextId();
1887         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), xId,
1888                                result, 0, out);
1889         IntLiteral fieldIndex(fContext, Position(), fRTHeightFieldIndex);
1890         SpvId fieldIndexId = this->writeIntLiteral(fieldIndex);
1891         SpvId heightPtr = this->nextId();
1892         this->writeOpCode(SpvOpAccessChain, 5, out);
1893         this->writeWord(this->getPointerType(*fContext.fFloat_Type, SpvStorageClassUniform), out);
1894         this->writeWord(heightPtr, out);
1895         this->writeWord(fRTHeightStructId, out);
1896         this->writeWord(fieldIndexId, out);
1897         SpvId heightRead = this->nextId();
1898         this->writeInstruction(SpvOpLoad, this->getType(*fContext.fFloat_Type), heightRead,
1899                                heightPtr, out);
1900         SpvId rawYId = this->nextId();
1901         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), rawYId,
1902                                result, 1, out);
1903         SpvId flippedYId = this->nextId();
1904         this->writeInstruction(SpvOpFSub, this->getType(*fContext.fFloat_Type), flippedYId,
1905                                heightRead, rawYId, out);
1906         FloatLiteral zero(fContext, Position(), 0.0);
1907         SpvId zeroId = writeFloatLiteral(zero);
1908         FloatLiteral one(fContext, Position(), 1.0);
1909         SpvId oneId = writeFloatLiteral(one);
1910         SpvId flipped = this->nextId();
1911         this->writeOpCode(SpvOpCompositeConstruct, 7, out);
1912         this->writeWord(this->getType(*fContext.fVec4_Type), out);
1913         this->writeWord(flipped, out);
1914         this->writeWord(xId, out);
1915         this->writeWord(flippedYId, out);
1916         this->writeWord(zeroId, out);
1917         this->writeWord(oneId, out);
1918         return flipped;
1919     }
1920     return result;
1921 }
1922 
writeIndexExpression(const IndexExpression & expr,SkWStream & out)1923 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, SkWStream& out) {
1924     return getLValue(expr, out)->load(out);
1925 }
1926 
writeFieldAccess(const FieldAccess & f,SkWStream & out)1927 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, SkWStream& out) {
1928     return getLValue(f, out)->load(out);
1929 }
1930 
writeSwizzle(const Swizzle & swizzle,SkWStream & out)1931 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, SkWStream& out) {
1932     SpvId base = this->writeExpression(*swizzle.fBase, out);
1933     SpvId result = this->nextId();
1934     size_t count = swizzle.fComponents.size();
1935     if (count == 1) {
1936         this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base,
1937                                swizzle.fComponents[0], out);
1938     } else {
1939         this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
1940         this->writeWord(this->getType(swizzle.fType), out);
1941         this->writeWord(result, out);
1942         this->writeWord(base, out);
1943         this->writeWord(base, out);
1944         for (int component : swizzle.fComponents) {
1945             this->writeWord(component, out);
1946         }
1947     }
1948     return result;
1949 }
1950 
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,SkWStream & out)1951 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
1952                                                const Type& operandType, SpvId lhs,
1953                                                SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
1954                                                SpvOp_ ifUInt, SpvOp_ ifBool, SkWStream& out) {
1955     SpvId result = this->nextId();
1956     if (is_float(fContext, operandType)) {
1957         this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
1958     } else if (is_signed(fContext, operandType)) {
1959         this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
1960     } else if (is_unsigned(fContext, operandType)) {
1961         this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
1962     } else if (operandType == *fContext.fBool_Type) {
1963         this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
1964     } else {
1965         ABORT("invalid operandType: %s", operandType.description().c_str());
1966     }
1967     return result;
1968 }
1969 
is_assignment(Token::Kind op)1970 bool is_assignment(Token::Kind op) {
1971     switch (op) {
1972         case Token::EQ:           // fall through
1973         case Token::PLUSEQ:       // fall through
1974         case Token::MINUSEQ:      // fall through
1975         case Token::STAREQ:       // fall through
1976         case Token::SLASHEQ:      // fall through
1977         case Token::PERCENTEQ:    // fall through
1978         case Token::SHLEQ:        // fall through
1979         case Token::SHREQ:        // fall through
1980         case Token::BITWISEOREQ:  // fall through
1981         case Token::BITWISEXOREQ: // fall through
1982         case Token::BITWISEANDEQ: // fall through
1983         case Token::LOGICALOREQ:  // fall through
1984         case Token::LOGICALXOREQ: // fall through
1985         case Token::LOGICALANDEQ:
1986             return true;
1987         default:
1988             return false;
1989     }
1990 }
1991 
foldToBool(SpvId id,const Type & operandType,SkWStream & out)1992 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SkWStream& out) {
1993     if (operandType.kind() == Type::kVector_Kind) {
1994         SpvId result = this->nextId();
1995         this->writeInstruction(SpvOpAll, this->getType(*fContext.fBool_Type), result, id, out);
1996         return result;
1997     }
1998     return id;
1999 }
2000 
writeBinaryExpression(const BinaryExpression & b,SkWStream & out)2001 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, SkWStream& out) {
2002     // handle cases where we don't necessarily evaluate both LHS and RHS
2003     switch (b.fOperator) {
2004         case Token::EQ: {
2005             SpvId rhs = this->writeExpression(*b.fRight, out);
2006             this->getLValue(*b.fLeft, out)->store(rhs, out);
2007             return rhs;
2008         }
2009         case Token::LOGICALAND:
2010             return this->writeLogicalAnd(b, out);
2011         case Token::LOGICALOR:
2012             return this->writeLogicalOr(b, out);
2013         default:
2014             break;
2015     }
2016 
2017     // "normal" operators
2018     const Type& resultType = b.fType;
2019     std::unique_ptr<LValue> lvalue;
2020     SpvId lhs;
2021     if (is_assignment(b.fOperator)) {
2022         lvalue = this->getLValue(*b.fLeft, out);
2023         lhs = lvalue->load(out);
2024     } else {
2025         lvalue = nullptr;
2026         lhs = this->writeExpression(*b.fLeft, out);
2027     }
2028     SpvId rhs = this->writeExpression(*b.fRight, out);
2029     // component type we are operating on: float, int, uint
2030     const Type* operandType;
2031     // IR allows mismatched types in expressions (e.g. vec2 * float), but they need special handling
2032     // in SPIR-V
2033     if (b.fLeft->fType != b.fRight->fType) {
2034         if (b.fLeft->fType.kind() == Type::kVector_Kind &&
2035             b.fRight->fType.isNumber()) {
2036             // promote number to vector
2037             SpvId vec = this->nextId();
2038             this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
2039             this->writeWord(this->getType(resultType), out);
2040             this->writeWord(vec, out);
2041             for (int i = 0; i < resultType.columns(); i++) {
2042                 this->writeWord(rhs, out);
2043             }
2044             rhs = vec;
2045             operandType = &b.fRight->fType;
2046         } else if (b.fRight->fType.kind() == Type::kVector_Kind &&
2047                    b.fLeft->fType.isNumber()) {
2048             // promote number to vector
2049             SpvId vec = this->nextId();
2050             this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
2051             this->writeWord(this->getType(resultType), out);
2052             this->writeWord(vec, out);
2053             for (int i = 0; i < resultType.columns(); i++) {
2054                 this->writeWord(lhs, out);
2055             }
2056             lhs = vec;
2057             ASSERT(!lvalue);
2058             operandType = &b.fLeft->fType;
2059         } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) {
2060             SpvOp_ op;
2061             if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
2062                 op = SpvOpMatrixTimesMatrix;
2063             } else if (b.fRight->fType.kind() == Type::kVector_Kind) {
2064                 op = SpvOpMatrixTimesVector;
2065             } else {
2066                 ASSERT(b.fRight->fType.kind() == Type::kScalar_Kind);
2067                 op = SpvOpMatrixTimesScalar;
2068             }
2069             SpvId result = this->nextId();
2070             this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out);
2071             if (b.fOperator == Token::STAREQ) {
2072                 lvalue->store(result, out);
2073             } else {
2074                 ASSERT(b.fOperator == Token::STAR);
2075             }
2076             return result;
2077         } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
2078             SpvId result = this->nextId();
2079             if (b.fLeft->fType.kind() == Type::kVector_Kind) {
2080                 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result,
2081                                        lhs, rhs, out);
2082             } else {
2083                 ASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind);
2084                 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs,
2085                                        lhs, out);
2086             }
2087             if (b.fOperator == Token::STAREQ) {
2088                 lvalue->store(result, out);
2089             } else {
2090                 ASSERT(b.fOperator == Token::STAR);
2091             }
2092             return result;
2093         } else {
2094             ABORT("unsupported binary expression: %s", b.description().c_str());
2095         }
2096     } else {
2097         operandType = &b.fLeft->fType;
2098         ASSERT(*operandType == b.fRight->fType);
2099     }
2100     switch (b.fOperator) {
2101         case Token::EQEQ: {
2102             ASSERT(resultType == *fContext.fBool_Type);
2103             return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2104                                                                SpvOpFOrdEqual, SpvOpIEqual,
2105                                                                SpvOpIEqual, SpvOpLogicalEqual, out),
2106                                     *operandType, out);
2107         }
2108         case Token::NEQ:
2109             ASSERT(resultType == *fContext.fBool_Type);
2110             return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2111                                                                SpvOpFOrdNotEqual, SpvOpINotEqual,
2112                                                                SpvOpINotEqual, SpvOpLogicalNotEqual,
2113                                                                out),
2114                                     *operandType, out);
2115         case Token::GT:
2116             ASSERT(resultType == *fContext.fBool_Type);
2117             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2118                                               SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
2119                                               SpvOpUGreaterThan, SpvOpUndef, out);
2120         case Token::LT:
2121             ASSERT(resultType == *fContext.fBool_Type);
2122             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
2123                                               SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
2124         case Token::GTEQ:
2125             ASSERT(resultType == *fContext.fBool_Type);
2126             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2127                                               SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
2128                                               SpvOpUGreaterThanEqual, SpvOpUndef, out);
2129         case Token::LTEQ:
2130             ASSERT(resultType == *fContext.fBool_Type);
2131             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2132                                               SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
2133                                               SpvOpULessThanEqual, SpvOpUndef, out);
2134         case Token::PLUS:
2135             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2136                                               SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2137         case Token::MINUS:
2138             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2139                                               SpvOpISub, SpvOpISub, SpvOpUndef, out);
2140         case Token::STAR:
2141             if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2142                 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2143                 // matrix multiply
2144                 SpvId result = this->nextId();
2145                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2146                                        lhs, rhs, out);
2147                 return result;
2148             }
2149             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2150                                               SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2151         case Token::SLASH:
2152             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2153                                               SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2154         case Token::PLUSEQ: {
2155             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2156                                                       SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2157             ASSERT(lvalue);
2158             lvalue->store(result, out);
2159             return result;
2160         }
2161         case Token::MINUSEQ: {
2162             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2163                                                       SpvOpISub, SpvOpISub, SpvOpUndef, out);
2164             ASSERT(lvalue);
2165             lvalue->store(result, out);
2166             return result;
2167         }
2168         case Token::STAREQ: {
2169             if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2170                 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2171                 // matrix multiply
2172                 SpvId result = this->nextId();
2173                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2174                                        lhs, rhs, out);
2175                 ASSERT(lvalue);
2176                 lvalue->store(result, out);
2177                 return result;
2178             }
2179             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2180                                                       SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2181             ASSERT(lvalue);
2182             lvalue->store(result, out);
2183             return result;
2184         }
2185         case Token::SLASHEQ: {
2186             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2187                                                       SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2188             ASSERT(lvalue);
2189             lvalue->store(result, out);
2190             return result;
2191         }
2192         default:
2193             // FIXME: missing support for some operators (bitwise, &&=, ||=, shift...)
2194             ABORT("unsupported binary expression: %s", b.description().c_str());
2195     }
2196 }
2197 
writeLogicalAnd(const BinaryExpression & a,SkWStream & out)2198 SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, SkWStream& out) {
2199     ASSERT(a.fOperator == Token::LOGICALAND);
2200     BoolLiteral falseLiteral(fContext, Position(), false);
2201     SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
2202     SpvId lhs = this->writeExpression(*a.fLeft, out);
2203     SpvId rhsLabel = this->nextId();
2204     SpvId end = this->nextId();
2205     SpvId lhsBlock = fCurrentBlock;
2206     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2207     this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
2208     this->writeLabel(rhsLabel, out);
2209     SpvId rhs = this->writeExpression(*a.fRight, out);
2210     SpvId rhsBlock = fCurrentBlock;
2211     this->writeInstruction(SpvOpBranch, end, out);
2212     this->writeLabel(end, out);
2213     SpvId result = this->nextId();
2214     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant,
2215                            lhsBlock, rhs, rhsBlock, out);
2216     return result;
2217 }
2218 
writeLogicalOr(const BinaryExpression & o,SkWStream & out)2219 SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, SkWStream& out) {
2220     ASSERT(o.fOperator == Token::LOGICALOR);
2221     BoolLiteral trueLiteral(fContext, Position(), true);
2222     SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
2223     SpvId lhs = this->writeExpression(*o.fLeft, out);
2224     SpvId rhsLabel = this->nextId();
2225     SpvId end = this->nextId();
2226     SpvId lhsBlock = fCurrentBlock;
2227     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2228     this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
2229     this->writeLabel(rhsLabel, out);
2230     SpvId rhs = this->writeExpression(*o.fRight, out);
2231     SpvId rhsBlock = fCurrentBlock;
2232     this->writeInstruction(SpvOpBranch, end, out);
2233     this->writeLabel(end, out);
2234     SpvId result = this->nextId();
2235     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant,
2236                            lhsBlock, rhs, rhsBlock, out);
2237     return result;
2238 }
2239 
writeTernaryExpression(const TernaryExpression & t,SkWStream & out)2240 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, SkWStream& out) {
2241     SpvId test = this->writeExpression(*t.fTest, out);
2242     if (t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) {
2243         // both true and false are constants, can just use OpSelect
2244         SpvId result = this->nextId();
2245         SpvId trueId = this->writeExpression(*t.fIfTrue, out);
2246         SpvId falseId = this->writeExpression(*t.fIfFalse, out);
2247         this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId,
2248                                out);
2249         return result;
2250     }
2251     // was originally using OpPhi to choose the result, but for some reason that is crashing on
2252     // Adreno. Switched to storing the result in a temp variable as glslang does.
2253     SpvId var = this->nextId();
2254     this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction),
2255                            var, SpvStorageClassFunction, fVariableBuffer);
2256     SpvId trueLabel = this->nextId();
2257     SpvId falseLabel = this->nextId();
2258     SpvId end = this->nextId();
2259     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2260     this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
2261     this->writeLabel(trueLabel, out);
2262     this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out);
2263     this->writeInstruction(SpvOpBranch, end, out);
2264     this->writeLabel(falseLabel, out);
2265     this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out);
2266     this->writeInstruction(SpvOpBranch, end, out);
2267     this->writeLabel(end, out);
2268     SpvId result = this->nextId();
2269     this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out);
2270     return result;
2271 }
2272 
create_literal_1(const Context & context,const Type & type)2273 std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
2274     if (type == *context.fInt_Type) {
2275         return std::unique_ptr<Expression>(new IntLiteral(context, Position(), 1));
2276     }
2277     else if (type == *context.fFloat_Type) {
2278         return std::unique_ptr<Expression>(new FloatLiteral(context, Position(), 1.0));
2279     } else {
2280         ABORT("math is unsupported on type '%s'")
2281     }
2282 }
2283 
writePrefixExpression(const PrefixExpression & p,SkWStream & out)2284 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, SkWStream& out) {
2285     if (p.fOperator == Token::MINUS) {
2286         SpvId result = this->nextId();
2287         SpvId typeId = this->getType(p.fType);
2288         SpvId expr = this->writeExpression(*p.fOperand, out);
2289         if (is_float(fContext, p.fType)) {
2290             this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
2291         } else if (is_signed(fContext, p.fType)) {
2292             this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
2293         } else {
2294             ABORT("unsupported prefix expression %s", p.description().c_str());
2295         };
2296         return result;
2297     }
2298     switch (p.fOperator) {
2299         case Token::PLUS:
2300             return this->writeExpression(*p.fOperand, out);
2301         case Token::PLUSPLUS: {
2302             std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2303             SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2304             SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2305                                                       SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
2306                                                       out);
2307             lv->store(result, out);
2308             return result;
2309         }
2310         case Token::MINUSMINUS: {
2311             std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2312             SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2313             SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2314                                                       SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef,
2315                                                       out);
2316             lv->store(result, out);
2317             return result;
2318         }
2319         case Token::LOGICALNOT: {
2320             ASSERT(p.fOperand->fType == *fContext.fBool_Type);
2321             SpvId result = this->nextId();
2322             this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result,
2323                                    this->writeExpression(*p.fOperand, out), out);
2324             return result;
2325         }
2326         case Token::BITWISENOT: {
2327             SpvId result = this->nextId();
2328             this->writeInstruction(SpvOpNot, this->getType(p.fOperand->fType), result,
2329                                    this->writeExpression(*p.fOperand, out), out);
2330             return result;
2331         }
2332         default:
2333             ABORT("unsupported prefix expression: %s", p.description().c_str());
2334     }
2335 }
2336 
writePostfixExpression(const PostfixExpression & p,SkWStream & out)2337 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, SkWStream& out) {
2338     std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2339     SpvId result = lv->load(out);
2340     SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2341     switch (p.fOperator) {
2342         case Token::PLUSPLUS: {
2343             SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd,
2344                                                     SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2345             lv->store(temp, out);
2346             return result;
2347         }
2348         case Token::MINUSMINUS: {
2349             SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub,
2350                                                     SpvOpISub, SpvOpISub, SpvOpUndef, out);
2351             lv->store(temp, out);
2352             return result;
2353         }
2354         default:
2355             ABORT("unsupported postfix expression %s", p.description().c_str());
2356     }
2357 }
2358 
writeBoolLiteral(const BoolLiteral & b)2359 SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
2360     if (b.fValue) {
2361         if (fBoolTrue == 0) {
2362             fBoolTrue = this->nextId();
2363             this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue,
2364                                    fConstantBuffer);
2365         }
2366         return fBoolTrue;
2367     } else {
2368         if (fBoolFalse == 0) {
2369             fBoolFalse = this->nextId();
2370             this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse,
2371                                    fConstantBuffer);
2372         }
2373         return fBoolFalse;
2374     }
2375 }
2376 
writeIntLiteral(const IntLiteral & i)2377 SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) {
2378     if (i.fType == *fContext.fInt_Type) {
2379         auto entry = fIntConstants.find(i.fValue);
2380         if (entry == fIntConstants.end()) {
2381             SpvId result = this->nextId();
2382             this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2383                                    fConstantBuffer);
2384             fIntConstants[i.fValue] = result;
2385             return result;
2386         }
2387         return entry->second;
2388     } else {
2389         ASSERT(i.fType == *fContext.fUInt_Type);
2390         auto entry = fUIntConstants.find(i.fValue);
2391         if (entry == fUIntConstants.end()) {
2392             SpvId result = this->nextId();
2393             this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2394                                    fConstantBuffer);
2395             fUIntConstants[i.fValue] = result;
2396             return result;
2397         }
2398         return entry->second;
2399     }
2400 }
2401 
writeFloatLiteral(const FloatLiteral & f)2402 SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
2403     if (f.fType == *fContext.fFloat_Type) {
2404         float value = (float) f.fValue;
2405         auto entry = fFloatConstants.find(value);
2406         if (entry == fFloatConstants.end()) {
2407             SpvId result = this->nextId();
2408             uint32_t bits;
2409             ASSERT(sizeof(bits) == sizeof(value));
2410             memcpy(&bits, &value, sizeof(bits));
2411             this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits,
2412                                    fConstantBuffer);
2413             fFloatConstants[value] = result;
2414             return result;
2415         }
2416         return entry->second;
2417     } else {
2418         ASSERT(f.fType == *fContext.fDouble_Type);
2419         auto entry = fDoubleConstants.find(f.fValue);
2420         if (entry == fDoubleConstants.end()) {
2421             SpvId result = this->nextId();
2422             uint64_t bits;
2423             ASSERT(sizeof(bits) == sizeof(f.fValue));
2424             memcpy(&bits, &f.fValue, sizeof(bits));
2425             this->writeInstruction(SpvOpConstant, this->getType(f.fType), result,
2426                                    bits & 0xffffffff, bits >> 32, fConstantBuffer);
2427             fDoubleConstants[f.fValue] = result;
2428             return result;
2429         }
2430         return entry->second;
2431     }
2432 }
2433 
writeFunctionStart(const FunctionDeclaration & f,SkWStream & out)2434 SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, SkWStream& out) {
2435     SpvId result = fFunctionMap[&f];
2436     this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result,
2437                            SpvFunctionControlMaskNone, this->getFunctionType(f), out);
2438     this->writeInstruction(SpvOpName, result, f.fName.c_str(), fNameBuffer);
2439     for (size_t i = 0; i < f.fParameters.size(); i++) {
2440         SpvId id = this->nextId();
2441         fVariableMap[f.fParameters[i]] = id;
2442         SpvId type;
2443         type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction);
2444         this->writeInstruction(SpvOpFunctionParameter, type, id, out);
2445     }
2446     return result;
2447 }
2448 
writeFunction(const FunctionDefinition & f,SkWStream & out)2449 SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, SkWStream& out) {
2450     SpvId result = this->writeFunctionStart(f.fDeclaration, out);
2451     this->writeLabel(this->nextId(), out);
2452     if (f.fDeclaration.fName == "main") {
2453         write_data(*fGlobalInitializersBuffer.detachAsData(), out);
2454     }
2455     SkDynamicMemoryWStream bodyBuffer;
2456     this->writeBlock(*f.fBody, bodyBuffer);
2457     write_data(*fVariableBuffer.detachAsData(), out);
2458     write_data(*bodyBuffer.detachAsData(), out);
2459     if (fCurrentBlock) {
2460         this->writeInstruction(SpvOpReturn, out);
2461     }
2462     this->writeInstruction(SpvOpFunctionEnd, out);
2463     return result;
2464 }
2465 
writeLayout(const Layout & layout,SpvId target)2466 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) {
2467     if (layout.fLocation >= 0) {
2468         this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
2469                                fDecorationBuffer);
2470     }
2471     if (layout.fBinding >= 0) {
2472         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
2473                                fDecorationBuffer);
2474     }
2475     if (layout.fIndex >= 0) {
2476         this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
2477                                fDecorationBuffer);
2478     }
2479     if (layout.fSet >= 0) {
2480         this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
2481                                fDecorationBuffer);
2482     }
2483     if (layout.fInputAttachmentIndex >= 0) {
2484         this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
2485                                layout.fInputAttachmentIndex, fDecorationBuffer);
2486     }
2487     if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN) {
2488         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
2489                                fDecorationBuffer);
2490     }
2491 }
2492 
writeLayout(const Layout & layout,SpvId target,int member)2493 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) {
2494     if (layout.fLocation >= 0) {
2495         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
2496                                layout.fLocation, fDecorationBuffer);
2497     }
2498     if (layout.fBinding >= 0) {
2499         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding,
2500                                layout.fBinding, fDecorationBuffer);
2501     }
2502     if (layout.fIndex >= 0) {
2503         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
2504                                layout.fIndex, fDecorationBuffer);
2505     }
2506     if (layout.fSet >= 0) {
2507         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet,
2508                                layout.fSet, fDecorationBuffer);
2509     }
2510     if (layout.fInputAttachmentIndex >= 0) {
2511         this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
2512                                layout.fInputAttachmentIndex, fDecorationBuffer);
2513     }
2514     if (layout.fBuiltin >= 0) {
2515         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
2516                                layout.fBuiltin, fDecorationBuffer);
2517     }
2518 }
2519 
writeInterfaceBlock(const InterfaceBlock & intf)2520 SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2521     MemoryLayout layout = intf.fVariable.fModifiers.fLayout.fPushConstant ?
2522                           MemoryLayout(MemoryLayout::k430_Standard) :
2523                           fDefaultLayout;
2524     SpvId result = this->nextId();
2525     const Type* type = &intf.fVariable.fType;
2526     if (fProgram.fInputs.fRTHeight) {
2527         ASSERT(fRTHeightStructId == (SpvId) -1);
2528         ASSERT(fRTHeightFieldIndex == (SpvId) -1);
2529         std::vector<Type::Field> fields = type->fields();
2530         fRTHeightStructId = result;
2531         fRTHeightFieldIndex = fields.size();
2532         fields.emplace_back(Modifiers(), SkString(SKSL_RTHEIGHT_NAME), fContext.fFloat_Type.get());
2533         type = new Type(type->fPosition, type->name(), fields);
2534     }
2535     SpvId typeId = this->getType(*type, layout);
2536     this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
2537     SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers);
2538     SpvId ptrType = this->nextId();
2539     this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
2540     this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
2541     this->writeLayout(intf.fVariable.fModifiers.fLayout, result);
2542     fVariableMap[&intf.fVariable] = result;
2543     if (fProgram.fInputs.fRTHeight) {
2544         delete type;
2545     }
2546     return result;
2547 }
2548 
2549 #define BUILTIN_IGNORE 9999
writeGlobalVars(Program::Kind kind,const VarDeclarations & decl,SkWStream & out)2550 void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclarations& decl,
2551                                          SkWStream& out) {
2552     for (size_t i = 0; i < decl.fVars.size(); i++) {
2553         const VarDeclaration& varDecl = decl.fVars[i];
2554         const Variable* var = varDecl.fVar;
2555         // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2556         // in the OpenGL backend.
2557         ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2558                                            Modifiers::kWriteOnly_Flag |
2559                                            Modifiers::kCoherent_Flag |
2560                                            Modifiers::kVolatile_Flag |
2561                                            Modifiers::kRestrict_Flag)));
2562         if (var->fModifiers.fLayout.fBuiltin == BUILTIN_IGNORE) {
2563             continue;
2564         }
2565         if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
2566             kind != Program::kFragment_Kind) {
2567             continue;
2568         }
2569         if (!var->fReadCount && !var->fWriteCount &&
2570                 !(var->fModifiers.fFlags & (Modifiers::kIn_Flag |
2571                                             Modifiers::kOut_Flag |
2572                                             Modifiers::kUniform_Flag))) {
2573             // variable is dead and not an input / output var (the Vulkan debug layers complain if
2574             // we elide an interface var, even if it's dead)
2575             continue;
2576         }
2577         SpvStorageClass_ storageClass;
2578         if (var->fModifiers.fFlags & Modifiers::kIn_Flag) {
2579             storageClass = SpvStorageClassInput;
2580         } else if (var->fModifiers.fFlags & Modifiers::kOut_Flag) {
2581             storageClass = SpvStorageClassOutput;
2582         } else if (var->fModifiers.fFlags & Modifiers::kUniform_Flag) {
2583             if (var->fType.kind() == Type::kSampler_Kind) {
2584                 storageClass = SpvStorageClassUniformConstant;
2585             } else {
2586                 storageClass = SpvStorageClassUniform;
2587             }
2588         } else {
2589             storageClass = SpvStorageClassPrivate;
2590         }
2591         SpvId id = this->nextId();
2592         fVariableMap[var] = id;
2593         SpvId type = this->getPointerType(var->fType, storageClass);
2594         this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer);
2595         this->writeInstruction(SpvOpName, id, var->fName.c_str(), fNameBuffer);
2596         if (var->fType.kind() == Type::kMatrix_Kind) {
2597             this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationColMajor,
2598                                    fDecorationBuffer);
2599             this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationMatrixStride,
2600                                    (SpvId) fDefaultLayout.stride(var->fType), fDecorationBuffer);
2601         }
2602         if (varDecl.fValue) {
2603             ASSERT(!fCurrentBlock);
2604             fCurrentBlock = -1;
2605             SpvId value = this->writeExpression(*varDecl.fValue, fGlobalInitializersBuffer);
2606             this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
2607             fCurrentBlock = 0;
2608         }
2609         this->writeLayout(var->fModifiers.fLayout, id);
2610     }
2611 }
2612 
writeVarDeclarations(const VarDeclarations & decl,SkWStream & out)2613 void SPIRVCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, SkWStream& out) {
2614     for (const auto& varDecl : decl.fVars) {
2615         const Variable* var = varDecl.fVar;
2616         // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2617         // in the OpenGL backend.
2618         ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2619                                            Modifiers::kWriteOnly_Flag |
2620                                            Modifiers::kCoherent_Flag |
2621                                            Modifiers::kVolatile_Flag |
2622                                            Modifiers::kRestrict_Flag)));
2623         SpvId id = this->nextId();
2624         fVariableMap[var] = id;
2625         SpvId type = this->getPointerType(var->fType, SpvStorageClassFunction);
2626         this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
2627         this->writeInstruction(SpvOpName, id, var->fName.c_str(), fNameBuffer);
2628         if (varDecl.fValue) {
2629             SpvId value = this->writeExpression(*varDecl.fValue, out);
2630             this->writeInstruction(SpvOpStore, id, value, out);
2631         }
2632     }
2633 }
2634 
writeStatement(const Statement & s,SkWStream & out)2635 void SPIRVCodeGenerator::writeStatement(const Statement& s, SkWStream& out) {
2636     switch (s.fKind) {
2637         case Statement::kBlock_Kind:
2638             this->writeBlock((Block&) s, out);
2639             break;
2640         case Statement::kExpression_Kind:
2641             this->writeExpression(*((ExpressionStatement&) s).fExpression, out);
2642             break;
2643         case Statement::kReturn_Kind:
2644             this->writeReturnStatement((ReturnStatement&) s, out);
2645             break;
2646         case Statement::kVarDeclarations_Kind:
2647             this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, out);
2648             break;
2649         case Statement::kIf_Kind:
2650             this->writeIfStatement((IfStatement&) s, out);
2651             break;
2652         case Statement::kFor_Kind:
2653             this->writeForStatement((ForStatement&) s, out);
2654             break;
2655         case Statement::kWhile_Kind:
2656             this->writeWhileStatement((WhileStatement&) s, out);
2657             break;
2658         case Statement::kDo_Kind:
2659             this->writeDoStatement((DoStatement&) s, out);
2660             break;
2661         case Statement::kBreak_Kind:
2662             this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
2663             break;
2664         case Statement::kContinue_Kind:
2665             this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
2666             break;
2667         case Statement::kDiscard_Kind:
2668             this->writeInstruction(SpvOpKill, out);
2669             break;
2670         default:
2671             ABORT("unsupported statement: %s", s.description().c_str());
2672     }
2673 }
2674 
writeBlock(const Block & b,SkWStream & out)2675 void SPIRVCodeGenerator::writeBlock(const Block& b, SkWStream& out) {
2676     for (size_t i = 0; i < b.fStatements.size(); i++) {
2677         this->writeStatement(*b.fStatements[i], out);
2678     }
2679 }
2680 
writeIfStatement(const IfStatement & stmt,SkWStream & out)2681 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, SkWStream& out) {
2682     SpvId test = this->writeExpression(*stmt.fTest, out);
2683     SpvId ifTrue = this->nextId();
2684     SpvId ifFalse = this->nextId();
2685     if (stmt.fIfFalse) {
2686         SpvId end = this->nextId();
2687         this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2688         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2689         this->writeLabel(ifTrue, out);
2690         this->writeStatement(*stmt.fIfTrue, out);
2691         if (fCurrentBlock) {
2692             this->writeInstruction(SpvOpBranch, end, out);
2693         }
2694         this->writeLabel(ifFalse, out);
2695         this->writeStatement(*stmt.fIfFalse, out);
2696         if (fCurrentBlock) {
2697             this->writeInstruction(SpvOpBranch, end, out);
2698         }
2699         this->writeLabel(end, out);
2700     } else {
2701         this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
2702         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2703         this->writeLabel(ifTrue, out);
2704         this->writeStatement(*stmt.fIfTrue, out);
2705         if (fCurrentBlock) {
2706             this->writeInstruction(SpvOpBranch, ifFalse, out);
2707         }
2708         this->writeLabel(ifFalse, out);
2709     }
2710 }
2711 
writeForStatement(const ForStatement & f,SkWStream & out)2712 void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, SkWStream& out) {
2713     if (f.fInitializer) {
2714         this->writeStatement(*f.fInitializer, out);
2715     }
2716     SpvId header = this->nextId();
2717     SpvId start = this->nextId();
2718     SpvId body = this->nextId();
2719     SpvId next = this->nextId();
2720     fContinueTarget.push(next);
2721     SpvId end = this->nextId();
2722     fBreakTarget.push(end);
2723     this->writeInstruction(SpvOpBranch, header, out);
2724     this->writeLabel(header, out);
2725     this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
2726     this->writeInstruction(SpvOpBranch, start, out);
2727     this->writeLabel(start, out);
2728     if (f.fTest) {
2729         SpvId test = this->writeExpression(*f.fTest, out);
2730         this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2731     }
2732     this->writeLabel(body, out);
2733     this->writeStatement(*f.fStatement, out);
2734     if (fCurrentBlock) {
2735         this->writeInstruction(SpvOpBranch, next, out);
2736     }
2737     this->writeLabel(next, out);
2738     if (f.fNext) {
2739         this->writeExpression(*f.fNext, out);
2740     }
2741     this->writeInstruction(SpvOpBranch, header, out);
2742     this->writeLabel(end, out);
2743     fBreakTarget.pop();
2744     fContinueTarget.pop();
2745 }
2746 
writeWhileStatement(const WhileStatement & w,SkWStream & out)2747 void SPIRVCodeGenerator::writeWhileStatement(const WhileStatement& w, SkWStream& out) {
2748     // We believe the while loop code below will work, but Skia doesn't actually use them and
2749     // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2750     // the time being, we just fail with an error due to the lack of testing. If you encounter this
2751     // message, simply remove the error call below to see whether our while loop support actually
2752     // works.
2753     fErrors.error(w.fPosition, "internal error: while loop support has been disabled in SPIR-V, "
2754                   "see SkSLSPIRVCodeGenerator.cpp for details");
2755 
2756     SpvId header = this->nextId();
2757     SpvId start = this->nextId();
2758     SpvId body = this->nextId();
2759     fContinueTarget.push(start);
2760     SpvId end = this->nextId();
2761     fBreakTarget.push(end);
2762     this->writeInstruction(SpvOpBranch, header, out);
2763     this->writeLabel(header, out);
2764     this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2765     this->writeInstruction(SpvOpBranch, start, out);
2766     this->writeLabel(start, out);
2767     SpvId test = this->writeExpression(*w.fTest, out);
2768     this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2769     this->writeLabel(body, out);
2770     this->writeStatement(*w.fStatement, out);
2771     if (fCurrentBlock) {
2772         this->writeInstruction(SpvOpBranch, start, out);
2773     }
2774     this->writeLabel(end, out);
2775     fBreakTarget.pop();
2776     fContinueTarget.pop();
2777 }
2778 
writeDoStatement(const DoStatement & d,SkWStream & out)2779 void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, SkWStream& out) {
2780     // We believe the do loop code below will work, but Skia doesn't actually use them and
2781     // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2782     // the time being, we just fail with an error due to the lack of testing. If you encounter this
2783     // message, simply remove the error call below to see whether our do loop support actually
2784     // works.
2785     fErrors.error(d.fPosition, "internal error: do loop support has been disabled in SPIR-V, see "
2786                   "SkSLSPIRVCodeGenerator.cpp for details");
2787 
2788     SpvId header = this->nextId();
2789     SpvId start = this->nextId();
2790     SpvId next = this->nextId();
2791     fContinueTarget.push(next);
2792     SpvId end = this->nextId();
2793     fBreakTarget.push(end);
2794     this->writeInstruction(SpvOpBranch, header, out);
2795     this->writeLabel(header, out);
2796     this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2797     this->writeInstruction(SpvOpBranch, start, out);
2798     this->writeLabel(start, out);
2799     this->writeStatement(*d.fStatement, out);
2800     if (fCurrentBlock) {
2801         this->writeInstruction(SpvOpBranch, next, out);
2802     }
2803     this->writeLabel(next, out);
2804     SpvId test = this->writeExpression(*d.fTest, out);
2805     this->writeInstruction(SpvOpBranchConditional, test, start, end, out);
2806     this->writeLabel(end, out);
2807     fBreakTarget.pop();
2808     fContinueTarget.pop();
2809 }
2810 
writeReturnStatement(const ReturnStatement & r,SkWStream & out)2811 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, SkWStream& out) {
2812     if (r.fExpression) {
2813         this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out),
2814                                out);
2815     } else {
2816         this->writeInstruction(SpvOpReturn, out);
2817     }
2818 }
2819 
writeInstructions(const Program & program,SkWStream & out)2820 void SPIRVCodeGenerator::writeInstructions(const Program& program, SkWStream& out) {
2821     fGLSLExtendedInstructions = this->nextId();
2822     SkDynamicMemoryWStream body;
2823     std::set<SpvId> interfaceVars;
2824     // assign IDs to functions
2825     for (size_t i = 0; i < program.fElements.size(); i++) {
2826         if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
2827             FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i];
2828             fFunctionMap[&f.fDeclaration] = this->nextId();
2829         }
2830     }
2831     for (size_t i = 0; i < program.fElements.size(); i++) {
2832         if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) {
2833             InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i];
2834             SpvId id = this->writeInterfaceBlock(intf);
2835             if ((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) ||
2836                 (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) {
2837                 interfaceVars.insert(id);
2838             }
2839         }
2840     }
2841     for (size_t i = 0; i < program.fElements.size(); i++) {
2842         if (program.fElements[i]->fKind == ProgramElement::kVar_Kind) {
2843             this->writeGlobalVars(program.fKind, ((VarDeclarations&) *program.fElements[i]),
2844                                   body);
2845         }
2846     }
2847     for (size_t i = 0; i < program.fElements.size(); i++) {
2848         if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
2849             this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body);
2850         }
2851     }
2852     const FunctionDeclaration* main = nullptr;
2853     for (auto entry : fFunctionMap) {
2854         if (entry.first->fName == "main") {
2855             main = entry.first;
2856         }
2857     }
2858     ASSERT(main);
2859     for (auto entry : fVariableMap) {
2860         const Variable* var = entry.first;
2861         if (var->fStorage == Variable::kGlobal_Storage &&
2862                 ((var->fModifiers.fFlags & Modifiers::kIn_Flag) ||
2863                  (var->fModifiers.fFlags & Modifiers::kOut_Flag))) {
2864             interfaceVars.insert(entry.second);
2865         }
2866     }
2867     this->writeCapabilities(out);
2868     this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
2869     this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
2870     this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (strlen(main->fName.c_str()) + 4) / 4) +
2871                       (int32_t) interfaceVars.size(), out);
2872     switch (program.fKind) {
2873         case Program::kVertex_Kind:
2874             this->writeWord(SpvExecutionModelVertex, out);
2875             break;
2876         case Program::kFragment_Kind:
2877             this->writeWord(SpvExecutionModelFragment, out);
2878             break;
2879         case Program::kGeometry_Kind:
2880             this->writeWord(SpvExecutionModelGeometry, out);
2881             break;
2882     }
2883     this->writeWord(fFunctionMap[main], out);
2884     this->writeString(main->fName.c_str(), out);
2885     for (int var : interfaceVars) {
2886         this->writeWord(var, out);
2887     }
2888     if (program.fKind == Program::kFragment_Kind) {
2889         this->writeInstruction(SpvOpExecutionMode,
2890                                fFunctionMap[main],
2891                                SpvExecutionModeOriginUpperLeft,
2892                                out);
2893     }
2894     for (size_t i = 0; i < program.fElements.size(); i++) {
2895         if (program.fElements[i]->fKind == ProgramElement::kExtension_Kind) {
2896             this->writeInstruction(SpvOpSourceExtension,
2897                                    ((Extension&) *program.fElements[i]).fName.c_str(),
2898                                    out);
2899         }
2900     }
2901 
2902     write_data(*fExtraGlobalsBuffer.detachAsData(), out);
2903     write_data(*fNameBuffer.detachAsData(), out);
2904     write_data(*fDecorationBuffer.detachAsData(), out);
2905     write_data(*fConstantBuffer.detachAsData(), out);
2906     write_data(*fExternalFunctionsBuffer.detachAsData(), out);
2907     write_data(*body.detachAsData(), out);
2908 }
2909 
generateCode()2910 bool SPIRVCodeGenerator::generateCode() {
2911     ASSERT(!fErrors.errorCount());
2912     this->writeWord(SpvMagicNumber, *fOut);
2913     this->writeWord(SpvVersion, *fOut);
2914     this->writeWord(SKSL_MAGIC, *fOut);
2915     SkDynamicMemoryWStream buffer;
2916     this->writeInstructions(fProgram, buffer);
2917     this->writeWord(fIdCount, *fOut);
2918     this->writeWord(0, *fOut); // reserved, always zero
2919     write_data(*buffer.detachAsData(), *fOut);
2920     return 0 == fErrors.errorCount();
2921 }
2922 
2923 }
2924