1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/codegen/SkSLSPIRVCodeGenerator.h"
9
10 #include "src/sksl/GLSL.std.450.h"
11
12 #include "include/sksl/DSLCore.h"
13 #include "src/sksl/SkSLCompiler.h"
14 #include "src/sksl/SkSLOperators.h"
15 #include "src/sksl/SkSLThreadContext.h"
16 #include "src/sksl/ir/SkSLBinaryExpression.h"
17 #include "src/sksl/ir/SkSLBlock.h"
18 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
19 #include "src/sksl/ir/SkSLConstructorCompound.h"
20 #include "src/sksl/ir/SkSLConstructorCompoundCast.h"
21 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
22 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
23 #include "src/sksl/ir/SkSLConstructorScalarCast.h"
24 #include "src/sksl/ir/SkSLConstructorSplat.h"
25 #include "src/sksl/ir/SkSLDoStatement.h"
26 #include "src/sksl/ir/SkSLExpressionStatement.h"
27 #include "src/sksl/ir/SkSLExtension.h"
28 #include "src/sksl/ir/SkSLField.h"
29 #include "src/sksl/ir/SkSLFieldAccess.h"
30 #include "src/sksl/ir/SkSLForStatement.h"
31 #include "src/sksl/ir/SkSLFunctionCall.h"
32 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
33 #include "src/sksl/ir/SkSLFunctionDefinition.h"
34 #include "src/sksl/ir/SkSLIfStatement.h"
35 #include "src/sksl/ir/SkSLIndexExpression.h"
36 #include "src/sksl/ir/SkSLInterfaceBlock.h"
37 #include "src/sksl/ir/SkSLPostfixExpression.h"
38 #include "src/sksl/ir/SkSLPrefixExpression.h"
39 #include "src/sksl/ir/SkSLReturnStatement.h"
40 #include "src/sksl/ir/SkSLSwitchStatement.h"
41 #include "src/sksl/ir/SkSLSwizzle.h"
42 #include "src/sksl/ir/SkSLTernaryExpression.h"
43 #include "src/sksl/ir/SkSLVarDeclarations.h"
44 #include "src/sksl/ir/SkSLVariableReference.h"
45
46 #ifdef SK_VULKAN
47 #include "src/gpu/vk/GrVkCaps.h"
48 #endif
49
// Highest SpvCapability index that writeCapabilities() scans for in the fCapabilities bitmask
// (every bit from 0 up to and including this one is checked).
#define kLast_Capability SpvCapabilityMultiViewport

// Negative sentinel values used in place of real SPIR-V builtin IDs. Presumably these mark
// device-space builtins this generator handles specially — TODO confirm against their uses.
constexpr int DEVICE_FRAGCOORDS_BUILTIN = -1000;
constexpr int DEVICE_CLOCKWISE_BUILTIN  = DEVICE_FRAGCOORDS_BUILTIN - 1;
54
55 namespace SkSL {
56
// Generator magic number for SkSL-produced SPIR-V. Skia's registered tool ID (31) occupies the
// upper 16 bits; the lower 16 bits are left free so the generator could be versioned later.
// https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/spir-v.xml#L84
static constexpr int32_t SKSL_MAGIC = int32_t{31} << 16;
61
setupIntrinsics()62 void SPIRVCodeGenerator::setupIntrinsics() {
63 #define ALL_GLSL(x) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, GLSLstd450 ## x, \
64 GLSLstd450 ## x, GLSLstd450 ## x, GLSLstd450 ## x}
65 #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, \
66 GLSLstd450 ## ifFloat, \
67 GLSLstd450 ## ifInt, \
68 GLSLstd450 ## ifUInt, \
69 SpvOpUndef}
70 #define ALL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
71 SpvOp ## x, SpvOp ## x, SpvOp ## x, SpvOp ## x}
72 #define BOOL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
73 SpvOpUndef, SpvOpUndef, SpvOpUndef, SpvOp ## x}
74 #define FLOAT_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
75 SpvOp ## x, SpvOpUndef, SpvOpUndef, SpvOpUndef}
76 #define SPECIAL(x) Intrinsic{kSpecial_IntrinsicOpcodeKind, k ## x ## _SpecialIntrinsic, \
77 k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
78 k ## x ## _SpecialIntrinsic}
79 fIntrinsicMap[k_round_IntrinsicKind] = ALL_GLSL(Round);
80 fIntrinsicMap[k_roundEven_IntrinsicKind] = ALL_GLSL(RoundEven);
81 fIntrinsicMap[k_trunc_IntrinsicKind] = ALL_GLSL(Trunc);
82 fIntrinsicMap[k_abs_IntrinsicKind] = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
83 fIntrinsicMap[k_sign_IntrinsicKind] = BY_TYPE_GLSL(FSign, SSign, SSign);
84 fIntrinsicMap[k_floor_IntrinsicKind] = ALL_GLSL(Floor);
85 fIntrinsicMap[k_ceil_IntrinsicKind] = ALL_GLSL(Ceil);
86 fIntrinsicMap[k_fract_IntrinsicKind] = ALL_GLSL(Fract);
87 fIntrinsicMap[k_radians_IntrinsicKind] = ALL_GLSL(Radians);
88 fIntrinsicMap[k_degrees_IntrinsicKind] = ALL_GLSL(Degrees);
89 fIntrinsicMap[k_sin_IntrinsicKind] = ALL_GLSL(Sin);
90 fIntrinsicMap[k_cos_IntrinsicKind] = ALL_GLSL(Cos);
91 fIntrinsicMap[k_tan_IntrinsicKind] = ALL_GLSL(Tan);
92 fIntrinsicMap[k_asin_IntrinsicKind] = ALL_GLSL(Asin);
93 fIntrinsicMap[k_acos_IntrinsicKind] = ALL_GLSL(Acos);
94 fIntrinsicMap[k_atan_IntrinsicKind] = SPECIAL(Atan);
95 fIntrinsicMap[k_sinh_IntrinsicKind] = ALL_GLSL(Sinh);
96 fIntrinsicMap[k_cosh_IntrinsicKind] = ALL_GLSL(Cosh);
97 fIntrinsicMap[k_tanh_IntrinsicKind] = ALL_GLSL(Tanh);
98 fIntrinsicMap[k_asinh_IntrinsicKind] = ALL_GLSL(Asinh);
99 fIntrinsicMap[k_acosh_IntrinsicKind] = ALL_GLSL(Acosh);
100 fIntrinsicMap[k_atanh_IntrinsicKind] = ALL_GLSL(Atanh);
101 fIntrinsicMap[k_pow_IntrinsicKind] = ALL_GLSL(Pow);
102 fIntrinsicMap[k_exp_IntrinsicKind] = ALL_GLSL(Exp);
103 fIntrinsicMap[k_log_IntrinsicKind] = ALL_GLSL(Log);
104 fIntrinsicMap[k_exp2_IntrinsicKind] = ALL_GLSL(Exp2);
105 fIntrinsicMap[k_log2_IntrinsicKind] = ALL_GLSL(Log2);
106 fIntrinsicMap[k_sqrt_IntrinsicKind] = ALL_GLSL(Sqrt);
107 fIntrinsicMap[k_inverse_IntrinsicKind] = ALL_GLSL(MatrixInverse);
108 fIntrinsicMap[k_outerProduct_IntrinsicKind] = ALL_SPIRV(OuterProduct);
109 fIntrinsicMap[k_transpose_IntrinsicKind] = ALL_SPIRV(Transpose);
110 fIntrinsicMap[k_isinf_IntrinsicKind] = ALL_SPIRV(IsInf);
111 fIntrinsicMap[k_isnan_IntrinsicKind] = ALL_SPIRV(IsNan);
112 fIntrinsicMap[k_inversesqrt_IntrinsicKind] = ALL_GLSL(InverseSqrt);
113 fIntrinsicMap[k_determinant_IntrinsicKind] = ALL_GLSL(Determinant);
114 fIntrinsicMap[k_matrixCompMult_IntrinsicKind] = SPECIAL(MatrixCompMult);
115 fIntrinsicMap[k_matrixInverse_IntrinsicKind] = ALL_GLSL(MatrixInverse);
116 fIntrinsicMap[k_mod_IntrinsicKind] = SPECIAL(Mod);
117 fIntrinsicMap[k_modf_IntrinsicKind] = ALL_GLSL(Modf);
118 fIntrinsicMap[k_min_IntrinsicKind] = SPECIAL(Min);
119 fIntrinsicMap[k_max_IntrinsicKind] = SPECIAL(Max);
120 fIntrinsicMap[k_clamp_IntrinsicKind] = SPECIAL(Clamp);
121 fIntrinsicMap[k_saturate_IntrinsicKind] = SPECIAL(Saturate);
122 fIntrinsicMap[k_dot_IntrinsicKind] = FLOAT_SPIRV(Dot);
123 fIntrinsicMap[k_mix_IntrinsicKind] = SPECIAL(Mix);
124 fIntrinsicMap[k_step_IntrinsicKind] = SPECIAL(Step);
125 fIntrinsicMap[k_smoothstep_IntrinsicKind] = SPECIAL(SmoothStep);
126 fIntrinsicMap[k_fma_IntrinsicKind] = ALL_GLSL(Fma);
127 fIntrinsicMap[k_frexp_IntrinsicKind] = ALL_GLSL(Frexp);
128 fIntrinsicMap[k_ldexp_IntrinsicKind] = ALL_GLSL(Ldexp);
129
130 #define PACK(type) fIntrinsicMap[k_pack##type##_IntrinsicKind] = ALL_GLSL(Pack##type); \
131 fIntrinsicMap[k_unpack##type##_IntrinsicKind] = ALL_GLSL(Unpack##type)
132 PACK(Snorm4x8);
133 PACK(Unorm4x8);
134 PACK(Snorm2x16);
135 PACK(Unorm2x16);
136 PACK(Half2x16);
137 PACK(Double2x32);
138 #undef PACK
139 fIntrinsicMap[k_length_IntrinsicKind] = ALL_GLSL(Length);
140 fIntrinsicMap[k_distance_IntrinsicKind] = ALL_GLSL(Distance);
141 fIntrinsicMap[k_cross_IntrinsicKind] = ALL_GLSL(Cross);
142 fIntrinsicMap[k_normalize_IntrinsicKind] = ALL_GLSL(Normalize);
143 fIntrinsicMap[k_faceforward_IntrinsicKind] = ALL_GLSL(FaceForward);
144 fIntrinsicMap[k_reflect_IntrinsicKind] = ALL_GLSL(Reflect);
145 fIntrinsicMap[k_refract_IntrinsicKind] = ALL_GLSL(Refract);
146 fIntrinsicMap[k_bitCount_IntrinsicKind] = ALL_SPIRV(BitCount);
147 fIntrinsicMap[k_findLSB_IntrinsicKind] = ALL_GLSL(FindILsb);
148 fIntrinsicMap[k_findMSB_IntrinsicKind] = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
149 fIntrinsicMap[k_dFdx_IntrinsicKind] = FLOAT_SPIRV(DPdx);
150 fIntrinsicMap[k_dFdy_IntrinsicKind] = SPECIAL(DFdy);
151 fIntrinsicMap[k_fwidth_IntrinsicKind] = FLOAT_SPIRV(Fwidth);
152 fIntrinsicMap[k_makeSampler2D_IntrinsicKind] = SPECIAL(SampledImage);
153
154 fIntrinsicMap[k_sample_IntrinsicKind] = SPECIAL(Texture);
155 fIntrinsicMap[k_subpassLoad_IntrinsicKind] = SPECIAL(SubpassLoad);
156
157 fIntrinsicMap[k_floatBitsToInt_IntrinsicKind] = ALL_SPIRV(Bitcast);
158 fIntrinsicMap[k_floatBitsToUint_IntrinsicKind] = ALL_SPIRV(Bitcast);
159 fIntrinsicMap[k_intBitsToFloat_IntrinsicKind] = ALL_SPIRV(Bitcast);
160 fIntrinsicMap[k_uintBitsToFloat_IntrinsicKind] = ALL_SPIRV(Bitcast);
161
162 fIntrinsicMap[k_any_IntrinsicKind] = BOOL_SPIRV(Any);
163 fIntrinsicMap[k_all_IntrinsicKind] = BOOL_SPIRV(All);
164 fIntrinsicMap[k_not_IntrinsicKind] = BOOL_SPIRV(LogicalNot);
165 fIntrinsicMap[k_equal_IntrinsicKind] = Intrinsic{kSPIRV_IntrinsicOpcodeKind,
166 SpvOpFOrdEqual, SpvOpIEqual,
167 SpvOpIEqual, SpvOpLogicalEqual};
168 fIntrinsicMap[k_notEqual_IntrinsicKind] = Intrinsic{kSPIRV_IntrinsicOpcodeKind,
169 SpvOpFOrdNotEqual, SpvOpINotEqual,
170 SpvOpINotEqual,
171 SpvOpLogicalNotEqual};
172 fIntrinsicMap[k_lessThan_IntrinsicKind] = Intrinsic{kSPIRV_IntrinsicOpcodeKind,
173 SpvOpFOrdLessThan,
174 SpvOpSLessThan,
175 SpvOpULessThan,
176 SpvOpUndef};
177 fIntrinsicMap[k_lessThanEqual_IntrinsicKind] = Intrinsic{kSPIRV_IntrinsicOpcodeKind,
178 SpvOpFOrdLessThanEqual,
179 SpvOpSLessThanEqual,
180 SpvOpULessThanEqual,
181 SpvOpUndef};
182 fIntrinsicMap[k_greaterThan_IntrinsicKind] = Intrinsic{kSPIRV_IntrinsicOpcodeKind,
183 SpvOpFOrdGreaterThan,
184 SpvOpSGreaterThan,
185 SpvOpUGreaterThan,
186 SpvOpUndef};
187 fIntrinsicMap[k_greaterThanEqual_IntrinsicKind] = Intrinsic{kSPIRV_IntrinsicOpcodeKind,
188 SpvOpFOrdGreaterThanEqual,
189 SpvOpSGreaterThanEqual,
190 SpvOpUGreaterThanEqual,
191 SpvOpUndef};
192 // interpolateAt* not yet supported...
193 }
194
writeWord(int32_t word,OutputStream & out)195 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
196 out.write((const char*) &word, sizeof(word));
197 }
198
is_float(const Context & context,const Type & type)199 static bool is_float(const Context& context, const Type& type) {
200 return (type.isScalar() || type.isVector() || type.isMatrix()) &&
201 type.componentType().isFloat();
202 }
203
is_signed(const Context & context,const Type & type)204 static bool is_signed(const Context& context, const Type& type) {
205 return (type.isScalar() || type.isVector()) && type.componentType().isSigned();
206 }
207
is_unsigned(const Context & context,const Type & type)208 static bool is_unsigned(const Context& context, const Type& type) {
209 return (type.isScalar() || type.isVector()) && type.componentType().isUnsigned();
210 }
211
is_bool(const Context & context,const Type & type)212 static bool is_bool(const Context& context, const Type& type) {
213 return (type.isScalar() || type.isVector()) && type.componentType().isBoolean();
214 }
215
is_out(const Modifiers & m)216 static bool is_out(const Modifiers& m) {
217 return (m.fFlags & Modifiers::kOut_Flag) != 0;
218 }
219
is_in(const Modifiers & m)220 static bool is_in(const Modifiers& m) {
221 switch (m.fFlags & (Modifiers::kOut_Flag | Modifiers::kIn_Flag)) {
222 case Modifiers::kOut_Flag: // out
223 return false;
224
225 case 0: // implicit in
226 case Modifiers::kIn_Flag: // explicit in
227 case Modifiers::kOut_Flag | Modifiers::kIn_Flag: // inout
228 return true;
229
230 default: SkUNREACHABLE;
231 }
232 }
233
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)234 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
235 SkASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
236 SkASSERT(opCode != SpvOpUndef);
237 switch (opCode) {
238 case SpvOpReturn: // fall through
239 case SpvOpReturnValue: // fall through
240 case SpvOpKill: // fall through
241 case SpvOpSwitch: // fall through
242 case SpvOpBranch: // fall through
243 case SpvOpBranchConditional:
244 if (fCurrentBlock == 0) {
245 // We just encountered dead code--instructions that don't have an associated block.
246 // Synthesize a label if this happens; this is necessary to satisfy the validator.
247 this->writeLabel(this->nextId(nullptr), out);
248 }
249 fCurrentBlock = 0;
250 break;
251 case SpvOpConstant: // fall through
252 case SpvOpConstantTrue: // fall through
253 case SpvOpConstantFalse: // fall through
254 case SpvOpConstantComposite: // fall through
255 case SpvOpTypeVoid: // fall through
256 case SpvOpTypeInt: // fall through
257 case SpvOpTypeFloat: // fall through
258 case SpvOpTypeBool: // fall through
259 case SpvOpTypeVector: // fall through
260 case SpvOpTypeMatrix: // fall through
261 case SpvOpTypeArray: // fall through
262 case SpvOpTypePointer: // fall through
263 case SpvOpTypeFunction: // fall through
264 case SpvOpTypeRuntimeArray: // fall through
265 case SpvOpTypeStruct: // fall through
266 case SpvOpTypeImage: // fall through
267 case SpvOpTypeSampledImage: // fall through
268 case SpvOpTypeSampler: // fall through
269 case SpvOpVariable: // fall through
270 case SpvOpFunction: // fall through
271 case SpvOpFunctionParameter: // fall through
272 case SpvOpFunctionEnd: // fall through
273 case SpvOpExecutionMode: // fall through
274 case SpvOpMemoryModel: // fall through
275 case SpvOpCapability: // fall through
276 case SpvOpExtInstImport: // fall through
277 case SpvOpEntryPoint: // fall through
278 case SpvOpSource: // fall through
279 case SpvOpSourceExtension: // fall through
280 case SpvOpName: // fall through
281 case SpvOpMemberName: // fall through
282 case SpvOpDecorate: // fall through
283 case SpvOpMemberDecorate:
284 break;
285 default:
286 // We may find ourselves with dead code--instructions that don't have an associated
287 // block. This should be a rare event, but if it happens, synthesize a label; this is
288 // necessary to satisfy the validator.
289 if (fCurrentBlock == 0) {
290 this->writeLabel(this->nextId(nullptr), out);
291 }
292 break;
293 }
294 this->writeWord((length << 16) | opCode, out);
295 }
296
writeLabel(SpvId label,OutputStream & out)297 void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) {
298 SkASSERT(!fCurrentBlock);
299 fCurrentBlock = label;
300 this->writeInstruction(SpvOpLabel, label, out);
301 }
302
writeInstruction(SpvOp_ opCode,OutputStream & out)303 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
304 this->writeOpCode(opCode, 1, out);
305 }
306
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)307 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
308 this->writeOpCode(opCode, 2, out);
309 this->writeWord(word1, out);
310 }
311
writeString(std::string_view s,OutputStream & out)312 void SPIRVCodeGenerator::writeString(std::string_view s, OutputStream& out) {
313 out.write(s.data(), s.length());
314 switch (s.length() % 4) {
315 case 1:
316 out.write8(0);
317 [[fallthrough]];
318 case 2:
319 out.write8(0);
320 [[fallthrough]];
321 case 3:
322 out.write8(0);
323 break;
324 default:
325 this->writeWord(0, out);
326 break;
327 }
328 }
329
writeInstruction(SpvOp_ opCode,std::string_view string,OutputStream & out)330 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, std::string_view string,
331 OutputStream& out) {
332 this->writeOpCode(opCode, 1 + (string.length() + 4) / 4, out);
333 this->writeString(string, out);
334 }
335
336
writeInstruction(SpvOp_ opCode,int32_t word1,std::string_view string,OutputStream & out)337 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, std::string_view string,
338 OutputStream& out) {
339 this->writeOpCode(opCode, 2 + (string.length() + 4) / 4, out);
340 this->writeWord(word1, out);
341 this->writeString(string, out);
342 }
343
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,std::string_view string,OutputStream & out)344 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
345 std::string_view string, OutputStream& out) {
346 this->writeOpCode(opCode, 3 + (string.length() + 4) / 4, out);
347 this->writeWord(word1, out);
348 this->writeWord(word2, out);
349 this->writeString(string, out);
350 }
351
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)352 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
353 OutputStream& out) {
354 this->writeOpCode(opCode, 3, out);
355 this->writeWord(word1, out);
356 this->writeWord(word2, out);
357 }
358
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)359 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
360 int32_t word3, OutputStream& out) {
361 this->writeOpCode(opCode, 4, out);
362 this->writeWord(word1, out);
363 this->writeWord(word2, out);
364 this->writeWord(word3, out);
365 }
366
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)367 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
368 int32_t word3, int32_t word4, OutputStream& out) {
369 this->writeOpCode(opCode, 5, out);
370 this->writeWord(word1, out);
371 this->writeWord(word2, out);
372 this->writeWord(word3, out);
373 this->writeWord(word4, out);
374 }
375
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)376 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
377 int32_t word3, int32_t word4, int32_t word5,
378 OutputStream& out) {
379 this->writeOpCode(opCode, 6, out);
380 this->writeWord(word1, out);
381 this->writeWord(word2, out);
382 this->writeWord(word3, out);
383 this->writeWord(word4, out);
384 this->writeWord(word5, out);
385 }
386
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)387 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
388 int32_t word3, int32_t word4, int32_t word5,
389 int32_t word6, OutputStream& out) {
390 this->writeOpCode(opCode, 7, out);
391 this->writeWord(word1, out);
392 this->writeWord(word2, out);
393 this->writeWord(word3, out);
394 this->writeWord(word4, out);
395 this->writeWord(word5, out);
396 this->writeWord(word6, out);
397 }
398
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)399 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
400 int32_t word3, int32_t word4, int32_t word5,
401 int32_t word6, int32_t word7, OutputStream& out) {
402 this->writeOpCode(opCode, 8, out);
403 this->writeWord(word1, out);
404 this->writeWord(word2, out);
405 this->writeWord(word3, out);
406 this->writeWord(word4, out);
407 this->writeWord(word5, out);
408 this->writeWord(word6, out);
409 this->writeWord(word7, out);
410 }
411
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)412 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
413 int32_t word3, int32_t word4, int32_t word5,
414 int32_t word6, int32_t word7, int32_t word8,
415 OutputStream& out) {
416 this->writeOpCode(opCode, 9, out);
417 this->writeWord(word1, out);
418 this->writeWord(word2, out);
419 this->writeWord(word3, out);
420 this->writeWord(word4, out);
421 this->writeWord(word5, out);
422 this->writeWord(word6, out);
423 this->writeWord(word7, out);
424 this->writeWord(word8, out);
425 }
426
writeCapabilities(OutputStream & out)427 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
428 for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
429 if (fCapabilities & bit) {
430 this->writeInstruction(SpvOpCapability, (SpvId) i, out);
431 }
432 }
433 this->writeInstruction(SpvOpCapability, SpvCapabilityShader, out);
434 }
435
nextId(const Type * type)436 SpvId SPIRVCodeGenerator::nextId(const Type* type) {
437 return this->nextId(type && type->hasPrecision() && !type->highPrecision()
438 ? Precision::kRelaxed
439 : Precision::kDefault);
440 }
441
nextId(Precision precision)442 SpvId SPIRVCodeGenerator::nextId(Precision precision) {
443 if (precision == Precision::kRelaxed && !fProgram.fConfig->fSettings.fForceHighPrecision) {
444 this->writeInstruction(SpvOpDecorate, fIdCount, SpvDecorationRelaxedPrecision,
445 fDecorationBuffer);
446 }
447 return fIdCount++;
448 }
449
writeStruct(const Type & type,const MemoryLayout & memoryLayout,SpvId resultId)450 void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout,
451 SpvId resultId) {
452 this->writeInstruction(SpvOpName, resultId, type.name(), fNameBuffer);
453 // go ahead and write all of the field types, so we don't inadvertently write them while we're
454 // in the middle of writing the struct instruction
455 std::vector<SpvId> types;
456 for (const auto& f : type.fields()) {
457 types.push_back(this->getType(*f.fType, memoryLayout));
458 }
459 this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
460 this->writeWord(resultId, fConstantBuffer);
461 for (SpvId id : types) {
462 this->writeWord(id, fConstantBuffer);
463 }
464 size_t offset = 0;
465 for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
466 const Type::Field& field = type.fields()[i];
467 if (!MemoryLayout::LayoutIsSupported(*field.fType)) {
468 fContext.fErrors->error(type.fLine, "type '" + field.fType->displayName() +
469 "' is not permitted here");
470 return;
471 }
472 size_t size = memoryLayout.size(*field.fType);
473 size_t alignment = memoryLayout.alignment(*field.fType);
474 const Layout& fieldLayout = field.fModifiers.fLayout;
475 if (fieldLayout.fOffset >= 0) {
476 if (fieldLayout.fOffset < (int) offset) {
477 fContext.fErrors->error(type.fLine, "offset of field '" + std::string(field.fName) +
478 "' must be at least " + std::to_string(offset));
479 }
480 if (fieldLayout.fOffset % alignment) {
481 fContext.fErrors->error(type.fLine,
482 "offset of field '" + std::string(field.fName) +
483 "' must be a multiple of " + std::to_string(alignment));
484 }
485 offset = fieldLayout.fOffset;
486 } else {
487 size_t mod = offset % alignment;
488 if (mod) {
489 offset += alignment - mod;
490 }
491 }
492 this->writeInstruction(SpvOpMemberName, resultId, i, field.fName, fNameBuffer);
493 this->writeFieldLayout(fieldLayout, resultId, i);
494 if (field.fModifiers.fLayout.fBuiltin < 0) {
495 this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
496 (SpvId) offset, fDecorationBuffer);
497 }
498 if (field.fType->isMatrix()) {
499 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
500 fDecorationBuffer);
501 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
502 (SpvId) memoryLayout.stride(*field.fType),
503 fDecorationBuffer);
504 }
505 if (!field.fType->highPrecision()) {
506 this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i,
507 SpvDecorationRelaxedPrecision, fDecorationBuffer);
508 }
509 offset += size;
510 if ((field.fType->isArray() || field.fType->isStruct()) && offset % alignment != 0) {
511 offset += alignment - offset % alignment;
512 }
513 }
514 }
515
const Type& SPIRVCodeGenerator::getActualType(const Type& type) {
    // Scalar number types collapse onto float / int / uint respectively.
    if (type.isFloat()) {
        return *fContext.fTypes.fFloat;
    }
    if (type.isSigned()) {
        return *fContext.fTypes.fInt;
    }
    if (type.isUnsigned()) {
        return *fContext.fTypes.fUInt;
    }
    if (!type.isMatrix() && !type.isVector()) {
        return type;
    }

    // Vectors/matrices of half, short or ushort become the same-shaped float/int/uint compound.
    const Type& component = type.componentType();
    if (component.matches(*fContext.fTypes.fHalf)) {
        return fContext.fTypes.fFloat->toCompound(fContext, type.columns(), type.rows());
    }
    if (component.matches(*fContext.fTypes.fShort)) {
        return fContext.fTypes.fInt->toCompound(fContext, type.columns(), type.rows());
    }
    if (component.matches(*fContext.fTypes.fUShort)) {
        return fContext.fTypes.fUInt->toCompound(fContext, type.columns(), type.rows());
    }
    return type;
}
539
getType(const Type & type)540 SpvId SPIRVCodeGenerator::getType(const Type& type) {
541 return this->getType(type, fDefaultLayout);
542 }
543
getType(const Type & rawType,const MemoryLayout & layout)544 SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layout) {
545 const Type* type;
546 std::unique_ptr<Type> arrayType;
547 std::string arrayName;
548
549 if (rawType.isArray()) {
550 // For arrays, we need to synthesize a temporary Array type using the "actual" component
551 // type. That is, if `short[10]` is passed in, we need to synthesize a `int[10]` Type.
552 // Otherwise, we can end up with two different SpvIds for the same array type.
553 const Type& component = this->getActualType(rawType.componentType());
554 arrayName = component.getArrayName(rawType.columns());
555 arrayType = Type::MakeArrayType(arrayName, component, rawType.columns());
556 type = arrayType.get();
557 } else {
558 // For non-array types, we can simply look up the "actual" type and use it.
559 type = &this->getActualType(rawType);
560 }
561
562 std::string key(type->name());
563 if (type->isStruct() || type->isArray()) {
564 key += std::to_string(layout.fStd);
565 #ifdef SK_DEBUG
566 SkASSERT(layout.fStd == MemoryLayout::Standard::k140_Standard ||
567 layout.fStd == MemoryLayout::Standard::k430_Standard);
568 MemoryLayout::Standard otherStd = layout.fStd == MemoryLayout::Standard::k140_Standard
569 ? MemoryLayout::Standard::k430_Standard
570 : MemoryLayout::Standard::k140_Standard;
571 std::string otherKey = type->displayName() + std::to_string(otherStd);
572 SkASSERT(fTypeMap.find(otherKey) == fTypeMap.end());
573 #endif
574 }
575 auto entry = fTypeMap.find(key);
576 if (entry == fTypeMap.end()) {
577 SpvId result = this->nextId(nullptr);
578 switch (type->typeKind()) {
579 case Type::TypeKind::kScalar:
580 if (type->isBoolean()) {
581 this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
582 } else if (type->isSigned()) {
583 this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
584 } else if (type->isUnsigned()) {
585 this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
586 } else if (type->isFloat()) {
587 this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
588 } else {
589 SkDEBUGFAILF("unrecognized scalar type '%s'", type->description().c_str());
590 }
591 break;
592 case Type::TypeKind::kVector:
593 this->writeInstruction(SpvOpTypeVector, result,
594 this->getType(type->componentType(), layout),
595 type->columns(), fConstantBuffer);
596 break;
597 case Type::TypeKind::kMatrix:
598 this->writeInstruction(
599 SpvOpTypeMatrix,
600 result,
601 this->getType(IndexExpression::IndexType(fContext, *type), layout),
602 type->columns(),
603 fConstantBuffer);
604 break;
605 case Type::TypeKind::kStruct:
606 this->writeStruct(*type, layout, result);
607 break;
608 case Type::TypeKind::kArray: {
609 if (!MemoryLayout::LayoutIsSupported(*type)) {
610 fContext.fErrors->error(type->fLine, "type '" + type->displayName() +
611 "' is not permitted here");
612 return this->nextId(nullptr);
613 }
614 if (type->columns() > 0) {
615 SpvId typeId = this->getType(type->componentType(), layout);
616 SpvId countId = this->writeLiteral(type->columns(), *fContext.fTypes.fInt);
617 this->writeInstruction(SpvOpTypeArray, result, typeId, countId,
618 fConstantBuffer);
619 this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
620 (int32_t) layout.stride(*type),
621 fDecorationBuffer);
622 } else {
623 // We shouldn't have any runtime-sized arrays right now
624 fContext.fErrors->error(type->fLine,
625 "runtime-sized arrays are not supported in SPIR-V");
626 this->writeInstruction(SpvOpTypeRuntimeArray, result,
627 this->getType(type->componentType(), layout),
628 fConstantBuffer);
629 this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
630 (int32_t) layout.stride(*type),
631 fDecorationBuffer);
632 }
633 break;
634 }
635 case Type::TypeKind::kSampler: {
636 SpvId image = result;
637 if (SpvDimSubpassData != type->dimensions()) {
638 image = this->getType(type->textureType(), layout);
639 }
640 if (SpvDimBuffer == type->dimensions()) {
641 fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer);
642 }
643 if (SpvDimSubpassData != type->dimensions()) {
644 this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
645 }
646 break;
647 }
648 case Type::TypeKind::kSeparateSampler: {
649 this->writeInstruction(SpvOpTypeSampler, result, fConstantBuffer);
650 break;
651 }
652 case Type::TypeKind::kTexture: {
653 this->writeInstruction(SpvOpTypeImage, result,
654 this->getType(*fContext.fTypes.fFloat, layout),
655 type->dimensions(), type->isDepth(),
656 type->isArrayedTexture(), type->isMultisampled(),
657 type->isSampled() ? 1 : 2, SpvImageFormatUnknown,
658 fConstantBuffer);
659 fImageTypeMap[key] = result;
660 break;
661 }
662 default:
663 if (type->isVoid()) {
664 this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
665 } else {
666 SkDEBUGFAILF("invalid type: %s", type->description().c_str());
667 }
668 break;
669 }
670 fTypeMap[key] = result;
671 return result;
672 }
673 return entry->second;
674 }
675
getImageType(const Type & type)676 SpvId SPIRVCodeGenerator::getImageType(const Type& type) {
677 SkASSERT(type.typeKind() == Type::TypeKind::kSampler);
678 this->getType(type);
679 std::string key = type.displayName() + std::to_string(fDefaultLayout.fStd);
680 SkASSERT(fImageTypeMap.find(key) != fImageTypeMap.end());
681 return fImageTypeMap[key];
682 }
683
getFunctionType(const FunctionDeclaration & function)684 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
685 std::string key = std::to_string(this->getType(function.returnType())) + "(";
686 std::string separator;
687 const std::vector<const Variable*>& parameters = function.parameters();
688 for (size_t i = 0; i < parameters.size(); i++) {
689 key += separator;
690 separator = ", ";
691 key += std::to_string(this->getType(parameters[i]->type()));
692 }
693 key += ")";
694 auto entry = fTypeMap.find(key);
695 if (entry == fTypeMap.end()) {
696 SpvId result = this->nextId(nullptr);
697 int32_t length = 3 + (int32_t) parameters.size();
698 SpvId returnType = this->getType(function.returnType());
699 std::vector<SpvId> parameterTypes;
700 for (size_t i = 0; i < parameters.size(); i++) {
701 // glslang seems to treat all function arguments as pointers whether they need to be or
702 // not. I was initially puzzled by this until I ran bizarre failures with certain
703 // patterns of function calls and control constructs, as exemplified by this minimal
704 // failure case:
705 //
706 // void sphere(float x) {
707 // }
708 //
709 // void map() {
710 // sphere(1.0);
711 // }
712 //
713 // void main() {
714 // for (int i = 0; i < 1; i++) {
715 // map();
716 // }
717 // }
718 //
719 // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
720 // crashes. Making it take a float* and storing the argument in a temporary variable,
721 // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
722 // the spec makes this make sense.
723 parameterTypes.push_back(this->getPointerType(parameters[i]->type(),
724 SpvStorageClassFunction));
725 }
726 this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
727 this->writeWord(result, fConstantBuffer);
728 this->writeWord(returnType, fConstantBuffer);
729 for (SpvId id : parameterTypes) {
730 this->writeWord(id, fConstantBuffer);
731 }
732 fTypeMap[key] = result;
733 return result;
734 }
735 return entry->second;
736 }
737
getPointerType(const Type & type,SpvStorageClass_ storageClass)738 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
739 return this->getPointerType(type, fDefaultLayout, storageClass);
740 }
741
getPointerType(const Type & rawType,const MemoryLayout & layout,SpvStorageClass_ storageClass)742 SpvId SPIRVCodeGenerator::getPointerType(const Type& rawType, const MemoryLayout& layout,
743 SpvStorageClass_ storageClass) {
744 const Type& type = this->getActualType(rawType);
745 std::string key = type.displayName() + "*" + std::to_string(layout.fStd) +
746 std::to_string(storageClass);
747 auto entry = fTypeMap.find(key);
748 if (entry == fTypeMap.end()) {
749 SpvId result = this->nextId(nullptr);
750 this->writeInstruction(SpvOpTypePointer, result, storageClass,
751 this->getType(type), fConstantBuffer);
752 fTypeMap[key] = result;
753 return result;
754 }
755 return entry->second;
756 }
757
writeExpression(const Expression & expr,OutputStream & out)758 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
759 switch (expr.kind()) {
760 case Expression::Kind::kBinary:
761 return this->writeBinaryExpression(expr.as<BinaryExpression>(), out);
762 case Expression::Kind::kConstructorArrayCast:
763 return this->writeExpression(*expr.as<ConstructorArrayCast>().argument(), out);
764 case Expression::Kind::kConstructorArray:
765 case Expression::Kind::kConstructorStruct:
766 return this->writeCompositeConstructor(expr.asAnyConstructor(), out);
767 case Expression::Kind::kConstructorDiagonalMatrix:
768 return this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(), out);
769 case Expression::Kind::kConstructorMatrixResize:
770 return this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(), out);
771 case Expression::Kind::kConstructorScalarCast:
772 return this->writeConstructorScalarCast(expr.as<ConstructorScalarCast>(), out);
773 case Expression::Kind::kConstructorSplat:
774 return this->writeConstructorSplat(expr.as<ConstructorSplat>(), out);
775 case Expression::Kind::kConstructorCompound:
776 return this->writeConstructorCompound(expr.as<ConstructorCompound>(), out);
777 case Expression::Kind::kConstructorCompoundCast:
778 return this->writeConstructorCompoundCast(expr.as<ConstructorCompoundCast>(), out);
779 case Expression::Kind::kFieldAccess:
780 return this->writeFieldAccess(expr.as<FieldAccess>(), out);
781 case Expression::Kind::kFunctionCall:
782 return this->writeFunctionCall(expr.as<FunctionCall>(), out);
783 case Expression::Kind::kLiteral:
784 return this->writeLiteral(expr.as<Literal>());
785 case Expression::Kind::kPrefix:
786 return this->writePrefixExpression(expr.as<PrefixExpression>(), out);
787 case Expression::Kind::kPostfix:
788 return this->writePostfixExpression(expr.as<PostfixExpression>(), out);
789 case Expression::Kind::kSwizzle:
790 return this->writeSwizzle(expr.as<Swizzle>(), out);
791 case Expression::Kind::kVariableReference:
792 return this->writeVariableReference(expr.as<VariableReference>(), out);
793 case Expression::Kind::kTernary:
794 return this->writeTernaryExpression(expr.as<TernaryExpression>(), out);
795 case Expression::Kind::kIndex:
796 return this->writeIndexExpression(expr.as<IndexExpression>(), out);
797 default:
798 SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
799 break;
800 }
801 return -1;
802 }
803
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)804 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
805 const FunctionDeclaration& function = c.function();
806 auto iter = fIntrinsicMap.find(function.intrinsicKind());
807 if (iter == fIntrinsicMap.end()) {
808 fContext.fErrors->error(c.fLine, "unsupported intrinsic '" + function.description() + "'");
809 return -1;
810 }
811 const ExpressionArray& arguments = c.arguments();
812 const Intrinsic& intrinsic = iter->second;
813 int32_t intrinsicId = intrinsic.floatOp;
814 if (arguments.size() > 0) {
815 const Type& type = arguments[0]->type();
816 if (intrinsic.opKind == kSpecial_IntrinsicOpcodeKind || is_float(fContext, type)) {
817 // Keep the default float op.
818 } else if (is_signed(fContext, type)) {
819 intrinsicId = intrinsic.signedOp;
820 } else if (is_unsigned(fContext, type)) {
821 intrinsicId = intrinsic.unsignedOp;
822 } else if (is_bool(fContext, type)) {
823 intrinsicId = intrinsic.boolOp;
824 }
825 }
826 switch (intrinsic.opKind) {
827 case kGLSL_STD_450_IntrinsicOpcodeKind: {
828 SpvId result = this->nextId(&c.type());
829 std::vector<SpvId> argumentIds;
830 std::vector<TempVar> tempVars;
831 argumentIds.reserve(arguments.size());
832 for (size_t i = 0; i < arguments.size(); i++) {
833 argumentIds.push_back(this->writeFunctionCallArgument(c, i, &tempVars, out));
834 }
835 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
836 this->writeWord(this->getType(c.type()), out);
837 this->writeWord(result, out);
838 this->writeWord(fGLSLExtendedInstructions, out);
839 this->writeWord(intrinsicId, out);
840 for (SpvId id : argumentIds) {
841 this->writeWord(id, out);
842 }
843 this->copyBackTempVars(tempVars, out);
844 return result;
845 }
846 case kSPIRV_IntrinsicOpcodeKind: {
847 // GLSL supports dot(float, float), but SPIR-V does not. Convert it to FMul
848 if (intrinsicId == SpvOpDot && arguments[0]->type().isScalar()) {
849 intrinsicId = SpvOpFMul;
850 }
851 SpvId result = this->nextId(&c.type());
852 std::vector<SpvId> argumentIds;
853 std::vector<TempVar> tempVars;
854 argumentIds.reserve(arguments.size());
855 for (size_t i = 0; i < arguments.size(); i++) {
856 argumentIds.push_back(this->writeFunctionCallArgument(c, i, &tempVars, out));
857 }
858 if (!c.type().isVoid()) {
859 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
860 this->writeWord(this->getType(c.type()), out);
861 this->writeWord(result, out);
862 } else {
863 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
864 }
865 for (SpvId id : argumentIds) {
866 this->writeWord(id, out);
867 }
868 this->copyBackTempVars(tempVars, out);
869 return result;
870 }
871 case kSpecial_IntrinsicOpcodeKind:
872 return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
873 default:
874 fContext.fErrors->error(c.fLine, "unsupported intrinsic '" + function.description() +
875 "'");
876 return -1;
877 }
878 }
879
vectorize(const Expression & arg,int vectorSize,OutputStream & out)880 SpvId SPIRVCodeGenerator::vectorize(const Expression& arg, int vectorSize, OutputStream& out) {
881 SkASSERT(vectorSize >= 1 && vectorSize <= 4);
882 const Type& argType = arg.type();
883 SpvId raw = this->writeExpression(arg, out);
884 if (argType.isScalar()) {
885 if (vectorSize == 1) {
886 return raw;
887 }
888 SpvId vector = this->nextId(&argType);
889 this->writeOpCode(SpvOpCompositeConstruct, 3 + vectorSize, out);
890 this->writeWord(this->getType(argType.toCompound(fContext, vectorSize, 1)), out);
891 this->writeWord(vector, out);
892 for (int i = 0; i < vectorSize; i++) {
893 this->writeWord(raw, out);
894 }
895 return vector;
896 } else {
897 SkASSERT(vectorSize == argType.columns());
898 return raw;
899 }
900 }
901
vectorize(const ExpressionArray & args,OutputStream & out)902 std::vector<SpvId> SPIRVCodeGenerator::vectorize(const ExpressionArray& args, OutputStream& out) {
903 int vectorSize = 1;
904 for (const auto& a : args) {
905 if (a->type().isVector()) {
906 if (vectorSize > 1) {
907 SkASSERT(a->type().columns() == vectorSize);
908 } else {
909 vectorSize = a->type().columns();
910 }
911 }
912 }
913 std::vector<SpvId> result;
914 result.reserve(args.size());
915 for (const auto& arg : args) {
916 result.push_back(this->vectorize(*arg, vectorSize, out));
917 }
918 return result;
919 }
920
writeGLSLExtendedInstruction(const Type & type,SpvId id,SpvId floatInst,SpvId signedInst,SpvId unsignedInst,const std::vector<SpvId> & args,OutputStream & out)921 void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
922 SpvId signedInst, SpvId unsignedInst,
923 const std::vector<SpvId>& args,
924 OutputStream& out) {
925 this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
926 this->writeWord(this->getType(type), out);
927 this->writeWord(id, out);
928 this->writeWord(fGLSLExtendedInstructions, out);
929
930 if (is_float(fContext, type)) {
931 this->writeWord(floatInst, out);
932 } else if (is_signed(fContext, type)) {
933 this->writeWord(signedInst, out);
934 } else if (is_unsigned(fContext, type)) {
935 this->writeWord(unsignedInst, out);
936 } else {
937 SkASSERT(false);
938 }
939 for (SpvId a : args) {
940 this->writeWord(a, out);
941 }
942 }
943
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)944 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
945 OutputStream& out) {
946 const ExpressionArray& arguments = c.arguments();
947 const Type& callType = c.type();
948 SpvId result = this->nextId(nullptr);
949 switch (kind) {
950 case kAtan_SpecialIntrinsic: {
951 std::vector<SpvId> argumentIds;
952 argumentIds.reserve(arguments.size());
953 for (const std::unique_ptr<Expression>& arg : arguments) {
954 argumentIds.push_back(this->writeExpression(*arg, out));
955 }
956 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
957 this->writeWord(this->getType(callType), out);
958 this->writeWord(result, out);
959 this->writeWord(fGLSLExtendedInstructions, out);
960 this->writeWord(argumentIds.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
961 for (SpvId id : argumentIds) {
962 this->writeWord(id, out);
963 }
964 break;
965 }
966 case kSampledImage_SpecialIntrinsic: {
967 SkASSERT(arguments.size() == 2);
968 SpvId img = this->writeExpression(*arguments[0], out);
969 SpvId sampler = this->writeExpression(*arguments[1], out);
970 this->writeInstruction(SpvOpSampledImage,
971 this->getType(callType),
972 result,
973 img,
974 sampler,
975 out);
976 break;
977 }
978 case kSubpassLoad_SpecialIntrinsic: {
979 SpvId img = this->writeExpression(*arguments[0], out);
980 ExpressionArray args;
981 args.reserve_back(2);
982 args.push_back(Literal::MakeInt(fContext, /*line=*/-1, /*value=*/0));
983 args.push_back(Literal::MakeInt(fContext, /*line=*/-1, /*value=*/0));
984 ConstructorCompound ctor(/*line=*/-1, *fContext.fTypes.fInt2, std::move(args));
985 SpvId coords = this->writeConstantVector(ctor);
986 if (arguments.size() == 1) {
987 this->writeInstruction(SpvOpImageRead,
988 this->getType(callType),
989 result,
990 img,
991 coords,
992 out);
993 } else {
994 SkASSERT(arguments.size() == 2);
995 SpvId sample = this->writeExpression(*arguments[1], out);
996 this->writeInstruction(SpvOpImageRead,
997 this->getType(callType),
998 result,
999 img,
1000 coords,
1001 SpvImageOperandsSampleMask,
1002 sample,
1003 out);
1004 }
1005 break;
1006 }
1007 case kTexture_SpecialIntrinsic: {
1008 SpvOp_ op = SpvOpImageSampleImplicitLod;
1009 const Type& arg1Type = arguments[1]->type();
1010 switch (arguments[0]->type().dimensions()) {
1011 case SpvDim1D:
1012 if (arg1Type.matches(*fContext.fTypes.fFloat2)) {
1013 op = SpvOpImageSampleProjImplicitLod;
1014 } else {
1015 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat));
1016 }
1017 break;
1018 case SpvDim2D:
1019 if (arg1Type.matches(*fContext.fTypes.fFloat3)) {
1020 op = SpvOpImageSampleProjImplicitLod;
1021 } else {
1022 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat2));
1023 }
1024 break;
1025 case SpvDim3D:
1026 if (arg1Type.matches(*fContext.fTypes.fFloat4)) {
1027 op = SpvOpImageSampleProjImplicitLod;
1028 } else {
1029 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat3));
1030 }
1031 break;
1032 case SpvDimCube: // fall through
1033 case SpvDimRect: // fall through
1034 case SpvDimBuffer: // fall through
1035 case SpvDimSubpassData:
1036 break;
1037 }
1038 SpvId type = this->getType(callType);
1039 SpvId sampler = this->writeExpression(*arguments[0], out);
1040 SpvId uv = this->writeExpression(*arguments[1], out);
1041 if (arguments.size() == 3) {
1042 this->writeInstruction(op, type, result, sampler, uv,
1043 SpvImageOperandsBiasMask,
1044 this->writeExpression(*arguments[2], out),
1045 out);
1046 } else {
1047 SkASSERT(arguments.size() == 2);
1048 if (fProgram.fConfig->fSettings.fSharpenTextures) {
1049 SpvId lodBias = this->writeLiteral(-0.5, *fContext.fTypes.fFloat);
1050 this->writeInstruction(op, type, result, sampler, uv,
1051 SpvImageOperandsBiasMask, lodBias, out);
1052 } else {
1053 this->writeInstruction(op, type, result, sampler, uv,
1054 out);
1055 }
1056 }
1057 break;
1058 }
1059 case kMod_SpecialIntrinsic: {
1060 std::vector<SpvId> args = this->vectorize(arguments, out);
1061 SkASSERT(args.size() == 2);
1062 const Type& operandType = arguments[0]->type();
1063 SpvOp_ op;
1064 if (is_float(fContext, operandType)) {
1065 op = SpvOpFMod;
1066 } else if (is_signed(fContext, operandType)) {
1067 op = SpvOpSMod;
1068 } else if (is_unsigned(fContext, operandType)) {
1069 op = SpvOpUMod;
1070 } else {
1071 SkASSERT(false);
1072 return 0;
1073 }
1074 this->writeOpCode(op, 5, out);
1075 this->writeWord(this->getType(operandType), out);
1076 this->writeWord(result, out);
1077 this->writeWord(args[0], out);
1078 this->writeWord(args[1], out);
1079 break;
1080 }
1081 case kDFdy_SpecialIntrinsic: {
1082 SpvId fn = this->writeExpression(*arguments[0], out);
1083 this->writeOpCode(SpvOpDPdy, 4, out);
1084 this->writeWord(this->getType(callType), out);
1085 this->writeWord(result, out);
1086 this->writeWord(fn, out);
1087 this->addRTFlipUniform(c.fLine);
1088 using namespace dsl;
1089 DSLExpression rtFlip(ThreadContext::Compiler().convertIdentifier(/*line=*/-1,
1090 SKSL_RTFLIP_NAME));
1091 SpvId rtFlipY = this->vectorize(*rtFlip.y().release(), callType.columns(), out);
1092 SpvId flipped = this->nextId(&callType);
1093 this->writeInstruction(SpvOpFMul, this->getType(callType), flipped, result, rtFlipY,
1094 out);
1095 result = flipped;
1096 break;
1097 }
1098 case kClamp_SpecialIntrinsic: {
1099 std::vector<SpvId> args = this->vectorize(arguments, out);
1100 SkASSERT(args.size() == 3);
1101 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
1102 GLSLstd450UClamp, args, out);
1103 break;
1104 }
1105 case kMax_SpecialIntrinsic: {
1106 std::vector<SpvId> args = this->vectorize(arguments, out);
1107 SkASSERT(args.size() == 2);
1108 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMax, GLSLstd450SMax,
1109 GLSLstd450UMax, args, out);
1110 break;
1111 }
1112 case kMin_SpecialIntrinsic: {
1113 std::vector<SpvId> args = this->vectorize(arguments, out);
1114 SkASSERT(args.size() == 2);
1115 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMin, GLSLstd450SMin,
1116 GLSLstd450UMin, args, out);
1117 break;
1118 }
1119 case kMix_SpecialIntrinsic: {
1120 std::vector<SpvId> args = this->vectorize(arguments, out);
1121 SkASSERT(args.size() == 3);
1122 if (arguments[2]->type().componentType().isBoolean()) {
1123 // Use OpSelect to implement Boolean mix().
1124 SpvId falseId = this->writeExpression(*arguments[0], out);
1125 SpvId trueId = this->writeExpression(*arguments[1], out);
1126 SpvId conditionId = this->writeExpression(*arguments[2], out);
1127 this->writeInstruction(SpvOpSelect, this->getType(arguments[0]->type()), result,
1128 conditionId, trueId, falseId, out);
1129 } else {
1130 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMix, SpvOpUndef,
1131 SpvOpUndef, args, out);
1132 }
1133 break;
1134 }
1135 case kSaturate_SpecialIntrinsic: {
1136 SkASSERT(arguments.size() == 1);
1137 ExpressionArray finalArgs;
1138 finalArgs.reserve_back(3);
1139 finalArgs.push_back(arguments[0]->clone());
1140 finalArgs.push_back(Literal::MakeFloat(fContext, /*line=*/-1, /*value=*/0));
1141 finalArgs.push_back(Literal::MakeFloat(fContext, /*line=*/-1, /*value=*/1));
1142 std::vector<SpvId> spvArgs = this->vectorize(finalArgs, out);
1143 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
1144 GLSLstd450UClamp, spvArgs, out);
1145 break;
1146 }
1147 case kSmoothStep_SpecialIntrinsic: {
1148 std::vector<SpvId> args = this->vectorize(arguments, out);
1149 SkASSERT(args.size() == 3);
1150 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450SmoothStep, SpvOpUndef,
1151 SpvOpUndef, args, out);
1152 break;
1153 }
1154 case kStep_SpecialIntrinsic: {
1155 std::vector<SpvId> args = this->vectorize(arguments, out);
1156 SkASSERT(args.size() == 2);
1157 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450Step, SpvOpUndef,
1158 SpvOpUndef, args, out);
1159 break;
1160 }
1161 case kMatrixCompMult_SpecialIntrinsic: {
1162 SkASSERT(arguments.size() == 2);
1163 SpvId lhs = this->writeExpression(*arguments[0], out);
1164 SpvId rhs = this->writeExpression(*arguments[1], out);
1165 result = this->writeComponentwiseMatrixBinary(callType, lhs, rhs, SpvOpFMul, out);
1166 break;
1167 }
1168 }
1169 return result;
1170 }
1171
writeFunctionCallArgument(const FunctionCall & call,int argIndex,std::vector<TempVar> * tempVars,OutputStream & out)1172 SpvId SPIRVCodeGenerator::writeFunctionCallArgument(const FunctionCall& call,
1173 int argIndex,
1174 std::vector<TempVar>* tempVars,
1175 OutputStream& out) {
1176 const FunctionDeclaration& funcDecl = call.function();
1177 const Expression& arg = *call.arguments()[argIndex];
1178 const Modifiers& paramModifiers = funcDecl.parameters()[argIndex]->modifiers();
1179
1180 // ID of temporary variable that we will use to hold this argument, or 0 if it is being
1181 // passed directly
1182 SpvId tmpVar;
1183 // if we need a temporary var to store this argument, this is the value to store in the var
1184 SpvId tmpValueId = -1;
1185
1186 if (is_out(paramModifiers)) {
1187 std::unique_ptr<LValue> lv = this->getLValue(arg, out);
1188 // We handle out params with a temp var that we copy back to the original variable at the
1189 // end of the call. GLSL guarantees that the original variable will be unchanged until the
1190 // end of the call, and also that out params are written back to their original variables in
1191 // a specific order (left-to-right), so it's unsafe to pass a pointer to the original value.
1192 if (is_in(paramModifiers)) {
1193 tmpValueId = lv->load(out);
1194 }
1195 tmpVar = this->nextId(&arg.type());
1196 tempVars->push_back(TempVar{tmpVar, &arg.type(), std::move(lv)});
1197 } else if (funcDecl.isIntrinsic()) {
1198 // Unlike user function calls, non-out intrinsic arguments don't need pointer parameters.
1199 return this->writeExpression(arg, out);
1200 } else {
1201 // We always use pointer parameters when calling user functions.
1202 // See getFunctionType for further explanation.
1203 tmpValueId = this->writeExpression(arg, out);
1204 tmpVar = this->nextId(nullptr);
1205 }
1206 this->writeInstruction(SpvOpVariable,
1207 this->getPointerType(arg.type(), SpvStorageClassFunction),
1208 tmpVar,
1209 SpvStorageClassFunction,
1210 fVariableBuffer);
1211 if (tmpValueId != (SpvId)-1) {
1212 this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
1213 }
1214 return tmpVar;
1215 }
1216
copyBackTempVars(const std::vector<TempVar> & tempVars,OutputStream & out)1217 void SPIRVCodeGenerator::copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out) {
1218 for (const TempVar& tempVar : tempVars) {
1219 SpvId load = this->nextId(tempVar.type);
1220 this->writeInstruction(SpvOpLoad, this->getType(*tempVar.type), load, tempVar.spvId, out);
1221 tempVar.lvalue->store(load, out);
1222 }
1223 }
1224
writeFunctionCall(const FunctionCall & c,OutputStream & out)1225 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
1226 const FunctionDeclaration& function = c.function();
1227 if (function.isIntrinsic() && !function.definition()) {
1228 return this->writeIntrinsicCall(c, out);
1229 }
1230 const ExpressionArray& arguments = c.arguments();
1231 const auto& entry = fFunctionMap.find(&function);
1232 if (entry == fFunctionMap.end()) {
1233 fContext.fErrors->error(c.fLine, "function '" + function.description() +
1234 "' is not defined");
1235 return -1;
1236 }
1237 // Temp variables are used to write back out-parameters after the function call is complete.
1238 std::vector<TempVar> tempVars;
1239 std::vector<SpvId> argumentIds;
1240 argumentIds.reserve(arguments.size());
1241 for (size_t i = 0; i < arguments.size(); i++) {
1242 argumentIds.push_back(this->writeFunctionCallArgument(c, i, &tempVars, out));
1243 }
1244 SpvId result = this->nextId(nullptr);
1245 this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) arguments.size(), out);
1246 this->writeWord(this->getType(c.type()), out);
1247 this->writeWord(result, out);
1248 this->writeWord(entry->second, out);
1249 for (SpvId id : argumentIds) {
1250 this->writeWord(id, out);
1251 }
1252 // Now that the call is complete, we copy temp out-variables back to their real lvalues.
1253 this->copyBackTempVars(tempVars, out);
1254 return result;
1255 }
1256
writeConstantVector(const AnyConstructor & c)1257 SpvId SPIRVCodeGenerator::writeConstantVector(const AnyConstructor& c) {
1258 const Type& type = c.type();
1259 SkASSERT(type.isVector() && c.isCompileTimeConstant());
1260
1261 // Get each of the constructor components as SPIR-V constants.
1262 SPIRVVectorConstant key{this->getType(type),
1263 /*fValueId=*/{SpvId(-1), SpvId(-1), SpvId(-1), SpvId(-1)}};
1264
1265 const Type& scalarType = type.componentType();
1266 for (int n = 0; n < type.columns(); n++) {
1267 std::optional<double> slotVal = c.getConstantValue(n);
1268 if (!slotVal.has_value()) {
1269 SkDEBUGFAILF("writeConstantVector: %s not actually constant", c.description().c_str());
1270 return (SpvId)-1;
1271 }
1272 key.fValueId[n] = this->writeLiteral(*slotVal, scalarType);
1273 }
1274
1275 // Check to see if we've already synthesized this vector constant.
1276 auto [iter, newlyCreated] = fVectorConstants.insert({key, (SpvId)-1});
1277 if (newlyCreated) {
1278 // Emit an OpConstantComposite instruction for this constant.
1279 SpvId result = this->nextId(&type);
1280 this->writeOpCode(SpvOpConstantComposite, 3 + type.columns(), fConstantBuffer);
1281 this->writeWord(key.fTypeId, fConstantBuffer);
1282 this->writeWord(result, fConstantBuffer);
1283 for (int i = 0; i < type.columns(); i++) {
1284 this->writeWord(key.fValueId[i], fConstantBuffer);
1285 }
1286 iter->second = result;
1287 }
1288 return iter->second;
1289 }
1290
castScalarToType(SpvId inputExprId,const Type & inputType,const Type & outputType,OutputStream & out)1291 SpvId SPIRVCodeGenerator::castScalarToType(SpvId inputExprId,
1292 const Type& inputType,
1293 const Type& outputType,
1294 OutputStream& out) {
1295 if (outputType.isFloat()) {
1296 return this->castScalarToFloat(inputExprId, inputType, outputType, out);
1297 }
1298 if (outputType.isSigned()) {
1299 return this->castScalarToSignedInt(inputExprId, inputType, outputType, out);
1300 }
1301 if (outputType.isUnsigned()) {
1302 return this->castScalarToUnsignedInt(inputExprId, inputType, outputType, out);
1303 }
1304 if (outputType.isBoolean()) {
1305 return this->castScalarToBoolean(inputExprId, inputType, outputType, out);
1306 }
1307
1308 fContext.fErrors->error(-1, "unsupported cast: " + inputType.description() +
1309 " to " + outputType.description());
1310 return inputExprId;
1311 }
1312
writeFloatConstructor(const AnyConstructor & c,OutputStream & out)1313 SpvId SPIRVCodeGenerator::writeFloatConstructor(const AnyConstructor& c, OutputStream& out) {
1314 SkASSERT(c.argumentSpan().size() == 1);
1315 SkASSERT(c.type().isFloat());
1316 const Expression& ctorExpr = *c.argumentSpan().front();
1317 SpvId expressionId = this->writeExpression(ctorExpr, out);
1318 return this->castScalarToFloat(expressionId, ctorExpr.type(), c.type(), out);
1319 }
1320
castScalarToFloat(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)1321 SpvId SPIRVCodeGenerator::castScalarToFloat(SpvId inputId, const Type& inputType,
1322 const Type& outputType, OutputStream& out) {
1323 // Casting a float to float is a no-op.
1324 if (inputType.isFloat()) {
1325 return inputId;
1326 }
1327
1328 // Given the input type, generate the appropriate instruction to cast to float.
1329 SpvId result = this->nextId(&outputType);
1330 if (inputType.isBoolean()) {
1331 // Use OpSelect to convert the boolean argument to a literal 1.0 or 0.0.
1332 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fFloat);
1333 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
1334 this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
1335 inputId, oneID, zeroID, out);
1336 } else if (inputType.isSigned()) {
1337 this->writeInstruction(SpvOpConvertSToF, this->getType(outputType), result, inputId, out);
1338 } else if (inputType.isUnsigned()) {
1339 this->writeInstruction(SpvOpConvertUToF, this->getType(outputType), result, inputId, out);
1340 } else {
1341 SkDEBUGFAILF("unsupported type for float typecast: %s", inputType.description().c_str());
1342 return (SpvId)-1;
1343 }
1344 return result;
1345 }
1346
writeIntConstructor(const AnyConstructor & c,OutputStream & out)1347 SpvId SPIRVCodeGenerator::writeIntConstructor(const AnyConstructor& c, OutputStream& out) {
1348 SkASSERT(c.argumentSpan().size() == 1);
1349 SkASSERT(c.type().isSigned());
1350 const Expression& ctorExpr = *c.argumentSpan().front();
1351 SpvId expressionId = this->writeExpression(ctorExpr, out);
1352 return this->castScalarToSignedInt(expressionId, ctorExpr.type(), c.type(), out);
1353 }
1354
castScalarToSignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)1355 SpvId SPIRVCodeGenerator::castScalarToSignedInt(SpvId inputId, const Type& inputType,
1356 const Type& outputType, OutputStream& out) {
1357 // Casting a signed int to signed int is a no-op.
1358 if (inputType.isSigned()) {
1359 return inputId;
1360 }
1361
1362 // Given the input type, generate the appropriate instruction to cast to signed int.
1363 SpvId result = this->nextId(&outputType);
1364 if (inputType.isBoolean()) {
1365 // Use OpSelect to convert the boolean argument to a literal 1 or 0.
1366 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fInt);
1367 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
1368 this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
1369 inputId, oneID, zeroID, out);
1370 } else if (inputType.isFloat()) {
1371 this->writeInstruction(SpvOpConvertFToS, this->getType(outputType), result, inputId, out);
1372 } else if (inputType.isUnsigned()) {
1373 this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
1374 } else {
1375 SkDEBUGFAILF("unsupported type for signed int typecast: %s",
1376 inputType.description().c_str());
1377 return (SpvId)-1;
1378 }
1379 return result;
1380 }
1381
writeUIntConstructor(const AnyConstructor & c,OutputStream & out)1382 SpvId SPIRVCodeGenerator::writeUIntConstructor(const AnyConstructor& c, OutputStream& out) {
1383 SkASSERT(c.argumentSpan().size() == 1);
1384 SkASSERT(c.type().isUnsigned());
1385 const Expression& ctorExpr = *c.argumentSpan().front();
1386 SpvId expressionId = this->writeExpression(ctorExpr, out);
1387 return this->castScalarToUnsignedInt(expressionId, ctorExpr.type(), c.type(), out);
1388 }
1389
castScalarToUnsignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)1390 SpvId SPIRVCodeGenerator::castScalarToUnsignedInt(SpvId inputId, const Type& inputType,
1391 const Type& outputType, OutputStream& out) {
1392 // Casting an unsigned int to unsigned int is a no-op.
1393 if (inputType.isUnsigned()) {
1394 return inputId;
1395 }
1396
1397 // Given the input type, generate the appropriate instruction to cast to unsigned int.
1398 SpvId result = this->nextId(&outputType);
1399 if (inputType.isBoolean()) {
1400 // Use OpSelect to convert the boolean argument to a literal 1u or 0u.
1401 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fUInt);
1402 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
1403 this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
1404 inputId, oneID, zeroID, out);
1405 } else if (inputType.isFloat()) {
1406 this->writeInstruction(SpvOpConvertFToU, this->getType(outputType), result, inputId, out);
1407 } else if (inputType.isSigned()) {
1408 this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
1409 } else {
1410 SkDEBUGFAILF("unsupported type for unsigned int typecast: %s",
1411 inputType.description().c_str());
1412 return (SpvId)-1;
1413 }
1414 return result;
1415 }
1416
writeBooleanConstructor(const AnyConstructor & c,OutputStream & out)1417 SpvId SPIRVCodeGenerator::writeBooleanConstructor(const AnyConstructor& c, OutputStream& out) {
1418 SkASSERT(c.argumentSpan().size() == 1);
1419 SkASSERT(c.type().isBoolean());
1420 const Expression& ctorExpr = *c.argumentSpan().front();
1421 SpvId expressionId = this->writeExpression(ctorExpr, out);
1422 return this->castScalarToBoolean(expressionId, ctorExpr.type(), c.type(), out);
1423 }
1424
castScalarToBoolean(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)1425 SpvId SPIRVCodeGenerator::castScalarToBoolean(SpvId inputId, const Type& inputType,
1426 const Type& outputType, OutputStream& out) {
1427 // Casting a bool to bool is a no-op.
1428 if (inputType.isBoolean()) {
1429 return inputId;
1430 }
1431
1432 // Given the input type, generate the appropriate instruction to cast to bool.
1433 SpvId result = this->nextId(nullptr);
1434 if (inputType.isSigned()) {
1435 // Synthesize a boolean result by comparing the input against a signed zero literal.
1436 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
1437 this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
1438 inputId, zeroID, out);
1439 } else if (inputType.isUnsigned()) {
1440 // Synthesize a boolean result by comparing the input against an unsigned zero literal.
1441 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
1442 this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
1443 inputId, zeroID, out);
1444 } else if (inputType.isFloat()) {
1445 // Synthesize a boolean result by comparing the input against a floating-point zero literal.
1446 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
1447 this->writeInstruction(SpvOpFUnordNotEqual, this->getType(outputType), result,
1448 inputId, zeroID, out);
1449 } else {
1450 SkDEBUGFAILF("unsupported type for boolean typecast: %s", inputType.description().c_str());
1451 return (SpvId)-1;
1452 }
1453 return result;
1454 }
1455
writeUniformScaleMatrix(SpvId id,SpvId diagonal,const Type & type,OutputStream & out)1456 void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
1457 OutputStream& out) {
1458 SpvId zeroId = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
1459 std::vector<SpvId> columnIds;
1460 columnIds.reserve(type.columns());
1461 for (int column = 0; column < type.columns(); column++) {
1462 this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
1463 out);
1464 this->writeWord(this->getType(type.componentType().toCompound(
1465 fContext, /*columns=*/type.rows(), /*rows=*/1)),
1466 out);
1467 SpvId columnId = this->nextId(&type);
1468 this->writeWord(columnId, out);
1469 columnIds.push_back(columnId);
1470 for (int row = 0; row < type.rows(); row++) {
1471 this->writeWord(row == column ? diagonal : zeroId, out);
1472 }
1473 }
1474 this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
1475 out);
1476 this->writeWord(this->getType(type), out);
1477 this->writeWord(id, out);
1478 for (SpvId columnId : columnIds) {
1479 this->writeWord(columnId, out);
1480 }
1481 }
1482
// Emits instructions that copy matrix `src` (of type `srcType`) into a new matrix of type
// `dstType`, and returns the id of the new matrix. Both types must share the same (float)
// component type.
//  - Entries present in both shapes are copied from `src`.
//  - Entries that exist only in the destination are filled from the identity matrix
//    (1 where row == column, 0 elsewhere).
//  - Source rows/columns that do not fit in the destination are dropped.
SpvId SPIRVCodeGenerator::writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType,
                                          OutputStream& out) {
    SkASSERT(srcType.isMatrix());
    SkASSERT(dstType.isMatrix());
    SkASSERT(srcType.componentType().matches(dstType.componentType()));
    SpvId id = this->nextId(&dstType);
    SpvId srcColumnType = this->getType(srcType.componentType().toCompound(fContext,
                                                                            srcType.rows(),
                                                                            1));
    SpvId dstColumnType = this->getType(dstType.componentType().toCompound(fContext,
                                                                            dstType.rows(),
                                                                            1));
    SkASSERT(dstType.componentType().isFloat());
    const SpvId zeroId = this->writeLiteral(0.0, dstType.componentType());
    const SpvId oneId = this->writeLiteral(1.0, dstType.componentType());

    // One id per destination column. NOTE(review): the fixed size assumes matrices never have
    // more than four columns.
    SpvId columns[4];
    for (int i = 0; i < dstType.columns(); i++) {
        if (i < srcType.columns()) {
            // we're still inside the src matrix, copy the column
            SpvId srcColumn = this->nextId(&dstType);
            this->writeInstruction(SpvOpCompositeExtract, srcColumnType, srcColumn, src, i, out);
            SpvId dstColumn;
            if (srcType.rows() == dstType.rows()) {
                // columns are equal size, don't need to do anything
                dstColumn = srcColumn;
            }
            else if (dstType.rows() > srcType.rows()) {
                // dst column is bigger: append the missing rows, using identity-matrix values
                // (1 on the diagonal, 0 elsewhere) for the padding.
                dstColumn = this->nextId(&dstType);
                int delta = dstType.rows() - srcType.rows();
                // Word count = opcode + result type + result id + source column + `delta` scalars.
                this->writeOpCode(SpvOpCompositeConstruct, 4 + delta, out);
                this->writeWord(dstColumnType, out);
                this->writeWord(dstColumn, out);
                this->writeWord(srcColumn, out);
                for (int j = srcType.rows(); j < dstType.rows(); ++j) {
                    this->writeWord((i == j) ? oneId : zeroId, out);
                }
            }
            else {
                // dst column is smaller: shuffle out its leading rows (indices 0..rows-1).
                dstColumn = this->nextId(&dstType);
                // Word count = opcode + result type + result id + two vector operands (both the
                // source column) + one component index per destination row.
                this->writeOpCode(SpvOpVectorShuffle, 5 + dstType.rows(), out);
                this->writeWord(dstColumnType, out);
                this->writeWord(dstColumn, out);
                this->writeWord(srcColumn, out);
                this->writeWord(srcColumn, out);
                for (int j = 0; j < dstType.rows(); j++) {
                    this->writeWord(j, out);
                }
            }
            columns[i] = dstColumn;
        } else {
            // we're past the end of the src matrix, need to synthesize an identity-matrix column
            SpvId identityColumn = this->nextId(&dstType);
            this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.rows(), out);
            this->writeWord(dstColumnType, out);
            this->writeWord(identityColumn, out);
            for (int j = 0; j < dstType.rows(); ++j) {
                this->writeWord((i == j) ? oneId : zeroId, out);
            }
            columns[i] = identityColumn;
        }
    }
    // Assemble the destination columns into the result matrix.
    this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.columns(), out);
    this->writeWord(this->getType(dstType), out);
    this->writeWord(id, out);
    for (int i = 0; i < dstType.columns(); i++) {
        this->writeWord(columns[i], out);
    }
    return id;
}
1555
addColumnEntry(const Type & columnType,std::vector<SpvId> * currentColumn,std::vector<SpvId> * columnIds,int rows,SpvId entry,OutputStream & out)1556 void SPIRVCodeGenerator::addColumnEntry(const Type& columnType,
1557 std::vector<SpvId>* currentColumn,
1558 std::vector<SpvId>* columnIds,
1559 int rows,
1560 SpvId entry,
1561 OutputStream& out) {
1562 SkASSERT((int)currentColumn->size() < rows);
1563 currentColumn->push_back(entry);
1564 if ((int)currentColumn->size() == rows) {
1565 // Synthesize this column into a vector.
1566 SpvId columnId = this->writeComposite(*currentColumn, columnType, out);
1567 columnIds->push_back(columnId);
1568 currentColumn->clear();
1569 }
1570 }
1571
writeMatrixConstructor(const ConstructorCompound & c,OutputStream & out)1572 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out) {
1573 const Type& type = c.type();
1574 SkASSERT(type.isMatrix());
1575 SkASSERT(!c.arguments().empty());
1576 const Type& arg0Type = c.arguments()[0]->type();
1577 // go ahead and write the arguments so we don't try to write new instructions in the middle of
1578 // an instruction
1579 std::vector<SpvId> arguments;
1580 arguments.reserve(c.arguments().size());
1581 for (const std::unique_ptr<Expression>& arg : c.arguments()) {
1582 arguments.push_back(this->writeExpression(*arg, out));
1583 }
1584
1585 if (arguments.size() == 1 && arg0Type.isVector()) {
1586 // Special-case handling of float4 -> mat2x2.
1587 SkASSERT(type.rows() == 2 && type.columns() == 2);
1588 SkASSERT(arg0Type.columns() == 4);
1589 SpvId componentType = this->getType(type.componentType());
1590 SpvId v[4];
1591 for (int i = 0; i < 4; ++i) {
1592 v[i] = this->nextId(&type);
1593 this->writeInstruction(SpvOpCompositeExtract, componentType, v[i], arguments[0], i,
1594 out);
1595 }
1596 const Type& vecType = type.componentType().toCompound(fContext, /*columns=*/2, /*rows=*/1);
1597 SpvId v0v1 = this->writeComposite({v[0], v[1]}, vecType, out);
1598 SpvId v2v3 = this->writeComposite({v[2], v[3]}, vecType, out);
1599 return this->writeComposite({v0v1, v2v3}, type, out);
1600 }
1601
1602 int rows = type.rows();
1603 const Type& columnType = type.componentType().toCompound(fContext,
1604 /*columns=*/rows, /*rows=*/1);
1605 // SpvIds of completed columns of the matrix.
1606 std::vector<SpvId> columnIds;
1607 // SpvIds of scalars we have written to the current column so far.
1608 std::vector<SpvId> currentColumn;
1609 for (size_t i = 0; i < arguments.size(); i++) {
1610 const Type& argType = c.arguments()[i]->type();
1611 if (currentColumn.empty() && argType.isVector() && argType.columns() == rows) {
1612 // This vector is a complete matrix column by itself and can be used as-is.
1613 columnIds.push_back(arguments[i]);
1614 } else if (argType.columns() == 1) {
1615 // This argument is a lone scalar and can be added to the current column as-is.
1616 this->addColumnEntry(columnType, ¤tColumn, &columnIds, rows, arguments[i], out);
1617 } else {
1618 // This argument needs to be decomposed into its constituent scalars.
1619 SpvId componentType = this->getType(argType.componentType());
1620 for (int j = 0; j < argType.columns(); ++j) {
1621 SpvId swizzle = this->nextId(&argType);
1622 this->writeInstruction(SpvOpCompositeExtract, componentType, swizzle,
1623 arguments[i], j, out);
1624 this->addColumnEntry(columnType, ¤tColumn, &columnIds, rows, swizzle, out);
1625 }
1626 }
1627 }
1628 SkASSERT(columnIds.size() == (size_t) type.columns());
1629 return this->writeComposite(columnIds, type, out);
1630 }
1631
writeConstructorCompound(const ConstructorCompound & c,OutputStream & out)1632 SpvId SPIRVCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1633 OutputStream& out) {
1634 return c.type().isMatrix() ? this->writeMatrixConstructor(c, out)
1635 : this->writeVectorConstructor(c, out);
1636 }
1637
writeVectorConstructor(const ConstructorCompound & c,OutputStream & out)1638 SpvId SPIRVCodeGenerator::writeVectorConstructor(const ConstructorCompound& c, OutputStream& out) {
1639 const Type& type = c.type();
1640 const Type& componentType = type.componentType();
1641 SkASSERT(type.isVector());
1642
1643 if (c.isCompileTimeConstant()) {
1644 return this->writeConstantVector(c);
1645 }
1646
1647 std::vector<SpvId> arguments;
1648 arguments.reserve(c.arguments().size());
1649 for (size_t i = 0; i < c.arguments().size(); i++) {
1650 const Type& argType = c.arguments()[i]->type();
1651 SkASSERT(componentType.matches(argType.componentType()));
1652
1653 SpvId arg = this->writeExpression(*c.arguments()[i], out);
1654 if (argType.isMatrix()) {
1655 // CompositeConstruct cannot take a 2x2 matrix as an input, so we need to extract out
1656 // each scalar separately.
1657 SkASSERT(argType.rows() == 2);
1658 SkASSERT(argType.columns() == 2);
1659 for (int j = 0; j < 4; ++j) {
1660 SpvId componentId = this->nextId(&componentType);
1661 this->writeInstruction(SpvOpCompositeExtract, this->getType(componentType),
1662 componentId, arg, j / 2, j % 2, out);
1663 arguments.push_back(componentId);
1664 }
1665 } else if (argType.isVector()) {
1666 // There's a bug in the Intel Vulkan driver where OpCompositeConstruct doesn't handle
1667 // vector arguments at all, so we always extract each vector component and pass them
1668 // into OpCompositeConstruct individually.
1669 for (int j = 0; j < argType.columns(); j++) {
1670 SpvId componentId = this->nextId(&componentType);
1671 this->writeInstruction(SpvOpCompositeExtract, this->getType(componentType),
1672 componentId, arg, j, out);
1673 arguments.push_back(componentId);
1674 }
1675 } else {
1676 arguments.push_back(arg);
1677 }
1678 }
1679
1680 return this->writeComposite(arguments, type, out);
1681 }
1682
writeComposite(const std::vector<SpvId> & arguments,const Type & type,OutputStream & out)1683 SpvId SPIRVCodeGenerator::writeComposite(const std::vector<SpvId>& arguments,
1684 const Type& type,
1685 OutputStream& out) {
1686 SkASSERT(arguments.size() == (type.isStruct() ? type.fields().size() : (size_t)type.columns()));
1687
1688 SpvId result = this->nextId(&type);
1689 this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out);
1690 this->writeWord(this->getType(type), out);
1691 this->writeWord(result, out);
1692 for (SpvId id : arguments) {
1693 this->writeWord(id, out);
1694 }
1695 return result;
1696 }
1697
writeConstructorSplat(const ConstructorSplat & c,OutputStream & out)1698 SpvId SPIRVCodeGenerator::writeConstructorSplat(const ConstructorSplat& c, OutputStream& out) {
1699 // Use writeConstantVector to deduplicate constant splats.
1700 if (c.isCompileTimeConstant()) {
1701 return this->writeConstantVector(c);
1702 }
1703
1704 // Write the splat argument.
1705 SpvId argument = this->writeExpression(*c.argument(), out);
1706
1707 // Generate a OpCompositeConstruct which repeats the argument N times.
1708 std::vector<SpvId> arguments(/*count*/ c.type().columns(), /*value*/ argument);
1709 return this->writeComposite(arguments, c.type(), out);
1710 }
1711
1712
writeCompositeConstructor(const AnyConstructor & c,OutputStream & out)1713 SpvId SPIRVCodeGenerator::writeCompositeConstructor(const AnyConstructor& c, OutputStream& out) {
1714 SkASSERT(c.type().isArray() || c.type().isStruct());
1715 auto ctorArgs = c.argumentSpan();
1716
1717 std::vector<SpvId> arguments;
1718 arguments.reserve(ctorArgs.size());
1719 for (const std::unique_ptr<Expression>& arg : ctorArgs) {
1720 arguments.push_back(this->writeExpression(*arg, out));
1721 }
1722
1723 return this->writeComposite(arguments, c.type(), out);
1724 }
1725
writeConstructorScalarCast(const ConstructorScalarCast & c,OutputStream & out)1726 SpvId SPIRVCodeGenerator::writeConstructorScalarCast(const ConstructorScalarCast& c,
1727 OutputStream& out) {
1728 const Type& type = c.type();
1729 if (this->getActualType(type).matches(this->getActualType(c.argument()->type()))) {
1730 return this->writeExpression(*c.argument(), out);
1731 }
1732
1733 const Expression& ctorExpr = *c.argument();
1734 SpvId expressionId = this->writeExpression(ctorExpr, out);
1735 return this->castScalarToType(expressionId, ctorExpr.type(), type, out);
1736 }
1737
writeConstructorCompoundCast(const ConstructorCompoundCast & c,OutputStream & out)1738 SpvId SPIRVCodeGenerator::writeConstructorCompoundCast(const ConstructorCompoundCast& c,
1739 OutputStream& out) {
1740 const Type& ctorType = c.type();
1741 const Type& argType = c.argument()->type();
1742 SkASSERT(ctorType.isVector() || ctorType.isMatrix());
1743
1744 // Write the composite that we are casting. If the actual type matches, we are done.
1745 SpvId compositeId = this->writeExpression(*c.argument(), out);
1746 if (this->getActualType(ctorType).matches(this->getActualType(argType))) {
1747 return compositeId;
1748 }
1749
1750 // writeMatrixCopy can cast matrices to a different type.
1751 if (ctorType.isMatrix()) {
1752 return this->writeMatrixCopy(compositeId, argType, ctorType, out);
1753 }
1754
1755 // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to extract the
1756 // components and convert each one manually.
1757 const Type& srcType = argType.componentType();
1758 const Type& dstType = ctorType.componentType();
1759
1760 std::vector<SpvId> arguments;
1761 arguments.reserve(argType.columns());
1762 for (int index = 0; index < argType.columns(); ++index) {
1763 SpvId componentId = this->nextId(&srcType);
1764 this->writeInstruction(SpvOpCompositeExtract, this->getType(srcType), componentId,
1765 compositeId, index, out);
1766 arguments.push_back(this->castScalarToType(componentId, srcType, dstType, out));
1767 }
1768
1769 return this->writeComposite(arguments, ctorType, out);
1770 }
1771
writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix & c,OutputStream & out)1772 SpvId SPIRVCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
1773 OutputStream& out) {
1774 const Type& type = c.type();
1775 SkASSERT(type.isMatrix());
1776 SkASSERT(c.argument()->type().isScalar());
1777
1778 // Write out the scalar argument.
1779 SpvId argument = this->writeExpression(*c.argument(), out);
1780
1781 // Build the diagonal matrix.
1782 SpvId result = this->nextId(&type);
1783 this->writeUniformScaleMatrix(result, argument, type, out);
1784 return result;
1785 }
1786
writeConstructorMatrixResize(const ConstructorMatrixResize & c,OutputStream & out)1787 SpvId SPIRVCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
1788 OutputStream& out) {
1789 // Write the input matrix.
1790 SpvId argument = this->writeExpression(*c.argument(), out);
1791
1792 // Use matrix-copy to resize the input matrix to its new size.
1793 return this->writeMatrixCopy(argument, c.argument()->type(), c.type(), out);
1794 }
1795
get_storage_class(const Variable & var,SpvStorageClass_ fallbackStorageClass)1796 static SpvStorageClass_ get_storage_class(const Variable& var,
1797 SpvStorageClass_ fallbackStorageClass) {
1798 const Modifiers& modifiers = var.modifiers();
1799 if (modifiers.fFlags & Modifiers::kIn_Flag) {
1800 SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1801 return SpvStorageClassInput;
1802 }
1803 if (modifiers.fFlags & Modifiers::kOut_Flag) {
1804 SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1805 return SpvStorageClassOutput;
1806 }
1807 if (modifiers.fFlags & Modifiers::kUniform_Flag) {
1808 if (modifiers.fLayout.fFlags & Layout::kPushConstant_Flag) {
1809 return SpvStorageClassPushConstant;
1810 }
1811 if (var.type().typeKind() == Type::TypeKind::kSampler ||
1812 var.type().typeKind() == Type::TypeKind::kSeparateSampler ||
1813 var.type().typeKind() == Type::TypeKind::kTexture) {
1814 return SpvStorageClassUniformConstant;
1815 }
1816 return SpvStorageClassUniform;
1817 }
1818 return fallbackStorageClass;
1819 }
1820
get_storage_class(const Expression & expr)1821 static SpvStorageClass_ get_storage_class(const Expression& expr) {
1822 switch (expr.kind()) {
1823 case Expression::Kind::kVariableReference: {
1824 const Variable& var = *expr.as<VariableReference>().variable();
1825 if (var.storage() != Variable::Storage::kGlobal) {
1826 return SpvStorageClassFunction;
1827 }
1828 return get_storage_class(var, SpvStorageClassPrivate);
1829 }
1830 case Expression::Kind::kFieldAccess:
1831 return get_storage_class(*expr.as<FieldAccess>().base());
1832 case Expression::Kind::kIndex:
1833 return get_storage_class(*expr.as<IndexExpression>().base());
1834 default:
1835 return SpvStorageClassFunction;
1836 }
1837 }
1838
getAccessChain(const Expression & expr,OutputStream & out)1839 std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
1840 std::vector<SpvId> chain;
1841 switch (expr.kind()) {
1842 case Expression::Kind::kIndex: {
1843 const IndexExpression& indexExpr = expr.as<IndexExpression>();
1844 chain = this->getAccessChain(*indexExpr.base(), out);
1845 chain.push_back(this->writeExpression(*indexExpr.index(), out));
1846 break;
1847 }
1848 case Expression::Kind::kFieldAccess: {
1849 const FieldAccess& fieldExpr = expr.as<FieldAccess>();
1850 chain = this->getAccessChain(*fieldExpr.base(), out);
1851 chain.push_back(this->writeLiteral(fieldExpr.fieldIndex(), *fContext.fTypes.fInt));
1852 break;
1853 }
1854 default: {
1855 SpvId id = this->getLValue(expr, out)->getPointer();
1856 SkASSERT(id != (SpvId) -1);
1857 chain.push_back(id);
1858 break;
1859 }
1860 }
1861 return chain;
1862 }
1863
1864 class PointerLValue : public SPIRVCodeGenerator::LValue {
1865 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,bool isMemoryObject,SpvId type,SPIRVCodeGenerator::Precision precision)1866 PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, bool isMemoryObject, SpvId type,
1867 SPIRVCodeGenerator::Precision precision)
1868 : fGen(gen)
1869 , fPointer(pointer)
1870 , fIsMemoryObject(isMemoryObject)
1871 , fType(type)
1872 , fPrecision(precision) {}
1873
getPointer()1874 SpvId getPointer() override {
1875 return fPointer;
1876 }
1877
isMemoryObjectPointer() const1878 bool isMemoryObjectPointer() const override {
1879 return fIsMemoryObject;
1880 }
1881
load(OutputStream & out)1882 SpvId load(OutputStream& out) override {
1883 SpvId result = fGen.nextId(fPrecision);
1884 fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
1885 return result;
1886 }
1887
store(SpvId value,OutputStream & out)1888 void store(SpvId value, OutputStream& out) override {
1889 fGen.writeInstruction(SpvOpStore, fPointer, value, out);
1890 }
1891
1892 private:
1893 SPIRVCodeGenerator& fGen;
1894 const SpvId fPointer;
1895 const bool fIsMemoryObject;
1896 const SpvId fType;
1897 const SPIRVCodeGenerator::Precision fPrecision;
1898 };
1899
// An lvalue for a multi-component swizzle of a vector stored behind `vecPointer`.
// Loads read the whole vector and shuffle out the selected components. Stores read the whole
// vector, splice the new components in with OpVectorShuffle, and write the whole vector back.
class SwizzleLValue : public SPIRVCodeGenerator::LValue {
public:
    SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const ComponentArray& components,
                  const Type& baseType, const Type& swizzleType)
            : fGen(gen)
            , fVecPointer(vecPointer)
            , fComponents(components)
            , fBaseType(&baseType)
            , fSwizzleType(&swizzleType) {}

    // Composes a further swizzle on top of this one: each new component index selects from the
    // current component list. Returns false (after a debug failure) on an out-of-range index.
    bool applySwizzle(const ComponentArray& components, const Type& newType) override {
        ComponentArray updatedSwizzle;
        for (int8_t component : components) {
            if (component < 0 || component >= fComponents.count()) {
                SkDEBUGFAILF("swizzle accessed nonexistent component %d", (int)component);
                return false;
            }
            updatedSwizzle.push_back(fComponents[component]);
        }
        fComponents = updatedSwizzle;
        fSwizzleType = &newType;
        return true;
    }

    // Loads the full base vector, then shuffles it with itself to pick out fComponents.
    SpvId load(OutputStream& out) override {
        SpvId base = fGen.nextId(fBaseType);
        fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
        SpvId result = fGen.nextId(fBaseType);
        fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
        fGen.writeWord(fGen.getType(*fSwizzleType), out);
        fGen.writeWord(result, out);
        fGen.writeWord(base, out);
        fGen.writeWord(base, out);
        for (int component : fComponents) {
            fGen.writeWord(component, out);
        }
        return result;
    }

    void store(SpvId value, OutputStream& out) override {
        // use OpVectorShuffle to mix and match the vector components. We effectively create
        // a virtual vector out of the concatenation of the left vector and `value` (the
        // already-evaluated right-hand side, which has one component per swizzle slot), and
        // then select components from this virtual vector to make the result vector. For
        // instance, given:
        //     float3 L = ...;
        //     float3 R = ...;
        //     L.xz = R.xy;
        // `value` holds R.xy, so the virtual vector is (L.x, L.y, L.z, R.x, R.y). We want our
        // result vector to look like (R.x, L.y, R.y), so we need to select indices (3, 1, 4).
        SpvId base = fGen.nextId(fBaseType);
        fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
        SpvId shuffle = fGen.nextId(fBaseType);
        fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType->columns(), out);
        fGen.writeWord(fGen.getType(*fBaseType), out);
        fGen.writeWord(shuffle, out);
        fGen.writeWord(base, out);
        fGen.writeWord(value, out);
        for (int i = 0; i < fBaseType->columns(); i++) {
            // current offset into the virtual vector, defaults to pulling the unmodified
            // value from the left side
            int offset = i;
            // check to see if we are writing this component
            for (size_t j = 0; j < fComponents.size(); j++) {
                if (fComponents[j] == i) {
                    // we're writing to this component, so adjust the offset to pull from
                    // the correct component of the right side instead of preserving the
                    // value from the left
                    offset = (int) (j + fBaseType->columns());
                    break;
                }
            }
            fGen.writeWord(offset, out);
        }
        // Write the merged vector back in one store.
        fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
    }

private:
    SPIRVCodeGenerator& fGen;
    // Pointer to the whole (unswizzled) vector.
    const SpvId fVecPointer;
    // Indices into the base vector, one per swizzle slot.
    ComponentArray fComponents;
    const Type* fBaseType;
    const Type* fSwizzleType;
};
1984
findUniformFieldIndex(const Variable & var) const1985 int SPIRVCodeGenerator::findUniformFieldIndex(const Variable& var) const {
1986 auto iter = fTopLevelUniformMap.find(&var);
1987 return (iter != fTopLevelUniformMap.end()) ? iter->second : -1;
1988 }
1989
// Returns an LValue through which `expr` can be loaded and stored. Any instructions needed to
// compute the address (access chains, literals, temporaries) are emitted into `out`.
// The LValue's precision is kDefault for high-precision types and kRelaxed otherwise.
std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
                                                                          OutputStream& out) {
    const Type& type = expr.type();
    Precision precision = type.highPrecision() ? Precision::kDefault : Precision::kRelaxed;
    switch (expr.kind()) {
        case Expression::Kind::kVariableReference: {
            const Variable& var = *expr.as<VariableReference>().variable();
            int uniformIdx = this->findUniformFieldIndex(var);
            if (uniformIdx >= 0) {
                // The variable is a field of the top-level uniform block: address it with an
                // access chain into fUniformBufferId using its field index.
                SpvId memberId = this->nextId(nullptr);
                SpvId typeId = this->getPointerType(type, SpvStorageClassUniform);
                SpvId uniformIdxId = this->writeLiteral((double)uniformIdx, *fContext.fTypes.fInt);
                this->writeInstruction(SpvOpAccessChain, typeId, memberId, fUniformBufferId,
                                       uniformIdxId, out);
                return std::make_unique<PointerLValue>(*this, memberId,
                                                       /*isMemoryObjectPointer=*/true,
                                                       this->getType(type), precision);
            }
            // Otherwise the variable already has a pointer id recorded in fVariableMap.
            SpvId typeId = this->getType(type, this->memoryLayoutForVariable(var));
            auto entry = fVariableMap.find(&var);
            SkASSERTF(entry != fVariableMap.end(), "%s", expr.description().c_str());
            return std::make_unique<PointerLValue>(*this, entry->second,
                                                   /*isMemoryObjectPointer=*/true,
                                                   typeId, precision);
        }
        case Expression::Kind::kIndex: // fall through
        case Expression::Kind::kFieldAccess: {
            // Emit one OpAccessChain covering the whole chain of index/field accesses.
            std::vector<SpvId> chain = this->getAccessChain(expr, out);
            SpvId member = this->nextId(nullptr);
            // Word count = opcode + result type + result id + chain (base pointer + indices).
            this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
            this->writeWord(this->getPointerType(type, get_storage_class(expr)), out);
            this->writeWord(member, out);
            for (SpvId idx : chain) {
                this->writeWord(idx, out);
            }
            return std::make_unique<PointerLValue>(*this, member, /*isMemoryObjectPointer=*/false,
                                                   this->getType(type), precision);
        }
        case Expression::Kind::kSwizzle: {
            const Swizzle& swizzle = expr.as<Swizzle>();
            std::unique_ptr<LValue> lvalue = this->getLValue(*swizzle.base(), out);
            // If the base lvalue can absorb the swizzle itself (e.g. it is already a
            // SwizzleLValue), reuse it.
            if (lvalue->applySwizzle(swizzle.components(), type)) {
                return lvalue;
            }
            SpvId base = lvalue->getPointer();
            if (base == (SpvId) -1) {
                fContext.fErrors->error(swizzle.fLine, "unable to retrieve lvalue from swizzle");
            }
            if (swizzle.components().size() == 1) {
                // A single component can be addressed directly with an access chain.
                SpvId member = this->nextId(nullptr);
                SpvId typeId = this->getPointerType(type, get_storage_class(*swizzle.base()));
                SpvId indexId = this->writeLiteral(swizzle.components()[0], *fContext.fTypes.fInt);
                this->writeInstruction(SpvOpAccessChain, typeId, member, base, indexId, out);
                return std::make_unique<PointerLValue>(*this,
                                                       member,
                                                       /*isMemoryObjectPointer=*/false,
                                                       this->getType(type),
                                                       precision);
            } else {
                // Multi-component swizzles are handled by load/shuffle/store of the whole vector.
                return std::make_unique<SwizzleLValue>(*this, base, swizzle.components(),
                                                       swizzle.base()->type(), type);
            }
        }
        default: {
            // expr isn't actually an lvalue, create a placeholder variable for it. This case
            // happens due to the need to store values in temporary variables during function
            // calls (see comments in getFunctionType); erroneous uses of rvalues as lvalues
            // should have been caught before code generation
            // The OpVariable goes into fVariableBuffer; the store of expr's value goes into out.
            SpvId result = this->nextId(nullptr);
            SpvId pointerType = this->getPointerType(type, SpvStorageClassFunction);
            this->writeInstruction(SpvOpVariable, pointerType, result, SpvStorageClassFunction,
                                   fVariableBuffer);
            this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
            return std::make_unique<PointerLValue>(*this, result, /*isMemoryObjectPointer=*/true,
                                                   this->getType(type), precision);
        }
    }
}
2068
// Emits a read of the referenced variable and returns the id of the loaded value.
// sk_FragCoord and sk_Clockwise get special handling so that they honor the RT-flip uniform
// (SKSL_RTFLIP_NAME):
//   sk_FragCoord -> float4(dev.x, rtFlip.x + rtFlip.y * dev.y, dev.z, dev.w)
//   sk_Clockwise -> (rtFlip.y > 0) ? !devClockwise : devClockwise
// Here "dev" refers to the synthetic $device_* variables. Those variables use the fake
// DEVICE_*_BUILTIN values and read the underlying built-ins unmodified.
// All other variables are simply loaded through getLValue.
SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
    const Variable* variable = ref.variable();
    if (variable->modifiers().fLayout.fBuiltin == DEVICE_FRAGCOORDS_BUILTIN) {
        // Down below, we rewrite raw references to sk_FragCoord with expressions that reference
        // DEVICE_FRAGCOORDS_BUILTIN. This is a fake variable that means we need to directly access
        // the fragcoord; do so now.
        dsl::DSLGlobalVar fragCoord("sk_FragCoord");
        return this->getLValue(*dsl::DSLExpression(fragCoord).release(), out)->load(out);
    }
    if (variable->modifiers().fLayout.fBuiltin == DEVICE_CLOCKWISE_BUILTIN) {
        // Down below, we rewrite raw references to sk_Clockwise with expressions that reference
        // DEVICE_CLOCKWISE_BUILTIN. This is a fake variable that means we need to directly
        // access front facing; do so now.
        dsl::DSLGlobalVar clockwise("sk_Clockwise");
        return this->getLValue(*dsl::DSLExpression(clockwise).release(), out)->load(out);
    }

    // Handle inserting use of uniform to flip y when referencing sk_FragCoord.
    if (variable->modifiers().fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
        this->addRTFlipUniform(ref.fLine);
        // Use the RT-flip uniform (SKSL_RTFLIP_NAME) to compute the flipped coordinate
        using namespace dsl;
        const char* DEVICE_COORDS_NAME = "$device_FragCoords";
        SymbolTable& symbols = *ThreadContext::SymbolTable();
        // Use a uniform to flip the Y coordinate. The new expression will be written in
        // terms of $device_FragCoords, which is a fake variable that means "access the
        // underlying fragcoords directly without flipping it".
        DSLExpression rtFlip(ThreadContext::Compiler().convertIdentifier(/*line=*/-1,
                                                                          SKSL_RTFLIP_NAME));
        if (!symbols[DEVICE_COORDS_NAME]) {
            // Create $device_FragCoords on first use; later references find it in the symbol
            // table. NOTE(review): the pool attachment presumably makes the new Variable
            // allocate from the program's pool — confirm.
            AutoAttachPoolToThread attach(fProgram.fPool.get());
            Modifiers modifiers;
            modifiers.fLayout.fBuiltin = DEVICE_FRAGCOORDS_BUILTIN;
            auto coordsVar = std::make_unique<Variable>(/*line=*/-1,
                                                        fContext.fModifiersPool->add(modifiers),
                                                        DEVICE_COORDS_NAME,
                                                        fContext.fTypes.fFloat4.get(),
                                                        /*builtin=*/true,
                                                        Variable::Storage::kGlobal);
            fSPIRVBonusVariables.insert(coordsVar.get());
            symbols.add(std::move(coordsVar));
        }
        DSLGlobalVar deviceCoord(DEVICE_COORDS_NAME);
        // rtFlip is referenced twice, so one of the two uses operates on a clone.
        std::unique_ptr<Expression> rtFlipSkSLExpr = rtFlip.release();
        DSLExpression x = DSLExpression(rtFlipSkSLExpr->clone()).x();
        DSLExpression y = DSLExpression(std::move(rtFlipSkSLExpr)).y();
        return this->writeExpression(*dsl::Float4(deviceCoord.x(),
                                                  std::move(x) + std::move(y) * deviceCoord.y(),
                                                  deviceCoord.z(),
                                                  deviceCoord.w()).release(),
                                     out);
    }

    // Handle flipping sk_Clockwise.
    if (variable->modifiers().fLayout.fBuiltin == SK_CLOCKWISE_BUILTIN) {
        this->addRTFlipUniform(ref.fLine);
        using namespace dsl;
        const char* DEVICE_CLOCKWISE_NAME = "$device_Clockwise";
        SymbolTable& symbols = *ThreadContext::SymbolTable();
        // Use the sign of the RT-flip uniform's Y component to decide whether to invert the
        // facing. The new expression will be written in terms of $device_Clockwise, which is
        // a fake variable that means "access the underlying FrontFacing directly".
        DSLExpression rtFlip(ThreadContext::Compiler().convertIdentifier(/*line=*/-1,
                                                                          SKSL_RTFLIP_NAME));
        if (!symbols[DEVICE_CLOCKWISE_NAME]) {
            // Create $device_Clockwise on first use (same pattern as $device_FragCoords above).
            AutoAttachPoolToThread attach(fProgram.fPool.get());
            Modifiers modifiers;
            modifiers.fLayout.fBuiltin = DEVICE_CLOCKWISE_BUILTIN;
            auto clockwiseVar = std::make_unique<Variable>(/*line=*/-1,
                                                           fContext.fModifiersPool->add(modifiers),
                                                           DEVICE_CLOCKWISE_NAME,
                                                           fContext.fTypes.fBool.get(),
                                                           /*builtin=*/true,
                                                           Variable::Storage::kGlobal);
            fSPIRVBonusVariables.insert(clockwiseVar.get());
            symbols.add(std::move(clockwiseVar));
        }
        DSLGlobalVar deviceClockwise(DEVICE_CLOCKWISE_NAME);
        // FrontFacing in Vulkan is defined in terms of a top-down render target. In skia,
        // we use the default convention of "counter-clockwise face is front".
        return this->writeExpression(*dsl::Bool(Select(rtFlip.y() > 0,
                                                       !deviceClockwise,
                                                       deviceClockwise)).release(),
                                     out);
    }

    // Ordinary variable: load it through its lvalue.
    return this->getLValue(ref, out)->load(out);
}
2157
writeIndexExpression(const IndexExpression & expr,OutputStream & out)2158 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
2159 if (expr.base()->type().isVector()) {
2160 SpvId base = this->writeExpression(*expr.base(), out);
2161 SpvId index = this->writeExpression(*expr.index(), out);
2162 SpvId result = this->nextId(nullptr);
2163 this->writeInstruction(SpvOpVectorExtractDynamic, this->getType(expr.type()), result, base,
2164 index, out);
2165 return result;
2166 }
2167 return getLValue(expr, out)->load(out);
2168 }
2169
writeFieldAccess(const FieldAccess & f,OutputStream & out)2170 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
2171 return getLValue(f, out)->load(out);
2172 }
2173
writeSwizzle(const Swizzle & swizzle,OutputStream & out)2174 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
2175 SpvId base = this->writeExpression(*swizzle.base(), out);
2176 SpvId result = this->nextId(&swizzle.type());
2177 size_t count = swizzle.components().size();
2178 if (count == 1) {
2179 this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.type()), result, base,
2180 swizzle.components()[0], out);
2181 } else {
2182 this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
2183 this->writeWord(this->getType(swizzle.type()), out);
2184 this->writeWord(result, out);
2185 this->writeWord(base, out);
2186 this->writeWord(base, out);
2187 for (int component : swizzle.components()) {
2188 this->writeWord(component, out);
2189 }
2190 }
2191 return result;
2192 }
2193
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)2194 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
2195 const Type& operandType, SpvId lhs,
2196 SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
2197 SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
2198 SpvId result = this->nextId(&resultType);
2199 if (is_float(fContext, operandType)) {
2200 this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
2201 } else if (is_signed(fContext, operandType)) {
2202 this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
2203 } else if (is_unsigned(fContext, operandType)) {
2204 this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
2205 } else if (is_bool(fContext, operandType)) {
2206 this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
2207 } else {
2208 fContext.fErrors->error(operandType.fLine,
2209 "unsupported operand for binary expression: " + operandType.description());
2210 }
2211 return result;
2212 }
2213
foldToBool(SpvId id,const Type & operandType,SpvOp op,OutputStream & out)2214 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op,
2215 OutputStream& out) {
2216 if (operandType.isVector()) {
2217 SpvId result = this->nextId(nullptr);
2218 this->writeInstruction(op, this->getType(*fContext.fTypes.fBool), result, id, out);
2219 return result;
2220 }
2221 return id;
2222 }
2223
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,SpvOp_ vectorMergeOperator,SpvOp_ mergeOperator,OutputStream & out)2224 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
2225 SpvOp_ floatOperator, SpvOp_ intOperator,
2226 SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
2227 OutputStream& out) {
2228 SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator;
2229 SkASSERT(operandType.isMatrix());
2230 SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
2231 operandType.rows(),
2232 1));
2233 SpvId bvecType = this->getType(fContext.fTypes.fBool->toCompound(fContext,
2234 operandType.rows(),
2235 1));
2236 SpvId boolType = this->getType(*fContext.fTypes.fBool);
2237 SpvId result = 0;
2238 for (int i = 0; i < operandType.columns(); i++) {
2239 SpvId columnL = this->nextId(&operandType);
2240 this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
2241 SpvId columnR = this->nextId(&operandType);
2242 this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
2243 SpvId compare = this->nextId(&operandType);
2244 this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
2245 SpvId merge = this->nextId(nullptr);
2246 this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
2247 if (result != 0) {
2248 SpvId next = this->nextId(nullptr);
2249 this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
2250 result = next;
2251 }
2252 else {
2253 result = merge;
2254 }
2255 }
2256 return result;
2257 }
2258
writeComponentwiseMatrixUnary(const Type & operandType,SpvId operand,SpvOp_ op,OutputStream & out)2259 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixUnary(const Type& operandType,
2260 SpvId operand,
2261 SpvOp_ op,
2262 OutputStream& out) {
2263 SkASSERT(operandType.isMatrix());
2264 SpvId columnType = this->getType(operandType.componentType().toCompound(
2265 fContext, /*columns=*/operandType.rows(), /*rows=*/1));
2266
2267 std::vector<SpvId> columns;
2268 columns.reserve(operandType.columns());
2269 for (int i = 0; i < operandType.columns(); i++) {
2270 SpvId srcColumn = this->nextId(&operandType);
2271 this->writeInstruction(SpvOpCompositeExtract, columnType, srcColumn, operand, i, out);
2272
2273 SpvId dstColumn = this->nextId(&operandType);
2274 this->writeInstruction(op, columnType, dstColumn, srcColumn, out);
2275 columns.push_back(dstColumn);
2276 }
2277
2278 return this->writeComposite(columns, operandType, out);
2279 }
2280
writeComponentwiseMatrixBinary(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ op,OutputStream & out)2281 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
2282 SpvId rhs, SpvOp_ op, OutputStream& out) {
2283 SkASSERT(operandType.isMatrix());
2284 SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
2285 operandType.rows(),
2286 1));
2287 std::vector<SpvId> columns;
2288 columns.reserve(operandType.columns());
2289 for (int i = 0; i < operandType.columns(); i++) {
2290 SpvId columnL = this->nextId(&operandType);
2291 this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
2292 SpvId columnR = this->nextId(&operandType);
2293 this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
2294 columns.push_back(this->nextId(&operandType));
2295 this->writeInstruction(op, columnType, columns[i], columnL, columnR, out);
2296 }
2297 return this->writeComposite(columns, operandType, out);
2298 }
2299
writeReciprocal(const Type & type,SpvId value,OutputStream & out)2300 SpvId SPIRVCodeGenerator::writeReciprocal(const Type& type, SpvId value, OutputStream& out) {
2301 SkASSERT(type.isFloat());
2302 SpvId one = this->writeLiteral(1.0, type);
2303 SpvId reciprocal = this->nextId(&type);
2304 this->writeInstruction(SpvOpFDiv, this->getType(type), reciprocal, one, value, out);
2305 return reciprocal;
2306 }
2307
writeScalarToMatrixSplat(const Type & matrixType,SpvId scalarId,OutputStream & out)2308 SpvId SPIRVCodeGenerator::writeScalarToMatrixSplat(const Type& matrixType,
2309 SpvId scalarId,
2310 OutputStream& out) {
2311 // Splat the scalar into a vector.
2312 const Type& vectorType = matrixType.componentType().toCompound(fContext,
2313 /*columns=*/matrixType.rows(),
2314 /*rows=*/1);
2315 std::vector<SpvId> vecArguments(/*count*/ matrixType.rows(), /*value*/ scalarId);
2316 SpvId vectorId = this->writeComposite(vecArguments, vectorType, out);
2317
2318 // Splat the vector into a matrix.
2319 std::vector<SpvId> matArguments(/*count*/ matrixType.columns(), /*value*/ vectorId);
2320 return this->writeComposite(matArguments, matrixType, out);
2321 }
2322
// Emits SPIR-V for `lhs <op> rhs` and returns the id of the result, whose type is `resultType`.
// Both operands have already been evaluated; their types are `leftType` and `rightType`.
//
// Operands whose types differ are resolved first:
//  - float `vector * scalar`, float `vector / scalar` (via the scalar's reciprocal) and float
//    `scalar * vector` use OpVectorTimesScalar.
//  - Any other vector/scalar mix splats the scalar into a vector of the vector operand's type.
//  - Multiplications involving a matrix use OpMatrixTimesMatrix, OpMatrixTimesVector,
//    OpMatrixTimesScalar or OpVectorTimesMatrix.
//  - Other matrix/scalar operations splat the scalar into a matrix and recurse as matrix-op-matrix.
// Once the operand types match, `op` selects the opcode. writeBinaryOperation then chooses the
// float/signed/unsigned/bool variant. Some unsupported combinations report an error and return -1.
SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
                                                const Type& rightType, SpvId rhs,
                                                const Type& resultType, OutputStream& out) {
    // The comma operator ignores the type of the left-hand side entirely.
    if (op.kind() == Token::Kind::TK_COMMA) {
        return rhs;
    }
    // overall type we are operating on: float2, int, uint4...
    const Type* operandType;
    // IR allows mismatched types in expressions (e.g. float2 * float), but they need special
    // handling in SPIR-V
    if (!this->getActualType(leftType).matches(this->getActualType(rightType))) {
        if (leftType.isVector() && rightType.isNumber()) {
            if (resultType.componentType().isFloat()) {
                switch (op.kind()) {
                    case Token::Kind::TK_SLASH: {
                        // `vector / scalar` is emitted as `vector * (1 / scalar)` so it can share
                        // the OpVectorTimesScalar path below.
                        rhs = this->writeReciprocal(rightType, rhs, out);
                        [[fallthrough]];
                    }
                    case Token::Kind::TK_STAR: {
                        SpvId result = this->nextId(&resultType);
                        this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
                                               result, lhs, rhs, out);
                        return result;
                    }
                    default:
                        break;
                }
            }
            // promote number to vector
            const Type& vecType = leftType;
            SpvId vec = this->nextId(&vecType);
            // Word count: opcode word, result type, result id, and one constituent per column.
            // Every constituent is the scalar `rhs`.
            this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
            this->writeWord(this->getType(vecType), out);
            this->writeWord(vec, out);
            for (int i = 0; i < vecType.columns(); i++) {
                this->writeWord(rhs, out);
            }
            rhs = vec;
            operandType = &leftType;
        } else if (rightType.isVector() && leftType.isNumber()) {
            if (resultType.componentType().isFloat()) {
                if (op.kind() == Token::Kind::TK_STAR) {
                    // OpVectorTimesScalar takes the vector operand first, so lhs/rhs are swapped.
                    SpvId result = this->nextId(&resultType);
                    this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
                                           result, rhs, lhs, out);
                    return result;
                }
            }
            // promote number to vector
            const Type& vecType = rightType;
            SpvId vec = this->nextId(&vecType);
            // Same construction as above, but replicating the scalar `lhs`.
            this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
            this->writeWord(this->getType(vecType), out);
            this->writeWord(vec, out);
            for (int i = 0; i < vecType.columns(); i++) {
                this->writeWord(lhs, out);
            }
            lhs = vec;
            operandType = &rightType;
        } else if (leftType.isMatrix()) {
            if (op.kind() == Token::Kind::TK_STAR) {
                // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
                SpvOp_ spvop;
                if (rightType.isMatrix()) {
                    spvop = SpvOpMatrixTimesMatrix;
                } else if (rightType.isVector()) {
                    spvop = SpvOpMatrixTimesVector;
                } else {
                    SkASSERT(rightType.isScalar());
                    spvop = SpvOpMatrixTimesScalar;
                }
                SpvId result = this->nextId(&resultType);
                this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
                return result;
            } else {
                // Matrix-op-vector is not supported in GLSL/SkSL for non-multiplication ops; we
                // expect to have a scalar here.
                SkASSERT(rightType.isScalar());

                // Splat rhs across an entire matrix so we can reuse the matrix-op-matrix path.
                SpvId rhsMatrix = this->writeScalarToMatrixSplat(leftType, rhs, out);

                // Perform this operation as matrix-op-matrix.
                return this->writeBinaryExpression(leftType, lhs, op, leftType, rhsMatrix,
                                                   resultType, out);
            }
        } else if (rightType.isMatrix()) {
            if (op.kind() == Token::Kind::TK_STAR) {
                // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
                SpvId result = this->nextId(&resultType);
                if (leftType.isVector()) {
                    this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType),
                                           result, lhs, rhs, out);
                } else {
                    SkASSERT(leftType.isScalar());
                    // OpMatrixTimesScalar takes the matrix operand first, so lhs/rhs are swapped.
                    this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType),
                                           result, rhs, lhs, out);
                }
                return result;
            } else {
                // Vector-op-matrix is not supported in GLSL/SkSL for non-multiplication ops; we
                // expect to have a scalar here.
                SkASSERT(leftType.isScalar());

                // Splat lhs across an entire matrix so we can reuse the matrix-op-matrix path.
                SpvId lhsMatrix = this->writeScalarToMatrixSplat(rightType, lhs, out);

                // Perform this operation as matrix-op-matrix.
                return this->writeBinaryExpression(rightType, lhsMatrix, op, rightType, rhs,
                                                   resultType, out);
            }
        } else {
            fContext.fErrors->error(leftType.fLine, "unsupported mixed-type expression");
            return -1;
        }
    } else {
        operandType = &this->getActualType(leftType);
        SkASSERT(operandType->matches(this->getActualType(rightType)));
    }
    // From here on, lhs and rhs share the type `*operandType`. The one exception is the
    // matrix-op-matrix case, where leftType and rightType are checked directly below.
    switch (op.kind()) {
        case Token::Kind::TK_EQEQ: {
            if (operandType->isMatrix()) {
                return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
                                                   SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
            }
            if (operandType->isStruct()) {
                return this->writeStructComparison(*operandType, lhs, op, rhs, out);
            }
            if (operandType->isArray()) {
                return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
            }
            SkASSERT(resultType.isBoolean());
            // A vector comparison produces a bool vector of the same shape. foldToBool then
            // reduces it to one bool with OpAll.
            const Type* tmpType;
            if (operandType->isVector()) {
                tmpType = &fContext.fTypes.fBool->toCompound(fContext,
                                                             operandType->columns(),
                                                             operandType->rows());
            } else {
                tmpType = &resultType;
            }
            return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
                                                               SpvOpFOrdEqual, SpvOpIEqual,
                                                               SpvOpIEqual, SpvOpLogicalEqual, out),
                                    *operandType, SpvOpAll, out);
        }
        case Token::Kind::TK_NEQ:
            if (operandType->isMatrix()) {
                return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
                                                   SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
            }
            if (operandType->isStruct()) {
                return this->writeStructComparison(*operandType, lhs, op, rhs, out);
            }
            if (operandType->isArray()) {
                return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
            }
            [[fallthrough]];
        case Token::Kind::TK_LOGICALXOR:
            // On bools, `^^` is emitted the same way as `!=` (OpLogicalNotEqual). Vector results
            // are reduced to one bool with OpAny.
            SkASSERT(resultType.isBoolean());
            const Type* tmpType;
            if (operandType->isVector()) {
                tmpType = &fContext.fTypes.fBool->toCompound(fContext,
                                                             operandType->columns(),
                                                             operandType->rows());
            } else {
                tmpType = &resultType;
            }
            return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
                                                               SpvOpFOrdNotEqual, SpvOpINotEqual,
                                                               SpvOpINotEqual, SpvOpLogicalNotEqual,
                                                               out),
                                    *operandType, SpvOpAny, out);
        // Relational operators: SpvOpUndef is passed as the bool variant because there is no
        // ordering comparison for bool operands.
        case Token::Kind::TK_GT:
            SkASSERT(resultType.isBoolean());
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
                                              SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
                                              SpvOpUGreaterThan, SpvOpUndef, out);
        case Token::Kind::TK_LT:
            SkASSERT(resultType.isBoolean());
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
                                              SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
        case Token::Kind::TK_GTEQ:
            SkASSERT(resultType.isBoolean());
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
                                              SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
                                              SpvOpUGreaterThanEqual, SpvOpUndef, out);
        case Token::Kind::TK_LTEQ:
            SkASSERT(resultType.isBoolean());
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
                                              SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
                                              SpvOpULessThanEqual, SpvOpUndef, out);
        case Token::Kind::TK_PLUS:
            if (leftType.isMatrix() && rightType.isMatrix()) {
                SkASSERT(leftType.matches(rightType));
                return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs, SpvOpFAdd, out);
            }
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
                                              SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
        case Token::Kind::TK_MINUS:
            if (leftType.isMatrix() && rightType.isMatrix()) {
                SkASSERT(leftType.matches(rightType));
                return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs, SpvOpFSub, out);
            }
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
                                              SpvOpISub, SpvOpISub, SpvOpUndef, out);
        case Token::Kind::TK_STAR:
            if (leftType.isMatrix() && rightType.isMatrix()) {
                // matrix multiply
                SpvId result = this->nextId(&resultType);
                this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
                                       lhs, rhs, out);
                return result;
            }
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
                                              SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
        case Token::Kind::TK_SLASH:
            if (leftType.isMatrix() && rightType.isMatrix()) {
                SkASSERT(leftType.matches(rightType));
                return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs, SpvOpFDiv, out);
            }
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
                                              SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
        case Token::Kind::TK_PERCENT:
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
                                              SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
        // Shift and bitwise operators are integer-only, so SpvOpUndef fills the float and bool
        // variants.
        case Token::Kind::TK_SHL:
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
                                              SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
                                              SpvOpUndef, out);
        case Token::Kind::TK_SHR:
            // Signed operands use an arithmetic (sign-extending) shift; unsigned use logical.
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
                                              SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
                                              SpvOpUndef, out);
        case Token::Kind::TK_BITWISEAND:
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
                                              SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
        case Token::Kind::TK_BITWISEOR:
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
                                              SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
        case Token::Kind::TK_BITWISEXOR:
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
                                              SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
        default:
            fContext.fErrors->error(0, "unsupported token");
            return -1;
    }
}
2571
writeArrayComparison(const Type & arrayType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)2572 SpvId SPIRVCodeGenerator::writeArrayComparison(const Type& arrayType, SpvId lhs, Operator op,
2573 SpvId rhs, OutputStream& out) {
2574 // The inputs must be arrays, and the op must be == or !=.
2575 SkASSERT(op.kind() == Token::Kind::TK_EQEQ || op.kind() == Token::Kind::TK_NEQ);
2576 SkASSERT(arrayType.isArray());
2577 const Type& componentType = arrayType.componentType();
2578 const SpvId componentTypeId = this->getType(componentType);
2579 const int arraySize = arrayType.columns();
2580 SkASSERT(arraySize > 0);
2581
2582 // Synthesize equality checks for each item in the array.
2583 const Type& boolType = *fContext.fTypes.fBool;
2584 SpvId allComparisons = (SpvId)-1;
2585 for (int index = 0; index < arraySize; ++index) {
2586 // Get the left and right item in the array.
2587 SpvId itemL = this->nextId(&componentType);
2588 this->writeInstruction(SpvOpCompositeExtract, componentTypeId, itemL, lhs, index, out);
2589 SpvId itemR = this->nextId(&componentType);
2590 this->writeInstruction(SpvOpCompositeExtract, componentTypeId, itemR, rhs, index, out);
2591 // Use `writeBinaryExpression` with the requested == or != operator on these items.
2592 SpvId comparison = this->writeBinaryExpression(componentType, itemL, op,
2593 componentType, itemR, boolType, out);
2594 // Merge this comparison result with all the other comparisons we've done.
2595 allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
2596 }
2597 return allComparisons;
2598 }
2599
writeStructComparison(const Type & structType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)2600 SpvId SPIRVCodeGenerator::writeStructComparison(const Type& structType, SpvId lhs, Operator op,
2601 SpvId rhs, OutputStream& out) {
2602 // The inputs must be structs containing fields, and the op must be == or !=.
2603 SkASSERT(op.kind() == Token::Kind::TK_EQEQ || op.kind() == Token::Kind::TK_NEQ);
2604 SkASSERT(structType.isStruct());
2605 const std::vector<Type::Field>& fields = structType.fields();
2606 SkASSERT(!fields.empty());
2607
2608 // Synthesize equality checks for each field in the struct.
2609 const Type& boolType = *fContext.fTypes.fBool;
2610 SpvId allComparisons = (SpvId)-1;
2611 for (int index = 0; index < (int)fields.size(); ++index) {
2612 // Get the left and right versions of this field.
2613 const Type& fieldType = *fields[index].fType;
2614 const SpvId fieldTypeId = this->getType(fieldType);
2615
2616 SpvId fieldL = this->nextId(&fieldType);
2617 this->writeInstruction(SpvOpCompositeExtract, fieldTypeId, fieldL, lhs, index, out);
2618 SpvId fieldR = this->nextId(&fieldType);
2619 this->writeInstruction(SpvOpCompositeExtract, fieldTypeId, fieldR, rhs, index, out);
2620 // Use `writeBinaryExpression` with the requested == or != operator on these fields.
2621 SpvId comparison = this->writeBinaryExpression(fieldType, fieldL, op, fieldType, fieldR,
2622 boolType, out);
2623 // Merge this comparison result with all the other comparisons we've done.
2624 allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
2625 }
2626 return allComparisons;
2627 }
2628
mergeComparisons(SpvId comparison,SpvId allComparisons,Operator op,OutputStream & out)2629 SpvId SPIRVCodeGenerator::mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op,
2630 OutputStream& out) {
2631 // If this is the first entry, we don't need to merge comparison results with anything.
2632 if (allComparisons == (SpvId)-1) {
2633 return comparison;
2634 }
2635 // Use LogicalAnd or LogicalOr to combine the comparison with all the other comparisons.
2636 const Type& boolType = *fContext.fTypes.fBool;
2637 SpvId boolTypeId = this->getType(boolType);
2638 SpvId logicalOp = this->nextId(&boolType);
2639 switch (op.kind()) {
2640 case Token::Kind::TK_EQEQ:
2641 this->writeInstruction(SpvOpLogicalAnd, boolTypeId, logicalOp,
2642 comparison, allComparisons, out);
2643 break;
2644 case Token::Kind::TK_NEQ:
2645 this->writeInstruction(SpvOpLogicalOr, boolTypeId, logicalOp,
2646 comparison, allComparisons, out);
2647 break;
2648 default:
2649 SkDEBUGFAILF("mergeComparisons only supports == and !=, not %s", op.operatorName());
2650 return (SpvId)-1;
2651 }
2652 return logicalOp;
2653 }
2654
division_by_literal_value(Operator op,const Expression & right)2655 static float division_by_literal_value(Operator op, const Expression& right) {
2656 // If this is a division by a literal value, returns that literal value. Otherwise, returns 0.
2657 if (op.kind() == Token::Kind::TK_SLASH && right.isFloatLiteral()) {
2658 float rhsValue = right.as<Literal>().floatValue();
2659 if (std::isfinite(rhsValue)) {
2660 return rhsValue;
2661 }
2662 }
2663 return 0.0f;
2664 }
2665
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)2666 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
2667 const Expression* left = b.left().get();
2668 const Expression* right = b.right().get();
2669 Operator op = b.getOperator();
2670
2671 switch (op.kind()) {
2672 case Token::Kind::TK_EQ: {
2673 // Handles assignment.
2674 SpvId rhs = this->writeExpression(*right, out);
2675 this->getLValue(*left, out)->store(rhs, out);
2676 return rhs;
2677 }
2678 case Token::Kind::TK_LOGICALAND:
2679 // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
2680 return this->writeLogicalAnd(*b.left(), *b.right(), out);
2681
2682 case Token::Kind::TK_LOGICALOR:
2683 // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
2684 return this->writeLogicalOr(*b.left(), *b.right(), out);
2685
2686 default:
2687 break;
2688 }
2689
2690 std::unique_ptr<LValue> lvalue;
2691 SpvId lhs;
2692 if (op.isAssignment()) {
2693 lvalue = this->getLValue(*left, out);
2694 lhs = lvalue->load(out);
2695 } else {
2696 lvalue = nullptr;
2697 lhs = this->writeExpression(*left, out);
2698 }
2699
2700 SpvId rhs;
2701 float rhsValue = division_by_literal_value(op, *right);
2702 if (rhsValue != 0.0f) {
2703 // Rewrite floating-point division by a literal into multiplication by the reciprocal.
2704 // This converts `expr / 2` into `expr * 0.5`
2705 // This improves codegen, especially for certain types of divides (e.g. vector/scalar).
2706 op = Operator(Token::Kind::TK_STAR);
2707 rhs = this->writeLiteral(1.0 / rhsValue, right->type());
2708 } else {
2709 // Write the right-hand side expression normally.
2710 rhs = this->writeExpression(*right, out);
2711 }
2712
2713 SpvId result = this->writeBinaryExpression(left->type(), lhs, op.removeAssignment(),
2714 right->type(), rhs, b.type(), out);
2715 if (lvalue) {
2716 lvalue->store(result, out);
2717 }
2718 return result;
2719 }
2720
writeLogicalAnd(const Expression & left,const Expression & right,OutputStream & out)2721 SpvId SPIRVCodeGenerator::writeLogicalAnd(const Expression& left, const Expression& right,
2722 OutputStream& out) {
2723 SpvId falseConstant = this->writeLiteral(0.0, *fContext.fTypes.fBool);
2724 SpvId lhs = this->writeExpression(left, out);
2725 SpvId rhsLabel = this->nextId(nullptr);
2726 SpvId end = this->nextId(nullptr);
2727 SpvId lhsBlock = fCurrentBlock;
2728 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2729 this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
2730 this->writeLabel(rhsLabel, out);
2731 SpvId rhs = this->writeExpression(right, out);
2732 SpvId rhsBlock = fCurrentBlock;
2733 this->writeInstruction(SpvOpBranch, end, out);
2734 this->writeLabel(end, out);
2735 SpvId result = this->nextId(nullptr);
2736 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, falseConstant,
2737 lhsBlock, rhs, rhsBlock, out);
2738 return result;
2739 }
2740
writeLogicalOr(const Expression & left,const Expression & right,OutputStream & out)2741 SpvId SPIRVCodeGenerator::writeLogicalOr(const Expression& left, const Expression& right,
2742 OutputStream& out) {
2743 SpvId trueConstant = this->writeLiteral(1.0, *fContext.fTypes.fBool);
2744 SpvId lhs = this->writeExpression(left, out);
2745 SpvId rhsLabel = this->nextId(nullptr);
2746 SpvId end = this->nextId(nullptr);
2747 SpvId lhsBlock = fCurrentBlock;
2748 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2749 this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
2750 this->writeLabel(rhsLabel, out);
2751 SpvId rhs = this->writeExpression(right, out);
2752 SpvId rhsBlock = fCurrentBlock;
2753 this->writeInstruction(SpvOpBranch, end, out);
2754 this->writeLabel(end, out);
2755 SpvId result = this->nextId(nullptr);
2756 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, trueConstant,
2757 lhsBlock, rhs, rhsBlock, out);
2758 return result;
2759 }
2760
writeTernaryExpression(const TernaryExpression & t,OutputStream & out)2761 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
2762 const Type& type = t.type();
2763 SpvId test = this->writeExpression(*t.test(), out);
2764 if (t.ifTrue()->type().columns() == 1 &&
2765 t.ifTrue()->isCompileTimeConstant() &&
2766 t.ifFalse()->isCompileTimeConstant()) {
2767 // both true and false are constants, can just use OpSelect
2768 SpvId result = this->nextId(nullptr);
2769 SpvId trueId = this->writeExpression(*t.ifTrue(), out);
2770 SpvId falseId = this->writeExpression(*t.ifFalse(), out);
2771 this->writeInstruction(SpvOpSelect, this->getType(type), result, test, trueId, falseId,
2772 out);
2773 return result;
2774 }
2775 // was originally using OpPhi to choose the result, but for some reason that is crashing on
2776 // Adreno. Switched to storing the result in a temp variable as glslang does.
2777 SpvId var = this->nextId(nullptr);
2778 this->writeInstruction(SpvOpVariable, this->getPointerType(type, SpvStorageClassFunction),
2779 var, SpvStorageClassFunction, fVariableBuffer);
2780 SpvId trueLabel = this->nextId(nullptr);
2781 SpvId falseLabel = this->nextId(nullptr);
2782 SpvId end = this->nextId(nullptr);
2783 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2784 this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
2785 this->writeLabel(trueLabel, out);
2786 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.ifTrue(), out), out);
2787 this->writeInstruction(SpvOpBranch, end, out);
2788 this->writeLabel(falseLabel, out);
2789 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.ifFalse(), out), out);
2790 this->writeInstruction(SpvOpBranch, end, out);
2791 this->writeLabel(end, out);
2792 SpvId result = this->nextId(&type);
2793 this->writeInstruction(SpvOpLoad, this->getType(type), result, var, out);
2794 return result;
2795 }
2796
writePrefixExpression(const PrefixExpression & p,OutputStream & out)2797 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
2798 const Type& type = p.type();
2799 if (p.getOperator().kind() == Token::Kind::TK_MINUS) {
2800 SpvOp_ negateOp = SpvOpFNegate;
2801 if (is_float(fContext, type)) {
2802 negateOp = SpvOpFNegate;
2803 } else if (is_signed(fContext, type) || is_unsigned(fContext, type)) {
2804 negateOp = SpvOpSNegate;
2805 } else {
2806 SkDEBUGFAILF("unsupported prefix expression %s", p.description().c_str());
2807 }
2808
2809 SpvId expr = this->writeExpression(*p.operand(), out);
2810 if (type.isMatrix()) {
2811 return this->writeComponentwiseMatrixUnary(type, expr, negateOp, out);
2812 }
2813 SpvId result = this->nextId(&type);
2814 SpvId typeId = this->getType(type);
2815 this->writeInstruction(negateOp, typeId, result, expr, out);
2816 return result;
2817 }
2818 switch (p.getOperator().kind()) {
2819 case Token::Kind::TK_PLUS:
2820 return this->writeExpression(*p.operand(), out);
2821 case Token::Kind::TK_PLUSPLUS: {
2822 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
2823 SpvId one = this->writeLiteral(1.0, type);
2824 SpvId result = this->writeBinaryOperation(type, type, lv->load(out), one,
2825 SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
2826 out);
2827 lv->store(result, out);
2828 return result;
2829 }
2830 case Token::Kind::TK_MINUSMINUS: {
2831 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
2832 SpvId one = this->writeLiteral(1.0, type);
2833 SpvId result = this->writeBinaryOperation(type, type, lv->load(out), one, SpvOpFSub,
2834 SpvOpISub, SpvOpISub, SpvOpUndef, out);
2835 lv->store(result, out);
2836 return result;
2837 }
2838 case Token::Kind::TK_LOGICALNOT: {
2839 SkASSERT(p.operand()->type().isBoolean());
2840 SpvId result = this->nextId(nullptr);
2841 this->writeInstruction(SpvOpLogicalNot, this->getType(type), result,
2842 this->writeExpression(*p.operand(), out), out);
2843 return result;
2844 }
2845 case Token::Kind::TK_BITWISENOT: {
2846 SpvId result = this->nextId(nullptr);
2847 this->writeInstruction(SpvOpNot, this->getType(type), result,
2848 this->writeExpression(*p.operand(), out), out);
2849 return result;
2850 }
2851 default:
2852 SkDEBUGFAILF("unsupported prefix expression: %s", p.description().c_str());
2853 return -1;
2854 }
2855 }
2856
writePostfixExpression(const PostfixExpression & p,OutputStream & out)2857 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
2858 const Type& type = p.type();
2859 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
2860 SpvId result = lv->load(out);
2861 SpvId one = this->writeLiteral(1.0, type);
2862 switch (p.getOperator().kind()) {
2863 case Token::Kind::TK_PLUSPLUS: {
2864 SpvId temp = this->writeBinaryOperation(type, type, result, one, SpvOpFAdd,
2865 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2866 lv->store(temp, out);
2867 return result;
2868 }
2869 case Token::Kind::TK_MINUSMINUS: {
2870 SpvId temp = this->writeBinaryOperation(type, type, result, one, SpvOpFSub,
2871 SpvOpISub, SpvOpISub, SpvOpUndef, out);
2872 lv->store(temp, out);
2873 return result;
2874 }
2875 default:
2876 SkDEBUGFAILF("unsupported postfix expression %s", p.description().c_str());
2877 return -1;
2878 }
2879 }
2880
writeLiteral(const Literal & l)2881 SpvId SPIRVCodeGenerator::writeLiteral(const Literal& l) {
2882 return this->writeLiteral(l.value(), l.type());
2883 }
2884
writeLiteral(double value,const Type & type)2885 SpvId SPIRVCodeGenerator::writeLiteral(double value, const Type& type) {
2886 int32_t valueBits;
2887 if (type.isFloat()) {
2888 float fValue = value;
2889 memcpy(&valueBits, &fValue, sizeof(valueBits));
2890 } else {
2891 SKSL_INT iValue = value;
2892 valueBits = iValue;
2893 }
2894
2895 SPIRVNumberConstant key{valueBits, type.numberKind()};
2896 auto [iter, newlyCreated] = fNumberConstants.insert({key, (SpvId)-1});
2897 if (newlyCreated) {
2898 SpvId result = this->nextId(nullptr);
2899 iter->second = result;
2900
2901 if (type.isBoolean()) {
2902 this->writeInstruction(valueBits ? SpvOpConstantTrue : SpvOpConstantFalse,
2903 this->getType(type), result, fConstantBuffer);
2904 } else {
2905 this->writeInstruction(SpvOpConstant, this->getType(type), result,
2906 (SpvId)valueBits, fConstantBuffer);
2907 }
2908 }
2909
2910 return iter->second;
2911 }
2912
writeFunctionStart(const FunctionDeclaration & f,OutputStream & out)2913 SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
2914 SpvId result = fFunctionMap[&f];
2915 SpvId returnTypeId = this->getType(f.returnType());
2916 SpvId functionTypeId = this->getFunctionType(f);
2917 this->writeInstruction(SpvOpFunction, returnTypeId, result,
2918 SpvFunctionControlMaskNone, functionTypeId, out);
2919 std::string mangledName = f.mangledName();
2920 this->writeInstruction(SpvOpName,
2921 result,
2922 std::string_view(mangledName.c_str(), mangledName.size()),
2923 fNameBuffer);
2924 for (const Variable* parameter : f.parameters()) {
2925 SpvId id = this->nextId(nullptr);
2926 fVariableMap[parameter] = id;
2927 SpvId type = this->getPointerType(parameter->type(), SpvStorageClassFunction);
2928 this->writeInstruction(SpvOpFunctionParameter, type, id, out);
2929 }
2930 return result;
2931 }
2932
writeFunction(const FunctionDefinition & f,OutputStream & out)2933 SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
2934 fVariableBuffer.reset();
2935 SpvId result = this->writeFunctionStart(f.declaration(), out);
2936 fCurrentBlock = 0;
2937 this->writeLabel(this->nextId(nullptr), out);
2938 StringStream bodyBuffer;
2939 this->writeBlock(f.body()->as<Block>(), bodyBuffer);
2940 write_stringstream(fVariableBuffer, out);
2941 if (f.declaration().isMain()) {
2942 write_stringstream(fGlobalInitializersBuffer, out);
2943 }
2944 write_stringstream(bodyBuffer, out);
2945 if (fCurrentBlock) {
2946 if (f.declaration().returnType().isVoid()) {
2947 this->writeInstruction(SpvOpReturn, out);
2948 } else {
2949 this->writeInstruction(SpvOpUnreachable, out);
2950 }
2951 }
2952 this->writeInstruction(SpvOpFunctionEnd, out);
2953 return result;
2954 }
2955
writeLayout(const Layout & layout,SpvId target,int line)2956 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int line) {
2957 bool isPushConstant = (layout.fFlags & Layout::kPushConstant_Flag);
2958 if (layout.fLocation >= 0) {
2959 this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
2960 fDecorationBuffer);
2961 }
2962 if (layout.fBinding >= 0) {
2963 if (isPushConstant) {
2964 fContext.fErrors->error(line, "Can't apply 'binding' to push constants");
2965 } else {
2966 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
2967 fDecorationBuffer);
2968 }
2969 }
2970 if (layout.fIndex >= 0) {
2971 this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
2972 fDecorationBuffer);
2973 }
2974 if (layout.fSet >= 0) {
2975 if (isPushConstant) {
2976 fContext.fErrors->error(line, "Can't apply 'set' to push constants");
2977 } else {
2978 this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
2979 fDecorationBuffer);
2980 }
2981 }
2982 if (layout.fInputAttachmentIndex >= 0) {
2983 this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
2984 layout.fInputAttachmentIndex, fDecorationBuffer);
2985 fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
2986 }
2987 if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN) {
2988 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
2989 fDecorationBuffer);
2990 }
2991 }
2992
writeFieldLayout(const Layout & layout,SpvId target,int member)2993 void SPIRVCodeGenerator::writeFieldLayout(const Layout& layout, SpvId target, int member) {
2994 // 'binding' and 'set' can not be applied to struct members
2995 SkASSERT(layout.fBinding == -1);
2996 SkASSERT(layout.fSet == -1);
2997 if (layout.fLocation >= 0) {
2998 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
2999 layout.fLocation, fDecorationBuffer);
3000 }
3001 if (layout.fIndex >= 0) {
3002 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
3003 layout.fIndex, fDecorationBuffer);
3004 }
3005 if (layout.fInputAttachmentIndex >= 0) {
3006 this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
3007 layout.fInputAttachmentIndex, fDecorationBuffer);
3008 }
3009 if (layout.fBuiltin >= 0) {
3010 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
3011 layout.fBuiltin, fDecorationBuffer);
3012 }
3013 }
3014
memoryLayoutForVariable(const Variable & v) const3015 MemoryLayout SPIRVCodeGenerator::memoryLayoutForVariable(const Variable& v) const {
3016 bool pushConstant = ((v.modifiers().fLayout.fFlags & Layout::kPushConstant_Flag) != 0);
3017 return pushConstant ? MemoryLayout(MemoryLayout::k430_Standard) : fDefaultLayout;
3018 }
3019
// Emits the SPIR-V declaration for an interface block and returns the id of its OpVariable.
// The output consists of:
//  - a pointer type to the block's struct type,
//  - an OpVariable of that pointer type (both in fConstantBuffer),
//  - a Block decoration on the struct type (only for non-builtin blocks),
//  - the variable's layout decorations.
// The id is also recorded in fVariableMap for the block's variable.
//
// Special case: the RT-flip path is taken when all of these hold:
//  - the program uses the RT-flip uniform (fUseFlipRTUniform),
//  - `appendRTFlip` is true,
//  - the block's type is a struct.
// In that case the block is not written as-is. Instead, a copy of its struct is made with an
// extra float2 member (SKSL_RTFLIP_NAME, at fRTFlipOffset), and the copy is written through a
// recursive call with appendRTFlip=false. fWroteRTFlip is then set.
SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip) {
    // NOTE(review): `memoryLayout` and `result` are computed before the early-exit paths below.
    // The unsupported-type path and the RT-flip path never use the id allocated here.
    MemoryLayout memoryLayout = this->memoryLayoutForVariable(intf.variable());
    SpvId result = this->nextId(nullptr);
    const Variable& intfVar = intf.variable();
    const Type& type = intfVar.type();
    if (!MemoryLayout::LayoutIsSupported(type)) {
        // Report the error, then hand back a fresh id that nothing else refers to.
        fContext.fErrors->error(type.fLine, "type '" + type.displayName() +
                "' is not permitted here");
        return this->nextId(nullptr);
    }
    SpvStorageClass_ storageClass = get_storage_class(intf.variable(), SpvStorageClassFunction);
    if (fProgram.fInputs.fUseFlipRTUniform && appendRTFlip && type.isStruct()) {
        // We can only have one interface block (because we use push_constant and that is limited
        // to one per program), so we need to append rtflip to this one rather than synthesize an
        // entirely new block when the variable is referenced. And we can't modify the existing
        // block, so we instead create a modified copy of it and write that.
        std::vector<Type::Field> fields = type.fields();
        fields.emplace_back(Modifiers(Layout(/*flags=*/0,
                                             /*location=*/-1,
                                             fProgram.fConfig->fSettings.fRTFlipOffset,
                                             /*binding=*/-1,
                                             /*index=*/-1,
                                             /*set=*/-1,
                                             /*builtin=*/-1,
                                             /*inputAttachmentIndex=*/-1),
                                      /*flags=*/0),
                            SKSL_RTFLIP_NAME,
                            fContext.fTypes.fFloat2.get());
        {
            AutoAttachPoolToThread attach(fProgram.fPool.get());
            // The modified struct type and its variable are owned by the program's symbol table.
            const Type* rtFlipStructType =
                    fProgram.fSymbols->takeOwnershipOfSymbol(Type::MakeStructType(
                            type.fLine, type.name(), std::move(fields), /*interfaceBlock=*/true));
            const Variable* modifiedVar = fProgram.fSymbols->takeOwnershipOfSymbol(
                    std::make_unique<Variable>(intfVar.fLine,
                                               &intfVar.modifiers(),
                                               intfVar.name(),
                                               rtFlipStructType,
                                               intfVar.isBuiltin(),
                                               intfVar.storage()));
            // The synthesized variable is unknown to ProgramUsage. Recording it here makes
            // isDead() report it as live.
            fSPIRVBonusVariables.insert(modifiedVar);
            InterfaceBlock modifiedCopy(intf.fLine,
                                        *modifiedVar,
                                        intf.typeName(),
                                        intf.instanceName(),
                                        intf.arraySize(),
                                        intf.typeOwner());
            result = this->writeInterfaceBlock(modifiedCopy, false);
            // Register a Field symbol for the appended rtflip member, which is the last field.
            fProgram.fSymbols->add(std::make_unique<Field>(
                    /*line=*/-1, modifiedVar, rtFlipStructType->fields().size() - 1));
        }
        // The original variable resolves to the modified copy's OpVariable.
        fVariableMap[&intfVar] = result;
        fWroteRTFlip = true;
        return result;
    }
    const Modifiers& intfModifiers = intfVar.modifiers();
    SpvId typeId = this->getType(type, memoryLayout);
    // Builtin blocks do not get the Block decoration.
    if (intfModifiers.fLayout.fBuiltin == -1) {
        this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
    }
    SpvId ptrType = this->nextId(nullptr);
    this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
    this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
    // Uniform blocks without an explicit set fall back to the configured default set.
    Layout layout = intfModifiers.fLayout;
    if (storageClass == SpvStorageClassUniform && layout.fSet < 0) {
        layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
    }
    this->writeLayout(layout, result, intfVar.fLine);
    fVariableMap[&intfVar] = result;
    return result;
}
3091
isDead(const Variable & var) const3092 bool SPIRVCodeGenerator::isDead(const Variable& var) const {
3093 // During SPIR-V code generation, we synthesize some extra bonus variables that don't actually
3094 // exist in the Program at all and aren't tracked by the ProgramUsage. They aren't dead, though.
3095 if (fSPIRVBonusVariables.count(&var)) {
3096 return false;
3097 }
3098 ProgramUsage::VariableCounts counts = fProgram.usage()->get(var);
3099 if (counts.fRead || counts.fWrite) {
3100 return false;
3101 }
3102 // It's not entirely clear what the rules are for eliding interface variables. Generally, it
3103 // causes problems to elide them, even when they're dead.
3104 return !(var.modifiers().fFlags &
3105 (Modifiers::kIn_Flag | Modifiers::kOut_Flag | Modifiers::kUniform_Flag));
3106 }
3107
// Emits a module-scope (global) variable declaration.
//  - sk_FragColor is skipped entirely when `kind` is not a fragment program.
//  - Variables that isDead() reports as dead are skipped.
//  - Uniform-storage globals are not emitted here. They are queued in fTopLevelUniforms and
//    later packed into one block by writeUniformBuffer.
//  - For all other globals, this writes:
//      - an OpVariable (fConstantBuffer) and an OpName (fNameBuffer),
//      - any initializer, into fGlobalInitializersBuffer (writeFunction splices that buffer
//        into main()),
//      - layout, flat and noperspective decorations.
void SPIRVCodeGenerator::writeGlobalVar(ProgramKind kind, const VarDeclaration& varDecl) {
    const Variable& var = varDecl.var();
    if (var.modifiers().fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
        kind != ProgramKind::kFragment) {
        SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
        return;
    }
    if (this->isDead(var)) {
        return;
    }
    SpvStorageClass_ storageClass = get_storage_class(var, SpvStorageClassPrivate);
    if (storageClass == SpvStorageClassUniform) {
        // Top-level uniforms are emitted in writeUniformBuffer.
        fTopLevelUniforms.push_back(&varDecl);
        return;
    }
    // Add this global to the variable map.
    const Type& type = var.type();
    SpvId id = this->nextId(&type);
    fVariableMap[&var] = id;
    // NOTE(review): this check runs after the id is allocated and mapped, so the error path
    // leaves a mapped id that has no OpVariable.
    if (var.modifiers().fLayout.fBuiltin == SK_SECONDARYFRAGCOLOR_BUILTIN) {
        // sk_SecondaryFragColor corresponds to gl_SecondaryFragColorEXT, which isn't supposed to
        // appear in a SPIR-V program (it's only valid in ES2). Report an error.
        fContext.fErrors->error(varDecl.fLine, "sk_SecondaryFragColor is not allowed in SPIR-V");
        return;
    }
    // UniformConstant globals without an explicit set fall back to the default uniform set.
    Layout layout = var.modifiers().fLayout;
    if (layout.fSet < 0 && storageClass == SpvStorageClassUniformConstant) {
        layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
    }
    SpvId typeId = this->getPointerType(type, storageClass);
    this->writeInstruction(SpvOpVariable, typeId, id, storageClass, fConstantBuffer);
    this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer);
    if (varDecl.value()) {
        // The initializer is written outside any function. While it is written, fCurrentBlock
        // is set to the nonzero sentinel -1, and it is reset to 0 afterwards. Presumably this
        // makes the expression writers treat the code as reachable — TODO confirm.
        SkASSERT(!fCurrentBlock);
        fCurrentBlock = -1;
        SpvId value = this->writeExpression(*varDecl.value(), fGlobalInitializersBuffer);
        this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
        fCurrentBlock = 0;
    }
    this->writeLayout(layout, id, var.fLine);
    if (var.modifiers().fFlags & Modifiers::kFlat_Flag) {
        this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
    }
    if (var.modifiers().fFlags & Modifiers::kNoPerspective_Flag) {
        this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
                               fDecorationBuffer);
    }
}
3157
writeVarDeclaration(const VarDeclaration & varDecl,OutputStream & out)3158 void SPIRVCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl, OutputStream& out) {
3159 const Variable& var = varDecl.var();
3160 SpvId id = this->nextId(&var.type());
3161 fVariableMap[&var] = id;
3162 SpvId type = this->getPointerType(var.type(), SpvStorageClassFunction);
3163 this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
3164 this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer);
3165 if (varDecl.value()) {
3166 SpvId value = this->writeExpression(*varDecl.value(), out);
3167 this->writeInstruction(SpvOpStore, id, value, out);
3168 }
3169 }
3170
writeStatement(const Statement & s,OutputStream & out)3171 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
3172 switch (s.kind()) {
3173 case Statement::Kind::kInlineMarker:
3174 case Statement::Kind::kNop:
3175 break;
3176 case Statement::Kind::kBlock:
3177 this->writeBlock(s.as<Block>(), out);
3178 break;
3179 case Statement::Kind::kExpression:
3180 this->writeExpression(*s.as<ExpressionStatement>().expression(), out);
3181 break;
3182 case Statement::Kind::kReturn:
3183 this->writeReturnStatement(s.as<ReturnStatement>(), out);
3184 break;
3185 case Statement::Kind::kVarDeclaration:
3186 this->writeVarDeclaration(s.as<VarDeclaration>(), out);
3187 break;
3188 case Statement::Kind::kIf:
3189 this->writeIfStatement(s.as<IfStatement>(), out);
3190 break;
3191 case Statement::Kind::kFor:
3192 this->writeForStatement(s.as<ForStatement>(), out);
3193 break;
3194 case Statement::Kind::kDo:
3195 this->writeDoStatement(s.as<DoStatement>(), out);
3196 break;
3197 case Statement::Kind::kSwitch:
3198 this->writeSwitchStatement(s.as<SwitchStatement>(), out);
3199 break;
3200 case Statement::Kind::kBreak:
3201 this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
3202 break;
3203 case Statement::Kind::kContinue:
3204 this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
3205 break;
3206 case Statement::Kind::kDiscard:
3207 this->writeInstruction(SpvOpKill, out);
3208 break;
3209 default:
3210 SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
3211 break;
3212 }
3213 }
3214
writeBlock(const Block & b,OutputStream & out)3215 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
3216 for (const std::unique_ptr<Statement>& stmt : b.children()) {
3217 this->writeStatement(*stmt, out);
3218 }
3219 }
3220
writeIfStatement(const IfStatement & stmt,OutputStream & out)3221 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
3222 SpvId test = this->writeExpression(*stmt.test(), out);
3223 SpvId ifTrue = this->nextId(nullptr);
3224 SpvId ifFalse = this->nextId(nullptr);
3225 if (stmt.ifFalse()) {
3226 SpvId end = this->nextId(nullptr);
3227 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
3228 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
3229 this->writeLabel(ifTrue, out);
3230 this->writeStatement(*stmt.ifTrue(), out);
3231 if (fCurrentBlock) {
3232 this->writeInstruction(SpvOpBranch, end, out);
3233 }
3234 this->writeLabel(ifFalse, out);
3235 this->writeStatement(*stmt.ifFalse(), out);
3236 if (fCurrentBlock) {
3237 this->writeInstruction(SpvOpBranch, end, out);
3238 }
3239 this->writeLabel(end, out);
3240 } else {
3241 this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
3242 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
3243 this->writeLabel(ifTrue, out);
3244 this->writeStatement(*stmt.ifTrue(), out);
3245 if (fCurrentBlock) {
3246 this->writeInstruction(SpvOpBranch, ifFalse, out);
3247 }
3248 this->writeLabel(ifFalse, out);
3249 }
3250 }
3251
writeForStatement(const ForStatement & f,OutputStream & out)3252 void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
3253 if (f.initializer()) {
3254 this->writeStatement(*f.initializer(), out);
3255 }
3256 SpvId header = this->nextId(nullptr);
3257 SpvId start = this->nextId(nullptr);
3258 SpvId body = this->nextId(nullptr);
3259 SpvId next = this->nextId(nullptr);
3260 fContinueTarget.push(next);
3261 SpvId end = this->nextId(nullptr);
3262 fBreakTarget.push(end);
3263 this->writeInstruction(SpvOpBranch, header, out);
3264 this->writeLabel(header, out);
3265 this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
3266 this->writeInstruction(SpvOpBranch, start, out);
3267 this->writeLabel(start, out);
3268 if (f.test()) {
3269 SpvId test = this->writeExpression(*f.test(), out);
3270 this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
3271 } else {
3272 this->writeInstruction(SpvOpBranch, body, out);
3273 }
3274 this->writeLabel(body, out);
3275 this->writeStatement(*f.statement(), out);
3276 if (fCurrentBlock) {
3277 this->writeInstruction(SpvOpBranch, next, out);
3278 }
3279 this->writeLabel(next, out);
3280 if (f.next()) {
3281 this->writeExpression(*f.next(), out);
3282 }
3283 this->writeInstruction(SpvOpBranch, header, out);
3284 this->writeLabel(end, out);
3285 fBreakTarget.pop();
3286 fContinueTarget.pop();
3287 }
3288
writeDoStatement(const DoStatement & d,OutputStream & out)3289 void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
3290 SpvId header = this->nextId(nullptr);
3291 SpvId start = this->nextId(nullptr);
3292 SpvId next = this->nextId(nullptr);
3293 SpvId continueTarget = this->nextId(nullptr);
3294 fContinueTarget.push(continueTarget);
3295 SpvId end = this->nextId(nullptr);
3296 fBreakTarget.push(end);
3297 this->writeInstruction(SpvOpBranch, header, out);
3298 this->writeLabel(header, out);
3299 this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out);
3300 this->writeInstruction(SpvOpBranch, start, out);
3301 this->writeLabel(start, out);
3302 this->writeStatement(*d.statement(), out);
3303 if (fCurrentBlock) {
3304 this->writeInstruction(SpvOpBranch, next, out);
3305 }
3306 this->writeLabel(next, out);
3307 this->writeInstruction(SpvOpBranch, continueTarget, out);
3308 this->writeLabel(continueTarget, out);
3309 SpvId test = this->writeExpression(*d.test(), out);
3310 this->writeInstruction(SpvOpBranchConditional, test, header, end, out);
3311 this->writeLabel(end, out);
3312 fBreakTarget.pop();
3313 fContinueTarget.pop();
3314 }
3315
writeSwitchStatement(const SwitchStatement & s,OutputStream & out)3316 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
3317 SpvId value = this->writeExpression(*s.value(), out);
3318 std::vector<SpvId> labels;
3319 SpvId end = this->nextId(nullptr);
3320 SpvId defaultLabel = end;
3321 fBreakTarget.push(end);
3322 int size = 3;
3323 auto& cases = s.cases();
3324 for (const std::unique_ptr<Statement>& stmt : cases) {
3325 const SwitchCase& c = stmt->as<SwitchCase>();
3326 SpvId label = this->nextId(nullptr);
3327 labels.push_back(label);
3328 if (!c.isDefault()) {
3329 size += 2;
3330 } else {
3331 defaultLabel = label;
3332 }
3333 }
3334 labels.push_back(end);
3335 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
3336 this->writeOpCode(SpvOpSwitch, size, out);
3337 this->writeWord(value, out);
3338 this->writeWord(defaultLabel, out);
3339 for (size_t i = 0; i < cases.size(); ++i) {
3340 const SwitchCase& c = cases[i]->as<SwitchCase>();
3341 if (c.isDefault()) {
3342 continue;
3343 }
3344 this->writeWord(c.value(), out);
3345 this->writeWord(labels[i], out);
3346 }
3347 for (size_t i = 0; i < cases.size(); ++i) {
3348 const SwitchCase& c = cases[i]->as<SwitchCase>();
3349 this->writeLabel(labels[i], out);
3350 this->writeStatement(*c.statement(), out);
3351 if (fCurrentBlock) {
3352 this->writeInstruction(SpvOpBranch, labels[i + 1], out);
3353 }
3354 }
3355 this->writeLabel(end, out);
3356 fBreakTarget.pop();
3357 }
3358
writeReturnStatement(const ReturnStatement & r,OutputStream & out)3359 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
3360 if (r.expression()) {
3361 this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.expression(), out),
3362 out);
3363 } else {
3364 this->writeInstruction(SpvOpReturn, out);
3365 }
3366 }
3367
3368 // Given any function, returns the top-level symbol table (OUTSIDE of the function's scope).
get_top_level_symbol_table(const FunctionDeclaration & anyFunc)3369 static std::shared_ptr<SymbolTable> get_top_level_symbol_table(const FunctionDeclaration& anyFunc) {
3370 return anyFunc.definition()->body()->as<Block>().symbolTable()->fParent;
3371 }
3372
writeEntrypointAdapter(const FunctionDeclaration & main)3373 SPIRVCodeGenerator::EntrypointAdapter SPIRVCodeGenerator::writeEntrypointAdapter(
3374 const FunctionDeclaration& main) {
3375 // Our goal is to synthesize a tiny helper function which looks like this:
3376 // void _entrypoint() { sk_FragColor = main(); }
3377
3378 // Fish a symbol table out of main().
3379 std::shared_ptr<SymbolTable> symbolTable = get_top_level_symbol_table(main);
3380
3381 // Get `sk_FragColor` as a writable reference.
3382 const Symbol* skFragColorSymbol = (*symbolTable)["sk_FragColor"];
3383 SkASSERT(skFragColorSymbol);
3384 const Variable& skFragColorVar = skFragColorSymbol->as<Variable>();
3385 auto skFragColorRef = std::make_unique<VariableReference>(/*line=*/-1, &skFragColorVar,
3386 VariableReference::RefKind::kWrite);
3387 // Synthesize a call to the `main()` function.
3388 if (!main.returnType().matches(skFragColorRef->type())) {
3389 fContext.fErrors->error(main.fLine, "SPIR-V does not support returning '" +
3390 main.returnType().description() + "' from main()");
3391 return {};
3392 }
3393 ExpressionArray args;
3394 if (main.parameters().size() == 1) {
3395 if (!main.parameters()[0]->type().matches(*fContext.fTypes.fFloat2)) {
3396 fContext.fErrors->error(main.fLine,
3397 "SPIR-V does not support parameter of type '" +
3398 main.parameters()[0]->type().description() + "' to main()");
3399 return {};
3400 }
3401 args.push_back(dsl::Float2(0).release());
3402 }
3403 auto callMainFn = std::make_unique<FunctionCall>(/*line=*/-1, &main.returnType(), &main,
3404 std::move(args));
3405
3406 // Synthesize `skFragColor = main()` as a BinaryExpression.
3407 auto assignmentStmt = std::make_unique<ExpressionStatement>(std::make_unique<BinaryExpression>(
3408 /*line=*/-1,
3409 std::move(skFragColorRef),
3410 Token::Kind::TK_EQ,
3411 std::move(callMainFn),
3412 &main.returnType()));
3413
3414 // Function bodies are always wrapped in a Block.
3415 StatementArray entrypointStmts;
3416 entrypointStmts.push_back(std::move(assignmentStmt));
3417 auto entrypointBlock = Block::Make(/*line=*/-1, std::move(entrypointStmts),
3418 symbolTable, /*isScope=*/true);
3419 // Declare an entrypoint function.
3420 EntrypointAdapter adapter;
3421 adapter.fLayout = {};
3422 adapter.fModifiers = Modifiers{adapter.fLayout, Modifiers::kHasSideEffects_Flag};
3423 adapter.entrypointDecl =
3424 std::make_unique<FunctionDeclaration>(/*line=*/-1,
3425 &adapter.fModifiers,
3426 "_entrypoint",
3427 /*parameters=*/std::vector<const Variable*>{},
3428 /*returnType=*/fContext.fTypes.fVoid.get(),
3429 /*builtin=*/false);
3430 // Define it.
3431 adapter.entrypointDef = FunctionDefinition::Convert(fContext,
3432 /*line=*/-1,
3433 *adapter.entrypointDecl,
3434 std::move(entrypointBlock),
3435 /*builtin=*/false);
3436
3437 adapter.entrypointDecl->setDefinition(adapter.entrypointDef.get());
3438 return adapter;
3439 }
3440
// Packs every queued top-level uniform (fTopLevelUniforms) into one synthesized interface block
// named _UniformBuffer, which uses the default uniform binding and set from the program
// settings.
//  - fTopLevelUniformMap records each uniform's field index within the block's struct.
//  - fUniformBufferId receives the block's SpvId.
//  - The synthesized struct, variable and interface block are kept in fUniformBuffer so they
//    outlive this call.
// Must only be called when at least one top-level uniform was queued.
void SPIRVCodeGenerator::writeUniformBuffer(std::shared_ptr<SymbolTable> topLevelSymbolTable) {
    SkASSERT(!fTopLevelUniforms.empty());
    static constexpr char kUniformBufferName[] = "_UniformBuffer";

    // Convert the list of top-level uniforms into a matching struct named _UniformBuffer, and build
    // a lookup table of variables to UniformBuffer field indices.
    std::vector<Type::Field> fields;
    fields.reserve(fTopLevelUniforms.size());
    fTopLevelUniformMap.reserve(fTopLevelUniforms.size());
    for (const VarDeclaration* topLevelUniform : fTopLevelUniforms) {
        const Variable* var = &topLevelUniform->var();
        fTopLevelUniformMap[var] = (int)fields.size();
        fields.emplace_back(var->modifiers(), var->name(), &var->type());
    }
    fUniformBuffer.fStruct = Type::MakeStructType(
            /*line=*/-1, kUniformBufferName, std::move(fields), /*interfaceBlock=*/true);

    // Create a global variable to contain this struct.
    Layout layout;
    layout.fBinding = fProgram.fConfig->fSettings.fDefaultUniformBinding;
    layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
    Modifiers modifiers{layout, Modifiers::kUniform_Flag};

    fUniformBuffer.fInnerVariable = std::make_unique<Variable>(
            /*line=*/-1, fProgram.fModifiers->add(modifiers), kUniformBufferName,
            fUniformBuffer.fStruct.get(), /*builtin=*/false, Variable::Storage::kGlobal);

    // Create an interface block object for this global variable.
    fUniformBuffer.fInterfaceBlock = std::make_unique<InterfaceBlock>(
            /*line=*/-1, *fUniformBuffer.fInnerVariable, kUniformBufferName,
            kUniformBufferName, /*arraySize=*/0, topLevelSymbolTable);

    // Generate an interface block and hold onto its ID.
    fUniformBufferId = this->writeInterfaceBlock(*fUniformBuffer.fInterfaceBlock);
}
3476
// Makes sure the RT-flip uniform (a float2 named SKSL_RTFLIP_NAME) exists. Runs at most once:
// it does nothing if fWroteRTFlip is already set, and otherwise sets it.
// When no existing interface block carried the value, this synthesizes a block named
// "sksl_synthetic_uniforms" holding just that field at fRTFlipOffset.
//  - The block is a push-constant block when fUsePushConstants is set.
//  - Otherwise it uses fRTFlipBinding / fRTFlipSet from the settings.
// Configuration problems are reported as errors at `line`: a negative offset, or a missing
// binding/set. Emission still continues after such an error.
void SPIRVCodeGenerator::addRTFlipUniform(int line) {
    if (fWroteRTFlip) {
        return;
    }
    // Flip variable hasn't been written yet. This means we don't have an existing
    // interface block, so we're free to just synthesize one.
    fWroteRTFlip = true;
    std::vector<Type::Field> fields;
    if (fProgram.fConfig->fSettings.fRTFlipOffset < 0) {
        fContext.fErrors->error(line, "RTFlipOffset is negative");
    }
    fields.emplace_back(Modifiers(Layout(/*flags=*/0,
                                         /*location=*/-1,
                                         fProgram.fConfig->fSettings.fRTFlipOffset,
                                         /*binding=*/-1,
                                         /*index=*/-1,
                                         /*set=*/-1,
                                         /*builtin=*/-1,
                                         /*inputAttachmentIndex=*/-1),
                                  /*flags=*/0),
                        SKSL_RTFLIP_NAME,
                        fContext.fTypes.fFloat2.get());
    std::string_view name = "sksl_synthetic_uniforms";
    // The synthesized struct type is owned by fSynthetics.
    const Type* intfStruct = fSynthetics.takeOwnershipOfSymbol(
            Type::MakeStructType(/*line=*/-1, name, fields, /*interfaceBlock=*/true));
    bool usePushConstants = fProgram.fConfig->fSettings.fUsePushConstants;
    int binding = -1, set = -1;
    // Push constants take no binding/set. Otherwise both must come from the settings.
    if (!usePushConstants) {
        binding = fProgram.fConfig->fSettings.fRTFlipBinding;
        if (binding == -1) {
            fContext.fErrors->error(line, "layout(binding=...) is required in SPIR-V");
        }
        set = fProgram.fConfig->fSettings.fRTFlipSet;
        if (set == -1) {
            fContext.fErrors->error(line, "layout(set=...) is required in SPIR-V");
        }
    }
    int flags = usePushConstants ? Layout::Flag::kPushConstant_Flag : 0;
    const Modifiers* modsPtr;
    {
        AutoAttachPoolToThread attach(fProgram.fPool.get());
        Modifiers modifiers(Layout(flags,
                                   /*location=*/-1,
                                   /*offset=*/-1,
                                   binding,
                                   /*index=*/-1,
                                   set,
                                   /*builtin=*/-1,
                                   /*inputAttachmentIndex=*/-1),
                            Modifiers::kUniform_Flag);
        modsPtr = fProgram.fModifiers->add(modifiers);
    }
    const Variable* intfVar = fSynthetics.takeOwnershipOfSymbol(
            std::make_unique<Variable>(/*line=*/-1,
                                       modsPtr,
                                       name,
                                       intfStruct,
                                       /*builtin=*/false,
                                       Variable::Storage::kGlobal));
    // ProgramUsage does not know about this variable. Recording it here makes isDead() keep it.
    fSPIRVBonusVariables.insert(intfVar);
    {
        AutoAttachPoolToThread attach(fProgram.fPool.get());
        // Register field 0 (the rtflip float2) as a top-level Field symbol.
        fProgram.fSymbols->add(std::make_unique<Field>(/*line=*/-1, intfVar, /*field=*/0));
    }
    InterfaceBlock intf(/*line=*/-1,
                        *intfVar,
                        name,
                        /*instanceName=*/"",
                        /*arraySize=*/0,
                        std::make_shared<SymbolTable>(fContext, /*builtin=*/false));

    // appendRTFlip=false: the struct above already contains the rtflip field.
    this->writeInterfaceBlock(intf, false);
}
3550
// Emits everything in the SPIR-V module that follows the five-word header, in these steps:
//  1. Pre-assign an SpvId to every function definition and locate main(). If there is no
//     main(), report an error and stop.
//  2. Emit the interface blocks, then the global variables. Top-level uniforms are packed into
//     a single uniform buffer.
//  3. If main() returns half4, synthesize an `_entrypoint` adapter that stores main()'s result
//     into sk_FragColor, and use the adapter as the entry point.
//  4. Emit every function into a scratch buffer, and collect the live global in/out variables
//     as entry-point interface ids.
//  5. Write, in order: capabilities, the GLSL.std.450 import, the memory model, OpEntryPoint,
//     the execution mode, and source extensions.
//  6. Finally write the buffered extra globals, names, decorations, constants and function
//     bodies.
void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
    fGLSLExtendedInstructions = this->nextId(nullptr);
    StringStream body;
    // Assign SpvIds to functions.
    const FunctionDeclaration* main = nullptr;
    for (const ProgramElement* e : program.elements()) {
        if (e->is<FunctionDefinition>()) {
            const FunctionDefinition& funcDef = e->as<FunctionDefinition>();
            const FunctionDeclaration& funcDecl = funcDef.declaration();
            fFunctionMap[&funcDecl] = this->nextId(nullptr);
            if (funcDecl.isMain()) {
                main = &funcDecl;
            }
        }
    }
    // Make sure we have a main() function.
    if (!main) {
        fContext.fErrors->error(/*line=*/-1, "program does not contain a main() function");
        return;
    }
    // Emit interface blocks.
    // Live, non-builtin in/out blocks become entry-point interface ids.
    std::set<SpvId> interfaceVars;
    for (const ProgramElement* e : program.elements()) {
        if (e->is<InterfaceBlock>()) {
            const InterfaceBlock& intf = e->as<InterfaceBlock>();
            SpvId id = this->writeInterfaceBlock(intf);

            const Modifiers& modifiers = intf.variable().modifiers();
            if ((modifiers.fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag)) &&
                modifiers.fLayout.fBuiltin == -1 && !this->isDead(intf.variable())) {
                interfaceVars.insert(id);
            }
        }
    }
    // Emit global variable declarations.
    for (const ProgramElement* e : program.elements()) {
        if (e->is<GlobalVarDeclaration>()) {
            this->writeGlobalVar(program.fConfig->fKind,
                                 e->as<GlobalVarDeclaration>().declaration()->as<VarDeclaration>());
        }
    }
    // Emit top-level uniforms into a dedicated uniform buffer.
    if (!fTopLevelUniforms.empty()) {
        this->writeUniformBuffer(get_top_level_symbol_table(*main));
    }
    // If main() returns a half4, synthesize a tiny entrypoint function which invokes the real
    // main() and stores the result into sk_FragColor.
    // NOTE(review): `adapter` is move-assigned from the returned temporary. Any pointer the
    // returned declaration holds into that temporary's own members would dangle afterwards.
    // The adapter must also stay alive until the entry point has been written below, because
    // `main` may point at its declaration.
    EntrypointAdapter adapter;
    if (main->returnType().matches(*fContext.fTypes.fHalf4)) {
        adapter = this->writeEntrypointAdapter(*main);
        if (adapter.entrypointDecl) {
            fFunctionMap[adapter.entrypointDecl.get()] = this->nextId(nullptr);
            this->writeFunction(*adapter.entrypointDef, body);
            main = adapter.entrypointDecl.get();
        }
    }
    // Emit all the functions.
    for (const ProgramElement* e : program.elements()) {
        if (e->is<FunctionDefinition>()) {
            this->writeFunction(e->as<FunctionDefinition>(), body);
        }
    }
    // Add global in/out variables to the list of interface variables.
    for (auto entry : fVariableMap) {
        const Variable* var = entry.first;
        if (var->storage() == Variable::Storage::kGlobal &&
            (var->modifiers().fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag)) &&
            !this->isDead(*var)) {
            interfaceVars.insert(entry.second);
        }
    }
    this->writeCapabilities(out);
    this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
    this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
    // OpEntryPoint word count breaks down as follows:
    //  - 3 words: opcode, execution model, entry-point id;
    //  - the name as a null-terminated string padded to whole words, (len + 4) / 4;
    //  - one word per interface id.
    this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (main->name().length() + 4) / 4) +
                      (int32_t) interfaceVars.size(), out);
    switch (program.fConfig->fKind) {
        case ProgramKind::kVertex:
            this->writeWord(SpvExecutionModelVertex, out);
            break;
        case ProgramKind::kFragment:
            this->writeWord(SpvExecutionModelFragment, out);
            break;
        default:
            SK_ABORT("cannot write this kind of program to SPIR-V\n");
    }
    SpvId entryPoint = fFunctionMap[main];
    this->writeWord(entryPoint, out);
    this->writeString(main->name(), out);
    for (int var : interfaceVars) {
        this->writeWord(var, out);
    }
    if (program.fConfig->fKind == ProgramKind::kFragment) {
        this->writeInstruction(SpvOpExecutionMode,
                               fFunctionMap[main],
                               SpvExecutionModeOriginUpperLeft,
                               out);
    }
    for (const ProgramElement* e : program.elements()) {
        if (e->is<Extension>()) {
            this->writeInstruction(SpvOpSourceExtension, e->as<Extension>().name(), out);
        }
    }

    write_stringstream(fExtraGlobalsBuffer, out);
    write_stringstream(fNameBuffer, out);
    write_stringstream(fDecorationBuffer, out);
    write_stringstream(fConstantBuffer, out);
    write_stringstream(body, out);
}
3661
generateCode()3662 bool SPIRVCodeGenerator::generateCode() {
3663 SkASSERT(!fContext.fErrors->errorCount());
3664 this->writeWord(SpvMagicNumber, *fOut);
3665 this->writeWord(SpvVersion, *fOut);
3666 this->writeWord(SKSL_MAGIC, *fOut);
3667 StringStream buffer;
3668 this->writeInstructions(fProgram, buffer);
3669 this->writeWord(fIdCount, *fOut);
3670 this->writeWord(0, *fOut); // reserved, always zero
3671 write_stringstream(buffer, *fOut);
3672 fContext.fErrors->reportPendingErrors(PositionInfo());
3673 return fContext.fErrors->errorCount() == 0;
3674 }
3675
3676 } // namespace SkSL
3677