1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/codegen/SkSLSPIRVCodeGenerator.h"
9
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/SkOpts_spi.h"
13 #include "include/private/SkSLIRNode.h"
14 #include "include/private/SkSLProgramElement.h"
15 #include "include/private/SkSLStatement.h"
16 #include "include/private/SkSLSymbol.h"
17 #include "include/private/base/SkTArray.h"
18 #include "include/sksl/DSLCore.h"
19 #include "include/sksl/DSLExpression.h"
20 #include "include/sksl/DSLType.h"
21 #include "include/sksl/DSLVar.h"
22 #include "include/sksl/SkSLErrorReporter.h"
23 #include "include/sksl/SkSLOperator.h"
24 #include "include/sksl/SkSLPosition.h"
25 #include "src/sksl/GLSL.std.450.h"
26 #include "src/sksl/SkSLAnalysis.h"
27 #include "src/sksl/SkSLBuiltinTypes.h"
28 #include "src/sksl/SkSLCompiler.h"
29 #include "src/sksl/SkSLConstantFolder.h"
30 #include "src/sksl/SkSLContext.h"
31 #include "src/sksl/SkSLIntrinsicList.h"
32 #include "src/sksl/SkSLModifiersPool.h"
33 #include "src/sksl/SkSLOutputStream.h"
34 #include "src/sksl/SkSLPool.h"
35 #include "src/sksl/SkSLProgramSettings.h"
36 #include "src/sksl/SkSLThreadContext.h"
37 #include "src/sksl/SkSLUtil.h"
38 #include "src/sksl/analysis/SkSLProgramUsage.h"
39 #include "src/sksl/ir/SkSLBinaryExpression.h"
40 #include "src/sksl/ir/SkSLBlock.h"
41 #include "src/sksl/ir/SkSLConstructor.h"
42 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
43 #include "src/sksl/ir/SkSLConstructorCompound.h"
44 #include "src/sksl/ir/SkSLConstructorCompoundCast.h"
45 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
46 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
47 #include "src/sksl/ir/SkSLConstructorScalarCast.h"
48 #include "src/sksl/ir/SkSLConstructorSplat.h"
49 #include "src/sksl/ir/SkSLDoStatement.h"
50 #include "src/sksl/ir/SkSLExpression.h"
51 #include "src/sksl/ir/SkSLExpressionStatement.h"
52 #include "src/sksl/ir/SkSLExtension.h"
53 #include "src/sksl/ir/SkSLField.h"
54 #include "src/sksl/ir/SkSLFieldAccess.h"
55 #include "src/sksl/ir/SkSLForStatement.h"
56 #include "src/sksl/ir/SkSLFunctionCall.h"
57 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
58 #include "src/sksl/ir/SkSLFunctionDefinition.h"
59 #include "src/sksl/ir/SkSLIfStatement.h"
60 #include "src/sksl/ir/SkSLIndexExpression.h"
61 #include "src/sksl/ir/SkSLInterfaceBlock.h"
62 #include "src/sksl/ir/SkSLLiteral.h"
63 #include "src/sksl/ir/SkSLPostfixExpression.h"
64 #include "src/sksl/ir/SkSLPrefixExpression.h"
65 #include "src/sksl/ir/SkSLProgram.h"
66 #include "src/sksl/ir/SkSLReturnStatement.h"
67 #include "src/sksl/ir/SkSLSetting.h"
68 #include "src/sksl/ir/SkSLSwitchCase.h"
69 #include "src/sksl/ir/SkSLSwitchStatement.h"
70 #include "src/sksl/ir/SkSLSwizzle.h"
71 #include "src/sksl/ir/SkSLTernaryExpression.h"
72 #include "src/sksl/ir/SkSLVarDeclarations.h"
73 #include "src/sksl/ir/SkSLVariableReference.h"
74 #include "src/utils/SkBitSet.h"
75
76 #include <cstring>
77 #include <set>
78 #include <string>
79 #include <utility>
80
// Alias for the last SPIR-V capability referenced by this generator.
// NOTE(review): the name suggests it bounds a capability enumeration — confirm at its uses.
#define kLast_Capability SpvCapabilityMultiViewport

// Out-of-band (negative) builtin identifiers used internally by this generator; the clockwise
// builtin sits immediately after the frag-coords builtin.
// NOTE(review): presumably negative so they can never collide with real SpvBuiltIn values —
// confirm at the use sites.
constexpr int DEVICE_FRAGCOORDS_BUILTIN = -1000;
constexpr int DEVICE_CLOCKWISE_BUILTIN = DEVICE_FRAGCOORDS_BUILTIN - 1;
85
86 namespace SkSL {
87
88 // Equality and hash operators for Instructions.
operator ==(const SPIRVCodeGenerator::Instruction & that) const89 bool SPIRVCodeGenerator::Instruction::operator==(const SPIRVCodeGenerator::Instruction& that) const {
90 return fOp == that.fOp &&
91 fResultKind == that.fResultKind &&
92 fWords == that.fWords;
93 }
94
95 struct SPIRVCodeGenerator::Instruction::Hash {
operator ()SkSL::SPIRVCodeGenerator::Instruction::Hash96 uint32_t operator()(const SPIRVCodeGenerator::Instruction& key) const {
97 uint32_t hash = key.fResultKind;
98 hash = SkOpts::hash_fn(&key.fOp, sizeof(key.fOp), hash);
99 hash = SkOpts::hash_fn(key.fWords.data(), key.fWords.size() * sizeof(int32_t), hash);
100 return hash;
101 }
102 };
103
104 // This class is used to pass values and result placeholder slots to writeInstruction.
105 struct SPIRVCodeGenerator::Word {
106 enum Kind {
107 kNone, // intended for use as a sentinel, not part of any Instruction
108 kSpvId,
109 kNumber,
110 kDefaultPrecisionResult,
111 kRelaxedPrecisionResult,
112 kUniqueResult,
113 kKeyedResult,
114 };
115
WordSkSL::SPIRVCodeGenerator::Word116 Word(SpvId id) : fValue(id), fKind(Kind::kSpvId) {}
WordSkSL::SPIRVCodeGenerator::Word117 Word(int32_t val, Kind kind) : fValue(val), fKind(kind) {}
118
NumberSkSL::SPIRVCodeGenerator::Word119 static Word Number(int32_t val) {
120 return Word{val, Kind::kNumber};
121 }
122
ResultSkSL::SPIRVCodeGenerator::Word123 static Word Result(const Type& type) {
124 return (type.hasPrecision() && !type.highPrecision()) ? RelaxedResult() : Result();
125 }
126
RelaxedResultSkSL::SPIRVCodeGenerator::Word127 static Word RelaxedResult() {
128 return Word{(int32_t)NA, kRelaxedPrecisionResult};
129 }
130
UniqueResultSkSL::SPIRVCodeGenerator::Word131 static Word UniqueResult() {
132 return Word{(int32_t)NA, kUniqueResult};
133 }
134
ResultSkSL::SPIRVCodeGenerator::Word135 static Word Result() {
136 return Word{(int32_t)NA, kDefaultPrecisionResult};
137 }
138
139 // Unlike a Result (where the result ID is always deduplicated to its first instruction) or a
140 // UniqueResult (which always produces a new instruction), a KeyedResult allows an instruction
141 // to be deduplicated among those that share the same `key`.
KeyedResultSkSL::SPIRVCodeGenerator::Word142 static Word KeyedResult(int32_t key) { return Word{key, Kind::kKeyedResult}; }
143
isResultSkSL::SPIRVCodeGenerator::Word144 bool isResult() const { return fKind >= Kind::kDefaultPrecisionResult; }
145
146 int32_t fValue;
147 Kind fKind;
148 };
149
// Skia's registered SPIR-V generator number is 31, stored in the top 16 bits of this word. The
// low 16 bits are left at zero and could be used to version the SkSL generator if ever needed.
// https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/spir-v.xml#L84
static const int32_t SKSL_MAGIC = 31 << 16;
154
// Describes how the SkSL intrinsic `ik` is lowered to SPIR-V.
//
// The returned Intrinsic holds an opcode kind (GLSL.std.450 extended instruction, core SPIR-V
// opcode, or a special-case code) followed by four opcodes selected by operand type, in the order
// float, signed int, unsigned int, bool (the same order pick_by_type uses). SpvOpUndef fills the
// slots for operand types an intrinsic does not support. Special-case intrinsics
// (kSpecial_IntrinsicOpcodeKind) presumably get dedicated lowering elsewhere in this file.
// Intrinsics not listed below map to kInvalid_IntrinsicOpcodeKind.
SPIRVCodeGenerator::Intrinsic SPIRVCodeGenerator::getIntrinsic(IntrinsicKind ik) const {

// Helper macros for building the table below.
// NOTE(review): only PACK is #undef'd after use; these macros are not #undef'd in this function
// and so remain defined for the rest of the file — confirm whether later code relies on that.

// ALL_GLSL: the same GLSL.std.450 instruction for every operand type.
#define ALL_GLSL(x) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, GLSLstd450 ## x, \
                              GLSLstd450 ## x, GLSLstd450 ## x, GLSLstd450 ## x}
// BY_TYPE_GLSL: distinct GLSL.std.450 instructions for float/int/uint; unsupported for bool.
#define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, \
                                                       GLSLstd450 ## ifFloat, \
                                                       GLSLstd450 ## ifInt, \
                                                       GLSLstd450 ## ifUInt, \
                                                       SpvOpUndef}
// ALL_SPIRV: the same core SPIR-V opcode for every operand type.
#define ALL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
                               SpvOp ## x, SpvOp ## x, SpvOp ## x, SpvOp ## x}
// BOOL_SPIRV: a core SPIR-V opcode that only applies to bool operands.
#define BOOL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
                                SpvOpUndef, SpvOpUndef, SpvOpUndef, SpvOp ## x}
// FLOAT_SPIRV: a core SPIR-V opcode that only applies to float operands.
#define FLOAT_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
                                 SpvOp ## x, SpvOpUndef, SpvOpUndef, SpvOpUndef}
// SPECIAL: the special-case code k<x>_SpecialIntrinsic for every operand type.
#define SPECIAL(x) Intrinsic{kSpecial_IntrinsicOpcodeKind, k ## x ## _SpecialIntrinsic, \
                             k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
                             k ## x ## _SpecialIntrinsic}

    switch (ik) {
        case k_round_IntrinsicKind:          return ALL_GLSL(Round);
        case k_roundEven_IntrinsicKind:      return ALL_GLSL(RoundEven);
        case k_trunc_IntrinsicKind:          return ALL_GLSL(Trunc);
        case k_abs_IntrinsicKind:            return BY_TYPE_GLSL(FAbs, SAbs, SAbs);
        case k_sign_IntrinsicKind:           return BY_TYPE_GLSL(FSign, SSign, SSign);
        case k_floor_IntrinsicKind:          return ALL_GLSL(Floor);
        case k_ceil_IntrinsicKind:           return ALL_GLSL(Ceil);
        case k_fract_IntrinsicKind:          return ALL_GLSL(Fract);
        case k_radians_IntrinsicKind:        return ALL_GLSL(Radians);
        case k_degrees_IntrinsicKind:        return ALL_GLSL(Degrees);
        case k_sin_IntrinsicKind:            return ALL_GLSL(Sin);
        case k_cos_IntrinsicKind:            return ALL_GLSL(Cos);
        case k_tan_IntrinsicKind:            return ALL_GLSL(Tan);
        case k_asin_IntrinsicKind:           return ALL_GLSL(Asin);
        case k_acos_IntrinsicKind:           return ALL_GLSL(Acos);
        case k_atan_IntrinsicKind:           return SPECIAL(Atan);
        case k_sinh_IntrinsicKind:           return ALL_GLSL(Sinh);
        case k_cosh_IntrinsicKind:           return ALL_GLSL(Cosh);
        case k_tanh_IntrinsicKind:           return ALL_GLSL(Tanh);
        case k_asinh_IntrinsicKind:          return ALL_GLSL(Asinh);
        case k_acosh_IntrinsicKind:          return ALL_GLSL(Acosh);
        case k_atanh_IntrinsicKind:          return ALL_GLSL(Atanh);
        case k_pow_IntrinsicKind:            return ALL_GLSL(Pow);
        case k_exp_IntrinsicKind:            return ALL_GLSL(Exp);
        case k_log_IntrinsicKind:            return ALL_GLSL(Log);
        case k_exp2_IntrinsicKind:           return ALL_GLSL(Exp2);
        case k_log2_IntrinsicKind:           return ALL_GLSL(Log2);
        case k_sqrt_IntrinsicKind:           return ALL_GLSL(Sqrt);
        case k_inverse_IntrinsicKind:        return ALL_GLSL(MatrixInverse);
        case k_outerProduct_IntrinsicKind:   return ALL_SPIRV(OuterProduct);
        case k_transpose_IntrinsicKind:      return ALL_SPIRV(Transpose);
        case k_isinf_IntrinsicKind:          return ALL_SPIRV(IsInf);
        case k_isnan_IntrinsicKind:          return ALL_SPIRV(IsNan);
        case k_inversesqrt_IntrinsicKind:    return ALL_GLSL(InverseSqrt);
        case k_determinant_IntrinsicKind:    return ALL_GLSL(Determinant);
        case k_matrixCompMult_IntrinsicKind: return SPECIAL(MatrixCompMult);
        case k_matrixInverse_IntrinsicKind:  return ALL_GLSL(MatrixInverse);
        case k_mod_IntrinsicKind:            return SPECIAL(Mod);
        case k_modf_IntrinsicKind:           return ALL_GLSL(Modf);
        case k_min_IntrinsicKind:            return SPECIAL(Min);
        case k_max_IntrinsicKind:            return SPECIAL(Max);
        case k_clamp_IntrinsicKind:          return SPECIAL(Clamp);
        case k_saturate_IntrinsicKind:       return SPECIAL(Saturate);
        case k_dot_IntrinsicKind:            return FLOAT_SPIRV(Dot);
        case k_mix_IntrinsicKind:            return SPECIAL(Mix);
        case k_step_IntrinsicKind:           return SPECIAL(Step);
        case k_smoothstep_IntrinsicKind:     return SPECIAL(SmoothStep);
        case k_fma_IntrinsicKind:            return ALL_GLSL(Fma);
        case k_frexp_IntrinsicKind:          return ALL_GLSL(Frexp);
        case k_ldexp_IntrinsicKind:          return ALL_GLSL(Ldexp);

// PACK(T) expands to two cases: k_packT maps to GLSLstd450PackT and k_unpackT to
// GLSLstd450UnpackT.
#define PACK(type) case k_pack##type##_IntrinsicKind:   return ALL_GLSL(Pack##type); \
                   case k_unpack##type##_IntrinsicKind: return ALL_GLSL(Unpack##type)
        PACK(Snorm4x8);
        PACK(Unorm4x8);
        PACK(Snorm2x16);
        PACK(Unorm2x16);
        PACK(Half2x16);
        PACK(Double2x32);
#undef PACK

        case k_length_IntrinsicKind:         return ALL_GLSL(Length);
        case k_distance_IntrinsicKind:       return ALL_GLSL(Distance);
        case k_cross_IntrinsicKind:          return ALL_GLSL(Cross);
        case k_normalize_IntrinsicKind:      return ALL_GLSL(Normalize);
        case k_faceforward_IntrinsicKind:    return ALL_GLSL(FaceForward);
        case k_reflect_IntrinsicKind:        return ALL_GLSL(Reflect);
        case k_refract_IntrinsicKind:        return ALL_GLSL(Refract);
        case k_bitCount_IntrinsicKind:       return ALL_SPIRV(BitCount);
        case k_findLSB_IntrinsicKind:        return ALL_GLSL(FindILsb);
        case k_findMSB_IntrinsicKind:        return BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
        case k_dFdx_IntrinsicKind:           return FLOAT_SPIRV(DPdx);
        case k_dFdy_IntrinsicKind:           return SPECIAL(DFdy);
        case k_fwidth_IntrinsicKind:         return FLOAT_SPIRV(Fwidth);
        case k_makeSampler2D_IntrinsicKind:  return SPECIAL(SampledImage);

        // Texture access is always special-cased.
        case k_sample_IntrinsicKind:         return SPECIAL(Texture);
        case k_sampleGrad_IntrinsicKind:     return SPECIAL(TextureGrad);
        case k_sampleLod_IntrinsicKind:      return SPECIAL(TextureLod);
        case k_subpassLoad_IntrinsicKind:    return SPECIAL(SubpassLoad);

        // Bit-pattern reinterpretations all lower to OpBitcast.
        case k_floatBitsToInt_IntrinsicKind:  return ALL_SPIRV(Bitcast);
        case k_floatBitsToUint_IntrinsicKind: return ALL_SPIRV(Bitcast);
        case k_intBitsToFloat_IntrinsicKind:  return ALL_SPIRV(Bitcast);
        case k_uintBitsToFloat_IntrinsicKind: return ALL_SPIRV(Bitcast);

        case k_any_IntrinsicKind:            return BOOL_SPIRV(Any);
        case k_all_IntrinsicKind:            return BOOL_SPIRV(All);
        case k_not_IntrinsicKind:            return BOOL_SPIRV(LogicalNot);

        // Component-wise comparisons. Float comparisons are ordered (false if either operand is
        // NaN), except notEqual, which uses an unordered compare (true if either operand is NaN).
        // The relational comparisons have no bool form.
        case k_equal_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdEqual,
                             SpvOpIEqual,
                             SpvOpIEqual,
                             SpvOpLogicalEqual};
        case k_notEqual_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFUnordNotEqual,
                             SpvOpINotEqual,
                             SpvOpINotEqual,
                             SpvOpLogicalNotEqual};
        case k_lessThan_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdLessThan,
                             SpvOpSLessThan,
                             SpvOpULessThan,
                             SpvOpUndef};
        case k_lessThanEqual_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdLessThanEqual,
                             SpvOpSLessThanEqual,
                             SpvOpULessThanEqual,
                             SpvOpUndef};
        case k_greaterThan_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdGreaterThan,
                             SpvOpSGreaterThan,
                             SpvOpUGreaterThan,
                             SpvOpUndef};
        case k_greaterThanEqual_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdGreaterThanEqual,
                             SpvOpSGreaterThanEqual,
                             SpvOpUGreaterThanEqual,
                             SpvOpUndef};
        default:
            // Not a recognized intrinsic for this backend.
            return Intrinsic{kInvalid_IntrinsicOpcodeKind, 0, 0, 0, 0};
    }
}
305
writeWord(int32_t word,OutputStream & out)306 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
307 out.write((const char*) &word, sizeof(word));
308 }
309
is_float(const Type & type)310 static bool is_float(const Type& type) {
311 return (type.isScalar() || type.isVector() || type.isMatrix()) &&
312 type.componentType().isFloat();
313 }
314
is_signed(const Type & type)315 static bool is_signed(const Type& type) {
316 return (type.isScalar() || type.isVector()) && type.componentType().isSigned();
317 }
318
is_unsigned(const Type & type)319 static bool is_unsigned(const Type& type) {
320 return (type.isScalar() || type.isVector()) && type.componentType().isUnsigned();
321 }
322
is_bool(const Type & type)323 static bool is_bool(const Type& type) {
324 return (type.isScalar() || type.isVector()) && type.componentType().isBoolean();
325 }
326
327 template <typename T>
pick_by_type(const Type & type,T ifFloat,T ifInt,T ifUInt,T ifBool)328 static T pick_by_type(const Type& type, T ifFloat, T ifInt, T ifUInt, T ifBool) {
329 if (is_float(type)) {
330 return ifFloat;
331 }
332 if (is_signed(type)) {
333 return ifInt;
334 }
335 if (is_unsigned(type)) {
336 return ifUInt;
337 }
338 if (is_bool(type)) {
339 return ifBool;
340 }
341 SkDEBUGFAIL("unrecognized type");
342 return ifFloat;
343 }
344
is_out(const Modifiers & m)345 static bool is_out(const Modifiers& m) {
346 return (m.fFlags & Modifiers::kOut_Flag) != 0;
347 }
348
is_in(const Modifiers & m)349 static bool is_in(const Modifiers& m) {
350 switch (m.fFlags & (Modifiers::kOut_Flag | Modifiers::kIn_Flag)) {
351 case Modifiers::kOut_Flag: // out
352 return false;
353
354 case 0: // implicit in
355 case Modifiers::kIn_Flag: // explicit in
356 case Modifiers::kOut_Flag | Modifiers::kIn_Flag: // inout
357 return true;
358
359 default: SkUNREACHABLE;
360 }
361 }
362
is_control_flow_op(SpvOp_ op)363 static bool is_control_flow_op(SpvOp_ op) {
364 switch (op) {
365 case SpvOpReturn:
366 case SpvOpReturnValue:
367 case SpvOpKill:
368 case SpvOpSwitch:
369 case SpvOpBranch:
370 case SpvOpBranchConditional:
371 return true;
372 default:
373 return false;
374 }
375 }
376
is_globally_reachable_op(SpvOp_ op)377 static bool is_globally_reachable_op(SpvOp_ op) {
378 switch (op) {
379 case SpvOpConstant:
380 case SpvOpConstantTrue:
381 case SpvOpConstantFalse:
382 case SpvOpConstantComposite:
383 case SpvOpTypeVoid:
384 case SpvOpTypeInt:
385 case SpvOpTypeFloat:
386 case SpvOpTypeBool:
387 case SpvOpTypeVector:
388 case SpvOpTypeMatrix:
389 case SpvOpTypeArray:
390 case SpvOpTypePointer:
391 case SpvOpTypeFunction:
392 case SpvOpTypeRuntimeArray:
393 case SpvOpTypeStruct:
394 case SpvOpTypeImage:
395 case SpvOpTypeSampledImage:
396 case SpvOpTypeSampler:
397 case SpvOpVariable:
398 case SpvOpFunction:
399 case SpvOpFunctionParameter:
400 case SpvOpFunctionEnd:
401 case SpvOpExecutionMode:
402 case SpvOpMemoryModel:
403 case SpvOpCapability:
404 case SpvOpExtInstImport:
405 case SpvOpEntryPoint:
406 case SpvOpSource:
407 case SpvOpSourceExtension:
408 case SpvOpName:
409 case SpvOpMemberName:
410 case SpvOpDecorate:
411 case SpvOpMemberDecorate:
412 return true;
413 default:
414 return false;
415 }
416 }
417
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)418 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
419 SkASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
420 SkASSERT(opCode != SpvOpUndef);
421 bool foundDeadCode = false;
422 if (is_control_flow_op(opCode)) {
423 // This instruction causes us to leave the current block.
424 foundDeadCode = (fCurrentBlock == 0);
425 fCurrentBlock = 0;
426 } else if (!is_globally_reachable_op(opCode)) {
427 foundDeadCode = (fCurrentBlock == 0);
428 }
429
430 if (foundDeadCode) {
431 // We just encountered dead code--an instruction that don't have an associated block.
432 // Synthesize a label if this happens; this is necessary to satisfy the validator.
433 this->writeLabel(this->nextId(nullptr), kBranchlessBlock, out);
434 }
435
436 this->writeWord((length << 16) | opCode, out);
437 }
438
writeLabel(SpvId label,StraightLineLabelType,OutputStream & out)439 void SPIRVCodeGenerator::writeLabel(SpvId label, StraightLineLabelType, OutputStream& out) {
440 // The straight-line label type is not important; in any case, no caches are invalidated.
441 SkASSERT(!fCurrentBlock);
442 fCurrentBlock = label;
443 this->writeInstruction(SpvOpLabel, label, out);
444 }
445
writeLabel(SpvId label,BranchingLabelType type,ConditionalOpCounts ops,OutputStream & out)446 void SPIRVCodeGenerator::writeLabel(SpvId label, BranchingLabelType type,
447 ConditionalOpCounts ops, OutputStream& out) {
448 switch (type) {
449 case kBranchIsBelow:
450 case kBranchesOnBothSides:
451 // With a backward or bidirectional branch, we haven't seen the code between the label
452 // and the branch yet, so any stored value is potentially suspect. Without scanning
453 // ahead to check, the only safe option is to ditch the store cache entirely.
454 fStoreCache.reset();
455 [[fallthrough]];
456
457 case kBranchIsAbove:
458 // With a forward branch, we can rely on stores that we had cached at the start of the
459 // statement/expression, if they haven't been touched yet. Anything newer than that is
460 // pruned.
461 this->pruneConditionalOps(ops);
462 break;
463 }
464
465 // Emit the label.
466 this->writeLabel(label, kBranchlessBlock, out);
467 }
468
writeInstruction(SpvOp_ opCode,OutputStream & out)469 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
470 this->writeOpCode(opCode, 1, out);
471 }
472
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)473 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
474 this->writeOpCode(opCode, 2, out);
475 this->writeWord(word1, out);
476 }
477
writeString(std::string_view s,OutputStream & out)478 void SPIRVCodeGenerator::writeString(std::string_view s, OutputStream& out) {
479 out.write(s.data(), s.length());
480 switch (s.length() % 4) {
481 case 1:
482 out.write8(0);
483 [[fallthrough]];
484 case 2:
485 out.write8(0);
486 [[fallthrough]];
487 case 3:
488 out.write8(0);
489 break;
490 default:
491 this->writeWord(0, out);
492 break;
493 }
494 }
495
writeInstruction(SpvOp_ opCode,std::string_view string,OutputStream & out)496 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, std::string_view string,
497 OutputStream& out) {
498 this->writeOpCode(opCode, 1 + (string.length() + 4) / 4, out);
499 this->writeString(string, out);
500 }
501
writeInstruction(SpvOp_ opCode,int32_t word1,std::string_view string,OutputStream & out)502 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, std::string_view string,
503 OutputStream& out) {
504 this->writeOpCode(opCode, 2 + (string.length() + 4) / 4, out);
505 this->writeWord(word1, out);
506 this->writeString(string, out);
507 }
508
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,std::string_view string,OutputStream & out)509 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
510 std::string_view string, OutputStream& out) {
511 this->writeOpCode(opCode, 3 + (string.length() + 4) / 4, out);
512 this->writeWord(word1, out);
513 this->writeWord(word2, out);
514 this->writeString(string, out);
515 }
516
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)517 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
518 OutputStream& out) {
519 this->writeOpCode(opCode, 3, out);
520 this->writeWord(word1, out);
521 this->writeWord(word2, out);
522 }
523
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)524 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
525 int32_t word3, OutputStream& out) {
526 this->writeOpCode(opCode, 4, out);
527 this->writeWord(word1, out);
528 this->writeWord(word2, out);
529 this->writeWord(word3, out);
530 }
531
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)532 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
533 int32_t word3, int32_t word4, OutputStream& out) {
534 this->writeOpCode(opCode, 5, out);
535 this->writeWord(word1, out);
536 this->writeWord(word2, out);
537 this->writeWord(word3, out);
538 this->writeWord(word4, out);
539 }
540
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)541 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
542 int32_t word3, int32_t word4, int32_t word5,
543 OutputStream& out) {
544 this->writeOpCode(opCode, 6, out);
545 this->writeWord(word1, out);
546 this->writeWord(word2, out);
547 this->writeWord(word3, out);
548 this->writeWord(word4, out);
549 this->writeWord(word5, out);
550 }
551
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)552 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
553 int32_t word3, int32_t word4, int32_t word5,
554 int32_t word6, OutputStream& out) {
555 this->writeOpCode(opCode, 7, out);
556 this->writeWord(word1, out);
557 this->writeWord(word2, out);
558 this->writeWord(word3, out);
559 this->writeWord(word4, out);
560 this->writeWord(word5, out);
561 this->writeWord(word6, out);
562 }
563
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)564 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
565 int32_t word3, int32_t word4, int32_t word5,
566 int32_t word6, int32_t word7, OutputStream& out) {
567 this->writeOpCode(opCode, 8, out);
568 this->writeWord(word1, out);
569 this->writeWord(word2, out);
570 this->writeWord(word3, out);
571 this->writeWord(word4, out);
572 this->writeWord(word5, out);
573 this->writeWord(word6, out);
574 this->writeWord(word7, out);
575 }
576
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)577 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
578 int32_t word3, int32_t word4, int32_t word5,
579 int32_t word6, int32_t word7, int32_t word8,
580 OutputStream& out) {
581 this->writeOpCode(opCode, 9, out);
582 this->writeWord(word1, out);
583 this->writeWord(word2, out);
584 this->writeWord(word3, out);
585 this->writeWord(word4, out);
586 this->writeWord(word5, out);
587 this->writeWord(word6, out);
588 this->writeWord(word7, out);
589 this->writeWord(word8, out);
590 }
591
BuildInstructionKey(SpvOp_ opCode,const SkTArray<Word> & words)592 SPIRVCodeGenerator::Instruction SPIRVCodeGenerator::BuildInstructionKey(
593 SpvOp_ opCode, const SkTArray<Word>& words) {
594 // Assemble a cache key for this instruction.
595 Instruction key;
596 key.fOp = opCode;
597 key.fWords.resize(words.size());
598 key.fResultKind = Word::Kind::kNone;
599
600 for (int index = 0; index < words.size(); ++index) {
601 const Word& word = words[index];
602 key.fWords[index] = word.fValue;
603 if (word.isResult()) {
604 SkASSERT(key.fResultKind == Word::Kind::kNone);
605 key.fResultKind = word.fKind;
606 }
607 }
608
609 return key;
610 }
611
// Writes `opCode` with operand `words` to `out`, deduplicating through fOpCache.
//
// If an identical instruction (same opcode, words, and result kind) is already in fOpCache,
// nothing is written and the cached SpvId is returned. Otherwise the instruction is written, and
// any result-placeholder word (see Word::isResult) is replaced by a freshly allocated SpvId,
// which is returned. Instructions without a result placeholder return NA. UniqueResult
// instructions are never added to fOpCache, so they are always written.
SpvId SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode,
                                           const SkTArray<Word>& words,
                                           OutputStream& out) {
    // writeOpLoad and writeOpStore have dedicated code.
    SkASSERT(opCode != SpvOpLoad);
    SkASSERT(opCode != SpvOpStore);

    // If this instruction exists in our op cache, return the cached SpvId.
    Instruction key = BuildInstructionKey(opCode, words);
    if (SpvId* cachedOp = fOpCache.find(key)) {
        return *cachedOp;
    }

    SpvId result = NA;
    Precision precision = Precision::kDefault;

    switch (key.fResultKind) {
        case Word::Kind::kUniqueResult:
            // The instruction returns a SpvId, but we do not want deduplication.
            // It is recorded in fSpvIdCache (lookup by ID) but deliberately not in fOpCache.
            result = this->nextId(Precision::kDefault);
            fSpvIdCache.set(result, key);
            break;

        case Word::Kind::kNone:
            // The instruction doesn't return a SpvId, but we can still cache and deduplicate it.
            fOpCache.set(key, result);
            break;

        case Word::Kind::kRelaxedPrecisionResult:
            precision = Precision::kRelaxed;
            [[fallthrough]];

        case Word::Kind::kKeyedResult:
            // The key value is stored in the placeholder word, so it is part of the cache key.
            [[fallthrough]];

        case Word::Kind::kDefaultPrecisionResult:
            // Consume a new SpvId.
            result = this->nextId(precision);
            fOpCache.set(key, result);
            fSpvIdCache.set(result, key);

            // Globally-reachable ops are not subject to the whims of flow control.
            // Block-scoped results are recorded in fReachableOps. NOTE(review): presumably so
            // they can be evicted from the caches once control flow makes them unreachable —
            // confirm against the code that consumes fReachableOps.
            if (!is_globally_reachable_op(opCode)) {
                fReachableOps.push_back(result);
            }
            break;

        default:
            SkDEBUGFAIL("unexpected result kind");
            break;
    }

    // Write the requested instruction.
    this->writeOpCode(opCode, words.size() + 1, out);
    for (const Word& word : words) {
        if (word.isResult()) {
            // Substitute the newly allocated result ID for the placeholder.
            SkASSERT(result != NA);
            this->writeWord(result, out);
        } else {
            this->writeWord(word.fValue, out);
        }
    }

    // Return the result.
    return result;
}
678
writeOpLoad(SpvId type,Precision precision,SpvId pointer,OutputStream & out)679 SpvId SPIRVCodeGenerator::writeOpLoad(SpvId type,
680 Precision precision,
681 SpvId pointer,
682 OutputStream& out) {
683 // Look for this pointer in our load-cache.
684 if (SpvId* cachedOp = fStoreCache.find(pointer)) {
685 return *cachedOp;
686 }
687
688 // Write the requested OpLoad instruction.
689 SpvId result = this->nextId(precision);
690 this->writeInstruction(SpvOpLoad, type, result, pointer, out);
691 return result;
692 }
693
writeOpStore(SpvStorageClass_ storageClass,SpvId pointer,SpvId value,OutputStream & out)694 void SPIRVCodeGenerator::writeOpStore(SpvStorageClass_ storageClass,
695 SpvId pointer,
696 SpvId value,
697 OutputStream& out) {
698 // Write the uncached SpvOpStore directly.
699 this->writeInstruction(SpvOpStore, pointer, value, out);
700
701 if (storageClass == SpvStorageClassFunction) {
702 // Insert a pointer-to-SpvId mapping into the load cache. A writeOpLoad to this pointer will
703 // return the cached value as-is.
704 fStoreCache.set(pointer, value);
705 fStoreOps.push_back(pointer);
706 }
707 }
708
writeOpConstantTrue(const Type & type)709 SpvId SPIRVCodeGenerator::writeOpConstantTrue(const Type& type) {
710 return this->writeInstruction(SpvOpConstantTrue,
711 Words{this->getType(type), Word::Result()},
712 fConstantBuffer);
713 }
714
writeOpConstantFalse(const Type & type)715 SpvId SPIRVCodeGenerator::writeOpConstantFalse(const Type& type) {
716 return this->writeInstruction(SpvOpConstantFalse,
717 Words{this->getType(type), Word::Result()},
718 fConstantBuffer);
719 }
720
writeOpConstant(const Type & type,int32_t valueBits)721 SpvId SPIRVCodeGenerator::writeOpConstant(const Type& type, int32_t valueBits) {
722 return this->writeInstruction(
723 SpvOpConstant,
724 Words{this->getType(type), Word::Result(), Word::Number(valueBits)},
725 fConstantBuffer);
726 }
727
writeOpConstantComposite(const Type & type,const SkTArray<SpvId> & values)728 SpvId SPIRVCodeGenerator::writeOpConstantComposite(const Type& type,
729 const SkTArray<SpvId>& values) {
730 SkASSERT(values.size() == (type.isStruct() ? (int)type.fields().size() : type.columns()));
731
732 Words words;
733 words.push_back(this->getType(type));
734 words.push_back(Word::Result());
735 for (SpvId value : values) {
736 words.push_back(value);
737 }
738 return this->writeInstruction(SpvOpConstantComposite, words, fConstantBuffer);
739 }
740
toConstants(SpvId value,SkTArray<SpvId> * constants)741 bool SPIRVCodeGenerator::toConstants(SpvId value, SkTArray<SpvId>* constants) {
742 Instruction* instr = fSpvIdCache.find(value);
743 if (!instr) {
744 return false;
745 }
746 switch (instr->fOp) {
747 case SpvOpConstant:
748 case SpvOpConstantTrue:
749 case SpvOpConstantFalse:
750 constants->push_back(value);
751 return true;
752
753 case SpvOpConstantComposite: // OpConstantComposite ResultType ResultID Constituents...
754 // Start at word 2 to skip past ResultType and ResultID.
755 for (int i = 2; i < instr->fWords.size(); ++i) {
756 if (!this->toConstants(instr->fWords[i], constants)) {
757 return false;
758 }
759 }
760 return true;
761
762 default:
763 return false;
764 }
765 }
766
toConstants(SkSpan<const SpvId> values,SkTArray<SpvId> * constants)767 bool SPIRVCodeGenerator::toConstants(SkSpan<const SpvId> values, SkTArray<SpvId>* constants) {
768 for (SpvId value : values) {
769 if (!this->toConstants(value, constants)) {
770 return false;
771 }
772 }
773 return true;
774 }
775
writeOpCompositeConstruct(const Type & type,const SkTArray<SpvId> & values,OutputStream & out)776 SpvId SPIRVCodeGenerator::writeOpCompositeConstruct(const Type& type,
777 const SkTArray<SpvId>& values,
778 OutputStream& out) {
779 // If this is a vector composed entirely of literals, write a constant-composite instead.
780 if (type.isVector()) {
781 SkSTArray<4, SpvId> constants;
782 if (this->toConstants(SkSpan(values), &constants)) {
783 // Create a vector from literals.
784 return this->writeOpConstantComposite(type, constants);
785 }
786 }
787
788 // If this is a matrix composed entirely of literals, constant-composite them instead.
789 if (type.isMatrix()) {
790 SkSTArray<16, SpvId> constants;
791 if (this->toConstants(SkSpan(values), &constants)) {
792 // Create each matrix column.
793 SkASSERT(type.isMatrix());
794 const Type& vecType = type.componentType().toCompound(fContext,
795 /*columns=*/type.rows(),
796 /*rows=*/1);
797 SkSTArray<4, SpvId> columnIDs;
798 for (int index=0; index < type.columns(); ++index) {
799 SkSTArray<4, SpvId> columnConstants(&constants[index * type.rows()],
800 type.rows());
801 columnIDs.push_back(this->writeOpConstantComposite(vecType, columnConstants));
802 }
803 // Compose the matrix from its columns.
804 return this->writeOpConstantComposite(type, columnIDs);
805 }
806 }
807
808 Words words;
809 words.push_back(this->getType(type));
810 words.push_back(Word::Result(type));
811 for (SpvId value : values) {
812 words.push_back(value);
813 }
814
815 return this->writeInstruction(SpvOpCompositeConstruct, words, out);
816 }
817
resultTypeForInstruction(const Instruction & instr)818 SPIRVCodeGenerator::Instruction* SPIRVCodeGenerator::resultTypeForInstruction(
819 const Instruction& instr) {
820 // This list should contain every op that we cache that has a result and result-type.
821 // (If one is missing, we will not find some optimization opportunities.)
822 // Generally, the result type of an op is in the 0th word, but I'm not sure if this is
823 // universally true, so it's configurable on a per-op basis.
824 int resultTypeWord;
825 switch (instr.fOp) {
826 case SpvOpConstant:
827 case SpvOpConstantTrue:
828 case SpvOpConstantFalse:
829 case SpvOpConstantComposite:
830 case SpvOpCompositeConstruct:
831 case SpvOpCompositeExtract:
832 case SpvOpLoad:
833 resultTypeWord = 0;
834 break;
835
836 default:
837 return nullptr;
838 }
839
840 Instruction* typeInstr = fSpvIdCache.find(instr.fWords[resultTypeWord]);
841 SkASSERT(typeInstr);
842 return typeInstr;
843 }
844
numComponentsForVecInstruction(const Instruction & instr)845 int SPIRVCodeGenerator::numComponentsForVecInstruction(const Instruction& instr) {
846 // If an instruction is in the op cache, its type should be as well.
847 Instruction* typeInstr = this->resultTypeForInstruction(instr);
848 SkASSERT(typeInstr);
849 SkASSERT(typeInstr->fOp == SpvOpTypeVector || typeInstr->fOp == SpvOpTypeFloat ||
850 typeInstr->fOp == SpvOpTypeInt || typeInstr->fOp == SpvOpTypeBool);
851
852 // For vectors, extract their column count. Scalars have one component by definition.
853 // SpvOpTypeVector ResultID ComponentType NumComponents
854 return (typeInstr->fOp == SpvOpTypeVector) ? typeInstr->fWords[2]
855 : 1;
856 }
857
toComponent(SpvId id,int component)858 SpvId SPIRVCodeGenerator::toComponent(SpvId id, int component) {
859 Instruction* instr = fSpvIdCache.find(id);
860 if (!instr) {
861 return NA;
862 }
863 if (instr->fOp == SpvOpConstantComposite) {
864 // SpvOpConstantComposite ResultType ResultID [components...]
865 // Add 2 to the component index to skip past ResultType and ResultID.
866 return instr->fWords[2 + component];
867 }
868 if (instr->fOp == SpvOpCompositeConstruct) {
869 // SpvOpCompositeConstruct ResultType ResultID [components...]
870 // Vectors have special rules; check to see if we are composing a vector.
871 Instruction* composedType = fSpvIdCache.find(instr->fWords[0]);
872 SkASSERT(composedType);
873
874 // When composing a non-vector, each instruction word maps 1:1 to the component index.
875 // We can just extract out the associated component directly.
876 if (composedType->fOp != SpvOpTypeVector) {
877 return instr->fWords[2 + component];
878 }
879
880 // When composing a vector, components can be either scalars or vectors.
881 // This means we need to check the op type on each component. (+2 to skip ResultType/Result)
882 for (int index = 2; index < instr->fWords.size(); ++index) {
883 int32_t currentWord = instr->fWords[index];
884
885 // Retrieve the sub-instruction pointed to by OpCompositeConstruct.
886 Instruction* subinstr = fSpvIdCache.find(currentWord);
887 if (!subinstr) {
888 return NA;
889 }
890 // If this subinstruction contains the component we're looking for...
891 int numComponents = this->numComponentsForVecInstruction(*subinstr);
892 if (component < numComponents) {
893 if (numComponents == 1) {
894 // ... it's a scalar. Return it.
895 SkASSERT(component == 0);
896 return currentWord;
897 } else {
898 // ... it's a vector. Recurse into it.
899 return this->toComponent(currentWord, component);
900 }
901 }
902 // This sub-instruction doesn't contain our component. Keep walking forward.
903 component -= numComponents;
904 }
905 SkDEBUGFAIL("component index goes past the end of this composite value");
906 return NA;
907 }
908 return NA;
909 }
910
writeOpCompositeExtract(const Type & type,SpvId base,int component,OutputStream & out)911 SpvId SPIRVCodeGenerator::writeOpCompositeExtract(const Type& type,
912 SpvId base,
913 int component,
914 OutputStream& out) {
915 // If the base op is a composite, we can extract from it directly.
916 SpvId result = this->toComponent(base, component);
917 if (result != NA) {
918 return result;
919 }
920 return this->writeInstruction(
921 SpvOpCompositeExtract,
922 {this->getType(type), Word::Result(type), base, Word::Number(component)},
923 out);
924 }
925
writeOpCompositeExtract(const Type & type,SpvId base,int componentA,int componentB,OutputStream & out)926 SpvId SPIRVCodeGenerator::writeOpCompositeExtract(const Type& type,
927 SpvId base,
928 int componentA,
929 int componentB,
930 OutputStream& out) {
931 // If the base op is a composite, we can extract from it directly.
932 SpvId result = this->toComponent(base, componentA);
933 if (result != NA) {
934 return this->writeOpCompositeExtract(type, result, componentB, out);
935 }
936 return this->writeInstruction(SpvOpCompositeExtract,
937 {this->getType(type),
938 Word::Result(type),
939 base,
940 Word::Number(componentA),
941 Word::Number(componentB)},
942 out);
943 }
944
writeCapabilities(OutputStream & out)945 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
946 for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
947 if (fCapabilities & bit) {
948 this->writeInstruction(SpvOpCapability, (SpvId) i, out);
949 }
950 }
951 this->writeInstruction(SpvOpCapability, SpvCapabilityShader, out);
952 }
953
nextId(const Type * type)954 SpvId SPIRVCodeGenerator::nextId(const Type* type) {
955 return this->nextId(type && type->hasPrecision() && !type->highPrecision()
956 ? Precision::kRelaxed
957 : Precision::kDefault);
958 }
959
nextId(Precision precision)960 SpvId SPIRVCodeGenerator::nextId(Precision precision) {
961 if (precision == Precision::kRelaxed && !fProgram.fConfig->fSettings.fForceHighPrecision) {
962 this->writeInstruction(SpvOpDecorate, fIdCount, SpvDecorationRelaxedPrecision,
963 fDecorationBuffer);
964 }
965 return fIdCount++;
966 }
967
writeStruct(const Type & type,const MemoryLayout & memoryLayout)968 SpvId SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout) {
969 // If we've already written out this struct, return its existing SpvId.
970 if (SpvId* cachedStructId = fStructMap.find(&type)) {
971 return *cachedStructId;
972 }
973
974 // Write all of the field types first, so we don't inadvertently write them while we're in the
975 // middle of writing the struct instruction.
976 Words words;
977 words.push_back(Word::UniqueResult());
978 for (const auto& f : type.fields()) {
979 words.push_back(this->getType(*f.fType, memoryLayout));
980 }
981 SpvId resultId = this->writeInstruction(SpvOpTypeStruct, words, fConstantBuffer);
982 this->writeInstruction(SpvOpName, resultId, type.name(), fNameBuffer);
983 fStructMap.set(&type, resultId);
984
985 size_t offset = 0;
986 for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
987 const Type::Field& field = type.fields()[i];
988 if (!memoryLayout.isSupported(*field.fType)) {
989 fContext.fErrors->error(type.fPosition, "type '" + field.fType->displayName() +
990 "' is not permitted here");
991 return resultId;
992 }
993 size_t size = memoryLayout.size(*field.fType);
994 size_t alignment = memoryLayout.alignment(*field.fType);
995 const Layout& fieldLayout = field.fModifiers.fLayout;
996 if (fieldLayout.fOffset >= 0) {
997 if (fieldLayout.fOffset < (int) offset) {
998 fContext.fErrors->error(field.fPosition, "offset of field '" +
999 std::string(field.fName) + "' must be at least " + std::to_string(offset));
1000 }
1001 if (fieldLayout.fOffset % alignment) {
1002 fContext.fErrors->error(field.fPosition,
1003 "offset of field '" + std::string(field.fName) +
1004 "' must be a multiple of " + std::to_string(alignment));
1005 }
1006 offset = fieldLayout.fOffset;
1007 } else {
1008 size_t mod = offset % alignment;
1009 if (mod) {
1010 offset += alignment - mod;
1011 }
1012 }
1013 this->writeInstruction(SpvOpMemberName, resultId, i, field.fName, fNameBuffer);
1014 this->writeFieldLayout(fieldLayout, resultId, i);
1015 if (field.fModifiers.fLayout.fBuiltin < 0) {
1016 this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
1017 (SpvId) offset, fDecorationBuffer);
1018 }
1019 if (field.fType->isMatrix()) {
1020 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
1021 fDecorationBuffer);
1022 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
1023 (SpvId) memoryLayout.stride(*field.fType),
1024 fDecorationBuffer);
1025 }
1026 if (!field.fType->highPrecision()) {
1027 this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i,
1028 SpvDecorationRelaxedPrecision, fDecorationBuffer);
1029 }
1030 offset += size;
1031 if ((field.fType->isArray() || field.fType->isStruct()) && offset % alignment != 0) {
1032 offset += alignment - offset % alignment;
1033 }
1034 }
1035
1036 return resultId;
1037 }
1038
getType(const Type & type)1039 SpvId SPIRVCodeGenerator::getType(const Type& type) {
1040 return this->getType(type, fDefaultLayout);
1041 }
1042
getType(const Type & rawType,const MemoryLayout & layout)1043 SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layout) {
1044 const Type* type = &rawType;
1045
1046 switch (type->typeKind()) {
1047 case Type::TypeKind::kVoid: {
1048 return this->writeInstruction(SpvOpTypeVoid, Words{Word::Result()}, fConstantBuffer);
1049 }
1050 case Type::TypeKind::kScalar:
1051 case Type::TypeKind::kLiteral: {
1052 if (type->isBoolean()) {
1053 return this->writeInstruction(SpvOpTypeBool, {Word::Result()}, fConstantBuffer);
1054 }
1055 if (type->isSigned()) {
1056 return this->writeInstruction(
1057 SpvOpTypeInt,
1058 Words{Word::Result(), Word::Number(32), Word::Number(1)},
1059 fConstantBuffer);
1060 }
1061 if (type->isUnsigned()) {
1062 return this->writeInstruction(
1063 SpvOpTypeInt,
1064 Words{Word::Result(), Word::Number(32), Word::Number(0)},
1065 fConstantBuffer);
1066 }
1067 if (type->isFloat()) {
1068 return this->writeInstruction(
1069 SpvOpTypeFloat,
1070 Words{Word::Result(), Word::Number(32)},
1071 fConstantBuffer);
1072 }
1073 SkDEBUGFAILF("unrecognized scalar type '%s'", type->description().c_str());
1074 return (SpvId)-1;
1075 }
1076 case Type::TypeKind::kVector: {
1077 SpvId scalarTypeId = this->getType(type->componentType(), layout);
1078 return this->writeInstruction(
1079 SpvOpTypeVector,
1080 Words{Word::Result(), scalarTypeId, Word::Number(type->columns())},
1081 fConstantBuffer);
1082 }
1083 case Type::TypeKind::kMatrix: {
1084 SpvId vectorTypeId = this->getType(IndexExpression::IndexType(fContext, *type), layout);
1085 return this->writeInstruction(
1086 SpvOpTypeMatrix,
1087 Words{Word::Result(), vectorTypeId, Word::Number(type->columns())},
1088 fConstantBuffer);
1089 }
1090 case Type::TypeKind::kArray: {
1091 if (!layout.isSupported(*type)) {
1092 fContext.fErrors->error(type->fPosition, "type '" + type->displayName() +
1093 "' is not permitted here");
1094 return NA;
1095 }
1096 if (type->columns() == 0) {
1097 // We do not support runtime-sized arrays.
1098 fContext.fErrors->error(type->fPosition, "runtime-sized arrays are not supported");
1099 return NA;
1100 }
1101 size_t stride = layout.stride(*type);
1102 SpvId typeId = this->getType(type->componentType(), layout);
1103 SpvId countId = this->writeLiteral(type->columns(), *fContext.fTypes.fInt);
1104 SpvId result = this->writeInstruction(SpvOpTypeArray,
1105 Words{Word::KeyedResult(stride), typeId, countId},
1106 fConstantBuffer);
1107 this->writeInstruction(SpvOpDecorate,
1108 {result, SpvDecorationArrayStride, Word::Number(stride)},
1109 fDecorationBuffer);
1110 return result;
1111 }
1112 case Type::TypeKind::kStruct: {
1113 return this->writeStruct(*type, layout);
1114 }
1115 case Type::TypeKind::kSeparateSampler: {
1116 return this->writeInstruction(SpvOpTypeSampler, Words{Word::Result()}, fConstantBuffer);
1117 }
1118 case Type::TypeKind::kSampler: {
1119 // Subpass inputs should use the Texture type, not a Sampler.
1120 SkASSERT(type->dimensions() != SpvDimSubpassData);
1121 if (SpvDimBuffer == type->dimensions()) {
1122 fCapabilities |= 1ULL << SpvCapabilitySampledBuffer;
1123 }
1124 SpvId imageTypeId = this->getType(type->textureType(), layout);
1125 return this->writeInstruction(SpvOpTypeSampledImage,
1126 Words{Word::Result(), imageTypeId},
1127 fConstantBuffer);
1128 }
1129 case Type::TypeKind::kTexture: {
1130 SpvId floatTypeId = this->getType(*fContext.fTypes.fFloat, layout);
1131 int sampled = (type->textureAccess() == Type::TextureAccess::kSample) ? 1 : 2;
1132 return this->writeInstruction(SpvOpTypeImage,
1133 Words{Word::Result(),
1134 floatTypeId,
1135 Word::Number(type->dimensions()),
1136 Word::Number(type->isDepth()),
1137 Word::Number(type->isArrayedTexture()),
1138 Word::Number(type->isMultisampled()),
1139 Word::Number(sampled),
1140 SpvImageFormatUnknown},
1141 fConstantBuffer);
1142 }
1143 default: {
1144 SkDEBUGFAILF("invalid type: %s", type->description().c_str());
1145 return NA;
1146 }
1147 }
1148 }
1149
getFunctionType(const FunctionDeclaration & function)1150 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
1151 Words words;
1152 words.push_back(Word::Result());
1153 words.push_back(this->getType(function.returnType()));
1154 for (const Variable* parameter : function.parameters()) {
1155 if (parameter->type().typeKind() == Type::TypeKind::kSampler &&
1156 fProgram.fConfig->fSettings.fSPIRVDawnCompatMode) {
1157 words.push_back(this->getFunctionParameterType(parameter->type().textureType()));
1158 words.push_back(this->getFunctionParameterType(*fContext.fTypes.fSampler));
1159 } else {
1160 words.push_back(this->getFunctionParameterType(parameter->type()));
1161 }
1162 }
1163 return this->writeInstruction(SpvOpTypeFunction, words, fConstantBuffer);
1164 }
1165
getFunctionParameterType(const Type & parameterType)1166 SpvId SPIRVCodeGenerator::getFunctionParameterType(const Type& parameterType) {
1167 // glslang treats all function arguments as pointers whether they need to be or
1168 // not. I was initially puzzled by this until I ran bizarre failures with certain
1169 // patterns of function calls and control constructs, as exemplified by this minimal
1170 // failure case:
1171 //
1172 // void sphere(float x) {
1173 // }
1174 //
1175 // void map() {
1176 // sphere(1.0);
1177 // }
1178 //
1179 // void main() {
1180 // for (int i = 0; i < 1; i++) {
1181 // map();
1182 // }
1183 // }
1184 //
1185 // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
1186 // crashes. Making it take a float* and storing the argument in a temporary variable,
1187 // as glslang does, fixes it.
1188 //
1189 // The consensus among shader compiler authors seems to be that GPU driver generally don't
1190 // handle value-based parameters consistently. It is highly likely that they fit their
1191 // implementations to conform to glslang. We take care to do so ourselves.
1192 //
1193 // Our implementation first stores every parameter value into a function storage-class pointer
1194 // before calling a function. The exception is for opaque handle types (samplers and textures)
1195 // which must be stored in a pointer with UniformConstant storage-class. This prevents
1196 // unnecessary temporaries (becuase opaque handles are always rooted in a pointer variable),
1197 // matches glslang's behavior, and translates into WGSL more easily when targeting Dawn.
1198 SpvStorageClass_ storageClass;
1199 if (parameterType.typeKind() == Type::TypeKind::kSampler ||
1200 parameterType.typeKind() == Type::TypeKind::kSeparateSampler ||
1201 parameterType.typeKind() == Type::TypeKind::kTexture) {
1202 storageClass = SpvStorageClassUniformConstant;
1203 } else {
1204 storageClass = SpvStorageClassFunction;
1205 }
1206 return this->getPointerType(parameterType, storageClass);
1207 }
1208
getPointerType(const Type & type,SpvStorageClass_ storageClass)1209 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
1210 return this->getPointerType(
1211 type, this->memoryLayoutForStorageClass(storageClass), storageClass);
1212 }
1213
getPointerType(const Type & type,const MemoryLayout & layout,SpvStorageClass_ storageClass)1214 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, const MemoryLayout& layout,
1215 SpvStorageClass_ storageClass) {
1216 return this->writeInstruction(
1217 SpvOpTypePointer,
1218 Words{Word::Result(), Word::Number(storageClass), this->getType(type, layout)},
1219 fConstantBuffer);
1220 }
1221
writeExpression(const Expression & expr,OutputStream & out)1222 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
1223 switch (expr.kind()) {
1224 case Expression::Kind::kBinary:
1225 return this->writeBinaryExpression(expr.as<BinaryExpression>(), out);
1226 case Expression::Kind::kConstructorArrayCast:
1227 return this->writeExpression(*expr.as<ConstructorArrayCast>().argument(), out);
1228 case Expression::Kind::kConstructorArray:
1229 case Expression::Kind::kConstructorStruct:
1230 return this->writeCompositeConstructor(expr.asAnyConstructor(), out);
1231 case Expression::Kind::kConstructorDiagonalMatrix:
1232 return this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(), out);
1233 case Expression::Kind::kConstructorMatrixResize:
1234 return this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(), out);
1235 case Expression::Kind::kConstructorScalarCast:
1236 return this->writeConstructorScalarCast(expr.as<ConstructorScalarCast>(), out);
1237 case Expression::Kind::kConstructorSplat:
1238 return this->writeConstructorSplat(expr.as<ConstructorSplat>(), out);
1239 case Expression::Kind::kConstructorCompound:
1240 return this->writeConstructorCompound(expr.as<ConstructorCompound>(), out);
1241 case Expression::Kind::kConstructorCompoundCast:
1242 return this->writeConstructorCompoundCast(expr.as<ConstructorCompoundCast>(), out);
1243 case Expression::Kind::kFieldAccess:
1244 return this->writeFieldAccess(expr.as<FieldAccess>(), out);
1245 case Expression::Kind::kFunctionCall:
1246 return this->writeFunctionCall(expr.as<FunctionCall>(), out);
1247 case Expression::Kind::kLiteral:
1248 return this->writeLiteral(expr.as<Literal>());
1249 case Expression::Kind::kPrefix:
1250 return this->writePrefixExpression(expr.as<PrefixExpression>(), out);
1251 case Expression::Kind::kPostfix:
1252 return this->writePostfixExpression(expr.as<PostfixExpression>(), out);
1253 case Expression::Kind::kSwizzle:
1254 return this->writeSwizzle(expr.as<Swizzle>(), out);
1255 case Expression::Kind::kVariableReference:
1256 return this->writeVariableReference(expr.as<VariableReference>(), out);
1257 case Expression::Kind::kTernary:
1258 return this->writeTernaryExpression(expr.as<TernaryExpression>(), out);
1259 case Expression::Kind::kIndex:
1260 return this->writeIndexExpression(expr.as<IndexExpression>(), out);
1261 case Expression::Kind::kSetting:
1262 return this->writeExpression(*expr.as<Setting>().toLiteral(fContext), out);
1263 default:
1264 SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
1265 break;
1266 }
1267 return NA;
1268 }
1269
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)1270 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
1271 const FunctionDeclaration& function = c.function();
1272 Intrinsic intrinsic = this->getIntrinsic(function.intrinsicKind());
1273 if (intrinsic.opKind == kInvalid_IntrinsicOpcodeKind) {
1274 fContext.fErrors->error(c.fPosition, "unsupported intrinsic '" + function.description() +
1275 "'");
1276 return NA;
1277 }
1278 const ExpressionArray& arguments = c.arguments();
1279 int32_t intrinsicId = intrinsic.floatOp;
1280 if (arguments.size() > 0) {
1281 const Type& type = arguments[0]->type();
1282 if (intrinsic.opKind == kSpecial_IntrinsicOpcodeKind) {
1283 // Keep the default float op.
1284 } else {
1285 intrinsicId = pick_by_type(type, intrinsic.floatOp, intrinsic.signedOp,
1286 intrinsic.unsignedOp, intrinsic.boolOp);
1287 }
1288 }
1289 switch (intrinsic.opKind) {
1290 case kGLSL_STD_450_IntrinsicOpcodeKind: {
1291 SpvId result = this->nextId(&c.type());
1292 SkTArray<SpvId> argumentIds;
1293 std::vector<TempVar> tempVars;
1294 argumentIds.reserve_back(arguments.size());
1295 for (int i = 0; i < arguments.size(); i++) {
1296 argumentIds.push_back(this->writeFunctionCallArgument(c, i, &tempVars, out));
1297 }
1298 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
1299 this->writeWord(this->getType(c.type()), out);
1300 this->writeWord(result, out);
1301 this->writeWord(fGLSLExtendedInstructions, out);
1302 this->writeWord(intrinsicId, out);
1303 for (SpvId id : argumentIds) {
1304 this->writeWord(id, out);
1305 }
1306 this->copyBackTempVars(tempVars, out);
1307 return result;
1308 }
1309 case kSPIRV_IntrinsicOpcodeKind: {
1310 // GLSL supports dot(float, float), but SPIR-V does not. Convert it to FMul
1311 if (intrinsicId == SpvOpDot && arguments[0]->type().isScalar()) {
1312 intrinsicId = SpvOpFMul;
1313 }
1314 SpvId result = this->nextId(&c.type());
1315 SkTArray<SpvId> argumentIds;
1316 std::vector<TempVar> tempVars;
1317 argumentIds.reserve_back(arguments.size());
1318 for (int i = 0; i < arguments.size(); i++) {
1319 argumentIds.push_back(this->writeFunctionCallArgument(c, i, &tempVars, out));
1320 }
1321 if (!c.type().isVoid()) {
1322 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
1323 this->writeWord(this->getType(c.type()), out);
1324 this->writeWord(result, out);
1325 } else {
1326 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
1327 }
1328 for (SpvId id : argumentIds) {
1329 this->writeWord(id, out);
1330 }
1331 this->copyBackTempVars(tempVars, out);
1332 return result;
1333 }
1334 case kSpecial_IntrinsicOpcodeKind:
1335 return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
1336 default:
1337 fContext.fErrors->error(c.fPosition, "unsupported intrinsic '" +
1338 function.description() + "'");
1339 return NA;
1340 }
1341 }
1342
vectorize(const Expression & arg,int vectorSize,OutputStream & out)1343 SpvId SPIRVCodeGenerator::vectorize(const Expression& arg, int vectorSize, OutputStream& out) {
1344 SkASSERT(vectorSize >= 1 && vectorSize <= 4);
1345 const Type& argType = arg.type();
1346 if (argType.isScalar() && vectorSize > 1) {
1347 ConstructorSplat splat{arg.fPosition,
1348 argType.toCompound(fContext, vectorSize, /*rows=*/1),
1349 arg.clone()};
1350 return this->writeConstructorSplat(splat, out);
1351 }
1352
1353 SkASSERT(vectorSize == argType.columns());
1354 return this->writeExpression(arg, out);
1355 }
1356
vectorize(const ExpressionArray & args,OutputStream & out)1357 SkTArray<SpvId> SPIRVCodeGenerator::vectorize(const ExpressionArray& args, OutputStream& out) {
1358 int vectorSize = 1;
1359 for (const auto& a : args) {
1360 if (a->type().isVector()) {
1361 if (vectorSize > 1) {
1362 SkASSERT(a->type().columns() == vectorSize);
1363 } else {
1364 vectorSize = a->type().columns();
1365 }
1366 }
1367 }
1368 SkTArray<SpvId> result;
1369 result.reserve_back(args.size());
1370 for (const auto& arg : args) {
1371 result.push_back(this->vectorize(*arg, vectorSize, out));
1372 }
1373 return result;
1374 }
1375
writeGLSLExtendedInstruction(const Type & type,SpvId id,SpvId floatInst,SpvId signedInst,SpvId unsignedInst,const SkTArray<SpvId> & args,OutputStream & out)1376 void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
1377 SpvId signedInst, SpvId unsignedInst,
1378 const SkTArray<SpvId>& args,
1379 OutputStream& out) {
1380 this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
1381 this->writeWord(this->getType(type), out);
1382 this->writeWord(id, out);
1383 this->writeWord(fGLSLExtendedInstructions, out);
1384 this->writeWord(pick_by_type(type, floatInst, signedInst, unsignedInst, NA), out);
1385 for (SpvId a : args) {
1386 this->writeWord(a, out);
1387 }
1388 }
1389
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)1390 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
1391 OutputStream& out) {
1392 const ExpressionArray& arguments = c.arguments();
1393 const Type& callType = c.type();
1394 SpvId result = this->nextId(nullptr);
1395 switch (kind) {
1396 case kAtan_SpecialIntrinsic: {
1397 SkSTArray<2, SpvId> argumentIds;
1398 for (const std::unique_ptr<Expression>& arg : arguments) {
1399 argumentIds.push_back(this->writeExpression(*arg, out));
1400 }
1401 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
1402 this->writeWord(this->getType(callType), out);
1403 this->writeWord(result, out);
1404 this->writeWord(fGLSLExtendedInstructions, out);
1405 this->writeWord(argumentIds.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
1406 for (SpvId id : argumentIds) {
1407 this->writeWord(id, out);
1408 }
1409 break;
1410 }
1411 case kSampledImage_SpecialIntrinsic: {
1412 SkASSERT(arguments.size() == 2);
1413 SpvId img = this->writeExpression(*arguments[0], out);
1414 SpvId sampler = this->writeExpression(*arguments[1], out);
1415 this->writeInstruction(SpvOpSampledImage,
1416 this->getType(callType),
1417 result,
1418 img,
1419 sampler,
1420 out);
1421 break;
1422 }
1423 case kSubpassLoad_SpecialIntrinsic: {
1424 SpvId img = this->writeExpression(*arguments[0], out);
1425 ExpressionArray args;
1426 args.reserve_back(2);
1427 args.push_back(Literal::MakeInt(fContext, Position(), /*value=*/0));
1428 args.push_back(Literal::MakeInt(fContext, Position(), /*value=*/0));
1429 ConstructorCompound ctor(Position(), *fContext.fTypes.fInt2, std::move(args));
1430 SpvId coords = this->writeExpression(ctor, out);
1431 if (arguments.size() == 1) {
1432 this->writeInstruction(SpvOpImageRead,
1433 this->getType(callType),
1434 result,
1435 img,
1436 coords,
1437 out);
1438 } else {
1439 SkASSERT(arguments.size() == 2);
1440 SpvId sample = this->writeExpression(*arguments[1], out);
1441 this->writeInstruction(SpvOpImageRead,
1442 this->getType(callType),
1443 result,
1444 img,
1445 coords,
1446 SpvImageOperandsSampleMask,
1447 sample,
1448 out);
1449 }
1450 break;
1451 }
1452 case kTexture_SpecialIntrinsic: {
1453 SpvOp_ op = SpvOpImageSampleImplicitLod;
1454 const Type& arg1Type = arguments[1]->type();
1455 switch (arguments[0]->type().dimensions()) {
1456 case SpvDim1D:
1457 if (arg1Type.matches(*fContext.fTypes.fFloat2)) {
1458 op = SpvOpImageSampleProjImplicitLod;
1459 } else {
1460 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat));
1461 }
1462 break;
1463 case SpvDim2D:
1464 if (arg1Type.matches(*fContext.fTypes.fFloat3)) {
1465 op = SpvOpImageSampleProjImplicitLod;
1466 } else {
1467 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat2));
1468 }
1469 break;
1470 case SpvDim3D:
1471 if (arg1Type.matches(*fContext.fTypes.fFloat4)) {
1472 op = SpvOpImageSampleProjImplicitLod;
1473 } else {
1474 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat3));
1475 }
1476 break;
1477 case SpvDimCube: // fall through
1478 case SpvDimRect: // fall through
1479 case SpvDimBuffer: // fall through
1480 case SpvDimSubpassData:
1481 break;
1482 }
1483 SpvId type = this->getType(callType);
1484 SpvId sampler = this->writeExpression(*arguments[0], out);
1485 SpvId uv = this->writeExpression(*arguments[1], out);
1486 if (arguments.size() == 3) {
1487 this->writeInstruction(op, type, result, sampler, uv,
1488 SpvImageOperandsBiasMask,
1489 this->writeExpression(*arguments[2], out),
1490 out);
1491 } else {
1492 SkASSERT(arguments.size() == 2);
1493 if (fProgram.fConfig->fSettings.fSharpenTextures) {
1494 SpvId lodBias = this->writeLiteral(kSharpenTexturesBias,
1495 *fContext.fTypes.fFloat);
1496 this->writeInstruction(op, type, result, sampler, uv,
1497 SpvImageOperandsBiasMask, lodBias, out);
1498 } else {
1499 this->writeInstruction(op, type, result, sampler, uv,
1500 out);
1501 }
1502 }
1503 break;
1504 }
1505 case kTextureGrad_SpecialIntrinsic: {
1506 SpvOp_ op = SpvOpImageSampleExplicitLod;
1507 SkASSERT(arguments.size() == 4);
1508 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
1509 SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fFloat2));
1510 SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fFloat2));
1511 SkASSERT(arguments[3]->type().matches(*fContext.fTypes.fFloat2));
1512 SpvId type = this->getType(callType);
1513 SpvId sampler = this->writeExpression(*arguments[0], out);
1514 SpvId uv = this->writeExpression(*arguments[1], out);
1515 SpvId dPdx = this->writeExpression(*arguments[2], out);
1516 SpvId dPdy = this->writeExpression(*arguments[3], out);
1517 this->writeInstruction(op, type, result, sampler, uv, SpvImageOperandsGradMask,
1518 dPdx, dPdy, out);
1519 break;
1520 }
1521 case kTextureLod_SpecialIntrinsic: {
1522 SpvOp_ op = SpvOpImageSampleExplicitLod;
1523 SkASSERT(arguments.size() == 3);
1524 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
1525 SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fFloat));
1526 const Type& arg1Type = arguments[1]->type();
1527 if (arg1Type.matches(*fContext.fTypes.fFloat3)) {
1528 op = SpvOpImageSampleProjExplicitLod;
1529 } else {
1530 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat2));
1531 }
1532 SpvId type = this->getType(callType);
1533 SpvId sampler = this->writeExpression(*arguments[0], out);
1534 SpvId uv = this->writeExpression(*arguments[1], out);
1535 this->writeInstruction(op, type, result, sampler, uv,
1536 SpvImageOperandsLodMask,
1537 this->writeExpression(*arguments[2], out),
1538 out);
1539 break;
1540 }
1541 case kMod_SpecialIntrinsic: {
1542 SkTArray<SpvId> args = this->vectorize(arguments, out);
1543 SkASSERT(args.size() == 2);
1544 const Type& operandType = arguments[0]->type();
1545 SpvOp_ op = pick_by_type(operandType, SpvOpFMod, SpvOpSMod, SpvOpUMod, SpvOpUndef);
1546 SkASSERT(op != SpvOpUndef);
1547 this->writeOpCode(op, 5, out);
1548 this->writeWord(this->getType(operandType), out);
1549 this->writeWord(result, out);
1550 this->writeWord(args[0], out);
1551 this->writeWord(args[1], out);
1552 break;
1553 }
1554 case kDFdy_SpecialIntrinsic: {
1555 SpvId fn = this->writeExpression(*arguments[0], out);
1556 this->writeOpCode(SpvOpDPdy, 4, out);
1557 this->writeWord(this->getType(callType), out);
1558 this->writeWord(result, out);
1559 this->writeWord(fn, out);
1560 if (!fProgram.fConfig->fSettings.fForceNoRTFlip) {
1561 this->addRTFlipUniform(c.fPosition);
1562 using namespace dsl;
1563 DSLExpression rtFlip(
1564 ThreadContext::Compiler().convertIdentifier(Position(), SKSL_RTFLIP_NAME));
1565 SpvId rtFlipY = this->vectorize(*rtFlip.y().release(), callType.columns(), out);
1566 SpvId flipped = this->nextId(&callType);
1567 this->writeInstruction(
1568 SpvOpFMul, this->getType(callType), flipped, result, rtFlipY, out);
1569 result = flipped;
1570 }
1571 break;
1572 }
1573 case kClamp_SpecialIntrinsic: {
1574 SkTArray<SpvId> args = this->vectorize(arguments, out);
1575 SkASSERT(args.size() == 3);
1576 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
1577 GLSLstd450UClamp, args, out);
1578 break;
1579 }
1580 case kMax_SpecialIntrinsic: {
1581 SkTArray<SpvId> args = this->vectorize(arguments, out);
1582 SkASSERT(args.size() == 2);
1583 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMax, GLSLstd450SMax,
1584 GLSLstd450UMax, args, out);
1585 break;
1586 }
1587 case kMin_SpecialIntrinsic: {
1588 SkTArray<SpvId> args = this->vectorize(arguments, out);
1589 SkASSERT(args.size() == 2);
1590 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMin, GLSLstd450SMin,
1591 GLSLstd450UMin, args, out);
1592 break;
1593 }
1594 case kMix_SpecialIntrinsic: {
1595 SkTArray<SpvId> args = this->vectorize(arguments, out);
1596 SkASSERT(args.size() == 3);
1597 if (arguments[2]->type().componentType().isBoolean()) {
1598 // Use OpSelect to implement Boolean mix().
1599 SpvId falseId = this->writeExpression(*arguments[0], out);
1600 SpvId trueId = this->writeExpression(*arguments[1], out);
1601 SpvId conditionId = this->writeExpression(*arguments[2], out);
1602 this->writeInstruction(SpvOpSelect, this->getType(arguments[0]->type()), result,
1603 conditionId, trueId, falseId, out);
1604 } else {
1605 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMix, SpvOpUndef,
1606 SpvOpUndef, args, out);
1607 }
1608 break;
1609 }
1610 case kSaturate_SpecialIntrinsic: {
1611 SkASSERT(arguments.size() == 1);
1612 ExpressionArray finalArgs;
1613 finalArgs.reserve_back(3);
1614 finalArgs.push_back(arguments[0]->clone());
1615 finalArgs.push_back(Literal::MakeFloat(fContext, Position(), /*value=*/0));
1616 finalArgs.push_back(Literal::MakeFloat(fContext, Position(), /*value=*/1));
1617 SkTArray<SpvId> spvArgs = this->vectorize(finalArgs, out);
1618 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
1619 GLSLstd450UClamp, spvArgs, out);
1620 break;
1621 }
1622 case kSmoothStep_SpecialIntrinsic: {
1623 SkTArray<SpvId> args = this->vectorize(arguments, out);
1624 SkASSERT(args.size() == 3);
1625 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450SmoothStep, SpvOpUndef,
1626 SpvOpUndef, args, out);
1627 break;
1628 }
1629 case kStep_SpecialIntrinsic: {
1630 SkTArray<SpvId> args = this->vectorize(arguments, out);
1631 SkASSERT(args.size() == 2);
1632 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450Step, SpvOpUndef,
1633 SpvOpUndef, args, out);
1634 break;
1635 }
1636 case kMatrixCompMult_SpecialIntrinsic: {
1637 SkASSERT(arguments.size() == 2);
1638 SpvId lhs = this->writeExpression(*arguments[0], out);
1639 SpvId rhs = this->writeExpression(*arguments[1], out);
1640 result = this->writeComponentwiseMatrixBinary(callType, lhs, rhs, SpvOpFMul, out);
1641 break;
1642 }
1643 }
1644 return result;
1645 }
1646
writeFunctionCallArgument(const FunctionCall & call,int argIndex,std::vector<TempVar> * tempVars,OutputStream & out,SpvId * outSynthesizedSamplerId)1647 SpvId SPIRVCodeGenerator::writeFunctionCallArgument(const FunctionCall& call,
1648 int argIndex,
1649 std::vector<TempVar>* tempVars,
1650 OutputStream& out,
1651 SpvId* outSynthesizedSamplerId) {
1652 const FunctionDeclaration& funcDecl = call.function();
1653 const Expression& arg = *call.arguments()[argIndex];
1654 const Modifiers& paramModifiers = funcDecl.parameters()[argIndex]->modifiers();
1655
1656 // ID of temporary variable that we will use to hold this argument, or 0 if it is being
1657 // passed directly
1658 SpvId tmpVar;
1659 // if we need a temporary var to store this argument, this is the value to store in the var
1660 SpvId tmpValueId = NA;
1661
1662 if (is_out(paramModifiers)) {
1663 std::unique_ptr<LValue> lv = this->getLValue(arg, out);
1664 // We handle out params with a temp var that we copy back to the original variable at the
1665 // end of the call. GLSL guarantees that the original variable will be unchanged until the
1666 // end of the call, and also that out params are written back to their original variables in
1667 // a specific order (left-to-right), so it's unsafe to pass a pointer to the original value.
1668 if (is_in(paramModifiers)) {
1669 tmpValueId = lv->load(out);
1670 }
1671 tmpVar = this->nextId(&arg.type());
1672 tempVars->push_back(TempVar{tmpVar, &arg.type(), std::move(lv)});
1673 } else if (funcDecl.isIntrinsic()) {
1674 // Unlike user function calls, non-out intrinsic arguments don't need pointer parameters.
1675 return this->writeExpression(arg, out);
1676 } else if (arg.is<VariableReference>() &&
1677 (arg.type().typeKind() == Type::TypeKind::kSampler ||
1678 arg.type().typeKind() == Type::TypeKind::kSeparateSampler ||
1679 arg.type().typeKind() == Type::TypeKind::kTexture)) {
1680 // Opaque handle (sampler/texture) arguments are always declared as pointers but never
1681 // stored in intermediates when calling user-defined functions.
1682 //
1683 // The case for intrinsics (which take opaque arguments by value) is handled above just like
1684 // regular pointers.
1685 //
1686 // See getFunctionParameterType for further explanation.
1687 const Variable* var = arg.as<VariableReference>().variable();
1688
1689 // In Dawn-mode the texture and sampler arguments are forwarded to the helper function.
1690 if (const auto* p = fSynthesizedSamplerMap.find(var)) {
1691 SkASSERT(fProgram.fConfig->fSettings.fSPIRVDawnCompatMode);
1692 SkASSERT(arg.type().typeKind() == Type::TypeKind::kSampler);
1693 SkASSERT(outSynthesizedSamplerId);
1694
1695 SpvId* img = fVariableMap.find((*p)->fTexture.get());
1696 SpvId* sampler = fVariableMap.find((*p)->fSampler.get());
1697 SkASSERT(img);
1698 SkASSERT(sampler);
1699
1700 *outSynthesizedSamplerId = *sampler;
1701 return *img;
1702 }
1703
1704 SpvId* entry = fVariableMap.find(var);
1705 SkASSERTF(entry, "%s", arg.description().c_str());
1706 return *entry;
1707 } else {
1708 // We always use pointer parameters when calling user functions.
1709 // See getFunctionParameterType for further explanation.
1710 tmpValueId = this->writeExpression(arg, out);
1711 tmpVar = this->nextId(nullptr);
1712 }
1713 this->writeInstruction(SpvOpVariable,
1714 this->getPointerType(arg.type(), SpvStorageClassFunction),
1715 tmpVar,
1716 SpvStorageClassFunction,
1717 fVariableBuffer);
1718 if (tmpValueId != NA) {
1719 this->writeOpStore(SpvStorageClassFunction, tmpVar, tmpValueId, out);
1720 }
1721 return tmpVar;
1722 }
1723
copyBackTempVars(const std::vector<TempVar> & tempVars,OutputStream & out)1724 void SPIRVCodeGenerator::copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out) {
1725 for (const TempVar& tempVar : tempVars) {
1726 SpvId load = this->nextId(tempVar.type);
1727 this->writeInstruction(SpvOpLoad, this->getType(*tempVar.type), load, tempVar.spvId, out);
1728 tempVar.lvalue->store(load, out);
1729 }
1730 }
1731
writeFunctionCall(const FunctionCall & c,OutputStream & out)1732 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
1733 const FunctionDeclaration& function = c.function();
1734 if (function.isIntrinsic() && !function.definition()) {
1735 return this->writeIntrinsicCall(c, out);
1736 }
1737 const ExpressionArray& arguments = c.arguments();
1738 SpvId* entry = fFunctionMap.find(&function);
1739 if (!entry) {
1740 fContext.fErrors->error(c.fPosition, "function '" + function.description() +
1741 "' is not defined");
1742 return NA;
1743 }
1744 // Temp variables are used to write back out-parameters after the function call is complete.
1745 std::vector<TempVar> tempVars;
1746 SkTArray<SpvId> argumentIds;
1747 argumentIds.reserve_back(arguments.size());
1748 for (int i = 0; i < arguments.size(); i++) {
1749 SpvId samplerId = NA;
1750 argumentIds.push_back(this->writeFunctionCallArgument(c, i, &tempVars, out, &samplerId));
1751 if (samplerId != NA) {
1752 argumentIds.push_back(samplerId);
1753 }
1754 }
1755 SpvId result = this->nextId(nullptr);
1756 this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t)argumentIds.size(), out);
1757 this->writeWord(this->getType(c.type()), out);
1758 this->writeWord(result, out);
1759 this->writeWord(*entry, out);
1760 for (SpvId id : argumentIds) {
1761 this->writeWord(id, out);
1762 }
1763 // Now that the call is complete, we copy temp out-variables back to their real lvalues.
1764 this->copyBackTempVars(tempVars, out);
1765 return result;
1766 }
1767
castScalarToType(SpvId inputExprId,const Type & inputType,const Type & outputType,OutputStream & out)1768 SpvId SPIRVCodeGenerator::castScalarToType(SpvId inputExprId,
1769 const Type& inputType,
1770 const Type& outputType,
1771 OutputStream& out) {
1772 if (outputType.isFloat()) {
1773 return this->castScalarToFloat(inputExprId, inputType, outputType, out);
1774 }
1775 if (outputType.isSigned()) {
1776 return this->castScalarToSignedInt(inputExprId, inputType, outputType, out);
1777 }
1778 if (outputType.isUnsigned()) {
1779 return this->castScalarToUnsignedInt(inputExprId, inputType, outputType, out);
1780 }
1781 if (outputType.isBoolean()) {
1782 return this->castScalarToBoolean(inputExprId, inputType, outputType, out);
1783 }
1784
1785 fContext.fErrors->error(Position(), "unsupported cast: " + inputType.description() + " to " +
1786 outputType.description());
1787 return inputExprId;
1788 }
1789
writeFloatConstructor(const AnyConstructor & c,OutputStream & out)1790 SpvId SPIRVCodeGenerator::writeFloatConstructor(const AnyConstructor& c, OutputStream& out) {
1791 SkASSERT(c.argumentSpan().size() == 1);
1792 SkASSERT(c.type().isFloat());
1793 const Expression& ctorExpr = *c.argumentSpan().front();
1794 SpvId expressionId = this->writeExpression(ctorExpr, out);
1795 return this->castScalarToFloat(expressionId, ctorExpr.type(), c.type(), out);
1796 }
1797
castScalarToFloat(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)1798 SpvId SPIRVCodeGenerator::castScalarToFloat(SpvId inputId, const Type& inputType,
1799 const Type& outputType, OutputStream& out) {
1800 // Casting a float to float is a no-op.
1801 if (inputType.isFloat()) {
1802 return inputId;
1803 }
1804
1805 // Given the input type, generate the appropriate instruction to cast to float.
1806 SpvId result = this->nextId(&outputType);
1807 if (inputType.isBoolean()) {
1808 // Use OpSelect to convert the boolean argument to a literal 1.0 or 0.0.
1809 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fFloat);
1810 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
1811 this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
1812 inputId, oneID, zeroID, out);
1813 } else if (inputType.isSigned()) {
1814 this->writeInstruction(SpvOpConvertSToF, this->getType(outputType), result, inputId, out);
1815 } else if (inputType.isUnsigned()) {
1816 this->writeInstruction(SpvOpConvertUToF, this->getType(outputType), result, inputId, out);
1817 } else {
1818 SkDEBUGFAILF("unsupported type for float typecast: %s", inputType.description().c_str());
1819 return NA;
1820 }
1821 return result;
1822 }
1823
writeIntConstructor(const AnyConstructor & c,OutputStream & out)1824 SpvId SPIRVCodeGenerator::writeIntConstructor(const AnyConstructor& c, OutputStream& out) {
1825 SkASSERT(c.argumentSpan().size() == 1);
1826 SkASSERT(c.type().isSigned());
1827 const Expression& ctorExpr = *c.argumentSpan().front();
1828 SpvId expressionId = this->writeExpression(ctorExpr, out);
1829 return this->castScalarToSignedInt(expressionId, ctorExpr.type(), c.type(), out);
1830 }
1831
castScalarToSignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)1832 SpvId SPIRVCodeGenerator::castScalarToSignedInt(SpvId inputId, const Type& inputType,
1833 const Type& outputType, OutputStream& out) {
1834 // Casting a signed int to signed int is a no-op.
1835 if (inputType.isSigned()) {
1836 return inputId;
1837 }
1838
1839 // Given the input type, generate the appropriate instruction to cast to signed int.
1840 SpvId result = this->nextId(&outputType);
1841 if (inputType.isBoolean()) {
1842 // Use OpSelect to convert the boolean argument to a literal 1 or 0.
1843 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fInt);
1844 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
1845 this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
1846 inputId, oneID, zeroID, out);
1847 } else if (inputType.isFloat()) {
1848 this->writeInstruction(SpvOpConvertFToS, this->getType(outputType), result, inputId, out);
1849 } else if (inputType.isUnsigned()) {
1850 this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
1851 } else {
1852 SkDEBUGFAILF("unsupported type for signed int typecast: %s",
1853 inputType.description().c_str());
1854 return NA;
1855 }
1856 return result;
1857 }
1858
writeUIntConstructor(const AnyConstructor & c,OutputStream & out)1859 SpvId SPIRVCodeGenerator::writeUIntConstructor(const AnyConstructor& c, OutputStream& out) {
1860 SkASSERT(c.argumentSpan().size() == 1);
1861 SkASSERT(c.type().isUnsigned());
1862 const Expression& ctorExpr = *c.argumentSpan().front();
1863 SpvId expressionId = this->writeExpression(ctorExpr, out);
1864 return this->castScalarToUnsignedInt(expressionId, ctorExpr.type(), c.type(), out);
1865 }
1866
castScalarToUnsignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)1867 SpvId SPIRVCodeGenerator::castScalarToUnsignedInt(SpvId inputId, const Type& inputType,
1868 const Type& outputType, OutputStream& out) {
1869 // Casting an unsigned int to unsigned int is a no-op.
1870 if (inputType.isUnsigned()) {
1871 return inputId;
1872 }
1873
1874 // Given the input type, generate the appropriate instruction to cast to unsigned int.
1875 SpvId result = this->nextId(&outputType);
1876 if (inputType.isBoolean()) {
1877 // Use OpSelect to convert the boolean argument to a literal 1u or 0u.
1878 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fUInt);
1879 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
1880 this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
1881 inputId, oneID, zeroID, out);
1882 } else if (inputType.isFloat()) {
1883 this->writeInstruction(SpvOpConvertFToU, this->getType(outputType), result, inputId, out);
1884 } else if (inputType.isSigned()) {
1885 this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
1886 } else {
1887 SkDEBUGFAILF("unsupported type for unsigned int typecast: %s",
1888 inputType.description().c_str());
1889 return NA;
1890 }
1891 return result;
1892 }
1893
writeBooleanConstructor(const AnyConstructor & c,OutputStream & out)1894 SpvId SPIRVCodeGenerator::writeBooleanConstructor(const AnyConstructor& c, OutputStream& out) {
1895 SkASSERT(c.argumentSpan().size() == 1);
1896 SkASSERT(c.type().isBoolean());
1897 const Expression& ctorExpr = *c.argumentSpan().front();
1898 SpvId expressionId = this->writeExpression(ctorExpr, out);
1899 return this->castScalarToBoolean(expressionId, ctorExpr.type(), c.type(), out);
1900 }
1901
castScalarToBoolean(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)1902 SpvId SPIRVCodeGenerator::castScalarToBoolean(SpvId inputId, const Type& inputType,
1903 const Type& outputType, OutputStream& out) {
1904 // Casting a bool to bool is a no-op.
1905 if (inputType.isBoolean()) {
1906 return inputId;
1907 }
1908
1909 // Given the input type, generate the appropriate instruction to cast to bool.
1910 SpvId result = this->nextId(nullptr);
1911 if (inputType.isSigned()) {
1912 // Synthesize a boolean result by comparing the input against a signed zero literal.
1913 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
1914 this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
1915 inputId, zeroID, out);
1916 } else if (inputType.isUnsigned()) {
1917 // Synthesize a boolean result by comparing the input against an unsigned zero literal.
1918 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
1919 this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
1920 inputId, zeroID, out);
1921 } else if (inputType.isFloat()) {
1922 // Synthesize a boolean result by comparing the input against a floating-point zero literal.
1923 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
1924 this->writeInstruction(SpvOpFUnordNotEqual, this->getType(outputType), result,
1925 inputId, zeroID, out);
1926 } else {
1927 SkDEBUGFAILF("unsupported type for boolean typecast: %s", inputType.description().c_str());
1928 return NA;
1929 }
1930 return result;
1931 }
1932
writeMatrixCopy(SpvId src,const Type & srcType,const Type & dstType,OutputStream & out)1933 SpvId SPIRVCodeGenerator::writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType,
1934 OutputStream& out) {
1935 SkASSERT(srcType.isMatrix());
1936 SkASSERT(dstType.isMatrix());
1937 SkASSERT(srcType.componentType().matches(dstType.componentType()));
1938 const Type& srcColumnType = srcType.componentType().toCompound(fContext, srcType.rows(), 1);
1939 const Type& dstColumnType = dstType.componentType().toCompound(fContext, dstType.rows(), 1);
1940 SkASSERT(dstType.componentType().isFloat());
1941 SpvId dstColumnTypeId = this->getType(dstColumnType);
1942 const SpvId zeroId = this->writeLiteral(0.0, dstType.componentType());
1943 const SpvId oneId = this->writeLiteral(1.0, dstType.componentType());
1944
1945 SkSTArray<4, SpvId> columns;
1946 for (int i = 0; i < dstType.columns(); i++) {
1947 if (i < srcType.columns()) {
1948 // we're still inside the src matrix, copy the column
1949 SpvId srcColumn = this->writeOpCompositeExtract(srcColumnType, src, i, out);
1950 SpvId dstColumn;
1951 if (srcType.rows() == dstType.rows()) {
1952 // columns are equal size, don't need to do anything
1953 dstColumn = srcColumn;
1954 }
1955 else if (dstType.rows() > srcType.rows()) {
1956 // dst column is bigger, need to zero-pad it
1957 SkSTArray<4, SpvId> values;
1958 values.push_back(srcColumn);
1959 for (int j = srcType.rows(); j < dstType.rows(); ++j) {
1960 values.push_back((i == j) ? oneId : zeroId);
1961 }
1962 dstColumn = this->writeOpCompositeConstruct(dstColumnType, values, out);
1963 }
1964 else {
1965 // dst column is smaller, need to swizzle the src column
1966 dstColumn = this->nextId(&dstType);
1967 this->writeOpCode(SpvOpVectorShuffle, 5 + dstType.rows(), out);
1968 this->writeWord(dstColumnTypeId, out);
1969 this->writeWord(dstColumn, out);
1970 this->writeWord(srcColumn, out);
1971 this->writeWord(srcColumn, out);
1972 for (int j = 0; j < dstType.rows(); j++) {
1973 this->writeWord(j, out);
1974 }
1975 }
1976 columns.push_back(dstColumn);
1977 } else {
1978 // we're past the end of the src matrix, need to synthesize an identity-matrix column
1979 SkSTArray<4, SpvId> values;
1980 for (int j = 0; j < dstType.rows(); ++j) {
1981 values.push_back((i == j) ? oneId : zeroId);
1982 }
1983 columns.push_back(this->writeOpCompositeConstruct(dstColumnType, values, out));
1984 }
1985 }
1986
1987 return this->writeOpCompositeConstruct(dstType, columns, out);
1988 }
1989
addColumnEntry(const Type & columnType,SkTArray<SpvId> * currentColumn,SkTArray<SpvId> * columnIds,int rows,SpvId entry,OutputStream & out)1990 void SPIRVCodeGenerator::addColumnEntry(const Type& columnType,
1991 SkTArray<SpvId>* currentColumn,
1992 SkTArray<SpvId>* columnIds,
1993 int rows,
1994 SpvId entry,
1995 OutputStream& out) {
1996 SkASSERT(currentColumn->size() < rows);
1997 currentColumn->push_back(entry);
1998 if (currentColumn->size() == rows) {
1999 // Synthesize this column into a vector.
2000 SpvId columnId = this->writeOpCompositeConstruct(columnType, *currentColumn, out);
2001 columnIds->push_back(columnId);
2002 currentColumn->clear();
2003 }
2004 }
2005
writeMatrixConstructor(const ConstructorCompound & c,OutputStream & out)2006 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out) {
2007 const Type& type = c.type();
2008 SkASSERT(type.isMatrix());
2009 SkASSERT(!c.arguments().empty());
2010 const Type& arg0Type = c.arguments()[0]->type();
2011 // go ahead and write the arguments so we don't try to write new instructions in the middle of
2012 // an instruction
2013 SkSTArray<16, SpvId> arguments;
2014 for (const std::unique_ptr<Expression>& arg : c.arguments()) {
2015 arguments.push_back(this->writeExpression(*arg, out));
2016 }
2017
2018 if (arguments.size() == 1 && arg0Type.isVector()) {
2019 // Special-case handling of float4 -> mat2x2.
2020 SkASSERT(type.rows() == 2 && type.columns() == 2);
2021 SkASSERT(arg0Type.columns() == 4);
2022 SpvId v[4];
2023 for (int i = 0; i < 4; ++i) {
2024 v[i] = this->writeOpCompositeExtract(type.componentType(), arguments[0], i, out);
2025 }
2026 const Type& vecType = type.componentType().toCompound(fContext, /*columns=*/2, /*rows=*/1);
2027 SpvId v0v1 = this->writeOpCompositeConstruct(vecType, {v[0], v[1]}, out);
2028 SpvId v2v3 = this->writeOpCompositeConstruct(vecType, {v[2], v[3]}, out);
2029 return this->writeOpCompositeConstruct(type, {v0v1, v2v3}, out);
2030 }
2031
2032 int rows = type.rows();
2033 const Type& columnType = type.componentType().toCompound(fContext,
2034 /*columns=*/rows, /*rows=*/1);
2035 // SpvIds of completed columns of the matrix.
2036 SkSTArray<4, SpvId> columnIds;
2037 // SpvIds of scalars we have written to the current column so far.
2038 SkSTArray<4, SpvId> currentColumn;
2039 for (int i = 0; i < arguments.size(); i++) {
2040 const Type& argType = c.arguments()[i]->type();
2041 if (currentColumn.empty() && argType.isVector() && argType.columns() == rows) {
2042 // This vector is a complete matrix column by itself and can be used as-is.
2043 columnIds.push_back(arguments[i]);
2044 } else if (argType.columns() == 1) {
2045 // This argument is a lone scalar and can be added to the current column as-is.
2046 this->addColumnEntry(columnType, ¤tColumn, &columnIds, rows, arguments[i], out);
2047 } else {
2048 // This argument needs to be decomposed into its constituent scalars.
2049 for (int j = 0; j < argType.columns(); ++j) {
2050 SpvId swizzle = this->writeOpCompositeExtract(argType.componentType(),
2051 arguments[i], j, out);
2052 this->addColumnEntry(columnType, ¤tColumn, &columnIds, rows, swizzle, out);
2053 }
2054 }
2055 }
2056 SkASSERT(columnIds.size() == type.columns());
2057 return this->writeOpCompositeConstruct(type, columnIds, out);
2058 }
2059
writeConstructorCompound(const ConstructorCompound & c,OutputStream & out)2060 SpvId SPIRVCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
2061 OutputStream& out) {
2062 return c.type().isMatrix() ? this->writeMatrixConstructor(c, out)
2063 : this->writeVectorConstructor(c, out);
2064 }
2065
writeVectorConstructor(const ConstructorCompound & c,OutputStream & out)2066 SpvId SPIRVCodeGenerator::writeVectorConstructor(const ConstructorCompound& c, OutputStream& out) {
2067 const Type& type = c.type();
2068 const Type& componentType = type.componentType();
2069 SkASSERT(type.isVector());
2070
2071 SkSTArray<4, SpvId> arguments;
2072 for (int i = 0; i < c.arguments().size(); i++) {
2073 const Type& argType = c.arguments()[i]->type();
2074 SkASSERT(componentType.numberKind() == argType.componentType().numberKind());
2075
2076 SpvId arg = this->writeExpression(*c.arguments()[i], out);
2077 if (argType.isMatrix()) {
2078 // CompositeConstruct cannot take a 2x2 matrix as an input, so we need to extract out
2079 // each scalar separately.
2080 SkASSERT(argType.rows() == 2);
2081 SkASSERT(argType.columns() == 2);
2082 for (int j = 0; j < 4; ++j) {
2083 arguments.push_back(this->writeOpCompositeExtract(componentType, arg,
2084 j / 2, j % 2, out));
2085 }
2086 } else if (argType.isVector()) {
2087 // There's a bug in the Intel Vulkan driver where OpCompositeConstruct doesn't handle
2088 // vector arguments at all, so we always extract each vector component and pass them
2089 // into OpCompositeConstruct individually.
2090 for (int j = 0; j < argType.columns(); j++) {
2091 arguments.push_back(this->writeOpCompositeExtract(componentType, arg, j, out));
2092 }
2093 } else {
2094 arguments.push_back(arg);
2095 }
2096 }
2097
2098 return this->writeOpCompositeConstruct(type, arguments, out);
2099 }
2100
writeConstructorSplat(const ConstructorSplat & c,OutputStream & out)2101 SpvId SPIRVCodeGenerator::writeConstructorSplat(const ConstructorSplat& c, OutputStream& out) {
2102 // Write the splat argument.
2103 SpvId argument = this->writeExpression(*c.argument(), out);
2104
2105 // Generate a OpCompositeConstruct which repeats the argument N times.
2106 SkSTArray<4, SpvId> values;
2107 values.push_back_n(/*n=*/c.type().columns(), /*t=*/argument);
2108 return this->writeOpCompositeConstruct(c.type(), values, out);
2109 }
2110
writeCompositeConstructor(const AnyConstructor & c,OutputStream & out)2111 SpvId SPIRVCodeGenerator::writeCompositeConstructor(const AnyConstructor& c, OutputStream& out) {
2112 SkASSERT(c.type().isArray() || c.type().isStruct());
2113 auto ctorArgs = c.argumentSpan();
2114
2115 SkSTArray<4, SpvId> arguments;
2116 for (const std::unique_ptr<Expression>& arg : ctorArgs) {
2117 arguments.push_back(this->writeExpression(*arg, out));
2118 }
2119
2120 return this->writeOpCompositeConstruct(c.type(), arguments, out);
2121 }
2122
writeConstructorScalarCast(const ConstructorScalarCast & c,OutputStream & out)2123 SpvId SPIRVCodeGenerator::writeConstructorScalarCast(const ConstructorScalarCast& c,
2124 OutputStream& out) {
2125 const Type& type = c.type();
2126 if (type.componentType().numberKind() == c.argument()->type().componentType().numberKind()) {
2127 return this->writeExpression(*c.argument(), out);
2128 }
2129
2130 const Expression& ctorExpr = *c.argument();
2131 SpvId expressionId = this->writeExpression(ctorExpr, out);
2132 return this->castScalarToType(expressionId, ctorExpr.type(), type, out);
2133 }
2134
writeConstructorCompoundCast(const ConstructorCompoundCast & c,OutputStream & out)2135 SpvId SPIRVCodeGenerator::writeConstructorCompoundCast(const ConstructorCompoundCast& c,
2136 OutputStream& out) {
2137 const Type& ctorType = c.type();
2138 const Type& argType = c.argument()->type();
2139 SkASSERT(ctorType.isVector() || ctorType.isMatrix());
2140
2141 // Write the composite that we are casting. If the actual type matches, we are done.
2142 SpvId compositeId = this->writeExpression(*c.argument(), out);
2143 if (ctorType.componentType().numberKind() == argType.componentType().numberKind()) {
2144 return compositeId;
2145 }
2146
2147 // writeMatrixCopy can cast matrices to a different type.
2148 if (ctorType.isMatrix()) {
2149 return this->writeMatrixCopy(compositeId, argType, ctorType, out);
2150 }
2151
2152 // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to extract the
2153 // components and convert each one manually.
2154 const Type& srcType = argType.componentType();
2155 const Type& dstType = ctorType.componentType();
2156
2157 SkSTArray<4, SpvId> arguments;
2158 for (int index = 0; index < argType.columns(); ++index) {
2159 SpvId componentId = this->writeOpCompositeExtract(srcType, compositeId, index, out);
2160 arguments.push_back(this->castScalarToType(componentId, srcType, dstType, out));
2161 }
2162
2163 return this->writeOpCompositeConstruct(ctorType, arguments, out);
2164 }
2165
writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix & c,OutputStream & out)2166 SpvId SPIRVCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
2167 OutputStream& out) {
2168 const Type& type = c.type();
2169 SkASSERT(type.isMatrix());
2170 SkASSERT(c.argument()->type().isScalar());
2171
2172 // Write out the scalar argument.
2173 SpvId diagonal = this->writeExpression(*c.argument(), out);
2174
2175 // Build the diagonal matrix.
2176 SpvId zeroId = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
2177
2178 const Type& vecType = type.componentType().toCompound(fContext,
2179 /*columns=*/type.rows(),
2180 /*rows=*/1);
2181 SkSTArray<4, SpvId> columnIds;
2182 SkSTArray<4, SpvId> arguments;
2183 arguments.resize(type.rows());
2184 for (int column = 0; column < type.columns(); column++) {
2185 for (int row = 0; row < type.rows(); row++) {
2186 arguments[row] = (row == column) ? diagonal : zeroId;
2187 }
2188 columnIds.push_back(this->writeOpCompositeConstruct(vecType, arguments, out));
2189 }
2190 return this->writeOpCompositeConstruct(type, columnIds, out);
2191 }
2192
writeConstructorMatrixResize(const ConstructorMatrixResize & c,OutputStream & out)2193 SpvId SPIRVCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
2194 OutputStream& out) {
2195 // Write the input matrix.
2196 SpvId argument = this->writeExpression(*c.argument(), out);
2197
2198 // Use matrix-copy to resize the input matrix to its new size.
2199 return this->writeMatrixCopy(argument, c.argument()->type(), c.type(), out);
2200 }
2201
get_storage_class_for_global_variable(const Variable & var,SpvStorageClass_ fallbackStorageClass)2202 static SpvStorageClass_ get_storage_class_for_global_variable(
2203 const Variable& var, SpvStorageClass_ fallbackStorageClass) {
2204 SkASSERT(var.storage() == Variable::Storage::kGlobal);
2205
2206 const Modifiers& modifiers = var.modifiers();
2207 if (modifiers.fFlags & Modifiers::kIn_Flag) {
2208 SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
2209 return SpvStorageClassInput;
2210 }
2211 if (modifiers.fFlags & Modifiers::kOut_Flag) {
2212 SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
2213 return SpvStorageClassOutput;
2214 }
2215 if (modifiers.fFlags & Modifiers::kUniform_Flag) {
2216 if (modifiers.fLayout.fFlags & Layout::kPushConstant_Flag) {
2217 return SpvStorageClassPushConstant;
2218 }
2219 if (var.type().typeKind() == Type::TypeKind::kSampler ||
2220 var.type().typeKind() == Type::TypeKind::kSeparateSampler ||
2221 var.type().typeKind() == Type::TypeKind::kTexture) {
2222 return SpvStorageClassUniformConstant;
2223 }
2224 return SpvStorageClassUniform;
2225 }
2226 return fallbackStorageClass;
2227 }
2228
get_storage_class(const Expression & expr)2229 static SpvStorageClass_ get_storage_class(const Expression& expr) {
2230 switch (expr.kind()) {
2231 case Expression::Kind::kVariableReference: {
2232 const Variable& var = *expr.as<VariableReference>().variable();
2233 if (var.storage() != Variable::Storage::kGlobal) {
2234 return SpvStorageClassFunction;
2235 }
2236 return get_storage_class_for_global_variable(var, SpvStorageClassPrivate);
2237 }
2238 case Expression::Kind::kFieldAccess:
2239 return get_storage_class(*expr.as<FieldAccess>().base());
2240 case Expression::Kind::kIndex:
2241 return get_storage_class(*expr.as<IndexExpression>().base());
2242 default:
2243 return SpvStorageClassFunction;
2244 }
2245 }
2246
getAccessChain(const Expression & expr,OutputStream & out)2247 SkTArray<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
2248 switch (expr.kind()) {
2249 case Expression::Kind::kIndex: {
2250 const IndexExpression& indexExpr = expr.as<IndexExpression>();
2251 SkTArray<SpvId> chain = this->getAccessChain(*indexExpr.base(), out);
2252 chain.push_back(this->writeExpression(*indexExpr.index(), out));
2253 return chain;
2254 }
2255 case Expression::Kind::kFieldAccess: {
2256 const FieldAccess& fieldExpr = expr.as<FieldAccess>();
2257 SkTArray<SpvId> chain = this->getAccessChain(*fieldExpr.base(), out);
2258 chain.push_back(this->writeLiteral(fieldExpr.fieldIndex(), *fContext.fTypes.fInt));
2259 return chain;
2260 }
2261 default: {
2262 SpvId id = this->getLValue(expr, out)->getPointer();
2263 SkASSERT(id != NA);
2264 return SkTArray<SpvId>{id};
2265 }
2266 }
2267 SkUNREACHABLE;
2268 }
2269
2270 class PointerLValue : public SPIRVCodeGenerator::LValue {
2271 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,bool isMemoryObject,SpvId type,SPIRVCodeGenerator::Precision precision,SpvStorageClass_ storageClass)2272 PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, bool isMemoryObject, SpvId type,
2273 SPIRVCodeGenerator::Precision precision, SpvStorageClass_ storageClass)
2274 : fGen(gen)
2275 , fPointer(pointer)
2276 , fIsMemoryObject(isMemoryObject)
2277 , fType(type)
2278 , fPrecision(precision)
2279 , fStorageClass(storageClass) {}
2280
getPointer()2281 SpvId getPointer() override {
2282 return fPointer;
2283 }
2284
isMemoryObjectPointer() const2285 bool isMemoryObjectPointer() const override {
2286 return fIsMemoryObject;
2287 }
2288
load(OutputStream & out)2289 SpvId load(OutputStream& out) override {
2290 return fGen.writeOpLoad(fType, fPrecision, fPointer, out);
2291 }
2292
store(SpvId value,OutputStream & out)2293 void store(SpvId value, OutputStream& out) override {
2294 if (!fIsMemoryObject) {
2295 // We are going to write into an access chain; this could represent one component of a
2296 // vector, or one element of an array. This has the potential to invalidate other,
2297 // *unknown* elements of our store cache. (e.g. if the store cache holds `%50 = myVec4`,
2298 // and we store `%60 = myVec4.z`, this invalidates the cached value for %50.) To avoid
2299 // relying on stale data, reset the store cache entirely when this happens.
2300 fGen.fStoreCache.reset();
2301 }
2302
2303 fGen.writeOpStore(fStorageClass, fPointer, value, out);
2304 }
2305
2306 private:
2307 SPIRVCodeGenerator& fGen;
2308 const SpvId fPointer;
2309 const bool fIsMemoryObject;
2310 const SpvId fType;
2311 const SPIRVCodeGenerator::Precision fPrecision;
2312 const SpvStorageClass_ fStorageClass;
2313 };
2314
2315 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
2316 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const ComponentArray & components,const Type & baseType,const Type & swizzleType,SpvStorageClass_ storageClass)2317 SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const ComponentArray& components,
2318 const Type& baseType, const Type& swizzleType, SpvStorageClass_ storageClass)
2319 : fGen(gen)
2320 , fVecPointer(vecPointer)
2321 , fComponents(components)
2322 , fBaseType(&baseType)
2323 , fSwizzleType(&swizzleType)
2324 , fStorageClass(storageClass) {}
2325
applySwizzle(const ComponentArray & components,const Type & newType)2326 bool applySwizzle(const ComponentArray& components, const Type& newType) override {
2327 ComponentArray updatedSwizzle;
2328 for (int8_t component : components) {
2329 if (component < 0 || component >= fComponents.size()) {
2330 SkDEBUGFAILF("swizzle accessed nonexistent component %d", (int)component);
2331 return false;
2332 }
2333 updatedSwizzle.push_back(fComponents[component]);
2334 }
2335 fComponents = updatedSwizzle;
2336 fSwizzleType = &newType;
2337 return true;
2338 }
2339
load(OutputStream & out)2340 SpvId load(OutputStream& out) override {
2341 SpvId base = fGen.nextId(fBaseType);
2342 fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
2343 SpvId result = fGen.nextId(fBaseType);
2344 fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
2345 fGen.writeWord(fGen.getType(*fSwizzleType), out);
2346 fGen.writeWord(result, out);
2347 fGen.writeWord(base, out);
2348 fGen.writeWord(base, out);
2349 for (int component : fComponents) {
2350 fGen.writeWord(component, out);
2351 }
2352 return result;
2353 }
2354
store(SpvId value,OutputStream & out)2355 void store(SpvId value, OutputStream& out) override {
2356 // use OpVectorShuffle to mix and match the vector components. We effectively create
2357 // a virtual vector out of the concatenation of the left and right vectors, and then
2358 // select components from this virtual vector to make the result vector. For
2359 // instance, given:
2360 // float3L = ...;
2361 // float3R = ...;
2362 // L.xz = R.xy;
2363 // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
2364 // our result vector to look like (R.x, L.y, R.y), so we need to select indices
2365 // (3, 1, 4).
2366 SpvId base = fGen.nextId(fBaseType);
2367 fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
2368 SpvId shuffle = fGen.nextId(fBaseType);
2369 fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType->columns(), out);
2370 fGen.writeWord(fGen.getType(*fBaseType), out);
2371 fGen.writeWord(shuffle, out);
2372 fGen.writeWord(base, out);
2373 fGen.writeWord(value, out);
2374 for (int i = 0; i < fBaseType->columns(); i++) {
2375 // current offset into the virtual vector, defaults to pulling the unmodified
2376 // value from the left side
2377 int offset = i;
2378 // check to see if we are writing this component
2379 for (int j = 0; j < fComponents.size(); j++) {
2380 if (fComponents[j] == i) {
2381 // we're writing to this component, so adjust the offset to pull from
2382 // the correct component of the right side instead of preserving the
2383 // value from the left
2384 offset = (int) (j + fBaseType->columns());
2385 break;
2386 }
2387 }
2388 fGen.writeWord(offset, out);
2389 }
2390 fGen.writeOpStore(fStorageClass, fVecPointer, shuffle, out);
2391 }
2392
2393 private:
2394 SPIRVCodeGenerator& fGen;
2395 const SpvId fVecPointer;
2396 ComponentArray fComponents;
2397 const Type* fBaseType;
2398 const Type* fSwizzleType;
2399 const SpvStorageClass_ fStorageClass;
2400 };
2401
findUniformFieldIndex(const Variable & var) const2402 int SPIRVCodeGenerator::findUniformFieldIndex(const Variable& var) const {
2403 int* fieldIndex = fTopLevelUniformMap.find(&var);
2404 return fieldIndex ? *fieldIndex : -1;
2405 }
2406
getLValue(const Expression & expr,OutputStream & out)2407 std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
2408 OutputStream& out) {
2409 const Type& type = expr.type();
2410 Precision precision = type.highPrecision() ? Precision::kDefault : Precision::kRelaxed;
2411 switch (expr.kind()) {
2412 case Expression::Kind::kVariableReference: {
2413 const Variable& var = *expr.as<VariableReference>().variable();
2414 int uniformIdx = this->findUniformFieldIndex(var);
2415 if (uniformIdx >= 0) {
2416 SpvId memberId = this->nextId(nullptr);
2417 SpvId typeId = this->getPointerType(type, SpvStorageClassUniform);
2418 SpvId uniformIdxId = this->writeLiteral((double)uniformIdx, *fContext.fTypes.fInt);
2419 this->writeInstruction(SpvOpAccessChain, typeId, memberId, fUniformBufferId,
2420 uniformIdxId, out);
2421 return std::make_unique<PointerLValue>(
2422 *this,
2423 memberId,
2424 /*isMemoryObjectPointer=*/true,
2425 this->getType(type, this->memoryLayoutForVariable(var)),
2426 precision,
2427 SpvStorageClassUniform);
2428 }
2429 SpvId typeId = this->getType(type, this->memoryLayoutForVariable(var));
2430 SpvId* entry = fVariableMap.find(&var);
2431 SkASSERTF(entry, "%s", expr.description().c_str());
2432 return std::make_unique<PointerLValue>(*this, *entry,
2433 /*isMemoryObjectPointer=*/true,
2434 typeId, precision, get_storage_class(expr));
2435 }
2436 case Expression::Kind::kIndex: // fall through
2437 case Expression::Kind::kFieldAccess: {
2438 SkTArray<SpvId> chain = this->getAccessChain(expr, out);
2439 SpvId member = this->nextId(nullptr);
2440 SpvStorageClass_ storageClass = get_storage_class(expr);
2441 this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
2442 this->writeWord(this->getPointerType(type, storageClass), out);
2443 this->writeWord(member, out);
2444 for (SpvId idx : chain) {
2445 this->writeWord(idx, out);
2446 }
2447 return std::make_unique<PointerLValue>(
2448 *this,
2449 member,
2450 /*isMemoryObjectPointer=*/false,
2451 this->getType(type, this->memoryLayoutForStorageClass(storageClass)),
2452 precision,
2453 storageClass);
2454 }
2455 case Expression::Kind::kSwizzle: {
2456 const Swizzle& swizzle = expr.as<Swizzle>();
2457 std::unique_ptr<LValue> lvalue = this->getLValue(*swizzle.base(), out);
2458 if (lvalue->applySwizzle(swizzle.components(), type)) {
2459 return lvalue;
2460 }
2461 SpvId base = lvalue->getPointer();
2462 if (base == NA) {
2463 fContext.fErrors->error(swizzle.fPosition,
2464 "unable to retrieve lvalue from swizzle");
2465 }
2466 SpvStorageClass_ storageClass = get_storage_class(*swizzle.base());
2467 if (swizzle.components().size() == 1) {
2468 SpvId member = this->nextId(nullptr);
2469 SpvId typeId = this->getPointerType(type, storageClass);
2470 SpvId indexId = this->writeLiteral(swizzle.components()[0], *fContext.fTypes.fInt);
2471 this->writeInstruction(SpvOpAccessChain, typeId, member, base, indexId, out);
2472 return std::make_unique<PointerLValue>(*this, member,
2473 /*isMemoryObjectPointer=*/false,
2474 this->getType(type),
2475 precision, storageClass);
2476 } else {
2477 return std::make_unique<SwizzleLValue>(*this, base, swizzle.components(),
2478 swizzle.base()->type(), type, storageClass);
2479 }
2480 }
2481 default: {
2482 // expr isn't actually an lvalue, create a placeholder variable for it. This case
2483 // happens due to the need to store values in temporary variables during function
2484 // calls (see comments in getFunctionParameterType); erroneous uses of rvalues as
2485 // lvalues should have been caught before code generation.
2486 //
2487 // This is with the exception of opaque handle types (textures/samplers) which are
2488 // always defined as UniformConstant pointers and don't need to be explicitly stored
2489 // into a temporary (which is handled explicitly in writeFunctionCallArgument).
2490 SpvId result = this->nextId(nullptr);
2491 SpvId pointerType = this->getPointerType(type, SpvStorageClassFunction);
2492 this->writeInstruction(SpvOpVariable, pointerType, result, SpvStorageClassFunction,
2493 fVariableBuffer);
2494 this->writeOpStore(SpvStorageClassFunction, result, this->writeExpression(expr, out),
2495 out);
2496 return std::make_unique<PointerLValue>(*this, result, /*isMemoryObjectPointer=*/true,
2497 this->getType(type), precision,
2498 SpvStorageClassFunction);
2499 }
2500 }
2501 }
2502
writeVariableReference(const VariableReference & ref,OutputStream & out)2503 SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
2504 const Variable* variable = ref.variable();
2505 switch (variable->modifiers().fLayout.fBuiltin) {
2506 case DEVICE_FRAGCOORDS_BUILTIN: {
2507 // Down below, we rewrite raw references to sk_FragCoord with expressions that reference
2508 // DEVICE_FRAGCOORDS_BUILTIN. This is a fake variable that means we need to directly
2509 // access the fragcoord; do so now.
2510 dsl::DSLGlobalVar fragCoord("sk_FragCoord");
2511 return this->getLValue(*dsl::DSLExpression(fragCoord).release(), out)->load(out);
2512 }
2513 case DEVICE_CLOCKWISE_BUILTIN: {
2514 // Down below, we rewrite raw references to sk_Clockwise with expressions that reference
2515 // DEVICE_CLOCKWISE_BUILTIN. This is a fake variable that means we need to directly
2516 // access front facing; do so now.
2517 dsl::DSLGlobalVar clockwise("sk_Clockwise");
2518 return this->getLValue(*dsl::DSLExpression(clockwise).release(), out)->load(out);
2519 }
2520 case SK_SECONDARYFRAGCOLOR_BUILTIN: {
2521 // sk_SecondaryFragColor corresponds to gl_SecondaryFragColorEXT, which isn't supposed
2522 // to appear in a SPIR-V program (it's only valid in ES2). Report an error.
2523 fContext.fErrors->error(ref.fPosition,
2524 "sk_SecondaryFragColor is not allowed in SPIR-V");
2525 return NA;
2526 }
2527 case SK_FRAGCOORD_BUILTIN: {
2528 if (fProgram.fConfig->fSettings.fForceNoRTFlip) {
2529 dsl::DSLGlobalVar fragCoord("sk_FragCoord");
2530 return this->getLValue(*dsl::DSLExpression(fragCoord).release(), out)->load(out);
2531 }
2532
2533 // Handle inserting use of uniform to flip y when referencing sk_FragCoord.
2534 this->addRTFlipUniform(ref.fPosition);
2535 // Use sk_RTAdjust to compute the flipped coordinate
2536 using namespace dsl;
2537 const char* DEVICE_COORDS_NAME = "$device_FragCoords";
2538 SymbolTable& symbols = *ThreadContext::SymbolTable();
2539 // Use a uniform to flip the Y coordinate. The new expression will be written in
2540 // terms of $device_FragCoords, which is a fake variable that means "access the
2541 // underlying fragcoords directly without flipping it".
2542 DSLExpression rtFlip(ThreadContext::Compiler().convertIdentifier(Position(),
2543 SKSL_RTFLIP_NAME));
2544 if (!symbols.find(DEVICE_COORDS_NAME)) {
2545 AutoAttachPoolToThread attach(fProgram.fPool.get());
2546 Modifiers modifiers;
2547 modifiers.fLayout.fBuiltin = DEVICE_FRAGCOORDS_BUILTIN;
2548 auto coordsVar = std::make_unique<Variable>(/*pos=*/Position(),
2549 /*modifiersPosition=*/Position(),
2550 fContext.fModifiersPool->add(modifiers),
2551 DEVICE_COORDS_NAME,
2552 fContext.fTypes.fFloat4.get(),
2553 /*builtin=*/true,
2554 Variable::Storage::kGlobal);
2555 fSPIRVBonusVariables.add(coordsVar.get());
2556 symbols.add(std::move(coordsVar));
2557 }
2558 DSLGlobalVar deviceCoord(DEVICE_COORDS_NAME);
2559 std::unique_ptr<Expression> rtFlipSkSLExpr = rtFlip.release();
2560 DSLExpression x = DSLExpression(rtFlipSkSLExpr->clone()).x();
2561 DSLExpression y = DSLExpression(std::move(rtFlipSkSLExpr)).y();
2562 return this->writeExpression(*dsl::Float4(deviceCoord.x(),
2563 std::move(x) + std::move(y) * deviceCoord.y(),
2564 deviceCoord.z(),
2565 deviceCoord.w()).release(),
2566 out);
2567 }
2568 case SK_CLOCKWISE_BUILTIN: {
2569 if (fProgram.fConfig->fSettings.fForceNoRTFlip) {
2570 dsl::DSLGlobalVar clockwise("sk_Clockwise");
2571 return this->getLValue(*dsl::DSLExpression(clockwise).release(), out)->load(out);
2572 }
2573
2574 // Handle flipping sk_Clockwise.
2575 this->addRTFlipUniform(ref.fPosition);
2576 using namespace dsl;
2577 const char* DEVICE_CLOCKWISE_NAME = "$device_Clockwise";
2578 SymbolTable& symbols = *ThreadContext::SymbolTable();
2579 // Use a uniform to flip the Y coordinate. The new expression will be written in
2580 // terms of $device_Clockwise, which is a fake variable that means "access the
2581 // underlying FrontFacing directly".
2582 DSLExpression rtFlip(ThreadContext::Compiler().convertIdentifier(Position(),
2583 SKSL_RTFLIP_NAME));
2584 if (!symbols.find(DEVICE_CLOCKWISE_NAME)) {
2585 AutoAttachPoolToThread attach(fProgram.fPool.get());
2586 Modifiers modifiers;
2587 modifiers.fLayout.fBuiltin = DEVICE_CLOCKWISE_BUILTIN;
2588 auto clockwiseVar = std::make_unique<Variable>(/*pos=*/Position(),
2589 /*modifiersPosition=*/Position(),
2590 fContext.fModifiersPool->add(modifiers),
2591 DEVICE_CLOCKWISE_NAME,
2592 fContext.fTypes.fBool.get(),
2593 /*builtin=*/true,
2594 Variable::Storage::kGlobal);
2595 fSPIRVBonusVariables.add(clockwiseVar.get());
2596 symbols.add(std::move(clockwiseVar));
2597 }
2598 DSLGlobalVar deviceClockwise(DEVICE_CLOCKWISE_NAME);
2599 // FrontFacing in Vulkan is defined in terms of a top-down render target. In skia,
2600 // we use the default convention of "counter-clockwise face is front".
2601 return this->writeExpression(*dsl::Bool(Select(rtFlip.y() > 0,
2602 !deviceClockwise,
2603 deviceClockwise)).release(),
2604 out);
2605 }
2606 default: {
2607 // Constant-propagate variables that have a known compile-time value.
2608 if (const Expression* expr = ConstantFolder::GetConstantValueOrNullForVariable(ref)) {
2609 return this->writeExpression(*expr, out);
2610 }
2611
2612 // A reference to a sampler variable at global scope with synthesized texture/sampler
2613 // backing should construct a function-scope combined image-sampler from the synthesized
2614 // constituents. This is the case in which a sample intrinsic was invoked.
2615 //
2616 // Variable references to opaque handles (texture/sampler) that appear as the argument
2617 // of a user-defined function call are explicitly handled in writeFunctionCallArgument.
2618 if (const auto* p = fSynthesizedSamplerMap.find(variable)) {
2619 SkASSERT(fProgram.fConfig->fSettings.fSPIRVDawnCompatMode);
2620
2621 SpvId* imgPtr = fVariableMap.find((*p)->fTexture.get());
2622 SpvId* samplerPtr = fVariableMap.find((*p)->fSampler.get());
2623 SkASSERT(imgPtr);
2624 SkASSERT(samplerPtr);
2625
2626 SpvId img = this->writeOpLoad(
2627 this->getType((*p)->fTexture->type()), Precision::kDefault, *imgPtr, out);
2628 SpvId sampler = this->writeOpLoad(this->getType((*p)->fSampler->type()),
2629 Precision::kDefault,
2630 *samplerPtr,
2631 out);
2632
2633 SpvId result = this->nextId(nullptr);
2634 this->writeInstruction(SpvOpSampledImage,
2635 this->getType(variable->type()),
2636 result,
2637 img,
2638 sampler,
2639 out);
2640
2641 return result;
2642 }
2643
2644 return this->getLValue(ref, out)->load(out);
2645 }
2646 }
2647 }
2648
writeIndexExpression(const IndexExpression & expr,OutputStream & out)2649 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
2650 if (expr.base()->type().isVector()) {
2651 SpvId base = this->writeExpression(*expr.base(), out);
2652 SpvId index = this->writeExpression(*expr.index(), out);
2653 SpvId result = this->nextId(nullptr);
2654 this->writeInstruction(SpvOpVectorExtractDynamic, this->getType(expr.type()), result, base,
2655 index, out);
2656 return result;
2657 }
2658 return getLValue(expr, out)->load(out);
2659 }
2660
writeFieldAccess(const FieldAccess & f,OutputStream & out)2661 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
2662 return getLValue(f, out)->load(out);
2663 }
2664
writeSwizzle(const Swizzle & swizzle,OutputStream & out)2665 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
2666 SpvId base = this->writeExpression(*swizzle.base(), out);
2667 size_t count = swizzle.components().size();
2668 if (count == 1) {
2669 return this->writeOpCompositeExtract(swizzle.type(), base, swizzle.components()[0], out);
2670 }
2671
2672 SpvId result = this->nextId(&swizzle.type());
2673 this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
2674 this->writeWord(this->getType(swizzle.type()), out);
2675 this->writeWord(result, out);
2676 this->writeWord(base, out);
2677 this->writeWord(base, out);
2678 for (int component : swizzle.components()) {
2679 this->writeWord(component, out);
2680 }
2681 return result;
2682 }
2683
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)2684 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
2685 const Type& operandType, SpvId lhs,
2686 SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
2687 SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
2688 SpvId result = this->nextId(&resultType);
2689 SpvOp_ op = pick_by_type(operandType, ifFloat, ifInt, ifUInt, ifBool);
2690 if (op == SpvOpUndef) {
2691 fContext.fErrors->error(operandType.fPosition,
2692 "unsupported operand for binary expression: " + operandType.description());
2693 return NA;
2694 }
2695 this->writeInstruction(op, this->getType(resultType), result, lhs, rhs, out);
2696 return result;
2697 }
2698
foldToBool(SpvId id,const Type & operandType,SpvOp op,OutputStream & out)2699 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op,
2700 OutputStream& out) {
2701 if (operandType.isVector()) {
2702 SpvId result = this->nextId(nullptr);
2703 this->writeInstruction(op, this->getType(*fContext.fTypes.fBool), result, id, out);
2704 return result;
2705 }
2706 return id;
2707 }
2708
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,SpvOp_ vectorMergeOperator,SpvOp_ mergeOperator,OutputStream & out)2709 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
2710 SpvOp_ floatOperator, SpvOp_ intOperator,
2711 SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
2712 OutputStream& out) {
2713 SpvOp_ compareOp = is_float(operandType) ? floatOperator : intOperator;
2714 SkASSERT(operandType.isMatrix());
2715 const Type& columnType = operandType.componentType().toCompound(fContext,
2716 operandType.rows(),
2717 1);
2718 SpvId bvecType = this->getType(fContext.fTypes.fBool->toCompound(fContext,
2719 operandType.rows(),
2720 1));
2721 SpvId boolType = this->getType(*fContext.fTypes.fBool);
2722 SpvId result = 0;
2723 for (int i = 0; i < operandType.columns(); i++) {
2724 SpvId columnL = this->writeOpCompositeExtract(columnType, lhs, i, out);
2725 SpvId columnR = this->writeOpCompositeExtract(columnType, rhs, i, out);
2726 SpvId compare = this->nextId(&operandType);
2727 this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
2728 SpvId merge = this->nextId(nullptr);
2729 this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
2730 if (result != 0) {
2731 SpvId next = this->nextId(nullptr);
2732 this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
2733 result = next;
2734 } else {
2735 result = merge;
2736 }
2737 }
2738 return result;
2739 }
2740
writeComponentwiseMatrixUnary(const Type & operandType,SpvId operand,SpvOp_ op,OutputStream & out)2741 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixUnary(const Type& operandType,
2742 SpvId operand,
2743 SpvOp_ op,
2744 OutputStream& out) {
2745 SkASSERT(operandType.isMatrix());
2746 const Type& columnType = operandType.componentType().toCompound(fContext,
2747 /*columns=*/operandType.rows(),
2748 /*rows=*/1);
2749 SpvId columnTypeId = this->getType(columnType);
2750
2751 SkSTArray<4, SpvId> columns;
2752 for (int i = 0; i < operandType.columns(); i++) {
2753 SpvId srcColumn = this->writeOpCompositeExtract(columnType, operand, i, out);
2754 SpvId dstColumn = this->nextId(&operandType);
2755 this->writeInstruction(op, columnTypeId, dstColumn, srcColumn, out);
2756 columns.push_back(dstColumn);
2757 }
2758
2759 return this->writeOpCompositeConstruct(operandType, columns, out);
2760 }
2761
writeComponentwiseMatrixBinary(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ op,OutputStream & out)2762 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
2763 SpvId rhs, SpvOp_ op, OutputStream& out) {
2764 SkASSERT(operandType.isMatrix());
2765 const Type& columnType = operandType.componentType().toCompound(fContext,
2766 /*columns=*/operandType.rows(),
2767 /*rows=*/1);
2768 SpvId columnTypeId = this->getType(columnType);
2769
2770 SkSTArray<4, SpvId> columns;
2771 for (int i = 0; i < operandType.columns(); i++) {
2772 SpvId columnL = this->writeOpCompositeExtract(columnType, lhs, i, out);
2773 SpvId columnR = this->writeOpCompositeExtract(columnType, rhs, i, out);
2774 columns.push_back(this->nextId(&operandType));
2775 this->writeInstruction(op, columnTypeId, columns[i], columnL, columnR, out);
2776 }
2777 return this->writeOpCompositeConstruct(operandType, columns, out);
2778 }
2779
writeReciprocal(const Type & type,SpvId value,OutputStream & out)2780 SpvId SPIRVCodeGenerator::writeReciprocal(const Type& type, SpvId value, OutputStream& out) {
2781 SkASSERT(type.isFloat());
2782 SpvId one = this->writeLiteral(1.0, type);
2783 SpvId reciprocal = this->nextId(&type);
2784 this->writeInstruction(SpvOpFDiv, this->getType(type), reciprocal, one, value, out);
2785 return reciprocal;
2786 }
2787
writeScalarToMatrixSplat(const Type & matrixType,SpvId scalarId,OutputStream & out)2788 SpvId SPIRVCodeGenerator::writeScalarToMatrixSplat(const Type& matrixType,
2789 SpvId scalarId,
2790 OutputStream& out) {
2791 // Splat the scalar into a vector.
2792 const Type& vectorType = matrixType.componentType().toCompound(fContext,
2793 /*columns=*/matrixType.rows(),
2794 /*rows=*/1);
2795 SkSTArray<4, SpvId> vecArguments;
2796 vecArguments.push_back_n(/*n=*/matrixType.rows(), /*t=*/scalarId);
2797 SpvId vectorId = this->writeOpCompositeConstruct(vectorType, vecArguments, out);
2798
2799 // Splat the vector into a matrix.
2800 SkSTArray<4, SpvId> matArguments;
2801 matArguments.push_back_n(/*n=*/matrixType.columns(), /*t=*/vectorId);
2802 return this->writeOpCompositeConstruct(matrixType, matArguments, out);
2803 }
2804
types_match(const Type & a,const Type & b)2805 static bool types_match(const Type& a, const Type& b) {
2806 if (a.matches(b)) {
2807 return true;
2808 }
2809 return (a.typeKind() == b.typeKind()) &&
2810 (a.isScalar() || a.isVector() || a.isMatrix()) &&
2811 (a.columns() == b.columns() && a.rows() == b.rows()) &&
2812 a.componentType().numberKind() == b.componentType().numberKind();
2813 }
2814
writeBinaryExpression(const Type & leftType,SpvId lhs,Operator op,const Type & rightType,SpvId rhs,const Type & resultType,OutputStream & out)2815 SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
2816 const Type& rightType, SpvId rhs,
2817 const Type& resultType, OutputStream& out) {
2818 // The comma operator ignores the type of the left-hand side entirely.
2819 if (op.kind() == Operator::Kind::COMMA) {
2820 return rhs;
2821 }
2822 // overall type we are operating on: float2, int, uint4...
2823 const Type* operandType;
2824 if (types_match(leftType, rightType)) {
2825 operandType = &leftType;
2826 } else {
2827 // IR allows mismatched types in expressions (e.g. float2 * float), but they need special
2828 // handling in SPIR-V
2829 if (leftType.isVector() && rightType.isNumber()) {
2830 if (resultType.componentType().isFloat()) {
2831 switch (op.kind()) {
2832 case Operator::Kind::SLASH: {
2833 rhs = this->writeReciprocal(rightType, rhs, out);
2834 [[fallthrough]];
2835 }
2836 case Operator::Kind::STAR: {
2837 SpvId result = this->nextId(&resultType);
2838 this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
2839 result, lhs, rhs, out);
2840 return result;
2841 }
2842 default:
2843 break;
2844 }
2845 }
2846 // Vectorize the right-hand side.
2847 SkSTArray<4, SpvId> arguments;
2848 arguments.push_back_n(/*n=*/leftType.columns(), /*t=*/rhs);
2849 rhs = this->writeOpCompositeConstruct(leftType, arguments, out);
2850 operandType = &leftType;
2851 } else if (rightType.isVector() && leftType.isNumber()) {
2852 if (resultType.componentType().isFloat()) {
2853 if (op.kind() == Operator::Kind::STAR) {
2854 SpvId result = this->nextId(&resultType);
2855 this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
2856 result, rhs, lhs, out);
2857 return result;
2858 }
2859 }
2860 // Vectorize the left-hand side.
2861 SkSTArray<4, SpvId> arguments;
2862 arguments.push_back_n(/*n=*/rightType.columns(), /*t=*/lhs);
2863 lhs = this->writeOpCompositeConstruct(rightType, arguments, out);
2864 operandType = &rightType;
2865 } else if (leftType.isMatrix()) {
2866 if (op.kind() == Operator::Kind::STAR) {
2867 // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
2868 SpvOp_ spvop;
2869 if (rightType.isMatrix()) {
2870 spvop = SpvOpMatrixTimesMatrix;
2871 } else if (rightType.isVector()) {
2872 spvop = SpvOpMatrixTimesVector;
2873 } else {
2874 SkASSERT(rightType.isScalar());
2875 spvop = SpvOpMatrixTimesScalar;
2876 }
2877 SpvId result = this->nextId(&resultType);
2878 this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
2879 return result;
2880 } else {
2881 // Matrix-op-vector is not supported in GLSL/SkSL for non-multiplication ops; we
2882 // expect to have a scalar here.
2883 SkASSERT(rightType.isScalar());
2884
2885 // Splat rhs across an entire matrix so we can reuse the matrix-op-matrix path.
2886 SpvId rhsMatrix = this->writeScalarToMatrixSplat(leftType, rhs, out);
2887
2888 // Perform this operation as matrix-op-matrix.
2889 return this->writeBinaryExpression(leftType, lhs, op, leftType, rhsMatrix,
2890 resultType, out);
2891 }
2892 } else if (rightType.isMatrix()) {
2893 if (op.kind() == Operator::Kind::STAR) {
2894 // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
2895 SpvId result = this->nextId(&resultType);
2896 if (leftType.isVector()) {
2897 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType),
2898 result, lhs, rhs, out);
2899 } else {
2900 SkASSERT(leftType.isScalar());
2901 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType),
2902 result, rhs, lhs, out);
2903 }
2904 return result;
2905 } else {
2906 // Vector-op-matrix is not supported in GLSL/SkSL for non-multiplication ops; we
2907 // expect to have a scalar here.
2908 SkASSERT(leftType.isScalar());
2909
2910 // Splat lhs across an entire matrix so we can reuse the matrix-op-matrix path.
2911 SpvId lhsMatrix = this->writeScalarToMatrixSplat(rightType, lhs, out);
2912
2913 // Perform this operation as matrix-op-matrix.
2914 return this->writeBinaryExpression(rightType, lhsMatrix, op, rightType, rhs,
2915 resultType, out);
2916 }
2917 } else {
2918 fContext.fErrors->error(leftType.fPosition, "unsupported mixed-type expression");
2919 return NA;
2920 }
2921 }
2922
2923 switch (op.kind()) {
2924 case Operator::Kind::EQEQ: {
2925 if (operandType->isMatrix()) {
2926 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
2927 SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
2928 }
2929 if (operandType->isStruct()) {
2930 return this->writeStructComparison(*operandType, lhs, op, rhs, out);
2931 }
2932 if (operandType->isArray()) {
2933 return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
2934 }
2935 SkASSERT(resultType.isBoolean());
2936 const Type* tmpType;
2937 if (operandType->isVector()) {
2938 tmpType = &fContext.fTypes.fBool->toCompound(fContext,
2939 operandType->columns(),
2940 operandType->rows());
2941 } else {
2942 tmpType = &resultType;
2943 }
2944 if (lhs == rhs) {
2945 // This ignores the effects of NaN.
2946 return this->writeOpConstantTrue(*fContext.fTypes.fBool);
2947 }
2948 return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
2949 SpvOpFOrdEqual, SpvOpIEqual,
2950 SpvOpIEqual, SpvOpLogicalEqual, out),
2951 *operandType, SpvOpAll, out);
2952 }
2953 case Operator::Kind::NEQ:
2954 if (operandType->isMatrix()) {
2955 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFUnordNotEqual,
2956 SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
2957 }
2958 if (operandType->isStruct()) {
2959 return this->writeStructComparison(*operandType, lhs, op, rhs, out);
2960 }
2961 if (operandType->isArray()) {
2962 return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
2963 }
2964 [[fallthrough]];
2965 case Operator::Kind::LOGICALXOR:
2966 SkASSERT(resultType.isBoolean());
2967 const Type* tmpType;
2968 if (operandType->isVector()) {
2969 tmpType = &fContext.fTypes.fBool->toCompound(fContext,
2970 operandType->columns(),
2971 operandType->rows());
2972 } else {
2973 tmpType = &resultType;
2974 }
2975 if (lhs == rhs) {
2976 // This ignores the effects of NaN.
2977 return this->writeOpConstantFalse(*fContext.fTypes.fBool);
2978 }
2979 return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
2980 SpvOpFUnordNotEqual, SpvOpINotEqual,
2981 SpvOpINotEqual, SpvOpLogicalNotEqual,
2982 out),
2983 *operandType, SpvOpAny, out);
2984 case Operator::Kind::GT:
2985 SkASSERT(resultType.isBoolean());
2986 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2987 SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
2988 SpvOpUGreaterThan, SpvOpUndef, out);
2989 case Operator::Kind::LT:
2990 SkASSERT(resultType.isBoolean());
2991 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
2992 SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
2993 case Operator::Kind::GTEQ:
2994 SkASSERT(resultType.isBoolean());
2995 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2996 SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
2997 SpvOpUGreaterThanEqual, SpvOpUndef, out);
2998 case Operator::Kind::LTEQ:
2999 SkASSERT(resultType.isBoolean());
3000 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
3001 SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
3002 SpvOpULessThanEqual, SpvOpUndef, out);
3003 case Operator::Kind::PLUS:
3004 if (leftType.isMatrix() && rightType.isMatrix()) {
3005 SkASSERT(leftType.matches(rightType));
3006 return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs, SpvOpFAdd, out);
3007 }
3008 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
3009 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
3010 case Operator::Kind::MINUS:
3011 if (leftType.isMatrix() && rightType.isMatrix()) {
3012 SkASSERT(leftType.matches(rightType));
3013 return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs, SpvOpFSub, out);
3014 }
3015 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
3016 SpvOpISub, SpvOpISub, SpvOpUndef, out);
3017 case Operator::Kind::STAR:
3018 if (leftType.isMatrix() && rightType.isMatrix()) {
3019 // matrix multiply
3020 SpvId result = this->nextId(&resultType);
3021 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
3022 lhs, rhs, out);
3023 return result;
3024 }
3025 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
3026 SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
3027 case Operator::Kind::SLASH:
3028 if (leftType.isMatrix() && rightType.isMatrix()) {
3029 SkASSERT(leftType.matches(rightType));
3030 return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs, SpvOpFDiv, out);
3031 }
3032 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
3033 SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
3034 case Operator::Kind::PERCENT:
3035 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
3036 SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
3037 case Operator::Kind::SHL:
3038 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
3039 SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
3040 SpvOpUndef, out);
3041 case Operator::Kind::SHR:
3042 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
3043 SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
3044 SpvOpUndef, out);
3045 case Operator::Kind::BITWISEAND:
3046 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
3047 SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
3048 case Operator::Kind::BITWISEOR:
3049 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
3050 SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
3051 case Operator::Kind::BITWISEXOR:
3052 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
3053 SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
3054 default:
3055 fContext.fErrors->error(Position(), "unsupported token");
3056 return NA;
3057 }
3058 }
3059
writeArrayComparison(const Type & arrayType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)3060 SpvId SPIRVCodeGenerator::writeArrayComparison(const Type& arrayType, SpvId lhs, Operator op,
3061 SpvId rhs, OutputStream& out) {
3062 // The inputs must be arrays, and the op must be == or !=.
3063 SkASSERT(op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ);
3064 SkASSERT(arrayType.isArray());
3065 const Type& componentType = arrayType.componentType();
3066 const int arraySize = arrayType.columns();
3067 SkASSERT(arraySize > 0);
3068
3069 // Synthesize equality checks for each item in the array.
3070 const Type& boolType = *fContext.fTypes.fBool;
3071 SpvId allComparisons = NA;
3072 for (int index = 0; index < arraySize; ++index) {
3073 // Get the left and right item in the array.
3074 SpvId itemL = this->writeOpCompositeExtract(componentType, lhs, index, out);
3075 SpvId itemR = this->writeOpCompositeExtract(componentType, rhs, index, out);
3076 // Use `writeBinaryExpression` with the requested == or != operator on these items.
3077 SpvId comparison = this->writeBinaryExpression(componentType, itemL, op,
3078 componentType, itemR, boolType, out);
3079 // Merge this comparison result with all the other comparisons we've done.
3080 allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
3081 }
3082 return allComparisons;
3083 }
3084
writeStructComparison(const Type & structType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)3085 SpvId SPIRVCodeGenerator::writeStructComparison(const Type& structType, SpvId lhs, Operator op,
3086 SpvId rhs, OutputStream& out) {
3087 // The inputs must be structs containing fields, and the op must be == or !=.
3088 SkASSERT(op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ);
3089 SkASSERT(structType.isStruct());
3090 const std::vector<Type::Field>& fields = structType.fields();
3091 SkASSERT(!fields.empty());
3092
3093 // Synthesize equality checks for each field in the struct.
3094 const Type& boolType = *fContext.fTypes.fBool;
3095 SpvId allComparisons = NA;
3096 for (int index = 0; index < (int)fields.size(); ++index) {
3097 // Get the left and right versions of this field.
3098 const Type& fieldType = *fields[index].fType;
3099
3100 SpvId fieldL = this->writeOpCompositeExtract(fieldType, lhs, index, out);
3101 SpvId fieldR = this->writeOpCompositeExtract(fieldType, rhs, index, out);
3102 // Use `writeBinaryExpression` with the requested == or != operator on these fields.
3103 SpvId comparison = this->writeBinaryExpression(fieldType, fieldL, op, fieldType, fieldR,
3104 boolType, out);
3105 // Merge this comparison result with all the other comparisons we've done.
3106 allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
3107 }
3108 return allComparisons;
3109 }
3110
mergeComparisons(SpvId comparison,SpvId allComparisons,Operator op,OutputStream & out)3111 SpvId SPIRVCodeGenerator::mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op,
3112 OutputStream& out) {
3113 // If this is the first entry, we don't need to merge comparison results with anything.
3114 if (allComparisons == NA) {
3115 return comparison;
3116 }
3117 // Use LogicalAnd or LogicalOr to combine the comparison with all the other comparisons.
3118 const Type& boolType = *fContext.fTypes.fBool;
3119 SpvId boolTypeId = this->getType(boolType);
3120 SpvId logicalOp = this->nextId(&boolType);
3121 switch (op.kind()) {
3122 case Operator::Kind::EQEQ:
3123 this->writeInstruction(SpvOpLogicalAnd, boolTypeId, logicalOp,
3124 comparison, allComparisons, out);
3125 break;
3126 case Operator::Kind::NEQ:
3127 this->writeInstruction(SpvOpLogicalOr, boolTypeId, logicalOp,
3128 comparison, allComparisons, out);
3129 break;
3130 default:
3131 SkDEBUGFAILF("mergeComparisons only supports == and !=, not %s", op.operatorName());
3132 return NA;
3133 }
3134 return logicalOp;
3135 }
3136
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)3137 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
3138 const Expression* left = b.left().get();
3139 const Expression* right = b.right().get();
3140 Operator op = b.getOperator();
3141
3142 switch (op.kind()) {
3143 case Operator::Kind::EQ: {
3144 // Handles assignment.
3145 SpvId rhs = this->writeExpression(*right, out);
3146 this->getLValue(*left, out)->store(rhs, out);
3147 return rhs;
3148 }
3149 case Operator::Kind::LOGICALAND:
3150 // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
3151 return this->writeLogicalAnd(*b.left(), *b.right(), out);
3152
3153 case Operator::Kind::LOGICALOR:
3154 // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
3155 return this->writeLogicalOr(*b.left(), *b.right(), out);
3156
3157 default:
3158 break;
3159 }
3160
3161 std::unique_ptr<LValue> lvalue;
3162 SpvId lhs;
3163 if (op.isAssignment()) {
3164 lvalue = this->getLValue(*left, out);
3165 lhs = lvalue->load(out);
3166 } else {
3167 lvalue = nullptr;
3168 lhs = this->writeExpression(*left, out);
3169 }
3170
3171 SpvId rhs = this->writeExpression(*right, out);
3172 SpvId result = this->writeBinaryExpression(left->type(), lhs, op.removeAssignment(),
3173 right->type(), rhs, b.type(), out);
3174 if (lvalue) {
3175 lvalue->store(result, out);
3176 }
3177 return result;
3178 }
3179
writeLogicalAnd(const Expression & left,const Expression & right,OutputStream & out)3180 SpvId SPIRVCodeGenerator::writeLogicalAnd(const Expression& left, const Expression& right,
3181 OutputStream& out) {
3182 SpvId falseConstant = this->writeLiteral(0.0, *fContext.fTypes.fBool);
3183 SpvId lhs = this->writeExpression(left, out);
3184
3185 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
3186
3187 SpvId rhsLabel = this->nextId(nullptr);
3188 SpvId end = this->nextId(nullptr);
3189 SpvId lhsBlock = fCurrentBlock;
3190 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
3191 this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
3192 this->writeLabel(rhsLabel, kBranchIsOnPreviousLine, out);
3193 SpvId rhs = this->writeExpression(right, out);
3194 SpvId rhsBlock = fCurrentBlock;
3195 this->writeInstruction(SpvOpBranch, end, out);
3196 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
3197 SpvId result = this->nextId(nullptr);
3198 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, falseConstant,
3199 lhsBlock, rhs, rhsBlock, out);
3200
3201 return result;
3202 }
3203
writeLogicalOr(const Expression & left,const Expression & right,OutputStream & out)3204 SpvId SPIRVCodeGenerator::writeLogicalOr(const Expression& left, const Expression& right,
3205 OutputStream& out) {
3206 SpvId trueConstant = this->writeLiteral(1.0, *fContext.fTypes.fBool);
3207 SpvId lhs = this->writeExpression(left, out);
3208
3209 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
3210
3211 SpvId rhsLabel = this->nextId(nullptr);
3212 SpvId end = this->nextId(nullptr);
3213 SpvId lhsBlock = fCurrentBlock;
3214 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
3215 this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
3216 this->writeLabel(rhsLabel, kBranchIsOnPreviousLine, out);
3217 SpvId rhs = this->writeExpression(right, out);
3218 SpvId rhsBlock = fCurrentBlock;
3219 this->writeInstruction(SpvOpBranch, end, out);
3220 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
3221 SpvId result = this->nextId(nullptr);
3222 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, trueConstant,
3223 lhsBlock, rhs, rhsBlock, out);
3224
3225 return result;
3226 }
3227
writeTernaryExpression(const TernaryExpression & t,OutputStream & out)3228 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
3229 const Type& type = t.type();
3230 SpvId test = this->writeExpression(*t.test(), out);
3231 if (t.ifTrue()->type().columns() == 1 &&
3232 Analysis::IsCompileTimeConstant(*t.ifTrue()) &&
3233 Analysis::IsCompileTimeConstant(*t.ifFalse())) {
3234 // both true and false are constants, can just use OpSelect
3235 SpvId result = this->nextId(nullptr);
3236 SpvId trueId = this->writeExpression(*t.ifTrue(), out);
3237 SpvId falseId = this->writeExpression(*t.ifFalse(), out);
3238 this->writeInstruction(SpvOpSelect, this->getType(type), result, test, trueId, falseId,
3239 out);
3240 return result;
3241 }
3242
3243 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
3244
3245 // was originally using OpPhi to choose the result, but for some reason that is crashing on
3246 // Adreno. Switched to storing the result in a temp variable as glslang does.
3247 SpvId var = this->nextId(nullptr);
3248 this->writeInstruction(SpvOpVariable, this->getPointerType(type, SpvStorageClassFunction),
3249 var, SpvStorageClassFunction, fVariableBuffer);
3250 SpvId trueLabel = this->nextId(nullptr);
3251 SpvId falseLabel = this->nextId(nullptr);
3252 SpvId end = this->nextId(nullptr);
3253 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
3254 this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
3255 this->writeLabel(trueLabel, kBranchIsOnPreviousLine, out);
3256 this->writeOpStore(SpvStorageClassFunction, var, this->writeExpression(*t.ifTrue(), out), out);
3257 this->writeInstruction(SpvOpBranch, end, out);
3258 this->writeLabel(falseLabel, kBranchIsAbove, conditionalOps, out);
3259 this->writeOpStore(SpvStorageClassFunction, var, this->writeExpression(*t.ifFalse(), out), out);
3260 this->writeInstruction(SpvOpBranch, end, out);
3261 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
3262 SpvId result = this->nextId(&type);
3263 this->writeInstruction(SpvOpLoad, this->getType(type), result, var, out);
3264
3265 return result;
3266 }
3267
writePrefixExpression(const PrefixExpression & p,OutputStream & out)3268 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
3269 const Type& type = p.type();
3270 if (p.getOperator().kind() == Operator::Kind::MINUS) {
3271 SpvOp_ negateOp = pick_by_type(type, SpvOpFNegate, SpvOpSNegate, SpvOpSNegate, SpvOpUndef);
3272 SkASSERT(negateOp != SpvOpUndef);
3273 SpvId expr = this->writeExpression(*p.operand(), out);
3274 if (type.isMatrix()) {
3275 return this->writeComponentwiseMatrixUnary(type, expr, negateOp, out);
3276 }
3277 SpvId result = this->nextId(&type);
3278 SpvId typeId = this->getType(type);
3279 this->writeInstruction(negateOp, typeId, result, expr, out);
3280 return result;
3281 }
3282 switch (p.getOperator().kind()) {
3283 case Operator::Kind::PLUS:
3284 return this->writeExpression(*p.operand(), out);
3285 case Operator::Kind::PLUSPLUS: {
3286 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
3287 SpvId one = this->writeLiteral(1.0, type);
3288 SpvId result = this->writeBinaryOperation(type, type, lv->load(out), one,
3289 SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
3290 out);
3291 lv->store(result, out);
3292 return result;
3293 }
3294 case Operator::Kind::MINUSMINUS: {
3295 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
3296 SpvId one = this->writeLiteral(1.0, type);
3297 SpvId result = this->writeBinaryOperation(type, type, lv->load(out), one, SpvOpFSub,
3298 SpvOpISub, SpvOpISub, SpvOpUndef, out);
3299 lv->store(result, out);
3300 return result;
3301 }
3302 case Operator::Kind::LOGICALNOT: {
3303 SkASSERT(p.operand()->type().isBoolean());
3304 SpvId result = this->nextId(nullptr);
3305 this->writeInstruction(SpvOpLogicalNot, this->getType(type), result,
3306 this->writeExpression(*p.operand(), out), out);
3307 return result;
3308 }
3309 case Operator::Kind::BITWISENOT: {
3310 SpvId result = this->nextId(nullptr);
3311 this->writeInstruction(SpvOpNot, this->getType(type), result,
3312 this->writeExpression(*p.operand(), out), out);
3313 return result;
3314 }
3315 default:
3316 SkDEBUGFAILF("unsupported prefix expression: %s",
3317 p.description(OperatorPrecedence::kTopLevel).c_str());
3318 return NA;
3319 }
3320 }
3321
writePostfixExpression(const PostfixExpression & p,OutputStream & out)3322 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
3323 const Type& type = p.type();
3324 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
3325 SpvId result = lv->load(out);
3326 SpvId one = this->writeLiteral(1.0, type);
3327 switch (p.getOperator().kind()) {
3328 case Operator::Kind::PLUSPLUS: {
3329 SpvId temp = this->writeBinaryOperation(type, type, result, one, SpvOpFAdd,
3330 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
3331 lv->store(temp, out);
3332 return result;
3333 }
3334 case Operator::Kind::MINUSMINUS: {
3335 SpvId temp = this->writeBinaryOperation(type, type, result, one, SpvOpFSub,
3336 SpvOpISub, SpvOpISub, SpvOpUndef, out);
3337 lv->store(temp, out);
3338 return result;
3339 }
3340 default:
3341 SkDEBUGFAILF("unsupported postfix expression %s",
3342 p.description(OperatorPrecedence::kTopLevel).c_str());
3343 return NA;
3344 }
3345 }
3346
writeLiteral(const Literal & l)3347 SpvId SPIRVCodeGenerator::writeLiteral(const Literal& l) {
3348 return this->writeLiteral(l.value(), l.type());
3349 }
3350
writeLiteral(double value,const Type & type)3351 SpvId SPIRVCodeGenerator::writeLiteral(double value, const Type& type) {
3352 switch (type.numberKind()) {
3353 case Type::NumberKind::kFloat: {
3354 float floatVal = value;
3355 int32_t valueBits;
3356 memcpy(&valueBits, &floatVal, sizeof(valueBits));
3357 return this->writeOpConstant(type, valueBits);
3358 }
3359 case Type::NumberKind::kBoolean: {
3360 return value ? this->writeOpConstantTrue(type)
3361 : this->writeOpConstantFalse(type);
3362 }
3363 default: {
3364 return this->writeOpConstant(type, (SKSL_INT)value);
3365 }
3366 }
3367 }
3368
writeFunctionStart(const FunctionDeclaration & f,OutputStream & out)3369 SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
3370 SpvId result = fFunctionMap[&f];
3371 SpvId returnTypeId = this->getType(f.returnType());
3372 SpvId functionTypeId = this->getFunctionType(f);
3373 this->writeInstruction(SpvOpFunction, returnTypeId, result,
3374 SpvFunctionControlMaskNone, functionTypeId, out);
3375 std::string mangledName = f.mangledName();
3376 this->writeInstruction(SpvOpName,
3377 result,
3378 std::string_view(mangledName.c_str(), mangledName.size()),
3379 fNameBuffer);
3380 for (const Variable* parameter : f.parameters()) {
3381 if (parameter->type().typeKind() == Type::TypeKind::kSampler &&
3382 fProgram.fConfig->fSettings.fSPIRVDawnCompatMode) {
3383 auto [texture, sampler] = this->synthesizeTextureAndSampler(*parameter);
3384
3385 SpvId textureId = this->nextId(nullptr);
3386 SpvId samplerId = this->nextId(nullptr);
3387 fVariableMap.set(texture, textureId);
3388 fVariableMap.set(sampler, samplerId);
3389
3390 SpvId textureType = this->getFunctionParameterType(texture->type());
3391 SpvId samplerType = this->getFunctionParameterType(sampler->type());
3392
3393 this->writeInstruction(SpvOpFunctionParameter, textureType, textureId, out);
3394 this->writeInstruction(SpvOpFunctionParameter, samplerType, samplerId, out);
3395 } else {
3396 SpvId id = this->nextId(nullptr);
3397 fVariableMap.set(parameter, id);
3398
3399 SpvId type = this->getFunctionParameterType(parameter->type());
3400 this->writeInstruction(SpvOpFunctionParameter, type, id, out);
3401 }
3402 }
3403 return result;
3404 }
3405
writeFunction(const FunctionDefinition & f,OutputStream & out)3406 SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
3407 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
3408
3409 fVariableBuffer.reset();
3410 SpvId result = this->writeFunctionStart(f.declaration(), out);
3411 fCurrentBlock = 0;
3412 this->writeLabel(this->nextId(nullptr), kBranchlessBlock, out);
3413 StringStream bodyBuffer;
3414 this->writeBlock(f.body()->as<Block>(), bodyBuffer);
3415 write_stringstream(fVariableBuffer, out);
3416 if (f.declaration().isMain()) {
3417 write_stringstream(fGlobalInitializersBuffer, out);
3418 }
3419 write_stringstream(bodyBuffer, out);
3420 if (fCurrentBlock) {
3421 if (f.declaration().returnType().isVoid()) {
3422 this->writeInstruction(SpvOpReturn, out);
3423 } else {
3424 this->writeInstruction(SpvOpUnreachable, out);
3425 }
3426 }
3427 this->writeInstruction(SpvOpFunctionEnd, out);
3428 this->pruneConditionalOps(conditionalOps);
3429 return result;
3430 }
3431
writeLayout(const Layout & layout,SpvId target,Position pos)3432 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, Position pos) {
3433 bool isPushConstant = (layout.fFlags & Layout::kPushConstant_Flag);
3434 if (layout.fLocation >= 0) {
3435 this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
3436 fDecorationBuffer);
3437 }
3438 if (layout.fBinding >= 0) {
3439 if (isPushConstant) {
3440 fContext.fErrors->error(pos, "Can't apply 'binding' to push constants");
3441 } else {
3442 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
3443 fDecorationBuffer);
3444 }
3445 }
3446 if (layout.fIndex >= 0) {
3447 this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
3448 fDecorationBuffer);
3449 }
3450 if (layout.fSet >= 0) {
3451 if (isPushConstant) {
3452 fContext.fErrors->error(pos, "Can't apply 'set' to push constants");
3453 } else {
3454 this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
3455 fDecorationBuffer);
3456 }
3457 }
3458 if (layout.fInputAttachmentIndex >= 0) {
3459 this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
3460 layout.fInputAttachmentIndex, fDecorationBuffer);
3461 fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
3462 }
3463 if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN) {
3464 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
3465 fDecorationBuffer);
3466 }
3467 }
3468
writeFieldLayout(const Layout & layout,SpvId target,int member)3469 void SPIRVCodeGenerator::writeFieldLayout(const Layout& layout, SpvId target, int member) {
3470 // 'binding' and 'set' can not be applied to struct members
3471 SkASSERT(layout.fBinding == -1);
3472 SkASSERT(layout.fSet == -1);
3473 if (layout.fLocation >= 0) {
3474 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
3475 layout.fLocation, fDecorationBuffer);
3476 }
3477 if (layout.fIndex >= 0) {
3478 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
3479 layout.fIndex, fDecorationBuffer);
3480 }
3481 if (layout.fInputAttachmentIndex >= 0) {
3482 this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
3483 layout.fInputAttachmentIndex, fDecorationBuffer);
3484 }
3485 if (layout.fBuiltin >= 0) {
3486 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
3487 layout.fBuiltin, fDecorationBuffer);
3488 }
3489 }
3490
memoryLayoutForStorageClass(SpvStorageClass_ storageClass)3491 MemoryLayout SPIRVCodeGenerator::memoryLayoutForStorageClass(SpvStorageClass_ storageClass) {
3492 return storageClass == SpvStorageClassPushConstant ? MemoryLayout(MemoryLayout::Standard::k430)
3493 : fDefaultLayout;
3494 }
3495
memoryLayoutForVariable(const Variable & v) const3496 MemoryLayout SPIRVCodeGenerator::memoryLayoutForVariable(const Variable& v) const {
3497 bool pushConstant = ((v.modifiers().fLayout.fFlags & Layout::kPushConstant_Flag) != 0);
3498 return pushConstant ? MemoryLayout(MemoryLayout::Standard::k430) : fDefaultLayout;
3499 }
3500
writeInterfaceBlock(const InterfaceBlock & intf,bool appendRTFlip)3501 SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip) {
3502 MemoryLayout memoryLayout = this->memoryLayoutForVariable(*intf.var());
3503 SpvId result = this->nextId(nullptr);
3504 const Variable& intfVar = *intf.var();
3505 const Type& type = intfVar.type();
3506 if (!memoryLayout.isSupported(type)) {
3507 fContext.fErrors->error(type.fPosition, "type '" + type.displayName() +
3508 "' is not permitted here");
3509 return this->nextId(nullptr);
3510 }
3511 SpvStorageClass_ storageClass =
3512 get_storage_class_for_global_variable(intfVar, SpvStorageClassFunction);
3513 if (fProgram.fInputs.fUseFlipRTUniform && appendRTFlip && type.isStruct()) {
3514 // We can only have one interface block (because we use push_constant and that is limited
3515 // to one per program), so we need to append rtflip to this one rather than synthesize an
3516 // entirely new block when the variable is referenced. And we can't modify the existing
3517 // block, so we instead create a modified copy of it and write that.
3518 std::vector<Type::Field> fields = type.fields();
3519 fields.emplace_back(Position(),
3520 Modifiers(Layout(/*flags=*/0,
3521 /*location=*/-1,
3522 fProgram.fConfig->fSettings.fRTFlipOffset,
3523 /*binding=*/-1,
3524 /*index=*/-1,
3525 /*set=*/-1,
3526 /*builtin=*/-1,
3527 /*inputAttachmentIndex=*/-1),
3528 /*flags=*/0),
3529 SKSL_RTFLIP_NAME,
3530 fContext.fTypes.fFloat2.get());
3531 {
3532 AutoAttachPoolToThread attach(fProgram.fPool.get());
3533 const Type* rtFlipStructType = fProgram.fSymbols->takeOwnershipOfSymbol(
3534 Type::MakeStructType(fContext,
3535 type.fPosition,
3536 type.name(),
3537 std::move(fields),
3538 /*interfaceBlock=*/true));
3539 InterfaceBlockVariable* modifiedVar = fProgram.fSymbols->takeOwnershipOfSymbol(
3540 std::make_unique<InterfaceBlockVariable>(intfVar.fPosition,
3541 intfVar.modifiersPosition(),
3542 &intfVar.modifiers(),
3543 intfVar.name(),
3544 rtFlipStructType,
3545 intfVar.isBuiltin(),
3546 intfVar.storage()));
3547 fSPIRVBonusVariables.add(modifiedVar);
3548 InterfaceBlock modifiedCopy(intf.fPosition, modifiedVar, intf.typeOwner());
3549 result = this->writeInterfaceBlock(modifiedCopy, /*appendRTFlip=*/false);
3550 fProgram.fSymbols->add(std::make_unique<Field>(
3551 Position(), modifiedVar, rtFlipStructType->fields().size() - 1));
3552 }
3553 fVariableMap.set(&intfVar, result);
3554 fWroteRTFlip = true;
3555 return result;
3556 }
3557 const Modifiers& intfModifiers = intfVar.modifiers();
3558 SpvId typeId = this->getType(type, memoryLayout);
3559 if (intfModifiers.fLayout.fBuiltin == -1) {
3560 this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
3561 }
3562 SpvId ptrType = this->nextId(nullptr);
3563 this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
3564 this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
3565 Layout layout = intfModifiers.fLayout;
3566 if (storageClass == SpvStorageClassUniform && layout.fSet < 0) {
3567 layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
3568 }
3569 this->writeLayout(layout, result, intfVar.fPosition);
3570 fVariableMap.set(&intfVar, result);
3571 return result;
3572 }
3573
isDead(const Variable & var) const3574 bool SPIRVCodeGenerator::isDead(const Variable& var) const {
3575 // During SPIR-V code generation, we synthesize some extra bonus variables that don't actually
3576 // exist in the Program at all and aren't tracked by the ProgramUsage. They aren't dead, though.
3577 if (fSPIRVBonusVariables.contains(&var)) {
3578 return false;
3579 }
3580 ProgramUsage::VariableCounts counts = fProgram.usage()->get(var);
3581 if (counts.fRead || counts.fWrite) {
3582 return false;
3583 }
3584 // It's not entirely clear what the rules are for eliding interface variables. Generally, it
3585 // causes problems to elide them, even when they're dead.
3586 return !(var.modifiers().fFlags &
3587 (Modifiers::kIn_Flag | Modifiers::kOut_Flag | Modifiers::kUniform_Flag));
3588 }
3589
3590 // This function determines whether to skip an OpVariable (of pointer type) declaration for
3591 // compile-time constant scalars and vectors which we turn into OpConstant/OpConstantComposite and
3592 // always reference by value.
3593 //
3594 // Accessing a matrix or array member with a dynamic index requires the use of OpAccessChain which
3595 // requires a base operand of pointer type. However, a vector can always be accessed by value using
3596 // OpVectorExtractDynamic (see writeIndexExpression).
3597 //
3598 // This is why we always emit an OpVariable for all non-scalar and non-vector types in case they get
3599 // accessed via a dynamic index.
is_vardecl_compile_time_constant(const VarDeclaration & varDecl)3600 static bool is_vardecl_compile_time_constant(const VarDeclaration& varDecl) {
3601 return varDecl.var()->modifiers().fFlags & Modifiers::kConst_Flag &&
3602 (varDecl.var()->type().isScalar() || varDecl.var()->type().isVector()) &&
3603 (ConstantFolder::GetConstantValueOrNullForVariable(*varDecl.value()) ||
3604 Analysis::IsCompileTimeConstant(*varDecl.value()));
3605 }
3606
writeGlobalVarDeclaration(ProgramKind kind,const VarDeclaration & varDecl)3607 bool SPIRVCodeGenerator::writeGlobalVarDeclaration(ProgramKind kind,
3608 const VarDeclaration& varDecl) {
3609 const Variable* var = varDecl.var();
3610 const bool inDawnMode = fProgram.fConfig->fSettings.fSPIRVDawnCompatMode;
3611 const int backendFlags = var->modifiers().fLayout.fFlags & Layout::kAllBackendFlagsMask;
3612 const int permittedBackendFlags = Layout::kSPIRV_Flag | (inDawnMode ? Layout::kWGSL_Flag : 0);
3613 if (backendFlags & ~permittedBackendFlags) {
3614 fContext.fErrors->error(var->fPosition, "incompatible backend flag in SPIR-V codegen");
3615 return false;
3616 }
3617
3618 // If this global variable is a compile-time constant then we'll emit OpConstant or
3619 // OpConstantComposite later when the variable is referenced. Avoid declaring an OpVariable now.
3620 if (is_vardecl_compile_time_constant(varDecl)) {
3621 return true;
3622 }
3623
3624 SpvStorageClass_ storageClass =
3625 get_storage_class_for_global_variable(*var, SpvStorageClassPrivate);
3626 if (storageClass == SpvStorageClassUniform) {
3627 // Top-level uniforms are emitted in writeUniformBuffer.
3628 fTopLevelUniforms.push_back(&varDecl);
3629 return true;
3630 }
3631
3632 if (this->isDead(*var)) {
3633 return true;
3634 }
3635
3636 if (var->type().typeKind() == Type::TypeKind::kSampler && inDawnMode) {
3637 if (var->modifiers().fLayout.fTexture == -1 || var->modifiers().fLayout.fSampler == -1 ||
3638 !(var->modifiers().fLayout.fFlags & Layout::kWGSL_Flag)) {
3639 fContext.fErrors->error(var->fPosition,
3640 "SPIR-V dawn compatibility mode requires an explicit texture "
3641 "and sampler index");
3642 return false;
3643 }
3644 SkASSERT(storageClass == SpvStorageClassUniformConstant);
3645
3646 auto [texture, sampler] = this->synthesizeTextureAndSampler(*var);
3647 this->writeGlobalVar(kind, storageClass, *texture);
3648 this->writeGlobalVar(kind, storageClass, *sampler);
3649
3650 return true;
3651 }
3652
3653 SpvId id = this->writeGlobalVar(kind, storageClass, *var);
3654 if (id != NA && varDecl.value()) {
3655 SkASSERT(!fCurrentBlock);
3656 fCurrentBlock = NA;
3657 SpvId value = this->writeExpression(*varDecl.value(), fGlobalInitializersBuffer);
3658 this->writeOpStore(storageClass, id, value, fGlobalInitializersBuffer);
3659 fCurrentBlock = 0;
3660 }
3661 return true;
3662 }
3663
writeGlobalVar(ProgramKind kind,SpvStorageClass_ storageClass,const Variable & var)3664 SpvId SPIRVCodeGenerator::writeGlobalVar(ProgramKind kind,
3665 SpvStorageClass_ storageClass,
3666 const Variable& var) {
3667 if (var.modifiers().fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
3668 !ProgramConfig::IsFragment(kind)) {
3669 SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
3670 return NA;
3671 }
3672
3673 // Add this global to the variable map.
3674 const Type& type = var.type();
3675 SpvId id = this->nextId(&type);
3676 fVariableMap.set(&var, id);
3677
3678 Layout layout = var.modifiers().fLayout;
3679 if (layout.fSet < 0 && storageClass == SpvStorageClassUniformConstant) {
3680 layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
3681 }
3682
3683 SpvId typeId = this->getPointerType(type, storageClass);
3684 this->writeInstruction(SpvOpVariable, typeId, id, storageClass, fConstantBuffer);
3685 this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer);
3686 this->writeLayout(layout, id, var.fPosition);
3687 if (var.modifiers().fFlags & Modifiers::kFlat_Flag) {
3688 this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
3689 }
3690 if (var.modifiers().fFlags & Modifiers::kNoPerspective_Flag) {
3691 this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
3692 fDecorationBuffer);
3693 }
3694
3695 return id;
3696 }
3697
writeVarDeclaration(const VarDeclaration & varDecl,OutputStream & out)3698 void SPIRVCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl, OutputStream& out) {
3699 // If this variable is a compile-time constant then we'll emit OpConstant or
3700 // OpConstantComposite later when the variable is referenced. Avoid declaring an OpVariable now.
3701 if (is_vardecl_compile_time_constant(varDecl)) {
3702 return;
3703 }
3704
3705 const Variable* var = varDecl.var();
3706 SpvId id = this->nextId(&var->type());
3707 fVariableMap.set(var, id);
3708 SpvId type = this->getPointerType(var->type(), SpvStorageClassFunction);
3709 this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
3710 this->writeInstruction(SpvOpName, id, var->name(), fNameBuffer);
3711 if (varDecl.value()) {
3712 SpvId value = this->writeExpression(*varDecl.value(), out);
3713 this->writeOpStore(SpvStorageClassFunction, id, value, out);
3714 }
3715 }
3716
writeStatement(const Statement & s,OutputStream & out)3717 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
3718 switch (s.kind()) {
3719 case Statement::Kind::kNop:
3720 break;
3721 case Statement::Kind::kBlock:
3722 this->writeBlock(s.as<Block>(), out);
3723 break;
3724 case Statement::Kind::kExpression:
3725 this->writeExpression(*s.as<ExpressionStatement>().expression(), out);
3726 break;
3727 case Statement::Kind::kReturn:
3728 this->writeReturnStatement(s.as<ReturnStatement>(), out);
3729 break;
3730 case Statement::Kind::kVarDeclaration:
3731 this->writeVarDeclaration(s.as<VarDeclaration>(), out);
3732 break;
3733 case Statement::Kind::kIf:
3734 this->writeIfStatement(s.as<IfStatement>(), out);
3735 break;
3736 case Statement::Kind::kFor:
3737 this->writeForStatement(s.as<ForStatement>(), out);
3738 break;
3739 case Statement::Kind::kDo:
3740 this->writeDoStatement(s.as<DoStatement>(), out);
3741 break;
3742 case Statement::Kind::kSwitch:
3743 this->writeSwitchStatement(s.as<SwitchStatement>(), out);
3744 break;
3745 case Statement::Kind::kBreak:
3746 this->writeInstruction(SpvOpBranch, fBreakTarget.back(), out);
3747 break;
3748 case Statement::Kind::kContinue:
3749 this->writeInstruction(SpvOpBranch, fContinueTarget.back(), out);
3750 break;
3751 case Statement::Kind::kDiscard:
3752 this->writeInstruction(SpvOpKill, out);
3753 break;
3754 default:
3755 SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
3756 break;
3757 }
3758 }
3759
writeBlock(const Block & b,OutputStream & out)3760 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
3761 for (const std::unique_ptr<Statement>& stmt : b.children()) {
3762 this->writeStatement(*stmt, out);
3763 }
3764 }
3765
getConditionalOpCounts()3766 SPIRVCodeGenerator::ConditionalOpCounts SPIRVCodeGenerator::getConditionalOpCounts() {
3767 return {fReachableOps.size(), fStoreOps.size()};
3768 }
3769
pruneConditionalOps(ConditionalOpCounts ops)3770 void SPIRVCodeGenerator::pruneConditionalOps(ConditionalOpCounts ops) {
3771 // Remove ops which are no longer reachable.
3772 while (fReachableOps.size() > ops.numReachableOps) {
3773 SpvId prunableSpvId = fReachableOps.back();
3774 const Instruction* prunableOp = fSpvIdCache.find(prunableSpvId);
3775
3776 if (prunableOp) {
3777 fOpCache.remove(*prunableOp);
3778 fSpvIdCache.remove(prunableSpvId);
3779 } else {
3780 SkDEBUGFAIL("reachable-op list contains unrecognized SpvId");
3781 }
3782
3783 fReachableOps.pop_back();
3784 }
3785
3786 // Remove any cached stores that occurred during the conditional block.
3787 while (fStoreOps.size() > ops.numStoreOps) {
3788 if (fStoreCache.find(fStoreOps.back())) {
3789 fStoreCache.remove(fStoreOps.back());
3790 }
3791 fStoreOps.pop_back();
3792 }
3793 }
3794
writeIfStatement(const IfStatement & stmt,OutputStream & out)3795 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
3796 SpvId test = this->writeExpression(*stmt.test(), out);
3797 SpvId ifTrue = this->nextId(nullptr);
3798 SpvId ifFalse = this->nextId(nullptr);
3799
3800 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
3801
3802 if (stmt.ifFalse()) {
3803 SpvId end = this->nextId(nullptr);
3804 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
3805 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
3806 this->writeLabel(ifTrue, kBranchIsOnPreviousLine, out);
3807 this->writeStatement(*stmt.ifTrue(), out);
3808 if (fCurrentBlock) {
3809 this->writeInstruction(SpvOpBranch, end, out);
3810 }
3811 this->writeLabel(ifFalse, kBranchIsAbove, conditionalOps, out);
3812 this->writeStatement(*stmt.ifFalse(), out);
3813 if (fCurrentBlock) {
3814 this->writeInstruction(SpvOpBranch, end, out);
3815 }
3816 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
3817 } else {
3818 this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
3819 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
3820 this->writeLabel(ifTrue, kBranchIsOnPreviousLine, out);
3821 this->writeStatement(*stmt.ifTrue(), out);
3822 if (fCurrentBlock) {
3823 this->writeInstruction(SpvOpBranch, ifFalse, out);
3824 }
3825 this->writeLabel(ifFalse, kBranchIsAbove, conditionalOps, out);
3826 }
3827 }
3828
// Emits a `for` loop as a SPIR-V structured loop. Any initializer is written first, into the
// current block. The loop itself uses five labels, allocated in this order:
//   header - only the OpLoopMerge (merge block `end`, continue target `next`) and a branch to
//            `start`
//   start  - evaluates the test, if any, and branches to `body` or `end`; with no test it
//            branches to `body` unconditionally
//   body   - the loop statement; branches on to `next` only if fCurrentBlock is still set
//            afterwards (presumably: the body did not already terminate its block)
//   next   - the continue target: evaluates the `next` expression, if any, then jumps to `header`
//   end    - the merge block
// While the loop is emitted, `continue` targets `next` and `break` targets `end` (via
// fContinueTarget / fBreakTarget).
void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
    if (f.initializer()) {
        this->writeStatement(*f.initializer(), out);
    }

    // Snapshot of the conditional-op bookkeeping, taken before any loop code is emitted. It is
    // handed to each writeLabel() call below that is entered via a branch.
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
    // in the context of linear straight-line execution. If we wanted to be more clever, we could
    // only invalidate store cache entries for variables affected by the loop body, but for now we
    // simply clear the entire cache whenever branching occurs.
    // NOTE(review): nothing is cleared directly in this function. The invalidation presumably
    // happens inside writeLabel() (kBranchIsBelow / kBranchIsAbove with `conditionalOps`);
    // confirm there.
    SpvId header = this->nextId(nullptr);
    SpvId start = this->nextId(nullptr);
    SpvId body = this->nextId(nullptr);
    SpvId next = this->nextId(nullptr);
    fContinueTarget.push_back(next);
    SpvId end = this->nextId(nullptr);
    fBreakTarget.push_back(end);
    this->writeInstruction(SpvOpBranch, header, out);
    // `header` is also reached from the back-edge at the bottom of the loop, hence kBranchIsBelow.
    this->writeLabel(header, kBranchIsBelow, conditionalOps, out);
    this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
    this->writeInstruction(SpvOpBranch, start, out);
    this->writeLabel(start, kBranchIsOnPreviousLine, out);
    if (f.test()) {
        SpvId test = this->writeExpression(*f.test(), out);
        this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
    } else {
        this->writeInstruction(SpvOpBranch, body, out);
    }
    this->writeLabel(body, kBranchIsOnPreviousLine, out);
    this->writeStatement(*f.statement(), out);
    if (fCurrentBlock) {
        this->writeInstruction(SpvOpBranch, next, out);
    }
    this->writeLabel(next, kBranchIsAbove, conditionalOps, out);
    if (f.next()) {
        this->writeExpression(*f.next(), out);
    }
    // Back-edge to the loop header.
    this->writeInstruction(SpvOpBranch, header, out);
    this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    fBreakTarget.pop_back();
    fContinueTarget.pop_back();
}
3872
// Emits a `do { ... } while (test);` loop as a SPIR-V structured loop with five labels, allocated
// in this order:
//   header         - only the OpLoopMerge (merge block `end`, continue target `continueTarget`)
//                    and a branch to `start`
//   start          - the loop statement
//   next           - an intermediate block between the body and the continue target; it is
//                    written only if fCurrentBlock is still set after the body
//   continueTarget - evaluates the test and branches back to `header` (run again) or to `end`
//   end            - the merge block
// While the body is emitted, `continue` targets `continueTarget` (so the test still runs) and
// `break` targets `end`.
void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
    // Snapshot of the conditional-op bookkeeping, taken before any loop code is emitted. It is
    // handed to each writeLabel() call below that is entered via a branch.
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
    // in the context of linear straight-line execution. If we wanted to be more clever, we could
    // only invalidate store cache entries for variables affected by the loop body, but for now we
    // simply clear the entire cache whenever branching occurs.
    // NOTE(review): nothing is cleared directly in this function. The invalidation presumably
    // happens inside writeLabel() (kBranchIsBelow / kBranchIsAbove with `conditionalOps`);
    // confirm there.
    SpvId header = this->nextId(nullptr);
    SpvId start = this->nextId(nullptr);
    SpvId next = this->nextId(nullptr);
    SpvId continueTarget = this->nextId(nullptr);
    fContinueTarget.push_back(continueTarget);
    SpvId end = this->nextId(nullptr);
    fBreakTarget.push_back(end);
    this->writeInstruction(SpvOpBranch, header, out);
    // `header` is also reached from the conditional back-edge below, hence kBranchIsBelow.
    this->writeLabel(header, kBranchIsBelow, conditionalOps, out);
    this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out);
    this->writeInstruction(SpvOpBranch, start, out);
    this->writeLabel(start, kBranchIsOnPreviousLine, out);
    this->writeStatement(*d.statement(), out);
    if (fCurrentBlock) {
        this->writeInstruction(SpvOpBranch, next, out);
        this->writeLabel(next, kBranchIsOnPreviousLine, out);
        this->writeInstruction(SpvOpBranch, continueTarget, out);
    }
    this->writeLabel(continueTarget, kBranchIsAbove, conditionalOps, out);
    SpvId test = this->writeExpression(*d.test(), out);
    // True: loop again via the header. False: leave the loop.
    this->writeInstruction(SpvOpBranchConditional, test, header, end, out);
    this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    fBreakTarget.pop_back();
    fContinueTarget.pop_back();
}
3905
writeSwitchStatement(const SwitchStatement & s,OutputStream & out)3906 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
3907 SpvId value = this->writeExpression(*s.value(), out);
3908
3909 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
3910
3911 // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
3912 // in the context of linear straight-line execution. If we wanted to be more clever, we could
3913 // only invalidate store cache entries for variables affected by the switch body, but for now we
3914 // simply clear the entire cache whenever branching occurs.
3915 SkTArray<SpvId> labels;
3916 SpvId end = this->nextId(nullptr);
3917 SpvId defaultLabel = end;
3918 fBreakTarget.push_back(end);
3919 int size = 3;
3920 const StatementArray& cases = s.cases();
3921 for (const std::unique_ptr<Statement>& stmt : cases) {
3922 const SwitchCase& c = stmt->as<SwitchCase>();
3923 SpvId label = this->nextId(nullptr);
3924 labels.push_back(label);
3925 if (!c.isDefault()) {
3926 size += 2;
3927 } else {
3928 defaultLabel = label;
3929 }
3930 }
3931
3932 // We should have exactly one label for each case.
3933 SkASSERT(labels.size() == cases.size());
3934
3935 // Collapse adjacent switch-cases into one; that is, reduce `case 1: case 2: case 3:` into a
3936 // single OpLabel. The Tint SPIR-V reader does not support switch-case fallthrough, but it
3937 // does support multiple switch-cases branching to the same label.
3938 SkBitSet caseIsCollapsed(cases.size());
3939 for (int index = cases.size() - 2; index >= 0; index--) {
3940 if (cases[index]->as<SwitchCase>().statement()->isEmpty()) {
3941 caseIsCollapsed.set(index);
3942 labels[index] = labels[index + 1];
3943 }
3944 }
3945
3946 labels.push_back(end);
3947
3948 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
3949 this->writeOpCode(SpvOpSwitch, size, out);
3950 this->writeWord(value, out);
3951 this->writeWord(defaultLabel, out);
3952 for (int i = 0; i < cases.size(); ++i) {
3953 const SwitchCase& c = cases[i]->as<SwitchCase>();
3954 if (c.isDefault()) {
3955 continue;
3956 }
3957 this->writeWord(c.value(), out);
3958 this->writeWord(labels[i], out);
3959 }
3960 for (int i = 0; i < cases.size(); ++i) {
3961 if (caseIsCollapsed.test(i)) {
3962 continue;
3963 }
3964 const SwitchCase& c = cases[i]->as<SwitchCase>();
3965 if (i == 0) {
3966 this->writeLabel(labels[i], kBranchIsOnPreviousLine, out);
3967 } else {
3968 this->writeLabel(labels[i], kBranchIsAbove, conditionalOps, out);
3969 }
3970 this->writeStatement(*c.statement(), out);
3971 if (fCurrentBlock) {
3972 this->writeInstruction(SpvOpBranch, labels[i + 1], out);
3973 }
3974 }
3975 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
3976 fBreakTarget.pop_back();
3977 }
3978
writeReturnStatement(const ReturnStatement & r,OutputStream & out)3979 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
3980 if (r.expression()) {
3981 this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.expression(), out),
3982 out);
3983 } else {
3984 this->writeInstruction(SpvOpReturn, out);
3985 }
3986 }
3987
3988 // Given any function, returns the top-level symbol table (OUTSIDE of the function's scope).
get_top_level_symbol_table(const FunctionDeclaration & anyFunc)3989 static std::shared_ptr<SymbolTable> get_top_level_symbol_table(const FunctionDeclaration& anyFunc) {
3990 return anyFunc.definition()->body()->as<Block>().symbolTable()->fParent;
3991 }
3992
writeEntrypointAdapter(const FunctionDeclaration & main)3993 SPIRVCodeGenerator::EntrypointAdapter SPIRVCodeGenerator::writeEntrypointAdapter(
3994 const FunctionDeclaration& main) {
3995 // Our goal is to synthesize a tiny helper function which looks like this:
3996 // void _entrypoint() { sk_FragColor = main(); }
3997
3998 // Fish a symbol table out of main().
3999 std::shared_ptr<SymbolTable> symbolTable = get_top_level_symbol_table(main);
4000
4001 // Get `sk_FragColor` as a writable reference.
4002 const Symbol* skFragColorSymbol = symbolTable->find("sk_FragColor");
4003 SkASSERT(skFragColorSymbol);
4004 const Variable& skFragColorVar = skFragColorSymbol->as<Variable>();
4005 auto skFragColorRef = std::make_unique<VariableReference>(Position(), &skFragColorVar,
4006 VariableReference::RefKind::kWrite);
4007 // Synthesize a call to the `main()` function.
4008 if (!main.returnType().matches(skFragColorRef->type())) {
4009 fContext.fErrors->error(main.fPosition, "SPIR-V does not support returning '" +
4010 main.returnType().description() + "' from main()");
4011 return {};
4012 }
4013 ExpressionArray args;
4014 if (main.parameters().size() == 1) {
4015 if (!main.parameters()[0]->type().matches(*fContext.fTypes.fFloat2)) {
4016 fContext.fErrors->error(main.fPosition,
4017 "SPIR-V does not support parameter of type '" +
4018 main.parameters()[0]->type().description() + "' to main()");
4019 return {};
4020 }
4021 args.push_back(dsl::Float2(0).release());
4022 }
4023 auto callMainFn = std::make_unique<FunctionCall>(Position(), &main.returnType(), &main,
4024 std::move(args));
4025
4026 // Synthesize `skFragColor = main()` as a BinaryExpression.
4027 auto assignmentStmt = std::make_unique<ExpressionStatement>(std::make_unique<BinaryExpression>(
4028 Position(),
4029 std::move(skFragColorRef),
4030 Operator::Kind::EQ,
4031 std::move(callMainFn),
4032 &main.returnType()));
4033
4034 // Function bodies are always wrapped in a Block.
4035 StatementArray entrypointStmts;
4036 entrypointStmts.push_back(std::move(assignmentStmt));
4037 auto entrypointBlock = Block::Make(Position(), std::move(entrypointStmts),
4038 Block::Kind::kBracedScope, symbolTable);
4039 // Declare an entrypoint function.
4040 EntrypointAdapter adapter;
4041 adapter.fLayout = {};
4042 adapter.fModifiers = Modifiers{adapter.fLayout, Modifiers::kNo_Flag};
4043 adapter.entrypointDecl =
4044 std::make_unique<FunctionDeclaration>(Position(),
4045 &adapter.fModifiers,
4046 "_entrypoint",
4047 /*parameters=*/std::vector<Variable*>{},
4048 /*returnType=*/fContext.fTypes.fVoid.get(),
4049 /*builtin=*/false);
4050 // Define it.
4051 adapter.entrypointDef = FunctionDefinition::Convert(fContext,
4052 Position(),
4053 *adapter.entrypointDecl,
4054 std::move(entrypointBlock),
4055 /*builtin=*/false);
4056
4057 adapter.entrypointDecl->setDefinition(adapter.entrypointDef.get());
4058 return adapter;
4059 }
4060
writeUniformBuffer(std::shared_ptr<SymbolTable> topLevelSymbolTable)4061 void SPIRVCodeGenerator::writeUniformBuffer(std::shared_ptr<SymbolTable> topLevelSymbolTable) {
4062 SkASSERT(!fTopLevelUniforms.empty());
4063 static constexpr char kUniformBufferName[] = "_UniformBuffer";
4064
4065 // Convert the list of top-level uniforms into a matching struct named _UniformBuffer, and build
4066 // a lookup table of variables to UniformBuffer field indices.
4067 std::vector<Type::Field> fields;
4068 fields.reserve(fTopLevelUniforms.size());
4069 for (const VarDeclaration* topLevelUniform : fTopLevelUniforms) {
4070 const Variable* var = topLevelUniform->var();
4071 fTopLevelUniformMap.set(var, (int)fields.size());
4072 Modifiers modifiers = var->modifiers();
4073 modifiers.fFlags &= ~Modifiers::kUniform_Flag;
4074 fields.emplace_back(var->fPosition, modifiers, var->name(), &var->type());
4075 }
4076 fUniformBuffer.fStruct = Type::MakeStructType(fContext,
4077 Position(),
4078 kUniformBufferName,
4079 std::move(fields),
4080 /*interfaceBlock=*/true);
4081
4082 // Create a global variable to contain this struct.
4083 Layout layout;
4084 layout.fBinding = fProgram.fConfig->fSettings.fDefaultUniformBinding;
4085 layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
4086 Modifiers modifiers{layout, Modifiers::kUniform_Flag};
4087
4088 fUniformBuffer.fInnerVariable = std::make_unique<InterfaceBlockVariable>(
4089 /*pos=*/Position(), /*modifiersPosition=*/Position(),
4090 fContext.fModifiersPool->add(modifiers), kUniformBufferName,
4091 fUniformBuffer.fStruct.get(), /*builtin=*/false, Variable::Storage::kGlobal);
4092
4093 // Create an interface block object for this global variable.
4094 fUniformBuffer.fInterfaceBlock =
4095 std::make_unique<InterfaceBlock>(Position(),
4096 fUniformBuffer.fInnerVariable.get(),
4097 topLevelSymbolTable);
4098
4099 // Generate an interface block and hold onto its ID.
4100 fUniformBufferId = this->writeInterfaceBlock(*fUniformBuffer.fInterfaceBlock);
4101 }
4102
addRTFlipUniform(Position pos)4103 void SPIRVCodeGenerator::addRTFlipUniform(Position pos) {
4104 SkASSERT(!fProgram.fConfig->fSettings.fForceNoRTFlip);
4105
4106 if (fWroteRTFlip) {
4107 return;
4108 }
4109 // Flip variable hasn't been written yet. This means we don't have an existing
4110 // interface block, so we're free to just synthesize one.
4111 fWroteRTFlip = true;
4112 std::vector<Type::Field> fields;
4113 if (fProgram.fConfig->fSettings.fRTFlipOffset < 0) {
4114 fContext.fErrors->error(pos, "RTFlipOffset is negative");
4115 }
4116 fields.emplace_back(pos,
4117 Modifiers(Layout(/*flags=*/0,
4118 /*location=*/-1,
4119 fProgram.fConfig->fSettings.fRTFlipOffset,
4120 /*binding=*/-1,
4121 /*index=*/-1,
4122 /*set=*/-1,
4123 /*builtin=*/-1,
4124 /*inputAttachmentIndex=*/-1),
4125 /*flags=*/0),
4126 SKSL_RTFLIP_NAME,
4127 fContext.fTypes.fFloat2.get());
4128 std::string_view name = "sksl_synthetic_uniforms";
4129 const Type* intfStruct = fSynthetics.takeOwnershipOfSymbol(
4130 Type::MakeStructType(fContext, Position(), name, fields, /*interfaceBlock=*/true));
4131 bool usePushConstants = fProgram.fConfig->fSettings.fUsePushConstants;
4132 int binding = -1, set = -1;
4133 if (!usePushConstants) {
4134 binding = fProgram.fConfig->fSettings.fRTFlipBinding;
4135 if (binding == -1) {
4136 fContext.fErrors->error(pos, "layout(binding=...) is required in SPIR-V");
4137 }
4138 set = fProgram.fConfig->fSettings.fRTFlipSet;
4139 if (set == -1) {
4140 fContext.fErrors->error(pos, "layout(set=...) is required in SPIR-V");
4141 }
4142 }
4143 int flags = usePushConstants ? Layout::Flag::kPushConstant_Flag : 0;
4144 const Modifiers* modsPtr;
4145 {
4146 AutoAttachPoolToThread attach(fProgram.fPool.get());
4147 Modifiers modifiers(Layout(flags,
4148 /*location=*/-1,
4149 /*offset=*/-1,
4150 binding,
4151 /*index=*/-1,
4152 set,
4153 /*builtin=*/-1,
4154 /*inputAttachmentIndex=*/-1),
4155 Modifiers::kUniform_Flag);
4156 modsPtr = fContext.fModifiersPool->add(modifiers);
4157 }
4158 InterfaceBlockVariable* intfVar = fSynthetics.takeOwnershipOfSymbol(
4159 std::make_unique<InterfaceBlockVariable>(/*pos=*/Position(),
4160 /*modifiersPosition=*/Position(),
4161 modsPtr,
4162 name,
4163 intfStruct,
4164 /*builtin=*/false,
4165 Variable::Storage::kGlobal));
4166 fSPIRVBonusVariables.add(intfVar);
4167 {
4168 AutoAttachPoolToThread attach(fProgram.fPool.get());
4169 fProgram.fSymbols->add(std::make_unique<Field>(Position(), intfVar, /*field=*/0));
4170 }
4171 InterfaceBlock intf(Position(), intfVar, std::make_shared<SymbolTable>(/*builtin=*/false));
4172 this->writeInterfaceBlock(intf, false);
4173 }
4174
synthesizeTextureAndSampler(const Variable & combinedSampler)4175 std::tuple<const Variable*, const Variable*> SPIRVCodeGenerator::synthesizeTextureAndSampler(
4176 const Variable& combinedSampler) {
4177 SkASSERT(fProgram.fConfig->fSettings.fSPIRVDawnCompatMode);
4178 SkASSERT(combinedSampler.type().typeKind() == Type::TypeKind::kSampler);
4179
4180 const Modifiers& modifiers = combinedSampler.modifiers();
4181
4182 auto data = std::make_unique<SynthesizedTextureSamplerPair>();
4183
4184 Modifiers texModifiers = modifiers;
4185 texModifiers.fLayout.fBinding = modifiers.fLayout.fTexture;
4186 data->fTextureName = std::string(combinedSampler.name()) + "_texture";
4187 auto texture = std::make_unique<Variable>(/*pos=*/Position(),
4188 /*modifierPosition=*/Position(),
4189 fContext.fModifiersPool->add(texModifiers),
4190 data->fTextureName,
4191 &combinedSampler.type().textureType(),
4192 /*builtin=*/false,
4193 Variable::Storage::kGlobal);
4194
4195 Modifiers samplerModifiers = modifiers;
4196 samplerModifiers.fLayout.fBinding = modifiers.fLayout.fSampler;
4197 data->fSamplerName = std::string(combinedSampler.name()) + "_sampler";
4198 auto sampler = std::make_unique<Variable>(/*pos=*/Position(),
4199 /*modifierPosition=*/Position(),
4200 fContext.fModifiersPool->add(samplerModifiers),
4201 data->fSamplerName,
4202 fContext.fTypes.fSampler.get(),
4203 /*builtin=*/false,
4204 Variable::Storage::kGlobal);
4205
4206 const Variable* t = texture.get();
4207 const Variable* s = sampler.get();
4208 data->fTexture = std::move(texture);
4209 data->fSampler = std::move(sampler);
4210 fSynthesizedSamplerMap.set(&combinedSampler, std::move(data));
4211
4212 return {t, s};
4213 }
4214
// Generates the whole SPIR-V module body (everything after the 5-word header) into `out`.
// Phases, in order:
//   1. Assign SpvIds to every function definition and locate main(); error out if there is none.
//   2. Emit interface blocks, collecting the in/out, non-builtin ones for which isDead() is false
//      as entry-point interface variables.
//   3. Emit global variable declarations (returns early if writeGlobalVarDeclaration() fails).
//   4. Emit the synthesized `_UniformBuffer` if there are top-level uniforms.
//   5. If main() returns half4, emit an `_entrypoint` adapter and use it as the entry point.
//   6. Emit all function bodies into a side buffer (`body`), then add every global in/out
//      variable in fVariableMap for which isDead() is false to the interface list.
//   7. Write capabilities, the GLSL.std.450 import, the memory model, OpEntryPoint,
//      OpExecutionMode (fragment only) and OpSourceExtension. Finally append the name,
//      decoration and constant buffers and the function bodies.
void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
    fGLSLExtendedInstructions = this->nextId(nullptr);
    // Function bodies are generated into this side buffer and appended at the very end, after
    // the module-level declarations that must precede them.
    StringStream body;
    // Assign SpvIds to functions.
    const FunctionDeclaration* main = nullptr;
    for (const ProgramElement* e : program.elements()) {
        if (e->is<FunctionDefinition>()) {
            const FunctionDefinition& funcDef = e->as<FunctionDefinition>();
            const FunctionDeclaration& funcDecl = funcDef.declaration();
            fFunctionMap.set(&funcDecl, this->nextId(nullptr));
            if (funcDecl.isMain()) {
                main = &funcDecl;
            }
        }
    }
    // Make sure we have a main() function.
    if (!main) {
        fContext.fErrors->error(Position(), "program does not contain a main() function");
        return;
    }
    // Emit interface blocks.
    std::set<SpvId> interfaceVars;
    for (const ProgramElement* e : program.elements()) {
        if (e->is<InterfaceBlock>()) {
            const InterfaceBlock& intf = e->as<InterfaceBlock>();
            SpvId id = this->writeInterfaceBlock(intf);

            // Only non-builtin in/out blocks for which isDead() is false are listed on
            // OpEntryPoint.
            const Modifiers& modifiers = intf.var()->modifiers();
            if ((modifiers.fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag)) &&
                modifiers.fLayout.fBuiltin == -1 && !this->isDead(*intf.var())) {
                interfaceVars.insert(id);
            }
        }
    }
    // Emit global variable declarations.
    for (const ProgramElement* e : program.elements()) {
        if (e->is<GlobalVarDeclaration>()) {
            if (!this->writeGlobalVarDeclaration(program.fConfig->fKind,
                                                 e->as<GlobalVarDeclaration>().varDeclaration())) {
                return;
            }
        }
    }
    // Emit top-level uniforms into a dedicated uniform buffer.
    if (!fTopLevelUniforms.empty()) {
        this->writeUniformBuffer(get_top_level_symbol_table(*main));
    }
    // If main() returns a half4, synthesize a tiny entrypoint function which invokes the real
    // main() and stores the result into sk_FragColor.
    // `adapter` must stay alive until the end of this function: `main` may point into it below.
    EntrypointAdapter adapter;
    if (main->returnType().matches(*fContext.fTypes.fHalf4)) {
        adapter = this->writeEntrypointAdapter(*main);
        if (adapter.entrypointDecl) {
            fFunctionMap.set(adapter.entrypointDecl.get(), this->nextId(nullptr));
            this->writeFunction(*adapter.entrypointDef, body);
            main = adapter.entrypointDecl.get();
        }
    }
    // Emit all the functions.
    for (const ProgramElement* e : program.elements()) {
        if (e->is<FunctionDefinition>()) {
            this->writeFunction(e->as<FunctionDefinition>(), body);
        }
    }
    // Add global in/out variables to the list of interface variables.
    // (Unlike the interface-block loop above, builtins are not filtered out here.)
    for (const auto& [var, spvId] : fVariableMap) {
        if (var->storage() == Variable::Storage::kGlobal &&
            (var->modifiers().fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag)) &&
            !this->isDead(*var)) {
            interfaceVars.insert(spvId);
        }
    }
    this->writeCapabilities(out);
    this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
    this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
    // OpEntryPoint word count: 3 fixed words (opcode/word count, execution model, entry-point
    // id), plus the nul-terminated name packed into 32-bit words (ceil((len + 1) / 4)), plus
    // one word per interface variable.
    this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (main->name().length() + 4) / 4) +
                                       (int32_t) interfaceVars.size(), out);
    if (ProgramConfig::IsVertex(program.fConfig->fKind)) {
        this->writeWord(SpvExecutionModelVertex, out);
    } else if (ProgramConfig::IsFragment(program.fConfig->fKind)) {
        this->writeWord(SpvExecutionModelFragment, out);
    } else {
        SK_ABORT("cannot write this kind of program to SPIR-V\n");
    }
    SpvId entryPoint = fFunctionMap[main];
    this->writeWord(entryPoint, out);
    this->writeString(main->name(), out);
    for (int var : interfaceVars) {
        this->writeWord(var, out);
    }
    if (ProgramConfig::IsFragment(program.fConfig->fKind)) {
        this->writeInstruction(SpvOpExecutionMode,
                               fFunctionMap[main],
                               SpvExecutionModeOriginUpperLeft,
                               out);
    }
    for (const ProgramElement* e : program.elements()) {
        if (e->is<Extension>()) {
            this->writeInstruction(SpvOpSourceExtension, e->as<Extension>().name(), out);
        }
    }

    // Append the deferred sections in module order: debug names, decorations, constants/types,
    // then the function bodies.
    write_stringstream(fNameBuffer, out);
    write_stringstream(fDecorationBuffer, out);
    write_stringstream(fConstantBuffer, out);
    write_stringstream(body, out);
}
4322
generateCode()4323 bool SPIRVCodeGenerator::generateCode() {
4324 SkASSERT(!fContext.fErrors->errorCount());
4325 this->writeWord(SpvMagicNumber, *fOut);
4326 this->writeWord(SpvVersion, *fOut);
4327 this->writeWord(SKSL_MAGIC, *fOut);
4328 StringStream buffer;
4329 this->writeInstructions(fProgram, buffer);
4330 this->writeWord(fIdCount, *fOut);
4331 this->writeWord(0, *fOut); // reserved, always zero
4332 write_stringstream(buffer, *fOut);
4333 return fContext.fErrors->errorCount() == 0;
4334 }
4335
4336 } // namespace SkSL
4337