1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "SkSLSPIRVCodeGenerator.h"
9
10 #include "GLSL.std.450.h"
11
12 #include "ir/SkSLExpressionStatement.h"
13 #include "ir/SkSLExtension.h"
14 #include "ir/SkSLIndexExpression.h"
15 #include "ir/SkSLVariableReference.h"
16 #include "SkSLCompiler.h"
17
18 namespace SkSL {
19
20 #define SPIRV_DEBUG 0
21
22 static const int32_t SKSL_MAGIC = 0x0; // FIXME: we should probably register a magic number
23
setupIntrinsics()24 void SPIRVCodeGenerator::setupIntrinsics() {
25 #define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \
26 GLSLstd450 ## x, GLSLstd450 ## x)
27 #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \
28 GLSLstd450 ## ifFloat, \
29 GLSLstd450 ## ifInt, \
30 GLSLstd450 ## ifUInt, \
31 SpvOpUndef)
32 #define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \
33 k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
34 k ## x ## _SpecialIntrinsic)
35 fIntrinsicMap[String("round")] = ALL_GLSL(Round);
36 fIntrinsicMap[String("roundEven")] = ALL_GLSL(RoundEven);
37 fIntrinsicMap[String("trunc")] = ALL_GLSL(Trunc);
38 fIntrinsicMap[String("abs")] = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
39 fIntrinsicMap[String("sign")] = BY_TYPE_GLSL(FSign, SSign, SSign);
40 fIntrinsicMap[String("floor")] = ALL_GLSL(Floor);
41 fIntrinsicMap[String("ceil")] = ALL_GLSL(Ceil);
42 fIntrinsicMap[String("fract")] = ALL_GLSL(Fract);
43 fIntrinsicMap[String("radians")] = ALL_GLSL(Radians);
44 fIntrinsicMap[String("degrees")] = ALL_GLSL(Degrees);
45 fIntrinsicMap[String("sin")] = ALL_GLSL(Sin);
46 fIntrinsicMap[String("cos")] = ALL_GLSL(Cos);
47 fIntrinsicMap[String("tan")] = ALL_GLSL(Tan);
48 fIntrinsicMap[String("asin")] = ALL_GLSL(Asin);
49 fIntrinsicMap[String("acos")] = ALL_GLSL(Acos);
50 fIntrinsicMap[String("atan")] = SPECIAL(Atan);
51 fIntrinsicMap[String("sinh")] = ALL_GLSL(Sinh);
52 fIntrinsicMap[String("cosh")] = ALL_GLSL(Cosh);
53 fIntrinsicMap[String("tanh")] = ALL_GLSL(Tanh);
54 fIntrinsicMap[String("asinh")] = ALL_GLSL(Asinh);
55 fIntrinsicMap[String("acosh")] = ALL_GLSL(Acosh);
56 fIntrinsicMap[String("atanh")] = ALL_GLSL(Atanh);
57 fIntrinsicMap[String("pow")] = ALL_GLSL(Pow);
58 fIntrinsicMap[String("exp")] = ALL_GLSL(Exp);
59 fIntrinsicMap[String("log")] = ALL_GLSL(Log);
60 fIntrinsicMap[String("exp2")] = ALL_GLSL(Exp2);
61 fIntrinsicMap[String("log2")] = ALL_GLSL(Log2);
62 fIntrinsicMap[String("sqrt")] = ALL_GLSL(Sqrt);
63 fIntrinsicMap[String("inversesqrt")] = ALL_GLSL(InverseSqrt);
64 fIntrinsicMap[String("determinant")] = ALL_GLSL(Determinant);
65 fIntrinsicMap[String("matrixInverse")] = ALL_GLSL(MatrixInverse);
66 fIntrinsicMap[String("mod")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpFMod,
67 SpvOpSMod, SpvOpUMod, SpvOpUndef);
68 fIntrinsicMap[String("min")] = BY_TYPE_GLSL(FMin, SMin, UMin);
69 fIntrinsicMap[String("max")] = BY_TYPE_GLSL(FMax, SMax, UMax);
70 fIntrinsicMap[String("clamp")] = BY_TYPE_GLSL(FClamp, SClamp, UClamp);
71 fIntrinsicMap[String("dot")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot,
72 SpvOpUndef, SpvOpUndef, SpvOpUndef);
73 fIntrinsicMap[String("mix")] = ALL_GLSL(FMix);
74 fIntrinsicMap[String("step")] = ALL_GLSL(Step);
75 fIntrinsicMap[String("smoothstep")] = ALL_GLSL(SmoothStep);
76 fIntrinsicMap[String("fma")] = ALL_GLSL(Fma);
77 fIntrinsicMap[String("frexp")] = ALL_GLSL(Frexp);
78 fIntrinsicMap[String("ldexp")] = ALL_GLSL(Ldexp);
79
80 #define PACK(type) fIntrinsicMap[String("pack" #type)] = ALL_GLSL(Pack ## type); \
81 fIntrinsicMap[String("unpack" #type)] = ALL_GLSL(Unpack ## type)
82 PACK(Snorm4x8);
83 PACK(Unorm4x8);
84 PACK(Snorm2x16);
85 PACK(Unorm2x16);
86 PACK(Half2x16);
87 PACK(Double2x32);
88 fIntrinsicMap[String("length")] = ALL_GLSL(Length);
89 fIntrinsicMap[String("distance")] = ALL_GLSL(Distance);
90 fIntrinsicMap[String("cross")] = ALL_GLSL(Cross);
91 fIntrinsicMap[String("normalize")] = ALL_GLSL(Normalize);
92 fIntrinsicMap[String("faceForward")] = ALL_GLSL(FaceForward);
93 fIntrinsicMap[String("reflect")] = ALL_GLSL(Reflect);
94 fIntrinsicMap[String("refract")] = ALL_GLSL(Refract);
95 fIntrinsicMap[String("findLSB")] = ALL_GLSL(FindILsb);
96 fIntrinsicMap[String("findMSB")] = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
97 fIntrinsicMap[String("dFdx")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx,
98 SpvOpUndef, SpvOpUndef, SpvOpUndef);
99 fIntrinsicMap[String("dFdy")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
100 SpvOpUndef, SpvOpUndef, SpvOpUndef);
101 fIntrinsicMap[String("dFdy")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
102 SpvOpUndef, SpvOpUndef, SpvOpUndef);
103 fIntrinsicMap[String("texture")] = SPECIAL(Texture);
104 fIntrinsicMap[String("texelFetch")] = SPECIAL(TexelFetch);
105 fIntrinsicMap[String("subpassLoad")] = SPECIAL(SubpassLoad);
106
107 fIntrinsicMap[String("any")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
108 SpvOpUndef, SpvOpUndef, SpvOpAny);
109 fIntrinsicMap[String("all")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
110 SpvOpUndef, SpvOpUndef, SpvOpAll);
111 fIntrinsicMap[String("equal")] = std::make_tuple(kSPIRV_IntrinsicKind,
112 SpvOpFOrdEqual, SpvOpIEqual,
113 SpvOpIEqual, SpvOpLogicalEqual);
114 fIntrinsicMap[String("notEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
115 SpvOpFOrdNotEqual, SpvOpINotEqual,
116 SpvOpINotEqual,
117 SpvOpLogicalNotEqual);
118 fIntrinsicMap[String("lessThan")] = std::make_tuple(kSPIRV_IntrinsicKind,
119 SpvOpSLessThan, SpvOpULessThan,
120 SpvOpFOrdLessThan, SpvOpUndef);
121 fIntrinsicMap[String("lessThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
122 SpvOpSLessThanEqual,
123 SpvOpULessThanEqual,
124 SpvOpFOrdLessThanEqual,
125 SpvOpUndef);
126 fIntrinsicMap[String("greaterThan")] = std::make_tuple(kSPIRV_IntrinsicKind,
127 SpvOpSGreaterThan,
128 SpvOpUGreaterThan,
129 SpvOpFOrdGreaterThan,
130 SpvOpUndef);
131 fIntrinsicMap[String("greaterThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
132 SpvOpSGreaterThanEqual,
133 SpvOpUGreaterThanEqual,
134 SpvOpFOrdGreaterThanEqual,
135 SpvOpUndef);
136 // interpolateAt* not yet supported...
137 }
138
writeWord(int32_t word,OutputStream & out)139 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
140 #if SPIRV_DEBUG
141 out << "(" << word << ") ";
142 #else
143 out.write((const char*) &word, sizeof(word));
144 #endif
145 }
146
is_float(const Context & context,const Type & type)147 static bool is_float(const Context& context, const Type& type) {
148 if (type.kind() == Type::kVector_Kind) {
149 return is_float(context, type.componentType());
150 }
151 return type == *context.fFloat_Type || type == *context.fDouble_Type;
152 }
153
is_signed(const Context & context,const Type & type)154 static bool is_signed(const Context& context, const Type& type) {
155 if (type.kind() == Type::kVector_Kind) {
156 return is_signed(context, type.componentType());
157 }
158 return type == *context.fInt_Type;
159 }
160
is_unsigned(const Context & context,const Type & type)161 static bool is_unsigned(const Context& context, const Type& type) {
162 if (type.kind() == Type::kVector_Kind) {
163 return is_unsigned(context, type.componentType());
164 }
165 return type == *context.fUInt_Type;
166 }
167
is_bool(const Context & context,const Type & type)168 static bool is_bool(const Context& context, const Type& type) {
169 if (type.kind() == Type::kVector_Kind) {
170 return is_bool(context, type.componentType());
171 }
172 return type == *context.fBool_Type;
173 }
174
is_out(const Variable & var)175 static bool is_out(const Variable& var) {
176 return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
177 }
178
#if SPIRV_DEBUG
// Debug-only: returns the name of a SPIR-V opcode without its "SpvOp" prefix
// (e.g. SpvOpFAdd -> "FAdd"). Aborts on opcodes not listed here.
static String opcode_text(SpvOp_ opCode) {
    // Each case maps SpvOp<Name> to the string "<Name>".
#define SKSL_SPV_OPCODE_NAME(x) case SpvOp ## x: return String(#x);
    switch (opCode) {
        SKSL_SPV_OPCODE_NAME(Nop)
        SKSL_SPV_OPCODE_NAME(Undef)
        SKSL_SPV_OPCODE_NAME(SourceContinued)
        SKSL_SPV_OPCODE_NAME(Source)
        SKSL_SPV_OPCODE_NAME(SourceExtension)
        SKSL_SPV_OPCODE_NAME(Name)
        SKSL_SPV_OPCODE_NAME(MemberName)
        SKSL_SPV_OPCODE_NAME(String)
        SKSL_SPV_OPCODE_NAME(Line)
        SKSL_SPV_OPCODE_NAME(Extension)
        SKSL_SPV_OPCODE_NAME(ExtInstImport)
        SKSL_SPV_OPCODE_NAME(ExtInst)
        SKSL_SPV_OPCODE_NAME(MemoryModel)
        SKSL_SPV_OPCODE_NAME(EntryPoint)
        SKSL_SPV_OPCODE_NAME(ExecutionMode)
        SKSL_SPV_OPCODE_NAME(Capability)
        SKSL_SPV_OPCODE_NAME(TypeVoid)
        SKSL_SPV_OPCODE_NAME(TypeBool)
        SKSL_SPV_OPCODE_NAME(TypeInt)
        SKSL_SPV_OPCODE_NAME(TypeFloat)
        SKSL_SPV_OPCODE_NAME(TypeVector)
        SKSL_SPV_OPCODE_NAME(TypeMatrix)
        SKSL_SPV_OPCODE_NAME(TypeImage)
        SKSL_SPV_OPCODE_NAME(TypeSampler)
        SKSL_SPV_OPCODE_NAME(TypeSampledImage)
        SKSL_SPV_OPCODE_NAME(TypeArray)
        SKSL_SPV_OPCODE_NAME(TypeRuntimeArray)
        SKSL_SPV_OPCODE_NAME(TypeStruct)
        SKSL_SPV_OPCODE_NAME(TypeOpaque)
        SKSL_SPV_OPCODE_NAME(TypePointer)
        SKSL_SPV_OPCODE_NAME(TypeFunction)
        SKSL_SPV_OPCODE_NAME(TypeEvent)
        SKSL_SPV_OPCODE_NAME(TypeDeviceEvent)
        SKSL_SPV_OPCODE_NAME(TypeReserveId)
        SKSL_SPV_OPCODE_NAME(TypeQueue)
        SKSL_SPV_OPCODE_NAME(TypePipe)
        SKSL_SPV_OPCODE_NAME(TypeForwardPointer)
        SKSL_SPV_OPCODE_NAME(ConstantTrue)
        SKSL_SPV_OPCODE_NAME(ConstantFalse)
        SKSL_SPV_OPCODE_NAME(Constant)
        SKSL_SPV_OPCODE_NAME(ConstantComposite)
        SKSL_SPV_OPCODE_NAME(ConstantSampler)
        SKSL_SPV_OPCODE_NAME(ConstantNull)
        SKSL_SPV_OPCODE_NAME(SpecConstantTrue)
        SKSL_SPV_OPCODE_NAME(SpecConstantFalse)
        SKSL_SPV_OPCODE_NAME(SpecConstant)
        SKSL_SPV_OPCODE_NAME(SpecConstantComposite)
        SKSL_SPV_OPCODE_NAME(SpecConstantOp)
        SKSL_SPV_OPCODE_NAME(Function)
        SKSL_SPV_OPCODE_NAME(FunctionParameter)
        SKSL_SPV_OPCODE_NAME(FunctionEnd)
        SKSL_SPV_OPCODE_NAME(FunctionCall)
        SKSL_SPV_OPCODE_NAME(Variable)
        SKSL_SPV_OPCODE_NAME(ImageTexelPointer)
        SKSL_SPV_OPCODE_NAME(Load)
        SKSL_SPV_OPCODE_NAME(Store)
        SKSL_SPV_OPCODE_NAME(CopyMemory)
        SKSL_SPV_OPCODE_NAME(CopyMemorySized)
        SKSL_SPV_OPCODE_NAME(AccessChain)
        SKSL_SPV_OPCODE_NAME(InBoundsAccessChain)
        SKSL_SPV_OPCODE_NAME(PtrAccessChain)
        SKSL_SPV_OPCODE_NAME(ArrayLength)
        SKSL_SPV_OPCODE_NAME(GenericPtrMemSemantics)
        SKSL_SPV_OPCODE_NAME(InBoundsPtrAccessChain)
        SKSL_SPV_OPCODE_NAME(Decorate)
        SKSL_SPV_OPCODE_NAME(MemberDecorate)
        SKSL_SPV_OPCODE_NAME(DecorationGroup)
        SKSL_SPV_OPCODE_NAME(GroupDecorate)
        SKSL_SPV_OPCODE_NAME(GroupMemberDecorate)
        SKSL_SPV_OPCODE_NAME(VectorExtractDynamic)
        SKSL_SPV_OPCODE_NAME(VectorInsertDynamic)
        SKSL_SPV_OPCODE_NAME(VectorShuffle)
        SKSL_SPV_OPCODE_NAME(CompositeConstruct)
        SKSL_SPV_OPCODE_NAME(CompositeExtract)
        SKSL_SPV_OPCODE_NAME(CompositeInsert)
        SKSL_SPV_OPCODE_NAME(CopyObject)
        SKSL_SPV_OPCODE_NAME(Transpose)
        SKSL_SPV_OPCODE_NAME(SampledImage)
        SKSL_SPV_OPCODE_NAME(ImageSampleImplicitLod)
        SKSL_SPV_OPCODE_NAME(ImageSampleExplicitLod)
        SKSL_SPV_OPCODE_NAME(ImageSampleDrefImplicitLod)
        SKSL_SPV_OPCODE_NAME(ImageSampleDrefExplicitLod)
        SKSL_SPV_OPCODE_NAME(ImageSampleProjImplicitLod)
        SKSL_SPV_OPCODE_NAME(ImageSampleProjExplicitLod)
        SKSL_SPV_OPCODE_NAME(ImageSampleProjDrefImplicitLod)
        SKSL_SPV_OPCODE_NAME(ImageSampleProjDrefExplicitLod)
        SKSL_SPV_OPCODE_NAME(ImageFetch)
        SKSL_SPV_OPCODE_NAME(ImageGather)
        SKSL_SPV_OPCODE_NAME(ImageDrefGather)
        SKSL_SPV_OPCODE_NAME(ImageRead)
        SKSL_SPV_OPCODE_NAME(ImageWrite)
        SKSL_SPV_OPCODE_NAME(Image)
        SKSL_SPV_OPCODE_NAME(ImageQueryFormat)
        SKSL_SPV_OPCODE_NAME(ImageQueryOrder)
        SKSL_SPV_OPCODE_NAME(ImageQuerySizeLod)
        SKSL_SPV_OPCODE_NAME(ImageQuerySize)
        SKSL_SPV_OPCODE_NAME(ImageQueryLod)
        SKSL_SPV_OPCODE_NAME(ImageQueryLevels)
        SKSL_SPV_OPCODE_NAME(ImageQuerySamples)
        SKSL_SPV_OPCODE_NAME(ConvertFToU)
        SKSL_SPV_OPCODE_NAME(ConvertFToS)
        SKSL_SPV_OPCODE_NAME(ConvertSToF)
        SKSL_SPV_OPCODE_NAME(ConvertUToF)
        SKSL_SPV_OPCODE_NAME(UConvert)
        SKSL_SPV_OPCODE_NAME(SConvert)
        SKSL_SPV_OPCODE_NAME(FConvert)
        SKSL_SPV_OPCODE_NAME(QuantizeToF16)
        SKSL_SPV_OPCODE_NAME(ConvertPtrToU)
        SKSL_SPV_OPCODE_NAME(SatConvertSToU)
        SKSL_SPV_OPCODE_NAME(SatConvertUToS)
        SKSL_SPV_OPCODE_NAME(ConvertUToPtr)
        SKSL_SPV_OPCODE_NAME(PtrCastToGeneric)
        SKSL_SPV_OPCODE_NAME(GenericCastToPtr)
        SKSL_SPV_OPCODE_NAME(GenericCastToPtrExplicit)
        SKSL_SPV_OPCODE_NAME(Bitcast)
        SKSL_SPV_OPCODE_NAME(SNegate)
        SKSL_SPV_OPCODE_NAME(FNegate)
        SKSL_SPV_OPCODE_NAME(IAdd)
        SKSL_SPV_OPCODE_NAME(FAdd)
        SKSL_SPV_OPCODE_NAME(ISub)
        SKSL_SPV_OPCODE_NAME(FSub)
        SKSL_SPV_OPCODE_NAME(IMul)
        SKSL_SPV_OPCODE_NAME(FMul)
        SKSL_SPV_OPCODE_NAME(UDiv)
        SKSL_SPV_OPCODE_NAME(SDiv)
        SKSL_SPV_OPCODE_NAME(FDiv)
        SKSL_SPV_OPCODE_NAME(UMod)
        SKSL_SPV_OPCODE_NAME(SRem)
        SKSL_SPV_OPCODE_NAME(SMod)
        SKSL_SPV_OPCODE_NAME(FRem)
        SKSL_SPV_OPCODE_NAME(FMod)
        SKSL_SPV_OPCODE_NAME(VectorTimesScalar)
        SKSL_SPV_OPCODE_NAME(MatrixTimesScalar)
        SKSL_SPV_OPCODE_NAME(VectorTimesMatrix)
        SKSL_SPV_OPCODE_NAME(MatrixTimesVector)
        SKSL_SPV_OPCODE_NAME(MatrixTimesMatrix)
        SKSL_SPV_OPCODE_NAME(OuterProduct)
        SKSL_SPV_OPCODE_NAME(Dot)
        SKSL_SPV_OPCODE_NAME(IAddCarry)
        SKSL_SPV_OPCODE_NAME(ISubBorrow)
        SKSL_SPV_OPCODE_NAME(UMulExtended)
        SKSL_SPV_OPCODE_NAME(SMulExtended)
        SKSL_SPV_OPCODE_NAME(Any)
        SKSL_SPV_OPCODE_NAME(All)
        SKSL_SPV_OPCODE_NAME(IsNan)
        SKSL_SPV_OPCODE_NAME(IsInf)
        SKSL_SPV_OPCODE_NAME(IsFinite)
        SKSL_SPV_OPCODE_NAME(IsNormal)
        SKSL_SPV_OPCODE_NAME(SignBitSet)
        SKSL_SPV_OPCODE_NAME(LessOrGreater)
        SKSL_SPV_OPCODE_NAME(Ordered)
        SKSL_SPV_OPCODE_NAME(Unordered)
        SKSL_SPV_OPCODE_NAME(LogicalEqual)
        SKSL_SPV_OPCODE_NAME(LogicalNotEqual)
        SKSL_SPV_OPCODE_NAME(LogicalOr)
        SKSL_SPV_OPCODE_NAME(LogicalAnd)
        SKSL_SPV_OPCODE_NAME(LogicalNot)
        SKSL_SPV_OPCODE_NAME(Select)
        SKSL_SPV_OPCODE_NAME(IEqual)
        SKSL_SPV_OPCODE_NAME(INotEqual)
        SKSL_SPV_OPCODE_NAME(UGreaterThan)
        SKSL_SPV_OPCODE_NAME(SGreaterThan)
        SKSL_SPV_OPCODE_NAME(UGreaterThanEqual)
        SKSL_SPV_OPCODE_NAME(SGreaterThanEqual)
        SKSL_SPV_OPCODE_NAME(ULessThan)
        SKSL_SPV_OPCODE_NAME(SLessThan)
        SKSL_SPV_OPCODE_NAME(ULessThanEqual)
        SKSL_SPV_OPCODE_NAME(SLessThanEqual)
        SKSL_SPV_OPCODE_NAME(FOrdEqual)
        SKSL_SPV_OPCODE_NAME(FUnordEqual)
        SKSL_SPV_OPCODE_NAME(FOrdNotEqual)
        SKSL_SPV_OPCODE_NAME(FUnordNotEqual)
        SKSL_SPV_OPCODE_NAME(FOrdLessThan)
        SKSL_SPV_OPCODE_NAME(FUnordLessThan)
        SKSL_SPV_OPCODE_NAME(FOrdGreaterThan)
        SKSL_SPV_OPCODE_NAME(FUnordGreaterThan)
        SKSL_SPV_OPCODE_NAME(FOrdLessThanEqual)
        SKSL_SPV_OPCODE_NAME(FUnordLessThanEqual)
        SKSL_SPV_OPCODE_NAME(FOrdGreaterThanEqual)
        SKSL_SPV_OPCODE_NAME(FUnordGreaterThanEqual)
        SKSL_SPV_OPCODE_NAME(ShiftRightLogical)
        SKSL_SPV_OPCODE_NAME(ShiftRightArithmetic)
        SKSL_SPV_OPCODE_NAME(ShiftLeftLogical)
        SKSL_SPV_OPCODE_NAME(BitwiseOr)
        SKSL_SPV_OPCODE_NAME(BitwiseXor)
        SKSL_SPV_OPCODE_NAME(BitwiseAnd)
        SKSL_SPV_OPCODE_NAME(Not)
        SKSL_SPV_OPCODE_NAME(BitFieldInsert)
        SKSL_SPV_OPCODE_NAME(BitFieldSExtract)
        SKSL_SPV_OPCODE_NAME(BitFieldUExtract)
        SKSL_SPV_OPCODE_NAME(BitReverse)
        SKSL_SPV_OPCODE_NAME(BitCount)
        SKSL_SPV_OPCODE_NAME(DPdx)
        SKSL_SPV_OPCODE_NAME(DPdy)
        SKSL_SPV_OPCODE_NAME(Fwidth)
        SKSL_SPV_OPCODE_NAME(DPdxFine)
        SKSL_SPV_OPCODE_NAME(DPdyFine)
        SKSL_SPV_OPCODE_NAME(FwidthFine)
        SKSL_SPV_OPCODE_NAME(DPdxCoarse)
        SKSL_SPV_OPCODE_NAME(DPdyCoarse)
        SKSL_SPV_OPCODE_NAME(FwidthCoarse)
        SKSL_SPV_OPCODE_NAME(EmitVertex)
        SKSL_SPV_OPCODE_NAME(EndPrimitive)
        SKSL_SPV_OPCODE_NAME(EmitStreamVertex)
        SKSL_SPV_OPCODE_NAME(EndStreamPrimitive)
        SKSL_SPV_OPCODE_NAME(ControlBarrier)
        SKSL_SPV_OPCODE_NAME(MemoryBarrier)
        SKSL_SPV_OPCODE_NAME(AtomicLoad)
        SKSL_SPV_OPCODE_NAME(AtomicStore)
        SKSL_SPV_OPCODE_NAME(AtomicExchange)
        SKSL_SPV_OPCODE_NAME(AtomicCompareExchange)
        SKSL_SPV_OPCODE_NAME(AtomicCompareExchangeWeak)
        SKSL_SPV_OPCODE_NAME(AtomicIIncrement)
        SKSL_SPV_OPCODE_NAME(AtomicIDecrement)
        SKSL_SPV_OPCODE_NAME(AtomicIAdd)
        SKSL_SPV_OPCODE_NAME(AtomicISub)
        SKSL_SPV_OPCODE_NAME(AtomicSMin)
        SKSL_SPV_OPCODE_NAME(AtomicUMin)
        SKSL_SPV_OPCODE_NAME(AtomicSMax)
        SKSL_SPV_OPCODE_NAME(AtomicUMax)
        SKSL_SPV_OPCODE_NAME(AtomicAnd)
        SKSL_SPV_OPCODE_NAME(AtomicOr)
        SKSL_SPV_OPCODE_NAME(AtomicXor)
        SKSL_SPV_OPCODE_NAME(Phi)
        SKSL_SPV_OPCODE_NAME(LoopMerge)
        SKSL_SPV_OPCODE_NAME(SelectionMerge)
        SKSL_SPV_OPCODE_NAME(Label)
        SKSL_SPV_OPCODE_NAME(Branch)
        SKSL_SPV_OPCODE_NAME(BranchConditional)
        SKSL_SPV_OPCODE_NAME(Switch)
        SKSL_SPV_OPCODE_NAME(Kill)
        SKSL_SPV_OPCODE_NAME(Return)
        SKSL_SPV_OPCODE_NAME(ReturnValue)
        SKSL_SPV_OPCODE_NAME(Unreachable)
        SKSL_SPV_OPCODE_NAME(LifetimeStart)
        SKSL_SPV_OPCODE_NAME(LifetimeStop)
        SKSL_SPV_OPCODE_NAME(GroupAsyncCopy)
        SKSL_SPV_OPCODE_NAME(GroupWaitEvents)
        SKSL_SPV_OPCODE_NAME(GroupAll)
        SKSL_SPV_OPCODE_NAME(GroupAny)
        SKSL_SPV_OPCODE_NAME(GroupBroadcast)
        SKSL_SPV_OPCODE_NAME(GroupIAdd)
        SKSL_SPV_OPCODE_NAME(GroupFAdd)
        SKSL_SPV_OPCODE_NAME(GroupFMin)
        SKSL_SPV_OPCODE_NAME(GroupUMin)
        SKSL_SPV_OPCODE_NAME(GroupSMin)
        SKSL_SPV_OPCODE_NAME(GroupFMax)
        SKSL_SPV_OPCODE_NAME(GroupUMax)
        SKSL_SPV_OPCODE_NAME(GroupSMax)
        SKSL_SPV_OPCODE_NAME(ReadPipe)
        SKSL_SPV_OPCODE_NAME(WritePipe)
        SKSL_SPV_OPCODE_NAME(ReservedReadPipe)
        SKSL_SPV_OPCODE_NAME(ReservedWritePipe)
        SKSL_SPV_OPCODE_NAME(ReserveReadPipePackets)
        SKSL_SPV_OPCODE_NAME(ReserveWritePipePackets)
        SKSL_SPV_OPCODE_NAME(CommitReadPipe)
        SKSL_SPV_OPCODE_NAME(CommitWritePipe)
        SKSL_SPV_OPCODE_NAME(IsValidReserveId)
        SKSL_SPV_OPCODE_NAME(GetNumPipePackets)
        SKSL_SPV_OPCODE_NAME(GetMaxPipePackets)
        SKSL_SPV_OPCODE_NAME(GroupReserveReadPipePackets)
        SKSL_SPV_OPCODE_NAME(GroupReserveWritePipePackets)
        SKSL_SPV_OPCODE_NAME(GroupCommitReadPipe)
        SKSL_SPV_OPCODE_NAME(GroupCommitWritePipe)
        SKSL_SPV_OPCODE_NAME(EnqueueMarker)
        SKSL_SPV_OPCODE_NAME(EnqueueKernel)
        SKSL_SPV_OPCODE_NAME(GetKernelNDrangeSubGroupCount)
        SKSL_SPV_OPCODE_NAME(GetKernelNDrangeMaxSubGroupSize)
        SKSL_SPV_OPCODE_NAME(GetKernelWorkGroupSize)
        SKSL_SPV_OPCODE_NAME(GetKernelPreferredWorkGroupSizeMultiple)
        SKSL_SPV_OPCODE_NAME(RetainEvent)
        SKSL_SPV_OPCODE_NAME(ReleaseEvent)
        SKSL_SPV_OPCODE_NAME(CreateUserEvent)
        SKSL_SPV_OPCODE_NAME(IsValidEvent)
        SKSL_SPV_OPCODE_NAME(SetUserEventStatus)
        SKSL_SPV_OPCODE_NAME(CaptureEventProfilingInfo)
        SKSL_SPV_OPCODE_NAME(GetDefaultQueue)
        SKSL_SPV_OPCODE_NAME(BuildNDRange)
        SKSL_SPV_OPCODE_NAME(ImageSparseSampleImplicitLod)
        SKSL_SPV_OPCODE_NAME(ImageSparseSampleExplicitLod)
        SKSL_SPV_OPCODE_NAME(ImageSparseSampleDrefImplicitLod)
        SKSL_SPV_OPCODE_NAME(ImageSparseSampleDrefExplicitLod)
        SKSL_SPV_OPCODE_NAME(ImageSparseSampleProjImplicitLod)
        SKSL_SPV_OPCODE_NAME(ImageSparseSampleProjExplicitLod)
        SKSL_SPV_OPCODE_NAME(ImageSparseSampleProjDrefImplicitLod)
        SKSL_SPV_OPCODE_NAME(ImageSparseSampleProjDrefExplicitLod)
        SKSL_SPV_OPCODE_NAME(ImageSparseFetch)
        SKSL_SPV_OPCODE_NAME(ImageSparseGather)
        SKSL_SPV_OPCODE_NAME(ImageSparseDrefGather)
        SKSL_SPV_OPCODE_NAME(ImageSparseTexelsResident)
        SKSL_SPV_OPCODE_NAME(NoLine)
        SKSL_SPV_OPCODE_NAME(AtomicFlagTestAndSet)
        SKSL_SPV_OPCODE_NAME(AtomicFlagClear)
        SKSL_SPV_OPCODE_NAME(ImageSparseRead)
        default:
            ABORT("unsupported SPIR-V op");
    }
#undef SKSL_SPV_OPCODE_NAME
}
#endif
775
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)776 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
777 ASSERT(opCode != SpvOpUndef);
778 switch (opCode) {
779 case SpvOpReturn: // fall through
780 case SpvOpReturnValue: // fall through
781 case SpvOpKill: // fall through
782 case SpvOpBranch: // fall through
783 case SpvOpBranchConditional:
784 ASSERT(fCurrentBlock);
785 fCurrentBlock = 0;
786 break;
787 case SpvOpConstant: // fall through
788 case SpvOpConstantTrue: // fall through
789 case SpvOpConstantFalse: // fall through
790 case SpvOpConstantComposite: // fall through
791 case SpvOpTypeVoid: // fall through
792 case SpvOpTypeInt: // fall through
793 case SpvOpTypeFloat: // fall through
794 case SpvOpTypeBool: // fall through
795 case SpvOpTypeVector: // fall through
796 case SpvOpTypeMatrix: // fall through
797 case SpvOpTypeArray: // fall through
798 case SpvOpTypePointer: // fall through
799 case SpvOpTypeFunction: // fall through
800 case SpvOpTypeRuntimeArray: // fall through
801 case SpvOpTypeStruct: // fall through
802 case SpvOpTypeImage: // fall through
803 case SpvOpTypeSampledImage: // fall through
804 case SpvOpVariable: // fall through
805 case SpvOpFunction: // fall through
806 case SpvOpFunctionParameter: // fall through
807 case SpvOpFunctionEnd: // fall through
808 case SpvOpExecutionMode: // fall through
809 case SpvOpMemoryModel: // fall through
810 case SpvOpCapability: // fall through
811 case SpvOpExtInstImport: // fall through
812 case SpvOpEntryPoint: // fall through
813 case SpvOpSource: // fall through
814 case SpvOpSourceExtension: // fall through
815 case SpvOpName: // fall through
816 case SpvOpMemberName: // fall through
817 case SpvOpDecorate: // fall through
818 case SpvOpMemberDecorate:
819 break;
820 default:
821 ASSERT(fCurrentBlock);
822 }
823 #if SPIRV_DEBUG
824 out << std::endl << opcode_text(opCode) << " ";
825 #else
826 this->writeWord((length << 16) | opCode, out);
827 #endif
828 }
829
writeLabel(SpvId label,OutputStream & out)830 void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) {
831 fCurrentBlock = label;
832 this->writeInstruction(SpvOpLabel, label, out);
833 }
834
writeInstruction(SpvOp_ opCode,OutputStream & out)835 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
836 this->writeOpCode(opCode, 1, out);
837 }
838
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)839 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
840 this->writeOpCode(opCode, 2, out);
841 this->writeWord(word1, out);
842 }
843
writeString(const char * string,OutputStream & out)844 void SPIRVCodeGenerator::writeString(const char* string, OutputStream& out) {
845 size_t length = strlen(string);
846 out.write(string, length);
847 switch (length % 4) {
848 case 1:
849 out.write8(0);
850 // fall through
851 case 2:
852 out.write8(0);
853 // fall through
854 case 3:
855 out.write8(0);
856 break;
857 default:
858 this->writeWord(0, out);
859 }
860 }
861
writeInstruction(SpvOp_ opCode,const char * string,OutputStream & out)862 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, const char* string, OutputStream& out) {
863 int32_t length = (int32_t) strlen(string);
864 this->writeOpCode(opCode, 1 + (length + 4) / 4, out);
865 this->writeString(string, out);
866 }
867
868
writeInstruction(SpvOp_ opCode,int32_t word1,const char * string,OutputStream & out)869 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, const char* string,
870 OutputStream& out) {
871 int32_t length = (int32_t) strlen(string);
872 this->writeOpCode(opCode, 2 + (length + 4) / 4, out);
873 this->writeWord(word1, out);
874 this->writeString(string, out);
875 }
876
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,const char * string,OutputStream & out)877 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
878 const char* string, OutputStream& out) {
879 int32_t length = (int32_t) strlen(string);
880 this->writeOpCode(opCode, 3 + (length + 4) / 4, out);
881 this->writeWord(word1, out);
882 this->writeWord(word2, out);
883 this->writeString(string, out);
884 }
885
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)886 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
887 OutputStream& out) {
888 this->writeOpCode(opCode, 3, out);
889 this->writeWord(word1, out);
890 this->writeWord(word2, out);
891 }
892
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)893 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
894 int32_t word3, OutputStream& out) {
895 this->writeOpCode(opCode, 4, out);
896 this->writeWord(word1, out);
897 this->writeWord(word2, out);
898 this->writeWord(word3, out);
899 }
900
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)901 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
902 int32_t word3, int32_t word4, OutputStream& out) {
903 this->writeOpCode(opCode, 5, out);
904 this->writeWord(word1, out);
905 this->writeWord(word2, out);
906 this->writeWord(word3, out);
907 this->writeWord(word4, out);
908 }
909
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)910 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
911 int32_t word3, int32_t word4, int32_t word5,
912 OutputStream& out) {
913 this->writeOpCode(opCode, 6, out);
914 this->writeWord(word1, out);
915 this->writeWord(word2, out);
916 this->writeWord(word3, out);
917 this->writeWord(word4, out);
918 this->writeWord(word5, out);
919 }
920
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)921 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
922 int32_t word3, int32_t word4, int32_t word5,
923 int32_t word6, OutputStream& out) {
924 this->writeOpCode(opCode, 7, out);
925 this->writeWord(word1, out);
926 this->writeWord(word2, out);
927 this->writeWord(word3, out);
928 this->writeWord(word4, out);
929 this->writeWord(word5, out);
930 this->writeWord(word6, out);
931 }
932
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)933 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
934 int32_t word3, int32_t word4, int32_t word5,
935 int32_t word6, int32_t word7, OutputStream& out) {
936 this->writeOpCode(opCode, 8, out);
937 this->writeWord(word1, out);
938 this->writeWord(word2, out);
939 this->writeWord(word3, out);
940 this->writeWord(word4, out);
941 this->writeWord(word5, out);
942 this->writeWord(word6, out);
943 this->writeWord(word7, out);
944 }
945
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)946 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
947 int32_t word3, int32_t word4, int32_t word5,
948 int32_t word6, int32_t word7, int32_t word8,
949 OutputStream& out) {
950 this->writeOpCode(opCode, 9, out);
951 this->writeWord(word1, out);
952 this->writeWord(word2, out);
953 this->writeWord(word3, out);
954 this->writeWord(word4, out);
955 this->writeWord(word5, out);
956 this->writeWord(word6, out);
957 this->writeWord(word7, out);
958 this->writeWord(word8, out);
959 }
960
writeCapabilities(OutputStream & out)961 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
962 for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
963 if (fCapabilities & bit) {
964 this->writeInstruction(SpvOpCapability, (SpvId) i, out);
965 }
966 }
967 }
968
nextId()969 SpvId SPIRVCodeGenerator::nextId() {
970 return fIdCount++;
971 }
972
writeStruct(const Type & type,const MemoryLayout & memoryLayout,SpvId resultId)973 void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout,
974 SpvId resultId) {
975 this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer);
976 // go ahead and write all of the field types, so we don't inadvertently write them while we're
977 // in the middle of writing the struct instruction
978 std::vector<SpvId> types;
979 for (const auto& f : type.fields()) {
980 types.push_back(this->getType(*f.fType, memoryLayout));
981 }
982 this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
983 this->writeWord(resultId, fConstantBuffer);
984 for (SpvId id : types) {
985 this->writeWord(id, fConstantBuffer);
986 }
987 size_t offset = 0;
988 for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
989 size_t size = memoryLayout.size(*type.fields()[i].fType);
990 size_t alignment = memoryLayout.alignment(*type.fields()[i].fType);
991 const Layout& fieldLayout = type.fields()[i].fModifiers.fLayout;
992 if (fieldLayout.fOffset >= 0) {
993 if (fieldLayout.fOffset < (int) offset) {
994 fErrors.error(type.fPosition,
995 "offset of field '" + type.fields()[i].fName + "' must be at "
996 "least " + to_string((int) offset));
997 }
998 if (fieldLayout.fOffset % alignment) {
999 fErrors.error(type.fPosition,
1000 "offset of field '" + type.fields()[i].fName + "' must be a multiple"
1001 " of " + to_string((int) alignment));
1002 }
1003 offset = fieldLayout.fOffset;
1004 } else {
1005 size_t mod = offset % alignment;
1006 if (mod) {
1007 offset += alignment - mod;
1008 }
1009 }
1010 this->writeInstruction(SpvOpMemberName, resultId, i, type.fields()[i].fName.c_str(),
1011 fNameBuffer);
1012 this->writeLayout(fieldLayout, resultId, i);
1013 if (type.fields()[i].fModifiers.fLayout.fBuiltin < 0) {
1014 this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
1015 (SpvId) offset, fDecorationBuffer);
1016 }
1017 if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) {
1018 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
1019 fDecorationBuffer);
1020 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
1021 (SpvId) memoryLayout.stride(*type.fields()[i].fType),
1022 fDecorationBuffer);
1023 }
1024 offset += size;
1025 Type::Kind kind = type.fields()[i].fType->kind();
1026 if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) {
1027 offset += alignment - offset % alignment;
1028 }
1029 }
1030 }
1031
getType(const Type & type)1032 SpvId SPIRVCodeGenerator::getType(const Type& type) {
1033 return this->getType(type, fDefaultLayout);
1034 }
1035
getType(const Type & type,const MemoryLayout & layout)1036 SpvId SPIRVCodeGenerator::getType(const Type& type, const MemoryLayout& layout) {
1037 String key = type.name() + to_string((int) layout.fStd);
1038 auto entry = fTypeMap.find(key);
1039 if (entry == fTypeMap.end()) {
1040 SpvId result = this->nextId();
1041 switch (type.kind()) {
1042 case Type::kScalar_Kind:
1043 if (type == *fContext.fBool_Type) {
1044 this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
1045 } else if (type == *fContext.fInt_Type) {
1046 this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
1047 } else if (type == *fContext.fUInt_Type) {
1048 this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
1049 } else if (type == *fContext.fFloat_Type) {
1050 this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
1051 } else if (type == *fContext.fDouble_Type) {
1052 this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer);
1053 } else {
1054 ASSERT(false);
1055 }
1056 break;
1057 case Type::kVector_Kind:
1058 this->writeInstruction(SpvOpTypeVector, result,
1059 this->getType(type.componentType(), layout),
1060 type.columns(), fConstantBuffer);
1061 break;
1062 case Type::kMatrix_Kind:
1063 this->writeInstruction(SpvOpTypeMatrix, result,
1064 this->getType(index_type(fContext, type), layout),
1065 type.columns(), fConstantBuffer);
1066 break;
1067 case Type::kStruct_Kind:
1068 this->writeStruct(type, layout, result);
1069 break;
1070 case Type::kArray_Kind: {
1071 if (type.columns() > 0) {
1072 IntLiteral count(fContext, Position(), type.columns());
1073 this->writeInstruction(SpvOpTypeArray, result,
1074 this->getType(type.componentType(), layout),
1075 this->writeIntLiteral(count), fConstantBuffer);
1076 this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
1077 (int32_t) layout.stride(type),
1078 fDecorationBuffer);
1079 } else {
1080 this->writeInstruction(SpvOpTypeRuntimeArray, result,
1081 this->getType(type.componentType(), layout),
1082 fConstantBuffer);
1083 this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
1084 (int32_t) layout.stride(type),
1085 fDecorationBuffer);
1086 }
1087 break;
1088 }
1089 case Type::kSampler_Kind: {
1090 SpvId image = result;
1091 if (SpvDimSubpassData != type.dimensions()) {
1092 image = this->nextId();
1093 }
1094 if (SpvDimBuffer == type.dimensions()) {
1095 fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer);
1096 }
1097 this->writeInstruction(SpvOpTypeImage, image,
1098 this->getType(*fContext.fFloat_Type, layout),
1099 type.dimensions(), type.isDepth(), type.isArrayed(),
1100 type.isMultisampled(), type.isSampled() ? 1 : 2,
1101 SpvImageFormatUnknown, fConstantBuffer);
1102 fImageTypeMap[key] = image;
1103 if (SpvDimSubpassData != type.dimensions()) {
1104 this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
1105 }
1106 break;
1107 }
1108 default:
1109 if (type == *fContext.fVoid_Type) {
1110 this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
1111 } else {
1112 ABORT("invalid type: %s", type.description().c_str());
1113 }
1114 }
1115 fTypeMap[key] = result;
1116 return result;
1117 }
1118 return entry->second;
1119 }
1120
getImageType(const Type & type)1121 SpvId SPIRVCodeGenerator::getImageType(const Type& type) {
1122 ASSERT(type.kind() == Type::kSampler_Kind);
1123 this->getType(type);
1124 String key = type.name() + to_string((int) fDefaultLayout.fStd);
1125 ASSERT(fImageTypeMap.find(key) != fImageTypeMap.end());
1126 return fImageTypeMap[key];
1127 }
1128
getFunctionType(const FunctionDeclaration & function)1129 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
1130 String key = function.fReturnType.description() + "(";
1131 String separator;
1132 for (size_t i = 0; i < function.fParameters.size(); i++) {
1133 key += separator;
1134 separator = ", ";
1135 key += function.fParameters[i]->fType.description();
1136 }
1137 key += ")";
1138 auto entry = fTypeMap.find(key);
1139 if (entry == fTypeMap.end()) {
1140 SpvId result = this->nextId();
1141 int32_t length = 3 + (int32_t) function.fParameters.size();
1142 SpvId returnType = this->getType(function.fReturnType);
1143 std::vector<SpvId> parameterTypes;
1144 for (size_t i = 0; i < function.fParameters.size(); i++) {
1145 // glslang seems to treat all function arguments as pointers whether they need to be or
1146 // not. I was initially puzzled by this until I ran bizarre failures with certain
1147 // patterns of function calls and control constructs, as exemplified by this minimal
1148 // failure case:
1149 //
1150 // void sphere(float x) {
1151 // }
1152 //
1153 // void map() {
1154 // sphere(1.0);
1155 // }
1156 //
1157 // void main() {
1158 // for (int i = 0; i < 1; i++) {
1159 // map();
1160 // }
1161 // }
1162 //
1163 // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
1164 // crashes. Making it take a float* and storing the argument in a temporary variable,
1165 // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
1166 // the spec makes this make sense.
1167 // if (is_out(function->fParameters[i])) {
1168 parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType,
1169 SpvStorageClassFunction));
1170 // } else {
1171 // parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
1172 // }
1173 }
1174 this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
1175 this->writeWord(result, fConstantBuffer);
1176 this->writeWord(returnType, fConstantBuffer);
1177 for (SpvId id : parameterTypes) {
1178 this->writeWord(id, fConstantBuffer);
1179 }
1180 fTypeMap[key] = result;
1181 return result;
1182 }
1183 return entry->second;
1184 }
1185
getPointerType(const Type & type,SpvStorageClass_ storageClass)1186 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
1187 return this->getPointerType(type, fDefaultLayout, storageClass);
1188 }
1189
getPointerType(const Type & type,const MemoryLayout & layout,SpvStorageClass_ storageClass)1190 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, const MemoryLayout& layout,
1191 SpvStorageClass_ storageClass) {
1192 String key = type.description() + "*" + to_string(layout.fStd) + to_string(storageClass);
1193 auto entry = fTypeMap.find(key);
1194 if (entry == fTypeMap.end()) {
1195 SpvId result = this->nextId();
1196 this->writeInstruction(SpvOpTypePointer, result, storageClass,
1197 this->getType(type), fConstantBuffer);
1198 fTypeMap[key] = result;
1199 return result;
1200 }
1201 return entry->second;
1202 }
1203
writeExpression(const Expression & expr,OutputStream & out)1204 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
1205 switch (expr.fKind) {
1206 case Expression::kBinary_Kind:
1207 return this->writeBinaryExpression((BinaryExpression&) expr, out);
1208 case Expression::kBoolLiteral_Kind:
1209 return this->writeBoolLiteral((BoolLiteral&) expr);
1210 case Expression::kConstructor_Kind:
1211 return this->writeConstructor((Constructor&) expr, out);
1212 case Expression::kIntLiteral_Kind:
1213 return this->writeIntLiteral((IntLiteral&) expr);
1214 case Expression::kFieldAccess_Kind:
1215 return this->writeFieldAccess(((FieldAccess&) expr), out);
1216 case Expression::kFloatLiteral_Kind:
1217 return this->writeFloatLiteral(((FloatLiteral&) expr));
1218 case Expression::kFunctionCall_Kind:
1219 return this->writeFunctionCall((FunctionCall&) expr, out);
1220 case Expression::kPrefix_Kind:
1221 return this->writePrefixExpression((PrefixExpression&) expr, out);
1222 case Expression::kPostfix_Kind:
1223 return this->writePostfixExpression((PostfixExpression&) expr, out);
1224 case Expression::kSwizzle_Kind:
1225 return this->writeSwizzle((Swizzle&) expr, out);
1226 case Expression::kVariableReference_Kind:
1227 return this->writeVariableReference((VariableReference&) expr, out);
1228 case Expression::kTernary_Kind:
1229 return this->writeTernaryExpression((TernaryExpression&) expr, out);
1230 case Expression::kIndex_Kind:
1231 return this->writeIndexExpression((IndexExpression&) expr, out);
1232 default:
1233 ABORT("unsupported expression: %s", expr.description().c_str());
1234 }
1235 return -1;
1236 }
1237
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)1238 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
1239 auto intrinsic = fIntrinsicMap.find(c.fFunction.fName);
1240 ASSERT(intrinsic != fIntrinsicMap.end());
1241 const Type& type = c.fArguments[0]->fType;
1242 int32_t intrinsicId;
1243 if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) {
1244 intrinsicId = std::get<1>(intrinsic->second);
1245 } else if (is_signed(fContext, type)) {
1246 intrinsicId = std::get<2>(intrinsic->second);
1247 } else if (is_unsigned(fContext, type)) {
1248 intrinsicId = std::get<3>(intrinsic->second);
1249 } else if (is_bool(fContext, type)) {
1250 intrinsicId = std::get<4>(intrinsic->second);
1251 } else {
1252 ABORT("invalid call %s, cannot operate on '%s'", c.description().c_str(),
1253 type.description().c_str());
1254 }
1255 switch (std::get<0>(intrinsic->second)) {
1256 case kGLSL_STD_450_IntrinsicKind: {
1257 SpvId result = this->nextId();
1258 std::vector<SpvId> arguments;
1259 for (size_t i = 0; i < c.fArguments.size(); i++) {
1260 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1261 }
1262 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
1263 this->writeWord(this->getType(c.fType), out);
1264 this->writeWord(result, out);
1265 this->writeWord(fGLSLExtendedInstructions, out);
1266 this->writeWord(intrinsicId, out);
1267 for (SpvId id : arguments) {
1268 this->writeWord(id, out);
1269 }
1270 return result;
1271 }
1272 case kSPIRV_IntrinsicKind: {
1273 SpvId result = this->nextId();
1274 std::vector<SpvId> arguments;
1275 for (size_t i = 0; i < c.fArguments.size(); i++) {
1276 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1277 }
1278 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
1279 this->writeWord(this->getType(c.fType), out);
1280 this->writeWord(result, out);
1281 for (SpvId id : arguments) {
1282 this->writeWord(id, out);
1283 }
1284 return result;
1285 }
1286 case kSpecial_IntrinsicKind:
1287 return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
1288 default:
1289 ABORT("unsupported intrinsic kind");
1290 }
1291 }
1292
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)1293 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
1294 OutputStream& out) {
1295 SpvId result = this->nextId();
1296 switch (kind) {
1297 case kAtan_SpecialIntrinsic: {
1298 std::vector<SpvId> arguments;
1299 for (size_t i = 0; i < c.fArguments.size(); i++) {
1300 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1301 }
1302 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
1303 this->writeWord(this->getType(c.fType), out);
1304 this->writeWord(result, out);
1305 this->writeWord(fGLSLExtendedInstructions, out);
1306 this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
1307 for (SpvId id : arguments) {
1308 this->writeWord(id, out);
1309 }
1310 break;
1311 }
1312 case kSubpassLoad_SpecialIntrinsic: {
1313 SpvId img = this->writeExpression(*c.fArguments[0], out);
1314 std::vector<std::unique_ptr<Expression>> args;
1315 args.emplace_back(new FloatLiteral(fContext, Position(), 0.0));
1316 args.emplace_back(new FloatLiteral(fContext, Position(), 0.0));
1317 Constructor ctor(Position(), *fContext.fVec2_Type, std::move(args));
1318 SpvId coords = this->writeConstantVector(ctor);
1319 if (1 == c.fArguments.size()) {
1320 this->writeInstruction(SpvOpImageRead,
1321 this->getType(c.fType),
1322 result,
1323 img,
1324 coords,
1325 out);
1326 } else {
1327 ASSERT(2 == c.fArguments.size());
1328 SpvId sample = this->writeExpression(*c.fArguments[1], out);
1329 this->writeInstruction(SpvOpImageRead,
1330 this->getType(c.fType),
1331 result,
1332 img,
1333 coords,
1334 SpvImageOperandsSampleMask,
1335 sample,
1336 out);
1337 }
1338 break;
1339 }
1340 case kTexelFetch_SpecialIntrinsic: {
1341 ASSERT(c.fArguments.size() == 2);
1342 SpvId image = this->nextId();
1343 this->writeInstruction(SpvOpImage,
1344 this->getImageType(c.fArguments[0]->fType),
1345 image,
1346 this->writeExpression(*c.fArguments[0], out),
1347 out);
1348 this->writeInstruction(SpvOpImageFetch,
1349 this->getType(c.fType),
1350 result,
1351 image,
1352 this->writeExpression(*c.fArguments[1], out),
1353 out);
1354 break;
1355 }
1356 case kTexture_SpecialIntrinsic: {
1357 SpvOp_ op = SpvOpImageSampleImplicitLod;
1358 switch (c.fArguments[0]->fType.dimensions()) {
1359 case SpvDim1D:
1360 if (c.fArguments[1]->fType == *fContext.fVec2_Type) {
1361 op = SpvOpImageSampleProjImplicitLod;
1362 } else {
1363 ASSERT(c.fArguments[1]->fType == *fContext.fFloat_Type);
1364 }
1365 break;
1366 case SpvDim2D:
1367 if (c.fArguments[1]->fType == *fContext.fVec3_Type) {
1368 op = SpvOpImageSampleProjImplicitLod;
1369 } else {
1370 ASSERT(c.fArguments[1]->fType == *fContext.fVec2_Type);
1371 }
1372 break;
1373 case SpvDim3D:
1374 if (c.fArguments[1]->fType == *fContext.fVec4_Type) {
1375 op = SpvOpImageSampleProjImplicitLod;
1376 } else {
1377 ASSERT(c.fArguments[1]->fType == *fContext.fVec3_Type);
1378 }
1379 break;
1380 case SpvDimCube: // fall through
1381 case SpvDimRect: // fall through
1382 case SpvDimBuffer: // fall through
1383 case SpvDimSubpassData:
1384 break;
1385 }
1386 SpvId type = this->getType(c.fType);
1387 SpvId sampler = this->writeExpression(*c.fArguments[0], out);
1388 SpvId uv = this->writeExpression(*c.fArguments[1], out);
1389 if (c.fArguments.size() == 3) {
1390 this->writeInstruction(op, type, result, sampler, uv,
1391 SpvImageOperandsBiasMask,
1392 this->writeExpression(*c.fArguments[2], out),
1393 out);
1394 } else {
1395 ASSERT(c.fArguments.size() == 2);
1396 this->writeInstruction(op, type, result, sampler, uv,
1397 out);
1398 }
1399 break;
1400 }
1401 }
1402 return result;
1403 }
1404
/**
 * Emits a call to a user-defined function and returns the OpFunctionCall result id. Calls whose
 * callee has no entry in fFunctionMap are treated as intrinsics (writeIntrinsicCall).
 *
 * All arguments are passed as pointers (see getFunctionType):
 *   - for a parameter where is_out() holds and whose lvalue exposes a pointer, that pointer is
 *     passed directly;
 *   - for such a parameter whose lvalue has no pointer (getPointer() returns 0), the current
 *     value is loaded into a fresh Function-storage temporary, and after the call the
 *     temporary is loaded again and stored back through the lvalue;
 *   - any other argument is evaluated and stored into a fresh Function-storage temporary.
 * Temporaries are declared (OpVariable) in fVariableBuffer; the stores, the call and the
 * write-backs go to `out`.
 */
SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
    const auto& entry = fFunctionMap.find(&c.fFunction);
    if (entry == fFunctionMap.end()) {
        return this->writeIntrinsicCall(c, out);
    }
    // stores (variable, type, lvalue) pairs to extract and save after the function call is complete
    std::vector<std::tuple<SpvId, SpvId, std::unique_ptr<LValue>>> lvalues;
    std::vector<SpvId> arguments;
    for (size_t i = 0; i < c.fArguments.size(); i++) {
        // id of temporary variable that we will use to hold this argument, or 0 if it is being
        // passed directly
        SpvId tmpVar;
        // if we need a temporary var to store this argument, this is the value to store in the var
        SpvId tmpValueId;
        if (is_out(*c.fFunction.fParameters[i])) {
            std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out);
            SpvId ptr = lv->getPointer();
            if (ptr) {
                // the lvalue's own pointer can be handed to the callee; no temporary needed
                arguments.push_back(ptr);
                continue;
            } else {
                // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to
                // copy it into a temp, call the function, read the value out of the temp, and then
                // update the lvalue.
                tmpValueId = lv->load(out);
                tmpVar = this->nextId();
                lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType),
                                  std::move(lv)));
            }
        } else {
            // see getFunctionType for an explanation of why we're always using pointer parameters
            tmpValueId = this->writeExpression(*c.fArguments[i], out);
            tmpVar = this->nextId();
        }
        // declare the temporary and initialize it with the argument's value
        this->writeInstruction(SpvOpVariable,
                               this->getPointerType(c.fArguments[i]->fType,
                                                    SpvStorageClassFunction),
                               tmpVar,
                               SpvStorageClassFunction,
                               fVariableBuffer);
        this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
        arguments.push_back(tmpVar);
    }
    SpvId result = this->nextId();
    // OpFunctionCall <result type> <result> <function> <arguments...>
    this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out);
    this->writeWord(this->getType(c.fType), out);
    this->writeWord(result, out);
    this->writeWord(entry->second, out);
    for (SpvId id : arguments) {
        this->writeWord(id, out);
    }
    // now that the call is complete, we may need to update some lvalues with the new values of out
    // arguments
    for (const auto& tuple : lvalues) {
        SpvId load = this->nextId();
        this->writeInstruction(SpvOpLoad, std::get<1>(tuple), load, std::get<0>(tuple), out);
        std::get<2>(tuple)->store(load, out);
    }
    return result;
}
1465
writeConstantVector(const Constructor & c)1466 SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) {
1467 ASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant());
1468 SpvId result = this->nextId();
1469 std::vector<SpvId> arguments;
1470 for (size_t i = 0; i < c.fArguments.size(); i++) {
1471 arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer));
1472 }
1473 SpvId type = this->getType(c.fType);
1474 if (c.fArguments.size() == 1) {
1475 // with a single argument, a vector will have all of its entries equal to the argument
1476 this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer);
1477 this->writeWord(type, fConstantBuffer);
1478 this->writeWord(result, fConstantBuffer);
1479 for (int i = 0; i < c.fType.columns(); i++) {
1480 this->writeWord(arguments[0], fConstantBuffer);
1481 }
1482 } else {
1483 this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(),
1484 fConstantBuffer);
1485 this->writeWord(type, fConstantBuffer);
1486 this->writeWord(result, fConstantBuffer);
1487 for (SpvId id : arguments) {
1488 this->writeWord(id, fConstantBuffer);
1489 }
1490 }
1491 return result;
1492 }
1493
writeFloatConstructor(const Constructor & c,OutputStream & out)1494 SpvId SPIRVCodeGenerator::writeFloatConstructor(const Constructor& c, OutputStream& out) {
1495 ASSERT(c.fType == *fContext.fFloat_Type);
1496 ASSERT(c.fArguments.size() == 1);
1497 ASSERT(c.fArguments[0]->fType.isNumber());
1498 SpvId result = this->nextId();
1499 SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1500 if (c.fArguments[0]->fType == *fContext.fInt_Type) {
1501 this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter,
1502 out);
1503 } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) {
1504 this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter,
1505 out);
1506 } else if (c.fArguments[0]->fType == *fContext.fFloat_Type) {
1507 return parameter;
1508 }
1509 return result;
1510 }
1511
writeIntConstructor(const Constructor & c,OutputStream & out)1512 SpvId SPIRVCodeGenerator::writeIntConstructor(const Constructor& c, OutputStream& out) {
1513 ASSERT(c.fType == *fContext.fInt_Type);
1514 ASSERT(c.fArguments.size() == 1);
1515 ASSERT(c.fArguments[0]->fType.isNumber());
1516 SpvId result = this->nextId();
1517 SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1518 if (c.fArguments[0]->fType == *fContext.fFloat_Type) {
1519 this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter,
1520 out);
1521 } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) {
1522 this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1523 out);
1524 } else if (c.fArguments[0]->fType == *fContext.fInt_Type) {
1525 return parameter;
1526 }
1527 return result;
1528 }
1529
writeUIntConstructor(const Constructor & c,OutputStream & out)1530 SpvId SPIRVCodeGenerator::writeUIntConstructor(const Constructor& c, OutputStream& out) {
1531 ASSERT(c.fType == *fContext.fUInt_Type);
1532 ASSERT(c.fArguments.size() == 1);
1533 ASSERT(c.fArguments[0]->fType.isNumber());
1534 SpvId result = this->nextId();
1535 SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1536 if (c.fArguments[0]->fType == *fContext.fFloat_Type) {
1537 this->writeInstruction(SpvOpConvertFToU, this->getType(c.fType), result, parameter,
1538 out);
1539 } else if (c.fArguments[0]->fType == *fContext.fInt_Type) {
1540 this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1541 out);
1542 } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) {
1543 return parameter;
1544 }
1545 return result;
1546 }
1547
writeUniformScaleMatrix(SpvId id,SpvId diagonal,const Type & type,OutputStream & out)1548 void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
1549 OutputStream& out) {
1550 FloatLiteral zero(fContext, Position(), 0);
1551 SpvId zeroId = this->writeFloatLiteral(zero);
1552 std::vector<SpvId> columnIds;
1553 for (int column = 0; column < type.columns(); column++) {
1554 this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
1555 out);
1556 this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)),
1557 out);
1558 SpvId columnId = this->nextId();
1559 this->writeWord(columnId, out);
1560 columnIds.push_back(columnId);
1561 for (int row = 0; row < type.columns(); row++) {
1562 this->writeWord(row == column ? diagonal : zeroId, out);
1563 }
1564 }
1565 this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
1566 out);
1567 this->writeWord(this->getType(type), out);
1568 this->writeWord(id, out);
1569 for (SpvId id : columnIds) {
1570 this->writeWord(id, out);
1571 }
1572 }
1573
writeMatrixCopy(SpvId id,SpvId src,const Type & srcType,const Type & dstType,OutputStream & out)1574 void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType,
1575 const Type& dstType, OutputStream& out) {
1576 ABORT("unimplemented");
1577 }
1578
/**
 * Emits a matrix-typed constructor and returns the id of the resulting matrix.
 *
 * Three shapes are handled:
 *   - one scalar argument: delegated to writeUniformScaleMatrix (scalar on the diagonal);
 *   - one matrix argument: delegated to writeMatrixCopy;
 *   - otherwise a list of vectors and/or scalars. Each vector argument becomes one column as-is.
 *     Runs of `rows` scalars are packed into a column with OpCompositeConstruct, whose words are
 *     streamed into `out` one scalar at a time.
 * NOTE(review): a vector is only accepted at a column boundary (ASSERT(currentCount == 0)), and
 * its length is not checked against `rows`. Inputs such as mat2(vec3, float) appear unsupported.
 */
SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, OutputStream& out) {
    ASSERT(c.fType.kind() == Type::kMatrix_Kind);
    // go ahead and write the arguments so we don't try to write new instructions in the middle of
    // an instruction
    std::vector<SpvId> arguments;
    for (size_t i = 0; i < c.fArguments.size(); i++) {
        arguments.push_back(this->writeExpression(*c.fArguments[i], out));
    }
    SpvId result = this->nextId();
    int rows = c.fType.rows();
    int columns = c.fType.columns();
    if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
        this->writeUniformScaleMatrix(result, arguments[0], c.fType, out);
    } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) {
        this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out);
    } else {
        std::vector<SpvId> columnIds;
        // number of scalars already written into the column currently being constructed
        int currentCount = 0;
        for (size_t i = 0; i < arguments.size(); i++) {
            if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) {
                ASSERT(currentCount == 0);
                columnIds.push_back(arguments[i]);
                currentCount = 0;
            } else {
                ASSERT(c.fArguments[i]->fType.kind() == Type::kScalar_Kind);
                if (currentCount == 0) {
                    // start a new column: opcode, column vector type and result id; the
                    // `rows` scalar operands follow as later arguments are visited
                    this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.rows(), out);
                    this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows,
                                                                                     1)),
                                    out);
                    SpvId id = this->nextId();
                    this->writeWord(id, out);
                    columnIds.push_back(id);
                }
                this->writeWord(arguments[i], out);
                currentCount = (currentCount + 1) % rows;
            }
        }
        ASSERT(columnIds.size() == (size_t) columns);
        this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
        this->writeWord(this->getType(c.fType), out);
        this->writeWord(result, out);
        for (SpvId id : columnIds) {
            this->writeWord(id, out);
        }
    }
    return result;
}
1627
writeVectorConstructor(const Constructor & c,OutputStream & out)1628 SpvId SPIRVCodeGenerator::writeVectorConstructor(const Constructor& c, OutputStream& out) {
1629 ASSERT(c.fType.kind() == Type::kVector_Kind);
1630 if (c.isConstant()) {
1631 return this->writeConstantVector(c);
1632 }
1633 // go ahead and write the arguments so we don't try to write new instructions in the middle of
1634 // an instruction
1635 std::vector<SpvId> arguments;
1636 for (size_t i = 0; i < c.fArguments.size(); i++) {
1637 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1638 }
1639 SpvId result = this->nextId();
1640 if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1641 this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out);
1642 this->writeWord(this->getType(c.fType), out);
1643 this->writeWord(result, out);
1644 for (int i = 0; i < c.fType.columns(); i++) {
1645 this->writeWord(arguments[0], out);
1646 }
1647 } else {
1648 this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out);
1649 this->writeWord(this->getType(c.fType), out);
1650 this->writeWord(result, out);
1651 for (SpvId id : arguments) {
1652 this->writeWord(id, out);
1653 }
1654 }
1655 return result;
1656 }
1657
writeConstructor(const Constructor & c,OutputStream & out)1658 SpvId SPIRVCodeGenerator::writeConstructor(const Constructor& c, OutputStream& out) {
1659 if (c.fType == *fContext.fFloat_Type) {
1660 return this->writeFloatConstructor(c, out);
1661 } else if (c.fType == *fContext.fInt_Type) {
1662 return this->writeIntConstructor(c, out);
1663 } else if (c.fType == *fContext.fUInt_Type) {
1664 return this->writeUIntConstructor(c, out);
1665 }
1666 switch (c.fType.kind()) {
1667 case Type::kVector_Kind:
1668 return this->writeVectorConstructor(c, out);
1669 case Type::kMatrix_Kind:
1670 return this->writeMatrixConstructor(c, out);
1671 default:
1672 ABORT("unsupported constructor: %s", c.description().c_str());
1673 }
1674 }
1675
get_storage_class(const Modifiers & modifiers)1676 SpvStorageClass_ get_storage_class(const Modifiers& modifiers) {
1677 if (modifiers.fFlags & Modifiers::kIn_Flag) {
1678 ASSERT(!modifiers.fLayout.fPushConstant);
1679 return SpvStorageClassInput;
1680 } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
1681 ASSERT(!modifiers.fLayout.fPushConstant);
1682 return SpvStorageClassOutput;
1683 } else if (modifiers.fFlags & Modifiers::kUniform_Flag) {
1684 if (modifiers.fLayout.fPushConstant) {
1685 return SpvStorageClassPushConstant;
1686 }
1687 return SpvStorageClassUniform;
1688 } else {
1689 return SpvStorageClassFunction;
1690 }
1691 }
1692
get_storage_class(const Expression & expr)1693 SpvStorageClass_ get_storage_class(const Expression& expr) {
1694 switch (expr.fKind) {
1695 case Expression::kVariableReference_Kind: {
1696 const Variable& var = ((VariableReference&) expr).fVariable;
1697 if (var.fStorage != Variable::kGlobal_Storage) {
1698 return SpvStorageClassFunction;
1699 }
1700 return get_storage_class(var.fModifiers);
1701 }
1702 case Expression::kFieldAccess_Kind:
1703 return get_storage_class(*((FieldAccess&) expr).fBase);
1704 case Expression::kIndex_Kind:
1705 return get_storage_class(*((IndexExpression&) expr).fBase);
1706 default:
1707 return SpvStorageClassFunction;
1708 }
1709 }
1710
getAccessChain(const Expression & expr,OutputStream & out)1711 std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
1712 std::vector<SpvId> chain;
1713 switch (expr.fKind) {
1714 case Expression::kIndex_Kind: {
1715 IndexExpression& indexExpr = (IndexExpression&) expr;
1716 chain = this->getAccessChain(*indexExpr.fBase, out);
1717 chain.push_back(this->writeExpression(*indexExpr.fIndex, out));
1718 break;
1719 }
1720 case Expression::kFieldAccess_Kind: {
1721 FieldAccess& fieldExpr = (FieldAccess&) expr;
1722 chain = this->getAccessChain(*fieldExpr.fBase, out);
1723 IntLiteral index(fContext, Position(), fieldExpr.fFieldIndex);
1724 chain.push_back(this->writeIntLiteral(index));
1725 break;
1726 }
1727 default:
1728 chain.push_back(this->getLValue(expr, out)->getPointer());
1729 }
1730 return chain;
1731 }
1732
1733 class PointerLValue : public SPIRVCodeGenerator::LValue {
1734 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,SpvId type)1735 PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type)
1736 : fGen(gen)
1737 , fPointer(pointer)
1738 , fType(type) {}
1739
getPointer()1740 virtual SpvId getPointer() override {
1741 return fPointer;
1742 }
1743
load(OutputStream & out)1744 virtual SpvId load(OutputStream& out) override {
1745 SpvId result = fGen.nextId();
1746 fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
1747 return result;
1748 }
1749
store(SpvId value,OutputStream & out)1750 virtual void store(SpvId value, OutputStream& out) override {
1751 fGen.writeInstruction(SpvOpStore, fPointer, value, out);
1752 }
1753
1754 private:
1755 SPIRVCodeGenerator& fGen;
1756 const SpvId fPointer;
1757 const SpvId fType;
1758 };
1759
1760 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
1761 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const std::vector<int> & components,const Type & baseType,const Type & swizzleType)1762 SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components,
1763 const Type& baseType, const Type& swizzleType)
1764 : fGen(gen)
1765 , fVecPointer(vecPointer)
1766 , fComponents(components)
1767 , fBaseType(baseType)
1768 , fSwizzleType(swizzleType) {}
1769
getPointer()1770 virtual SpvId getPointer() override {
1771 return 0;
1772 }
1773
load(OutputStream & out)1774 virtual SpvId load(OutputStream& out) override {
1775 SpvId base = fGen.nextId();
1776 fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1777 SpvId result = fGen.nextId();
1778 fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
1779 fGen.writeWord(fGen.getType(fSwizzleType), out);
1780 fGen.writeWord(result, out);
1781 fGen.writeWord(base, out);
1782 fGen.writeWord(base, out);
1783 for (int component : fComponents) {
1784 fGen.writeWord(component, out);
1785 }
1786 return result;
1787 }
1788
store(SpvId value,OutputStream & out)1789 virtual void store(SpvId value, OutputStream& out) override {
1790 // use OpVectorShuffle to mix and match the vector components. We effectively create
1791 // a virtual vector out of the concatenation of the left and right vectors, and then
1792 // select components from this virtual vector to make the result vector. For
1793 // instance, given:
1794 // vec3 L = ...;
1795 // vec3 R = ...;
1796 // L.xz = R.xy;
1797 // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
1798 // our result vector to look like (R.x, L.y, R.y), so we need to select indices
1799 // (3, 1, 4).
1800 SpvId base = fGen.nextId();
1801 fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1802 SpvId shuffle = fGen.nextId();
1803 fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out);
1804 fGen.writeWord(fGen.getType(fBaseType), out);
1805 fGen.writeWord(shuffle, out);
1806 fGen.writeWord(base, out);
1807 fGen.writeWord(value, out);
1808 for (int i = 0; i < fBaseType.columns(); i++) {
1809 // current offset into the virtual vector, defaults to pulling the unmodified
1810 // value from the left side
1811 int offset = i;
1812 // check to see if we are writing this component
1813 for (size_t j = 0; j < fComponents.size(); j++) {
1814 if (fComponents[j] == i) {
1815 // we're writing to this component, so adjust the offset to pull from
1816 // the correct component of the right side instead of preserving the
1817 // value from the left
1818 offset = (int) (j + fBaseType.columns());
1819 break;
1820 }
1821 }
1822 fGen.writeWord(offset, out);
1823 }
1824 fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
1825 }
1826
1827 private:
1828 SPIRVCodeGenerator& fGen;
1829 const SpvId fVecPointer;
1830 const std::vector<int>& fComponents;
1831 const Type& fBaseType;
1832 const Type& fSwizzleType;
1833 };
1834
// Returns an LValue through which `expr` can be loaded from and/or stored to.
//
// - Variable references map directly to the pointer recorded for the variable in fVariableMap.
// - Index and field-access expressions emit an OpAccessChain built by getAccessChain.
// - A single-component swizzle emits an OpAccessChain to that one component of its base.
// - A multi-component swizzle returns a SwizzleLValue that works on the whole base vector.
// - Any other expression is not a real lvalue. It is evaluated and stored into a fresh
//   Function-storage temporary (declared in fVariableBuffer), and that temporary is returned.
// Instruction and id emission order in this function is significant; several id-allocating calls
// appear inside argument lists.
std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
                                                                          OutputStream& out) {
    switch (expr.fKind) {
        case Expression::kVariableReference_Kind: {
            const Variable& var = ((VariableReference&) expr).fVariable;
            auto entry = fVariableMap.find(&var);
            ASSERT(entry != fVariableMap.end());
            return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
                                                                       *this,
                                                                       entry->second,
                                                                       this->getType(expr.fType)));
        }
        case Expression::kIndex_Kind: // fall through
        case Expression::kFieldAccess_Kind: {
            std::vector<SpvId> chain = this->getAccessChain(expr, out);
            SpvId member = this->nextId();
            // word count = opcode + result type + result id + base pointer and indices (chain)
            this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
            this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out);
            this->writeWord(member, out);
            for (SpvId idx : chain) {
                this->writeWord(idx, out);
            }
            return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
                                                                       *this,
                                                                       member,
                                                                       this->getType(expr.fType)));
        }

        case Expression::kSwizzle_Kind: {
            Swizzle& swizzle = (Swizzle&) expr;
            size_t count = swizzle.fComponents.size();
            // The swizzle base must itself have a real pointer; a SwizzleLValue base (pointer 0),
            // e.g. a swizzle of a swizzle, trips this assert.
            SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer();
            ASSERT(base);
            if (count == 1) {
                // a single component is addressable: access-chain into the base vector
                IntLiteral index(fContext, Position(), swizzle.fComponents[0]);
                SpvId member = this->nextId();
                this->writeInstruction(SpvOpAccessChain,
                                       this->getPointerType(swizzle.fType,
                                                            get_storage_class(*swizzle.fBase)),
                                       member,
                                       base,
                                       this->writeIntLiteral(index),
                                       out);
                return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
                                                                       *this,
                                                                       member,
                                                                       this->getType(expr.fType)));
            } else {
                // several components: load/store the whole base vector and shuffle
                return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue(
                                                                              *this,
                                                                              base,
                                                                              swizzle.fComponents,
                                                                              swizzle.fBase->fType,
                                                                              expr.fType));
            }
        }

        default:
            // expr isn't actually an lvalue, create a dummy variable for it. This case happens due
            // to the need to store values in temporary variables during function calls (see
            // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been
            // caught by IRGenerator
            SpvId result = this->nextId();
            SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction);
            // The OpVariable goes to fVariableBuffer rather than `out` (presumably the function's
            // variable-declaration section -- confirm), while the initializing store goes to `out`.
            this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction,
                                   fVariableBuffer);
            this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
            return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
                                                                       *this,
                                                                       result,
                                                                       this->getType(expr.fType)));
    }
}
1908
writeVariableReference(const VariableReference & ref,OutputStream & out)1909 SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
1910 SpvId result = this->nextId();
1911 auto entry = fVariableMap.find(&ref.fVariable);
1912 ASSERT(entry != fVariableMap.end());
1913 SpvId var = entry->second;
1914 this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out);
1915 if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN &&
1916 fProgram.fSettings.fFlipY) {
1917 // need to remap to a top-left coordinate system
1918 if (fRTHeightStructId == (SpvId) -1) {
1919 // height variable hasn't been written yet
1920 std::shared_ptr<SymbolTable> st(new SymbolTable(&fErrors));
1921 ASSERT(fRTHeightFieldIndex == (SpvId) -1);
1922 std::vector<Type::Field> fields;
1923 fields.emplace_back(Modifiers(), String(SKSL_RTHEIGHT_NAME),
1924 fContext.fFloat_Type.get());
1925 String name("sksl_synthetic_uniforms");
1926 Type intfStruct(Position(), name, fields);
1927 Layout layout(-1, -1, 1, -1, -1, -1, -1, false, false, false,
1928 Layout::Format::kUnspecified, false, Layout::kUnspecified_Primitive, -1,
1929 -1, "", Layout::kNo_Key);
1930 Variable* intfVar = new Variable(Position(),
1931 Modifiers(layout, Modifiers::kUniform_Flag),
1932 name,
1933 intfStruct,
1934 Variable::kGlobal_Storage);
1935 fSynthetics.takeOwnership(intfVar);
1936 InterfaceBlock intf(Position(), intfVar, name, String(""),
1937 std::vector<std::unique_ptr<Expression>>(), st);
1938 fRTHeightStructId = this->writeInterfaceBlock(intf);
1939 fRTHeightFieldIndex = 0;
1940 }
1941 ASSERT(fRTHeightFieldIndex != (SpvId) -1);
1942 // write vec4(gl_FragCoord.x, u_skRTHeight - gl_FragCoord.y, 0.0, 1.0)
1943 SpvId xId = this->nextId();
1944 this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), xId,
1945 result, 0, out);
1946 IntLiteral fieldIndex(fContext, Position(), fRTHeightFieldIndex);
1947 SpvId fieldIndexId = this->writeIntLiteral(fieldIndex);
1948 SpvId heightPtr = this->nextId();
1949 this->writeOpCode(SpvOpAccessChain, 5, out);
1950 this->writeWord(this->getPointerType(*fContext.fFloat_Type, SpvStorageClassUniform), out);
1951 this->writeWord(heightPtr, out);
1952 this->writeWord(fRTHeightStructId, out);
1953 this->writeWord(fieldIndexId, out);
1954 SpvId heightRead = this->nextId();
1955 this->writeInstruction(SpvOpLoad, this->getType(*fContext.fFloat_Type), heightRead,
1956 heightPtr, out);
1957 SpvId rawYId = this->nextId();
1958 this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), rawYId,
1959 result, 1, out);
1960 SpvId flippedYId = this->nextId();
1961 this->writeInstruction(SpvOpFSub, this->getType(*fContext.fFloat_Type), flippedYId,
1962 heightRead, rawYId, out);
1963 FloatLiteral zero(fContext, Position(), 0.0);
1964 SpvId zeroId = writeFloatLiteral(zero);
1965 FloatLiteral one(fContext, Position(), 1.0);
1966 SpvId oneId = writeFloatLiteral(one);
1967 SpvId flipped = this->nextId();
1968 this->writeOpCode(SpvOpCompositeConstruct, 7, out);
1969 this->writeWord(this->getType(*fContext.fVec4_Type), out);
1970 this->writeWord(flipped, out);
1971 this->writeWord(xId, out);
1972 this->writeWord(flippedYId, out);
1973 this->writeWord(zeroId, out);
1974 this->writeWord(oneId, out);
1975 return flipped;
1976 }
1977 return result;
1978 }
1979
writeIndexExpression(const IndexExpression & expr,OutputStream & out)1980 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
1981 return getLValue(expr, out)->load(out);
1982 }
1983
writeFieldAccess(const FieldAccess & f,OutputStream & out)1984 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
1985 return getLValue(f, out)->load(out);
1986 }
1987
writeSwizzle(const Swizzle & swizzle,OutputStream & out)1988 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
1989 SpvId base = this->writeExpression(*swizzle.fBase, out);
1990 SpvId result = this->nextId();
1991 size_t count = swizzle.fComponents.size();
1992 if (count == 1) {
1993 this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base,
1994 swizzle.fComponents[0], out);
1995 } else {
1996 this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
1997 this->writeWord(this->getType(swizzle.fType), out);
1998 this->writeWord(result, out);
1999 this->writeWord(base, out);
2000 this->writeWord(base, out);
2001 for (int component : swizzle.fComponents) {
2002 this->writeWord(component, out);
2003 }
2004 }
2005 return result;
2006 }
2007
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)2008 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
2009 const Type& operandType, SpvId lhs,
2010 SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
2011 SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
2012 SpvId result = this->nextId();
2013 if (is_float(fContext, operandType)) {
2014 this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
2015 } else if (is_signed(fContext, operandType)) {
2016 this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
2017 } else if (is_unsigned(fContext, operandType)) {
2018 this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
2019 } else if (operandType == *fContext.fBool_Type) {
2020 this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
2021 } else {
2022 ABORT("invalid operandType: %s", operandType.description().c_str());
2023 }
2024 return result;
2025 }
2026
is_assignment(Token::Kind op)2027 bool is_assignment(Token::Kind op) {
2028 switch (op) {
2029 case Token::EQ: // fall through
2030 case Token::PLUSEQ: // fall through
2031 case Token::MINUSEQ: // fall through
2032 case Token::STAREQ: // fall through
2033 case Token::SLASHEQ: // fall through
2034 case Token::PERCENTEQ: // fall through
2035 case Token::SHLEQ: // fall through
2036 case Token::SHREQ: // fall through
2037 case Token::BITWISEOREQ: // fall through
2038 case Token::BITWISEXOREQ: // fall through
2039 case Token::BITWISEANDEQ: // fall through
2040 case Token::LOGICALOREQ: // fall through
2041 case Token::LOGICALXOREQ: // fall through
2042 case Token::LOGICALANDEQ:
2043 return true;
2044 default:
2045 return false;
2046 }
2047 }
2048
foldToBool(SpvId id,const Type & operandType,OutputStream & out)2049 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, OutputStream& out) {
2050 if (operandType.kind() == Type::kVector_Kind) {
2051 SpvId result = this->nextId();
2052 this->writeInstruction(SpvOpAll, this->getType(*fContext.fBool_Type), result, id, out);
2053 return result;
2054 }
2055 return id;
2056 }
2057
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,OutputStream & out)2058 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
2059 SpvOp_ floatOperator, SpvOp_ intOperator,
2060 OutputStream& out) {
2061 SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator;
2062 ASSERT(operandType.kind() == Type::kMatrix_Kind);
2063 SpvId rowType = this->getType(operandType.componentType().toCompound(fContext,
2064 operandType.columns(),
2065 1));
2066 SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext,
2067 operandType.columns(),
2068 1));
2069 SpvId boolType = this->getType(*fContext.fBool_Type);
2070 SpvId result = 0;
2071 for (int i = 0; i < operandType.rows(); i++) {
2072 SpvId rowL = this->nextId();
2073 this->writeInstruction(SpvOpCompositeExtract, rowType, rowL, lhs, 0, out);
2074 SpvId rowR = this->nextId();
2075 this->writeInstruction(SpvOpCompositeExtract, rowType, rowR, rhs, 0, out);
2076 SpvId compare = this->nextId();
2077 this->writeInstruction(compareOp, bvecType, compare, rowL, rowR, out);
2078 SpvId all = this->nextId();
2079 this->writeInstruction(SpvOpAll, boolType, all, compare, out);
2080 if (result != 0) {
2081 SpvId next = this->nextId();
2082 this->writeInstruction(SpvOpLogicalAnd, boolType, next, result, all, out);
2083 result = next;
2084 }
2085 else {
2086 result = all;
2087 }
2088 }
2089 return result;
2090 }
2091
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)2092 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
2093 // handle cases where we don't necessarily evaluate both LHS and RHS
2094 switch (b.fOperator) {
2095 case Token::EQ: {
2096 SpvId rhs = this->writeExpression(*b.fRight, out);
2097 this->getLValue(*b.fLeft, out)->store(rhs, out);
2098 return rhs;
2099 }
2100 case Token::LOGICALAND:
2101 return this->writeLogicalAnd(b, out);
2102 case Token::LOGICALOR:
2103 return this->writeLogicalOr(b, out);
2104 default:
2105 break;
2106 }
2107
2108 // "normal" operators
2109 const Type& resultType = b.fType;
2110 std::unique_ptr<LValue> lvalue;
2111 SpvId lhs;
2112 if (is_assignment(b.fOperator)) {
2113 lvalue = this->getLValue(*b.fLeft, out);
2114 lhs = lvalue->load(out);
2115 } else {
2116 lvalue = nullptr;
2117 lhs = this->writeExpression(*b.fLeft, out);
2118 }
2119 SpvId rhs = this->writeExpression(*b.fRight, out);
2120 if (b.fOperator == Token::COMMA) {
2121 return rhs;
2122 }
2123 // component type we are operating on: float, int, uint
2124 const Type* operandType;
2125 // IR allows mismatched types in expressions (e.g. vec2 * float), but they need special handling
2126 // in SPIR-V
2127 if (b.fLeft->fType != b.fRight->fType) {
2128 if (b.fLeft->fType.kind() == Type::kVector_Kind &&
2129 b.fRight->fType.isNumber()) {
2130 // promote number to vector
2131 SpvId vec = this->nextId();
2132 this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
2133 this->writeWord(this->getType(resultType), out);
2134 this->writeWord(vec, out);
2135 for (int i = 0; i < resultType.columns(); i++) {
2136 this->writeWord(rhs, out);
2137 }
2138 rhs = vec;
2139 operandType = &b.fRight->fType;
2140 } else if (b.fRight->fType.kind() == Type::kVector_Kind &&
2141 b.fLeft->fType.isNumber()) {
2142 // promote number to vector
2143 SpvId vec = this->nextId();
2144 this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
2145 this->writeWord(this->getType(resultType), out);
2146 this->writeWord(vec, out);
2147 for (int i = 0; i < resultType.columns(); i++) {
2148 this->writeWord(lhs, out);
2149 }
2150 lhs = vec;
2151 ASSERT(!lvalue);
2152 operandType = &b.fLeft->fType;
2153 } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) {
2154 SpvOp_ op;
2155 if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
2156 op = SpvOpMatrixTimesMatrix;
2157 } else if (b.fRight->fType.kind() == Type::kVector_Kind) {
2158 op = SpvOpMatrixTimesVector;
2159 } else {
2160 ASSERT(b.fRight->fType.kind() == Type::kScalar_Kind);
2161 op = SpvOpMatrixTimesScalar;
2162 }
2163 SpvId result = this->nextId();
2164 this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out);
2165 if (b.fOperator == Token::STAREQ) {
2166 lvalue->store(result, out);
2167 } else {
2168 ASSERT(b.fOperator == Token::STAR);
2169 }
2170 return result;
2171 } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
2172 SpvId result = this->nextId();
2173 if (b.fLeft->fType.kind() == Type::kVector_Kind) {
2174 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result,
2175 lhs, rhs, out);
2176 } else {
2177 ASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind);
2178 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs,
2179 lhs, out);
2180 }
2181 if (b.fOperator == Token::STAREQ) {
2182 lvalue->store(result, out);
2183 } else {
2184 ASSERT(b.fOperator == Token::STAR);
2185 }
2186 return result;
2187 } else {
2188 ABORT("unsupported binary expression: %s", b.description().c_str());
2189 }
2190 } else {
2191 operandType = &b.fLeft->fType;
2192 ASSERT(*operandType == b.fRight->fType);
2193 }
2194 switch (b.fOperator) {
2195 case Token::EQEQ: {
2196 if (operandType->kind() == Type::kMatrix_Kind) {
2197 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
2198 SpvOpIEqual, out);
2199 }
2200 ASSERT(resultType == *fContext.fBool_Type);
2201 return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2202 SpvOpFOrdEqual, SpvOpIEqual,
2203 SpvOpIEqual, SpvOpLogicalEqual, out),
2204 *operandType, out);
2205 }
2206 case Token::NEQ:
2207 if (operandType->kind() == Type::kMatrix_Kind) {
2208 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
2209 SpvOpINotEqual, out);
2210 }
2211 ASSERT(resultType == *fContext.fBool_Type);
2212 return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2213 SpvOpFOrdNotEqual, SpvOpINotEqual,
2214 SpvOpINotEqual, SpvOpLogicalNotEqual,
2215 out),
2216 *operandType, out);
2217 case Token::GT:
2218 ASSERT(resultType == *fContext.fBool_Type);
2219 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2220 SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
2221 SpvOpUGreaterThan, SpvOpUndef, out);
2222 case Token::LT:
2223 ASSERT(resultType == *fContext.fBool_Type);
2224 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
2225 SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
2226 case Token::GTEQ:
2227 ASSERT(resultType == *fContext.fBool_Type);
2228 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2229 SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
2230 SpvOpUGreaterThanEqual, SpvOpUndef, out);
2231 case Token::LTEQ:
2232 ASSERT(resultType == *fContext.fBool_Type);
2233 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2234 SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
2235 SpvOpULessThanEqual, SpvOpUndef, out);
2236 case Token::PLUS:
2237 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2238 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2239 case Token::MINUS:
2240 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2241 SpvOpISub, SpvOpISub, SpvOpUndef, out);
2242 case Token::STAR:
2243 if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2244 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2245 // matrix multiply
2246 SpvId result = this->nextId();
2247 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2248 lhs, rhs, out);
2249 return result;
2250 }
2251 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2252 SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2253 case Token::SLASH:
2254 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2255 SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2256 case Token::PERCENT:
2257 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2258 SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2259 case Token::SHL:
2260 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2261 SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
2262 SpvOpUndef, out);
2263 case Token::SHR:
2264 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2265 SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
2266 SpvOpUndef, out);
2267 case Token::BITWISEAND:
2268 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2269 SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
2270 case Token::BITWISEOR:
2271 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2272 SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
2273 case Token::BITWISEXOR:
2274 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2275 SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
2276 case Token::PLUSEQ: {
2277 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2278 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2279 ASSERT(lvalue);
2280 lvalue->store(result, out);
2281 return result;
2282 }
2283 case Token::MINUSEQ: {
2284 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2285 SpvOpISub, SpvOpISub, SpvOpUndef, out);
2286 ASSERT(lvalue);
2287 lvalue->store(result, out);
2288 return result;
2289 }
2290 case Token::STAREQ: {
2291 if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2292 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2293 // matrix multiply
2294 SpvId result = this->nextId();
2295 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2296 lhs, rhs, out);
2297 ASSERT(lvalue);
2298 lvalue->store(result, out);
2299 return result;
2300 }
2301 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2302 SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2303 ASSERT(lvalue);
2304 lvalue->store(result, out);
2305 return result;
2306 }
2307 case Token::SLASHEQ: {
2308 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2309 SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2310 ASSERT(lvalue);
2311 lvalue->store(result, out);
2312 return result;
2313 }
2314 case Token::PERCENTEQ: {
2315 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2316 SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2317 ASSERT(lvalue);
2318 lvalue->store(result, out);
2319 return result;
2320 }
2321 case Token::SHLEQ: {
2322 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2323 SpvOpUndef, SpvOpShiftLeftLogical,
2324 SpvOpShiftLeftLogical, SpvOpUndef, out);
2325 ASSERT(lvalue);
2326 lvalue->store(result, out);
2327 return result;
2328 }
2329 case Token::SHREQ: {
2330 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2331 SpvOpUndef, SpvOpShiftRightArithmetic,
2332 SpvOpShiftRightLogical, SpvOpUndef, out);
2333 ASSERT(lvalue);
2334 lvalue->store(result, out);
2335 return result;
2336 }
2337 case Token::BITWISEANDEQ: {
2338 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2339 SpvOpUndef, SpvOpBitwiseAnd, SpvOpBitwiseAnd,
2340 SpvOpUndef, out);
2341 ASSERT(lvalue);
2342 lvalue->store(result, out);
2343 return result;
2344 }
2345 case Token::BITWISEOREQ: {
2346 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2347 SpvOpUndef, SpvOpBitwiseOr, SpvOpBitwiseOr,
2348 SpvOpUndef, out);
2349 ASSERT(lvalue);
2350 lvalue->store(result, out);
2351 return result;
2352 }
2353 case Token::BITWISEXOREQ: {
2354 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2355 SpvOpUndef, SpvOpBitwiseXor, SpvOpBitwiseXor,
2356 SpvOpUndef, out);
2357 ASSERT(lvalue);
2358 lvalue->store(result, out);
2359 return result;
2360 }
2361 default:
2362 ABORT("unsupported binary expression: %s", b.description().c_str());
2363 }
2364 }
2365
writeLogicalAnd(const BinaryExpression & a,OutputStream & out)2366 SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) {
2367 ASSERT(a.fOperator == Token::LOGICALAND);
2368 BoolLiteral falseLiteral(fContext, Position(), false);
2369 SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
2370 SpvId lhs = this->writeExpression(*a.fLeft, out);
2371 SpvId rhsLabel = this->nextId();
2372 SpvId end = this->nextId();
2373 SpvId lhsBlock = fCurrentBlock;
2374 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2375 this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
2376 this->writeLabel(rhsLabel, out);
2377 SpvId rhs = this->writeExpression(*a.fRight, out);
2378 SpvId rhsBlock = fCurrentBlock;
2379 this->writeInstruction(SpvOpBranch, end, out);
2380 this->writeLabel(end, out);
2381 SpvId result = this->nextId();
2382 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant,
2383 lhsBlock, rhs, rhsBlock, out);
2384 return result;
2385 }
2386
writeLogicalOr(const BinaryExpression & o,OutputStream & out)2387 SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, OutputStream& out) {
2388 ASSERT(o.fOperator == Token::LOGICALOR);
2389 BoolLiteral trueLiteral(fContext, Position(), true);
2390 SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
2391 SpvId lhs = this->writeExpression(*o.fLeft, out);
2392 SpvId rhsLabel = this->nextId();
2393 SpvId end = this->nextId();
2394 SpvId lhsBlock = fCurrentBlock;
2395 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2396 this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
2397 this->writeLabel(rhsLabel, out);
2398 SpvId rhs = this->writeExpression(*o.fRight, out);
2399 SpvId rhsBlock = fCurrentBlock;
2400 this->writeInstruction(SpvOpBranch, end, out);
2401 this->writeLabel(end, out);
2402 SpvId result = this->nextId();
2403 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant,
2404 lhsBlock, rhs, rhsBlock, out);
2405 return result;
2406 }
2407
writeTernaryExpression(const TernaryExpression & t,OutputStream & out)2408 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
2409 SpvId test = this->writeExpression(*t.fTest, out);
2410 if (t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) {
2411 // both true and false are constants, can just use OpSelect
2412 SpvId result = this->nextId();
2413 SpvId trueId = this->writeExpression(*t.fIfTrue, out);
2414 SpvId falseId = this->writeExpression(*t.fIfFalse, out);
2415 this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId,
2416 out);
2417 return result;
2418 }
2419 // was originally using OpPhi to choose the result, but for some reason that is crashing on
2420 // Adreno. Switched to storing the result in a temp variable as glslang does.
2421 SpvId var = this->nextId();
2422 this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction),
2423 var, SpvStorageClassFunction, fVariableBuffer);
2424 SpvId trueLabel = this->nextId();
2425 SpvId falseLabel = this->nextId();
2426 SpvId end = this->nextId();
2427 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2428 this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
2429 this->writeLabel(trueLabel, out);
2430 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out);
2431 this->writeInstruction(SpvOpBranch, end, out);
2432 this->writeLabel(falseLabel, out);
2433 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out);
2434 this->writeInstruction(SpvOpBranch, end, out);
2435 this->writeLabel(end, out);
2436 SpvId result = this->nextId();
2437 this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out);
2438 return result;
2439 }
2440
create_literal_1(const Context & context,const Type & type)2441 std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
2442 if (type == *context.fInt_Type) {
2443 return std::unique_ptr<Expression>(new IntLiteral(context, Position(), 1));
2444 }
2445 else if (type == *context.fFloat_Type) {
2446 return std::unique_ptr<Expression>(new FloatLiteral(context, Position(), 1.0));
2447 } else {
2448 ABORT("math is unsupported on type '%s'", type.name().c_str());
2449 }
2450 }
2451
writePrefixExpression(const PrefixExpression & p,OutputStream & out)2452 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
2453 if (p.fOperator == Token::MINUS) {
2454 SpvId result = this->nextId();
2455 SpvId typeId = this->getType(p.fType);
2456 SpvId expr = this->writeExpression(*p.fOperand, out);
2457 if (is_float(fContext, p.fType)) {
2458 this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
2459 } else if (is_signed(fContext, p.fType)) {
2460 this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
2461 } else {
2462 ABORT("unsupported prefix expression %s", p.description().c_str());
2463 };
2464 return result;
2465 }
2466 switch (p.fOperator) {
2467 case Token::PLUS:
2468 return this->writeExpression(*p.fOperand, out);
2469 case Token::PLUSPLUS: {
2470 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2471 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2472 SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2473 SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
2474 out);
2475 lv->store(result, out);
2476 return result;
2477 }
2478 case Token::MINUSMINUS: {
2479 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2480 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2481 SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2482 SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef,
2483 out);
2484 lv->store(result, out);
2485 return result;
2486 }
2487 case Token::LOGICALNOT: {
2488 ASSERT(p.fOperand->fType == *fContext.fBool_Type);
2489 SpvId result = this->nextId();
2490 this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result,
2491 this->writeExpression(*p.fOperand, out), out);
2492 return result;
2493 }
2494 case Token::BITWISENOT: {
2495 SpvId result = this->nextId();
2496 this->writeInstruction(SpvOpNot, this->getType(p.fOperand->fType), result,
2497 this->writeExpression(*p.fOperand, out), out);
2498 return result;
2499 }
2500 default:
2501 ABORT("unsupported prefix expression: %s", p.description().c_str());
2502 }
2503 }
2504
writePostfixExpression(const PostfixExpression & p,OutputStream & out)2505 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
2506 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2507 SpvId result = lv->load(out);
2508 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2509 switch (p.fOperator) {
2510 case Token::PLUSPLUS: {
2511 SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd,
2512 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2513 lv->store(temp, out);
2514 return result;
2515 }
2516 case Token::MINUSMINUS: {
2517 SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub,
2518 SpvOpISub, SpvOpISub, SpvOpUndef, out);
2519 lv->store(temp, out);
2520 return result;
2521 }
2522 default:
2523 ABORT("unsupported postfix expression %s", p.description().c_str());
2524 }
2525 }
2526
writeBoolLiteral(const BoolLiteral & b)2527 SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
2528 if (b.fValue) {
2529 if (fBoolTrue == 0) {
2530 fBoolTrue = this->nextId();
2531 this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue,
2532 fConstantBuffer);
2533 }
2534 return fBoolTrue;
2535 } else {
2536 if (fBoolFalse == 0) {
2537 fBoolFalse = this->nextId();
2538 this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse,
2539 fConstantBuffer);
2540 }
2541 return fBoolFalse;
2542 }
2543 }
2544
writeIntLiteral(const IntLiteral & i)2545 SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) {
2546 if (i.fType == *fContext.fInt_Type) {
2547 auto entry = fIntConstants.find(i.fValue);
2548 if (entry == fIntConstants.end()) {
2549 SpvId result = this->nextId();
2550 this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2551 fConstantBuffer);
2552 fIntConstants[i.fValue] = result;
2553 return result;
2554 }
2555 return entry->second;
2556 } else {
2557 ASSERT(i.fType == *fContext.fUInt_Type);
2558 auto entry = fUIntConstants.find(i.fValue);
2559 if (entry == fUIntConstants.end()) {
2560 SpvId result = this->nextId();
2561 this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2562 fConstantBuffer);
2563 fUIntConstants[i.fValue] = result;
2564 return result;
2565 }
2566 return entry->second;
2567 }
2568 }
2569
writeFloatLiteral(const FloatLiteral & f)2570 SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
2571 if (f.fType == *fContext.fFloat_Type) {
2572 float value = (float) f.fValue;
2573 auto entry = fFloatConstants.find(value);
2574 if (entry == fFloatConstants.end()) {
2575 SpvId result = this->nextId();
2576 uint32_t bits;
2577 ASSERT(sizeof(bits) == sizeof(value));
2578 memcpy(&bits, &value, sizeof(bits));
2579 this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits,
2580 fConstantBuffer);
2581 fFloatConstants[value] = result;
2582 return result;
2583 }
2584 return entry->second;
2585 } else {
2586 ASSERT(f.fType == *fContext.fDouble_Type);
2587 auto entry = fDoubleConstants.find(f.fValue);
2588 if (entry == fDoubleConstants.end()) {
2589 SpvId result = this->nextId();
2590 uint64_t bits;
2591 ASSERT(sizeof(bits) == sizeof(f.fValue));
2592 memcpy(&bits, &f.fValue, sizeof(bits));
2593 this->writeInstruction(SpvOpConstant, this->getType(f.fType), result,
2594 bits & 0xffffffff, bits >> 32, fConstantBuffer);
2595 fDoubleConstants[f.fValue] = result;
2596 return result;
2597 }
2598 return entry->second;
2599 }
2600 }
2601
writeFunctionStart(const FunctionDeclaration & f,OutputStream & out)2602 SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
2603 SpvId result = fFunctionMap[&f];
2604 this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result,
2605 SpvFunctionControlMaskNone, this->getFunctionType(f), out);
2606 this->writeInstruction(SpvOpName, result, f.fName.c_str(), fNameBuffer);
2607 for (size_t i = 0; i < f.fParameters.size(); i++) {
2608 SpvId id = this->nextId();
2609 fVariableMap[f.fParameters[i]] = id;
2610 SpvId type;
2611 type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction);
2612 this->writeInstruction(SpvOpFunctionParameter, type, id, out);
2613 }
2614 return result;
2615 }
2616
writeFunction(const FunctionDefinition & f,OutputStream & out)2617 SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
2618 fVariableBuffer.reset();
2619 SpvId result = this->writeFunctionStart(f.fDeclaration, out);
2620 this->writeLabel(this->nextId(), out);
2621 if (f.fDeclaration.fName == "main") {
2622 write_stringstream(fGlobalInitializersBuffer, out);
2623 }
2624 StringStream bodyBuffer;
2625 this->writeBlock((Block&) *f.fBody, bodyBuffer);
2626 write_stringstream(fVariableBuffer, out);
2627 write_stringstream(bodyBuffer, out);
2628 if (fCurrentBlock) {
2629 this->writeInstruction(SpvOpReturn, out);
2630 }
2631 this->writeInstruction(SpvOpFunctionEnd, out);
2632 return result;
2633 }
2634
writeLayout(const Layout & layout,SpvId target)2635 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) {
2636 if (layout.fLocation >= 0) {
2637 this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
2638 fDecorationBuffer);
2639 }
2640 if (layout.fBinding >= 0) {
2641 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
2642 fDecorationBuffer);
2643 }
2644 if (layout.fIndex >= 0) {
2645 this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
2646 fDecorationBuffer);
2647 }
2648 if (layout.fSet >= 0) {
2649 this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
2650 fDecorationBuffer);
2651 }
2652 if (layout.fInputAttachmentIndex >= 0) {
2653 this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
2654 layout.fInputAttachmentIndex, fDecorationBuffer);
2655 }
2656 if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN) {
2657 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
2658 fDecorationBuffer);
2659 }
2660 }
2661
writeLayout(const Layout & layout,SpvId target,int member)2662 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) {
2663 if (layout.fLocation >= 0) {
2664 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
2665 layout.fLocation, fDecorationBuffer);
2666 }
2667 if (layout.fBinding >= 0) {
2668 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding,
2669 layout.fBinding, fDecorationBuffer);
2670 }
2671 if (layout.fIndex >= 0) {
2672 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
2673 layout.fIndex, fDecorationBuffer);
2674 }
2675 if (layout.fSet >= 0) {
2676 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet,
2677 layout.fSet, fDecorationBuffer);
2678 }
2679 if (layout.fInputAttachmentIndex >= 0) {
2680 this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
2681 layout.fInputAttachmentIndex, fDecorationBuffer);
2682 }
2683 if (layout.fBuiltin >= 0) {
2684 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
2685 layout.fBuiltin, fDecorationBuffer);
2686 }
2687 }
2688
writeInterfaceBlock(const InterfaceBlock & intf)2689 SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2690 bool isBuffer = (0 != (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag));
2691 MemoryLayout layout = (intf.fVariable.fModifiers.fLayout.fPushConstant || isBuffer) ?
2692 MemoryLayout(MemoryLayout::k430_Standard) :
2693 fDefaultLayout;
2694 SpvId result = this->nextId();
2695 const Type* type = &intf.fVariable.fType;
2696 if (fProgram.fInputs.fRTHeight) {
2697 ASSERT(fRTHeightStructId == (SpvId) -1);
2698 ASSERT(fRTHeightFieldIndex == (SpvId) -1);
2699 std::vector<Type::Field> fields = type->fields();
2700 fRTHeightStructId = result;
2701 fRTHeightFieldIndex = fields.size();
2702 fields.emplace_back(Modifiers(), String(SKSL_RTHEIGHT_NAME), fContext.fFloat_Type.get());
2703 type = new Type(type->fPosition, type->name(), fields);
2704 }
2705 SpvId typeId = this->getType(*type, layout);
2706 if (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag) {
2707 this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBufferBlock, fDecorationBuffer);
2708 } else {
2709 this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
2710 }
2711 SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers);
2712 SpvId ptrType = this->nextId();
2713 this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
2714 this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
2715 this->writeLayout(intf.fVariable.fModifiers.fLayout, result);
2716 fVariableMap[&intf.fVariable] = result;
2717 if (fProgram.fInputs.fRTHeight) {
2718 delete type;
2719 }
2720 return result;
2721 }
2722
writePrecisionModifier(const Modifiers & modifiers,SpvId id)2723 void SPIRVCodeGenerator::writePrecisionModifier(const Modifiers& modifiers, SpvId id) {
2724 if ((modifiers.fFlags & Modifiers::kLowp_Flag) |
2725 (modifiers.fFlags & Modifiers::kMediump_Flag)) {
2726 this->writeInstruction(SpvOpDecorate, id, SpvDecorationRelaxedPrecision, fDecorationBuffer);
2727 }
2728 }
2729
2730 #define BUILTIN_IGNORE 9999
writeGlobalVars(Program::Kind kind,const VarDeclarations & decl,OutputStream & out)2731 void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclarations& decl,
2732 OutputStream& out) {
2733 for (size_t i = 0; i < decl.fVars.size(); i++) {
2734 if (decl.fVars[i]->fKind == Statement::kNop_Kind) {
2735 continue;
2736 }
2737 const VarDeclaration& varDecl = (VarDeclaration&) *decl.fVars[i];
2738 const Variable* var = varDecl.fVar;
2739 // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2740 // in the OpenGL backend.
2741 ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2742 Modifiers::kWriteOnly_Flag |
2743 Modifiers::kCoherent_Flag |
2744 Modifiers::kVolatile_Flag |
2745 Modifiers::kRestrict_Flag)));
2746 if (var->fModifiers.fLayout.fBuiltin == BUILTIN_IGNORE) {
2747 continue;
2748 }
2749 if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
2750 kind != Program::kFragment_Kind) {
2751 continue;
2752 }
2753 if (!var->fReadCount && !var->fWriteCount &&
2754 !(var->fModifiers.fFlags & (Modifiers::kIn_Flag |
2755 Modifiers::kOut_Flag |
2756 Modifiers::kUniform_Flag |
2757 Modifiers::kBuffer_Flag))) {
2758 // variable is dead and not an input / output var (the Vulkan debug layers complain if
2759 // we elide an interface var, even if it's dead)
2760 continue;
2761 }
2762 SpvStorageClass_ storageClass;
2763 if (var->fModifiers.fFlags & Modifiers::kIn_Flag) {
2764 storageClass = SpvStorageClassInput;
2765 } else if (var->fModifiers.fFlags & Modifiers::kOut_Flag) {
2766 storageClass = SpvStorageClassOutput;
2767 } else if (var->fModifiers.fFlags & Modifiers::kUniform_Flag) {
2768 if (var->fType.kind() == Type::kSampler_Kind) {
2769 storageClass = SpvStorageClassUniformConstant;
2770 } else {
2771 storageClass = SpvStorageClassUniform;
2772 }
2773 } else {
2774 storageClass = SpvStorageClassPrivate;
2775 }
2776 SpvId id = this->nextId();
2777 fVariableMap[var] = id;
2778 SpvId type = this->getPointerType(var->fType, storageClass);
2779 this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer);
2780 this->writeInstruction(SpvOpName, id, var->fName.c_str(), fNameBuffer);
2781 this->writePrecisionModifier(var->fModifiers, id);
2782 if (var->fType.kind() == Type::kMatrix_Kind) {
2783 this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationColMajor,
2784 fDecorationBuffer);
2785 this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationMatrixStride,
2786 (SpvId) fDefaultLayout.stride(var->fType), fDecorationBuffer);
2787 }
2788 if (varDecl.fValue) {
2789 ASSERT(!fCurrentBlock);
2790 fCurrentBlock = -1;
2791 SpvId value = this->writeExpression(*varDecl.fValue, fGlobalInitializersBuffer);
2792 this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
2793 fCurrentBlock = 0;
2794 }
2795 this->writeLayout(var->fModifiers.fLayout, id);
2796 }
2797 }
2798
writeVarDeclarations(const VarDeclarations & decl,OutputStream & out)2799 void SPIRVCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, OutputStream& out) {
2800 for (const auto& stmt : decl.fVars) {
2801 ASSERT(stmt->fKind == Statement::kVarDeclaration_Kind);
2802 VarDeclaration& varDecl = (VarDeclaration&) *stmt;
2803 const Variable* var = varDecl.fVar;
2804 // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2805 // in the OpenGL backend.
2806 ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2807 Modifiers::kWriteOnly_Flag |
2808 Modifiers::kCoherent_Flag |
2809 Modifiers::kVolatile_Flag |
2810 Modifiers::kRestrict_Flag)));
2811 SpvId id = this->nextId();
2812 fVariableMap[var] = id;
2813 SpvId type = this->getPointerType(var->fType, SpvStorageClassFunction);
2814 this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
2815 this->writeInstruction(SpvOpName, id, var->fName.c_str(), fNameBuffer);
2816 if (varDecl.fValue) {
2817 SpvId value = this->writeExpression(*varDecl.fValue, out);
2818 this->writeInstruction(SpvOpStore, id, value, out);
2819 }
2820 }
2821 }
2822
writeStatement(const Statement & s,OutputStream & out)2823 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
2824 switch (s.fKind) {
2825 case Statement::kNop_Kind:
2826 break;
2827 case Statement::kBlock_Kind:
2828 this->writeBlock((Block&) s, out);
2829 break;
2830 case Statement::kExpression_Kind:
2831 this->writeExpression(*((ExpressionStatement&) s).fExpression, out);
2832 break;
2833 case Statement::kReturn_Kind:
2834 this->writeReturnStatement((ReturnStatement&) s, out);
2835 break;
2836 case Statement::kVarDeclarations_Kind:
2837 this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, out);
2838 break;
2839 case Statement::kIf_Kind:
2840 this->writeIfStatement((IfStatement&) s, out);
2841 break;
2842 case Statement::kFor_Kind:
2843 this->writeForStatement((ForStatement&) s, out);
2844 break;
2845 case Statement::kWhile_Kind:
2846 this->writeWhileStatement((WhileStatement&) s, out);
2847 break;
2848 case Statement::kDo_Kind:
2849 this->writeDoStatement((DoStatement&) s, out);
2850 break;
2851 case Statement::kBreak_Kind:
2852 this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
2853 break;
2854 case Statement::kContinue_Kind:
2855 this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
2856 break;
2857 case Statement::kDiscard_Kind:
2858 this->writeInstruction(SpvOpKill, out);
2859 break;
2860 default:
2861 ABORT("unsupported statement: %s", s.description().c_str());
2862 }
2863 }
2864
writeBlock(const Block & b,OutputStream & out)2865 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
2866 for (size_t i = 0; i < b.fStatements.size(); i++) {
2867 this->writeStatement(*b.fStatements[i], out);
2868 }
2869 }
2870
writeIfStatement(const IfStatement & stmt,OutputStream & out)2871 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
2872 SpvId test = this->writeExpression(*stmt.fTest, out);
2873 SpvId ifTrue = this->nextId();
2874 SpvId ifFalse = this->nextId();
2875 if (stmt.fIfFalse) {
2876 SpvId end = this->nextId();
2877 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2878 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2879 this->writeLabel(ifTrue, out);
2880 this->writeStatement(*stmt.fIfTrue, out);
2881 if (fCurrentBlock) {
2882 this->writeInstruction(SpvOpBranch, end, out);
2883 }
2884 this->writeLabel(ifFalse, out);
2885 this->writeStatement(*stmt.fIfFalse, out);
2886 if (fCurrentBlock) {
2887 this->writeInstruction(SpvOpBranch, end, out);
2888 }
2889 this->writeLabel(end, out);
2890 } else {
2891 this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
2892 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2893 this->writeLabel(ifTrue, out);
2894 this->writeStatement(*stmt.fIfTrue, out);
2895 if (fCurrentBlock) {
2896 this->writeInstruction(SpvOpBranch, ifFalse, out);
2897 }
2898 this->writeLabel(ifFalse, out);
2899 }
2900 }
2901
writeForStatement(const ForStatement & f,OutputStream & out)2902 void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
2903 if (f.fInitializer) {
2904 this->writeStatement(*f.fInitializer, out);
2905 }
2906 SpvId header = this->nextId();
2907 SpvId start = this->nextId();
2908 SpvId body = this->nextId();
2909 SpvId next = this->nextId();
2910 fContinueTarget.push(next);
2911 SpvId end = this->nextId();
2912 fBreakTarget.push(end);
2913 this->writeInstruction(SpvOpBranch, header, out);
2914 this->writeLabel(header, out);
2915 this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
2916 this->writeInstruction(SpvOpBranch, start, out);
2917 this->writeLabel(start, out);
2918 if (f.fTest) {
2919 SpvId test = this->writeExpression(*f.fTest, out);
2920 this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2921 }
2922 this->writeLabel(body, out);
2923 this->writeStatement(*f.fStatement, out);
2924 if (fCurrentBlock) {
2925 this->writeInstruction(SpvOpBranch, next, out);
2926 }
2927 this->writeLabel(next, out);
2928 if (f.fNext) {
2929 this->writeExpression(*f.fNext, out);
2930 }
2931 this->writeInstruction(SpvOpBranch, header, out);
2932 this->writeLabel(end, out);
2933 fBreakTarget.pop();
2934 fContinueTarget.pop();
2935 }
2936
writeWhileStatement(const WhileStatement & w,OutputStream & out)2937 void SPIRVCodeGenerator::writeWhileStatement(const WhileStatement& w, OutputStream& out) {
2938 // We believe the while loop code below will work, but Skia doesn't actually use them and
2939 // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2940 // the time being, we just fail with an error due to the lack of testing. If you encounter this
2941 // message, simply remove the error call below to see whether our while loop support actually
2942 // works.
2943 fErrors.error(w.fPosition, "internal error: while loop support has been disabled in SPIR-V, "
2944 "see SkSLSPIRVCodeGenerator.cpp for details");
2945
2946 SpvId header = this->nextId();
2947 SpvId start = this->nextId();
2948 SpvId body = this->nextId();
2949 fContinueTarget.push(start);
2950 SpvId end = this->nextId();
2951 fBreakTarget.push(end);
2952 this->writeInstruction(SpvOpBranch, header, out);
2953 this->writeLabel(header, out);
2954 this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2955 this->writeInstruction(SpvOpBranch, start, out);
2956 this->writeLabel(start, out);
2957 SpvId test = this->writeExpression(*w.fTest, out);
2958 this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2959 this->writeLabel(body, out);
2960 this->writeStatement(*w.fStatement, out);
2961 if (fCurrentBlock) {
2962 this->writeInstruction(SpvOpBranch, start, out);
2963 }
2964 this->writeLabel(end, out);
2965 fBreakTarget.pop();
2966 fContinueTarget.pop();
2967 }
2968
writeDoStatement(const DoStatement & d,OutputStream & out)2969 void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
2970 // We believe the do loop code below will work, but Skia doesn't actually use them and
2971 // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2972 // the time being, we just fail with an error due to the lack of testing. If you encounter this
2973 // message, simply remove the error call below to see whether our do loop support actually
2974 // works.
2975 fErrors.error(d.fPosition, "internal error: do loop support has been disabled in SPIR-V, see "
2976 "SkSLSPIRVCodeGenerator.cpp for details");
2977
2978 SpvId header = this->nextId();
2979 SpvId start = this->nextId();
2980 SpvId next = this->nextId();
2981 fContinueTarget.push(next);
2982 SpvId end = this->nextId();
2983 fBreakTarget.push(end);
2984 this->writeInstruction(SpvOpBranch, header, out);
2985 this->writeLabel(header, out);
2986 this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2987 this->writeInstruction(SpvOpBranch, start, out);
2988 this->writeLabel(start, out);
2989 this->writeStatement(*d.fStatement, out);
2990 if (fCurrentBlock) {
2991 this->writeInstruction(SpvOpBranch, next, out);
2992 }
2993 this->writeLabel(next, out);
2994 SpvId test = this->writeExpression(*d.fTest, out);
2995 this->writeInstruction(SpvOpBranchConditional, test, start, end, out);
2996 this->writeLabel(end, out);
2997 fBreakTarget.pop();
2998 fContinueTarget.pop();
2999 }
3000
writeReturnStatement(const ReturnStatement & r,OutputStream & out)3001 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
3002 if (r.fExpression) {
3003 this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out),
3004 out);
3005 } else {
3006 this->writeInstruction(SpvOpReturn, out);
3007 }
3008 }
3009
writeInstructions(const Program & program,OutputStream & out)3010 void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
3011 fGLSLExtendedInstructions = this->nextId();
3012 StringStream body;
3013 std::set<SpvId> interfaceVars;
3014 // assign IDs to functions
3015 for (size_t i = 0; i < program.fElements.size(); i++) {
3016 if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
3017 FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i];
3018 fFunctionMap[&f.fDeclaration] = this->nextId();
3019 }
3020 }
3021 for (size_t i = 0; i < program.fElements.size(); i++) {
3022 if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) {
3023 InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i];
3024 SpvId id = this->writeInterfaceBlock(intf);
3025 if ((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) ||
3026 (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) {
3027 interfaceVars.insert(id);
3028 }
3029 }
3030 }
3031 for (size_t i = 0; i < program.fElements.size(); i++) {
3032 if (program.fElements[i]->fKind == ProgramElement::kVar_Kind) {
3033 this->writeGlobalVars(program.fKind, ((VarDeclarations&) *program.fElements[i]),
3034 body);
3035 }
3036 }
3037 for (size_t i = 0; i < program.fElements.size(); i++) {
3038 if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
3039 this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body);
3040 }
3041 }
3042 const FunctionDeclaration* main = nullptr;
3043 for (auto entry : fFunctionMap) {
3044 if (entry.first->fName == "main") {
3045 main = entry.first;
3046 }
3047 }
3048 ASSERT(main);
3049 for (auto entry : fVariableMap) {
3050 const Variable* var = entry.first;
3051 if (var->fStorage == Variable::kGlobal_Storage &&
3052 ((var->fModifiers.fFlags & Modifiers::kIn_Flag) ||
3053 (var->fModifiers.fFlags & Modifiers::kOut_Flag))) {
3054 interfaceVars.insert(entry.second);
3055 }
3056 }
3057 this->writeCapabilities(out);
3058 this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
3059 this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
3060 this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (strlen(main->fName.c_str()) + 4) / 4) +
3061 (int32_t) interfaceVars.size(), out);
3062 switch (program.fKind) {
3063 case Program::kVertex_Kind:
3064 this->writeWord(SpvExecutionModelVertex, out);
3065 break;
3066 case Program::kFragment_Kind:
3067 this->writeWord(SpvExecutionModelFragment, out);
3068 break;
3069 case Program::kGeometry_Kind:
3070 this->writeWord(SpvExecutionModelGeometry, out);
3071 break;
3072 default:
3073 ABORT("cannot write this kind of program to SPIR-V\n");
3074 }
3075 this->writeWord(fFunctionMap[main], out);
3076 this->writeString(main->fName.c_str(), out);
3077 for (int var : interfaceVars) {
3078 this->writeWord(var, out);
3079 }
3080 if (program.fKind == Program::kFragment_Kind) {
3081 this->writeInstruction(SpvOpExecutionMode,
3082 fFunctionMap[main],
3083 SpvExecutionModeOriginUpperLeft,
3084 out);
3085 }
3086 for (size_t i = 0; i < program.fElements.size(); i++) {
3087 if (program.fElements[i]->fKind == ProgramElement::kExtension_Kind) {
3088 this->writeInstruction(SpvOpSourceExtension,
3089 ((Extension&) *program.fElements[i]).fName.c_str(),
3090 out);
3091 }
3092 }
3093
3094 write_stringstream(fExtraGlobalsBuffer, out);
3095 write_stringstream(fNameBuffer, out);
3096 write_stringstream(fDecorationBuffer, out);
3097 write_stringstream(fConstantBuffer, out);
3098 write_stringstream(fExternalFunctionsBuffer, out);
3099 write_stringstream(body, out);
3100 }
3101
generateCode()3102 bool SPIRVCodeGenerator::generateCode() {
3103 ASSERT(!fErrors.errorCount());
3104 this->writeWord(SpvMagicNumber, *fOut);
3105 this->writeWord(SpvVersion, *fOut);
3106 this->writeWord(SKSL_MAGIC, *fOut);
3107 StringStream buffer;
3108 this->writeInstructions(fProgram, buffer);
3109 this->writeWord(fIdCount, *fOut);
3110 this->writeWord(0, *fOut); // reserved, always zero
3111 write_stringstream(buffer, *fOut);
3112 return 0 == fErrors.errorCount();
3113 }
3114
3115 }
3116