1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "SkSLSPIRVCodeGenerator.h"
9
10 #include "GLSL.std.450.h"
11
12 #include "ir/SkSLExpressionStatement.h"
13 #include "ir/SkSLExtension.h"
14 #include "ir/SkSLIndexExpression.h"
15 #include "ir/SkSLVariableReference.h"
16 #include "SkSLCompiler.h"
17
18 namespace SkSL {
19
// Placeholder "magic number" for SkSL; it is zero because no number has been registered yet.
// NOTE(review): the use site is not in this part of the file -- presumably this value is written
// into the generated module's header; confirm before relying on that.
static const int32_t SKSL_MAGIC = 0x0; // FIXME: we should probably register a magic number
21
setupIntrinsics()22 void SPIRVCodeGenerator::setupIntrinsics() {
23 #define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \
24 GLSLstd450 ## x, GLSLstd450 ## x)
25 #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \
26 GLSLstd450 ## ifFloat, \
27 GLSLstd450 ## ifInt, \
28 GLSLstd450 ## ifUInt, \
29 SpvOpUndef)
30 #define ALL_SPIRV(x) std::make_tuple(kSPIRV_IntrinsicKind, SpvOp ## x, SpvOp ## x, SpvOp ## x, \
31 SpvOp ## x)
32 #define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \
33 k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
34 k ## x ## _SpecialIntrinsic)
35 fIntrinsicMap[String("round")] = ALL_GLSL(Round);
36 fIntrinsicMap[String("roundEven")] = ALL_GLSL(RoundEven);
37 fIntrinsicMap[String("trunc")] = ALL_GLSL(Trunc);
38 fIntrinsicMap[String("abs")] = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
39 fIntrinsicMap[String("sign")] = BY_TYPE_GLSL(FSign, SSign, SSign);
40 fIntrinsicMap[String("floor")] = ALL_GLSL(Floor);
41 fIntrinsicMap[String("ceil")] = ALL_GLSL(Ceil);
42 fIntrinsicMap[String("fract")] = ALL_GLSL(Fract);
43 fIntrinsicMap[String("radians")] = ALL_GLSL(Radians);
44 fIntrinsicMap[String("degrees")] = ALL_GLSL(Degrees);
45 fIntrinsicMap[String("sin")] = ALL_GLSL(Sin);
46 fIntrinsicMap[String("cos")] = ALL_GLSL(Cos);
47 fIntrinsicMap[String("tan")] = ALL_GLSL(Tan);
48 fIntrinsicMap[String("asin")] = ALL_GLSL(Asin);
49 fIntrinsicMap[String("acos")] = ALL_GLSL(Acos);
50 fIntrinsicMap[String("atan")] = SPECIAL(Atan);
51 fIntrinsicMap[String("sinh")] = ALL_GLSL(Sinh);
52 fIntrinsicMap[String("cosh")] = ALL_GLSL(Cosh);
53 fIntrinsicMap[String("tanh")] = ALL_GLSL(Tanh);
54 fIntrinsicMap[String("asinh")] = ALL_GLSL(Asinh);
55 fIntrinsicMap[String("acosh")] = ALL_GLSL(Acosh);
56 fIntrinsicMap[String("atanh")] = ALL_GLSL(Atanh);
57 fIntrinsicMap[String("pow")] = ALL_GLSL(Pow);
58 fIntrinsicMap[String("exp")] = ALL_GLSL(Exp);
59 fIntrinsicMap[String("log")] = ALL_GLSL(Log);
60 fIntrinsicMap[String("exp2")] = ALL_GLSL(Exp2);
61 fIntrinsicMap[String("log2")] = ALL_GLSL(Log2);
62 fIntrinsicMap[String("sqrt")] = ALL_GLSL(Sqrt);
63 fIntrinsicMap[String("inverse")] = ALL_GLSL(MatrixInverse);
64 fIntrinsicMap[String("transpose")] = ALL_SPIRV(Transpose);
65 fIntrinsicMap[String("inversesqrt")] = ALL_GLSL(InverseSqrt);
66 fIntrinsicMap[String("determinant")] = ALL_GLSL(Determinant);
67 fIntrinsicMap[String("matrixInverse")] = ALL_GLSL(MatrixInverse);
68 fIntrinsicMap[String("mod")] = SPECIAL(Mod);
69 fIntrinsicMap[String("min")] = SPECIAL(Min);
70 fIntrinsicMap[String("max")] = SPECIAL(Max);
71 fIntrinsicMap[String("clamp")] = SPECIAL(Clamp);
72 fIntrinsicMap[String("saturate")] = SPECIAL(Saturate);
73 fIntrinsicMap[String("dot")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot,
74 SpvOpUndef, SpvOpUndef, SpvOpUndef);
75 fIntrinsicMap[String("mix")] = SPECIAL(Mix);
76 fIntrinsicMap[String("step")] = ALL_GLSL(Step);
77 fIntrinsicMap[String("smoothstep")] = ALL_GLSL(SmoothStep);
78 fIntrinsicMap[String("fma")] = ALL_GLSL(Fma);
79 fIntrinsicMap[String("frexp")] = ALL_GLSL(Frexp);
80 fIntrinsicMap[String("ldexp")] = ALL_GLSL(Ldexp);
81
82 #define PACK(type) fIntrinsicMap[String("pack" #type)] = ALL_GLSL(Pack ## type); \
83 fIntrinsicMap[String("unpack" #type)] = ALL_GLSL(Unpack ## type)
84 PACK(Snorm4x8);
85 PACK(Unorm4x8);
86 PACK(Snorm2x16);
87 PACK(Unorm2x16);
88 PACK(Half2x16);
89 PACK(Double2x32);
90 fIntrinsicMap[String("length")] = ALL_GLSL(Length);
91 fIntrinsicMap[String("distance")] = ALL_GLSL(Distance);
92 fIntrinsicMap[String("cross")] = ALL_GLSL(Cross);
93 fIntrinsicMap[String("normalize")] = ALL_GLSL(Normalize);
94 fIntrinsicMap[String("faceForward")] = ALL_GLSL(FaceForward);
95 fIntrinsicMap[String("reflect")] = ALL_GLSL(Reflect);
96 fIntrinsicMap[String("refract")] = ALL_GLSL(Refract);
97 fIntrinsicMap[String("findLSB")] = ALL_GLSL(FindILsb);
98 fIntrinsicMap[String("findMSB")] = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
99 fIntrinsicMap[String("dFdx")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx,
100 SpvOpUndef, SpvOpUndef, SpvOpUndef);
101 fIntrinsicMap[String("dFdy")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
102 SpvOpUndef, SpvOpUndef, SpvOpUndef);
103 fIntrinsicMap[String("fwidth")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpFwidth,
104 SpvOpUndef, SpvOpUndef, SpvOpUndef);
105 fIntrinsicMap[String("texture")] = SPECIAL(Texture);
106 fIntrinsicMap[String("subpassLoad")] = SPECIAL(SubpassLoad);
107
108 fIntrinsicMap[String("any")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
109 SpvOpUndef, SpvOpUndef, SpvOpAny);
110 fIntrinsicMap[String("all")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
111 SpvOpUndef, SpvOpUndef, SpvOpAll);
112 fIntrinsicMap[String("equal")] = std::make_tuple(kSPIRV_IntrinsicKind,
113 SpvOpFOrdEqual, SpvOpIEqual,
114 SpvOpIEqual, SpvOpLogicalEqual);
115 fIntrinsicMap[String("notEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
116 SpvOpFOrdNotEqual, SpvOpINotEqual,
117 SpvOpINotEqual,
118 SpvOpLogicalNotEqual);
119 fIntrinsicMap[String("lessThan")] = std::make_tuple(kSPIRV_IntrinsicKind,
120 SpvOpFOrdLessThan, SpvOpSLessThan,
121 SpvOpULessThan, SpvOpUndef);
122 fIntrinsicMap[String("lessThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
123 SpvOpFOrdLessThanEqual,
124 SpvOpSLessThanEqual,
125 SpvOpULessThanEqual,
126 SpvOpUndef);
127 fIntrinsicMap[String("greaterThan")] = std::make_tuple(kSPIRV_IntrinsicKind,
128 SpvOpFOrdGreaterThan,
129 SpvOpSGreaterThan,
130 SpvOpUGreaterThan,
131 SpvOpUndef);
132 fIntrinsicMap[String("greaterThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
133 SpvOpFOrdGreaterThanEqual,
134 SpvOpSGreaterThanEqual,
135 SpvOpUGreaterThanEqual,
136 SpvOpUndef);
137 fIntrinsicMap[String("EmitVertex")] = ALL_SPIRV(EmitVertex);
138 fIntrinsicMap[String("EndPrimitive")] = ALL_SPIRV(EndPrimitive);
139 // interpolateAt* not yet supported...
140 }
141
writeWord(int32_t word,OutputStream & out)142 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
143 out.write((const char*) &word, sizeof(word));
144 }
145
is_float(const Context & context,const Type & type)146 static bool is_float(const Context& context, const Type& type) {
147 if (type.columns() > 1) {
148 return is_float(context, type.componentType());
149 }
150 return type == *context.fFloat_Type || type == *context.fHalf_Type ||
151 type == *context.fDouble_Type;
152 }
153
is_signed(const Context & context,const Type & type)154 static bool is_signed(const Context& context, const Type& type) {
155 if (type.kind() == Type::kVector_Kind) {
156 return is_signed(context, type.componentType());
157 }
158 return type == *context.fInt_Type || type == *context.fShort_Type ||
159 type == *context.fByte_Type;
160 }
161
is_unsigned(const Context & context,const Type & type)162 static bool is_unsigned(const Context& context, const Type& type) {
163 if (type.kind() == Type::kVector_Kind) {
164 return is_unsigned(context, type.componentType());
165 }
166 return type == *context.fUInt_Type || type == *context.fUShort_Type ||
167 type == *context.fUByte_Type;
168 }
169
is_bool(const Context & context,const Type & type)170 static bool is_bool(const Context& context, const Type& type) {
171 if (type.kind() == Type::kVector_Kind) {
172 return is_bool(context, type.componentType());
173 }
174 return type == *context.fBool_Type;
175 }
176
is_out(const Variable & var)177 static bool is_out(const Variable& var) {
178 return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
179 }
180
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)181 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
182 SkASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
183 SkASSERT(opCode != SpvOpUndef);
184 switch (opCode) {
185 case SpvOpReturn: // fall through
186 case SpvOpReturnValue: // fall through
187 case SpvOpKill: // fall through
188 case SpvOpBranch: // fall through
189 case SpvOpBranchConditional:
190 SkASSERT(fCurrentBlock);
191 fCurrentBlock = 0;
192 break;
193 case SpvOpConstant: // fall through
194 case SpvOpConstantTrue: // fall through
195 case SpvOpConstantFalse: // fall through
196 case SpvOpConstantComposite: // fall through
197 case SpvOpTypeVoid: // fall through
198 case SpvOpTypeInt: // fall through
199 case SpvOpTypeFloat: // fall through
200 case SpvOpTypeBool: // fall through
201 case SpvOpTypeVector: // fall through
202 case SpvOpTypeMatrix: // fall through
203 case SpvOpTypeArray: // fall through
204 case SpvOpTypePointer: // fall through
205 case SpvOpTypeFunction: // fall through
206 case SpvOpTypeRuntimeArray: // fall through
207 case SpvOpTypeStruct: // fall through
208 case SpvOpTypeImage: // fall through
209 case SpvOpTypeSampledImage: // fall through
210 case SpvOpVariable: // fall through
211 case SpvOpFunction: // fall through
212 case SpvOpFunctionParameter: // fall through
213 case SpvOpFunctionEnd: // fall through
214 case SpvOpExecutionMode: // fall through
215 case SpvOpMemoryModel: // fall through
216 case SpvOpCapability: // fall through
217 case SpvOpExtInstImport: // fall through
218 case SpvOpEntryPoint: // fall through
219 case SpvOpSource: // fall through
220 case SpvOpSourceExtension: // fall through
221 case SpvOpName: // fall through
222 case SpvOpMemberName: // fall through
223 case SpvOpDecorate: // fall through
224 case SpvOpMemberDecorate:
225 break;
226 default:
227 SkASSERT(fCurrentBlock);
228 }
229 this->writeWord((length << 16) | opCode, out);
230 }
231
writeLabel(SpvId label,OutputStream & out)232 void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) {
233 fCurrentBlock = label;
234 this->writeInstruction(SpvOpLabel, label, out);
235 }
236
writeInstruction(SpvOp_ opCode,OutputStream & out)237 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
238 this->writeOpCode(opCode, 1, out);
239 }
240
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)241 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
242 this->writeOpCode(opCode, 2, out);
243 this->writeWord(word1, out);
244 }
245
writeString(const char * string,size_t length,OutputStream & out)246 void SPIRVCodeGenerator::writeString(const char* string, size_t length, OutputStream& out) {
247 out.write(string, length);
248 switch (length % 4) {
249 case 1:
250 out.write8(0);
251 // fall through
252 case 2:
253 out.write8(0);
254 // fall through
255 case 3:
256 out.write8(0);
257 break;
258 default:
259 this->writeWord(0, out);
260 }
261 }
262
writeInstruction(SpvOp_ opCode,StringFragment string,OutputStream & out)263 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, StringFragment string, OutputStream& out) {
264 this->writeOpCode(opCode, 1 + (string.fLength + 4) / 4, out);
265 this->writeString(string.fChars, string.fLength, out);
266 }
267
268
writeInstruction(SpvOp_ opCode,int32_t word1,StringFragment string,OutputStream & out)269 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, StringFragment string,
270 OutputStream& out) {
271 this->writeOpCode(opCode, 2 + (string.fLength + 4) / 4, out);
272 this->writeWord(word1, out);
273 this->writeString(string.fChars, string.fLength, out);
274 }
275
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,StringFragment string,OutputStream & out)276 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
277 StringFragment string, OutputStream& out) {
278 this->writeOpCode(opCode, 3 + (string.fLength + 4) / 4, out);
279 this->writeWord(word1, out);
280 this->writeWord(word2, out);
281 this->writeString(string.fChars, string.fLength, out);
282 }
283
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)284 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
285 OutputStream& out) {
286 this->writeOpCode(opCode, 3, out);
287 this->writeWord(word1, out);
288 this->writeWord(word2, out);
289 }
290
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)291 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
292 int32_t word3, OutputStream& out) {
293 this->writeOpCode(opCode, 4, out);
294 this->writeWord(word1, out);
295 this->writeWord(word2, out);
296 this->writeWord(word3, out);
297 }
298
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)299 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
300 int32_t word3, int32_t word4, OutputStream& out) {
301 this->writeOpCode(opCode, 5, out);
302 this->writeWord(word1, out);
303 this->writeWord(word2, out);
304 this->writeWord(word3, out);
305 this->writeWord(word4, out);
306 }
307
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)308 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
309 int32_t word3, int32_t word4, int32_t word5,
310 OutputStream& out) {
311 this->writeOpCode(opCode, 6, out);
312 this->writeWord(word1, out);
313 this->writeWord(word2, out);
314 this->writeWord(word3, out);
315 this->writeWord(word4, out);
316 this->writeWord(word5, out);
317 }
318
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)319 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
320 int32_t word3, int32_t word4, int32_t word5,
321 int32_t word6, OutputStream& out) {
322 this->writeOpCode(opCode, 7, out);
323 this->writeWord(word1, out);
324 this->writeWord(word2, out);
325 this->writeWord(word3, out);
326 this->writeWord(word4, out);
327 this->writeWord(word5, out);
328 this->writeWord(word6, out);
329 }
330
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)331 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
332 int32_t word3, int32_t word4, int32_t word5,
333 int32_t word6, int32_t word7, OutputStream& out) {
334 this->writeOpCode(opCode, 8, out);
335 this->writeWord(word1, out);
336 this->writeWord(word2, out);
337 this->writeWord(word3, out);
338 this->writeWord(word4, out);
339 this->writeWord(word5, out);
340 this->writeWord(word6, out);
341 this->writeWord(word7, out);
342 }
343
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)344 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
345 int32_t word3, int32_t word4, int32_t word5,
346 int32_t word6, int32_t word7, int32_t word8,
347 OutputStream& out) {
348 this->writeOpCode(opCode, 9, out);
349 this->writeWord(word1, out);
350 this->writeWord(word2, out);
351 this->writeWord(word3, out);
352 this->writeWord(word4, out);
353 this->writeWord(word5, out);
354 this->writeWord(word6, out);
355 this->writeWord(word7, out);
356 this->writeWord(word8, out);
357 }
358
writeCapabilities(OutputStream & out)359 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
360 for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
361 if (fCapabilities & bit) {
362 this->writeInstruction(SpvOpCapability, (SpvId) i, out);
363 }
364 }
365 if (fProgram.fKind == Program::kGeometry_Kind) {
366 this->writeInstruction(SpvOpCapability, SpvCapabilityGeometry, out);
367 }
368 else {
369 this->writeInstruction(SpvOpCapability, SpvCapabilityShader, out);
370 }
371 }
372
nextId()373 SpvId SPIRVCodeGenerator::nextId() {
374 return fIdCount++;
375 }
376
writeStruct(const Type & type,const MemoryLayout & memoryLayout,SpvId resultId)377 void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout,
378 SpvId resultId) {
379 this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer);
380 // go ahead and write all of the field types, so we don't inadvertently write them while we're
381 // in the middle of writing the struct instruction
382 std::vector<SpvId> types;
383 for (const auto& f : type.fields()) {
384 types.push_back(this->getType(*f.fType, memoryLayout));
385 }
386 this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
387 this->writeWord(resultId, fConstantBuffer);
388 for (SpvId id : types) {
389 this->writeWord(id, fConstantBuffer);
390 }
391 size_t offset = 0;
392 for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
393 size_t size = memoryLayout.size(*type.fields()[i].fType);
394 size_t alignment = memoryLayout.alignment(*type.fields()[i].fType);
395 const Layout& fieldLayout = type.fields()[i].fModifiers.fLayout;
396 if (fieldLayout.fOffset >= 0) {
397 if (fieldLayout.fOffset < (int) offset) {
398 fErrors.error(type.fOffset,
399 "offset of field '" + type.fields()[i].fName + "' must be at "
400 "least " + to_string((int) offset));
401 }
402 if (fieldLayout.fOffset % alignment) {
403 fErrors.error(type.fOffset,
404 "offset of field '" + type.fields()[i].fName + "' must be a multiple"
405 " of " + to_string((int) alignment));
406 }
407 offset = fieldLayout.fOffset;
408 } else {
409 size_t mod = offset % alignment;
410 if (mod) {
411 offset += alignment - mod;
412 }
413 }
414 this->writeInstruction(SpvOpMemberName, resultId, i, type.fields()[i].fName, fNameBuffer);
415 this->writeLayout(fieldLayout, resultId, i);
416 if (type.fields()[i].fModifiers.fLayout.fBuiltin < 0) {
417 this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
418 (SpvId) offset, fDecorationBuffer);
419 }
420 if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) {
421 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
422 fDecorationBuffer);
423 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
424 (SpvId) memoryLayout.stride(*type.fields()[i].fType),
425 fDecorationBuffer);
426 }
427 offset += size;
428 Type::Kind kind = type.fields()[i].fType->kind();
429 if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) {
430 offset += alignment - offset % alignment;
431 }
432 }
433 }
434
// Maps an SkSL type to the type that is actually emitted for it: half becomes float, short/byte
// become int and ushort/ubyte become uint. Vectors and matrices of those component types are
// mapped to the matching compound type. Any other type is returned unchanged.
Type SPIRVCodeGenerator::getActualType(const Type& type) {
    // returns the replacement for a scalar type, or nullptr if the scalar is kept as is
    auto replacementFor = [this](const Type& scalar) -> const Type* {
        if (scalar == *fContext.fHalf_Type) {
            return &*fContext.fFloat_Type;
        }
        if (scalar == *fContext.fShort_Type || scalar == *fContext.fByte_Type) {
            return &*fContext.fInt_Type;
        }
        if (scalar == *fContext.fUShort_Type || scalar == *fContext.fUByte_Type) {
            return &*fContext.fUInt_Type;
        }
        return nullptr;
    };

    if (const Type* scalar = replacementFor(type)) {
        return *scalar;
    }
    const Type::Kind kind = type.kind();
    if (kind == Type::kMatrix_Kind || kind == Type::kVector_Kind) {
        if (const Type* component = replacementFor(type.componentType())) {
            return component->toCompound(fContext, type.columns(), type.rows());
        }
    }
    return type;
}
460
getType(const Type & type)461 SpvId SPIRVCodeGenerator::getType(const Type& type) {
462 return this->getType(type, fDefaultLayout);
463 }
464
getType(const Type & rawType,const MemoryLayout & layout)465 SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layout) {
466 Type type = this->getActualType(rawType);
467 String key = type.name() + to_string((int) layout.fStd);
468 auto entry = fTypeMap.find(key);
469 if (entry == fTypeMap.end()) {
470 SpvId result = this->nextId();
471 switch (type.kind()) {
472 case Type::kScalar_Kind:
473 if (type == *fContext.fBool_Type) {
474 this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
475 } else if (type == *fContext.fInt_Type) {
476 this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
477 } else if (type == *fContext.fUInt_Type) {
478 this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
479 } else if (type == *fContext.fFloat_Type) {
480 this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
481 } else if (type == *fContext.fDouble_Type) {
482 this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer);
483 } else {
484 SkASSERT(false);
485 }
486 break;
487 case Type::kVector_Kind:
488 this->writeInstruction(SpvOpTypeVector, result,
489 this->getType(type.componentType(), layout),
490 type.columns(), fConstantBuffer);
491 break;
492 case Type::kMatrix_Kind:
493 this->writeInstruction(SpvOpTypeMatrix, result,
494 this->getType(index_type(fContext, type), layout),
495 type.columns(), fConstantBuffer);
496 break;
497 case Type::kStruct_Kind:
498 this->writeStruct(type, layout, result);
499 break;
500 case Type::kArray_Kind: {
501 if (type.columns() > 0) {
502 IntLiteral count(fContext, -1, type.columns());
503 this->writeInstruction(SpvOpTypeArray, result,
504 this->getType(type.componentType(), layout),
505 this->writeIntLiteral(count), fConstantBuffer);
506 this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
507 (int32_t) layout.stride(type),
508 fDecorationBuffer);
509 } else {
510 SkASSERT(false); // we shouldn't have any runtime-sized arrays right now
511 this->writeInstruction(SpvOpTypeRuntimeArray, result,
512 this->getType(type.componentType(), layout),
513 fConstantBuffer);
514 this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
515 (int32_t) layout.stride(type),
516 fDecorationBuffer);
517 }
518 break;
519 }
520 case Type::kSampler_Kind: {
521 SpvId image = result;
522 if (SpvDimSubpassData != type.dimensions()) {
523 image = this->nextId();
524 }
525 if (SpvDimBuffer == type.dimensions()) {
526 fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer);
527 }
528 this->writeInstruction(SpvOpTypeImage, image,
529 this->getType(*fContext.fFloat_Type, layout),
530 type.dimensions(), type.isDepth(), type.isArrayed(),
531 type.isMultisampled(), type.isSampled() ? 1 : 2,
532 SpvImageFormatUnknown, fConstantBuffer);
533 fImageTypeMap[key] = image;
534 if (SpvDimSubpassData != type.dimensions()) {
535 this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
536 }
537 break;
538 }
539 default:
540 if (type == *fContext.fVoid_Type) {
541 this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
542 } else {
543 ABORT("invalid type: %s", type.description().c_str());
544 }
545 }
546 fTypeMap[key] = result;
547 return result;
548 }
549 return entry->second;
550 }
551
getImageType(const Type & type)552 SpvId SPIRVCodeGenerator::getImageType(const Type& type) {
553 SkASSERT(type.kind() == Type::kSampler_Kind);
554 this->getType(type);
555 String key = type.name() + to_string((int) fDefaultLayout.fStd);
556 SkASSERT(fImageTypeMap.find(key) != fImageTypeMap.end());
557 return fImageTypeMap[key];
558 }
559
getFunctionType(const FunctionDeclaration & function)560 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
561 String key = function.fReturnType.description() + "(";
562 String separator;
563 for (size_t i = 0; i < function.fParameters.size(); i++) {
564 key += separator;
565 separator = ", ";
566 key += function.fParameters[i]->fType.description();
567 }
568 key += ")";
569 auto entry = fTypeMap.find(key);
570 if (entry == fTypeMap.end()) {
571 SpvId result = this->nextId();
572 int32_t length = 3 + (int32_t) function.fParameters.size();
573 SpvId returnType = this->getType(function.fReturnType);
574 std::vector<SpvId> parameterTypes;
575 for (size_t i = 0; i < function.fParameters.size(); i++) {
576 // glslang seems to treat all function arguments as pointers whether they need to be or
577 // not. I was initially puzzled by this until I ran bizarre failures with certain
578 // patterns of function calls and control constructs, as exemplified by this minimal
579 // failure case:
580 //
581 // void sphere(float x) {
582 // }
583 //
584 // void map() {
585 // sphere(1.0);
586 // }
587 //
588 // void main() {
589 // for (int i = 0; i < 1; i++) {
590 // map();
591 // }
592 // }
593 //
594 // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
595 // crashes. Making it take a float* and storing the argument in a temporary variable,
596 // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
597 // the spec makes this make sense.
598 // if (is_out(function->fParameters[i])) {
599 parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType,
600 SpvStorageClassFunction));
601 // } else {
602 // parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
603 // }
604 }
605 this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
606 this->writeWord(result, fConstantBuffer);
607 this->writeWord(returnType, fConstantBuffer);
608 for (SpvId id : parameterTypes) {
609 this->writeWord(id, fConstantBuffer);
610 }
611 fTypeMap[key] = result;
612 return result;
613 }
614 return entry->second;
615 }
616
getPointerType(const Type & type,SpvStorageClass_ storageClass)617 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
618 return this->getPointerType(type, fDefaultLayout, storageClass);
619 }
620
getPointerType(const Type & rawType,const MemoryLayout & layout,SpvStorageClass_ storageClass)621 SpvId SPIRVCodeGenerator::getPointerType(const Type& rawType, const MemoryLayout& layout,
622 SpvStorageClass_ storageClass) {
623 Type type = this->getActualType(rawType);
624 String key = type.description() + "*" + to_string(layout.fStd) + to_string(storageClass);
625 auto entry = fTypeMap.find(key);
626 if (entry == fTypeMap.end()) {
627 SpvId result = this->nextId();
628 this->writeInstruction(SpvOpTypePointer, result, storageClass,
629 this->getType(type), fConstantBuffer);
630 fTypeMap[key] = result;
631 return result;
632 }
633 return entry->second;
634 }
635
writeExpression(const Expression & expr,OutputStream & out)636 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
637 switch (expr.fKind) {
638 case Expression::kBinary_Kind:
639 return this->writeBinaryExpression((BinaryExpression&) expr, out);
640 case Expression::kBoolLiteral_Kind:
641 return this->writeBoolLiteral((BoolLiteral&) expr);
642 case Expression::kConstructor_Kind:
643 return this->writeConstructor((Constructor&) expr, out);
644 case Expression::kIntLiteral_Kind:
645 return this->writeIntLiteral((IntLiteral&) expr);
646 case Expression::kFieldAccess_Kind:
647 return this->writeFieldAccess(((FieldAccess&) expr), out);
648 case Expression::kFloatLiteral_Kind:
649 return this->writeFloatLiteral(((FloatLiteral&) expr));
650 case Expression::kFunctionCall_Kind:
651 return this->writeFunctionCall((FunctionCall&) expr, out);
652 case Expression::kPrefix_Kind:
653 return this->writePrefixExpression((PrefixExpression&) expr, out);
654 case Expression::kPostfix_Kind:
655 return this->writePostfixExpression((PostfixExpression&) expr, out);
656 case Expression::kSwizzle_Kind:
657 return this->writeSwizzle((Swizzle&) expr, out);
658 case Expression::kVariableReference_Kind:
659 return this->writeVariableReference((VariableReference&) expr, out);
660 case Expression::kTernary_Kind:
661 return this->writeTernaryExpression((TernaryExpression&) expr, out);
662 case Expression::kIndex_Kind:
663 return this->writeIndexExpression((IndexExpression&) expr, out);
664 default:
665 ABORT("unsupported expression: %s", expr.description().c_str());
666 }
667 return -1;
668 }
669
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)670 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
671 auto intrinsic = fIntrinsicMap.find(c.fFunction.fName);
672 SkASSERT(intrinsic != fIntrinsicMap.end());
673 int32_t intrinsicId;
674 if (c.fArguments.size() > 0) {
675 const Type& type = c.fArguments[0]->fType;
676 if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) {
677 intrinsicId = std::get<1>(intrinsic->second);
678 } else if (is_signed(fContext, type)) {
679 intrinsicId = std::get<2>(intrinsic->second);
680 } else if (is_unsigned(fContext, type)) {
681 intrinsicId = std::get<3>(intrinsic->second);
682 } else if (is_bool(fContext, type)) {
683 intrinsicId = std::get<4>(intrinsic->second);
684 } else {
685 intrinsicId = std::get<1>(intrinsic->second);
686 }
687 } else {
688 intrinsicId = std::get<1>(intrinsic->second);
689 }
690 switch (std::get<0>(intrinsic->second)) {
691 case kGLSL_STD_450_IntrinsicKind: {
692 SpvId result = this->nextId();
693 std::vector<SpvId> arguments;
694 for (size_t i = 0; i < c.fArguments.size(); i++) {
695 if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
696 arguments.push_back(this->getLValue(*c.fArguments[i], out)->getPointer());
697 } else {
698 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
699 }
700 }
701 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
702 this->writeWord(this->getType(c.fType), out);
703 this->writeWord(result, out);
704 this->writeWord(fGLSLExtendedInstructions, out);
705 this->writeWord(intrinsicId, out);
706 for (SpvId id : arguments) {
707 this->writeWord(id, out);
708 }
709 return result;
710 }
711 case kSPIRV_IntrinsicKind: {
712 SpvId result = this->nextId();
713 std::vector<SpvId> arguments;
714 for (size_t i = 0; i < c.fArguments.size(); i++) {
715 if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
716 arguments.push_back(this->getLValue(*c.fArguments[i], out)->getPointer());
717 } else {
718 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
719 }
720 }
721 if (c.fType != *fContext.fVoid_Type) {
722 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
723 this->writeWord(this->getType(c.fType), out);
724 this->writeWord(result, out);
725 } else {
726 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
727 }
728 for (SpvId id : arguments) {
729 this->writeWord(id, out);
730 }
731 return result;
732 }
733 case kSpecial_IntrinsicKind:
734 return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
735 default:
736 ABORT("unsupported intrinsic kind");
737 }
738 }
739
vectorize(const std::vector<std::unique_ptr<Expression>> & args,OutputStream & out)740 std::vector<SpvId> SPIRVCodeGenerator::vectorize(
741 const std::vector<std::unique_ptr<Expression>>& args,
742 OutputStream& out) {
743 int vectorSize = 0;
744 for (const auto& a : args) {
745 if (a->fType.kind() == Type::kVector_Kind) {
746 if (vectorSize) {
747 SkASSERT(a->fType.columns() == vectorSize);
748 }
749 else {
750 vectorSize = a->fType.columns();
751 }
752 }
753 }
754 std::vector<SpvId> result;
755 for (const auto& a : args) {
756 SpvId raw = this->writeExpression(*a, out);
757 if (vectorSize && a->fType.kind() == Type::kScalar_Kind) {
758 SpvId vector = this->nextId();
759 this->writeOpCode(SpvOpCompositeConstruct, 3 + vectorSize, out);
760 this->writeWord(this->getType(a->fType.toCompound(fContext, vectorSize, 1)), out);
761 this->writeWord(vector, out);
762 for (int i = 0; i < vectorSize; i++) {
763 this->writeWord(raw, out);
764 }
765 result.push_back(vector);
766 } else {
767 result.push_back(raw);
768 }
769 }
770 return result;
771 }
772
writeGLSLExtendedInstruction(const Type & type,SpvId id,SpvId floatInst,SpvId signedInst,SpvId unsignedInst,const std::vector<SpvId> & args,OutputStream & out)773 void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
774 SpvId signedInst, SpvId unsignedInst,
775 const std::vector<SpvId>& args,
776 OutputStream& out) {
777 this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
778 this->writeWord(this->getType(type), out);
779 this->writeWord(id, out);
780 this->writeWord(fGLSLExtendedInstructions, out);
781
782 if (is_float(fContext, type)) {
783 this->writeWord(floatInst, out);
784 } else if (is_signed(fContext, type)) {
785 this->writeWord(signedInst, out);
786 } else if (is_unsigned(fContext, type)) {
787 this->writeWord(unsignedInst, out);
788 } else {
789 SkASSERT(false);
790 }
791 for (SpvId a : args) {
792 this->writeWord(a, out);
793 }
794 }
795
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)796 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
797 OutputStream& out) {
798 SpvId result = this->nextId();
799 switch (kind) {
800 case kAtan_SpecialIntrinsic: {
801 std::vector<SpvId> arguments;
802 for (size_t i = 0; i < c.fArguments.size(); i++) {
803 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
804 }
805 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
806 this->writeWord(this->getType(c.fType), out);
807 this->writeWord(result, out);
808 this->writeWord(fGLSLExtendedInstructions, out);
809 this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
810 for (SpvId id : arguments) {
811 this->writeWord(id, out);
812 }
813 break;
814 }
815 case kSubpassLoad_SpecialIntrinsic: {
816 SpvId img = this->writeExpression(*c.fArguments[0], out);
817 std::vector<std::unique_ptr<Expression>> args;
818 args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
819 args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
820 Constructor ctor(-1, *fContext.fFloat2_Type, std::move(args));
821 SpvId coords = this->writeConstantVector(ctor);
822 if (1 == c.fArguments.size()) {
823 this->writeInstruction(SpvOpImageRead,
824 this->getType(c.fType),
825 result,
826 img,
827 coords,
828 out);
829 } else {
830 SkASSERT(2 == c.fArguments.size());
831 SpvId sample = this->writeExpression(*c.fArguments[1], out);
832 this->writeInstruction(SpvOpImageRead,
833 this->getType(c.fType),
834 result,
835 img,
836 coords,
837 SpvImageOperandsSampleMask,
838 sample,
839 out);
840 }
841 break;
842 }
843 case kTexture_SpecialIntrinsic: {
844 SpvOp_ op = SpvOpImageSampleImplicitLod;
845 switch (c.fArguments[0]->fType.dimensions()) {
846 case SpvDim1D:
847 if (c.fArguments[1]->fType == *fContext.fFloat2_Type) {
848 op = SpvOpImageSampleProjImplicitLod;
849 } else {
850 SkASSERT(c.fArguments[1]->fType == *fContext.fFloat_Type);
851 }
852 break;
853 case SpvDim2D:
854 if (c.fArguments[1]->fType == *fContext.fFloat3_Type) {
855 op = SpvOpImageSampleProjImplicitLod;
856 } else {
857 SkASSERT(c.fArguments[1]->fType == *fContext.fFloat2_Type);
858 }
859 break;
860 case SpvDim3D:
861 if (c.fArguments[1]->fType == *fContext.fFloat4_Type) {
862 op = SpvOpImageSampleProjImplicitLod;
863 } else {
864 SkASSERT(c.fArguments[1]->fType == *fContext.fFloat3_Type);
865 }
866 break;
867 case SpvDimCube: // fall through
868 case SpvDimRect: // fall through
869 case SpvDimBuffer: // fall through
870 case SpvDimSubpassData:
871 break;
872 }
873 SpvId type = this->getType(c.fType);
874 SpvId sampler = this->writeExpression(*c.fArguments[0], out);
875 SpvId uv = this->writeExpression(*c.fArguments[1], out);
876 if (c.fArguments.size() == 3) {
877 this->writeInstruction(op, type, result, sampler, uv,
878 SpvImageOperandsBiasMask,
879 this->writeExpression(*c.fArguments[2], out),
880 out);
881 } else {
882 SkASSERT(c.fArguments.size() == 2);
883 if (fProgram.fSettings.fSharpenTextures) {
884 FloatLiteral lodBias(fContext, -1, -0.5);
885 this->writeInstruction(op, type, result, sampler, uv,
886 SpvImageOperandsBiasMask,
887 this->writeFloatLiteral(lodBias),
888 out);
889 } else {
890 this->writeInstruction(op, type, result, sampler, uv,
891 out);
892 }
893 }
894 break;
895 }
896 case kMod_SpecialIntrinsic: {
897 std::vector<SpvId> args = this->vectorize(c.fArguments, out);
898 SkASSERT(args.size() == 2);
899 const Type& operandType = c.fArguments[0]->fType;
900 SpvOp_ op;
901 if (is_float(fContext, operandType)) {
902 op = SpvOpFMod;
903 } else if (is_signed(fContext, operandType)) {
904 op = SpvOpSMod;
905 } else if (is_unsigned(fContext, operandType)) {
906 op = SpvOpUMod;
907 } else {
908 SkASSERT(false);
909 return 0;
910 }
911 this->writeOpCode(op, 5, out);
912 this->writeWord(this->getType(operandType), out);
913 this->writeWord(result, out);
914 this->writeWord(args[0], out);
915 this->writeWord(args[1], out);
916 break;
917 }
918 case kClamp_SpecialIntrinsic: {
919 std::vector<SpvId> args = this->vectorize(c.fArguments, out);
920 SkASSERT(args.size() == 3);
921 this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FClamp, GLSLstd450SClamp,
922 GLSLstd450UClamp, args, out);
923 break;
924 }
925 case kMax_SpecialIntrinsic: {
926 std::vector<SpvId> args = this->vectorize(c.fArguments, out);
927 SkASSERT(args.size() == 2);
928 this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMax, GLSLstd450SMax,
929 GLSLstd450UMax, args, out);
930 break;
931 }
932 case kMin_SpecialIntrinsic: {
933 std::vector<SpvId> args = this->vectorize(c.fArguments, out);
934 SkASSERT(args.size() == 2);
935 this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMin, GLSLstd450SMin,
936 GLSLstd450UMin, args, out);
937 break;
938 }
939 case kMix_SpecialIntrinsic: {
940 std::vector<SpvId> args = this->vectorize(c.fArguments, out);
941 SkASSERT(args.size() == 3);
942 this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMix, SpvOpUndef,
943 SpvOpUndef, args, out);
944 break;
945 }
946 case kSaturate_SpecialIntrinsic: {
947 SkASSERT(c.fArguments.size() == 1);
948 std::vector<std::unique_ptr<Expression>> finalArgs;
949 finalArgs.push_back(c.fArguments[0]->clone());
950 finalArgs.emplace_back(new FloatLiteral(fContext, -1, 0));
951 finalArgs.emplace_back(new FloatLiteral(fContext, -1, 1));
952 std::vector<SpvId> spvArgs = this->vectorize(finalArgs, out);
953 this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FClamp, GLSLstd450SClamp,
954 GLSLstd450UClamp, spvArgs, out);
955 break;
956 }
957 }
958 return result;
959 }
960
writeFunctionCall(const FunctionCall & c,OutputStream & out)961 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
962 const auto& entry = fFunctionMap.find(&c.fFunction);
963 if (entry == fFunctionMap.end()) {
964 return this->writeIntrinsicCall(c, out);
965 }
966 // stores (variable, type, lvalue) pairs to extract and save after the function call is complete
967 std::vector<std::tuple<SpvId, SpvId, std::unique_ptr<LValue>>> lvalues;
968 std::vector<SpvId> arguments;
969 for (size_t i = 0; i < c.fArguments.size(); i++) {
970 // id of temporary variable that we will use to hold this argument, or 0 if it is being
971 // passed directly
972 SpvId tmpVar;
973 // if we need a temporary var to store this argument, this is the value to store in the var
974 SpvId tmpValueId;
975 if (is_out(*c.fFunction.fParameters[i])) {
976 std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out);
977 SpvId ptr = lv->getPointer();
978 if (ptr) {
979 arguments.push_back(ptr);
980 continue;
981 } else {
982 // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to
983 // copy it into a temp, call the function, read the value out of the temp, and then
984 // update the lvalue.
985 tmpValueId = lv->load(out);
986 tmpVar = this->nextId();
987 lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType),
988 std::move(lv)));
989 }
990 } else {
991 // see getFunctionType for an explanation of why we're always using pointer parameters
992 tmpValueId = this->writeExpression(*c.fArguments[i], out);
993 tmpVar = this->nextId();
994 }
995 this->writeInstruction(SpvOpVariable,
996 this->getPointerType(c.fArguments[i]->fType,
997 SpvStorageClassFunction),
998 tmpVar,
999 SpvStorageClassFunction,
1000 fVariableBuffer);
1001 this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
1002 arguments.push_back(tmpVar);
1003 }
1004 SpvId result = this->nextId();
1005 this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out);
1006 this->writeWord(this->getType(c.fType), out);
1007 this->writeWord(result, out);
1008 this->writeWord(entry->second, out);
1009 for (SpvId id : arguments) {
1010 this->writeWord(id, out);
1011 }
1012 // now that the call is complete, we may need to update some lvalues with the new values of out
1013 // arguments
1014 for (const auto& tuple : lvalues) {
1015 SpvId load = this->nextId();
1016 this->writeInstruction(SpvOpLoad, std::get<1>(tuple), load, std::get<0>(tuple), out);
1017 std::get<2>(tuple)->store(load, out);
1018 }
1019 return result;
1020 }
1021
writeConstantVector(const Constructor & c)1022 SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) {
1023 SkASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant());
1024 SpvId result = this->nextId();
1025 std::vector<SpvId> arguments;
1026 for (size_t i = 0; i < c.fArguments.size(); i++) {
1027 arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer));
1028 }
1029 SpvId type = this->getType(c.fType);
1030 if (c.fArguments.size() == 1) {
1031 // with a single argument, a vector will have all of its entries equal to the argument
1032 this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer);
1033 this->writeWord(type, fConstantBuffer);
1034 this->writeWord(result, fConstantBuffer);
1035 for (int i = 0; i < c.fType.columns(); i++) {
1036 this->writeWord(arguments[0], fConstantBuffer);
1037 }
1038 } else {
1039 this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(),
1040 fConstantBuffer);
1041 this->writeWord(type, fConstantBuffer);
1042 this->writeWord(result, fConstantBuffer);
1043 for (SpvId id : arguments) {
1044 this->writeWord(id, fConstantBuffer);
1045 }
1046 }
1047 return result;
1048 }
1049
writeFloatConstructor(const Constructor & c,OutputStream & out)1050 SpvId SPIRVCodeGenerator::writeFloatConstructor(const Constructor& c, OutputStream& out) {
1051 SkASSERT(c.fType.isFloat());
1052 SkASSERT(c.fArguments.size() == 1);
1053 SkASSERT(c.fArguments[0]->fType.isNumber());
1054 SpvId result = this->nextId();
1055 SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1056 if (c.fArguments[0]->fType.isSigned()) {
1057 this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter,
1058 out);
1059 } else {
1060 SkASSERT(c.fArguments[0]->fType.isUnsigned());
1061 this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter,
1062 out);
1063 }
1064 return result;
1065 }
1066
writeIntConstructor(const Constructor & c,OutputStream & out)1067 SpvId SPIRVCodeGenerator::writeIntConstructor(const Constructor& c, OutputStream& out) {
1068 SkASSERT(c.fType.isSigned());
1069 SkASSERT(c.fArguments.size() == 1);
1070 SkASSERT(c.fArguments[0]->fType.isNumber());
1071 SpvId result = this->nextId();
1072 SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1073 if (c.fArguments[0]->fType.isFloat()) {
1074 this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter,
1075 out);
1076 }
1077 else {
1078 SkASSERT(c.fArguments[0]->fType.isUnsigned());
1079 this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1080 out);
1081 }
1082 return result;
1083 }
1084
writeUIntConstructor(const Constructor & c,OutputStream & out)1085 SpvId SPIRVCodeGenerator::writeUIntConstructor(const Constructor& c, OutputStream& out) {
1086 SkASSERT(c.fType.isUnsigned());
1087 SkASSERT(c.fArguments.size() == 1);
1088 SkASSERT(c.fArguments[0]->fType.isNumber());
1089 SpvId result = this->nextId();
1090 SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1091 if (c.fArguments[0]->fType.isFloat()) {
1092 this->writeInstruction(SpvOpConvertFToU, this->getType(c.fType), result, parameter,
1093 out);
1094 } else {
1095 SkASSERT(c.fArguments[0]->fType.isSigned());
1096 this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1097 out);
1098 }
1099 return result;
1100 }
1101
writeUniformScaleMatrix(SpvId id,SpvId diagonal,const Type & type,OutputStream & out)1102 void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
1103 OutputStream& out) {
1104 FloatLiteral zero(fContext, -1, 0);
1105 SpvId zeroId = this->writeFloatLiteral(zero);
1106 std::vector<SpvId> columnIds;
1107 for (int column = 0; column < type.columns(); column++) {
1108 this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
1109 out);
1110 this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)),
1111 out);
1112 SpvId columnId = this->nextId();
1113 this->writeWord(columnId, out);
1114 columnIds.push_back(columnId);
1115 for (int row = 0; row < type.columns(); row++) {
1116 this->writeWord(row == column ? diagonal : zeroId, out);
1117 }
1118 }
1119 this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
1120 out);
1121 this->writeWord(this->getType(type), out);
1122 this->writeWord(id, out);
1123 for (SpvId id : columnIds) {
1124 this->writeWord(id, out);
1125 }
1126 }
1127
writeMatrixCopy(SpvId id,SpvId src,const Type & srcType,const Type & dstType,OutputStream & out)1128 void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType,
1129 const Type& dstType, OutputStream& out) {
1130 SkASSERT(srcType.kind() == Type::kMatrix_Kind);
1131 SkASSERT(dstType.kind() == Type::kMatrix_Kind);
1132 SkASSERT(srcType.componentType() == dstType.componentType());
1133 SpvId srcColumnType = this->getType(srcType.componentType().toCompound(fContext,
1134 srcType.rows(),
1135 1));
1136 SpvId dstColumnType = this->getType(dstType.componentType().toCompound(fContext,
1137 dstType.rows(),
1138 1));
1139 SpvId zeroId;
1140 if (dstType.componentType() == *fContext.fFloat_Type) {
1141 FloatLiteral zero(fContext, -1, 0.0);
1142 zeroId = this->writeFloatLiteral(zero);
1143 } else if (dstType.componentType() == *fContext.fInt_Type) {
1144 IntLiteral zero(fContext, -1, 0);
1145 zeroId = this->writeIntLiteral(zero);
1146 } else {
1147 ABORT("unsupported matrix component type");
1148 }
1149 SpvId zeroColumn = 0;
1150 SpvId columns[4];
1151 for (int i = 0; i < dstType.columns(); i++) {
1152 if (i < srcType.columns()) {
1153 // we're still inside the src matrix, copy the column
1154 SpvId srcColumn = this->nextId();
1155 this->writeInstruction(SpvOpCompositeExtract, srcColumnType, srcColumn, src, i, out);
1156 SpvId dstColumn;
1157 if (srcType.rows() == dstType.rows()) {
1158 // columns are equal size, don't need to do anything
1159 dstColumn = srcColumn;
1160 }
1161 else if (dstType.rows() > srcType.rows()) {
1162 // dst column is bigger, need to zero-pad it
1163 dstColumn = this->nextId();
1164 int delta = dstType.rows() - srcType.rows();
1165 this->writeOpCode(SpvOpCompositeConstruct, 4 + delta, out);
1166 this->writeWord(dstColumnType, out);
1167 this->writeWord(dstColumn, out);
1168 this->writeWord(srcColumn, out);
1169 for (int i = 0; i < delta; ++i) {
1170 this->writeWord(zeroId, out);
1171 }
1172 }
1173 else {
1174 // dst column is smaller, need to swizzle the src column
1175 dstColumn = this->nextId();
1176 int count = dstType.rows();
1177 this->writeOpCode(SpvOpVectorShuffle, 5 + count, out);
1178 this->writeWord(dstColumnType, out);
1179 this->writeWord(dstColumn, out);
1180 this->writeWord(srcColumn, out);
1181 this->writeWord(srcColumn, out);
1182 for (int i = 0; i < count; i++) {
1183 this->writeWord(i, out);
1184 }
1185 }
1186 columns[i] = dstColumn;
1187 } else {
1188 // we're past the end of the src matrix, need a vector of zeroes
1189 if (!zeroColumn) {
1190 zeroColumn = this->nextId();
1191 this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.rows(), out);
1192 this->writeWord(dstColumnType, out);
1193 this->writeWord(zeroColumn, out);
1194 for (int i = 0; i < dstType.rows(); ++i) {
1195 this->writeWord(zeroId, out);
1196 }
1197 }
1198 columns[i] = zeroColumn;
1199 }
1200 }
1201 this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.columns(), out);
1202 this->writeWord(this->getType(dstType), out);
1203 this->writeWord(id, out);
1204 for (int i = 0; i < dstType.columns(); i++) {
1205 this->writeWord(columns[i], out);
1206 }
1207 }
1208
writeMatrixConstructor(const Constructor & c,OutputStream & out)1209 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, OutputStream& out) {
1210 SkASSERT(c.fType.kind() == Type::kMatrix_Kind);
1211 // go ahead and write the arguments so we don't try to write new instructions in the middle of
1212 // an instruction
1213 std::vector<SpvId> arguments;
1214 for (size_t i = 0; i < c.fArguments.size(); i++) {
1215 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1216 }
1217 SpvId result = this->nextId();
1218 int rows = c.fType.rows();
1219 int columns = c.fType.columns();
1220 if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1221 this->writeUniformScaleMatrix(result, arguments[0], c.fType, out);
1222 } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) {
1223 this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out);
1224 } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kVector_Kind) {
1225 SkASSERT(c.fType.rows() == 2 && c.fType.columns() == 2);
1226 SkASSERT(c.fArguments[0]->fType.columns() == 4);
1227 SpvId componentType = this->getType(c.fType.componentType());
1228 SpvId v[4];
1229 for (int i = 0; i < 4; ++i) {
1230 v[i] = this->nextId();
1231 this->writeInstruction(SpvOpCompositeExtract, componentType, v[i], arguments[0], i, out);
1232 }
1233 SpvId columnType = this->getType(c.fType.componentType().toCompound(fContext, 2, 1));
1234 SpvId column1 = this->nextId();
1235 this->writeInstruction(SpvOpCompositeConstruct, columnType, column1, v[0], v[1], out);
1236 SpvId column2 = this->nextId();
1237 this->writeInstruction(SpvOpCompositeConstruct, columnType, column2, v[2], v[3], out);
1238 this->writeInstruction(SpvOpCompositeConstruct, this->getType(c.fType), result, column1,
1239 column2, out);
1240 } else {
1241 std::vector<SpvId> columnIds;
1242 // ids of vectors and scalars we have written to the current column so far
1243 std::vector<SpvId> currentColumn;
1244 // the total number of scalars represented by currentColumn's entries
1245 int currentCount = 0;
1246 for (size_t i = 0; i < arguments.size(); i++) {
1247 if (c.fArguments[i]->fType.kind() == Type::kVector_Kind &&
1248 c.fArguments[i]->fType.columns() == c.fType.rows()) {
1249 // this is a complete column by itself
1250 SkASSERT(currentCount == 0);
1251 columnIds.push_back(arguments[i]);
1252 } else {
1253 if (c.fArguments[i]->fType.columns() == 1) {
1254 currentColumn.push_back(arguments[i]);
1255 } else {
1256 SpvId componentType = this->getType(c.fArguments[i]->fType.componentType());
1257 for (int j = 0; j < c.fArguments[i]->fType.columns(); ++j) {
1258 SpvId swizzle = this->nextId();
1259 this->writeInstruction(SpvOpCompositeExtract, componentType, swizzle,
1260 arguments[i], j, out);
1261 currentColumn.push_back(swizzle);
1262 }
1263 }
1264 currentCount += c.fArguments[i]->fType.columns();
1265 if (currentCount == rows) {
1266 currentCount = 0;
1267 this->writeOpCode(SpvOpCompositeConstruct, 3 + currentColumn.size(), out);
1268 this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows,
1269 1)),
1270 out);
1271 SpvId columnId = this->nextId();
1272 this->writeWord(columnId, out);
1273 columnIds.push_back(columnId);
1274 for (SpvId id : currentColumn) {
1275 this->writeWord(id, out);
1276 }
1277 currentColumn.clear();
1278 }
1279 SkASSERT(currentCount < rows);
1280 }
1281 }
1282 SkASSERT(columnIds.size() == (size_t) columns);
1283 this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
1284 this->writeWord(this->getType(c.fType), out);
1285 this->writeWord(result, out);
1286 for (SpvId id : columnIds) {
1287 this->writeWord(id, out);
1288 }
1289 }
1290 return result;
1291 }
1292
writeVectorConstructor(const Constructor & c,OutputStream & out)1293 SpvId SPIRVCodeGenerator::writeVectorConstructor(const Constructor& c, OutputStream& out) {
1294 SkASSERT(c.fType.kind() == Type::kVector_Kind);
1295 if (c.isConstant()) {
1296 return this->writeConstantVector(c);
1297 }
1298 // go ahead and write the arguments so we don't try to write new instructions in the middle of
1299 // an instruction
1300 std::vector<SpvId> arguments;
1301 for (size_t i = 0; i < c.fArguments.size(); i++) {
1302 if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) {
1303 // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to
1304 // extract the components and convert them in that case manually. On top of that,
1305 // as of this writing there's a bug in the Intel Vulkan driver where OpCreateComposite
1306 // doesn't handle vector arguments at all, so we always extract vector components and
1307 // pass them into OpCreateComposite individually.
1308 SpvId vec = this->writeExpression(*c.fArguments[i], out);
1309 SpvOp_ op = SpvOpUndef;
1310 const Type& src = c.fArguments[i]->fType.componentType();
1311 const Type& dst = c.fType.componentType();
1312 if (dst == *fContext.fFloat_Type || dst == *fContext.fHalf_Type) {
1313 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1314 if (c.fArguments.size() == 1) {
1315 return vec;
1316 }
1317 } else if (src == *fContext.fInt_Type ||
1318 src == *fContext.fShort_Type ||
1319 src == *fContext.fByte_Type) {
1320 op = SpvOpConvertSToF;
1321 } else if (src == *fContext.fUInt_Type ||
1322 src == *fContext.fUShort_Type ||
1323 src == *fContext.fUByte_Type) {
1324 op = SpvOpConvertUToF;
1325 } else {
1326 SkASSERT(false);
1327 }
1328 } else if (dst == *fContext.fInt_Type ||
1329 dst == *fContext.fShort_Type ||
1330 dst == *fContext.fByte_Type) {
1331 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1332 op = SpvOpConvertFToS;
1333 } else if (src == *fContext.fInt_Type ||
1334 src == *fContext.fShort_Type ||
1335 src == *fContext.fByte_Type) {
1336 if (c.fArguments.size() == 1) {
1337 return vec;
1338 }
1339 } else if (src == *fContext.fUInt_Type ||
1340 src == *fContext.fUShort_Type ||
1341 src == *fContext.fUByte_Type) {
1342 op = SpvOpBitcast;
1343 } else {
1344 SkASSERT(false);
1345 }
1346 } else if (dst == *fContext.fUInt_Type ||
1347 dst == *fContext.fUShort_Type ||
1348 dst == *fContext.fUByte_Type) {
1349 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1350 op = SpvOpConvertFToS;
1351 } else if (src == *fContext.fInt_Type ||
1352 src == *fContext.fShort_Type ||
1353 src == *fContext.fByte_Type) {
1354 op = SpvOpBitcast;
1355 } else if (src == *fContext.fUInt_Type ||
1356 src == *fContext.fUShort_Type ||
1357 src == *fContext.fUByte_Type) {
1358 if (c.fArguments.size() == 1) {
1359 return vec;
1360 }
1361 } else {
1362 SkASSERT(false);
1363 }
1364 }
1365 for (int j = 0; j < c.fArguments[i]->fType.columns(); j++) {
1366 SpvId swizzle = this->nextId();
1367 this->writeInstruction(SpvOpCompositeExtract, this->getType(src), swizzle, vec, j,
1368 out);
1369 if (op != SpvOpUndef) {
1370 SpvId cast = this->nextId();
1371 this->writeInstruction(op, this->getType(dst), cast, swizzle, out);
1372 arguments.push_back(cast);
1373 } else {
1374 arguments.push_back(swizzle);
1375 }
1376 }
1377 } else {
1378 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1379 }
1380 }
1381 SpvId result = this->nextId();
1382 if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1383 this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out);
1384 this->writeWord(this->getType(c.fType), out);
1385 this->writeWord(result, out);
1386 for (int i = 0; i < c.fType.columns(); i++) {
1387 this->writeWord(arguments[0], out);
1388 }
1389 } else {
1390 SkASSERT(arguments.size() > 1);
1391 this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out);
1392 this->writeWord(this->getType(c.fType), out);
1393 this->writeWord(result, out);
1394 for (SpvId id : arguments) {
1395 this->writeWord(id, out);
1396 }
1397 }
1398 return result;
1399 }
1400
writeArrayConstructor(const Constructor & c,OutputStream & out)1401 SpvId SPIRVCodeGenerator::writeArrayConstructor(const Constructor& c, OutputStream& out) {
1402 SkASSERT(c.fType.kind() == Type::kArray_Kind);
1403 // go ahead and write the arguments so we don't try to write new instructions in the middle of
1404 // an instruction
1405 std::vector<SpvId> arguments;
1406 for (size_t i = 0; i < c.fArguments.size(); i++) {
1407 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1408 }
1409 SpvId result = this->nextId();
1410 this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out);
1411 this->writeWord(this->getType(c.fType), out);
1412 this->writeWord(result, out);
1413 for (SpvId id : arguments) {
1414 this->writeWord(id, out);
1415 }
1416 return result;
1417 }
1418
writeConstructor(const Constructor & c,OutputStream & out)1419 SpvId SPIRVCodeGenerator::writeConstructor(const Constructor& c, OutputStream& out) {
1420 if (c.fArguments.size() == 1 &&
1421 this->getActualType(c.fType) == this->getActualType(c.fArguments[0]->fType)) {
1422 return this->writeExpression(*c.fArguments[0], out);
1423 }
1424 if (c.fType == *fContext.fFloat_Type || c.fType == *fContext.fHalf_Type) {
1425 return this->writeFloatConstructor(c, out);
1426 } else if (c.fType == *fContext.fInt_Type ||
1427 c.fType == *fContext.fShort_Type ||
1428 c.fType == *fContext.fByte_Type) {
1429 return this->writeIntConstructor(c, out);
1430 } else if (c.fType == *fContext.fUInt_Type ||
1431 c.fType == *fContext.fUShort_Type ||
1432 c.fType == *fContext.fUByte_Type) {
1433 return this->writeUIntConstructor(c, out);
1434 }
1435 switch (c.fType.kind()) {
1436 case Type::kVector_Kind:
1437 return this->writeVectorConstructor(c, out);
1438 case Type::kMatrix_Kind:
1439 return this->writeMatrixConstructor(c, out);
1440 case Type::kArray_Kind:
1441 return this->writeArrayConstructor(c, out);
1442 default:
1443 ABORT("unsupported constructor: %s", c.description().c_str());
1444 }
1445 }
1446
get_storage_class(const Modifiers & modifiers)1447 SpvStorageClass_ get_storage_class(const Modifiers& modifiers) {
1448 if (modifiers.fFlags & Modifiers::kIn_Flag) {
1449 SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1450 return SpvStorageClassInput;
1451 } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
1452 SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1453 return SpvStorageClassOutput;
1454 } else if (modifiers.fFlags & Modifiers::kUniform_Flag) {
1455 if (modifiers.fLayout.fFlags & Layout::kPushConstant_Flag) {
1456 return SpvStorageClassPushConstant;
1457 }
1458 return SpvStorageClassUniform;
1459 } else {
1460 return SpvStorageClassFunction;
1461 }
1462 }
1463
get_storage_class(const Expression & expr)1464 SpvStorageClass_ get_storage_class(const Expression& expr) {
1465 switch (expr.fKind) {
1466 case Expression::kVariableReference_Kind: {
1467 const Variable& var = ((VariableReference&) expr).fVariable;
1468 if (var.fStorage != Variable::kGlobal_Storage) {
1469 return SpvStorageClassFunction;
1470 }
1471 SpvStorageClass_ result = get_storage_class(var.fModifiers);
1472 if (result == SpvStorageClassFunction) {
1473 result = SpvStorageClassPrivate;
1474 }
1475 return result;
1476 }
1477 case Expression::kFieldAccess_Kind:
1478 return get_storage_class(*((FieldAccess&) expr).fBase);
1479 case Expression::kIndex_Kind:
1480 return get_storage_class(*((IndexExpression&) expr).fBase);
1481 default:
1482 return SpvStorageClassFunction;
1483 }
1484 }
1485
getAccessChain(const Expression & expr,OutputStream & out)1486 std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
1487 std::vector<SpvId> chain;
1488 switch (expr.fKind) {
1489 case Expression::kIndex_Kind: {
1490 IndexExpression& indexExpr = (IndexExpression&) expr;
1491 chain = this->getAccessChain(*indexExpr.fBase, out);
1492 chain.push_back(this->writeExpression(*indexExpr.fIndex, out));
1493 break;
1494 }
1495 case Expression::kFieldAccess_Kind: {
1496 FieldAccess& fieldExpr = (FieldAccess&) expr;
1497 chain = this->getAccessChain(*fieldExpr.fBase, out);
1498 IntLiteral index(fContext, -1, fieldExpr.fFieldIndex);
1499 chain.push_back(this->writeIntLiteral(index));
1500 break;
1501 }
1502 default: {
1503 SpvId id = this->getLValue(expr, out)->getPointer();
1504 SkASSERT(id != 0);
1505 chain.push_back(id);
1506 }
1507 }
1508 return chain;
1509 }
1510
1511 class PointerLValue : public SPIRVCodeGenerator::LValue {
1512 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,SpvId type)1513 PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type)
1514 : fGen(gen)
1515 , fPointer(pointer)
1516 , fType(type) {}
1517
getPointer()1518 virtual SpvId getPointer() override {
1519 return fPointer;
1520 }
1521
load(OutputStream & out)1522 virtual SpvId load(OutputStream& out) override {
1523 SpvId result = fGen.nextId();
1524 fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
1525 return result;
1526 }
1527
store(SpvId value,OutputStream & out)1528 virtual void store(SpvId value, OutputStream& out) override {
1529 fGen.writeInstruction(SpvOpStore, fPointer, value, out);
1530 }
1531
1532 private:
1533 SPIRVCodeGenerator& fGen;
1534 const SpvId fPointer;
1535 const SpvId fType;
1536 };
1537
1538 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
1539 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const std::vector<int> & components,const Type & baseType,const Type & swizzleType)1540 SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components,
1541 const Type& baseType, const Type& swizzleType)
1542 : fGen(gen)
1543 , fVecPointer(vecPointer)
1544 , fComponents(components)
1545 , fBaseType(baseType)
1546 , fSwizzleType(swizzleType) {}
1547
getPointer()1548 virtual SpvId getPointer() override {
1549 return 0;
1550 }
1551
load(OutputStream & out)1552 virtual SpvId load(OutputStream& out) override {
1553 SpvId base = fGen.nextId();
1554 fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1555 SpvId result = fGen.nextId();
1556 fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
1557 fGen.writeWord(fGen.getType(fSwizzleType), out);
1558 fGen.writeWord(result, out);
1559 fGen.writeWord(base, out);
1560 fGen.writeWord(base, out);
1561 for (int component : fComponents) {
1562 fGen.writeWord(component, out);
1563 }
1564 return result;
1565 }
1566
store(SpvId value,OutputStream & out)1567 virtual void store(SpvId value, OutputStream& out) override {
1568 // use OpVectorShuffle to mix and match the vector components. We effectively create
1569 // a virtual vector out of the concatenation of the left and right vectors, and then
1570 // select components from this virtual vector to make the result vector. For
1571 // instance, given:
1572 // float3L = ...;
1573 // float3R = ...;
1574 // L.xz = R.xy;
1575 // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
1576 // our result vector to look like (R.x, L.y, R.y), so we need to select indices
1577 // (3, 1, 4).
1578 SpvId base = fGen.nextId();
1579 fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1580 SpvId shuffle = fGen.nextId();
1581 fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out);
1582 fGen.writeWord(fGen.getType(fBaseType), out);
1583 fGen.writeWord(shuffle, out);
1584 fGen.writeWord(base, out);
1585 fGen.writeWord(value, out);
1586 for (int i = 0; i < fBaseType.columns(); i++) {
1587 // current offset into the virtual vector, defaults to pulling the unmodified
1588 // value from the left side
1589 int offset = i;
1590 // check to see if we are writing this component
1591 for (size_t j = 0; j < fComponents.size(); j++) {
1592 if (fComponents[j] == i) {
1593 // we're writing to this component, so adjust the offset to pull from
1594 // the correct component of the right side instead of preserving the
1595 // value from the left
1596 offset = (int) (j + fBaseType.columns());
1597 break;
1598 }
1599 }
1600 fGen.writeWord(offset, out);
1601 }
1602 fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
1603 }
1604
1605 private:
1606 SPIRVCodeGenerator& fGen;
1607 const SpvId fVecPointer;
1608 const std::vector<int>& fComponents;
1609 const Type& fBaseType;
1610 const Type& fSwizzleType;
1611 };
1612
getLValue(const Expression & expr,OutputStream & out)1613 std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
1614 OutputStream& out) {
1615 switch (expr.fKind) {
1616 case Expression::kVariableReference_Kind: {
1617 SpvId type;
1618 const Variable& var = ((VariableReference&) expr).fVariable;
1619 if (var.fModifiers.fLayout.fBuiltin == SK_IN_BUILTIN) {
1620 type = this->getType(Type("sk_in", Type::kArray_Kind, var.fType.componentType(),
1621 fSkInCount));
1622 } else {
1623 type = this->getType(expr.fType);
1624 }
1625 auto entry = fVariableMap.find(&var);
1626 SkASSERT(entry != fVariableMap.end());
1627 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(*this,
1628 entry->second,
1629 type));
1630 }
1631 case Expression::kIndex_Kind: // fall through
1632 case Expression::kFieldAccess_Kind: {
1633 std::vector<SpvId> chain = this->getAccessChain(expr, out);
1634 SpvId member = this->nextId();
1635 this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
1636 this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out);
1637 this->writeWord(member, out);
1638 for (SpvId idx : chain) {
1639 this->writeWord(idx, out);
1640 }
1641 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1642 *this,
1643 member,
1644 this->getType(expr.fType)));
1645 }
1646 case Expression::kSwizzle_Kind: {
1647 Swizzle& swizzle = (Swizzle&) expr;
1648 size_t count = swizzle.fComponents.size();
1649 SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer();
1650 SkASSERT(base);
1651 if (count == 1) {
1652 IntLiteral index(fContext, -1, swizzle.fComponents[0]);
1653 SpvId member = this->nextId();
1654 this->writeInstruction(SpvOpAccessChain,
1655 this->getPointerType(swizzle.fType,
1656 get_storage_class(*swizzle.fBase)),
1657 member,
1658 base,
1659 this->writeIntLiteral(index),
1660 out);
1661 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1662 *this,
1663 member,
1664 this->getType(expr.fType)));
1665 } else {
1666 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue(
1667 *this,
1668 base,
1669 swizzle.fComponents,
1670 swizzle.fBase->fType,
1671 expr.fType));
1672 }
1673 }
1674 case Expression::kTernary_Kind: {
1675 TernaryExpression& t = (TernaryExpression&) expr;
1676 SpvId test = this->writeExpression(*t.fTest, out);
1677 SpvId end = this->nextId();
1678 SpvId ifTrueLabel = this->nextId();
1679 SpvId ifFalseLabel = this->nextId();
1680 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
1681 this->writeInstruction(SpvOpBranchConditional, test, ifTrueLabel, ifFalseLabel, out);
1682 this->writeLabel(ifTrueLabel, out);
1683 SpvId ifTrue = this->getLValue(*t.fIfTrue, out)->getPointer();
1684 SkASSERT(ifTrue);
1685 this->writeInstruction(SpvOpBranch, end, out);
1686 ifTrueLabel = fCurrentBlock;
1687 SpvId ifFalse = this->getLValue(*t.fIfFalse, out)->getPointer();
1688 SkASSERT(ifFalse);
1689 ifFalseLabel = fCurrentBlock;
1690 this->writeInstruction(SpvOpBranch, end, out);
1691 SpvId result = this->nextId();
1692 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, ifTrue,
1693 ifTrueLabel, ifFalse, ifFalseLabel, out);
1694 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1695 *this,
1696 result,
1697 this->getType(expr.fType)));
1698 }
1699 default:
1700 // expr isn't actually an lvalue, create a dummy variable for it. This case happens due
1701 // to the need to store values in temporary variables during function calls (see
1702 // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been
1703 // caught by IRGenerator
1704 SpvId result = this->nextId();
1705 SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction);
1706 this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction,
1707 fVariableBuffer);
1708 this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
1709 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1710 *this,
1711 result,
1712 this->getType(expr.fType)));
1713 }
1714 }
1715
writeVariableReference(const VariableReference & ref,OutputStream & out)1716 SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
1717 SpvId result = this->nextId();
1718 auto entry = fVariableMap.find(&ref.fVariable);
1719 SkASSERT(entry != fVariableMap.end());
1720 SpvId var = entry->second;
1721 this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out);
1722 if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN &&
1723 fProgram.fSettings.fFlipY) {
1724 // need to remap to a top-left coordinate system
1725 if (fRTHeightStructId == (SpvId) -1) {
1726 // height variable hasn't been written yet
1727 std::shared_ptr<SymbolTable> st(new SymbolTable(&fErrors));
1728 SkASSERT(fRTHeightFieldIndex == (SpvId) -1);
1729 std::vector<Type::Field> fields;
1730 fields.emplace_back(Modifiers(), SKSL_RTHEIGHT_NAME, fContext.fFloat_Type.get());
1731 StringFragment name("sksl_synthetic_uniforms");
1732 Type intfStruct(-1, name, fields);
1733 Layout layout(0, -1, -1, 1, -1, -1, -1, -1, Layout::Format::kUnspecified,
1734 Layout::kUnspecified_Primitive, -1, -1, "", Layout::kNo_Key,
1735 Layout::CType::kDefault);
1736 Variable* intfVar = new Variable(-1,
1737 Modifiers(layout, Modifiers::kUniform_Flag),
1738 name,
1739 intfStruct,
1740 Variable::kGlobal_Storage);
1741 fSynthetics.takeOwnership(intfVar);
1742 InterfaceBlock intf(-1, intfVar, name, String(""),
1743 std::vector<std::unique_ptr<Expression>>(), st);
1744 fRTHeightStructId = this->writeInterfaceBlock(intf);
1745 fRTHeightFieldIndex = 0;
1746 }
1747 SkASSERT(fRTHeightFieldIndex != (SpvId) -1);
1748 // write float4(gl_FragCoord.x, u_skRTHeight - gl_FragCoord.y, 0.0, gl_FragCoord.w)
1749 SpvId xId = this->nextId();
1750 this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), xId,
1751 result, 0, out);
1752 IntLiteral fieldIndex(fContext, -1, fRTHeightFieldIndex);
1753 SpvId fieldIndexId = this->writeIntLiteral(fieldIndex);
1754 SpvId heightPtr = this->nextId();
1755 this->writeOpCode(SpvOpAccessChain, 5, out);
1756 this->writeWord(this->getPointerType(*fContext.fFloat_Type, SpvStorageClassUniform), out);
1757 this->writeWord(heightPtr, out);
1758 this->writeWord(fRTHeightStructId, out);
1759 this->writeWord(fieldIndexId, out);
1760 SpvId heightRead = this->nextId();
1761 this->writeInstruction(SpvOpLoad, this->getType(*fContext.fFloat_Type), heightRead,
1762 heightPtr, out);
1763 SpvId rawYId = this->nextId();
1764 this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), rawYId,
1765 result, 1, out);
1766 SpvId flippedYId = this->nextId();
1767 this->writeInstruction(SpvOpFSub, this->getType(*fContext.fFloat_Type), flippedYId,
1768 heightRead, rawYId, out);
1769 FloatLiteral zero(fContext, -1, 0.0);
1770 SpvId zeroId = writeFloatLiteral(zero);
1771 FloatLiteral one(fContext, -1, 1.0);
1772 SpvId wId = this->nextId();
1773 this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), wId,
1774 result, 3, out);
1775 SpvId flipped = this->nextId();
1776 this->writeOpCode(SpvOpCompositeConstruct, 7, out);
1777 this->writeWord(this->getType(*fContext.fFloat4_Type), out);
1778 this->writeWord(flipped, out);
1779 this->writeWord(xId, out);
1780 this->writeWord(flippedYId, out);
1781 this->writeWord(zeroId, out);
1782 this->writeWord(wId, out);
1783 return flipped;
1784 }
1785 if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_CLOCKWISE_BUILTIN &&
1786 !fProgram.fSettings.fFlipY) {
1787 // FrontFacing in Vulkan is defined in terms of a top-down render target. In skia, we use
1788 // the default convention of "counter-clockwise face is front".
1789 SpvId inverse = this->nextId();
1790 this->writeInstruction(SpvOpLogicalNot, this->getType(*fContext.fBool_Type), inverse,
1791 result, out);
1792 return inverse;
1793 }
1794 return result;
1795 }
1796
writeIndexExpression(const IndexExpression & expr,OutputStream & out)1797 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
1798 if (expr.fBase->fType.kind() == Type::Kind::kVector_Kind) {
1799 SpvId base = this->writeExpression(*expr.fBase, out);
1800 SpvId index = this->writeExpression(*expr.fIndex, out);
1801 SpvId result = this->nextId();
1802 this->writeInstruction(SpvOpVectorExtractDynamic, this->getType(expr.fType), result, base,
1803 index, out);
1804 return result;
1805 }
1806 return getLValue(expr, out)->load(out);
1807 }
1808
writeFieldAccess(const FieldAccess & f,OutputStream & out)1809 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
1810 return getLValue(f, out)->load(out);
1811 }
1812
writeSwizzle(const Swizzle & swizzle,OutputStream & out)1813 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
1814 SpvId base = this->writeExpression(*swizzle.fBase, out);
1815 SpvId result = this->nextId();
1816 size_t count = swizzle.fComponents.size();
1817 if (count == 1) {
1818 this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base,
1819 swizzle.fComponents[0], out);
1820 } else {
1821 this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
1822 this->writeWord(this->getType(swizzle.fType), out);
1823 this->writeWord(result, out);
1824 this->writeWord(base, out);
1825 this->writeWord(base, out);
1826 for (int component : swizzle.fComponents) {
1827 this->writeWord(component, out);
1828 }
1829 }
1830 return result;
1831 }
1832
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)1833 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
1834 const Type& operandType, SpvId lhs,
1835 SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
1836 SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
1837 SpvId result = this->nextId();
1838 if (is_float(fContext, operandType)) {
1839 this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
1840 } else if (is_signed(fContext, operandType)) {
1841 this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
1842 } else if (is_unsigned(fContext, operandType)) {
1843 this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
1844 } else if (operandType == *fContext.fBool_Type) {
1845 this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
1846 } else {
1847 ABORT("invalid operandType: %s", operandType.description().c_str());
1848 }
1849 return result;
1850 }
1851
is_assignment(Token::Kind op)1852 bool is_assignment(Token::Kind op) {
1853 switch (op) {
1854 case Token::EQ: // fall through
1855 case Token::PLUSEQ: // fall through
1856 case Token::MINUSEQ: // fall through
1857 case Token::STAREQ: // fall through
1858 case Token::SLASHEQ: // fall through
1859 case Token::PERCENTEQ: // fall through
1860 case Token::SHLEQ: // fall through
1861 case Token::SHREQ: // fall through
1862 case Token::BITWISEOREQ: // fall through
1863 case Token::BITWISEXOREQ: // fall through
1864 case Token::BITWISEANDEQ: // fall through
1865 case Token::LOGICALOREQ: // fall through
1866 case Token::LOGICALXOREQ: // fall through
1867 case Token::LOGICALANDEQ:
1868 return true;
1869 default:
1870 return false;
1871 }
1872 }
1873
foldToBool(SpvId id,const Type & operandType,SpvOp op,OutputStream & out)1874 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op,
1875 OutputStream& out) {
1876 if (operandType.kind() == Type::kVector_Kind) {
1877 SpvId result = this->nextId();
1878 this->writeInstruction(op, this->getType(*fContext.fBool_Type), result, id, out);
1879 return result;
1880 }
1881 return id;
1882 }
1883
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,SpvOp_ vectorMergeOperator,SpvOp_ mergeOperator,OutputStream & out)1884 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
1885 SpvOp_ floatOperator, SpvOp_ intOperator,
1886 SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
1887 OutputStream& out) {
1888 SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator;
1889 SkASSERT(operandType.kind() == Type::kMatrix_Kind);
1890 SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
1891 operandType.rows(),
1892 1));
1893 SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext,
1894 operandType.rows(),
1895 1));
1896 SpvId boolType = this->getType(*fContext.fBool_Type);
1897 SpvId result = 0;
1898 for (int i = 0; i < operandType.columns(); i++) {
1899 SpvId columnL = this->nextId();
1900 this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
1901 SpvId columnR = this->nextId();
1902 this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
1903 SpvId compare = this->nextId();
1904 this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
1905 SpvId merge = this->nextId();
1906 this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
1907 if (result != 0) {
1908 SpvId next = this->nextId();
1909 this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
1910 result = next;
1911 }
1912 else {
1913 result = merge;
1914 }
1915 }
1916 return result;
1917 }
1918
writeComponentwiseMatrixBinary(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,OutputStream & out)1919 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
1920 SpvId rhs, SpvOp_ floatOperator,
1921 SpvOp_ intOperator,
1922 OutputStream& out) {
1923 SpvOp_ op = is_float(fContext, operandType) ? floatOperator : intOperator;
1924 SkASSERT(operandType.kind() == Type::kMatrix_Kind);
1925 SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
1926 operandType.rows(),
1927 1));
1928 SpvId columns[4];
1929 for (int i = 0; i < operandType.columns(); i++) {
1930 SpvId columnL = this->nextId();
1931 this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
1932 SpvId columnR = this->nextId();
1933 this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
1934 columns[i] = this->nextId();
1935 this->writeInstruction(op, columnType, columns[i], columnL, columnR, out);
1936 }
1937 SpvId result = this->nextId();
1938 this->writeOpCode(SpvOpCompositeConstruct, 3 + operandType.columns(), out);
1939 this->writeWord(this->getType(operandType), out);
1940 this->writeWord(result, out);
1941 for (int i = 0; i < operandType.columns(); i++) {
1942 this->writeWord(columns[i], out);
1943 }
1944 return result;
1945 }
1946
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)1947 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
1948 // handle cases where we don't necessarily evaluate both LHS and RHS
1949 switch (b.fOperator) {
1950 case Token::EQ: {
1951 SpvId rhs = this->writeExpression(*b.fRight, out);
1952 this->getLValue(*b.fLeft, out)->store(rhs, out);
1953 return rhs;
1954 }
1955 case Token::LOGICALAND:
1956 return this->writeLogicalAnd(b, out);
1957 case Token::LOGICALOR:
1958 return this->writeLogicalOr(b, out);
1959 default:
1960 break;
1961 }
1962
1963 // "normal" operators
1964 const Type& resultType = b.fType;
1965 std::unique_ptr<LValue> lvalue;
1966 SpvId lhs;
1967 if (is_assignment(b.fOperator)) {
1968 lvalue = this->getLValue(*b.fLeft, out);
1969 lhs = lvalue->load(out);
1970 } else {
1971 lvalue = nullptr;
1972 lhs = this->writeExpression(*b.fLeft, out);
1973 }
1974 SpvId rhs = this->writeExpression(*b.fRight, out);
1975 if (b.fOperator == Token::COMMA) {
1976 return rhs;
1977 }
1978 Type tmp("<invalid>");
1979 // overall type we are operating on: float2, int, uint4...
1980 const Type* operandType;
1981 // IR allows mismatched types in expressions (e.g. float2 * float), but they need special
1982 // handling in SPIR-V
1983 if (this->getActualType(b.fLeft->fType) != this->getActualType(b.fRight->fType)) {
1984 if (b.fLeft->fType.kind() == Type::kVector_Kind &&
1985 b.fRight->fType.isNumber()) {
1986 // promote number to vector
1987 SpvId vec = this->nextId();
1988 const Type& vecType = b.fLeft->fType;
1989 this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
1990 this->writeWord(this->getType(vecType), out);
1991 this->writeWord(vec, out);
1992 for (int i = 0; i < vecType.columns(); i++) {
1993 this->writeWord(rhs, out);
1994 }
1995 rhs = vec;
1996 operandType = &b.fLeft->fType;
1997 } else if (b.fRight->fType.kind() == Type::kVector_Kind &&
1998 b.fLeft->fType.isNumber()) {
1999 // promote number to vector
2000 SpvId vec = this->nextId();
2001 const Type& vecType = b.fRight->fType;
2002 this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
2003 this->writeWord(this->getType(vecType), out);
2004 this->writeWord(vec, out);
2005 for (int i = 0; i < vecType.columns(); i++) {
2006 this->writeWord(lhs, out);
2007 }
2008 lhs = vec;
2009 SkASSERT(!lvalue);
2010 operandType = &b.fRight->fType;
2011 } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) {
2012 SpvOp_ op;
2013 if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
2014 op = SpvOpMatrixTimesMatrix;
2015 } else if (b.fRight->fType.kind() == Type::kVector_Kind) {
2016 op = SpvOpMatrixTimesVector;
2017 } else {
2018 SkASSERT(b.fRight->fType.kind() == Type::kScalar_Kind);
2019 op = SpvOpMatrixTimesScalar;
2020 }
2021 SpvId result = this->nextId();
2022 this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out);
2023 if (b.fOperator == Token::STAREQ) {
2024 lvalue->store(result, out);
2025 } else {
2026 SkASSERT(b.fOperator == Token::STAR);
2027 }
2028 return result;
2029 } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
2030 SpvId result = this->nextId();
2031 if (b.fLeft->fType.kind() == Type::kVector_Kind) {
2032 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result,
2033 lhs, rhs, out);
2034 } else {
2035 SkASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind);
2036 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs,
2037 lhs, out);
2038 }
2039 if (b.fOperator == Token::STAREQ) {
2040 lvalue->store(result, out);
2041 } else {
2042 SkASSERT(b.fOperator == Token::STAR);
2043 }
2044 return result;
2045 } else {
2046 ABORT("unsupported binary expression: %s", b.description().c_str());
2047 }
2048 } else {
2049 tmp = this->getActualType(b.fLeft->fType);
2050 operandType = &tmp;
2051 SkASSERT(*operandType == this->getActualType(b.fRight->fType));
2052 }
2053 switch (b.fOperator) {
2054 case Token::EQEQ: {
2055 if (operandType->kind() == Type::kMatrix_Kind) {
2056 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
2057 SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
2058 }
2059 SkASSERT(resultType == *fContext.fBool_Type);
2060 const Type* tmpType;
2061 if (operandType->kind() == Type::kVector_Kind) {
2062 tmpType = &fContext.fBool_Type->toCompound(fContext,
2063 operandType->columns(),
2064 operandType->rows());
2065 } else {
2066 tmpType = &resultType;
2067 }
2068 return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
2069 SpvOpFOrdEqual, SpvOpIEqual,
2070 SpvOpIEqual, SpvOpLogicalEqual, out),
2071 *operandType, SpvOpAll, out);
2072 }
2073 case Token::NEQ:
2074 if (operandType->kind() == Type::kMatrix_Kind) {
2075 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
2076 SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
2077 }
2078 SkASSERT(resultType == *fContext.fBool_Type);
2079 const Type* tmpType;
2080 if (operandType->kind() == Type::kVector_Kind) {
2081 tmpType = &fContext.fBool_Type->toCompound(fContext,
2082 operandType->columns(),
2083 operandType->rows());
2084 } else {
2085 tmpType = &resultType;
2086 }
2087 return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
2088 SpvOpFOrdNotEqual, SpvOpINotEqual,
2089 SpvOpINotEqual, SpvOpLogicalNotEqual,
2090 out),
2091 *operandType, SpvOpAny, out);
2092 case Token::GT:
2093 SkASSERT(resultType == *fContext.fBool_Type);
2094 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2095 SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
2096 SpvOpUGreaterThan, SpvOpUndef, out);
2097 case Token::LT:
2098 SkASSERT(resultType == *fContext.fBool_Type);
2099 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
2100 SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
2101 case Token::GTEQ:
2102 SkASSERT(resultType == *fContext.fBool_Type);
2103 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2104 SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
2105 SpvOpUGreaterThanEqual, SpvOpUndef, out);
2106 case Token::LTEQ:
2107 SkASSERT(resultType == *fContext.fBool_Type);
2108 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2109 SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
2110 SpvOpULessThanEqual, SpvOpUndef, out);
2111 case Token::PLUS:
2112 if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2113 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2114 SkASSERT(b.fLeft->fType == b.fRight->fType);
2115 return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
2116 SpvOpFAdd, SpvOpIAdd, out);
2117 }
2118 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2119 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2120 case Token::MINUS:
2121 if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2122 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2123 SkASSERT(b.fLeft->fType == b.fRight->fType);
2124 return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
2125 SpvOpFSub, SpvOpISub, out);
2126 }
2127 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2128 SpvOpISub, SpvOpISub, SpvOpUndef, out);
2129 case Token::STAR:
2130 if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2131 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2132 // matrix multiply
2133 SpvId result = this->nextId();
2134 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2135 lhs, rhs, out);
2136 return result;
2137 }
2138 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2139 SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2140 case Token::SLASH:
2141 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2142 SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2143 case Token::PERCENT:
2144 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2145 SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2146 case Token::SHL:
2147 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2148 SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
2149 SpvOpUndef, out);
2150 case Token::SHR:
2151 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2152 SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
2153 SpvOpUndef, out);
2154 case Token::BITWISEAND:
2155 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2156 SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
2157 case Token::BITWISEOR:
2158 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2159 SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
2160 case Token::BITWISEXOR:
2161 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2162 SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
2163 case Token::PLUSEQ: {
2164 SpvId result;
2165 if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2166 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2167 SkASSERT(b.fLeft->fType == b.fRight->fType);
2168 result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
2169 SpvOpFAdd, SpvOpIAdd, out);
2170 }
2171 else {
2172 result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2173 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2174 }
2175 SkASSERT(lvalue);
2176 lvalue->store(result, out);
2177 return result;
2178 }
2179 case Token::MINUSEQ: {
2180 SpvId result;
2181 if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2182 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2183 SkASSERT(b.fLeft->fType == b.fRight->fType);
2184 result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
2185 SpvOpFSub, SpvOpISub, out);
2186 }
2187 else {
2188 result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2189 SpvOpISub, SpvOpISub, SpvOpUndef, out);
2190 }
2191 SkASSERT(lvalue);
2192 lvalue->store(result, out);
2193 return result;
2194 }
2195 case Token::STAREQ: {
2196 if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2197 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2198 // matrix multiply
2199 SpvId result = this->nextId();
2200 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2201 lhs, rhs, out);
2202 SkASSERT(lvalue);
2203 lvalue->store(result, out);
2204 return result;
2205 }
2206 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2207 SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2208 SkASSERT(lvalue);
2209 lvalue->store(result, out);
2210 return result;
2211 }
2212 case Token::SLASHEQ: {
2213 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2214 SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2215 SkASSERT(lvalue);
2216 lvalue->store(result, out);
2217 return result;
2218 }
2219 case Token::PERCENTEQ: {
2220 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2221 SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2222 SkASSERT(lvalue);
2223 lvalue->store(result, out);
2224 return result;
2225 }
2226 case Token::SHLEQ: {
2227 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2228 SpvOpUndef, SpvOpShiftLeftLogical,
2229 SpvOpShiftLeftLogical, SpvOpUndef, out);
2230 SkASSERT(lvalue);
2231 lvalue->store(result, out);
2232 return result;
2233 }
2234 case Token::SHREQ: {
2235 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2236 SpvOpUndef, SpvOpShiftRightArithmetic,
2237 SpvOpShiftRightLogical, SpvOpUndef, out);
2238 SkASSERT(lvalue);
2239 lvalue->store(result, out);
2240 return result;
2241 }
2242 case Token::BITWISEANDEQ: {
2243 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2244 SpvOpUndef, SpvOpBitwiseAnd, SpvOpBitwiseAnd,
2245 SpvOpUndef, out);
2246 SkASSERT(lvalue);
2247 lvalue->store(result, out);
2248 return result;
2249 }
2250 case Token::BITWISEOREQ: {
2251 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2252 SpvOpUndef, SpvOpBitwiseOr, SpvOpBitwiseOr,
2253 SpvOpUndef, out);
2254 SkASSERT(lvalue);
2255 lvalue->store(result, out);
2256 return result;
2257 }
2258 case Token::BITWISEXOREQ: {
2259 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2260 SpvOpUndef, SpvOpBitwiseXor, SpvOpBitwiseXor,
2261 SpvOpUndef, out);
2262 SkASSERT(lvalue);
2263 lvalue->store(result, out);
2264 return result;
2265 }
2266 default:
2267 ABORT("unsupported binary expression: %s", b.description().c_str());
2268 }
2269 }
2270
writeLogicalAnd(const BinaryExpression & a,OutputStream & out)2271 SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) {
2272 SkASSERT(a.fOperator == Token::LOGICALAND);
2273 BoolLiteral falseLiteral(fContext, -1, false);
2274 SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
2275 SpvId lhs = this->writeExpression(*a.fLeft, out);
2276 SpvId rhsLabel = this->nextId();
2277 SpvId end = this->nextId();
2278 SpvId lhsBlock = fCurrentBlock;
2279 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2280 this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
2281 this->writeLabel(rhsLabel, out);
2282 SpvId rhs = this->writeExpression(*a.fRight, out);
2283 SpvId rhsBlock = fCurrentBlock;
2284 this->writeInstruction(SpvOpBranch, end, out);
2285 this->writeLabel(end, out);
2286 SpvId result = this->nextId();
2287 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant,
2288 lhsBlock, rhs, rhsBlock, out);
2289 return result;
2290 }
2291
writeLogicalOr(const BinaryExpression & o,OutputStream & out)2292 SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, OutputStream& out) {
2293 SkASSERT(o.fOperator == Token::LOGICALOR);
2294 BoolLiteral trueLiteral(fContext, -1, true);
2295 SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
2296 SpvId lhs = this->writeExpression(*o.fLeft, out);
2297 SpvId rhsLabel = this->nextId();
2298 SpvId end = this->nextId();
2299 SpvId lhsBlock = fCurrentBlock;
2300 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2301 this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
2302 this->writeLabel(rhsLabel, out);
2303 SpvId rhs = this->writeExpression(*o.fRight, out);
2304 SpvId rhsBlock = fCurrentBlock;
2305 this->writeInstruction(SpvOpBranch, end, out);
2306 this->writeLabel(end, out);
2307 SpvId result = this->nextId();
2308 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant,
2309 lhsBlock, rhs, rhsBlock, out);
2310 return result;
2311 }
2312
writeTernaryExpression(const TernaryExpression & t,OutputStream & out)2313 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
2314 SpvId test = this->writeExpression(*t.fTest, out);
2315 if (t.fIfTrue->fType.columns() == 1 && t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) {
2316 // both true and false are constants, can just use OpSelect
2317 SpvId result = this->nextId();
2318 SpvId trueId = this->writeExpression(*t.fIfTrue, out);
2319 SpvId falseId = this->writeExpression(*t.fIfFalse, out);
2320 this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId,
2321 out);
2322 return result;
2323 }
2324 // was originally using OpPhi to choose the result, but for some reason that is crashing on
2325 // Adreno. Switched to storing the result in a temp variable as glslang does.
2326 SpvId var = this->nextId();
2327 this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction),
2328 var, SpvStorageClassFunction, fVariableBuffer);
2329 SpvId trueLabel = this->nextId();
2330 SpvId falseLabel = this->nextId();
2331 SpvId end = this->nextId();
2332 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2333 this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
2334 this->writeLabel(trueLabel, out);
2335 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out);
2336 this->writeInstruction(SpvOpBranch, end, out);
2337 this->writeLabel(falseLabel, out);
2338 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out);
2339 this->writeInstruction(SpvOpBranch, end, out);
2340 this->writeLabel(end, out);
2341 SpvId result = this->nextId();
2342 this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out);
2343 return result;
2344 }
2345
create_literal_1(const Context & context,const Type & type)2346 std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
2347 if (type.isInteger()) {
2348 return std::unique_ptr<Expression>(new IntLiteral(-1, 1, &type));
2349 }
2350 else if (type.isFloat()) {
2351 return std::unique_ptr<Expression>(new FloatLiteral(-1, 1.0, &type));
2352 } else {
2353 ABORT("math is unsupported on type '%s'", type.name().c_str());
2354 }
2355 }
2356
writePrefixExpression(const PrefixExpression & p,OutputStream & out)2357 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
2358 if (p.fOperator == Token::MINUS) {
2359 SpvId result = this->nextId();
2360 SpvId typeId = this->getType(p.fType);
2361 SpvId expr = this->writeExpression(*p.fOperand, out);
2362 if (is_float(fContext, p.fType)) {
2363 this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
2364 } else if (is_signed(fContext, p.fType)) {
2365 this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
2366 } else {
2367 ABORT("unsupported prefix expression %s", p.description().c_str());
2368 }
2369 return result;
2370 }
2371 switch (p.fOperator) {
2372 case Token::PLUS:
2373 return this->writeExpression(*p.fOperand, out);
2374 case Token::PLUSPLUS: {
2375 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2376 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2377 SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2378 SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
2379 out);
2380 lv->store(result, out);
2381 return result;
2382 }
2383 case Token::MINUSMINUS: {
2384 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2385 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2386 SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2387 SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef,
2388 out);
2389 lv->store(result, out);
2390 return result;
2391 }
2392 case Token::LOGICALNOT: {
2393 SkASSERT(p.fOperand->fType == *fContext.fBool_Type);
2394 SpvId result = this->nextId();
2395 this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result,
2396 this->writeExpression(*p.fOperand, out), out);
2397 return result;
2398 }
2399 case Token::BITWISENOT: {
2400 SpvId result = this->nextId();
2401 this->writeInstruction(SpvOpNot, this->getType(p.fOperand->fType), result,
2402 this->writeExpression(*p.fOperand, out), out);
2403 return result;
2404 }
2405 default:
2406 ABORT("unsupported prefix expression: %s", p.description().c_str());
2407 }
2408 }
2409
writePostfixExpression(const PostfixExpression & p,OutputStream & out)2410 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
2411 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2412 SpvId result = lv->load(out);
2413 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2414 switch (p.fOperator) {
2415 case Token::PLUSPLUS: {
2416 SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd,
2417 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2418 lv->store(temp, out);
2419 return result;
2420 }
2421 case Token::MINUSMINUS: {
2422 SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub,
2423 SpvOpISub, SpvOpISub, SpvOpUndef, out);
2424 lv->store(temp, out);
2425 return result;
2426 }
2427 default:
2428 ABORT("unsupported postfix expression %s", p.description().c_str());
2429 }
2430 }
2431
writeBoolLiteral(const BoolLiteral & b)2432 SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
2433 if (b.fValue) {
2434 if (fBoolTrue == 0) {
2435 fBoolTrue = this->nextId();
2436 this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue,
2437 fConstantBuffer);
2438 }
2439 return fBoolTrue;
2440 } else {
2441 if (fBoolFalse == 0) {
2442 fBoolFalse = this->nextId();
2443 this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse,
2444 fConstantBuffer);
2445 }
2446 return fBoolFalse;
2447 }
2448 }
2449
writeIntLiteral(const IntLiteral & i)2450 SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) {
2451 if (i.fType == *fContext.fInt_Type) {
2452 auto entry = fIntConstants.find(i.fValue);
2453 if (entry == fIntConstants.end()) {
2454 SpvId result = this->nextId();
2455 this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2456 fConstantBuffer);
2457 fIntConstants[i.fValue] = result;
2458 return result;
2459 }
2460 return entry->second;
2461 } else {
2462 SkASSERT(i.fType == *fContext.fUInt_Type);
2463 auto entry = fUIntConstants.find(i.fValue);
2464 if (entry == fUIntConstants.end()) {
2465 SpvId result = this->nextId();
2466 this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2467 fConstantBuffer);
2468 fUIntConstants[i.fValue] = result;
2469 return result;
2470 }
2471 return entry->second;
2472 }
2473 }
2474
writeFloatLiteral(const FloatLiteral & f)2475 SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
2476 if (f.fType == *fContext.fFloat_Type || f.fType == *fContext.fHalf_Type) {
2477 float value = (float) f.fValue;
2478 auto entry = fFloatConstants.find(value);
2479 if (entry == fFloatConstants.end()) {
2480 SpvId result = this->nextId();
2481 uint32_t bits;
2482 SkASSERT(sizeof(bits) == sizeof(value));
2483 memcpy(&bits, &value, sizeof(bits));
2484 this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits,
2485 fConstantBuffer);
2486 fFloatConstants[value] = result;
2487 return result;
2488 }
2489 return entry->second;
2490 } else {
2491 SkASSERT(f.fType == *fContext.fDouble_Type);
2492 auto entry = fDoubleConstants.find(f.fValue);
2493 if (entry == fDoubleConstants.end()) {
2494 SpvId result = this->nextId();
2495 uint64_t bits;
2496 SkASSERT(sizeof(bits) == sizeof(f.fValue));
2497 memcpy(&bits, &f.fValue, sizeof(bits));
2498 this->writeInstruction(SpvOpConstant, this->getType(f.fType), result,
2499 bits & 0xffffffff, bits >> 32, fConstantBuffer);
2500 fDoubleConstants[f.fValue] = result;
2501 return result;
2502 }
2503 return entry->second;
2504 }
2505 }
2506
writeFunctionStart(const FunctionDeclaration & f,OutputStream & out)2507 SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
2508 SpvId result = fFunctionMap[&f];
2509 this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result,
2510 SpvFunctionControlMaskNone, this->getFunctionType(f), out);
2511 this->writeInstruction(SpvOpName, result, f.fName, fNameBuffer);
2512 for (size_t i = 0; i < f.fParameters.size(); i++) {
2513 SpvId id = this->nextId();
2514 fVariableMap[f.fParameters[i]] = id;
2515 SpvId type;
2516 type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction);
2517 this->writeInstruction(SpvOpFunctionParameter, type, id, out);
2518 }
2519 return result;
2520 }
2521
writeFunction(const FunctionDefinition & f,OutputStream & out)2522 SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
2523 fVariableBuffer.reset();
2524 SpvId result = this->writeFunctionStart(f.fDeclaration, out);
2525 this->writeLabel(this->nextId(), out);
2526 StringStream bodyBuffer;
2527 this->writeBlock((Block&) *f.fBody, bodyBuffer);
2528 write_stringstream(fVariableBuffer, out);
2529 if (f.fDeclaration.fName == "main") {
2530 write_stringstream(fGlobalInitializersBuffer, out);
2531 }
2532 write_stringstream(bodyBuffer, out);
2533 if (fCurrentBlock) {
2534 if (f.fDeclaration.fReturnType == *fContext.fVoid_Type) {
2535 this->writeInstruction(SpvOpReturn, out);
2536 } else {
2537 this->writeInstruction(SpvOpUnreachable, out);
2538 }
2539 }
2540 this->writeInstruction(SpvOpFunctionEnd, out);
2541 return result;
2542 }
2543
writeLayout(const Layout & layout,SpvId target)2544 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) {
2545 if (layout.fLocation >= 0) {
2546 this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
2547 fDecorationBuffer);
2548 }
2549 if (layout.fBinding >= 0) {
2550 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
2551 fDecorationBuffer);
2552 }
2553 if (layout.fIndex >= 0) {
2554 this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
2555 fDecorationBuffer);
2556 }
2557 if (layout.fSet >= 0) {
2558 this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
2559 fDecorationBuffer);
2560 }
2561 if (layout.fInputAttachmentIndex >= 0) {
2562 this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
2563 layout.fInputAttachmentIndex, fDecorationBuffer);
2564 fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
2565 }
2566 if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN &&
2567 layout.fBuiltin != SK_IN_BUILTIN && layout.fBuiltin != SK_OUT_BUILTIN) {
2568 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
2569 fDecorationBuffer);
2570 }
2571 }
2572
writeLayout(const Layout & layout,SpvId target,int member)2573 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) {
2574 if (layout.fLocation >= 0) {
2575 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
2576 layout.fLocation, fDecorationBuffer);
2577 }
2578 if (layout.fBinding >= 0) {
2579 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding,
2580 layout.fBinding, fDecorationBuffer);
2581 }
2582 if (layout.fIndex >= 0) {
2583 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
2584 layout.fIndex, fDecorationBuffer);
2585 }
2586 if (layout.fSet >= 0) {
2587 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet,
2588 layout.fSet, fDecorationBuffer);
2589 }
2590 if (layout.fInputAttachmentIndex >= 0) {
2591 this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
2592 layout.fInputAttachmentIndex, fDecorationBuffer);
2593 }
2594 if (layout.fBuiltin >= 0) {
2595 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
2596 layout.fBuiltin, fDecorationBuffer);
2597 }
2598 }
2599
update_sk_in_count(const Modifiers & m,int * outSkInCount)2600 static void update_sk_in_count(const Modifiers& m, int* outSkInCount) {
2601 switch (m.fLayout.fPrimitive) {
2602 case Layout::kPoints_Primitive:
2603 *outSkInCount = 1;
2604 break;
2605 case Layout::kLines_Primitive:
2606 *outSkInCount = 2;
2607 break;
2608 case Layout::kLinesAdjacency_Primitive:
2609 *outSkInCount = 4;
2610 break;
2611 case Layout::kTriangles_Primitive:
2612 *outSkInCount = 3;
2613 break;
2614 case Layout::kTrianglesAdjacency_Primitive:
2615 *outSkInCount = 6;
2616 break;
2617 default:
2618 return;
2619 }
2620 }
2621
writeInterfaceBlock(const InterfaceBlock & intf)2622 SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2623 bool isBuffer = (0 != (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag));
2624 bool pushConstant = (0 != (intf.fVariable.fModifiers.fLayout.fFlags &
2625 Layout::kPushConstant_Flag));
2626 MemoryLayout memoryLayout = (pushConstant || isBuffer) ?
2627 MemoryLayout(MemoryLayout::k430_Standard) :
2628 fDefaultLayout;
2629 SpvId result = this->nextId();
2630 const Type* type = &intf.fVariable.fType;
2631 if (fProgram.fInputs.fRTHeight) {
2632 SkASSERT(fRTHeightStructId == (SpvId) -1);
2633 SkASSERT(fRTHeightFieldIndex == (SpvId) -1);
2634 std::vector<Type::Field> fields = type->fields();
2635 fRTHeightStructId = result;
2636 fRTHeightFieldIndex = fields.size();
2637 fields.emplace_back(Modifiers(), StringFragment(SKSL_RTHEIGHT_NAME), fContext.fFloat_Type.get());
2638 type = new Type(type->fOffset, type->name(), fields);
2639 }
2640 SpvId typeId;
2641 if (intf.fVariable.fModifiers.fLayout.fBuiltin == SK_IN_BUILTIN) {
2642 for (const auto& e : fProgram) {
2643 if (e.fKind == ProgramElement::kModifiers_Kind) {
2644 const Modifiers& m = ((ModifiersDeclaration&) e).fModifiers;
2645 update_sk_in_count(m, &fSkInCount);
2646 }
2647 }
2648 typeId = this->getType(Type("sk_in", Type::kArray_Kind, intf.fVariable.fType.componentType(),
2649 fSkInCount), memoryLayout);
2650 } else {
2651 typeId = this->getType(*type, memoryLayout);
2652 }
2653 if (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag) {
2654 this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBufferBlock, fDecorationBuffer);
2655 } else if (intf.fVariable.fModifiers.fLayout.fBuiltin == -1) {
2656 this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
2657 }
2658 SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers);
2659 SpvId ptrType = this->nextId();
2660 this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
2661 this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
2662 Layout layout = intf.fVariable.fModifiers.fLayout;
2663 if (intf.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag && layout.fSet == -1) {
2664 layout.fSet = 0;
2665 }
2666 this->writeLayout(layout, result);
2667 fVariableMap[&intf.fVariable] = result;
2668 if (fProgram.fInputs.fRTHeight) {
2669 delete type;
2670 }
2671 return result;
2672 }
2673
writePrecisionModifier(const Modifiers & modifiers,SpvId id)2674 void SPIRVCodeGenerator::writePrecisionModifier(const Modifiers& modifiers, SpvId id) {
2675 if ((modifiers.fFlags & Modifiers::kLowp_Flag) |
2676 (modifiers.fFlags & Modifiers::kMediump_Flag)) {
2677 this->writeInstruction(SpvOpDecorate, id, SpvDecorationRelaxedPrecision, fDecorationBuffer);
2678 }
2679 }
2680
2681 #define BUILTIN_IGNORE 9999
writeGlobalVars(Program::Kind kind,const VarDeclarations & decl,OutputStream & out)2682 void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclarations& decl,
2683 OutputStream& out) {
2684 for (size_t i = 0; i < decl.fVars.size(); i++) {
2685 if (decl.fVars[i]->fKind == Statement::kNop_Kind) {
2686 continue;
2687 }
2688 const VarDeclaration& varDecl = (VarDeclaration&) *decl.fVars[i];
2689 const Variable* var = varDecl.fVar;
2690 // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2691 // in the OpenGL backend.
2692 SkASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2693 Modifiers::kWriteOnly_Flag |
2694 Modifiers::kCoherent_Flag |
2695 Modifiers::kVolatile_Flag |
2696 Modifiers::kRestrict_Flag)));
2697 if (var->fModifiers.fLayout.fBuiltin == BUILTIN_IGNORE) {
2698 continue;
2699 }
2700 if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
2701 kind != Program::kFragment_Kind) {
2702 SkASSERT(!fProgram.fSettings.fFragColorIsInOut);
2703 continue;
2704 }
2705 if (!var->fReadCount && !var->fWriteCount &&
2706 !(var->fModifiers.fFlags & (Modifiers::kIn_Flag |
2707 Modifiers::kOut_Flag |
2708 Modifiers::kUniform_Flag |
2709 Modifiers::kBuffer_Flag))) {
2710 // variable is dead and not an input / output var (the Vulkan debug layers complain if
2711 // we elide an interface var, even if it's dead)
2712 continue;
2713 }
2714 SpvStorageClass_ storageClass;
2715 if (var->fModifiers.fFlags & Modifiers::kIn_Flag) {
2716 storageClass = SpvStorageClassInput;
2717 } else if (var->fModifiers.fFlags & Modifiers::kOut_Flag) {
2718 storageClass = SpvStorageClassOutput;
2719 } else if (var->fModifiers.fFlags & Modifiers::kUniform_Flag) {
2720 if (var->fType.kind() == Type::kSampler_Kind) {
2721 storageClass = SpvStorageClassUniformConstant;
2722 } else {
2723 storageClass = SpvStorageClassUniform;
2724 }
2725 } else {
2726 storageClass = SpvStorageClassPrivate;
2727 }
2728 SpvId id = this->nextId();
2729 fVariableMap[var] = id;
2730 SpvId type;
2731 if (var->fModifiers.fLayout.fBuiltin == SK_IN_BUILTIN) {
2732 type = this->getPointerType(Type("sk_in", Type::kArray_Kind,
2733 var->fType.componentType(), fSkInCount),
2734 storageClass);
2735 } else {
2736 type = this->getPointerType(var->fType, storageClass);
2737 }
2738 this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer);
2739 this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
2740 this->writePrecisionModifier(var->fModifiers, id);
2741 if (varDecl.fValue) {
2742 SkASSERT(!fCurrentBlock);
2743 fCurrentBlock = -1;
2744 SpvId value = this->writeExpression(*varDecl.fValue, fGlobalInitializersBuffer);
2745 this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
2746 fCurrentBlock = 0;
2747 }
2748 this->writeLayout(var->fModifiers.fLayout, id);
2749 if (var->fModifiers.fFlags & Modifiers::kFlat_Flag) {
2750 this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
2751 }
2752 if (var->fModifiers.fFlags & Modifiers::kNoPerspective_Flag) {
2753 this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
2754 fDecorationBuffer);
2755 }
2756 }
2757 }
2758
writeVarDeclarations(const VarDeclarations & decl,OutputStream & out)2759 void SPIRVCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, OutputStream& out) {
2760 for (const auto& stmt : decl.fVars) {
2761 SkASSERT(stmt->fKind == Statement::kVarDeclaration_Kind);
2762 VarDeclaration& varDecl = (VarDeclaration&) *stmt;
2763 const Variable* var = varDecl.fVar;
2764 // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2765 // in the OpenGL backend.
2766 SkASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2767 Modifiers::kWriteOnly_Flag |
2768 Modifiers::kCoherent_Flag |
2769 Modifiers::kVolatile_Flag |
2770 Modifiers::kRestrict_Flag)));
2771 SpvId id = this->nextId();
2772 fVariableMap[var] = id;
2773 SpvId type = this->getPointerType(var->fType, SpvStorageClassFunction);
2774 this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
2775 this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
2776 if (varDecl.fValue) {
2777 SpvId value = this->writeExpression(*varDecl.fValue, out);
2778 this->writeInstruction(SpvOpStore, id, value, out);
2779 }
2780 }
2781 }
2782
writeStatement(const Statement & s,OutputStream & out)2783 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
2784 switch (s.fKind) {
2785 case Statement::kNop_Kind:
2786 break;
2787 case Statement::kBlock_Kind:
2788 this->writeBlock((Block&) s, out);
2789 break;
2790 case Statement::kExpression_Kind:
2791 this->writeExpression(*((ExpressionStatement&) s).fExpression, out);
2792 break;
2793 case Statement::kReturn_Kind:
2794 this->writeReturnStatement((ReturnStatement&) s, out);
2795 break;
2796 case Statement::kVarDeclarations_Kind:
2797 this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, out);
2798 break;
2799 case Statement::kIf_Kind:
2800 this->writeIfStatement((IfStatement&) s, out);
2801 break;
2802 case Statement::kFor_Kind:
2803 this->writeForStatement((ForStatement&) s, out);
2804 break;
2805 case Statement::kWhile_Kind:
2806 this->writeWhileStatement((WhileStatement&) s, out);
2807 break;
2808 case Statement::kDo_Kind:
2809 this->writeDoStatement((DoStatement&) s, out);
2810 break;
2811 case Statement::kSwitch_Kind:
2812 this->writeSwitchStatement((SwitchStatement&) s, out);
2813 break;
2814 case Statement::kBreak_Kind:
2815 this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
2816 break;
2817 case Statement::kContinue_Kind:
2818 this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
2819 break;
2820 case Statement::kDiscard_Kind:
2821 this->writeInstruction(SpvOpKill, out);
2822 break;
2823 default:
2824 ABORT("unsupported statement: %s", s.description().c_str());
2825 }
2826 }
2827
writeBlock(const Block & b,OutputStream & out)2828 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
2829 for (size_t i = 0; i < b.fStatements.size(); i++) {
2830 this->writeStatement(*b.fStatements[i], out);
2831 }
2832 }
2833
writeIfStatement(const IfStatement & stmt,OutputStream & out)2834 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
2835 SpvId test = this->writeExpression(*stmt.fTest, out);
2836 SpvId ifTrue = this->nextId();
2837 SpvId ifFalse = this->nextId();
2838 if (stmt.fIfFalse) {
2839 SpvId end = this->nextId();
2840 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2841 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2842 this->writeLabel(ifTrue, out);
2843 this->writeStatement(*stmt.fIfTrue, out);
2844 if (fCurrentBlock) {
2845 this->writeInstruction(SpvOpBranch, end, out);
2846 }
2847 this->writeLabel(ifFalse, out);
2848 this->writeStatement(*stmt.fIfFalse, out);
2849 if (fCurrentBlock) {
2850 this->writeInstruction(SpvOpBranch, end, out);
2851 }
2852 this->writeLabel(end, out);
2853 } else {
2854 this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
2855 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2856 this->writeLabel(ifTrue, out);
2857 this->writeStatement(*stmt.fIfTrue, out);
2858 if (fCurrentBlock) {
2859 this->writeInstruction(SpvOpBranch, ifFalse, out);
2860 }
2861 this->writeLabel(ifFalse, out);
2862 }
2863 }
2864
writeForStatement(const ForStatement & f,OutputStream & out)2865 void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
2866 if (f.fInitializer) {
2867 this->writeStatement(*f.fInitializer, out);
2868 }
2869 SpvId header = this->nextId();
2870 SpvId start = this->nextId();
2871 SpvId body = this->nextId();
2872 SpvId next = this->nextId();
2873 fContinueTarget.push(next);
2874 SpvId end = this->nextId();
2875 fBreakTarget.push(end);
2876 this->writeInstruction(SpvOpBranch, header, out);
2877 this->writeLabel(header, out);
2878 this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
2879 this->writeInstruction(SpvOpBranch, start, out);
2880 this->writeLabel(start, out);
2881 if (f.fTest) {
2882 SpvId test = this->writeExpression(*f.fTest, out);
2883 this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2884 }
2885 this->writeLabel(body, out);
2886 this->writeStatement(*f.fStatement, out);
2887 if (fCurrentBlock) {
2888 this->writeInstruction(SpvOpBranch, next, out);
2889 }
2890 this->writeLabel(next, out);
2891 if (f.fNext) {
2892 this->writeExpression(*f.fNext, out);
2893 }
2894 this->writeInstruction(SpvOpBranch, header, out);
2895 this->writeLabel(end, out);
2896 fBreakTarget.pop();
2897 fContinueTarget.pop();
2898 }
2899
writeWhileStatement(const WhileStatement & w,OutputStream & out)2900 void SPIRVCodeGenerator::writeWhileStatement(const WhileStatement& w, OutputStream& out) {
2901 // We believe the while loop code below will work, but Skia doesn't actually use them and
2902 // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2903 // the time being, we just fail with an error due to the lack of testing. If you encounter this
2904 // message, simply remove the error call below to see whether our while loop support actually
2905 // works.
2906 fErrors.error(w.fOffset, "internal error: while loop support has been disabled in SPIR-V, "
2907 "see SkSLSPIRVCodeGenerator.cpp for details");
2908
2909 SpvId header = this->nextId();
2910 SpvId start = this->nextId();
2911 SpvId body = this->nextId();
2912 fContinueTarget.push(start);
2913 SpvId end = this->nextId();
2914 fBreakTarget.push(end);
2915 this->writeInstruction(SpvOpBranch, header, out);
2916 this->writeLabel(header, out);
2917 this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2918 this->writeInstruction(SpvOpBranch, start, out);
2919 this->writeLabel(start, out);
2920 SpvId test = this->writeExpression(*w.fTest, out);
2921 this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2922 this->writeLabel(body, out);
2923 this->writeStatement(*w.fStatement, out);
2924 if (fCurrentBlock) {
2925 this->writeInstruction(SpvOpBranch, start, out);
2926 }
2927 this->writeLabel(end, out);
2928 fBreakTarget.pop();
2929 fContinueTarget.pop();
2930 }
2931
writeDoStatement(const DoStatement & d,OutputStream & out)2932 void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
2933 // We believe the do loop code below will work, but Skia doesn't actually use them and
2934 // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2935 // the time being, we just fail with an error due to the lack of testing. If you encounter this
2936 // message, simply remove the error call below to see whether our do loop support actually
2937 // works.
2938 fErrors.error(d.fOffset, "internal error: do loop support has been disabled in SPIR-V, see "
2939 "SkSLSPIRVCodeGenerator.cpp for details");
2940
2941 SpvId header = this->nextId();
2942 SpvId start = this->nextId();
2943 SpvId next = this->nextId();
2944 fContinueTarget.push(next);
2945 SpvId end = this->nextId();
2946 fBreakTarget.push(end);
2947 this->writeInstruction(SpvOpBranch, header, out);
2948 this->writeLabel(header, out);
2949 this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2950 this->writeInstruction(SpvOpBranch, start, out);
2951 this->writeLabel(start, out);
2952 this->writeStatement(*d.fStatement, out);
2953 if (fCurrentBlock) {
2954 this->writeInstruction(SpvOpBranch, next, out);
2955 }
2956 this->writeLabel(next, out);
2957 SpvId test = this->writeExpression(*d.fTest, out);
2958 this->writeInstruction(SpvOpBranchConditional, test, start, end, out);
2959 this->writeLabel(end, out);
2960 fBreakTarget.pop();
2961 fContinueTarget.pop();
2962 }
2963
writeSwitchStatement(const SwitchStatement & s,OutputStream & out)2964 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
2965 SpvId value = this->writeExpression(*s.fValue, out);
2966 std::vector<SpvId> labels;
2967 SpvId end = this->nextId();
2968 SpvId defaultLabel = end;
2969 fBreakTarget.push(end);
2970 int size = 3;
2971 for (const auto& c : s.fCases) {
2972 SpvId label = this->nextId();
2973 labels.push_back(label);
2974 if (c->fValue) {
2975 size += 2;
2976 } else {
2977 defaultLabel = label;
2978 }
2979 }
2980 labels.push_back(end);
2981 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2982 this->writeOpCode(SpvOpSwitch, size, out);
2983 this->writeWord(value, out);
2984 this->writeWord(defaultLabel, out);
2985 for (size_t i = 0; i < s.fCases.size(); ++i) {
2986 if (!s.fCases[i]->fValue) {
2987 continue;
2988 }
2989 SkASSERT(s.fCases[i]->fValue->fKind == Expression::kIntLiteral_Kind);
2990 this->writeWord(((IntLiteral&) *s.fCases[i]->fValue).fValue, out);
2991 this->writeWord(labels[i], out);
2992 }
2993 for (size_t i = 0; i < s.fCases.size(); ++i) {
2994 this->writeLabel(labels[i], out);
2995 for (const auto& stmt : s.fCases[i]->fStatements) {
2996 this->writeStatement(*stmt, out);
2997 }
2998 if (fCurrentBlock) {
2999 this->writeInstruction(SpvOpBranch, labels[i + 1], out);
3000 }
3001 }
3002 this->writeLabel(end, out);
3003 fBreakTarget.pop();
3004 }
3005
writeReturnStatement(const ReturnStatement & r,OutputStream & out)3006 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
3007 if (r.fExpression) {
3008 this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out),
3009 out);
3010 } else {
3011 this->writeInstruction(SpvOpReturn, out);
3012 }
3013 }
3014
writeGeometryShaderExecutionMode(SpvId entryPoint,OutputStream & out)3015 void SPIRVCodeGenerator::writeGeometryShaderExecutionMode(SpvId entryPoint, OutputStream& out) {
3016 SkASSERT(fProgram.fKind == Program::kGeometry_Kind);
3017 int invocations = 1;
3018 for (const auto& e : fProgram) {
3019 if (e.fKind == ProgramElement::kModifiers_Kind) {
3020 const Modifiers& m = ((ModifiersDeclaration&) e).fModifiers;
3021 if (m.fFlags & Modifiers::kIn_Flag) {
3022 if (m.fLayout.fInvocations != -1) {
3023 invocations = m.fLayout.fInvocations;
3024 }
3025 SpvId input;
3026 switch (m.fLayout.fPrimitive) {
3027 case Layout::kPoints_Primitive:
3028 input = SpvExecutionModeInputPoints;
3029 break;
3030 case Layout::kLines_Primitive:
3031 input = SpvExecutionModeInputLines;
3032 break;
3033 case Layout::kLinesAdjacency_Primitive:
3034 input = SpvExecutionModeInputLinesAdjacency;
3035 break;
3036 case Layout::kTriangles_Primitive:
3037 input = SpvExecutionModeTriangles;
3038 break;
3039 case Layout::kTrianglesAdjacency_Primitive:
3040 input = SpvExecutionModeInputTrianglesAdjacency;
3041 break;
3042 default:
3043 input = 0;
3044 break;
3045 }
3046 update_sk_in_count(m, &fSkInCount);
3047 if (input) {
3048 this->writeInstruction(SpvOpExecutionMode, entryPoint, input, out);
3049 }
3050 } else if (m.fFlags & Modifiers::kOut_Flag) {
3051 SpvId output;
3052 switch (m.fLayout.fPrimitive) {
3053 case Layout::kPoints_Primitive:
3054 output = SpvExecutionModeOutputPoints;
3055 break;
3056 case Layout::kLineStrip_Primitive:
3057 output = SpvExecutionModeOutputLineStrip;
3058 break;
3059 case Layout::kTriangleStrip_Primitive:
3060 output = SpvExecutionModeOutputTriangleStrip;
3061 break;
3062 default:
3063 output = 0;
3064 break;
3065 }
3066 if (output) {
3067 this->writeInstruction(SpvOpExecutionMode, entryPoint, output, out);
3068 }
3069 if (m.fLayout.fMaxVertices != -1) {
3070 this->writeInstruction(SpvOpExecutionMode, entryPoint,
3071 SpvExecutionModeOutputVertices, m.fLayout.fMaxVertices,
3072 out);
3073 }
3074 }
3075 }
3076 }
3077 this->writeInstruction(SpvOpExecutionMode, entryPoint, SpvExecutionModeInvocations,
3078 invocations, out);
3079 }
3080
writeInstructions(const Program & program,OutputStream & out)3081 void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
3082 fGLSLExtendedInstructions = this->nextId();
3083 StringStream body;
3084 std::set<SpvId> interfaceVars;
3085 // assign IDs to functions, determine sk_in size
3086 int skInSize = -1;
3087 for (const auto& e : program) {
3088 switch (e.fKind) {
3089 case ProgramElement::kFunction_Kind: {
3090 FunctionDefinition& f = (FunctionDefinition&) e;
3091 fFunctionMap[&f.fDeclaration] = this->nextId();
3092 break;
3093 }
3094 case ProgramElement::kModifiers_Kind: {
3095 Modifiers& m = ((ModifiersDeclaration&) e).fModifiers;
3096 if (m.fFlags & Modifiers::kIn_Flag) {
3097 switch (m.fLayout.fPrimitive) {
3098 case Layout::kPoints_Primitive: // break
3099 case Layout::kLines_Primitive:
3100 skInSize = 1;
3101 break;
3102 case Layout::kLinesAdjacency_Primitive: // break
3103 skInSize = 2;
3104 break;
3105 case Layout::kTriangles_Primitive: // break
3106 case Layout::kTrianglesAdjacency_Primitive:
3107 skInSize = 3;
3108 break;
3109 default:
3110 break;
3111 }
3112 }
3113 break;
3114 }
3115 default:
3116 break;
3117 }
3118 }
3119 for (const auto& e : program) {
3120 if (e.fKind == ProgramElement::kInterfaceBlock_Kind) {
3121 InterfaceBlock& intf = (InterfaceBlock&) e;
3122 if (SK_IN_BUILTIN == intf.fVariable.fModifiers.fLayout.fBuiltin) {
3123 SkASSERT(skInSize != -1);
3124 intf.fSizes.emplace_back(new IntLiteral(fContext, -1, skInSize));
3125 }
3126 SpvId id = this->writeInterfaceBlock(intf);
3127 if (((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) ||
3128 (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) &&
3129 intf.fVariable.fModifiers.fLayout.fBuiltin == -1) {
3130 interfaceVars.insert(id);
3131 }
3132 }
3133 }
3134 for (const auto& e : program) {
3135 if (e.fKind == ProgramElement::kVar_Kind) {
3136 this->writeGlobalVars(program.fKind, ((VarDeclarations&) e), body);
3137 }
3138 }
3139 for (const auto& e : program) {
3140 if (e.fKind == ProgramElement::kFunction_Kind) {
3141 this->writeFunction(((FunctionDefinition&) e), body);
3142 }
3143 }
3144 const FunctionDeclaration* main = nullptr;
3145 for (auto entry : fFunctionMap) {
3146 if (entry.first->fName == "main") {
3147 main = entry.first;
3148 }
3149 }
3150 SkASSERT(main);
3151 for (auto entry : fVariableMap) {
3152 const Variable* var = entry.first;
3153 if (var->fStorage == Variable::kGlobal_Storage &&
3154 ((var->fModifiers.fFlags & Modifiers::kIn_Flag) ||
3155 (var->fModifiers.fFlags & Modifiers::kOut_Flag))) {
3156 interfaceVars.insert(entry.second);
3157 }
3158 }
3159 this->writeCapabilities(out);
3160 this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
3161 this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
3162 this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (main->fName.fLength + 4) / 4) +
3163 (int32_t) interfaceVars.size(), out);
3164 switch (program.fKind) {
3165 case Program::kVertex_Kind:
3166 this->writeWord(SpvExecutionModelVertex, out);
3167 break;
3168 case Program::kFragment_Kind:
3169 this->writeWord(SpvExecutionModelFragment, out);
3170 break;
3171 case Program::kGeometry_Kind:
3172 this->writeWord(SpvExecutionModelGeometry, out);
3173 break;
3174 default:
3175 ABORT("cannot write this kind of program to SPIR-V\n");
3176 }
3177 SpvId entryPoint = fFunctionMap[main];
3178 this->writeWord(entryPoint, out);
3179 this->writeString(main->fName.fChars, main->fName.fLength, out);
3180 for (int var : interfaceVars) {
3181 this->writeWord(var, out);
3182 }
3183 if (program.fKind == Program::kGeometry_Kind) {
3184 this->writeGeometryShaderExecutionMode(entryPoint, out);
3185 }
3186 if (program.fKind == Program::kFragment_Kind) {
3187 this->writeInstruction(SpvOpExecutionMode,
3188 fFunctionMap[main],
3189 SpvExecutionModeOriginUpperLeft,
3190 out);
3191 }
3192 for (const auto& e : program) {
3193 if (e.fKind == ProgramElement::kExtension_Kind) {
3194 this->writeInstruction(SpvOpSourceExtension, ((Extension&) e).fName.c_str(), out);
3195 }
3196 }
3197
3198 write_stringstream(fExtraGlobalsBuffer, out);
3199 write_stringstream(fNameBuffer, out);
3200 write_stringstream(fDecorationBuffer, out);
3201 write_stringstream(fConstantBuffer, out);
3202 write_stringstream(fExternalFunctionsBuffer, out);
3203 write_stringstream(body, out);
3204 }
3205
generateCode()3206 bool SPIRVCodeGenerator::generateCode() {
3207 SkASSERT(!fErrors.errorCount());
3208 this->writeWord(SpvMagicNumber, *fOut);
3209 this->writeWord(SpvVersion, *fOut);
3210 this->writeWord(SKSL_MAGIC, *fOut);
3211 StringStream buffer;
3212 this->writeInstructions(fProgram, buffer);
3213 this->writeWord(fIdCount, *fOut);
3214 this->writeWord(0, *fOut); // reserved, always zero
3215 write_stringstream(buffer, *fOut);
3216 return 0 == fErrors.errorCount();
3217 }
3218
3219 }
3220