1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "SkSLSPIRVCodeGenerator.h"
9
10 #include "GLSL.std.450.h"
11
12 #include "ir/SkSLExpressionStatement.h"
13 #include "ir/SkSLExtension.h"
14 #include "ir/SkSLIndexExpression.h"
15 #include "ir/SkSLVariableReference.h"
16 #include "SkSLCompiler.h"
17
18 namespace SkSL {
19
// Magic number identifying this code generator. The code that consumes it is outside this
// chunk; presumably it is written into the SPIR-V module header — confirm.
// FIXME: we should probably register a magic number
static constexpr int32_t SKSL_MAGIC = 0x0;
21
setupIntrinsics()22 void SPIRVCodeGenerator::setupIntrinsics() {
23 #define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \
24 GLSLstd450 ## x, GLSLstd450 ## x)
25 #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \
26 GLSLstd450 ## ifFloat, \
27 GLSLstd450 ## ifInt, \
28 GLSLstd450 ## ifUInt, \
29 SpvOpUndef)
30 #define ALL_SPIRV(x) std::make_tuple(kSPIRV_IntrinsicKind, SpvOp ## x, SpvOp ## x, SpvOp ## x, \
31 SpvOp ## x)
32 #define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \
33 k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
34 k ## x ## _SpecialIntrinsic)
35 fIntrinsicMap[String("round")] = ALL_GLSL(Round);
36 fIntrinsicMap[String("roundEven")] = ALL_GLSL(RoundEven);
37 fIntrinsicMap[String("trunc")] = ALL_GLSL(Trunc);
38 fIntrinsicMap[String("abs")] = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
39 fIntrinsicMap[String("sign")] = BY_TYPE_GLSL(FSign, SSign, SSign);
40 fIntrinsicMap[String("floor")] = ALL_GLSL(Floor);
41 fIntrinsicMap[String("ceil")] = ALL_GLSL(Ceil);
42 fIntrinsicMap[String("fract")] = ALL_GLSL(Fract);
43 fIntrinsicMap[String("radians")] = ALL_GLSL(Radians);
44 fIntrinsicMap[String("degrees")] = ALL_GLSL(Degrees);
45 fIntrinsicMap[String("sin")] = ALL_GLSL(Sin);
46 fIntrinsicMap[String("cos")] = ALL_GLSL(Cos);
47 fIntrinsicMap[String("tan")] = ALL_GLSL(Tan);
48 fIntrinsicMap[String("asin")] = ALL_GLSL(Asin);
49 fIntrinsicMap[String("acos")] = ALL_GLSL(Acos);
50 fIntrinsicMap[String("atan")] = SPECIAL(Atan);
51 fIntrinsicMap[String("sinh")] = ALL_GLSL(Sinh);
52 fIntrinsicMap[String("cosh")] = ALL_GLSL(Cosh);
53 fIntrinsicMap[String("tanh")] = ALL_GLSL(Tanh);
54 fIntrinsicMap[String("asinh")] = ALL_GLSL(Asinh);
55 fIntrinsicMap[String("acosh")] = ALL_GLSL(Acosh);
56 fIntrinsicMap[String("atanh")] = ALL_GLSL(Atanh);
57 fIntrinsicMap[String("pow")] = ALL_GLSL(Pow);
58 fIntrinsicMap[String("exp")] = ALL_GLSL(Exp);
59 fIntrinsicMap[String("log")] = ALL_GLSL(Log);
60 fIntrinsicMap[String("exp2")] = ALL_GLSL(Exp2);
61 fIntrinsicMap[String("log2")] = ALL_GLSL(Log2);
62 fIntrinsicMap[String("sqrt")] = ALL_GLSL(Sqrt);
63 fIntrinsicMap[String("inverse")] = ALL_GLSL(MatrixInverse);
64 fIntrinsicMap[String("transpose")] = ALL_SPIRV(Transpose);
65 fIntrinsicMap[String("inversesqrt")] = ALL_GLSL(InverseSqrt);
66 fIntrinsicMap[String("determinant")] = ALL_GLSL(Determinant);
67 fIntrinsicMap[String("matrixInverse")] = ALL_GLSL(MatrixInverse);
68 fIntrinsicMap[String("mod")] = SPECIAL(Mod);
69 fIntrinsicMap[String("min")] = SPECIAL(Min);
70 fIntrinsicMap[String("max")] = SPECIAL(Max);
71 fIntrinsicMap[String("clamp")] = SPECIAL(Clamp);
72 fIntrinsicMap[String("dot")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot,
73 SpvOpUndef, SpvOpUndef, SpvOpUndef);
74 fIntrinsicMap[String("mix")] = SPECIAL(Mix);
75 fIntrinsicMap[String("step")] = ALL_GLSL(Step);
76 fIntrinsicMap[String("smoothstep")] = ALL_GLSL(SmoothStep);
77 fIntrinsicMap[String("fma")] = ALL_GLSL(Fma);
78 fIntrinsicMap[String("frexp")] = ALL_GLSL(Frexp);
79 fIntrinsicMap[String("ldexp")] = ALL_GLSL(Ldexp);
80
81 #define PACK(type) fIntrinsicMap[String("pack" #type)] = ALL_GLSL(Pack ## type); \
82 fIntrinsicMap[String("unpack" #type)] = ALL_GLSL(Unpack ## type)
83 PACK(Snorm4x8);
84 PACK(Unorm4x8);
85 PACK(Snorm2x16);
86 PACK(Unorm2x16);
87 PACK(Half2x16);
88 PACK(Double2x32);
89 fIntrinsicMap[String("length")] = ALL_GLSL(Length);
90 fIntrinsicMap[String("distance")] = ALL_GLSL(Distance);
91 fIntrinsicMap[String("cross")] = ALL_GLSL(Cross);
92 fIntrinsicMap[String("normalize")] = ALL_GLSL(Normalize);
93 fIntrinsicMap[String("faceForward")] = ALL_GLSL(FaceForward);
94 fIntrinsicMap[String("reflect")] = ALL_GLSL(Reflect);
95 fIntrinsicMap[String("refract")] = ALL_GLSL(Refract);
96 fIntrinsicMap[String("findLSB")] = ALL_GLSL(FindILsb);
97 fIntrinsicMap[String("findMSB")] = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
98 fIntrinsicMap[String("dFdx")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx,
99 SpvOpUndef, SpvOpUndef, SpvOpUndef);
100 fIntrinsicMap[String("dFdy")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
101 SpvOpUndef, SpvOpUndef, SpvOpUndef);
102 fIntrinsicMap[String("dFdy")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
103 SpvOpUndef, SpvOpUndef, SpvOpUndef);
104 fIntrinsicMap[String("texture")] = SPECIAL(Texture);
105 fIntrinsicMap[String("texelFetch")] = SPECIAL(TexelFetch);
106 fIntrinsicMap[String("subpassLoad")] = SPECIAL(SubpassLoad);
107
108 fIntrinsicMap[String("any")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
109 SpvOpUndef, SpvOpUndef, SpvOpAny);
110 fIntrinsicMap[String("all")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
111 SpvOpUndef, SpvOpUndef, SpvOpAll);
112 fIntrinsicMap[String("equal")] = std::make_tuple(kSPIRV_IntrinsicKind,
113 SpvOpFOrdEqual, SpvOpIEqual,
114 SpvOpIEqual, SpvOpLogicalEqual);
115 fIntrinsicMap[String("notEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
116 SpvOpFOrdNotEqual, SpvOpINotEqual,
117 SpvOpINotEqual,
118 SpvOpLogicalNotEqual);
119 fIntrinsicMap[String("lessThan")] = std::make_tuple(kSPIRV_IntrinsicKind,
120 SpvOpFOrdLessThan, SpvOpSLessThan,
121 SpvOpULessThan, SpvOpUndef);
122 fIntrinsicMap[String("lessThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
123 SpvOpFOrdLessThanEqual,
124 SpvOpSLessThanEqual,
125 SpvOpULessThanEqual,
126 SpvOpUndef);
127 fIntrinsicMap[String("greaterThan")] = std::make_tuple(kSPIRV_IntrinsicKind,
128 SpvOpFOrdGreaterThan,
129 SpvOpSGreaterThan,
130 SpvOpUGreaterThan,
131 SpvOpUndef);
132 fIntrinsicMap[String("greaterThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
133 SpvOpFOrdGreaterThanEqual,
134 SpvOpSGreaterThanEqual,
135 SpvOpUGreaterThanEqual,
136 SpvOpUndef);
137 fIntrinsicMap[String("EmitVertex")] = ALL_SPIRV(EmitVertex);
138 fIntrinsicMap[String("EndPrimitive")] = ALL_SPIRV(EndPrimitive);
139 // interpolateAt* not yet supported...
140 }
141
writeWord(int32_t word,OutputStream & out)142 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
143 out.write((const char*) &word, sizeof(word));
144 }
145
is_float(const Context & context,const Type & type)146 static bool is_float(const Context& context, const Type& type) {
147 if (type.kind() == Type::kVector_Kind) {
148 return is_float(context, type.componentType());
149 }
150 return type == *context.fFloat_Type || type == *context.fHalf_Type ||
151 type == *context.fDouble_Type;
152 }
153
is_signed(const Context & context,const Type & type)154 static bool is_signed(const Context& context, const Type& type) {
155 if (type.kind() == Type::kVector_Kind) {
156 return is_signed(context, type.componentType());
157 }
158 return type == *context.fInt_Type || type == *context.fShort_Type;
159 }
160
is_unsigned(const Context & context,const Type & type)161 static bool is_unsigned(const Context& context, const Type& type) {
162 if (type.kind() == Type::kVector_Kind) {
163 return is_unsigned(context, type.componentType());
164 }
165 return type == *context.fUInt_Type || type == *context.fUShort_Type;
166 }
167
is_bool(const Context & context,const Type & type)168 static bool is_bool(const Context& context, const Type& type) {
169 if (type.kind() == Type::kVector_Kind) {
170 return is_bool(context, type.componentType());
171 }
172 return type == *context.fBool_Type;
173 }
174
is_out(const Variable & var)175 static bool is_out(const Variable& var) {
176 return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
177 }
178
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)179 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
180 ASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
181 ASSERT(opCode != SpvOpUndef);
182 switch (opCode) {
183 case SpvOpReturn: // fall through
184 case SpvOpReturnValue: // fall through
185 case SpvOpKill: // fall through
186 case SpvOpBranch: // fall through
187 case SpvOpBranchConditional:
188 ASSERT(fCurrentBlock);
189 fCurrentBlock = 0;
190 break;
191 case SpvOpConstant: // fall through
192 case SpvOpConstantTrue: // fall through
193 case SpvOpConstantFalse: // fall through
194 case SpvOpConstantComposite: // fall through
195 case SpvOpTypeVoid: // fall through
196 case SpvOpTypeInt: // fall through
197 case SpvOpTypeFloat: // fall through
198 case SpvOpTypeBool: // fall through
199 case SpvOpTypeVector: // fall through
200 case SpvOpTypeMatrix: // fall through
201 case SpvOpTypeArray: // fall through
202 case SpvOpTypePointer: // fall through
203 case SpvOpTypeFunction: // fall through
204 case SpvOpTypeRuntimeArray: // fall through
205 case SpvOpTypeStruct: // fall through
206 case SpvOpTypeImage: // fall through
207 case SpvOpTypeSampledImage: // fall through
208 case SpvOpVariable: // fall through
209 case SpvOpFunction: // fall through
210 case SpvOpFunctionParameter: // fall through
211 case SpvOpFunctionEnd: // fall through
212 case SpvOpExecutionMode: // fall through
213 case SpvOpMemoryModel: // fall through
214 case SpvOpCapability: // fall through
215 case SpvOpExtInstImport: // fall through
216 case SpvOpEntryPoint: // fall through
217 case SpvOpSource: // fall through
218 case SpvOpSourceExtension: // fall through
219 case SpvOpName: // fall through
220 case SpvOpMemberName: // fall through
221 case SpvOpDecorate: // fall through
222 case SpvOpMemberDecorate:
223 break;
224 default:
225 ASSERT(fCurrentBlock);
226 }
227 this->writeWord((length << 16) | opCode, out);
228 }
229
writeLabel(SpvId label,OutputStream & out)230 void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) {
231 fCurrentBlock = label;
232 this->writeInstruction(SpvOpLabel, label, out);
233 }
234
writeInstruction(SpvOp_ opCode,OutputStream & out)235 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
236 this->writeOpCode(opCode, 1, out);
237 }
238
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)239 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
240 this->writeOpCode(opCode, 2, out);
241 this->writeWord(word1, out);
242 }
243
writeString(const char * string,size_t length,OutputStream & out)244 void SPIRVCodeGenerator::writeString(const char* string, size_t length, OutputStream& out) {
245 out.write(string, length);
246 switch (length % 4) {
247 case 1:
248 out.write8(0);
249 // fall through
250 case 2:
251 out.write8(0);
252 // fall through
253 case 3:
254 out.write8(0);
255 break;
256 default:
257 this->writeWord(0, out);
258 }
259 }
260
writeInstruction(SpvOp_ opCode,StringFragment string,OutputStream & out)261 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, StringFragment string, OutputStream& out) {
262 this->writeOpCode(opCode, 1 + (string.fLength + 4) / 4, out);
263 this->writeString(string.fChars, string.fLength, out);
264 }
265
266
writeInstruction(SpvOp_ opCode,int32_t word1,StringFragment string,OutputStream & out)267 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, StringFragment string,
268 OutputStream& out) {
269 this->writeOpCode(opCode, 2 + (string.fLength + 4) / 4, out);
270 this->writeWord(word1, out);
271 this->writeString(string.fChars, string.fLength, out);
272 }
273
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,StringFragment string,OutputStream & out)274 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
275 StringFragment string, OutputStream& out) {
276 this->writeOpCode(opCode, 3 + (string.fLength + 4) / 4, out);
277 this->writeWord(word1, out);
278 this->writeWord(word2, out);
279 this->writeString(string.fChars, string.fLength, out);
280 }
281
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)282 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
283 OutputStream& out) {
284 this->writeOpCode(opCode, 3, out);
285 this->writeWord(word1, out);
286 this->writeWord(word2, out);
287 }
288
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)289 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
290 int32_t word3, OutputStream& out) {
291 this->writeOpCode(opCode, 4, out);
292 this->writeWord(word1, out);
293 this->writeWord(word2, out);
294 this->writeWord(word3, out);
295 }
296
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)297 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
298 int32_t word3, int32_t word4, OutputStream& out) {
299 this->writeOpCode(opCode, 5, out);
300 this->writeWord(word1, out);
301 this->writeWord(word2, out);
302 this->writeWord(word3, out);
303 this->writeWord(word4, out);
304 }
305
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)306 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
307 int32_t word3, int32_t word4, int32_t word5,
308 OutputStream& out) {
309 this->writeOpCode(opCode, 6, out);
310 this->writeWord(word1, out);
311 this->writeWord(word2, out);
312 this->writeWord(word3, out);
313 this->writeWord(word4, out);
314 this->writeWord(word5, out);
315 }
316
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)317 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
318 int32_t word3, int32_t word4, int32_t word5,
319 int32_t word6, OutputStream& out) {
320 this->writeOpCode(opCode, 7, out);
321 this->writeWord(word1, out);
322 this->writeWord(word2, out);
323 this->writeWord(word3, out);
324 this->writeWord(word4, out);
325 this->writeWord(word5, out);
326 this->writeWord(word6, out);
327 }
328
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)329 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
330 int32_t word3, int32_t word4, int32_t word5,
331 int32_t word6, int32_t word7, OutputStream& out) {
332 this->writeOpCode(opCode, 8, out);
333 this->writeWord(word1, out);
334 this->writeWord(word2, out);
335 this->writeWord(word3, out);
336 this->writeWord(word4, out);
337 this->writeWord(word5, out);
338 this->writeWord(word6, out);
339 this->writeWord(word7, out);
340 }
341
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)342 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
343 int32_t word3, int32_t word4, int32_t word5,
344 int32_t word6, int32_t word7, int32_t word8,
345 OutputStream& out) {
346 this->writeOpCode(opCode, 9, out);
347 this->writeWord(word1, out);
348 this->writeWord(word2, out);
349 this->writeWord(word3, out);
350 this->writeWord(word4, out);
351 this->writeWord(word5, out);
352 this->writeWord(word6, out);
353 this->writeWord(word7, out);
354 this->writeWord(word8, out);
355 }
356
writeCapabilities(OutputStream & out)357 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
358 for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
359 if (fCapabilities & bit) {
360 this->writeInstruction(SpvOpCapability, (SpvId) i, out);
361 }
362 }
363 if (fProgram.fKind == Program::kGeometry_Kind) {
364 this->writeInstruction(SpvOpCapability, SpvCapabilityGeometry, out);
365 }
366 }
367
nextId()368 SpvId SPIRVCodeGenerator::nextId() {
369 return fIdCount++;
370 }
371
writeStruct(const Type & type,const MemoryLayout & memoryLayout,SpvId resultId)372 void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout,
373 SpvId resultId) {
374 this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer);
375 // go ahead and write all of the field types, so we don't inadvertently write them while we're
376 // in the middle of writing the struct instruction
377 std::vector<SpvId> types;
378 for (const auto& f : type.fields()) {
379 types.push_back(this->getType(*f.fType, memoryLayout));
380 }
381 this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
382 this->writeWord(resultId, fConstantBuffer);
383 for (SpvId id : types) {
384 this->writeWord(id, fConstantBuffer);
385 }
386 size_t offset = 0;
387 for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
388 size_t size = memoryLayout.size(*type.fields()[i].fType);
389 size_t alignment = memoryLayout.alignment(*type.fields()[i].fType);
390 const Layout& fieldLayout = type.fields()[i].fModifiers.fLayout;
391 if (fieldLayout.fOffset >= 0) {
392 if (fieldLayout.fOffset < (int) offset) {
393 fErrors.error(type.fOffset,
394 "offset of field '" + type.fields()[i].fName + "' must be at "
395 "least " + to_string((int) offset));
396 }
397 if (fieldLayout.fOffset % alignment) {
398 fErrors.error(type.fOffset,
399 "offset of field '" + type.fields()[i].fName + "' must be a multiple"
400 " of " + to_string((int) alignment));
401 }
402 offset = fieldLayout.fOffset;
403 } else {
404 size_t mod = offset % alignment;
405 if (mod) {
406 offset += alignment - mod;
407 }
408 }
409 this->writeInstruction(SpvOpMemberName, resultId, i, type.fields()[i].fName, fNameBuffer);
410 this->writeLayout(fieldLayout, resultId, i);
411 if (type.fields()[i].fModifiers.fLayout.fBuiltin < 0) {
412 this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
413 (SpvId) offset, fDecorationBuffer);
414 }
415 if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) {
416 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
417 fDecorationBuffer);
418 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
419 (SpvId) memoryLayout.stride(*type.fields()[i].fType),
420 fDecorationBuffer);
421 }
422 offset += size;
423 Type::Kind kind = type.fields()[i].fType->kind();
424 if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) {
425 offset += alignment - offset % alignment;
426 }
427 }
428 }
429
// Maps the 16-bit SkSL types (half, short, ushort), as scalars or as vector/matrix components,
// to their 32-bit counterparts (float, int, uint), which is what gets emitted. Every other type
// is returned unchanged.
Type SPIRVCodeGenerator::getActualType(const Type& type) {
    const bool isCompound = type.kind() == Type::kMatrix_Kind ||
                            type.kind() == Type::kVector_Kind;
    const Type& base = isCompound ? type.componentType() : type;
    const Type* widened = nullptr;
    if (base == *fContext.fHalf_Type) {
        widened = &*fContext.fFloat_Type;
    } else if (base == *fContext.fShort_Type) {
        widened = &*fContext.fInt_Type;
    } else if (base == *fContext.fUShort_Type) {
        widened = &*fContext.fUInt_Type;
    }
    if (!widened) {
        return type;
    }
    return isCompound ? widened->toCompound(fContext, type.columns(), type.rows()) : *widened;
}
453
getType(const Type & type)454 SpvId SPIRVCodeGenerator::getType(const Type& type) {
455 return this->getType(type, fDefaultLayout);
456 }
457
getType(const Type & rawType,const MemoryLayout & layout)458 SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layout) {
459 Type type = this->getActualType(rawType);
460 String key = type.name() + to_string((int) layout.fStd);
461 auto entry = fTypeMap.find(key);
462 if (entry == fTypeMap.end()) {
463 SpvId result = this->nextId();
464 switch (type.kind()) {
465 case Type::kScalar_Kind:
466 if (type == *fContext.fBool_Type) {
467 this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
468 } else if (type == *fContext.fInt_Type) {
469 this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
470 } else if (type == *fContext.fUInt_Type) {
471 this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
472 } else if (type == *fContext.fFloat_Type) {
473 this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
474 } else if (type == *fContext.fDouble_Type) {
475 this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer);
476 } else {
477 ASSERT(false);
478 }
479 break;
480 case Type::kVector_Kind:
481 this->writeInstruction(SpvOpTypeVector, result,
482 this->getType(type.componentType(), layout),
483 type.columns(), fConstantBuffer);
484 break;
485 case Type::kMatrix_Kind:
486 this->writeInstruction(SpvOpTypeMatrix, result,
487 this->getType(index_type(fContext, type), layout),
488 type.columns(), fConstantBuffer);
489 break;
490 case Type::kStruct_Kind:
491 this->writeStruct(type, layout, result);
492 break;
493 case Type::kArray_Kind: {
494 if (type.columns() > 0) {
495 IntLiteral count(fContext, -1, type.columns());
496 this->writeInstruction(SpvOpTypeArray, result,
497 this->getType(type.componentType(), layout),
498 this->writeIntLiteral(count), fConstantBuffer);
499 this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
500 (int32_t) layout.stride(type),
501 fDecorationBuffer);
502 } else {
503 this->writeInstruction(SpvOpTypeRuntimeArray, result,
504 this->getType(type.componentType(), layout),
505 fConstantBuffer);
506 this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
507 (int32_t) layout.stride(type),
508 fDecorationBuffer);
509 }
510 break;
511 }
512 case Type::kSampler_Kind: {
513 SpvId image = result;
514 if (SpvDimSubpassData != type.dimensions()) {
515 image = this->nextId();
516 }
517 if (SpvDimBuffer == type.dimensions()) {
518 fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer);
519 }
520 this->writeInstruction(SpvOpTypeImage, image,
521 this->getType(*fContext.fFloat_Type, layout),
522 type.dimensions(), type.isDepth(), type.isArrayed(),
523 type.isMultisampled(), type.isSampled() ? 1 : 2,
524 SpvImageFormatUnknown, fConstantBuffer);
525 fImageTypeMap[key] = image;
526 if (SpvDimSubpassData != type.dimensions()) {
527 this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
528 }
529 break;
530 }
531 default:
532 if (type == *fContext.fVoid_Type) {
533 this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
534 } else {
535 ABORT("invalid type: %s", type.description().c_str());
536 }
537 }
538 fTypeMap[key] = result;
539 return result;
540 }
541 return entry->second;
542 }
543
getImageType(const Type & type)544 SpvId SPIRVCodeGenerator::getImageType(const Type& type) {
545 ASSERT(type.kind() == Type::kSampler_Kind);
546 this->getType(type);
547 String key = type.name() + to_string((int) fDefaultLayout.fStd);
548 ASSERT(fImageTypeMap.find(key) != fImageTypeMap.end());
549 return fImageTypeMap[key];
550 }
551
getFunctionType(const FunctionDeclaration & function)552 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
553 String key = function.fReturnType.description() + "(";
554 String separator;
555 for (size_t i = 0; i < function.fParameters.size(); i++) {
556 key += separator;
557 separator = ", ";
558 key += function.fParameters[i]->fType.description();
559 }
560 key += ")";
561 auto entry = fTypeMap.find(key);
562 if (entry == fTypeMap.end()) {
563 SpvId result = this->nextId();
564 int32_t length = 3 + (int32_t) function.fParameters.size();
565 SpvId returnType = this->getType(function.fReturnType);
566 std::vector<SpvId> parameterTypes;
567 for (size_t i = 0; i < function.fParameters.size(); i++) {
568 // glslang seems to treat all function arguments as pointers whether they need to be or
569 // not. I was initially puzzled by this until I ran bizarre failures with certain
570 // patterns of function calls and control constructs, as exemplified by this minimal
571 // failure case:
572 //
573 // void sphere(float x) {
574 // }
575 //
576 // void map() {
577 // sphere(1.0);
578 // }
579 //
580 // void main() {
581 // for (int i = 0; i < 1; i++) {
582 // map();
583 // }
584 // }
585 //
586 // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
587 // crashes. Making it take a float* and storing the argument in a temporary variable,
588 // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
589 // the spec makes this make sense.
590 // if (is_out(function->fParameters[i])) {
591 parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType,
592 SpvStorageClassFunction));
593 // } else {
594 // parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
595 // }
596 }
597 this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
598 this->writeWord(result, fConstantBuffer);
599 this->writeWord(returnType, fConstantBuffer);
600 for (SpvId id : parameterTypes) {
601 this->writeWord(id, fConstantBuffer);
602 }
603 fTypeMap[key] = result;
604 return result;
605 }
606 return entry->second;
607 }
608
getPointerType(const Type & type,SpvStorageClass_ storageClass)609 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
610 return this->getPointerType(type, fDefaultLayout, storageClass);
611 }
612
getPointerType(const Type & rawType,const MemoryLayout & layout,SpvStorageClass_ storageClass)613 SpvId SPIRVCodeGenerator::getPointerType(const Type& rawType, const MemoryLayout& layout,
614 SpvStorageClass_ storageClass) {
615 Type type = this->getActualType(rawType);
616 String key = type.description() + "*" + to_string(layout.fStd) + to_string(storageClass);
617 auto entry = fTypeMap.find(key);
618 if (entry == fTypeMap.end()) {
619 SpvId result = this->nextId();
620 this->writeInstruction(SpvOpTypePointer, result, storageClass,
621 this->getType(type), fConstantBuffer);
622 fTypeMap[key] = result;
623 return result;
624 }
625 return entry->second;
626 }
627
writeExpression(const Expression & expr,OutputStream & out)628 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
629 switch (expr.fKind) {
630 case Expression::kBinary_Kind:
631 return this->writeBinaryExpression((BinaryExpression&) expr, out);
632 case Expression::kBoolLiteral_Kind:
633 return this->writeBoolLiteral((BoolLiteral&) expr);
634 case Expression::kConstructor_Kind:
635 return this->writeConstructor((Constructor&) expr, out);
636 case Expression::kIntLiteral_Kind:
637 return this->writeIntLiteral((IntLiteral&) expr);
638 case Expression::kFieldAccess_Kind:
639 return this->writeFieldAccess(((FieldAccess&) expr), out);
640 case Expression::kFloatLiteral_Kind:
641 return this->writeFloatLiteral(((FloatLiteral&) expr));
642 case Expression::kFunctionCall_Kind:
643 return this->writeFunctionCall((FunctionCall&) expr, out);
644 case Expression::kPrefix_Kind:
645 return this->writePrefixExpression((PrefixExpression&) expr, out);
646 case Expression::kPostfix_Kind:
647 return this->writePostfixExpression((PostfixExpression&) expr, out);
648 case Expression::kSwizzle_Kind:
649 return this->writeSwizzle((Swizzle&) expr, out);
650 case Expression::kVariableReference_Kind:
651 return this->writeVariableReference((VariableReference&) expr, out);
652 case Expression::kTernary_Kind:
653 return this->writeTernaryExpression((TernaryExpression&) expr, out);
654 case Expression::kIndex_Kind:
655 return this->writeIndexExpression((IndexExpression&) expr, out);
656 default:
657 ABORT("unsupported expression: %s", expr.description().c_str());
658 }
659 return -1;
660 }
661
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)662 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
663 auto intrinsic = fIntrinsicMap.find(c.fFunction.fName);
664 ASSERT(intrinsic != fIntrinsicMap.end());
665 int32_t intrinsicId;
666 if (c.fArguments.size() > 0) {
667 const Type& type = c.fArguments[0]->fType;
668 if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) {
669 intrinsicId = std::get<1>(intrinsic->second);
670 } else if (is_signed(fContext, type)) {
671 intrinsicId = std::get<2>(intrinsic->second);
672 } else if (is_unsigned(fContext, type)) {
673 intrinsicId = std::get<3>(intrinsic->second);
674 } else if (is_bool(fContext, type)) {
675 intrinsicId = std::get<4>(intrinsic->second);
676 } else {
677 intrinsicId = std::get<1>(intrinsic->second);
678 }
679 } else {
680 intrinsicId = std::get<1>(intrinsic->second);
681 }
682 switch (std::get<0>(intrinsic->second)) {
683 case kGLSL_STD_450_IntrinsicKind: {
684 SpvId result = this->nextId();
685 std::vector<SpvId> arguments;
686 for (size_t i = 0; i < c.fArguments.size(); i++) {
687 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
688 }
689 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
690 this->writeWord(this->getType(c.fType), out);
691 this->writeWord(result, out);
692 this->writeWord(fGLSLExtendedInstructions, out);
693 this->writeWord(intrinsicId, out);
694 for (SpvId id : arguments) {
695 this->writeWord(id, out);
696 }
697 return result;
698 }
699 case kSPIRV_IntrinsicKind: {
700 SpvId result = this->nextId();
701 std::vector<SpvId> arguments;
702 for (size_t i = 0; i < c.fArguments.size(); i++) {
703 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
704 }
705 if (c.fType != *fContext.fVoid_Type) {
706 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
707 this->writeWord(this->getType(c.fType), out);
708 this->writeWord(result, out);
709 } else {
710 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
711 }
712 for (SpvId id : arguments) {
713 this->writeWord(id, out);
714 }
715 return result;
716 }
717 case kSpecial_IntrinsicKind:
718 return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
719 default:
720 ABORT("unsupported intrinsic kind");
721 }
722 }
723
vectorize(const std::vector<std::unique_ptr<Expression>> & args,OutputStream & out)724 std::vector<SpvId> SPIRVCodeGenerator::vectorize(
725 const std::vector<std::unique_ptr<Expression>>& args,
726 OutputStream& out) {
727 int vectorSize = 0;
728 for (const auto& a : args) {
729 if (a->fType.kind() == Type::kVector_Kind) {
730 if (vectorSize) {
731 ASSERT(a->fType.columns() == vectorSize);
732 }
733 else {
734 vectorSize = a->fType.columns();
735 }
736 }
737 }
738 std::vector<SpvId> result;
739 for (const auto& a : args) {
740 SpvId raw = this->writeExpression(*a, out);
741 if (vectorSize && a->fType.kind() == Type::kScalar_Kind) {
742 SpvId vector = this->nextId();
743 this->writeOpCode(SpvOpCompositeConstruct, 3 + vectorSize, out);
744 this->writeWord(this->getType(a->fType.toCompound(fContext, vectorSize, 1)), out);
745 this->writeWord(vector, out);
746 for (int i = 0; i < vectorSize; i++) {
747 this->writeWord(raw, out);
748 }
749 result.push_back(vector);
750 } else {
751 result.push_back(raw);
752 }
753 }
754 return result;
755 }
756
writeGLSLExtendedInstruction(const Type & type,SpvId id,SpvId floatInst,SpvId signedInst,SpvId unsignedInst,const std::vector<SpvId> & args,OutputStream & out)757 void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
758 SpvId signedInst, SpvId unsignedInst,
759 const std::vector<SpvId>& args,
760 OutputStream& out) {
761 this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
762 this->writeWord(this->getType(type), out);
763 this->writeWord(id, out);
764 this->writeWord(fGLSLExtendedInstructions, out);
765
766 if (is_float(fContext, type)) {
767 this->writeWord(floatInst, out);
768 } else if (is_signed(fContext, type)) {
769 this->writeWord(signedInst, out);
770 } else if (is_unsigned(fContext, type)) {
771 this->writeWord(unsignedInst, out);
772 } else {
773 ASSERT(false);
774 }
775 for (SpvId a : args) {
776 this->writeWord(a, out);
777 }
778 }
779
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)780 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
781 OutputStream& out) {
782 SpvId result = this->nextId();
783 switch (kind) {
784 case kAtan_SpecialIntrinsic: {
785 std::vector<SpvId> arguments;
786 for (size_t i = 0; i < c.fArguments.size(); i++) {
787 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
788 }
789 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
790 this->writeWord(this->getType(c.fType), out);
791 this->writeWord(result, out);
792 this->writeWord(fGLSLExtendedInstructions, out);
793 this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
794 for (SpvId id : arguments) {
795 this->writeWord(id, out);
796 }
797 break;
798 }
799 case kSubpassLoad_SpecialIntrinsic: {
800 SpvId img = this->writeExpression(*c.fArguments[0], out);
801 std::vector<std::unique_ptr<Expression>> args;
802 args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
803 args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
804 Constructor ctor(-1, *fContext.fFloat2_Type, std::move(args));
805 SpvId coords = this->writeConstantVector(ctor);
806 if (1 == c.fArguments.size()) {
807 this->writeInstruction(SpvOpImageRead,
808 this->getType(c.fType),
809 result,
810 img,
811 coords,
812 out);
813 } else {
814 ASSERT(2 == c.fArguments.size());
815 SpvId sample = this->writeExpression(*c.fArguments[1], out);
816 this->writeInstruction(SpvOpImageRead,
817 this->getType(c.fType),
818 result,
819 img,
820 coords,
821 SpvImageOperandsSampleMask,
822 sample,
823 out);
824 }
825 break;
826 }
827 case kTexelFetch_SpecialIntrinsic: {
828 ASSERT(c.fArguments.size() == 2);
829 SpvId image = this->nextId();
830 this->writeInstruction(SpvOpImage,
831 this->getImageType(c.fArguments[0]->fType),
832 image,
833 this->writeExpression(*c.fArguments[0], out),
834 out);
835 this->writeInstruction(SpvOpImageFetch,
836 this->getType(c.fType),
837 result,
838 image,
839 this->writeExpression(*c.fArguments[1], out),
840 out);
841 break;
842 }
843 case kTexture_SpecialIntrinsic: {
844 SpvOp_ op = SpvOpImageSampleImplicitLod;
845 switch (c.fArguments[0]->fType.dimensions()) {
846 case SpvDim1D:
847 if (c.fArguments[1]->fType == *fContext.fFloat2_Type) {
848 op = SpvOpImageSampleProjImplicitLod;
849 } else {
850 ASSERT(c.fArguments[1]->fType == *fContext.fFloat_Type);
851 }
852 break;
853 case SpvDim2D:
854 if (c.fArguments[1]->fType == *fContext.fFloat3_Type) {
855 op = SpvOpImageSampleProjImplicitLod;
856 } else {
857 ASSERT(c.fArguments[1]->fType == *fContext.fFloat2_Type);
858 }
859 break;
860 case SpvDim3D:
861 if (c.fArguments[1]->fType == *fContext.fFloat4_Type) {
862 op = SpvOpImageSampleProjImplicitLod;
863 } else {
864 ASSERT(c.fArguments[1]->fType == *fContext.fFloat3_Type);
865 }
866 break;
867 case SpvDimCube: // fall through
868 case SpvDimRect: // fall through
869 case SpvDimBuffer: // fall through
870 case SpvDimSubpassData:
871 break;
872 }
873 SpvId type = this->getType(c.fType);
874 SpvId sampler = this->writeExpression(*c.fArguments[0], out);
875 SpvId uv = this->writeExpression(*c.fArguments[1], out);
876 if (c.fArguments.size() == 3) {
877 this->writeInstruction(op, type, result, sampler, uv,
878 SpvImageOperandsBiasMask,
879 this->writeExpression(*c.fArguments[2], out),
880 out);
881 } else {
882 ASSERT(c.fArguments.size() == 2);
883 this->writeInstruction(op, type, result, sampler, uv,
884 out);
885 }
886 break;
887 }
888 case kMod_SpecialIntrinsic: {
889 std::vector<SpvId> args = this->vectorize(c.fArguments, out);
890 ASSERT(args.size() == 2);
891 const Type& operandType = c.fArguments[0]->fType;
892 SpvOp_ op;
893 if (is_float(fContext, operandType)) {
894 op = SpvOpFMod;
895 } else if (is_signed(fContext, operandType)) {
896 op = SpvOpSMod;
897 } else if (is_unsigned(fContext, operandType)) {
898 op = SpvOpUMod;
899 } else {
900 ASSERT(false);
901 return 0;
902 }
903 this->writeOpCode(op, 5, out);
904 this->writeWord(this->getType(operandType), out);
905 this->writeWord(result, out);
906 this->writeWord(args[0], out);
907 this->writeWord(args[1], out);
908 break;
909 }
910 case kClamp_SpecialIntrinsic: {
911 std::vector<SpvId> args = this->vectorize(c.fArguments, out);
912 ASSERT(args.size() == 3);
913 this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FClamp, GLSLstd450SClamp,
914 GLSLstd450UClamp, args, out);
915 break;
916 }
917 case kMax_SpecialIntrinsic: {
918 std::vector<SpvId> args = this->vectorize(c.fArguments, out);
919 ASSERT(args.size() == 2);
920 this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMax, GLSLstd450SMax,
921 GLSLstd450UMax, args, out);
922 break;
923 }
924 case kMin_SpecialIntrinsic: {
925 std::vector<SpvId> args = this->vectorize(c.fArguments, out);
926 ASSERT(args.size() == 2);
927 this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMin, GLSLstd450SMin,
928 GLSLstd450UMin, args, out);
929 break;
930 }
931 case kMix_SpecialIntrinsic: {
932 std::vector<SpvId> args = this->vectorize(c.fArguments, out);
933 ASSERT(args.size() == 3);
934 this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMix, SpvOpUndef,
935 SpvOpUndef, args, out);
936 break;
937 }
938 }
939 return result;
940 }
941
writeFunctionCall(const FunctionCall & c,OutputStream & out)942 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
943 const auto& entry = fFunctionMap.find(&c.fFunction);
944 if (entry == fFunctionMap.end()) {
945 return this->writeIntrinsicCall(c, out);
946 }
947 // stores (variable, type, lvalue) pairs to extract and save after the function call is complete
948 std::vector<std::tuple<SpvId, SpvId, std::unique_ptr<LValue>>> lvalues;
949 std::vector<SpvId> arguments;
950 for (size_t i = 0; i < c.fArguments.size(); i++) {
951 // id of temporary variable that we will use to hold this argument, or 0 if it is being
952 // passed directly
953 SpvId tmpVar;
954 // if we need a temporary var to store this argument, this is the value to store in the var
955 SpvId tmpValueId;
956 if (is_out(*c.fFunction.fParameters[i])) {
957 std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out);
958 SpvId ptr = lv->getPointer();
959 if (ptr) {
960 arguments.push_back(ptr);
961 continue;
962 } else {
963 // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to
964 // copy it into a temp, call the function, read the value out of the temp, and then
965 // update the lvalue.
966 tmpValueId = lv->load(out);
967 tmpVar = this->nextId();
968 lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType),
969 std::move(lv)));
970 }
971 } else {
972 // see getFunctionType for an explanation of why we're always using pointer parameters
973 tmpValueId = this->writeExpression(*c.fArguments[i], out);
974 tmpVar = this->nextId();
975 }
976 this->writeInstruction(SpvOpVariable,
977 this->getPointerType(c.fArguments[i]->fType,
978 SpvStorageClassFunction),
979 tmpVar,
980 SpvStorageClassFunction,
981 fVariableBuffer);
982 this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
983 arguments.push_back(tmpVar);
984 }
985 SpvId result = this->nextId();
986 this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out);
987 this->writeWord(this->getType(c.fType), out);
988 this->writeWord(result, out);
989 this->writeWord(entry->second, out);
990 for (SpvId id : arguments) {
991 this->writeWord(id, out);
992 }
993 // now that the call is complete, we may need to update some lvalues with the new values of out
994 // arguments
995 for (const auto& tuple : lvalues) {
996 SpvId load = this->nextId();
997 this->writeInstruction(SpvOpLoad, std::get<1>(tuple), load, std::get<0>(tuple), out);
998 std::get<2>(tuple)->store(load, out);
999 }
1000 return result;
1001 }
1002
writeConstantVector(const Constructor & c)1003 SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) {
1004 ASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant());
1005 SpvId result = this->nextId();
1006 std::vector<SpvId> arguments;
1007 for (size_t i = 0; i < c.fArguments.size(); i++) {
1008 arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer));
1009 }
1010 SpvId type = this->getType(c.fType);
1011 if (c.fArguments.size() == 1) {
1012 // with a single argument, a vector will have all of its entries equal to the argument
1013 this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer);
1014 this->writeWord(type, fConstantBuffer);
1015 this->writeWord(result, fConstantBuffer);
1016 for (int i = 0; i < c.fType.columns(); i++) {
1017 this->writeWord(arguments[0], fConstantBuffer);
1018 }
1019 } else {
1020 this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(),
1021 fConstantBuffer);
1022 this->writeWord(type, fConstantBuffer);
1023 this->writeWord(result, fConstantBuffer);
1024 for (SpvId id : arguments) {
1025 this->writeWord(id, fConstantBuffer);
1026 }
1027 }
1028 return result;
1029 }
1030
writeFloatConstructor(const Constructor & c,OutputStream & out)1031 SpvId SPIRVCodeGenerator::writeFloatConstructor(const Constructor& c, OutputStream& out) {
1032 ASSERT(c.fType.isFloat());
1033 ASSERT(c.fArguments.size() == 1);
1034 ASSERT(c.fArguments[0]->fType.isNumber());
1035 SpvId result = this->nextId();
1036 SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1037 if (c.fArguments[0]->fType.isSigned()) {
1038 this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter,
1039 out);
1040 } else {
1041 ASSERT(c.fArguments[0]->fType.isUnsigned());
1042 this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter,
1043 out);
1044 }
1045 return result;
1046 }
1047
writeIntConstructor(const Constructor & c,OutputStream & out)1048 SpvId SPIRVCodeGenerator::writeIntConstructor(const Constructor& c, OutputStream& out) {
1049 ASSERT(c.fType.isSigned());
1050 ASSERT(c.fArguments.size() == 1);
1051 ASSERT(c.fArguments[0]->fType.isNumber());
1052 SpvId result = this->nextId();
1053 SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1054 if (c.fArguments[0]->fType.isFloat()) {
1055 this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter,
1056 out);
1057 }
1058 else {
1059 ASSERT(c.fArguments[0]->fType.isUnsigned());
1060 this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1061 out);
1062 }
1063 return result;
1064 }
1065
writeUIntConstructor(const Constructor & c,OutputStream & out)1066 SpvId SPIRVCodeGenerator::writeUIntConstructor(const Constructor& c, OutputStream& out) {
1067 ASSERT(c.fType.isUnsigned());
1068 ASSERT(c.fArguments.size() == 1);
1069 ASSERT(c.fArguments[0]->fType.isNumber());
1070 SpvId result = this->nextId();
1071 SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1072 if (c.fArguments[0]->fType.isFloat()) {
1073 this->writeInstruction(SpvOpConvertFToU, this->getType(c.fType), result, parameter,
1074 out);
1075 } else {
1076 ASSERT(c.fArguments[0]->fType.isSigned());
1077 this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1078 out);
1079 }
1080 return result;
1081 }
1082
writeUniformScaleMatrix(SpvId id,SpvId diagonal,const Type & type,OutputStream & out)1083 void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
1084 OutputStream& out) {
1085 FloatLiteral zero(fContext, -1, 0);
1086 SpvId zeroId = this->writeFloatLiteral(zero);
1087 std::vector<SpvId> columnIds;
1088 for (int column = 0; column < type.columns(); column++) {
1089 this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
1090 out);
1091 this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)),
1092 out);
1093 SpvId columnId = this->nextId();
1094 this->writeWord(columnId, out);
1095 columnIds.push_back(columnId);
1096 for (int row = 0; row < type.columns(); row++) {
1097 this->writeWord(row == column ? diagonal : zeroId, out);
1098 }
1099 }
1100 this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
1101 out);
1102 this->writeWord(this->getType(type), out);
1103 this->writeWord(id, out);
1104 for (SpvId id : columnIds) {
1105 this->writeWord(id, out);
1106 }
1107 }
1108
writeMatrixCopy(SpvId id,SpvId src,const Type & srcType,const Type & dstType,OutputStream & out)1109 void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType,
1110 const Type& dstType, OutputStream& out) {
1111 ASSERT(srcType.kind() == Type::kMatrix_Kind);
1112 ASSERT(dstType.kind() == Type::kMatrix_Kind);
1113 ASSERT(srcType.componentType() == dstType.componentType());
1114 SpvId srcColumnType = this->getType(srcType.componentType().toCompound(fContext,
1115 srcType.rows(),
1116 1));
1117 SpvId dstColumnType = this->getType(dstType.componentType().toCompound(fContext,
1118 dstType.rows(),
1119 1));
1120 SpvId zeroId;
1121 if (dstType.componentType() == *fContext.fFloat_Type) {
1122 FloatLiteral zero(fContext, -1, 0.0);
1123 zeroId = this->writeFloatLiteral(zero);
1124 } else if (dstType.componentType() == *fContext.fInt_Type) {
1125 IntLiteral zero(fContext, -1, 0);
1126 zeroId = this->writeIntLiteral(zero);
1127 } else {
1128 ABORT("unsupported matrix component type");
1129 }
1130 SpvId zeroColumn = 0;
1131 SpvId columns[4];
1132 for (int i = 0; i < dstType.columns(); i++) {
1133 if (i < srcType.columns()) {
1134 // we're still inside the src matrix, copy the column
1135 SpvId srcColumn = this->nextId();
1136 this->writeInstruction(SpvOpCompositeExtract, srcColumnType, srcColumn, src, i, out);
1137 SpvId dstColumn;
1138 if (srcType.rows() == dstType.rows()) {
1139 // columns are equal size, don't need to do anything
1140 dstColumn = srcColumn;
1141 }
1142 else if (dstType.rows() > srcType.rows()) {
1143 // dst column is bigger, need to zero-pad it
1144 dstColumn = this->nextId();
1145 int delta = dstType.rows() - srcType.rows();
1146 this->writeOpCode(SpvOpCompositeConstruct, 4 + delta, out);
1147 this->writeWord(dstColumnType, out);
1148 this->writeWord(dstColumn, out);
1149 this->writeWord(srcColumn, out);
1150 for (int i = 0; i < delta; ++i) {
1151 this->writeWord(zeroId, out);
1152 }
1153 }
1154 else {
1155 // dst column is smaller, need to swizzle the src column
1156 dstColumn = this->nextId();
1157 int count = dstType.rows();
1158 this->writeOpCode(SpvOpVectorShuffle, 5 + count, out);
1159 this->writeWord(dstColumnType, out);
1160 this->writeWord(dstColumn, out);
1161 this->writeWord(srcColumn, out);
1162 this->writeWord(srcColumn, out);
1163 for (int i = 0; i < count; i++) {
1164 this->writeWord(i, out);
1165 }
1166 }
1167 columns[i] = dstColumn;
1168 } else {
1169 // we're past the end of the src matrix, need a vector of zeroes
1170 if (!zeroColumn) {
1171 zeroColumn = this->nextId();
1172 this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.rows(), out);
1173 this->writeWord(dstColumnType, out);
1174 this->writeWord(zeroColumn, out);
1175 for (int i = 0; i < dstType.rows(); ++i) {
1176 this->writeWord(zeroId, out);
1177 }
1178 }
1179 columns[i] = zeroColumn;
1180 }
1181 }
1182 this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.columns(), out);
1183 this->writeWord(this->getType(dstType), out);
1184 this->writeWord(id, out);
1185 for (int i = 0; i < dstType.columns(); i++) {
1186 this->writeWord(columns[i], out);
1187 }
1188 }
1189
writeMatrixConstructor(const Constructor & c,OutputStream & out)1190 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, OutputStream& out) {
1191 ASSERT(c.fType.kind() == Type::kMatrix_Kind);
1192 // go ahead and write the arguments so we don't try to write new instructions in the middle of
1193 // an instruction
1194 std::vector<SpvId> arguments;
1195 for (size_t i = 0; i < c.fArguments.size(); i++) {
1196 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1197 }
1198 SpvId result = this->nextId();
1199 int rows = c.fType.rows();
1200 int columns = c.fType.columns();
1201 if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1202 this->writeUniformScaleMatrix(result, arguments[0], c.fType, out);
1203 } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) {
1204 this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out);
1205 } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kVector_Kind) {
1206 ASSERT(c.fType.rows() == 2 && c.fType.columns() == 2);
1207 ASSERT(c.fArguments[0]->fType.columns() == 4);
1208 SpvId componentType = this->getType(c.fType.componentType());
1209 SpvId v[4];
1210 for (int i = 0; i < 4; ++i) {
1211 v[i] = this->nextId();
1212 this->writeInstruction(SpvOpCompositeExtract, componentType, v[i], arguments[0], i, out);
1213 }
1214 SpvId columnType = this->getType(c.fType.componentType().toCompound(fContext, 2, 1));
1215 SpvId column1 = this->nextId();
1216 this->writeInstruction(SpvOpCompositeConstruct, columnType, column1, v[0], v[1], out);
1217 SpvId column2 = this->nextId();
1218 this->writeInstruction(SpvOpCompositeConstruct, columnType, column2, v[2], v[3], out);
1219 this->writeInstruction(SpvOpCompositeConstruct, this->getType(c.fType), result, column1,
1220 column2, out);
1221 } else {
1222 std::vector<SpvId> columnIds;
1223 // ids of vectors and scalars we have written to the current column so far
1224 std::vector<SpvId> currentColumn;
1225 // the total number of scalars represented by currentColumn's entries
1226 int currentCount = 0;
1227 for (size_t i = 0; i < arguments.size(); i++) {
1228 if (c.fArguments[i]->fType.kind() == Type::kVector_Kind &&
1229 c.fArguments[i]->fType.columns() == c.fType.rows()) {
1230 // this is a complete column by itself
1231 ASSERT(currentCount == 0);
1232 columnIds.push_back(arguments[i]);
1233 } else {
1234 currentColumn.push_back(arguments[i]);
1235 currentCount += c.fArguments[i]->fType.columns();
1236 if (currentCount == rows) {
1237 currentCount = 0;
1238 this->writeOpCode(SpvOpCompositeConstruct, 3 + currentColumn.size(), out);
1239 this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows,
1240 1)),
1241 out);
1242 SpvId columnId = this->nextId();
1243 this->writeWord(columnId, out);
1244 columnIds.push_back(columnId);
1245 for (SpvId id : currentColumn) {
1246 this->writeWord(id, out);
1247 }
1248 currentColumn.clear();
1249 }
1250 ASSERT(currentCount < rows);
1251 }
1252 }
1253 ASSERT(columnIds.size() == (size_t) columns);
1254 this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
1255 this->writeWord(this->getType(c.fType), out);
1256 this->writeWord(result, out);
1257 for (SpvId id : columnIds) {
1258 this->writeWord(id, out);
1259 }
1260 }
1261 return result;
1262 }
1263
writeVectorConstructor(const Constructor & c,OutputStream & out)1264 SpvId SPIRVCodeGenerator::writeVectorConstructor(const Constructor& c, OutputStream& out) {
1265 ASSERT(c.fType.kind() == Type::kVector_Kind);
1266 if (c.isConstant()) {
1267 return this->writeConstantVector(c);
1268 }
1269 // go ahead and write the arguments so we don't try to write new instructions in the middle of
1270 // an instruction
1271 std::vector<SpvId> arguments;
1272 for (size_t i = 0; i < c.fArguments.size(); i++) {
1273 if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) {
1274 // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to
1275 // extract the components and convert them in that case manually. On top of that,
1276 // as of this writing there's a bug in the Intel Vulkan driver where OpCreateComposite
1277 // doesn't handle vector arguments at all, so we always extract vector components and
1278 // pass them into OpCreateComposite individually.
1279 SpvId vec = this->writeExpression(*c.fArguments[i], out);
1280 SpvOp_ op = SpvOpUndef;
1281 const Type& src = c.fArguments[i]->fType.componentType();
1282 const Type& dst = c.fType.componentType();
1283 if (dst == *fContext.fFloat_Type || dst == *fContext.fHalf_Type) {
1284 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1285 if (c.fArguments.size() == 1) {
1286 return vec;
1287 }
1288 } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) {
1289 op = SpvOpConvertSToF;
1290 } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) {
1291 op = SpvOpConvertUToF;
1292 } else {
1293 ASSERT(false);
1294 }
1295 } else if (dst == *fContext.fInt_Type || dst == *fContext.fShort_Type) {
1296 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1297 op = SpvOpConvertFToS;
1298 } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) {
1299 if (c.fArguments.size() == 1) {
1300 return vec;
1301 }
1302 } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) {
1303 op = SpvOpBitcast;
1304 } else {
1305 ASSERT(false);
1306 }
1307 } else if (dst == *fContext.fUInt_Type || dst == *fContext.fUShort_Type) {
1308 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1309 op = SpvOpConvertFToS;
1310 } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) {
1311 op = SpvOpBitcast;
1312 } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) {
1313 if (c.fArguments.size() == 1) {
1314 return vec;
1315 }
1316 } else {
1317 ASSERT(false);
1318 }
1319 }
1320 for (int j = 0; j < c.fArguments[i]->fType.columns(); j++) {
1321 SpvId swizzle = this->nextId();
1322 this->writeInstruction(SpvOpCompositeExtract, this->getType(src), swizzle, vec, j,
1323 out);
1324 if (op != SpvOpUndef) {
1325 SpvId cast = this->nextId();
1326 this->writeInstruction(op, this->getType(dst), cast, swizzle, out);
1327 arguments.push_back(cast);
1328 } else {
1329 arguments.push_back(swizzle);
1330 }
1331 }
1332 } else {
1333 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1334 }
1335 }
1336 SpvId result = this->nextId();
1337 if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1338 this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out);
1339 this->writeWord(this->getType(c.fType), out);
1340 this->writeWord(result, out);
1341 for (int i = 0; i < c.fType.columns(); i++) {
1342 this->writeWord(arguments[0], out);
1343 }
1344 } else {
1345 ASSERT(arguments.size() > 1);
1346 this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out);
1347 this->writeWord(this->getType(c.fType), out);
1348 this->writeWord(result, out);
1349 for (SpvId id : arguments) {
1350 this->writeWord(id, out);
1351 }
1352 }
1353 return result;
1354 }
1355
writeArrayConstructor(const Constructor & c,OutputStream & out)1356 SpvId SPIRVCodeGenerator::writeArrayConstructor(const Constructor& c, OutputStream& out) {
1357 ASSERT(c.fType.kind() == Type::kArray_Kind);
1358 // go ahead and write the arguments so we don't try to write new instructions in the middle of
1359 // an instruction
1360 std::vector<SpvId> arguments;
1361 for (size_t i = 0; i < c.fArguments.size(); i++) {
1362 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1363 }
1364 SpvId result = this->nextId();
1365 this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out);
1366 this->writeWord(this->getType(c.fType), out);
1367 this->writeWord(result, out);
1368 for (SpvId id : arguments) {
1369 this->writeWord(id, out);
1370 }
1371 return result;
1372 }
1373
writeConstructor(const Constructor & c,OutputStream & out)1374 SpvId SPIRVCodeGenerator::writeConstructor(const Constructor& c, OutputStream& out) {
1375 if (c.fArguments.size() == 1 &&
1376 this->getActualType(c.fType) == this->getActualType(c.fArguments[0]->fType)) {
1377 return this->writeExpression(*c.fArguments[0], out);
1378 }
1379 if (c.fType == *fContext.fFloat_Type || c.fType == *fContext.fHalf_Type) {
1380 return this->writeFloatConstructor(c, out);
1381 } else if (c.fType == *fContext.fInt_Type || c.fType == *fContext.fShort_Type) {
1382 return this->writeIntConstructor(c, out);
1383 } else if (c.fType == *fContext.fUInt_Type || c.fType == *fContext.fUShort_Type) {
1384 return this->writeUIntConstructor(c, out);
1385 }
1386 switch (c.fType.kind()) {
1387 case Type::kVector_Kind:
1388 return this->writeVectorConstructor(c, out);
1389 case Type::kMatrix_Kind:
1390 return this->writeMatrixConstructor(c, out);
1391 case Type::kArray_Kind:
1392 return this->writeArrayConstructor(c, out);
1393 default:
1394 ABORT("unsupported constructor: %s", c.description().c_str());
1395 }
1396 }
1397
get_storage_class(const Modifiers & modifiers)1398 SpvStorageClass_ get_storage_class(const Modifiers& modifiers) {
1399 if (modifiers.fFlags & Modifiers::kIn_Flag) {
1400 ASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1401 return SpvStorageClassInput;
1402 } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
1403 ASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1404 return SpvStorageClassOutput;
1405 } else if (modifiers.fFlags & Modifiers::kUniform_Flag) {
1406 if (modifiers.fLayout.fFlags & Layout::kPushConstant_Flag) {
1407 return SpvStorageClassPushConstant;
1408 }
1409 return SpvStorageClassUniform;
1410 } else {
1411 return SpvStorageClassFunction;
1412 }
1413 }
1414
get_storage_class(const Expression & expr)1415 SpvStorageClass_ get_storage_class(const Expression& expr) {
1416 switch (expr.fKind) {
1417 case Expression::kVariableReference_Kind: {
1418 const Variable& var = ((VariableReference&) expr).fVariable;
1419 if (var.fStorage != Variable::kGlobal_Storage) {
1420 return SpvStorageClassFunction;
1421 }
1422 SpvStorageClass_ result = get_storage_class(var.fModifiers);
1423 if (result == SpvStorageClassFunction) {
1424 result = SpvStorageClassPrivate;
1425 }
1426 return result;
1427 }
1428 case Expression::kFieldAccess_Kind:
1429 return get_storage_class(*((FieldAccess&) expr).fBase);
1430 case Expression::kIndex_Kind:
1431 return get_storage_class(*((IndexExpression&) expr).fBase);
1432 default:
1433 return SpvStorageClassFunction;
1434 }
1435 }
1436
getAccessChain(const Expression & expr,OutputStream & out)1437 std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
1438 std::vector<SpvId> chain;
1439 switch (expr.fKind) {
1440 case Expression::kIndex_Kind: {
1441 IndexExpression& indexExpr = (IndexExpression&) expr;
1442 chain = this->getAccessChain(*indexExpr.fBase, out);
1443 chain.push_back(this->writeExpression(*indexExpr.fIndex, out));
1444 break;
1445 }
1446 case Expression::kFieldAccess_Kind: {
1447 FieldAccess& fieldExpr = (FieldAccess&) expr;
1448 chain = this->getAccessChain(*fieldExpr.fBase, out);
1449 IntLiteral index(fContext, -1, fieldExpr.fFieldIndex);
1450 chain.push_back(this->writeIntLiteral(index));
1451 break;
1452 }
1453 default:
1454 chain.push_back(this->getLValue(expr, out)->getPointer());
1455 }
1456 return chain;
1457 }
1458
1459 class PointerLValue : public SPIRVCodeGenerator::LValue {
1460 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,SpvId type)1461 PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type)
1462 : fGen(gen)
1463 , fPointer(pointer)
1464 , fType(type) {}
1465
getPointer()1466 virtual SpvId getPointer() override {
1467 return fPointer;
1468 }
1469
load(OutputStream & out)1470 virtual SpvId load(OutputStream& out) override {
1471 SpvId result = fGen.nextId();
1472 fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
1473 return result;
1474 }
1475
store(SpvId value,OutputStream & out)1476 virtual void store(SpvId value, OutputStream& out) override {
1477 fGen.writeInstruction(SpvOpStore, fPointer, value, out);
1478 }
1479
1480 private:
1481 SPIRVCodeGenerator& fGen;
1482 const SpvId fPointer;
1483 const SpvId fType;
1484 };
1485
1486 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
1487 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const std::vector<int> & components,const Type & baseType,const Type & swizzleType)1488 SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components,
1489 const Type& baseType, const Type& swizzleType)
1490 : fGen(gen)
1491 , fVecPointer(vecPointer)
1492 , fComponents(components)
1493 , fBaseType(baseType)
1494 , fSwizzleType(swizzleType) {}
1495
getPointer()1496 virtual SpvId getPointer() override {
1497 return 0;
1498 }
1499
load(OutputStream & out)1500 virtual SpvId load(OutputStream& out) override {
1501 SpvId base = fGen.nextId();
1502 fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1503 SpvId result = fGen.nextId();
1504 fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
1505 fGen.writeWord(fGen.getType(fSwizzleType), out);
1506 fGen.writeWord(result, out);
1507 fGen.writeWord(base, out);
1508 fGen.writeWord(base, out);
1509 for (int component : fComponents) {
1510 fGen.writeWord(component, out);
1511 }
1512 return result;
1513 }
1514
store(SpvId value,OutputStream & out)1515 virtual void store(SpvId value, OutputStream& out) override {
1516 // use OpVectorShuffle to mix and match the vector components. We effectively create
1517 // a virtual vector out of the concatenation of the left and right vectors, and then
1518 // select components from this virtual vector to make the result vector. For
1519 // instance, given:
1520 // float3L = ...;
1521 // float3R = ...;
1522 // L.xz = R.xy;
1523 // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
1524 // our result vector to look like (R.x, L.y, R.y), so we need to select indices
1525 // (3, 1, 4).
1526 SpvId base = fGen.nextId();
1527 fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1528 SpvId shuffle = fGen.nextId();
1529 fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out);
1530 fGen.writeWord(fGen.getType(fBaseType), out);
1531 fGen.writeWord(shuffle, out);
1532 fGen.writeWord(base, out);
1533 fGen.writeWord(value, out);
1534 for (int i = 0; i < fBaseType.columns(); i++) {
1535 // current offset into the virtual vector, defaults to pulling the unmodified
1536 // value from the left side
1537 int offset = i;
1538 // check to see if we are writing this component
1539 for (size_t j = 0; j < fComponents.size(); j++) {
1540 if (fComponents[j] == i) {
1541 // we're writing to this component, so adjust the offset to pull from
1542 // the correct component of the right side instead of preserving the
1543 // value from the left
1544 offset = (int) (j + fBaseType.columns());
1545 break;
1546 }
1547 }
1548 fGen.writeWord(offset, out);
1549 }
1550 fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
1551 }
1552
1553 private:
1554 SPIRVCodeGenerator& fGen;
1555 const SpvId fVecPointer;
1556 const std::vector<int>& fComponents;
1557 const Type& fBaseType;
1558 const Type& fSwizzleType;
1559 };
1560
getLValue(const Expression & expr,OutputStream & out)1561 std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
1562 OutputStream& out) {
1563 switch (expr.fKind) {
1564 case Expression::kVariableReference_Kind: {
1565 const Variable& var = ((VariableReference&) expr).fVariable;
1566 auto entry = fVariableMap.find(&var);
1567 ASSERT(entry != fVariableMap.end());
1568 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1569 *this,
1570 entry->second,
1571 this->getType(expr.fType)));
1572 }
1573 case Expression::kIndex_Kind: // fall through
1574 case Expression::kFieldAccess_Kind: {
1575 std::vector<SpvId> chain = this->getAccessChain(expr, out);
1576 SpvId member = this->nextId();
1577 this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
1578 this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out);
1579 this->writeWord(member, out);
1580 for (SpvId idx : chain) {
1581 this->writeWord(idx, out);
1582 }
1583 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1584 *this,
1585 member,
1586 this->getType(expr.fType)));
1587 }
1588 case Expression::kSwizzle_Kind: {
1589 Swizzle& swizzle = (Swizzle&) expr;
1590 size_t count = swizzle.fComponents.size();
1591 SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer();
1592 ASSERT(base);
1593 if (count == 1) {
1594 IntLiteral index(fContext, -1, swizzle.fComponents[0]);
1595 SpvId member = this->nextId();
1596 this->writeInstruction(SpvOpAccessChain,
1597 this->getPointerType(swizzle.fType,
1598 get_storage_class(*swizzle.fBase)),
1599 member,
1600 base,
1601 this->writeIntLiteral(index),
1602 out);
1603 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1604 *this,
1605 member,
1606 this->getType(expr.fType)));
1607 } else {
1608 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue(
1609 *this,
1610 base,
1611 swizzle.fComponents,
1612 swizzle.fBase->fType,
1613 expr.fType));
1614 }
1615 }
1616 case Expression::kTernary_Kind: {
1617 TernaryExpression& t = (TernaryExpression&) expr;
1618 SpvId test = this->writeExpression(*t.fTest, out);
1619 SpvId end = this->nextId();
1620 SpvId ifTrueLabel = this->nextId();
1621 SpvId ifFalseLabel = this->nextId();
1622 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
1623 this->writeInstruction(SpvOpBranchConditional, test, ifTrueLabel, ifFalseLabel, out);
1624 this->writeLabel(ifTrueLabel, out);
1625 SpvId ifTrue = this->getLValue(*t.fIfTrue, out)->getPointer();
1626 ASSERT(ifTrue);
1627 this->writeInstruction(SpvOpBranch, end, out);
1628 ifTrueLabel = fCurrentBlock;
1629 SpvId ifFalse = this->getLValue(*t.fIfFalse, out)->getPointer();
1630 ASSERT(ifFalse);
1631 ifFalseLabel = fCurrentBlock;
1632 this->writeInstruction(SpvOpBranch, end, out);
1633 SpvId result = this->nextId();
1634 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, ifTrue,
1635 ifTrueLabel, ifFalse, ifFalseLabel, out);
1636 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1637 *this,
1638 result,
1639 this->getType(expr.fType)));
1640 }
1641 default:
1642 // expr isn't actually an lvalue, create a dummy variable for it. This case happens due
1643 // to the need to store values in temporary variables during function calls (see
1644 // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been
1645 // caught by IRGenerator
1646 SpvId result = this->nextId();
1647 SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction);
1648 this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction,
1649 fVariableBuffer);
1650 this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
1651 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1652 *this,
1653 result,
1654 this->getType(expr.fType)));
1655 }
1656 }
1657
writeVariableReference(const VariableReference & ref,OutputStream & out)1658 SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
1659 SpvId result = this->nextId();
1660 auto entry = fVariableMap.find(&ref.fVariable);
1661 ASSERT(entry != fVariableMap.end());
1662 SpvId var = entry->second;
1663 this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out);
1664 if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN &&
1665 fProgram.fSettings.fFlipY) {
1666 // need to remap to a top-left coordinate system
1667 if (fRTHeightStructId == (SpvId) -1) {
1668 // height variable hasn't been written yet
1669 std::shared_ptr<SymbolTable> st(new SymbolTable(&fErrors));
1670 ASSERT(fRTHeightFieldIndex == (SpvId) -1);
1671 std::vector<Type::Field> fields;
1672 fields.emplace_back(Modifiers(), SKSL_RTHEIGHT_NAME, fContext.fFloat_Type.get());
1673 StringFragment name("sksl_synthetic_uniforms");
1674 Type intfStruct(-1, name, fields);
1675 Layout layout(0, -1, -1, 1, -1, -1, -1, -1, Layout::Format::kUnspecified,
1676 Layout::kUnspecified_Primitive, -1, -1, "", Layout::kNo_Key,
1677 StringFragment());
1678 Variable* intfVar = new Variable(-1,
1679 Modifiers(layout, Modifiers::kUniform_Flag),
1680 name,
1681 intfStruct,
1682 Variable::kGlobal_Storage);
1683 fSynthetics.takeOwnership(intfVar);
1684 InterfaceBlock intf(-1, intfVar, name, String(""),
1685 std::vector<std::unique_ptr<Expression>>(), st);
1686 fRTHeightStructId = this->writeInterfaceBlock(intf);
1687 fRTHeightFieldIndex = 0;
1688 }
1689 ASSERT(fRTHeightFieldIndex != (SpvId) -1);
1690 // write float4(gl_FragCoord.x, u_skRTHeight - gl_FragCoord.y, 0.0, 1.0)
1691 SpvId xId = this->nextId();
1692 this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), xId,
1693 result, 0, out);
1694 IntLiteral fieldIndex(fContext, -1, fRTHeightFieldIndex);
1695 SpvId fieldIndexId = this->writeIntLiteral(fieldIndex);
1696 SpvId heightPtr = this->nextId();
1697 this->writeOpCode(SpvOpAccessChain, 5, out);
1698 this->writeWord(this->getPointerType(*fContext.fFloat_Type, SpvStorageClassUniform), out);
1699 this->writeWord(heightPtr, out);
1700 this->writeWord(fRTHeightStructId, out);
1701 this->writeWord(fieldIndexId, out);
1702 SpvId heightRead = this->nextId();
1703 this->writeInstruction(SpvOpLoad, this->getType(*fContext.fFloat_Type), heightRead,
1704 heightPtr, out);
1705 SpvId rawYId = this->nextId();
1706 this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), rawYId,
1707 result, 1, out);
1708 SpvId flippedYId = this->nextId();
1709 this->writeInstruction(SpvOpFSub, this->getType(*fContext.fFloat_Type), flippedYId,
1710 heightRead, rawYId, out);
1711 FloatLiteral zero(fContext, -1, 0.0);
1712 SpvId zeroId = writeFloatLiteral(zero);
1713 FloatLiteral one(fContext, -1, 1.0);
1714 SpvId oneId = writeFloatLiteral(one);
1715 SpvId flipped = this->nextId();
1716 this->writeOpCode(SpvOpCompositeConstruct, 7, out);
1717 this->writeWord(this->getType(*fContext.fFloat4_Type), out);
1718 this->writeWord(flipped, out);
1719 this->writeWord(xId, out);
1720 this->writeWord(flippedYId, out);
1721 this->writeWord(zeroId, out);
1722 this->writeWord(oneId, out);
1723 return flipped;
1724 }
1725 return result;
1726 }
1727
writeIndexExpression(const IndexExpression & expr,OutputStream & out)1728 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
1729 return getLValue(expr, out)->load(out);
1730 }
1731
writeFieldAccess(const FieldAccess & f,OutputStream & out)1732 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
1733 return getLValue(f, out)->load(out);
1734 }
1735
writeSwizzle(const Swizzle & swizzle,OutputStream & out)1736 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
1737 SpvId base = this->writeExpression(*swizzle.fBase, out);
1738 SpvId result = this->nextId();
1739 size_t count = swizzle.fComponents.size();
1740 if (count == 1) {
1741 this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base,
1742 swizzle.fComponents[0], out);
1743 } else {
1744 this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
1745 this->writeWord(this->getType(swizzle.fType), out);
1746 this->writeWord(result, out);
1747 this->writeWord(base, out);
1748 this->writeWord(base, out);
1749 for (int component : swizzle.fComponents) {
1750 this->writeWord(component, out);
1751 }
1752 }
1753 return result;
1754 }
1755
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)1756 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
1757 const Type& operandType, SpvId lhs,
1758 SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
1759 SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
1760 SpvId result = this->nextId();
1761 if (is_float(fContext, operandType)) {
1762 this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
1763 } else if (is_signed(fContext, operandType)) {
1764 this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
1765 } else if (is_unsigned(fContext, operandType)) {
1766 this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
1767 } else if (operandType == *fContext.fBool_Type) {
1768 this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
1769 } else {
1770 ABORT("invalid operandType: %s", operandType.description().c_str());
1771 }
1772 return result;
1773 }
1774
is_assignment(Token::Kind op)1775 bool is_assignment(Token::Kind op) {
1776 switch (op) {
1777 case Token::EQ: // fall through
1778 case Token::PLUSEQ: // fall through
1779 case Token::MINUSEQ: // fall through
1780 case Token::STAREQ: // fall through
1781 case Token::SLASHEQ: // fall through
1782 case Token::PERCENTEQ: // fall through
1783 case Token::SHLEQ: // fall through
1784 case Token::SHREQ: // fall through
1785 case Token::BITWISEOREQ: // fall through
1786 case Token::BITWISEXOREQ: // fall through
1787 case Token::BITWISEANDEQ: // fall through
1788 case Token::LOGICALOREQ: // fall through
1789 case Token::LOGICALXOREQ: // fall through
1790 case Token::LOGICALANDEQ:
1791 return true;
1792 default:
1793 return false;
1794 }
1795 }
1796
foldToBool(SpvId id,const Type & operandType,OutputStream & out)1797 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, OutputStream& out) {
1798 if (operandType.kind() == Type::kVector_Kind) {
1799 SpvId result = this->nextId();
1800 this->writeInstruction(SpvOpAll, this->getType(*fContext.fBool_Type), result, id, out);
1801 return result;
1802 }
1803 return id;
1804 }
1805
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,OutputStream & out)1806 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
1807 SpvOp_ floatOperator, SpvOp_ intOperator,
1808 OutputStream& out) {
1809 SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator;
1810 ASSERT(operandType.kind() == Type::kMatrix_Kind);
1811 SpvId rowType = this->getType(operandType.componentType().toCompound(fContext,
1812 operandType.columns(),
1813 1));
1814 SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext,
1815 operandType.columns(),
1816 1));
1817 SpvId boolType = this->getType(*fContext.fBool_Type);
1818 SpvId result = 0;
1819 for (int i = 0; i < operandType.rows(); i++) {
1820 SpvId rowL = this->nextId();
1821 this->writeInstruction(SpvOpCompositeExtract, rowType, rowL, lhs, 0, out);
1822 SpvId rowR = this->nextId();
1823 this->writeInstruction(SpvOpCompositeExtract, rowType, rowR, rhs, 0, out);
1824 SpvId compare = this->nextId();
1825 this->writeInstruction(compareOp, bvecType, compare, rowL, rowR, out);
1826 SpvId all = this->nextId();
1827 this->writeInstruction(SpvOpAll, boolType, all, compare, out);
1828 if (result != 0) {
1829 SpvId next = this->nextId();
1830 this->writeInstruction(SpvOpLogicalAnd, boolType, next, result, all, out);
1831 result = next;
1832 }
1833 else {
1834 result = all;
1835 }
1836 }
1837 return result;
1838 }
1839
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)1840 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
1841 // handle cases where we don't necessarily evaluate both LHS and RHS
1842 switch (b.fOperator) {
1843 case Token::EQ: {
1844 SpvId rhs = this->writeExpression(*b.fRight, out);
1845 this->getLValue(*b.fLeft, out)->store(rhs, out);
1846 return rhs;
1847 }
1848 case Token::LOGICALAND:
1849 return this->writeLogicalAnd(b, out);
1850 case Token::LOGICALOR:
1851 return this->writeLogicalOr(b, out);
1852 default:
1853 break;
1854 }
1855
1856 // "normal" operators
1857 const Type& resultType = b.fType;
1858 std::unique_ptr<LValue> lvalue;
1859 SpvId lhs;
1860 if (is_assignment(b.fOperator)) {
1861 lvalue = this->getLValue(*b.fLeft, out);
1862 lhs = lvalue->load(out);
1863 } else {
1864 lvalue = nullptr;
1865 lhs = this->writeExpression(*b.fLeft, out);
1866 }
1867 SpvId rhs = this->writeExpression(*b.fRight, out);
1868 if (b.fOperator == Token::COMMA) {
1869 return rhs;
1870 }
1871 Type tmp("<invalid>");
1872 // component type we are operating on: float, int, uint
1873 const Type* operandType;
1874 // IR allows mismatched types in expressions (e.g. float2* float), but they need special handling
1875 // in SPIR-V
1876 if (this->getActualType(b.fLeft->fType) != this->getActualType(b.fRight->fType)) {
1877 if (b.fLeft->fType.kind() == Type::kVector_Kind &&
1878 b.fRight->fType.isNumber()) {
1879 // promote number to vector
1880 SpvId vec = this->nextId();
1881 this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
1882 this->writeWord(this->getType(resultType), out);
1883 this->writeWord(vec, out);
1884 for (int i = 0; i < resultType.columns(); i++) {
1885 this->writeWord(rhs, out);
1886 }
1887 rhs = vec;
1888 operandType = &b.fRight->fType;
1889 } else if (b.fRight->fType.kind() == Type::kVector_Kind &&
1890 b.fLeft->fType.isNumber()) {
1891 // promote number to vector
1892 SpvId vec = this->nextId();
1893 this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
1894 this->writeWord(this->getType(resultType), out);
1895 this->writeWord(vec, out);
1896 for (int i = 0; i < resultType.columns(); i++) {
1897 this->writeWord(lhs, out);
1898 }
1899 lhs = vec;
1900 ASSERT(!lvalue);
1901 operandType = &b.fLeft->fType;
1902 } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) {
1903 SpvOp_ op;
1904 if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
1905 op = SpvOpMatrixTimesMatrix;
1906 } else if (b.fRight->fType.kind() == Type::kVector_Kind) {
1907 op = SpvOpMatrixTimesVector;
1908 } else {
1909 ASSERT(b.fRight->fType.kind() == Type::kScalar_Kind);
1910 op = SpvOpMatrixTimesScalar;
1911 }
1912 SpvId result = this->nextId();
1913 this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out);
1914 if (b.fOperator == Token::STAREQ) {
1915 lvalue->store(result, out);
1916 } else {
1917 ASSERT(b.fOperator == Token::STAR);
1918 }
1919 return result;
1920 } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
1921 SpvId result = this->nextId();
1922 if (b.fLeft->fType.kind() == Type::kVector_Kind) {
1923 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result,
1924 lhs, rhs, out);
1925 } else {
1926 ASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind);
1927 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs,
1928 lhs, out);
1929 }
1930 if (b.fOperator == Token::STAREQ) {
1931 lvalue->store(result, out);
1932 } else {
1933 ASSERT(b.fOperator == Token::STAR);
1934 }
1935 return result;
1936 } else {
1937 ABORT("unsupported binary expression: %s", b.description().c_str());
1938 }
1939 } else {
1940 tmp = this->getActualType(b.fLeft->fType);
1941 operandType = &tmp;
1942 ASSERT(*operandType == this->getActualType(b.fRight->fType));
1943 }
1944 switch (b.fOperator) {
1945 case Token::EQEQ: {
1946 if (operandType->kind() == Type::kMatrix_Kind) {
1947 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
1948 SpvOpIEqual, out);
1949 }
1950 ASSERT(resultType == *fContext.fBool_Type);
1951 return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1952 SpvOpFOrdEqual, SpvOpIEqual,
1953 SpvOpIEqual, SpvOpLogicalEqual, out),
1954 *operandType, out);
1955 }
1956 case Token::NEQ:
1957 if (operandType->kind() == Type::kMatrix_Kind) {
1958 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
1959 SpvOpINotEqual, out);
1960 }
1961 ASSERT(resultType == *fContext.fBool_Type);
1962 return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1963 SpvOpFOrdNotEqual, SpvOpINotEqual,
1964 SpvOpINotEqual, SpvOpLogicalNotEqual,
1965 out),
1966 *operandType, out);
1967 case Token::GT:
1968 ASSERT(resultType == *fContext.fBool_Type);
1969 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1970 SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
1971 SpvOpUGreaterThan, SpvOpUndef, out);
1972 case Token::LT:
1973 ASSERT(resultType == *fContext.fBool_Type);
1974 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
1975 SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
1976 case Token::GTEQ:
1977 ASSERT(resultType == *fContext.fBool_Type);
1978 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1979 SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
1980 SpvOpUGreaterThanEqual, SpvOpUndef, out);
1981 case Token::LTEQ:
1982 ASSERT(resultType == *fContext.fBool_Type);
1983 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1984 SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
1985 SpvOpULessThanEqual, SpvOpUndef, out);
1986 case Token::PLUS:
1987 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
1988 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
1989 case Token::MINUS:
1990 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
1991 SpvOpISub, SpvOpISub, SpvOpUndef, out);
1992 case Token::STAR:
1993 if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
1994 b.fRight->fType.kind() == Type::kMatrix_Kind) {
1995 // matrix multiply
1996 SpvId result = this->nextId();
1997 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
1998 lhs, rhs, out);
1999 return result;
2000 }
2001 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2002 SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2003 case Token::SLASH:
2004 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2005 SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2006 case Token::PERCENT:
2007 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2008 SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2009 case Token::SHL:
2010 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2011 SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
2012 SpvOpUndef, out);
2013 case Token::SHR:
2014 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2015 SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
2016 SpvOpUndef, out);
2017 case Token::BITWISEAND:
2018 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2019 SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
2020 case Token::BITWISEOR:
2021 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2022 SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
2023 case Token::BITWISEXOR:
2024 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2025 SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
2026 case Token::PLUSEQ: {
2027 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2028 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2029 ASSERT(lvalue);
2030 lvalue->store(result, out);
2031 return result;
2032 }
2033 case Token::MINUSEQ: {
2034 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2035 SpvOpISub, SpvOpISub, SpvOpUndef, out);
2036 ASSERT(lvalue);
2037 lvalue->store(result, out);
2038 return result;
2039 }
2040 case Token::STAREQ: {
2041 if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2042 b.fRight->fType.kind() == Type::kMatrix_Kind) {
2043 // matrix multiply
2044 SpvId result = this->nextId();
2045 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2046 lhs, rhs, out);
2047 ASSERT(lvalue);
2048 lvalue->store(result, out);
2049 return result;
2050 }
2051 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2052 SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2053 ASSERT(lvalue);
2054 lvalue->store(result, out);
2055 return result;
2056 }
2057 case Token::SLASHEQ: {
2058 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2059 SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2060 ASSERT(lvalue);
2061 lvalue->store(result, out);
2062 return result;
2063 }
2064 case Token::PERCENTEQ: {
2065 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2066 SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2067 ASSERT(lvalue);
2068 lvalue->store(result, out);
2069 return result;
2070 }
2071 case Token::SHLEQ: {
2072 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2073 SpvOpUndef, SpvOpShiftLeftLogical,
2074 SpvOpShiftLeftLogical, SpvOpUndef, out);
2075 ASSERT(lvalue);
2076 lvalue->store(result, out);
2077 return result;
2078 }
2079 case Token::SHREQ: {
2080 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2081 SpvOpUndef, SpvOpShiftRightArithmetic,
2082 SpvOpShiftRightLogical, SpvOpUndef, out);
2083 ASSERT(lvalue);
2084 lvalue->store(result, out);
2085 return result;
2086 }
2087 case Token::BITWISEANDEQ: {
2088 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2089 SpvOpUndef, SpvOpBitwiseAnd, SpvOpBitwiseAnd,
2090 SpvOpUndef, out);
2091 ASSERT(lvalue);
2092 lvalue->store(result, out);
2093 return result;
2094 }
2095 case Token::BITWISEOREQ: {
2096 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2097 SpvOpUndef, SpvOpBitwiseOr, SpvOpBitwiseOr,
2098 SpvOpUndef, out);
2099 ASSERT(lvalue);
2100 lvalue->store(result, out);
2101 return result;
2102 }
2103 case Token::BITWISEXOREQ: {
2104 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2105 SpvOpUndef, SpvOpBitwiseXor, SpvOpBitwiseXor,
2106 SpvOpUndef, out);
2107 ASSERT(lvalue);
2108 lvalue->store(result, out);
2109 return result;
2110 }
2111 default:
2112 ABORT("unsupported binary expression: %s", b.description().c_str());
2113 }
2114 }
2115
writeLogicalAnd(const BinaryExpression & a,OutputStream & out)2116 SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) {
2117 ASSERT(a.fOperator == Token::LOGICALAND);
2118 BoolLiteral falseLiteral(fContext, -1, false);
2119 SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
2120 SpvId lhs = this->writeExpression(*a.fLeft, out);
2121 SpvId rhsLabel = this->nextId();
2122 SpvId end = this->nextId();
2123 SpvId lhsBlock = fCurrentBlock;
2124 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2125 this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
2126 this->writeLabel(rhsLabel, out);
2127 SpvId rhs = this->writeExpression(*a.fRight, out);
2128 SpvId rhsBlock = fCurrentBlock;
2129 this->writeInstruction(SpvOpBranch, end, out);
2130 this->writeLabel(end, out);
2131 SpvId result = this->nextId();
2132 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant,
2133 lhsBlock, rhs, rhsBlock, out);
2134 return result;
2135 }
2136
writeLogicalOr(const BinaryExpression & o,OutputStream & out)2137 SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, OutputStream& out) {
2138 ASSERT(o.fOperator == Token::LOGICALOR);
2139 BoolLiteral trueLiteral(fContext, -1, true);
2140 SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
2141 SpvId lhs = this->writeExpression(*o.fLeft, out);
2142 SpvId rhsLabel = this->nextId();
2143 SpvId end = this->nextId();
2144 SpvId lhsBlock = fCurrentBlock;
2145 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2146 this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
2147 this->writeLabel(rhsLabel, out);
2148 SpvId rhs = this->writeExpression(*o.fRight, out);
2149 SpvId rhsBlock = fCurrentBlock;
2150 this->writeInstruction(SpvOpBranch, end, out);
2151 this->writeLabel(end, out);
2152 SpvId result = this->nextId();
2153 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant,
2154 lhsBlock, rhs, rhsBlock, out);
2155 return result;
2156 }
2157
writeTernaryExpression(const TernaryExpression & t,OutputStream & out)2158 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
2159 SpvId test = this->writeExpression(*t.fTest, out);
2160 if (t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) {
2161 // both true and false are constants, can just use OpSelect
2162 SpvId result = this->nextId();
2163 SpvId trueId = this->writeExpression(*t.fIfTrue, out);
2164 SpvId falseId = this->writeExpression(*t.fIfFalse, out);
2165 this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId,
2166 out);
2167 return result;
2168 }
2169 // was originally using OpPhi to choose the result, but for some reason that is crashing on
2170 // Adreno. Switched to storing the result in a temp variable as glslang does.
2171 SpvId var = this->nextId();
2172 this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction),
2173 var, SpvStorageClassFunction, fVariableBuffer);
2174 SpvId trueLabel = this->nextId();
2175 SpvId falseLabel = this->nextId();
2176 SpvId end = this->nextId();
2177 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2178 this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
2179 this->writeLabel(trueLabel, out);
2180 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out);
2181 this->writeInstruction(SpvOpBranch, end, out);
2182 this->writeLabel(falseLabel, out);
2183 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out);
2184 this->writeInstruction(SpvOpBranch, end, out);
2185 this->writeLabel(end, out);
2186 SpvId result = this->nextId();
2187 this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out);
2188 return result;
2189 }
2190
create_literal_1(const Context & context,const Type & type)2191 std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
2192 if (type.isInteger()) {
2193 return std::unique_ptr<Expression>(new IntLiteral(context, -1, 1, &type));
2194 }
2195 else if (type.isFloat()) {
2196 return std::unique_ptr<Expression>(new FloatLiteral(context, -1, 1.0, &type));
2197 } else {
2198 ABORT("math is unsupported on type '%s'", type.name().c_str());
2199 }
2200 }
2201
writePrefixExpression(const PrefixExpression & p,OutputStream & out)2202 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
2203 if (p.fOperator == Token::MINUS) {
2204 SpvId result = this->nextId();
2205 SpvId typeId = this->getType(p.fType);
2206 SpvId expr = this->writeExpression(*p.fOperand, out);
2207 if (is_float(fContext, p.fType)) {
2208 this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
2209 } else if (is_signed(fContext, p.fType)) {
2210 this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
2211 } else {
2212 ABORT("unsupported prefix expression %s", p.description().c_str());
2213 };
2214 return result;
2215 }
2216 switch (p.fOperator) {
2217 case Token::PLUS:
2218 return this->writeExpression(*p.fOperand, out);
2219 case Token::PLUSPLUS: {
2220 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2221 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2222 SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2223 SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
2224 out);
2225 lv->store(result, out);
2226 return result;
2227 }
2228 case Token::MINUSMINUS: {
2229 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2230 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2231 SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2232 SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef,
2233 out);
2234 lv->store(result, out);
2235 return result;
2236 }
2237 case Token::LOGICALNOT: {
2238 ASSERT(p.fOperand->fType == *fContext.fBool_Type);
2239 SpvId result = this->nextId();
2240 this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result,
2241 this->writeExpression(*p.fOperand, out), out);
2242 return result;
2243 }
2244 case Token::BITWISENOT: {
2245 SpvId result = this->nextId();
2246 this->writeInstruction(SpvOpNot, this->getType(p.fOperand->fType), result,
2247 this->writeExpression(*p.fOperand, out), out);
2248 return result;
2249 }
2250 default:
2251 ABORT("unsupported prefix expression: %s", p.description().c_str());
2252 }
2253 }
2254
writePostfixExpression(const PostfixExpression & p,OutputStream & out)2255 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
2256 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2257 SpvId result = lv->load(out);
2258 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2259 switch (p.fOperator) {
2260 case Token::PLUSPLUS: {
2261 SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd,
2262 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2263 lv->store(temp, out);
2264 return result;
2265 }
2266 case Token::MINUSMINUS: {
2267 SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub,
2268 SpvOpISub, SpvOpISub, SpvOpUndef, out);
2269 lv->store(temp, out);
2270 return result;
2271 }
2272 default:
2273 ABORT("unsupported postfix expression %s", p.description().c_str());
2274 }
2275 }
2276
writeBoolLiteral(const BoolLiteral & b)2277 SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
2278 if (b.fValue) {
2279 if (fBoolTrue == 0) {
2280 fBoolTrue = this->nextId();
2281 this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue,
2282 fConstantBuffer);
2283 }
2284 return fBoolTrue;
2285 } else {
2286 if (fBoolFalse == 0) {
2287 fBoolFalse = this->nextId();
2288 this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse,
2289 fConstantBuffer);
2290 }
2291 return fBoolFalse;
2292 }
2293 }
2294
writeIntLiteral(const IntLiteral & i)2295 SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) {
2296 if (i.fType == *fContext.fInt_Type) {
2297 auto entry = fIntConstants.find(i.fValue);
2298 if (entry == fIntConstants.end()) {
2299 SpvId result = this->nextId();
2300 this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2301 fConstantBuffer);
2302 fIntConstants[i.fValue] = result;
2303 return result;
2304 }
2305 return entry->second;
2306 } else {
2307 ASSERT(i.fType == *fContext.fUInt_Type);
2308 auto entry = fUIntConstants.find(i.fValue);
2309 if (entry == fUIntConstants.end()) {
2310 SpvId result = this->nextId();
2311 this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2312 fConstantBuffer);
2313 fUIntConstants[i.fValue] = result;
2314 return result;
2315 }
2316 return entry->second;
2317 }
2318 }
2319
writeFloatLiteral(const FloatLiteral & f)2320 SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
2321 if (f.fType == *fContext.fFloat_Type || f.fType == *fContext.fHalf_Type) {
2322 float value = (float) f.fValue;
2323 auto entry = fFloatConstants.find(value);
2324 if (entry == fFloatConstants.end()) {
2325 SpvId result = this->nextId();
2326 uint32_t bits;
2327 ASSERT(sizeof(bits) == sizeof(value));
2328 memcpy(&bits, &value, sizeof(bits));
2329 this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits,
2330 fConstantBuffer);
2331 fFloatConstants[value] = result;
2332 return result;
2333 }
2334 return entry->second;
2335 } else {
2336 ASSERT(f.fType == *fContext.fDouble_Type);
2337 auto entry = fDoubleConstants.find(f.fValue);
2338 if (entry == fDoubleConstants.end()) {
2339 SpvId result = this->nextId();
2340 uint64_t bits;
2341 ASSERT(sizeof(bits) == sizeof(f.fValue));
2342 memcpy(&bits, &f.fValue, sizeof(bits));
2343 this->writeInstruction(SpvOpConstant, this->getType(f.fType), result,
2344 bits & 0xffffffff, bits >> 32, fConstantBuffer);
2345 fDoubleConstants[f.fValue] = result;
2346 return result;
2347 }
2348 return entry->second;
2349 }
2350 }
2351
writeFunctionStart(const FunctionDeclaration & f,OutputStream & out)2352 SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
2353 SpvId result = fFunctionMap[&f];
2354 this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result,
2355 SpvFunctionControlMaskNone, this->getFunctionType(f), out);
2356 this->writeInstruction(SpvOpName, result, f.fName, fNameBuffer);
2357 for (size_t i = 0; i < f.fParameters.size(); i++) {
2358 SpvId id = this->nextId();
2359 fVariableMap[f.fParameters[i]] = id;
2360 SpvId type;
2361 type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction);
2362 this->writeInstruction(SpvOpFunctionParameter, type, id, out);
2363 }
2364 return result;
2365 }
2366
writeFunction(const FunctionDefinition & f,OutputStream & out)2367 SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
2368 fVariableBuffer.reset();
2369 SpvId result = this->writeFunctionStart(f.fDeclaration, out);
2370 this->writeLabel(this->nextId(), out);
2371 if (f.fDeclaration.fName == "main") {
2372 write_stringstream(fGlobalInitializersBuffer, out);
2373 }
2374 StringStream bodyBuffer;
2375 this->writeBlock((Block&) *f.fBody, bodyBuffer);
2376 write_stringstream(fVariableBuffer, out);
2377 write_stringstream(bodyBuffer, out);
2378 if (fCurrentBlock) {
2379 if (f.fDeclaration.fReturnType == *fContext.fVoid_Type) {
2380 this->writeInstruction(SpvOpReturn, out);
2381 } else {
2382 this->writeInstruction(SpvOpUnreachable, out);
2383 }
2384 }
2385 this->writeInstruction(SpvOpFunctionEnd, out);
2386 return result;
2387 }
2388
writeLayout(const Layout & layout,SpvId target)2389 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) {
2390 if (layout.fLocation >= 0) {
2391 this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
2392 fDecorationBuffer);
2393 }
2394 if (layout.fBinding >= 0) {
2395 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
2396 fDecorationBuffer);
2397 }
2398 if (layout.fIndex >= 0) {
2399 this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
2400 fDecorationBuffer);
2401 }
2402 if (layout.fSet >= 0) {
2403 this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
2404 fDecorationBuffer);
2405 }
2406 if (layout.fInputAttachmentIndex >= 0) {
2407 this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
2408 layout.fInputAttachmentIndex, fDecorationBuffer);
2409 fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
2410 }
2411 if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN &&
2412 layout.fBuiltin != SK_IN_BUILTIN) {
2413 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
2414 fDecorationBuffer);
2415 }
2416 }
2417
writeLayout(const Layout & layout,SpvId target,int member)2418 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) {
2419 if (layout.fLocation >= 0) {
2420 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
2421 layout.fLocation, fDecorationBuffer);
2422 }
2423 if (layout.fBinding >= 0) {
2424 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding,
2425 layout.fBinding, fDecorationBuffer);
2426 }
2427 if (layout.fIndex >= 0) {
2428 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
2429 layout.fIndex, fDecorationBuffer);
2430 }
2431 if (layout.fSet >= 0) {
2432 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet,
2433 layout.fSet, fDecorationBuffer);
2434 }
2435 if (layout.fInputAttachmentIndex >= 0) {
2436 this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
2437 layout.fInputAttachmentIndex, fDecorationBuffer);
2438 }
2439 if (layout.fBuiltin >= 0) {
2440 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
2441 layout.fBuiltin, fDecorationBuffer);
2442 }
2443 }
2444
writeInterfaceBlock(const InterfaceBlock & intf)2445 SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2446 bool isBuffer = (0 != (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag));
2447 bool pushConstant = (0 != (intf.fVariable.fModifiers.fLayout.fFlags &
2448 Layout::kPushConstant_Flag));
2449 MemoryLayout memoryLayout = (pushConstant || isBuffer) ?
2450 MemoryLayout(MemoryLayout::k430_Standard) :
2451 fDefaultLayout;
2452 SpvId result = this->nextId();
2453 const Type* type = &intf.fVariable.fType;
2454 if (fProgram.fInputs.fRTHeight) {
2455 ASSERT(fRTHeightStructId == (SpvId) -1);
2456 ASSERT(fRTHeightFieldIndex == (SpvId) -1);
2457 std::vector<Type::Field> fields = type->fields();
2458 fRTHeightStructId = result;
2459 fRTHeightFieldIndex = fields.size();
2460 fields.emplace_back(Modifiers(), StringFragment(SKSL_RTHEIGHT_NAME), fContext.fFloat_Type.get());
2461 type = new Type(type->fOffset, type->name(), fields);
2462 }
2463 SpvId typeId = this->getType(*type, memoryLayout);
2464 if (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag) {
2465 this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBufferBlock, fDecorationBuffer);
2466 } else {
2467 this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
2468 }
2469 SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers);
2470 SpvId ptrType = this->nextId();
2471 this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
2472 this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
2473 Layout layout = intf.fVariable.fModifiers.fLayout;
2474 if (intf.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag && layout.fSet == -1) {
2475 layout.fSet = 0;
2476 }
2477 this->writeLayout(layout, result);
2478 fVariableMap[&intf.fVariable] = result;
2479 if (fProgram.fInputs.fRTHeight) {
2480 delete type;
2481 }
2482 return result;
2483 }
2484
writePrecisionModifier(const Modifiers & modifiers,SpvId id)2485 void SPIRVCodeGenerator::writePrecisionModifier(const Modifiers& modifiers, SpvId id) {
2486 if ((modifiers.fFlags & Modifiers::kLowp_Flag) |
2487 (modifiers.fFlags & Modifiers::kMediump_Flag)) {
2488 this->writeInstruction(SpvOpDecorate, id, SpvDecorationRelaxedPrecision, fDecorationBuffer);
2489 }
2490 }
2491
2492 #define BUILTIN_IGNORE 9999
writeGlobalVars(Program::Kind kind,const VarDeclarations & decl,OutputStream & out)2493 void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclarations& decl,
2494 OutputStream& out) {
2495 for (size_t i = 0; i < decl.fVars.size(); i++) {
2496 if (decl.fVars[i]->fKind == Statement::kNop_Kind) {
2497 continue;
2498 }
2499 const VarDeclaration& varDecl = (VarDeclaration&) *decl.fVars[i];
2500 const Variable* var = varDecl.fVar;
2501 // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2502 // in the OpenGL backend.
2503 ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2504 Modifiers::kWriteOnly_Flag |
2505 Modifiers::kCoherent_Flag |
2506 Modifiers::kVolatile_Flag |
2507 Modifiers::kRestrict_Flag)));
2508 if (var->fModifiers.fLayout.fBuiltin == BUILTIN_IGNORE) {
2509 continue;
2510 }
2511 if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
2512 kind != Program::kFragment_Kind) {
2513 ASSERT(!fProgram.fSettings.fFragColorIsInOut);
2514 continue;
2515 }
2516 if (!var->fReadCount && !var->fWriteCount &&
2517 !(var->fModifiers.fFlags & (Modifiers::kIn_Flag |
2518 Modifiers::kOut_Flag |
2519 Modifiers::kUniform_Flag |
2520 Modifiers::kBuffer_Flag))) {
2521 // variable is dead and not an input / output var (the Vulkan debug layers complain if
2522 // we elide an interface var, even if it's dead)
2523 continue;
2524 }
2525 SpvStorageClass_ storageClass;
2526 if (var->fModifiers.fFlags & Modifiers::kIn_Flag) {
2527 storageClass = SpvStorageClassInput;
2528 } else if (var->fModifiers.fFlags & Modifiers::kOut_Flag) {
2529 storageClass = SpvStorageClassOutput;
2530 } else if (var->fModifiers.fFlags & Modifiers::kUniform_Flag) {
2531 if (var->fType.kind() == Type::kSampler_Kind) {
2532 storageClass = SpvStorageClassUniformConstant;
2533 } else {
2534 storageClass = SpvStorageClassUniform;
2535 }
2536 } else {
2537 storageClass = SpvStorageClassPrivate;
2538 }
2539 SpvId id = this->nextId();
2540 fVariableMap[var] = id;
2541 SpvId type = this->getPointerType(var->fType, storageClass);
2542 this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer);
2543 this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
2544 this->writePrecisionModifier(var->fModifiers, id);
2545 if (varDecl.fValue) {
2546 ASSERT(!fCurrentBlock);
2547 fCurrentBlock = -1;
2548 SpvId value = this->writeExpression(*varDecl.fValue, fGlobalInitializersBuffer);
2549 this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
2550 fCurrentBlock = 0;
2551 }
2552 this->writeLayout(var->fModifiers.fLayout, id);
2553 if (var->fModifiers.fFlags & Modifiers::kFlat_Flag) {
2554 this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
2555 }
2556 if (var->fModifiers.fFlags & Modifiers::kNoPerspective_Flag) {
2557 this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
2558 fDecorationBuffer);
2559 }
2560 }
2561 }
2562
writeVarDeclarations(const VarDeclarations & decl,OutputStream & out)2563 void SPIRVCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, OutputStream& out) {
2564 for (const auto& stmt : decl.fVars) {
2565 ASSERT(stmt->fKind == Statement::kVarDeclaration_Kind);
2566 VarDeclaration& varDecl = (VarDeclaration&) *stmt;
2567 const Variable* var = varDecl.fVar;
2568 // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2569 // in the OpenGL backend.
2570 ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2571 Modifiers::kWriteOnly_Flag |
2572 Modifiers::kCoherent_Flag |
2573 Modifiers::kVolatile_Flag |
2574 Modifiers::kRestrict_Flag)));
2575 SpvId id = this->nextId();
2576 fVariableMap[var] = id;
2577 SpvId type = this->getPointerType(var->fType, SpvStorageClassFunction);
2578 this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
2579 this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
2580 if (varDecl.fValue) {
2581 SpvId value = this->writeExpression(*varDecl.fValue, out);
2582 this->writeInstruction(SpvOpStore, id, value, out);
2583 }
2584 }
2585 }
2586
writeStatement(const Statement & s,OutputStream & out)2587 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
2588 switch (s.fKind) {
2589 case Statement::kNop_Kind:
2590 break;
2591 case Statement::kBlock_Kind:
2592 this->writeBlock((Block&) s, out);
2593 break;
2594 case Statement::kExpression_Kind:
2595 this->writeExpression(*((ExpressionStatement&) s).fExpression, out);
2596 break;
2597 case Statement::kReturn_Kind:
2598 this->writeReturnStatement((ReturnStatement&) s, out);
2599 break;
2600 case Statement::kVarDeclarations_Kind:
2601 this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, out);
2602 break;
2603 case Statement::kIf_Kind:
2604 this->writeIfStatement((IfStatement&) s, out);
2605 break;
2606 case Statement::kFor_Kind:
2607 this->writeForStatement((ForStatement&) s, out);
2608 break;
2609 case Statement::kWhile_Kind:
2610 this->writeWhileStatement((WhileStatement&) s, out);
2611 break;
2612 case Statement::kDo_Kind:
2613 this->writeDoStatement((DoStatement&) s, out);
2614 break;
2615 case Statement::kSwitch_Kind:
2616 this->writeSwitchStatement((SwitchStatement&) s, out);
2617 break;
2618 case Statement::kBreak_Kind:
2619 this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
2620 break;
2621 case Statement::kContinue_Kind:
2622 this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
2623 break;
2624 case Statement::kDiscard_Kind:
2625 this->writeInstruction(SpvOpKill, out);
2626 break;
2627 default:
2628 ABORT("unsupported statement: %s", s.description().c_str());
2629 }
2630 }
2631
writeBlock(const Block & b,OutputStream & out)2632 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
2633 for (size_t i = 0; i < b.fStatements.size(); i++) {
2634 this->writeStatement(*b.fStatements[i], out);
2635 }
2636 }
2637
writeIfStatement(const IfStatement & stmt,OutputStream & out)2638 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
2639 SpvId test = this->writeExpression(*stmt.fTest, out);
2640 SpvId ifTrue = this->nextId();
2641 SpvId ifFalse = this->nextId();
2642 if (stmt.fIfFalse) {
2643 SpvId end = this->nextId();
2644 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2645 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2646 this->writeLabel(ifTrue, out);
2647 this->writeStatement(*stmt.fIfTrue, out);
2648 if (fCurrentBlock) {
2649 this->writeInstruction(SpvOpBranch, end, out);
2650 }
2651 this->writeLabel(ifFalse, out);
2652 this->writeStatement(*stmt.fIfFalse, out);
2653 if (fCurrentBlock) {
2654 this->writeInstruction(SpvOpBranch, end, out);
2655 }
2656 this->writeLabel(end, out);
2657 } else {
2658 this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
2659 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2660 this->writeLabel(ifTrue, out);
2661 this->writeStatement(*stmt.fIfTrue, out);
2662 if (fCurrentBlock) {
2663 this->writeInstruction(SpvOpBranch, ifFalse, out);
2664 }
2665 this->writeLabel(ifFalse, out);
2666 }
2667 }
2668
writeForStatement(const ForStatement & f,OutputStream & out)2669 void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
2670 if (f.fInitializer) {
2671 this->writeStatement(*f.fInitializer, out);
2672 }
2673 SpvId header = this->nextId();
2674 SpvId start = this->nextId();
2675 SpvId body = this->nextId();
2676 SpvId next = this->nextId();
2677 fContinueTarget.push(next);
2678 SpvId end = this->nextId();
2679 fBreakTarget.push(end);
2680 this->writeInstruction(SpvOpBranch, header, out);
2681 this->writeLabel(header, out);
2682 this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
2683 this->writeInstruction(SpvOpBranch, start, out);
2684 this->writeLabel(start, out);
2685 if (f.fTest) {
2686 SpvId test = this->writeExpression(*f.fTest, out);
2687 this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2688 }
2689 this->writeLabel(body, out);
2690 this->writeStatement(*f.fStatement, out);
2691 if (fCurrentBlock) {
2692 this->writeInstruction(SpvOpBranch, next, out);
2693 }
2694 this->writeLabel(next, out);
2695 if (f.fNext) {
2696 this->writeExpression(*f.fNext, out);
2697 }
2698 this->writeInstruction(SpvOpBranch, header, out);
2699 this->writeLabel(end, out);
2700 fBreakTarget.pop();
2701 fContinueTarget.pop();
2702 }
2703
writeWhileStatement(const WhileStatement & w,OutputStream & out)2704 void SPIRVCodeGenerator::writeWhileStatement(const WhileStatement& w, OutputStream& out) {
2705 // We believe the while loop code below will work, but Skia doesn't actually use them and
2706 // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2707 // the time being, we just fail with an error due to the lack of testing. If you encounter this
2708 // message, simply remove the error call below to see whether our while loop support actually
2709 // works.
2710 fErrors.error(w.fOffset, "internal error: while loop support has been disabled in SPIR-V, "
2711 "see SkSLSPIRVCodeGenerator.cpp for details");
2712
2713 SpvId header = this->nextId();
2714 SpvId start = this->nextId();
2715 SpvId body = this->nextId();
2716 fContinueTarget.push(start);
2717 SpvId end = this->nextId();
2718 fBreakTarget.push(end);
2719 this->writeInstruction(SpvOpBranch, header, out);
2720 this->writeLabel(header, out);
2721 this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2722 this->writeInstruction(SpvOpBranch, start, out);
2723 this->writeLabel(start, out);
2724 SpvId test = this->writeExpression(*w.fTest, out);
2725 this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2726 this->writeLabel(body, out);
2727 this->writeStatement(*w.fStatement, out);
2728 if (fCurrentBlock) {
2729 this->writeInstruction(SpvOpBranch, start, out);
2730 }
2731 this->writeLabel(end, out);
2732 fBreakTarget.pop();
2733 fContinueTarget.pop();
2734 }
2735
writeDoStatement(const DoStatement & d,OutputStream & out)2736 void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
2737 // We believe the do loop code below will work, but Skia doesn't actually use them and
2738 // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2739 // the time being, we just fail with an error due to the lack of testing. If you encounter this
2740 // message, simply remove the error call below to see whether our do loop support actually
2741 // works.
2742 fErrors.error(d.fOffset, "internal error: do loop support has been disabled in SPIR-V, see "
2743 "SkSLSPIRVCodeGenerator.cpp for details");
2744
2745 SpvId header = this->nextId();
2746 SpvId start = this->nextId();
2747 SpvId next = this->nextId();
2748 fContinueTarget.push(next);
2749 SpvId end = this->nextId();
2750 fBreakTarget.push(end);
2751 this->writeInstruction(SpvOpBranch, header, out);
2752 this->writeLabel(header, out);
2753 this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2754 this->writeInstruction(SpvOpBranch, start, out);
2755 this->writeLabel(start, out);
2756 this->writeStatement(*d.fStatement, out);
2757 if (fCurrentBlock) {
2758 this->writeInstruction(SpvOpBranch, next, out);
2759 }
2760 this->writeLabel(next, out);
2761 SpvId test = this->writeExpression(*d.fTest, out);
2762 this->writeInstruction(SpvOpBranchConditional, test, start, end, out);
2763 this->writeLabel(end, out);
2764 fBreakTarget.pop();
2765 fContinueTarget.pop();
2766 }
2767
writeSwitchStatement(const SwitchStatement & s,OutputStream & out)2768 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
2769 SpvId value = this->writeExpression(*s.fValue, out);
2770 std::vector<SpvId> labels;
2771 SpvId end = this->nextId();
2772 SpvId defaultLabel = end;
2773 fBreakTarget.push(end);
2774 int size = 3;
2775 for (const auto& c : s.fCases) {
2776 SpvId label = this->nextId();
2777 labels.push_back(label);
2778 if (c->fValue) {
2779 size += 2;
2780 } else {
2781 defaultLabel = label;
2782 }
2783 }
2784 labels.push_back(end);
2785 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2786 this->writeOpCode(SpvOpSwitch, size, out);
2787 this->writeWord(value, out);
2788 this->writeWord(defaultLabel, out);
2789 for (size_t i = 0; i < s.fCases.size(); ++i) {
2790 if (!s.fCases[i]->fValue) {
2791 continue;
2792 }
2793 ASSERT(s.fCases[i]->fValue->fKind == Expression::kIntLiteral_Kind);
2794 this->writeWord(((IntLiteral&) *s.fCases[i]->fValue).fValue, out);
2795 this->writeWord(labels[i], out);
2796 }
2797 for (size_t i = 0; i < s.fCases.size(); ++i) {
2798 this->writeLabel(labels[i], out);
2799 for (const auto& stmt : s.fCases[i]->fStatements) {
2800 this->writeStatement(*stmt, out);
2801 }
2802 if (fCurrentBlock) {
2803 this->writeInstruction(SpvOpBranch, labels[i + 1], out);
2804 }
2805 }
2806 this->writeLabel(end, out);
2807 fBreakTarget.pop();
2808 }
2809
writeReturnStatement(const ReturnStatement & r,OutputStream & out)2810 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
2811 if (r.fExpression) {
2812 this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out),
2813 out);
2814 } else {
2815 this->writeInstruction(SpvOpReturn, out);
2816 }
2817 }
2818
writeGeometryShaderExecutionMode(SpvId entryPoint,OutputStream & out)2819 void SPIRVCodeGenerator::writeGeometryShaderExecutionMode(SpvId entryPoint, OutputStream& out) {
2820 ASSERT(fProgram.fKind == Program::kGeometry_Kind);
2821 int invocations = 1;
2822 for (size_t i = 0; i < fProgram.fElements.size(); i++) {
2823 if (fProgram.fElements[i]->fKind == ProgramElement::kModifiers_Kind) {
2824 const Modifiers& m = ((ModifiersDeclaration&) *fProgram.fElements[i]).fModifiers;
2825 if (m.fFlags & Modifiers::kIn_Flag) {
2826 if (m.fLayout.fInvocations != -1) {
2827 invocations = m.fLayout.fInvocations;
2828 }
2829 SpvId input;
2830 switch (m.fLayout.fPrimitive) {
2831 case Layout::kPoints_Primitive:
2832 input = SpvExecutionModeInputPoints;
2833 break;
2834 case Layout::kLines_Primitive:
2835 input = SpvExecutionModeInputLines;
2836 break;
2837 case Layout::kLinesAdjacency_Primitive:
2838 input = SpvExecutionModeInputLinesAdjacency;
2839 break;
2840 case Layout::kTriangles_Primitive:
2841 input = SpvExecutionModeTriangles;
2842 break;
2843 case Layout::kTrianglesAdjacency_Primitive:
2844 input = SpvExecutionModeInputTrianglesAdjacency;
2845 break;
2846 default:
2847 input = 0;
2848 break;
2849 }
2850 if (input) {
2851 this->writeInstruction(SpvOpExecutionMode, entryPoint, input, out);
2852 }
2853 } else if (m.fFlags & Modifiers::kOut_Flag) {
2854 SpvId output;
2855 switch (m.fLayout.fPrimitive) {
2856 case Layout::kPoints_Primitive:
2857 output = SpvExecutionModeOutputPoints;
2858 break;
2859 case Layout::kLineStrip_Primitive:
2860 output = SpvExecutionModeOutputLineStrip;
2861 break;
2862 case Layout::kTriangleStrip_Primitive:
2863 output = SpvExecutionModeOutputTriangleStrip;
2864 break;
2865 default:
2866 output = 0;
2867 break;
2868 }
2869 if (output) {
2870 this->writeInstruction(SpvOpExecutionMode, entryPoint, output, out);
2871 }
2872 if (m.fLayout.fMaxVertices != -1) {
2873 this->writeInstruction(SpvOpExecutionMode, entryPoint,
2874 SpvExecutionModeOutputVertices, m.fLayout.fMaxVertices,
2875 out);
2876 }
2877 }
2878 }
2879 }
2880 this->writeInstruction(SpvOpExecutionMode, entryPoint, SpvExecutionModeInvocations,
2881 invocations, out);
2882 }
2883
writeInstructions(const Program & program,OutputStream & out)2884 void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
2885 fGLSLExtendedInstructions = this->nextId();
2886 StringStream body;
2887 std::set<SpvId> interfaceVars;
2888 // assign IDs to functions, determine sk_in size
2889 int skInSize = -1;
2890 for (size_t i = 0; i < program.fElements.size(); i++) {
2891 switch (program.fElements[i]->fKind) {
2892 case ProgramElement::kFunction_Kind: {
2893 FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i];
2894 fFunctionMap[&f.fDeclaration] = this->nextId();
2895 break;
2896 }
2897 case ProgramElement::kModifiers_Kind: {
2898 Modifiers& m = ((ModifiersDeclaration&) *program.fElements[i]).fModifiers;
2899 if (m.fFlags & Modifiers::kIn_Flag) {
2900 switch (m.fLayout.fPrimitive) {
2901 case Layout::kPoints_Primitive: // break
2902 case Layout::kLines_Primitive:
2903 skInSize = 1;
2904 break;
2905 case Layout::kLinesAdjacency_Primitive: // break
2906 skInSize = 2;
2907 break;
2908 case Layout::kTriangles_Primitive: // break
2909 case Layout::kTrianglesAdjacency_Primitive:
2910 skInSize = 3;
2911 break;
2912 default:
2913 break;
2914 }
2915 }
2916 break;
2917 }
2918 default:
2919 break;
2920 }
2921 }
2922 for (size_t i = 0; i < program.fElements.size(); i++) {
2923 if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) {
2924 InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i];
2925 if (SK_IN_BUILTIN == intf.fVariable.fModifiers.fLayout.fBuiltin) {
2926 ASSERT(skInSize != -1);
2927 intf.fSizes.emplace_back(new IntLiteral(fContext, -1, skInSize));
2928 }
2929 SpvId id = this->writeInterfaceBlock(intf);
2930 if ((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) ||
2931 (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) {
2932 interfaceVars.insert(id);
2933 }
2934 }
2935 }
2936 for (size_t i = 0; i < program.fElements.size(); i++) {
2937 if (program.fElements[i]->fKind == ProgramElement::kVar_Kind) {
2938 this->writeGlobalVars(program.fKind, ((VarDeclarations&) *program.fElements[i]),
2939 body);
2940 }
2941 }
2942 for (size_t i = 0; i < program.fElements.size(); i++) {
2943 if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
2944 this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body);
2945 }
2946 }
2947 const FunctionDeclaration* main = nullptr;
2948 for (auto entry : fFunctionMap) {
2949 if (entry.first->fName == "main") {
2950 main = entry.first;
2951 }
2952 }
2953 ASSERT(main);
2954 for (auto entry : fVariableMap) {
2955 const Variable* var = entry.first;
2956 if (var->fStorage == Variable::kGlobal_Storage &&
2957 ((var->fModifiers.fFlags & Modifiers::kIn_Flag) ||
2958 (var->fModifiers.fFlags & Modifiers::kOut_Flag))) {
2959 interfaceVars.insert(entry.second);
2960 }
2961 }
2962 this->writeCapabilities(out);
2963 this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
2964 this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
2965 this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (main->fName.fLength + 4) / 4) +
2966 (int32_t) interfaceVars.size(), out);
2967 switch (program.fKind) {
2968 case Program::kVertex_Kind:
2969 this->writeWord(SpvExecutionModelVertex, out);
2970 break;
2971 case Program::kFragment_Kind:
2972 this->writeWord(SpvExecutionModelFragment, out);
2973 break;
2974 case Program::kGeometry_Kind:
2975 this->writeWord(SpvExecutionModelGeometry, out);
2976 break;
2977 default:
2978 ABORT("cannot write this kind of program to SPIR-V\n");
2979 }
2980 SpvId entryPoint = fFunctionMap[main];
2981 this->writeWord(entryPoint, out);
2982 this->writeString(main->fName.fChars, main->fName.fLength, out);
2983 for (int var : interfaceVars) {
2984 this->writeWord(var, out);
2985 }
2986 if (program.fKind == Program::kGeometry_Kind) {
2987 this->writeGeometryShaderExecutionMode(entryPoint, out);
2988 }
2989 if (program.fKind == Program::kFragment_Kind) {
2990 this->writeInstruction(SpvOpExecutionMode,
2991 fFunctionMap[main],
2992 SpvExecutionModeOriginUpperLeft,
2993 out);
2994 }
2995 for (size_t i = 0; i < program.fElements.size(); i++) {
2996 if (program.fElements[i]->fKind == ProgramElement::kExtension_Kind) {
2997 this->writeInstruction(SpvOpSourceExtension,
2998 ((Extension&) *program.fElements[i]).fName.c_str(),
2999 out);
3000 }
3001 }
3002
3003 write_stringstream(fExtraGlobalsBuffer, out);
3004 write_stringstream(fNameBuffer, out);
3005 write_stringstream(fDecorationBuffer, out);
3006 write_stringstream(fConstantBuffer, out);
3007 write_stringstream(fExternalFunctionsBuffer, out);
3008 write_stringstream(body, out);
3009 }
3010
generateCode()3011 bool SPIRVCodeGenerator::generateCode() {
3012 ASSERT(!fErrors.errorCount());
3013 this->writeWord(SpvMagicNumber, *fOut);
3014 this->writeWord(SpvVersion, *fOut);
3015 this->writeWord(SKSL_MAGIC, *fOut);
3016 StringStream buffer;
3017 this->writeInstructions(fProgram, buffer);
3018 this->writeWord(fIdCount, *fOut);
3019 this->writeWord(0, *fOut); // reserved, always zero
3020 write_stringstream(buffer, *fOut);
3021 return 0 == fErrors.errorCount();
3022 }
3023
3024 }
3025