1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/codegen/SkSLSPIRVCodeGenerator.h"
9
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/base/SkTArray.h"
13 #include "include/private/base/SkTo.h"
14 #include "src/base/SkEnumBitMask.h"
15 #include "src/core/SkChecksum.h"
16 #include "src/core/SkTHash.h"
17 #include "src/core/SkTraceEvent.h"
18 #include "src/sksl/GLSL.std.450.h"
19 #include "src/sksl/SkSLAnalysis.h"
20 #include "src/sksl/SkSLBuiltinTypes.h"
21 #include "src/sksl/SkSLCompiler.h"
22 #include "src/sksl/SkSLConstantFolder.h"
23 #include "src/sksl/SkSLContext.h"
24 #include "src/sksl/SkSLDefines.h"
25 #include "src/sksl/SkSLErrorReporter.h"
26 #include "src/sksl/SkSLIntrinsicList.h"
27 #include "src/sksl/SkSLMemoryLayout.h"
28 #include "src/sksl/SkSLOperator.h"
29 #include "src/sksl/SkSLOutputStream.h"
30 #include "src/sksl/SkSLPool.h"
31 #include "src/sksl/SkSLPosition.h"
32 #include "src/sksl/SkSLProgramSettings.h"
33 #include "src/sksl/SkSLStringStream.h"
34 #include "src/sksl/SkSLUtil.h"
35 #include "src/sksl/analysis/SkSLSpecialization.h"
36 #include "src/sksl/codegen/SkSLCodeGenerator.h"
37 #include "src/sksl/ir/SkSLBinaryExpression.h"
38 #include "src/sksl/ir/SkSLBlock.h"
39 #include "src/sksl/ir/SkSLConstructor.h"
40 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
41 #include "src/sksl/ir/SkSLConstructorCompound.h"
42 #include "src/sksl/ir/SkSLConstructorCompoundCast.h"
43 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
44 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
45 #include "src/sksl/ir/SkSLConstructorScalarCast.h"
46 #include "src/sksl/ir/SkSLConstructorSplat.h"
47 #include "src/sksl/ir/SkSLDoStatement.h"
48 #include "src/sksl/ir/SkSLExpression.h"
49 #include "src/sksl/ir/SkSLExpressionStatement.h"
50 #include "src/sksl/ir/SkSLExtension.h"
51 #include "src/sksl/ir/SkSLFieldAccess.h"
52 #include "src/sksl/ir/SkSLFieldSymbol.h"
53 #include "src/sksl/ir/SkSLForStatement.h"
54 #include "src/sksl/ir/SkSLFunctionCall.h"
55 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
56 #include "src/sksl/ir/SkSLFunctionDefinition.h"
57 #include "src/sksl/ir/SkSLIRNode.h"
58 #include "src/sksl/ir/SkSLIfStatement.h"
59 #include "src/sksl/ir/SkSLIndexExpression.h"
60 #include "src/sksl/ir/SkSLInterfaceBlock.h"
61 #include "src/sksl/ir/SkSLLayout.h"
62 #include "src/sksl/ir/SkSLLiteral.h"
63 #include "src/sksl/ir/SkSLModifierFlags.h"
64 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
65 #include "src/sksl/ir/SkSLPoison.h"
66 #include "src/sksl/ir/SkSLPostfixExpression.h"
67 #include "src/sksl/ir/SkSLPrefixExpression.h"
68 #include "src/sksl/ir/SkSLProgram.h"
69 #include "src/sksl/ir/SkSLProgramElement.h"
70 #include "src/sksl/ir/SkSLReturnStatement.h"
71 #include "src/sksl/ir/SkSLSetting.h"
72 #include "src/sksl/ir/SkSLStatement.h"
73 #include "src/sksl/ir/SkSLSwitchCase.h"
74 #include "src/sksl/ir/SkSLSwitchStatement.h"
75 #include "src/sksl/ir/SkSLSwizzle.h"
76 #include "src/sksl/ir/SkSLSymbol.h"
77 #include "src/sksl/ir/SkSLSymbolTable.h"
78 #include "src/sksl/ir/SkSLTernaryExpression.h"
79 #include "src/sksl/ir/SkSLType.h"
80 #include "src/sksl/ir/SkSLVarDeclarations.h"
81 #include "src/sksl/ir/SkSLVariable.h"
82 #include "src/sksl/ir/SkSLVariableReference.h"
83 #include "src/sksl/spirv.h"
84 #include "src/sksl/transform/SkSLTransform.h"
85 #include "src/utils/SkBitSet.h"
86
87 #include <algorithm>
88 #include <cstdint>
89 #include <cstring>
90 #include <ctype.h>
91 #include <functional>
92 #include <memory>
93 #include <set>
94 #include <string>
95 #include <string_view>
96 #include <tuple>
97 #include <utility>
98 #include <vector>
99 #include <unordered_map>
100 #include <unordered_set>
101
// NOTE(review): presumably brings TArray/STArray/THashMap (used unqualified below) into scope.
using namespace skia_private;

// Highest SPIR-V capability value this generator tracks. NOTE(review): presumably the upper bound
// used when walking the capability bits stored in the 64-bit fCapabilities mask, so it must stay
// below 64 — confirm.
#define kLast_Capability SpvCapabilityMultiViewport

// Special builtin IDs for device-space fragment coordinates and device-space clockwise winding.
// NOTE(review): negative values are presumably chosen so they can never collide with a real
// SpvBuiltIn value — confirm.
constexpr int DEVICE_FRAGCOORDS_BUILTIN = -1000;
constexpr int DEVICE_CLOCKWISE_BUILTIN = -1001;
// A default-constructed Layout. NOTE(review): presumably used where a type carries no explicit
// layout qualifiers — confirm.
static constexpr SkSL::Layout kDefaultTypeLayout;
109
110 namespace SkSL {
111
// Forward declaration of ProgramKind; it is only passed through by reference/value in this file
// (see writeGlobalVarDeclaration / writeGlobalVar).
enum class ProgramKind : int8_t;

// The storage classes this generator can assign to variables and pointer types. Each enumerator
// is translated into a SpvStorageClass by get_storage_class_spv_id(). The explicit values below
// match the implicit numbering (0, 1, 2, ...).
enum class StorageClass : int {
    kUniformConstant = 0,
    kInput = 1,
    kUniform = 2,
    kStorageBuffer = 3,
    kOutput = 4,
    kWorkgroup = 5,
    kCrossWorkgroup = 6,
    kPrivate = 7,
    kFunction = 8,
    kGeneric = 9,
    kPushConstant = 10,
    kAtomicCounter = 11,
    kImage = 12,
};
129
get_storage_class_spv_id(StorageClass storageClass)130 static SpvStorageClass get_storage_class_spv_id(StorageClass storageClass) {
131 switch (storageClass) {
132 case StorageClass::kUniformConstant: return SpvStorageClassUniformConstant;
133 case StorageClass::kInput: return SpvStorageClassInput;
134 case StorageClass::kUniform: return SpvStorageClassUniform;
135 // Note: In SPIR-V 1.3, a storage buffer can be declared with the "StorageBuffer"
136 // storage class and the "Block" decoration and the <1.3 approach we use here ("Uniform"
137 // storage class and the "BufferBlock" decoration) is deprecated. Since we target SPIR-V
138 // 1.0, we have to use the deprecated approach which is well supported in Vulkan and
139 // addresses SkSL use cases (notably SkSL currently doesn't support pointer features that
140 // would benefit from SPV_KHR_variable_pointers capabilities).
141 #ifdef SKSL_EXT
142 case StorageClass::kStorageBuffer: return SpvStorageClassStorageBuffer;
143 #else
144 case StorageClass::kStorageBuffer: return SpvStorageClassUniform;
145 #endif
146 case StorageClass::kOutput: return SpvStorageClassOutput;
147 case StorageClass::kWorkgroup: return SpvStorageClassWorkgroup;
148 case StorageClass::kCrossWorkgroup: return SpvStorageClassCrossWorkgroup;
149 case StorageClass::kPrivate: return SpvStorageClassPrivate;
150 case StorageClass::kFunction: return SpvStorageClassFunction;
151 case StorageClass::kGeneric: return SpvStorageClassGeneric;
152 case StorageClass::kPushConstant: return SpvStorageClassPushConstant;
153 case StorageClass::kAtomicCounter: return SpvStorageClassAtomicCounter;
154 case StorageClass::kImage: return SpvStorageClassImage;
155 }
156
157 SkUNREACHABLE;
158 }
159
160 class SPIRVCodeGenerator : public CodeGenerator {
161 public:
162 // We reserve an impossible SpvId as a sentinel. (NA meaning none, n/a, etc.)
163 static constexpr SpvId NA = (SpvId)-1;
164
165 class LValue {
166 public:
~LValue()167 virtual ~LValue() {}
168
169 // returns a pointer to the lvalue, if possible. If the lvalue cannot be directly referenced
170 // by a pointer (e.g. vector swizzles), returns NA.
getPointer()171 virtual SpvId getPointer() { return NA; }
172
173 // Returns true if a valid pointer returned by getPointer represents a memory object
174 // (see https://github.com/KhronosGroup/SPIRV-Tools/issues/2892). Has no meaning if
175 // getPointer() returns NA.
isMemoryObjectPointer() const176 virtual bool isMemoryObjectPointer() const { return true; }
177
178 // Applies a swizzle to the components of the LValue, if possible. This is used to create
179 // LValues that are swizzes-of-swizzles. Non-swizzle LValues can just return false.
applySwizzle(const ComponentArray & components,const Type & newType)180 virtual bool applySwizzle(const ComponentArray& components, const Type& newType) {
181 return false;
182 }
183
184 // Returns the storage class of the lvalue.
185 virtual StorageClass storageClass() const = 0;
186
187 virtual SpvId load(OutputStream& out) = 0;
188
189 virtual void store(SpvId value, OutputStream& out) = 0;
190 };
191
SPIRVCodeGenerator(const Context * context,const ShaderCaps * caps,const Program * program,OutputStream * out)192 SPIRVCodeGenerator(const Context* context,
193 const ShaderCaps* caps,
194 const Program* program,
195 OutputStream* out)
196 : CodeGenerator(context, caps, program, out) {}
197
198 bool generateCode() override;
199
200 private:
201 enum IntrinsicOpcodeKind {
202 kGLSL_STD_450_IntrinsicOpcodeKind,
203 kSPIRV_IntrinsicOpcodeKind,
204 kSpecial_IntrinsicOpcodeKind,
205 kInvalid_IntrinsicOpcodeKind,
206 };
207
208 enum SpecialIntrinsic {
209 kAtan_SpecialIntrinsic,
210 kClamp_SpecialIntrinsic,
211 kMatrixCompMult_SpecialIntrinsic,
212 kMax_SpecialIntrinsic,
213 kMin_SpecialIntrinsic,
214 kMix_SpecialIntrinsic,
215 kMod_SpecialIntrinsic,
216 kDFdy_SpecialIntrinsic,
217 kSaturate_SpecialIntrinsic,
218 kSampledImage_SpecialIntrinsic,
219 kSmoothStep_SpecialIntrinsic,
220 kStep_SpecialIntrinsic,
221 kSubpassLoad_SpecialIntrinsic,
222 kTexture_SpecialIntrinsic,
223 kTextureGrad_SpecialIntrinsic,
224 kTextureLod_SpecialIntrinsic,
225 kTextureRead_SpecialIntrinsic,
226 kTextureWrite_SpecialIntrinsic,
227 kTextureWidth_SpecialIntrinsic,
228 kTextureHeight_SpecialIntrinsic,
229 kAtomicAdd_SpecialIntrinsic,
230 kAtomicLoad_SpecialIntrinsic,
231 kAtomicStore_SpecialIntrinsic,
232 kStorageBarrier_SpecialIntrinsic,
233 kWorkgroupBarrier_SpecialIntrinsic,
234 #ifdef SKSL_EXT
235 kTextureSize_SpecialIntrinsic,
236 kSampleGather_SpecialIntrinsic,
237 kNonuniformEXT_SpecialIntrinsic,
238 #endif
239 };
240
241 enum class Precision {
242 kDefault,
243 kRelaxed,
244 };
245
246 struct TempVar {
247 SpvId spvId;
248 const Type* type;
249 std::unique_ptr<SPIRVCodeGenerator::LValue> lvalue;
250 };
251
252 /**
253 * Pass in the type to automatically add a RelaxedPrecision decoration for the id when
254 * appropriate, or null to never add one.
255 */
256 SpvId nextId(const Type* type);
257
258 SpvId nextId(Precision precision);
259
260 SpvId getType(const Type& type);
261
262 SpvId getType(const Type& type, const Layout& typeLayout, const MemoryLayout& memoryLayout);
263
264 SpvId getFunctionType(const FunctionDeclaration& function);
265
266 SpvId getFunctionParameterType(const Type& parameterType, const Layout& parameterLayout);
267
268 SpvId getPointerType(const Type& type, StorageClass storageClass);
269
270 SpvId getPointerType(const Type& type,
271 const Layout& typeLayout,
272 const MemoryLayout& memoryLayout,
273 StorageClass storageClass);
274
275 StorageClass getStorageClass(const Expression& expr);
276
277 TArray<SpvId> getAccessChain(const Expression& expr, OutputStream& out);
278
279 void writeLayout(const Layout& layout, SpvId target, Position pos);
280
281 void writeFieldLayout(const Layout& layout, SpvId target, int member);
282
283 SpvId writeStruct(const Type& type, const MemoryLayout& memoryLayout);
284
285 void writeProgramElement(const ProgramElement& pe, OutputStream& out);
286
287 SpvId writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip = true);
288
289 void writeFunctionStart(const FunctionDeclaration& f, OutputStream& out);
290
291 SpvId writeFunctionDeclaration(const FunctionDeclaration& f, OutputStream& out);
292
293 void writeFunction(const FunctionDefinition& f, OutputStream& out);
294
295 // Writes the function with the defined specializationIndex, if the index is -1, then it is
296 // assumed that the function has no specializations.
297 void writeFunctionInstantiation(const FunctionDefinition& f,
298 Analysis::SpecializationIndex specializationIndex,
299 const Analysis::SpecializedParameters* specializedParams,
300 OutputStream& out);
301
302 bool writeGlobalVarDeclaration(ProgramKind kind, const VarDeclaration& v);
303
304 SpvId writeGlobalVar(ProgramKind kind, StorageClass, const Variable& v);
305
306 void writeVarDeclaration(const VarDeclaration& var, OutputStream& out);
307
308 SpvId writeVariableReference(const VariableReference& ref, OutputStream& out);
309
310 int findUniformFieldIndex(const Variable& var) const;
311
312 std::unique_ptr<LValue> getLValue(const Expression& value, OutputStream& out);
313
314 SpvId writeExpression(const Expression& expr, OutputStream& out);
315
316 SpvId writeIntrinsicCall(const FunctionCall& c, OutputStream& out);
317
318 void writeFunctionCallArgument(TArray<SpvId>& argumentList,
319 const FunctionCall& call,
320 int argIndex,
321 std::vector<TempVar>* tempVars,
322 const SkBitSet* specializedParams,
323 OutputStream& out);
324
325 void copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out);
326
327 SpvId writeFunctionCall(const FunctionCall& c, OutputStream& out);
328
329
330 void writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
331 SpvId signedInst, SpvId unsignedInst,
332 const TArray<SpvId>& args, OutputStream& out);
333
334 /**
335 * Promotes an expression to a vector. If the expression is already a vector with vectorSize
336 * columns, returns it unmodified. If the expression is a scalar, either promotes it to a
337 * vector (if vectorSize > 1) or returns it unmodified (if vectorSize == 1). Asserts if the
338 * expression is already a vector and it does not have vectorSize columns.
339 */
340 SpvId vectorize(const Expression& expr, int vectorSize, OutputStream& out);
341
342 /**
343 * Given a list of potentially mixed scalars and vectors, promotes the scalars to match the
344 * size of the vectors and returns the ids of the written expressions. e.g. given (float, vec2),
345 * returns (vec2(float), vec2). It is an error to use mismatched vector sizes, e.g. (float,
346 * vec2, vec3).
347 */
348 TArray<SpvId> vectorize(const ExpressionArray& args, OutputStream& out);
349
350 /**
351 * Given a SpvId of a scalar, splats it across the passed-in type (scalar, vector or matrix) and
352 * returns the SpvId of the new value.
353 */
354 SpvId splat(const Type& type, SpvId id, OutputStream& out);
355
356 SpvId writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, OutputStream& out);
357 SpvId writeAtomicIntrinsic(const FunctionCall& c,
358 SpecialIntrinsic kind,
359 SpvId resultId,
360 OutputStream& out);
361
362 SpvId castScalarToFloat(SpvId inputId, const Type& inputType, const Type& outputType,
363 OutputStream& out);
364
365 SpvId castScalarToSignedInt(SpvId inputId, const Type& inputType, const Type& outputType,
366 OutputStream& out);
367
368 SpvId castScalarToUnsignedInt(SpvId inputId, const Type& inputType, const Type& outputType,
369 OutputStream& out);
370
371 SpvId castScalarToBoolean(SpvId inputId, const Type& inputType, const Type& outputType,
372 OutputStream& out);
373
374 SpvId castScalarToType(SpvId inputExprId, const Type& inputType, const Type& outputType,
375 OutputStream& out);
376
377 /**
378 * Writes a potentially-different-sized copy of a matrix. Entries which do not exist in the
379 * source matrix are filled with zero; entries which do not exist in the destination matrix are
380 * ignored.
381 */
382 SpvId writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType, OutputStream& out);
383
384 void addColumnEntry(const Type& columnType,
385 TArray<SpvId>* currentColumn,
386 TArray<SpvId>* columnIds,
387 int rows,
388 SpvId entry,
389 OutputStream& out);
390
391 SpvId writeConstructorCompound(const ConstructorCompound& c, OutputStream& out);
392
393 SpvId writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out);
394
395 SpvId writeVectorConstructor(const ConstructorCompound& c, OutputStream& out);
396
397 SpvId writeCompositeConstructor(const AnyConstructor& c, OutputStream& out);
398
399 SpvId writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c, OutputStream& out);
400
401 SpvId writeConstructorMatrixResize(const ConstructorMatrixResize& c, OutputStream& out);
402
403 SpvId writeConstructorScalarCast(const ConstructorScalarCast& c, OutputStream& out);
404
405 SpvId writeConstructorSplat(const ConstructorSplat& c, OutputStream& out);
406
407 SpvId writeConstructorCompoundCast(const ConstructorCompoundCast& c, OutputStream& out);
408
409 SpvId writeFieldAccess(const FieldAccess& f, OutputStream& out);
410
411 SpvId writeSwizzle(const Expression& baseExpr,
412 const ComponentArray& components,
413 OutputStream& out);
414
415 SpvId writeSwizzle(const Swizzle& swizzle, OutputStream& out);
416
417 /**
418 * Folds the potentially-vector result of a logical operation down to a single bool. If
419 * operandType is a vector type, assumes that the intermediate result in id is a bvec of the
420 * same dimensions, and applys all() to it to fold it down to a single bool value. Otherwise,
421 * returns the original id value.
422 */
423 SpvId foldToBool(SpvId id, const Type& operandType, SpvOp op, OutputStream& out);
424
425 SpvId writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ floatOperator,
426 SpvOp_ intOperator, SpvOp_ vectorMergeOperator,
427 SpvOp_ mergeOperator, OutputStream& out);
428
429 SpvId writeStructComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs,
430 OutputStream& out);
431
432 SpvId writeArrayComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs,
433 OutputStream& out);
434
435 // Used by writeStructComparison and writeArrayComparison to logically combine field-by-field
436 // comparisons into an overall comparison result.
437 // - `a.x == b.x` merged with `a.y == b.y` generates `(a.x == b.x) && (a.y == b.y)`
438 // - `a.x != b.x` merged with `a.y != b.y` generates `(a.x != b.x) || (a.y != b.y)`
439 SpvId mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op, OutputStream& out);
440
441 // When the RewriteMatrixVectorMultiply caps bit is set, we manually decompose the M*V
442 // multiplication into a sum of vector-scalar products.
443 SpvId writeDecomposedMatrixVectorMultiply(const Type& leftType,
444 SpvId lhs,
445 const Type& rightType,
446 SpvId rhs,
447 const Type& resultType,
448 OutputStream& out);
449
450 SpvId writeComponentwiseMatrixUnary(const Type& operandType,
451 SpvId operand,
452 SpvOp_ op,
453 OutputStream& out);
454
455 SpvId writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, SpvId rhs,
456 SpvOp_ op, OutputStream& out);
457
458 SpvId writeBinaryOperationComponentwiseIfMatrix(const Type& resultType, const Type& operandType,
459 SpvId lhs, SpvId rhs,
460 SpvOp_ ifFloat, SpvOp_ ifInt,
461 SpvOp_ ifUInt, SpvOp_ ifBool,
462 OutputStream& out);
463
464 SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
465 SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,
466 SpvOp_ ifBool, OutputStream& out);
467
468 SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
469 SpvId rhs, bool writeComponentwiseIfMatrix, SpvOp_ ifFloat,
470 SpvOp_ ifInt, SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out);
471
472 SpvId writeReciprocal(const Type& type, SpvId value, OutputStream& out);
473
474 SpvId writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
475 const Type& rightType, SpvId rhs, const Type& resultType,
476 OutputStream& out);
477
478 SpvId writeBinaryExpression(const BinaryExpression& b, OutputStream& out);
479
480 SpvId writeTernaryExpression(const TernaryExpression& t, OutputStream& out);
481
482 SpvId writeIndexExpression(const IndexExpression& expr, OutputStream& out);
483
484 SpvId writeLogicalAnd(const Expression& left, const Expression& right, OutputStream& out);
485
486 SpvId writeLogicalOr(const Expression& left, const Expression& right, OutputStream& out);
487
488 SpvId writePrefixExpression(const PrefixExpression& p, OutputStream& out);
489
490 SpvId writePostfixExpression(const PostfixExpression& p, OutputStream& out);
491
492 SpvId writeLiteral(const Literal& f);
493
494 SpvId writeLiteral(double value, const Type& type);
495
496 void writeStatement(const Statement& s, OutputStream& out);
497
498 void writeBlock(const Block& b, OutputStream& out);
499
500 void writeIfStatement(const IfStatement& stmt, OutputStream& out);
501
502 void writeForStatement(const ForStatement& f, OutputStream& out);
503
504 void writeDoStatement(const DoStatement& d, OutputStream& out);
505
506 void writeSwitchStatement(const SwitchStatement& s, OutputStream& out);
507
508 void writeReturnStatement(const ReturnStatement& r, OutputStream& out);
509
510 void writeCapabilities(OutputStream& out);
511
512 void writeInstructions(const Program& program, OutputStream& out);
513
514 void writeOpCode(SpvOp_ opCode, int length, OutputStream& out);
515
516 void writeWord(int32_t word, OutputStream& out);
517
518 void writeString(std::string_view s, OutputStream& out);
519
520 void writeInstruction(SpvOp_ opCode, OutputStream& out);
521
522 void writeInstruction(SpvOp_ opCode, std::string_view string, OutputStream& out);
523
524 void writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out);
525
526 void writeInstruction(SpvOp_ opCode, int32_t word1, std::string_view string,
527 OutputStream& out);
528
529 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, std::string_view string,
530 OutputStream& out);
531
532 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, OutputStream& out);
533
534 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3,
535 OutputStream& out);
536
537 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
538 OutputStream& out);
539
540 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
541 int32_t word5, OutputStream& out);
542
543 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
544 int32_t word5, int32_t word6, OutputStream& out);
545
546 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
547 int32_t word5, int32_t word6, int32_t word7, OutputStream& out);
548
549 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
550 int32_t word5, int32_t word6, int32_t word7, int32_t word8,
551 OutputStream& out);
552
553 // This form of writeInstruction can deduplicate redundant ops.
554 struct Word;
555 // 8 Words is enough for nearly all instructions (except variable-length instructions like
556 // OpAccessChain or OpConstantComposite).
557 using Words = STArray<8, Word, true>;
558 SpvId writeInstruction(SpvOp_ opCode, const TArray<Word, true>& words, OutputStream& out);
559
560 struct Instruction {
561 SpvId fOp;
562 int32_t fResultKind;
563 STArray<8, int32_t> fWords;
564
565 bool operator==(const Instruction& that) const;
566 struct Hash;
567 };
568
569 static Instruction BuildInstructionKey(SpvOp_ opCode, const TArray<Word, true>& words);
570
571 // The writeOpXxxxx calls will simplify and deduplicate ops where possible.
572 SpvId writeOpConstantTrue(const Type& type);
573 SpvId writeOpConstantFalse(const Type& type);
574 SpvId writeOpConstant(const Type& type, int32_t valueBits);
575 SpvId writeOpConstantComposite(const Type& type, const TArray<SpvId>& values);
576 SpvId writeOpCompositeConstruct(const Type& type, const TArray<SpvId>&, OutputStream& out);
577 SpvId writeOpCompositeExtract(const Type& type, SpvId base, int component, OutputStream& out);
578 SpvId writeOpCompositeExtract(const Type& type, SpvId base, int componentA, int componentB,
579 OutputStream& out);
580 SpvId writeOpLoad(SpvId type, Precision precision, SpvId pointer, OutputStream& out);
581 void writeOpStore(StorageClass storageClass, SpvId pointer, SpvId value, OutputStream& out);
582
583 // Converts the provided SpvId(s) into an array of scalar OpConstants, if it can be done.
584 bool toConstants(SpvId value, TArray<SpvId>* constants);
585 bool toConstants(SkSpan<const SpvId> values, TArray<SpvId>* constants);
586
587 // Extracts the requested component SpvId from a composite instruction, if it can be done.
588 Instruction* resultTypeForInstruction(const Instruction& instr);
589 int numComponentsForVecInstruction(const Instruction& instr);
590 SpvId toComponent(SpvId id, int component);
591
592 struct ConditionalOpCounts {
593 int numReachableOps;
594 int numStoreOps;
595 };
596 ConditionalOpCounts getConditionalOpCounts();
597 void pruneConditionalOps(ConditionalOpCounts ops);
598
599 enum StraightLineLabelType {
600 // Use "BranchlessBlock" for blocks which are never explicitly branched-to at all. This
601 // happens at the start of a function, or when we find unreachable code.
602 kBranchlessBlock,
603
604 // Use "BranchIsOnPreviousLine" when writing a label that comes immediately after its
605 // associated branch. Example usage:
606 // - SPIR-V does not implicitly fall through from one block to the next, so you may need to
607 // use an OpBranch to explicitly jump to the next block, even when they are adjacent in
608 // the code.
609 // - The block immediately following an OpBranchConditional or OpSwitch.
610 kBranchIsOnPreviousLine,
611 };
612
613 enum BranchingLabelType {
614 // Use "BranchIsAbove" for labels which are referenced by OpBranch or OpBranchConditional
615 // ops that are above the label in the code--i.e., the branch skips forward in the code.
616 kBranchIsAbove,
617
618 // Use "BranchIsBelow" for labels which are referenced by OpBranch or OpBranchConditional
619 // ops below the label in the code--i.e., the branch jumps backward in the code.
620 kBranchIsBelow,
621
622 // Use "BranchesOnBothSides" for labels which have branches coming from both directions.
623 kBranchesOnBothSides,
624 };
625 void writeLabel(SpvId label, StraightLineLabelType type, OutputStream& out);
626 void writeLabel(SpvId label, BranchingLabelType type, ConditionalOpCounts ops,
627 OutputStream& out);
628
629 MemoryLayout memoryLayoutForStorageClass(StorageClass storageClass);
630 MemoryLayout memoryLayoutForVariable(const Variable&) const;
631
632 struct EntrypointAdapter {
633 std::unique_ptr<FunctionDefinition> entrypointDef;
634 std::unique_ptr<FunctionDeclaration> entrypointDecl;
635 };
636
637 EntrypointAdapter writeEntrypointAdapter(const FunctionDeclaration& main);
638
639 struct UniformBuffer {
640 std::unique_ptr<InterfaceBlock> fInterfaceBlock;
641 std::unique_ptr<Variable> fInnerVariable;
642 std::unique_ptr<Type> fStruct;
643 };
644
645 void writeUniformBuffer(SymbolTable* topLevelSymbolTable);
646
647 void addRTFlipUniform(Position pos);
648
649 #ifdef SKSL_EXT
650 SpvId writeSpecConstBinaryExpression(const BinaryExpression& b, const Operator& op,
651 SpvId lhs, SpvId rhs);
652 void writeExtensions(OutputStream& out);
653
654 std::unordered_set<std::string> fExtensions;
655 std::unordered_set<uint32_t> fCapabilitiesExt;
656 std::unordered_set<SpvId> fNonUniformSpvId;
657 std::unordered_map<const Variable*, SpvId> fGlobalConstVariableValueMap;
658 bool fEmittingGlobalConstConstructor = false;
659 #endif
660
661 std::unique_ptr<Expression> identifier(std::string_view name);
662
663 std::tuple<const Variable*, const Variable*> synthesizeTextureAndSampler(
664 const Variable& combinedSampler);
665
666 #ifdef SKSL_EXT
667 const MemoryLayout fDefaultMemoryLayout{MemoryLayout::Standard::k430};
668 #else
669 const MemoryLayout fDefaultMemoryLayout{MemoryLayout::Standard::k140};
670 #endif
671 uint64_t fCapabilities = 0;
672 SpvId fIdCount = 1;
673 SpvId fGLSLExtendedInstructions;
674 struct Intrinsic {
675 IntrinsicOpcodeKind opKind;
676 int32_t floatOp;
677 int32_t signedOp;
678 int32_t unsignedOp;
679 int32_t boolOp;
680 };
681 Intrinsic getIntrinsic(IntrinsicKind) const;
682
683 THashMap<Analysis::SpecializedFunctionKey, SpvId, Analysis::SpecializedFunctionKey::Hash>
684 fFunctionMap;
685
686 Analysis::SpecializationInfo fSpecializationInfo;
687 Analysis::SpecializationIndex fActiveSpecializationIndex = Analysis::kUnspecialized;
688 const Analysis::SpecializedParameters* fActiveSpecialization = nullptr;
689
690 THashMap<const Variable*, SpvId> fVariableMap;
691 THashMap<const Type*, SpvId> fStructMap;
692 StringStream fGlobalInitializersBuffer;
693 StringStream fConstantBuffer;
694 StringStream fVariableBuffer;
695 StringStream fNameBuffer;
696 StringStream fDecorationBuffer;
697
698 // Mapping from combined sampler declarations to synthesized texture/sampler variables.
699 // This is used when the sampler is declared as `layout(webgpu)` or `layout(direct3d)`.
700 bool fUseTextureSamplerPairs = false;
701 struct SynthesizedTextureSamplerPair {
702 // The names of the synthesized variables. The Variable objects themselves store string
703 // views referencing these strings. It is important for the std::string instances to have a
704 // fixed memory location after the string views get created, which is why
705 // `fSynthesizedSamplerMap` stores unique_ptr instead of values.
706 std::string fTextureName;
707 std::string fSamplerName;
708 std::unique_ptr<Variable> fTexture;
709 std::unique_ptr<Variable> fSampler;
710 };
711 THashMap<const Variable*, std::unique_ptr<SynthesizedTextureSamplerPair>>
712 fSynthesizedSamplerMap;
713
714 // These caches map SpvIds to Instructions, and vice-versa. This enables us to deduplicate code
715 // (by detecting an Instruction we've already issued and reusing the SpvId), and to introspect
716 // and simplify code we've already emitted (by taking a SpvId from an Instruction and following
717 // it back to its source).
718
719 // A map of instruction -> SpvId:
720 THashMap<Instruction, SpvId, Instruction::Hash> fOpCache;
721 // A map of SpvId -> instruction:
722 THashMap<SpvId, Instruction> fSpvIdCache;
723 // A map of SpvId -> value SpvId:
724 THashMap<SpvId, SpvId> fStoreCache;
725
726 // "Reachable" ops are instructions which can safely be accessed from the current block.
727 // For instance, if our SPIR-V contains `%3 = OpFAdd %1 %2`, we would be able to access and
728 // reuse that computation on following lines. However, if that Add operation occurred inside an
729 // `if` block, then its SpvId becomes inaccessible once we complete the if statement (since
730 // depending on the if condition, we may or may not have actually done that computation). The
731 // same logic applies to other control-flow blocks as well. Once an instruction becomes
732 // unreachable, we remove it from both op-caches.
733 TArray<SpvId> fReachableOps;
734
735 // The "store-ops" list contains a running list of all the pointers in the store cache. If a
736 // store occurs inside of a conditional block, once that block exits, we no longer know what is
737 // stored in that particular SpvId. At that point, we must remove any associated entry from the
738 // store cache.
739 TArray<SpvId> fStoreOps;
740
741 // label of the current block, or 0 if we are not in a block
742 SpvId fCurrentBlock = 0;
743 TArray<SpvId> fBreakTarget;
744 TArray<SpvId> fContinueTarget;
745 bool fWroteRTFlip = false;
746 // holds variables synthesized during output, for lifetime purposes
747 SymbolTable fSynthetics{/*builtin=*/true};
748 // Holds a list of uniforms that were declared as globals at the top-level instead of in an
749 // interface block.
750 UniformBuffer fUniformBuffer;
751 std::vector<const VarDeclaration*> fTopLevelUniforms;
752 THashMap<const Variable*, int> fTopLevelUniformMap; // <var, UniformBuffer field index>
753 SpvId fUniformBufferId = NA;
754
755 friend class PointerLValue;
756 friend class SwizzleLValue;
757 };
758
759 // Equality and hash operators for Instructions.
operator ==(const SPIRVCodeGenerator::Instruction & that) const760 bool SPIRVCodeGenerator::Instruction::operator==(const SPIRVCodeGenerator::Instruction& that) const {
761 return fOp == that.fOp &&
762 fResultKind == that.fResultKind &&
763 fWords == that.fWords;
764 }
765
766 struct SPIRVCodeGenerator::Instruction::Hash {
operator ()SkSL::SPIRVCodeGenerator::Instruction::Hash767 uint32_t operator()(const SPIRVCodeGenerator::Instruction& key) const {
768 uint32_t hash = key.fResultKind;
769 hash = SkChecksum::Hash32(&key.fOp, sizeof(key.fOp), hash);
770 hash = SkChecksum::Hash32(key.fWords.data(), key.fWords.size() * sizeof(int32_t), hash);
771 return hash;
772 }
773 };
774
775 // This class is used to pass values and result placeholder slots to writeInstruction.
776 struct SPIRVCodeGenerator::Word {
777 enum Kind {
778 kNone, // intended for use as a sentinel, not part of any Instruction
779 kSpvId,
780 kNumber,
781 kDefaultPrecisionResult,
782 kRelaxedPrecisionResult,
783 kUniqueResult,
784 kKeyedResult,
785 };
786
WordSkSL::SPIRVCodeGenerator::Word787 Word(SpvId id) : fValue(id), fKind(Kind::kSpvId) {}
WordSkSL::SPIRVCodeGenerator::Word788 Word(int32_t val, Kind kind) : fValue(val), fKind(kind) {}
789
NumberSkSL::SPIRVCodeGenerator::Word790 static Word Number(int32_t val) {
791 return Word{val, Kind::kNumber};
792 }
793
ResultSkSL::SPIRVCodeGenerator::Word794 static Word Result(const Type& type) {
795 return (type.hasPrecision() && !type.highPrecision()) ? RelaxedResult() : Result();
796 }
797
RelaxedResultSkSL::SPIRVCodeGenerator::Word798 static Word RelaxedResult() {
799 return Word{(int32_t)NA, kRelaxedPrecisionResult};
800 }
801
UniqueResultSkSL::SPIRVCodeGenerator::Word802 static Word UniqueResult() {
803 return Word{(int32_t)NA, kUniqueResult};
804 }
805
ResultSkSL::SPIRVCodeGenerator::Word806 static Word Result() {
807 return Word{(int32_t)NA, kDefaultPrecisionResult};
808 }
809
810 // Unlike a Result (where the result ID is always deduplicated to its first instruction) or a
811 // UniqueResult (which always produces a new instruction), a KeyedResult allows an instruction
812 // to be deduplicated among those that share the same `key`.
KeyedResultSkSL::SPIRVCodeGenerator::Word813 static Word KeyedResult(int32_t key) { return Word{key, Kind::kKeyedResult}; }
814
isResultSkSL::SPIRVCodeGenerator::Word815 bool isResult() const { return fKind >= Kind::kDefaultPrecisionResult; }
816
817 int32_t fValue;
818 Kind fKind;
819 };
820
// Generator magic number for SkSL-produced SPIR-V. Skia's registered tool ID (31 == 0x1F) sits in
// the upper 16 bits. The lower 16 bits are free and could be used to version the SkSL generator.
// https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/spir-v.xml#L84
static constexpr int32_t SKSL_MAGIC = 0x001F0000;
825
// Returns the table entry used to lower the SkSL intrinsic `ik` to SPIR-V.
//
// Each Intrinsic built below holds an opcode kind followed by four opcode slots. The slot order
// follows from the macro parameter names (BY_TYPE_GLSL's ifFloat/ifInt/ifUInt) and from
// BOOL_SPIRV filling only the last slot: float, signed int, unsigned int, bool. SpvOpUndef marks
// a type the intrinsic has no opcode for.
//
// The opcode kind says how to interpret the slots:
//   - kGLSL_STD_450_IntrinsicOpcodeKind: GLSL.std.450 extended-instruction numbers.
//   - kSPIRV_IntrinsicOpcodeKind:        core SPIR-V opcodes.
//   - kSpecial_IntrinsicOpcodeKind:      a k*_SpecialIntrinsic tag. Presumably this selects
//                                        hand-written lowering elsewhere in this file (confirm).
// Intrinsics without an entry return an Intrinsic of kind kInvalid_IntrinsicOpcodeKind.
SPIRVCodeGenerator::Intrinsic SPIRVCodeGenerator::getIntrinsic(IntrinsicKind ik) const {

// Shorthand constructors for the table below.
// NOTE(review): apart from PACK, these macros are never #undef'd here, so they remain defined for
// the rest of this file.
#define ALL_GLSL(x) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, GLSLstd450 ## x, \
                              GLSLstd450 ## x, GLSLstd450 ## x, GLSLstd450 ## x}
#define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, \
                                                       GLSLstd450 ## ifFloat, \
                                                       GLSLstd450 ## ifInt, \
                                                       GLSLstd450 ## ifUInt, \
                                                       SpvOpUndef}
#define ALL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
                               SpvOp ## x, SpvOp ## x, SpvOp ## x, SpvOp ## x}
#define BOOL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
                                SpvOpUndef, SpvOpUndef, SpvOpUndef, SpvOp ## x}
#define FLOAT_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
                                 SpvOp ## x, SpvOpUndef, SpvOpUndef, SpvOpUndef}
#define SPECIAL(x) Intrinsic{kSpecial_IntrinsicOpcodeKind, k ## x ## _SpecialIntrinsic, \
                             k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
                             k ## x ## _SpecialIntrinsic}

    switch (ik) {
        case k_round_IntrinsicKind:          return ALL_GLSL(Round);
        case k_roundEven_IntrinsicKind:      return ALL_GLSL(RoundEven);
        case k_trunc_IntrinsicKind:          return ALL_GLSL(Trunc);
        case k_abs_IntrinsicKind:            return BY_TYPE_GLSL(FAbs, SAbs, SAbs);
        case k_sign_IntrinsicKind:           return BY_TYPE_GLSL(FSign, SSign, SSign);
        case k_floor_IntrinsicKind:          return ALL_GLSL(Floor);
        case k_ceil_IntrinsicKind:           return ALL_GLSL(Ceil);
        case k_fract_IntrinsicKind:          return ALL_GLSL(Fract);
        case k_radians_IntrinsicKind:        return ALL_GLSL(Radians);
        case k_degrees_IntrinsicKind:        return ALL_GLSL(Degrees);
        case k_sin_IntrinsicKind:            return ALL_GLSL(Sin);
        case k_cos_IntrinsicKind:            return ALL_GLSL(Cos);
        case k_tan_IntrinsicKind:            return ALL_GLSL(Tan);
        case k_asin_IntrinsicKind:           return ALL_GLSL(Asin);
        case k_acos_IntrinsicKind:           return ALL_GLSL(Acos);
        case k_atan_IntrinsicKind:           return SPECIAL(Atan);
        case k_sinh_IntrinsicKind:           return ALL_GLSL(Sinh);
        case k_cosh_IntrinsicKind:           return ALL_GLSL(Cosh);
        case k_tanh_IntrinsicKind:           return ALL_GLSL(Tanh);
        case k_asinh_IntrinsicKind:          return ALL_GLSL(Asinh);
        case k_acosh_IntrinsicKind:          return ALL_GLSL(Acosh);
        case k_atanh_IntrinsicKind:          return ALL_GLSL(Atanh);
        case k_pow_IntrinsicKind:            return ALL_GLSL(Pow);
        case k_exp_IntrinsicKind:            return ALL_GLSL(Exp);
        case k_log_IntrinsicKind:            return ALL_GLSL(Log);
        case k_exp2_IntrinsicKind:           return ALL_GLSL(Exp2);
        case k_log2_IntrinsicKind:           return ALL_GLSL(Log2);
        case k_sqrt_IntrinsicKind:           return ALL_GLSL(Sqrt);
        // `inverse` and `matrixInverse` share the same GLSL.std.450 instruction.
        case k_inverse_IntrinsicKind:        return ALL_GLSL(MatrixInverse);
        case k_outerProduct_IntrinsicKind:   return ALL_SPIRV(OuterProduct);
        case k_transpose_IntrinsicKind:      return ALL_SPIRV(Transpose);
        case k_isinf_IntrinsicKind:          return ALL_SPIRV(IsInf);
        case k_isnan_IntrinsicKind:          return ALL_SPIRV(IsNan);
        case k_inversesqrt_IntrinsicKind:    return ALL_GLSL(InverseSqrt);
        case k_determinant_IntrinsicKind:    return ALL_GLSL(Determinant);
        case k_matrixCompMult_IntrinsicKind: return SPECIAL(MatrixCompMult);
        case k_matrixInverse_IntrinsicKind:  return ALL_GLSL(MatrixInverse);
        case k_mod_IntrinsicKind:            return SPECIAL(Mod);
        case k_modf_IntrinsicKind:           return ALL_GLSL(Modf);
        case k_min_IntrinsicKind:            return SPECIAL(Min);
        case k_max_IntrinsicKind:            return SPECIAL(Max);
        case k_clamp_IntrinsicKind:          return SPECIAL(Clamp);
        case k_saturate_IntrinsicKind:       return SPECIAL(Saturate);
        case k_dot_IntrinsicKind:            return FLOAT_SPIRV(Dot);
        case k_mix_IntrinsicKind:            return SPECIAL(Mix);
        case k_step_IntrinsicKind:           return SPECIAL(Step);
        case k_smoothstep_IntrinsicKind:     return SPECIAL(SmoothStep);
        case k_fma_IntrinsicKind:            return ALL_GLSL(Fma);
        case k_frexp_IntrinsicKind:          return ALL_GLSL(Frexp);
        case k_ldexp_IntrinsicKind:          return ALL_GLSL(Ldexp);

// Each PACK(T) line expands to both the packT and the unpackT cases.
#define PACK(type) case k_pack##type##_IntrinsicKind: return ALL_GLSL(Pack##type); \
                   case k_unpack##type##_IntrinsicKind: return ALL_GLSL(Unpack##type)
        PACK(Snorm4x8);
        PACK(Unorm4x8);
        PACK(Snorm2x16);
        PACK(Unorm2x16);
        PACK(Half2x16);
#undef PACK

        case k_length_IntrinsicKind:         return ALL_GLSL(Length);
        case k_distance_IntrinsicKind:       return ALL_GLSL(Distance);
        case k_cross_IntrinsicKind:          return ALL_GLSL(Cross);
        case k_normalize_IntrinsicKind:      return ALL_GLSL(Normalize);
        case k_faceforward_IntrinsicKind:    return ALL_GLSL(FaceForward);
        case k_reflect_IntrinsicKind:        return ALL_GLSL(Reflect);
        case k_refract_IntrinsicKind:        return ALL_GLSL(Refract);
        case k_bitCount_IntrinsicKind:       return ALL_SPIRV(BitCount);
        case k_findLSB_IntrinsicKind:        return ALL_GLSL(FindILsb);
        // The signed variant is selected for float and int; unsigned ints use FindUMsb.
        case k_findMSB_IntrinsicKind:        return BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
        case k_dFdx_IntrinsicKind:           return FLOAT_SPIRV(DPdx);
        case k_dFdy_IntrinsicKind:           return SPECIAL(DFdy);
        case k_fwidth_IntrinsicKind:         return FLOAT_SPIRV(Fwidth);

        case k_sample_IntrinsicKind:         return SPECIAL(Texture);
        case k_sampleGrad_IntrinsicKind:     return SPECIAL(TextureGrad);
        case k_sampleLod_IntrinsicKind:      return SPECIAL(TextureLod);
        case k_subpassLoad_IntrinsicKind:    return SPECIAL(SubpassLoad);

        case k_textureRead_IntrinsicKind:    return SPECIAL(TextureRead);
        case k_textureWrite_IntrinsicKind:   return SPECIAL(TextureWrite);
        case k_textureWidth_IntrinsicKind:   return SPECIAL(TextureWidth);
        case k_textureHeight_IntrinsicKind:  return SPECIAL(TextureHeight);

        // Bit reinterpretation between float and integer types is a plain OpBitcast.
        case k_floatBitsToInt_IntrinsicKind:  return ALL_SPIRV(Bitcast);
        case k_floatBitsToUint_IntrinsicKind: return ALL_SPIRV(Bitcast);
        case k_intBitsToFloat_IntrinsicKind:  return ALL_SPIRV(Bitcast);
        case k_uintBitsToFloat_IntrinsicKind: return ALL_SPIRV(Bitcast);

        case k_any_IntrinsicKind:            return BOOL_SPIRV(Any);
        case k_all_IntrinsicKind:            return BOOL_SPIRV(All);
        case k_not_IntrinsicKind:            return BOOL_SPIRV(LogicalNot);

        // Component-wise comparisons. Floats use ordered compares, so any NaN operand yields
        // false. The exception is notEqual, which uses an unordered compare so that a NaN operand
        // yields true.
        case k_equal_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdEqual,
                             SpvOpIEqual,
                             SpvOpIEqual,
                             SpvOpLogicalEqual};
        case k_notEqual_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFUnordNotEqual,
                             SpvOpINotEqual,
                             SpvOpINotEqual,
                             SpvOpLogicalNotEqual};
        case k_lessThan_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdLessThan,
                             SpvOpSLessThan,
                             SpvOpULessThan,
                             SpvOpUndef};
        case k_lessThanEqual_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdLessThanEqual,
                             SpvOpSLessThanEqual,
                             SpvOpULessThanEqual,
                             SpvOpUndef};
        case k_greaterThan_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdGreaterThan,
                             SpvOpSGreaterThan,
                             SpvOpUGreaterThan,
                             SpvOpUndef};
        case k_greaterThanEqual_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdGreaterThanEqual,
                             SpvOpSGreaterThanEqual,
                             SpvOpUGreaterThanEqual,
                             SpvOpUndef};

        case k_atomicAdd_IntrinsicKind:   return SPECIAL(AtomicAdd);
        case k_atomicLoad_IntrinsicKind:  return SPECIAL(AtomicLoad);
        case k_atomicStore_IntrinsicKind: return SPECIAL(AtomicStore);

        case k_storageBarrier_IntrinsicKind:   return SPECIAL(StorageBarrier);
        case k_workgroupBarrier_IntrinsicKind: return SPECIAL(WorkgroupBarrier);
// Extension intrinsics; these cases exist only when SKSL_EXT is defined.
#ifdef SKSL_EXT
        case k_textureSize_IntrinsicKind: return SPECIAL(TextureSize);
        case k_sampleGather_IntrinsicKind: return SPECIAL(SampleGather);
        case k_nonuniformEXT_IntrinsicKind: return SPECIAL(NonuniformEXT);
#endif
        default:
            // No SPIR-V lowering is known for this intrinsic.
            return Intrinsic{kInvalid_IntrinsicOpcodeKind, 0, 0, 0, 0};
    }
}
991
writeWord(int32_t word,OutputStream & out)992 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
993 out.write((const char*) &word, sizeof(word));
994 }
995
is_float(const Type & type)996 static bool is_float(const Type& type) {
997 return (type.isScalar() || type.isVector() || type.isMatrix()) &&
998 type.componentType().isFloat();
999 }
1000
is_signed(const Type & type)1001 static bool is_signed(const Type& type) {
1002 return (type.isScalar() || type.isVector()) && type.componentType().isSigned();
1003 }
1004
is_unsigned(const Type & type)1005 static bool is_unsigned(const Type& type) {
1006 return (type.isScalar() || type.isVector()) && type.componentType().isUnsigned();
1007 }
1008
is_bool(const Type & type)1009 static bool is_bool(const Type& type) {
1010 return (type.isScalar() || type.isVector()) && type.componentType().isBoolean();
1011 }
1012
1013 template <typename T>
pick_by_type(const Type & type,T ifFloat,T ifInt,T ifUInt,T ifBool)1014 static T pick_by_type(const Type& type, T ifFloat, T ifInt, T ifUInt, T ifBool) {
1015 if (is_float(type)) {
1016 return ifFloat;
1017 }
1018 if (is_signed(type)) {
1019 return ifInt;
1020 }
1021 if (is_unsigned(type)) {
1022 return ifUInt;
1023 }
1024 if (is_bool(type)) {
1025 return ifBool;
1026 }
1027 SkDEBUGFAIL("unrecognized type");
1028 return ifFloat;
1029 }
1030
is_out(ModifierFlags f)1031 static bool is_out(ModifierFlags f) {
1032 return SkToBool(f & ModifierFlag::kOut);
1033 }
1034
is_in(ModifierFlags f)1035 static bool is_in(ModifierFlags f) {
1036 if (f & ModifierFlag::kIn) {
1037 return true; // `in` and `inout` both count
1038 }
1039 // If neither in/out flag is set, the type is implicitly `in`.
1040 return !SkToBool(f & ModifierFlag::kOut);
1041 }
1042
is_control_flow_op(SpvOp_ op)1043 static bool is_control_flow_op(SpvOp_ op) {
1044 switch (op) {
1045 case SpvOpReturn:
1046 case SpvOpReturnValue:
1047 case SpvOpKill:
1048 case SpvOpSwitch:
1049 case SpvOpBranch:
1050 case SpvOpBranchConditional:
1051 return true;
1052 default:
1053 return false;
1054 }
1055 }
1056
is_globally_reachable_op(SpvOp_ op)1057 static bool is_globally_reachable_op(SpvOp_ op) {
1058 switch (op) {
1059 case SpvOpConstant:
1060 case SpvOpConstantTrue:
1061 case SpvOpConstantFalse:
1062 case SpvOpConstantComposite:
1063 case SpvOpTypeVoid:
1064 case SpvOpTypeInt:
1065 case SpvOpTypeFloat:
1066 case SpvOpTypeBool:
1067 case SpvOpTypeVector:
1068 case SpvOpTypeMatrix:
1069 case SpvOpTypeArray:
1070 case SpvOpTypePointer:
1071 case SpvOpTypeFunction:
1072 case SpvOpTypeRuntimeArray:
1073 case SpvOpTypeStruct:
1074 case SpvOpTypeImage:
1075 case SpvOpTypeSampledImage:
1076 case SpvOpTypeSampler:
1077 case SpvOpVariable:
1078 case SpvOpFunction:
1079 case SpvOpFunctionParameter:
1080 case SpvOpFunctionEnd:
1081 case SpvOpExecutionMode:
1082 case SpvOpMemoryModel:
1083 case SpvOpCapability:
1084 case SpvOpExtInstImport:
1085 case SpvOpEntryPoint:
1086 case SpvOpSource:
1087 case SpvOpSourceExtension:
1088 case SpvOpName:
1089 case SpvOpMemberName:
1090 case SpvOpDecorate:
1091 case SpvOpMemberDecorate:
1092 #ifdef SKSL_EXT
1093 case SpvOpExtension:
1094 case SpvOpSpecConstant:
1095 case SpvOpSpecConstantOp:
1096 #endif
1097 return true;
1098 default:
1099 return false;
1100 }
1101 }
1102
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)1103 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
1104 SkASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
1105 SkASSERT(opCode != SpvOpUndef);
1106 bool foundDeadCode = false;
1107 if (is_control_flow_op(opCode)) {
1108 // This instruction causes us to leave the current block.
1109 foundDeadCode = (fCurrentBlock == 0);
1110 fCurrentBlock = 0;
1111 } else if (!is_globally_reachable_op(opCode)) {
1112 foundDeadCode = (fCurrentBlock == 0);
1113 }
1114
1115 if (foundDeadCode) {
1116 // We just encountered dead code--an instruction that don't have an associated block.
1117 // Synthesize a label if this happens; this is necessary to satisfy the validator.
1118 this->writeLabel(this->nextId(nullptr), kBranchlessBlock, out);
1119 }
1120
1121 this->writeWord((length << 16) | opCode, out);
1122 }
1123
writeLabel(SpvId label,StraightLineLabelType,OutputStream & out)1124 void SPIRVCodeGenerator::writeLabel(SpvId label, StraightLineLabelType, OutputStream& out) {
1125 // The straight-line label type is not important; in any case, no caches are invalidated.
1126 SkASSERT(!fCurrentBlock);
1127 fCurrentBlock = label;
1128 this->writeInstruction(SpvOpLabel, label, out);
1129 }
1130
writeLabel(SpvId label,BranchingLabelType type,ConditionalOpCounts ops,OutputStream & out)1131 void SPIRVCodeGenerator::writeLabel(SpvId label, BranchingLabelType type,
1132 ConditionalOpCounts ops, OutputStream& out) {
1133 switch (type) {
1134 case kBranchIsBelow:
1135 case kBranchesOnBothSides:
1136 // With a backward or bidirectional branch, we haven't seen the code between the label
1137 // and the branch yet, so any stored value is potentially suspect. Without scanning
1138 // ahead to check, the only safe option is to ditch the store cache entirely.
1139 fStoreCache.reset();
1140 [[fallthrough]];
1141
1142 case kBranchIsAbove:
1143 // With a forward branch, we can rely on stores that we had cached at the start of the
1144 // statement/expression, if they haven't been touched yet. Anything newer than that is
1145 // pruned.
1146 this->pruneConditionalOps(ops);
1147 break;
1148 }
1149
1150 // Emit the label.
1151 this->writeLabel(label, kBranchlessBlock, out);
1152 }
1153
writeInstruction(SpvOp_ opCode,OutputStream & out)1154 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
1155 this->writeOpCode(opCode, 1, out);
1156 }
1157
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)1158 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
1159 this->writeOpCode(opCode, 2, out);
1160 this->writeWord(word1, out);
1161 }
1162
writeString(std::string_view s,OutputStream & out)1163 void SPIRVCodeGenerator::writeString(std::string_view s, OutputStream& out) {
1164 out.write(s.data(), s.length());
1165 switch (s.length() % 4) {
1166 case 1:
1167 out.write8(0);
1168 [[fallthrough]];
1169 case 2:
1170 out.write8(0);
1171 [[fallthrough]];
1172 case 3:
1173 out.write8(0);
1174 break;
1175 default:
1176 this->writeWord(0, out);
1177 break;
1178 }
1179 }
1180
writeInstruction(SpvOp_ opCode,std::string_view string,OutputStream & out)1181 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, std::string_view string,
1182 OutputStream& out) {
1183 this->writeOpCode(opCode, 1 + (string.length() + 4) / 4, out);
1184 this->writeString(string, out);
1185 }
1186
writeInstruction(SpvOp_ opCode,int32_t word1,std::string_view string,OutputStream & out)1187 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, std::string_view string,
1188 OutputStream& out) {
1189 this->writeOpCode(opCode, 2 + (string.length() + 4) / 4, out);
1190 this->writeWord(word1, out);
1191 this->writeString(string, out);
1192 }
1193
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,std::string_view string,OutputStream & out)1194 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1195 std::string_view string, OutputStream& out) {
1196 this->writeOpCode(opCode, 3 + (string.length() + 4) / 4, out);
1197 this->writeWord(word1, out);
1198 this->writeWord(word2, out);
1199 this->writeString(string, out);
1200 }
1201
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)1202 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1203 OutputStream& out) {
1204 this->writeOpCode(opCode, 3, out);
1205 this->writeWord(word1, out);
1206 this->writeWord(word2, out);
1207 }
1208
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)1209 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1210 int32_t word3, OutputStream& out) {
1211 this->writeOpCode(opCode, 4, out);
1212 this->writeWord(word1, out);
1213 this->writeWord(word2, out);
1214 this->writeWord(word3, out);
1215 }
1216
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)1217 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1218 int32_t word3, int32_t word4, OutputStream& out) {
1219 this->writeOpCode(opCode, 5, out);
1220 this->writeWord(word1, out);
1221 this->writeWord(word2, out);
1222 this->writeWord(word3, out);
1223 this->writeWord(word4, out);
1224 }
1225
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)1226 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1227 int32_t word3, int32_t word4, int32_t word5,
1228 OutputStream& out) {
1229 this->writeOpCode(opCode, 6, out);
1230 this->writeWord(word1, out);
1231 this->writeWord(word2, out);
1232 this->writeWord(word3, out);
1233 this->writeWord(word4, out);
1234 this->writeWord(word5, out);
1235 }
1236
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)1237 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1238 int32_t word3, int32_t word4, int32_t word5,
1239 int32_t word6, OutputStream& out) {
1240 this->writeOpCode(opCode, 7, out);
1241 this->writeWord(word1, out);
1242 this->writeWord(word2, out);
1243 this->writeWord(word3, out);
1244 this->writeWord(word4, out);
1245 this->writeWord(word5, out);
1246 this->writeWord(word6, out);
1247 }
1248
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)1249 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1250 int32_t word3, int32_t word4, int32_t word5,
1251 int32_t word6, int32_t word7, OutputStream& out) {
1252 this->writeOpCode(opCode, 8, out);
1253 this->writeWord(word1, out);
1254 this->writeWord(word2, out);
1255 this->writeWord(word3, out);
1256 this->writeWord(word4, out);
1257 this->writeWord(word5, out);
1258 this->writeWord(word6, out);
1259 this->writeWord(word7, out);
1260 }
1261
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)1262 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1263 int32_t word3, int32_t word4, int32_t word5,
1264 int32_t word6, int32_t word7, int32_t word8,
1265 OutputStream& out) {
1266 this->writeOpCode(opCode, 9, out);
1267 this->writeWord(word1, out);
1268 this->writeWord(word2, out);
1269 this->writeWord(word3, out);
1270 this->writeWord(word4, out);
1271 this->writeWord(word5, out);
1272 this->writeWord(word6, out);
1273 this->writeWord(word7, out);
1274 this->writeWord(word8, out);
1275 }
1276
BuildInstructionKey(SpvOp_ opCode,const TArray<Word> & words)1277 SPIRVCodeGenerator::Instruction SPIRVCodeGenerator::BuildInstructionKey(SpvOp_ opCode,
1278 const TArray<Word>& words) {
1279 // Assemble a cache key for this instruction.
1280 Instruction key;
1281 key.fOp = opCode;
1282 key.fWords.resize(words.size());
1283 key.fResultKind = Word::Kind::kNone;
1284
1285 for (int index = 0; index < words.size(); ++index) {
1286 const Word& word = words[index];
1287 key.fWords[index] = word.fValue;
1288 if (word.isResult()) {
1289 SkASSERT(key.fResultKind == Word::Kind::kNone);
1290 key.fResultKind = word.fKind;
1291 }
1292 }
1293
1294 return key;
1295 }
1296
writeInstruction(SpvOp_ opCode,const TArray<Word> & words,OutputStream & out)1297 SpvId SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode,
1298 const TArray<Word>& words,
1299 OutputStream& out) {
1300 // writeOpLoad and writeOpStore have dedicated code.
1301 SkASSERT(opCode != SpvOpLoad);
1302 SkASSERT(opCode != SpvOpStore);
1303
1304 // If this instruction exists in our op cache, return the cached SpvId.
1305 Instruction key = BuildInstructionKey(opCode, words);
1306 if (SpvId* cachedOp = fOpCache.find(key)) {
1307 return *cachedOp;
1308 }
1309
1310 SpvId result = NA;
1311 Precision precision = Precision::kDefault;
1312
1313 switch (key.fResultKind) {
1314 case Word::Kind::kUniqueResult:
1315 // The instruction returns a SpvId, but we do not want deduplication.
1316 result = this->nextId(Precision::kDefault);
1317 fSpvIdCache.set(result, key);
1318 break;
1319
1320 case Word::Kind::kNone:
1321 // The instruction doesn't return a SpvId, but we can still cache and deduplicate it.
1322 fOpCache.set(key, result);
1323 break;
1324
1325 case Word::Kind::kRelaxedPrecisionResult:
1326 precision = Precision::kRelaxed;
1327 [[fallthrough]];
1328
1329 case Word::Kind::kKeyedResult:
1330 [[fallthrough]];
1331
1332 case Word::Kind::kDefaultPrecisionResult:
1333 // Consume a new SpvId.
1334 result = this->nextId(precision);
1335 fOpCache.set(key, result);
1336 fSpvIdCache.set(result, key);
1337
1338 // Globally-reachable ops are not subject to the whims of flow control.
1339 if (!is_globally_reachable_op(opCode)) {
1340 fReachableOps.push_back(result);
1341 }
1342 break;
1343
1344 default:
1345 SkDEBUGFAIL("unexpected result kind");
1346 break;
1347 }
1348
1349 // Write the requested instruction.
1350 this->writeOpCode(opCode, words.size() + 1, out);
1351 for (const Word& word : words) {
1352 if (word.isResult()) {
1353 SkASSERT(result != NA);
1354 this->writeWord(result, out);
1355 } else {
1356 this->writeWord(word.fValue, out);
1357 }
1358 }
1359
1360 // Return the result.
1361 return result;
1362 }
1363
writeOpLoad(SpvId type,Precision precision,SpvId pointer,OutputStream & out)1364 SpvId SPIRVCodeGenerator::writeOpLoad(SpvId type,
1365 Precision precision,
1366 SpvId pointer,
1367 OutputStream& out) {
1368 // Look for this pointer in our load-cache.
1369 if (SpvId* cachedOp = fStoreCache.find(pointer)) {
1370 return *cachedOp;
1371 }
1372
1373 // Write the requested OpLoad instruction.
1374 #ifdef SKSL_EXT
1375 SpvId result = -1;
1376 if (fNonUniformSpvId.find(pointer) != fNonUniformSpvId.end()) {
1377 result = this->nextId(nullptr);
1378 this->writeInstruction(SpvOpDecorate, result, SpvDecorationNonUniform, fDecorationBuffer);
1379 } else {
1380 result = this->nextId(precision);
1381 }
1382 #else
1383 SpvId result = this->nextId(precision);
1384 #endif
1385 this->writeInstruction(SpvOpLoad, type, result, pointer, out);
1386 return result;
1387 }
1388
#ifdef SKSL_EXT
// Emits one OpExtension instruction per entry in fExtensions, in iteration order.
void SPIRVCodeGenerator::writeExtensions(OutputStream& out) {
    for (const auto& extensionName : fExtensions) {
        this->writeInstruction(SpvOpExtension, extensionName, out);
    }
}
#endif
1396
writeOpStore(StorageClass storageClass,SpvId pointer,SpvId value,OutputStream & out)1397 void SPIRVCodeGenerator::writeOpStore(StorageClass storageClass,
1398 SpvId pointer,
1399 SpvId value,
1400 OutputStream& out) {
1401 // Write the uncached SpvOpStore directly.
1402 this->writeInstruction(SpvOpStore, pointer, value, out);
1403
1404 if (storageClass == StorageClass::kFunction) {
1405 // Insert a pointer-to-SpvId mapping into the load cache. A writeOpLoad to this pointer will
1406 // return the cached value as-is.
1407 fStoreCache.set(pointer, value);
1408 fStoreOps.push_back(pointer);
1409 }
1410 }
1411
writeOpConstantTrue(const Type & type)1412 SpvId SPIRVCodeGenerator::writeOpConstantTrue(const Type& type) {
1413 return this->writeInstruction(SpvOpConstantTrue,
1414 Words{this->getType(type), Word::Result()},
1415 fConstantBuffer);
1416 }
1417
writeOpConstantFalse(const Type & type)1418 SpvId SPIRVCodeGenerator::writeOpConstantFalse(const Type& type) {
1419 return this->writeInstruction(SpvOpConstantFalse,
1420 Words{this->getType(type), Word::Result()},
1421 fConstantBuffer);
1422 }
1423
writeOpConstant(const Type & type,int32_t valueBits)1424 SpvId SPIRVCodeGenerator::writeOpConstant(const Type& type, int32_t valueBits) {
1425 return this->writeInstruction(
1426 SpvOpConstant,
1427 Words{this->getType(type), Word::Result(), Word::Number(valueBits)},
1428 fConstantBuffer);
1429 }
1430
writeOpConstantComposite(const Type & type,const TArray<SpvId> & values)1431 SpvId SPIRVCodeGenerator::writeOpConstantComposite(const Type& type,
1432 const TArray<SpvId>& values) {
1433 SkASSERT(values.size() == (type.isStruct() ? SkToInt(type.fields().size()) : type.columns()));
1434
1435 Words words;
1436 words.push_back(this->getType(type));
1437 words.push_back(Word::Result());
1438 for (SpvId value : values) {
1439 words.push_back(value);
1440 }
1441 return this->writeInstruction(SpvOpConstantComposite, words, fConstantBuffer);
1442 }
1443
toConstants(SpvId value,TArray<SpvId> * constants)1444 bool SPIRVCodeGenerator::toConstants(SpvId value, TArray<SpvId>* constants) {
1445 Instruction* instr = fSpvIdCache.find(value);
1446 if (!instr) {
1447 return false;
1448 }
1449 switch (instr->fOp) {
1450 case SpvOpConstant:
1451 case SpvOpConstantTrue:
1452 case SpvOpConstantFalse:
1453 constants->push_back(value);
1454 return true;
1455
1456 case SpvOpConstantComposite: // OpConstantComposite ResultType ResultID Constituents...
1457 // Start at word 2 to skip past ResultType and ResultID.
1458 for (int i = 2; i < instr->fWords.size(); ++i) {
1459 if (!this->toConstants(instr->fWords[i], constants)) {
1460 return false;
1461 }
1462 }
1463 return true;
1464
1465 default:
1466 return false;
1467 }
1468 }
1469
toConstants(SkSpan<const SpvId> values,TArray<SpvId> * constants)1470 bool SPIRVCodeGenerator::toConstants(SkSpan<const SpvId> values, TArray<SpvId>* constants) {
1471 for (SpvId value : values) {
1472 if (!this->toConstants(value, constants)) {
1473 return false;
1474 }
1475 }
1476 return true;
1477 }
1478
writeOpCompositeConstruct(const Type & type,const TArray<SpvId> & values,OutputStream & out)1479 SpvId SPIRVCodeGenerator::writeOpCompositeConstruct(const Type& type,
1480 const TArray<SpvId>& values,
1481 OutputStream& out) {
1482 // If this is a vector composed entirely of literals, write a constant-composite instead.
1483 if (type.isVector()) {
1484 STArray<4, SpvId> constants;
1485 if (this->toConstants(SkSpan(values), &constants)) {
1486 // Create a vector from literals.
1487 return this->writeOpConstantComposite(type, constants);
1488 }
1489 }
1490
1491 // If this is a matrix composed entirely of literals, constant-composite them instead.
1492 if (type.isMatrix()) {
1493 STArray<16, SpvId> constants;
1494 if (this->toConstants(SkSpan(values), &constants)) {
1495 // Create each matrix column.
1496 SkASSERT(type.isMatrix());
1497 const Type& vecType = type.columnType(fContext);
1498 STArray<4, SpvId> columnIDs;
1499 for (int index=0; index < type.columns(); ++index) {
1500 STArray<4, SpvId> columnConstants(&constants[index * type.rows()],
1501 type.rows());
1502 columnIDs.push_back(this->writeOpConstantComposite(vecType, columnConstants));
1503 }
1504 // Compose the matrix from its columns.
1505 return this->writeOpConstantComposite(type, columnIDs);
1506 }
1507 }
1508
1509 Words words;
1510 words.push_back(this->getType(type));
1511 words.push_back(Word::Result(type));
1512 for (SpvId value : values) {
1513 words.push_back(value);
1514 }
1515 #ifdef SKSL_EXT
1516 return this->writeInstruction(fEmittingGlobalConstConstructor ?
1517 SpvOpConstantComposite : SpvOpCompositeConstruct, words, out);
1518 #else
1519 return this->writeInstruction(SpvOpCompositeConstruct, words, out);
1520 #endif
1521 }
1522
resultTypeForInstruction(const Instruction & instr)1523 SPIRVCodeGenerator::Instruction* SPIRVCodeGenerator::resultTypeForInstruction(
1524 const Instruction& instr) {
1525 // This list should contain every op that we cache that has a result and result-type.
1526 // (If one is missing, we will not find some optimization opportunities.)
1527 // Generally, the result type of an op is in the 0th word, but I'm not sure if this is
1528 // universally true, so it's configurable on a per-op basis.
1529 int resultTypeWord;
1530 switch (instr.fOp) {
1531 case SpvOpConstant:
1532 case SpvOpConstantTrue:
1533 case SpvOpConstantFalse:
1534 case SpvOpConstantComposite:
1535 case SpvOpCompositeConstruct:
1536 case SpvOpCompositeExtract:
1537 case SpvOpLoad:
1538 resultTypeWord = 0;
1539 break;
1540
1541 default:
1542 return nullptr;
1543 }
1544
1545 Instruction* typeInstr = fSpvIdCache.find(instr.fWords[resultTypeWord]);
1546 SkASSERT(typeInstr);
1547 return typeInstr;
1548 }
1549
numComponentsForVecInstruction(const Instruction & instr)1550 int SPIRVCodeGenerator::numComponentsForVecInstruction(const Instruction& instr) {
1551 // If an instruction is in the op cache, its type should be as well.
1552 Instruction* typeInstr = this->resultTypeForInstruction(instr);
1553 SkASSERT(typeInstr);
1554 SkASSERT(typeInstr->fOp == SpvOpTypeVector || typeInstr->fOp == SpvOpTypeFloat ||
1555 typeInstr->fOp == SpvOpTypeInt || typeInstr->fOp == SpvOpTypeBool);
1556
1557 // For vectors, extract their column count. Scalars have one component by definition.
1558 // SpvOpTypeVector ResultID ComponentType NumComponents
1559 return (typeInstr->fOp == SpvOpTypeVector) ? typeInstr->fWords[2]
1560 : 1;
1561 }
1562
toComponent(SpvId id,int component)1563 SpvId SPIRVCodeGenerator::toComponent(SpvId id, int component) {
1564 Instruction* instr = fSpvIdCache.find(id);
1565 if (!instr) {
1566 return NA;
1567 }
1568 if (instr->fOp == SpvOpConstantComposite) {
1569 // SpvOpConstantComposite ResultType ResultID [components...]
1570 // Add 2 to the component index to skip past ResultType and ResultID.
1571 return instr->fWords[2 + component];
1572 }
1573 if (instr->fOp == SpvOpCompositeConstruct) {
1574 // SpvOpCompositeConstruct ResultType ResultID [components...]
1575 // Vectors have special rules; check to see if we are composing a vector.
1576 Instruction* composedType = fSpvIdCache.find(instr->fWords[0]);
1577 SkASSERT(composedType);
1578
1579 // When composing a non-vector, each instruction word maps 1:1 to the component index.
1580 // We can just extract out the associated component directly.
1581 if (composedType->fOp != SpvOpTypeVector) {
1582 return instr->fWords[2 + component];
1583 }
1584
1585 // When composing a vector, components can be either scalars or vectors.
1586 // This means we need to check the op type on each component. (+2 to skip ResultType/Result)
1587 for (int index = 2; index < instr->fWords.size(); ++index) {
1588 int32_t currentWord = instr->fWords[index];
1589
1590 // Retrieve the sub-instruction pointed to by OpCompositeConstruct.
1591 Instruction* subinstr = fSpvIdCache.find(currentWord);
1592 if (!subinstr) {
1593 return NA;
1594 }
1595 // If this subinstruction contains the component we're looking for...
1596 int numComponents = this->numComponentsForVecInstruction(*subinstr);
1597 if (component < numComponents) {
1598 if (numComponents == 1) {
1599 // ... it's a scalar. Return it.
1600 SkASSERT(component == 0);
1601 return currentWord;
1602 } else {
1603 // ... it's a vector. Recurse into it.
1604 return this->toComponent(currentWord, component);
1605 }
1606 }
1607 // This sub-instruction doesn't contain our component. Keep walking forward.
1608 component -= numComponents;
1609 }
1610 SkDEBUGFAIL("component index goes past the end of this composite value");
1611 return NA;
1612 }
1613 return NA;
1614 }
1615
writeOpCompositeExtract(const Type & type,SpvId base,int component,OutputStream & out)1616 SpvId SPIRVCodeGenerator::writeOpCompositeExtract(const Type& type,
1617 SpvId base,
1618 int component,
1619 OutputStream& out) {
1620 // If the base op is a composite, we can extract from it directly.
1621 SpvId result = this->toComponent(base, component);
1622 if (result != NA) {
1623 return result;
1624 }
1625 return this->writeInstruction(
1626 SpvOpCompositeExtract,
1627 {this->getType(type), Word::Result(type), base, Word::Number(component)},
1628 out);
1629 }
1630
writeOpCompositeExtract(const Type & type,SpvId base,int componentA,int componentB,OutputStream & out)1631 SpvId SPIRVCodeGenerator::writeOpCompositeExtract(const Type& type,
1632 SpvId base,
1633 int componentA,
1634 int componentB,
1635 OutputStream& out) {
1636 // If the base op is a composite, we can extract from it directly.
1637 SpvId result = this->toComponent(base, componentA);
1638 if (result != NA) {
1639 return this->writeOpCompositeExtract(type, result, componentB, out);
1640 }
1641 return this->writeInstruction(SpvOpCompositeExtract,
1642 {this->getType(type),
1643 Word::Result(type),
1644 base,
1645 Word::Number(componentA),
1646 Word::Number(componentB)},
1647 out);
1648 }
1649
writeCapabilities(OutputStream & out)1650 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
1651 for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
1652 if (fCapabilities & bit) {
1653 this->writeInstruction(SpvOpCapability, (SpvId) i, out);
1654 }
1655 }
1656 this->writeInstruction(SpvOpCapability, SpvCapabilityShader, out);
1657 #ifdef SKSL_EXT
1658 for (auto i : fCapabilitiesExt) {
1659 this->writeInstruction(SpvOpCapability, (SpvId) i, out);
1660 }
1661 #endif
1662 }
1663
nextId(const Type * type)1664 SpvId SPIRVCodeGenerator::nextId(const Type* type) {
1665 return this->nextId(type && type->hasPrecision() && !type->highPrecision()
1666 ? Precision::kRelaxed
1667 : Precision::kDefault);
1668 }
1669
nextId(Precision precision)1670 SpvId SPIRVCodeGenerator::nextId(Precision precision) {
1671 if (precision == Precision::kRelaxed && !fProgram.fConfig->fSettings.fForceHighPrecision) {
1672 this->writeInstruction(SpvOpDecorate, fIdCount, SpvDecorationRelaxedPrecision,
1673 fDecorationBuffer);
1674 }
1675 return fIdCount++;
1676 }
1677
// Writes an OpTypeStruct for `type` into fConstantBuffer and returns its id. Also writes the
// struct's OpName and per-member OpMemberName (fNameBuffer) and per-member decorations
// (fDecorationBuffer): Offset, ColMajor/MatrixStride for matrices, RelaxedPrecision, and whatever
// writeFieldLayout emits. Member offsets follow `memoryLayout`. Explicit `layout(offset=...)`
// values are honored, and errors are reported for invalid ones.
//
// NOTE(review): the result is cached in fStructMap keyed only on &type, not on `memoryLayout`.
// If the same struct can be reached under two different memory layouts, later callers receive the
// offsets computed for the first one. Confirm this cannot happen.
SpvId SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout) {
    // If we've already written out this struct, return its existing SpvId.
    if (SpvId* cachedStructId = fStructMap.find(&type)) {
        return *cachedStructId;
    }

    // Write all of the field types first, so we don't inadvertently write them while we're in the
    // middle of writing the struct instruction.
    Words words;
    words.push_back(Word::UniqueResult());
    for (const auto& f : type.fields()) {
        words.push_back(this->getType(*f.fType, f.fLayout, memoryLayout));
    }
    SpvId resultId = this->writeInstruction(SpvOpTypeStruct, words, fConstantBuffer);
    this->writeInstruction(SpvOpName, resultId, type.name(), fNameBuffer);
    fStructMap.set(&type, resultId);

    // Running byte offset of the next member within the struct.
    size_t offset = 0;
    for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
        const Field& field = type.fields()[i];
        if (!memoryLayout.isSupported(*field.fType)) {
            // Stop at the first member this layout can't represent. The (already cached) struct id
            // is still returned, but this and later members get no names or decorations.
            fContext.fErrors->error(type.fPosition, "type '" + field.fType->displayName() +
                                                    "' is not permitted here");
            return resultId;
        }
        size_t size = memoryLayout.size(*field.fType);
        size_t alignment = memoryLayout.alignment(*field.fType);
        const Layout& fieldLayout = field.fLayout;
        if (fieldLayout.fOffset >= 0) {
            // Explicit offset: it must not overlap earlier members and must respect the member's
            // alignment. Violations are reported, but the requested offset is used regardless.
            if (fieldLayout.fOffset < (int) offset) {
                fContext.fErrors->error(field.fPosition, "offset of field '" +
                        std::string(field.fName) + "' must be at least " + std::to_string(offset));
            }
            if (fieldLayout.fOffset % alignment) {
                fContext.fErrors->error(field.fPosition,
                                        "offset of field '" + std::string(field.fName) +
                                        "' must be a multiple of " + std::to_string(alignment));
            }
            offset = fieldLayout.fOffset;
        } else {
            // Implicit offset: round the running offset up to the member's alignment.
            size_t mod = offset % alignment;
            if (mod) {
                offset += alignment - mod;
            }
        }
        this->writeInstruction(SpvOpMemberName, resultId, i, field.fName, fNameBuffer);
        this->writeFieldLayout(fieldLayout, resultId, i);
        // Members with a builtin layout (fBuiltin >= 0) get no Offset decoration.
        if (field.fLayout.fBuiltin < 0) {
            this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
                                   (SpvId) offset, fDecorationBuffer);
        }
        // Matrix members are declared column-major, with the column stride from the memory layout.
        if (field.fType->isMatrix()) {
            this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
                                   fDecorationBuffer);
            this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
                                   (SpvId) memoryLayout.stride(*field.fType),
                                   fDecorationBuffer);
        }
#ifdef SKSL_EXT
        // Arrays of matrices receive the same ColMajor/MatrixStride decorations, using the
        // element (matrix) type's stride.
        if (field.fType->isArray()) {
            if (field.fType->componentType().isMatrix()) {
                this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
                                       fDecorationBuffer);
                this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
                                       (SpvId) memoryLayout.stride(field.fType->componentType()),
                                       fDecorationBuffer);
            }
        }
#endif
        // NOTE(review): unlike nextId(const Type*), this does not also check hasPrecision(), so
        // any member whose type reports !highPrecision() is decorated RelaxedPrecision. That may
        // include types that carry no precision at all. Confirm this is intended.
        if (!field.fType->highPrecision()) {
            this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i,
                                   SpvDecorationRelaxedPrecision, fDecorationBuffer);
        }
        offset += size;
        // Array and struct members are padded out to a multiple of their alignment before the
        // next member is placed.
        if ((field.fType->isArray() || field.fType->isStruct()) && offset % alignment != 0) {
            offset += alignment - offset % alignment;
        }
    }

    return resultId;
}
1759
getType(const Type & type)1760 SpvId SPIRVCodeGenerator::getType(const Type& type) {
1761 return this->getType(type, kDefaultTypeLayout, fDefaultMemoryLayout);
1762 }
1763
layout_flags_to_image_format(LayoutFlags flags)1764 static SpvImageFormat layout_flags_to_image_format(LayoutFlags flags) {
1765 flags &= LayoutFlag::kAllPixelFormats;
1766 switch (flags.value()) {
1767 case (int)LayoutFlag::kRGBA8:
1768 return SpvImageFormatRgba8;
1769
1770 case (int)LayoutFlag::kRGBA32F:
1771 return SpvImageFormatRgba32f;
1772
1773 case (int)LayoutFlag::kR32F:
1774 return SpvImageFormatR32f;
1775
1776 default:
1777 return SpvImageFormatUnknown;
1778 }
1779
1780 SkUNREACHABLE;
1781 }
1782
// Returns the id of the SPIR-V type for `rawType`. The type declaration, and any types it depends
// on, are written into fConstantBuffer. Array strides and struct member offsets follow
// `memoryLayout`. `typeLayout` is forwarded to component types and, for textures, supplies the
// image format. Reports an error and returns NA for an array type the (array) memory layout does
// not support.
SpvId SPIRVCodeGenerator::getType(const Type& rawType,
                                  const Layout& typeLayout,
                                  const MemoryLayout& memoryLayout) {
    const Type* type = &rawType;

    switch (type->typeKind()) {
        case Type::TypeKind::kVoid: {
            return this->writeInstruction(SpvOpTypeVoid, Words{Word::Result()}, fConstantBuffer);
        }
        case Type::TypeKind::kScalar:
        case Type::TypeKind::kLiteral: {
            // Every numeric scalar is declared with a width of 32. For OpTypeInt, the last
            // operand is the signedness (1 = signed, 0 = unsigned).
            if (type->isBoolean()) {
                return this->writeInstruction(SpvOpTypeBool, {Word::Result()}, fConstantBuffer);
            }
            if (type->isSigned()) {
                return this->writeInstruction(
                        SpvOpTypeInt,
                        Words{Word::Result(), Word::Number(32), Word::Number(1)},
                        fConstantBuffer);
            }
            if (type->isUnsigned()) {
                return this->writeInstruction(
                        SpvOpTypeInt,
                        Words{Word::Result(), Word::Number(32), Word::Number(0)},
                        fConstantBuffer);
            }
            if (type->isFloat()) {
                return this->writeInstruction(
                        SpvOpTypeFloat,
                        Words{Word::Result(), Word::Number(32)},
                        fConstantBuffer);
            }
            SkDEBUGFAILF("unrecognized scalar type '%s'", type->description().c_str());
            return NA;
        }
        case Type::TypeKind::kVector: {
            // OpTypeVector ResultID ComponentType ComponentCount
            SpvId scalarTypeId = this->getType(type->componentType(), typeLayout, memoryLayout);
            return this->writeInstruction(
                    SpvOpTypeVector,
                    Words{Word::Result(), scalarTypeId, Word::Number(type->columns())},
                    fConstantBuffer);
        }
        case Type::TypeKind::kMatrix: {
            // OpTypeMatrix ResultID ColumnType ColumnCount. The column type is the vector type
            // that indexing the matrix yields.
            SpvId vectorTypeId = this->getType(IndexExpression::IndexType(fContext, *type),
                                               typeLayout,
                                               memoryLayout);
            return this->writeInstruction(
                    SpvOpTypeMatrix,
                    Words{Word::Result(), vectorTypeId, Word::Number(type->columns())},
                    fConstantBuffer);
        }
        case Type::TypeKind::kArray: {
            // When fForceStd430ArrayLayout is set, std430 is used for the array (and its element
            // type) regardless of the caller's memory layout.
            const MemoryLayout arrayMemoryLayout =
                    fCaps.fForceStd430ArrayLayout
                            ? MemoryLayout(MemoryLayout::Standard::k430)
                            : memoryLayout;

            if (!arrayMemoryLayout.isSupported(*type)) {
                fContext.fErrors->error(type->fPosition, "type '" + type->displayName() +
                                                         "' is not permitted here");
                return NA;
            }
            size_t stride = arrayMemoryLayout.stride(*type);
            SpvId typeId = this->getType(type->componentType(), typeLayout, arrayMemoryLayout);
            SpvId result = NA;
            // The result is keyed on the stride. This presumably keeps arrays that differ only in
            // stride (and hence in their ArrayStride decoration) from sharing one id. TODO
            // confirm against Word::KeyedResult.
            if (type->isUnsizedArray()) {
#ifdef SKSL_EXT
                // A runtime-sized array of samplers is a descriptor array and needs the
                // RuntimeDescriptorArray capability and the descriptor-indexing extension.
                if (type->componentType().isSampler()) {
                    fCapabilitiesExt.insert(SpvCapabilityRuntimeDescriptorArray);
                    fExtensions.insert("SPV_EXT_descriptor_indexing");
                }
#endif
                result = this->writeInstruction(SpvOpTypeRuntimeArray,
                                                Words{Word::KeyedResult(stride), typeId},
                                                fConstantBuffer);
            } else {
                // OpTypeArray takes its length as the id of an int constant.
                SpvId countId = this->writeLiteral(type->columns(), *fContext.fTypes.fInt);
                result = this->writeInstruction(SpvOpTypeArray,
                                                Words{Word::KeyedResult(stride), typeId, countId},
                                                fConstantBuffer);
            }
            this->writeInstruction(SpvOpDecorate,
                                   {result, SpvDecorationArrayStride, Word::Number(stride)},
                                   fDecorationBuffer);
            return result;
        }
        case Type::TypeKind::kStruct: {
            // writeStruct caches the struct's id, so each struct Type is written once.
            return this->writeStruct(*type, memoryLayout);
        }
        case Type::TypeKind::kSeparateSampler: {
            return this->writeInstruction(SpvOpTypeSampler, Words{Word::Result()}, fConstantBuffer);
        }
        case Type::TypeKind::kSampler: {
            // SkSL samplers are combined image-samplers: an OpTypeSampledImage wrapping the
            // texture type. Buffer-dimension samplers also require the SampledBuffer capability.
            if (SpvDimBuffer == type->dimensions()) {
                fCapabilities |= 1ULL << SpvCapabilitySampledBuffer;
            }
            SpvId imageTypeId = this->getType(type->textureType(), typeLayout, memoryLayout);
            return this->writeInstruction(SpvOpTypeSampledImage,
                                          Words{Word::Result(), imageTypeId},
                                          fConstantBuffer);
        }
        case Type::TypeKind::kTexture: {
            // The sampled component type is always float.
            SpvId floatTypeId = this->getType(*fContext.fTypes.fFloat,
                                              kDefaultTypeLayout,
                                              memoryLayout);

            // Only non-sampled (storage) images that are not subpass inputs take their format
            // from the layout qualifiers. Everything else uses an unknown format.
            bool sampled = (type->textureAccess() == Type::TextureAccess::kSample);
            SpvImageFormat format = (!sampled && type->dimensions() != SpvDimSubpassData)
                                            ? layout_flags_to_image_format(typeLayout.fFlags)
                                            : SpvImageFormatUnknown;

            // OpTypeImage ResultID SampledType Dim Depth Arrayed MS Sampled Format, where
            // Sampled = 1 means "used with a sampler" and 2 means "used without a sampler".
            return this->writeInstruction(SpvOpTypeImage,
                                          Words{Word::Result(),
                                                floatTypeId,
                                                Word::Number(type->dimensions()),
                                                Word::Number(type->isDepth()),
                                                Word::Number(type->isArrayedTexture()),
                                                Word::Number(type->isMultisampled()),
                                                Word::Number(sampled ? 1 : 2),
                                                format},
                                          fConstantBuffer);
        }
        case Type::TypeKind::kAtomic: {
            // SkSL currently only supports the atomicUint type.
            SkASSERT(type->matches(*fContext.fTypes.fAtomicUInt));
            // SPIR-V doesn't have atomic types. Rather, it allows atomic operations on primitive
            // types. The SPIR-V type of an SkSL atomic is simply the underlying type.
            return this->writeInstruction(SpvOpTypeInt,
                                          Words{Word::Result(), Word::Number(32), Word::Number(0)},
                                          fConstantBuffer);
        }
        default: {
            SkDEBUGFAILF("invalid type: %s", type->description().c_str());
            return NA;
        }
    }
}
1920
getFunctionType(const FunctionDeclaration & function)1921 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
1922 Words words;
1923 words.push_back(Word::Result());
1924 words.push_back(this->getType(function.returnType()));
1925 for (const Variable* parameter : function.parameters()) {
1926 bool paramIsSpecialized = fActiveSpecialization && fActiveSpecialization->find(parameter);
1927 if (fUseTextureSamplerPairs && parameter->type().isSampler()) {
1928 words.push_back(this->getFunctionParameterType(parameter->type().textureType(),
1929 parameter->layout()));
1930 if (!paramIsSpecialized) {
1931 words.push_back(this->getFunctionParameterType(*fContext.fTypes.fSampler,
1932 kDefaultTypeLayout));
1933 }
1934 } else if (!paramIsSpecialized) {
1935 words.push_back(this->getFunctionParameterType(parameter->type(), parameter->layout()));
1936 }
1937 }
1938 return this->writeInstruction(SpvOpTypeFunction, words, fConstantBuffer);
1939 }
1940
getFunctionParameterType(const Type & parameterType,const Layout & parameterLayout)1941 SpvId SPIRVCodeGenerator::getFunctionParameterType(const Type& parameterType,
1942 const Layout& parameterLayout) {
1943 // glslang treats all function arguments as pointers whether they need to be or
1944 // not. I was initially puzzled by this until I ran bizarre failures with certain
1945 // patterns of function calls and control constructs, as exemplified by this minimal
1946 // failure case:
1947 //
1948 // void sphere(float x) {
1949 // }
1950 //
1951 // void map() {
1952 // sphere(1.0);
1953 // }
1954 //
1955 // void main() {
1956 // for (int i = 0; i < 1; i++) {
1957 // map();
1958 // }
1959 // }
1960 //
1961 // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
1962 // crashes. Making it take a float* and storing the argument in a temporary variable,
1963 // as glslang does, fixes it.
1964 //
1965 // The consensus among shader compiler authors seems to be that GPU driver generally don't
1966 // handle value-based parameters consistently. It is highly likely that they fit their
1967 // implementations to conform to glslang. We take care to do so ourselves.
1968 //
1969 // Our implementation first stores every parameter value into a function storage-class pointer
1970 // before calling a function. The exception is for opaque handle types (samplers and textures)
1971 // which must be stored in a pointer with UniformConstant storage-class. This prevents
1972 // unnecessary temporaries (becuase opaque handles are always rooted in a pointer variable),
1973 // matches glslang's behavior, and translates into WGSL more easily when targeting Dawn.
1974 StorageClass storageClass;
1975 if (parameterType.typeKind() == Type::TypeKind::kSampler ||
1976 parameterType.typeKind() == Type::TypeKind::kSeparateSampler ||
1977 parameterType.typeKind() == Type::TypeKind::kTexture) {
1978 storageClass = StorageClass::kUniformConstant;
1979 } else {
1980 storageClass = StorageClass::kFunction;
1981 }
1982 return this->getPointerType(parameterType,
1983 parameterLayout,
1984 this->memoryLayoutForStorageClass(storageClass),
1985 storageClass);
1986 }
1987
getPointerType(const Type & type,StorageClass storageClass)1988 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, StorageClass storageClass) {
1989 return this->getPointerType(type,
1990 kDefaultTypeLayout,
1991 this->memoryLayoutForStorageClass(storageClass),
1992 storageClass);
1993 }
1994
getPointerType(const Type & type,const Layout & typeLayout,const MemoryLayout & memoryLayout,StorageClass storageClass)1995 SpvId SPIRVCodeGenerator::getPointerType(const Type& type,
1996 const Layout& typeLayout,
1997 const MemoryLayout& memoryLayout,
1998 StorageClass storageClass) {
1999 return this->writeInstruction(SpvOpTypePointer,
2000 Words{Word::Result(),
2001 Word::Number(get_storage_class_spv_id(storageClass)),
2002 this->getType(type, typeLayout, memoryLayout)},
2003 fConstantBuffer);
2004 }
2005
writeExpression(const Expression & expr,OutputStream & out)2006 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
2007 switch (expr.kind()) {
2008 case Expression::Kind::kBinary:
2009 return this->writeBinaryExpression(expr.as<BinaryExpression>(), out);
2010 case Expression::Kind::kConstructorArrayCast:
2011 return this->writeExpression(*expr.as<ConstructorArrayCast>().argument(), out);
2012 case Expression::Kind::kConstructorArray:
2013 case Expression::Kind::kConstructorStruct:
2014 return this->writeCompositeConstructor(expr.asAnyConstructor(), out);
2015 case Expression::Kind::kConstructorDiagonalMatrix:
2016 return this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(), out);
2017 case Expression::Kind::kConstructorMatrixResize:
2018 return this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(), out);
2019 case Expression::Kind::kConstructorScalarCast:
2020 return this->writeConstructorScalarCast(expr.as<ConstructorScalarCast>(), out);
2021 case Expression::Kind::kConstructorSplat:
2022 return this->writeConstructorSplat(expr.as<ConstructorSplat>(), out);
2023 case Expression::Kind::kConstructorCompound:
2024 return this->writeConstructorCompound(expr.as<ConstructorCompound>(), out);
2025 case Expression::Kind::kConstructorCompoundCast:
2026 return this->writeConstructorCompoundCast(expr.as<ConstructorCompoundCast>(), out);
2027 case Expression::Kind::kEmpty:
2028 return NA;
2029 case Expression::Kind::kFieldAccess:
2030 return this->writeFieldAccess(expr.as<FieldAccess>(), out);
2031 case Expression::Kind::kFunctionCall:
2032 return this->writeFunctionCall(expr.as<FunctionCall>(), out);
2033 case Expression::Kind::kLiteral:
2034 return this->writeLiteral(expr.as<Literal>());
2035 case Expression::Kind::kPrefix:
2036 return this->writePrefixExpression(expr.as<PrefixExpression>(), out);
2037 case Expression::Kind::kPostfix:
2038 return this->writePostfixExpression(expr.as<PostfixExpression>(), out);
2039 case Expression::Kind::kSwizzle:
2040 return this->writeSwizzle(expr.as<Swizzle>(), out);
2041 case Expression::Kind::kVariableReference:
2042 return this->writeVariableReference(expr.as<VariableReference>(), out);
2043 case Expression::Kind::kTernary:
2044 return this->writeTernaryExpression(expr.as<TernaryExpression>(), out);
2045 case Expression::Kind::kIndex:
2046 return this->writeIndexExpression(expr.as<IndexExpression>(), out);
2047 case Expression::Kind::kSetting:
2048 return this->writeExpression(*expr.as<Setting>().toLiteral(fCaps), out);
2049 default:
2050 SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
2051 break;
2052 }
2053 return NA;
2054 }
2055
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)2056 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
2057 const FunctionDeclaration& function = c.function();
2058 Intrinsic intrinsic = this->getIntrinsic(function.intrinsicKind());
2059 if (intrinsic.opKind == kInvalid_IntrinsicOpcodeKind) {
2060 fContext.fErrors->error(c.fPosition, "unsupported intrinsic '" + function.description() +
2061 "'");
2062 return NA;
2063 }
2064 const ExpressionArray& arguments = c.arguments();
2065 int32_t intrinsicId = intrinsic.floatOp;
2066 if (!arguments.empty()) {
2067 const Type& type = arguments[0]->type();
2068 if (intrinsic.opKind == kSpecial_IntrinsicOpcodeKind) {
2069 // Keep the default float op.
2070 } else {
2071 intrinsicId = pick_by_type(type, intrinsic.floatOp, intrinsic.signedOp,
2072 intrinsic.unsignedOp, intrinsic.boolOp);
2073 }
2074 }
2075 switch (intrinsic.opKind) {
2076 case kGLSL_STD_450_IntrinsicOpcodeKind: {
2077 SpvId result = this->nextId(&c.type());
2078 TArray<SpvId> argumentIds;
2079 argumentIds.reserve_exact(arguments.size());
2080 std::vector<TempVar> tempVars;
2081 for (int i = 0; i < arguments.size(); i++) {
2082 this->writeFunctionCallArgument(argumentIds, c, i, &tempVars,
2083 /*specializedParams=*/nullptr, out);
2084 }
2085 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
2086 this->writeWord(this->getType(c.type()), out);
2087 this->writeWord(result, out);
2088 this->writeWord(fGLSLExtendedInstructions, out);
2089 this->writeWord(intrinsicId, out);
2090 for (SpvId id : argumentIds) {
2091 this->writeWord(id, out);
2092 }
2093 this->copyBackTempVars(tempVars, out);
2094 return result;
2095 }
2096 case kSPIRV_IntrinsicOpcodeKind: {
2097 // GLSL supports dot(float, float), but SPIR-V does not. Convert it to FMul
2098 if (intrinsicId == SpvOpDot && arguments[0]->type().isScalar()) {
2099 intrinsicId = SpvOpFMul;
2100 }
2101 SpvId result = this->nextId(&c.type());
2102 TArray<SpvId> argumentIds;
2103 argumentIds.reserve_exact(arguments.size());
2104 std::vector<TempVar> tempVars;
2105 for (int i = 0; i < arguments.size(); i++) {
2106 this->writeFunctionCallArgument(argumentIds, c, i, &tempVars,
2107 /*specializedParams=*/nullptr, out);
2108 }
2109 if (!c.type().isVoid()) {
2110 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
2111 this->writeWord(this->getType(c.type()), out);
2112 this->writeWord(result, out);
2113 } else {
2114 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
2115 }
2116 for (SpvId id : argumentIds) {
2117 this->writeWord(id, out);
2118 }
2119 this->copyBackTempVars(tempVars, out);
2120 return result;
2121 }
2122 case kSpecial_IntrinsicOpcodeKind:
2123 return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
2124 default:
2125 fContext.fErrors->error(c.fPosition, "unsupported intrinsic '" +
2126 function.description() + "'");
2127 return NA;
2128 }
2129 }
2130
vectorize(const Expression & arg,int vectorSize,OutputStream & out)2131 SpvId SPIRVCodeGenerator::vectorize(const Expression& arg, int vectorSize, OutputStream& out) {
2132 SkASSERT(vectorSize >= 1 && vectorSize <= 4);
2133 const Type& argType = arg.type();
2134 if (argType.isScalar() && vectorSize > 1) {
2135 SpvId argID = this->writeExpression(arg, out);
2136 return this->splat(argType.toCompound(fContext, vectorSize, /*rows=*/1), argID, out);
2137 }
2138
2139 SkASSERT(vectorSize == argType.columns());
2140 return this->writeExpression(arg, out);
2141 }
2142
vectorize(const ExpressionArray & args,OutputStream & out)2143 TArray<SpvId> SPIRVCodeGenerator::vectorize(const ExpressionArray& args, OutputStream& out) {
2144 int vectorSize = 1;
2145 for (const auto& a : args) {
2146 if (a->type().isVector()) {
2147 if (vectorSize > 1) {
2148 SkASSERT(a->type().columns() == vectorSize);
2149 } else {
2150 vectorSize = a->type().columns();
2151 }
2152 }
2153 }
2154 TArray<SpvId> result;
2155 result.reserve_exact(args.size());
2156 for (const auto& arg : args) {
2157 result.push_back(this->vectorize(*arg, vectorSize, out));
2158 }
2159 return result;
2160 }
2161
writeGLSLExtendedInstruction(const Type & type,SpvId id,SpvId floatInst,SpvId signedInst,SpvId unsignedInst,const TArray<SpvId> & args,OutputStream & out)2162 void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
2163 SpvId signedInst, SpvId unsignedInst,
2164 const TArray<SpvId>& args,
2165 OutputStream& out) {
2166 this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
2167 this->writeWord(this->getType(type), out);
2168 this->writeWord(id, out);
2169 this->writeWord(fGLSLExtendedInstructions, out);
2170 this->writeWord(pick_by_type(type, floatInst, signedInst, unsignedInst, NA), out);
2171 for (SpvId a : args) {
2172 this->writeWord(a, out);
2173 }
2174 }
2175
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)2176 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
2177 OutputStream& out) {
2178 const ExpressionArray& arguments = c.arguments();
2179 const Type& callType = c.type();
2180 #ifdef SKSL_EXT
2181 SpvId result = this->nextId(kind == kMix_SpecialIntrinsic ? nullptr : &callType);
2182 #else
2183 SpvId result = this->nextId(nullptr);
2184 #endif
2185 switch (kind) {
2186 case kAtan_SpecialIntrinsic: {
2187 STArray<2, SpvId> argumentIds;
2188 for (const std::unique_ptr<Expression>& arg : arguments) {
2189 argumentIds.push_back(this->writeExpression(*arg, out));
2190 }
2191 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
2192 this->writeWord(this->getType(callType), out);
2193 this->writeWord(result, out);
2194 this->writeWord(fGLSLExtendedInstructions, out);
2195 this->writeWord(argumentIds.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
2196 for (SpvId id : argumentIds) {
2197 this->writeWord(id, out);
2198 }
2199 break;
2200 }
2201 case kSampledImage_SpecialIntrinsic: {
2202 SkASSERT(arguments.size() == 2);
2203 SpvId img = this->writeExpression(*arguments[0], out);
2204 SpvId sampler = this->writeExpression(*arguments[1], out);
2205 this->writeInstruction(SpvOpSampledImage,
2206 this->getType(callType),
2207 result,
2208 img,
2209 sampler,
2210 out);
2211 break;
2212 }
2213 case kSubpassLoad_SpecialIntrinsic: {
2214 SpvId img = this->writeExpression(*arguments[0], out);
2215 ExpressionArray args;
2216 args.reserve_exact(2);
2217 args.push_back(Literal::MakeInt(fContext, Position(), /*value=*/0));
2218 args.push_back(Literal::MakeInt(fContext, Position(), /*value=*/0));
2219 ConstructorCompound ctor(Position(), *fContext.fTypes.fInt2, std::move(args));
2220 SpvId coords = this->writeExpression(ctor, out);
2221 if (arguments.size() == 1) {
2222 this->writeInstruction(SpvOpImageRead,
2223 this->getType(callType),
2224 result,
2225 img,
2226 coords,
2227 out);
2228 } else {
2229 SkASSERT(arguments.size() == 2);
2230 SpvId sample = this->writeExpression(*arguments[1], out);
2231 this->writeInstruction(SpvOpImageRead,
2232 this->getType(callType),
2233 result,
2234 img,
2235 coords,
2236 SpvImageOperandsSampleMask,
2237 sample,
2238 out);
2239 }
2240 break;
2241 }
2242 case kTexture_SpecialIntrinsic: {
2243 SpvOp_ op = SpvOpImageSampleImplicitLod;
2244 const Type& arg1Type = arguments[1]->type();
2245 switch (arguments[0]->type().dimensions()) {
2246 case SpvDim1D:
2247 if (arg1Type.matches(*fContext.fTypes.fFloat2)) {
2248 op = SpvOpImageSampleProjImplicitLod;
2249 } else {
2250 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat));
2251 }
2252 break;
2253 case SpvDim2D:
2254 if (arg1Type.matches(*fContext.fTypes.fFloat3)) {
2255 op = SpvOpImageSampleProjImplicitLod;
2256 } else {
2257 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat2));
2258 }
2259 break;
2260 case SpvDim3D:
2261 if (arg1Type.matches(*fContext.fTypes.fFloat4)) {
2262 op = SpvOpImageSampleProjImplicitLod;
2263 } else {
2264 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat3));
2265 }
2266 break;
2267 case SpvDimCube: // fall through
2268 case SpvDimRect: // fall through
2269 case SpvDimBuffer: // fall through
2270 case SpvDimSubpassData:
2271 break;
2272 }
2273 SpvId type = this->getType(callType);
2274 SpvId sampler = this->writeExpression(*arguments[0], out);
2275 SpvId uv = this->writeExpression(*arguments[1], out);
2276 if (arguments.size() == 3) {
2277 this->writeInstruction(op, type, result, sampler, uv,
2278 SpvImageOperandsBiasMask,
2279 this->writeExpression(*arguments[2], out),
2280 out);
2281 } else {
2282 SkASSERT(arguments.size() == 2);
2283 if (fProgram.fConfig->fSettings.fSharpenTextures) {
2284 SpvId lodBias = this->writeLiteral(kSharpenTexturesBias,
2285 *fContext.fTypes.fFloat);
2286 this->writeInstruction(op, type, result, sampler, uv,
2287 SpvImageOperandsBiasMask, lodBias, out);
2288 } else {
2289 this->writeInstruction(op, type, result, sampler, uv,
2290 out);
2291 }
2292 }
2293 break;
2294 }
2295 case kTextureGrad_SpecialIntrinsic: {
2296 SpvOp_ op = SpvOpImageSampleExplicitLod;
2297 SkASSERT(arguments.size() == 4);
2298 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2299 SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fFloat2));
2300 SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fFloat2));
2301 SkASSERT(arguments[3]->type().matches(*fContext.fTypes.fFloat2));
2302 SpvId type = this->getType(callType);
2303 SpvId sampler = this->writeExpression(*arguments[0], out);
2304 SpvId uv = this->writeExpression(*arguments[1], out);
2305 SpvId dPdx = this->writeExpression(*arguments[2], out);
2306 SpvId dPdy = this->writeExpression(*arguments[3], out);
2307 this->writeInstruction(op, type, result, sampler, uv, SpvImageOperandsGradMask,
2308 dPdx, dPdy, out);
2309 break;
2310 }
2311 case kTextureLod_SpecialIntrinsic: {
2312 SpvOp_ op = SpvOpImageSampleExplicitLod;
2313 SkASSERT(arguments.size() == 3);
2314 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2315 SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fFloat));
2316 const Type& arg1Type = arguments[1]->type();
2317 if (arg1Type.matches(*fContext.fTypes.fFloat3)) {
2318 op = SpvOpImageSampleProjExplicitLod;
2319 } else {
2320 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat2));
2321 }
2322 SpvId type = this->getType(callType);
2323 SpvId sampler = this->writeExpression(*arguments[0], out);
2324 SpvId uv = this->writeExpression(*arguments[1], out);
2325 this->writeInstruction(op, type, result, sampler, uv,
2326 SpvImageOperandsLodMask,
2327 this->writeExpression(*arguments[2], out),
2328 out);
2329 break;
2330 }
2331 case kTextureRead_SpecialIntrinsic: {
2332 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2333 SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fUInt2));
2334
2335 SpvId type = this->getType(callType);
2336 SpvId image = this->writeExpression(*arguments[0], out);
2337 SpvId coord = this->writeExpression(*arguments[1], out);
2338
2339 const Type& arg0Type = arguments[0]->type();
2340 SkASSERT(arg0Type.typeKind() == Type::TypeKind::kTexture);
2341
2342 switch (arg0Type.textureAccess()) {
2343 case Type::TextureAccess::kSample:
2344 this->writeInstruction(SpvOpImageFetch, type, result, image, coord,
2345 SpvImageOperandsLodMask,
2346 this->writeOpConstant(*fContext.fTypes.fInt, 0),
2347 out);
2348 break;
2349 case Type::TextureAccess::kRead:
2350 case Type::TextureAccess::kReadWrite:
2351 this->writeInstruction(SpvOpImageRead, type, result, image, coord, out);
2352 break;
2353 case Type::TextureAccess::kWrite:
2354 default:
2355 SkDEBUGFAIL("'textureRead' called on writeonly texture type");
2356 break;
2357 }
2358
2359 break;
2360 }
2361 case kTextureWrite_SpecialIntrinsic: {
2362 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2363 SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fUInt2));
2364 SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fHalf4));
2365
2366 SpvId image = this->writeExpression(*arguments[0], out);
2367 SpvId coord = this->writeExpression(*arguments[1], out);
2368 SpvId texel = this->writeExpression(*arguments[2], out);
2369
2370 this->writeInstruction(SpvOpImageWrite, image, coord, texel, out);
2371 break;
2372 }
2373 case kTextureWidth_SpecialIntrinsic:
2374 case kTextureHeight_SpecialIntrinsic: {
2375 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2376 fCapabilities |= 1ULL << SpvCapabilityImageQuery;
2377
2378 SpvId dimsType = this->getType(*fContext.fTypes.fUInt2);
2379 SpvId dims = this->nextId(nullptr);
2380 SpvId image = this->writeExpression(*arguments[0], out);
2381 this->writeInstruction(SpvOpImageQuerySize, dimsType, dims, image, out);
2382
2383 SpvId type = this->getType(callType);
2384 int32_t index = (kind == kTextureWidth_SpecialIntrinsic) ? 0 : 1;
2385 this->writeInstruction(SpvOpCompositeExtract, type, result, dims, index, out);
2386 break;
2387 }
2388 case kMod_SpecialIntrinsic: {
2389 TArray<SpvId> args = this->vectorize(arguments, out);
2390 SkASSERT(args.size() == 2);
2391 const Type& operandType = arguments[0]->type();
2392 SpvOp_ op = pick_by_type(operandType, SpvOpFMod, SpvOpSMod, SpvOpUMod, SpvOpUndef);
2393 SkASSERT(op != SpvOpUndef);
2394 this->writeOpCode(op, 5, out);
2395 this->writeWord(this->getType(operandType), out);
2396 this->writeWord(result, out);
2397 this->writeWord(args[0], out);
2398 this->writeWord(args[1], out);
2399 break;
2400 }
2401 case kDFdy_SpecialIntrinsic: {
2402 SpvId fn = this->writeExpression(*arguments[0], out);
2403 this->writeOpCode(SpvOpDPdy, 4, out);
2404 this->writeWord(this->getType(callType), out);
2405 this->writeWord(result, out);
2406 this->writeWord(fn, out);
2407 if (!fProgram.fConfig->fSettings.fForceNoRTFlip) {
2408 this->addRTFlipUniform(c.fPosition);
2409 ComponentArray componentArray;
2410 for (int index = 0; index < callType.columns(); ++index) {
2411 componentArray.push_back(SwizzleComponent::Y);
2412 }
2413 SpvId rtFlipY = this->writeSwizzle(*this->identifier(SKSL_RTFLIP_NAME),
2414 componentArray, out);
2415 SpvId flipped = this->nextId(&callType);
2416 this->writeInstruction(SpvOpFMul, this->getType(callType), flipped, result,
2417 rtFlipY, out);
2418 result = flipped;
2419 }
2420 break;
2421 }
2422 case kClamp_SpecialIntrinsic: {
2423 TArray<SpvId> args = this->vectorize(arguments, out);
2424 SkASSERT(args.size() == 3);
2425 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
2426 GLSLstd450UClamp, args, out);
2427 break;
2428 }
2429 case kMax_SpecialIntrinsic: {
2430 TArray<SpvId> args = this->vectorize(arguments, out);
2431 SkASSERT(args.size() == 2);
2432 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMax, GLSLstd450SMax,
2433 GLSLstd450UMax, args, out);
2434 break;
2435 }
2436 case kMin_SpecialIntrinsic: {
2437 TArray<SpvId> args = this->vectorize(arguments, out);
2438 SkASSERT(args.size() == 2);
2439 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMin, GLSLstd450SMin,
2440 GLSLstd450UMin, args, out);
2441 break;
2442 }
2443 case kMix_SpecialIntrinsic: {
2444 TArray<SpvId> args = this->vectorize(arguments, out);
2445 SkASSERT(args.size() == 3);
2446 if (arguments[2]->type().componentType().isBoolean()) {
2447 // Use OpSelect to implement Boolean mix().
2448 SpvId falseId = this->writeExpression(*arguments[0], out);
2449 SpvId trueId = this->writeExpression(*arguments[1], out);
2450 SpvId conditionId = this->writeExpression(*arguments[2], out);
2451 this->writeInstruction(SpvOpSelect, this->getType(arguments[0]->type()), result,
2452 conditionId, trueId, falseId, out);
2453 } else {
2454 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMix, SpvOpUndef,
2455 SpvOpUndef, args, out);
2456 }
2457 break;
2458 }
2459 case kSaturate_SpecialIntrinsic: {
2460 SkASSERT(arguments.size() == 1);
2461 int width = arguments[0]->type().columns();
2462 STArray<3, SpvId> spvArgs{
2463 this->vectorize(*arguments[0], width, out),
2464 this->vectorize(*Literal::MakeFloat(fContext, Position(), /*value=*/0), width, out),
2465 this->vectorize(*Literal::MakeFloat(fContext, Position(), /*value=*/1), width, out),
2466 };
2467 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
2468 GLSLstd450UClamp, spvArgs, out);
2469 break;
2470 }
2471 case kSmoothStep_SpecialIntrinsic: {
2472 TArray<SpvId> args = this->vectorize(arguments, out);
2473 SkASSERT(args.size() == 3);
2474 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450SmoothStep, SpvOpUndef,
2475 SpvOpUndef, args, out);
2476 break;
2477 }
2478 case kStep_SpecialIntrinsic: {
2479 TArray<SpvId> args = this->vectorize(arguments, out);
2480 SkASSERT(args.size() == 2);
2481 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450Step, SpvOpUndef,
2482 SpvOpUndef, args, out);
2483 break;
2484 }
2485 case kMatrixCompMult_SpecialIntrinsic: {
2486 SkASSERT(arguments.size() == 2);
2487 SpvId lhs = this->writeExpression(*arguments[0], out);
2488 SpvId rhs = this->writeExpression(*arguments[1], out);
2489 result = this->writeComponentwiseMatrixBinary(callType, lhs, rhs, SpvOpFMul, out);
2490 break;
2491 }
2492 #ifdef SKSL_EXT
2493 case kTextureSize_SpecialIntrinsic: {
2494 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2495
2496 fCapabilities |= 1ULL << SpvCapabilityImageQuery;
2497
2498 SpvId dimsType = this->getType(*fContext.fTypes.fInt2);
2499 SpvId sampledImage = this->writeExpression(*arguments[0], out);
2500 SpvId image = this->nextId(nullptr);
2501 SpvId imageType = this->getType(arguments[0]->type().textureType());
2502 SpvId lod = this->writeExpression(*arguments[1], out);
2503 this->writeInstruction(SpvOpImage, imageType, image, sampledImage, out);
2504 this->writeInstruction(SpvOpImageQuerySizeLod, dimsType, result, image, lod, out);
2505 break;
2506 }
2507 case kSampleGather_SpecialIntrinsic: {
2508 SpvOp_ op = SpvOpImageGather;
2509 SpvId type = this->getType(callType);
2510 SpvId sampler = this->writeExpression(*arguments[0], out);
2511 SpvId uv = this->writeExpression(*arguments[1], out);
2512 SpvId comp;
2513 // 3 is the number of input parameters for this function.
2514 if (arguments.size() == 3) {
2515 // 2 is the position of the comp parameter, representing the sequence number of the color channel.
2516 comp = this->writeExpression(*arguments[2], out);
2517 } else {
2518 // 2 is the number of input parameter for function.
2519 SkASSERT(arguments.size() == 2);
2520 comp = this->writeLiteral(0, *fContext.fTypes.fInt);
2521 }
2522 this->writeInstruction(op, type, result, sampler, uv, comp, out);
2523 break;
2524 }
2525 case kNonuniformEXT_SpecialIntrinsic: {
2526 fCapabilitiesExt.insert(SpvCapabilityShaderNonUniform);
2527 SpvId dimsType = this->getType(*fContext.fTypes.fUInt);
2528 this->writeInstruction(SpvOpDecorate, result, SpvDecorationNonUniform, fDecorationBuffer);
2529 SpvId lod = this->writeExpression(*arguments[0], out);
2530 fNonUniformSpvId.insert(result);
2531 this->writeInstruction(SpvOpCopyObject, dimsType, result, lod, out);
2532 break;
2533 }
2534 #endif
2535 case kAtomicAdd_SpecialIntrinsic:
2536 case kAtomicLoad_SpecialIntrinsic:
2537 case kAtomicStore_SpecialIntrinsic:
2538 result = this->writeAtomicIntrinsic(c, kind, result, out);
2539 break;
2540 case kStorageBarrier_SpecialIntrinsic:
2541 case kWorkgroupBarrier_SpecialIntrinsic: {
2542 // Both barrier types operate in the workgroup execution and memory scope and differ
2543 // only in memory semantics. storageBarrier() is not a device-scope barrier.
2544 SpvId scopeId =
2545 this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)SpvScopeWorkgroup);
2546 int32_t memSemMask = (kind == kStorageBarrier_SpecialIntrinsic)
2547 ? SpvMemorySemanticsAcquireReleaseMask |
2548 SpvMemorySemanticsUniformMemoryMask
2549 : SpvMemorySemanticsAcquireReleaseMask |
2550 SpvMemorySemanticsWorkgroupMemoryMask;
2551 SpvId memorySemanticsId = this->writeOpConstant(*fContext.fTypes.fUInt, memSemMask);
2552 this->writeInstruction(SpvOpControlBarrier,
2553 scopeId, // execution scope
2554 scopeId, // memory scope
2555 memorySemanticsId,
2556 out);
2557 break;
2558 }
2559 }
2560 return result;
2561 }
2562
// Emits the SPIR-V instruction for the atomicAdd / atomicLoad / atomicStore intrinsics.
//
// arguments[0] of `c` must produce an lvalue with a pointer (the atomic object); for add and store,
// arguments[1] is the operand value. The instruction's result is written to `resultId`, which is
// also the return value (the store form has no result, but `resultId` is still returned).
// Returns NA, after a debug failure, if the first argument has no pointer or lives in a storage
// class other than uniform / storage buffer / workgroup.
SpvId SPIRVCodeGenerator::writeAtomicIntrinsic(const FunctionCall& c,
                                               SpecialIntrinsic kind,
                                               SpvId resultId,
                                               OutputStream& out) {
    const ExpressionArray& arguments = c.arguments();
    SkASSERT(!arguments.empty());

    // Atomic instructions operate on a pointer, so resolve the first argument as an lvalue.
    std::unique_ptr<LValue> atomicPtr = this->getLValue(*arguments[0], out);
    SpvId atomicPtrId = atomicPtr->getPointer();
    if (atomicPtrId == NA) {
        SkDEBUGFAILF("atomic intrinsic expected a pointer argument: %s",
                     arguments[0]->description().c_str());
        return NA;
    }

    SpvId memoryScopeId = NA;
    {
        // In SkSL, the atomicUint type can only be declared as a workgroup variable or SSBO block
        // member. The two memory scopes that these map to are "workgroup" and "device",
        // respectively.
        SpvScope memoryScope;
        switch (atomicPtr->storageClass()) {
            case StorageClass::kUniform:
            case StorageClass::kStorageBuffer:
                // We encode storage buffers in the uniform address space (with the BufferBlock
                // decorator).
                memoryScope = SpvScopeDevice;
                break;
            case StorageClass::kWorkgroup:
                memoryScope = SpvScopeWorkgroup;
                break;
            default:
                SkDEBUGFAILF("atomic argument has invalid storage class: %d",
                             get_storage_class_spv_id(atomicPtr->storageClass()));
                return NA;
        }
        memoryScopeId = this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)memoryScope);
    }

    // Every atomic emitted here uses memory semantics "None" (relaxed ordering).
    SpvId relaxedMemoryOrderId =
            this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)SpvMemorySemanticsMaskNone);

    // NOTE(review): below, getType() and writeExpression() are evaluated as arguments of the same
    // writeInstruction() call; C++ leaves their relative order unspecified, so any ids they
    // allocate may be numbered differently across compilers. Confirm this is acceptable.
    switch (kind) {
        case kAtomicAdd_SpecialIntrinsic:
            SkASSERT(arguments.size() == 2);
            this->writeInstruction(SpvOpAtomicIAdd,
                                   this->getType(c.type()),
                                   resultId,
                                   atomicPtrId,
                                   memoryScopeId,
                                   relaxedMemoryOrderId,
                                   this->writeExpression(*arguments[1], out),
                                   out);
            break;
        case kAtomicLoad_SpecialIntrinsic:
            SkASSERT(arguments.size() == 1);
            this->writeInstruction(SpvOpAtomicLoad,
                                   this->getType(c.type()),
                                   resultId,
                                   atomicPtrId,
                                   memoryScopeId,
                                   relaxedMemoryOrderId,
                                   out);
            break;
        case kAtomicStore_SpecialIntrinsic:
            SkASSERT(arguments.size() == 2);
            // OpAtomicStore produces no value, so no result type or result id is written.
            this->writeInstruction(SpvOpAtomicStore,
                                   atomicPtrId,
                                   memoryScopeId,
                                   relaxedMemoryOrderId,
                                   this->writeExpression(*arguments[1], out),
                                   out);
            break;
        default:
            SkUNREACHABLE;
    }

    return resultId;
}
2642
// Appends to `argumentList` the SpvId(s) to pass for argument `argIndex` of `call`.
//
// - Arguments whose parameter is specialized (per `specializedParams`) are dropped, except a
//   sampler in texture/sampler-pair mode, which is handled by the opaque-handle path below.
// - Opaque handles (sampler / separate sampler / texture variable references) are passed as the
//   variable's existing SpvId. In texture/sampler-pair mode, a sampler expands to its synthesized
//   texture id, followed by its sampler id unless the parameter is specialized.
// - out/inout arguments go through a fresh temporary. The temporary and the argument's lvalue are
//   recorded in `tempVars` so the caller can copy the value back after the call; an inout argument
//   is loaded into the temporary first.
// - Non-out arguments to intrinsics are passed by value.
// - All other arguments are stored into a fresh Function-storage OpVariable and passed by pointer.
void SPIRVCodeGenerator::writeFunctionCallArgument(TArray<SpvId>& argumentList,
                                                   const FunctionCall& call,
                                                   int argIndex,
                                                   std::vector<TempVar>* tempVars,
                                                   const SkBitSet* specializedParams,
                                                   OutputStream& out) {
    const FunctionDeclaration& funcDecl = call.function();
    const Expression& arg = *call.arguments()[argIndex];
    const Variable* param = funcDecl.parameters()[argIndex];
    bool paramIsSpecialized = specializedParams && specializedParams->test(argIndex);
    ModifierFlags paramFlags = param->modifierFlags();

    // Specialized arguments are omitted from the call. The exception is a sampler when
    // fUseTextureSamplerPairs is true: its texture half is still passed, so dropping only the
    // sampler half is done below when the texture and sampler pair arguments are generated.
    if (paramIsSpecialized && !(param->type().isSampler() && fUseTextureSamplerPairs)) {
        return;
    }

    if (arg.is<VariableReference>() && (arg.type().typeKind() == Type::TypeKind::kSampler ||
                                        arg.type().typeKind() == Type::TypeKind::kSeparateSampler ||
                                        arg.type().typeKind() == Type::TypeKind::kTexture)) {
        // Opaque handle (sampler/texture) arguments are always declared as pointers but never
        // stored in intermediates when calling user-defined functions.
        //
        // The case for intrinsics (which take opaque arguments by value) is handled above just like
        // regular pointers.
        // NOTE(review): within this function, the by-value intrinsic path is the isIntrinsic()
        // branch further below, not above; confirm which code this comment refers to.
        //
        // See getFunctionParameterType for further explanation.
        const Variable* var = arg.as<VariableReference>().variable();

        // In Dawn-mode the texture and sampler arguments are forwarded to the helper function.
        if (fUseTextureSamplerPairs && var->type().isSampler()) {
            if (const auto* p = fSynthesizedSamplerMap.find(var)) {
                SpvId* img = fVariableMap.find((*p)->fTexture.get());
                SkASSERT(img);

                // The texture half is always passed.
                argumentList.push_back(*img);

                // The sampler half is passed only when the parameter is not specialized.
                if (!paramIsSpecialized) {
                    SpvId* sampler = fVariableMap.find((*p)->fSampler.get());
                    SkASSERT(sampler);
                    argumentList.push_back(*sampler);
                }
                return;
            }
            SkDEBUGFAIL("sampler missing from fSynthesizedSamplerMap");
        }

        SpvId* entry = fVariableMap.find(var);
        SkASSERTF(entry, "%s", arg.description().c_str());
        argumentList.push_back(*entry);
        return;
    }
    SkASSERT(!paramIsSpecialized);

    // ID of temporary variable that we will use to hold this argument, or 0 if it is being
    // passed directly
    SpvId tmpVar = NA;
    // if we need a temporary var to store this argument, this is the value to store in the var
    SpvId tmpValueId = NA;

    if (is_out(paramFlags)) {
        std::unique_ptr<LValue> lv = this->getLValue(arg, out);
        // We handle out params with a temp var that we copy back to the original variable at the
        // end of the call. GLSL guarantees that the original variable will be unchanged until the
        // end of the call, and also that out params are written back to their original variables in
        // a specific order (left-to-right), so it's unsafe to pass a pointer to the original value.
        if (is_in(paramFlags)) {
            // inout: seed the temporary with the argument's current value.
            tmpValueId = lv->load(out);
        }
        tmpVar = this->nextId(&arg.type());
        tempVars->push_back(TempVar{tmpVar, &arg.type(), std::move(lv)});
    } else if (funcDecl.isIntrinsic()) {
        // Unlike user function calls, non-out intrinsic arguments don't need pointer parameters.
        argumentList.push_back(this->writeExpression(arg, out));
        return;
    } else {
        // We always use pointer parameters when calling user functions.
        // See getFunctionParameterType for further explanation.
        tmpValueId = this->writeExpression(arg, out);
        tmpVar = this->nextId(nullptr);
    }
    // The temporary is declared in fVariableBuffer; only the optional initializing store is
    // written to `out`.
    this->writeInstruction(SpvOpVariable,
                           this->getPointerType(arg.type(), StorageClass::kFunction),
                           tmpVar,
                           SpvStorageClassFunction,
                           fVariableBuffer);
    if (tmpValueId != NA) {
        this->writeOpStore(StorageClass::kFunction, tmpVar, tmpValueId, out);
    }
    argumentList.push_back(tmpVar);
}
2736
copyBackTempVars(const std::vector<TempVar> & tempVars,OutputStream & out)2737 void SPIRVCodeGenerator::copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out) {
2738 for (const TempVar& tempVar : tempVars) {
2739 SpvId load = this->nextId(tempVar.type);
2740 this->writeInstruction(SpvOpLoad, this->getType(*tempVar.type), load, tempVar.spvId, out);
2741 tempVar.lvalue->store(load, out);
2742 }
2743 }
2744
writeFunctionCall(const FunctionCall & c,OutputStream & out)2745 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
2746 // Handle intrinsics.
2747 const FunctionDeclaration& function = c.function();
2748 if (function.isIntrinsic() && !function.definition()) {
2749 return this->writeIntrinsicCall(c, out);
2750 }
2751
2752 // Look up this function (or its specialization, if any) in our map of function SpvIds.
2753 Analysis::SpecializationIndex specializationIndex = Analysis::FindSpecializationIndexForCall(
2754 c, fSpecializationInfo, fActiveSpecializationIndex);
2755 SpvId* entry = fFunctionMap.find({&function, specializationIndex});
2756 if (!entry) {
2757 fContext.fErrors->error(c.fPosition, "function '" + function.description() +
2758 "' is not defined");
2759 return NA;
2760 }
2761
2762 // If we are calling a specialized function, we need to gather the specialized parameters
2763 // so we can remove them from the argument list.
2764 SkBitSet specializedParams =
2765 Analysis::FindSpecializedParametersForFunction(c.function(), fSpecializationInfo);
2766
2767 // Temp variables are used to write back out-parameters after the function call is complete.
2768 const ExpressionArray& arguments = c.arguments();
2769 std::vector<TempVar> tempVars;
2770 TArray<SpvId> argumentIds;
2771 argumentIds.reserve_exact(arguments.size());
2772 for (int i = 0; i < arguments.size(); i++) {
2773 this->writeFunctionCallArgument(argumentIds, c, i, &tempVars, &specializedParams, out);
2774 }
2775 SpvId result = this->nextId(nullptr);
2776 this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t)argumentIds.size(), out);
2777 this->writeWord(this->getType(c.type()), out);
2778 this->writeWord(result, out);
2779 this->writeWord(*entry, out);
2780 for (SpvId id : argumentIds) {
2781 this->writeWord(id, out);
2782 }
2783 // Now that the call is complete, we copy temp out-variables back to their real lvalues.
2784 this->copyBackTempVars(tempVars, out);
2785 return result;
2786 }
2787
castScalarToType(SpvId inputExprId,const Type & inputType,const Type & outputType,OutputStream & out)2788 SpvId SPIRVCodeGenerator::castScalarToType(SpvId inputExprId,
2789 const Type& inputType,
2790 const Type& outputType,
2791 OutputStream& out) {
2792 if (outputType.isFloat()) {
2793 return this->castScalarToFloat(inputExprId, inputType, outputType, out);
2794 }
2795 if (outputType.isSigned()) {
2796 return this->castScalarToSignedInt(inputExprId, inputType, outputType, out);
2797 }
2798 if (outputType.isUnsigned()) {
2799 return this->castScalarToUnsignedInt(inputExprId, inputType, outputType, out);
2800 }
2801 if (outputType.isBoolean()) {
2802 return this->castScalarToBoolean(inputExprId, inputType, outputType, out);
2803 }
2804
2805 fContext.fErrors->error(Position(), "unsupported cast: " + inputType.description() + " to " +
2806 outputType.description());
2807 return inputExprId;
2808 }
2809
castScalarToFloat(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2810 SpvId SPIRVCodeGenerator::castScalarToFloat(SpvId inputId, const Type& inputType,
2811 const Type& outputType, OutputStream& out) {
2812 // Casting a float to float is a no-op.
2813 if (inputType.isFloat()) {
2814 return inputId;
2815 }
2816
2817 // Given the input type, generate the appropriate instruction to cast to float.
2818 SpvId result = this->nextId(&outputType);
2819 if (inputType.isBoolean()) {
2820 // Use OpSelect to convert the boolean argument to a literal 1.0 or 0.0.
2821 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fFloat);
2822 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
2823 this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2824 inputId, oneID, zeroID, out);
2825 } else if (inputType.isSigned()) {
2826 this->writeInstruction(SpvOpConvertSToF, this->getType(outputType), result, inputId, out);
2827 } else if (inputType.isUnsigned()) {
2828 this->writeInstruction(SpvOpConvertUToF, this->getType(outputType), result, inputId, out);
2829 } else {
2830 SkDEBUGFAILF("unsupported type for float typecast: %s", inputType.description().c_str());
2831 return NA;
2832 }
2833 return result;
2834 }
2835
castScalarToSignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2836 SpvId SPIRVCodeGenerator::castScalarToSignedInt(SpvId inputId, const Type& inputType,
2837 const Type& outputType, OutputStream& out) {
2838 // Casting a signed int to signed int is a no-op.
2839 if (inputType.isSigned()) {
2840 return inputId;
2841 }
2842
2843 // Given the input type, generate the appropriate instruction to cast to signed int.
2844 SpvId result = this->nextId(&outputType);
2845 if (inputType.isBoolean()) {
2846 // Use OpSelect to convert the boolean argument to a literal 1 or 0.
2847 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fInt);
2848 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
2849 this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2850 inputId, oneID, zeroID, out);
2851 } else if (inputType.isFloat()) {
2852 this->writeInstruction(SpvOpConvertFToS, this->getType(outputType), result, inputId, out);
2853 } else if (inputType.isUnsigned()) {
2854 this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
2855 } else {
2856 SkDEBUGFAILF("unsupported type for signed int typecast: %s",
2857 inputType.description().c_str());
2858 return NA;
2859 }
2860 #ifdef SKSL_EXT
2861 if (fNonUniformSpvId.find(inputId) != fNonUniformSpvId.end()) {
2862 fNonUniformSpvId.insert(result);
2863 this->writeInstruction(SpvOpDecorate, result, SpvDecorationNonUniform, fDecorationBuffer);
2864 }
2865 #endif
2866 return result;
2867 }
2868
castScalarToUnsignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2869 SpvId SPIRVCodeGenerator::castScalarToUnsignedInt(SpvId inputId, const Type& inputType,
2870 const Type& outputType, OutputStream& out) {
2871 // Casting an unsigned int to unsigned int is a no-op.
2872 if (inputType.isUnsigned()) {
2873 return inputId;
2874 }
2875
2876 // Given the input type, generate the appropriate instruction to cast to unsigned int.
2877 SpvId result = this->nextId(&outputType);
2878 if (inputType.isBoolean()) {
2879 // Use OpSelect to convert the boolean argument to a literal 1u or 0u.
2880 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fUInt);
2881 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
2882 this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2883 inputId, oneID, zeroID, out);
2884 } else if (inputType.isFloat()) {
2885 this->writeInstruction(SpvOpConvertFToU, this->getType(outputType), result, inputId, out);
2886 } else if (inputType.isSigned()) {
2887 this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
2888 } else {
2889 SkDEBUGFAILF("unsupported type for unsigned int typecast: %s",
2890 inputType.description().c_str());
2891 return NA;
2892 }
2893 #ifdef SKSL_EXT
2894 if (fNonUniformSpvId.find(inputId) != fNonUniformSpvId.end()) {
2895 fNonUniformSpvId.insert(result);
2896 this->writeInstruction(SpvOpDecorate, result, SpvDecorationNonUniform, fDecorationBuffer);
2897 }
2898 #endif
2899 return result;
2900 }
2901
castScalarToBoolean(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2902 SpvId SPIRVCodeGenerator::castScalarToBoolean(SpvId inputId, const Type& inputType,
2903 const Type& outputType, OutputStream& out) {
2904 // Casting a bool to bool is a no-op.
2905 if (inputType.isBoolean()) {
2906 return inputId;
2907 }
2908
2909 // Given the input type, generate the appropriate instruction to cast to bool.
2910 SpvId result = this->nextId(nullptr);
2911 if (inputType.isSigned()) {
2912 // Synthesize a boolean result by comparing the input against a signed zero literal.
2913 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
2914 this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
2915 inputId, zeroID, out);
2916 } else if (inputType.isUnsigned()) {
2917 // Synthesize a boolean result by comparing the input against an unsigned zero literal.
2918 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
2919 this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
2920 inputId, zeroID, out);
2921 } else if (inputType.isFloat()) {
2922 // Synthesize a boolean result by comparing the input against a floating-point zero literal.
2923 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
2924 this->writeInstruction(SpvOpFUnordNotEqual, this->getType(outputType), result,
2925 inputId, zeroID, out);
2926 } else {
2927 SkDEBUGFAILF("unsupported type for boolean typecast: %s", inputType.description().c_str());
2928 return NA;
2929 }
2930 return result;
2931 }
2932
// Converts the matrix `src` of type `srcType` into a matrix of type `dstType` and returns the new
// matrix's SpvId. The component types must match, and the asserts require them to be
// floating-point.
//
// Column i of the result is:
//   - src column i unchanged, when the row counts match;
//   - src column i extended with identity-matrix entries (1 where row == column, else 0), when dst
//     has more rows;
//   - the first dstType.rows() components of src column i (via OpVectorShuffle), when dst has
//     fewer rows;
//   - an identity-matrix column, when i is past the last src column.
SpvId SPIRVCodeGenerator::writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType,
                                          OutputStream& out) {
    SkASSERT(srcType.isMatrix());
    SkASSERT(dstType.isMatrix());
    SkASSERT(srcType.componentType().matches(dstType.componentType()));
    const Type& srcColumnType = srcType.componentType().toCompound(fContext, srcType.rows(), 1);
    const Type& dstColumnType = dstType.componentType().toCompound(fContext, dstType.rows(), 1);
    SkASSERT(dstType.componentType().isFloat());
    SpvId dstColumnTypeId = this->getType(dstColumnType);
    // Shared literals used to fill padded and synthesized identity entries.
    const SpvId zeroId = this->writeLiteral(0.0, dstType.componentType());
    const SpvId oneId = this->writeLiteral(1.0, dstType.componentType());

    STArray<4, SpvId> columns;
    for (int i = 0; i < dstType.columns(); i++) {
        if (i < srcType.columns()) {
            // we're still inside the src matrix, copy the column
            SpvId srcColumn = this->writeOpCompositeExtract(srcColumnType, src, i, out);
            SpvId dstColumn;
            if (srcType.rows() == dstType.rows()) {
                // columns are equal size, don't need to do anything
                dstColumn = srcColumn;
            }
            else if (dstType.rows() > srcType.rows()) {
                // dst column is bigger, so pad it: the new rows get 1 on the diagonal (i == j)
                // and 0 elsewhere, matching the identity matrix
                STArray<4, SpvId> values;
                values.push_back(srcColumn);
                for (int j = srcType.rows(); j < dstType.rows(); ++j) {
                    values.push_back((i == j) ? oneId : zeroId);
                }
                dstColumn = this->writeOpCompositeConstruct(dstColumnType, values, out);
            }
            else {
                // dst column is smaller, need to swizzle the src column
                // NOTE(review): nextId() receives the matrix type (&dstType) rather than
                // dstColumnType; both share a component type. Confirm this is intended.
                dstColumn = this->nextId(&dstType);
                // OpVectorShuffle word count: opcode, result type, result, two source vectors,
                // plus one component index per destination row.
                this->writeOpCode(SpvOpVectorShuffle, 5 + dstType.rows(), out);
                this->writeWord(dstColumnTypeId, out);
                this->writeWord(dstColumn, out);
                this->writeWord(srcColumn, out);
                this->writeWord(srcColumn, out);
                for (int j = 0; j < dstType.rows(); j++) {
                    this->writeWord(j, out);
                }
            }
            columns.push_back(dstColumn);
        } else {
            // we're past the end of the src matrix, need to synthesize an identity-matrix column
            STArray<4, SpvId> values;
            for (int j = 0; j < dstType.rows(); ++j) {
                values.push_back((i == j) ? oneId : zeroId);
            }
            columns.push_back(this->writeOpCompositeConstruct(dstColumnType, values, out));
        }
    }

    return this->writeOpCompositeConstruct(dstType, columns, out);
}
2989
addColumnEntry(const Type & columnType,TArray<SpvId> * currentColumn,TArray<SpvId> * columnIds,int rows,SpvId entry,OutputStream & out)2990 void SPIRVCodeGenerator::addColumnEntry(const Type& columnType,
2991 TArray<SpvId>* currentColumn,
2992 TArray<SpvId>* columnIds,
2993 int rows,
2994 SpvId entry,
2995 OutputStream& out) {
2996 SkASSERT(currentColumn->size() < rows);
2997 currentColumn->push_back(entry);
2998 if (currentColumn->size() == rows) {
2999 // Synthesize this column into a vector.
3000 SpvId columnId = this->writeOpCompositeConstruct(columnType, *currentColumn, out);
3001 columnIds->push_back(columnId);
3002 currentColumn->clear();
3003 }
3004 }
3005
writeMatrixConstructor(const ConstructorCompound & c,OutputStream & out)3006 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out) {
3007 const Type& type = c.type();
3008 SkASSERT(type.isMatrix());
3009 SkASSERT(!c.arguments().empty());
3010 const Type& arg0Type = c.arguments()[0]->type();
3011 // go ahead and write the arguments so we don't try to write new instructions in the middle of
3012 // an instruction
3013 STArray<16, SpvId> arguments;
3014 for (const std::unique_ptr<Expression>& arg : c.arguments()) {
3015 arguments.push_back(this->writeExpression(*arg, out));
3016 }
3017
3018 if (arguments.size() == 1 && arg0Type.isVector()) {
3019 // Special-case handling of float4 -> mat2x2.
3020 SkASSERT(type.rows() == 2 && type.columns() == 2);
3021 SkASSERT(arg0Type.columns() == 4);
3022 SpvId v[4];
3023 for (int i = 0; i < 4; ++i) {
3024 v[i] = this->writeOpCompositeExtract(type.componentType(), arguments[0], i, out);
3025 }
3026 const Type& vecType = type.columnType(fContext);
3027 SpvId v0v1 = this->writeOpCompositeConstruct(vecType, {v[0], v[1]}, out);
3028 SpvId v2v3 = this->writeOpCompositeConstruct(vecType, {v[2], v[3]}, out);
3029 return this->writeOpCompositeConstruct(type, {v0v1, v2v3}, out);
3030 }
3031
3032 int rows = type.rows();
3033 const Type& columnType = type.columnType(fContext);
3034 // SpvIds of completed columns of the matrix.
3035 STArray<4, SpvId> columnIds;
3036 // SpvIds of scalars we have written to the current column so far.
3037 STArray<4, SpvId> currentColumn;
3038 for (int i = 0; i < arguments.size(); i++) {
3039 const Type& argType = c.arguments()[i]->type();
3040 if (currentColumn.empty() && argType.isVector() && argType.columns() == rows) {
3041 // This vector is a complete matrix column by itself and can be used as-is.
3042 columnIds.push_back(arguments[i]);
3043 } else if (argType.columns() == 1) {
3044 // This argument is a lone scalar and can be added to the current column as-is.
3045 this->addColumnEntry(columnType, ¤tColumn, &columnIds, rows, arguments[i], out);
3046 } else {
3047 // This argument needs to be decomposed into its constituent scalars.
3048 for (int j = 0; j < argType.columns(); ++j) {
3049 SpvId swizzle = this->writeOpCompositeExtract(argType.componentType(),
3050 arguments[i], j, out);
3051 this->addColumnEntry(columnType, ¤tColumn, &columnIds, rows, swizzle, out);
3052 }
3053 }
3054 }
3055 SkASSERT(columnIds.size() == type.columns());
3056 return this->writeOpCompositeConstruct(type, columnIds, out);
3057 }
3058
writeConstructorCompound(const ConstructorCompound & c,OutputStream & out)3059 SpvId SPIRVCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
3060 OutputStream& out) {
3061 return c.type().isMatrix() ? this->writeMatrixConstructor(c, out)
3062 : this->writeVectorConstructor(c, out);
3063 }
3064
writeVectorConstructor(const ConstructorCompound & c,OutputStream & out)3065 SpvId SPIRVCodeGenerator::writeVectorConstructor(const ConstructorCompound& c, OutputStream& out) {
3066 const Type& type = c.type();
3067 const Type& componentType = type.componentType();
3068 SkASSERT(type.isVector());
3069
3070 STArray<4, SpvId> arguments;
3071 for (int i = 0; i < c.arguments().size(); i++) {
3072 const Type& argType = c.arguments()[i]->type();
3073 SkASSERT(componentType.numberKind() == argType.componentType().numberKind());
3074
3075 SpvId arg = this->writeExpression(*c.arguments()[i], out);
3076 if (argType.isMatrix()) {
3077 // CompositeConstruct cannot take a 2x2 matrix as an input, so we need to extract out
3078 // each scalar separately.
3079 SkASSERT(argType.rows() == 2);
3080 SkASSERT(argType.columns() == 2);
3081 for (int j = 0; j < 4; ++j) {
3082 arguments.push_back(this->writeOpCompositeExtract(componentType, arg,
3083 j / 2, j % 2, out));
3084 }
3085 } else if (argType.isVector()) {
3086 // There's a bug in the Intel Vulkan driver where OpCompositeConstruct doesn't handle
3087 // vector arguments at all, so we always extract each vector component and pass them
3088 // into OpCompositeConstruct individually.
3089 for (int j = 0; j < argType.columns(); j++) {
3090 arguments.push_back(this->writeOpCompositeExtract(componentType, arg, j, out));
3091 }
3092 } else {
3093 arguments.push_back(arg);
3094 }
3095 }
3096
3097 return this->writeOpCompositeConstruct(type, arguments, out);
3098 }
3099
writeConstructorSplat(const ConstructorSplat & c,OutputStream & out)3100 SpvId SPIRVCodeGenerator::writeConstructorSplat(const ConstructorSplat& c, OutputStream& out) {
3101 // Write the splat argument as a scalar, then splat it.
3102 SpvId argument = this->writeExpression(*c.argument(), out);
3103 return this->splat(c.type(), argument, out);
3104 }
3105
writeCompositeConstructor(const AnyConstructor & c,OutputStream & out)3106 SpvId SPIRVCodeGenerator::writeCompositeConstructor(const AnyConstructor& c, OutputStream& out) {
3107 SkASSERT(c.type().isArray() || c.type().isStruct());
3108 auto ctorArgs = c.argumentSpan();
3109
3110 STArray<4, SpvId> arguments;
3111 for (const std::unique_ptr<Expression>& arg : ctorArgs) {
3112 arguments.push_back(this->writeExpression(*arg, out));
3113 }
3114
3115 return this->writeOpCompositeConstruct(c.type(), arguments, out);
3116 }
3117
writeConstructorScalarCast(const ConstructorScalarCast & c,OutputStream & out)3118 SpvId SPIRVCodeGenerator::writeConstructorScalarCast(const ConstructorScalarCast& c,
3119 OutputStream& out) {
3120 const Type& type = c.type();
3121 if (type.componentType().numberKind() == c.argument()->type().componentType().numberKind()) {
3122 return this->writeExpression(*c.argument(), out);
3123 }
3124
3125 const Expression& ctorExpr = *c.argument();
3126 SpvId expressionId = this->writeExpression(ctorExpr, out);
3127 return this->castScalarToType(expressionId, ctorExpr.type(), type, out);
3128 }
3129
writeConstructorCompoundCast(const ConstructorCompoundCast & c,OutputStream & out)3130 SpvId SPIRVCodeGenerator::writeConstructorCompoundCast(const ConstructorCompoundCast& c,
3131 OutputStream& out) {
3132 const Type& ctorType = c.type();
3133 const Type& argType = c.argument()->type();
3134 SkASSERT(ctorType.isVector() || ctorType.isMatrix());
3135
3136 // Write the composite that we are casting. If the actual type matches, we are done.
3137 SpvId compositeId = this->writeExpression(*c.argument(), out);
3138 if (ctorType.componentType().numberKind() == argType.componentType().numberKind()) {
3139 return compositeId;
3140 }
3141
3142 // writeMatrixCopy can cast matrices to a different type.
3143 if (ctorType.isMatrix()) {
3144 return this->writeMatrixCopy(compositeId, argType, ctorType, out);
3145 }
3146
3147 // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to extract the
3148 // components and convert each one manually.
3149 const Type& srcType = argType.componentType();
3150 const Type& dstType = ctorType.componentType();
3151
3152 STArray<4, SpvId> arguments;
3153 for (int index = 0; index < argType.columns(); ++index) {
3154 SpvId componentId = this->writeOpCompositeExtract(srcType, compositeId, index, out);
3155 arguments.push_back(this->castScalarToType(componentId, srcType, dstType, out));
3156 }
3157
3158 return this->writeOpCompositeConstruct(ctorType, arguments, out);
3159 }
3160
writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix & c,OutputStream & out)3161 SpvId SPIRVCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
3162 OutputStream& out) {
3163 const Type& type = c.type();
3164 SkASSERT(type.isMatrix());
3165 SkASSERT(c.argument()->type().isScalar());
3166
3167 // Write out the scalar argument.
3168 SpvId diagonal = this->writeExpression(*c.argument(), out);
3169
3170 // Build the diagonal matrix.
3171 SpvId zeroId = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
3172
3173 const Type& vecType = type.columnType(fContext);
3174 STArray<4, SpvId> columnIds;
3175 STArray<4, SpvId> arguments;
3176 arguments.resize(type.rows());
3177 for (int column = 0; column < type.columns(); column++) {
3178 for (int row = 0; row < type.rows(); row++) {
3179 arguments[row] = (row == column) ? diagonal : zeroId;
3180 }
3181 columnIds.push_back(this->writeOpCompositeConstruct(vecType, arguments, out));
3182 }
3183 return this->writeOpCompositeConstruct(type, columnIds, out);
3184 }
3185
writeConstructorMatrixResize(const ConstructorMatrixResize & c,OutputStream & out)3186 SpvId SPIRVCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
3187 OutputStream& out) {
3188 // Write the input matrix.
3189 SpvId argument = this->writeExpression(*c.argument(), out);
3190
3191 // Use matrix-copy to resize the input matrix to its new size.
3192 return this->writeMatrixCopy(argument, c.argument()->type(), c.type(), out);
3193 }
3194
get_storage_class_for_global_variable(const Variable & var,StorageClass fallbackStorageClass)3195 static StorageClass get_storage_class_for_global_variable(
3196 const Variable& var, StorageClass fallbackStorageClass) {
3197 SkASSERT(var.storage() == Variable::Storage::kGlobal);
3198
3199 if (var.type().typeKind() == Type::TypeKind::kSampler ||
3200 #ifdef SKSL_EXT
3201 (var.type().typeKind() == Type::TypeKind::kArray &&
3202 var.type().componentType().typeKind() == Type::TypeKind::kSampler) ||
3203 #endif
3204 var.type().typeKind() == Type::TypeKind::kSeparateSampler ||
3205 var.type().typeKind() == Type::TypeKind::kTexture) {
3206 return StorageClass::kUniformConstant;
3207 }
3208
3209 const Layout& layout = var.layout();
3210 ModifierFlags flags = var.modifierFlags();
3211 #ifdef SKSL_EXT
3212 if (!(flags & ModifierFlag::kUniform) &&
3213 var.storage() == Variable::Storage::kGlobal &&
3214 (flags & ModifierFlag::kConst)) {
3215 return StorageClass::kFunction;
3216 }
3217 #endif
3218 if (flags & ModifierFlag::kIn) {
3219 SkASSERT(!(layout.fFlags & LayoutFlag::kPushConstant));
3220 return StorageClass::kInput;
3221 }
3222 if (flags & ModifierFlag::kOut) {
3223 SkASSERT(!(layout.fFlags & LayoutFlag::kPushConstant));
3224 return StorageClass::kOutput;
3225 }
3226 if (flags.isUniform()) {
3227 if (layout.fFlags & LayoutFlag::kPushConstant) {
3228 return StorageClass::kPushConstant;
3229 }
3230 return StorageClass::kUniform;
3231 }
3232 if (flags.isBuffer()) {
3233 return StorageClass::kStorageBuffer;
3234 }
3235 if (flags.isWorkgroup()) {
3236 return StorageClass::kWorkgroup;
3237 }
3238 return fallbackStorageClass;
3239 }
3240
getStorageClass(const Expression & expr)3241 StorageClass SPIRVCodeGenerator::getStorageClass(const Expression& expr) {
3242 switch (expr.kind()) {
3243 case Expression::Kind::kVariableReference: {
3244 const Variable& var = *expr.as<VariableReference>().variable();
3245 if (fActiveSpecialization) {
3246 const Expression** specializedExpr = fActiveSpecialization->find(&var);
3247 if (specializedExpr && (*specializedExpr)->is<FieldAccess>()) {
3248 return this->getStorageClass(**specializedExpr);
3249 }
3250 }
3251 if (var.storage() != Variable::Storage::kGlobal) {
3252 return StorageClass::kFunction;
3253 }
3254 return get_storage_class_for_global_variable(var, StorageClass::kPrivate);
3255 }
3256 case Expression::Kind::kFieldAccess:
3257 return this->getStorageClass(*expr.as<FieldAccess>().base());
3258 case Expression::Kind::kIndex:
3259 return this->getStorageClass(*expr.as<IndexExpression>().base());
3260 default:
3261 return StorageClass::kFunction;
3262 }
3263 }
3264
getAccessChain(const Expression & expr,OutputStream & out)3265 TArray<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
3266 switch (expr.kind()) {
3267 case Expression::Kind::kIndex: {
3268 const IndexExpression& indexExpr = expr.as<IndexExpression>();
3269 if (indexExpr.base()->is<Swizzle>()) {
3270 // Access chains don't directly support dynamically indexing into a swizzle, but we
3271 // can rewrite them into a supported form.
3272 return this->getAccessChain(*Transform::RewriteIndexedSwizzle(fContext, indexExpr),
3273 out);
3274 }
3275 // All other index-expressions can be represented as typical access chains.
3276 TArray<SpvId> chain = this->getAccessChain(*indexExpr.base(), out);
3277 chain.push_back(this->writeExpression(*indexExpr.index(), out));
3278 return chain;
3279 }
3280 case Expression::Kind::kFieldAccess: {
3281 const FieldAccess& fieldExpr = expr.as<FieldAccess>();
3282 TArray<SpvId> chain = this->getAccessChain(*fieldExpr.base(), out);
3283 chain.push_back(this->writeLiteral(fieldExpr.fieldIndex(), *fContext.fTypes.fInt));
3284 return chain;
3285 }
3286 case Expression::Kind::kVariableReference: {
3287 if (fActiveSpecialization) {
3288 const Expression** specializedFieldIndex =
3289 fActiveSpecialization->find(expr.as<VariableReference>().variable());
3290 if (specializedFieldIndex && (*specializedFieldIndex)->is<FieldAccess>()) {
3291 return this->getAccessChain(**specializedFieldIndex, out);
3292 }
3293 }
3294 [[fallthrough]];
3295 }
3296 default: {
3297 SpvId id = this->getLValue(expr, out)->getPointer();
3298 SkASSERT(id != NA);
3299 return TArray<SpvId>{id};
3300 }
3301 }
3302 SkUNREACHABLE;
3303 }
3304
3305 class PointerLValue : public SPIRVCodeGenerator::LValue {
3306 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,bool isMemoryObject,SpvId type,SPIRVCodeGenerator::Precision precision,StorageClass storageClass)3307 PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, bool isMemoryObject, SpvId type,
3308 SPIRVCodeGenerator::Precision precision, StorageClass storageClass)
3309 : fGen(gen)
3310 , fPointer(pointer)
3311 , fIsMemoryObject(isMemoryObject)
3312 , fType(type)
3313 , fPrecision(precision)
3314 , fStorageClass(storageClass) {}
3315
getPointer()3316 SpvId getPointer() override {
3317 return fPointer;
3318 }
3319
isMemoryObjectPointer() const3320 bool isMemoryObjectPointer() const override {
3321 return fIsMemoryObject;
3322 }
3323
storageClass() const3324 StorageClass storageClass() const override {
3325 return fStorageClass;
3326 }
3327
load(OutputStream & out)3328 SpvId load(OutputStream& out) override {
3329 return fGen.writeOpLoad(fType, fPrecision, fPointer, out);
3330 }
3331
store(SpvId value,OutputStream & out)3332 void store(SpvId value, OutputStream& out) override {
3333 if (!fIsMemoryObject) {
3334 // We are going to write into an access chain; this could represent one component of a
3335 // vector, or one element of an array. This has the potential to invalidate other,
3336 // *unknown* elements of our store cache. (e.g. if the store cache holds `%50 = myVec4`,
3337 // and we store `%60 = myVec4.z`, this invalidates the cached value for %50.) To avoid
3338 // relying on stale data, reset the store cache entirely when this happens.
3339 fGen.fStoreCache.reset();
3340 }
3341
3342 fGen.writeOpStore(fStorageClass, fPointer, value, out);
3343 }
3344
3345 private:
3346 SPIRVCodeGenerator& fGen;
3347 const SpvId fPointer;
3348 const bool fIsMemoryObject;
3349 const SpvId fType;
3350 const SPIRVCodeGenerator::Precision fPrecision;
3351 const StorageClass fStorageClass;
3352 };
3353
3354 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
3355 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const ComponentArray & components,const Type & baseType,const Type & swizzleType,StorageClass storageClass)3356 SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const ComponentArray& components,
3357 const Type& baseType, const Type& swizzleType, StorageClass storageClass)
3358 : fGen(gen)
3359 , fVecPointer(vecPointer)
3360 , fComponents(components)
3361 , fBaseType(&baseType)
3362 , fSwizzleType(&swizzleType)
3363 , fStorageClass(storageClass) {}
3364
applySwizzle(const ComponentArray & components,const Type & newType)3365 bool applySwizzle(const ComponentArray& components, const Type& newType) override {
3366 ComponentArray updatedSwizzle;
3367 for (int8_t component : components) {
3368 if (component < 0 || component >= fComponents.size()) {
3369 SkDEBUGFAILF("swizzle accessed nonexistent component %d", (int)component);
3370 return false;
3371 }
3372 updatedSwizzle.push_back(fComponents[component]);
3373 }
3374 fComponents = updatedSwizzle;
3375 fSwizzleType = &newType;
3376 return true;
3377 }
3378
storageClass() const3379 StorageClass storageClass() const override {
3380 return fStorageClass;
3381 }
3382
load(OutputStream & out)3383 SpvId load(OutputStream& out) override {
3384 SpvId base = fGen.nextId(fBaseType);
3385 fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
3386 SpvId result = fGen.nextId(fBaseType);
3387 fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
3388 fGen.writeWord(fGen.getType(*fSwizzleType), out);
3389 fGen.writeWord(result, out);
3390 fGen.writeWord(base, out);
3391 fGen.writeWord(base, out);
3392 for (int component : fComponents) {
3393 fGen.writeWord(component, out);
3394 }
3395 return result;
3396 }
3397
store(SpvId value,OutputStream & out)3398 void store(SpvId value, OutputStream& out) override {
3399 // use OpVectorShuffle to mix and match the vector components. We effectively create
3400 // a virtual vector out of the concatenation of the left and right vectors, and then
3401 // select components from this virtual vector to make the result vector. For
3402 // instance, given:
3403 // float3L = ...;
3404 // float3R = ...;
3405 // L.xz = R.xy;
3406 // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
3407 // our result vector to look like (R.x, L.y, R.y), so we need to select indices
3408 // (3, 1, 4).
3409 SpvId base = fGen.nextId(fBaseType);
3410 fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
3411 SpvId shuffle = fGen.nextId(fBaseType);
3412 fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType->columns(), out);
3413 fGen.writeWord(fGen.getType(*fBaseType), out);
3414 fGen.writeWord(shuffle, out);
3415 fGen.writeWord(base, out);
3416 fGen.writeWord(value, out);
3417 for (int i = 0; i < fBaseType->columns(); i++) {
3418 // current offset into the virtual vector, defaults to pulling the unmodified
3419 // value from the left side
3420 int offset = i;
3421 // check to see if we are writing this component
3422 for (int j = 0; j < fComponents.size(); j++) {
3423 if (fComponents[j] == i) {
3424 // we're writing to this component, so adjust the offset to pull from
3425 // the correct component of the right side instead of preserving the
3426 // value from the left
3427 offset = (int) (j + fBaseType->columns());
3428 break;
3429 }
3430 }
3431 fGen.writeWord(offset, out);
3432 }
3433 fGen.writeOpStore(fStorageClass, fVecPointer, shuffle, out);
3434 }
3435
3436 private:
3437 SPIRVCodeGenerator& fGen;
3438 const SpvId fVecPointer;
3439 ComponentArray fComponents;
3440 const Type* fBaseType;
3441 const Type* fSwizzleType;
3442 const StorageClass fStorageClass;
3443 };
3444
findUniformFieldIndex(const Variable & var) const3445 int SPIRVCodeGenerator::findUniformFieldIndex(const Variable& var) const {
3446 int* fieldIndex = fTopLevelUniformMap.find(&var);
3447 return fieldIndex ? *fieldIndex : -1;
3448 }
3449
getLValue(const Expression & expr,OutputStream & out)3450 std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
3451 OutputStream& out) {
3452 const Type& type = expr.type();
3453 Precision precision = type.highPrecision() ? Precision::kDefault : Precision::kRelaxed;
3454 switch (expr.kind()) {
3455 case Expression::Kind::kVariableReference: {
3456 const Variable& var = *expr.as<VariableReference>().variable();
3457 int uniformIdx = this->findUniformFieldIndex(var);
3458 if (uniformIdx >= 0) {
3459 // Access uniforms via an AccessChain into the uniform-buffer struct.
3460 SpvId memberId = this->nextId(nullptr);
3461 SpvId typeId = this->getPointerType(type, StorageClass::kUniform);
3462 SpvId uniformIdxId = this->writeLiteral((double)uniformIdx, *fContext.fTypes.fInt);
3463 this->writeInstruction(SpvOpAccessChain, typeId, memberId, fUniformBufferId,
3464 uniformIdxId, out);
3465 return std::make_unique<PointerLValue>(
3466 *this,
3467 memberId,
3468 /*isMemoryObjectPointer=*/true,
3469 this->getType(type, kDefaultTypeLayout, this->memoryLayoutForVariable(var)),
3470 precision,
3471 StorageClass::kUniform);
3472 }
3473 #ifdef SKSL_EXT
3474 if (fGlobalConstVariableValueMap.find(&var) != fGlobalConstVariableValueMap.end()) {
3475 SpvId id = this->nextId(&type);
3476 fVariableMap[&var] = id;
3477 SpvId typeId = this->getPointerType(type, StorageClass::kFunction);
3478 this->writeInstruction(SpvOpVariable, typeId, id, SpvStorageClassFunction, fVariableBuffer);
3479 this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer);
3480 this->writeInstruction(SpvOpStore, id, fGlobalConstVariableValueMap[&var], out);
3481 }
3482 #endif
3483 SpvId* entry = fVariableMap.find(&var);
3484 SkASSERTF(entry, "%s", expr.description().c_str());
3485
3486 if (var.layout().fBuiltin == SK_SAMPLEMASKIN_BUILTIN ||
3487 var.layout().fBuiltin == SK_SAMPLEMASK_BUILTIN) {
3488 // Access sk_SampleMask and sk_SampleMaskIn via an array access, since Vulkan
3489 // represents sample masks as an array of uints.
3490 StorageClass storageClass =
3491 get_storage_class_for_global_variable(var, StorageClass::kPrivate);
3492 SkASSERT(storageClass != StorageClass::kPrivate);
3493 SkASSERT(type.matches(*fContext.fTypes.fUInt));
3494
3495 SpvId accessId = this->nextId(nullptr);
3496 SpvId typeId = this->getPointerType(type, storageClass);
3497 SpvId indexId = this->writeLiteral(0.0, *fContext.fTypes.fInt);
3498 this->writeInstruction(SpvOpAccessChain, typeId, accessId, *entry, indexId, out);
3499 return std::make_unique<PointerLValue>(*this,
3500 accessId,
3501 /*isMemoryObjectPointer=*/true,
3502 this->getType(type),
3503 precision,
3504 storageClass);
3505 }
3506 SpvId typeId = this->getType(type, var.layout(), this->memoryLayoutForVariable(var));
3507 return std::make_unique<PointerLValue>(*this, *entry,
3508 /*isMemoryObjectPointer=*/true,
3509 typeId, precision, this->getStorageClass(expr));
3510 }
3511 case Expression::Kind::kIndex: // fall through
3512 case Expression::Kind::kFieldAccess: {
3513 TArray<SpvId> chain = this->getAccessChain(expr, out);
3514 SpvId member = this->nextId(nullptr);
3515 StorageClass storageClass = this->getStorageClass(expr);
3516 this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
3517 this->writeWord(this->getPointerType(type, storageClass), out);
3518 this->writeWord(member, out);
3519 #ifdef SKSL_EXT
3520 bool needDecorate = false;
3521 for (SpvId idx : chain) {
3522 this->writeWord(idx, out);
3523 needDecorate |= fNonUniformSpvId.find(idx) != fNonUniformSpvId.end();
3524 }
3525 if (needDecorate) {
3526 fNonUniformSpvId.insert(member);
3527 this->writeInstruction(SpvOpDecorate, member, SpvDecorationNonUniform, fDecorationBuffer);
3528 }
3529 #else
3530 for (SpvId idx : chain) {
3531 this->writeWord(idx, out);
3532 }
3533 #endif
3534 return std::make_unique<PointerLValue>(
3535 *this,
3536 member,
3537 /*isMemoryObjectPointer=*/false,
3538 this->getType(type,
3539 kDefaultTypeLayout,
3540 this->memoryLayoutForStorageClass(storageClass)),
3541 precision,
3542 storageClass);
3543 }
3544 case Expression::Kind::kSwizzle: {
3545 const Swizzle& swizzle = expr.as<Swizzle>();
3546 std::unique_ptr<LValue> lvalue = this->getLValue(*swizzle.base(), out);
3547 if (lvalue->applySwizzle(swizzle.components(), type)) {
3548 return lvalue;
3549 }
3550 SpvId base = lvalue->getPointer();
3551 if (base == NA) {
3552 fContext.fErrors->error(swizzle.fPosition,
3553 "unable to retrieve lvalue from swizzle");
3554 }
3555 StorageClass storageClass = this->getStorageClass(*swizzle.base());
3556 if (swizzle.components().size() == 1) {
3557 SpvId member = this->nextId(nullptr);
3558 SpvId typeId = this->getPointerType(type, storageClass);
3559 SpvId indexId = this->writeLiteral(swizzle.components()[0], *fContext.fTypes.fInt);
3560 this->writeInstruction(SpvOpAccessChain, typeId, member, base, indexId, out);
3561 return std::make_unique<PointerLValue>(*this, member,
3562 /*isMemoryObjectPointer=*/false,
3563 this->getType(type),
3564 precision, storageClass);
3565 } else {
3566 return std::make_unique<SwizzleLValue>(*this, base, swizzle.components(),
3567 swizzle.base()->type(), type, storageClass);
3568 }
3569 }
3570 default: {
3571 // expr isn't actually an lvalue, create a placeholder variable for it. This case
3572 // happens due to the need to store values in temporary variables during function
3573 // calls (see comments in getFunctionParameterType); erroneous uses of rvalues as
3574 // lvalues should have been caught before code generation.
3575 //
3576 // This is with the exception of opaque handle types (textures/samplers) which are
3577 // always defined as UniformConstant pointers and don't need to be explicitly stored
3578 // into a temporary (which is handled explicitly in writeFunctionCallArgument).
3579 SpvId result = this->nextId(nullptr);
3580 SpvId pointerType = this->getPointerType(type, StorageClass::kFunction);
3581 this->writeInstruction(SpvOpVariable, pointerType, result, SpvStorageClassFunction,
3582 fVariableBuffer);
3583 this->writeOpStore(StorageClass::kFunction, result, this->writeExpression(expr, out),
3584 out);
3585 return std::make_unique<PointerLValue>(*this, result, /*isMemoryObjectPointer=*/true,
3586 this->getType(type), precision,
3587 StorageClass::kFunction);
3588 }
3589 }
3590 }
3591
identifier(std::string_view name)3592 std::unique_ptr<Expression> SPIRVCodeGenerator::identifier(std::string_view name) {
3593 std::unique_ptr<Expression> expr =
3594 fProgram.fSymbols->instantiateSymbolRef(fContext, name, Position());
3595 return expr ? std::move(expr)
3596 : Poison::Make(Position(), fContext);
3597 }
3598
writeVariableReference(const VariableReference & ref,OutputStream & out)3599 SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
3600 const Variable* variable = ref.variable();
3601 switch (variable->layout().fBuiltin) {
3602 case DEVICE_FRAGCOORDS_BUILTIN: {
3603 // Down below, we rewrite raw references to sk_FragCoord with expressions that reference
3604 // DEVICE_FRAGCOORDS_BUILTIN. This is a fake variable that means we need to directly
3605 // access the fragcoord; do so now.
3606 return this->getLValue(*this->identifier("sk_FragCoord"), out)->load(out);
3607 }
3608 case DEVICE_CLOCKWISE_BUILTIN: {
3609 // Down below, we rewrite raw references to sk_Clockwise with expressions that reference
3610 // DEVICE_CLOCKWISE_BUILTIN. This is a fake variable that means we need to directly
3611 // access front facing; do so now.
3612 return this->getLValue(*this->identifier("sk_Clockwise"), out)->load(out);
3613 }
3614 case SK_SECONDARYFRAGCOLOR_BUILTIN: {
3615 if (fCaps.fDualSourceBlendingSupport) {
3616 return this->getLValue(*this->identifier("sk_SecondaryFragColor"), out)->load(out);
3617 } else {
3618 fContext.fErrors->error(ref.position(), "'sk_SecondaryFragColor' not supported");
3619 return NA;
3620 }
3621 }
3622 case SK_FRAGCOORD_BUILTIN: {
3623 if (fProgram.fConfig->fSettings.fForceNoRTFlip) {
3624 return this->getLValue(*this->identifier("sk_FragCoord"), out)->load(out);
3625 }
3626
3627 // Handle inserting use of uniform to flip y when referencing sk_FragCoord.
3628 this->addRTFlipUniform(ref.fPosition);
3629 // Use sk_RTAdjust to compute the flipped coordinate
3630 // Use a uniform to flip the Y coordinate. The new expression will be written in
3631 // terms of $device_FragCoords, which is a fake variable that means "access the
3632 // underlying fragcoords directly without flipping it".
3633 static constexpr char DEVICE_COORDS_NAME[] = "$device_FragCoords";
3634 if (!fProgram.fSymbols->find(DEVICE_COORDS_NAME)) {
3635 AutoAttachPoolToThread attach(fProgram.fPool.get());
3636 Layout layout;
3637 layout.fBuiltin = DEVICE_FRAGCOORDS_BUILTIN;
3638 auto coordsVar = Variable::Make(/*pos=*/Position(),
3639 /*modifiersPosition=*/Position(),
3640 layout,
3641 ModifierFlag::kNone,
3642 fContext.fTypes.fFloat4.get(),
3643 DEVICE_COORDS_NAME,
3644 /*mangledName=*/"",
3645 /*builtin=*/true,
3646 Variable::Storage::kGlobal);
3647 fProgram.fSymbols->add(fContext, std::move(coordsVar));
3648 }
3649 std::unique_ptr<Expression> deviceCoord = this->identifier(DEVICE_COORDS_NAME);
3650 std::unique_ptr<Expression> rtFlip = this->identifier(SKSL_RTFLIP_NAME);
3651 SpvId rtFlipX = this->writeSwizzle(*rtFlip, {SwizzleComponent::X}, out);
3652 SpvId rtFlipY = this->writeSwizzle(*rtFlip, {SwizzleComponent::Y}, out);
3653 SpvId deviceCoordX = this->writeSwizzle(*deviceCoord, {SwizzleComponent::X}, out);
3654 SpvId deviceCoordY = this->writeSwizzle(*deviceCoord, {SwizzleComponent::Y}, out);
3655 SpvId deviceCoordZW = this->writeSwizzle(*deviceCoord, {SwizzleComponent::Z,
3656 SwizzleComponent::W}, out);
3657 // Compute `flippedY = u_RTFlip.y * $device_FragCoords.y`.
3658 SpvId flippedY = this->writeBinaryExpression(
3659 *fContext.fTypes.fFloat, rtFlipY, OperatorKind::STAR,
3660 *fContext.fTypes.fFloat, deviceCoordY,
3661 *fContext.fTypes.fFloat, out);
3662
3663 // Compute `flippedY = u_RTFlip.x + flippedY`.
3664 flippedY = this->writeBinaryExpression(
3665 *fContext.fTypes.fFloat, rtFlipX, OperatorKind::PLUS,
3666 *fContext.fTypes.fFloat, flippedY,
3667 *fContext.fTypes.fFloat, out);
3668
3669 // Return `float4(deviceCoord.x, flippedY, deviceCoord.zw)`.
3670 return this->writeOpCompositeConstruct(*fContext.fTypes.fFloat4,
3671 {deviceCoordX, flippedY, deviceCoordZW},
3672 out);
3673 }
3674 case SK_CLOCKWISE_BUILTIN: {
3675 if (fProgram.fConfig->fSettings.fForceNoRTFlip) {
3676 return this->getLValue(*this->identifier("sk_Clockwise"), out)->load(out);
3677 }
3678
3679 // Apply RTFlip to sk_Clockwise.
3680 this->addRTFlipUniform(ref.fPosition);
3681 // Use a uniform to flip the Y coordinate. The new expression will be written in
3682 // terms of $device_Clockwise, which is a fake variable that means "access the
3683 // underlying FrontFacing directly".
3684 static constexpr char DEVICE_CLOCKWISE_NAME[] = "$device_Clockwise";
3685 if (!fProgram.fSymbols->find(DEVICE_CLOCKWISE_NAME)) {
3686 AutoAttachPoolToThread attach(fProgram.fPool.get());
3687 Layout layout;
3688 layout.fBuiltin = DEVICE_CLOCKWISE_BUILTIN;
3689 auto clockwiseVar = Variable::Make(/*pos=*/Position(),
3690 /*modifiersPosition=*/Position(),
3691 layout,
3692 ModifierFlag::kNone,
3693 fContext.fTypes.fBool.get(),
3694 DEVICE_CLOCKWISE_NAME,
3695 /*mangledName=*/"",
3696 /*builtin=*/true,
3697 Variable::Storage::kGlobal);
3698 fProgram.fSymbols->add(fContext, std::move(clockwiseVar));
3699 }
3700 // FrontFacing in Vulkan is defined in terms of a top-down render target. In Skia,
3701 // we use the default convention of "counter-clockwise face is front".
3702
3703 // Compute `positiveRTFlip = (rtFlip.y > 0)`.
3704 std::unique_ptr<Expression> rtFlip = this->identifier(SKSL_RTFLIP_NAME);
3705 SpvId rtFlipY = this->writeSwizzle(*rtFlip, {SwizzleComponent::Y}, out);
3706 SpvId zero = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
3707 SpvId positiveRTFlip = this->writeBinaryExpression(
3708 *fContext.fTypes.fFloat, rtFlipY, OperatorKind::GT,
3709 *fContext.fTypes.fFloat, zero,
3710 *fContext.fTypes.fBool, out);
3711
3712 // Compute `positiveRTFlip ^^ $device_Clockwise`.
3713 std::unique_ptr<Expression> deviceClockwise = this->identifier(DEVICE_CLOCKWISE_NAME);
3714 SpvId deviceClockwiseID = this->writeExpression(*deviceClockwise, out);
3715 return this->writeBinaryExpression(
3716 *fContext.fTypes.fBool, positiveRTFlip, OperatorKind::LOGICALXOR,
3717 *fContext.fTypes.fBool, deviceClockwiseID,
3718 *fContext.fTypes.fBool, out);
3719 }
3720 default: {
3721 // Constant-propagate variables that have a known compile-time value.
3722 if (const Expression* expr = ConstantFolder::GetConstantValueOrNull(ref)) {
3723 return this->writeExpression(*expr, out);
3724 }
3725
3726 // A reference to a sampler variable at global scope with synthesized texture/sampler
3727 // backing should construct a function-scope combined image-sampler from the synthesized
3728 // constituents. This is the case in which a sample intrinsic was invoked.
3729 //
3730 // Variable references to opaque handles (texture/sampler) that appear as the argument
3731 // of a user-defined function call are explicitly handled in writeFunctionCallArgument.
3732 if (fUseTextureSamplerPairs && variable->type().isSampler()) {
3733 if (const auto* p = fSynthesizedSamplerMap.find(variable)) {
3734 SpvId* imgPtr = fVariableMap.find((*p)->fTexture.get());
3735 SpvId* samplerPtr = fVariableMap.find((*p)->fSampler.get());
3736 SkASSERT(imgPtr);
3737 SkASSERT(samplerPtr);
3738
3739 SpvId img = this->writeOpLoad(this->getType((*p)->fTexture->type()),
3740 Precision::kDefault, *imgPtr, out);
3741 SpvId sampler = this->writeOpLoad(this->getType((*p)->fSampler->type()),
3742 Precision::kDefault,
3743 *samplerPtr,
3744 out);
3745 SpvId result = this->nextId(nullptr);
3746 this->writeInstruction(SpvOpSampledImage,
3747 this->getType(variable->type()),
3748 result,
3749 img,
3750 sampler,
3751 out);
3752 return result;
3753 }
3754 SkDEBUGFAIL("sampler missing from fSynthesizedSamplerMap");
3755 }
3756 #ifdef SKSL_EXT
3757 const Variable* var = ref.as<VariableReference>().variable();
3758 if (var && (var->layout().fFlags & SkSL::LayoutFlag::kConstantId)) {
3759 return fVariableMap[var];
3760 }
3761 #endif
3762 return this->getLValue(ref, out)->load(out);
3763 }
3764 }
3765 }
3766
writeIndexExpression(const IndexExpression & expr,OutputStream & out)3767 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
3768 if (expr.base()->type().isVector()) {
3769 SpvId base = this->writeExpression(*expr.base(), out);
3770 SpvId index = this->writeExpression(*expr.index(), out);
3771 SpvId result = this->nextId(nullptr);
3772 this->writeInstruction(SpvOpVectorExtractDynamic, this->getType(expr.type()), result, base,
3773 index, out);
3774 return result;
3775 }
3776 return getLValue(expr, out)->load(out);
3777 }
3778
writeFieldAccess(const FieldAccess & f,OutputStream & out)3779 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
3780 return getLValue(f, out)->load(out);
3781 }
3782
writeSwizzle(const Expression & baseExpr,const ComponentArray & components,OutputStream & out)3783 SpvId SPIRVCodeGenerator::writeSwizzle(const Expression& baseExpr,
3784 const ComponentArray& components,
3785 OutputStream& out) {
3786 size_t count = components.size();
3787 const Type& type = baseExpr.type().componentType().toCompound(fContext, count, /*rows=*/1);
3788 SpvId base = this->writeExpression(baseExpr, out);
3789 if (count == 1) {
3790 return this->writeOpCompositeExtract(type, base, components[0], out);
3791 }
3792
3793 SpvId result = this->nextId(&type);
3794 this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
3795 this->writeWord(this->getType(type), out);
3796 this->writeWord(result, out);
3797 this->writeWord(base, out);
3798 this->writeWord(base, out);
3799 for (int component : components) {
3800 this->writeWord(component, out);
3801 }
3802 return result;
3803 }
3804
writeSwizzle(const Swizzle & swizzle,OutputStream & out)3805 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
3806 return this->writeSwizzle(*swizzle.base(), swizzle.components(), out);
3807 }
3808
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,bool writeComponentwiseIfMatrix,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3809 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, const Type& operandType,
3810 SpvId lhs, SpvId rhs,
3811 bool writeComponentwiseIfMatrix,
3812 SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,
3813 SpvOp_ ifBool, OutputStream& out) {
3814 SpvOp_ op = pick_by_type(operandType, ifFloat, ifInt, ifUInt, ifBool);
3815 if (op == SpvOpUndef) {
3816 fContext.fErrors->error(operandType.fPosition,
3817 "unsupported operand for binary expression: " + operandType.description());
3818 return NA;
3819 }
3820 if (writeComponentwiseIfMatrix && operandType.isMatrix()) {
3821 return this->writeComponentwiseMatrixBinary(resultType, lhs, rhs, op, out);
3822 }
3823 SpvId result = this->nextId(&resultType);
3824 this->writeInstruction(op, this->getType(resultType), result, lhs, rhs, out);
3825 return result;
3826 }
3827
writeBinaryOperationComponentwiseIfMatrix(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3828 SpvId SPIRVCodeGenerator::writeBinaryOperationComponentwiseIfMatrix(const Type& resultType,
3829 const Type& operandType,
3830 SpvId lhs, SpvId rhs,
3831 SpvOp_ ifFloat, SpvOp_ ifInt,
3832 SpvOp_ ifUInt, SpvOp_ ifBool,
3833 OutputStream& out) {
3834 return this->writeBinaryOperation(resultType, operandType, lhs, rhs,
3835 /*writeComponentwiseIfMatrix=*/true,
3836 ifFloat, ifInt, ifUInt, ifBool, out);
3837 }
3838
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3839 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, const Type& operandType,
3840 SpvId lhs, SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
3841 SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
3842 return this->writeBinaryOperation(resultType, operandType, lhs, rhs,
3843 /*writeComponentwiseIfMatrix=*/false,
3844 ifFloat, ifInt, ifUInt, ifBool, out);
3845 }
3846
foldToBool(SpvId id,const Type & operandType,SpvOp op,OutputStream & out)3847 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op,
3848 OutputStream& out) {
3849 if (operandType.isVector()) {
3850 SpvId result = this->nextId(nullptr);
3851 this->writeInstruction(op, this->getType(*fContext.fTypes.fBool), result, id, out);
3852 return result;
3853 }
3854 return id;
3855 }
3856
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,SpvOp_ vectorMergeOperator,SpvOp_ mergeOperator,OutputStream & out)3857 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
3858 SpvOp_ floatOperator, SpvOp_ intOperator,
3859 SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
3860 OutputStream& out) {
3861 SpvOp_ compareOp = is_float(operandType) ? floatOperator : intOperator;
3862 SkASSERT(operandType.isMatrix());
3863 const Type& columnType = operandType.componentType().toCompound(fContext,
3864 operandType.rows(),
3865 1);
3866 SpvId bvecType = this->getType(fContext.fTypes.fBool->toCompound(fContext,
3867 operandType.rows(),
3868 1));
3869 SpvId boolType = this->getType(*fContext.fTypes.fBool);
3870 SpvId result = 0;
3871 for (int i = 0; i < operandType.columns(); i++) {
3872 SpvId columnL = this->writeOpCompositeExtract(columnType, lhs, i, out);
3873 SpvId columnR = this->writeOpCompositeExtract(columnType, rhs, i, out);
3874 SpvId compare = this->nextId(&operandType);
3875 this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
3876 SpvId merge = this->nextId(nullptr);
3877 this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
3878 if (result != 0) {
3879 SpvId next = this->nextId(nullptr);
3880 this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
3881 result = next;
3882 } else {
3883 result = merge;
3884 }
3885 }
3886 return result;
3887 }
3888
writeComponentwiseMatrixUnary(const Type & operandType,SpvId operand,SpvOp_ op,OutputStream & out)3889 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixUnary(const Type& operandType,
3890 SpvId operand,
3891 SpvOp_ op,
3892 OutputStream& out) {
3893 SkASSERT(operandType.isMatrix());
3894 const Type& columnType = operandType.columnType(fContext);
3895 SpvId columnTypeId = this->getType(columnType);
3896
3897 STArray<4, SpvId> columns;
3898 for (int i = 0; i < operandType.columns(); i++) {
3899 SpvId srcColumn = this->writeOpCompositeExtract(columnType, operand, i, out);
3900 SpvId dstColumn = this->nextId(&operandType);
3901 this->writeInstruction(op, columnTypeId, dstColumn, srcColumn, out);
3902 columns.push_back(dstColumn);
3903 }
3904
3905 return this->writeOpCompositeConstruct(operandType, columns, out);
3906 }
3907
writeComponentwiseMatrixBinary(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ op,OutputStream & out)3908 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
3909 SpvId rhs, SpvOp_ op, OutputStream& out) {
3910 SkASSERT(operandType.isMatrix());
3911 const Type& columnType = operandType.columnType(fContext);
3912 SpvId columnTypeId = this->getType(columnType);
3913
3914 STArray<4, SpvId> columns;
3915 for (int i = 0; i < operandType.columns(); i++) {
3916 SpvId columnL = this->writeOpCompositeExtract(columnType, lhs, i, out);
3917 SpvId columnR = this->writeOpCompositeExtract(columnType, rhs, i, out);
3918 columns.push_back(this->nextId(&operandType));
3919 this->writeInstruction(op, columnTypeId, columns[i], columnL, columnR, out);
3920 }
3921 return this->writeOpCompositeConstruct(operandType, columns, out);
3922 }
3923
writeReciprocal(const Type & type,SpvId value,OutputStream & out)3924 SpvId SPIRVCodeGenerator::writeReciprocal(const Type& type, SpvId value, OutputStream& out) {
3925 SkASSERT(type.isFloat());
3926 SpvId one = this->writeLiteral(1.0, type);
3927 SpvId reciprocal = this->nextId(&type);
3928 this->writeInstruction(SpvOpFDiv, this->getType(type), reciprocal, one, value, out);
3929 return reciprocal;
3930 }
3931
splat(const Type & type,SpvId id,OutputStream & out)3932 SpvId SPIRVCodeGenerator::splat(const Type& type, SpvId id, OutputStream& out) {
3933 if (type.isScalar()) {
3934 // Scalars require no additional work; we can return the passed-in ID as is.
3935 } else {
3936 SkASSERT(type.isVector() || type.isMatrix());
3937 bool isMatrix = type.isMatrix();
3938
3939 // Splat the input scalar across a vector.
3940 int vectorSize = (isMatrix ? type.rows() : type.columns());
3941 const Type& vectorType = type.componentType().toCompound(fContext, vectorSize, /*rows=*/1);
3942
3943 STArray<4, SpvId> values;
3944 values.push_back_n(/*n=*/vectorSize, /*t=*/id);
3945 id = this->writeOpCompositeConstruct(vectorType, values, out);
3946
3947 if (isMatrix) {
3948 // Splat the newly-synthesized vector into a matrix.
3949 STArray<4, SpvId> matArguments;
3950 matArguments.push_back_n(/*n=*/type.columns(), /*t=*/id);
3951 id = this->writeOpCompositeConstruct(type, matArguments, out);
3952 }
3953 }
3954
3955 return id;
3956 }
3957
types_match(const Type & a,const Type & b)3958 static bool types_match(const Type& a, const Type& b) {
3959 if (a.matches(b)) {
3960 return true;
3961 }
3962 return (a.typeKind() == b.typeKind()) &&
3963 (a.isScalar() || a.isVector() || a.isMatrix()) &&
3964 (a.columns() == b.columns() && a.rows() == b.rows()) &&
3965 a.componentType().numberKind() == b.componentType().numberKind();
3966 }
3967
writeDecomposedMatrixVectorMultiply(const Type & leftType,SpvId lhs,const Type & rightType,SpvId rhs,const Type & resultType,OutputStream & out)3968 SpvId SPIRVCodeGenerator::writeDecomposedMatrixVectorMultiply(const Type& leftType,
3969 SpvId lhs,
3970 const Type& rightType,
3971 SpvId rhs,
3972 const Type& resultType,
3973 OutputStream& out) {
3974 SpvId sum = NA;
3975 const Type& columnType = leftType.columnType(fContext);
3976 const Type& scalarType = rightType.componentType();
3977
3978 for (int n = 0; n < leftType.rows(); ++n) {
3979 // Extract mat[N] from the matrix.
3980 SpvId matN = this->writeOpCompositeExtract(columnType, lhs, n, out);
3981
3982 // Extract vec[N] from the vector.
3983 SpvId vecN = this->writeOpCompositeExtract(scalarType, rhs, n, out);
3984
3985 // Multiply them together.
3986 SpvId product = this->writeBinaryExpression(columnType, matN, OperatorKind::STAR,
3987 scalarType, vecN,
3988 columnType, out);
3989
3990 // Sum all the components together.
3991 if (sum == NA) {
3992 sum = product;
3993 } else {
3994 sum = this->writeBinaryExpression(columnType, sum, OperatorKind::PLUS,
3995 columnType, product,
3996 columnType, out);
3997 }
3998 }
3999
4000 return sum;
4001 }
4002
writeBinaryExpression(const Type & leftType,SpvId lhs,Operator op,const Type & rightType,SpvId rhs,const Type & resultType,OutputStream & out)4003 SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
4004 const Type& rightType, SpvId rhs,
4005 const Type& resultType, OutputStream& out) {
4006 // The comma operator ignores the type of the left-hand side entirely.
4007 if (op.kind() == Operator::Kind::COMMA) {
4008 return rhs;
4009 }
4010 // overall type we are operating on: float2, int, uint4...
4011 const Type* operandType;
4012 if (types_match(leftType, rightType)) {
4013 operandType = &leftType;
4014 } else {
4015 // IR allows mismatched types in expressions (e.g. float2 * float), but they need special
4016 // handling in SPIR-V
4017 if (leftType.isVector() && rightType.isNumber()) {
4018 if (resultType.componentType().isFloat()) {
4019 switch (op.kind()) {
4020 case Operator::Kind::SLASH: {
4021 rhs = this->writeReciprocal(rightType, rhs, out);
4022 [[fallthrough]];
4023 }
4024 case Operator::Kind::STAR: {
4025 SpvId result = this->nextId(&resultType);
4026 this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
4027 result, lhs, rhs, out);
4028 return result;
4029 }
4030 default:
4031 break;
4032 }
4033 }
4034 // Vectorize the right-hand side.
4035 STArray<4, SpvId> arguments;
4036 arguments.push_back_n(/*n=*/leftType.columns(), /*t=*/rhs);
4037 rhs = this->writeOpCompositeConstruct(leftType, arguments, out);
4038 operandType = &leftType;
4039 } else if (rightType.isVector() && leftType.isNumber()) {
4040 if (resultType.componentType().isFloat()) {
4041 if (op.kind() == Operator::Kind::STAR) {
4042 SpvId result = this->nextId(&resultType);
4043 this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
4044 result, rhs, lhs, out);
4045 return result;
4046 }
4047 }
4048 // Vectorize the left-hand side.
4049 STArray<4, SpvId> arguments;
4050 arguments.push_back_n(/*n=*/rightType.columns(), /*t=*/lhs);
4051 lhs = this->writeOpCompositeConstruct(rightType, arguments, out);
4052 operandType = &rightType;
4053 } else if (leftType.isMatrix()) {
4054 if (op.kind() == Operator::Kind::STAR) {
4055 // When the rewriteMatrixVectorMultiply bit is set, we rewrite medium-precision
4056 // matrix * vector multiplication as (mat[0]*vec[0] + ... + mat[N]*vec[N]).
4057 if (fCaps.fRewriteMatrixVectorMultiply &&
4058 rightType.isVector() &&
4059 !resultType.highPrecision()) {
4060 return this->writeDecomposedMatrixVectorMultiply(leftType, lhs, rightType, rhs,
4061 resultType, out);
4062 }
4063
4064 // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
4065 SpvOp_ spvop;
4066 if (rightType.isMatrix()) {
4067 spvop = SpvOpMatrixTimesMatrix;
4068 } else if (rightType.isVector()) {
4069 spvop = SpvOpMatrixTimesVector;
4070 } else {
4071 SkASSERT(rightType.isScalar());
4072 spvop = SpvOpMatrixTimesScalar;
4073 }
4074 SpvId result = this->nextId(&resultType);
4075 this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
4076 return result;
4077 } else {
4078 // Matrix-op-vector is not supported in GLSL/SkSL for non-multiplication ops; we
4079 // expect to have a scalar here.
4080 SkASSERT(rightType.isScalar());
4081
4082 // Splat rhs across an entire matrix so we can reuse the matrix-op-matrix path.
4083 SpvId rhsMatrix = this->splat(leftType, rhs, out);
4084
4085 // Perform this operation as matrix-op-matrix.
4086 return this->writeBinaryExpression(leftType, lhs, op, leftType, rhsMatrix,
4087 resultType, out);
4088 }
4089 } else if (rightType.isMatrix()) {
4090 if (op.kind() == Operator::Kind::STAR) {
4091 // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
4092 SpvId result = this->nextId(&resultType);
4093 if (leftType.isVector()) {
4094 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType),
4095 result, lhs, rhs, out);
4096 } else {
4097 SkASSERT(leftType.isScalar());
4098 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType),
4099 result, rhs, lhs, out);
4100 }
4101 return result;
4102 } else {
4103 // Vector-op-matrix is not supported in GLSL/SkSL for non-multiplication ops; we
4104 // expect to have a scalar here.
4105 SkASSERT(leftType.isScalar());
4106
4107 // Splat lhs across an entire matrix so we can reuse the matrix-op-matrix path.
4108 SpvId lhsMatrix = this->splat(rightType, lhs, out);
4109
4110 // Perform this operation as matrix-op-matrix.
4111 return this->writeBinaryExpression(rightType, lhsMatrix, op, rightType, rhs,
4112 resultType, out);
4113 }
4114 } else {
4115 fContext.fErrors->error(leftType.fPosition, "unsupported mixed-type expression");
4116 return NA;
4117 }
4118 }
4119
4120 switch (op.kind()) {
4121 case Operator::Kind::EQEQ: {
4122 if (operandType->isMatrix()) {
4123 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
4124 SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
4125 }
4126 if (operandType->isStruct()) {
4127 return this->writeStructComparison(*operandType, lhs, op, rhs, out);
4128 }
4129 if (operandType->isArray()) {
4130 return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
4131 }
4132 SkASSERT(resultType.isBoolean());
4133 const Type* tmpType;
4134 if (operandType->isVector()) {
4135 tmpType = &fContext.fTypes.fBool->toCompound(fContext,
4136 operandType->columns(),
4137 operandType->rows());
4138 } else {
4139 tmpType = &resultType;
4140 }
4141 if (lhs == rhs) {
4142 // This ignores the effects of NaN.
4143 return this->writeOpConstantTrue(*fContext.fTypes.fBool);
4144 }
4145 return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
4146 SpvOpFOrdEqual, SpvOpIEqual,
4147 SpvOpIEqual, SpvOpLogicalEqual, out),
4148 *operandType, SpvOpAll, out);
4149 }
4150 case Operator::Kind::NEQ:
4151 if (operandType->isMatrix()) {
4152 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFUnordNotEqual,
4153 SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
4154 }
4155 if (operandType->isStruct()) {
4156 return this->writeStructComparison(*operandType, lhs, op, rhs, out);
4157 }
4158 if (operandType->isArray()) {
4159 return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
4160 }
4161 [[fallthrough]];
4162 case Operator::Kind::LOGICALXOR:
4163 SkASSERT(resultType.isBoolean());
4164 const Type* tmpType;
4165 if (operandType->isVector()) {
4166 tmpType = &fContext.fTypes.fBool->toCompound(fContext,
4167 operandType->columns(),
4168 operandType->rows());
4169 } else {
4170 tmpType = &resultType;
4171 }
4172 if (lhs == rhs) {
4173 // This ignores the effects of NaN.
4174 return this->writeOpConstantFalse(*fContext.fTypes.fBool);
4175 }
4176 return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
4177 SpvOpFUnordNotEqual, SpvOpINotEqual,
4178 SpvOpINotEqual, SpvOpLogicalNotEqual,
4179 out),
4180 *operandType, SpvOpAny, out);
4181 case Operator::Kind::GT:
4182 SkASSERT(resultType.isBoolean());
4183 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
4184 SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
4185 SpvOpUGreaterThan, SpvOpUndef, out);
4186 case Operator::Kind::LT:
4187 SkASSERT(resultType.isBoolean());
4188 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
4189 SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
4190 case Operator::Kind::GTEQ:
4191 SkASSERT(resultType.isBoolean());
4192 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
4193 SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
4194 SpvOpUGreaterThanEqual, SpvOpUndef, out);
4195 case Operator::Kind::LTEQ:
4196 SkASSERT(resultType.isBoolean());
4197 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
4198 SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
4199 SpvOpULessThanEqual, SpvOpUndef, out);
4200 case Operator::Kind::PLUS:
4201 return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
4202 lhs, rhs, SpvOpFAdd, SpvOpIAdd,
4203 SpvOpIAdd, SpvOpUndef, out);
4204 case Operator::Kind::MINUS:
4205 return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
4206 lhs, rhs, SpvOpFSub, SpvOpISub,
4207 SpvOpISub, SpvOpUndef, out);
4208 case Operator::Kind::STAR:
4209 if (leftType.isMatrix() && rightType.isMatrix()) {
4210 // matrix multiply
4211 SpvId result = this->nextId(&resultType);
4212 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
4213 lhs, rhs, out);
4214 return result;
4215 }
4216 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
4217 SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
4218 case Operator::Kind::SLASH:
4219 return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
4220 lhs, rhs, SpvOpFDiv, SpvOpSDiv,
4221 SpvOpUDiv, SpvOpUndef, out);
4222 case Operator::Kind::PERCENT:
4223 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
4224 SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
4225 case Operator::Kind::SHL:
4226 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
4227 SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
4228 SpvOpUndef, out);
4229 case Operator::Kind::SHR:
4230 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
4231 SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
4232 SpvOpUndef, out);
4233 case Operator::Kind::BITWISEAND:
4234 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
4235 SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
4236 case Operator::Kind::BITWISEOR:
4237 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
4238 SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
4239 case Operator::Kind::BITWISEXOR:
4240 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
4241 SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
4242 default:
4243 fContext.fErrors->error(Position(), "unsupported token");
4244 return NA;
4245 }
4246 }
4247
writeArrayComparison(const Type & arrayType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)4248 SpvId SPIRVCodeGenerator::writeArrayComparison(const Type& arrayType, SpvId lhs, Operator op,
4249 SpvId rhs, OutputStream& out) {
4250 // The inputs must be arrays, and the op must be == or !=.
4251 SkASSERT(op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ);
4252 SkASSERT(arrayType.isArray());
4253 const Type& componentType = arrayType.componentType();
4254 const int arraySize = arrayType.columns();
4255 SkASSERT(arraySize > 0);
4256
4257 // Synthesize equality checks for each item in the array.
4258 const Type& boolType = *fContext.fTypes.fBool;
4259 SpvId allComparisons = NA;
4260 for (int index = 0; index < arraySize; ++index) {
4261 // Get the left and right item in the array.
4262 SpvId itemL = this->writeOpCompositeExtract(componentType, lhs, index, out);
4263 SpvId itemR = this->writeOpCompositeExtract(componentType, rhs, index, out);
4264 // Use `writeBinaryExpression` with the requested == or != operator on these items.
4265 SpvId comparison = this->writeBinaryExpression(componentType, itemL, op,
4266 componentType, itemR, boolType, out);
4267 // Merge this comparison result with all the other comparisons we've done.
4268 allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
4269 }
4270 return allComparisons;
4271 }
4272
writeStructComparison(const Type & structType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)4273 SpvId SPIRVCodeGenerator::writeStructComparison(const Type& structType, SpvId lhs, Operator op,
4274 SpvId rhs, OutputStream& out) {
4275 // The inputs must be structs containing fields, and the op must be == or !=.
4276 SkASSERT(op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ);
4277 SkASSERT(structType.isStruct());
4278 SkSpan<const Field> fields = structType.fields();
4279 SkASSERT(!fields.empty());
4280
4281 // Synthesize equality checks for each field in the struct.
4282 const Type& boolType = *fContext.fTypes.fBool;
4283 SpvId allComparisons = NA;
4284 for (int index = 0; index < (int)fields.size(); ++index) {
4285 // Get the left and right versions of this field.
4286 const Type& fieldType = *fields[index].fType;
4287
4288 SpvId fieldL = this->writeOpCompositeExtract(fieldType, lhs, index, out);
4289 SpvId fieldR = this->writeOpCompositeExtract(fieldType, rhs, index, out);
4290 // Use `writeBinaryExpression` with the requested == or != operator on these fields.
4291 SpvId comparison = this->writeBinaryExpression(fieldType, fieldL, op, fieldType, fieldR,
4292 boolType, out);
4293 // Merge this comparison result with all the other comparisons we've done.
4294 allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
4295 }
4296 return allComparisons;
4297 }
4298
mergeComparisons(SpvId comparison,SpvId allComparisons,Operator op,OutputStream & out)4299 SpvId SPIRVCodeGenerator::mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op,
4300 OutputStream& out) {
4301 // If this is the first entry, we don't need to merge comparison results with anything.
4302 if (allComparisons == NA) {
4303 return comparison;
4304 }
4305 // Use LogicalAnd or LogicalOr to combine the comparison with all the other comparisons.
4306 const Type& boolType = *fContext.fTypes.fBool;
4307 SpvId boolTypeId = this->getType(boolType);
4308 SpvId logicalOp = this->nextId(&boolType);
4309 switch (op.kind()) {
4310 case Operator::Kind::EQEQ:
4311 this->writeInstruction(SpvOpLogicalAnd, boolTypeId, logicalOp,
4312 comparison, allComparisons, out);
4313 break;
4314 case Operator::Kind::NEQ:
4315 this->writeInstruction(SpvOpLogicalOr, boolTypeId, logicalOp,
4316 comparison, allComparisons, out);
4317 break;
4318 default:
4319 SkDEBUGFAILF("mergeComparisons only supports == and !=, not %s", op.operatorName());
4320 return NA;
4321 }
4322 return logicalOp;
4323 }
4324
#ifdef SKSL_EXT
// Emits `lhs <op> rhs` as an OpSpecConstantOp in the constant section, for comparisons that
// involve a specialization constant. Supported operators: ==, !=, <, <=, >, >=.
//
// The opcode is chosen from the left operand's type, with the matching right operand assumed
// to share it:
//  - signed integers use the S* comparisons;
//  - unsigned integers use the U* comparisons;
//  - bools use LogicalEqual / LogicalNotEqual.
// Emitting the unsigned forms for signed operands would mis-order negative values, and IEqual
// is not valid on bools.
//
// An unsupported operator or operand type reports an error and returns NA.
SpvId SPIRVCodeGenerator::writeSpecConstBinaryExpression(const BinaryExpression& b,
                                                         const Operator& op,
                                                         SpvId lhs, SpvId rhs) {
    const Type& operandType = b.left()->type();
    SpvOp_ specOp;
    switch (op.removeAssignment().kind()) {
        case Operator::Kind::EQEQ:
            specOp = pick_by_type(operandType, SpvOpUndef, SpvOpIEqual, SpvOpIEqual,
                                  SpvOpLogicalEqual);
            break;
        case Operator::Kind::NEQ:
            specOp = pick_by_type(operandType, SpvOpUndef, SpvOpINotEqual, SpvOpINotEqual,
                                  SpvOpLogicalNotEqual);
            break;
        case Operator::Kind::LT:
            specOp = pick_by_type(operandType, SpvOpUndef, SpvOpSLessThan, SpvOpULessThan,
                                  SpvOpUndef);
            break;
        case Operator::Kind::LTEQ:
            specOp = pick_by_type(operandType, SpvOpUndef, SpvOpSLessThanEqual,
                                  SpvOpULessThanEqual, SpvOpUndef);
            break;
        case Operator::Kind::GT:
            specOp = pick_by_type(operandType, SpvOpUndef, SpvOpSGreaterThan, SpvOpUGreaterThan,
                                  SpvOpUndef);
            break;
        case Operator::Kind::GTEQ:
            specOp = pick_by_type(operandType, SpvOpUndef, SpvOpSGreaterThanEqual,
                                  SpvOpUGreaterThanEqual, SpvOpUndef);
            break;
        default:
            fContext.fErrors->error(b.fPosition, "spec constant does not support operator: " +
                                    std::string(op.operatorName()));
            return NA;
    }
    if (specOp == SpvOpUndef) {
        fContext.fErrors->error(b.fPosition, "spec constant does not support operand type: " +
                                operandType.description());
        return NA;
    }

    SpvId result = this->nextId(&(b.type()));
    this->writeInstruction(SpvOpSpecConstantOp, this->getType(b.type()), result,
                           specOp, lhs, rhs, fConstantBuffer);
    return result;
}
#endif
4362
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)4363 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
4364 const Expression* left = b.left().get();
4365 const Expression* right = b.right().get();
4366 Operator op = b.getOperator();
4367
4368 switch (op.kind()) {
4369 case Operator::Kind::EQ: {
4370 // Handles assignment.
4371 SpvId rhs = this->writeExpression(*right, out);
4372 this->getLValue(*left, out)->store(rhs, out);
4373 return rhs;
4374 }
4375 case Operator::Kind::LOGICALAND:
4376 // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
4377 return this->writeLogicalAnd(*b.left(), *b.right(), out);
4378
4379 case Operator::Kind::LOGICALOR:
4380 // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
4381 return this->writeLogicalOr(*b.left(), *b.right(), out);
4382
4383 default:
4384 break;
4385 }
4386
4387 std::unique_ptr<LValue> lvalue;
4388 SpvId lhs;
4389 if (op.isAssignment()) {
4390 lvalue = this->getLValue(*left, out);
4391 lhs = lvalue->load(out);
4392 } else {
4393 lvalue = nullptr;
4394 lhs = this->writeExpression(*left, out);
4395 }
4396
4397 SpvId rhs = this->writeExpression(*right, out);
4398 #ifdef SKSL_EXT
4399 if (left->kind() == Expression::Kind::kVariableReference) {
4400 VariableReference* rightRef = (VariableReference*) right;
4401 const Expression* expr = ConstantFolder::GetConstantValueForVariable(*rightRef);
4402 if (expr != rightRef) {
4403 VariableReference* ref = (VariableReference*) left;
4404 const Variable* var = ref->variable();
4405 if (var && (var->layout().fFlags & SkSL::LayoutFlag::kConstantId)) {
4406 return writeSpecConstBinaryExpression(b, op, lhs, rhs);
4407 }
4408 }
4409 }
4410 if (right->kind() == Expression::Kind::kVariableReference) {
4411 VariableReference* leftRef = (VariableReference*) left;
4412 const Expression* expr = ConstantFolder::GetConstantValueForVariable(*leftRef);
4413 if (expr != leftRef) {
4414 VariableReference* ref = (VariableReference*) right;
4415 const Variable* var = ref->variable();
4416 if (var && (var->layout().fFlags & SkSL::LayoutFlag::kConstantId)) {
4417 return writeSpecConstBinaryExpression(b, op, lhs, rhs);
4418 }
4419 }
4420 }
4421 #endif
4422 SpvId result = this->writeBinaryExpression(left->type(), lhs, op.removeAssignment(),
4423 right->type(), rhs, b.type(), out);
4424 if (lvalue) {
4425 lvalue->store(result, out);
4426 }
4427 return result;
4428 }
4429
writeLogicalAnd(const Expression & left,const Expression & right,OutputStream & out)4430 SpvId SPIRVCodeGenerator::writeLogicalAnd(const Expression& left, const Expression& right,
4431 OutputStream& out) {
4432 SpvId falseConstant = this->writeLiteral(0.0, *fContext.fTypes.fBool);
4433 SpvId lhs = this->writeExpression(left, out);
4434
4435 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4436
4437 SpvId rhsLabel = this->nextId(nullptr);
4438 SpvId end = this->nextId(nullptr);
4439 SpvId lhsBlock = fCurrentBlock;
4440 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
4441 this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
4442 this->writeLabel(rhsLabel, kBranchIsOnPreviousLine, out);
4443 SpvId rhs = this->writeExpression(right, out);
4444 SpvId rhsBlock = fCurrentBlock;
4445 this->writeInstruction(SpvOpBranch, end, out);
4446 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4447 SpvId result = this->nextId(nullptr);
4448 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, falseConstant,
4449 lhsBlock, rhs, rhsBlock, out);
4450
4451 return result;
4452 }
4453
writeLogicalOr(const Expression & left,const Expression & right,OutputStream & out)4454 SpvId SPIRVCodeGenerator::writeLogicalOr(const Expression& left, const Expression& right,
4455 OutputStream& out) {
4456 SpvId trueConstant = this->writeLiteral(1.0, *fContext.fTypes.fBool);
4457 SpvId lhs = this->writeExpression(left, out);
4458
4459 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4460
4461 SpvId rhsLabel = this->nextId(nullptr);
4462 SpvId end = this->nextId(nullptr);
4463 SpvId lhsBlock = fCurrentBlock;
4464 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
4465 this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
4466 this->writeLabel(rhsLabel, kBranchIsOnPreviousLine, out);
4467 SpvId rhs = this->writeExpression(right, out);
4468 SpvId rhsBlock = fCurrentBlock;
4469 this->writeInstruction(SpvOpBranch, end, out);
4470 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4471 SpvId result = this->nextId(nullptr);
4472 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, trueConstant,
4473 lhsBlock, rhs, rhsBlock, out);
4474
4475 return result;
4476 }
4477
writeTernaryExpression(const TernaryExpression & t,OutputStream & out)4478 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
4479 const Type& type = t.type();
4480 SpvId test = this->writeExpression(*t.test(), out);
4481 if (t.ifTrue()->type().columns() == 1 &&
4482 Analysis::IsCompileTimeConstant(*t.ifTrue()) &&
4483 Analysis::IsCompileTimeConstant(*t.ifFalse())) {
4484 // both true and false are constants, can just use OpSelect
4485 SpvId result = this->nextId(nullptr);
4486 SpvId trueId = this->writeExpression(*t.ifTrue(), out);
4487 SpvId falseId = this->writeExpression(*t.ifFalse(), out);
4488 this->writeInstruction(SpvOpSelect, this->getType(type), result, test, trueId, falseId,
4489 out);
4490 return result;
4491 }
4492
4493 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4494
4495 // was originally using OpPhi to choose the result, but for some reason that is crashing on
4496 // Adreno. Switched to storing the result in a temp variable as glslang does.
4497 SpvId var = this->nextId(nullptr);
4498 this->writeInstruction(SpvOpVariable, this->getPointerType(type, StorageClass::kFunction),
4499 var, SpvStorageClassFunction, fVariableBuffer);
4500 SpvId trueLabel = this->nextId(nullptr);
4501 SpvId falseLabel = this->nextId(nullptr);
4502 SpvId end = this->nextId(nullptr);
4503 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
4504 this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
4505 this->writeLabel(trueLabel, kBranchIsOnPreviousLine, out);
4506 this->writeOpStore(StorageClass::kFunction, var, this->writeExpression(*t.ifTrue(), out), out);
4507 this->writeInstruction(SpvOpBranch, end, out);
4508 this->writeLabel(falseLabel, kBranchIsAbove, conditionalOps, out);
4509 this->writeOpStore(StorageClass::kFunction, var, this->writeExpression(*t.ifFalse(), out), out);
4510 this->writeInstruction(SpvOpBranch, end, out);
4511 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4512 SpvId result = this->nextId(&type);
4513 this->writeInstruction(SpvOpLoad, this->getType(type), result, var, out);
4514
4515 return result;
4516 }
4517
writePrefixExpression(const PrefixExpression & p,OutputStream & out)4518 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
4519 const Type& type = p.type();
4520 if (p.getOperator().kind() == Operator::Kind::MINUS) {
4521 SpvOp_ negateOp = pick_by_type(type, SpvOpFNegate, SpvOpSNegate, SpvOpSNegate, SpvOpUndef);
4522 SkASSERT(negateOp != SpvOpUndef);
4523 SpvId expr = this->writeExpression(*p.operand(), out);
4524 if (type.isMatrix()) {
4525 return this->writeComponentwiseMatrixUnary(type, expr, negateOp, out);
4526 }
4527 SpvId result = this->nextId(&type);
4528 SpvId typeId = this->getType(type);
4529 this->writeInstruction(negateOp, typeId, result, expr, out);
4530 return result;
4531 }
4532 switch (p.getOperator().kind()) {
4533 case Operator::Kind::PLUS:
4534 return this->writeExpression(*p.operand(), out);
4535
4536 case Operator::Kind::PLUSPLUS: {
4537 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4538 SpvId one = this->writeLiteral(1.0, type.componentType());
4539 one = this->splat(type, one, out);
4540 SpvId result = this->writeBinaryOperationComponentwiseIfMatrix(type, type,
4541 lv->load(out), one,
4542 SpvOpFAdd, SpvOpIAdd,
4543 SpvOpIAdd, SpvOpUndef,
4544 out);
4545 lv->store(result, out);
4546 return result;
4547 }
4548 case Operator::Kind::MINUSMINUS: {
4549 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4550 SpvId one = this->writeLiteral(1.0, type.componentType());
4551 one = this->splat(type, one, out);
4552 SpvId result = this->writeBinaryOperationComponentwiseIfMatrix(type, type,
4553 lv->load(out), one,
4554 SpvOpFSub, SpvOpISub,
4555 SpvOpISub, SpvOpUndef,
4556 out);
4557 lv->store(result, out);
4558 return result;
4559 }
4560 case Operator::Kind::LOGICALNOT: {
4561 SkASSERT(p.operand()->type().isBoolean());
4562 SpvId result = this->nextId(nullptr);
4563 this->writeInstruction(SpvOpLogicalNot, this->getType(type), result,
4564 this->writeExpression(*p.operand(), out), out);
4565 return result;
4566 }
4567 case Operator::Kind::BITWISENOT: {
4568 SpvId result = this->nextId(nullptr);
4569 this->writeInstruction(SpvOpNot, this->getType(type), result,
4570 this->writeExpression(*p.operand(), out), out);
4571 return result;
4572 }
4573 default:
4574 SkDEBUGFAILF("unsupported prefix expression: %s",
4575 p.description(OperatorPrecedence::kExpression).c_str());
4576 return NA;
4577 }
4578 }
4579
writePostfixExpression(const PostfixExpression & p,OutputStream & out)4580 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
4581 const Type& type = p.type();
4582 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4583 SpvId result = lv->load(out);
4584 SpvId one = this->writeLiteral(1.0, type.componentType());
4585 one = this->splat(type, one, out);
4586 switch (p.getOperator().kind()) {
4587 case Operator::Kind::PLUSPLUS: {
4588 SpvId temp = this->writeBinaryOperationComponentwiseIfMatrix(type, type, result, one,
4589 SpvOpFAdd, SpvOpIAdd,
4590 SpvOpIAdd, SpvOpUndef,
4591 out);
4592 lv->store(temp, out);
4593 return result;
4594 }
4595 case Operator::Kind::MINUSMINUS: {
4596 SpvId temp = this->writeBinaryOperationComponentwiseIfMatrix(type, type, result, one,
4597 SpvOpFSub, SpvOpISub,
4598 SpvOpISub, SpvOpUndef,
4599 out);
4600 lv->store(temp, out);
4601 return result;
4602 }
4603 default:
4604 SkDEBUGFAILF("unsupported postfix expression %s",
4605 p.description(OperatorPrecedence::kExpression).c_str());
4606 return NA;
4607 }
4608 }
4609
writeLiteral(const Literal & l)4610 SpvId SPIRVCodeGenerator::writeLiteral(const Literal& l) {
4611 return this->writeLiteral(l.value(), l.type());
4612 }
4613
writeLiteral(double value,const Type & type)4614 SpvId SPIRVCodeGenerator::writeLiteral(double value, const Type& type) {
4615 switch (type.numberKind()) {
4616 case Type::NumberKind::kFloat: {
4617 float floatVal = value;
4618 int32_t valueBits;
4619 memcpy(&valueBits, &floatVal, sizeof(valueBits));
4620 return this->writeOpConstant(type, valueBits);
4621 }
4622 case Type::NumberKind::kBoolean: {
4623 return value ? this->writeOpConstantTrue(type)
4624 : this->writeOpConstantFalse(type);
4625 }
4626 default: {
4627 return this->writeOpConstant(type, (SKSL_INT)value);
4628 }
4629 }
4630 }
4631
writeFunctionStart(const FunctionDeclaration & f,OutputStream & out)4632 void SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
4633 SpvId result = fFunctionMap[{&f, fActiveSpecializationIndex}];
4634 SpvId returnTypeId = this->getType(f.returnType());
4635 SpvId functionTypeId = this->getFunctionType(f);
4636 this->writeInstruction(SpvOpFunction, returnTypeId, result,
4637 SpvFunctionControlMaskNone, functionTypeId, out);
4638 std::string mangledName = f.mangledName();
4639
4640 // For specialized functions, tack on `_param1_param2` to the function name.
4641 Analysis::GetParameterMappingsForFunction(
4642 f, fSpecializationInfo, fActiveSpecializationIndex,
4643 [&](int, const Variable*, const Expression* expr) {
4644 std::string name = expr->description();
4645 std::replace_if(name.begin(), name.end(), [](char c) { return !isalnum(c); }, '_');
4646
4647 mangledName += "_" + name;
4648 });
4649
4650 this->writeInstruction(SpvOpName,
4651 result,
4652 std::string_view(mangledName.c_str(), mangledName.size()),
4653 fNameBuffer);
4654 for (const Variable* parameter : f.parameters()) {
4655 const Variable* specializedVar = nullptr;
4656 if (fActiveSpecialization) {
4657 if (const Expression** specializedExpr = fActiveSpecialization->find(parameter)) {
4658 if ((*specializedExpr)->is<FieldAccess>()) {
4659 continue;
4660 }
4661 SkASSERT((*specializedExpr)->is<VariableReference>());
4662 specializedVar = (*specializedExpr)->as<VariableReference>().variable();
4663 }
4664 }
4665
4666 if (fUseTextureSamplerPairs && parameter->type().isSampler()) {
4667 auto [texture, sampler] = this->synthesizeTextureAndSampler(*parameter);
4668
4669 SpvId textureId = this->nextId(nullptr);
4670 fVariableMap.set(texture, textureId);
4671
4672 SpvId textureType = this->getFunctionParameterType(texture->type(), texture->layout());
4673 this->writeInstruction(SpvOpFunctionParameter, textureType, textureId, out);
4674
4675 if (specializedVar) {
4676 const auto* p = fSynthesizedSamplerMap.find(specializedVar);
4677 SkASSERT(p);
4678 const SpvId* uniformId = fVariableMap.find((*p)->fSampler.get());
4679 SkASSERT(uniformId);
4680 fVariableMap.set(sampler, *uniformId);
4681 } else {
4682 SpvId samplerId = this->nextId(nullptr);
4683 fVariableMap.set(sampler, samplerId);
4684
4685 SpvId samplerType =
4686 this->getFunctionParameterType(sampler->type(), kDefaultTypeLayout);
4687 this->writeInstruction(SpvOpFunctionParameter, samplerType, samplerId, out);
4688 }
4689 } else {
4690 if (specializedVar) {
4691 const SpvId* uniformId = fVariableMap.find(specializedVar);
4692 SkASSERT(uniformId);
4693 fVariableMap.set(parameter, *uniformId);
4694 } else {
4695 SpvId id = this->nextId(nullptr);
4696 fVariableMap.set(parameter, id);
4697
4698 SpvId type = this->getFunctionParameterType(parameter->type(), parameter->layout());
4699 this->writeInstruction(SpvOpFunctionParameter, type, id, out);
4700 }
4701 }
4702 }
4703 }
4704
writeFunction(const FunctionDefinition & f,OutputStream & out)4705 void SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
4706 if (const Analysis::Specializations* specializations =
4707 fSpecializationInfo.fSpecializationMap.find(&f.declaration())) {
4708 for (int i = 0; i < specializations->size(); i++) {
4709 this->writeFunctionInstantiation(f, i, &specializations->at(i), out);
4710 }
4711 } else {
4712 this->writeFunctionInstantiation(f,
4713 Analysis::kUnspecialized,
4714 /*specializedParams=*/nullptr,
4715 out);
4716 }
4717 }
4718
// Emits one complete SPIR-V function (OpFunction ... OpFunctionEnd) for `f`. When
// `specializedParams` is non-null, this is the instantiation at `specializationIndex`.
// Otherwise the caller passes Analysis::kUnspecialized.
void SPIRVCodeGenerator::writeFunctionInstantiation(
        const FunctionDefinition& f,
        Analysis::SpecializationIndex specializationIndex,
        const Analysis::SpecializedParameters* specializedParams,
        OutputStream& out) {
    // Snapshot the conditional-op bookkeeping so everything recorded while writing this function
    // is discarded by pruneConditionalOps() at the end.
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    // fVariableBuffer collects this function's local OpVariables (writeVarDeclaration writes there).
    fVariableBuffer.reset();
    // The active specialization is read by writeFunctionStart and stays set while the body is
    // written. It is cleared again once the body has been generated.
    fActiveSpecialization = specializedParams;
    fActiveSpecializationIndex = specializationIndex;
    this->writeFunctionStart(f.declaration(), out);
    fCurrentBlock = 0;
    this->writeLabel(this->nextId(nullptr), kBranchlessBlock, out);
    // The body goes to a side buffer so the locals gathered in fVariableBuffer can be written to
    // `out` first, directly after the entry label.
    StringStream bodyBuffer;
    this->writeBlock(f.body()->as<Block>(), bodyBuffer);
    fActiveSpecialization = nullptr;
    fActiveSpecializationIndex = Analysis::kUnspecialized;
    write_stringstream(fVariableBuffer, out);
    if (f.declaration().isMain()) {
        // Global variable initializers (collected by writeGlobalVarDeclaration) run at the top
        // of main().
        write_stringstream(fGlobalInitializersBuffer, out);
    }
    write_stringstream(bodyBuffer, out);
    // If the last block is still open (nothing terminated it), terminate it here. A void function
    // returns; a non-void one has no value to return, so the block is marked unreachable.
    if (fCurrentBlock) {
        if (f.declaration().returnType().isVoid()) {
            this->writeInstruction(SpvOpReturn, out);
        } else {
            this->writeInstruction(SpvOpUnreachable, out);
        }
    }
    this->writeInstruction(SpvOpFunctionEnd, out);
    this->pruneConditionalOps(conditionalOps);
}
4751
writeLayout(const Layout & layout,SpvId target,Position pos)4752 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, Position pos) {
4753 bool isPushConstant = SkToBool(layout.fFlags & LayoutFlag::kPushConstant);
4754 if (layout.fLocation >= 0) {
4755 this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
4756 fDecorationBuffer);
4757 }
4758 if (layout.fBinding >= 0) {
4759 if (isPushConstant) {
4760 fContext.fErrors->error(pos, "Can't apply 'binding' to push constants");
4761 } else {
4762 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
4763 fDecorationBuffer);
4764 }
4765 }
4766 if (layout.fIndex >= 0) {
4767 this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
4768 fDecorationBuffer);
4769 }
4770 if (layout.fSet >= 0) {
4771 if (isPushConstant) {
4772 fContext.fErrors->error(pos, "Can't apply 'set' to push constants");
4773 } else {
4774 this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
4775 fDecorationBuffer);
4776 }
4777 }
4778 if (layout.fInputAttachmentIndex >= 0) {
4779 this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
4780 layout.fInputAttachmentIndex, fDecorationBuffer);
4781 fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
4782 }
4783 if (layout.fBuiltin >= 0 && (layout.fBuiltin != SK_FRAGCOLOR_BUILTIN &&
4784 layout.fBuiltin != SK_SECONDARYFRAGCOLOR_BUILTIN)) {
4785 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
4786 fDecorationBuffer);
4787 }
4788 }
4789
writeFieldLayout(const Layout & layout,SpvId target,int member)4790 void SPIRVCodeGenerator::writeFieldLayout(const Layout& layout, SpvId target, int member) {
4791 // 'binding' and 'set' can not be applied to struct members
4792 SkASSERT(layout.fBinding == -1);
4793 SkASSERT(layout.fSet == -1);
4794 if (layout.fLocation >= 0) {
4795 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
4796 layout.fLocation, fDecorationBuffer);
4797 }
4798 if (layout.fIndex >= 0) {
4799 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
4800 layout.fIndex, fDecorationBuffer);
4801 }
4802 if (layout.fInputAttachmentIndex >= 0) {
4803 this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
4804 layout.fInputAttachmentIndex, fDecorationBuffer);
4805 }
4806 if (layout.fBuiltin >= 0) {
4807 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
4808 layout.fBuiltin, fDecorationBuffer);
4809 }
4810 }
4811
memoryLayoutForStorageClass(StorageClass storageClass)4812 MemoryLayout SPIRVCodeGenerator::memoryLayoutForStorageClass(StorageClass storageClass) {
4813 return storageClass == StorageClass::kPushConstant ||
4814 storageClass == StorageClass::kStorageBuffer
4815 ? MemoryLayout(MemoryLayout::Standard::k430)
4816 : fDefaultMemoryLayout;
4817 }
4818
memoryLayoutForVariable(const Variable & v) const4819 MemoryLayout SPIRVCodeGenerator::memoryLayoutForVariable(const Variable& v) const {
4820 bool pushConstant = SkToBool(v.layout().fFlags & LayoutFlag::kPushConstant);
4821 bool buffer = v.modifierFlags().isBuffer();
4822 return pushConstant || buffer ? MemoryLayout(MemoryLayout::Standard::k430)
4823 : fDefaultMemoryLayout;
4824 }
4825
// Declares the interface block `intf` as a decorated struct type plus an OpVariable (written to
// fConstantBuffer), maps the block's variable to that OpVariable in fVariableMap, and returns its
// SpvId.
//
// When the program needs the RTFlip uniform, `appendRTFlip` is true, RTFlip hasn't been written
// yet, and the block type is a struct: a copy of the block gains an extra float2 RTFlip field,
// that copy is written instead (recursively, with appendRTFlip=false), and the original variable
// is mapped to the copy's id.
//
// If the block's type is not supported by its memory layout, an error is reported and a fresh,
// otherwise unused id is returned.
SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip) {
    // std430 for push-constant/buffer blocks, otherwise the default layout.
    MemoryLayout memoryLayout = this->memoryLayoutForVariable(*intf.var());
    // NOTE(review): in the RTFlip path below this id is replaced by the recursive call's result
    // and never used.
    SpvId result = this->nextId(nullptr);
    const Variable& intfVar = *intf.var();
    const Type& type = intfVar.type();
    if (!memoryLayout.isSupported(type)) {
        fContext.fErrors->error(type.fPosition, "type '" + type.displayName() +
                                                "' is not permitted here");
        return this->nextId(nullptr);
    }
    StorageClass storageClass =
            get_storage_class_for_global_variable(intfVar, StorageClass::kFunction);
    if (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None && appendRTFlip &&
        !fWroteRTFlip && type.isStruct()) {
        // We can only have one interface block (because we use push_constant and that is limited
        // to one per program), so we need to append rtflip to this one rather than synthesize an
        // entirely new block when the variable is referenced. And we can't modify the existing
        // block, so we instead create a modified copy of it and write that.
        SkSpan<const Field> fieldSpan = type.fields();
        TArray<Field> fields(fieldSpan.data(), fieldSpan.size());
        // The appended field is placed at the configured RTFlip offset.
        fields.emplace_back(Position(),
                            Layout(LayoutFlag::kNone,
                                   /*location=*/-1,
                                   fProgram.fConfig->fSettings.fRTFlipOffset,
                                   /*binding=*/-1,
                                   /*index=*/-1,
                                   /*set=*/-1,
                                   /*builtin=*/-1,
                                   /*inputAttachmentIndex=*/-1),
                            ModifierFlag::kNone,
                            SKSL_RTFLIP_NAME,
                            fContext.fTypes.fFloat2.get());
        {
            // The synthesized struct type and variable are owned by the program's symbol table;
            // the program's pool is attached to this thread while they are created.
            AutoAttachPoolToThread attach(fProgram.fPool.get());
            const Type* rtFlipStructType = fProgram.fSymbols->takeOwnershipOfSymbol(
                    Type::MakeStructType(fContext,
                                         type.fPosition,
                                         type.name(),
                                         std::move(fields),
                                         /*interfaceBlock=*/true));
            Variable* modifiedVar = fProgram.fSymbols->takeOwnershipOfSymbol(
                    Variable::Make(intfVar.fPosition,
                                   intfVar.modifiersPosition(),
                                   intfVar.layout(),
                                   intfVar.modifierFlags(),
                                   rtFlipStructType,
                                   intfVar.name(),
                                   /*mangledName=*/"",
                                   intfVar.isBuiltin(),
                                   intfVar.storage()));
            InterfaceBlock modifiedCopy(intf.fPosition, modifiedVar);
            result = this->writeInterfaceBlock(modifiedCopy, /*appendRTFlip=*/false);
            // Register a FieldSymbol for the appended (last) field of the copied struct.
            fProgram.fSymbols->add(fContext, std::make_unique<FieldSymbol>(
                    Position(), modifiedVar, rtFlipStructType->fields().size() - 1));
        }
        // References to the original variable resolve to the modified copy's OpVariable.
        fVariableMap.set(&intfVar, result);
        fWroteRTFlip = true;
        return result;
    }
    SpvId typeId = this->getType(type, kDefaultTypeLayout, memoryLayout);
    // Builtin blocks are not given a Block/BufferBlock decoration.
    if (intfVar.layout().fBuiltin == -1) {
        // Note: In SPIR-V 1.3, a storage buffer can be declared with the "StorageBuffer"
        // storage class and the "Block" decoration and the <1.3 approach we use here ("Uniform"
        // storage class and the "BufferBlock" decoration) is deprecated. Since we target SPIR-V
        // 1.0, we have to use the deprecated approach which is well supported in Vulkan and
        // addresses SkSL use cases (notably SkSL currently doesn't support pointer features that
        // would benefit from SPV_KHR_variable_pointers capabilities).
#ifdef SKSL_EXT
        this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
#else
        bool isStorageBuffer = intfVar.modifierFlags().isBuffer();
        this->writeInstruction(SpvOpDecorate,
                               typeId,
                               isStorageBuffer ? SpvDecorationBufferBlock : SpvDecorationBlock,
                               fDecorationBuffer);
#endif
    }
    SpvId ptrType = this->nextId(nullptr);
    this->writeInstruction(SpvOpTypePointer, ptrType,
                           get_storage_class_spv_id(storageClass), typeId, fConstantBuffer);
    this->writeInstruction(SpvOpVariable, ptrType, result,
                           get_storage_class_spv_id(storageClass), fConstantBuffer);
    // Uniform and storage-buffer blocks without an explicit `set` get the default uniform set.
    Layout layout = intfVar.layout();
    if ((storageClass == StorageClass::kUniform ||
         storageClass == StorageClass::kStorageBuffer) && layout.fSet < 0) {
        layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
    }
    this->writeLayout(layout, result, intfVar.fPosition);
    fVariableMap.set(&intfVar, result);
    return result;
}
4917
4918 // This function determines whether to skip an OpVariable (of pointer type) declaration for
4919 // compile-time constant scalars and vectors which we turn into OpConstant/OpConstantComposite and
4920 // always reference by value.
4921 //
4922 // Accessing a matrix or array member with a dynamic index requires the use of OpAccessChain which
4923 // requires a base operand of pointer type. However, a vector can always be accessed by value using
4924 // OpVectorExtractDynamic (see writeIndexExpression).
4925 //
4926 // This is why we always emit an OpVariable for all non-scalar and non-vector types in case they get
4927 // accessed via a dynamic index.
is_vardecl_compile_time_constant(const VarDeclaration & varDecl)4928 static bool is_vardecl_compile_time_constant(const VarDeclaration& varDecl) {
4929 return varDecl.var()->modifierFlags().isConst() &&
4930 #ifdef SKSL_EXT
4931 (varDecl.var()->layout().fFlags & LayoutFlag::kConstantId) == LayoutFlag::kNone &&
4932 #endif
4933 (varDecl.var()->type().isScalar() || varDecl.var()->type().isVector()) &&
4934 (ConstantFolder::GetConstantValueOrNull(*varDecl.value()) ||
4935 Analysis::IsCompileTimeConstant(*varDecl.value()));
4936 }
4937
// Handles a global variable declaration. Depending on the variable, it either:
//  - emits nothing now (compile-time constants, which are materialized at their use sites),
//  - defers the declaration (top-level uniforms / storage buffers go to fTopLevelUniforms),
//  - emits a spec constant (SKSL_EXT `constant_id`),
//  - splits a sampler into a texture global plus a sampler global (texture/sampler pairs), or
//  - declares an OpVariable, plus an initializer store in fGlobalInitializersBuffer.
// Returns false after reporting an error, true otherwise.
bool SPIRVCodeGenerator::writeGlobalVarDeclaration(ProgramKind kind,
                                                   const VarDeclaration& varDecl) {
    const Variable* var = varDecl.var();
    // Only the Vulkan, WebGPU, and Direct3D backend layout flags are accepted by this generator.
    const LayoutFlags backendFlags = var->layout().fFlags & LayoutFlag::kAllBackends;
    const LayoutFlags kPermittedBackendFlags =
            LayoutFlag::kVulkan | LayoutFlag::kWebGPU | LayoutFlag::kDirect3D;
    if (backendFlags & ~kPermittedBackendFlags) {
        fContext.fErrors->error(var->fPosition, "incompatible backend flag in SPIR-V codegen");
        return false;
    }

    // If this global variable is a compile-time constant then we'll emit OpConstant or
    // OpConstantComposite later when the variable is referenced. Avoid declaring an OpVariable now.
    if (is_vardecl_compile_time_constant(varDecl)) {
        return true;
    }

    StorageClass storageClass =
            get_storage_class_for_global_variable(*var, StorageClass::kPrivate);
    if (storageClass == StorageClass::kUniform || storageClass == StorageClass::kStorageBuffer) {
        // Top-level uniforms are emitted in writeUniformBuffer.
        fTopLevelUniforms.push_back(&varDecl);
        return true;
    }
#ifdef SKSL_EXT
    // Specialization constants: emitted as OpSpecConstant with a SpecId decoration. Only integer
    // types with an initializer are accepted.
    if (var->layout().fFlags & LayoutFlag::kConstantId) {
        Layout layout = var->layout();
        const Type& type = var->type();
        SpvId id = this->nextId(&type);
        fVariableMap[var] = id;
        SpvId typeId = this->getType(type);
        if (type.isInteger() && varDecl.value()) {
            // NOTE(review): assumes the initializer is a Literal; `as<Literal>()` is not checked
            // here for other integer expressions — confirm the front end guarantees this.
            int tmp = (*varDecl.value()).as<Literal>().intValue();
            this->writeInstruction(SpvOpSpecConstant, typeId, id, tmp, fConstantBuffer);
        } else {
            fContext.fErrors->error(var->fPosition,
                "spec const '" + std::string(var->name()) + "' must be an integer literal");
            return false;
        }
        this->writeInstruction(SpvOpName, id, var->name(), fNameBuffer);
        this->writeInstruction(SpvOpDecorate, id, SpvDecorationSpecId, layout.fConstantId,
                               fDecorationBuffer);
        return true;
    }
#endif
    // With texture/sampler pairs, a combined sampler global becomes two globals: a texture and a
    // sampler. Both need explicit `texture` and `sampler` layout indices.
    if (fUseTextureSamplerPairs && var->type().isSampler()) {
        if (var->layout().fTexture == -1 || var->layout().fSampler == -1) {
            fContext.fErrors->error(var->fPosition, "selected backend requires separate texture "
                                                    "and sampler indices");
            return false;
        }
        SkASSERT(storageClass == StorageClass::kUniformConstant);

        auto [texture, sampler] = this->synthesizeTextureAndSampler(*var);
        this->writeGlobalVar(kind, storageClass, *texture);
        this->writeGlobalVar(kind, storageClass, *sampler);

        return true;
    }
#ifdef SKSL_EXT
    // Function-storage globals: no OpVariable is declared. The initializer is evaluated into
    // fConstantBuffer, and the resulting value id is remembered in fGlobalConstVariableValueMap.
    if (storageClass == StorageClass::kFunction) {
        SkASSERT(varDecl.value());
        // Result unused; presumably called so the pointer type gets declared — TODO confirm.
        this->getPointerType(var->type(), storageClass);
        fEmittingGlobalConstConstructor = true;
        SpvId value = this->writeExpression(*varDecl.value(), fConstantBuffer);
        fEmittingGlobalConstConstructor = false;
        fGlobalConstVariableValueMap[var] = value;
    } else {
#endif
        SpvId id = this->writeGlobalVar(kind, storageClass, *var);
        // Initializers are written to fGlobalInitializersBuffer, which writeFunctionInstantiation
        // emits at the start of main(). writeGlobalVar returns NA for variables it skips.
        if (id != NA && varDecl.value()) {
            SkASSERT(!fCurrentBlock);
            // Temporarily mark a current block while writing the initializer — presumably because
            // expression/store writers expect to be inside a block; verify against writeOpStore.
            fCurrentBlock = NA;
            SpvId value = this->writeExpression(*varDecl.value(), fGlobalInitializersBuffer);
            this->writeOpStore(storageClass, id, value, fGlobalInitializersBuffer);
            fCurrentBlock = 0;
        }
#ifdef SKSL_EXT
    }
#endif
    return true;
}
5020
writeGlobalVar(ProgramKind kind,StorageClass storageClass,const Variable & var)5021 SpvId SPIRVCodeGenerator::writeGlobalVar(ProgramKind kind,
5022 StorageClass storageClass,
5023 const Variable& var) {
5024 Layout layout = var.layout();
5025 const ModifierFlags flags = var.modifierFlags();
5026 const Type* type = &var.type();
5027 switch (layout.fBuiltin) {
5028 case SK_FRAGCOLOR_BUILTIN:
5029 case SK_SECONDARYFRAGCOLOR_BUILTIN:
5030 if (!ProgramConfig::IsFragment(kind)) {
5031 SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
5032 return NA;
5033 }
5034 break;
5035
5036 case SK_SAMPLEMASKIN_BUILTIN:
5037 case SK_SAMPLEMASK_BUILTIN:
5038 // SkSL exposes this as a `uint` but SPIR-V, like GLSL, uses an array of signed `uint`
5039 // decorated with SpvBuiltinSampleMask.
5040 type = fSynthetics.addArrayDimension(fContext, type, /*arraySize=*/1);
5041 layout.fBuiltin = SpvBuiltInSampleMask;
5042 break;
5043 }
5044
5045 // Add this global to the variable map.
5046 SpvId id = this->nextId(type);
5047 fVariableMap.set(&var, id);
5048
5049 if (layout.fSet < 0 && storageClass == StorageClass::kUniformConstant) {
5050 layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
5051 }
5052
5053 SpvId typeId = this->getPointerType(*type,
5054 layout,
5055 this->memoryLayoutForStorageClass(storageClass),
5056 storageClass);
5057 this->writeInstruction(SpvOpVariable, typeId, id,
5058 get_storage_class_spv_id(storageClass), fConstantBuffer);
5059 this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer);
5060 this->writeLayout(layout, id, var.fPosition);
5061 if (flags & ModifierFlag::kFlat) {
5062 this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
5063 }
5064 if (flags & ModifierFlag::kNoPerspective) {
5065 this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
5066 fDecorationBuffer);
5067 }
5068 if (flags.isWriteOnly()) {
5069 this->writeInstruction(SpvOpDecorate, id, SpvDecorationNonReadable, fDecorationBuffer);
5070 } else if (flags.isReadOnly()) {
5071 this->writeInstruction(SpvOpDecorate, id, SpvDecorationNonWritable, fDecorationBuffer);
5072 }
5073
5074 return id;
5075 }
5076
writeVarDeclaration(const VarDeclaration & varDecl,OutputStream & out)5077 void SPIRVCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl, OutputStream& out) {
5078 // If this variable is a compile-time constant then we'll emit OpConstant or
5079 // OpConstantComposite later when the variable is referenced. Avoid declaring an OpVariable now.
5080 if (is_vardecl_compile_time_constant(varDecl)) {
5081 return;
5082 }
5083
5084 const Variable* var = varDecl.var();
5085 SpvId id = this->nextId(&var->type());
5086 fVariableMap.set(var, id);
5087 SpvId type = this->getPointerType(var->type(), StorageClass::kFunction);
5088 this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
5089 this->writeInstruction(SpvOpName, id, var->name(), fNameBuffer);
5090 if (varDecl.value()) {
5091 SpvId value = this->writeExpression(*varDecl.value(), out);
5092 this->writeOpStore(StorageClass::kFunction, id, value, out);
5093 }
5094 }
5095
writeStatement(const Statement & s,OutputStream & out)5096 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
5097 switch (s.kind()) {
5098 case Statement::Kind::kNop:
5099 break;
5100 case Statement::Kind::kBlock:
5101 this->writeBlock(s.as<Block>(), out);
5102 break;
5103 case Statement::Kind::kExpression:
5104 this->writeExpression(*s.as<ExpressionStatement>().expression(), out);
5105 break;
5106 case Statement::Kind::kReturn:
5107 this->writeReturnStatement(s.as<ReturnStatement>(), out);
5108 break;
5109 case Statement::Kind::kVarDeclaration:
5110 this->writeVarDeclaration(s.as<VarDeclaration>(), out);
5111 break;
5112 case Statement::Kind::kIf:
5113 this->writeIfStatement(s.as<IfStatement>(), out);
5114 break;
5115 case Statement::Kind::kFor:
5116 this->writeForStatement(s.as<ForStatement>(), out);
5117 break;
5118 case Statement::Kind::kDo:
5119 this->writeDoStatement(s.as<DoStatement>(), out);
5120 break;
5121 case Statement::Kind::kSwitch:
5122 this->writeSwitchStatement(s.as<SwitchStatement>(), out);
5123 break;
5124 case Statement::Kind::kBreak:
5125 this->writeInstruction(SpvOpBranch, fBreakTarget.back(), out);
5126 break;
5127 case Statement::Kind::kContinue:
5128 this->writeInstruction(SpvOpBranch, fContinueTarget.back(), out);
5129 break;
5130 case Statement::Kind::kDiscard:
5131 this->writeInstruction(SpvOpKill, out);
5132 break;
5133 default:
5134 SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
5135 break;
5136 }
5137 }
5138
writeBlock(const Block & b,OutputStream & out)5139 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
5140 for (const std::unique_ptr<Statement>& stmt : b.children()) {
5141 this->writeStatement(*stmt, out);
5142 }
5143 }
5144
getConditionalOpCounts()5145 SPIRVCodeGenerator::ConditionalOpCounts SPIRVCodeGenerator::getConditionalOpCounts() {
5146 return {fReachableOps.size(), fStoreOps.size()};
5147 }
5148
pruneConditionalOps(ConditionalOpCounts ops)5149 void SPIRVCodeGenerator::pruneConditionalOps(ConditionalOpCounts ops) {
5150 // Remove ops which are no longer reachable.
5151 while (fReachableOps.size() > ops.numReachableOps) {
5152 SpvId prunableSpvId = fReachableOps.back();
5153 const Instruction* prunableOp = fSpvIdCache.find(prunableSpvId);
5154
5155 if (prunableOp) {
5156 fOpCache.remove(*prunableOp);
5157 fSpvIdCache.remove(prunableSpvId);
5158 } else {
5159 SkDEBUGFAIL("reachable-op list contains unrecognized SpvId");
5160 }
5161
5162 fReachableOps.pop_back();
5163 }
5164
5165 // Remove any cached stores that occurred during the conditional block.
5166 while (fStoreOps.size() > ops.numStoreOps) {
5167 if (fStoreCache.find(fStoreOps.back())) {
5168 fStoreCache.remove(fStoreOps.back());
5169 }
5170 fStoreOps.pop_back();
5171 }
5172 }
5173
// Emits an `if` as a structured selection: OpSelectionMerge + OpBranchConditional. With an `else`,
// both arms branch to a separate `end` merge block. Without one, the false label itself serves as
// the merge block. Each arm gets a closing OpBranch only if it did not already terminate its
// block (fCurrentBlock is zero after a terminator).
void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
    SpvId test = this->writeExpression(*stmt.test(), out);
    SpvId ifTrue = this->nextId(nullptr);
    SpvId ifFalse = this->nextId(nullptr);

    // Snapshot taken after the test is evaluated; labels reached from a branch above receive it
    // so ops recorded inside an arm can be rolled back.
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    if (stmt.ifFalse()) {
        SpvId end = this->nextId(nullptr);
        this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
        this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
        this->writeLabel(ifTrue, kBranchIsOnPreviousLine, out);
        this->writeStatement(*stmt.ifTrue(), out);
        if (fCurrentBlock) {
            this->writeInstruction(SpvOpBranch, end, out);
        }
        this->writeLabel(ifFalse, kBranchIsAbove, conditionalOps, out);
        this->writeStatement(*stmt.ifFalse(), out);
        if (fCurrentBlock) {
            this->writeInstruction(SpvOpBranch, end, out);
        }
        this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    } else {
        // No else-arm: the false target doubles as the merge block.
        this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
        this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
        this->writeLabel(ifTrue, kBranchIsOnPreviousLine, out);
        this->writeStatement(*stmt.ifTrue(), out);
        if (fCurrentBlock) {
            this->writeInstruction(SpvOpBranch, ifFalse, out);
        }
        this->writeLabel(ifFalse, kBranchIsAbove, conditionalOps, out);
    }
}
5207
// Emits a `for` loop as a structured loop with this block layout:
//   header: OpLoopMerge(end, next) -> start
//   start:  evaluate the test (if any) and branch to body or end
//   body:   the loop statement, falling through to next
//   next:   the `next` expression (the continue target), then a back-edge to header
//   end:    the merge block (the break target)
// The initializer is written before the loop, in the enclosing block.
void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
    if (f.initializer()) {
        this->writeStatement(*f.initializer(), out);
    }

    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
    // in the context of linear straight-line execution. If we wanted to be more clever, we could
    // only invalidate store cache entries for variables affected by the loop body, but for now we
    // simply clear the entire cache whenever branching occurs.
    // NOTE(review): the visible code passes `conditionalOps` to writeLabel rather than clearing a
    // cache here; confirm the comment above still matches writeLabel's behavior.
    SpvId header = this->nextId(nullptr);
    SpvId start = this->nextId(nullptr);
    SpvId body = this->nextId(nullptr);
    SpvId next = this->nextId(nullptr);
    fContinueTarget.push_back(next);
    SpvId end = this->nextId(nullptr);
    fBreakTarget.push_back(end);
    this->writeInstruction(SpvOpBranch, header, out);
    this->writeLabel(header, kBranchIsBelow, conditionalOps, out);
    this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
    this->writeInstruction(SpvOpBranch, start, out);
    this->writeLabel(start, kBranchIsOnPreviousLine, out);
    if (f.test()) {
        SpvId test = this->writeExpression(*f.test(), out);
        this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
    } else {
        // No test: the loop only exits via `break` or `return`.
        this->writeInstruction(SpvOpBranch, body, out);
    }
    this->writeLabel(body, kBranchIsOnPreviousLine, out);
    this->writeStatement(*f.statement(), out);
    if (fCurrentBlock) {
        this->writeInstruction(SpvOpBranch, next, out);
    }
    this->writeLabel(next, kBranchIsAbove, conditionalOps, out);
    if (f.next()) {
        this->writeExpression(*f.next(), out);
    }
    this->writeInstruction(SpvOpBranch, header, out);
    this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    fBreakTarget.pop_back();
    fContinueTarget.pop_back();
}
5251
// Emits a `do { ... } while (test)` loop as a structured loop with this block layout:
//   header:         OpLoopMerge(end, continueTarget) -> start
//   start:          the loop statement; if still open, -> next -> continueTarget
//   continueTarget: evaluate the test, branch back to header or out to end
//   end:            the merge block (the break target)
// `continue` inside the body branches to continueTarget, so the test is still evaluated.
void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
    // in the context of linear straight-line execution. If we wanted to be more clever, we could
    // only invalidate store cache entries for variables affected by the loop body, but for now we
    // simply clear the entire cache whenever branching occurs.
    // NOTE(review): the visible code passes `conditionalOps` to writeLabel rather than clearing a
    // cache here; confirm the comment above still matches writeLabel's behavior.
    SpvId header = this->nextId(nullptr);
    SpvId start = this->nextId(nullptr);
    SpvId next = this->nextId(nullptr);
    SpvId continueTarget = this->nextId(nullptr);
    fContinueTarget.push_back(continueTarget);
    SpvId end = this->nextId(nullptr);
    fBreakTarget.push_back(end);
    this->writeInstruction(SpvOpBranch, header, out);
    this->writeLabel(header, kBranchIsBelow, conditionalOps, out);
    this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out);
    this->writeInstruction(SpvOpBranch, start, out);
    this->writeLabel(start, kBranchIsOnPreviousLine, out);
    this->writeStatement(*d.statement(), out);
    // The intermediate `next` block is only emitted when the body falls off its end; otherwise
    // its id goes unused.
    if (fCurrentBlock) {
        this->writeInstruction(SpvOpBranch, next, out);
        this->writeLabel(next, kBranchIsOnPreviousLine, out);
        this->writeInstruction(SpvOpBranch, continueTarget, out);
    }
    this->writeLabel(continueTarget, kBranchIsAbove, conditionalOps, out);
    SpvId test = this->writeExpression(*d.test(), out);
    this->writeInstruction(SpvOpBranchConditional, test, header, end, out);
    this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    fBreakTarget.pop_back();
    fContinueTarget.pop_back();
}
5284
writeSwitchStatement(const SwitchStatement & s,OutputStream & out)5285 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
5286 SpvId value = this->writeExpression(*s.value(), out);
5287
5288 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
5289
5290 // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
5291 // in the context of linear straight-line execution. If we wanted to be more clever, we could
5292 // only invalidate store cache entries for variables affected by the switch body, but for now we
5293 // simply clear the entire cache whenever branching occurs.
5294 TArray<SpvId> labels;
5295 SpvId end = this->nextId(nullptr);
5296 SpvId defaultLabel = end;
5297 fBreakTarget.push_back(end);
5298 int size = 3;
5299 const StatementArray& cases = s.cases();
5300 for (const std::unique_ptr<Statement>& stmt : cases) {
5301 const SwitchCase& c = stmt->as<SwitchCase>();
5302 SpvId label = this->nextId(nullptr);
5303 labels.push_back(label);
5304 if (!c.isDefault()) {
5305 size += 2;
5306 } else {
5307 defaultLabel = label;
5308 }
5309 }
5310
5311 // We should have exactly one label for each case.
5312 SkASSERT(labels.size() == cases.size());
5313
5314 // Collapse adjacent switch-cases into one; that is, reduce `case 1: case 2: case 3:` into a
5315 // single OpLabel. The Tint SPIR-V reader does not support switch-case fallthrough, but it
5316 // does support multiple switch-cases branching to the same label.
5317 SkBitSet caseIsCollapsed(cases.size());
5318 for (int index = cases.size() - 2; index >= 0; index--) {
5319 if (cases[index]->as<SwitchCase>().statement()->isEmpty()) {
5320 caseIsCollapsed.set(index);
5321 labels[index] = labels[index + 1];
5322 }
5323 }
5324
5325 labels.push_back(end);
5326
5327 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
5328 this->writeOpCode(SpvOpSwitch, size, out);
5329 this->writeWord(value, out);
5330 this->writeWord(defaultLabel, out);
5331 for (int i = 0; i < cases.size(); ++i) {
5332 const SwitchCase& c = cases[i]->as<SwitchCase>();
5333 if (c.isDefault()) {
5334 continue;
5335 }
5336 this->writeWord(c.value(), out);
5337 this->writeWord(labels[i], out);
5338 }
5339 for (int i = 0; i < cases.size(); ++i) {
5340 if (caseIsCollapsed.test(i)) {
5341 continue;
5342 }
5343 const SwitchCase& c = cases[i]->as<SwitchCase>();
5344 if (i == 0) {
5345 this->writeLabel(labels[i], kBranchIsOnPreviousLine, out);
5346 } else {
5347 this->writeLabel(labels[i], kBranchIsAbove, conditionalOps, out);
5348 }
5349 this->writeStatement(*c.statement(), out);
5350 if (fCurrentBlock) {
5351 this->writeInstruction(SpvOpBranch, labels[i + 1], out);
5352 }
5353 }
5354 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
5355 fBreakTarget.pop_back();
5356 }
5357
writeReturnStatement(const ReturnStatement & r,OutputStream & out)5358 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
5359 if (r.expression()) {
5360 this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.expression(), out),
5361 out);
5362 } else {
5363 this->writeInstruction(SpvOpReturn, out);
5364 }
5365 }
5366
5367 // Given any function, returns the top-level symbol table (OUTSIDE of the function's scope).
get_top_level_symbol_table(const FunctionDeclaration & anyFunc)5368 static SymbolTable* get_top_level_symbol_table(const FunctionDeclaration& anyFunc) {
5369 return anyFunc.definition()->body()->as<Block>().symbolTable()->fParent;
5370 }
5371
writeEntrypointAdapter(const FunctionDeclaration & main)5372 SPIRVCodeGenerator::EntrypointAdapter SPIRVCodeGenerator::writeEntrypointAdapter(
5373 const FunctionDeclaration& main) {
5374 // Our goal is to synthesize a tiny helper function which looks like this:
5375 // void _entrypoint() { sk_FragColor = main(); }
5376
5377 // Fish a symbol table out of main().
5378 SymbolTable* symbolTable = get_top_level_symbol_table(main);
5379
5380 // Get `sk_FragColor` as a writable reference.
5381 const Symbol* skFragColorSymbol = symbolTable->find("sk_FragColor");
5382 SkASSERT(skFragColorSymbol);
5383 const Variable& skFragColorVar = skFragColorSymbol->as<Variable>();
5384 auto skFragColorRef = std::make_unique<VariableReference>(Position(), &skFragColorVar,
5385 VariableReference::RefKind::kWrite);
5386
5387 // TODO get secondary frag color as one as well?
5388
5389 // Synthesize a call to the `main()` function.
5390 if (!main.returnType().matches(skFragColorRef->type())) {
5391 fContext.fErrors->error(main.fPosition, "SPIR-V does not support returning '" +
5392 main.returnType().description() + "' from main()");
5393 return {};
5394 }
5395 ExpressionArray args;
5396 if (main.parameters().size() == 1) {
5397 if (!main.parameters()[0]->type().matches(*fContext.fTypes.fFloat2)) {
5398 fContext.fErrors->error(main.fPosition,
5399 "SPIR-V does not support parameter of type '" +
5400 main.parameters()[0]->type().description() + "' to main()");
5401 return {};
5402 }
5403 double kZero[2] = {0.0, 0.0};
5404 args.push_back(ConstructorCompound::MakeFromConstants(fContext, Position{},
5405 *fContext.fTypes.fFloat2, kZero));
5406 }
5407 auto callMainFn = FunctionCall::Make(fContext, Position(), &main.returnType(),
5408 main, std::move(args));
5409
5410 // Synthesize `skFragColor = main()` as a BinaryExpression.
5411 auto assignmentStmt = std::make_unique<ExpressionStatement>(std::make_unique<BinaryExpression>(
5412 Position(),
5413 std::move(skFragColorRef),
5414 Operator::Kind::EQ,
5415 std::move(callMainFn),
5416 &main.returnType()));
5417
5418 // Function bodies are always wrapped in a Block.
5419 StatementArray entrypointStmts;
5420 entrypointStmts.push_back(std::move(assignmentStmt));
5421 auto entrypointBlock = Block::Make(Position(), std::move(entrypointStmts),
5422 Block::Kind::kBracedScope, /*symbols=*/nullptr);
5423 // Declare an entrypoint function.
5424 EntrypointAdapter adapter;
5425 adapter.entrypointDecl =
5426 std::make_unique<FunctionDeclaration>(fContext,
5427 Position(),
5428 ModifierFlag::kNone,
5429 "_entrypoint",
5430 /*parameters=*/TArray<Variable*>{},
5431 /*returnType=*/fContext.fTypes.fVoid.get(),
5432 kNotIntrinsic);
5433 // Define it.
5434 adapter.entrypointDef = FunctionDefinition::Convert(fContext,
5435 Position(),
5436 *adapter.entrypointDecl,
5437 std::move(entrypointBlock));
5438
5439 adapter.entrypointDecl->setDefinition(adapter.entrypointDef.get());
5440 return adapter;
5441 }
5442
writeUniformBuffer(SymbolTable * topLevelSymbolTable)5443 void SPIRVCodeGenerator::writeUniformBuffer(SymbolTable* topLevelSymbolTable) {
5444 SkASSERT(!fTopLevelUniforms.empty());
5445 static constexpr char kUniformBufferName[] = "_UniformBuffer";
5446
5447 // Convert the list of top-level uniforms into a matching struct named _UniformBuffer, and build
5448 // a lookup table of variables to UniformBuffer field indices.
5449 TArray<Field> fields;
5450 fields.reserve_exact(fTopLevelUniforms.size());
5451 for (const VarDeclaration* topLevelUniform : fTopLevelUniforms) {
5452 const Variable* var = topLevelUniform->var();
5453 fTopLevelUniformMap.set(var, (int)fields.size());
5454 ModifierFlags flags = var->modifierFlags() & ~ModifierFlag::kUniform;
5455 fields.emplace_back(var->fPosition, var->layout(), flags, var->name(), &var->type());
5456 }
5457 fUniformBuffer.fStruct = Type::MakeStructType(fContext,
5458 Position(),
5459 kUniformBufferName,
5460 std::move(fields),
5461 /*interfaceBlock=*/true);
5462
5463 // Create a global variable to contain this struct.
5464 Layout layout;
5465 layout.fBinding = fProgram.fConfig->fSettings.fDefaultUniformBinding;
5466 layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
5467
5468 fUniformBuffer.fInnerVariable = Variable::Make(/*pos=*/Position(),
5469 /*modifiersPosition=*/Position(),
5470 layout,
5471 ModifierFlag::kUniform,
5472 fUniformBuffer.fStruct.get(),
5473 kUniformBufferName,
5474 /*mangledName=*/"",
5475 /*builtin=*/false,
5476 Variable::Storage::kGlobal);
5477
5478 // Create an interface block object for this global variable.
5479 fUniformBuffer.fInterfaceBlock =
5480 std::make_unique<InterfaceBlock>(Position(), fUniformBuffer.fInnerVariable.get());
5481
5482 // Generate an interface block and hold onto its ID.
5483 fUniformBufferId = this->writeInterfaceBlock(*fUniformBuffer.fInterfaceBlock);
5484 }
5485
addRTFlipUniform(Position pos)5486 void SPIRVCodeGenerator::addRTFlipUniform(Position pos) {
5487 SkASSERT(!fProgram.fConfig->fSettings.fForceNoRTFlip);
5488
5489 if (fWroteRTFlip) {
5490 return;
5491 }
5492 // Flip variable hasn't been written yet. This means we don't have an existing
5493 // interface block, so we're free to just synthesize one.
5494 fWroteRTFlip = true;
5495 TArray<Field> fields;
5496 if (fProgram.fConfig->fSettings.fRTFlipOffset < 0) {
5497 fContext.fErrors->error(pos, "RTFlipOffset is negative");
5498 }
5499 fields.emplace_back(pos,
5500 Layout(LayoutFlag::kNone,
5501 /*location=*/-1,
5502 fProgram.fConfig->fSettings.fRTFlipOffset,
5503 /*binding=*/-1,
5504 /*index=*/-1,
5505 /*set=*/-1,
5506 /*builtin=*/-1,
5507 /*inputAttachmentIndex=*/-1),
5508 ModifierFlag::kNone,
5509 SKSL_RTFLIP_NAME,
5510 fContext.fTypes.fFloat2.get());
5511 std::string_view name = "sksl_synthetic_uniforms";
5512 const Type* intfStruct = fSynthetics.takeOwnershipOfSymbol(Type::MakeStructType(
5513 fContext, Position(), name, std::move(fields), /*interfaceBlock=*/true));
5514 bool usePushConstants = fProgram.fConfig->fSettings.fUseVulkanPushConstantsForGaneshRTAdjust;
5515 int binding = -1, set = -1;
5516 if (!usePushConstants) {
5517 binding = fProgram.fConfig->fSettings.fRTFlipBinding;
5518 if (binding == -1) {
5519 fContext.fErrors->error(pos, "layout(binding=...) is required in SPIR-V");
5520 }
5521 set = fProgram.fConfig->fSettings.fRTFlipSet;
5522 if (set == -1) {
5523 fContext.fErrors->error(pos, "layout(set=...) is required in SPIR-V");
5524 }
5525 }
5526 Layout layout(/*flags=*/usePushConstants ? LayoutFlag::kPushConstant : LayoutFlag::kNone,
5527 /*location=*/-1,
5528 /*offset=*/-1,
5529 binding,
5530 /*index=*/-1,
5531 set,
5532 /*builtin=*/-1,
5533 /*inputAttachmentIndex=*/-1);
5534 Variable* intfVar =
5535 fSynthetics.takeOwnershipOfSymbol(Variable::Make(/*pos=*/Position(),
5536 /*modifiersPosition=*/Position(),
5537 layout,
5538 ModifierFlag::kUniform,
5539 intfStruct,
5540 name,
5541 /*mangledName=*/"",
5542 /*builtin=*/false,
5543 Variable::Storage::kGlobal));
5544 {
5545 AutoAttachPoolToThread attach(fProgram.fPool.get());
5546 fProgram.fSymbols->add(fContext,
5547 std::make_unique<FieldSymbol>(Position(), intfVar, /*field=*/0));
5548 }
5549 InterfaceBlock intf(Position(), intfVar);
5550 this->writeInterfaceBlock(intf, false);
5551 }
5552
synthesizeTextureAndSampler(const Variable & combinedSampler)5553 std::tuple<const Variable*, const Variable*> SPIRVCodeGenerator::synthesizeTextureAndSampler(
5554 const Variable& combinedSampler) {
5555 SkASSERT(fUseTextureSamplerPairs);
5556 SkASSERT(combinedSampler.type().typeKind() == Type::TypeKind::kSampler);
5557
5558 if (std::unique_ptr<SynthesizedTextureSamplerPair>* existingData =
5559 fSynthesizedSamplerMap.find(&combinedSampler)) {
5560 return {(*existingData)->fTexture.get(), (*existingData)->fSampler.get()};
5561 }
5562
5563 auto data = std::make_unique<SynthesizedTextureSamplerPair>();
5564
5565 Layout texLayout = combinedSampler.layout();
5566 texLayout.fBinding = texLayout.fTexture;
5567 data->fTextureName = std::string(combinedSampler.name()) + "_texture";
5568
5569 auto texture = Variable::Make(/*pos=*/Position(),
5570 /*modifiersPosition=*/Position(),
5571 texLayout,
5572 combinedSampler.modifierFlags(),
5573 &combinedSampler.type().textureType(),
5574 data->fTextureName,
5575 /*mangledName=*/"",
5576 /*builtin=*/false,
5577 Variable::Storage::kGlobal);
5578
5579 Layout samplerLayout = combinedSampler.layout();
5580 samplerLayout.fBinding = samplerLayout.fSampler;
5581 samplerLayout.fFlags &= ~LayoutFlag::kAllPixelFormats;
5582 data->fSamplerName = std::string(combinedSampler.name()) + "_sampler";
5583
5584 auto sampler = Variable::Make(/*pos=*/Position(),
5585 /*modifiersPosition=*/Position(),
5586 samplerLayout,
5587 combinedSampler.modifierFlags(),
5588 fContext.fTypes.fSampler.get(),
5589 data->fSamplerName,
5590 /*mangledName=*/"",
5591 /*builtin=*/false,
5592 Variable::Storage::kGlobal);
5593
5594 const Variable* t = texture.get();
5595 const Variable* s = sampler.get();
5596 data->fTexture = std::move(texture);
5597 data->fSampler = std::move(sampler);
5598 fSynthesizedSamplerMap.set(&combinedSampler, std::move(data));
5599
5600 return {t, s};
5601 }
5602
writeInstructions(const Program & program,OutputStream & out)5603 void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
5604 Analysis::FindFunctionsToSpecialize(program, &fSpecializationInfo, [](const Variable& param) {
5605 return param.type().isSampler() || param.type().isUnsizedArray();
5606 });
5607
5608 fGLSLExtendedInstructions = this->nextId(nullptr);
5609 StringStream body;
5610
5611 // Do an initial pass over the program elements to establish some baseline info.
5612 const FunctionDeclaration* main = nullptr;
5613 int localSizeX = 1, localSizeY = 1, localSizeZ = 1;
5614 Position combinedSamplerPos;
5615 Position separateSamplerPos;
5616 for (const ProgramElement* e : program.elements()) {
5617 switch (e->kind()) {
5618 case ProgramElement::Kind::kFunction: {
5619 // Assign SpvIds to functions.
5620 const FunctionDefinition& funcDef = e->as<FunctionDefinition>();
5621 const FunctionDeclaration& funcDecl = funcDef.declaration();
5622 if (const Analysis::Specializations* specializations =
5623 fSpecializationInfo.fSpecializationMap.find(&funcDecl)) {
5624 for (int i = 0; i < specializations->size(); i++) {
5625 fFunctionMap.set({&funcDecl, i}, this->nextId(nullptr));
5626 }
5627 } else {
5628 fFunctionMap.set({&funcDecl, Analysis::kUnspecialized}, this->nextId(nullptr));
5629 }
5630 if (funcDecl.isMain()) {
5631 main = &funcDecl;
5632 }
5633 break;
5634 }
5635 case ProgramElement::Kind::kGlobalVar: {
5636 // Look for sampler variables and determine whether or not this program uses
5637 // combined samplers or separate samplers. The layout backend will be marked as
5638 // WebGPU for separate samplers, or Vulkan for combined samplers.
5639 const GlobalVarDeclaration& decl = e->as<GlobalVarDeclaration>();
5640 const Variable& var = *decl.varDeclaration().var();
5641 if (var.type().isSampler()) {
5642 if (var.layout().fFlags & LayoutFlag::kVulkan) {
5643 combinedSamplerPos = decl.position();
5644 }
5645 if (var.layout().fFlags & (LayoutFlag::kWebGPU | LayoutFlag::kDirect3D)) {
5646 separateSamplerPos = decl.position();
5647 }
5648 }
5649 break;
5650 }
5651 case ProgramElement::Kind::kModifiers: {
5652 // If this is a compute program, collect the local-size values. Dimensions that are
5653 // not present will be assigned a value of 1.
5654 if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
5655 const ModifiersDeclaration& modifiers = e->as<ModifiersDeclaration>();
5656 if (modifiers.layout().fLocalSizeX >= 0) {
5657 localSizeX = modifiers.layout().fLocalSizeX;
5658 }
5659 if (modifiers.layout().fLocalSizeY >= 0) {
5660 localSizeY = modifiers.layout().fLocalSizeY;
5661 }
5662 if (modifiers.layout().fLocalSizeZ >= 0) {
5663 localSizeZ = modifiers.layout().fLocalSizeZ;
5664 }
5665 }
5666 break;
5667 }
5668 default:
5669 break;
5670 }
5671 }
5672
5673 // Make sure we have a main() function.
5674 if (!main) {
5675 fContext.fErrors->error(Position(), "program does not contain a main() function");
5676 return;
5677 }
5678 // Make sure our program's sampler usage is consistent.
5679 if (combinedSamplerPos.valid() && separateSamplerPos.valid()) {
5680 fContext.fErrors->error(Position(), "programs cannot contain a mixture of sampler types");
5681 fContext.fErrors->error(combinedSamplerPos, "combined sampler found here:");
5682 fContext.fErrors->error(separateSamplerPos, "separate sampler found here:");
5683 return;
5684 }
5685 fUseTextureSamplerPairs = separateSamplerPos.valid();
5686
5687 // Emit interface blocks.
5688 std::set<SpvId> interfaceVars;
5689 for (const ProgramElement* e : program.elements()) {
5690 if (e->is<InterfaceBlock>()) {
5691 const InterfaceBlock& intf = e->as<InterfaceBlock>();
5692 SpvId id = this->writeInterfaceBlock(intf);
5693
5694 if ((intf.var()->modifierFlags() & (ModifierFlag::kIn | ModifierFlag::kOut)) &&
5695 intf.var()->layout().fBuiltin == -1) {
5696 interfaceVars.insert(id);
5697 }
5698 }
5699 }
5700 // If MustDeclareFragmentFrontFacing is set, the front-facing flag (sk_Clockwise) needs to be
5701 // explicitly declared in the output, whether or not the program explicitly references it.
5702 // However, if the program naturally declares it, we don't want to include it a second time;
5703 // we keep track of the real global variable declarations to see if sk_Clockwise is emitted.
5704 const VarDeclaration* missingClockwiseDecl = nullptr;
5705 if (fCaps.fMustDeclareFragmentFrontFacing) {
5706 if (const Symbol* clockwise = program.fSymbols->findBuiltinSymbol("sk_Clockwise")) {
5707 missingClockwiseDecl = clockwise->as<Variable>().varDeclaration();
5708 }
5709 }
5710 // Emit global variable declarations.
5711 for (const ProgramElement* e : program.elements()) {
5712 if (e->is<GlobalVarDeclaration>()) {
5713 const VarDeclaration& decl = e->as<GlobalVarDeclaration>().varDeclaration();
5714 if (!this->writeGlobalVarDeclaration(program.fConfig->fKind, decl)) {
5715 return;
5716 }
5717 if (missingClockwiseDecl == &decl) {
5718 // We emitted an sk_Clockwise declaration naturally, so we don't need a workaround.
5719 missingClockwiseDecl = nullptr;
5720 }
5721 }
5722 }
5723 // All the global variables have been declared. If sk_Clockwise was not naturally included in
5724 // the output, but MustDeclareFragmentFrontFacing was set, we need to bodge it in ourselves.
5725 if (missingClockwiseDecl) {
5726 if (!this->writeGlobalVarDeclaration(program.fConfig->fKind, *missingClockwiseDecl)) {
5727 return;
5728 }
5729 missingClockwiseDecl = nullptr;
5730 }
5731 // Emit top-level uniforms into a dedicated uniform buffer.
5732 if (!fTopLevelUniforms.empty()) {
5733 this->writeUniformBuffer(get_top_level_symbol_table(*main));
5734 }
5735 // If main() returns a half4, synthesize a tiny entrypoint function which invokes the real
5736 // main() and stores the result into sk_FragColor.
5737 EntrypointAdapter adapter;
5738 if (main->returnType().matches(*fContext.fTypes.fHalf4)) {
5739 adapter = this->writeEntrypointAdapter(*main);
5740 if (adapter.entrypointDecl) {
5741 fFunctionMap.set({adapter.entrypointDecl.get(), Analysis::kUnspecialized},
5742 this->nextId(nullptr));
5743 this->writeFunction(*adapter.entrypointDef, body);
5744 main = adapter.entrypointDecl.get();
5745 }
5746 }
5747 // Emit all the functions.
5748 for (const ProgramElement* e : program.elements()) {
5749 if (e->is<FunctionDefinition>()) {
5750 this->writeFunction(e->as<FunctionDefinition>(), body);
5751 }
5752 }
5753 // Add global in/out variables to the list of interface variables.
5754 for (const auto& [var, spvId] : fVariableMap) {
5755 if (var->storage() == Variable::Storage::kGlobal &&
5756 #ifdef SKSL_EXT
5757 !(var->layout().fFlags & SkSL::LayoutFlag::kConstantId) &&
5758 ((var->modifierFlags() == ModifierFlag::kNone) ||
5759 (var->modifierFlags() & (
5760 ModifierFlag::kIn |
5761 ModifierFlag::kOut |
5762 ModifierFlag::kUniform |
5763 ModifierFlag::kBuffer)))) {
5764 #else
5765 (var->modifierFlags() & (ModifierFlag::kIn | ModifierFlag::kOut))) {
5766 #endif
5767 interfaceVars.insert(spvId);
5768 }
5769 }
5770 this->writeCapabilities(out);
5771 #ifdef SKSL_EXT
5772 this->writeExtensions(out);
5773 #endif
5774 this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
5775 this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
5776 this->writeOpCode(SpvOpEntryPoint,
5777 (SpvId)(3 + (main->name().length() + 4) / 4) + (int32_t)interfaceVars.size(),
5778 out);
5779 if (ProgramConfig::IsVertex(program.fConfig->fKind)) {
5780 this->writeWord(SpvExecutionModelVertex, out);
5781 } else if (ProgramConfig::IsFragment(program.fConfig->fKind)) {
5782 this->writeWord(SpvExecutionModelFragment, out);
5783 } else if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
5784 this->writeWord(SpvExecutionModelGLCompute, out);
5785 } else {
5786 SK_ABORT("cannot write this kind of program to SPIR-V\n");
5787 }
5788 const Analysis::SpecializedFunctionKey mainKey{main, Analysis::kUnspecialized};
5789 SpvId entryPoint = fFunctionMap[mainKey];
5790 this->writeWord(entryPoint, out);
5791 this->writeString(main->name(), out);
5792 for (int var : interfaceVars) {
5793 this->writeWord(var, out);
5794 }
5795 if (ProgramConfig::IsFragment(program.fConfig->fKind)) {
5796 this->writeInstruction(SpvOpExecutionMode,
5797 fFunctionMap[mainKey],
5798 SpvExecutionModeOriginUpperLeft,
5799 out);
5800 } else if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
5801 this->writeInstruction(SpvOpExecutionMode,
5802 fFunctionMap[mainKey],
5803 SpvExecutionModeLocalSize,
5804 localSizeX, localSizeY, localSizeZ,
5805 out);
5806 }
5807 for (const ProgramElement* e : program.elements()) {
5808 if (e->is<Extension>()) {
5809 this->writeInstruction(SpvOpSourceExtension, e->as<Extension>().name(), out);
5810 }
5811 }
5812
5813 write_stringstream(fNameBuffer, out);
5814 write_stringstream(fDecorationBuffer, out);
5815 write_stringstream(fConstantBuffer, out);
5816 write_stringstream(body, out);
5817 }
5818
5819 bool SPIRVCodeGenerator::generateCode() {
5820 SkASSERT(!fContext.fErrors->errorCount());
5821 this->writeWord(SpvMagicNumber, *fOut);
5822 this->writeWord(SpvVersion, *fOut);
5823 this->writeWord(SKSL_MAGIC, *fOut);
5824 StringStream buffer;
5825 this->writeInstructions(fProgram, buffer);
5826 this->writeWord(fIdCount, *fOut);
5827 this->writeWord(0, *fOut); // reserved, always zero
5828 write_stringstream(buffer, *fOut);
5829 return fContext.fErrors->errorCount() == 0;
5830 }
5831
5832 bool ToSPIRV(Program& program,
5833 const ShaderCaps* caps,
5834 OutputStream& out,
5835 ValidateSPIRVProc validateSPIRV) {
5836 TRACE_EVENT0("skia.shaders", "SkSL::ToSPIRV");
5837 SkASSERT(caps != nullptr);
5838
5839 program.fContext->fErrors->setSource(*program.fSource);
5840 bool result;
5841 if (validateSPIRV) {
5842 StringStream buffer;
5843 SPIRVCodeGenerator cg(program.fContext.get(), caps, &program, &buffer);
5844 result = cg.generateCode();
5845
5846 if (result && program.fConfig->fSettings.fValidateSPIRV) {
5847 std::string_view binary = buffer.str();
5848 result = validateSPIRV(*program.fContext->fErrors, binary);
5849 out.write(binary.data(), binary.size());
5850 }
5851 } else {
5852 SPIRVCodeGenerator cg(program.fContext.get(), caps, &program, &out);
5853 result = cg.generateCode();
5854 }
5855 program.fContext->fErrors->setSource(std::string_view());
5856
5857 return result;
5858 }
5859
5860 bool ToSPIRV(Program& program,
5861 const ShaderCaps* caps,
5862 std::string* out,
5863 ValidateSPIRVProc validateSPIRV) {
5864 StringStream buffer;
5865 if (!ToSPIRV(program, caps, buffer, validateSPIRV)) {
5866 return false;
5867 }
5868 *out = buffer.str();
5869 return true;
5870 }
5871
5872 } // namespace SkSL
5873