1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/codegen/SkSLSPIRVCodeGenerator.h"
9
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/base/SkTArray.h"
13 #include "include/private/base/SkTo.h"
14 #include "src/base/SkEnumBitMask.h"
15 #include "src/core/SkChecksum.h"
16 #include "src/core/SkTHash.h"
17 #include "src/core/SkTraceEvent.h"
18 #include "src/sksl/GLSL.std.450.h"
19 #include "src/sksl/SkSLAnalysis.h"
20 #include "src/sksl/SkSLBuiltinTypes.h"
21 #include "src/sksl/SkSLCompiler.h"
22 #include "src/sksl/SkSLConstantFolder.h"
23 #include "src/sksl/SkSLContext.h"
24 #include "src/sksl/SkSLDefines.h"
25 #include "src/sksl/SkSLErrorReporter.h"
26 #include "src/sksl/SkSLIntrinsicList.h"
27 #include "src/sksl/SkSLMemoryLayout.h"
28 #include "src/sksl/SkSLOperator.h"
29 #include "src/sksl/SkSLOutputStream.h"
30 #include "src/sksl/SkSLPool.h"
31 #include "src/sksl/SkSLPosition.h"
32 #include "src/sksl/SkSLProgramSettings.h"
33 #include "src/sksl/SkSLStringStream.h"
34 #include "src/sksl/SkSLUtil.h"
35 #include "src/sksl/codegen/SkSLCodeGenerator.h"
36 #include "src/sksl/ir/SkSLBinaryExpression.h"
37 #include "src/sksl/ir/SkSLBlock.h"
38 #include "src/sksl/ir/SkSLConstructor.h"
39 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
40 #include "src/sksl/ir/SkSLConstructorCompound.h"
41 #include "src/sksl/ir/SkSLConstructorCompoundCast.h"
42 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
43 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
44 #include "src/sksl/ir/SkSLConstructorScalarCast.h"
45 #include "src/sksl/ir/SkSLConstructorSplat.h"
46 #include "src/sksl/ir/SkSLDoStatement.h"
47 #include "src/sksl/ir/SkSLExpression.h"
48 #include "src/sksl/ir/SkSLExpressionStatement.h"
49 #include "src/sksl/ir/SkSLExtension.h"
50 #include "src/sksl/ir/SkSLFieldAccess.h"
51 #include "src/sksl/ir/SkSLFieldSymbol.h"
52 #include "src/sksl/ir/SkSLForStatement.h"
53 #include "src/sksl/ir/SkSLFunctionCall.h"
54 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
55 #include "src/sksl/ir/SkSLFunctionDefinition.h"
56 #include "src/sksl/ir/SkSLIRNode.h"
57 #include "src/sksl/ir/SkSLIfStatement.h"
58 #include "src/sksl/ir/SkSLIndexExpression.h"
59 #include "src/sksl/ir/SkSLInterfaceBlock.h"
60 #include "src/sksl/ir/SkSLLayout.h"
61 #include "src/sksl/ir/SkSLLiteral.h"
62 #include "src/sksl/ir/SkSLModifierFlags.h"
63 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
64 #include "src/sksl/ir/SkSLPoison.h"
65 #include "src/sksl/ir/SkSLPostfixExpression.h"
66 #include "src/sksl/ir/SkSLPrefixExpression.h"
67 #include "src/sksl/ir/SkSLProgram.h"
68 #include "src/sksl/ir/SkSLProgramElement.h"
69 #include "src/sksl/ir/SkSLReturnStatement.h"
70 #include "src/sksl/ir/SkSLSetting.h"
71 #include "src/sksl/ir/SkSLStatement.h"
72 #include "src/sksl/ir/SkSLSwitchCase.h"
73 #include "src/sksl/ir/SkSLSwitchStatement.h"
74 #include "src/sksl/ir/SkSLSwizzle.h"
75 #include "src/sksl/ir/SkSLSymbol.h"
76 #include "src/sksl/ir/SkSLSymbolTable.h"
77 #include "src/sksl/ir/SkSLTernaryExpression.h"
78 #include "src/sksl/ir/SkSLType.h"
79 #include "src/sksl/ir/SkSLVarDeclarations.h"
80 #include "src/sksl/ir/SkSLVariable.h"
81 #include "src/sksl/ir/SkSLVariableReference.h"
82 #include "src/sksl/spirv.h"
83 #include "src/sksl/transform/SkSLTransform.h"
84 #include "src/utils/SkBitSet.h"
85
86 #include <cstdint>
87 #include <cstring>
88 #include <memory>
89 #include <set>
90 #include <string>
91 #include <string_view>
92 #include <tuple>
93 #include <utility>
94 #include <vector>
95
96 #ifdef SK_ENABLE_SPIRV_VALIDATION
97 #include "spirv-tools/libspirv.hpp"
98 #endif
99
100 using namespace skia_private;
101
// Highest SpvCapability value this generator tracks (presumably the upper bound when walking the
// fCapabilities bitmask; NOTE(review): confirm at the use site).
#define kLast_Capability SpvCapabilityMultiViewport

// Pseudo builtin IDs used by this generator. They are negative, presumably so they can never
// collide with a real SpvBuiltIn value — TODO confirm. The clockwise ID sits directly below the
// frag-coords ID.
static constexpr int DEVICE_FRAGCOORDS_BUILTIN = -1000;
static constexpr int DEVICE_CLOCKWISE_BUILTIN = DEVICE_FRAGCOORDS_BUILTIN - 1;
// A default-constructed Layout, shared as a file-wide constant. Presumably passed wherever a type
// carries no explicit layout qualifiers (NOTE(review): confirm at call sites).
static constexpr SkSL::Layout kDefaultTypeLayout;
107
108 namespace SkSL {
109
// Forward declaration: ProgramKind is only named by value in the writeGlobalVarDeclaration() and
// writeGlobalVar() declarations below, so its full definition is not required here.
enum class ProgramKind : int8_t;
111
112 class SPIRVCodeGenerator : public CodeGenerator {
113 public:
114 // We reserve an impossible SpvId as a sentinel. (NA meaning none, n/a, etc.)
115 static constexpr SpvId NA = (SpvId)-1;
116
117 class LValue {
118 public:
~LValue()119 virtual ~LValue() {}
120
121 // returns a pointer to the lvalue, if possible. If the lvalue cannot be directly referenced
122 // by a pointer (e.g. vector swizzles), returns NA.
getPointer()123 virtual SpvId getPointer() { return NA; }
124
125 // Returns true if a valid pointer returned by getPointer represents a memory object
126 // (see https://github.com/KhronosGroup/SPIRV-Tools/issues/2892). Has no meaning if
127 // getPointer() returns NA.
isMemoryObjectPointer() const128 virtual bool isMemoryObjectPointer() const { return true; }
129
130 // Applies a swizzle to the components of the LValue, if possible. This is used to create
131 // LValues that are swizzes-of-swizzles. Non-swizzle LValues can just return false.
applySwizzle(const ComponentArray & components,const Type & newType)132 virtual bool applySwizzle(const ComponentArray& components, const Type& newType) {
133 return false;
134 }
135
136 // Returns the storage class of the lvalue.
137 virtual SpvStorageClass storageClass() const = 0;
138
139 virtual SpvId load(OutputStream& out) = 0;
140
141 virtual void store(SpvId value, OutputStream& out) = 0;
142 };
143
SPIRVCodeGenerator(const Context * context,const ShaderCaps * caps,const Program * program,OutputStream * out)144 SPIRVCodeGenerator(const Context* context,
145 const ShaderCaps* caps,
146 const Program* program,
147 OutputStream* out)
148 : INHERITED(context, caps, program, out) {}
149
150 bool generateCode() override;
151
152 private:
153 enum IntrinsicOpcodeKind {
154 kGLSL_STD_450_IntrinsicOpcodeKind,
155 kSPIRV_IntrinsicOpcodeKind,
156 kSpecial_IntrinsicOpcodeKind,
157 kInvalid_IntrinsicOpcodeKind,
158 };
159
160 enum SpecialIntrinsic {
161 kAtan_SpecialIntrinsic,
162 kClamp_SpecialIntrinsic,
163 kMatrixCompMult_SpecialIntrinsic,
164 kMax_SpecialIntrinsic,
165 kMin_SpecialIntrinsic,
166 kMix_SpecialIntrinsic,
167 kMod_SpecialIntrinsic,
168 kDFdy_SpecialIntrinsic,
169 kSaturate_SpecialIntrinsic,
170 kSampledImage_SpecialIntrinsic,
171 kSmoothStep_SpecialIntrinsic,
172 kStep_SpecialIntrinsic,
173 kSubpassLoad_SpecialIntrinsic,
174 kTexture_SpecialIntrinsic,
175 kTextureGrad_SpecialIntrinsic,
176 kTextureLod_SpecialIntrinsic,
177 kTextureRead_SpecialIntrinsic,
178 kTextureWrite_SpecialIntrinsic,
179 kTextureWidth_SpecialIntrinsic,
180 kTextureHeight_SpecialIntrinsic,
181 kAtomicAdd_SpecialIntrinsic,
182 kAtomicLoad_SpecialIntrinsic,
183 kAtomicStore_SpecialIntrinsic,
184 kStorageBarrier_SpecialIntrinsic,
185 kWorkgroupBarrier_SpecialIntrinsic,
186 };
187
188 enum class Precision {
189 kDefault,
190 kRelaxed,
191 };
192
193 struct TempVar {
194 SpvId spvId;
195 const Type* type;
196 std::unique_ptr<SPIRVCodeGenerator::LValue> lvalue;
197 };
198
199 /**
200 * Pass in the type to automatically add a RelaxedPrecision decoration for the id when
201 * appropriate, or null to never add one.
202 */
203 SpvId nextId(const Type* type);
204
205 SpvId nextId(Precision precision);
206
207 SpvId getType(const Type& type);
208
209 SpvId getType(const Type& type, const Layout& typeLayout, const MemoryLayout& memoryLayout);
210
211 SpvId getFunctionType(const FunctionDeclaration& function);
212
213 SpvId getFunctionParameterType(const Type& parameterType, const Layout& parameterLayout);
214
215 SpvId getPointerType(const Type& type, SpvStorageClass_ storageClass);
216
217 SpvId getPointerType(const Type& type,
218 const Layout& typeLayout,
219 const MemoryLayout& memoryLayout,
220 SpvStorageClass_ storageClass);
221
222 skia_private::TArray<SpvId> getAccessChain(const Expression& expr, OutputStream& out);
223
224 void writeLayout(const Layout& layout, SpvId target, Position pos);
225
226 void writeFieldLayout(const Layout& layout, SpvId target, int member);
227
228 SpvId writeStruct(const Type& type, const MemoryLayout& memoryLayout);
229
230 void writeProgramElement(const ProgramElement& pe, OutputStream& out);
231
232 SpvId writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip = true);
233
234 SpvId writeFunctionStart(const FunctionDeclaration& f, OutputStream& out);
235
236 SpvId writeFunctionDeclaration(const FunctionDeclaration& f, OutputStream& out);
237
238 SpvId writeFunction(const FunctionDefinition& f, OutputStream& out);
239
240 bool writeGlobalVarDeclaration(ProgramKind kind, const VarDeclaration& v);
241
242 SpvId writeGlobalVar(ProgramKind kind, SpvStorageClass_, const Variable& v);
243
244 void writeVarDeclaration(const VarDeclaration& var, OutputStream& out);
245
246 SpvId writeVariableReference(const VariableReference& ref, OutputStream& out);
247
248 int findUniformFieldIndex(const Variable& var) const;
249
250 std::unique_ptr<LValue> getLValue(const Expression& value, OutputStream& out);
251
252 SpvId writeExpression(const Expression& expr, OutputStream& out);
253
254 SpvId writeIntrinsicCall(const FunctionCall& c, OutputStream& out);
255
256 SpvId writeFunctionCallArgument(const FunctionCall& call,
257 int argIndex,
258 std::vector<TempVar>* tempVars,
259 OutputStream& out,
260 SpvId* outSynthesizedSamplerId = nullptr);
261
262 void copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out);
263
264 SpvId writeFunctionCall(const FunctionCall& c, OutputStream& out);
265
266
267 void writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
268 SpvId signedInst, SpvId unsignedInst,
269 const skia_private::TArray<SpvId>& args, OutputStream& out);
270
271 /**
272 * Promotes an expression to a vector. If the expression is already a vector with vectorSize
273 * columns, returns it unmodified. If the expression is a scalar, either promotes it to a
274 * vector (if vectorSize > 1) or returns it unmodified (if vectorSize == 1). Asserts if the
275 * expression is already a vector and it does not have vectorSize columns.
276 */
277 SpvId vectorize(const Expression& expr, int vectorSize, OutputStream& out);
278
279 /**
280 * Given a list of potentially mixed scalars and vectors, promotes the scalars to match the
281 * size of the vectors and returns the ids of the written expressions. e.g. given (float, vec2),
282 * returns (vec2(float), vec2). It is an error to use mismatched vector sizes, e.g. (float,
283 * vec2, vec3).
284 */
285 skia_private::TArray<SpvId> vectorize(const ExpressionArray& args, OutputStream& out);
286
287 /**
288 * Given a SpvId of a scalar, splats it across the passed-in type (scalar, vector or matrix) and
289 * returns the SpvId of the new value.
290 */
291 SpvId splat(const Type& type, SpvId id, OutputStream& out);
292
293 SpvId writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, OutputStream& out);
294 SpvId writeAtomicIntrinsic(const FunctionCall& c,
295 SpecialIntrinsic kind,
296 SpvId resultId,
297 OutputStream& out);
298
299 SpvId castScalarToFloat(SpvId inputId, const Type& inputType, const Type& outputType,
300 OutputStream& out);
301
302 SpvId castScalarToSignedInt(SpvId inputId, const Type& inputType, const Type& outputType,
303 OutputStream& out);
304
305 SpvId castScalarToUnsignedInt(SpvId inputId, const Type& inputType, const Type& outputType,
306 OutputStream& out);
307
308 SpvId castScalarToBoolean(SpvId inputId, const Type& inputType, const Type& outputType,
309 OutputStream& out);
310
311 SpvId castScalarToType(SpvId inputExprId, const Type& inputType, const Type& outputType,
312 OutputStream& out);
313
314 /**
315 * Writes a potentially-different-sized copy of a matrix. Entries which do not exist in the
316 * source matrix are filled with zero; entries which do not exist in the destination matrix are
317 * ignored.
318 */
319 SpvId writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType, OutputStream& out);
320
321 void addColumnEntry(const Type& columnType,
322 skia_private::TArray<SpvId>* currentColumn,
323 skia_private::TArray<SpvId>* columnIds,
324 int rows, SpvId entry, OutputStream& out);
325
326 SpvId writeConstructorCompound(const ConstructorCompound& c, OutputStream& out);
327
328 SpvId writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out);
329
330 SpvId writeVectorConstructor(const ConstructorCompound& c, OutputStream& out);
331
332 SpvId writeCompositeConstructor(const AnyConstructor& c, OutputStream& out);
333
334 SpvId writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c, OutputStream& out);
335
336 SpvId writeConstructorMatrixResize(const ConstructorMatrixResize& c, OutputStream& out);
337
338 SpvId writeConstructorScalarCast(const ConstructorScalarCast& c, OutputStream& out);
339
340 SpvId writeConstructorSplat(const ConstructorSplat& c, OutputStream& out);
341
342 SpvId writeConstructorCompoundCast(const ConstructorCompoundCast& c, OutputStream& out);
343
344 SpvId writeFieldAccess(const FieldAccess& f, OutputStream& out);
345
346 SpvId writeSwizzle(const Expression& baseExpr,
347 const ComponentArray& components,
348 OutputStream& out);
349
350 SpvId writeSwizzle(const Swizzle& swizzle, OutputStream& out);
351
352 /**
353 * Folds the potentially-vector result of a logical operation down to a single bool. If
354 * operandType is a vector type, assumes that the intermediate result in id is a bvec of the
355 * same dimensions, and applys all() to it to fold it down to a single bool value. Otherwise,
356 * returns the original id value.
357 */
358 SpvId foldToBool(SpvId id, const Type& operandType, SpvOp op, OutputStream& out);
359
360 SpvId writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ floatOperator,
361 SpvOp_ intOperator, SpvOp_ vectorMergeOperator,
362 SpvOp_ mergeOperator, OutputStream& out);
363
364 SpvId writeStructComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs,
365 OutputStream& out);
366
367 SpvId writeArrayComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs,
368 OutputStream& out);
369
370 // Used by writeStructComparison and writeArrayComparison to logically combine field-by-field
371 // comparisons into an overall comparison result.
372 // - `a.x == b.x` merged with `a.y == b.y` generates `(a.x == b.x) && (a.y == b.y)`
373 // - `a.x != b.x` merged with `a.y != b.y` generates `(a.x != b.x) || (a.y != b.y)`
374 SpvId mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op, OutputStream& out);
375
376 // When the RewriteMatrixVectorMultiply caps bit is set, we manually decompose the M*V
377 // multiplication into a sum of vector-scalar products.
378 SpvId writeDecomposedMatrixVectorMultiply(const Type& leftType,
379 SpvId lhs,
380 const Type& rightType,
381 SpvId rhs,
382 const Type& resultType,
383 OutputStream& out);
384
385 SpvId writeComponentwiseMatrixUnary(const Type& operandType,
386 SpvId operand,
387 SpvOp_ op,
388 OutputStream& out);
389
390 SpvId writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, SpvId rhs,
391 SpvOp_ op, OutputStream& out);
392
393 SpvId writeBinaryOperationComponentwiseIfMatrix(const Type& resultType, const Type& operandType,
394 SpvId lhs, SpvId rhs,
395 SpvOp_ ifFloat, SpvOp_ ifInt,
396 SpvOp_ ifUInt, SpvOp_ ifBool,
397 OutputStream& out);
398
399 SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
400 SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,
401 SpvOp_ ifBool, OutputStream& out);
402
403 SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
404 SpvId rhs, bool writeComponentwiseIfMatrix, SpvOp_ ifFloat,
405 SpvOp_ ifInt, SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out);
406
407 SpvId writeReciprocal(const Type& type, SpvId value, OutputStream& out);
408
409 SpvId writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
410 const Type& rightType, SpvId rhs, const Type& resultType,
411 OutputStream& out);
412
413 SpvId writeBinaryExpression(const BinaryExpression& b, OutputStream& out);
414
415 SpvId writeTernaryExpression(const TernaryExpression& t, OutputStream& out);
416
417 SpvId writeIndexExpression(const IndexExpression& expr, OutputStream& out);
418
419 SpvId writeLogicalAnd(const Expression& left, const Expression& right, OutputStream& out);
420
421 SpvId writeLogicalOr(const Expression& left, const Expression& right, OutputStream& out);
422
423 SpvId writePrefixExpression(const PrefixExpression& p, OutputStream& out);
424
425 SpvId writePostfixExpression(const PostfixExpression& p, OutputStream& out);
426
427 SpvId writeLiteral(const Literal& f);
428
429 SpvId writeLiteral(double value, const Type& type);
430
431 void writeStatement(const Statement& s, OutputStream& out);
432
433 void writeBlock(const Block& b, OutputStream& out);
434
435 void writeIfStatement(const IfStatement& stmt, OutputStream& out);
436
437 void writeForStatement(const ForStatement& f, OutputStream& out);
438
439 void writeDoStatement(const DoStatement& d, OutputStream& out);
440
441 void writeSwitchStatement(const SwitchStatement& s, OutputStream& out);
442
443 void writeReturnStatement(const ReturnStatement& r, OutputStream& out);
444
445 void writeCapabilities(OutputStream& out);
446
447 void writeInstructions(const Program& program, OutputStream& out);
448
449 void writeOpCode(SpvOp_ opCode, int length, OutputStream& out);
450
451 void writeWord(int32_t word, OutputStream& out);
452
453 void writeString(std::string_view s, OutputStream& out);
454
455 void writeInstruction(SpvOp_ opCode, OutputStream& out);
456
457 void writeInstruction(SpvOp_ opCode, std::string_view string, OutputStream& out);
458
459 void writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out);
460
461 void writeInstruction(SpvOp_ opCode, int32_t word1, std::string_view string,
462 OutputStream& out);
463
464 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, std::string_view string,
465 OutputStream& out);
466
467 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, OutputStream& out);
468
469 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3,
470 OutputStream& out);
471
472 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
473 OutputStream& out);
474
475 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
476 int32_t word5, OutputStream& out);
477
478 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
479 int32_t word5, int32_t word6, OutputStream& out);
480
481 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
482 int32_t word5, int32_t word6, int32_t word7, OutputStream& out);
483
484 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
485 int32_t word5, int32_t word6, int32_t word7, int32_t word8,
486 OutputStream& out);
487
488 // This form of writeInstruction can deduplicate redundant ops.
489 struct Word;
490 // 8 Words is enough for nearly all instructions (except variable-length instructions like
491 // OpAccessChain or OpConstantComposite).
492 using Words = skia_private::STArray<8, Word, true>;
493 SpvId writeInstruction(
494 SpvOp_ opCode, const skia_private::TArray<Word, true>& words, OutputStream& out);
495
496 struct Instruction {
497 SpvId fOp;
498 int32_t fResultKind;
499 skia_private::STArray<8, int32_t> fWords;
500
501 bool operator==(const Instruction& that) const;
502 struct Hash;
503 };
504
505 static Instruction BuildInstructionKey(SpvOp_ opCode,
506 const skia_private::TArray<Word, true>& words);
507
508 // The writeOpXxxxx calls will simplify and deduplicate ops where possible.
509 SpvId writeOpConstantTrue(const Type& type);
510 SpvId writeOpConstantFalse(const Type& type);
511 SpvId writeOpConstant(const Type& type, int32_t valueBits);
512 SpvId writeOpConstantComposite(const Type& type, const skia_private::TArray<SpvId>& values);
513 SpvId writeOpCompositeConstruct(const Type& type, const skia_private::TArray<SpvId>&,
514 OutputStream& out);
515 SpvId writeOpCompositeExtract(const Type& type, SpvId base, int component, OutputStream& out);
516 SpvId writeOpCompositeExtract(const Type& type, SpvId base, int componentA, int componentB,
517 OutputStream& out);
518 SpvId writeOpLoad(SpvId type, Precision precision, SpvId pointer, OutputStream& out);
519 void writeOpStore(SpvStorageClass_ storageClass, SpvId pointer, SpvId value, OutputStream& out);
520
521 // Converts the provided SpvId(s) into an array of scalar OpConstants, if it can be done.
522 bool toConstants(SpvId value, skia_private::TArray<SpvId>* constants);
523 bool toConstants(SkSpan<const SpvId> values, skia_private::TArray<SpvId>* constants);
524
525 // Extracts the requested component SpvId from a composite instruction, if it can be done.
526 Instruction* resultTypeForInstruction(const Instruction& instr);
527 int numComponentsForVecInstruction(const Instruction& instr);
528 SpvId toComponent(SpvId id, int component);
529
530 struct ConditionalOpCounts {
531 int numReachableOps;
532 int numStoreOps;
533 };
534 ConditionalOpCounts getConditionalOpCounts();
535 void pruneConditionalOps(ConditionalOpCounts ops);
536
537 enum StraightLineLabelType {
538 // Use "BranchlessBlock" for blocks which are never explicitly branched-to at all. This
539 // happens at the start of a function, or when we find unreachable code.
540 kBranchlessBlock,
541
542 // Use "BranchIsOnPreviousLine" when writing a label that comes immediately after its
543 // associated branch. Example usage:
544 // - SPIR-V does not implicitly fall through from one block to the next, so you may need to
545 // use an OpBranch to explicitly jump to the next block, even when they are adjacent in
546 // the code.
547 // - The block immediately following an OpBranchConditional or OpSwitch.
548 kBranchIsOnPreviousLine,
549 };
550
551 enum BranchingLabelType {
552 // Use "BranchIsAbove" for labels which are referenced by OpBranch or OpBranchConditional
553 // ops that are above the label in the code--i.e., the branch skips forward in the code.
554 kBranchIsAbove,
555
556 // Use "BranchIsBelow" for labels which are referenced by OpBranch or OpBranchConditional
557 // ops below the label in the code--i.e., the branch jumps backward in the code.
558 kBranchIsBelow,
559
560 // Use "BranchesOnBothSides" for labels which have branches coming from both directions.
561 kBranchesOnBothSides,
562 };
563 void writeLabel(SpvId label, StraightLineLabelType type, OutputStream& out);
564 void writeLabel(SpvId label, BranchingLabelType type, ConditionalOpCounts ops,
565 OutputStream& out);
566
567 MemoryLayout memoryLayoutForStorageClass(SpvStorageClass_ storageClass);
568 MemoryLayout memoryLayoutForVariable(const Variable&) const;
569
570 struct EntrypointAdapter {
571 std::unique_ptr<FunctionDefinition> entrypointDef;
572 std::unique_ptr<FunctionDeclaration> entrypointDecl;
573 };
574
575 EntrypointAdapter writeEntrypointAdapter(const FunctionDeclaration& main);
576
577 struct UniformBuffer {
578 std::unique_ptr<InterfaceBlock> fInterfaceBlock;
579 std::unique_ptr<Variable> fInnerVariable;
580 std::unique_ptr<Type> fStruct;
581 };
582
583 void writeUniformBuffer(SymbolTable* topLevelSymbolTable);
584
585 void addRTFlipUniform(Position pos);
586
587 std::unique_ptr<Expression> identifier(std::string_view name);
588
589 std::tuple<const Variable*, const Variable*> synthesizeTextureAndSampler(
590 const Variable& combinedSampler);
591
592 const MemoryLayout fDefaultMemoryLayout{MemoryLayout::Standard::k140};
593
594 uint64_t fCapabilities = 0;
595 SpvId fIdCount = 1;
596 SpvId fGLSLExtendedInstructions;
597 struct Intrinsic {
598 IntrinsicOpcodeKind opKind;
599 int32_t floatOp;
600 int32_t signedOp;
601 int32_t unsignedOp;
602 int32_t boolOp;
603 };
604 Intrinsic getIntrinsic(IntrinsicKind) const;
605 skia_private::THashMap<const FunctionDeclaration*, SpvId> fFunctionMap;
606 skia_private::THashMap<const Variable*, SpvId> fVariableMap;
607 skia_private::THashMap<const Type*, SpvId> fStructMap;
608 StringStream fGlobalInitializersBuffer;
609 StringStream fConstantBuffer;
610 StringStream fVariableBuffer;
611 StringStream fNameBuffer;
612 StringStream fDecorationBuffer;
613
614 // Mapping from combined sampler declarations to synthesized texture/sampler variables.
615 // This is used when the sampler is declared as `layout(webgpu)` or `layout(direct3d)`.
616 bool fUseTextureSamplerPairs = false;
617 struct SynthesizedTextureSamplerPair {
618 // The names of the synthesized variables. The Variable objects themselves store string
619 // views referencing these strings. It is important for the std::string instances to have a
620 // fixed memory location after the string views get created, which is why
621 // `fSynthesizedSamplerMap` stores unique_ptr instead of values.
622 std::string fTextureName;
623 std::string fSamplerName;
624 std::unique_ptr<Variable> fTexture;
625 std::unique_ptr<Variable> fSampler;
626 };
627 skia_private::THashMap<const Variable*, std::unique_ptr<SynthesizedTextureSamplerPair>>
628 fSynthesizedSamplerMap;
629
630 // These caches map SpvIds to Instructions, and vice-versa. This enables us to deduplicate code
631 // (by detecting an Instruction we've already issued and reusing the SpvId), and to introspect
632 // and simplify code we've already emitted (by taking a SpvId from an Instruction and following
633 // it back to its source).
634
635 // A map of instruction -> SpvId:
636 skia_private::THashMap<Instruction, SpvId, Instruction::Hash> fOpCache;
637 // A map of SpvId -> instruction:
638 skia_private::THashMap<SpvId, Instruction> fSpvIdCache;
639 // A map of SpvId -> value SpvId:
640 skia_private::THashMap<SpvId, SpvId> fStoreCache;
641
642 // "Reachable" ops are instructions which can safely be accessed from the current block.
643 // For instance, if our SPIR-V contains `%3 = OpFAdd %1 %2`, we would be able to access and
644 // reuse that computation on following lines. However, if that Add operation occurred inside an
645 // `if` block, then its SpvId becomes inaccessible once we complete the if statement (since
646 // depending on the if condition, we may or may not have actually done that computation). The
647 // same logic applies to other control-flow blocks as well. Once an instruction becomes
648 // unreachable, we remove it from both op-caches.
649 skia_private::TArray<SpvId> fReachableOps;
650
651 // The "store-ops" list contains a running list of all the pointers in the store cache. If a
652 // store occurs inside of a conditional block, once that block exits, we no longer know what is
653 // stored in that particular SpvId. At that point, we must remove any associated entry from the
654 // store cache.
655 skia_private::TArray<SpvId> fStoreOps;
656
657 // label of the current block, or 0 if we are not in a block
658 SpvId fCurrentBlock = 0;
659 skia_private::TArray<SpvId> fBreakTarget;
660 skia_private::TArray<SpvId> fContinueTarget;
661 bool fWroteRTFlip = false;
662 // holds variables synthesized during output, for lifetime purposes
663 SymbolTable fSynthetics{/*builtin=*/true};
664 // Holds a list of uniforms that were declared as globals at the top-level instead of in an
665 // interface block.
666 UniformBuffer fUniformBuffer;
667 std::vector<const VarDeclaration*> fTopLevelUniforms;
668 skia_private::THashMap<const Variable*, int>
669 fTopLevelUniformMap; // <var, UniformBuffer field index>
670 SpvId fUniformBufferId = NA;
671
672 friend class PointerLValue;
673 friend class SwizzleLValue;
674
675 using INHERITED = CodeGenerator;
676 };
677
678 // Equality and hash operators for Instructions.
operator ==(const SPIRVCodeGenerator::Instruction & that) const679 bool SPIRVCodeGenerator::Instruction::operator==(const SPIRVCodeGenerator::Instruction& that) const {
680 return fOp == that.fOp &&
681 fResultKind == that.fResultKind &&
682 fWords == that.fWords;
683 }
684
685 struct SPIRVCodeGenerator::Instruction::Hash {
operator ()SkSL::SPIRVCodeGenerator::Instruction::Hash686 uint32_t operator()(const SPIRVCodeGenerator::Instruction& key) const {
687 uint32_t hash = key.fResultKind;
688 hash = SkChecksum::Hash32(&key.fOp, sizeof(key.fOp), hash);
689 hash = SkChecksum::Hash32(key.fWords.data(), key.fWords.size() * sizeof(int32_t), hash);
690 return hash;
691 }
692 };
693
694 // This class is used to pass values and result placeholder slots to writeInstruction.
695 struct SPIRVCodeGenerator::Word {
696 enum Kind {
697 kNone, // intended for use as a sentinel, not part of any Instruction
698 kSpvId,
699 kNumber,
700 kDefaultPrecisionResult,
701 kRelaxedPrecisionResult,
702 kUniqueResult,
703 kKeyedResult,
704 };
705
WordSkSL::SPIRVCodeGenerator::Word706 Word(SpvId id) : fValue(id), fKind(Kind::kSpvId) {}
WordSkSL::SPIRVCodeGenerator::Word707 Word(int32_t val, Kind kind) : fValue(val), fKind(kind) {}
708
NumberSkSL::SPIRVCodeGenerator::Word709 static Word Number(int32_t val) {
710 return Word{val, Kind::kNumber};
711 }
712
ResultSkSL::SPIRVCodeGenerator::Word713 static Word Result(const Type& type) {
714 return (type.hasPrecision() && !type.highPrecision()) ? RelaxedResult() : Result();
715 }
716
RelaxedResultSkSL::SPIRVCodeGenerator::Word717 static Word RelaxedResult() {
718 return Word{(int32_t)NA, kRelaxedPrecisionResult};
719 }
720
UniqueResultSkSL::SPIRVCodeGenerator::Word721 static Word UniqueResult() {
722 return Word{(int32_t)NA, kUniqueResult};
723 }
724
ResultSkSL::SPIRVCodeGenerator::Word725 static Word Result() {
726 return Word{(int32_t)NA, kDefaultPrecisionResult};
727 }
728
729 // Unlike a Result (where the result ID is always deduplicated to its first instruction) or a
730 // UniqueResult (which always produces a new instruction), a KeyedResult allows an instruction
731 // to be deduplicated among those that share the same `key`.
KeyedResultSkSL::SPIRVCodeGenerator::Word732 static Word KeyedResult(int32_t key) { return Word{key, Kind::kKeyedResult}; }
733
isResultSkSL::SPIRVCodeGenerator::Word734 bool isResult() const { return fKind >= Kind::kDefaultPrecisionResult; }
735
736 int32_t fValue;
737 Kind fKind;
738 };
739
// Value for the SPIR-V header's generator word. Skia's registered generator number (31) occupies
// the top 16 bits; the low 16 bits are left at zero and could be used to version the SkSL
// generator if we ever want to.
// https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/spir-v.xml#L84
static constexpr int32_t SKSL_MAGIC = 31 << 16;
744
// Maps an SkSL intrinsic to the SPIR-V instruction(s) used to implement it.
//
// The returned Intrinsic holds an opcode kind (a GLSL.std.450 extended instruction, a core SPIR-V
// opcode, or a "special" intrinsic that needs custom codegen) followed by four opcodes. The slots
// are ordered float, signed int, unsigned int, bool; presumably the caller picks one by the
// argument's component type (compare pick_by_type) -- TODO confirm at the call site.
// SpvOpUndef in a slot marks "no implementation for this type". Intrinsics not listed here map to
// kInvalid_IntrinsicOpcodeKind with all-zero opcodes.
SPIRVCodeGenerator::Intrinsic SPIRVCodeGenerator::getIntrinsic(IntrinsicKind ik) const {

// ALL_GLSL: the same GLSL.std.450 instruction for every type slot.
#define ALL_GLSL(x) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, GLSLstd450 ## x, \
                              GLSLstd450 ## x, GLSLstd450 ## x, GLSLstd450 ## x}
// BY_TYPE_GLSL: distinct GLSL.std.450 instructions for float/int/uint; bool is unsupported.
#define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, \
                                                       GLSLstd450 ## ifFloat, \
                                                       GLSLstd450 ## ifInt, \
                                                       GLSLstd450 ## ifUInt, \
                                                       SpvOpUndef}
// ALL_SPIRV: the same core SPIR-V opcode for every type slot.
#define ALL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
                               SpvOp ## x, SpvOp ## x, SpvOp ## x, SpvOp ## x}
// BOOL_SPIRV: a core SPIR-V opcode that only exists for the bool slot.
#define BOOL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
                                SpvOpUndef, SpvOpUndef, SpvOpUndef, SpvOp ## x}
// FLOAT_SPIRV: a core SPIR-V opcode that only exists for the float slot.
#define FLOAT_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
                                 SpvOp ## x, SpvOpUndef, SpvOpUndef, SpvOpUndef}
// SPECIAL: handled by hand-written codegen, identified by a k<Name>_SpecialIntrinsic value.
#define SPECIAL(x) Intrinsic{kSpecial_IntrinsicOpcodeKind, k ## x ## _SpecialIntrinsic, \
                             k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
                             k ## x ## _SpecialIntrinsic}
// NOTE(review): unlike PACK below, the macros above are never #undef'd, so they stay defined for
// the rest of this translation unit.

    switch (ik) {
        case k_round_IntrinsicKind:          return ALL_GLSL(Round);
        case k_roundEven_IntrinsicKind:      return ALL_GLSL(RoundEven);
        case k_trunc_IntrinsicKind:          return ALL_GLSL(Trunc);
        case k_abs_IntrinsicKind:            return BY_TYPE_GLSL(FAbs, SAbs, SAbs);
        case k_sign_IntrinsicKind:           return BY_TYPE_GLSL(FSign, SSign, SSign);
        case k_floor_IntrinsicKind:          return ALL_GLSL(Floor);
        case k_ceil_IntrinsicKind:           return ALL_GLSL(Ceil);
        case k_fract_IntrinsicKind:          return ALL_GLSL(Fract);
        case k_radians_IntrinsicKind:        return ALL_GLSL(Radians);
        case k_degrees_IntrinsicKind:        return ALL_GLSL(Degrees);
        case k_sin_IntrinsicKind:            return ALL_GLSL(Sin);
        case k_cos_IntrinsicKind:            return ALL_GLSL(Cos);
        case k_tan_IntrinsicKind:            return ALL_GLSL(Tan);
        case k_asin_IntrinsicKind:           return ALL_GLSL(Asin);
        case k_acos_IntrinsicKind:           return ALL_GLSL(Acos);
        case k_atan_IntrinsicKind:           return SPECIAL(Atan);
        case k_sinh_IntrinsicKind:           return ALL_GLSL(Sinh);
        case k_cosh_IntrinsicKind:           return ALL_GLSL(Cosh);
        case k_tanh_IntrinsicKind:           return ALL_GLSL(Tanh);
        case k_asinh_IntrinsicKind:          return ALL_GLSL(Asinh);
        case k_acosh_IntrinsicKind:          return ALL_GLSL(Acosh);
        case k_atanh_IntrinsicKind:          return ALL_GLSL(Atanh);
        case k_pow_IntrinsicKind:            return ALL_GLSL(Pow);
        case k_exp_IntrinsicKind:            return ALL_GLSL(Exp);
        case k_log_IntrinsicKind:            return ALL_GLSL(Log);
        case k_exp2_IntrinsicKind:           return ALL_GLSL(Exp2);
        case k_log2_IntrinsicKind:           return ALL_GLSL(Log2);
        case k_sqrt_IntrinsicKind:           return ALL_GLSL(Sqrt);
        case k_inverse_IntrinsicKind:        return ALL_GLSL(MatrixInverse);
        case k_outerProduct_IntrinsicKind:   return ALL_SPIRV(OuterProduct);
        case k_transpose_IntrinsicKind:      return ALL_SPIRV(Transpose);
        case k_isinf_IntrinsicKind:          return ALL_SPIRV(IsInf);
        case k_isnan_IntrinsicKind:          return ALL_SPIRV(IsNan);
        case k_inversesqrt_IntrinsicKind:    return ALL_GLSL(InverseSqrt);
        case k_determinant_IntrinsicKind:    return ALL_GLSL(Determinant);
        case k_matrixCompMult_IntrinsicKind: return SPECIAL(MatrixCompMult);
        case k_matrixInverse_IntrinsicKind:  return ALL_GLSL(MatrixInverse);
        case k_mod_IntrinsicKind:            return SPECIAL(Mod);
        case k_modf_IntrinsicKind:           return ALL_GLSL(Modf);
        case k_min_IntrinsicKind:            return SPECIAL(Min);
        case k_max_IntrinsicKind:            return SPECIAL(Max);
        case k_clamp_IntrinsicKind:          return SPECIAL(Clamp);
        case k_saturate_IntrinsicKind:       return SPECIAL(Saturate);
        case k_dot_IntrinsicKind:            return FLOAT_SPIRV(Dot);
        case k_mix_IntrinsicKind:            return SPECIAL(Mix);
        case k_step_IntrinsicKind:           return SPECIAL(Step);
        case k_smoothstep_IntrinsicKind:     return SPECIAL(SmoothStep);
        case k_fma_IntrinsicKind:            return ALL_GLSL(Fma);
        case k_frexp_IntrinsicKind:          return ALL_GLSL(Frexp);
        case k_ldexp_IntrinsicKind:          return ALL_GLSL(Ldexp);

// PACK(type) expands to both the pack<type> and unpack<type> cases.
#define PACK(type) case k_pack##type##_IntrinsicKind:   return ALL_GLSL(Pack##type); \
                   case k_unpack##type##_IntrinsicKind: return ALL_GLSL(Unpack##type)
        PACK(Snorm4x8);
        PACK(Unorm4x8);
        PACK(Snorm2x16);
        PACK(Unorm2x16);
        PACK(Half2x16);
#undef PACK

        case k_length_IntrinsicKind:         return ALL_GLSL(Length);
        case k_distance_IntrinsicKind:       return ALL_GLSL(Distance);
        case k_cross_IntrinsicKind:          return ALL_GLSL(Cross);
        case k_normalize_IntrinsicKind:      return ALL_GLSL(Normalize);
        case k_faceforward_IntrinsicKind:    return ALL_GLSL(FaceForward);
        case k_reflect_IntrinsicKind:        return ALL_GLSL(Reflect);
        case k_refract_IntrinsicKind:        return ALL_GLSL(Refract);
        case k_bitCount_IntrinsicKind:       return ALL_SPIRV(BitCount);
        case k_findLSB_IntrinsicKind:        return ALL_GLSL(FindILsb);
        case k_findMSB_IntrinsicKind:        return BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
        case k_dFdx_IntrinsicKind:           return FLOAT_SPIRV(DPdx);
        // dFdy is special-cased rather than mapped directly to OpDPdy.
        case k_dFdy_IntrinsicKind:           return SPECIAL(DFdy);
        case k_fwidth_IntrinsicKind:         return FLOAT_SPIRV(Fwidth);

        case k_sample_IntrinsicKind:         return SPECIAL(Texture);
        case k_sampleGrad_IntrinsicKind:     return SPECIAL(TextureGrad);
        case k_sampleLod_IntrinsicKind:      return SPECIAL(TextureLod);
        case k_subpassLoad_IntrinsicKind:    return SPECIAL(SubpassLoad);

        case k_textureRead_IntrinsicKind:    return SPECIAL(TextureRead);
        case k_textureWrite_IntrinsicKind:   return SPECIAL(TextureWrite);
        case k_textureWidth_IntrinsicKind:   return SPECIAL(TextureWidth);
        case k_textureHeight_IntrinsicKind:  return SPECIAL(TextureHeight);

        // Bit reinterpretation between float and (u)int is a plain OpBitcast.
        case k_floatBitsToInt_IntrinsicKind:  return ALL_SPIRV(Bitcast);
        case k_floatBitsToUint_IntrinsicKind: return ALL_SPIRV(Bitcast);
        case k_intBitsToFloat_IntrinsicKind:  return ALL_SPIRV(Bitcast);
        case k_uintBitsToFloat_IntrinsicKind: return ALL_SPIRV(Bitcast);

        case k_any_IntrinsicKind:            return BOOL_SPIRV(Any);
        case k_all_IntrinsicKind:            return BOOL_SPIRV(All);
        case k_not_IntrinsicKind:            return BOOL_SPIRV(LogicalNot);

        // Component-wise comparisons: each type slot gets the matching SPIR-V comparison.
        case k_equal_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdEqual,
                             SpvOpIEqual,
                             SpvOpIEqual,
                             SpvOpLogicalEqual};
        // Note the *unordered* float compare here (vs. ordered for the others).
        case k_notEqual_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFUnordNotEqual,
                             SpvOpINotEqual,
                             SpvOpINotEqual,
                             SpvOpLogicalNotEqual};
        case k_lessThan_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdLessThan,
                             SpvOpSLessThan,
                             SpvOpULessThan,
                             SpvOpUndef};
        case k_lessThanEqual_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdLessThanEqual,
                             SpvOpSLessThanEqual,
                             SpvOpULessThanEqual,
                             SpvOpUndef};
        case k_greaterThan_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdGreaterThan,
                             SpvOpSGreaterThan,
                             SpvOpUGreaterThan,
                             SpvOpUndef};
        case k_greaterThanEqual_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdGreaterThanEqual,
                             SpvOpSGreaterThanEqual,
                             SpvOpUGreaterThanEqual,
                             SpvOpUndef};

        case k_atomicAdd_IntrinsicKind:      return SPECIAL(AtomicAdd);
        case k_atomicLoad_IntrinsicKind:     return SPECIAL(AtomicLoad);
        case k_atomicStore_IntrinsicKind:    return SPECIAL(AtomicStore);

        case k_storageBarrier_IntrinsicKind:   return SPECIAL(StorageBarrier);
        case k_workgroupBarrier_IntrinsicKind: return SPECIAL(WorkgroupBarrier);
        default:
            // Not an intrinsic this generator knows how to emit.
            return Intrinsic{kInvalid_IntrinsicOpcodeKind, 0, 0, 0, 0};
    }
}
905
writeWord(int32_t word,OutputStream & out)906 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
907 out.write((const char*) &word, sizeof(word));
908 }
909
is_float(const Type & type)910 static bool is_float(const Type& type) {
911 return (type.isScalar() || type.isVector() || type.isMatrix()) &&
912 type.componentType().isFloat();
913 }
914
is_signed(const Type & type)915 static bool is_signed(const Type& type) {
916 return (type.isScalar() || type.isVector()) && type.componentType().isSigned();
917 }
918
is_unsigned(const Type & type)919 static bool is_unsigned(const Type& type) {
920 return (type.isScalar() || type.isVector()) && type.componentType().isUnsigned();
921 }
922
is_bool(const Type & type)923 static bool is_bool(const Type& type) {
924 return (type.isScalar() || type.isVector()) && type.componentType().isBoolean();
925 }
926
927 template <typename T>
pick_by_type(const Type & type,T ifFloat,T ifInt,T ifUInt,T ifBool)928 static T pick_by_type(const Type& type, T ifFloat, T ifInt, T ifUInt, T ifBool) {
929 if (is_float(type)) {
930 return ifFloat;
931 }
932 if (is_signed(type)) {
933 return ifInt;
934 }
935 if (is_unsigned(type)) {
936 return ifUInt;
937 }
938 if (is_bool(type)) {
939 return ifBool;
940 }
941 SkDEBUGFAIL("unrecognized type");
942 return ifFloat;
943 }
944
is_out(ModifierFlags f)945 static bool is_out(ModifierFlags f) {
946 return SkToBool(f & ModifierFlag::kOut);
947 }
948
is_in(ModifierFlags f)949 static bool is_in(ModifierFlags f) {
950 if (f & ModifierFlag::kIn) {
951 return true; // `in` and `inout` both count
952 }
953 // If neither in/out flag is set, the type is implicitly `in`.
954 return !SkToBool(f & ModifierFlag::kOut);
955 }
956
is_control_flow_op(SpvOp_ op)957 static bool is_control_flow_op(SpvOp_ op) {
958 switch (op) {
959 case SpvOpReturn:
960 case SpvOpReturnValue:
961 case SpvOpKill:
962 case SpvOpSwitch:
963 case SpvOpBranch:
964 case SpvOpBranchConditional:
965 return true;
966 default:
967 return false;
968 }
969 }
970
is_globally_reachable_op(SpvOp_ op)971 static bool is_globally_reachable_op(SpvOp_ op) {
972 switch (op) {
973 case SpvOpConstant:
974 case SpvOpConstantTrue:
975 case SpvOpConstantFalse:
976 case SpvOpConstantComposite:
977 case SpvOpTypeVoid:
978 case SpvOpTypeInt:
979 case SpvOpTypeFloat:
980 case SpvOpTypeBool:
981 case SpvOpTypeVector:
982 case SpvOpTypeMatrix:
983 case SpvOpTypeArray:
984 case SpvOpTypePointer:
985 case SpvOpTypeFunction:
986 case SpvOpTypeRuntimeArray:
987 case SpvOpTypeStruct:
988 case SpvOpTypeImage:
989 case SpvOpTypeSampledImage:
990 case SpvOpTypeSampler:
991 case SpvOpVariable:
992 case SpvOpFunction:
993 case SpvOpFunctionParameter:
994 case SpvOpFunctionEnd:
995 case SpvOpExecutionMode:
996 case SpvOpMemoryModel:
997 case SpvOpCapability:
998 case SpvOpExtInstImport:
999 case SpvOpEntryPoint:
1000 case SpvOpSource:
1001 case SpvOpSourceExtension:
1002 case SpvOpName:
1003 case SpvOpMemberName:
1004 case SpvOpDecorate:
1005 case SpvOpMemberDecorate:
1006 return true;
1007 default:
1008 return false;
1009 }
1010 }
1011
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)1012 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
1013 SkASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
1014 SkASSERT(opCode != SpvOpUndef);
1015 bool foundDeadCode = false;
1016 if (is_control_flow_op(opCode)) {
1017 // This instruction causes us to leave the current block.
1018 foundDeadCode = (fCurrentBlock == 0);
1019 fCurrentBlock = 0;
1020 } else if (!is_globally_reachable_op(opCode)) {
1021 foundDeadCode = (fCurrentBlock == 0);
1022 }
1023
1024 if (foundDeadCode) {
1025 // We just encountered dead code--an instruction that don't have an associated block.
1026 // Synthesize a label if this happens; this is necessary to satisfy the validator.
1027 this->writeLabel(this->nextId(nullptr), kBranchlessBlock, out);
1028 }
1029
1030 this->writeWord((length << 16) | opCode, out);
1031 }
1032
writeLabel(SpvId label,StraightLineLabelType,OutputStream & out)1033 void SPIRVCodeGenerator::writeLabel(SpvId label, StraightLineLabelType, OutputStream& out) {
1034 // The straight-line label type is not important; in any case, no caches are invalidated.
1035 SkASSERT(!fCurrentBlock);
1036 fCurrentBlock = label;
1037 this->writeInstruction(SpvOpLabel, label, out);
1038 }
1039
writeLabel(SpvId label,BranchingLabelType type,ConditionalOpCounts ops,OutputStream & out)1040 void SPIRVCodeGenerator::writeLabel(SpvId label, BranchingLabelType type,
1041 ConditionalOpCounts ops, OutputStream& out) {
1042 switch (type) {
1043 case kBranchIsBelow:
1044 case kBranchesOnBothSides:
1045 // With a backward or bidirectional branch, we haven't seen the code between the label
1046 // and the branch yet, so any stored value is potentially suspect. Without scanning
1047 // ahead to check, the only safe option is to ditch the store cache entirely.
1048 fStoreCache.reset();
1049 [[fallthrough]];
1050
1051 case kBranchIsAbove:
1052 // With a forward branch, we can rely on stores that we had cached at the start of the
1053 // statement/expression, if they haven't been touched yet. Anything newer than that is
1054 // pruned.
1055 this->pruneConditionalOps(ops);
1056 break;
1057 }
1058
1059 // Emit the label.
1060 this->writeLabel(label, kBranchlessBlock, out);
1061 }
1062
writeInstruction(SpvOp_ opCode,OutputStream & out)1063 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
1064 this->writeOpCode(opCode, 1, out);
1065 }
1066
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)1067 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
1068 this->writeOpCode(opCode, 2, out);
1069 this->writeWord(word1, out);
1070 }
1071
writeString(std::string_view s,OutputStream & out)1072 void SPIRVCodeGenerator::writeString(std::string_view s, OutputStream& out) {
1073 out.write(s.data(), s.length());
1074 switch (s.length() % 4) {
1075 case 1:
1076 out.write8(0);
1077 [[fallthrough]];
1078 case 2:
1079 out.write8(0);
1080 [[fallthrough]];
1081 case 3:
1082 out.write8(0);
1083 break;
1084 default:
1085 this->writeWord(0, out);
1086 break;
1087 }
1088 }
1089
writeInstruction(SpvOp_ opCode,std::string_view string,OutputStream & out)1090 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, std::string_view string,
1091 OutputStream& out) {
1092 this->writeOpCode(opCode, 1 + (string.length() + 4) / 4, out);
1093 this->writeString(string, out);
1094 }
1095
writeInstruction(SpvOp_ opCode,int32_t word1,std::string_view string,OutputStream & out)1096 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, std::string_view string,
1097 OutputStream& out) {
1098 this->writeOpCode(opCode, 2 + (string.length() + 4) / 4, out);
1099 this->writeWord(word1, out);
1100 this->writeString(string, out);
1101 }
1102
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,std::string_view string,OutputStream & out)1103 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1104 std::string_view string, OutputStream& out) {
1105 this->writeOpCode(opCode, 3 + (string.length() + 4) / 4, out);
1106 this->writeWord(word1, out);
1107 this->writeWord(word2, out);
1108 this->writeString(string, out);
1109 }
1110
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)1111 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1112 OutputStream& out) {
1113 this->writeOpCode(opCode, 3, out);
1114 this->writeWord(word1, out);
1115 this->writeWord(word2, out);
1116 }
1117
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)1118 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1119 int32_t word3, OutputStream& out) {
1120 this->writeOpCode(opCode, 4, out);
1121 this->writeWord(word1, out);
1122 this->writeWord(word2, out);
1123 this->writeWord(word3, out);
1124 }
1125
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)1126 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1127 int32_t word3, int32_t word4, OutputStream& out) {
1128 this->writeOpCode(opCode, 5, out);
1129 this->writeWord(word1, out);
1130 this->writeWord(word2, out);
1131 this->writeWord(word3, out);
1132 this->writeWord(word4, out);
1133 }
1134
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)1135 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1136 int32_t word3, int32_t word4, int32_t word5,
1137 OutputStream& out) {
1138 this->writeOpCode(opCode, 6, out);
1139 this->writeWord(word1, out);
1140 this->writeWord(word2, out);
1141 this->writeWord(word3, out);
1142 this->writeWord(word4, out);
1143 this->writeWord(word5, out);
1144 }
1145
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)1146 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1147 int32_t word3, int32_t word4, int32_t word5,
1148 int32_t word6, OutputStream& out) {
1149 this->writeOpCode(opCode, 7, out);
1150 this->writeWord(word1, out);
1151 this->writeWord(word2, out);
1152 this->writeWord(word3, out);
1153 this->writeWord(word4, out);
1154 this->writeWord(word5, out);
1155 this->writeWord(word6, out);
1156 }
1157
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)1158 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1159 int32_t word3, int32_t word4, int32_t word5,
1160 int32_t word6, int32_t word7, OutputStream& out) {
1161 this->writeOpCode(opCode, 8, out);
1162 this->writeWord(word1, out);
1163 this->writeWord(word2, out);
1164 this->writeWord(word3, out);
1165 this->writeWord(word4, out);
1166 this->writeWord(word5, out);
1167 this->writeWord(word6, out);
1168 this->writeWord(word7, out);
1169 }
1170
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)1171 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1172 int32_t word3, int32_t word4, int32_t word5,
1173 int32_t word6, int32_t word7, int32_t word8,
1174 OutputStream& out) {
1175 this->writeOpCode(opCode, 9, out);
1176 this->writeWord(word1, out);
1177 this->writeWord(word2, out);
1178 this->writeWord(word3, out);
1179 this->writeWord(word4, out);
1180 this->writeWord(word5, out);
1181 this->writeWord(word6, out);
1182 this->writeWord(word7, out);
1183 this->writeWord(word8, out);
1184 }
1185
BuildInstructionKey(SpvOp_ opCode,const TArray<Word> & words)1186 SPIRVCodeGenerator::Instruction SPIRVCodeGenerator::BuildInstructionKey(SpvOp_ opCode,
1187 const TArray<Word>& words) {
1188 // Assemble a cache key for this instruction.
1189 Instruction key;
1190 key.fOp = opCode;
1191 key.fWords.resize(words.size());
1192 key.fResultKind = Word::Kind::kNone;
1193
1194 for (int index = 0; index < words.size(); ++index) {
1195 const Word& word = words[index];
1196 key.fWords[index] = word.fValue;
1197 if (word.isResult()) {
1198 SkASSERT(key.fResultKind == Word::Kind::kNone);
1199 key.fResultKind = word.fKind;
1200 }
1201 }
1202
1203 return key;
1204 }
1205
writeInstruction(SpvOp_ opCode,const TArray<Word> & words,OutputStream & out)1206 SpvId SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode,
1207 const TArray<Word>& words,
1208 OutputStream& out) {
1209 // writeOpLoad and writeOpStore have dedicated code.
1210 SkASSERT(opCode != SpvOpLoad);
1211 SkASSERT(opCode != SpvOpStore);
1212
1213 // If this instruction exists in our op cache, return the cached SpvId.
1214 Instruction key = BuildInstructionKey(opCode, words);
1215 if (SpvId* cachedOp = fOpCache.find(key)) {
1216 return *cachedOp;
1217 }
1218
1219 SpvId result = NA;
1220 Precision precision = Precision::kDefault;
1221
1222 switch (key.fResultKind) {
1223 case Word::Kind::kUniqueResult:
1224 // The instruction returns a SpvId, but we do not want deduplication.
1225 result = this->nextId(Precision::kDefault);
1226 fSpvIdCache.set(result, key);
1227 break;
1228
1229 case Word::Kind::kNone:
1230 // The instruction doesn't return a SpvId, but we can still cache and deduplicate it.
1231 fOpCache.set(key, result);
1232 break;
1233
1234 case Word::Kind::kRelaxedPrecisionResult:
1235 precision = Precision::kRelaxed;
1236 [[fallthrough]];
1237
1238 case Word::Kind::kKeyedResult:
1239 [[fallthrough]];
1240
1241 case Word::Kind::kDefaultPrecisionResult:
1242 // Consume a new SpvId.
1243 result = this->nextId(precision);
1244 fOpCache.set(key, result);
1245 fSpvIdCache.set(result, key);
1246
1247 // Globally-reachable ops are not subject to the whims of flow control.
1248 if (!is_globally_reachable_op(opCode)) {
1249 fReachableOps.push_back(result);
1250 }
1251 break;
1252
1253 default:
1254 SkDEBUGFAIL("unexpected result kind");
1255 break;
1256 }
1257
1258 // Write the requested instruction.
1259 this->writeOpCode(opCode, words.size() + 1, out);
1260 for (const Word& word : words) {
1261 if (word.isResult()) {
1262 SkASSERT(result != NA);
1263 this->writeWord(result, out);
1264 } else {
1265 this->writeWord(word.fValue, out);
1266 }
1267 }
1268
1269 // Return the result.
1270 return result;
1271 }
1272
writeOpLoad(SpvId type,Precision precision,SpvId pointer,OutputStream & out)1273 SpvId SPIRVCodeGenerator::writeOpLoad(SpvId type,
1274 Precision precision,
1275 SpvId pointer,
1276 OutputStream& out) {
1277 // Look for this pointer in our load-cache.
1278 if (SpvId* cachedOp = fStoreCache.find(pointer)) {
1279 return *cachedOp;
1280 }
1281
1282 // Write the requested OpLoad instruction.
1283 SpvId result = this->nextId(precision);
1284 this->writeInstruction(SpvOpLoad, type, result, pointer, out);
1285 return result;
1286 }
1287
writeOpStore(SpvStorageClass_ storageClass,SpvId pointer,SpvId value,OutputStream & out)1288 void SPIRVCodeGenerator::writeOpStore(SpvStorageClass_ storageClass,
1289 SpvId pointer,
1290 SpvId value,
1291 OutputStream& out) {
1292 // Write the uncached SpvOpStore directly.
1293 this->writeInstruction(SpvOpStore, pointer, value, out);
1294
1295 if (storageClass == SpvStorageClassFunction) {
1296 // Insert a pointer-to-SpvId mapping into the load cache. A writeOpLoad to this pointer will
1297 // return the cached value as-is.
1298 fStoreCache.set(pointer, value);
1299 fStoreOps.push_back(pointer);
1300 }
1301 }
1302
writeOpConstantTrue(const Type & type)1303 SpvId SPIRVCodeGenerator::writeOpConstantTrue(const Type& type) {
1304 return this->writeInstruction(SpvOpConstantTrue,
1305 Words{this->getType(type), Word::Result()},
1306 fConstantBuffer);
1307 }
1308
writeOpConstantFalse(const Type & type)1309 SpvId SPIRVCodeGenerator::writeOpConstantFalse(const Type& type) {
1310 return this->writeInstruction(SpvOpConstantFalse,
1311 Words{this->getType(type), Word::Result()},
1312 fConstantBuffer);
1313 }
1314
writeOpConstant(const Type & type,int32_t valueBits)1315 SpvId SPIRVCodeGenerator::writeOpConstant(const Type& type, int32_t valueBits) {
1316 return this->writeInstruction(
1317 SpvOpConstant,
1318 Words{this->getType(type), Word::Result(), Word::Number(valueBits)},
1319 fConstantBuffer);
1320 }
1321
writeOpConstantComposite(const Type & type,const TArray<SpvId> & values)1322 SpvId SPIRVCodeGenerator::writeOpConstantComposite(const Type& type,
1323 const TArray<SpvId>& values) {
1324 SkASSERT(values.size() == (type.isStruct() ? SkToInt(type.fields().size()) : type.columns()));
1325
1326 Words words;
1327 words.push_back(this->getType(type));
1328 words.push_back(Word::Result());
1329 for (SpvId value : values) {
1330 words.push_back(value);
1331 }
1332 return this->writeInstruction(SpvOpConstantComposite, words, fConstantBuffer);
1333 }
1334
toConstants(SpvId value,TArray<SpvId> * constants)1335 bool SPIRVCodeGenerator::toConstants(SpvId value, TArray<SpvId>* constants) {
1336 Instruction* instr = fSpvIdCache.find(value);
1337 if (!instr) {
1338 return false;
1339 }
1340 switch (instr->fOp) {
1341 case SpvOpConstant:
1342 case SpvOpConstantTrue:
1343 case SpvOpConstantFalse:
1344 constants->push_back(value);
1345 return true;
1346
1347 case SpvOpConstantComposite: // OpConstantComposite ResultType ResultID Constituents...
1348 // Start at word 2 to skip past ResultType and ResultID.
1349 for (int i = 2; i < instr->fWords.size(); ++i) {
1350 if (!this->toConstants(instr->fWords[i], constants)) {
1351 return false;
1352 }
1353 }
1354 return true;
1355
1356 default:
1357 return false;
1358 }
1359 }
1360
toConstants(SkSpan<const SpvId> values,TArray<SpvId> * constants)1361 bool SPIRVCodeGenerator::toConstants(SkSpan<const SpvId> values, TArray<SpvId>* constants) {
1362 for (SpvId value : values) {
1363 if (!this->toConstants(value, constants)) {
1364 return false;
1365 }
1366 }
1367 return true;
1368 }
1369
writeOpCompositeConstruct(const Type & type,const TArray<SpvId> & values,OutputStream & out)1370 SpvId SPIRVCodeGenerator::writeOpCompositeConstruct(const Type& type,
1371 const TArray<SpvId>& values,
1372 OutputStream& out) {
1373 // If this is a vector composed entirely of literals, write a constant-composite instead.
1374 if (type.isVector()) {
1375 STArray<4, SpvId> constants;
1376 if (this->toConstants(SkSpan(values), &constants)) {
1377 // Create a vector from literals.
1378 return this->writeOpConstantComposite(type, constants);
1379 }
1380 }
1381
1382 // If this is a matrix composed entirely of literals, constant-composite them instead.
1383 if (type.isMatrix()) {
1384 STArray<16, SpvId> constants;
1385 if (this->toConstants(SkSpan(values), &constants)) {
1386 // Create each matrix column.
1387 SkASSERT(type.isMatrix());
1388 const Type& vecType = type.columnType(fContext);
1389 STArray<4, SpvId> columnIDs;
1390 for (int index=0; index < type.columns(); ++index) {
1391 STArray<4, SpvId> columnConstants(&constants[index * type.rows()],
1392 type.rows());
1393 columnIDs.push_back(this->writeOpConstantComposite(vecType, columnConstants));
1394 }
1395 // Compose the matrix from its columns.
1396 return this->writeOpConstantComposite(type, columnIDs);
1397 }
1398 }
1399
1400 Words words;
1401 words.push_back(this->getType(type));
1402 words.push_back(Word::Result(type));
1403 for (SpvId value : values) {
1404 words.push_back(value);
1405 }
1406
1407 return this->writeInstruction(SpvOpCompositeConstruct, words, out);
1408 }
1409
resultTypeForInstruction(const Instruction & instr)1410 SPIRVCodeGenerator::Instruction* SPIRVCodeGenerator::resultTypeForInstruction(
1411 const Instruction& instr) {
1412 // This list should contain every op that we cache that has a result and result-type.
1413 // (If one is missing, we will not find some optimization opportunities.)
1414 // Generally, the result type of an op is in the 0th word, but I'm not sure if this is
1415 // universally true, so it's configurable on a per-op basis.
1416 int resultTypeWord;
1417 switch (instr.fOp) {
1418 case SpvOpConstant:
1419 case SpvOpConstantTrue:
1420 case SpvOpConstantFalse:
1421 case SpvOpConstantComposite:
1422 case SpvOpCompositeConstruct:
1423 case SpvOpCompositeExtract:
1424 case SpvOpLoad:
1425 resultTypeWord = 0;
1426 break;
1427
1428 default:
1429 return nullptr;
1430 }
1431
1432 Instruction* typeInstr = fSpvIdCache.find(instr.fWords[resultTypeWord]);
1433 SkASSERT(typeInstr);
1434 return typeInstr;
1435 }
1436
numComponentsForVecInstruction(const Instruction & instr)1437 int SPIRVCodeGenerator::numComponentsForVecInstruction(const Instruction& instr) {
1438 // If an instruction is in the op cache, its type should be as well.
1439 Instruction* typeInstr = this->resultTypeForInstruction(instr);
1440 SkASSERT(typeInstr);
1441 SkASSERT(typeInstr->fOp == SpvOpTypeVector || typeInstr->fOp == SpvOpTypeFloat ||
1442 typeInstr->fOp == SpvOpTypeInt || typeInstr->fOp == SpvOpTypeBool);
1443
1444 // For vectors, extract their column count. Scalars have one component by definition.
1445 // SpvOpTypeVector ResultID ComponentType NumComponents
1446 return (typeInstr->fOp == SpvOpTypeVector) ? typeInstr->fWords[2]
1447 : 1;
1448 }
1449
toComponent(SpvId id,int component)1450 SpvId SPIRVCodeGenerator::toComponent(SpvId id, int component) {
1451 Instruction* instr = fSpvIdCache.find(id);
1452 if (!instr) {
1453 return NA;
1454 }
1455 if (instr->fOp == SpvOpConstantComposite) {
1456 // SpvOpConstantComposite ResultType ResultID [components...]
1457 // Add 2 to the component index to skip past ResultType and ResultID.
1458 return instr->fWords[2 + component];
1459 }
1460 if (instr->fOp == SpvOpCompositeConstruct) {
1461 // SpvOpCompositeConstruct ResultType ResultID [components...]
1462 // Vectors have special rules; check to see if we are composing a vector.
1463 Instruction* composedType = fSpvIdCache.find(instr->fWords[0]);
1464 SkASSERT(composedType);
1465
1466 // When composing a non-vector, each instruction word maps 1:1 to the component index.
1467 // We can just extract out the associated component directly.
1468 if (composedType->fOp != SpvOpTypeVector) {
1469 return instr->fWords[2 + component];
1470 }
1471
1472 // When composing a vector, components can be either scalars or vectors.
1473 // This means we need to check the op type on each component. (+2 to skip ResultType/Result)
1474 for (int index = 2; index < instr->fWords.size(); ++index) {
1475 int32_t currentWord = instr->fWords[index];
1476
1477 // Retrieve the sub-instruction pointed to by OpCompositeConstruct.
1478 Instruction* subinstr = fSpvIdCache.find(currentWord);
1479 if (!subinstr) {
1480 return NA;
1481 }
1482 // If this subinstruction contains the component we're looking for...
1483 int numComponents = this->numComponentsForVecInstruction(*subinstr);
1484 if (component < numComponents) {
1485 if (numComponents == 1) {
1486 // ... it's a scalar. Return it.
1487 SkASSERT(component == 0);
1488 return currentWord;
1489 } else {
1490 // ... it's a vector. Recurse into it.
1491 return this->toComponent(currentWord, component);
1492 }
1493 }
1494 // This sub-instruction doesn't contain our component. Keep walking forward.
1495 component -= numComponents;
1496 }
1497 SkDEBUGFAIL("component index goes past the end of this composite value");
1498 return NA;
1499 }
1500 return NA;
1501 }
1502
writeOpCompositeExtract(const Type & type,SpvId base,int component,OutputStream & out)1503 SpvId SPIRVCodeGenerator::writeOpCompositeExtract(const Type& type,
1504 SpvId base,
1505 int component,
1506 OutputStream& out) {
1507 // If the base op is a composite, we can extract from it directly.
1508 SpvId result = this->toComponent(base, component);
1509 if (result != NA) {
1510 return result;
1511 }
1512 return this->writeInstruction(
1513 SpvOpCompositeExtract,
1514 {this->getType(type), Word::Result(type), base, Word::Number(component)},
1515 out);
1516 }
1517
writeOpCompositeExtract(const Type & type,SpvId base,int componentA,int componentB,OutputStream & out)1518 SpvId SPIRVCodeGenerator::writeOpCompositeExtract(const Type& type,
1519 SpvId base,
1520 int componentA,
1521 int componentB,
1522 OutputStream& out) {
1523 // If the base op is a composite, we can extract from it directly.
1524 SpvId result = this->toComponent(base, componentA);
1525 if (result != NA) {
1526 return this->writeOpCompositeExtract(type, result, componentB, out);
1527 }
1528 return this->writeInstruction(SpvOpCompositeExtract,
1529 {this->getType(type),
1530 Word::Result(type),
1531 base,
1532 Word::Number(componentA),
1533 Word::Number(componentB)},
1534 out);
1535 }
1536
writeCapabilities(OutputStream & out)1537 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
1538 for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
1539 if (fCapabilities & bit) {
1540 this->writeInstruction(SpvOpCapability, (SpvId) i, out);
1541 }
1542 }
1543 this->writeInstruction(SpvOpCapability, SpvCapabilityShader, out);
1544 }
1545
nextId(const Type * type)1546 SpvId SPIRVCodeGenerator::nextId(const Type* type) {
1547 return this->nextId(type && type->hasPrecision() && !type->highPrecision()
1548 ? Precision::kRelaxed
1549 : Precision::kDefault);
1550 }
1551
nextId(Precision precision)1552 SpvId SPIRVCodeGenerator::nextId(Precision precision) {
1553 if (precision == Precision::kRelaxed && !fProgram.fConfig->fSettings.fForceHighPrecision) {
1554 this->writeInstruction(SpvOpDecorate, fIdCount, SpvDecorationRelaxedPrecision,
1555 fDecorationBuffer);
1556 }
1557 return fIdCount++;
1558 }
1559
// Emits an OpTypeStruct for `type` (at most once per Type, via fStructMap) and returns its SpvId.
// Also emits the struct's OpName and OpMemberName debug names and, per member, the Offset,
// ColMajor/MatrixStride and RelaxedPrecision decorations, plus whatever writeFieldLayout adds.
// Member offsets follow `memoryLayout`; explicit `offset` layout qualifiers are validated against
// it and reported through fContext.fErrors.
SpvId SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout) {
    // If we've already written out this struct, return its existing SpvId.
    if (SpvId* cachedStructId = fStructMap.find(&type)) {
        return *cachedStructId;
    }

    // Write all of the field types first, so we don't inadvertently write them while we're in the
    // middle of writing the struct instruction.
    Words words;
    // UniqueResult: writeInstruction never deduplicates this, so each struct Type gets its own id.
    words.push_back(Word::UniqueResult());
    for (const auto& f : type.fields()) {
        words.push_back(this->getType(*f.fType, f.fLayout, memoryLayout));
    }
    SpvId resultId = this->writeInstruction(SpvOpTypeStruct, words, fConstantBuffer);
    this->writeInstruction(SpvOpName, resultId, type.name(), fNameBuffer);
    fStructMap.set(&type, resultId);

    // Running byte offset of the current member within the struct.
    size_t offset = 0;
    for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
        const Field& field = type.fields()[i];
        if (!memoryLayout.isSupported(*field.fType)) {
            // Report and stop decorating; the struct id (already cached above) is still returned.
            fContext.fErrors->error(type.fPosition, "type '" + field.fType->displayName() +
                                                    "' is not permitted here");
            return resultId;
        }
        size_t size = memoryLayout.size(*field.fType);
        size_t alignment = memoryLayout.alignment(*field.fType);
        const Layout& fieldLayout = field.fLayout;
        if (fieldLayout.fOffset >= 0) {
            // Explicit offset: it must not overlap the previous member and must respect the
            // member's alignment. Errors are reported, but the given offset is still used.
            if (fieldLayout.fOffset < (int) offset) {
                fContext.fErrors->error(field.fPosition, "offset of field '" +
                        std::string(field.fName) + "' must be at least " + std::to_string(offset));
            }
            if (fieldLayout.fOffset % alignment) {
                fContext.fErrors->error(field.fPosition,
                                        "offset of field '" + std::string(field.fName) +
                                        "' must be a multiple of " + std::to_string(alignment));
            }
            offset = fieldLayout.fOffset;
        } else {
            // No explicit offset: round the running offset up to the member's alignment.
            size_t mod = offset % alignment;
            if (mod) {
                offset += alignment - mod;
            }
        }
        this->writeInstruction(SpvOpMemberName, resultId, i, field.fName, fNameBuffer);
        this->writeFieldLayout(fieldLayout, resultId, i);
        // Offset is only emitted when fBuiltin < 0 (presumably: the member is not a built-in).
        if (field.fLayout.fBuiltin < 0) {
            this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
                                   (SpvId) offset, fDecorationBuffer);
        }
        // Matrix members are declared column-major with the layout's matrix stride.
        if (field.fType->isMatrix()) {
            this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
                                   fDecorationBuffer);
            this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
                                   (SpvId) memoryLayout.stride(*field.fType),
                                   fDecorationBuffer);
        }
        // Members whose type is not high precision are marked RelaxedPrecision.
        if (!field.fType->highPrecision()) {
            this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i,
                                   SpvDecorationRelaxedPrecision, fDecorationBuffer);
        }
        offset += size;
        // After an array or struct member, round the running offset up to that member's alignment.
        if ((field.fType->isArray() || field.fType->isStruct()) && offset % alignment != 0) {
            offset += alignment - offset % alignment;
        }
    }

    return resultId;
}
1630
getType(const Type & type)1631 SpvId SPIRVCodeGenerator::getType(const Type& type) {
1632 return this->getType(type, kDefaultTypeLayout, fDefaultMemoryLayout);
1633 }
1634
layout_flags_to_image_format(LayoutFlags flags)1635 static SpvImageFormat layout_flags_to_image_format(LayoutFlags flags) {
1636 flags &= LayoutFlag::kAllPixelFormats;
1637 switch (flags.value()) {
1638 case (int)LayoutFlag::kRGBA8:
1639 return SpvImageFormatRgba8;
1640
1641 case (int)LayoutFlag::kRGBA32F:
1642 return SpvImageFormatRgba32f;
1643
1644 case (int)LayoutFlag::kR32F:
1645 return SpvImageFormatR32f;
1646
1647 default:
1648 return SpvImageFormatUnknown;
1649 }
1650
1651 SkUNREACHABLE;
1652 }
1653
// Returns the SpvId of the SPIR-V type declaration for `rawType`, writing the declaration (and
// any component/element types it depends on) into fConstantBuffer.
//   typeLayout   - forwarded to component/element types; for non-sampled, non-subpass textures
//                  its pixel-format flags choose the OpTypeImage image format.
//   memoryLayout - used to validate array types, compute array strides, and lay out structs.
// Returns NA for types that cannot be represented, after reporting an error (unsupported array
// layout) or a debug failure (unrecognized scalar / type kind).
SpvId SPIRVCodeGenerator::getType(const Type& rawType,
                                  const Layout& typeLayout,
                                  const MemoryLayout& memoryLayout) {
    const Type* type = &rawType;

    switch (type->typeKind()) {
        case Type::TypeKind::kVoid: {
            return this->writeInstruction(SpvOpTypeVoid, Words{Word::Result()}, fConstantBuffer);
        }
        case Type::TypeKind::kScalar:
        case Type::TypeKind::kLiteral: {
            if (type->isBoolean()) {
                return this->writeInstruction(SpvOpTypeBool, {Word::Result()}, fConstantBuffer);
            }
            // OpTypeInt operands: width 32, then signedness (1 = signed, 0 = unsigned).
            if (type->isSigned()) {
                return this->writeInstruction(
                        SpvOpTypeInt,
                        Words{Word::Result(), Word::Number(32), Word::Number(1)},
                        fConstantBuffer);
            }
            if (type->isUnsigned()) {
                return this->writeInstruction(
                        SpvOpTypeInt,
                        Words{Word::Result(), Word::Number(32), Word::Number(0)},
                        fConstantBuffer);
            }
            if (type->isFloat()) {
                // All floating-point scalars are declared as 32-bit floats.
                return this->writeInstruction(
                        SpvOpTypeFloat,
                        Words{Word::Result(), Word::Number(32)},
                        fConstantBuffer);
            }
            SkDEBUGFAILF("unrecognized scalar type '%s'", type->description().c_str());
            return NA;
        }
        case Type::TypeKind::kVector: {
            SpvId scalarTypeId = this->getType(type->componentType(), typeLayout, memoryLayout);
            return this->writeInstruction(
                    SpvOpTypeVector,
                    Words{Word::Result(), scalarTypeId, Word::Number(type->columns())},
                    fConstantBuffer);
        }
        case Type::TypeKind::kMatrix: {
            // A matrix is declared as `columns()` copies of its column-vector type.
            SpvId vectorTypeId = this->getType(IndexExpression::IndexType(fContext, *type),
                                               typeLayout,
                                               memoryLayout);
            return this->writeInstruction(
                    SpvOpTypeMatrix,
                    Words{Word::Result(), vectorTypeId, Word::Number(type->columns())},
                    fConstantBuffer);
        }
        case Type::TypeKind::kArray: {
            if (!memoryLayout.isSupported(*type)) {
                fContext.fErrors->error(type->fPosition, "type '" + type->displayName() +
                                                         "' is not permitted here");
                return NA;
            }
            size_t stride = memoryLayout.stride(*type);
            SpvId typeId = this->getType(type->componentType(), typeLayout, memoryLayout);
            SpvId result = NA;
            // NOTE(review): KeyedResult(stride) presumably keys the instruction by stride so
            // arrays that differ only in stride get distinct ids -- confirm in writeInstruction.
            if (type->isUnsizedArray()) {
                result = this->writeInstruction(SpvOpTypeRuntimeArray,
                                                Words{Word::KeyedResult(stride), typeId},
                                                fConstantBuffer);
            } else {
                // The array length operand is an `int` constant.
                SpvId countId = this->writeLiteral(type->columns(), *fContext.fTypes.fInt);
                result = this->writeInstruction(SpvOpTypeArray,
                                                Words{Word::KeyedResult(stride), typeId, countId},
                                                fConstantBuffer);
            }
            this->writeInstruction(SpvOpDecorate,
                                   {result, SpvDecorationArrayStride, Word::Number(stride)},
                                   fDecorationBuffer);
            return result;
        }
        case Type::TypeKind::kStruct: {
            return this->writeStruct(*type, memoryLayout);
        }
        case Type::TypeKind::kSeparateSampler: {
            return this->writeInstruction(SpvOpTypeSampler, Words{Word::Result()}, fConstantBuffer);
        }
        case Type::TypeKind::kSampler: {
            // Buffer-dimension combined samplers need the SampledBuffer capability.
            if (SpvDimBuffer == type->dimensions()) {
                fCapabilities |= 1ULL << SpvCapabilitySampledBuffer;
            }
            SpvId imageTypeId = this->getType(type->textureType(), typeLayout, memoryLayout);
            return this->writeInstruction(SpvOpTypeSampledImage,
                                          Words{Word::Result(), imageTypeId},
                                          fConstantBuffer);
        }
        case Type::TypeKind::kTexture: {
            // The sampled component type is always `float`.
            SpvId floatTypeId = this->getType(*fContext.fTypes.fFloat,
                                              kDefaultTypeLayout,
                                              memoryLayout);

            // An explicit image format is only emitted for non-sampled textures that are not
            // subpass inputs; everything else uses Unknown.
            bool sampled = (type->textureAccess() == Type::TextureAccess::kSample);
            SpvImageFormat format = (!sampled && type->dimensions() != SpvDimSubpassData)
                                            ? layout_flags_to_image_format(typeLayout.fFlags)
                                            : SpvImageFormatUnknown;

            // The "Sampled" operand is 1 for textures used with a sampler, 2 otherwise.
            return this->writeInstruction(SpvOpTypeImage,
                                          Words{Word::Result(),
                                                floatTypeId,
                                                Word::Number(type->dimensions()),
                                                Word::Number(type->isDepth()),
                                                Word::Number(type->isArrayedTexture()),
                                                Word::Number(type->isMultisampled()),
                                                Word::Number(sampled ? 1 : 2),
                                                format},
                                          fConstantBuffer);
        }
        case Type::TypeKind::kAtomic: {
            // SkSL currently only supports the atomicUint type.
            SkASSERT(type->matches(*fContext.fTypes.fAtomicUInt));
            // SPIR-V doesn't have atomic types. Rather, it allows atomic operations on primitive
            // types. The SPIR-V type of an SkSL atomic is simply the underlying type.
            return this->writeInstruction(SpvOpTypeInt,
                                          Words{Word::Result(), Word::Number(32), Word::Number(0)},
                                          fConstantBuffer);
        }
        default: {
            SkDEBUGFAILF("invalid type: %s", type->description().c_str());
            return NA;
        }
    }
}
1780
getFunctionType(const FunctionDeclaration & function)1781 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
1782 Words words;
1783 words.push_back(Word::Result());
1784 words.push_back(this->getType(function.returnType()));
1785 for (const Variable* parameter : function.parameters()) {
1786 if (fUseTextureSamplerPairs && parameter->type().isSampler()) {
1787 words.push_back(this->getFunctionParameterType(parameter->type().textureType(),
1788 parameter->layout()));
1789 words.push_back(this->getFunctionParameterType(*fContext.fTypes.fSampler,
1790 kDefaultTypeLayout));
1791 } else {
1792 words.push_back(this->getFunctionParameterType(parameter->type(), parameter->layout()));
1793 }
1794 }
1795 return this->writeInstruction(SpvOpTypeFunction, words, fConstantBuffer);
1796 }
1797
getFunctionParameterType(const Type & parameterType,const Layout & parameterLayout)1798 SpvId SPIRVCodeGenerator::getFunctionParameterType(const Type& parameterType,
1799 const Layout& parameterLayout) {
1800 // glslang treats all function arguments as pointers whether they need to be or
1801 // not. I was initially puzzled by this until I ran bizarre failures with certain
1802 // patterns of function calls and control constructs, as exemplified by this minimal
1803 // failure case:
1804 //
1805 // void sphere(float x) {
1806 // }
1807 //
1808 // void map() {
1809 // sphere(1.0);
1810 // }
1811 //
1812 // void main() {
1813 // for (int i = 0; i < 1; i++) {
1814 // map();
1815 // }
1816 // }
1817 //
1818 // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
1819 // crashes. Making it take a float* and storing the argument in a temporary variable,
1820 // as glslang does, fixes it.
1821 //
1822 // The consensus among shader compiler authors seems to be that GPU driver generally don't
1823 // handle value-based parameters consistently. It is highly likely that they fit their
1824 // implementations to conform to glslang. We take care to do so ourselves.
1825 //
1826 // Our implementation first stores every parameter value into a function storage-class pointer
1827 // before calling a function. The exception is for opaque handle types (samplers and textures)
1828 // which must be stored in a pointer with UniformConstant storage-class. This prevents
1829 // unnecessary temporaries (becuase opaque handles are always rooted in a pointer variable),
1830 // matches glslang's behavior, and translates into WGSL more easily when targeting Dawn.
1831 SpvStorageClass_ storageClass;
1832 if (parameterType.typeKind() == Type::TypeKind::kSampler ||
1833 parameterType.typeKind() == Type::TypeKind::kSeparateSampler ||
1834 parameterType.typeKind() == Type::TypeKind::kTexture) {
1835 storageClass = SpvStorageClassUniformConstant;
1836 } else {
1837 storageClass = SpvStorageClassFunction;
1838 }
1839 return this->getPointerType(parameterType,
1840 parameterLayout,
1841 this->memoryLayoutForStorageClass(storageClass),
1842 storageClass);
1843 }
1844
getPointerType(const Type & type,SpvStorageClass_ storageClass)1845 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
1846 return this->getPointerType(type,
1847 kDefaultTypeLayout,
1848 this->memoryLayoutForStorageClass(storageClass),
1849 storageClass);
1850 }
1851
getPointerType(const Type & type,const Layout & typeLayout,const MemoryLayout & memoryLayout,SpvStorageClass_ storageClass)1852 SpvId SPIRVCodeGenerator::getPointerType(const Type& type,
1853 const Layout& typeLayout,
1854 const MemoryLayout& memoryLayout,
1855 SpvStorageClass_ storageClass) {
1856 return this->writeInstruction(SpvOpTypePointer,
1857 Words{Word::Result(),
1858 Word::Number(storageClass),
1859 this->getType(type, typeLayout, memoryLayout)},
1860 fConstantBuffer);
1861 }
1862
writeExpression(const Expression & expr,OutputStream & out)1863 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
1864 switch (expr.kind()) {
1865 case Expression::Kind::kBinary:
1866 return this->writeBinaryExpression(expr.as<BinaryExpression>(), out);
1867 case Expression::Kind::kConstructorArrayCast:
1868 return this->writeExpression(*expr.as<ConstructorArrayCast>().argument(), out);
1869 case Expression::Kind::kConstructorArray:
1870 case Expression::Kind::kConstructorStruct:
1871 return this->writeCompositeConstructor(expr.asAnyConstructor(), out);
1872 case Expression::Kind::kConstructorDiagonalMatrix:
1873 return this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(), out);
1874 case Expression::Kind::kConstructorMatrixResize:
1875 return this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(), out);
1876 case Expression::Kind::kConstructorScalarCast:
1877 return this->writeConstructorScalarCast(expr.as<ConstructorScalarCast>(), out);
1878 case Expression::Kind::kConstructorSplat:
1879 return this->writeConstructorSplat(expr.as<ConstructorSplat>(), out);
1880 case Expression::Kind::kConstructorCompound:
1881 return this->writeConstructorCompound(expr.as<ConstructorCompound>(), out);
1882 case Expression::Kind::kConstructorCompoundCast:
1883 return this->writeConstructorCompoundCast(expr.as<ConstructorCompoundCast>(), out);
1884 case Expression::Kind::kEmpty:
1885 return NA;
1886 case Expression::Kind::kFieldAccess:
1887 return this->writeFieldAccess(expr.as<FieldAccess>(), out);
1888 case Expression::Kind::kFunctionCall:
1889 return this->writeFunctionCall(expr.as<FunctionCall>(), out);
1890 case Expression::Kind::kLiteral:
1891 return this->writeLiteral(expr.as<Literal>());
1892 case Expression::Kind::kPrefix:
1893 return this->writePrefixExpression(expr.as<PrefixExpression>(), out);
1894 case Expression::Kind::kPostfix:
1895 return this->writePostfixExpression(expr.as<PostfixExpression>(), out);
1896 case Expression::Kind::kSwizzle:
1897 return this->writeSwizzle(expr.as<Swizzle>(), out);
1898 case Expression::Kind::kVariableReference:
1899 return this->writeVariableReference(expr.as<VariableReference>(), out);
1900 case Expression::Kind::kTernary:
1901 return this->writeTernaryExpression(expr.as<TernaryExpression>(), out);
1902 case Expression::Kind::kIndex:
1903 return this->writeIndexExpression(expr.as<IndexExpression>(), out);
1904 case Expression::Kind::kSetting:
1905 return this->writeExpression(*expr.as<Setting>().toLiteral(fCaps), out);
1906 default:
1907 SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
1908 break;
1909 }
1910 return NA;
1911 }
1912
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)1913 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
1914 const FunctionDeclaration& function = c.function();
1915 Intrinsic intrinsic = this->getIntrinsic(function.intrinsicKind());
1916 if (intrinsic.opKind == kInvalid_IntrinsicOpcodeKind) {
1917 fContext.fErrors->error(c.fPosition, "unsupported intrinsic '" + function.description() +
1918 "'");
1919 return NA;
1920 }
1921 const ExpressionArray& arguments = c.arguments();
1922 int32_t intrinsicId = intrinsic.floatOp;
1923 if (!arguments.empty()) {
1924 const Type& type = arguments[0]->type();
1925 if (intrinsic.opKind == kSpecial_IntrinsicOpcodeKind) {
1926 // Keep the default float op.
1927 } else {
1928 intrinsicId = pick_by_type(type, intrinsic.floatOp, intrinsic.signedOp,
1929 intrinsic.unsignedOp, intrinsic.boolOp);
1930 }
1931 }
1932 switch (intrinsic.opKind) {
1933 case kGLSL_STD_450_IntrinsicOpcodeKind: {
1934 SpvId result = this->nextId(&c.type());
1935 TArray<SpvId> argumentIds;
1936 argumentIds.reserve_exact(arguments.size());
1937 std::vector<TempVar> tempVars;
1938 for (int i = 0; i < arguments.size(); i++) {
1939 argumentIds.push_back(this->writeFunctionCallArgument(c, i, &tempVars, out));
1940 }
1941 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
1942 this->writeWord(this->getType(c.type()), out);
1943 this->writeWord(result, out);
1944 this->writeWord(fGLSLExtendedInstructions, out);
1945 this->writeWord(intrinsicId, out);
1946 for (SpvId id : argumentIds) {
1947 this->writeWord(id, out);
1948 }
1949 this->copyBackTempVars(tempVars, out);
1950 return result;
1951 }
1952 case kSPIRV_IntrinsicOpcodeKind: {
1953 // GLSL supports dot(float, float), but SPIR-V does not. Convert it to FMul
1954 if (intrinsicId == SpvOpDot && arguments[0]->type().isScalar()) {
1955 intrinsicId = SpvOpFMul;
1956 }
1957 SpvId result = this->nextId(&c.type());
1958 TArray<SpvId> argumentIds;
1959 argumentIds.reserve_exact(arguments.size());
1960 std::vector<TempVar> tempVars;
1961 for (int i = 0; i < arguments.size(); i++) {
1962 argumentIds.push_back(this->writeFunctionCallArgument(c, i, &tempVars, out));
1963 }
1964 if (!c.type().isVoid()) {
1965 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
1966 this->writeWord(this->getType(c.type()), out);
1967 this->writeWord(result, out);
1968 } else {
1969 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
1970 }
1971 for (SpvId id : argumentIds) {
1972 this->writeWord(id, out);
1973 }
1974 this->copyBackTempVars(tempVars, out);
1975 return result;
1976 }
1977 case kSpecial_IntrinsicOpcodeKind:
1978 return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
1979 default:
1980 fContext.fErrors->error(c.fPosition, "unsupported intrinsic '" +
1981 function.description() + "'");
1982 return NA;
1983 }
1984 }
1985
vectorize(const Expression & arg,int vectorSize,OutputStream & out)1986 SpvId SPIRVCodeGenerator::vectorize(const Expression& arg, int vectorSize, OutputStream& out) {
1987 SkASSERT(vectorSize >= 1 && vectorSize <= 4);
1988 const Type& argType = arg.type();
1989 if (argType.isScalar() && vectorSize > 1) {
1990 ConstructorSplat splat{arg.fPosition,
1991 argType.toCompound(fContext, vectorSize, /*rows=*/1),
1992 arg.clone()};
1993 return this->writeConstructorSplat(splat, out);
1994 }
1995
1996 SkASSERT(vectorSize == argType.columns());
1997 return this->writeExpression(arg, out);
1998 }
1999
vectorize(const ExpressionArray & args,OutputStream & out)2000 TArray<SpvId> SPIRVCodeGenerator::vectorize(const ExpressionArray& args, OutputStream& out) {
2001 int vectorSize = 1;
2002 for (const auto& a : args) {
2003 if (a->type().isVector()) {
2004 if (vectorSize > 1) {
2005 SkASSERT(a->type().columns() == vectorSize);
2006 } else {
2007 vectorSize = a->type().columns();
2008 }
2009 }
2010 }
2011 TArray<SpvId> result;
2012 result.reserve_exact(args.size());
2013 for (const auto& arg : args) {
2014 result.push_back(this->vectorize(*arg, vectorSize, out));
2015 }
2016 return result;
2017 }
2018
writeGLSLExtendedInstruction(const Type & type,SpvId id,SpvId floatInst,SpvId signedInst,SpvId unsignedInst,const TArray<SpvId> & args,OutputStream & out)2019 void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
2020 SpvId signedInst, SpvId unsignedInst,
2021 const TArray<SpvId>& args,
2022 OutputStream& out) {
2023 this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
2024 this->writeWord(this->getType(type), out);
2025 this->writeWord(id, out);
2026 this->writeWord(fGLSLExtendedInstructions, out);
2027 this->writeWord(pick_by_type(type, floatInst, signedInst, unsignedInst, NA), out);
2028 for (SpvId a : args) {
2029 this->writeWord(a, out);
2030 }
2031 }
2032
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)2033 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
2034 OutputStream& out) {
2035 const ExpressionArray& arguments = c.arguments();
2036 const Type& callType = c.type();
2037 SpvId result = this->nextId(nullptr);
2038 switch (kind) {
2039 case kAtan_SpecialIntrinsic: {
2040 STArray<2, SpvId> argumentIds;
2041 for (const std::unique_ptr<Expression>& arg : arguments) {
2042 argumentIds.push_back(this->writeExpression(*arg, out));
2043 }
2044 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
2045 this->writeWord(this->getType(callType), out);
2046 this->writeWord(result, out);
2047 this->writeWord(fGLSLExtendedInstructions, out);
2048 this->writeWord(argumentIds.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
2049 for (SpvId id : argumentIds) {
2050 this->writeWord(id, out);
2051 }
2052 break;
2053 }
2054 case kSampledImage_SpecialIntrinsic: {
2055 SkASSERT(arguments.size() == 2);
2056 SpvId img = this->writeExpression(*arguments[0], out);
2057 SpvId sampler = this->writeExpression(*arguments[1], out);
2058 this->writeInstruction(SpvOpSampledImage,
2059 this->getType(callType),
2060 result,
2061 img,
2062 sampler,
2063 out);
2064 break;
2065 }
2066 case kSubpassLoad_SpecialIntrinsic: {
2067 SpvId img = this->writeExpression(*arguments[0], out);
2068 ExpressionArray args;
2069 args.reserve_exact(2);
2070 args.push_back(Literal::MakeInt(fContext, Position(), /*value=*/0));
2071 args.push_back(Literal::MakeInt(fContext, Position(), /*value=*/0));
2072 ConstructorCompound ctor(Position(), *fContext.fTypes.fInt2, std::move(args));
2073 SpvId coords = this->writeExpression(ctor, out);
2074 if (arguments.size() == 1) {
2075 this->writeInstruction(SpvOpImageRead,
2076 this->getType(callType),
2077 result,
2078 img,
2079 coords,
2080 out);
2081 } else {
2082 SkASSERT(arguments.size() == 2);
2083 SpvId sample = this->writeExpression(*arguments[1], out);
2084 this->writeInstruction(SpvOpImageRead,
2085 this->getType(callType),
2086 result,
2087 img,
2088 coords,
2089 SpvImageOperandsSampleMask,
2090 sample,
2091 out);
2092 }
2093 break;
2094 }
2095 case kTexture_SpecialIntrinsic: {
2096 SpvOp_ op = SpvOpImageSampleImplicitLod;
2097 const Type& arg1Type = arguments[1]->type();
2098 switch (arguments[0]->type().dimensions()) {
2099 case SpvDim1D:
2100 if (arg1Type.matches(*fContext.fTypes.fFloat2)) {
2101 op = SpvOpImageSampleProjImplicitLod;
2102 } else {
2103 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat));
2104 }
2105 break;
2106 case SpvDim2D:
2107 if (arg1Type.matches(*fContext.fTypes.fFloat3)) {
2108 op = SpvOpImageSampleProjImplicitLod;
2109 } else {
2110 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat2));
2111 }
2112 break;
2113 case SpvDim3D:
2114 if (arg1Type.matches(*fContext.fTypes.fFloat4)) {
2115 op = SpvOpImageSampleProjImplicitLod;
2116 } else {
2117 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat3));
2118 }
2119 break;
2120 case SpvDimCube: // fall through
2121 case SpvDimRect: // fall through
2122 case SpvDimBuffer: // fall through
2123 case SpvDimSubpassData:
2124 break;
2125 }
2126 SpvId type = this->getType(callType);
2127 SpvId sampler = this->writeExpression(*arguments[0], out);
2128 SpvId uv = this->writeExpression(*arguments[1], out);
2129 if (arguments.size() == 3) {
2130 this->writeInstruction(op, type, result, sampler, uv,
2131 SpvImageOperandsBiasMask,
2132 this->writeExpression(*arguments[2], out),
2133 out);
2134 } else {
2135 SkASSERT(arguments.size() == 2);
2136 if (fProgram.fConfig->fSettings.fSharpenTextures) {
2137 SpvId lodBias = this->writeLiteral(kSharpenTexturesBias,
2138 *fContext.fTypes.fFloat);
2139 this->writeInstruction(op, type, result, sampler, uv,
2140 SpvImageOperandsBiasMask, lodBias, out);
2141 } else {
2142 this->writeInstruction(op, type, result, sampler, uv,
2143 out);
2144 }
2145 }
2146 break;
2147 }
2148 case kTextureGrad_SpecialIntrinsic: {
2149 SpvOp_ op = SpvOpImageSampleExplicitLod;
2150 SkASSERT(arguments.size() == 4);
2151 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2152 SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fFloat2));
2153 SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fFloat2));
2154 SkASSERT(arguments[3]->type().matches(*fContext.fTypes.fFloat2));
2155 SpvId type = this->getType(callType);
2156 SpvId sampler = this->writeExpression(*arguments[0], out);
2157 SpvId uv = this->writeExpression(*arguments[1], out);
2158 SpvId dPdx = this->writeExpression(*arguments[2], out);
2159 SpvId dPdy = this->writeExpression(*arguments[3], out);
2160 this->writeInstruction(op, type, result, sampler, uv, SpvImageOperandsGradMask,
2161 dPdx, dPdy, out);
2162 break;
2163 }
2164 case kTextureLod_SpecialIntrinsic: {
2165 SpvOp_ op = SpvOpImageSampleExplicitLod;
2166 SkASSERT(arguments.size() == 3);
2167 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2168 SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fFloat));
2169 const Type& arg1Type = arguments[1]->type();
2170 if (arg1Type.matches(*fContext.fTypes.fFloat3)) {
2171 op = SpvOpImageSampleProjExplicitLod;
2172 } else {
2173 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat2));
2174 }
2175 SpvId type = this->getType(callType);
2176 SpvId sampler = this->writeExpression(*arguments[0], out);
2177 SpvId uv = this->writeExpression(*arguments[1], out);
2178 this->writeInstruction(op, type, result, sampler, uv,
2179 SpvImageOperandsLodMask,
2180 this->writeExpression(*arguments[2], out),
2181 out);
2182 break;
2183 }
2184 case kTextureRead_SpecialIntrinsic: {
2185 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2186 SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fUInt2));
2187
2188 SpvId type = this->getType(callType);
2189 SpvId image = this->writeExpression(*arguments[0], out);
2190 SpvId coord = this->writeExpression(*arguments[1], out);
2191
2192 const Type& arg0Type = arguments[0]->type();
2193 SkASSERT(arg0Type.typeKind() == Type::TypeKind::kTexture);
2194
2195 switch (arg0Type.textureAccess()) {
2196 case Type::TextureAccess::kSample:
2197 this->writeInstruction(SpvOpImageFetch, type, result, image, coord,
2198 SpvImageOperandsLodMask,
2199 this->writeOpConstant(*fContext.fTypes.fInt, 0),
2200 out);
2201 break;
2202 case Type::TextureAccess::kRead:
2203 case Type::TextureAccess::kReadWrite:
2204 this->writeInstruction(SpvOpImageRead, type, result, image, coord, out);
2205 break;
2206 case Type::TextureAccess::kWrite:
2207 default:
2208 SkDEBUGFAIL("'textureRead' called on writeonly texture type");
2209 break;
2210 }
2211
2212 break;
2213 }
2214 case kTextureWrite_SpecialIntrinsic: {
2215 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2216 SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fUInt2));
2217 SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fHalf4));
2218
2219 SpvId image = this->writeExpression(*arguments[0], out);
2220 SpvId coord = this->writeExpression(*arguments[1], out);
2221 SpvId texel = this->writeExpression(*arguments[2], out);
2222
2223 this->writeInstruction(SpvOpImageWrite, image, coord, texel, out);
2224 break;
2225 }
2226 case kTextureWidth_SpecialIntrinsic:
2227 case kTextureHeight_SpecialIntrinsic: {
2228 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2229 fCapabilities |= 1ULL << SpvCapabilityImageQuery;
2230
2231 SpvId dimsType = this->getType(*fContext.fTypes.fUInt2);
2232 SpvId dims = this->nextId(nullptr);
2233 SpvId image = this->writeExpression(*arguments[0], out);
2234 this->writeInstruction(SpvOpImageQuerySize, dimsType, dims, image, out);
2235
2236 SpvId type = this->getType(callType);
2237 int32_t index = (kind == kTextureWidth_SpecialIntrinsic) ? 0 : 1;
2238 this->writeInstruction(SpvOpCompositeExtract, type, result, dims, index, out);
2239 break;
2240 }
2241 case kMod_SpecialIntrinsic: {
2242 TArray<SpvId> args = this->vectorize(arguments, out);
2243 SkASSERT(args.size() == 2);
2244 const Type& operandType = arguments[0]->type();
2245 SpvOp_ op = pick_by_type(operandType, SpvOpFMod, SpvOpSMod, SpvOpUMod, SpvOpUndef);
2246 SkASSERT(op != SpvOpUndef);
2247 this->writeOpCode(op, 5, out);
2248 this->writeWord(this->getType(operandType), out);
2249 this->writeWord(result, out);
2250 this->writeWord(args[0], out);
2251 this->writeWord(args[1], out);
2252 break;
2253 }
2254 case kDFdy_SpecialIntrinsic: {
2255 SpvId fn = this->writeExpression(*arguments[0], out);
2256 this->writeOpCode(SpvOpDPdy, 4, out);
2257 this->writeWord(this->getType(callType), out);
2258 this->writeWord(result, out);
2259 this->writeWord(fn, out);
2260 if (!fProgram.fConfig->fSettings.fForceNoRTFlip) {
2261 this->addRTFlipUniform(c.fPosition);
2262 ComponentArray componentArray;
2263 for (int index = 0; index < callType.columns(); ++index) {
2264 componentArray.push_back(SwizzleComponent::Y);
2265 }
2266 SpvId rtFlipY = this->writeSwizzle(*this->identifier(SKSL_RTFLIP_NAME),
2267 componentArray, out);
2268 SpvId flipped = this->nextId(&callType);
2269 this->writeInstruction(SpvOpFMul, this->getType(callType), flipped, result,
2270 rtFlipY, out);
2271 result = flipped;
2272 }
2273 break;
2274 }
2275 case kClamp_SpecialIntrinsic: {
2276 TArray<SpvId> args = this->vectorize(arguments, out);
2277 SkASSERT(args.size() == 3);
2278 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
2279 GLSLstd450UClamp, args, out);
2280 break;
2281 }
2282 case kMax_SpecialIntrinsic: {
2283 TArray<SpvId> args = this->vectorize(arguments, out);
2284 SkASSERT(args.size() == 2);
2285 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMax, GLSLstd450SMax,
2286 GLSLstd450UMax, args, out);
2287 break;
2288 }
2289 case kMin_SpecialIntrinsic: {
2290 TArray<SpvId> args = this->vectorize(arguments, out);
2291 SkASSERT(args.size() == 2);
2292 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMin, GLSLstd450SMin,
2293 GLSLstd450UMin, args, out);
2294 break;
2295 }
2296 case kMix_SpecialIntrinsic: {
2297 TArray<SpvId> args = this->vectorize(arguments, out);
2298 SkASSERT(args.size() == 3);
2299 if (arguments[2]->type().componentType().isBoolean()) {
2300 // Use OpSelect to implement Boolean mix().
2301 SpvId falseId = this->writeExpression(*arguments[0], out);
2302 SpvId trueId = this->writeExpression(*arguments[1], out);
2303 SpvId conditionId = this->writeExpression(*arguments[2], out);
2304 this->writeInstruction(SpvOpSelect, this->getType(arguments[0]->type()), result,
2305 conditionId, trueId, falseId, out);
2306 } else {
2307 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMix, SpvOpUndef,
2308 SpvOpUndef, args, out);
2309 }
2310 break;
2311 }
2312 case kSaturate_SpecialIntrinsic: {
2313 SkASSERT(arguments.size() == 1);
2314 ExpressionArray finalArgs;
2315 finalArgs.reserve_exact(3);
2316 finalArgs.push_back(arguments[0]->clone());
2317 finalArgs.push_back(Literal::MakeFloat(fContext, Position(), /*value=*/0));
2318 finalArgs.push_back(Literal::MakeFloat(fContext, Position(), /*value=*/1));
2319 TArray<SpvId> spvArgs = this->vectorize(finalArgs, out);
2320 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
2321 GLSLstd450UClamp, spvArgs, out);
2322 break;
2323 }
2324 case kSmoothStep_SpecialIntrinsic: {
2325 TArray<SpvId> args = this->vectorize(arguments, out);
2326 SkASSERT(args.size() == 3);
2327 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450SmoothStep, SpvOpUndef,
2328 SpvOpUndef, args, out);
2329 break;
2330 }
2331 case kStep_SpecialIntrinsic: {
2332 TArray<SpvId> args = this->vectorize(arguments, out);
2333 SkASSERT(args.size() == 2);
2334 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450Step, SpvOpUndef,
2335 SpvOpUndef, args, out);
2336 break;
2337 }
2338 case kMatrixCompMult_SpecialIntrinsic: {
2339 SkASSERT(arguments.size() == 2);
2340 SpvId lhs = this->writeExpression(*arguments[0], out);
2341 SpvId rhs = this->writeExpression(*arguments[1], out);
2342 result = this->writeComponentwiseMatrixBinary(callType, lhs, rhs, SpvOpFMul, out);
2343 break;
2344 }
2345 case kAtomicAdd_SpecialIntrinsic:
2346 case kAtomicLoad_SpecialIntrinsic:
2347 case kAtomicStore_SpecialIntrinsic:
2348 result = this->writeAtomicIntrinsic(c, kind, result, out);
2349 break;
2350 case kStorageBarrier_SpecialIntrinsic:
2351 case kWorkgroupBarrier_SpecialIntrinsic: {
2352 // Both barrier types operate in the workgroup execution and memory scope and differ
2353 // only in memory semantics. storageBarrier() is not a device-scope barrier.
2354 SpvId scopeId =
2355 this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)SpvScopeWorkgroup);
2356 int32_t memSemMask = (kind == kStorageBarrier_SpecialIntrinsic)
2357 ? SpvMemorySemanticsAcquireReleaseMask |
2358 SpvMemorySemanticsUniformMemoryMask
2359 : SpvMemorySemanticsAcquireReleaseMask |
2360 SpvMemorySemanticsWorkgroupMemoryMask;
2361 SpvId memorySemanticsId = this->writeOpConstant(*fContext.fTypes.fUInt, memSemMask);
2362 this->writeInstruction(SpvOpControlBarrier,
2363 scopeId, // execution scope
2364 scopeId, // memory scope
2365 memorySemanticsId,
2366 out);
2367 break;
2368 }
2369 }
2370 return result;
2371 }
2372
writeAtomicIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,SpvId resultId,OutputStream & out)2373 SpvId SPIRVCodeGenerator::writeAtomicIntrinsic(const FunctionCall& c,
2374 SpecialIntrinsic kind,
2375 SpvId resultId,
2376 OutputStream& out) {
2377 const ExpressionArray& arguments = c.arguments();
2378 SkASSERT(!arguments.empty());
2379
2380 std::unique_ptr<LValue> atomicPtr = this->getLValue(*arguments[0], out);
2381 SpvId atomicPtrId = atomicPtr->getPointer();
2382 if (atomicPtrId == NA) {
2383 SkDEBUGFAILF("atomic intrinsic expected a pointer argument: %s",
2384 arguments[0]->description().c_str());
2385 return NA;
2386 }
2387
2388 SpvId memoryScopeId = NA;
2389 {
2390 // In SkSL, the atomicUint type can only be declared as a workgroup variable or SSBO block
2391 // member. The two memory scopes that these map to are "workgroup" and "device",
2392 // respectively.
2393 SpvScope memoryScope;
2394 switch (atomicPtr->storageClass()) {
2395 case SpvStorageClassUniform:
2396 // We encode storage buffers in the uniform address space (with the BufferBlock
2397 // decorator).
2398 memoryScope = SpvScopeDevice;
2399 break;
2400 case SpvStorageClassWorkgroup:
2401 memoryScope = SpvScopeWorkgroup;
2402 break;
2403 default:
2404 SkDEBUGFAILF("atomic argument has invalid storage class: %d",
2405 atomicPtr->storageClass());
2406 return NA;
2407 }
2408 memoryScopeId = this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)memoryScope);
2409 }
2410
2411 SpvId relaxedMemoryOrderId =
2412 this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)SpvMemorySemanticsMaskNone);
2413
2414 switch (kind) {
2415 case kAtomicAdd_SpecialIntrinsic:
2416 SkASSERT(arguments.size() == 2);
2417 this->writeInstruction(SpvOpAtomicIAdd,
2418 this->getType(c.type()),
2419 resultId,
2420 atomicPtrId,
2421 memoryScopeId,
2422 relaxedMemoryOrderId,
2423 this->writeExpression(*arguments[1], out),
2424 out);
2425 break;
2426 case kAtomicLoad_SpecialIntrinsic:
2427 SkASSERT(arguments.size() == 1);
2428 this->writeInstruction(SpvOpAtomicLoad,
2429 this->getType(c.type()),
2430 resultId,
2431 atomicPtrId,
2432 memoryScopeId,
2433 relaxedMemoryOrderId,
2434 out);
2435 break;
2436 case kAtomicStore_SpecialIntrinsic:
2437 SkASSERT(arguments.size() == 2);
2438 this->writeInstruction(SpvOpAtomicStore,
2439 atomicPtrId,
2440 memoryScopeId,
2441 relaxedMemoryOrderId,
2442 this->writeExpression(*arguments[1], out),
2443 out);
2444 break;
2445 default:
2446 SkUNREACHABLE;
2447 }
2448
2449 return resultId;
2450 }
2451
// Produces the SpvId to pass as argument `argIndex` of `call`.
//
// - out / inout parameters: a Function-storage temporary is declared (into fVariableBuffer). For
//   inout parameters, the lvalue's current value is loaded and stored into the temporary first.
//   The temporary is recorded in `tempVars` so the caller can copy it back to the original lvalue
//   after the call (see copyBackTempVars).
// - Non-out arguments to intrinsics are evaluated and passed by value.
// - Variable references of sampler / separate-sampler / texture type are forwarded as their
//   existing variable ids. When fUseTextureSamplerPairs is set and the sampler appears in
//   fSynthesizedSamplerMap, the texture's id is returned and the paired sampler's id is written
//   to `*outSynthesizedSamplerId`.
// - Every other argument to a user function is evaluated, stored into a fresh Function-storage
//   temporary, and the temporary's pointer is returned.
SpvId SPIRVCodeGenerator::writeFunctionCallArgument(const FunctionCall& call,
                                                    int argIndex,
                                                    std::vector<TempVar>* tempVars,
                                                    OutputStream& out,
                                                    SpvId* outSynthesizedSamplerId) {
    const FunctionDeclaration& funcDecl = call.function();
    const Expression& arg = *call.arguments()[argIndex];
    ModifierFlags paramFlags = funcDecl.parameters()[argIndex]->modifierFlags();

    // ID of the temporary variable that will hold this argument. It is only assigned on the paths
    // that fall through to the OpVariable below; arguments passed directly return early instead.
    SpvId tmpVar;
    // if we need a temporary var to store this argument, this is the value to store in the var
    SpvId tmpValueId = NA;

    if (is_out(paramFlags)) {
        std::unique_ptr<LValue> lv = this->getLValue(arg, out);
        // We handle out params with a temp var that we copy back to the original variable at the
        // end of the call. GLSL guarantees that the original variable will be unchanged until the
        // end of the call, and also that out params are written back to their original variables in
        // a specific order (left-to-right), so it's unsafe to pass a pointer to the original value.
        if (is_in(paramFlags)) {
            // inout: seed the temporary with the lvalue's current value.
            tmpValueId = lv->load(out);
        }
        tmpVar = this->nextId(&arg.type());
        tempVars->push_back(TempVar{tmpVar, &arg.type(), std::move(lv)});
    } else if (funcDecl.isIntrinsic()) {
        // Unlike user function calls, non-out intrinsic arguments don't need pointer parameters.
        return this->writeExpression(arg, out);
    } else if (arg.is<VariableReference>() &&
               (arg.type().typeKind() == Type::TypeKind::kSampler ||
                arg.type().typeKind() == Type::TypeKind::kSeparateSampler ||
                arg.type().typeKind() == Type::TypeKind::kTexture)) {
        // Opaque handle (sampler/texture) arguments are always declared as pointers but never
        // stored in intermediates when calling user-defined functions.
        //
        // The case for intrinsics (which take opaque arguments by value) is handled above just like
        // regular pointers.
        //
        // See getFunctionParameterType for further explanation.
        const Variable* var = arg.as<VariableReference>().variable();

        // In Dawn-mode the texture and sampler arguments are forwarded to the helper function.
        if (fUseTextureSamplerPairs && var->type().isSampler()) {
            if (const auto* p = fSynthesizedSamplerMap.find(var)) {
                SkASSERT(outSynthesizedSamplerId);

                SpvId* img = fVariableMap.find((*p)->fTexture.get());
                SpvId* sampler = fVariableMap.find((*p)->fSampler.get());
                SkASSERT(img);
                SkASSERT(sampler);

                // The texture id is returned; the caller appends the sampler id after it.
                *outSynthesizedSamplerId = *sampler;
                return *img;
            }
            // Not found: fail in debug builds, otherwise fall back to the plain variable lookup.
            SkDEBUGFAIL("sampler missing from fSynthesizedSamplerMap");
        }

        SpvId* entry = fVariableMap.find(var);
        SkASSERTF(entry, "%s", arg.description().c_str());
        return *entry;
    } else {
        // We always use pointer parameters when calling user functions.
        // See getFunctionParameterType for further explanation.
        tmpValueId = this->writeExpression(arg, out);
        tmpVar = this->nextId(nullptr);
    }
    // Declare the temporary in the function's variable section (fVariableBuffer), then store the
    // initial value, if any, at the call site.
    this->writeInstruction(SpvOpVariable,
                           this->getPointerType(arg.type(), SpvStorageClassFunction),
                           tmpVar,
                           SpvStorageClassFunction,
                           fVariableBuffer);
    if (tmpValueId != NA) {
        this->writeOpStore(SpvStorageClassFunction, tmpVar, tmpValueId, out);
    }
    return tmpVar;
}
2529
copyBackTempVars(const std::vector<TempVar> & tempVars,OutputStream & out)2530 void SPIRVCodeGenerator::copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out) {
2531 for (const TempVar& tempVar : tempVars) {
2532 SpvId load = this->nextId(tempVar.type);
2533 this->writeInstruction(SpvOpLoad, this->getType(*tempVar.type), load, tempVar.spvId, out);
2534 tempVar.lvalue->store(load, out);
2535 }
2536 }
2537
writeFunctionCall(const FunctionCall & c,OutputStream & out)2538 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
2539 const FunctionDeclaration& function = c.function();
2540 if (function.isIntrinsic() && !function.definition()) {
2541 return this->writeIntrinsicCall(c, out);
2542 }
2543 const ExpressionArray& arguments = c.arguments();
2544 SpvId* entry = fFunctionMap.find(&function);
2545 if (!entry) {
2546 fContext.fErrors->error(c.fPosition, "function '" + function.description() +
2547 "' is not defined");
2548 return NA;
2549 }
2550 // Temp variables are used to write back out-parameters after the function call is complete.
2551 std::vector<TempVar> tempVars;
2552 TArray<SpvId> argumentIds;
2553 argumentIds.reserve_exact(arguments.size());
2554 for (int i = 0; i < arguments.size(); i++) {
2555 SpvId samplerId = NA;
2556 argumentIds.push_back(this->writeFunctionCallArgument(c, i, &tempVars, out, &samplerId));
2557 if (samplerId != NA) {
2558 argumentIds.push_back(samplerId);
2559 }
2560 }
2561 SpvId result = this->nextId(nullptr);
2562 this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t)argumentIds.size(), out);
2563 this->writeWord(this->getType(c.type()), out);
2564 this->writeWord(result, out);
2565 this->writeWord(*entry, out);
2566 for (SpvId id : argumentIds) {
2567 this->writeWord(id, out);
2568 }
2569 // Now that the call is complete, we copy temp out-variables back to their real lvalues.
2570 this->copyBackTempVars(tempVars, out);
2571 return result;
2572 }
2573
castScalarToType(SpvId inputExprId,const Type & inputType,const Type & outputType,OutputStream & out)2574 SpvId SPIRVCodeGenerator::castScalarToType(SpvId inputExprId,
2575 const Type& inputType,
2576 const Type& outputType,
2577 OutputStream& out) {
2578 if (outputType.isFloat()) {
2579 return this->castScalarToFloat(inputExprId, inputType, outputType, out);
2580 }
2581 if (outputType.isSigned()) {
2582 return this->castScalarToSignedInt(inputExprId, inputType, outputType, out);
2583 }
2584 if (outputType.isUnsigned()) {
2585 return this->castScalarToUnsignedInt(inputExprId, inputType, outputType, out);
2586 }
2587 if (outputType.isBoolean()) {
2588 return this->castScalarToBoolean(inputExprId, inputType, outputType, out);
2589 }
2590
2591 fContext.fErrors->error(Position(), "unsupported cast: " + inputType.description() + " to " +
2592 outputType.description());
2593 return inputExprId;
2594 }
2595
castScalarToFloat(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2596 SpvId SPIRVCodeGenerator::castScalarToFloat(SpvId inputId, const Type& inputType,
2597 const Type& outputType, OutputStream& out) {
2598 // Casting a float to float is a no-op.
2599 if (inputType.isFloat()) {
2600 return inputId;
2601 }
2602
2603 // Given the input type, generate the appropriate instruction to cast to float.
2604 SpvId result = this->nextId(&outputType);
2605 if (inputType.isBoolean()) {
2606 // Use OpSelect to convert the boolean argument to a literal 1.0 or 0.0.
2607 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fFloat);
2608 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
2609 this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2610 inputId, oneID, zeroID, out);
2611 } else if (inputType.isSigned()) {
2612 this->writeInstruction(SpvOpConvertSToF, this->getType(outputType), result, inputId, out);
2613 } else if (inputType.isUnsigned()) {
2614 this->writeInstruction(SpvOpConvertUToF, this->getType(outputType), result, inputId, out);
2615 } else {
2616 SkDEBUGFAILF("unsupported type for float typecast: %s", inputType.description().c_str());
2617 return NA;
2618 }
2619 return result;
2620 }
2621
castScalarToSignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2622 SpvId SPIRVCodeGenerator::castScalarToSignedInt(SpvId inputId, const Type& inputType,
2623 const Type& outputType, OutputStream& out) {
2624 // Casting a signed int to signed int is a no-op.
2625 if (inputType.isSigned()) {
2626 return inputId;
2627 }
2628
2629 // Given the input type, generate the appropriate instruction to cast to signed int.
2630 SpvId result = this->nextId(&outputType);
2631 if (inputType.isBoolean()) {
2632 // Use OpSelect to convert the boolean argument to a literal 1 or 0.
2633 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fInt);
2634 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
2635 this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2636 inputId, oneID, zeroID, out);
2637 } else if (inputType.isFloat()) {
2638 this->writeInstruction(SpvOpConvertFToS, this->getType(outputType), result, inputId, out);
2639 } else if (inputType.isUnsigned()) {
2640 this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
2641 } else {
2642 SkDEBUGFAILF("unsupported type for signed int typecast: %s",
2643 inputType.description().c_str());
2644 return NA;
2645 }
2646 return result;
2647 }
2648
castScalarToUnsignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2649 SpvId SPIRVCodeGenerator::castScalarToUnsignedInt(SpvId inputId, const Type& inputType,
2650 const Type& outputType, OutputStream& out) {
2651 // Casting an unsigned int to unsigned int is a no-op.
2652 if (inputType.isUnsigned()) {
2653 return inputId;
2654 }
2655
2656 // Given the input type, generate the appropriate instruction to cast to unsigned int.
2657 SpvId result = this->nextId(&outputType);
2658 if (inputType.isBoolean()) {
2659 // Use OpSelect to convert the boolean argument to a literal 1u or 0u.
2660 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fUInt);
2661 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
2662 this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2663 inputId, oneID, zeroID, out);
2664 } else if (inputType.isFloat()) {
2665 this->writeInstruction(SpvOpConvertFToU, this->getType(outputType), result, inputId, out);
2666 } else if (inputType.isSigned()) {
2667 this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
2668 } else {
2669 SkDEBUGFAILF("unsupported type for unsigned int typecast: %s",
2670 inputType.description().c_str());
2671 return NA;
2672 }
2673 return result;
2674 }
2675
castScalarToBoolean(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2676 SpvId SPIRVCodeGenerator::castScalarToBoolean(SpvId inputId, const Type& inputType,
2677 const Type& outputType, OutputStream& out) {
2678 // Casting a bool to bool is a no-op.
2679 if (inputType.isBoolean()) {
2680 return inputId;
2681 }
2682
2683 // Given the input type, generate the appropriate instruction to cast to bool.
2684 SpvId result = this->nextId(nullptr);
2685 if (inputType.isSigned()) {
2686 // Synthesize a boolean result by comparing the input against a signed zero literal.
2687 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
2688 this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
2689 inputId, zeroID, out);
2690 } else if (inputType.isUnsigned()) {
2691 // Synthesize a boolean result by comparing the input against an unsigned zero literal.
2692 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
2693 this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
2694 inputId, zeroID, out);
2695 } else if (inputType.isFloat()) {
2696 // Synthesize a boolean result by comparing the input against a floating-point zero literal.
2697 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
2698 this->writeInstruction(SpvOpFUnordNotEqual, this->getType(outputType), result,
2699 inputId, zeroID, out);
2700 } else {
2701 SkDEBUGFAILF("unsupported type for boolean typecast: %s", inputType.description().c_str());
2702 return NA;
2703 }
2704 return result;
2705 }
2706
// Converts the matrix `src` (of `srcType`) to `dstType`, whose column and/or row counts may differ.
//
// Entries present in both shapes are copied. Entries that exist only in the destination are taken
// from the identity matrix (1 where row == column, 0 elsewhere). Source rows and columns beyond
// the destination's size are dropped. The component types must match (asserted), so this resizes
// but does not convert between number kinds.
SpvId SPIRVCodeGenerator::writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType,
                                          OutputStream& out) {
    SkASSERT(srcType.isMatrix());
    SkASSERT(dstType.isMatrix());
    SkASSERT(srcType.componentType().matches(dstType.componentType()));
    const Type& srcColumnType = srcType.componentType().toCompound(fContext, srcType.rows(), 1);
    const Type& dstColumnType = dstType.componentType().toCompound(fContext, dstType.rows(), 1);
    SkASSERT(dstType.componentType().isFloat());
    SpvId dstColumnTypeId = this->getType(dstColumnType);
    // Identity-matrix filler values, created once and shared by all synthesized entries.
    const SpvId zeroId = this->writeLiteral(0.0, dstType.componentType());
    const SpvId oneId = this->writeLiteral(1.0, dstType.componentType());

    STArray<4, SpvId> columns;
    for (int i = 0; i < dstType.columns(); i++) {
        if (i < srcType.columns()) {
            // we're still inside the src matrix, copy the column
            SpvId srcColumn = this->writeOpCompositeExtract(srcColumnType, src, i, out);
            SpvId dstColumn;
            if (srcType.rows() == dstType.rows()) {
                // columns are equal size, don't need to do anything
                dstColumn = srcColumn;
            }
            else if (dstType.rows() > srcType.rows()) {
                // dst column is bigger: append identity-matrix entries for the extra rows
                // (1 where the row index equals the column index `i`, otherwise 0)
                STArray<4, SpvId> values;
                values.push_back(srcColumn);
                for (int j = srcType.rows(); j < dstType.rows(); ++j) {
                    values.push_back((i == j) ? oneId : zeroId);
                }
                dstColumn = this->writeOpCompositeConstruct(dstColumnType, values, out);
            }
            else {
                // dst column is smaller: truncate it with an OpVectorShuffle that keeps the first
                // dstType.rows() components. The source column is passed as both shuffle inputs.
                // NOTE(review): the id is allocated with the matrix type rather than
                // dstColumnType; both share a component type.
                dstColumn = this->nextId(&dstType);
                this->writeOpCode(SpvOpVectorShuffle, 5 + dstType.rows(), out);
                this->writeWord(dstColumnTypeId, out);
                this->writeWord(dstColumn, out);
                this->writeWord(srcColumn, out);
                this->writeWord(srcColumn, out);
                for (int j = 0; j < dstType.rows(); j++) {
                    this->writeWord(j, out);
                }
            }
            columns.push_back(dstColumn);
        } else {
            // we're past the end of the src matrix, need to synthesize an identity-matrix column
            STArray<4, SpvId> values;
            for (int j = 0; j < dstType.rows(); ++j) {
                values.push_back((i == j) ? oneId : zeroId);
            }
            columns.push_back(this->writeOpCompositeConstruct(dstColumnType, values, out));
        }
    }

    return this->writeOpCompositeConstruct(dstType, columns, out);
}
2763
addColumnEntry(const Type & columnType,TArray<SpvId> * currentColumn,TArray<SpvId> * columnIds,int rows,SpvId entry,OutputStream & out)2764 void SPIRVCodeGenerator::addColumnEntry(const Type& columnType,
2765 TArray<SpvId>* currentColumn,
2766 TArray<SpvId>* columnIds,
2767 int rows,
2768 SpvId entry,
2769 OutputStream& out) {
2770 SkASSERT(currentColumn->size() < rows);
2771 currentColumn->push_back(entry);
2772 if (currentColumn->size() == rows) {
2773 // Synthesize this column into a vector.
2774 SpvId columnId = this->writeOpCompositeConstruct(columnType, *currentColumn, out);
2775 columnIds->push_back(columnId);
2776 currentColumn->clear();
2777 }
2778 }
2779
writeMatrixConstructor(const ConstructorCompound & c,OutputStream & out)2780 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out) {
2781 const Type& type = c.type();
2782 SkASSERT(type.isMatrix());
2783 SkASSERT(!c.arguments().empty());
2784 const Type& arg0Type = c.arguments()[0]->type();
2785 // go ahead and write the arguments so we don't try to write new instructions in the middle of
2786 // an instruction
2787 STArray<16, SpvId> arguments;
2788 for (const std::unique_ptr<Expression>& arg : c.arguments()) {
2789 arguments.push_back(this->writeExpression(*arg, out));
2790 }
2791
2792 if (arguments.size() == 1 && arg0Type.isVector()) {
2793 // Special-case handling of float4 -> mat2x2.
2794 SkASSERT(type.rows() == 2 && type.columns() == 2);
2795 SkASSERT(arg0Type.columns() == 4);
2796 SpvId v[4];
2797 for (int i = 0; i < 4; ++i) {
2798 v[i] = this->writeOpCompositeExtract(type.componentType(), arguments[0], i, out);
2799 }
2800 const Type& vecType = type.columnType(fContext);
2801 SpvId v0v1 = this->writeOpCompositeConstruct(vecType, {v[0], v[1]}, out);
2802 SpvId v2v3 = this->writeOpCompositeConstruct(vecType, {v[2], v[3]}, out);
2803 return this->writeOpCompositeConstruct(type, {v0v1, v2v3}, out);
2804 }
2805
2806 int rows = type.rows();
2807 const Type& columnType = type.columnType(fContext);
2808 // SpvIds of completed columns of the matrix.
2809 STArray<4, SpvId> columnIds;
2810 // SpvIds of scalars we have written to the current column so far.
2811 STArray<4, SpvId> currentColumn;
2812 for (int i = 0; i < arguments.size(); i++) {
2813 const Type& argType = c.arguments()[i]->type();
2814 if (currentColumn.empty() && argType.isVector() && argType.columns() == rows) {
2815 // This vector is a complete matrix column by itself and can be used as-is.
2816 columnIds.push_back(arguments[i]);
2817 } else if (argType.columns() == 1) {
2818 // This argument is a lone scalar and can be added to the current column as-is.
2819 this->addColumnEntry(columnType, ¤tColumn, &columnIds, rows, arguments[i], out);
2820 } else {
2821 // This argument needs to be decomposed into its constituent scalars.
2822 for (int j = 0; j < argType.columns(); ++j) {
2823 SpvId swizzle = this->writeOpCompositeExtract(argType.componentType(),
2824 arguments[i], j, out);
2825 this->addColumnEntry(columnType, ¤tColumn, &columnIds, rows, swizzle, out);
2826 }
2827 }
2828 }
2829 SkASSERT(columnIds.size() == type.columns());
2830 return this->writeOpCompositeConstruct(type, columnIds, out);
2831 }
2832
writeConstructorCompound(const ConstructorCompound & c,OutputStream & out)2833 SpvId SPIRVCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
2834 OutputStream& out) {
2835 return c.type().isMatrix() ? this->writeMatrixConstructor(c, out)
2836 : this->writeVectorConstructor(c, out);
2837 }
2838
writeVectorConstructor(const ConstructorCompound & c,OutputStream & out)2839 SpvId SPIRVCodeGenerator::writeVectorConstructor(const ConstructorCompound& c, OutputStream& out) {
2840 const Type& type = c.type();
2841 const Type& componentType = type.componentType();
2842 SkASSERT(type.isVector());
2843
2844 STArray<4, SpvId> arguments;
2845 for (int i = 0; i < c.arguments().size(); i++) {
2846 const Type& argType = c.arguments()[i]->type();
2847 SkASSERT(componentType.numberKind() == argType.componentType().numberKind());
2848
2849 SpvId arg = this->writeExpression(*c.arguments()[i], out);
2850 if (argType.isMatrix()) {
2851 // CompositeConstruct cannot take a 2x2 matrix as an input, so we need to extract out
2852 // each scalar separately.
2853 SkASSERT(argType.rows() == 2);
2854 SkASSERT(argType.columns() == 2);
2855 for (int j = 0; j < 4; ++j) {
2856 arguments.push_back(this->writeOpCompositeExtract(componentType, arg,
2857 j / 2, j % 2, out));
2858 }
2859 } else if (argType.isVector()) {
2860 // There's a bug in the Intel Vulkan driver where OpCompositeConstruct doesn't handle
2861 // vector arguments at all, so we always extract each vector component and pass them
2862 // into OpCompositeConstruct individually.
2863 for (int j = 0; j < argType.columns(); j++) {
2864 arguments.push_back(this->writeOpCompositeExtract(componentType, arg, j, out));
2865 }
2866 } else {
2867 arguments.push_back(arg);
2868 }
2869 }
2870
2871 return this->writeOpCompositeConstruct(type, arguments, out);
2872 }
2873
writeConstructorSplat(const ConstructorSplat & c,OutputStream & out)2874 SpvId SPIRVCodeGenerator::writeConstructorSplat(const ConstructorSplat& c, OutputStream& out) {
2875 // Write the splat argument as a scalar, then splat it.
2876 SpvId argument = this->writeExpression(*c.argument(), out);
2877 return this->splat(c.type(), argument, out);
2878 }
2879
writeCompositeConstructor(const AnyConstructor & c,OutputStream & out)2880 SpvId SPIRVCodeGenerator::writeCompositeConstructor(const AnyConstructor& c, OutputStream& out) {
2881 SkASSERT(c.type().isArray() || c.type().isStruct());
2882 auto ctorArgs = c.argumentSpan();
2883
2884 STArray<4, SpvId> arguments;
2885 for (const std::unique_ptr<Expression>& arg : ctorArgs) {
2886 arguments.push_back(this->writeExpression(*arg, out));
2887 }
2888
2889 return this->writeOpCompositeConstruct(c.type(), arguments, out);
2890 }
2891
writeConstructorScalarCast(const ConstructorScalarCast & c,OutputStream & out)2892 SpvId SPIRVCodeGenerator::writeConstructorScalarCast(const ConstructorScalarCast& c,
2893 OutputStream& out) {
2894 const Type& type = c.type();
2895 if (type.componentType().numberKind() == c.argument()->type().componentType().numberKind()) {
2896 return this->writeExpression(*c.argument(), out);
2897 }
2898
2899 const Expression& ctorExpr = *c.argument();
2900 SpvId expressionId = this->writeExpression(ctorExpr, out);
2901 return this->castScalarToType(expressionId, ctorExpr.type(), type, out);
2902 }
2903
writeConstructorCompoundCast(const ConstructorCompoundCast & c,OutputStream & out)2904 SpvId SPIRVCodeGenerator::writeConstructorCompoundCast(const ConstructorCompoundCast& c,
2905 OutputStream& out) {
2906 const Type& ctorType = c.type();
2907 const Type& argType = c.argument()->type();
2908 SkASSERT(ctorType.isVector() || ctorType.isMatrix());
2909
2910 // Write the composite that we are casting. If the actual type matches, we are done.
2911 SpvId compositeId = this->writeExpression(*c.argument(), out);
2912 if (ctorType.componentType().numberKind() == argType.componentType().numberKind()) {
2913 return compositeId;
2914 }
2915
2916 // writeMatrixCopy can cast matrices to a different type.
2917 if (ctorType.isMatrix()) {
2918 return this->writeMatrixCopy(compositeId, argType, ctorType, out);
2919 }
2920
2921 // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to extract the
2922 // components and convert each one manually.
2923 const Type& srcType = argType.componentType();
2924 const Type& dstType = ctorType.componentType();
2925
2926 STArray<4, SpvId> arguments;
2927 for (int index = 0; index < argType.columns(); ++index) {
2928 SpvId componentId = this->writeOpCompositeExtract(srcType, compositeId, index, out);
2929 arguments.push_back(this->castScalarToType(componentId, srcType, dstType, out));
2930 }
2931
2932 return this->writeOpCompositeConstruct(ctorType, arguments, out);
2933 }
2934
writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix & c,OutputStream & out)2935 SpvId SPIRVCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
2936 OutputStream& out) {
2937 const Type& type = c.type();
2938 SkASSERT(type.isMatrix());
2939 SkASSERT(c.argument()->type().isScalar());
2940
2941 // Write out the scalar argument.
2942 SpvId diagonal = this->writeExpression(*c.argument(), out);
2943
2944 // Build the diagonal matrix.
2945 SpvId zeroId = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
2946
2947 const Type& vecType = type.columnType(fContext);
2948 STArray<4, SpvId> columnIds;
2949 STArray<4, SpvId> arguments;
2950 arguments.resize(type.rows());
2951 for (int column = 0; column < type.columns(); column++) {
2952 for (int row = 0; row < type.rows(); row++) {
2953 arguments[row] = (row == column) ? diagonal : zeroId;
2954 }
2955 columnIds.push_back(this->writeOpCompositeConstruct(vecType, arguments, out));
2956 }
2957 return this->writeOpCompositeConstruct(type, columnIds, out);
2958 }
2959
writeConstructorMatrixResize(const ConstructorMatrixResize & c,OutputStream & out)2960 SpvId SPIRVCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
2961 OutputStream& out) {
2962 // Write the input matrix.
2963 SpvId argument = this->writeExpression(*c.argument(), out);
2964
2965 // Use matrix-copy to resize the input matrix to its new size.
2966 return this->writeMatrixCopy(argument, c.argument()->type(), c.type(), out);
2967 }
2968
get_storage_class_for_global_variable(const Variable & var,SpvStorageClass_ fallbackStorageClass)2969 static SpvStorageClass_ get_storage_class_for_global_variable(
2970 const Variable& var, SpvStorageClass_ fallbackStorageClass) {
2971 SkASSERT(var.storage() == Variable::Storage::kGlobal);
2972
2973 if (var.type().typeKind() == Type::TypeKind::kSampler ||
2974 var.type().typeKind() == Type::TypeKind::kSeparateSampler ||
2975 var.type().typeKind() == Type::TypeKind::kTexture) {
2976 return SpvStorageClassUniformConstant;
2977 }
2978
2979 const Layout& layout = var.layout();
2980 ModifierFlags flags = var.modifierFlags();
2981 if (flags & ModifierFlag::kIn) {
2982 SkASSERT(!(layout.fFlags & LayoutFlag::kPushConstant));
2983 return SpvStorageClassInput;
2984 }
2985 if (flags & ModifierFlag::kOut) {
2986 SkASSERT(!(layout.fFlags & LayoutFlag::kPushConstant));
2987 return SpvStorageClassOutput;
2988 }
2989 if (flags.isUniform()) {
2990 if (layout.fFlags & LayoutFlag::kPushConstant) {
2991 return SpvStorageClassPushConstant;
2992 }
2993 return SpvStorageClassUniform;
2994 }
2995 if (flags.isBuffer()) {
2996 // Note: In SPIR-V 1.3, a storage buffer can be declared with the "StorageBuffer"
2997 // storage class and the "Block" decoration and the <1.3 approach we use here ("Uniform"
2998 // storage class and the "BufferBlock" decoration) is deprecated. Since we target SPIR-V
2999 // 1.0, we have to use the deprecated approach which is well supported in Vulkan and
3000 // addresses SkSL use cases (notably SkSL currently doesn't support pointer features that
3001 // would benefit from SPV_KHR_variable_pointers capabilities).
3002 return SpvStorageClassUniform;
3003 }
3004 if (flags.isWorkgroup()) {
3005 return SpvStorageClassWorkgroup;
3006 }
3007 return fallbackStorageClass;
3008 }
3009
get_storage_class(const Expression & expr)3010 static SpvStorageClass_ get_storage_class(const Expression& expr) {
3011 switch (expr.kind()) {
3012 case Expression::Kind::kVariableReference: {
3013 const Variable& var = *expr.as<VariableReference>().variable();
3014 if (var.storage() != Variable::Storage::kGlobal) {
3015 return SpvStorageClassFunction;
3016 }
3017 return get_storage_class_for_global_variable(var, SpvStorageClassPrivate);
3018 }
3019 case Expression::Kind::kFieldAccess:
3020 return get_storage_class(*expr.as<FieldAccess>().base());
3021 case Expression::Kind::kIndex:
3022 return get_storage_class(*expr.as<IndexExpression>().base());
3023 default:
3024 return SpvStorageClassFunction;
3025 }
3026 }
3027
getAccessChain(const Expression & expr,OutputStream & out)3028 TArray<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
3029 switch (expr.kind()) {
3030 case Expression::Kind::kIndex: {
3031 const IndexExpression& indexExpr = expr.as<IndexExpression>();
3032 if (indexExpr.base()->is<Swizzle>()) {
3033 // Access chains don't directly support dynamically indexing into a swizzle, but we
3034 // can rewrite them into a supported form.
3035 return this->getAccessChain(*Transform::RewriteIndexedSwizzle(fContext, indexExpr),
3036 out);
3037 }
3038 // All other index-expressions can be represented as typical access chains.
3039 TArray<SpvId> chain = this->getAccessChain(*indexExpr.base(), out);
3040 chain.push_back(this->writeExpression(*indexExpr.index(), out));
3041 return chain;
3042 }
3043 case Expression::Kind::kFieldAccess: {
3044 const FieldAccess& fieldExpr = expr.as<FieldAccess>();
3045 TArray<SpvId> chain = this->getAccessChain(*fieldExpr.base(), out);
3046 chain.push_back(this->writeLiteral(fieldExpr.fieldIndex(), *fContext.fTypes.fInt));
3047 return chain;
3048 }
3049 default: {
3050 SpvId id = this->getLValue(expr, out)->getPointer();
3051 SkASSERT(id != NA);
3052 return TArray<SpvId>{id};
3053 }
3054 }
3055 SkUNREACHABLE;
3056 }
3057
3058 class PointerLValue : public SPIRVCodeGenerator::LValue {
3059 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,bool isMemoryObject,SpvId type,SPIRVCodeGenerator::Precision precision,SpvStorageClass_ storageClass)3060 PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, bool isMemoryObject, SpvId type,
3061 SPIRVCodeGenerator::Precision precision, SpvStorageClass_ storageClass)
3062 : fGen(gen)
3063 , fPointer(pointer)
3064 , fIsMemoryObject(isMemoryObject)
3065 , fType(type)
3066 , fPrecision(precision)
3067 , fStorageClass(storageClass) {}
3068
getPointer()3069 SpvId getPointer() override {
3070 return fPointer;
3071 }
3072
isMemoryObjectPointer() const3073 bool isMemoryObjectPointer() const override {
3074 return fIsMemoryObject;
3075 }
3076
storageClass() const3077 SpvStorageClass storageClass() const override {
3078 return fStorageClass;
3079 }
3080
load(OutputStream & out)3081 SpvId load(OutputStream& out) override {
3082 return fGen.writeOpLoad(fType, fPrecision, fPointer, out);
3083 }
3084
store(SpvId value,OutputStream & out)3085 void store(SpvId value, OutputStream& out) override {
3086 if (!fIsMemoryObject) {
3087 // We are going to write into an access chain; this could represent one component of a
3088 // vector, or one element of an array. This has the potential to invalidate other,
3089 // *unknown* elements of our store cache. (e.g. if the store cache holds `%50 = myVec4`,
3090 // and we store `%60 = myVec4.z`, this invalidates the cached value for %50.) To avoid
3091 // relying on stale data, reset the store cache entirely when this happens.
3092 fGen.fStoreCache.reset();
3093 }
3094
3095 fGen.writeOpStore(fStorageClass, fPointer, value, out);
3096 }
3097
3098 private:
3099 SPIRVCodeGenerator& fGen;
3100 const SpvId fPointer;
3101 const bool fIsMemoryObject;
3102 const SpvId fType;
3103 const SPIRVCodeGenerator::Precision fPrecision;
3104 const SpvStorageClass_ fStorageClass;
3105 };
3106
3107 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
3108 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const ComponentArray & components,const Type & baseType,const Type & swizzleType,SpvStorageClass_ storageClass)3109 SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const ComponentArray& components,
3110 const Type& baseType, const Type& swizzleType, SpvStorageClass_ storageClass)
3111 : fGen(gen)
3112 , fVecPointer(vecPointer)
3113 , fComponents(components)
3114 , fBaseType(&baseType)
3115 , fSwizzleType(&swizzleType)
3116 , fStorageClass(storageClass) {}
3117
applySwizzle(const ComponentArray & components,const Type & newType)3118 bool applySwizzle(const ComponentArray& components, const Type& newType) override {
3119 ComponentArray updatedSwizzle;
3120 for (int8_t component : components) {
3121 if (component < 0 || component >= fComponents.size()) {
3122 SkDEBUGFAILF("swizzle accessed nonexistent component %d", (int)component);
3123 return false;
3124 }
3125 updatedSwizzle.push_back(fComponents[component]);
3126 }
3127 fComponents = updatedSwizzle;
3128 fSwizzleType = &newType;
3129 return true;
3130 }
3131
storageClass() const3132 SpvStorageClass storageClass() const override {
3133 return fStorageClass;
3134 }
3135
load(OutputStream & out)3136 SpvId load(OutputStream& out) override {
3137 SpvId base = fGen.nextId(fBaseType);
3138 fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
3139 SpvId result = fGen.nextId(fBaseType);
3140 fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
3141 fGen.writeWord(fGen.getType(*fSwizzleType), out);
3142 fGen.writeWord(result, out);
3143 fGen.writeWord(base, out);
3144 fGen.writeWord(base, out);
3145 for (int component : fComponents) {
3146 fGen.writeWord(component, out);
3147 }
3148 return result;
3149 }
3150
store(SpvId value,OutputStream & out)3151 void store(SpvId value, OutputStream& out) override {
3152 // use OpVectorShuffle to mix and match the vector components. We effectively create
3153 // a virtual vector out of the concatenation of the left and right vectors, and then
3154 // select components from this virtual vector to make the result vector. For
3155 // instance, given:
3156 // float3L = ...;
3157 // float3R = ...;
3158 // L.xz = R.xy;
3159 // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
3160 // our result vector to look like (R.x, L.y, R.y), so we need to select indices
3161 // (3, 1, 4).
3162 SpvId base = fGen.nextId(fBaseType);
3163 fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
3164 SpvId shuffle = fGen.nextId(fBaseType);
3165 fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType->columns(), out);
3166 fGen.writeWord(fGen.getType(*fBaseType), out);
3167 fGen.writeWord(shuffle, out);
3168 fGen.writeWord(base, out);
3169 fGen.writeWord(value, out);
3170 for (int i = 0; i < fBaseType->columns(); i++) {
3171 // current offset into the virtual vector, defaults to pulling the unmodified
3172 // value from the left side
3173 int offset = i;
3174 // check to see if we are writing this component
3175 for (int j = 0; j < fComponents.size(); j++) {
3176 if (fComponents[j] == i) {
3177 // we're writing to this component, so adjust the offset to pull from
3178 // the correct component of the right side instead of preserving the
3179 // value from the left
3180 offset = (int) (j + fBaseType->columns());
3181 break;
3182 }
3183 }
3184 fGen.writeWord(offset, out);
3185 }
3186 fGen.writeOpStore(fStorageClass, fVecPointer, shuffle, out);
3187 }
3188
3189 private:
3190 SPIRVCodeGenerator& fGen;
3191 const SpvId fVecPointer;
3192 ComponentArray fComponents;
3193 const Type* fBaseType;
3194 const Type* fSwizzleType;
3195 const SpvStorageClass_ fStorageClass;
3196 };
3197
findUniformFieldIndex(const Variable & var) const3198 int SPIRVCodeGenerator::findUniformFieldIndex(const Variable& var) const {
3199 int* fieldIndex = fTopLevelUniformMap.find(&var);
3200 return fieldIndex ? *fieldIndex : -1;
3201 }
3202
// Builds an LValue for `expr`, emitting any access-chain instructions needed to address it.
//  - Variable references yield a PointerLValue on the variable itself. Top-level uniforms go
//    through an access chain into the uniform buffer. sk_SampleMask/sk_SampleMaskIn go through
//    element 0 of their array.
//  - Index/field accesses yield a PointerLValue on an OpAccessChain (marked as not a whole
//    memory object).
//  - Swizzles either fold into an existing lvalue's swizzle, become a single-lane access chain,
//    or become a SwizzleLValue over the base pointer.
//  - Anything else is spilled into a fresh Function-storage temporary variable.
std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
                                                                          OutputStream& out) {
    const Type& type = expr.type();
    // Loads through the resulting lvalue use relaxed precision unless the type is high-precision.
    Precision precision = type.highPrecision() ? Precision::kDefault : Precision::kRelaxed;
    switch (expr.kind()) {
        case Expression::Kind::kVariableReference: {
            const Variable& var = *expr.as<VariableReference>().variable();
            int uniformIdx = this->findUniformFieldIndex(var);
            if (uniformIdx >= 0) {
                // Access uniforms via an AccessChain into the uniform-buffer struct.
                SpvId memberId = this->nextId(nullptr);
                SpvId typeId = this->getPointerType(type, SpvStorageClassUniform);
                SpvId uniformIdxId = this->writeLiteral((double)uniformIdx, *fContext.fTypes.fInt);
                this->writeInstruction(SpvOpAccessChain, typeId, memberId, fUniformBufferId,
                                       uniformIdxId, out);
                return std::make_unique<PointerLValue>(
                        *this,
                        memberId,
                        /*isMemoryObjectPointer=*/true,
                        this->getType(type, kDefaultTypeLayout, this->memoryLayoutForVariable(var)),
                        precision,
                        SpvStorageClassUniform);
            }

            // Every other variable must already have a pointer id registered in fVariableMap.
            SpvId* entry = fVariableMap.find(&var);
            SkASSERTF(entry, "%s", expr.description().c_str());

            if (var.layout().fBuiltin == SK_SAMPLEMASKIN_BUILTIN ||
                var.layout().fBuiltin == SK_SAMPLEMASK_BUILTIN) {
                // Access sk_SampleMask and sk_SampleMaskIn via an array access, since Vulkan
                // represents sample masks as an array of uints.
                SpvStorageClass_ storageClass =
                        get_storage_class_for_global_variable(var, SpvStorageClassPrivate);
                SkASSERT(storageClass != SpvStorageClassPrivate);
                SkASSERT(type.matches(*fContext.fTypes.fUInt));

                SpvId accessId = this->nextId(nullptr);
                SpvId typeId = this->getPointerType(type, storageClass);
                // Always addresses element 0 of the sample-mask array.
                SpvId indexId = this->writeLiteral(0.0, *fContext.fTypes.fInt);
                this->writeInstruction(SpvOpAccessChain, typeId, accessId, *entry, indexId, out);
                return std::make_unique<PointerLValue>(*this,
                                                       accessId,
                                                       /*isMemoryObjectPointer=*/true,
                                                       this->getType(type),
                                                       precision,
                                                       storageClass);
            }

            // Ordinary variable: the lvalue is the variable's own pointer.
            SpvId typeId = this->getType(type, var.layout(), this->memoryLayoutForVariable(var));
            return std::make_unique<PointerLValue>(*this, *entry,
                                                   /*isMemoryObjectPointer=*/true,
                                                   typeId, precision, get_storage_class(expr));
        }
        case Expression::Kind::kIndex: // fall through
        case Expression::Kind::kFieldAccess: {
            // Emit one OpAccessChain over the full chain of indices collected for this
            // expression (the first entry is the base pointer).
            TArray<SpvId> chain = this->getAccessChain(expr, out);
            SpvId member = this->nextId(nullptr);
            SpvStorageClass_ storageClass = get_storage_class(expr);
            this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
            this->writeWord(this->getPointerType(type, storageClass), out);
            this->writeWord(member, out);
            for (SpvId idx : chain) {
                this->writeWord(idx, out);
            }
            // The result addresses only part of an object, so it is not a memory-object pointer.
            return std::make_unique<PointerLValue>(
                    *this,
                    member,
                    /*isMemoryObjectPointer=*/false,
                    this->getType(type,
                                  kDefaultTypeLayout,
                                  this->memoryLayoutForStorageClass(storageClass)),
                    precision,
                    storageClass);
        }
        case Expression::Kind::kSwizzle: {
            const Swizzle& swizzle = expr.as<Swizzle>();
            std::unique_ptr<LValue> lvalue = this->getLValue(*swizzle.base(), out);
            // If the base lvalue can absorb the swizzle itself, reuse it directly.
            if (lvalue->applySwizzle(swizzle.components(), type)) {
                return lvalue;
            }
            SpvId base = lvalue->getPointer();
            if (base == NA) {
                // NOTE(review): an error is reported, but code generation continues below using
                // NA as the base pointer.
                fContext.fErrors->error(swizzle.fPosition,
                                        "unable to retrieve lvalue from swizzle");
            }
            SpvStorageClass_ storageClass = get_storage_class(*swizzle.base());
            if (swizzle.components().size() == 1) {
                // A single lane can be addressed directly with an access chain.
                SpvId member = this->nextId(nullptr);
                SpvId typeId = this->getPointerType(type, storageClass);
                SpvId indexId = this->writeLiteral(swizzle.components()[0], *fContext.fTypes.fInt);
                this->writeInstruction(SpvOpAccessChain, typeId, member, base, indexId, out);
                return std::make_unique<PointerLValue>(*this, member,
                                                       /*isMemoryObjectPointer=*/false,
                                                       this->getType(type),
                                                       precision, storageClass);
            } else {
                // Multiple lanes need a load/shuffle/store wrapper around the base vector.
                return std::make_unique<SwizzleLValue>(*this, base, swizzle.components(),
                                                       swizzle.base()->type(), type, storageClass);
            }
        }
        default: {
            // expr isn't actually an lvalue, create a placeholder variable for it. This case
            // happens due to the need to store values in temporary variables during function
            // calls (see comments in getFunctionParameterType); erroneous uses of rvalues as
            // lvalues should have been caught before code generation.
            //
            // This is with the exception of opaque handle types (textures/samplers) which are
            // always defined as UniformConstant pointers and don't need to be explicitly stored
            // into a temporary (which is handled explicitly in writeFunctionCallArgument).
            SpvId result = this->nextId(nullptr);
            SpvId pointerType = this->getPointerType(type, SpvStorageClassFunction);
            // The OpVariable is emitted into fVariableBuffer rather than `out`.
            this->writeInstruction(SpvOpVariable, pointerType, result, SpvStorageClassFunction,
                                   fVariableBuffer);
            this->writeOpStore(SpvStorageClassFunction, result, this->writeExpression(expr, out),
                               out);
            return std::make_unique<PointerLValue>(*this, result, /*isMemoryObjectPointer=*/true,
                                                   this->getType(type), precision,
                                                   SpvStorageClassFunction);
        }
    }
}
3324
identifier(std::string_view name)3325 std::unique_ptr<Expression> SPIRVCodeGenerator::identifier(std::string_view name) {
3326 std::unique_ptr<Expression> expr =
3327 fProgram.fSymbols->instantiateSymbolRef(fContext, name, Position());
3328 return expr ? std::move(expr)
3329 : Poison::Make(Position(), fContext);
3330 }
3331
// Emits a read of the variable referenced by `ref` and returns the id of the resulting value.
// Several builtins are special-cased:
//  - the synthetic $device_FragCoords / $device_Clockwise variables read the raw builtin;
//  - sk_FragCoord and sk_Clockwise are adjusted with the RTFlip uniform unless
//    fForceNoRTFlip is set;
//  - sk_SecondaryFragColor is rejected with an error.
// Other variables are constant-folded when possible. Sampler variables with synthesized
// texture/sampler backing become an OpSampledImage. Everything else is loaded through getLValue.
SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
    const Variable* variable = ref.variable();
    switch (variable->layout().fBuiltin) {
        case DEVICE_FRAGCOORDS_BUILTIN: {
            // Down below, we rewrite raw references to sk_FragCoord with expressions that reference
            // DEVICE_FRAGCOORDS_BUILTIN. This is a fake variable that means we need to directly
            // access the fragcoord; do so now.
            return this->getLValue(*this->identifier("sk_FragCoord"), out)->load(out);
        }
        case DEVICE_CLOCKWISE_BUILTIN: {
            // Down below, we rewrite raw references to sk_Clockwise with expressions that reference
            // DEVICE_CLOCKWISE_BUILTIN. This is a fake variable that means we need to directly
            // access front facing; do so now.
            return this->getLValue(*this->identifier("sk_Clockwise"), out)->load(out);
        }
        case SK_SECONDARYFRAGCOLOR_BUILTIN: {
            // sk_SecondaryFragColor corresponds to gl_SecondaryFragColorEXT, which isn't supposed
            // to appear in a SPIR-V program (it's only valid in ES2). Report an error.
            fContext.fErrors->error(ref.fPosition,
                                    "sk_SecondaryFragColor is not allowed in SPIR-V");
            return NA;
        }
        case SK_FRAGCOORD_BUILTIN: {
            if (fProgram.fConfig->fSettings.fForceNoRTFlip) {
                return this->getLValue(*this->identifier("sk_FragCoord"), out)->load(out);
            }

            // Handle inserting use of uniform to flip y when referencing sk_FragCoord.
            this->addRTFlipUniform(ref.fPosition);
            // Use the RTFlip uniform (SKSL_RTFLIP_NAME) to compute the flipped coordinate.
            // The new expression will be written in
            // terms of $device_FragCoords, which is a fake variable that means "access the
            // underlying fragcoords directly without flipping it".
            static constexpr char DEVICE_COORDS_NAME[] = "$device_FragCoords";
            if (!fProgram.fSymbols->find(DEVICE_COORDS_NAME)) {
                // Lazily declare $device_FragCoords the first time it is needed.
                AutoAttachPoolToThread attach(fProgram.fPool.get());
                Layout layout;
                layout.fBuiltin = DEVICE_FRAGCOORDS_BUILTIN;
                auto coordsVar = Variable::Make(/*pos=*/Position(),
                                                /*modifiersPosition=*/Position(),
                                                layout,
                                                ModifierFlag::kNone,
                                                fContext.fTypes.fFloat4.get(),
                                                DEVICE_COORDS_NAME,
                                                /*mangledName=*/"",
                                                /*builtin=*/true,
                                                Variable::Storage::kGlobal);
                fProgram.fSymbols->add(fContext, std::move(coordsVar));
            }
            std::unique_ptr<Expression> deviceCoord = this->identifier(DEVICE_COORDS_NAME);
            std::unique_ptr<Expression> rtFlip = this->identifier(SKSL_RTFLIP_NAME);
            SpvId rtFlipX = this->writeSwizzle(*rtFlip, {SwizzleComponent::X}, out);
            SpvId rtFlipY = this->writeSwizzle(*rtFlip, {SwizzleComponent::Y}, out);
            SpvId deviceCoordX = this->writeSwizzle(*deviceCoord, {SwizzleComponent::X}, out);
            SpvId deviceCoordY = this->writeSwizzle(*deviceCoord, {SwizzleComponent::Y}, out);
            SpvId deviceCoordZW = this->writeSwizzle(*deviceCoord, {SwizzleComponent::Z,
                                                                    SwizzleComponent::W}, out);
            // Compute `flippedY = u_RTFlip.y * $device_FragCoords.y`.
            SpvId flippedY = this->writeBinaryExpression(
                                     *fContext.fTypes.fFloat, rtFlipY, OperatorKind::STAR,
                                     *fContext.fTypes.fFloat, deviceCoordY,
                                     *fContext.fTypes.fFloat, out);

            // Compute `flippedY = u_RTFlip.x + flippedY`.
            flippedY = this->writeBinaryExpression(
                               *fContext.fTypes.fFloat, rtFlipX, OperatorKind::PLUS,
                               *fContext.fTypes.fFloat, flippedY,
                               *fContext.fTypes.fFloat, out);

            // Return `float4(deviceCoord.x, flippedY, deviceCoord.zw)`.
            return this->writeOpCompositeConstruct(*fContext.fTypes.fFloat4,
                                                   {deviceCoordX, flippedY, deviceCoordZW},
                                                   out);
        }
        case SK_CLOCKWISE_BUILTIN: {
            if (fProgram.fConfig->fSettings.fForceNoRTFlip) {
                return this->getLValue(*this->identifier("sk_Clockwise"), out)->load(out);
            }

            // Apply RTFlip to sk_Clockwise.
            this->addRTFlipUniform(ref.fPosition);
            // Use a uniform to flip the Y coordinate. The new expression will be written in
            // terms of $device_Clockwise, which is a fake variable that means "access the
            // underlying FrontFacing directly".
            static constexpr char DEVICE_CLOCKWISE_NAME[] = "$device_Clockwise";
            if (!fProgram.fSymbols->find(DEVICE_CLOCKWISE_NAME)) {
                // Lazily declare $device_Clockwise the first time it is needed.
                AutoAttachPoolToThread attach(fProgram.fPool.get());
                Layout layout;
                layout.fBuiltin = DEVICE_CLOCKWISE_BUILTIN;
                auto clockwiseVar = Variable::Make(/*pos=*/Position(),
                                                   /*modifiersPosition=*/Position(),
                                                   layout,
                                                   ModifierFlag::kNone,
                                                   fContext.fTypes.fBool.get(),
                                                   DEVICE_CLOCKWISE_NAME,
                                                   /*mangledName=*/"",
                                                   /*builtin=*/true,
                                                   Variable::Storage::kGlobal);
                fProgram.fSymbols->add(fContext, std::move(clockwiseVar));
            }
            // FrontFacing in Vulkan is defined in terms of a top-down render target. In Skia,
            // we use the default convention of "counter-clockwise face is front".

            // Compute `positiveRTFlip = (rtFlip.y > 0)`.
            std::unique_ptr<Expression> rtFlip = this->identifier(SKSL_RTFLIP_NAME);
            SpvId rtFlipY = this->writeSwizzle(*rtFlip, {SwizzleComponent::Y}, out);
            SpvId zero = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
            SpvId positiveRTFlip = this->writeBinaryExpression(
                                           *fContext.fTypes.fFloat, rtFlipY, OperatorKind::GT,
                                           *fContext.fTypes.fFloat, zero,
                                           *fContext.fTypes.fBool, out);

            // Compute `positiveRTFlip ^^ $device_Clockwise`.
            std::unique_ptr<Expression> deviceClockwise = this->identifier(DEVICE_CLOCKWISE_NAME);
            SpvId deviceClockwiseID = this->writeExpression(*deviceClockwise, out);
            return this->writeBinaryExpression(
                           *fContext.fTypes.fBool, positiveRTFlip, OperatorKind::LOGICALXOR,
                           *fContext.fTypes.fBool, deviceClockwiseID,
                           *fContext.fTypes.fBool, out);
        }
        default: {
            // Constant-propagate variables that have a known compile-time value.
            if (const Expression* expr = ConstantFolder::GetConstantValueOrNull(ref)) {
                return this->writeExpression(*expr, out);
            }

            // A reference to a sampler variable at global scope with synthesized texture/sampler
            // backing should construct a function-scope combined image-sampler from the synthesized
            // constituents. This is the case in which a sample intrinsic was invoked.
            //
            // Variable references to opaque handles (texture/sampler) that appear as the argument
            // of a user-defined function call are explicitly handled in writeFunctionCallArgument.
            if (fUseTextureSamplerPairs && variable->type().isSampler()) {
                if (const auto* p = fSynthesizedSamplerMap.find(variable)) {
                    SpvId* imgPtr = fVariableMap.find((*p)->fTexture.get());
                    SpvId* samplerPtr = fVariableMap.find((*p)->fSampler.get());
                    SkASSERT(imgPtr);
                    SkASSERT(samplerPtr);

                    SpvId img = this->writeOpLoad(this->getType((*p)->fTexture->type()),
                                                  Precision::kDefault, *imgPtr, out);
                    SpvId sampler = this->writeOpLoad(this->getType((*p)->fSampler->type()),
                                                      Precision::kDefault,
                                                      *samplerPtr,
                                                      out);
                    SpvId result = this->nextId(nullptr);
                    this->writeInstruction(SpvOpSampledImage,
                                           this->getType(variable->type()),
                                           result,
                                           img,
                                           sampler,
                                           out);
                    return result;
                }
                // In release builds this falls through to the ordinary load below.
                SkDEBUGFAIL("sampler missing from fSynthesizedSamplerMap");
            }
            return this->getLValue(ref, out)->load(out);
        }
    }
}
3492
writeIndexExpression(const IndexExpression & expr,OutputStream & out)3493 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
3494 if (expr.base()->type().isVector()) {
3495 SpvId base = this->writeExpression(*expr.base(), out);
3496 SpvId index = this->writeExpression(*expr.index(), out);
3497 SpvId result = this->nextId(nullptr);
3498 this->writeInstruction(SpvOpVectorExtractDynamic, this->getType(expr.type()), result, base,
3499 index, out);
3500 return result;
3501 }
3502 return getLValue(expr, out)->load(out);
3503 }
3504
writeFieldAccess(const FieldAccess & f,OutputStream & out)3505 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
3506 return getLValue(f, out)->load(out);
3507 }
3508
writeSwizzle(const Expression & baseExpr,const ComponentArray & components,OutputStream & out)3509 SpvId SPIRVCodeGenerator::writeSwizzle(const Expression& baseExpr,
3510 const ComponentArray& components,
3511 OutputStream& out) {
3512 size_t count = components.size();
3513 const Type& type = baseExpr.type().componentType().toCompound(fContext, count, /*rows=*/1);
3514 SpvId base = this->writeExpression(baseExpr, out);
3515 if (count == 1) {
3516 return this->writeOpCompositeExtract(type, base, components[0], out);
3517 }
3518
3519 SpvId result = this->nextId(&type);
3520 this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
3521 this->writeWord(this->getType(type), out);
3522 this->writeWord(result, out);
3523 this->writeWord(base, out);
3524 this->writeWord(base, out);
3525 for (int component : components) {
3526 this->writeWord(component, out);
3527 }
3528 return result;
3529 }
3530
writeSwizzle(const Swizzle & swizzle,OutputStream & out)3531 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
3532 return this->writeSwizzle(*swizzle.base(), swizzle.components(), out);
3533 }
3534
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,bool writeComponentwiseIfMatrix,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3535 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, const Type& operandType,
3536 SpvId lhs, SpvId rhs,
3537 bool writeComponentwiseIfMatrix,
3538 SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,
3539 SpvOp_ ifBool, OutputStream& out) {
3540 SpvOp_ op = pick_by_type(operandType, ifFloat, ifInt, ifUInt, ifBool);
3541 if (op == SpvOpUndef) {
3542 fContext.fErrors->error(operandType.fPosition,
3543 "unsupported operand for binary expression: " + operandType.description());
3544 return NA;
3545 }
3546 if (writeComponentwiseIfMatrix && operandType.isMatrix()) {
3547 return this->writeComponentwiseMatrixBinary(resultType, lhs, rhs, op, out);
3548 }
3549 SpvId result = this->nextId(&resultType);
3550 this->writeInstruction(op, this->getType(resultType), result, lhs, rhs, out);
3551 return result;
3552 }
3553
writeBinaryOperationComponentwiseIfMatrix(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3554 SpvId SPIRVCodeGenerator::writeBinaryOperationComponentwiseIfMatrix(const Type& resultType,
3555 const Type& operandType,
3556 SpvId lhs, SpvId rhs,
3557 SpvOp_ ifFloat, SpvOp_ ifInt,
3558 SpvOp_ ifUInt, SpvOp_ ifBool,
3559 OutputStream& out) {
3560 return this->writeBinaryOperation(resultType, operandType, lhs, rhs,
3561 /*writeComponentwiseIfMatrix=*/true,
3562 ifFloat, ifInt, ifUInt, ifBool, out);
3563 }
3564
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3565 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, const Type& operandType,
3566 SpvId lhs, SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
3567 SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
3568 return this->writeBinaryOperation(resultType, operandType, lhs, rhs,
3569 /*writeComponentwiseIfMatrix=*/false,
3570 ifFloat, ifInt, ifUInt, ifBool, out);
3571 }
3572
foldToBool(SpvId id,const Type & operandType,SpvOp op,OutputStream & out)3573 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op,
3574 OutputStream& out) {
3575 if (operandType.isVector()) {
3576 SpvId result = this->nextId(nullptr);
3577 this->writeInstruction(op, this->getType(*fContext.fTypes.fBool), result, id, out);
3578 return result;
3579 }
3580 return id;
3581 }
3582
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,SpvOp_ vectorMergeOperator,SpvOp_ mergeOperator,OutputStream & out)3583 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
3584 SpvOp_ floatOperator, SpvOp_ intOperator,
3585 SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
3586 OutputStream& out) {
3587 SpvOp_ compareOp = is_float(operandType) ? floatOperator : intOperator;
3588 SkASSERT(operandType.isMatrix());
3589 const Type& columnType = operandType.componentType().toCompound(fContext,
3590 operandType.rows(),
3591 1);
3592 SpvId bvecType = this->getType(fContext.fTypes.fBool->toCompound(fContext,
3593 operandType.rows(),
3594 1));
3595 SpvId boolType = this->getType(*fContext.fTypes.fBool);
3596 SpvId result = 0;
3597 for (int i = 0; i < operandType.columns(); i++) {
3598 SpvId columnL = this->writeOpCompositeExtract(columnType, lhs, i, out);
3599 SpvId columnR = this->writeOpCompositeExtract(columnType, rhs, i, out);
3600 SpvId compare = this->nextId(&operandType);
3601 this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
3602 SpvId merge = this->nextId(nullptr);
3603 this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
3604 if (result != 0) {
3605 SpvId next = this->nextId(nullptr);
3606 this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
3607 result = next;
3608 } else {
3609 result = merge;
3610 }
3611 }
3612 return result;
3613 }
3614
writeComponentwiseMatrixUnary(const Type & operandType,SpvId operand,SpvOp_ op,OutputStream & out)3615 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixUnary(const Type& operandType,
3616 SpvId operand,
3617 SpvOp_ op,
3618 OutputStream& out) {
3619 SkASSERT(operandType.isMatrix());
3620 const Type& columnType = operandType.columnType(fContext);
3621 SpvId columnTypeId = this->getType(columnType);
3622
3623 STArray<4, SpvId> columns;
3624 for (int i = 0; i < operandType.columns(); i++) {
3625 SpvId srcColumn = this->writeOpCompositeExtract(columnType, operand, i, out);
3626 SpvId dstColumn = this->nextId(&operandType);
3627 this->writeInstruction(op, columnTypeId, dstColumn, srcColumn, out);
3628 columns.push_back(dstColumn);
3629 }
3630
3631 return this->writeOpCompositeConstruct(operandType, columns, out);
3632 }
3633
writeComponentwiseMatrixBinary(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ op,OutputStream & out)3634 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
3635 SpvId rhs, SpvOp_ op, OutputStream& out) {
3636 SkASSERT(operandType.isMatrix());
3637 const Type& columnType = operandType.columnType(fContext);
3638 SpvId columnTypeId = this->getType(columnType);
3639
3640 STArray<4, SpvId> columns;
3641 for (int i = 0; i < operandType.columns(); i++) {
3642 SpvId columnL = this->writeOpCompositeExtract(columnType, lhs, i, out);
3643 SpvId columnR = this->writeOpCompositeExtract(columnType, rhs, i, out);
3644 columns.push_back(this->nextId(&operandType));
3645 this->writeInstruction(op, columnTypeId, columns[i], columnL, columnR, out);
3646 }
3647 return this->writeOpCompositeConstruct(operandType, columns, out);
3648 }
3649
writeReciprocal(const Type & type,SpvId value,OutputStream & out)3650 SpvId SPIRVCodeGenerator::writeReciprocal(const Type& type, SpvId value, OutputStream& out) {
3651 SkASSERT(type.isFloat());
3652 SpvId one = this->writeLiteral(1.0, type);
3653 SpvId reciprocal = this->nextId(&type);
3654 this->writeInstruction(SpvOpFDiv, this->getType(type), reciprocal, one, value, out);
3655 return reciprocal;
3656 }
3657
splat(const Type & type,SpvId id,OutputStream & out)3658 SpvId SPIRVCodeGenerator::splat(const Type& type, SpvId id, OutputStream& out) {
3659 if (type.isScalar()) {
3660 // Scalars require no additional work; we can return the passed-in ID as is.
3661 } else {
3662 SkASSERT(type.isVector() || type.isMatrix());
3663 bool isMatrix = type.isMatrix();
3664
3665 // Splat the input scalar across a vector.
3666 int vectorSize = (isMatrix ? type.rows() : type.columns());
3667 const Type& vectorType = type.componentType().toCompound(fContext, vectorSize, /*rows=*/1);
3668
3669 STArray<4, SpvId> values;
3670 values.push_back_n(/*n=*/vectorSize, /*t=*/id);
3671 id = this->writeOpCompositeConstruct(vectorType, values, out);
3672
3673 if (isMatrix) {
3674 // Splat the newly-synthesized vector into a matrix.
3675 STArray<4, SpvId> matArguments;
3676 matArguments.push_back_n(/*n=*/type.columns(), /*t=*/id);
3677 id = this->writeOpCompositeConstruct(type, matArguments, out);
3678 }
3679 }
3680
3681 return id;
3682 }
3683
types_match(const Type & a,const Type & b)3684 static bool types_match(const Type& a, const Type& b) {
3685 if (a.matches(b)) {
3686 return true;
3687 }
3688 return (a.typeKind() == b.typeKind()) &&
3689 (a.isScalar() || a.isVector() || a.isMatrix()) &&
3690 (a.columns() == b.columns() && a.rows() == b.rows()) &&
3691 a.componentType().numberKind() == b.componentType().numberKind();
3692 }
3693
writeDecomposedMatrixVectorMultiply(const Type & leftType,SpvId lhs,const Type & rightType,SpvId rhs,const Type & resultType,OutputStream & out)3694 SpvId SPIRVCodeGenerator::writeDecomposedMatrixVectorMultiply(const Type& leftType,
3695 SpvId lhs,
3696 const Type& rightType,
3697 SpvId rhs,
3698 const Type& resultType,
3699 OutputStream& out) {
3700 SpvId sum = NA;
3701 const Type& columnType = leftType.columnType(fContext);
3702 const Type& scalarType = rightType.componentType();
3703
3704 for (int n = 0; n < leftType.rows(); ++n) {
3705 // Extract mat[N] from the matrix.
3706 SpvId matN = this->writeOpCompositeExtract(columnType, lhs, n, out);
3707
3708 // Extract vec[N] from the vector.
3709 SpvId vecN = this->writeOpCompositeExtract(scalarType, rhs, n, out);
3710
3711 // Multiply them together.
3712 SpvId product = this->writeBinaryExpression(columnType, matN, OperatorKind::STAR,
3713 scalarType, vecN,
3714 columnType, out);
3715
3716 // Sum all the components together.
3717 if (sum == NA) {
3718 sum = product;
3719 } else {
3720 sum = this->writeBinaryExpression(columnType, sum, OperatorKind::PLUS,
3721 columnType, product,
3722 columnType, out);
3723 }
3724 }
3725
3726 return sum;
3727 }
3728
writeBinaryExpression(const Type & leftType,SpvId lhs,Operator op,const Type & rightType,SpvId rhs,const Type & resultType,OutputStream & out)3729 SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
3730 const Type& rightType, SpvId rhs,
3731 const Type& resultType, OutputStream& out) {
3732 // The comma operator ignores the type of the left-hand side entirely.
3733 if (op.kind() == Operator::Kind::COMMA) {
3734 return rhs;
3735 }
3736 // overall type we are operating on: float2, int, uint4...
3737 const Type* operandType;
3738 if (types_match(leftType, rightType)) {
3739 operandType = &leftType;
3740 } else {
3741 // IR allows mismatched types in expressions (e.g. float2 * float), but they need special
3742 // handling in SPIR-V
3743 if (leftType.isVector() && rightType.isNumber()) {
3744 if (resultType.componentType().isFloat()) {
3745 switch (op.kind()) {
3746 case Operator::Kind::SLASH: {
3747 rhs = this->writeReciprocal(rightType, rhs, out);
3748 [[fallthrough]];
3749 }
3750 case Operator::Kind::STAR: {
3751 SpvId result = this->nextId(&resultType);
3752 this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
3753 result, lhs, rhs, out);
3754 return result;
3755 }
3756 default:
3757 break;
3758 }
3759 }
3760 // Vectorize the right-hand side.
3761 STArray<4, SpvId> arguments;
3762 arguments.push_back_n(/*n=*/leftType.columns(), /*t=*/rhs);
3763 rhs = this->writeOpCompositeConstruct(leftType, arguments, out);
3764 operandType = &leftType;
3765 } else if (rightType.isVector() && leftType.isNumber()) {
3766 if (resultType.componentType().isFloat()) {
3767 if (op.kind() == Operator::Kind::STAR) {
3768 SpvId result = this->nextId(&resultType);
3769 this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
3770 result, rhs, lhs, out);
3771 return result;
3772 }
3773 }
3774 // Vectorize the left-hand side.
3775 STArray<4, SpvId> arguments;
3776 arguments.push_back_n(/*n=*/rightType.columns(), /*t=*/lhs);
3777 lhs = this->writeOpCompositeConstruct(rightType, arguments, out);
3778 operandType = &rightType;
3779 } else if (leftType.isMatrix()) {
3780 if (op.kind() == Operator::Kind::STAR) {
3781 // When the rewriteMatrixVectorMultiply bit is set, we rewrite medium-precision
3782 // matrix * vector multiplication as (mat[0]*vec[0] + ... + mat[N]*vec[N]).
3783 if (fCaps.fRewriteMatrixVectorMultiply &&
3784 rightType.isVector() &&
3785 !resultType.highPrecision()) {
3786 return this->writeDecomposedMatrixVectorMultiply(leftType, lhs, rightType, rhs,
3787 resultType, out);
3788 }
3789
3790 // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
3791 SpvOp_ spvop;
3792 if (rightType.isMatrix()) {
3793 spvop = SpvOpMatrixTimesMatrix;
3794 } else if (rightType.isVector()) {
3795 spvop = SpvOpMatrixTimesVector;
3796 } else {
3797 SkASSERT(rightType.isScalar());
3798 spvop = SpvOpMatrixTimesScalar;
3799 }
3800 SpvId result = this->nextId(&resultType);
3801 this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
3802 return result;
3803 } else {
3804 // Matrix-op-vector is not supported in GLSL/SkSL for non-multiplication ops; we
3805 // expect to have a scalar here.
3806 SkASSERT(rightType.isScalar());
3807
3808 // Splat rhs across an entire matrix so we can reuse the matrix-op-matrix path.
3809 SpvId rhsMatrix = this->splat(leftType, rhs, out);
3810
3811 // Perform this operation as matrix-op-matrix.
3812 return this->writeBinaryExpression(leftType, lhs, op, leftType, rhsMatrix,
3813 resultType, out);
3814 }
3815 } else if (rightType.isMatrix()) {
3816 if (op.kind() == Operator::Kind::STAR) {
3817 // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
3818 SpvId result = this->nextId(&resultType);
3819 if (leftType.isVector()) {
3820 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType),
3821 result, lhs, rhs, out);
3822 } else {
3823 SkASSERT(leftType.isScalar());
3824 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType),
3825 result, rhs, lhs, out);
3826 }
3827 return result;
3828 } else {
3829 // Vector-op-matrix is not supported in GLSL/SkSL for non-multiplication ops; we
3830 // expect to have a scalar here.
3831 SkASSERT(leftType.isScalar());
3832
3833 // Splat lhs across an entire matrix so we can reuse the matrix-op-matrix path.
3834 SpvId lhsMatrix = this->splat(rightType, lhs, out);
3835
3836 // Perform this operation as matrix-op-matrix.
3837 return this->writeBinaryExpression(rightType, lhsMatrix, op, rightType, rhs,
3838 resultType, out);
3839 }
3840 } else {
3841 fContext.fErrors->error(leftType.fPosition, "unsupported mixed-type expression");
3842 return NA;
3843 }
3844 }
3845
3846 switch (op.kind()) {
3847 case Operator::Kind::EQEQ: {
3848 if (operandType->isMatrix()) {
3849 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
3850 SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
3851 }
3852 if (operandType->isStruct()) {
3853 return this->writeStructComparison(*operandType, lhs, op, rhs, out);
3854 }
3855 if (operandType->isArray()) {
3856 return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
3857 }
3858 SkASSERT(resultType.isBoolean());
3859 const Type* tmpType;
3860 if (operandType->isVector()) {
3861 tmpType = &fContext.fTypes.fBool->toCompound(fContext,
3862 operandType->columns(),
3863 operandType->rows());
3864 } else {
3865 tmpType = &resultType;
3866 }
3867 if (lhs == rhs) {
3868 // This ignores the effects of NaN.
3869 return this->writeOpConstantTrue(*fContext.fTypes.fBool);
3870 }
3871 return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
3872 SpvOpFOrdEqual, SpvOpIEqual,
3873 SpvOpIEqual, SpvOpLogicalEqual, out),
3874 *operandType, SpvOpAll, out);
3875 }
3876 case Operator::Kind::NEQ:
3877 if (operandType->isMatrix()) {
3878 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFUnordNotEqual,
3879 SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
3880 }
3881 if (operandType->isStruct()) {
3882 return this->writeStructComparison(*operandType, lhs, op, rhs, out);
3883 }
3884 if (operandType->isArray()) {
3885 return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
3886 }
3887 [[fallthrough]];
3888 case Operator::Kind::LOGICALXOR:
3889 SkASSERT(resultType.isBoolean());
3890 const Type* tmpType;
3891 if (operandType->isVector()) {
3892 tmpType = &fContext.fTypes.fBool->toCompound(fContext,
3893 operandType->columns(),
3894 operandType->rows());
3895 } else {
3896 tmpType = &resultType;
3897 }
3898 if (lhs == rhs) {
3899 // This ignores the effects of NaN.
3900 return this->writeOpConstantFalse(*fContext.fTypes.fBool);
3901 }
3902 return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
3903 SpvOpFUnordNotEqual, SpvOpINotEqual,
3904 SpvOpINotEqual, SpvOpLogicalNotEqual,
3905 out),
3906 *operandType, SpvOpAny, out);
3907 case Operator::Kind::GT:
3908 SkASSERT(resultType.isBoolean());
3909 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
3910 SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
3911 SpvOpUGreaterThan, SpvOpUndef, out);
3912 case Operator::Kind::LT:
3913 SkASSERT(resultType.isBoolean());
3914 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
3915 SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
3916 case Operator::Kind::GTEQ:
3917 SkASSERT(resultType.isBoolean());
3918 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
3919 SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
3920 SpvOpUGreaterThanEqual, SpvOpUndef, out);
3921 case Operator::Kind::LTEQ:
3922 SkASSERT(resultType.isBoolean());
3923 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
3924 SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
3925 SpvOpULessThanEqual, SpvOpUndef, out);
3926 case Operator::Kind::PLUS:
3927 return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
3928 lhs, rhs, SpvOpFAdd, SpvOpIAdd,
3929 SpvOpIAdd, SpvOpUndef, out);
3930 case Operator::Kind::MINUS:
3931 return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
3932 lhs, rhs, SpvOpFSub, SpvOpISub,
3933 SpvOpISub, SpvOpUndef, out);
3934 case Operator::Kind::STAR:
3935 if (leftType.isMatrix() && rightType.isMatrix()) {
3936 // matrix multiply
3937 SpvId result = this->nextId(&resultType);
3938 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
3939 lhs, rhs, out);
3940 return result;
3941 }
3942 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
3943 SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
3944 case Operator::Kind::SLASH:
3945 return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
3946 lhs, rhs, SpvOpFDiv, SpvOpSDiv,
3947 SpvOpUDiv, SpvOpUndef, out);
3948 case Operator::Kind::PERCENT:
3949 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
3950 SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
3951 case Operator::Kind::SHL:
3952 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
3953 SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
3954 SpvOpUndef, out);
3955 case Operator::Kind::SHR:
3956 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
3957 SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
3958 SpvOpUndef, out);
3959 case Operator::Kind::BITWISEAND:
3960 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
3961 SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
3962 case Operator::Kind::BITWISEOR:
3963 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
3964 SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
3965 case Operator::Kind::BITWISEXOR:
3966 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
3967 SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
3968 default:
3969 fContext.fErrors->error(Position(), "unsupported token");
3970 return NA;
3971 }
3972 }
3973
writeArrayComparison(const Type & arrayType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)3974 SpvId SPIRVCodeGenerator::writeArrayComparison(const Type& arrayType, SpvId lhs, Operator op,
3975 SpvId rhs, OutputStream& out) {
3976 // The inputs must be arrays, and the op must be == or !=.
3977 SkASSERT(op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ);
3978 SkASSERT(arrayType.isArray());
3979 const Type& componentType = arrayType.componentType();
3980 const int arraySize = arrayType.columns();
3981 SkASSERT(arraySize > 0);
3982
3983 // Synthesize equality checks for each item in the array.
3984 const Type& boolType = *fContext.fTypes.fBool;
3985 SpvId allComparisons = NA;
3986 for (int index = 0; index < arraySize; ++index) {
3987 // Get the left and right item in the array.
3988 SpvId itemL = this->writeOpCompositeExtract(componentType, lhs, index, out);
3989 SpvId itemR = this->writeOpCompositeExtract(componentType, rhs, index, out);
3990 // Use `writeBinaryExpression` with the requested == or != operator on these items.
3991 SpvId comparison = this->writeBinaryExpression(componentType, itemL, op,
3992 componentType, itemR, boolType, out);
3993 // Merge this comparison result with all the other comparisons we've done.
3994 allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
3995 }
3996 return allComparisons;
3997 }
3998
writeStructComparison(const Type & structType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)3999 SpvId SPIRVCodeGenerator::writeStructComparison(const Type& structType, SpvId lhs, Operator op,
4000 SpvId rhs, OutputStream& out) {
4001 // The inputs must be structs containing fields, and the op must be == or !=.
4002 SkASSERT(op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ);
4003 SkASSERT(structType.isStruct());
4004 SkSpan<const Field> fields = structType.fields();
4005 SkASSERT(!fields.empty());
4006
4007 // Synthesize equality checks for each field in the struct.
4008 const Type& boolType = *fContext.fTypes.fBool;
4009 SpvId allComparisons = NA;
4010 for (int index = 0; index < (int)fields.size(); ++index) {
4011 // Get the left and right versions of this field.
4012 const Type& fieldType = *fields[index].fType;
4013
4014 SpvId fieldL = this->writeOpCompositeExtract(fieldType, lhs, index, out);
4015 SpvId fieldR = this->writeOpCompositeExtract(fieldType, rhs, index, out);
4016 // Use `writeBinaryExpression` with the requested == or != operator on these fields.
4017 SpvId comparison = this->writeBinaryExpression(fieldType, fieldL, op, fieldType, fieldR,
4018 boolType, out);
4019 // Merge this comparison result with all the other comparisons we've done.
4020 allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
4021 }
4022 return allComparisons;
4023 }
4024
mergeComparisons(SpvId comparison,SpvId allComparisons,Operator op,OutputStream & out)4025 SpvId SPIRVCodeGenerator::mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op,
4026 OutputStream& out) {
4027 // If this is the first entry, we don't need to merge comparison results with anything.
4028 if (allComparisons == NA) {
4029 return comparison;
4030 }
4031 // Use LogicalAnd or LogicalOr to combine the comparison with all the other comparisons.
4032 const Type& boolType = *fContext.fTypes.fBool;
4033 SpvId boolTypeId = this->getType(boolType);
4034 SpvId logicalOp = this->nextId(&boolType);
4035 switch (op.kind()) {
4036 case Operator::Kind::EQEQ:
4037 this->writeInstruction(SpvOpLogicalAnd, boolTypeId, logicalOp,
4038 comparison, allComparisons, out);
4039 break;
4040 case Operator::Kind::NEQ:
4041 this->writeInstruction(SpvOpLogicalOr, boolTypeId, logicalOp,
4042 comparison, allComparisons, out);
4043 break;
4044 default:
4045 SkDEBUGFAILF("mergeComparisons only supports == and !=, not %s", op.operatorName());
4046 return NA;
4047 }
4048 return logicalOp;
4049 }
4050
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)4051 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
4052 const Expression* left = b.left().get();
4053 const Expression* right = b.right().get();
4054 Operator op = b.getOperator();
4055
4056 switch (op.kind()) {
4057 case Operator::Kind::EQ: {
4058 // Handles assignment.
4059 SpvId rhs = this->writeExpression(*right, out);
4060 this->getLValue(*left, out)->store(rhs, out);
4061 return rhs;
4062 }
4063 case Operator::Kind::LOGICALAND:
4064 // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
4065 return this->writeLogicalAnd(*b.left(), *b.right(), out);
4066
4067 case Operator::Kind::LOGICALOR:
4068 // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
4069 return this->writeLogicalOr(*b.left(), *b.right(), out);
4070
4071 default:
4072 break;
4073 }
4074
4075 std::unique_ptr<LValue> lvalue;
4076 SpvId lhs;
4077 if (op.isAssignment()) {
4078 lvalue = this->getLValue(*left, out);
4079 lhs = lvalue->load(out);
4080 } else {
4081 lvalue = nullptr;
4082 lhs = this->writeExpression(*left, out);
4083 }
4084
4085 SpvId rhs = this->writeExpression(*right, out);
4086 SpvId result = this->writeBinaryExpression(left->type(), lhs, op.removeAssignment(),
4087 right->type(), rhs, b.type(), out);
4088 if (lvalue) {
4089 lvalue->store(result, out);
4090 }
4091 return result;
4092 }
4093
// Emits `left && right` with short-circuit semantics: `right` is only evaluated when `left` is
// true. Both paths meet in a merge block, where an OpPhi yields `false` if control arrived
// directly from the left-hand side, or the value of `right` otherwise.
SpvId SPIRVCodeGenerator::writeLogicalAnd(const Expression& left, const Expression& right,
                                          OutputStream& out) {
    // Value the phi produces on the short-circuit (left is false) path.
    SpvId falseConstant = this->writeLiteral(0.0, *fContext.fTypes.fBool);
    SpvId lhs = this->writeExpression(left, out);

    // Snapshot of the conditional-op bookkeeping taken before branching; it is handed to the
    // merge label below.
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    SpvId rhsLabel = this->nextId(nullptr);
    SpvId end = this->nextId(nullptr);
    // Recorded *after* emitting `left`, since `left` may itself have opened new blocks; the
    // phi needs the block that actually branches to `end`.
    SpvId lhsBlock = fCurrentBlock;
    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
    // left true -> evaluate right; left false -> jump straight to the merge block.
    this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
    this->writeLabel(rhsLabel, kBranchIsOnPreviousLine, out);
    SpvId rhs = this->writeExpression(right, out);
    // Likewise, `right` may have opened new blocks; record the one that branches to `end`.
    SpvId rhsBlock = fCurrentBlock;
    this->writeInstruction(SpvOpBranch, end, out);
    this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    SpvId result = this->nextId(nullptr);
    this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, falseConstant,
                           lhsBlock, rhs, rhsBlock, out);

    return result;
}
4117
// Emits `left || right` with short-circuit semantics: `right` is only evaluated when `left` is
// false. Both paths meet in a merge block, where an OpPhi yields `true` if control arrived
// directly from the left-hand side, or the value of `right` otherwise.
SpvId SPIRVCodeGenerator::writeLogicalOr(const Expression& left, const Expression& right,
                                         OutputStream& out) {
    // Value the phi produces on the short-circuit (left is true) path.
    SpvId trueConstant = this->writeLiteral(1.0, *fContext.fTypes.fBool);
    SpvId lhs = this->writeExpression(left, out);

    // Snapshot of the conditional-op bookkeeping taken before branching; it is handed to the
    // merge label below.
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    SpvId rhsLabel = this->nextId(nullptr);
    SpvId end = this->nextId(nullptr);
    // Recorded *after* emitting `left`, since `left` may itself have opened new blocks; the
    // phi needs the block that actually branches to `end`.
    SpvId lhsBlock = fCurrentBlock;
    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
    // left true -> jump straight to the merge block; left false -> evaluate right.
    this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
    this->writeLabel(rhsLabel, kBranchIsOnPreviousLine, out);
    SpvId rhs = this->writeExpression(right, out);
    // Likewise, `right` may have opened new blocks; record the one that branches to `end`.
    SpvId rhsBlock = fCurrentBlock;
    this->writeInstruction(SpvOpBranch, end, out);
    this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    SpvId result = this->nextId(nullptr);
    this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, trueConstant,
                           lhsBlock, rhs, rhsBlock, out);

    return result;
}
4141
// Emits `test ? ifTrue : ifFalse`.
// Fast path: when the true arm's type has a single column and both arms are compile-time
// constants, both values are materialized unconditionally and chosen with OpSelect.
// General path: each arm is evaluated in its own block and stored into a Function-storage
// temporary, which is loaded back in the merge block, so only the selected arm is evaluated.
SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
    const Type& type = t.type();
    SpvId test = this->writeExpression(*t.test(), out);
    if (t.ifTrue()->type().columns() == 1 &&
        Analysis::IsCompileTimeConstant(*t.ifTrue()) &&
        Analysis::IsCompileTimeConstant(*t.ifFalse())) {
        // both true and false are constants, can just use OpSelect
        SpvId result = this->nextId(nullptr);
        SpvId trueId = this->writeExpression(*t.ifTrue(), out);
        SpvId falseId = this->writeExpression(*t.ifFalse(), out);
        this->writeInstruction(SpvOpSelect, this->getType(type), result, test, trueId, falseId,
                               out);
        return result;
    }

    // Snapshot of the conditional-op bookkeeping taken before branching; it is handed to the
    // labels written below.
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    // was originally using OpPhi to choose the result, but for some reason that is crashing on
    // Adreno. Switched to storing the result in a temp variable as glslang does.
    // The temporary's OpVariable goes into fVariableBuffer rather than `out`, keeping it apart
    // from the code emitted into the current block.
    SpvId var = this->nextId(nullptr);
    this->writeInstruction(SpvOpVariable, this->getPointerType(type, SpvStorageClassFunction),
                           var, SpvStorageClassFunction, fVariableBuffer);
    SpvId trueLabel = this->nextId(nullptr);
    SpvId falseLabel = this->nextId(nullptr);
    SpvId end = this->nextId(nullptr);
    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
    this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
    // True arm: evaluate ifTrue, store it, and branch to the merge block.
    this->writeLabel(trueLabel, kBranchIsOnPreviousLine, out);
    this->writeOpStore(SpvStorageClassFunction, var, this->writeExpression(*t.ifTrue(), out), out);
    this->writeInstruction(SpvOpBranch, end, out);
    // False arm: evaluate ifFalse, store it, and branch to the merge block.
    this->writeLabel(falseLabel, kBranchIsAbove, conditionalOps, out);
    this->writeOpStore(SpvStorageClassFunction, var, this->writeExpression(*t.ifFalse(), out), out);
    this->writeInstruction(SpvOpBranch, end, out);
    // Merge block: read back whichever value was stored.
    this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    SpvId result = this->nextId(&type);
    this->writeInstruction(SpvOpLoad, this->getType(type), result, var, out);

    return result;
}
4181
writePrefixExpression(const PrefixExpression & p,OutputStream & out)4182 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
4183 const Type& type = p.type();
4184 if (p.getOperator().kind() == Operator::Kind::MINUS) {
4185 SpvOp_ negateOp = pick_by_type(type, SpvOpFNegate, SpvOpSNegate, SpvOpSNegate, SpvOpUndef);
4186 SkASSERT(negateOp != SpvOpUndef);
4187 SpvId expr = this->writeExpression(*p.operand(), out);
4188 if (type.isMatrix()) {
4189 return this->writeComponentwiseMatrixUnary(type, expr, negateOp, out);
4190 }
4191 SpvId result = this->nextId(&type);
4192 SpvId typeId = this->getType(type);
4193 this->writeInstruction(negateOp, typeId, result, expr, out);
4194 return result;
4195 }
4196 switch (p.getOperator().kind()) {
4197 case Operator::Kind::PLUS:
4198 return this->writeExpression(*p.operand(), out);
4199
4200 case Operator::Kind::PLUSPLUS: {
4201 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4202 SpvId one = this->writeLiteral(1.0, type.componentType());
4203 one = this->splat(type, one, out);
4204 SpvId result = this->writeBinaryOperationComponentwiseIfMatrix(type, type,
4205 lv->load(out), one,
4206 SpvOpFAdd, SpvOpIAdd,
4207 SpvOpIAdd, SpvOpUndef,
4208 out);
4209 lv->store(result, out);
4210 return result;
4211 }
4212 case Operator::Kind::MINUSMINUS: {
4213 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4214 SpvId one = this->writeLiteral(1.0, type.componentType());
4215 one = this->splat(type, one, out);
4216 SpvId result = this->writeBinaryOperationComponentwiseIfMatrix(type, type,
4217 lv->load(out), one,
4218 SpvOpFSub, SpvOpISub,
4219 SpvOpISub, SpvOpUndef,
4220 out);
4221 lv->store(result, out);
4222 return result;
4223 }
4224 case Operator::Kind::LOGICALNOT: {
4225 SkASSERT(p.operand()->type().isBoolean());
4226 SpvId result = this->nextId(nullptr);
4227 this->writeInstruction(SpvOpLogicalNot, this->getType(type), result,
4228 this->writeExpression(*p.operand(), out), out);
4229 return result;
4230 }
4231 case Operator::Kind::BITWISENOT: {
4232 SpvId result = this->nextId(nullptr);
4233 this->writeInstruction(SpvOpNot, this->getType(type), result,
4234 this->writeExpression(*p.operand(), out), out);
4235 return result;
4236 }
4237 default:
4238 SkDEBUGFAILF("unsupported prefix expression: %s",
4239 p.description(OperatorPrecedence::kExpression).c_str());
4240 return NA;
4241 }
4242 }
4243
writePostfixExpression(const PostfixExpression & p,OutputStream & out)4244 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
4245 const Type& type = p.type();
4246 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4247 SpvId result = lv->load(out);
4248 SpvId one = this->writeLiteral(1.0, type.componentType());
4249 one = this->splat(type, one, out);
4250 switch (p.getOperator().kind()) {
4251 case Operator::Kind::PLUSPLUS: {
4252 SpvId temp = this->writeBinaryOperationComponentwiseIfMatrix(type, type, result, one,
4253 SpvOpFAdd, SpvOpIAdd,
4254 SpvOpIAdd, SpvOpUndef,
4255 out);
4256 lv->store(temp, out);
4257 return result;
4258 }
4259 case Operator::Kind::MINUSMINUS: {
4260 SpvId temp = this->writeBinaryOperationComponentwiseIfMatrix(type, type, result, one,
4261 SpvOpFSub, SpvOpISub,
4262 SpvOpISub, SpvOpUndef,
4263 out);
4264 lv->store(temp, out);
4265 return result;
4266 }
4267 default:
4268 SkDEBUGFAILF("unsupported postfix expression %s",
4269 p.description(OperatorPrecedence::kExpression).c_str());
4270 return NA;
4271 }
4272 }
4273
writeLiteral(const Literal & l)4274 SpvId SPIRVCodeGenerator::writeLiteral(const Literal& l) {
4275 return this->writeLiteral(l.value(), l.type());
4276 }
4277
writeLiteral(double value,const Type & type)4278 SpvId SPIRVCodeGenerator::writeLiteral(double value, const Type& type) {
4279 switch (type.numberKind()) {
4280 case Type::NumberKind::kFloat: {
4281 float floatVal = value;
4282 int32_t valueBits;
4283 memcpy(&valueBits, &floatVal, sizeof(valueBits));
4284 return this->writeOpConstant(type, valueBits);
4285 }
4286 case Type::NumberKind::kBoolean: {
4287 return value ? this->writeOpConstantTrue(type)
4288 : this->writeOpConstantFalse(type);
4289 }
4290 default: {
4291 return this->writeOpConstant(type, (SKSL_INT)value);
4292 }
4293 }
4294 }
4295
// Emits the header of function `f`: OpFunction, an OpName carrying the mangled name (written to
// fNameBuffer), and one OpFunctionParameter per parameter. Each parameter's id is recorded in
// fVariableMap. Returns the function's id as found in fFunctionMap.
// NOTE(review): this assumes fFunctionMap already holds an id for `f` — confirm with caller.
SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
    SpvId result = fFunctionMap[&f];
    SpvId returnTypeId = this->getType(f.returnType());
    SpvId functionTypeId = this->getFunctionType(f);
    this->writeInstruction(SpvOpFunction, returnTypeId, result,
                           SpvFunctionControlMaskNone, functionTypeId, out);
    std::string mangledName = f.mangledName();
    this->writeInstruction(SpvOpName,
                           result,
                           std::string_view(mangledName.c_str(), mangledName.size()),
                           fNameBuffer);
    for (const Variable* parameter : f.parameters()) {
        if (fUseTextureSamplerPairs && parameter->type().isSampler()) {
            // Separate texture/sampler mode: one SkSL sampler parameter becomes two SPIR-V
            // parameters, the texture followed by the sampler.
            auto [texture, sampler] = this->synthesizeTextureAndSampler(*parameter);

            SpvId textureId = this->nextId(nullptr);
            SpvId samplerId = this->nextId(nullptr);
            fVariableMap.set(texture, textureId);
            fVariableMap.set(sampler, samplerId);

            SpvId textureType = this->getFunctionParameterType(texture->type(), texture->layout());
            SpvId samplerType = this->getFunctionParameterType(sampler->type(), kDefaultTypeLayout);

            this->writeInstruction(SpvOpFunctionParameter, textureType, textureId, out);
            this->writeInstruction(SpvOpFunctionParameter, samplerType, samplerId, out);
        } else {
            SpvId id = this->nextId(nullptr);
            fVariableMap.set(parameter, id);

            SpvId type = this->getFunctionParameterType(parameter->type(), parameter->layout());
            this->writeInstruction(SpvOpFunctionParameter, type, id, out);
        }
    }
    return result;
}
4331
// Emits a complete function definition in this order: the OpFunction header and parameters,
// the entry label, function-scope variables, global-variable initializers (for `main` only),
// the body, an implicit terminator if the body falls off the end, and OpFunctionEnd.
SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
    // Snapshot of the conditional-op bookkeeping; handed to pruneConditionalOps at the end.
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    // fVariableBuffer collects this function's OpVariable declarations while the body is written.
    fVariableBuffer.reset();
    SpvId result = this->writeFunctionStart(f.declaration(), out);
    fCurrentBlock = 0;
    this->writeLabel(this->nextId(nullptr), kBranchlessBlock, out);
    // The body goes to a side buffer so that the variables it declares can be emitted ahead
    // of it. SPIR-V requires a function's OpVariables to come first in its first block.
    StringStream bodyBuffer;
    this->writeBlock(f.body()->as<Block>(), bodyBuffer);
    write_stringstream(fVariableBuffer, out);
    if (f.declaration().isMain()) {
        // Global variable initializers run at the start of main, before its body.
        write_stringstream(fGlobalInitializersBuffer, out);
    }
    write_stringstream(bodyBuffer, out);
    // A still-set fCurrentBlock indicates the last block has no terminator yet (presumably
    // writeBlock clears it on return/branch — confirm). Void functions get an implicit
    // return; for non-void functions the end is treated as unreachable.
    if (fCurrentBlock) {
        if (f.declaration().returnType().isVoid()) {
            this->writeInstruction(SpvOpReturn, out);
        } else {
            this->writeInstruction(SpvOpUnreachable, out);
        }
    }
    this->writeInstruction(SpvOpFunctionEnd, out);
    this->pruneConditionalOps(conditionalOps);
    return result;
}
4357
writeLayout(const Layout & layout,SpvId target,Position pos)4358 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, Position pos) {
4359 bool isPushConstant = SkToBool(layout.fFlags & LayoutFlag::kPushConstant);
4360 if (layout.fLocation >= 0) {
4361 this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
4362 fDecorationBuffer);
4363 }
4364 if (layout.fBinding >= 0) {
4365 if (isPushConstant) {
4366 fContext.fErrors->error(pos, "Can't apply 'binding' to push constants");
4367 } else {
4368 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
4369 fDecorationBuffer);
4370 }
4371 }
4372 if (layout.fIndex >= 0) {
4373 this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
4374 fDecorationBuffer);
4375 }
4376 if (layout.fSet >= 0) {
4377 if (isPushConstant) {
4378 fContext.fErrors->error(pos, "Can't apply 'set' to push constants");
4379 } else {
4380 this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
4381 fDecorationBuffer);
4382 }
4383 }
4384 if (layout.fInputAttachmentIndex >= 0) {
4385 this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
4386 layout.fInputAttachmentIndex, fDecorationBuffer);
4387 fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
4388 }
4389 if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN) {
4390 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
4391 fDecorationBuffer);
4392 }
4393 }
4394
writeFieldLayout(const Layout & layout,SpvId target,int member)4395 void SPIRVCodeGenerator::writeFieldLayout(const Layout& layout, SpvId target, int member) {
4396 // 'binding' and 'set' can not be applied to struct members
4397 SkASSERT(layout.fBinding == -1);
4398 SkASSERT(layout.fSet == -1);
4399 if (layout.fLocation >= 0) {
4400 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
4401 layout.fLocation, fDecorationBuffer);
4402 }
4403 if (layout.fIndex >= 0) {
4404 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
4405 layout.fIndex, fDecorationBuffer);
4406 }
4407 if (layout.fInputAttachmentIndex >= 0) {
4408 this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
4409 layout.fInputAttachmentIndex, fDecorationBuffer);
4410 }
4411 if (layout.fBuiltin >= 0) {
4412 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
4413 layout.fBuiltin, fDecorationBuffer);
4414 }
4415 }
4416
memoryLayoutForStorageClass(SpvStorageClass_ storageClass)4417 MemoryLayout SPIRVCodeGenerator::memoryLayoutForStorageClass(SpvStorageClass_ storageClass) {
4418 return storageClass == SpvStorageClassPushConstant ? MemoryLayout(MemoryLayout::Standard::k430)
4419 : fDefaultMemoryLayout;
4420 }
4421
memoryLayoutForVariable(const Variable & v) const4422 MemoryLayout SPIRVCodeGenerator::memoryLayoutForVariable(const Variable& v) const {
4423 bool pushConstant = SkToBool(v.layout().fFlags & LayoutFlag::kPushConstant);
4424 return pushConstant ? MemoryLayout(MemoryLayout::Standard::k430)
4425 : fDefaultMemoryLayout;
4426 }
4427
// Declares the global variable for interface block `intf` and returns its id.
//  - If the block's type is not supported by the chosen memory layout, an error is reported
//    and an unused fresh id is returned.
//  - If an RT-flip uniform is required and has not yet been written (and `appendRTFlip` is
//    set), a copy of the block with an extra float2 RT-flip field is synthesized and written
//    instead. The original variable is mapped to the copy's id.
//  - Otherwise the block type is decorated as Block/BufferBlock, a pointer type and an
//    OpVariable are emitted, and the variable's layout decorations are written.
SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip) {
    MemoryLayout memoryLayout = this->memoryLayoutForVariable(*intf.var());
    SpvId result = this->nextId(nullptr);
    const Variable& intfVar = *intf.var();
    const Type& type = intfVar.type();
    if (!memoryLayout.isSupported(type)) {
        fContext.fErrors->error(type.fPosition, "type '" + type.displayName() +
                                                "' is not permitted here");
        // Error path: hand back a fresh (unused) id rather than NA.
        return this->nextId(nullptr);
    }
    SpvStorageClass_ storageClass =
            get_storage_class_for_global_variable(intfVar, SpvStorageClassFunction);
    if (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None && appendRTFlip &&
        !fWroteRTFlip && type.isStruct()) {
        // We can only have one interface block (because we use push_constant and that is limited
        // to one per program), so we need to append rtflip to this one rather than synthesize an
        // entirely new block when the variable is referenced. And we can't modify the existing
        // block, so we instead create a modified copy of it and write that.
        SkSpan<const Field> fieldSpan = type.fields();
        TArray<Field> fields(fieldSpan.data(), fieldSpan.size());
        // The RT-flip field is a float2 placed at the offset given by fRTFlipOffset.
        fields.emplace_back(Position(),
                            Layout(LayoutFlag::kNone,
                                   /*location=*/-1,
                                   fProgram.fConfig->fSettings.fRTFlipOffset,
                                   /*binding=*/-1,
                                   /*index=*/-1,
                                   /*set=*/-1,
                                   /*builtin=*/-1,
                                   /*inputAttachmentIndex=*/-1),
                            ModifierFlag::kNone,
                            SKSL_RTFLIP_NAME,
                            fContext.fTypes.fFloat2.get());
        {
            // The new type and variable are handed to the program's symbol table, with the
            // program's pool attached to this thread for the duration of this scope.
            AutoAttachPoolToThread attach(fProgram.fPool.get());
            const Type* rtFlipStructType = fProgram.fSymbols->takeOwnershipOfSymbol(
                    Type::MakeStructType(fContext,
                                         type.fPosition,
                                         type.name(),
                                         std::move(fields),
                                         /*interfaceBlock=*/true));
            Variable* modifiedVar = fProgram.fSymbols->takeOwnershipOfSymbol(
                    Variable::Make(intfVar.fPosition,
                                   intfVar.modifiersPosition(),
                                   intfVar.layout(),
                                   intfVar.modifierFlags(),
                                   rtFlipStructType,
                                   intfVar.name(),
                                   /*mangledName=*/"",
                                   intfVar.isBuiltin(),
                                   intfVar.storage()));
            InterfaceBlock modifiedCopy(intf.fPosition, modifiedVar);
            // Recurse with appendRTFlip=false so the copy is written through the normal path.
            result = this->writeInterfaceBlock(modifiedCopy, /*appendRTFlip=*/false);
            // Register a FieldSymbol for the appended (last) field of the modified block.
            fProgram.fSymbols->add(fContext, std::make_unique<FieldSymbol>(
                    Position(), modifiedVar, rtFlipStructType->fields().size() - 1));
        }
        // References to the original variable resolve to the modified copy.
        fVariableMap.set(&intfVar, result);
        fWroteRTFlip = true;
        return result;
    }
    SpvId typeId = this->getType(type, kDefaultTypeLayout, memoryLayout);
    if (intfVar.layout().fBuiltin == -1) {
        // Note: In SPIR-V 1.3, a storage buffer can be declared with the "StorageBuffer"
        // storage class and the "Block" decoration and the <1.3 approach we use here ("Uniform"
        // storage class and the "BufferBlock" decoration) is deprecated. Since we target SPIR-V
        // 1.0, we have to use the deprecated approach which is well supported in Vulkan and
        // addresses SkSL use cases (notably SkSL currently doesn't support pointer features that
        // would benefit from SPV_KHR_variable_pointers capabilities).
        bool isStorageBuffer = intfVar.modifierFlags().isBuffer();
        this->writeInstruction(SpvOpDecorate,
                               typeId,
                               isStorageBuffer ? SpvDecorationBufferBlock : SpvDecorationBlock,
                               fDecorationBuffer);
    }
    // A dedicated pointer type is emitted directly into fConstantBuffer here, rather than going
    // through getPointerType (NOTE(review): reason not visible in this function — confirm).
    SpvId ptrType = this->nextId(nullptr);
    this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
    this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
    Layout layout = intfVar.layout();
    // Uniform blocks without an explicit 'set' fall back to the configured default set.
    if (storageClass == SpvStorageClassUniform && layout.fSet < 0) {
        layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
    }
    this->writeLayout(layout, result, intfVar.fPosition);
    fVariableMap.set(&intfVar, result);
    return result;
}
4512
4513 // This function determines whether to skip an OpVariable (of pointer type) declaration for
4514 // compile-time constant scalars and vectors which we turn into OpConstant/OpConstantComposite and
4515 // always reference by value.
4516 //
4517 // Accessing a matrix or array member with a dynamic index requires the use of OpAccessChain which
4518 // requires a base operand of pointer type. However, a vector can always be accessed by value using
4519 // OpVectorExtractDynamic (see writeIndexExpression).
4520 //
4521 // This is why we always emit an OpVariable for all non-scalar and non-vector types in case they get
4522 // accessed via a dynamic index.
is_vardecl_compile_time_constant(const VarDeclaration & varDecl)4523 static bool is_vardecl_compile_time_constant(const VarDeclaration& varDecl) {
4524 return varDecl.var()->modifierFlags().isConst() &&
4525 (varDecl.var()->type().isScalar() || varDecl.var()->type().isVector()) &&
4526 (ConstantFolder::GetConstantValueOrNull(*varDecl.value()) ||
4527 Analysis::IsCompileTimeConstant(*varDecl.value()));
4528 }
4529
// Handles a global variable declaration. Returns false (after reporting an error) when the
// declaration cannot be expressed in SPIR-V, and true otherwise. Depending on the variable, it:
//  - emits nothing now (compile-time-constant scalars/vectors are emitted as constants on use);
//  - defers to writeUniformBuffer (top-level uniforms are queued in fTopLevelUniforms);
//  - splits a sampler into separate texture and sampler globals (texture/sampler-pair mode);
//  - or declares the global and appends its initializer, if any, to fGlobalInitializersBuffer.
bool SPIRVCodeGenerator::writeGlobalVarDeclaration(ProgramKind kind,
                                                   const VarDeclaration& varDecl) {
    const Variable* var = varDecl.var();
    // Reject backend-specific layout flags for backends other than Vulkan/WebGPU/Direct3D.
    const LayoutFlags backendFlags = var->layout().fFlags & LayoutFlag::kAllBackends;
    const LayoutFlags kPermittedBackendFlags =
            LayoutFlag::kVulkan | LayoutFlag::kWebGPU | LayoutFlag::kDirect3D;
    if (backendFlags & ~kPermittedBackendFlags) {
        fContext.fErrors->error(var->fPosition, "incompatible backend flag in SPIR-V codegen");
        return false;
    }

    // If this global variable is a compile-time constant then we'll emit OpConstant or
    // OpConstantComposite later when the variable is referenced. Avoid declaring an OpVariable now.
    if (is_vardecl_compile_time_constant(varDecl)) {
        return true;
    }

    SpvStorageClass_ storageClass =
            get_storage_class_for_global_variable(*var, SpvStorageClassPrivate);
    if (storageClass == SpvStorageClassUniform) {
        // Top-level uniforms are emitted in writeUniformBuffer.
        fTopLevelUniforms.push_back(&varDecl);
        return true;
    }

    if (fUseTextureSamplerPairs && var->type().isSampler()) {
        // Texture/sampler-pair mode needs explicit texture and sampler indices on the layout.
        if (var->layout().fTexture == -1 || var->layout().fSampler == -1) {
            fContext.fErrors->error(var->fPosition, "selected backend requires separate texture "
                                                    "and sampler indices");
            return false;
        }
        SkASSERT(storageClass == SpvStorageClassUniformConstant);

        auto [texture, sampler] = this->synthesizeTextureAndSampler(*var);
        this->writeGlobalVar(kind, storageClass, *texture);
        this->writeGlobalVar(kind, storageClass, *sampler);

        return true;
    }

    SpvId id = this->writeGlobalVar(kind, storageClass, *var);
    if (id != NA && varDecl.value()) {
        // The initializer is generated into fGlobalInitializersBuffer, which writeFunction
        // splices into the start of main. fCurrentBlock is expected to be 0 (outside any
        // function) and is set to the NA sentinel while the initializer is written
        // (NOTE(review): presumably so callees can tell they are in a global initializer —
        // confirm), then reset to 0.
        SkASSERT(!fCurrentBlock);
        fCurrentBlock = NA;
        SpvId value = this->writeExpression(*varDecl.value(), fGlobalInitializersBuffer);
        this->writeOpStore(storageClass, id, value, fGlobalInitializersBuffer);
        fCurrentBlock = 0;
    }
    return true;
}
4580
writeGlobalVar(ProgramKind kind,SpvStorageClass_ storageClass,const Variable & var)4581 SpvId SPIRVCodeGenerator::writeGlobalVar(ProgramKind kind,
4582 SpvStorageClass_ storageClass,
4583 const Variable& var) {
4584 Layout layout = var.layout();
4585 const ModifierFlags flags = var.modifierFlags();
4586 const Type* type = &var.type();
4587 switch (layout.fBuiltin) {
4588 case SK_FRAGCOLOR_BUILTIN:
4589 if (!ProgramConfig::IsFragment(kind)) {
4590 SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
4591 return NA;
4592 }
4593 break;
4594
4595 case SK_SAMPLEMASKIN_BUILTIN:
4596 case SK_SAMPLEMASK_BUILTIN:
4597 // SkSL exposes this as a `uint` but SPIR-V, like GLSL, uses an array of signed `uint`
4598 // decorated with SpvBuiltinSampleMask.
4599 type = fSynthetics.addArrayDimension(fContext, type, /*arraySize=*/1);
4600 layout.fBuiltin = SpvBuiltInSampleMask;
4601 break;
4602 }
4603
4604 // Add this global to the variable map.
4605 SpvId id = this->nextId(type);
4606 fVariableMap.set(&var, id);
4607
4608 if (layout.fSet < 0 && storageClass == SpvStorageClassUniformConstant) {
4609 layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
4610 }
4611
4612 SpvId typeId = this->getPointerType(*type,
4613 layout,
4614 this->memoryLayoutForStorageClass(storageClass),
4615 storageClass);
4616 this->writeInstruction(SpvOpVariable, typeId, id, storageClass, fConstantBuffer);
4617 this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer);
4618 this->writeLayout(layout, id, var.fPosition);
4619 if (flags & ModifierFlag::kFlat) {
4620 this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
4621 }
4622 if (flags & ModifierFlag::kNoPerspective) {
4623 this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
4624 fDecorationBuffer);
4625 }
4626 if (flags.isWriteOnly()) {
4627 this->writeInstruction(SpvOpDecorate, id, SpvDecorationNonReadable, fDecorationBuffer);
4628 } else if (flags.isReadOnly()) {
4629 this->writeInstruction(SpvOpDecorate, id, SpvDecorationNonWritable, fDecorationBuffer);
4630 }
4631
4632 return id;
4633 }
4634
writeVarDeclaration(const VarDeclaration & varDecl,OutputStream & out)4635 void SPIRVCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl, OutputStream& out) {
4636 // If this variable is a compile-time constant then we'll emit OpConstant or
4637 // OpConstantComposite later when the variable is referenced. Avoid declaring an OpVariable now.
4638 if (is_vardecl_compile_time_constant(varDecl)) {
4639 return;
4640 }
4641
4642 const Variable* var = varDecl.var();
4643 SpvId id = this->nextId(&var->type());
4644 fVariableMap.set(var, id);
4645 SpvId type = this->getPointerType(var->type(), SpvStorageClassFunction);
4646 this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
4647 this->writeInstruction(SpvOpName, id, var->name(), fNameBuffer);
4648 if (varDecl.value()) {
4649 SpvId value = this->writeExpression(*varDecl.value(), out);
4650 this->writeOpStore(SpvStorageClassFunction, id, value, out);
4651 }
4652 }
4653
writeStatement(const Statement & s,OutputStream & out)4654 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
4655 switch (s.kind()) {
4656 case Statement::Kind::kNop:
4657 break;
4658 case Statement::Kind::kBlock:
4659 this->writeBlock(s.as<Block>(), out);
4660 break;
4661 case Statement::Kind::kExpression:
4662 this->writeExpression(*s.as<ExpressionStatement>().expression(), out);
4663 break;
4664 case Statement::Kind::kReturn:
4665 this->writeReturnStatement(s.as<ReturnStatement>(), out);
4666 break;
4667 case Statement::Kind::kVarDeclaration:
4668 this->writeVarDeclaration(s.as<VarDeclaration>(), out);
4669 break;
4670 case Statement::Kind::kIf:
4671 this->writeIfStatement(s.as<IfStatement>(), out);
4672 break;
4673 case Statement::Kind::kFor:
4674 this->writeForStatement(s.as<ForStatement>(), out);
4675 break;
4676 case Statement::Kind::kDo:
4677 this->writeDoStatement(s.as<DoStatement>(), out);
4678 break;
4679 case Statement::Kind::kSwitch:
4680 this->writeSwitchStatement(s.as<SwitchStatement>(), out);
4681 break;
4682 case Statement::Kind::kBreak:
4683 this->writeInstruction(SpvOpBranch, fBreakTarget.back(), out);
4684 break;
4685 case Statement::Kind::kContinue:
4686 this->writeInstruction(SpvOpBranch, fContinueTarget.back(), out);
4687 break;
4688 case Statement::Kind::kDiscard:
4689 this->writeInstruction(SpvOpKill, out);
4690 break;
4691 default:
4692 SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
4693 break;
4694 }
4695 }
4696
writeBlock(const Block & b,OutputStream & out)4697 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
4698 for (const std::unique_ptr<Statement>& stmt : b.children()) {
4699 this->writeStatement(*stmt, out);
4700 }
4701 }
4702
getConditionalOpCounts()4703 SPIRVCodeGenerator::ConditionalOpCounts SPIRVCodeGenerator::getConditionalOpCounts() {
4704 return {fReachableOps.size(), fStoreOps.size()};
4705 }
4706
pruneConditionalOps(ConditionalOpCounts ops)4707 void SPIRVCodeGenerator::pruneConditionalOps(ConditionalOpCounts ops) {
4708 // Remove ops which are no longer reachable.
4709 while (fReachableOps.size() > ops.numReachableOps) {
4710 SpvId prunableSpvId = fReachableOps.back();
4711 const Instruction* prunableOp = fSpvIdCache.find(prunableSpvId);
4712
4713 if (prunableOp) {
4714 fOpCache.remove(*prunableOp);
4715 fSpvIdCache.remove(prunableSpvId);
4716 } else {
4717 SkDEBUGFAIL("reachable-op list contains unrecognized SpvId");
4718 }
4719
4720 fReachableOps.pop_back();
4721 }
4722
4723 // Remove any cached stores that occurred during the conditional block.
4724 while (fStoreOps.size() > ops.numStoreOps) {
4725 if (fStoreCache.find(fStoreOps.back())) {
4726 fStoreCache.remove(fStoreOps.back());
4727 }
4728 fStoreOps.pop_back();
4729 }
4730 }
4731
writeIfStatement(const IfStatement & stmt,OutputStream & out)4732 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
4733 SpvId test = this->writeExpression(*stmt.test(), out);
4734 SpvId ifTrue = this->nextId(nullptr);
4735 SpvId ifFalse = this->nextId(nullptr);
4736
4737 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4738
4739 if (stmt.ifFalse()) {
4740 SpvId end = this->nextId(nullptr);
4741 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
4742 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
4743 this->writeLabel(ifTrue, kBranchIsOnPreviousLine, out);
4744 this->writeStatement(*stmt.ifTrue(), out);
4745 if (fCurrentBlock) {
4746 this->writeInstruction(SpvOpBranch, end, out);
4747 }
4748 this->writeLabel(ifFalse, kBranchIsAbove, conditionalOps, out);
4749 this->writeStatement(*stmt.ifFalse(), out);
4750 if (fCurrentBlock) {
4751 this->writeInstruction(SpvOpBranch, end, out);
4752 }
4753 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4754 } else {
4755 this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
4756 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
4757 this->writeLabel(ifTrue, kBranchIsOnPreviousLine, out);
4758 this->writeStatement(*stmt.ifTrue(), out);
4759 if (fCurrentBlock) {
4760 this->writeInstruction(SpvOpBranch, ifFalse, out);
4761 }
4762 this->writeLabel(ifFalse, kBranchIsAbove, conditionalOps, out);
4763 }
4764 }
4765
writeForStatement(const ForStatement & f,OutputStream & out)4766 void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
4767 if (f.initializer()) {
4768 this->writeStatement(*f.initializer(), out);
4769 }
4770
4771 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4772
4773 // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
4774 // in the context of linear straight-line execution. If we wanted to be more clever, we could
4775 // only invalidate store cache entries for variables affected by the loop body, but for now we
4776 // simply clear the entire cache whenever branching occurs.
4777 SpvId header = this->nextId(nullptr);
4778 SpvId start = this->nextId(nullptr);
4779 SpvId body = this->nextId(nullptr);
4780 SpvId next = this->nextId(nullptr);
4781 fContinueTarget.push_back(next);
4782 SpvId end = this->nextId(nullptr);
4783 fBreakTarget.push_back(end);
4784 this->writeInstruction(SpvOpBranch, header, out);
4785 this->writeLabel(header, kBranchIsBelow, conditionalOps, out);
4786 this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
4787 this->writeInstruction(SpvOpBranch, start, out);
4788 this->writeLabel(start, kBranchIsOnPreviousLine, out);
4789 if (f.test()) {
4790 SpvId test = this->writeExpression(*f.test(), out);
4791 this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
4792 } else {
4793 this->writeInstruction(SpvOpBranch, body, out);
4794 }
4795 this->writeLabel(body, kBranchIsOnPreviousLine, out);
4796 this->writeStatement(*f.statement(), out);
4797 if (fCurrentBlock) {
4798 this->writeInstruction(SpvOpBranch, next, out);
4799 }
4800 this->writeLabel(next, kBranchIsAbove, conditionalOps, out);
4801 if (f.next()) {
4802 this->writeExpression(*f.next(), out);
4803 }
4804 this->writeInstruction(SpvOpBranch, header, out);
4805 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4806 fBreakTarget.pop_back();
4807 fContinueTarget.pop_back();
4808 }
4809
writeDoStatement(const DoStatement & d,OutputStream & out)4810 void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
4811 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4812
4813 // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
4814 // in the context of linear straight-line execution. If we wanted to be more clever, we could
4815 // only invalidate store cache entries for variables affected by the loop body, but for now we
4816 // simply clear the entire cache whenever branching occurs.
4817 SpvId header = this->nextId(nullptr);
4818 SpvId start = this->nextId(nullptr);
4819 SpvId next = this->nextId(nullptr);
4820 SpvId continueTarget = this->nextId(nullptr);
4821 fContinueTarget.push_back(continueTarget);
4822 SpvId end = this->nextId(nullptr);
4823 fBreakTarget.push_back(end);
4824 this->writeInstruction(SpvOpBranch, header, out);
4825 this->writeLabel(header, kBranchIsBelow, conditionalOps, out);
4826 this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out);
4827 this->writeInstruction(SpvOpBranch, start, out);
4828 this->writeLabel(start, kBranchIsOnPreviousLine, out);
4829 this->writeStatement(*d.statement(), out);
4830 if (fCurrentBlock) {
4831 this->writeInstruction(SpvOpBranch, next, out);
4832 this->writeLabel(next, kBranchIsOnPreviousLine, out);
4833 this->writeInstruction(SpvOpBranch, continueTarget, out);
4834 }
4835 this->writeLabel(continueTarget, kBranchIsAbove, conditionalOps, out);
4836 SpvId test = this->writeExpression(*d.test(), out);
4837 this->writeInstruction(SpvOpBranchConditional, test, header, end, out);
4838 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4839 fBreakTarget.pop_back();
4840 fContinueTarget.pop_back();
4841 }
4842
writeSwitchStatement(const SwitchStatement & s,OutputStream & out)4843 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
4844 SpvId value = this->writeExpression(*s.value(), out);
4845
4846 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4847
4848 // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
4849 // in the context of linear straight-line execution. If we wanted to be more clever, we could
4850 // only invalidate store cache entries for variables affected by the switch body, but for now we
4851 // simply clear the entire cache whenever branching occurs.
4852 TArray<SpvId> labels;
4853 SpvId end = this->nextId(nullptr);
4854 SpvId defaultLabel = end;
4855 fBreakTarget.push_back(end);
4856 int size = 3;
4857 const StatementArray& cases = s.cases();
4858 for (const std::unique_ptr<Statement>& stmt : cases) {
4859 const SwitchCase& c = stmt->as<SwitchCase>();
4860 SpvId label = this->nextId(nullptr);
4861 labels.push_back(label);
4862 if (!c.isDefault()) {
4863 size += 2;
4864 } else {
4865 defaultLabel = label;
4866 }
4867 }
4868
4869 // We should have exactly one label for each case.
4870 SkASSERT(labels.size() == cases.size());
4871
4872 // Collapse adjacent switch-cases into one; that is, reduce `case 1: case 2: case 3:` into a
4873 // single OpLabel. The Tint SPIR-V reader does not support switch-case fallthrough, but it
4874 // does support multiple switch-cases branching to the same label.
4875 SkBitSet caseIsCollapsed(cases.size());
4876 for (int index = cases.size() - 2; index >= 0; index--) {
4877 if (cases[index]->as<SwitchCase>().statement()->isEmpty()) {
4878 caseIsCollapsed.set(index);
4879 labels[index] = labels[index + 1];
4880 }
4881 }
4882
4883 labels.push_back(end);
4884
4885 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
4886 this->writeOpCode(SpvOpSwitch, size, out);
4887 this->writeWord(value, out);
4888 this->writeWord(defaultLabel, out);
4889 for (int i = 0; i < cases.size(); ++i) {
4890 const SwitchCase& c = cases[i]->as<SwitchCase>();
4891 if (c.isDefault()) {
4892 continue;
4893 }
4894 this->writeWord(c.value(), out);
4895 this->writeWord(labels[i], out);
4896 }
4897 for (int i = 0; i < cases.size(); ++i) {
4898 if (caseIsCollapsed.test(i)) {
4899 continue;
4900 }
4901 const SwitchCase& c = cases[i]->as<SwitchCase>();
4902 if (i == 0) {
4903 this->writeLabel(labels[i], kBranchIsOnPreviousLine, out);
4904 } else {
4905 this->writeLabel(labels[i], kBranchIsAbove, conditionalOps, out);
4906 }
4907 this->writeStatement(*c.statement(), out);
4908 if (fCurrentBlock) {
4909 this->writeInstruction(SpvOpBranch, labels[i + 1], out);
4910 }
4911 }
4912 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4913 fBreakTarget.pop_back();
4914 }
4915
writeReturnStatement(const ReturnStatement & r,OutputStream & out)4916 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
4917 if (r.expression()) {
4918 this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.expression(), out),
4919 out);
4920 } else {
4921 this->writeInstruction(SpvOpReturn, out);
4922 }
4923 }
4924
4925 // Given any function, returns the top-level symbol table (OUTSIDE of the function's scope).
get_top_level_symbol_table(const FunctionDeclaration & anyFunc)4926 static SymbolTable* get_top_level_symbol_table(const FunctionDeclaration& anyFunc) {
4927 return anyFunc.definition()->body()->as<Block>().symbolTable()->fParent;
4928 }
4929
writeEntrypointAdapter(const FunctionDeclaration & main)4930 SPIRVCodeGenerator::EntrypointAdapter SPIRVCodeGenerator::writeEntrypointAdapter(
4931 const FunctionDeclaration& main) {
4932 // Our goal is to synthesize a tiny helper function which looks like this:
4933 // void _entrypoint() { sk_FragColor = main(); }
4934
4935 // Fish a symbol table out of main().
4936 SymbolTable* symbolTable = get_top_level_symbol_table(main);
4937
4938 // Get `sk_FragColor` as a writable reference.
4939 const Symbol* skFragColorSymbol = symbolTable->find("sk_FragColor");
4940 SkASSERT(skFragColorSymbol);
4941 const Variable& skFragColorVar = skFragColorSymbol->as<Variable>();
4942 auto skFragColorRef = std::make_unique<VariableReference>(Position(), &skFragColorVar,
4943 VariableReference::RefKind::kWrite);
4944 // Synthesize a call to the `main()` function.
4945 if (!main.returnType().matches(skFragColorRef->type())) {
4946 fContext.fErrors->error(main.fPosition, "SPIR-V does not support returning '" +
4947 main.returnType().description() + "' from main()");
4948 return {};
4949 }
4950 ExpressionArray args;
4951 if (main.parameters().size() == 1) {
4952 if (!main.parameters()[0]->type().matches(*fContext.fTypes.fFloat2)) {
4953 fContext.fErrors->error(main.fPosition,
4954 "SPIR-V does not support parameter of type '" +
4955 main.parameters()[0]->type().description() + "' to main()");
4956 return {};
4957 }
4958 double kZero[2] = {0.0, 0.0};
4959 args.push_back(ConstructorCompound::MakeFromConstants(fContext, Position{},
4960 *fContext.fTypes.fFloat2, kZero));
4961 }
4962 auto callMainFn = std::make_unique<FunctionCall>(Position(), &main.returnType(), &main,
4963 std::move(args));
4964
4965 // Synthesize `skFragColor = main()` as a BinaryExpression.
4966 auto assignmentStmt = std::make_unique<ExpressionStatement>(std::make_unique<BinaryExpression>(
4967 Position(),
4968 std::move(skFragColorRef),
4969 Operator::Kind::EQ,
4970 std::move(callMainFn),
4971 &main.returnType()));
4972
4973 // Function bodies are always wrapped in a Block.
4974 StatementArray entrypointStmts;
4975 entrypointStmts.push_back(std::move(assignmentStmt));
4976 auto entrypointBlock = Block::Make(Position(), std::move(entrypointStmts),
4977 Block::Kind::kBracedScope, /*symbols=*/nullptr);
4978 // Declare an entrypoint function.
4979 EntrypointAdapter adapter;
4980 adapter.entrypointDecl =
4981 std::make_unique<FunctionDeclaration>(fContext,
4982 Position(),
4983 ModifierFlag::kNone,
4984 "_entrypoint",
4985 /*parameters=*/TArray<Variable*>{},
4986 /*returnType=*/fContext.fTypes.fVoid.get(),
4987 kNotIntrinsic);
4988 // Define it.
4989 adapter.entrypointDef = FunctionDefinition::Convert(fContext,
4990 Position(),
4991 *adapter.entrypointDecl,
4992 std::move(entrypointBlock),
4993 /*builtin=*/false);
4994
4995 adapter.entrypointDecl->setDefinition(adapter.entrypointDef.get());
4996 return adapter;
4997 }
4998
writeUniformBuffer(SymbolTable * topLevelSymbolTable)4999 void SPIRVCodeGenerator::writeUniformBuffer(SymbolTable* topLevelSymbolTable) {
5000 SkASSERT(!fTopLevelUniforms.empty());
5001 static constexpr char kUniformBufferName[] = "_UniformBuffer";
5002
5003 // Convert the list of top-level uniforms into a matching struct named _UniformBuffer, and build
5004 // a lookup table of variables to UniformBuffer field indices.
5005 TArray<Field> fields;
5006 fields.reserve_exact(fTopLevelUniforms.size());
5007 for (const VarDeclaration* topLevelUniform : fTopLevelUniforms) {
5008 const Variable* var = topLevelUniform->var();
5009 fTopLevelUniformMap.set(var, (int)fields.size());
5010 ModifierFlags flags = var->modifierFlags() & ~ModifierFlag::kUniform;
5011 fields.emplace_back(var->fPosition, var->layout(), flags, var->name(), &var->type());
5012 }
5013 fUniformBuffer.fStruct = Type::MakeStructType(fContext,
5014 Position(),
5015 kUniformBufferName,
5016 std::move(fields),
5017 /*interfaceBlock=*/true);
5018
5019 // Create a global variable to contain this struct.
5020 Layout layout;
5021 layout.fBinding = fProgram.fConfig->fSettings.fDefaultUniformBinding;
5022 layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
5023
5024 fUniformBuffer.fInnerVariable = Variable::Make(/*pos=*/Position(),
5025 /*modifiersPosition=*/Position(),
5026 layout,
5027 ModifierFlag::kUniform,
5028 fUniformBuffer.fStruct.get(),
5029 kUniformBufferName,
5030 /*mangledName=*/"",
5031 /*builtin=*/false,
5032 Variable::Storage::kGlobal);
5033
5034 // Create an interface block object for this global variable.
5035 fUniformBuffer.fInterfaceBlock =
5036 std::make_unique<InterfaceBlock>(Position(), fUniformBuffer.fInnerVariable.get());
5037
5038 // Generate an interface block and hold onto its ID.
5039 fUniformBufferId = this->writeInterfaceBlock(*fUniformBuffer.fInterfaceBlock);
5040 }
5041
addRTFlipUniform(Position pos)5042 void SPIRVCodeGenerator::addRTFlipUniform(Position pos) {
5043 SkASSERT(!fProgram.fConfig->fSettings.fForceNoRTFlip);
5044
5045 if (fWroteRTFlip) {
5046 return;
5047 }
5048 // Flip variable hasn't been written yet. This means we don't have an existing
5049 // interface block, so we're free to just synthesize one.
5050 fWroteRTFlip = true;
5051 TArray<Field> fields;
5052 if (fProgram.fConfig->fSettings.fRTFlipOffset < 0) {
5053 fContext.fErrors->error(pos, "RTFlipOffset is negative");
5054 }
5055 fields.emplace_back(pos,
5056 Layout(LayoutFlag::kNone,
5057 /*location=*/-1,
5058 fProgram.fConfig->fSettings.fRTFlipOffset,
5059 /*binding=*/-1,
5060 /*index=*/-1,
5061 /*set=*/-1,
5062 /*builtin=*/-1,
5063 /*inputAttachmentIndex=*/-1),
5064 ModifierFlag::kNone,
5065 SKSL_RTFLIP_NAME,
5066 fContext.fTypes.fFloat2.get());
5067 std::string_view name = "sksl_synthetic_uniforms";
5068 const Type* intfStruct = fSynthetics.takeOwnershipOfSymbol(Type::MakeStructType(
5069 fContext, Position(), name, std::move(fields), /*interfaceBlock=*/true));
5070 bool usePushConstants = fProgram.fConfig->fSettings.fUsePushConstants;
5071 int binding = -1, set = -1;
5072 if (!usePushConstants) {
5073 binding = fProgram.fConfig->fSettings.fRTFlipBinding;
5074 if (binding == -1) {
5075 fContext.fErrors->error(pos, "layout(binding=...) is required in SPIR-V");
5076 }
5077 set = fProgram.fConfig->fSettings.fRTFlipSet;
5078 if (set == -1) {
5079 fContext.fErrors->error(pos, "layout(set=...) is required in SPIR-V");
5080 }
5081 }
5082 Layout layout(/*flags=*/usePushConstants ? LayoutFlag::kPushConstant : LayoutFlag::kNone,
5083 /*location=*/-1,
5084 /*offset=*/-1,
5085 binding,
5086 /*index=*/-1,
5087 set,
5088 /*builtin=*/-1,
5089 /*inputAttachmentIndex=*/-1);
5090 Variable* intfVar =
5091 fSynthetics.takeOwnershipOfSymbol(Variable::Make(/*pos=*/Position(),
5092 /*modifiersPosition=*/Position(),
5093 layout,
5094 ModifierFlag::kUniform,
5095 intfStruct,
5096 name,
5097 /*mangledName=*/"",
5098 /*builtin=*/false,
5099 Variable::Storage::kGlobal));
5100 {
5101 AutoAttachPoolToThread attach(fProgram.fPool.get());
5102 fProgram.fSymbols->add(fContext,
5103 std::make_unique<FieldSymbol>(Position(), intfVar, /*field=*/0));
5104 }
5105 InterfaceBlock intf(Position(), intfVar);
5106 this->writeInterfaceBlock(intf, false);
5107 }
5108
synthesizeTextureAndSampler(const Variable & combinedSampler)5109 std::tuple<const Variable*, const Variable*> SPIRVCodeGenerator::synthesizeTextureAndSampler(
5110 const Variable& combinedSampler) {
5111 SkASSERT(fUseTextureSamplerPairs);
5112 SkASSERT(combinedSampler.type().typeKind() == Type::TypeKind::kSampler);
5113
5114 auto data = std::make_unique<SynthesizedTextureSamplerPair>();
5115
5116 Layout texLayout = combinedSampler.layout();
5117 texLayout.fBinding = texLayout.fTexture;
5118 data->fTextureName = std::string(combinedSampler.name()) + "_texture";
5119
5120 auto texture = Variable::Make(/*pos=*/Position(),
5121 /*modifiersPosition=*/Position(),
5122 texLayout,
5123 combinedSampler.modifierFlags(),
5124 &combinedSampler.type().textureType(),
5125 data->fTextureName,
5126 /*mangledName=*/"",
5127 /*builtin=*/false,
5128 Variable::Storage::kGlobal);
5129
5130 Layout samplerLayout = combinedSampler.layout();
5131 samplerLayout.fBinding = samplerLayout.fSampler;
5132 samplerLayout.fFlags &= ~LayoutFlag::kAllPixelFormats;
5133 data->fSamplerName = std::string(combinedSampler.name()) + "_sampler";
5134
5135 auto sampler = Variable::Make(/*pos=*/Position(),
5136 /*modifiersPosition=*/Position(),
5137 samplerLayout,
5138 combinedSampler.modifierFlags(),
5139 fContext.fTypes.fSampler.get(),
5140 data->fSamplerName,
5141 /*mangledName=*/"",
5142 /*builtin=*/false,
5143 Variable::Storage::kGlobal);
5144
5145 const Variable* t = texture.get();
5146 const Variable* s = sampler.get();
5147 data->fTexture = std::move(texture);
5148 data->fSampler = std::move(sampler);
5149 fSynthesizedSamplerMap.set(&combinedSampler, std::move(data));
5150
5151 return {t, s};
5152 }
5153
// Generates the complete SPIR-V module body for `program` and writes it to `out`.
// Function bodies are first generated into a local `body` stream. The OpEntryPoint interface list
// is only complete once every function has been written; the header instructions
// (capabilities, GLSL.std.450 import, memory model, entry point, execution modes, source
// extensions) are then written to `out`, followed by the buffered names, decorations, constants
// and finally the function bodies. On error, it reports to fContext.fErrors and returns early.
void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
    fGLSLExtendedInstructions = this->nextId(nullptr);
    StringStream body;

    // Do an initial pass over the program elements to establish some baseline info.
    const FunctionDeclaration* main = nullptr;
    int localSizeX = 1, localSizeY = 1, localSizeZ = 1;
    Position combinedSamplerPos;
    Position separateSamplerPos;
    for (const ProgramElement* e : program.elements()) {
        switch (e->kind()) {
            case ProgramElement::Kind::kFunction: {
                // Assign SpvIds to functions.
                const FunctionDefinition& funcDef = e->as<FunctionDefinition>();
                const FunctionDeclaration& funcDecl = funcDef.declaration();
                fFunctionMap.set(&funcDecl, this->nextId(nullptr));
                if (funcDecl.isMain()) {
                    main = &funcDecl;
                }
                break;
            }
            case ProgramElement::Kind::kGlobalVar: {
                // Look for sampler variables and determine whether or not this program uses
                // combined samplers or separate samplers. The layout backend will be marked as
                // WebGPU for separate samplers, or Vulkan for combined samplers.
                const GlobalVarDeclaration& decl = e->as<GlobalVarDeclaration>();
                const Variable& var = *decl.varDeclaration().var();
                if (var.type().isSampler()) {
                    if (var.layout().fFlags & LayoutFlag::kVulkan) {
                        combinedSamplerPos = decl.position();
                    }
                    if (var.layout().fFlags & (LayoutFlag::kWebGPU | LayoutFlag::kDirect3D)) {
                        separateSamplerPos = decl.position();
                    }
                }
                break;
            }
            case ProgramElement::Kind::kModifiers: {
                // If this is a compute program, collect the local-size values. Dimensions that are
                // not present will be assigned a value of 1.
                if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
                    const ModifiersDeclaration& modifiers = e->as<ModifiersDeclaration>();
                    if (modifiers.layout().fLocalSizeX >= 0) {
                        localSizeX = modifiers.layout().fLocalSizeX;
                    }
                    if (modifiers.layout().fLocalSizeY >= 0) {
                        localSizeY = modifiers.layout().fLocalSizeY;
                    }
                    if (modifiers.layout().fLocalSizeZ >= 0) {
                        localSizeZ = modifiers.layout().fLocalSizeZ;
                    }
                }
                break;
            }
            default:
                break;
        }
    }

    // Make sure we have a main() function.
    if (!main) {
        fContext.fErrors->error(Position(), "program does not contain a main() function");
        return;
    }
    // Make sure our program's sampler usage is consistent.
    if (combinedSamplerPos.valid() && separateSamplerPos.valid()) {
        fContext.fErrors->error(Position(), "programs cannot contain a mixture of sampler types");
        fContext.fErrors->error(combinedSamplerPos, "combined sampler found here:");
        fContext.fErrors->error(separateSamplerPos, "separate sampler found here:");
        return;
    }
    // Separate texture/sampler pairs are used whenever a WebGPU/Direct3D sampler was seen.
    fUseTextureSamplerPairs = separateSamplerPos.valid();

    // Emit interface blocks. Non-builtin in/out blocks are listed on OpEntryPoint.
    // (std::set keeps the interface ids sorted and de-duplicated.)
    std::set<SpvId> interfaceVars;
    for (const ProgramElement* e : program.elements()) {
        if (e->is<InterfaceBlock>()) {
            const InterfaceBlock& intf = e->as<InterfaceBlock>();
            SpvId id = this->writeInterfaceBlock(intf);

            if ((intf.var()->modifierFlags() & (ModifierFlag::kIn | ModifierFlag::kOut)) &&
                intf.var()->layout().fBuiltin == -1) {
                interfaceVars.insert(id);
            }
        }
    }
    // If MustDeclareFragmentFrontFacing is set, the front-facing flag (sk_Clockwise) needs to be
    // explicitly declared in the output, whether or not the program explicitly references it.
    // However, if the program naturally declares it, we don't want to include it a second time;
    // we keep track of the real global variable declarations to see if sk_Clockwise is emitted.
    const VarDeclaration* missingClockwiseDecl = nullptr;
    if (fCaps.fMustDeclareFragmentFrontFacing) {
        if (const Symbol* clockwise = program.fSymbols->findBuiltinSymbol("sk_Clockwise")) {
            missingClockwiseDecl = clockwise->as<Variable>().varDeclaration();
        }
    }
    // Emit global variable declarations.
    for (const ProgramElement* e : program.elements()) {
        if (e->is<GlobalVarDeclaration>()) {
            const VarDeclaration& decl = e->as<GlobalVarDeclaration>().varDeclaration();
            if (!this->writeGlobalVarDeclaration(program.fConfig->fKind, decl)) {
                return;
            }
            if (missingClockwiseDecl == &decl) {
                // We emitted an sk_Clockwise declaration naturally, so we don't need a workaround.
                missingClockwiseDecl = nullptr;
            }
        }
    }
    // All the global variables have been declared. If sk_Clockwise was not naturally included in
    // the output, but MustDeclareFragmentFrontFacing was set, we need to bodge it in ourselves.
    if (missingClockwiseDecl) {
        if (!this->writeGlobalVarDeclaration(program.fConfig->fKind, *missingClockwiseDecl)) {
            return;
        }
        missingClockwiseDecl = nullptr;
    }
    // Emit top-level uniforms into a dedicated uniform buffer.
    if (!fTopLevelUniforms.empty()) {
        this->writeUniformBuffer(get_top_level_symbol_table(*main));
    }
    // If main() returns a half4, synthesize a tiny entrypoint function which invokes the real
    // main() and stores the result into sk_FragColor.
    // On success, `main` is redirected to the adapter, so the entry point below refers to it.
    EntrypointAdapter adapter;
    if (main->returnType().matches(*fContext.fTypes.fHalf4)) {
        adapter = this->writeEntrypointAdapter(*main);
        if (adapter.entrypointDecl) {
            fFunctionMap.set(adapter.entrypointDecl.get(), this->nextId(nullptr));
            this->writeFunction(*adapter.entrypointDef, body);
            main = adapter.entrypointDecl.get();
        }
    }
    // Emit all the functions.
    for (const ProgramElement* e : program.elements()) {
        if (e->is<FunctionDefinition>()) {
            this->writeFunction(e->as<FunctionDefinition>(), body);
        }
    }
    // Add global in/out variables to the list of interface variables.
    for (const auto& [var, spvId] : fVariableMap) {
        if (var->storage() == Variable::Storage::kGlobal &&
            (var->modifierFlags() & (ModifierFlag::kIn | ModifierFlag::kOut))) {
            interfaceVars.insert(spvId);
        }
    }
    this->writeCapabilities(out);
    this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
    this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
    // OpEntryPoint word count: 3 fixed words (opcode/word-count, execution model, entry-point id),
    // plus the nul-terminated name packed into 32-bit words, plus one word per interface id.
    this->writeOpCode(SpvOpEntryPoint,
                      (SpvId)(3 + (main->name().length() + 4) / 4) + (int32_t)interfaceVars.size(),
                      out);
    if (ProgramConfig::IsVertex(program.fConfig->fKind)) {
        this->writeWord(SpvExecutionModelVertex, out);
    } else if (ProgramConfig::IsFragment(program.fConfig->fKind)) {
        this->writeWord(SpvExecutionModelFragment, out);
    } else if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
        this->writeWord(SpvExecutionModelGLCompute, out);
    } else {
        SK_ABORT("cannot write this kind of program to SPIR-V\n");
    }
    SpvId entryPoint = fFunctionMap[main];
    this->writeWord(entryPoint, out);
    this->writeString(main->name(), out);
    // NOTE(review): interfaceVars holds SpvId (unsigned); iterating it as `int` is an implicit
    // narrowing conversion. Harmless for realistic id counts, but `SpvId` would be the exact type.
    for (int var : interfaceVars) {
        this->writeWord(var, out);
    }
    if (ProgramConfig::IsFragment(program.fConfig->fKind)) {
        this->writeInstruction(SpvOpExecutionMode,
                               fFunctionMap[main],
                               SpvExecutionModeOriginUpperLeft,
                               out);
    } else if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
        this->writeInstruction(SpvOpExecutionMode,
                               fFunctionMap[main],
                               SpvExecutionModeLocalSize,
                               localSizeX, localSizeY, localSizeZ,
                               out);
    }
    for (const ProgramElement* e : program.elements()) {
        if (e->is<Extension>()) {
            this->writeInstruction(SpvOpSourceExtension, e->as<Extension>().name(), out);
        }
    }

    // Flush the buffered sections in SPIR-V's required module layout order:
    // debug names, then annotations, then types/constants/globals, then function bodies.
    write_stringstream(fNameBuffer, out);
    write_stringstream(fDecorationBuffer, out);
    write_stringstream(fConstantBuffer, out);
    write_stringstream(body, out);
}
5343
generateCode()5344 bool SPIRVCodeGenerator::generateCode() {
5345 SkASSERT(!fContext.fErrors->errorCount());
5346 this->writeWord(SpvMagicNumber, *fOut);
5347 this->writeWord(SpvVersion, *fOut);
5348 this->writeWord(SKSL_MAGIC, *fOut);
5349 StringStream buffer;
5350 this->writeInstructions(fProgram, buffer);
5351 this->writeWord(fIdCount, *fOut);
5352 this->writeWord(0, *fOut); // reserved, always zero
5353 write_stringstream(buffer, *fOut);
5354 return fContext.fErrors->errorCount() == 0;
5355 }
5356
5357 #if defined(SK_ENABLE_SPIRV_VALIDATION)
validate_spirv(ErrorReporter & reporter,std::string_view program)5358 static bool validate_spirv(ErrorReporter& reporter, std::string_view program) {
5359 SkASSERT(0 == program.size() % 4);
5360 const uint32_t* programData = reinterpret_cast<const uint32_t*>(program.data());
5361 size_t programSize = program.size() / 4;
5362
5363 spvtools::SpirvTools tools(SPV_ENV_VULKAN_1_0);
5364 std::string errors;
5365 auto msgFn = [&errors](spv_message_level_t, const char*, const spv_position_t&, const char* m) {
5366 errors += "SPIR-V validation error: ";
5367 errors += m;
5368 errors += '\n';
5369 };
5370 tools.SetMessageConsumer(msgFn);
5371
5372 // Verify that the SPIR-V we produced is valid. At runtime, we will abort() with a message
5373 // explaining the error. In standalone mode (skslc), we will send the message, plus the
5374 // entire disassembled SPIR-V (for easier context & debugging) as *our* error message.
5375 bool result = tools.Validate(programData, programSize);
5376 if (!result) {
5377 #if defined(SKSL_STANDALONE)
5378 // Convert the string-stream to a SPIR-V disassembly.
5379 std::string disassembly;
5380 uint32_t options = spvtools::SpirvTools::kDefaultDisassembleOption;
5381 options |= SPV_BINARY_TO_TEXT_OPTION_INDENT;
5382 if (tools.Disassemble(programData, programSize, &disassembly, options)) {
5383 errors.append(disassembly);
5384 }
5385 reporter.error(Position(), errors);
5386 #else
5387 SkDEBUGFAILF("%s", errors.c_str());
5388 #endif
5389 }
5390 return result;
5391 }
5392 #endif
5393
ToSPIRV(Program & program,const ShaderCaps * caps,OutputStream & out)5394 bool ToSPIRV(Program& program, const ShaderCaps* caps, OutputStream& out) {
5395 TRACE_EVENT0("skia.shaders", "SkSL::ToSPIRV");
5396 SkASSERT(caps != nullptr);
5397
5398 program.fContext->fErrors->setSource(*program.fSource);
5399 #ifdef SK_ENABLE_SPIRV_VALIDATION
5400 StringStream buffer;
5401 SPIRVCodeGenerator cg(program.fContext.get(), caps, &program, &buffer);
5402 bool result = cg.generateCode();
5403
5404 if (result && program.fConfig->fSettings.fValidateSPIRV) {
5405 std::string_view binary = buffer.str();
5406 result = validate_spirv(*program.fContext->fErrors, binary);
5407 out.write(binary.data(), binary.size());
5408 }
5409 #else
5410 SPIRVCodeGenerator cg(program.fContext.get(), caps, &program, &out);
5411 bool result = cg.generateCode();
5412 #endif
5413 program.fContext->fErrors->setSource(std::string_view());
5414
5415 return result;
5416 }
5417
ToSPIRV(Program & program,const ShaderCaps * caps,std::string * out)5418 bool ToSPIRV(Program& program, const ShaderCaps* caps, std::string* out) {
5419 StringStream buffer;
5420 if (!ToSPIRV(program, caps, buffer)) {
5421 return false;
5422 }
5423 *out = buffer.str();
5424 return true;
5425 }
5426
5427 } // namespace SkSL
5428